├── LICENSE ├── Makefile ├── README.md ├── configs ├── default_config.py ├── overfit_kitti_mf_gt.yaml ├── train_demon_mf_gt.yaml ├── train_kitti_mf_gt.yaml ├── train_kitti_mf_selfsup.yaml ├── train_nyu_mf_gt.yaml ├── train_rgbd_mf_gt.yaml ├── train_scannet_mf_gt_view2.yaml ├── train_scannet_mf_gt_view3.yaml ├── train_scannet_mf_gt_view5.yaml ├── train_scannet_mf_selfsup_view3.yaml ├── train_scannet_mf_selfsup_view5.yaml ├── train_scene11_mf_gt.yaml ├── train_sun3d_mf_gt.yaml └── train_video_mf_selfsup_out_random.yaml ├── docker └── Dockerfile ├── dro_sfm ├── __init__.py ├── datasets │ ├── __init__.py │ ├── augmentations.py │ ├── demon_dataset.py │ ├── demon_mf_dataset.py │ ├── dgp_dataset.py │ ├── image_dataset.py │ ├── kitti_dataset.py │ ├── kitti_dataset_utils.py │ ├── nyu_dataset_processed.py │ ├── nyu_dataset_test_processed.py │ ├── scannet_banet_dataset.py │ ├── scannet_dataset.py │ ├── scannet_test_dataset.py │ ├── transforms.py │ ├── video_dataset.py │ └── video_random_dataset.py ├── geometry │ ├── __init__.py │ ├── camera.py │ ├── camera_utils.py │ ├── pose.py │ ├── pose_trans.py │ └── pose_utils.py ├── loggers │ ├── __init__.py │ └── wandb_logger.py ├── losses │ ├── __init__.py │ ├── loss_base.py │ ├── multiview_photometric_loss_mf.py │ └── supervised_loss.py ├── models │ ├── SelfSupModelMF.py │ ├── SemiSupModelMF.py │ ├── SfmModel.py │ ├── SfmModelMF.py │ ├── SupModelMF.py │ ├── __init__.py │ ├── model_checkpoint.py │ ├── model_utils.py │ └── model_wrapper.py ├── networks │ ├── __init__.py │ ├── depth_pose │ │ └── DepthPoseNet.py │ ├── layers │ │ ├── PercepNet.py │ │ └── resnet │ │ │ ├── depth_decoder.py │ │ │ ├── layers.py │ │ │ ├── pose_decoder.py │ │ │ ├── pose_res_decoder.py │ │ │ └── resnet_encoder.py │ └── optim │ │ ├── __init__.py │ │ ├── extractor.py │ │ └── update.py ├── trainers │ ├── __init__.py │ ├── base_trainer.py │ └── horovod_trainer.py └── utils │ ├── __init__.py │ ├── config.py │ ├── depth.py │ ├── horovod.py │ ├── image.py │ ├── image_gt.py │ ├── load.py │ ├── logging.py │ ├── misc.py │ ├── reduce.py │ ├── save.py │ └── types.py ├── media └── figs │ ├── demo_kitti.gif │ └── demo_scannet.gif ├── run.sh └── scripts ├── eval.py ├── evaluate_depth_maps.py ├── infer.py ├── infer_pose.py ├── infer_pose.sh ├── infer_video.py ├── train.py └── vis.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Alibaba Cloud 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Handy commands: 2 | # - `make docker-build`: builds DOCKERIMAGE (default: `dro-sfm:latest`) 3 | PROJECT ?= dro-sfm 4 | WORKSPACE ?= /workspace/$(PROJECT) 5 | DOCKER_IMAGE ?= ${PROJECT}:latest 6 | 7 | SHMSIZE ?= 32G 8 | WANDB_MODE ?= run 9 | DOCKER_OPTS := \ 10 | --name ${PROJECT} \ 11 | -it \ 12 | --shm-size=${SHMSIZE} \ 13 | -e AWS_DEFAULT_REGION \ 14 | -e AWS_ACCESS_KEY_ID \ 15 | -e AWS_SECRET_ACCESS_KEY \ 16 | -e WANDB_API_KEY \ 17 | -e WANDB_ENTITY \ 18 | -e WANDB_MODE \ 19 | -e HOST_HOSTNAME= \ 20 | -e OMP_NUM_THREADS=1 -e KMP_AFFINITY="granularity=fine,compact,1,0" \ 21 | -e OMPI_ALLOW_RUN_AS_ROOT=1 \ 22 | -e OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1 \ 23 | -e NCCL_DEBUG=VERSION \ 24 | -e DISPLAY=${DISPLAY} \ 25 | -e XAUTHORITY \ 26 | -e NVIDIA_DRIVER_CAPABILITIES=all \ 27 | -v ~/.aws:/root/.aws \ 28 | -v /root/.ssh:/root/.ssh \ 29 | -v ~/.cache:/root/.cache \ 30 | -v /data:/data \ 31 | -v /data0:/data0 \ 32 | -v /mnt:/mnt \ 33 | -v /mnt/fsx/:/mnt/fsx \ 34 | -v /dev/null:/dev/raw1394 \ 35 | -v /tmp:/tmp \ 36 | -v /tmp/.X11-unix/X0:/tmp/.X11-unix/X0 \ 37 | -v /var/run/docker.sock:/var/run/docker.sock \ 38 | -v ${PWD}:${WORKSPACE} \ 39 | -w ${WORKSPACE} \ 40 | --privileged \ 41 | --ipc=host \ 42 | --network=host 43 | 44 | NGPUS=$(shell nvidia-smi -L | wc -l) 45 | MPI_CMD=mpirun \ 46 | -allow-run-as-root \ 47 | -np ${NGPUS} \ 48 | -H localhost:${NGPUS} \ 49 | -x MASTER_ADDR=127.0.0.1 \ 50 | -x MASTER_PORT=23457 \ 51 | -x HOROVOD_TIMELINE \ 52 | -x OMP_NUM_THREADS=1 \ 53 | -x KMP_AFFINITY='granularity=fine,compact,1,0' \ 54 | -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x NCCL_MIN_NRINGS=4 \ 55 | --report-bindings 56 | 57 | 58 | .PHONY: all clean docker-build docker-overfit-pose 59 | 60 | all: clean 61 | 62 | clean: 63 | find . -name "*.pyc" | xargs rm -f && \ 64 | find . -name "__pycache__" | xargs rm -rf 65 | 66 | docker-build: 67 | docker build \ 68 | -f docker/Dockerfile \ 69 | -t ${DOCKER_IMAGE} . 70 | 71 | docker-start-interactive: 72 | docker run --gpus=all ${DOCKER_OPTS} ${DOCKER_IMAGE} bash 73 | 74 | docker-start-jupyter: 75 | docker run --gpus=all ${DOCKER_OPTS} ${DOCKER_IMAGE} \ 76 | bash -c "jupyter notebook --port=8888 -ip=0.0.0.0 --allow-root --no-browser" 77 | 78 | docker-run: 79 | docker run --gpus=all ${DOCKER_OPTS} ${DOCKER_IMAGE} \ 80 | bash -c "${COMMAND}" 81 | 82 | docker-run-mpi: 83 | docker run --gpus=all ${DOCKER_OPTS} ${DOCKER_IMAGE} \ 84 | bash -c "${MPI_CMD} ${COMMAND}" -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## DRO: Deep Recurrent Optimizer for Structure-from-Motion 2 | 3 | This is the official PyTorch implementation code for DRO-sfm. For technical details, please refer to: 4 | 5 | **DRO: Deep Recurrent Optimizer for Structure-from-Motion**
6 | Xiaodong Gu\*, Weihao Yuan\*, Zuozhuo Dai, Chengzhou Tang, Siyu Zhu, Ping Tan
7 | **[[Paper](https://arxiv.org/abs/2103.13201)]**
*(demo GIFs: `media/figs/demo_kitti.gif` (KITTI), `media/figs/demo_scannet.gif` (ScanNet))*
13 | 14 | ## Bibtex 15 | If you find this code useful in your research, please cite: 16 | 17 | ``` 18 | @article{gu2021dro, 19 | title={DRO: Deep Recurrent Optimizer for Structure-from-Motion}, 20 | author={Gu, Xiaodong and Yuan, Weihao and Dai, Zuozhuo and Tang, Chengzhou and Zhu, Siyu and Tan, Ping}, 21 | journal={arXiv preprint arXiv:2103.13201}, 22 | year={2021} 23 | } 24 | ``` 25 | 26 | ## Contents 27 | 1. [Install](#install) 28 | 2. [Datasets](#datasets) 29 | 3. [Training](#training) 30 | 4. [Evaluation](#evaluation) 31 | 5. [Models](#models) 32 | 33 | 34 | ## Install 35 | + We recommend using [nvidia-docker2](https://github.com/NVIDIA/nvidia-docker) to have a reproducible environment. 36 | 37 | ```bash 38 | git clone https://github.com/aliyun/dro-sfm.git 39 | cd dro-sfm 40 | sudo make docker-build 41 | sudo make docker-start-interactive 42 | ``` 43 | You can also download the built docker directly from [dro-sfm-image.tar](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/dro-sfm/dro-sfm-image.tar) 44 | 45 | ```bash 46 | docker load < dro-sfm-image.tar 47 | ``` 48 | 49 | + If you do not use docker, you could create an environment following the steps in the Dockerfile. 50 | 51 | ```bash 52 | # Environment variables 53 | export PYTORCH_VERSION=1.4.0 54 | export TORCHVISION_VERSION=0.5.0 55 | export NCCL_VERSION=2.4.8-1+cuda10.1 56 | export HOROVOD_VERSION=65de4c961d1e5ad2828f2f6c4329072834f27661 57 | # Install NCCL 58 | sudo apt-get install libnccl2=${NCCL_VERSION} libnccl-dev=${NCCL_VERSION} 59 | 60 | # Install Open MPI 61 | mkdir /tmp/openmpi && \ 62 | cd /tmp/openmpi && \ 63 | wget https://www.open-mpi.org/software/ompi/v4.0/downloads/openmpi-4.0.0.tar.gz && \ 64 | tar zxf openmpi-4.0.0.tar.gz && \ 65 | cd openmpi-4.0.0 && \ 66 | ./configure --enable-orterun-prefix-by-default && \ 67 | make -j $(nproc) all && \ 68 | make install && \ 69 | ldconfig && \ 70 | rm -rf /tmp/openmpi 71 | 72 | # Install PyTorch 73 | pip install torch==${PYTORCH_VERSION} torchvision==${TORCHVISION_VERSION} && ldconfig 74 | 75 | # Install horovod (for distributed training) 76 | sudo ldconfig /usr/local/cuda/targets/x86_64-linux/lib/stubs && HOROVOD_GPU_ALLREDUCE=NCCL HOROVOD_GPU_BROADCAST=NCCL HOROVOD_WITH_PYTORCH=1 pip install --no-cache-dir git+https://github.com/horovod/horovod.git@${HOROVOD_VERSION} && sudo ldconfig 77 | ``` 78 | 79 | To verify that the environment is setup correctly, you can run a simple overfitting test: 80 | 81 | ```bash 82 | # download a tiny subset of KITTI 83 | cd dro-sfm 84 | curl -s https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/dro-sfm/datasets/KITTI_tiny.tar | tar xv -C /data/datasets/kitti/ 85 | # in docker 86 | ./run.sh "python scripts/train.py configs/overfit_kitti_mf_gt.yaml" log.txt 87 | ``` 88 | 89 | ## Datasets 90 | Datasets are assumed to be downloaded in `/data/datasets/` (can be a symbolic link). 91 | 92 | ### KITTI 93 | The KITTI (raw) dataset used in our experiments can be downloaded from the [KITTI website](http://www.cvlibs.net/datasets/kitti/raw_data.php). 
94 | For convenience, you can download data from [packnet](https://tri-ml-public.s3.amazonaws.com/github/packnet-sfm/datasets/KITTI_raw.tar.gz) or [here](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/dro-sfm/datasets/KITTI_raw.tar) 95 | 96 | ### Tiny KITTI 97 | For simple tests, you can download a "tiny" version of [KITTI](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/dro-sfm/datasets/KITTI_tiny.tar): 98 | 99 | 100 | ### Scannet 101 | The Scannet (raw) dataset used in our experiments can be downloaded from the [Scannet website](http://www.scan-net.org). 102 | For convenience, you can download data from [here](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/dro-sfm/datasets/scannet.tar) 103 | 104 | 105 | ### DeMoN 106 | Download [DeMoN](https://github.com/lmb-freiburg/demon/tree/master/datasets). 107 | 108 | ```bash 109 | bash download_traindata.sh 110 | python ./dataset/preparation/preparedata_train.py 111 | bash download_testdata.sh 112 | python ./dataset/preparation/preparedata_test.py 113 | ``` 114 | 115 | ## Training 116 | Any training, including fine-tuning, can be done by passing either a `.yaml` config file or a `.ckpt` model checkpoint to [scripts/train.py](./scripts/train.py): 117 | 118 | ```bash 119 | # kitti, checkpoints will be saved in ./results/model/ 120 | ./run.sh 'python scripts/train.py configs/train_kitti_mf_gt.yaml' logs/kitti_sup.txt 121 | ./run.sh 'python scripts/train.py configs/train_kitti_mf_selfsup.yaml' logs/kitti_selfsup.txt 122 | 123 | # scannet 124 | ./run.sh 'python scripts/train.py configs/train_scannet_mf_gt_view3.yaml' logs/scannet_sup.txt 125 | ./run.sh 'python scripts/train.py configs/train_scannet_mf_selfsup_view3.yaml' logs/scannet_selfsup.txt 126 | ./run.sh 'python scripts/train.py configs/train_scannet_mf_gt_view5.yaml' logs/scannet_sup_view5.txt 127 | 128 | # demon 129 | ./run.sh 'python scripts/train.py configs/train_demon_mf_gt.yaml' logs/demon_sup.txt 130 | ``` 131 | 132 | 133 | ## Evaluation 134 | 135 | ```bash 136 | python scripts/eval.py --checkpoint <checkpoint.ckpt> [--config <config.yaml>] 137 | # example: kitti, results will be saved in results/depth/ 138 | python scripts/eval.py --checkpoint ckpt/outdoor_kitti.ckpt --config configs/train_kitti_mf_gt.yaml 139 | 140 | ``` 141 | 142 | You can also directly run inference on a single image or video: 143 | 144 | ```bash 145 | # video or folder 146 | # indoor-scannet 147 | python scripts/infer_video.py --checkpoint ckpt/indoor_scannet.ckpt --input /path/to/video_or_folder --output /path/to/save_folder --sample_rate 1 --data_type scannet --ply_mode 148 | # indoor-general 149 | python scripts/infer_video.py --checkpoint ckpt/indoor_scannet.ckpt --input /path/to/video_or_folder --output /path/to/save_folder --sample_rate 1 --data_type general --ply_mode 150 | 151 | # outdoor 152 | python scripts/infer_video.py --checkpoint ckpt/outdoor_kitti.ckpt --input /path/to/video_or_folder --output /path/to/save_folder --sample_rate 1 --data_type kitti --ply_mode 153 | 154 | # image 155 | python scripts/infer.py --checkpoint <checkpoint.ckpt> --input <image_or_folder> --output <output_folder> 156 | ``` 157 | 158 | 159 | ## Models 160 | 161 | | Model | Abs.Rel. 
| Sqr.Rel | RMSE | RMSElog | a1 | a2 | a3| SILog| L1_inv| rot_ang| t_ang| t_cm| 162 | | :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | 163 | |[Kitti_sup](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/dro-sfm/models/outdoor_kitti.ckpt) | 0.045 | 0.193 | 2.570 | 0.080 | 0.971 | 0.994 | 0.998| 0.079 | 0.003 | -| -| -| 164 | |[Kitti_selfsup](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/dro-sfm/models/outdoor_kitti_selfsup.ckpt) | 0.053 |0.346 | 3.037 | 0.102 | 0.962 | 0.990| 0.996|0.101 | 0.004 | -| -| -| 165 | |[scannet_sup](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/dro-sfm/models/indoor_scannet.ckpt) | 0.053 | 0.017 | 0.165 | 0.080 | 0.967 | 0.994 | 0.998| 0.078 | 0.033| 0.472| 9.297| 1.160| 166 | |[scannet_sup(view5)](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/dro-sfm/models/indoor_scannet_view5.ckpt) |0.047 |0.014 | 0.151 | 0.072 | 0.976 | 0.996 | 0.999| 0.071 | 0.030 | 0.456| 8.502| 1.163| 167 | |[scannet_selfsup](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/dro-sfm/models/indoor_scannet_selfsup.ckpt) | 0.143 | 0.345 | 0.656 | 0.274 | 0.896 | 0.954 | 0.969|0.272 | 0.106 | 0.609| 10.779| 1.393 | 168 | 169 | 170 | 171 | ## Acknowledgements 172 | Thanks to Toyota Research Institute for opening source of excellent work [packnet-sfm](https://github.com/TRI-ML/packnet-sfm). Thanks to Zachary Teed for opening source of his excellent work [RAFT](https://github.com/princeton-vl/RAFT). 173 | -------------------------------------------------------------------------------- /configs/overfit_kitti_mf_gt.yaml: -------------------------------------------------------------------------------- 1 | name: 'overfit_kitti_gt' 2 | arch: 3 | max_epochs: 50 4 | checkpoint: 5 | save_top_k: 10 6 | monitor: 'abs_rel_pp_gt' 7 | monitor_index: 0 8 | model: 9 | name: 'SupModelMF' 10 | optimizer: 11 | name: 'Adam' 12 | depth: 13 | lr: 0.0002 14 | pose: 15 | lr: 0.0002 16 | scheduler: 17 | name: 'StepLR' 18 | step_size: 30 19 | gamma: 0.5 20 | depth_net: 21 | name: 'DepthPoseNet' 22 | version: 'it12-h-out' 23 | loss: 24 | automask_loss: True 25 | photometric_reduce_op: 'min' 26 | params: 27 | crop: 'garg' 28 | min_depth: 0.2 29 | max_depth: 80.0 30 | datasets: 31 | augmentation: 32 | image_shape: (192, 640) 33 | train: 34 | batch_size: 4 35 | dataset: ['KITTI'] 36 | path: ['/data/datasets/kitti/KITTI_tiny'] 37 | split: ['kitti_tiny.txt'] 38 | depth_type: ['velodyne'] 39 | repeat: [100] 40 | forward_context: 1 41 | back_context: 1 42 | validation: 43 | dataset: ['KITTI'] 44 | path: ['/data/datasets/kitti/KITTI_tiny'] 45 | split: ['kitti_tiny.txt'] 46 | depth_type: ['velodyne'] 47 | forward_context: 1 48 | back_context: 1 49 | test: 50 | dataset: ['KITTI'] 51 | path: ['/data/datasets/kitti/KITTI_tiny'] 52 | split: ['kitti_tiny.txt'] 53 | depth_type: ['velodyne'] 54 | forward_context: 1 55 | back_context: 1 56 | -------------------------------------------------------------------------------- /configs/train_demon_mf_gt.yaml: -------------------------------------------------------------------------------- 1 | name: 'demon_sun3d_rgbd_gt' 2 | arch: 3 | max_epochs: 100 4 | checkpoint: 5 | save_top_k: 10 6 | monitor: 'abs_rel_pp_gt' 7 | monitor_index: 0 8 | model: 9 | # pretrain from xxx.ckpt 10 | # checkpoint_path: ./logs/models/xxc.ckpt 11 | name: 'SupModelMF' 12 | optimizer: 13 | name: 'Adam' 14 | depth: 15 | lr: 0.0002 16 | pose: 17 | lr: 0.0002 18 | scheduler: 19 | name: 'StepLR' 20 | 
step_size: 30 21 | gamma: 0.5 22 | depth_net: 23 | name: 'DepthPoseNet' 24 | version: 'it12-h-out' 25 | loss: 26 | automask_loss: True 27 | photometric_reduce_op: 'min' 28 | params: 29 | crop: '' 30 | min_depth: 0.2 31 | max_depth: 10.0 32 | datasets: 33 | augmentation: 34 | image_shape: (240, 320) 35 | train: 36 | batch_size: 12 37 | dataset: ['DemonMF'] 38 | path: ['/data/datasets/demon/train'] 39 | split: ['sun3d_train.txt', 'rgbd_train.txt'] 40 | depth_type: ['groundtruth'] 41 | repeat: [1, 5] 42 | forward_context: 1 43 | back_context: 0 44 | validation: 45 | dataset: ['Demon'] 46 | path: ['/data/datasets/demon/test'] 47 | split: ['sun3d_test.txt', 'rgbd_test.txt'] 48 | depth_type: ['groundtruth'] 49 | forward_context: 1 50 | back_context: 1 51 | test: 52 | dataset: ['Demon'] 53 | path: ['/data/datasets/demon/test'] 54 | split: ['sun3d_test.txt', 'rgbd_test.txt'] 55 | depth_type: ['groundtruth'] 56 | forward_context: 1 57 | back_context: 1 58 | -------------------------------------------------------------------------------- /configs/train_kitti_mf_gt.yaml: -------------------------------------------------------------------------------- 1 | name: 'kitti_gt' 2 | save: 3 | folder: './results' 4 | arch: 5 | max_epochs: 50 6 | checkpoint: 7 | save_top_k: 10 8 | monitor: 'abs_rel_pp_gt' 9 | monitor_index: 0 10 | model: 11 | name: 'SupModelMF' 12 | optimizer: 13 | name: 'Adam' 14 | depth: 15 | lr: 0.0002 16 | pose: 17 | lr: 0.0002 18 | scheduler: 19 | name: 'StepLR' 20 | step_size: 30 21 | gamma: 0.5 22 | depth_net: 23 | name: 'DepthPoseNet' 24 | version: 'it12-h' 25 | loss: 26 | automask_loss: True 27 | photometric_reduce_op: 'min' 28 | params: 29 | crop: 'garg' 30 | min_depth: 0.2 31 | max_depth: 80.0 32 | datasets: 33 | augmentation: 34 | image_shape: (320, 960) 35 | train: 36 | batch_size: 2 37 | dataset: ['KITTI'] 38 | path: ['/data/datasets/kitti/KITTI_raw'] 39 | split: ['data_splits/eigen_zhou_files.txt'] 40 | depth_type: ['groundtruth'] 41 | repeat: [2] 42 | forward_context: 1 43 | back_context: 1 44 | validation: 45 | dataset: ['KITTI'] 46 | path: ['/data/datasets/kitti/KITTI_raw'] 47 | split: ['data_splits/eigen_val_files.txt', 48 | 'data_splits/eigen_test_files.txt'] 49 | depth_type: ['groundtruth'] 50 | forward_context: 1 51 | back_context: 0 52 | test: 53 | dataset: ['KITTI'] 54 | path: ['/data/datasets/kitti/KITTI_raw'] 55 | split: ['data_splits/eigen_test_files.txt'] 56 | depth_type: ['groundtruth'] 57 | forward_context: 1 58 | back_context: 0 -------------------------------------------------------------------------------- /configs/train_kitti_mf_selfsup.yaml: -------------------------------------------------------------------------------- 1 | name: 'kitt_selfsup_view3' 2 | arch: 3 | max_epochs: 50 4 | model: 5 | name: 'SelfSupModelMF' 6 | optimizer: 7 | name: 'Adam' 8 | depth: 9 | lr: 0.0002 10 | pose: 11 | lr: 0.0002 12 | scheduler: 13 | name: 'StepLR' 14 | step_size: 30 15 | gamma: 0.5 16 | depth_net: 17 | name: 'DepthPoseNet' 18 | version: 'it8-seq4-inter-out' 19 | loss: 20 | automask_loss: True 21 | photometric_reduce_op: 'min' 22 | params: 23 | crop: 'garg' 24 | min_depth: 0.5 25 | max_depth: 80.0 26 | datasets: 27 | augmentation: 28 | image_shape: (320, 960) 29 | train: 30 | batch_size: 2 31 | dataset: ['KITTI'] 32 | path: ['/data/datasets/kitti/KITTI_raw'] 33 | split: ['data_splits/eigen_zhou_files.txt'] 34 | depth_type: ['groundtruth'] 35 | repeat: [2] 36 | forward_context: 1 37 | back_context: 1 38 | validation: 39 | dataset: ['KITTI'] 40 | path: 
['/data/datasets/kitti/KITTI_raw'] 41 | split: ['data_splits/eigen_val_files.txt', 42 | 'data_splits/eigen_test_files.txt'] 43 | depth_type: ['groundtruth'] 44 | forward_context: 1 45 | back_context: 0 46 | test: 47 | dataset: ['KITTI'] 48 | path: ['/data/datasets/kitti/KITTI_raw'] 49 | split: ['data_splits/eigen_test_files.txt'] 50 | depth_type: ['groundtruth'] 51 | forward_context: 1 52 | back_context: 0 53 | -------------------------------------------------------------------------------- /configs/train_nyu_mf_gt.yaml: -------------------------------------------------------------------------------- 1 | name: 'nyu_semi' 2 | arch: 3 | max_epochs: 100 4 | checkpoint: 5 | save_top_k: 10 6 | monitor: 'abs_rel_pp_gt' 7 | monitor_index: 0 8 | model: 9 | name: 'SemiSupModelMF' 10 | optimizer: 11 | name: 'Adam' 12 | depth: 13 | lr: 0.0002 14 | pose: 15 | lr: 0.0002 16 | scheduler: 17 | name: 'StepLR' 18 | step_size: 30 19 | gamma: 0.5 20 | depth_net: 21 | name: 'DepthPoseNet' 22 | version: 'it12-h-out' 23 | loss: 24 | automask_loss: True 25 | photometric_reduce_op: 'min' 26 | params: 27 | crop: 'eigen_nyu' 28 | min_depth: 0.2 29 | max_depth: 10.0 30 | datasets: 31 | augmentation: 32 | image_shape: (240, 320) 33 | train: 34 | batch_size: 8 35 | dataset: ['NYU'] 36 | path: ['/data/datasets/nyu/train_processed'] 37 | split: [''] 38 | depth_type: ['groundtruth'] 39 | repeat: [1] 40 | forward_context: 1 41 | back_context: 1 42 | validation: 43 | dataset: ['NYUtest'] 44 | path: ['/data/datasets/nyu/test_processed'] 45 | split: ['test_depth.txt'] 46 | depth_type: ['groundtruth'] 47 | forward_context: 1 48 | back_context: 1 49 | test: 50 | dataset: ['NYUtest'] 51 | path: ['/data/datasets/nyu/test_processed'] 52 | split: ['test_depth.txt'] 53 | depth_type: ['groundtruth'] 54 | forward_context: 1 55 | back_context: 1 56 | -------------------------------------------------------------------------------- /configs/train_rgbd_mf_gt.yaml: -------------------------------------------------------------------------------- 1 | name: 'rgbd_gt' 2 | arch: 3 | max_epochs: 100 4 | checkpoint: 5 | save_top_k: 10 6 | monitor: 'abs_rel_pp_gt' 7 | monitor_index: 0 8 | model: 9 | name: 'SupModelMF' 10 | optimizer: 11 | name: 'Adam' 12 | depth: 13 | lr: 0.0002 14 | pose: 15 | lr: 0.0002 16 | scheduler: 17 | name: 'StepLR' 18 | step_size: 30 19 | gamma: 0.5 20 | depth_net: 21 | name: 'DepthPoseNet' 22 | version: 'it12-h-out' 23 | loss: 24 | automask_loss: True 25 | photometric_reduce_op: 'min' 26 | params: 27 | crop: '' 28 | min_depth: 0.2 29 | max_depth: 10.0 30 | datasets: 31 | augmentation: 32 | image_shape: (240, 320) 33 | train: 34 | batch_size: 12 35 | dataset: ['Demon'] 36 | path: ['/data/datasets/demon/train'] 37 | split: ['rgbd_train.txt'] 38 | depth_type: ['groundtruth'] 39 | repeat: [1] 40 | forward_context: 1 41 | back_context: 1 42 | validation: 43 | dataset: ['Demon'] 44 | path: ['/data/datasets/demon/test'] 45 | split: ['rgbd_test.txt'] 46 | depth_type: ['groundtruth'] 47 | forward_context: 1 48 | back_context: 1 49 | test: 50 | dataset: ['Demon'] 51 | path: ['/data/datasets/demon/test'] 52 | split: ['rgbd_test.txt'] 53 | depth_type: ['groundtruth'] 54 | forward_context: 1 55 | back_context: 1 56 | -------------------------------------------------------------------------------- /configs/train_scannet_mf_gt_view2.yaml: -------------------------------------------------------------------------------- 1 | name: 'scannet_gt_view2' 2 | save: 3 | folder: './results' 4 | arch: 5 | max_epochs: 100 6 | checkpoint: 7 | 
save_top_k: 10 8 | monitor: 'abs_rel_pp_gt' 9 | monitor_index: 0 10 | model: 11 | name: 'SupModelMF' 12 | optimizer: 13 | name: 'Adam' 14 | depth: 15 | lr: 0.0002 16 | pose: 17 | lr: 0.0002 18 | scheduler: 19 | name: 'StepLR' 20 | step_size: 30 21 | gamma: 0.5 22 | depth_net: 23 | name: 'DepthPoseNet' 24 | version: 'it12-h-out' 25 | loss: 26 | automask_loss: True 27 | photometric_reduce_op: 'min' 28 | params: 29 | crop: '' 30 | min_depth: 0.2 31 | max_depth: 10.0 32 | datasets: 33 | augmentation: 34 | image_shape: (240, 320) 35 | train: 36 | batch_size: 12 37 | dataset: ['ScannetBA'] 38 | path: ['/data/datasets/scannet/train'] 39 | split: ['splits/train_all_list.txt'] 40 | depth_type: ['groundtruth'] 41 | repeat: [1] 42 | forward_context: 1 43 | back_context: 0 44 | num_workers: 32 45 | validation: 46 | dataset: ['ScannetTest'] 47 | path: ['/data/datasets/scannet/test'] 48 | split: ['splits/test_split.txt'] 49 | depth_type: ['groundtruth'] 50 | forward_context: 1 51 | back_context: 0 52 | test: 53 | dataset: ['ScannetTest'] 54 | path: ['/data/datasets/scannet/test'] 55 | split: ['splits/test_split.txt'] 56 | depth_type: ['groundtruth'] 57 | forward_context: 1 58 | back_context: 0 59 | -------------------------------------------------------------------------------- /configs/train_scannet_mf_gt_view3.yaml: -------------------------------------------------------------------------------- 1 | name: 'scannet_gt_view3' 2 | save: 3 | folder: './results' 4 | arch: 5 | max_epochs: 100 6 | checkpoint: 7 | save_top_k: 10 8 | monitor: 'abs_rel_pp_gt' 9 | monitor_index: 0 10 | model: 11 | name: 'SupModelMF' 12 | optimizer: 13 | name: 'Adam' 14 | depth: 15 | lr: 0.0002 16 | pose: 17 | lr: 0.0002 18 | scheduler: 19 | name: 'StepLR' 20 | step_size: 30 21 | gamma: 0.5 22 | depth_net: 23 | name: 'DepthPoseNet' 24 | version: 'it12-h-out' 25 | loss: 26 | automask_loss: True 27 | photometric_reduce_op: 'min' 28 | params: 29 | crop: '' 30 | min_depth: 0.2 31 | max_depth: 10.0 32 | datasets: 33 | augmentation: 34 | image_shape: (240, 320) 35 | train: 36 | batch_size: 8 37 | dataset: ['ScannetBA'] 38 | path: ['/data/datasets/scannet/train'] 39 | split: ['splits/train_all_list.txt'] 40 | depth_type: ['groundtruth'] 41 | repeat: [1] 42 | forward_context: 1 43 | back_context: 1 44 | validation: 45 | dataset: ['ScannetTest'] 46 | path: ['/data/datasets/scannet/test'] 47 | split: ['splits/test_split.txt'] 48 | depth_type: ['groundtruth'] 49 | forward_context: 1 50 | back_context: 0 51 | test: 52 | dataset: ['ScannetTest'] 53 | path: ['/data/datasets/scannet/test'] 54 | split: ['splits/test_split.txt'] 55 | depth_type: ['groundtruth'] 56 | forward_context: 1 57 | back_context: 0 58 | -------------------------------------------------------------------------------- /configs/train_scannet_mf_gt_view5.yaml: -------------------------------------------------------------------------------- 1 | name: 'scannet_gt_view5' 2 | save: 3 | folder: './results' 4 | arch: 5 | max_epochs: 100 6 | checkpoint: 7 | save_top_k: 10 8 | monitor: 'abs_rel_pp_gt' 9 | monitor_index: 0 10 | model: 11 | name: 'SupModelMF' 12 | optimizer: 13 | name: 'Adam' 14 | depth: 15 | lr: 0.0002 16 | pose: 17 | lr: 0.0002 18 | scheduler: 19 | name: 'StepLR' 20 | step_size: 30 21 | gamma: 0.5 22 | depth_net: 23 | name: 'DepthPoseNet' 24 | version: 'it12-h-out' 25 | loss: 26 | automask_loss: True 27 | photometric_reduce_op: 'min' 28 | params: 29 | crop: '' 30 | min_depth: 0.2 31 | max_depth: 10.0 32 | datasets: 33 | augmentation: 34 | image_shape: (384, 512) 35 | 
train: 36 | batch_size: 2 37 | dataset: ['ScannetBA'] 38 | path: ['/data/datasets/scannet/train'] 39 | split: ['splits/train_all_list.txt'] 40 | depth_type: ['groundtruth'] 41 | repeat: [1] 42 | forward_context: 2 43 | back_context: 2 44 | num_workers: 16 45 | validation: 46 | dataset: ['ScannetTest'] 47 | path: ['/data/datasets/scannet/test'] 48 | split: ['splits/test_split.txt'] 49 | depth_type: ['groundtruth'] 50 | forward_context: 1 51 | back_context: 0 52 | test: 53 | dataset: ['ScannetTest'] 54 | path: ['/data/datasets/scannet/test'] 55 | split: ['splits/test_split.txt'] 56 | depth_type: ['groundtruth'] 57 | forward_context: 2 58 | back_context: 2 59 | -------------------------------------------------------------------------------- /configs/train_scannet_mf_selfsup_view3.yaml: -------------------------------------------------------------------------------- 1 | name: 'scannet_selfsup_view3' 2 | arch: 3 | max_epochs: 100 4 | checkpoint: 5 | save_top_k: 10 6 | monitor: 'abs_rel_pp_gt' 7 | monitor_index: 0 8 | model: 9 | name: 'SelfSupModelMF' 10 | optimizer: 11 | name: 'Adam' 12 | depth: 13 | lr: 0.0002 14 | pose: 15 | lr: 0.0002 16 | scheduler: 17 | name: 'StepLR' 18 | step_size: 30 19 | gamma: 0.5 20 | depth_net: 21 | name: 'DepthPoseNet' 22 | version: 'it12-h-out' 23 | loss: 24 | automask_loss: True 25 | photometric_reduce_op: 'min' 26 | params: 27 | crop: '' 28 | min_depth: 0.2 29 | max_depth: 10.0 30 | datasets: 31 | augmentation: 32 | image_shape: (320, 512) 33 | train: 34 | batch_size: 3 35 | dataset: ['ScannetBA'] 36 | path: ['/data/datasets/scannet/train'] 37 | split: ['splits/train_all_list.txt'] 38 | depth_type: ['groundtruth'] 39 | repeat: [1] 40 | forward_context: 1 41 | back_context: 1 42 | validation: 43 | dataset: ['ScannetTest'] 44 | path: ['/data/datasets/scannet/test'] 45 | split: ['splits/test_split.txt'] 46 | depth_type: ['groundtruth'] 47 | forward_context: 1 48 | back_context: 1 49 | test: 50 | dataset: ['ScannetTest'] 51 | path: ['/data/datasets/scannet/test'] 52 | split: ['splits/test_split.txt'] 53 | depth_type: ['groundtruth'] 54 | forward_context: 1 55 | back_context: 1 56 | -------------------------------------------------------------------------------- /configs/train_scannet_mf_selfsup_view5.yaml: -------------------------------------------------------------------------------- 1 | name: 'scannet_selfsup_view5' 2 | arch: 3 | max_epochs: 100 4 | checkpoint: 5 | save_top_k: 10 6 | monitor: 'abs_rel_pp_gt' 7 | monitor_index: 0 8 | model: 9 | name: 'SelfSupModelMF' 10 | optimizer: 11 | name: 'Adam' 12 | depth: 13 | lr: 0.0002 14 | pose: 15 | lr: 0.0002 16 | scheduler: 17 | name: 'StepLR' 18 | step_size: 30 19 | gamma: 0.5 20 | depth_net: 21 | name: 'DepthPoseNet' 22 | version: 'it12-h-out' 23 | loss: 24 | automask_loss: True 25 | photometric_reduce_op: 'min' 26 | params: 27 | crop: '' 28 | min_depth: 0.2 29 | max_depth: 10.0 30 | datasets: 31 | augmentation: 32 | image_shape: (240, 320) 33 | train: 34 | batch_size: 4 35 | dataset: ['ScannetBA'] 36 | path: ['/data/datasets/scannet/train'] 37 | split: ['splits/train_all_list.txt'] 38 | depth_type: ['groundtruth'] 39 | repeat: [1] 40 | forward_context: 2 41 | back_context: 2 42 | num_workers: 16 43 | validation: 44 | dataset: ['ScannetTest'] 45 | path: ['/data/datasets/scannet/test'] 46 | split: ['splits/test_split.txt'] 47 | depth_type: ['groundtruth'] 48 | forward_context: 1 49 | back_context: 0 50 | test: 51 | dataset: ['ScannetTest'] 52 | path: ['/data/datasets/scannet/test'] 53 | split: 
['splits/test_split.txt'] 54 | depth_type: ['groundtruth'] 55 | forward_context: 1 56 | back_context: 0 57 | -------------------------------------------------------------------------------- /configs/train_scene11_mf_gt.yaml: -------------------------------------------------------------------------------- 1 | name: 'scene11_gt' 2 | arch: 3 | max_epochs: 100 4 | checkpoint: 5 | save_top_k: 10 6 | monitor: 'abs_rel_pp_gt' 7 | monitor_index: 0 8 | model: 9 | name: 'SupModelMF' 10 | optimizer: 11 | name: 'Adam' 12 | depth: 13 | lr: 0.0002 14 | pose: 15 | lr: 0.0002 16 | scheduler: 17 | name: 'StepLR' 18 | step_size: 30 19 | gamma: 0.5 20 | depth_net: 21 | name: 'DepthPoseNet' 22 | version: 'it12-h-out' 23 | loss: 24 | automask_loss: True 25 | photometric_reduce_op: 'min' 26 | params: 27 | crop: '' 28 | min_depth: 0.2 29 | max_depth: 10.0 30 | datasets: 31 | augmentation: 32 | image_shape: (240, 320) 33 | train: 34 | batch_size: 12 35 | dataset: ['Demon'] 36 | path: ['/data/datasets/demon/train'] 37 | split: ['scene11_train.txt'] 38 | depth_type: ['groundtruth'] 39 | repeat: [1] 40 | forward_context: 1 41 | back_context: 1 42 | validation: 43 | dataset: ['Demon'] 44 | path: ['/data/datasets/demon/test'] 45 | split: ['scene11_test.txt'] 46 | depth_type: ['groundtruth'] 47 | forward_context: 1 48 | back_context: 1 49 | test: 50 | dataset: ['Demon'] 51 | path: ['/data/datasets/demon/test'] 52 | split: ['scene11_test.txt'] 53 | depth_type: ['groundtruth'] 54 | forward_context: 1 55 | back_context: 1 56 | -------------------------------------------------------------------------------- /configs/train_sun3d_mf_gt.yaml: -------------------------------------------------------------------------------- 1 | name: 'sun3d_gt' 2 | arch: 3 | max_epochs: 100 4 | checkpoint: 5 | save_top_k: 10 6 | monitor: 'abs_rel_pp_gt' 7 | monitor_index: 0 8 | model: 9 | name: 'SupModelMF' 10 | optimizer: 11 | name: 'Adam' 12 | depth: 13 | lr: 0.0002 14 | pose: 15 | lr: 0.0002 16 | scheduler: 17 | name: 'StepLR' 18 | step_size: 30 19 | gamma: 0.5 20 | depth_net: 21 | name: 'DepthPoseNet' 22 | version: 'it12-h-out' 23 | loss: 24 | automask_loss: True 25 | photometric_reduce_op: 'min' 26 | params: 27 | crop: '' 28 | min_depth: 0.2 29 | max_depth: 10.0 30 | datasets: 31 | augmentation: 32 | image_shape: (240, 320) 33 | train: 34 | batch_size: 12 35 | dataset: ['DemonMF'] 36 | path: ['/data/datasets/demon/train'] 37 | split: ['sun3d_train.txt'] 38 | depth_type: ['groundtruth'] 39 | repeat: [1] 40 | forward_context: 1 41 | back_context: 0 42 | validation: 43 | dataset: ['Demon'] 44 | path: ['/data/datasets/demon/test'] 45 | split: ['sun3d_test.txt'] 46 | depth_type: ['groundtruth'] 47 | forward_context: 1 48 | back_context: 1 49 | test: 50 | dataset: ['Demon'] 51 | path: ['/data/datasets/demon/test'] 52 | split: ['sun3d_test.txt'] 53 | depth_type: ['groundtruth'] 54 | forward_context: 1 55 | back_context: 1 56 | -------------------------------------------------------------------------------- /configs/train_video_mf_selfsup_out_random.yaml: -------------------------------------------------------------------------------- 1 | name: 'video_selfsup_out_random' 2 | arch: 3 | max_epochs: 10 4 | checkpoint: 5 | save_top_k: -1 6 | monitor: 'abs_rel_pp_gt' 7 | monitor_index: 0 8 | model: 9 | checkpoint_path: "results/indoor_scannet.ckpt" #pretrain model 10 | name: 'SelfSupModelMF' 11 | optimizer: 12 | name: 'Adam' 13 | depth: 14 | lr: 0.0002 15 | pose: 16 | lr: 0.0002 17 | scheduler: 18 | name: 'StepLR' 19 | step_size: 30 20 | gamma: 0.5 21 
| depth_net: 22 | name: 'DepthPoseNet' 23 | version: 'it12-h-out' 24 | loss: 25 | automask_loss: True 26 | photometric_reduce_op: 'min' 27 | params: 28 | crop: '' 29 | min_depth: 0.2 30 | max_depth: 10.0 31 | datasets: 32 | augmentation: 33 | image_shape: (240, 320) 34 | train: 35 | batch_size: 6 36 | dataset: ['Video_Random'] 37 | path: ['/data/datasets/video/indoor'] 38 | split: [''] 39 | depth_type: [''] 40 | repeat: [30] 41 | forward_context: 1 42 | back_context: 1 43 | strides: (2, ) 44 | validation: 45 | dataset: ['Video'] 46 | path: ['/data/datasets/video/indoor'] 47 | split: [''] 48 | depth_type: [''] 49 | forward_context: 1 50 | back_context: 1 51 | strides: (1, ) 52 | test: 53 | dataset: ['Video'] 54 | path: ['/data/datasets/video/indoor'] 55 | split: [''] 56 | depth_type: [''] 57 | forward_context: 1 58 | back_context: 1 59 | strides: (2, ) 60 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | 2 | FROM nvidia/cuda:10.1-devel-ubuntu18.04 3 | 4 | ENV PROJECT=dro-sfm 5 | ENV PYTORCH_VERSION=1.4.0 6 | ENV TORCHVISION_VERSION=0.5.0 7 | ENV CUDNN_VERSION=7.6.5.32-1+cuda10.1 8 | ENV NCCL_VERSION=2.4.8-1+cuda10.1 9 | ENV HOROVOD_VERSION=65de4c961d1e5ad2828f2f6c4329072834f27661 10 | ENV LC_ALL=C.UTF-8 11 | ENV LANG=C.UTF-8 12 | 13 | ARG python=3.6 14 | ENV PYTHON_VERSION=${python} 15 | ENV DEBIAN_FRONTEND=noninteractive 16 | 17 | # Set default shell to /bin/bash 18 | SHELL ["/bin/bash", "-cu"] 19 | 20 | RUN apt-get clean && cd /var/lib/apt && mv lists lists.old && mkdir -p lists/partial && apt-get clean && apt-get update 21 | 22 | RUN apt-get update && apt-get install -y --allow-downgrades --allow-change-held-packages --no-install-recommends \ 23 | build-essential \ 24 | cmake \ 25 | g++-4.8 \ 26 | git \ 27 | curl \ 28 | docker.io \ 29 | vim \ 30 | wget \ 31 | ca-certificates \ 32 | libcudnn7=${CUDNN_VERSION} \ 33 | libnccl2=${NCCL_VERSION} \ 34 | libnccl-dev=${NCCL_VERSION} \ 35 | libjpeg-dev \ 36 | libpng-dev \ 37 | python${PYTHON_VERSION} \ 38 | python${PYTHON_VERSION}-dev \ 39 | python3-tk \ 40 | librdmacm1 \ 41 | libibverbs1 \ 42 | ibverbs-providers \ 43 | libgtk2.0-dev \ 44 | unzip \ 45 | bzip2 \ 46 | htop \ 47 | gnuplot \ 48 | ffmpeg 49 | 50 | # Install Open MPI 51 | RUN mkdir /tmp/openmpi && \ 52 | cd /tmp/openmpi && \ 53 | wget https://www.open-mpi.org/software/ompi/v4.0/downloads/openmpi-4.0.0.tar.gz && \ 54 | tar zxf openmpi-4.0.0.tar.gz && \ 55 | cd openmpi-4.0.0 && \ 56 | ./configure --enable-orterun-prefix-by-default && \ 57 | make -j $(nproc) all && \ 58 | make install && \ 59 | ldconfig && \ 60 | rm -rf /tmp/openmpi 61 | 62 | # Install OpenSSH for MPI to communicate between containers 63 | RUN apt-get install -y --no-install-recommends openssh-client openssh-server && \ 64 | mkdir -p /var/run/sshd 65 | 66 | # Allow OpenSSH to talk to containers without asking for confirmation 67 | RUN cat /etc/ssh/ssh_config | grep -v StrictHostKeyChecking > /etc/ssh/ssh_config.new && \ 68 | echo " StrictHostKeyChecking no" >> /etc/ssh/ssh_config.new && \ 69 | mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config 70 | 71 | # Instal Python and pip 72 | RUN if [[ "${PYTHON_VERSION}" == "3.6" ]]; then \ 73 | apt-get install -y python${PYTHON_VERSION}-distutils; \ 74 | fi 75 | 76 | RUN ln -sf /usr/bin/python${PYTHON_VERSION} /usr/bin/python 77 | 78 | RUN curl -O https://bootstrap.pypa.io/get-pip.py && \ 79 | python get-pip.py && \ 80 | rm get-pip.py 81 | 82 | # 
Install Pydata and other deps 83 | RUN pip install future typing numpy pandas matplotlib jupyter h5py \ 84 | awscli boto3 tqdm termcolor path.py pillow-simd opencv-python-headless \ 85 | mpi4py onnx onnxruntime pycuda yacs cython==0.29.10 86 | 87 | # Install PyTorch 88 | RUN pip install torch==${PYTORCH_VERSION} \ 89 | torchvision==${TORCHVISION_VERSION} && ldconfig 90 | 91 | 92 | # install horovod (for distributed training) 93 | RUN ldconfig /usr/local/cuda/targets/x86_64-linux/lib/stubs && \ 94 | HOROVOD_GPU_ALLREDUCE=NCCL HOROVOD_GPU_BROADCAST=NCCL HOROVOD_WITH_PYTORCH=1 \ 95 | pip install --no-cache-dir git+https://github.com/horovod/horovod.git@${HOROVOD_VERSION} && \ 96 | ldconfig 97 | 98 | 99 | # Override DGP wandb with required version 100 | RUN pip install wandb==0.8.21 101 | 102 | # Expose Port for jupyter (8888) 103 | EXPOSE 8888 104 | 105 | # create project workspace dir 106 | RUN mkdir -p /workspace/experiments 107 | RUN mkdir -p /workspace/${PROJECT} 108 | WORKDIR /workspace/${PROJECT} 109 | 110 | # Copy project source last (to avoid cache busting) 111 | WORKDIR /workspace/${PROJECT} 112 | COPY . /workspace/${PROJECT} 113 | ENV PYTHONPATH="/workspace/${PROJECT}:$PYTHONPATH" -------------------------------------------------------------------------------- /dro_sfm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/dro-sfm/8707e2e0ef799d7d47418a018060f503ef449fe3/dro_sfm/__init__.py -------------------------------------------------------------------------------- /dro_sfm/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /dro_sfm/datasets/augmentations.py: -------------------------------------------------------------------------------- 1 | 2 | import cv2 3 | import numpy as np 4 | import random 5 | import torchvision.transforms as transforms 6 | from PIL import Image 7 | 8 | from dro_sfm.utils.misc import filter_dict 9 | 10 | ######################################################################################################################## 11 | 12 | def resize_image(image, shape, interpolation=Image.ANTIALIAS): 13 | """ 14 | Resizes input image. 15 | 16 | Parameters 17 | ---------- 18 | image : Image.PIL 19 | Input image 20 | shape : tuple [H,W] 21 | Output shape 22 | interpolation : int 23 | Interpolation mode 24 | 25 | Returns 26 | ------- 27 | image : Image.PIL 28 | Resized image 29 | """ 30 | transform = transforms.Resize(shape, interpolation=interpolation) 31 | return transform(image) 32 | 33 | def resize_depth(depth, shape): 34 | """ 35 | Resizes depth map. 
36 | 37 | Parameters 38 | ---------- 39 | depth : np.array [h,w] 40 | Depth map 41 | shape : tuple (H,W) 42 | Output shape 43 | 44 | Returns 45 | ------- 46 | depth : np.array [H,W] 47 | Resized depth map 48 | """ 49 | depth = cv2.resize(depth, dsize=shape[::-1], 50 | interpolation=cv2.INTER_NEAREST) 51 | return np.expand_dims(depth, axis=2) 52 | 53 | def resize_sample_image_and_intrinsics(sample, shape, 54 | image_interpolation=Image.ANTIALIAS): 55 | """ 56 | Resizes the image and intrinsics of a sample 57 | 58 | Parameters 59 | ---------- 60 | sample : dict 61 | Dictionary with sample values 62 | shape : tuple (H,W) 63 | Output shape 64 | image_interpolation : int 65 | Interpolation mode 66 | 67 | Returns 68 | ------- 69 | sample : dict 70 | Resized sample 71 | """ 72 | # Resize image and corresponding intrinsics 73 | image_transform = transforms.Resize(shape, interpolation=image_interpolation) 74 | (orig_w, orig_h) = sample['rgb'].size 75 | (out_h, out_w) = shape 76 | # Scale intrinsics 77 | for key in filter_dict(sample, [ 78 | 'intrinsics' 79 | ]): 80 | intrinsics = np.copy(sample[key]) 81 | intrinsics[0] *= out_w / orig_w 82 | intrinsics[1] *= out_h / orig_h 83 | sample[key] = intrinsics 84 | # Scale images 85 | for key in filter_dict(sample, [ 86 | 'rgb', 'rgb_original', 87 | ]): 88 | sample[key] = image_transform(sample[key]) 89 | # Scale context images 90 | for key in filter_dict(sample, [ 91 | 'rgb_context', 'rgb_context_original', 92 | ]): 93 | sample[key] = [image_transform(k) for k in sample[key]] 94 | # Return resized sample 95 | return sample 96 | 97 | def resize_sample(sample, shape, image_interpolation=Image.ANTIALIAS): 98 | """ 99 | Resizes a sample, including image, intrinsics and depth maps. 100 | 101 | Parameters 102 | ---------- 103 | sample : dict 104 | Dictionary with sample values 105 | shape : tuple (H,W) 106 | Output shape 107 | image_interpolation : int 108 | Interpolation mode 109 | 110 | Returns 111 | ------- 112 | sample : dict 113 | Resized sample 114 | """ 115 | # Resize image and intrinsics 116 | sample = resize_sample_image_and_intrinsics(sample, shape, image_interpolation) 117 | # Resize depth maps 118 | for key in filter_dict(sample, [ 119 | 'depth', 120 | ]): 121 | sample[key] = resize_depth(sample[key], shape) 122 | # Resize depth contexts 123 | for key in filter_dict(sample, [ 124 | 'depth_context', 125 | ]): 126 | sample[key] = [resize_depth(k, shape) for k in sample[key]] 127 | # Return resized sample 128 | return sample 129 | 130 | ######################################################################################################################## 131 | 132 | def to_tensor(image, tensor_type='torch.FloatTensor'): 133 | """Casts an image to a torch.Tensor""" 134 | transform = transforms.ToTensor() 135 | return transform(image).type(tensor_type) 136 | 137 | def to_tensor_sample(sample, tensor_type='torch.FloatTensor'): 138 | """ 139 | Casts the keys of sample to tensors. 
140 | 141 | Parameters 142 | ---------- 143 | sample : dict 144 | Input sample 145 | tensor_type : str 146 | Type of tensor we are casting to 147 | 148 | Returns 149 | ------- 150 | sample : dict 151 | Sample with keys cast as tensors 152 | """ 153 | transform = transforms.ToTensor() 154 | # Convert single items 155 | for key in filter_dict(sample, [ 156 | 'rgb', 'rgb_original', 'depth', 157 | ]): 158 | sample[key] = transform(sample[key]).type(tensor_type) 159 | # Convert lists 160 | for key in filter_dict(sample, [ 161 | 'rgb_context', 'rgb_context_original', 'depth_context' 162 | ]): 163 | sample[key] = [transform(k).type(tensor_type) for k in sample[key]] 164 | # Return converted sample 165 | return sample 166 | 167 | ######################################################################################################################## 168 | 169 | def duplicate_sample(sample): 170 | """ 171 | Duplicates sample images and contexts to preserve their unaugmented versions. 172 | 173 | Parameters 174 | ---------- 175 | sample : dict 176 | Input sample 177 | 178 | Returns 179 | ------- 180 | sample : dict 181 | Sample including [+"_original"] keys with copies of images and contexts. 182 | """ 183 | # Duplicate single items 184 | for key in filter_dict(sample, [ 185 | 'rgb' 186 | ]): 187 | sample['{}_original'.format(key)] = sample[key].copy() 188 | # Duplicate lists 189 | for key in filter_dict(sample, [ 190 | 'rgb_context' 191 | ]): 192 | sample['{}_original'.format(key)] = [k.copy() for k in sample[key]] 193 | # Return duplicated sample 194 | return sample 195 | 196 | def colorjitter_sample(sample, parameters, prob=1.0): 197 | """ 198 | Jitters input images as data augmentation. 199 | 200 | Parameters 201 | ---------- 202 | sample : dict 203 | Input sample 204 | parameters : tuple (brightness, contrast, saturation, hue) 205 | Color jittering parameters 206 | prob : float 207 | Jittering probability 208 | 209 | Returns 210 | ------- 211 | sample : dict 212 | Jittered sample 213 | """ 214 | if random.random() < prob: 215 | # Prepare transformation 216 | color_augmentation = transforms.ColorJitter() 217 | brightness, contrast, saturation, hue = parameters 218 | augment_image = color_augmentation.get_params( 219 | brightness=[max(0, 1 - brightness), 1 + brightness], 220 | contrast=[max(0, 1 - contrast), 1 + contrast], 221 | saturation=[max(0, 1 - saturation), 1 + saturation], 222 | hue=[-hue, hue]) 223 | # Jitter single items 224 | for key in filter_dict(sample, [ 225 | 'rgb' 226 | ]): 227 | sample[key] = augment_image(sample[key]) 228 | # Jitter lists 229 | for key in filter_dict(sample, [ 230 | 'rgb_context' 231 | ]): 232 | sample[key] = [augment_image(k) for k in sample[key]] 233 | # Return jittered (?) 
sample 234 | return sample 235 | 236 | ######################################################################################################################## 237 | 238 | 239 | -------------------------------------------------------------------------------- /dro_sfm/datasets/demon_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import re 3 | from collections import defaultdict 4 | import os 5 | 6 | from torch.utils.data import Dataset 7 | import numpy as np 8 | from dro_sfm.utils.image import load_image 9 | import cv2, IPython 10 | from PIL import Image 11 | from PIL import ImageFile 12 | ImageFile.LOAD_TRUNCATED_IMAGES = True 13 | import h5py 14 | ######################################################################################################################## 15 | #### FUNCTIONS 16 | ######################################################################################################################## 17 | 18 | def dummy_calibration(image): 19 | return np.array([[5.703422047415297129e+02 , 0. , 3.200000000000000000e+02], 20 | [0. , 5.703422047415297129e+02 , 2.400000000000000000e+02], 21 | [0. , 0. , 1. ]]) 22 | 23 | 24 | ######################################################################################################################## 25 | #### DATASET 26 | ######################################################################################################################## 27 | 28 | class DemonDataset(Dataset): 29 | def __init__(self, root_dir, split, data_transform=None, 30 | forward_context=0, back_context=0, strides=(1,), 31 | depth_type=None, **kwargs): 32 | super().__init__() 33 | # Asserts 34 | # assert depth_type is None or depth_type == '', \ 35 | # 'NYUDataset currently does not support depth types' 36 | assert len(strides) == 1 and strides[0] == 1, \ 37 | 'NYUDataset currently only supports stride of 1.' 38 | 39 | self.depth_type = depth_type 40 | self.with_depth = depth_type is not '' and depth_type is not None 41 | self.root_dir = root_dir 42 | self.split = split 43 | with open(os.path.join(self.root_dir, split), "r") as f: 44 | data = f.readlines() 45 | 46 | self.paths = [] 47 | # Get file list from data 48 | for i, fname in enumerate(data): 49 | path = os.path.join(root_dir, fname.split()[0]) 50 | # if os.path.exists(path): 51 | self.paths.append(path) 52 | 53 | 54 | self.backward_context = back_context 55 | self.forward_context = forward_context 56 | self.has_context = self.backward_context + self.forward_context > 0 57 | self.strides = strides[0] 58 | 59 | self.data_transform = data_transform 60 | 61 | def __len__(self): 62 | return len(self.paths) 63 | 64 | def _change_idx(self, idx, filename): 65 | _, ext = os.path.splitext(os.path.basename(filename)) 66 | return str(idx) + ext 67 | 68 | def __getitem__(self, idx): 69 | filepath = self.paths[idx] 70 | 71 | image = load_image(os.path.join(filepath, '0000.jpg')) 72 | depth = np.load(os.path.join(filepath, '0000.npy')) 73 | 74 | rgb_contexts = [load_image(os.path.join(filepath, '0001.jpg'))] 75 | 76 | poses = [p.reshape((3, 4)) for p in np.genfromtxt(os.path.join(filepath, 'poses.txt')).astype(np.float64)] 77 | pos0 = np.zeros((4, 4)) 78 | pos1 = np.zeros((4, 4)) 79 | pos0[:3, :] = poses[0] 80 | pos0[3, 3] = 1. 81 | pos1[:3, :] = poses[1] 82 | pos1[3, 3] = 1. 
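        # The two 3x4 camera poses read from poses.txt were padded into homogeneous 4x4
        # matrices above; the line below composes them into the relative transform between
        # the reference view and its context view, which is what 'pose_context' will hold.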
83 | pos = np.matmul(pos1, np.linalg.inv(pos0)) 84 | # pos = np.matmul(np.linalg.inv(pos1), pos0) 85 | pose_context = [pos.astype(np.float32)] 86 | 87 | intr = np.genfromtxt(os.path.join(filepath, 'cam.txt')) 88 | 89 | sample = { 90 | 'idx': idx, 91 | 'filename': '%s' % (filepath.split('/')[-1]), 92 | 'rgb': image, 93 | 'depth': depth, 94 | 'pose_context': pose_context, 95 | 'intrinsics': intr 96 | } 97 | 98 | if self.has_context: 99 | sample['rgb_context'] = rgb_contexts 100 | 101 | if self.data_transform: 102 | sample = self.data_transform(sample) 103 | 104 | return sample 105 | 106 | ######################################################################################################################## 107 | -------------------------------------------------------------------------------- /dro_sfm/datasets/demon_mf_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import re 3 | from collections import defaultdict 4 | import os 5 | 6 | from torch.utils.data import Dataset 7 | import numpy as np 8 | from dro_sfm.utils.image import load_image 9 | import cv2, IPython 10 | from PIL import Image 11 | from PIL import ImageFile 12 | ImageFile.LOAD_TRUNCATED_IMAGES = True 13 | import h5py 14 | ######################################################################################################################## 15 | #### FUNCTIONS 16 | ######################################################################################################################## 17 | 18 | def dummy_calibration(image): 19 | return np.array([[5.703422047415297129e+02 , 0. , 3.200000000000000000e+02], 20 | [0. , 5.703422047415297129e+02 , 2.400000000000000000e+02], 21 | [0. , 0. , 1. ]]) 22 | 23 | 24 | ######################################################################################################################## 25 | #### DATASET 26 | ######################################################################################################################## 27 | 28 | class DemonDataset(Dataset): 29 | def __init__(self, root_dir, split, data_transform=None, 30 | forward_context=0, back_context=0, strides=(1,), 31 | depth_type=None, **kwargs): 32 | super().__init__() 33 | # Asserts 34 | # assert depth_type is None or depth_type == '', \ 35 | # 'NYUDataset currently does not support depth types' 36 | assert len(strides) == 1 and strides[0] == 1, \ 37 | 'NYUDataset currently only supports stride of 1.' 
38 | 39 | self.depth_type = depth_type 40 | self.with_depth = depth_type is not '' and depth_type is not None 41 | self.root_dir = root_dir 42 | self.split = split 43 | with open(os.path.join(self.root_dir, split), "r") as f: 44 | data = f.readlines() 45 | 46 | self.paths = [] 47 | # Get file list from data 48 | for i, fname in enumerate(data): 49 | path = os.path.join(root_dir, fname.split()[0]) 50 | view3_rgb = os.path.join(path, "0002.jpg") 51 | view3_depth = os.path.join(path, "0002.npy") 52 | if not (forward_context == 1 and back_context == 1): 53 | if os.path.exists(view3_rgb) and os.path.exists(view3_depth): 54 | self.paths.append((path, True)) 55 | else: 56 | self.paths.append((path, False)) 57 | else: 58 | if os.path.exists(view3_rgb) and os.path.exists(view3_depth): 59 | self.paths.append((path, True)) 60 | 61 | self.backward_context = back_context 62 | self.forward_context = forward_context 63 | self.has_context = (self.backward_context > 0 ) or (self.forward_context > 0) 64 | self.strides = strides[0] 65 | 66 | self.data_transform = data_transform 67 | 68 | def __len__(self): 69 | return len(self.paths) 70 | 71 | def _change_idx(self, idx, filename): 72 | _, ext = os.path.splitext(os.path.basename(filename)) 73 | return str(idx) + ext 74 | 75 | def _get_view2(self, filepath): 76 | image = load_image(os.path.join(filepath, '0000.jpg')) 77 | depth = np.load(os.path.join(filepath, '0000.npy')) 78 | rgb_contexts = [load_image(os.path.join(filepath, '0001.jpg'))] 79 | 80 | poses = [p.reshape((3, 4)) for p in np.genfromtxt(os.path.join(filepath, 'poses.txt')).astype(np.float64)] 81 | pos0 = np.zeros((4, 4)) 82 | pos1 = np.zeros((4, 4)) 83 | pos0[:3, :] = poses[0] 84 | pos0[3, 3] = 1. 85 | pos1[:3, :] = poses[1] 86 | pos1[3, 3] = 1. 87 | 88 | pos01 = np.matmul(pos1, np.linalg.inv(pos0)) 89 | pose_context = [pos01.astype(np.float32)] 90 | 91 | return image, depth, rgb_contexts, pose_context 92 | 93 | def _get_view3_dummy(self, filepath): 94 | image = load_image(os.path.join(filepath, '0000.jpg')) 95 | depth = np.load(os.path.join(filepath, '0000.npy')) 96 | rgb_contexts = [load_image(os.path.join(filepath, '0001.jpg')), load_image(os.path.join(filepath, '0001.jpg'))] 97 | 98 | poses = [p.reshape((3, 4)) for p in np.genfromtxt(os.path.join(filepath, 'poses.txt')).astype(np.float64)] 99 | pos0 = np.zeros((4, 4)) 100 | pos1 = np.zeros((4, 4)) 101 | pos0[:3, :] = poses[0] 102 | pos0[3, 3] = 1. 103 | pos1[:3, :] = poses[1] 104 | pos1[3, 3] = 1. 
105 | 106 | pos01 = np.matmul(pos1, np.linalg.inv(pos0)) 107 | pose_context = [pos01.astype(np.float32), pos01.astype(np.float32)] 108 | 109 | return image, depth, rgb_contexts, pose_context 110 | 111 | def _get_view3(self, filepath): 112 | image = load_image(os.path.join(filepath, '0001.jpg')) 113 | depth = np.load(os.path.join(filepath, '0001.npy')) 114 | rgb_contexts = [load_image(os.path.join(filepath, '0000.jpg')), load_image(os.path.join(filepath, '0002.jpg'))] 115 | 116 | poses = [p.reshape((3, 4)) for p in np.genfromtxt(os.path.join(filepath, 'poses.txt')).astype(np.float64)] 117 | 118 | pos0 = np.eye(4) 119 | pos1 = np.eye(4) 120 | pos2 = np.eye(4) 121 | pos0[:3, :] = poses[0] 122 | pos1[:3, :] = poses[1] 123 | pos2[:3, :] = poses[2] 124 | 125 | pos10 = np.matmul(pos0, np.linalg.inv(pos1)) 126 | pos12 = np.matmul(pos2, np.linalg.inv(pos1)) 127 | pose_context = [pos10.astype(np.float32), pos12.astype(np.float32)] 128 | 129 | return image, depth, rgb_contexts, pose_context 130 | 131 | 132 | def __getitem__(self, idx): 133 | filepath, mf = self.paths[idx] 134 | 135 | if self.forward_context == 1 and self.backward_context == 0: 136 | image, depth, rgb_contexts, pose_context = self._get_view2(filepath) 137 | 138 | elif self.forward_context == 1 and self.backward_context == 1: 139 | image, depth, rgb_contexts, pose_context = self._get_view3(filepath) 140 | 141 | elif self.forward_context == 1 and self.backward_context == -1: 142 | if np.random.random() > 0.5 and mf: 143 | image, depth, rgb_contexts, pose_context = self._get_view3(filepath) 144 | else: 145 | image, depth, rgb_contexts, pose_context = self._get_view3_dummy(filepath) 146 | else: 147 | raise NotImplementedError 148 | 149 | intr = np.genfromtxt(os.path.join(filepath, 'cam.txt')) 150 | 151 | sample = { 152 | 'idx': idx, 153 | 'filename': '%s' % (filepath.split('/')[-1]), 154 | 'rgb': image, 155 | 'depth': depth, 156 | 'pose_context': pose_context, 157 | 'intrinsics': intr 158 | } 159 | 160 | if self.has_context: 161 | sample['rgb_context'] = rgb_contexts 162 | 163 | if self.data_transform: 164 | sample = self.data_transform(sample) 165 | 166 | return sample 167 | 168 | ######################################################################################################################## 169 | -------------------------------------------------------------------------------- /dro_sfm/datasets/image_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import re 3 | from collections import defaultdict 4 | import os 5 | 6 | from torch.utils.data import Dataset 7 | import numpy as np 8 | from dro_sfm.utils.image import load_image 9 | 10 | ######################################################################################################################## 11 | #### FUNCTIONS 12 | ######################################################################################################################## 13 | 14 | def dummy_calibration(image): 15 | w, h = [float(d) for d in image.size] 16 | return np.array([[1000. , 0. , w / 2. - 0.5], 17 | [0. , 1000. , h / 2. - 0.5], 18 | [0. , 0. , 1. 
]]) 19 | 20 | def get_idx(filename): 21 | return int(re.search(r'\d+', filename).group()) 22 | 23 | def read_files(directory, ext=('.png', '.jpg', '.jpeg'), skip_empty=True): 24 | files = defaultdict(list) 25 | for entry in os.scandir(directory): 26 | relpath = os.path.relpath(entry.path, directory) 27 | if entry.is_dir(): 28 | d_files = read_files(entry.path, ext=ext, skip_empty=skip_empty) 29 | if skip_empty and not len(d_files): 30 | continue 31 | files[relpath] = d_files[entry.path] 32 | elif entry.is_file(): 33 | if ext is None or entry.path.lower().endswith(tuple(ext)): 34 | files[directory].append(relpath) 35 | return files 36 | 37 | ######################################################################################################################## 38 | #### DATASET 39 | ######################################################################################################################## 40 | 41 | class ImageDataset(Dataset): 42 | def __init__(self, root_dir, split, data_transform=None, 43 | forward_context=0, back_context=0, strides=(1,), 44 | depth_type=None, **kwargs): 45 | super().__init__() 46 | # Asserts 47 | assert depth_type is None or depth_type == '', \ 48 | 'ImageDataset currently does not support depth types' 49 | assert len(strides) == 1 and strides[0] == 1, \ 50 | 'ImageDataset currently only supports stride of 1.' 51 | 52 | self.root_dir = root_dir 53 | self.split = split 54 | 55 | self.backward_context = back_context 56 | self.forward_context = forward_context 57 | self.has_context = self.backward_context + self.forward_context > 0 58 | self.strides = 1 59 | 60 | self.files = [] 61 | file_tree = read_files(root_dir) 62 | for k, v in file_tree.items(): 63 | file_set = set(file_tree[k]) 64 | files = [fname for fname in sorted(v) if self._has_context(fname, file_set)] 65 | self.files.extend([[k, fname] for fname in files]) 66 | 67 | self.data_transform = data_transform 68 | 69 | def __len__(self): 70 | return len(self.files) 71 | 72 | def _change_idx(self, idx, filename): 73 | _, ext = os.path.splitext(os.path.basename(filename)) 74 | return self.split.format(idx) + ext 75 | 76 | def _has_context(self, filename, file_set): 77 | context_paths = self._get_context_file_paths(filename) 78 | return all([f in file_set for f in context_paths]) 79 | 80 | def _get_context_file_paths(self, filename): 81 | fidx = get_idx(filename) 82 | idxs = list(np.arange(-self.backward_context * self.strides, 0, self.strides)) + \ 83 | list(np.arange(0, self.forward_context * self.strides, self.strides) + self.strides) 84 | return [self._change_idx(fidx + i, filename) for i in idxs] 85 | 86 | def _read_rgb_context_files(self, session, filename): 87 | context_paths = self._get_context_file_paths(filename) 88 | return [load_image(os.path.join(self.root_dir, session, filename)) 89 | for filename in context_paths] 90 | 91 | def _read_rgb_file(self, session, filename): 92 | return load_image(os.path.join(self.root_dir, session, filename)) 93 | 94 | def __getitem__(self, idx): 95 | session, filename = self.files[idx] 96 | image = self._read_rgb_file(session, filename) 97 | 98 | sample = { 99 | 'idx': idx, 100 | 'filename': '%s_%s' % (session, os.path.splitext(filename)[0]), 101 | # 102 | 'rgb': image, 103 | 'intrinsics': dummy_calibration(image) 104 | } 105 | 106 | if self.has_context: 107 | sample['rgb_context'] = \ 108 | self._read_rgb_context_files(session, filename) 109 | 110 | if self.data_transform: 111 | sample = self.data_transform(sample) 112 | 113 | return sample 114 | 115 | 
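# Illustrative usage of ImageDataset (path and filename pattern are hypothetical;
# `split` is used by _change_idx as a format string for frame indices):
#
#   dataset = ImageDataset('/path/to/frames', split='{:010d}',
#                          back_context=1, forward_context=1, strides=(1,))
#   sample = dataset[0]
#   sample['rgb']          # target frame loaded with load_image()
#   sample['rgb_context']  # [previous frame, next frame]
#   sample['intrinsics']   # pinhole matrix from dummy_calibration()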
######################################################################################################################## 116 | -------------------------------------------------------------------------------- /dro_sfm/datasets/kitti_dataset_utils.py: -------------------------------------------------------------------------------- 1 | """Provides helper methods for loading and parsing KITTI data.""" 2 | 3 | from collections import namedtuple 4 | 5 | import numpy as np 6 | 7 | __author__ = "Lee Clement" 8 | __email__ = "lee.clement@robotics.utias.utoronto.ca" 9 | 10 | # Per dataformat.txt 11 | OxtsPacket = namedtuple('OxtsPacket', 12 | 'lat, lon, alt, ' + 13 | 'roll, pitch, yaw, ' + 14 | 'vn, ve, vf, vl, vu, ' + 15 | 'ax, ay, az, af, al, au, ' + 16 | 'wx, wy, wz, wf, wl, wu, ' + 17 | 'pos_accuracy, vel_accuracy, ' + 18 | 'navstat, numsats, ' + 19 | 'posmode, velmode, orimode') 20 | 21 | # Bundle into an easy-to-access structure 22 | OxtsData = namedtuple('OxtsData', 'packet, T_w_imu') 23 | 24 | 25 | def rotx(t): 26 | """ 27 | Rotation about the x-axis 28 | 29 | Parameters 30 | ---------- 31 | t : float 32 | Theta angle 33 | 34 | Returns 35 | ------- 36 | matrix : np.array [3,3] 37 | Rotation matrix 38 | """ 39 | c = np.cos(t) 40 | s = np.sin(t) 41 | return np.array([[1, 0, 0], 42 | [0, c, -s], 43 | [0, s, c]]) 44 | 45 | 46 | def roty(t): 47 | """ 48 | Rotation about the y-axis 49 | 50 | Parameters 51 | ---------- 52 | t : float 53 | Theta angle 54 | 55 | Returns 56 | ------- 57 | matrix : np.array [3,3] 58 | Rotation matrix 59 | """ 60 | c = np.cos(t) 61 | s = np.sin(t) 62 | return np.array([[c, 0, s], 63 | [0, 1, 0], 64 | [-s, 0, c]]) 65 | 66 | 67 | def rotz(t): 68 | """ 69 | Rotation about the z-axis 70 | 71 | Parameters 72 | ---------- 73 | t : float 74 | Theta angle 75 | 76 | Returns 77 | ------- 78 | matrix : np.array [3,3] 79 | Rotation matrix 80 | """ 81 | c = np.cos(t) 82 | s = np.sin(t) 83 | return np.array([[c, -s, 0], 84 | [s, c, 0], 85 | [0, 0, 1]]) 86 | 87 | 88 | def transform_from_rot_trans(R, t): 89 | """ 90 | Transformation matrix from rotation matrix and translation vector. 
91 | 92 | Parameters 93 | ---------- 94 | R : np.array [3,3] 95 | Rotation matrix 96 | t : np.array [3] 97 | translation vector 98 | 99 | Returns 100 | ------- 101 | matrix : np.array [4,4] 102 | Transformation matrix 103 | """ 104 | R = R.reshape(3, 3) 105 | t = t.reshape(3, 1) 106 | return np.vstack((np.hstack([R, t]), [0, 0, 0, 1])) 107 | 108 | 109 | def read_calib_file(filepath): 110 | """ 111 | Read in a calibration file and parse into a dictionary 112 | 113 | Parameters 114 | ---------- 115 | filepath : str 116 | File path to read from 117 | 118 | Returns 119 | ------- 120 | calib : dict 121 | Dictionary with calibration values 122 | """ 123 | data = {} 124 | 125 | with open(filepath, 'r') as f: 126 | for line in f.readlines(): 127 | key, value = line.split(':', 1) 128 | # The only non-float values in these files are dates, which 129 | # we don't care about anyway 130 | try: 131 | data[key] = np.array([float(x) for x in value.split()]) 132 | except ValueError: 133 | pass 134 | 135 | return data 136 | 137 | 138 | def pose_from_oxts_packet(raw_data, scale): 139 | """ 140 | Helper method to compute a SE(3) pose matrix from an OXTS packet 141 | 142 | Parameters 143 | ---------- 144 | raw_data : dict 145 | Oxts data to read from 146 | scale : float 147 | Oxts scale 148 | 149 | Returns 150 | ------- 151 | R : np.array [3,3] 152 | Rotation matrix 153 | t : np.array [3] 154 | Translation vector 155 | """ 156 | packet = OxtsPacket(*raw_data) 157 | er = 6378137. # earth radius (approx.) in meters 158 | 159 | # Use a Mercator projection to get the translation vector 160 | tx = scale * packet.lon * np.pi * er / 180. 161 | ty = scale * er * \ 162 | np.log(np.tan((90. + packet.lat) * np.pi / 360.)) 163 | tz = packet.alt 164 | t = np.array([tx, ty, tz]) 165 | 166 | # Use the Euler angles to get the rotation matrix 167 | Rx = rotx(packet.roll) 168 | Ry = roty(packet.pitch) 169 | Rz = rotz(packet.yaw) 170 | R = Rz.dot(Ry.dot(Rx)) 171 | 172 | # Combine the translation and rotation into a homogeneous transform 173 | return R, t 174 | 175 | 176 | def load_oxts_packets_and_poses(oxts_files): 177 | """ 178 | Generator to read OXTS ground truth data. 179 | Poses are given in an East-North-Up coordinate system 180 | whose origin is the first GPS position. 181 | 182 | Parameters 183 | ---------- 184 | oxts_files : list of str 185 | List of oxts files to read from 186 | 187 | Returns 188 | ------- 189 | oxts : list of dict 190 | List of oxts ground-truth data 191 | """ 192 | # Scale for Mercator projection (from first lat value) 193 | scale = None 194 | # Origin of the global coordinate system (first GPS position) 195 | origin = None 196 | 197 | oxts = [] 198 | 199 | for filename in oxts_files: 200 | with open(filename, 'r') as f: 201 | for line in f.readlines(): 202 | line = line.split() 203 | # Last five entries are flags and counts 204 | line[:-5] = [float(x) for x in line[:-5]] 205 | line[-5:] = [int(float(x)) for x in line[-5:]] 206 | 207 | packet = OxtsPacket(*line) 208 | 209 | if scale is None: 210 | scale = np.cos(packet.lat * np.pi / 180.) 
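                # The Mercator scale cos(lat0) is computed once from the first
                # packet and reused for the whole sequence; pose_from_oxts_packet()
                # defined above uses it for both tx and ty.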
211 | 212 | R, t = pose_from_oxts_packet(packet, scale) 213 | 214 | if origin is None: 215 | origin = t 216 | 217 | T_w_imu = transform_from_rot_trans(R, t - origin) 218 | 219 | oxts.append(OxtsData(packet, T_w_imu)) 220 | 221 | return oxts 222 | 223 | 224 | -------------------------------------------------------------------------------- /dro_sfm/datasets/nyu_dataset_processed.py: -------------------------------------------------------------------------------- 1 | 2 | import re 3 | from collections import defaultdict 4 | import os 5 | 6 | from torch.utils.data import Dataset 7 | import numpy as np 8 | from dro_sfm.utils.image import load_image 9 | import cv2, IPython 10 | from PIL import Image 11 | from PIL import ImageFile 12 | ImageFile.LOAD_TRUNCATED_IMAGES = True 13 | import h5py 14 | ######################################################################################################################## 15 | #### FUNCTIONS 16 | ######################################################################################################################## 17 | 18 | def dummy_calibration(image): 19 | return np.array([[518.85790117450188 , 0. , 325.58244941119034], 20 | [0. , 519.46961112127485 , 253.73616633400465], 21 | [0. , 0. , 1. ]]) 22 | 23 | def get_idx(filename): 24 | return int(re.search(r'\d+', filename).group()) 25 | 26 | def read_files(directory, ext=('.depth', '.h5'), skip_empty=True): 27 | files = defaultdict(list) 28 | for entry in os.scandir(directory): 29 | relpath = os.path.relpath(entry.path, directory) 30 | if entry.is_dir(): 31 | d_files = read_files(entry.path, ext=ext, skip_empty=skip_empty) 32 | if skip_empty and not len(d_files): 33 | continue 34 | files[relpath] = d_files[entry.path] 35 | elif entry.is_file(): 36 | if (ext is None or entry.path.lower().endswith(tuple(ext))): 37 | files[directory].append(relpath) 38 | return files 39 | 40 | def read_npz_depth(file, depth_type): 41 | """Reads a .npz depth map given a certain depth_type.""" 42 | depth = np.load(file)[depth_type + '_depth'].astype(np.float32) 43 | return np.expand_dims(depth, axis=2) 44 | 45 | def read_png_depth(file): 46 | """Reads a .png depth map.""" 47 | depth_png = np.array(load_image(file), dtype=int) 48 | 49 | # assert (np.max(depth_png) > 255), 'Wrong .png depth file' 50 | if (np.max(depth_png) > 255): 51 | depth = depth_png.astype(np.float) / 256. 52 | else: 53 | depth = depth_png.astype(np.float) 54 | depth[depth_png == 0] = -1. 55 | return np.expand_dims(depth, axis=2) 56 | 57 | ######################################################################################################################## 58 | #### DATASET 59 | ######################################################################################################################## 60 | 61 | class NYUDataset(Dataset): 62 | def __init__(self, root_dir, split, data_transform=None, 63 | forward_context=0, back_context=0, strides=(1,), 64 | depth_type=None, **kwargs): 65 | super().__init__() 66 | # Asserts 67 | # assert depth_type is None or depth_type == '', \ 68 | # 'NYUDataset currently does not support depth types' 69 | assert len(strides) == 1 and strides[0] == 1, \ 70 | 'NYUDataset currently only supports stride of 1.' 
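        # Each sample is an .h5 file whose 'rgb' array is stored as [3, H, W]
        # (transposed to H x W x 3 before conversion to a PIL image) and whose
        # 'depth' array is a dense depth map, as read in __getitem__ below.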
71 | 72 | self.depth_type = depth_type 73 | self.with_depth = depth_type is not '' and depth_type is not None 74 | self.root_dir = root_dir 75 | self.split = split 76 | 77 | self.backward_context = back_context 78 | self.forward_context = forward_context 79 | self.has_context = self.backward_context + self.forward_context > 0 80 | self.strides = strides[0] 81 | 82 | self.files = [] 83 | self.file_tree = read_files(root_dir) 84 | for k, v in self.file_tree.items(): 85 | file_list = sorted(v) 86 | files = [fname for fname in file_list if self._has_context(k, fname, file_list)] 87 | self.files.extend([[k, fname] for fname in files]) 88 | 89 | self.data_transform = data_transform 90 | 91 | def __len__(self): 92 | return len(self.files) 93 | 94 | def _change_idx(self, idx, filename): 95 | _, ext = os.path.splitext(os.path.basename(filename)) 96 | return str(idx) + ext 97 | 98 | def _has_context(self, session, filename, file_list): 99 | context_paths = self._get_context_file_paths(filename, file_list) 100 | return all([f in file_list for f in context_paths]) 101 | 102 | def _get_context_file_paths(self, filename, filelist): 103 | # fidx = get_idx(filename) 104 | fidx = filelist.index(filename) 105 | idxs = list(np.arange(-self.backward_context * self.strides, 0, self.strides)) + \ 106 | list(np.arange(0, self.forward_context * self.strides, self.strides) + self.strides) 107 | return [filelist[fidx+i] if 0 <= fidx+i < len(filelist) else 'none' for i in idxs] 108 | 109 | def _read_rgb_context_files(self, session, filename): 110 | context_paths = self._get_context_file_paths(filename, sorted(self.file_tree[session])) 111 | 112 | return [self._read_rgb_file(session, filename) 113 | for filename in context_paths] 114 | 115 | def _read_rgb_file(self, session, filename): 116 | file_path = os.path.join(self.root_dir, session, filename) 117 | h5f = h5py.File(file_path, "r") 118 | rgb = np.array(h5f['rgb']) 119 | image = np.transpose(rgb, (1, 2, 0)) 120 | image_pil = Image.fromarray(image) 121 | return image_pil 122 | 123 | def __getitem__(self, idx): 124 | session, filename = self.files[idx] 125 | if session == self.root_dir: 126 | session = '' 127 | 128 | file_path = os.path.join(self.root_dir, session, filename) 129 | h5f = h5py.File(file_path, "r") 130 | rgb = np.array(h5f['rgb']) 131 | image = np.transpose(rgb, (1, 2, 0)) 132 | # image = rgb[...,::-1] 133 | image_pil = Image.fromarray(image) 134 | depth = np.array(h5f['depth']) 135 | 136 | sample = { 137 | 'idx': idx, 138 | 'filename': '%s_%s' % (session, os.path.splitext(filename)[0]), 139 | 'rgb': image_pil, 140 | 'depth': depth, 141 | 'intrinsics': dummy_calibration(image) 142 | } 143 | 144 | if self.has_context: 145 | sample['rgb_context'] = \ 146 | self._read_rgb_context_files(session, filename) 147 | 148 | if self.data_transform: 149 | sample = self.data_transform(sample) 150 | 151 | return sample 152 | 153 | ######################################################################################################################## 154 | -------------------------------------------------------------------------------- /dro_sfm/datasets/transforms.py: -------------------------------------------------------------------------------- 1 | 2 | from functools import partial 3 | from dro_sfm.datasets.augmentations import resize_image, resize_sample, \ 4 | duplicate_sample, colorjitter_sample, to_tensor_sample, resize_sample_image_and_intrinsics 5 | 6 | 
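# Illustrative call into get_transforms() (defined at the end of this module);
# the image shape and jitter values here are example values only:
#
#   train_tf = get_transforms('train', image_shape=(192, 640),
#                             jittering=(0.2, 0.2, 0.2, 0.05))
#   sample = train_tf(sample)  # resize, keep originals, colour-jitter, to tensors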
######################################################################################################################## 7 | 8 | def train_transforms(sample, image_shape, jittering): 9 | """ 10 | Training data augmentation transformations 11 | 12 | Parameters 13 | ---------- 14 | sample : dict 15 | Sample to be augmented 16 | image_shape : tuple (height, width) 17 | Image dimension to reshape 18 | jittering : tuple (brightness, contrast, saturation, hue) 19 | Color jittering parameters 20 | 21 | Returns 22 | ------- 23 | sample : dict 24 | Augmented sample 25 | """ 26 | if len(image_shape) > 0: 27 | sample = resize_sample(sample, image_shape) 28 | sample = duplicate_sample(sample) 29 | if len(jittering) > 0: 30 | sample = colorjitter_sample(sample, jittering) 31 | sample = to_tensor_sample(sample) 32 | return sample 33 | 34 | def validation_transforms(sample, image_shape): 35 | """ 36 | Validation data augmentation transformations 37 | 38 | Parameters 39 | ---------- 40 | sample : dict 41 | Sample to be augmented 42 | image_shape : tuple (height, width) 43 | Image dimension to reshape 44 | 45 | Returns 46 | ------- 47 | sample : dict 48 | Augmented sample 49 | """ 50 | # if len(image_shape) > 0: 51 | # sample['rgb'] = resize_image(sample['rgb'], image_shape) 52 | # if 'rgb_context' in sample: 53 | # sample['rgb_context'] = [resize_image(img, image_shape) for img in sample['rgb_context']] 54 | 55 | if len(image_shape) > 0: 56 | sample = resize_sample_image_and_intrinsics(sample, image_shape) 57 | 58 | sample = to_tensor_sample(sample) 59 | return sample 60 | 61 | def test_transforms(sample, image_shape): 62 | """ 63 | Test data augmentation transformations 64 | 65 | Parameters 66 | ---------- 67 | sample : dict 68 | Sample to be augmented 69 | image_shape : tuple (height, width) 70 | Image dimension to reshape 71 | 72 | Returns 73 | ------- 74 | sample : dict 75 | Augmented sample 76 | """ 77 | # if len(image_shape) > 0: 78 | # sample['rgb'] = resize_image(sample['rgb'], image_shape) 79 | # if 'rgb_context' in sample: 80 | # sample['rgb_context'] = [resize_image(img, image_shape) for img in sample['rgb_context']] 81 | 82 | if len(image_shape) > 0: 83 | sample = resize_sample_image_and_intrinsics(sample, image_shape) 84 | 85 | sample = to_tensor_sample(sample) 86 | return sample 87 | 88 | def get_transforms(mode, image_shape, jittering, **kwargs): 89 | """ 90 | Get data augmentation transformations for each split 91 | 92 | Parameters 93 | ---------- 94 | mode : str {'train', 'validation', 'test'} 95 | Mode from which we want the data augmentation transformations 96 | image_shape : tuple (height, width) 97 | Image dimension to reshape 98 | jittering : tuple (brightness, contrast, saturation, hue) 99 | Color jittering parameters 100 | 101 | Returns 102 | ------- 103 | XXX_transform: Partial function 104 | Data augmentation transformation for that mode 105 | """ 106 | if mode == 'train': 107 | return partial(train_transforms, 108 | image_shape=image_shape, 109 | jittering=jittering) 110 | elif mode == 'validation': 111 | return partial(validation_transforms, 112 | image_shape=image_shape) 113 | elif mode == 'test': 114 | return partial(test_transforms, 115 | image_shape=image_shape) 116 | else: 117 | raise ValueError('Unknown mode {}'.format(mode)) 118 | 119 | ######################################################################################################################## 120 | 121 | -------------------------------------------------------------------------------- 
/dro_sfm/datasets/video_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import re 3 | from collections import defaultdict 4 | import os 5 | from numpy.core.defchararray import join 6 | 7 | from torch.utils.data import Dataset 8 | import numpy as np 9 | from dro_sfm.utils.image import load_image 10 | import cv2, IPython 11 | from PIL import Image 12 | from PIL import ImageFile 13 | ImageFile.LOAD_TRUNCATED_IMAGES = True 14 | import h5py 15 | ######################################################################################################################## 16 | #### FUNCTIONS 17 | ######################################################################################################################## 18 | 19 | def dummy_calibration(w, h): 20 | fx = w * 1.2 21 | fy = w * 1.2 22 | cx = w / 2.0 23 | cy = h / 2.0 24 | return np.array([[fx, 0., cx], 25 | [0., fy, cy], 26 | [0., 0., 1.]]) 27 | 28 | 29 | def get_idx(filename): 30 | return int(re.search(r'\d+', filename).group()) 31 | 32 | def read_files(directory, ext=('.depth', '.h5'), skip_empty=True): 33 | files = defaultdict(list) 34 | for entry in os.scandir(directory): 35 | relpath = os.path.relpath(entry.path, directory) 36 | if entry.is_dir(): 37 | d_files = read_files(entry.path, ext=ext, skip_empty=skip_empty) 38 | if skip_empty and not len(d_files): 39 | continue 40 | files[relpath] = d_files[entry.path] 41 | elif entry.is_file(): 42 | if (ext is None or entry.path.lower().endswith(tuple(ext))): 43 | files[directory].append(relpath) 44 | return files 45 | 46 | def read_npz_depth(file, depth_type): 47 | """Reads a .npz depth map given a certain depth_type.""" 48 | depth = np.load(file)[depth_type + '_depth'].astype(np.float32) 49 | return np.expand_dims(depth, axis=2) 50 | 51 | def read_png_depth(file): 52 | """Reads a .png depth map.""" 53 | depth_png = np.array(load_image(file), dtype=int) 54 | 55 | # assert (np.max(depth_png) > 255), 'Wrong .png depth file' 56 | if (np.max(depth_png) > 255): 57 | depth = depth_png.astype(np.float) / 256. 58 | else: 59 | depth = depth_png.astype(np.float) 60 | depth[depth_png == 0] = -1. 
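    # KITTI-style convention: 16-bit PNGs encode depth scaled by 256, 8-bit PNGs
    # are taken as-is, and zero pixels are marked invalid with -1 before returning.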
61 | return np.expand_dims(depth, axis=2) 62 | 63 | ######################################################################################################################## 64 | #### DATASET 65 | ######################################################################################################################## 66 | 67 | class VideoDataset(Dataset): 68 | def __init__(self, root_dir, split, data_transform=None, 69 | forward_context=0, back_context=0, strides=(1,), 70 | depth_type=None, **kwargs): 71 | super().__init__() 72 | # Asserts 73 | assert depth_type is None or depth_type == '', \ 74 | 'VideoDataset currently does not support depth types' 75 | self.depth_type = depth_type 76 | self.with_depth = depth_type is not '' and depth_type is not None 77 | self.root_dir = root_dir 78 | self.split = split 79 | 80 | self.backward_context = back_context 81 | self.forward_context = forward_context 82 | self.has_context = self.backward_context + self.forward_context > 0 83 | self.strides = strides[0] 84 | 85 | self.files = [] 86 | self.file_tree = read_files(root_dir, ext=(".jpg", ".png", ".bmp", ".jpeg")) 87 | for k, v in self.file_tree.items(): 88 | file_list = sorted(v) 89 | files = [fname for fname in file_list if self._has_context(k, fname, file_list)] 90 | self.files.extend([[k, fname] for fname in files]) 91 | 92 | self.data_transform = data_transform 93 | 94 | def __len__(self): 95 | return len(self.files) 96 | 97 | def _has_context(self, session, filename, file_list): 98 | context_paths = self._get_context_file_paths(filename, file_list) 99 | return all([os.path.exists(os.path.join(self.root_dir, session, f)) for f in context_paths]) 100 | 101 | def _get_context_file_paths(self, filename, filelist): 102 | # fidx = get_idx(filename) 103 | fidx = filelist.index(filename) 104 | idxs = list(np.arange(-self.backward_context * self.strides, 0, self.strides)) + \ 105 | list(np.arange(0, self.forward_context * self.strides, self.strides) + self.strides) 106 | 107 | return [filelist[fidx+i] if (0 <= (fidx+i)) and ((fidx+i) < len(filelist)) else 'none' for i in idxs] 108 | 109 | def _read_rgb_context_files(self, session, filename): 110 | context_paths = self._get_context_file_paths(filename, sorted(self.file_tree[session])) 111 | return [self._read_rgb_file(session, filename) 112 | for filename in context_paths] 113 | 114 | def _read_rgb_file(self, session, filename): 115 | file_path = os.path.join(self.root_dir, session, filename) 116 | rgb = load_image(file_path) 117 | return rgb 118 | 119 | def __getitem__(self, idx): 120 | session, filename = self.files[idx] 121 | if session == self.root_dir: 122 | session = '' 123 | 124 | rgb = self._read_rgb_file(session, filename) 125 | 126 | sample = { 127 | 'idx': idx, 128 | 'filename': '%s_%s' % (session, os.path.splitext(filename)[0]), 129 | 'rgb': rgb, 130 | 'intrinsics': dummy_calibration(w=rgb.size[0], h=rgb.size[1]) 131 | } 132 | 133 | # Add depth information if requested 134 | if self.with_depth: 135 | sample.update({ 136 | 'depth': None, 137 | }) 138 | 139 | if self.has_context: 140 | sample['rgb_context'] = \ 141 | self._read_rgb_context_files(session, filename) 142 | 143 | if self.data_transform: 144 | sample = self.data_transform(sample) 145 | 146 | return sample 147 | 148 | ######################################################################################################################## 149 | -------------------------------------------------------------------------------- /dro_sfm/datasets/video_random_dataset.py: 
-------------------------------------------------------------------------------- 1 | 2 | import re 3 | from collections import defaultdict 4 | import os 5 | from numpy.core.defchararray import join 6 | 7 | from torch.utils.data import Dataset 8 | import numpy as np 9 | from dro_sfm.utils.image import load_image 10 | import cv2, IPython 11 | from PIL import Image 12 | from PIL import ImageFile 13 | ImageFile.LOAD_TRUNCATED_IMAGES = True 14 | import h5py 15 | ######################################################################################################################## 16 | #### FUNCTIONS 17 | ######################################################################################################################## 18 | 19 | def dummy_calibration(w, h): 20 | fx = w * 1.2 21 | fy = w * 1.2 22 | cx = w / 2.0 23 | cy = h / 2.0 24 | return np.array([[fx, 0., cx], 25 | [0., fy, cy], 26 | [0., 0., 1.]]) 27 | 28 | 29 | def get_idx(filename): 30 | return int(re.search(r'\d+', filename).group()) 31 | 32 | def read_files(directory, ext=('.depth', '.h5'), skip_empty=True): 33 | files = defaultdict(list) 34 | for entry in os.scandir(directory): 35 | relpath = os.path.relpath(entry.path, directory) 36 | if entry.is_dir(): 37 | d_files = read_files(entry.path, ext=ext, skip_empty=skip_empty) 38 | if skip_empty and not len(d_files): 39 | continue 40 | files[relpath] = d_files[entry.path] 41 | elif entry.is_file(): 42 | if (ext is None or entry.path.lower().endswith(tuple(ext))): 43 | files[directory].append(relpath) 44 | return files 45 | 46 | def read_npz_depth(file, depth_type): 47 | """Reads a .npz depth map given a certain depth_type.""" 48 | depth = np.load(file)[depth_type + '_depth'].astype(np.float32) 49 | return np.expand_dims(depth, axis=2) 50 | 51 | def read_png_depth(file): 52 | """Reads a .png depth map.""" 53 | depth_png = np.array(load_image(file), dtype=int) 54 | 55 | # assert (np.max(depth_png) > 255), 'Wrong .png depth file' 56 | if (np.max(depth_png) > 255): 57 | depth = depth_png.astype(np.float) / 256. 58 | else: 59 | depth = depth_png.astype(np.float) 60 | depth[depth_png == 0] = -1. 
61 | return np.expand_dims(depth, axis=2) 62 | 63 | ######################################################################################################################## 64 | #### DATASET 65 | ######################################################################################################################## 66 | 67 | class VideoRandomDataset(Dataset): 68 | def __init__(self, root_dir, split, data_transform=None, 69 | forward_context=0, back_context=0, strides=(1,), 70 | depth_type=None, **kwargs): 71 | super().__init__() 72 | # Asserts 73 | assert depth_type is None or depth_type == '', \ 74 | 'VideoDataset currently does not support depth types' 75 | self.depth_type = depth_type 76 | self.with_depth = depth_type is not '' and depth_type is not None 77 | self.root_dir = root_dir 78 | self.split = split 79 | 80 | self.backward_context = back_context 81 | self.forward_context = forward_context 82 | 83 | self.max_len = 11 84 | 85 | self.has_context = self.backward_context + self.forward_context > 0 86 | self.strides = strides[0] 87 | 88 | self.files = [] 89 | self.file_tree = read_files(root_dir, ext=(".jpg", ".png", ".bmp", ".jpeg")) 90 | for k, v in self.file_tree.items(): 91 | file_list = sorted(v) 92 | files = [fname for fname in file_list if self._has_context(k, fname, file_list)] 93 | self.files.extend([[k, fname] for fname in files]) 94 | 95 | self.data_transform = data_transform 96 | 97 | def __len__(self): 98 | return len(self.files) 99 | 100 | def _has_context(self, session, filename, file_list): 101 | context_paths = self._get_context_file_paths(filename, file_list) 102 | return all([os.path.exists(os.path.join(self.root_dir, session, f)) for f in context_paths]) 103 | 104 | def _get_context_file_paths(self, filename, filelist): 105 | # fidx = get_idx(filename) 106 | fidx = filelist.index(filename) 107 | idxs = list(np.arange(-self.backward_context * self.strides, 0, self.strides)) + \ 108 | list(np.arange(0, self.forward_context * self.strides, self.strides) + self.strides) 109 | 110 | return [filelist[fidx+i] if (0 <= (fidx+i)) and ((fidx+i) < len(filelist)) else 'none' for i in idxs] 111 | 112 | 113 | def _get_context_file_paths_random(self, filename, filelist): 114 | # fidx = get_idx(filename) 115 | fidx = filelist.index(filename) 116 | idxs_back = list(np.arange(-self.backward_context * self.strides, 0, 1)) 117 | 118 | idx_forw= list(np.arange(0, self.forward_context * self.strides, 1) + 1) 119 | 120 | idxs_back = np.random.choice(idxs_back, self.backward_context).tolist() 121 | 122 | idx_forw = np.random.choice(idx_forw, self.forward_context).tolist() 123 | 124 | idxs = idxs_back + idx_forw 125 | 126 | return [filelist[i + fidx] for i in idxs] 127 | 128 | 129 | def _read_rgb_context_files(self, session, filename): 130 | context_paths = self._get_context_file_paths_random(filename, sorted(self.file_tree[session])) 131 | # print(filename, context_paths) 132 | return [self._read_rgb_file(session, filename) 133 | for filename in context_paths] 134 | 135 | 136 | def _read_rgb_file(self, session, filename): 137 | file_path = os.path.join(self.root_dir, session, filename) 138 | rgb = load_image(file_path) 139 | return rgb 140 | 141 | def __getitem__(self, idx): 142 | session, filename = self.files[idx] 143 | if session == self.root_dir: 144 | session = '' 145 | 146 | rgb = self._read_rgb_file(session, filename) 147 | 148 | sample = { 149 | 'idx': idx, 150 | 'filename': '%s_%s' % (session, os.path.splitext(filename)[0]), 151 | 'rgb': rgb, 152 | 'intrinsics': 
dummy_calibration(w=rgb.size[0], h=rgb.size[1]) 153 | } 154 | 155 | # Add depth information if requested 156 | if self.with_depth: 157 | sample.update({ 158 | 'depth': None, 159 | }) 160 | 161 | if self.has_context: 162 | sample['rgb_context'] = \ 163 | self._read_rgb_context_files(session, filename) 164 | 165 | if self.data_transform: 166 | sample = self.data_transform(sample) 167 | 168 | return sample 169 | 170 | ######################################################################################################################## 171 | -------------------------------------------------------------------------------- /dro_sfm/geometry/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/dro-sfm/8707e2e0ef799d7d47418a018060f503ef449fe3/dro_sfm/geometry/__init__.py -------------------------------------------------------------------------------- /dro_sfm/geometry/camera.py: -------------------------------------------------------------------------------- 1 | 2 | from functools import lru_cache 3 | import torch 4 | import torch.nn as nn 5 | 6 | from dro_sfm.geometry.pose import Pose 7 | from dro_sfm.geometry.camera_utils import scale_intrinsics 8 | from dro_sfm.utils.image import image_grid 9 | 10 | ######################################################################################################################## 11 | 12 | class Camera(nn.Module): 13 | """ 14 | Differentiable camera class implementing reconstruction and projection 15 | functions for a pinhole model. 16 | """ 17 | def __init__(self, K, Tcw=None): 18 | """ 19 | Initializes the Camera class 20 | 21 | Parameters 22 | ---------- 23 | K : torch.Tensor [3,3] 24 | Camera intrinsics 25 | Tcw : Pose 26 | Camera -> World pose transformation 27 | """ 28 | super().__init__() 29 | self.K = K 30 | self.Tcw = Pose.identity(len(K)) if Tcw is None else Tcw 31 | 32 | def __len__(self): 33 | """Batch size of the camera intrinsics""" 34 | return len(self.K) 35 | 36 | def to(self, *args, **kwargs): 37 | """Moves object to a specific device""" 38 | self.K = self.K.to(*args, **kwargs) 39 | self.Tcw = self.Tcw.to(*args, **kwargs) 40 | return self 41 | 42 | ######################################################################################################################## 43 | 44 | @property 45 | def fx(self): 46 | """Focal length in x""" 47 | return self.K[:, 0, 0] 48 | 49 | @property 50 | def fy(self): 51 | """Focal length in y""" 52 | return self.K[:, 1, 1] 53 | 54 | @property 55 | def cx(self): 56 | """Principal point in x""" 57 | return self.K[:, 0, 2] 58 | 59 | @property 60 | def cy(self): 61 | """Principal point in y""" 62 | return self.K[:, 1, 2] 63 | 64 | @property 65 | @lru_cache() 66 | def Twc(self): 67 | """World -> Camera pose transformation (inverse of Tcw)""" 68 | return self.Tcw.inverse() 69 | 70 | @property 71 | @lru_cache() 72 | def Kinv(self): 73 | """Inverse intrinsics (for lifting)""" 74 | Kinv = self.K.clone() 75 | Kinv[:, 0, 0] = 1. / self.fx 76 | Kinv[:, 1, 1] = 1. / self.fy 77 | Kinv[:, 0, 2] = -1. * self.cx / self.fx 78 | Kinv[:, 1, 2] = -1. 
* self.cy / self.fy 79 | return Kinv 80 | 81 | ######################################################################################################################## 82 | 83 | def scaled(self, x_scale, y_scale=None): 84 | """ 85 | Returns a scaled version of the camera (changing intrinsics) 86 | 87 | Parameters 88 | ---------- 89 | x_scale : float 90 | Resize scale in x 91 | y_scale : float 92 | Resize scale in y. If None, use the same as x_scale 93 | 94 | Returns 95 | ------- 96 | camera : Camera 97 | Scaled version of the current cmaera 98 | """ 99 | # If single value is provided, use for both dimensions 100 | if y_scale is None: 101 | y_scale = x_scale 102 | # If no scaling is necessary, return same camera 103 | if x_scale == 1. and y_scale == 1.: 104 | return self 105 | # Scale intrinsics and return new camera with same Pose 106 | K = scale_intrinsics(self.K.clone(), x_scale, y_scale) 107 | return Camera(K, Tcw=self.Tcw) 108 | 109 | ######################################################################################################################## 110 | 111 | def reconstruct(self, depth, frame='w'): 112 | """ 113 | Reconstructs pixel-wise 3D points from a depth map. 114 | 115 | Parameters 116 | ---------- 117 | depth : torch.Tensor [B,1,H,W] 118 | Depth map for the camera 119 | frame : 'w' 120 | Reference frame: 'c' for camera and 'w' for world 121 | 122 | Returns 123 | ------- 124 | points : torch.tensor [B,3,H,W] 125 | Pixel-wise 3D points 126 | """ 127 | B, C, H, W = depth.shape 128 | assert C == 1 129 | 130 | # Create flat index grid 131 | grid = image_grid(B, H, W, depth.dtype, depth.device, normalized=False) # [B,3,H,W] 132 | flat_grid = grid.view(B, 3, -1) # [B,3,HW] 133 | 134 | # Estimate the outward rays in the camera frame 135 | xnorm = (self.Kinv.bmm(flat_grid)).view(B, 3, H, W) 136 | # Scale rays to metric depth 137 | Xc = xnorm * depth 138 | 139 | # If in camera frame of reference 140 | if frame == 'c': 141 | return Xc 142 | # If in world frame of reference 143 | elif frame == 'w': 144 | return self.Twc @ Xc 145 | # If none of the above 146 | else: 147 | raise ValueError('Unknown reference frame {}'.format(frame)) 148 | 149 | def project(self, X, frame='w', normalize=True): 150 | """ 151 | Projects 3D points onto the image plane 152 | 153 | Parameters 154 | ---------- 155 | X : torch.Tensor [B,3,H,W] 156 | 3D points to be projected 157 | frame : 'w' 158 | Reference frame: 'c' for camera and 'w' for world 159 | 160 | Returns 161 | ------- 162 | points : torch.Tensor [B,H,W,2] 163 | 2D projected points that are within the image boundaries 164 | """ 165 | B, C, H, W = X.shape 166 | assert C == 3 167 | 168 | # Project 3D points onto the camera image plane 169 | if frame == 'c': 170 | Xc = self.K.bmm(X.view(B, 3, -1)) 171 | elif frame == 'w': 172 | Xc = self.K.bmm((self.Tcw @ X).view(B, 3, -1)) 173 | else: 174 | raise ValueError('Unknown reference frame {}'.format(frame)) 175 | 176 | # Normalize points 177 | X = Xc[:, 0] 178 | Y = Xc[:, 1] 179 | Z = Xc[:, 2].clamp(min=1e-5) 180 | if normalize: 181 | Xnorm = 2 * (X / Z) / (W - 1) - 1. #(-1, 1) 182 | Ynorm = 2 * (Y / Z) / (H - 1) - 1. 183 | else: 184 | Xnorm = X / Z 185 | Ynorm = Y / Z 186 | 187 | # Clamp out-of-bounds pixels 188 | # Xmask = ((Xnorm > 1) + (Xnorm < -1)).detach() 189 | # Xnorm[Xmask] = 2. 190 | # Ymask = ((Ynorm > 1) + (Ynorm < -1)).detach() 191 | # Ynorm[Ymask] = 2. 
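        # With normalize=True the coordinates are mapped to [-1, 1], which is the
        # convention expected by torch.nn.functional.grid_sample in view_synthesis().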
192 | 193 | # Return pixel coordinates 194 | return torch.stack([Xnorm, Ynorm], dim=-1).view(B, H, W, 2) 195 | -------------------------------------------------------------------------------- /dro_sfm/geometry/camera_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn.functional as funct 4 | 5 | ######################################################################################################################## 6 | 7 | def construct_K(fx, fy, cx, cy, dtype=torch.float, device=None): 8 | """Construct a [3,3] camera intrinsics from pinhole parameters""" 9 | return torch.tensor([[fx, 0, cx], 10 | [ 0, fy, cy], 11 | [ 0, 0, 1]], dtype=dtype, device=device) 12 | 13 | def scale_intrinsics(K, x_scale, y_scale): 14 | """Scale intrinsics given x_scale and y_scale factors""" 15 | K[..., 0, 0] *= x_scale 16 | K[..., 1, 1] *= y_scale 17 | K[..., 0, 2] = (K[..., 0, 2] + 0.5) * x_scale - 0.5 18 | K[..., 1, 2] = (K[..., 1, 2] + 0.5) * y_scale - 0.5 19 | return K 20 | 21 | ######################################################################################################################## 22 | 23 | def view_synthesis(ref_image, depth, ref_cam, cam, 24 | mode='bilinear', padding_mode='zeros'): 25 | """ 26 | Synthesize an image from another plus a depth map. 27 | 28 | Parameters 29 | ---------- 30 | ref_image : torch.Tensor [B,3,H,W] 31 | Reference image to be warped 32 | depth : torch.Tensor [B,1,H,W] 33 | Depth map from the original image 34 | ref_cam : Camera 35 | Camera class for the reference image 36 | cam : Camera 37 | Camera class for the original image 38 | mode : str 39 | Interpolation mode 40 | padding_mode : str 41 | Padding mode for interpolation 42 | 43 | Returns 44 | ------- 45 | ref_warped : torch.Tensor [B,3,H,W] 46 | Warped reference image in the original frame of reference 47 | """ 48 | assert depth.size(1) == 1 49 | # Reconstruct world points from target_camera 50 | world_points = cam.reconstruct(depth, frame='w') 51 | # Project world points onto reference camera 52 | ref_coords = ref_cam.project(world_points, frame='w') 53 | 54 | # View-synthesis given the projected reference points 55 | return funct.grid_sample(ref_image, ref_coords, mode=mode, 56 | padding_mode=padding_mode, align_corners=True) 57 | 58 | ######################################################################################################################## 59 | 60 | -------------------------------------------------------------------------------- /dro_sfm/geometry/pose.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from dro_sfm.geometry.pose_utils import invert_pose, pose_vec2mat 4 | 5 | ######################################################################################################################## 6 | 7 | class Pose: 8 | """ 9 | Pose class, that encapsulates a [4,4] transformation matrix 10 | for a specific reference frame 11 | """ 12 | def __init__(self, mat): 13 | """ 14 | Initializes a Pose object. 
15 | 16 | Parameters 17 | ---------- 18 | mat : torch.Tensor [B,4,4] 19 | Transformation matrix 20 | """ 21 | assert tuple(mat.shape[-2:]) == (4, 4) 22 | if mat.dim() == 2: 23 | mat = mat.unsqueeze(0) 24 | assert mat.dim() == 3 25 | self.mat = mat 26 | 27 | def __len__(self): 28 | """Batch size of the transformation matrix""" 29 | return len(self.mat) 30 | 31 | ######################################################################################################################## 32 | 33 | @classmethod 34 | def identity(cls, N=1, device=None, dtype=torch.float): 35 | """Initializes as a [4,4] identity matrix""" 36 | return cls(torch.eye(4, device=device, dtype=dtype).repeat([N,1,1])) 37 | 38 | @classmethod 39 | def from_vec(cls, vec, mode): 40 | """Initializes from a [B,6] batch vector""" 41 | mat = pose_vec2mat(vec, mode) # [B,3,4] 42 | pose = torch.eye(4, device=vec.device, dtype=vec.dtype).repeat([len(vec), 1, 1]) 43 | pose[:, :3, :3] = mat[:, :3, :3] 44 | pose[:, :3, -1] = mat[:, :3, -1] 45 | return cls(pose) 46 | 47 | ######################################################################################################################## 48 | 49 | @property 50 | def shape(self): 51 | """Returns the transformation matrix shape""" 52 | return self.mat.shape 53 | 54 | def item(self): 55 | """Returns the transformation matrix""" 56 | return self.mat 57 | 58 | def repeat(self, *args, **kwargs): 59 | """Repeats the transformation matrix multiple times""" 60 | self.mat = self.mat.repeat(*args, **kwargs) 61 | return self 62 | 63 | def inverse(self): 64 | """Returns a new Pose that is the inverse of this one""" 65 | return Pose(invert_pose(self.mat)) 66 | 67 | def to(self, *args, **kwargs): 68 | """Moves object to a specific device""" 69 | self.mat = self.mat.to(*args, **kwargs) 70 | return self 71 | 72 | ######################################################################################################################## 73 | 74 | def transform_pose(self, pose): 75 | """Creates a new pose object that compounds this and another one (self * pose)""" 76 | assert tuple(pose.shape[-2:]) == (4, 4) 77 | return Pose(self.mat.bmm(pose.item())) 78 | 79 | def transform_points(self, points): 80 | """Transforms 3D points using this object""" 81 | assert points.shape[1] == 3 82 | B, _, H, W = points.shape 83 | out = self.mat[:,:3,:3].bmm(points.view(B, 3, -1)) + \ 84 | self.mat[:,:3,-1].unsqueeze(-1) 85 | return out.view(B, 3, H, W) 86 | 87 | def __matmul__(self, other): 88 | """Transforms the input (Pose or 3D points) using this object""" 89 | if isinstance(other, Pose): 90 | return self.transform_pose(other) 91 | elif isinstance(other, torch.Tensor): 92 | if other.shape[1] == 3 and other.dim() > 2: 93 | assert other.dim() == 3 or other.dim() == 4 94 | return self.transform_points(other) 95 | else: 96 | raise ValueError('Unknown tensor dimensions {}'.format(other.shape)) 97 | else: 98 | raise NotImplementedError() 99 | 100 | ######################################################################################################################## 101 | -------------------------------------------------------------------------------- /dro_sfm/geometry/pose_utils.py: -------------------------------------------------------------------------------- 1 | 2 | from dro_sfm.geometry.pose_trans import axis_angle_to_matrix 3 | import torch 4 | import numpy as np 5 | from torch._C import dtype 6 | 7 | def mat2euler(mat): 8 | euler = torch.ones(mat.shape[0], 3, dtype=mat.dtype, device=mat.device) 9 | cy_thresh = 1e-6 10 
| # try: 11 | # cy_thresh = np.finfo(mat.dtype).eps * 4 12 | # except ValueError: 13 | # cy_thresh = np.finfo(np.float).eps * 4.0 14 | # print("cy_thresh", cy_thresh) 15 | r11, r12, r13, r21, r22, r23, r31, r32, r33 = mat[:, 0, 0], mat[:, 0, 1], mat[:, 0, 2], \ 16 | mat[:, 1, 0], mat[:, 1, 1], mat[:, 1, 2], \ 17 | mat[:, 2, 0], mat[:, 2, 1], mat[:, 2, 2] 18 | # cy: sqrt((cos(y)*cos(z))**2 + (cos(x)*cos(y))**2) 19 | cy = torch.sqrt(r33 * r33 + r23 * r23) 20 | 21 | mask = cy > cy_thresh 22 | 23 | if torch.sum(mask) > 1: 24 | euler[mask, 0] = torch.atan2(-r23, r33)[mask] 25 | euler[mask, 1] = torch.atan2(r13, cy)[mask] 26 | euler[mask, 2] = torch.atan2(-r12, r11)[mask] 27 | 28 | mask = cy <= cy_thresh 29 | if torch.sum(mask) > 1: 30 | print("mat2euler!!!!!!") 31 | euler[mask, 0] = 0.0 32 | euler[mask, 1] = torch.atan2(r13, cy) # atan2(sin(y), cy) 33 | euler[mask, 2] = torch.atan2(r21, r22) 34 | 35 | return euler 36 | 37 | 38 | ######################################################################################################################## 39 | 40 | def euler2mat(angle): 41 | """Convert euler angles to rotation matrix""" 42 | B = angle.size(0) 43 | x, y, z = angle[:, 0], angle[:, 1], angle[:, 2] 44 | 45 | cosz = torch.cos(z) 46 | sinz = torch.sin(z) 47 | 48 | zeros = z.detach() * 0 49 | ones = zeros.detach() + 1 50 | zmat = torch.stack([cosz, -sinz, zeros, 51 | sinz, cosz, zeros, 52 | zeros, zeros, ones], dim=1).view(B, 3, 3) 53 | 54 | cosy = torch.cos(y) 55 | siny = torch.sin(y) 56 | 57 | ymat = torch.stack([cosy, zeros, siny, 58 | zeros, ones, zeros, 59 | -siny, zeros, cosy], dim=1).view(B, 3, 3) 60 | 61 | cosx = torch.cos(x) 62 | sinx = torch.sin(x) 63 | 64 | xmat = torch.stack([ones, zeros, zeros, 65 | zeros, cosx, -sinx, 66 | zeros, sinx, cosx], dim=1).view(B, 3, 3) 67 | 68 | rot_mat = xmat.bmm(ymat).bmm(zmat) 69 | return rot_mat 70 | 71 | ######################################################################################################################## 72 | 73 | def pose_vec2mat(vec, mode='euler'): 74 | """Convert Euler parameters to transformation matrix.""" 75 | if mode is None: 76 | return vec 77 | trans, rot = vec[:, :3].unsqueeze(-1), vec[:, 3:] 78 | if mode == 'euler': 79 | rot_mat = euler2mat(rot) 80 | elif mode == 'axis_angle': 81 | rot_mat = axis_angle_to_matrix(rot) 82 | else: 83 | raise ValueError('Rotation mode not supported {}'.format(mode)) 84 | mat = torch.cat([rot_mat, trans], dim=2) # [B,3,4] 85 | return mat 86 | 87 | ######################################################################################################################## 88 | 89 | def invert_pose(T): 90 | """Inverts a [B,4,4] torch.tensor pose""" 91 | Tinv = torch.eye(4, device=T.device, dtype=T.dtype).repeat([len(T), 1, 1]) 92 | Tinv[:, :3, :3] = torch.transpose(T[:, :3, :3], -2, -1) 93 | Tinv[:, :3, -1] = torch.bmm(-1. 
* Tinv[:, :3, :3], T[:, :3, -1].unsqueeze(-1)).squeeze(-1) 94 | return Tinv 95 | 96 | ######################################################################################################################## 97 | 98 | def invert_pose_numpy(T): 99 | """Inverts a [4,4] np.array pose""" 100 | Tinv = np.copy(T) 101 | R, t = Tinv[:3, :3], Tinv[:3, 3] 102 | Tinv[:3, :3], Tinv[:3, 3] = R.T, - np.matmul(R.T, t) 103 | return Tinv 104 | 105 | ######################################################################################################################## 106 | -------------------------------------------------------------------------------- /dro_sfm/loggers/__init__.py: -------------------------------------------------------------------------------- 1 | from dro_sfm.loggers.wandb_logger import WandbLogger 2 | 3 | __all__ = ["WandbLogger"] -------------------------------------------------------------------------------- /dro_sfm/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/dro-sfm/8707e2e0ef799d7d47418a018060f503ef449fe3/dro_sfm/losses/__init__.py -------------------------------------------------------------------------------- /dro_sfm/losses/loss_base.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch.nn as nn 4 | from dro_sfm.utils.types import is_list 5 | 6 | ######################################################################################################################## 7 | 8 | class ProgressiveScaling: 9 | """ 10 | Helper class to manage progressive scaling. 11 | After a certain training progress percentage, decrease the number of scales by 1. 12 | 13 | Parameters 14 | ---------- 15 | progressive_scaling : float 16 | Training progress percentage where the number of scales is decreased 17 | num_scales : int 18 | Initial number of scales 19 | """ 20 | def __init__(self, progressive_scaling, num_scales=4): 21 | self.num_scales = num_scales 22 | # Use it only if bigger than zero (make a list) 23 | if progressive_scaling > 0.0: 24 | self.progressive_scaling = np.float32( 25 | [progressive_scaling * (i + 1) for i in range(num_scales - 1)] + [1.0]) 26 | # Otherwise, disable it 27 | else: 28 | self.progressive_scaling = progressive_scaling 29 | def __call__(self, progress): 30 | """ 31 | Call for an update in the number of scales 32 | 33 | Parameters 34 | ---------- 35 | progress : float 36 | Training progress percentage 37 | 38 | Returns 39 | ------- 40 | num_scales : int 41 | New number of scales 42 | """ 43 | if is_list(self.progressive_scaling): 44 | return int(self.num_scales - 45 | np.searchsorted(self.progressive_scaling, progress)) 46 | else: 47 | return self.num_scales 48 | 49 | ######################################################################################################################## 50 | 51 | class LossBase(nn.Module): 52 | """Base class for losses.""" 53 | def __init__(self): 54 | """Initializes logs and metrics dictionaries""" 55 | super().__init__() 56 | self._logs = {} 57 | self._metrics = {} 58 | 59 | ######################################################################################################################## 60 | 61 | @property 62 | def logs(self): 63 | """Return logs.""" 64 | return self._logs 65 | 66 | @property 67 | def metrics(self): 68 | """Return metrics.""" 69 | return self._metrics 70 | 71 | def add_metric(self, key, val): 72 | """Add a new metric to the dictionary and 
detach it.""" 73 | self._metrics[key] = val.detach() 74 | 75 | ######################################################################################################################## 76 | -------------------------------------------------------------------------------- /dro_sfm/models/SelfSupModelMF.py: -------------------------------------------------------------------------------- 1 | from dro_sfm.models.SfmModelMF import SfmModelMF 2 | from dro_sfm.losses.multiview_photometric_loss_mf import MultiViewPhotometricDecayLoss 3 | from dro_sfm.models.model_utils import merge_outputs 4 | 5 | 6 | class SelfSupModelMF(SfmModelMF): 7 | """ 8 | Model that inherits a depth and pose network from SfmModel and 9 | includes the photometric loss for self-supervised training. 10 | 11 | Parameters 12 | ---------- 13 | kwargs : dict 14 | Extra parameters 15 | """ 16 | def __init__(self, **kwargs): 17 | # Initializes SfmModel 18 | super().__init__(**kwargs) 19 | # Initializes the photometric loss 20 | self._photometric_loss = MultiViewPhotometricDecayLoss(**kwargs) 21 | 22 | @property 23 | def logs(self): 24 | """Return logs.""" 25 | return { 26 | **super().logs, 27 | **self._photometric_loss.logs 28 | } 29 | 30 | def self_supervised_loss(self, image, ref_images, inv_depths, poses, 31 | intrinsics, return_logs=False, progress=0.0): 32 | """ 33 | Calculates the self-supervised photometric loss. 34 | 35 | Parameters 36 | ---------- 37 | image : torch.Tensor [B,3,H,W] 38 | Original image 39 | ref_images : list of torch.Tensor [B,3,H,W] 40 | Reference images from context 41 | inv_depths : torch.Tensor [B,1,H,W] 42 | Predicted inverse depth maps from the original image 43 | poses : list of Pose 44 | List containing predicted poses between original and context images 45 | intrinsics : torch.Tensor [B,3,3] 46 | Camera intrinsics 47 | return_logs : bool 48 | True if logs are stored 49 | progress : 50 | Training progress percentage 51 | 52 | Returns 53 | ------- 54 | output : dict 55 | Dictionary containing a "loss" scalar a "metrics" dictionary 56 | """ 57 | return self._photometric_loss( 58 | image, ref_images, inv_depths, intrinsics, intrinsics, poses, 59 | return_logs=return_logs, progress=progress) 60 | 61 | def forward(self, batch, return_logs=False, progress=0.0): 62 | """ 63 | Processes a batch. 64 | 65 | Parameters 66 | ---------- 67 | batch : dict 68 | Input batch 69 | return_logs : bool 70 | True if logs are stored 71 | progress : 72 | Training progress percentage 73 | 74 | Returns 75 | ------- 76 | output : dict 77 | Dictionary containing a "loss" scalar and different metrics and predictions 78 | for logging and downstream usage. 
79 | """ 80 | # Calculate predicted depth and pose output 81 | output = super().forward(batch, return_logs=return_logs) 82 | if not self.training: 83 | # If not training, no need for self-supervised loss 84 | return output 85 | else: 86 | if output["poses"] is None: 87 | return None 88 | # Otherwise, calculate self-supervised loss 89 | self_sup_output = self.self_supervised_loss( 90 | batch['rgb_original'], batch['rgb_context_original'], 91 | output['inv_depths'], output['poses'], batch['intrinsics'], 92 | return_logs=return_logs, progress=progress) 93 | # Return loss and metrics 94 | return { 95 | 'loss': self_sup_output['loss'], 96 | **merge_outputs(output, self_sup_output), 97 | } 98 | -------------------------------------------------------------------------------- /dro_sfm/models/SemiSupModelMF.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dro_sfm.models.SelfSupModelMF import SelfSupModelMF, SfmModelMF 3 | from dro_sfm.losses.supervised_loss import SupervisedDepthPoseLoss as SupervisedLoss 4 | from dro_sfm.models.model_utils import merge_outputs 5 | from dro_sfm.utils.depth import depth2inv 6 | 7 | 8 | class SemiSupModelMFPose(SelfSupModelMF): 9 | """ 10 | Model that inherits a depth and pose networks, plus the self-supervised loss from 11 | SelfSupModel and includes a supervised loss for semi-supervision. 12 | 13 | Parameters 14 | ---------- 15 | supervised_loss_weight : float 16 | Weight for the supervised loss 17 | kwargs : dict 18 | Extra parameters 19 | """ 20 | def __init__(self, supervised_loss_weight=0.9, **kwargs): 21 | # Initializes SelfSupModel 22 | super().__init__(**kwargs) 23 | # If supervision weight is 0.0, use SelfSupModel directly 24 | assert 0. < supervised_loss_weight <= 1., "Model requires (0, 1] supervision" 25 | # Store weight and initializes supervised loss 26 | self.supervised_loss_weight = supervised_loss_weight 27 | self._supervised_loss = SupervisedLoss(**kwargs) 28 | 29 | print(f"=================supervised_loss_weight:{supervised_loss_weight}====================") 30 | # # Pose network is only required if there is self-supervision 31 | # self._network_requirements['pose_net'] = self.supervised_loss_weight < 1 32 | # # GT depth is only required if there is supervision 33 | self._train_requirements['gt_depth'] = self.supervised_loss_weight > 0 34 | self._train_requirements['gt_pose'] = True 35 | 36 | @property 37 | def logs(self): 38 | """Return logs.""" 39 | return { 40 | **super().logs, 41 | **self._supervised_loss.logs 42 | } 43 | 44 | def supervised_loss(self, image, ref_images, inv_depths, gt_depth, gt_poses, poses, 45 | intrinsics, return_logs=False, progress=0.0): 46 | """ 47 | Calculates the self-supervised photometric loss. 
48 | 49 | Parameters 50 | ---------- 51 | image : torch.Tensor [B,3,H,W] 52 | Original image 53 | ref_images : list of torch.Tensor [B,3,H,W] 54 | Reference images from context 55 | inv_depths : torch.Tensor [B,1,H,W] 56 | Predicted inverse depth maps from the original image 57 | poses : list of Pose 58 | List containing predicted poses between original and context images 59 | intrinsics : torch.Tensor [B,3,3] 60 | Camera intrinsics 61 | return_logs : bool 62 | True if logs are stored 63 | progress : 64 | Training progress percentage 65 | 66 | Returns 67 | ------- 68 | output : dict 69 | Dictionary containing a "loss" scalar a "metrics" dictionary 70 | """ 71 | return self._supervised_loss( 72 | image, ref_images, inv_depths, depth2inv(gt_depth), gt_poses, intrinsics, intrinsics, poses, 73 | return_logs=return_logs, progress=progress) 74 | 75 | def forward(self, batch, return_logs=False, progress=0.0): 76 | """ 77 | Processes a batch. 78 | 79 | Parameters 80 | ---------- 81 | batch : dict 82 | Input batch 83 | return_logs : bool 84 | True if logs are stored 85 | progress : 86 | Training progress percentage 87 | 88 | Returns 89 | ------- 90 | output : dict 91 | Dictionary containing a "loss" scalar and different metrics and predictions 92 | for logging and downstream usage. 93 | """ 94 | if not self.training: 95 | # If not training, no need for self-supervised loss 96 | return SfmModelMF.forward(self, batch) 97 | else: 98 | if self.supervised_loss_weight == 1.: 99 | # If no self-supervision, no need to calculate loss 100 | self_sup_output = SfmModelMF.forward(self, batch) 101 | loss = torch.tensor([0.]).type_as(batch['rgb']) 102 | else: 103 | # Otherwise, calculate and weight self-supervised loss 104 | self_sup_output = SelfSupModelMF.forward(self, batch) 105 | loss = (1.0 - self.supervised_loss_weight) * self_sup_output['loss'] 106 | # Calculate and weight supervised loss 107 | sup_output = self.supervised_loss( 108 | batch['rgb_original'], batch['rgb_context_original'], 109 | self_sup_output['inv_depths'], batch['depth'], batch['pose_context'], self_sup_output['poses'], batch['intrinsics'], 110 | return_logs=return_logs, progress=progress) 111 | loss += self.supervised_loss_weight * sup_output['loss'] 112 | # Merge and return outputs 113 | return { 114 | 'loss': loss, 115 | **merge_outputs(self_sup_output, sup_output), 116 | } -------------------------------------------------------------------------------- /dro_sfm/models/SfmModel.py: -------------------------------------------------------------------------------- 1 | 2 | import random 3 | from kornia.geometry.epipolar.projection import depth 4 | import torch.nn as nn 5 | from dro_sfm.utils.image import flip_model, interpolate_scales 6 | from dro_sfm.geometry.pose import Pose 7 | from dro_sfm.utils.misc import make_list 8 | from dro_sfm.utils.depth import inv2depth 9 | 10 | 11 | class SfmModel(nn.Module): 12 | """ 13 | Model class encapsulating a pose and depth networks. 
14 | 15 | Parameters 16 | ---------- 17 | depth_net : nn.Module 18 | Depth network to be used 19 | pose_net : nn.Module 20 | Pose network to be used 21 | rotation_mode : str 22 | Rotation mode for the pose network 23 | flip_lr_prob : float 24 | Probability of flipping when using the depth network 25 | upsample_depth_maps : bool 26 | True if depth map scales are upsampled to highest resolution 27 | kwargs : dict 28 | Extra parameters 29 | """ 30 | def __init__(self, depth_net=None, pose_net=None, 31 | rotation_mode='euler', flip_lr_prob=0.0, 32 | upsample_depth_maps=False, **kwargs): 33 | super().__init__() 34 | self.depth_net = depth_net 35 | self.pose_net = pose_net 36 | self.rotation_mode = rotation_mode 37 | self.flip_lr_prob = flip_lr_prob 38 | self.upsample_depth_maps = upsample_depth_maps 39 | self._logs = {} 40 | self._losses = {} 41 | 42 | self._network_requirements = { 43 | 'depth_net': True, # Depth network required 44 | 'pose_net': True, # Pose network required 45 | 'percep_net': False, # Pose network required 46 | } 47 | self._train_requirements = { 48 | 'gt_depth': False, # No ground-truth depth required 49 | 'gt_pose': False, # No ground-truth pose required 50 | } 51 | 52 | @property 53 | def logs(self): 54 | """Return logs.""" 55 | return self._logs 56 | 57 | @property 58 | def losses(self): 59 | """Return metrics.""" 60 | return self._losses 61 | 62 | def add_loss(self, key, val): 63 | """Add a new loss to the dictionary and detaches it.""" 64 | self._losses[key] = val.detach() 65 | 66 | @property 67 | def network_requirements(self): 68 | """ 69 | Networks required to run the model 70 | 71 | Returns 72 | ------- 73 | requirements : dict 74 | depth_net : bool 75 | Whether a depth network is required by the model 76 | pose_net : bool 77 | Whether a depth network is required by the model 78 | """ 79 | return self._network_requirements 80 | 81 | @property 82 | def train_requirements(self): 83 | """ 84 | Information required by the model at training stage 85 | 86 | Returns 87 | ------- 88 | requirements : dict 89 | gt_depth : bool 90 | Whether ground truth depth is required by the model at training time 91 | gt_pose : bool 92 | Whether ground truth pose is required by the model at training time 93 | """ 94 | return self._train_requirements 95 | 96 | def add_depth_net(self, depth_net): 97 | """Add a depth network to the model""" 98 | self.depth_net = depth_net 99 | 100 | def add_pose_net(self, pose_net): 101 | """Add a pose network to the model""" 102 | self.pose_net = pose_net 103 | 104 | def compute_inv_depths(self, image): 105 | """Computes inverse depth maps from single images""" 106 | # Randomly flip and estimate inverse depth maps 107 | flip_lr = random.random() < self.flip_lr_prob if self.training else False 108 | inv_depths = make_list(flip_model(self.depth_net, image, flip_lr)) 109 | # If upsampling depth maps 110 | if self.upsample_depth_maps: 111 | inv_depths = interpolate_scales( 112 | inv_depths, mode='nearest', align_corners=None) 113 | # Return inverse depth maps 114 | return inv_depths 115 | 116 | def compute_poses(self, image, contexts, intrinsics, depth): 117 | """Compute poses from image and a sequence of context images""" 118 | pose_vec = self.pose_net(image, contexts, intrinsics, depth) 119 | if pose_vec is None: 120 | return None 121 | if pose_vec.shape[2] == 6: 122 | return [Pose.from_vec(pose_vec[:, i], self.rotation_mode) 123 | for i in range(pose_vec.shape[1])] 124 | else: 125 | return [Pose(pose_vec[:, i]) for i in range(pose_vec.shape[1])] 126 | 127 | 
def forward(self, batch, return_logs=False): 128 | """ 129 | Processes a batch. 130 | 131 | Parameters 132 | ---------- 133 | batch : dict 134 | Input batch 135 | return_logs : bool 136 | True if logs are stored 137 | 138 | Returns 139 | ------- 140 | output : dict 141 | Dictionary containing predicted inverse depth maps and poses 142 | """ 143 | # Generate inverse depth predictions 144 | inv_depths = self.compute_inv_depths(batch['rgb']) 145 | # Generate pose predictions if available 146 | pose = None 147 | if 'rgb_context' in batch and self.pose_net is not None: 148 | pose = self.compute_poses(batch['rgb'], 149 | batch['rgb_context'], batch["intrinsics"], inv2depth(inv_depths[0])) 150 | # Return output dictionary 151 | return { 152 | 'inv_depths': inv_depths, 153 | 'poses': pose, 154 | } 155 | -------------------------------------------------------------------------------- /dro_sfm/models/SfmModelMF.py: -------------------------------------------------------------------------------- 1 | 2 | import random 3 | import torch.nn as nn 4 | from dro_sfm.utils.image import flip_lr_intr, flip_mf_model, interpolate_scales 5 | from dro_sfm.utils.image import flip_lr as flip_lr_img 6 | from dro_sfm.geometry.pose import Pose 7 | from dro_sfm.utils.misc import make_list 8 | 9 | 10 | class SfmModelMF(nn.Module): 11 | """ 12 | Model class encapsulating a pose and depth networks. 13 | 14 | Parameters 15 | ---------- 16 | depth_net : nn.Module 17 | Depth network to be used 18 | pose_net : nn.Module 19 | Pose network to be used 20 | rotation_mode : str 21 | Rotation mode for the pose network 22 | flip_lr_prob : float 23 | Probability of flipping when using the depth network 24 | upsample_depth_maps : bool 25 | True if depth map scales are upsampled to highest resolution 26 | kwargs : dict 27 | Extra parameters 28 | """ 29 | def __init__(self, depth_net=None, pose_net=None, 30 | rotation_mode='euler', flip_lr_prob=0.0, 31 | upsample_depth_maps=False, min_depth=0.1, max_depth=100, **kwargs): 32 | super().__init__() 33 | self.depth_net = depth_net 34 | self.pose_net = pose_net 35 | self.rotation_mode = rotation_mode 36 | self.flip_lr_prob = flip_lr_prob 37 | self.upsample_depth_maps = upsample_depth_maps 38 | self.min_depth = min_depth 39 | self.max_depth = max_depth 40 | self._logs = {} 41 | self._losses = {} 42 | self._network_requirements = { 43 | 'depth_net': True, # Depth network required 44 | 'pose_net': False, # Pose network required 45 | 'percep_net': False, # Pose network required 46 | } 47 | self._train_requirements = { 48 | 'gt_depth': False, # No ground-truth depth required 49 | 'gt_pose': False, # No ground-truth pose required 50 | } 51 | 52 | @property 53 | def logs(self): 54 | """Return logs.""" 55 | return self._logs 56 | 57 | @property 58 | def losses(self): 59 | """Return metrics.""" 60 | return self._losses 61 | 62 | def add_loss(self, key, val): 63 | """Add a new loss to the dictionary and detaches it.""" 64 | self._losses[key] = val.detach() 65 | 66 | @property 67 | def network_requirements(self): 68 | """ 69 | Networks required to run the model 70 | 71 | Returns 72 | ------- 73 | requirements : dict 74 | depth_net : bool 75 | Whether a depth network is required by the model 76 | pose_net : bool 77 | Whether a depth network is required by the model 78 | """ 79 | return self._network_requirements 80 | 81 | @property 82 | def train_requirements(self): 83 | """ 84 | Information required by the model at training stage 85 | 86 | Returns 87 | ------- 88 | requirements : dict 89 | gt_depth : 
bool 90 | Whether ground truth depth is required by the model at training time 91 | gt_pose : bool 92 | Whether ground truth pose is required by the model at training time 93 | """ 94 | return self._train_requirements 95 | 96 | def add_depth_net(self, depth_net): 97 | """Add a depth network to the model""" 98 | self.depth_net = depth_net 99 | 100 | def add_pose_net(self, pose_net): 101 | """Add a pose network to the model""" 102 | self.pose_net = pose_net 103 | 104 | def compute_inv_depths(self, image, ref_imgs, intrinsics): 105 | """Computes inverse depth maps from single images""" 106 | # Randomly flip and estimate inverse depth maps 107 | flip_lr = random.random() < self.flip_lr_prob if self.training else False 108 | if flip_lr: 109 | intrinsics = flip_lr_intr(intrinsics, width=image.shape[3]) 110 | inv_depths_with_poses = flip_mf_model(self.depth_net, image, ref_imgs, intrinsics, flip_lr) 111 | inv_depths, poses = inv_depths_with_poses 112 | inv_depths = make_list(inv_depths) 113 | if flip_lr: 114 | inv_depths = [flip_lr_img(inv_d) for inv_d in inv_depths] 115 | # If upsampling depth maps 116 | if self.upsample_depth_maps: 117 | inv_depths = interpolate_scales( 118 | inv_depths, mode='nearest', align_corners=None) 119 | # Return inverse depth maps 120 | return inv_depths, poses 121 | 122 | def compute_poses(self, image, contexts, intrinsics, depth): 123 | """Compute poses from image and a sequence of context images""" 124 | pose_vec = self.pose_net(image, contexts, intrinsics, depth) 125 | if pose_vec is None: 126 | return None 127 | if pose_vec.shape[2] == 6: 128 | return [Pose.from_vec(pose_vec[:, i], self.rotation_mode) 129 | for i in range(pose_vec.shape[1])] 130 | else: 131 | return [Pose(pose_vec[:, i]) for i in range(pose_vec.shape[1])] 132 | 133 | def forward(self, batch, return_logs=False): 134 | """ 135 | Processes a batch. 
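The ``poses`` entry of the returned dictionary holds one element per context view; when the pose head predicts per-iteration estimates with shape ``(B, n_views, n_iters, 6)``, each element is itself a list of ``Pose`` objects, one per refinement iteration.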
136 | 137 | Parameters 138 | ---------- 139 | batch : dict 140 | Input batch 141 | return_logs : bool 142 | True if logs are stored 143 | 144 | Returns 145 | ------- 146 | output : dict 147 | Dictionary containing predicted inverse depth maps and poses 148 | """ 149 | # Generate inverse depth predictions 150 | inv_depths, pose_vec = self.compute_inv_depths(batch['rgb'], batch['rgb_context'], batch["intrinsics"]) 151 | # # Generate pose predictions if available 152 | # pose = None 153 | # if 'rgb_context' in batch and self.pose_net is not None: 154 | # pose = self.compute_poses(batch['rgb'], 155 | # batch['rgb_context'], batch["intrinsics"], inv2depth(inv_depths[0])) 156 | # Return output dictionary 157 | if pose_vec.shape[2] == 6: 158 | poses = [Pose.from_vec(pose_vec[:, i], self.rotation_mode) 159 | for i in range(pose_vec.shape[1])] 160 | elif (pose_vec.shape[2]) == 4 and (pose_vec.shape[3] == 4): 161 | poses = [Pose(pose_vec[:, i]) for i in range(pose_vec.shape[1])] 162 | else: 163 | #pose_vec shape: (b, n_view, n_iter, 6) 164 | poses = [] 165 | for i in range(pose_vec.shape[1]): 166 | poses_view = [] 167 | for j in range(pose_vec.shape[2]): 168 | poses_view.append(Pose.from_vec(pose_vec[:, i, j], self.rotation_mode)) 169 | poses.append(poses_view) #([pose_view1, pose_view2, ....]) each view has n_iter pose 170 | 171 | # print(poses[0][-1].shape, len(poses), len(poses[0]), len(inv_depths), inv_depths[0].shape) 172 | # print(poses[0][-1].mat[0], inv2depth(inv_depths)[-1][0, 0, 12, 40]) 173 | # print("gt", batch["pose_context"][0][0]) 174 | return { 175 | 'inv_depths': inv_depths, 176 | 'poses': poses, 177 | } -------------------------------------------------------------------------------- /dro_sfm/models/SupModelMF.py: -------------------------------------------------------------------------------- 1 | 2 | from dro_sfm.utils.depth import depth2inv 3 | from dro_sfm.losses.supervised_loss import SupervisedDepthPoseLoss 4 | from dro_sfm.models.SfmModelMF import SfmModelMF 5 | from dro_sfm.models.model_utils import merge_outputs 6 | 7 | 8 | class SupModelMF(SfmModelMF): 9 | """ 10 | Model that inherits a depth and pose network from SfmModel and 11 | includes the photometric loss for self-supervised training. 12 | 13 | Parameters 14 | ---------- 15 | kwargs : dict 16 | Extra parameters 17 | """ 18 | def __init__(self, **kwargs): 19 | # Initializes SfmModel 20 | super().__init__(**kwargs) 21 | # Initializes the photometric loss 22 | 23 | self._network_requirements = { 24 | 'depth_net': True, # Depth network required 25 | 'pose_net': False, # Pose network required 26 | 'percep_net': False, # Pose network required 27 | } 28 | 29 | self._train_requirements = { 30 | 'gt_depth': True, # No ground-truth depth required 31 | 'gt_pose': True, # No ground-truth pose required 32 | } 33 | 34 | # self._photometric_loss = MultiViewPhotometricLoss(**kwargs) 35 | self._loss = SupervisedDepthPoseLoss(**kwargs) 36 | 37 | @property 38 | def logs(self): 39 | """Return logs.""" 40 | return { 41 | **super().logs, 42 | **self._photometric_loss.logs 43 | } 44 | 45 | def supervised_loss(self, image, ref_images, inv_depths, gt_depth, gt_poses, poses, 46 | intrinsics, return_logs=False, progress=0.0): 47 | """ 48 | Calculates the self-supervised photometric loss. 
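(The name notwithstanding, this method computes the supervised loss: ``SupervisedDepthPoseLoss`` is driven by the ground-truth depth ``gt_depth`` and poses ``gt_poses`` listed below.)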
49 | 50 | Parameters 51 | ---------- 52 | image : torch.Tensor [B,3,H,W] 53 | Original image 54 | ref_images : list of torch.Tensor [B,3,H,W] 55 | Reference images from context 56 | inv_depths : torch.Tensor [B,1,H,W] 57 | Predicted inverse depth maps from the original image 58 | poses : list of Pose 59 | List containing predicted poses between original and context images 60 | intrinsics : torch.Tensor [B,3,3] 61 | Camera intrinsics 62 | return_logs : bool 63 | True if logs are stored 64 | progress : 65 | Training progress percentage 66 | 67 | Returns 68 | ------- 69 | output : dict 70 | Dictionary containing a "loss" scalar a "metrics" dictionary 71 | """ 72 | return self._loss( 73 | image, ref_images, inv_depths, depth2inv(gt_depth), gt_poses, intrinsics, intrinsics, poses, 74 | return_logs=return_logs, progress=progress) 75 | 76 | def forward(self, batch, return_logs=False, progress=0.0): 77 | """ 78 | Processes a batch. 79 | 80 | Parameters 81 | ---------- 82 | batch : dict 83 | Input batch 84 | return_logs : bool 85 | True if logs are stored 86 | progress : 87 | Training progress percentage 88 | 89 | Returns 90 | ------- 91 | output : dict 92 | Dictionary containing a "loss" scalar and different metrics and predictions 93 | for logging and downstream usage. 94 | """ 95 | # Calculate predicted depth and pose output 96 | output = super().forward(batch, return_logs=return_logs) 97 | if not self.training: 98 | # If not training, no need for self-supervised loss 99 | return output 100 | else: 101 | if output["poses"] is None: 102 | return None 103 | # Otherwise, calculate self-supervised loss 104 | self_sup_output = self.supervised_loss( 105 | batch['rgb_original'], batch['rgb_context_original'], 106 | output['inv_depths'], batch['depth'], batch['pose_context'], output['poses'], batch['intrinsics'], 107 | return_logs=return_logs, progress=progress) 108 | # Return loss and metrics 109 | return { 110 | 'loss': self_sup_output['loss'], 111 | **merge_outputs(output, self_sup_output), 112 | } -------------------------------------------------------------------------------- /dro_sfm/models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Structure-from-Motion (SfM) Models and wrappers 3 | =============================================== 4 | 5 | - SfmModel is a torch.nn.Module wrapping both a Depth and a Pose network to enable training in a Structure-from-Motion setup (i.e. 
from videos) 6 | - SelfSupModel is an SfmModel specialized for self-supervised learning (using videos only) 7 | - SemiSupModel is an SfmModel specialized for semi-supervised learning (using videos and depth supervision) 8 | - ModelWrapper is a torch.nn.Module that wraps an SfmModel to enable easy training and eval with a trainer 9 | - ModelCheckpoint enables saving/restoring state of torch.nn.Module objects 10 | 11 | """ 12 | -------------------------------------------------------------------------------- /dro_sfm/models/model_checkpoint.py: -------------------------------------------------------------------------------- 1 | 2 | # Adapted from Pytorch-Lightning 3 | # https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/callbacks/model_checkpoint.py 4 | 5 | import os, re 6 | import numpy as np 7 | import torch 8 | from dro_sfm.utils.logging import pcolor 9 | 10 | 11 | def sync_s3_data(local, model): 12 | """Sync saved models with the s3 bucket""" 13 | remote = os.path.join(model.config.checkpoint.s3_path, model.config.name) 14 | command = 'aws s3 sync {} {} --acl bucket-owner-full-control --quiet --delete'.format(local, remote) 15 | os.system(command) 16 | 17 | 18 | def save_code(filepath): 19 | """Save code in the models folder""" 20 | os.system('tar cfz {}/code.tar.gz *'.format(filepath)) 21 | 22 | 23 | class ModelCheckpoint: 24 | def __init__(self, filepath=None, monitor='val_loss', 25 | save_top_k=1, mode='auto', period=1, 26 | s3_path='', s3_frequency=5): 27 | super().__init__() 28 | # If save_top_k is zero, save all models 29 | if save_top_k == 0: 30 | save_top_k = 1e6 31 | # Create checkpoint folder 32 | self.dirpath, self.filename = os.path.split(filepath) 33 | print(self.dirpath, self.filename, filepath) 34 | os.makedirs(self.dirpath, exist_ok=True) 35 | # Store arguments 36 | self.monitor = monitor 37 | self.save_top_k = save_top_k 38 | self.period = period 39 | self.epoch_last_check = None 40 | self.best_k_models = {} 41 | self.kth_best_model = '' 42 | self.best = 0 43 | # Monitoring modes 44 | torch_inf = torch.tensor(np.Inf) 45 | mode_dict = { 46 | 'min': (torch_inf, 'min'), 47 | 'max': (-torch_inf, 'max'), 48 | 'auto': (-torch_inf, 'max') if \ 49 | 'acc' in self.monitor or \ 50 | 'a1' in self.monitor or \ 51 | self.monitor.startswith('fmeasure') 52 | else (torch_inf, 'min'), 53 | } 54 | self.kth_value, self.mode = mode_dict[mode] 55 | 56 | self.s3_path = s3_path 57 | self.s3_frequency = s3_frequency 58 | self.s3_enabled = s3_path is not '' and s3_frequency > 0 59 | self.save_code = True 60 | 61 | @staticmethod 62 | def _del_model(filepath): 63 | if os.path.isfile(filepath): 64 | os.remove(filepath) 65 | 66 | def _save_model(self, filepath, model): 67 | # Create folder, save model and sync to s3 68 | os.makedirs(os.path.dirname(filepath), exist_ok=True) 69 | torch.save({ 70 | 'config': model.config, 71 | 'epoch': model.current_epoch, 72 | 'state_dict': model.state_dict(), 73 | 'optimizer': model.optimizer.state_dict(), 74 | 'scheduler': model.scheduler.state_dict(), 75 | }, filepath) 76 | self._sync_s3(filepath, model) 77 | 78 | def _sync_s3(self, filepath, model): 79 | # If it's not time to sync, do nothing 80 | if self.s3_enabled and (model.current_epoch + 1) % self.s3_frequency == 0: 81 | filepath = os.path.dirname(filepath) 82 | # Print message and links 83 | print(pcolor('###### Syncing: {} -> {}'.format(filepath, 84 | model.config.checkpoint.s3_path), 'red', attrs=['bold'])) 85 | print(pcolor('###### URL: {}'.format( 86 | 
model.config.checkpoint.s3_url), 'red', attrs=['bold'])) 87 | # If it's time to save code 88 | if self.save_code: 89 | self.save_code = False 90 | save_code(filepath) 91 | # Sync model to s3 92 | sync_s3_data(filepath, model) 93 | 94 | def check_monitor_top_k(self, current): 95 | # If we don't have enough models 96 | if len(self.best_k_models) < self.save_top_k: 97 | return True 98 | # Convert to torch if necessary 99 | if not isinstance(current, torch.Tensor): 100 | current = torch.tensor(current) 101 | # Get monitoring operation 102 | monitor_op = { 103 | "min": torch.lt, 104 | "max": torch.gt, 105 | }[self.mode] 106 | # Compare and return 107 | return monitor_op(current, self.best_k_models[self.kth_best_model]) 108 | 109 | def format_checkpoint_name(self, epoch, metrics): 110 | metrics['epoch'] = epoch 111 | filename = self.filename 112 | for tmp in re.findall(r'(\{.*?)[:\}]', self.filename): 113 | name = tmp[1:] 114 | filename = filename.replace(tmp, name + '={' + name) 115 | if name not in metrics: 116 | metrics[name] = 0 117 | filename = filename.format(**metrics) 118 | return os.path.join(self.dirpath, '{}.ckpt'.format(filename)) 119 | 120 | def check_and_save(self, model, metrics): 121 | # Check saving interval 122 | epoch = model.current_epoch 123 | if self.epoch_last_check is not None and \ 124 | (epoch - self.epoch_last_check) < self.period: 125 | return 126 | self.epoch_last_check = epoch 127 | # Prepare filepath 128 | filepath = self.format_checkpoint_name(epoch, metrics) 129 | while os.path.isfile(filepath): 130 | filepath = self.format_checkpoint_name(epoch, metrics) 131 | # Check if saving or not 132 | if self.save_top_k != -1: 133 | current = metrics.get(self.monitor) 134 | assert current, 'Checkpoint metric is not available' 135 | if self.check_monitor_top_k(current): 136 | self._do_check_save(filepath, model, current) 137 | else: 138 | self._save_model(filepath, model) 139 | 140 | def _do_check_save(self, filepath, model, current): 141 | # List of models to delete 142 | del_list = [] 143 | if len(self.best_k_models) == self.save_top_k and self.save_top_k > 0: 144 | delpath = self.kth_best_model 145 | self.best_k_models.pop(self.kth_best_model) 146 | del_list.append(delpath) 147 | # Monitor current models 148 | self.best_k_models[filepath] = current 149 | if len(self.best_k_models) == self.save_top_k: 150 | # Monitor dict has reached k elements 151 | _op = max if self.mode == 'min' else min 152 | self.kth_best_model = _op(self.best_k_models, 153 | key=self.best_k_models.get) 154 | self.kth_value = self.best_k_models[self.kth_best_model] 155 | # Determine best model 156 | _op = min if self.mode == 'min' else max 157 | self.best = _op(self.best_k_models.values()) 158 | # Delete old models 159 | for cur_path in del_list: 160 | if cur_path != filepath: 161 | self._del_model(cur_path) 162 | # Save model 163 | self._save_model(filepath, model) 164 | -------------------------------------------------------------------------------- /dro_sfm/models/model_utils.py: -------------------------------------------------------------------------------- 1 | 2 | from dro_sfm.utils.types import is_tensor, is_list, is_numpy 3 | 4 | def merge_outputs(*outputs): 5 | """ 6 | Merges model outputs for logging 7 | 8 | Parameters 9 | ---------- 10 | outputs : tuple of dict 11 | Outputs to be merged 12 | 13 | Returns 14 | ------- 15 | output : dict 16 | Dictionary with a "metrics" key containing a dictionary with various metrics and 17 | all other keys that are not "loss" (it is handled differently). 
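Example (illustrative only; the dictionaries below are made up)::

    >>> a = {'loss': 1.0, 'metrics': {'photo_loss': 0.2}, 'inv_depths': 'd'}
    >>> b = {'loss': 2.0, 'metrics': {'supervised_loss': 0.3}}
    >>> merged = merge_outputs(a, b)
    >>> sorted(merged['metrics'])
    ['photo_loss', 'supervised_loss']
    >>> 'loss' in merged
    False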
18 | """ 19 | ignore = ['loss'] # Keys to ignore 20 | combine = ['metrics'] # Keys to combine 21 | merge = {key: {} for key in combine} 22 | for output in outputs: 23 | # Iterate over all keys 24 | for key, val in output.items(): 25 | # Combine these keys 26 | if key in combine: 27 | for sub_key, sub_val in output[key].items(): 28 | assert sub_key not in merge[key].keys(), \ 29 | 'Combining duplicated key {} to {}'.format(sub_key, key) 30 | merge[key][sub_key] = sub_val 31 | # Ignore these keys 32 | elif key not in ignore: 33 | assert key not in merge.keys(), \ 34 | 'Adding duplicated key {}'.format(key) 35 | merge[key] = val 36 | return merge 37 | 38 | 39 | def stack_batch(batch): 40 | """ 41 | Stack multi-camera batches (B,N,C,H,W becomes BN,C,H,W) 42 | 43 | Parameters 44 | ---------- 45 | batch : dict 46 | Batch 47 | 48 | Returns 49 | ------- 50 | batch : dict 51 | Stacked batch 52 | """ 53 | # If there is multi-camera information 54 | if len(batch['rgb'].shape) == 5: 55 | assert batch['rgb'].shape[0] == 1, 'Only batch size 1 is supported for multi-cameras' 56 | # Loop over all keys 57 | for key in batch.keys(): 58 | # If list, stack every item 59 | if is_list(batch[key]): 60 | if is_tensor(batch[key][0]) or is_numpy(batch[key][0]): 61 | batch[key] = [sample[0] for sample in batch[key]] 62 | # Else, stack single item 63 | else: 64 | batch[key] = batch[key][0] 65 | return batch 66 | -------------------------------------------------------------------------------- /dro_sfm/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/dro-sfm/8707e2e0ef799d7d47418a018060f503ef449fe3/dro_sfm/networks/__init__.py -------------------------------------------------------------------------------- /dro_sfm/networks/layers/PercepNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | 5 | 6 | class PercepNet(torch.nn.Module): 7 | def __init__(self, requires_grad=False, resize=True): 8 | super(PercepNet, self).__init__() 9 | mean_rgb = torch.FloatTensor([0.485, 0.456, 0.406]) 10 | std_rgb = torch.FloatTensor([0.229, 0.224, 0.225]) 11 | self.register_buffer('mean_rgb', mean_rgb) 12 | self.register_buffer('std_rgb', std_rgb) 13 | self.resize = resize 14 | 15 | vgg_pretrained_features = torchvision.models.vgg16(pretrained=True).features 16 | self.slice1 = nn.Sequential() 17 | self.slice2 = nn.Sequential() 18 | self.slice3 = nn.Sequential() 19 | self.slice4 = nn.Sequential() 20 | for x in range(4): 21 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 22 | for x in range(4, 9): 23 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 24 | for x in range(9, 16): 25 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 26 | for x in range(16, 23): 27 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 28 | if not requires_grad: 29 | for param in self.parameters(): 30 | param.requires_grad = False 31 | 32 | def normalize(self, x): 33 | out = (x - self.mean_rgb.view(1,3,1,1)) / self.std_rgb.view(1,3,1,1) 34 | if self.resize: 35 | out = nn.functional.interpolate(out, mode='bilinear',size=(224, 224), align_corners=False) 36 | return out 37 | 38 | def forward(self, im1, im2): 39 | im = torch.cat([im1,im2], 0) 40 | im = self.normalize(im) # normalize input 41 | 42 | ## compute features 43 | feats = [] 44 | f = self.slice1(im) 45 | h, w = f.shape[-2:] 46 | feats += [torch.chunk(f, 2, dim=0)] 47 | f 
= self.slice2(f) 48 | feats += [torch.chunk(f, 2, dim=0)] 49 | f = self.slice3(f) 50 | feats += [torch.chunk(f, 2, dim=0)] 51 | f = self.slice4(f) 52 | feats += [torch.chunk(f, 2, dim=0)] 53 | 54 | losses = [] 55 | #weights = [0.3, 0.3, 0.4] 56 | #weights = [0.15, 0.25, 0.25, 0.25] 57 | weights = [0.15, 0.25, 0.6] 58 | #weights = [1.0, 1.0, 1.0] 59 | for i, (f1, f2) in enumerate(feats[0:3]): 60 | loss = weights[i] * torch.abs(f1-f2).mean(1, True) #(B, 1, H, W) 61 | loss = nn.functional.interpolate(loss, mode='bilinear',size=(h, w), align_corners=False) 62 | losses += [loss] 63 | 64 | return sum(losses) 65 | 66 | -------------------------------------------------------------------------------- /dro_sfm/networks/layers/resnet/layers.py: -------------------------------------------------------------------------------- 1 | 2 | # Adapted from monodepth2 3 | # https://github.com/nianticlabs/monodepth2/blob/master/layers.py 4 | 5 | from __future__ import absolute_import, division, print_function 6 | 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | def disp_to_depth(disp, min_depth, max_depth): 12 | """Convert network's sigmoid output into depth prediction 13 | The formula for this conversion is given in the 'additional considerations' 14 | section of the paper. 15 | """ 16 | min_disp = 1 / max_depth 17 | max_disp = 1 / min_depth 18 | scaled_disp = min_disp + (max_disp - min_disp) * disp 19 | depth = 1 / scaled_disp 20 | return scaled_disp, depth 21 | 22 | 23 | class ConvBlock(nn.Module): 24 | """Layer to perform a convolution followed by ELU 25 | """ 26 | def __init__(self, in_channels, out_channels): 27 | super(ConvBlock, self).__init__() 28 | 29 | self.conv = Conv3x3(in_channels, out_channels) 30 | self.nonlin = nn.ELU(inplace=True) 31 | 32 | def forward(self, x): 33 | out = self.conv(x) 34 | out = self.nonlin(out) 35 | return out 36 | 37 | 38 | class Conv3x3(nn.Module): 39 | """Layer to pad and convolve input 40 | """ 41 | def __init__(self, in_channels, out_channels, use_refl=True): 42 | super(Conv3x3, self).__init__() 43 | 44 | if use_refl: 45 | self.pad = nn.ReflectionPad2d(1) 46 | else: 47 | self.pad = nn.ZeroPad2d(1) 48 | self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3) 49 | 50 | def forward(self, x): 51 | out = self.pad(x) 52 | out = self.conv(out) 53 | return out 54 | 55 | 56 | def upsample(x): 57 | """Upsample input tensor by a factor of 2 58 | """ 59 | return F.interpolate(x, scale_factor=2, mode="nearest") 60 | 61 | 62 | -------------------------------------------------------------------------------- /dro_sfm/networks/layers/resnet/pose_decoder.py: -------------------------------------------------------------------------------- 1 | 2 | # Adapted from monodepth2 3 | # https://github.com/nianticlabs/monodepth2/blob/master/networks/pose_decoder.py 4 | 5 | from __future__ import absolute_import, division, print_function 6 | 7 | import torch 8 | import torch.nn as nn 9 | from collections import OrderedDict 10 | 11 | 12 | class PoseDecoder(nn.Module): 13 | def __init__(self, num_ch_enc, num_input_features, num_frames_to_predict_for=None, stride=1): 14 | super(PoseDecoder, self).__init__() 15 | 16 | self.num_ch_enc = num_ch_enc 17 | self.num_input_features = num_input_features 18 | 19 | if num_frames_to_predict_for is None: 20 | num_frames_to_predict_for = num_input_features - 1 21 | self.num_frames_to_predict_for = num_frames_to_predict_for 22 | 23 | self.convs = OrderedDict() 24 | self.convs[("squeeze")] = nn.Conv2d(self.num_ch_enc[-1], 256, 1) 25 | 
self.convs[("pose", 0)] = nn.Conv2d(num_input_features * 256, 256, 3, stride, 1) 26 | self.convs[("pose", 1)] = nn.Conv2d(256, 256, 3, stride, 1) 27 | self.convs[("pose", 2)] = nn.Conv2d(256, 6 * num_frames_to_predict_for, 1) 28 | 29 | self.relu = nn.ReLU() 30 | 31 | self.net = nn.ModuleList(list(self.convs.values())) 32 | 33 | def forward(self, input_features): 34 | last_features = [f[-1] for f in input_features] 35 | 36 | cat_features = [self.relu(self.convs["squeeze"](f)) for f in last_features] 37 | cat_features = torch.cat(cat_features, 1) 38 | 39 | out = cat_features 40 | for i in range(3): 41 | out = self.convs[("pose", i)](out) 42 | if i != 2: 43 | out = self.relu(out) 44 | out = out.mean(3).mean(2) 45 | 46 | out = 0.01 * out.view(-1, self.num_frames_to_predict_for, 1, 6) 47 | 48 | axisangle = out[..., :3] 49 | translation = out[..., 3:] 50 | 51 | return axisangle, translation 52 | -------------------------------------------------------------------------------- /dro_sfm/networks/layers/resnet/pose_res_decoder.py: -------------------------------------------------------------------------------- 1 | 2 | # Adapted from monodepth2 3 | # https://github.com/nianticlabs/monodepth2/blob/master/networks/pose_decoder.py 4 | 5 | from __future__ import absolute_import, division, print_function 6 | 7 | import torch 8 | import torch.nn as nn 9 | from collections import OrderedDict 10 | 11 | 12 | class PoseResDecoder(nn.Module): 13 | def __init__(self, num_ch_enc, num_input_features, num_frames_to_predict_for=None, stride=1): 14 | super(PoseResDecoder, self).__init__() 15 | 16 | self.num_ch_enc = num_ch_enc 17 | self.num_input_features = num_input_features 18 | 19 | if num_frames_to_predict_for is None: 20 | num_frames_to_predict_for = num_input_features - 1 21 | self.num_frames_to_predict_for = num_frames_to_predict_for 22 | 23 | self.convs = OrderedDict() 24 | self.convs[("squeeze")] = nn.Conv2d(self.num_ch_enc[-1], 256, 1) 25 | self.convs[("pose", 0)] = nn.Conv2d(num_input_features * 256, 256, 3, stride, 1) 26 | self.convs[("pose", 1)] = nn.Conv2d(256, 256, 3, stride, 1) 27 | self.convs[("pose", 2)] = nn.Conv2d(256, 6 * num_frames_to_predict_for, 1) 28 | 29 | self.relu = nn.ReLU() 30 | 31 | self.net = nn.ModuleList(list(self.convs.values())) 32 | 33 | self.found_mat_net = nn.Sequential(nn.Linear(9, 18), nn.ReLU(inplace=True), 34 | nn.Linear(18, 18), nn.ReLU(inplace=True), 35 | nn.Linear(18, 6)) 36 | 37 | self.fusion_net = nn.Sequential(nn.Linear(12, 24), nn.ReLU(inplace=True), 38 | nn.Linear(24, 24), nn.ReLU(inplace=True), 39 | nn.Linear(24, 6)) 40 | 41 | def forward(self, input_features, foud_mat): 42 | last_features = [f[-1] for f in input_features] 43 | 44 | cat_features = [self.relu(self.convs["squeeze"](f)) for f in last_features] 45 | cat_features = torch.cat(cat_features, 1) 46 | 47 | out = cat_features 48 | for i in range(3): 49 | out = self.convs[("pose", i)](out) 50 | if i != 2: 51 | out = self.relu(out) 52 | out = out.mean(3).mean(2) 53 | 54 | # out = 0.01 * (out.view(-1, self.num_frames_to_predict_for, 1, 6) + self.found_mat_net(foud_mat.view(-1, 9)).view(-1, 1, 1, 6)) 55 | 56 | fund_mat_proj = self.found_mat_net(foud_mat.view(-1, 9)) 57 | out = 0.01 * (self.fusion_net(torch.cat([out.view(-1, 6), fund_mat_proj], dim=1)) + fund_mat_proj).view(-1, 1, 1, 6) 58 | 59 | print("out", out.view(-1, 6)) 60 | print("fund", fund_mat_proj.view(-1, 6)) 61 | 62 | axisangle = out[..., :3] 63 | translation = out[..., 3:] 64 | return axisangle, translation 65 | 66 | 67 | 68 | class 
PoseResAngleDecoder(nn.Module): 69 | def __init__(self, num_ch_enc, num_input_features, num_frames_to_predict_for=None, stride=1): 70 | super(PoseResAngleDecoder, self).__init__() 71 | 72 | self.num_ch_enc = num_ch_enc 73 | self.num_input_features = num_input_features 74 | 75 | if num_frames_to_predict_for is None: 76 | num_frames_to_predict_for = num_input_features - 1 77 | self.num_frames_to_predict_for = num_frames_to_predict_for 78 | 79 | self.convs = OrderedDict() 80 | self.convs[("squeeze")] = nn.Conv2d(self.num_ch_enc[-1], 256, 1) 81 | self.convs[("pose", 0)] = nn.Conv2d(num_input_features * 256, 256, 3, stride, 1) 82 | self.convs[("pose", 1)] = nn.Conv2d(256, 256, 3, stride, 1) 83 | self.convs[("pose", 2)] = nn.Conv2d(256, 7 * num_frames_to_predict_for, 1) 84 | 85 | self.relu = nn.ReLU() 86 | 87 | self.net = nn.ModuleList(list(self.convs.values())) 88 | 89 | self.found_mat_net = nn.Sequential(nn.Linear(6, 18), nn.ReLU(inplace=True), 90 | nn.Linear(18, 18), nn.ReLU(inplace=True), 91 | nn.Linear(18, 6)) 92 | 93 | self.fusion_net = nn.Sequential(nn.Linear(12, 128), nn.ReLU(inplace=True), 94 | nn.Linear(128, 128), nn.ReLU(inplace=True), 95 | nn.Linear(128, 6)) 96 | 97 | def forward(self, input_features, pose_geo): 98 | last_features = [f[-1] for f in input_features] 99 | 100 | cat_features = [self.relu(self.convs["squeeze"](f)) for f in last_features] 101 | cat_features = torch.cat(cat_features, 1) 102 | 103 | out = cat_features 104 | for i in range(3): 105 | out = self.convs[("pose", i)](out) 106 | if i != 2: 107 | out = self.relu(out) 108 | out = out.mean(3).mean(2).view(-1, 7) 109 | 110 | # out = 0.01 * (out.view(-1, self.num_frames_to_predict_for, 1, 6) + self.found_mat_net(foud_mat.view(-1, 9)).view(-1, 1, 1, 6)) 111 | 112 | #trans_scale = 0.01 * pose_geo[:, 3:] * out[:, -1].unsqueeze(1) 113 | trans_scale = pose_geo[:, 3:] * out[:, -1].unsqueeze(1) 114 | 115 | pose_geo_new = torch.cat([pose_geo[:, :3], trans_scale], dim=1) 116 | 117 | #out = 0.01 * (self.fusion_net(torch.cat([out[:, :6], self.found_mat_net(pose_geo_new)], dim=1))) 118 | # out = 0.01 * (self.fusion_net(torch.cat([0.01 * out[:, :6], pose_geo_new], dim=1))) 119 | out = 0.01 * (self.fusion_net(torch.cat([out[:, :6], pose_geo_new], dim=1))) 120 | 121 | # out = 0.01 * out[:, :6] + pose_geo_new 122 | 123 | out = out.view(-1, 1, 1, 6) 124 | 125 | axisangle = out[..., :3] 126 | translation = out[..., 3:] 127 | return axisangle, translation 128 | 129 | -------------------------------------------------------------------------------- /dro_sfm/networks/layers/resnet/resnet_encoder.py: -------------------------------------------------------------------------------- 1 | 2 | # Adapted from monodepth2 3 | # https://github.com/nianticlabs/monodepth2/blob/master/networks/resnet_encoder.py 4 | 5 | from __future__ import absolute_import, division, print_function 6 | 7 | import numpy as np 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torchvision.models as models 12 | import torch.utils.model_zoo as model_zoo 13 | 14 | 15 | class ResNetMultiImageInput(models.ResNet): 16 | """Constructs a resnet model with varying number of input images. 
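The first convolution expects ``num_input_images * 3`` channels; when ImageNet weights are requested, ``resnet_multiimage_input`` below tiles the pretrained ``conv1.weight`` across the stacked frames and divides by their count so the initial activations keep roughly the original scale.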
17 | Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 18 | """ 19 | def __init__(self, block, layers, num_classes=1000, num_input_images=1): 20 | super(ResNetMultiImageInput, self).__init__(block, layers) 21 | self.inplanes = 64 22 | self.conv1 = nn.Conv2d( 23 | num_input_images * 3, 64, kernel_size=7, stride=2, padding=3, bias=False) 24 | self.bn1 = nn.BatchNorm2d(64) 25 | self.relu = nn.ReLU(inplace=True) 26 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 27 | self.layer1 = self._make_layer(block, 64, layers[0]) 28 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 29 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 30 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 31 | 32 | for m in self.modules(): 33 | if isinstance(m, nn.Conv2d): 34 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 35 | elif isinstance(m, nn.BatchNorm2d): 36 | nn.init.constant_(m.weight, 1) 37 | nn.init.constant_(m.bias, 0) 38 | 39 | 40 | def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1): 41 | """Constructs a ResNet model. 42 | Args: 43 | num_layers (int): Number of resnet layers. Must be 18 or 50 44 | pretrained (bool): If True, returns a model pre-trained on ImageNet 45 | num_input_images (int): Number of frames stacked as input 46 | """ 47 | assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet" 48 | blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers] 49 | block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers] 50 | model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images) 51 | 52 | if pretrained: 53 | loaded = model_zoo.load_url(models.resnet.model_urls['resnet{}'.format(num_layers)]) 54 | loaded['conv1.weight'] = torch.cat( 55 | [loaded['conv1.weight']] * num_input_images, 1) / num_input_images 56 | model.load_state_dict(loaded) 57 | return model 58 | 59 | 60 | class ResnetEncoder(nn.Module): 61 | """Pytorch module for a resnet encoder 62 | """ 63 | def __init__(self, num_layers, pretrained, num_input_images=1): 64 | super(ResnetEncoder, self).__init__() 65 | self.num_ch_enc = np.array([64, 64, 128, 256, 512]) 66 | 67 | resnets = {18: models.resnet18, 68 | 34: models.resnet34, 69 | 50: models.resnet50, 70 | 101: models.resnet101, 71 | 152: models.resnet152} 72 | 73 | if num_layers not in resnets: 74 | raise ValueError("{} is not a valid number of resnet layers".format(num_layers)) 75 | 76 | if num_input_images > 1: 77 | self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images) 78 | else: 79 | self.encoder = resnets[num_layers](pretrained) 80 | 81 | if num_layers > 34: 82 | self.num_ch_enc[1:] *= 4 83 | 84 | def forward(self, input_image): 85 | self.features = [] 86 | x = (input_image - 0.45) / 0.225 87 | x = self.encoder.conv1(x) 88 | x = self.encoder.bn1(x) 89 | self.features.append(self.encoder.relu(x)) 90 | self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1]))) 91 | self.features.append(self.encoder.layer2(self.features[-1])) 92 | self.features.append(self.encoder.layer3(self.features[-1])) 93 | self.features.append(self.encoder.layer4(self.features[-1])) 94 | 95 | return self.features 96 | -------------------------------------------------------------------------------- /dro_sfm/networks/optim/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 
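A minimal sketch of how the ``ResnetEncoder`` defined above can be exercised (not part of the repository; the batch size and the 192x640 input resolution are arbitrary assumptions):

```python
import torch

from dro_sfm.networks.layers.resnet.resnet_encoder import ResnetEncoder

# Single-image encoder: expects a (B, 3, H, W) tensor roughly in [0, 1]
encoder = ResnetEncoder(num_layers=18, pretrained=False)
features = encoder(torch.rand(2, 3, 192, 640))
print([f.shape[1] for f in features])   # [64, 64, 128, 256, 512] == num_ch_enc
print(features[-1].shape)               # coarsest level, 1/32 of the input size

# Two stacked frames: conv1 is rebuilt for num_input_images * 3 = 6 channels
pair_encoder = ResnetEncoder(num_layers=18, pretrained=False, num_input_images=2)
pair_features = pair_encoder(torch.rand(2, 6, 192, 640))
```

With ``num_layers`` greater than 34 the bottleneck blocks multiply the channel counts of the four deeper levels by four, which is what the ``num_ch_enc[1:] *= 4`` line accounts for.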
-------------------------------------------------------------------------------- /dro_sfm/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Trainers 3 | ======== 4 | 5 | Trainer classes providing an easy way to train and evaluate SfM models 6 | when wrapped in a ModelWrapper. 7 | 8 | Inspired by pytorch-lightning. 9 | 10 | """ 11 | 12 | from dro_sfm.trainers.horovod_trainer import HorovodTrainer 13 | 14 | __all__ = ["HorovodTrainer"] -------------------------------------------------------------------------------- /dro_sfm/trainers/base_trainer.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from tqdm import tqdm 4 | from dro_sfm.utils.logging import prepare_dataset_prefix 5 | 6 | 7 | def sample_to_cuda(data, dtype=None): 8 | if isinstance(data, str): 9 | return data 10 | elif isinstance(data, dict): 11 | return {key: sample_to_cuda(data[key], dtype) for key in data.keys()} 12 | elif isinstance(data, list): 13 | return [sample_to_cuda(val, dtype) for val in data] 14 | else: 15 | # only convert floats (e.g., to half), otherwise preserve (e.g, ints) 16 | dtype = dtype if torch.is_floating_point(data) else None 17 | return data.to('cuda', dtype=dtype) 18 | 19 | 20 | class BaseTrainer: 21 | def __init__(self, min_epochs=0, max_epochs=50, 22 | checkpoint=None, **kwargs): 23 | 24 | self.min_epochs = min_epochs 25 | self.max_epochs = max_epochs 26 | 27 | self.checkpoint = checkpoint 28 | self.module = None 29 | 30 | @property 31 | def proc_rank(self): 32 | raise NotImplementedError('Not implemented for BaseTrainer') 33 | 34 | @property 35 | def world_size(self): 36 | raise NotImplementedError('Not implemented for BaseTrainer') 37 | 38 | @property 39 | def is_rank_0(self): 40 | return self.proc_rank == 0 41 | 42 | def check_and_save(self, module, output): 43 | if self.checkpoint: 44 | self.checkpoint.check_and_save(module, output) 45 | 46 | def train_progress_bar(self, dataloader, config, ncols=120): 47 | return tqdm(enumerate(dataloader, 0), 48 | unit=' images', unit_scale=self.world_size * config.batch_size, 49 | total=len(dataloader), smoothing=0, 50 | disable=not self.is_rank_0, ncols=ncols, 51 | ) 52 | 53 | def val_progress_bar(self, dataloader, config, n=0, ncols=120): 54 | return tqdm(enumerate(dataloader, 0), 55 | unit=' images', unit_scale=self.world_size * config.batch_size, 56 | total=len(dataloader), smoothing=0, 57 | disable=not self.is_rank_0, ncols=ncols, 58 | desc=prepare_dataset_prefix(config, n) 59 | ) 60 | 61 | def test_progress_bar(self, dataloader, config, n=0, ncols=120): 62 | return tqdm(enumerate(dataloader, 0), 63 | unit=' images', unit_scale=self.world_size * config.batch_size, 64 | total=len(dataloader), smoothing=0, 65 | disable=not self.is_rank_0, ncols=ncols, 66 | desc=prepare_dataset_prefix(config, n) 67 | ) 68 | -------------------------------------------------------------------------------- /dro_sfm/trainers/horovod_trainer.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import torch 4 | import horovod.torch as hvd 5 | from dro_sfm.trainers.base_trainer import BaseTrainer, sample_to_cuda 6 | from dro_sfm.utils.config import prep_logger_and_checkpoint 7 | from dro_sfm.utils.logging import print_config 8 | from dro_sfm.utils.logging import AvgMeter 9 | 10 | 11 | class HorovodTrainer(BaseTrainer): 12 | def __init__(self, **kwargs): 13 | super().__init__(**kwargs) 14 | 15 | hvd.init() 16 
| torch.set_num_threads(int(os.environ.get("OMP_NUM_THREADS", 1))) 17 | torch.cuda.set_device(hvd.local_rank()) 18 | torch.backends.cudnn.benchmark = True 19 | 20 | self.avg_loss = AvgMeter(50) 21 | self.avg_loss2 = AvgMeter(50) 22 | self.dtype = kwargs.get("dtype", None) # just for test for now 23 | 24 | @property 25 | def proc_rank(self): 26 | return hvd.rank() 27 | 28 | @property 29 | def world_size(self): 30 | return hvd.size() 31 | 32 | def fit(self, module): 33 | 34 | # Prepare module for training 35 | module.trainer = self 36 | # Update and print module configuration 37 | prep_logger_and_checkpoint(module) 38 | print_config(module.config) 39 | 40 | # Send module to GPU 41 | module = module.to('cuda') 42 | # Configure optimizer and scheduler 43 | module.configure_optimizers() 44 | 45 | # Create distributed optimizer 46 | compression = hvd.Compression.none 47 | optimizer = hvd.DistributedOptimizer(module.optimizer, 48 | named_parameters=module.named_parameters(), compression=compression) 49 | scheduler = module.scheduler 50 | 51 | # Get train and val dataloaders 52 | train_dataloader = module.train_dataloader() 53 | val_dataloaders = module.val_dataloader() 54 | 55 | # Epoch loop 56 | for epoch in range(module.current_epoch, self.max_epochs): 57 | # Train 58 | self.train(train_dataloader, module, optimizer) 59 | # Validation 60 | validation_output = self.validate(val_dataloaders, module) 61 | # Check and save model 62 | self.check_and_save(module, validation_output) 63 | # Update current epoch 64 | module.current_epoch += 1 65 | # Take a scheduler step 66 | scheduler.step() 67 | 68 | def train(self, dataloader, module, optimizer): 69 | # Set module to train 70 | module.train() 71 | # Shuffle dataloader sampler 72 | if hasattr(dataloader.sampler, "set_epoch"): 73 | dataloader.sampler.set_epoch(module.current_epoch) 74 | # Prepare progress bar 75 | progress_bar = self.train_progress_bar( 76 | dataloader, module.config.datasets.train) 77 | # Start training loop 78 | outputs = [] 79 | # For all batches 80 | for i, batch in progress_bar: 81 | # Reset optimizer 82 | optimizer.zero_grad() 83 | # Send samples to GPU and take a training step 84 | batch = sample_to_cuda(batch) 85 | output = module.training_step(batch, i) 86 | if output is None: 87 | print("skip this training step.....") 88 | continue 89 | # Backprop through loss and take an optimizer step 90 | output['loss'].backward() 91 | optimizer.step() 92 | # Append output to list of outputs 93 | output['loss'] = output['loss'].detach() 94 | outputs.append(output) 95 | # Update progress bar if in rank 0 96 | if self.is_rank_0: 97 | progress_bar.set_description( 98 | 'Epoch {} | Avg.Loss {:.4f} Loss2 {:.4f}'.format( 99 | module.current_epoch, self.avg_loss(output['loss'].item()), 100 | self.avg_loss2(output['metrics']['pose_loss'].item() if 'pose_loss' in output['metrics'] else 0.0))) 101 | # Return outputs for epoch end 102 | # return module.training_epoch_end(outputs) 103 | 104 | def validate(self, dataloaders, module): 105 | # Set module to eval 106 | module.eval() 107 | # Start validation loop 108 | all_outputs = [] 109 | # For all validation datasets 110 | for n, dataloader in enumerate(dataloaders): 111 | # Prepare progress bar for that dataset 112 | progress_bar = self.val_progress_bar( 113 | dataloader, module.config.datasets.validation, n) 114 | outputs = [] 115 | # For all batches 116 | for i, batch in progress_bar: 117 | # Send batch to GPU and take a validation step 118 | batch = sample_to_cuda(batch) 119 | output = 
module.validation_step(batch, i, n) 120 | # Append output to list of outputs 121 | outputs.append(output) 122 | # Append dataset outputs to list of all outputs 123 | all_outputs.append(outputs) 124 | # Return all outputs for epoch end 125 | return module.validation_epoch_end(all_outputs) 126 | 127 | def test(self, module): 128 | # Send module to GPU 129 | module = module.to('cuda', dtype=self.dtype) 130 | # Get test dataloaders 131 | test_dataloaders = module.test_dataloader() 132 | # Run evaluation 133 | self.evaluate(test_dataloaders, module) 134 | 135 | @torch.no_grad() 136 | def evaluate(self, dataloaders, module): 137 | # Set module to eval 138 | module.eval() 139 | # Start evaluation loop 140 | all_outputs = [] 141 | # For all test datasets 142 | for n, dataloader in enumerate(dataloaders): 143 | # Prepare progress bar for that dataset 144 | progress_bar = self.val_progress_bar( 145 | dataloader, module.config.datasets.test, n) 146 | outputs = [] 147 | # For all batches 148 | for i, batch in progress_bar: 149 | # Send batch to GPU and take a test step 150 | batch = sample_to_cuda(batch, self.dtype) 151 | output = module.test_step(batch, i, n) 152 | # Append output to list of outputs 153 | outputs.append(output) 154 | # Append dataset outputs to list of all outputs 155 | all_outputs.append(outputs) 156 | # Return all outputs for epoch end 157 | return module.test_epoch_end(all_outputs) 158 | -------------------------------------------------------------------------------- /dro_sfm/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/dro-sfm/8707e2e0ef799d7d47418a018060f503ef449fe3/dro_sfm/utils/__init__.py -------------------------------------------------------------------------------- /dro_sfm/utils/horovod.py: -------------------------------------------------------------------------------- 1 | 2 | try: 3 | import horovod.torch as hvd 4 | HAS_HOROVOD = True 5 | except ImportError: 6 | HAS_HOROVOD = False 7 | 8 | def hvd_disable(): 9 | global HAS_HOROVOD 10 | HAS_HOROVOD=False 11 | 12 | def hvd_init(): 13 | if HAS_HOROVOD: 14 | hvd.init() 15 | return HAS_HOROVOD 16 | 17 | def on_rank_0(func): 18 | def wrapper(*args, **kwargs): 19 | if rank() == 0: 20 | func(*args, **kwargs) 21 | return wrapper 22 | 23 | def rank(): 24 | return hvd.rank() if HAS_HOROVOD else 0 25 | 26 | def world_size(): 27 | return hvd.size() if HAS_HOROVOD else 1 28 | 29 | @on_rank_0 30 | def print0(string='\n'): 31 | print(string) 32 | 33 | def reduce_value(value, average, name): 34 | """ 35 | Reduce the mean value of a tensor from all GPUs 36 | 37 | Parameters 38 | ---------- 39 | value : torch.Tensor 40 | Value to be reduced 41 | average : bool 42 | Whether values will be averaged or not 43 | name : str 44 | Value name 45 | 46 | Returns 47 | ------- 48 | value : torch.Tensor 49 | reduced value 50 | """ 51 | return hvd.allreduce(value, average=average, name=name) 52 | -------------------------------------------------------------------------------- /dro_sfm/utils/load.py: -------------------------------------------------------------------------------- 1 | 2 | import importlib 3 | import logging 4 | import os 5 | import warnings 6 | import torch 7 | 8 | from inspect import signature 9 | from collections import OrderedDict 10 | 11 | from dro_sfm.utils.misc import make_list, same_shape 12 | from dro_sfm.utils.logging import pcolor 13 | from dro_sfm.utils.horovod import print0 14 | from dro_sfm.utils.types import is_str 15 | 16 | 17 | 
def set_debug(debug): 18 | """ 19 | Enable or disable debug terminal logging 20 | 21 | Parameters 22 | ---------- 23 | debug : bool 24 | Debugging flag (True to enable) 25 | """ 26 | # Disable logging if requested 27 | if not debug: 28 | os.environ['NCCL_DEBUG'] = '' 29 | os.environ['WANDB_SILENT'] = 'false' 30 | warnings.filterwarnings("ignore") 31 | logging.disable(logging.CRITICAL) 32 | 33 | 34 | def filter_args(func, keys): 35 | """ 36 | Filters a dictionary so it only contains keys that are arguments of a function 37 | 38 | Parameters 39 | ---------- 40 | func : Function 41 | Function for which we are filtering the dictionary 42 | keys : dict 43 | Dictionary with keys we are filtering 44 | 45 | Returns 46 | ------- 47 | filtered : dict 48 | Dictionary containing only keys that are arguments of func 49 | """ 50 | filtered = {} 51 | sign = list(signature(func).parameters.keys()) 52 | for k, v in {**keys}.items(): 53 | if k in sign: 54 | filtered[k] = v 55 | return filtered 56 | 57 | 58 | def filter_args_create(func, keys): 59 | """ 60 | Filters a dictionary so it only contains keys that are arguments of a function 61 | and creates a function with those arguments 62 | 63 | Parameters 64 | ---------- 65 | func : Function 66 | Function for which we are filtering the dictionary 67 | keys : dict 68 | Dictionary with keys we are filtering 69 | 70 | Returns 71 | ------- 72 | func : Function 73 | Function with filtered keys as arguments 74 | """ 75 | return func(**filter_args(func, keys)) 76 | 77 | 78 | def load_class(filename, paths, concat=True): 79 | """ 80 | Look for a file in different locations and return its method with the same name 81 | Optionally, you can use concat to search in path.filename instead 82 | 83 | Parameters 84 | ---------- 85 | filename : str 86 | Name of the file we are searching for 87 | paths : str or list of str 88 | Folders in which the file will be searched 89 | concat : bool 90 | Flag to concatenate filename to each path during the search 91 | 92 | Returns 93 | ------- 94 | method : Function 95 | Loaded method 96 | """ 97 | # for each path in paths 98 | for path in make_list(paths): 99 | # Create full path 100 | full_path = '{}.{}'.format(path, filename) if concat else path 101 | if importlib.util.find_spec(full_path): 102 | # Return method with same name as the file 103 | return getattr(importlib.import_module(full_path), filename) 104 | raise ValueError('Unknown class {}'.format(filename)) 105 | 106 | 107 | def load_class_args_create(filename, paths, args={}, concat=True): 108 | """Loads a class (filename) and returns an instance with filtered arguments (args)""" 109 | class_type = load_class(filename, paths, concat) 110 | return filter_args_create(class_type, args) 111 | 112 | 113 | def load_network(network, path, prefixes=''): 114 | """ 115 | Loads a pretrained network 116 | 117 | Parameters 118 | ---------- 119 | network : nn.Module 120 | Network that will receive the pretrained weights 121 | path : str 122 | File containing a 'state_dict' key with pretrained network weights 123 | prefixes : str or list of str 124 | Layer name prefixes to consider when loading the network 125 | 126 | Returns 127 | ------- 128 | network : nn.Module 129 | Updated network with pretrained weights 130 | """ 131 | prefixes = make_list(prefixes) 132 | # If path is a string 133 | if is_str(path): 134 | saved_state_dict = torch.load(path, map_location='cpu')['state_dict'] 135 | if path.endswith('.pth.tar'): 136 | saved_state_dict = backwards_state_dict(saved_state_dict) 137 | # If 
state dict is already provided 138 | else: 139 | saved_state_dict = path 140 | # Get network state dict 141 | network_state_dict = network.state_dict() 142 | 143 | updated_state_dict = OrderedDict() 144 | n, n_total = 0, len(network_state_dict.keys()) 145 | for key, val in saved_state_dict.items(): 146 | for prefix in prefixes: 147 | prefix = prefix + '.' 148 | if prefix in key: 149 | idx = key.find(prefix) + len(prefix) 150 | key = key[idx:] 151 | if key in network_state_dict.keys() and \ 152 | same_shape(val.shape, network_state_dict[key].shape): 153 | updated_state_dict[key] = val 154 | n += 1 155 | try: 156 | network.load_state_dict(updated_state_dict, strict=True) 157 | except Exception as e: 158 | print(e) 159 | network.load_state_dict(updated_state_dict, strict=False) 160 | base_color, attrs = 'cyan', ['bold', 'dark'] 161 | color = 'green' if n == n_total else 'yellow' if n > 0 else 'red' 162 | print0(pcolor('=====###### Pretrained {} loaded:'.format(prefixes[0]), base_color, attrs=attrs) + 163 | pcolor(' {}/{} '.format(n, n_total), color, attrs=attrs) + 164 | pcolor('tensors', base_color, attrs=attrs)) 165 | return network 166 | 167 | 168 | def backwards_state_dict(state_dict): 169 | """ 170 | Modify the state dict of older models for backwards compatibility 171 | 172 | Parameters 173 | ---------- 174 | state_dict : dict 175 | Model state dict with pretrained weights 176 | 177 | Returns 178 | ------- 179 | state_dict : dict 180 | Updated model state dict with modified layer names 181 | """ 182 | # List of layer names to change 183 | changes = (('model.model', 'model'), 184 | ('pose_network', 'pose_net'), 185 | ('disp_network', 'depth_net')) 186 | # Iterate over all keys and values 187 | updated_state_dict = OrderedDict() 188 | for key, val in state_dict.items(): 189 | # Ad hoc changes due to version changes 190 | key = '{}.{}'.format('model', key) 191 | if 'disp_network' in key: 192 | key = key.replace('conv3.0.weight', 'conv3.weight') 193 | key = key.replace('conv3.0.bias', 'conv3.bias') 194 | # Change layer names 195 | for change in changes: 196 | key = key.replace('{}.'.format(change[0]), 197 | '{}.'.format(change[1])) 198 | updated_state_dict[key] = val 199 | # Return updated state dict 200 | return updated_state_dict 201 | -------------------------------------------------------------------------------- /dro_sfm/utils/logging.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from termcolor import colored 4 | from functools import partial 5 | 6 | from dro_sfm.utils.horovod import on_rank_0 7 | 8 | 9 | def pcolor(string, color, on_color=None, attrs=None): 10 | """ 11 | Produces a colored string for printing 12 | 13 | Parameters 14 | ---------- 15 | string : str 16 | String that will be colored 17 | color : str 18 | Color to use 19 | on_color : str 20 | Background color to use 21 | attrs : list of str 22 | Different attributes for the string 23 | 24 | Returns 25 | ------- 26 | string: str 27 | Colored string 28 | """ 29 | return colored(string, color, on_color, attrs) 30 | 31 | 32 | def prepare_dataset_prefix(config, dataset_idx): 33 | """ 34 | Concatenates dataset path and split for metrics logging 35 | 36 | Parameters 37 | ---------- 38 | config : CfgNode 39 | Dataset configuration 40 | dataset_idx : int 41 | Dataset index for multiple datasets 42 | 43 | Returns 44 | ------- 45 | prefix : str 46 | Dataset prefix for metrics logging 47 | """ 48 | # Path is always available 49 | prefix = 
'{}'.format(os.path.splitext(config.path[dataset_idx].split('/')[-1])[0]) 50 | # If split is available and does not contain { character 51 | if config.split[dataset_idx] != '' and '{' not in config.split[dataset_idx]: 52 | prefix += '-{}'.format(os.path.splitext(os.path.basename(config.split[dataset_idx]))[0]) 53 | # If depth type is available 54 | if config.depth_type[dataset_idx] != '': 55 | prefix += '-{}'.format(config.depth_type[dataset_idx]) 56 | # If we are using specific cameras 57 | if len(config.cameras[dataset_idx]) == 1: # only allows single cameras 58 | prefix += '-{}'.format(config.cameras[dataset_idx][0]) 59 | # Return full prefix 60 | return prefix 61 | 62 | 63 | def s3_url(config): 64 | """ 65 | Generate the s3 url where the models will be saved 66 | 67 | Parameters 68 | ---------- 69 | config : CfgNode 70 | Model configuration 71 | 72 | Returns 73 | ------- 74 | url : str 75 | String containing the URL pointing to the s3 bucket 76 | """ 77 | return 'https://s3.console.aws.amazon.com/s3/buckets/{}/{}'.format( 78 | config.checkpoint.s3_path[5:], config.name) 79 | 80 | 81 | @on_rank_0 82 | def print_config(config, color=('blue', 'red', 'cyan'), attrs=('bold', 'dark')): 83 | """ 84 | Prints header for model configuration 85 | 86 | Parameters 87 | ---------- 88 | config : CfgNode 89 | Model configuration 90 | color : list of str 91 | Color pallete for the header 92 | attrs : 93 | Colored string attributes 94 | """ 95 | # Recursive print function 96 | def print_recursive(rec_args, n=2, l=0): 97 | if l == 0: 98 | print(pcolor('config:', color[1], attrs=attrs)) 99 | for key, val in rec_args.items(): 100 | if isinstance(val, dict): 101 | print(pcolor('{} {}:'.format('-' * n, key), color[1], attrs=attrs)) 102 | print_recursive(val, n + 2, l + 1) 103 | else: 104 | print('{}: {}'.format(pcolor('{} {}'.format('-' * n, key), color[2]), val)) 105 | 106 | # Color partial functions 107 | pcolor1 = partial(pcolor, color='blue', attrs=['bold', 'dark']) 108 | pcolor2 = partial(pcolor, color='blue', attrs=['bold']) 109 | # Config and name 110 | line = pcolor1('#' * 120) 111 | path = pcolor1('### Config: ') + \ 112 | pcolor2('{}'.format(config.default.replace('/', '.'))) + \ 113 | pcolor1(' -> ') + \ 114 | pcolor2('{}'.format(config.config.replace('/', '.'))) 115 | name = pcolor1('### Name: ') + \ 116 | pcolor2('{}'.format(config.name)) 117 | # Add wandb link if available 118 | if not config.wandb.dry_run: 119 | name += pcolor1(' -> ') + \ 120 | pcolor2('{}'.format(config.wandb.url)) 121 | # Add s3 link if available 122 | if config.checkpoint.s3_path is not '': 123 | name += pcolor1('\n### s3:') + \ 124 | pcolor2(' {}'.format(config.checkpoint.s3_url)) 125 | # Create header string 126 | header = '%s\n%s\n%s\n%s' % (line, path, name, line) 127 | 128 | # Print header, config and header again 129 | print() 130 | print(header) 131 | print_recursive(config) 132 | print(header) 133 | print() 134 | 135 | 136 | class AvgMeter: 137 | """Average meter for logging""" 138 | def __init__(self, n_max=100): 139 | """ 140 | Initializes a AvgMeter object. 
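Values are kept in a sliding window of at most ``n_max`` entries, and calling the meter returns the running mean over that window; with ``n_max=2``, feeding 1.0, 2.0 and 4.0 returns 1.0, 1.5 and then 3.0, since only the last two values are kept.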
141 | 142 | Parameters 143 | ---------- 144 | n_max : int 145 | Number of steps to average over 146 | """ 147 | self.n_max = n_max 148 | self.values = [] 149 | 150 | def __call__(self, value): 151 | """Appends new value and returns average""" 152 | self.values.append(value) 153 | if len(self.values) > self.n_max: 154 | self.values.pop(0) 155 | return self.get() 156 | 157 | def get(self): 158 | """Get current average""" 159 | return sum(self.values) / len(self.values) 160 | 161 | def reset(self): 162 | """Reset meter""" 163 | self.values.clear() 164 | 165 | def get_and_reset(self): 166 | """Gets current average and resets""" 167 | average = self.get() 168 | self.reset() 169 | return average 170 | -------------------------------------------------------------------------------- /dro_sfm/utils/misc.py: -------------------------------------------------------------------------------- 1 | 2 | from dro_sfm.utils.types import is_list 3 | 4 | ######################################################################################################################## 5 | 6 | def filter_dict(dictionary, keywords): 7 | """ 8 | Returns only the keywords that are part of a dictionary 9 | 10 | Parameters 11 | ---------- 12 | dictionary : dict 13 | Dictionary for filtering 14 | keywords : list of str 15 | Keywords that will be filtered 16 | 17 | Returns 18 | ------- 19 | keywords : list of str 20 | List containing the keywords that are keys in dictionary 21 | """ 22 | return [key for key in keywords if key in dictionary] 23 | 24 | ######################################################################################################################## 25 | 26 | def make_list(var, n=None): 27 | """ 28 | Wraps the input into a list, and optionally repeats it to be size n 29 | 30 | Parameters 31 | ---------- 32 | var : Any 33 | Variable to be wrapped in a list 34 | n : int 35 | How much the wrapped variable will be repeated 36 | 37 | Returns 38 | ------- 39 | var_list : list 40 | List generated from var 41 | """ 42 | var = var if is_list(var) else [var] 43 | if n is None: 44 | return var 45 | else: 46 | assert len(var) == 1 or len(var) == n, 'Wrong list length for make_list' 47 | return var * n if len(var) == 1 else var 48 | 49 | ######################################################################################################################## 50 | 51 | def same_shape(shape1, shape2): 52 | """ 53 | Checks if two shapes are the same 54 | 55 | Parameters 56 | ---------- 57 | shape1 : tuple 58 | First shape 59 | shape2 : tuple 60 | Second shape 61 | 62 | Returns 63 | ------- 64 | flag : bool 65 | True if both shapes are the same (same length and dimensions) 66 | """ 67 | if len(shape1) != len(shape2): 68 | return False 69 | for i in range(len(shape1)): 70 | if shape1[i] != shape2[i]: 71 | return False 72 | return True 73 | 74 | ######################################################################################################################## 75 | -------------------------------------------------------------------------------- /dro_sfm/utils/reduce.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import numpy as np 4 | from collections import OrderedDict 5 | from dro_sfm.utils.horovod import reduce_value 6 | from dro_sfm.utils.logging import prepare_dataset_prefix 7 | 8 | 9 | def reduce_dict(data, to_item=False): 10 | """ 11 | Reduce the mean values of a dictionary from all GPUs 12 | 13 | Parameters 14 | ---------- 15 | data : dict 16 | Dictionary to 
be reduced 17 | to_item : bool 18 | True if the reduced values will be returned as .item() 19 | 20 | Returns 21 | ------- 22 | dict : dict 23 | Reduced dictionary 24 | """ 25 | for key, val in data.items(): 26 | data[key] = reduce_value(data[key], average=True, name=key) 27 | if to_item: 28 | data[key] = data[key].item() 29 | return data 30 | 31 | def all_reduce_metrics(output_data_batch, datasets, name='depth'): 32 | """ 33 | Reduce metrics for all batches and all datasets using Horovod 34 | 35 | Parameters 36 | ---------- 37 | output_data_batch : list 38 | List of outputs for each batch 39 | datasets : list 40 | List of all considered datasets 41 | name : str 42 | Name of the task for the metric 43 | 44 | Returns 45 | ------- 46 | all_metrics_dict : list 47 | List of reduced metrics 48 | """ 49 | # If there is only one dataset, wrap in a list 50 | if isinstance(output_data_batch[0], dict): 51 | output_data_batch = [output_data_batch] 52 | # Get metrics keys and dimensions 53 | names = [key for key in list(output_data_batch[0][0].keys()) if key.startswith(name)] 54 | dims = [output_data_batch[0][0][name].shape[0] for name in names] 55 | # List storing metrics for all datasets 56 | all_metrics_dict = [] 57 | # Loop over all datasets and all batches 58 | for output_batch, dataset in zip(output_data_batch, datasets): 59 | metrics_dict = OrderedDict() 60 | length = len(dataset) 61 | # Count how many times each sample was seen 62 | seen = torch.zeros(length) 63 | for output in output_batch: 64 | for i, idx in enumerate(output['idx']): 65 | seen[idx] += 1 66 | seen = reduce_value(seen, average=False, name='idx') 67 | assert not np.any(seen.numpy() == 0), \ 68 | 'Not all samples were seen during evaluation' 69 | # Reduce all relevant metrics 70 | for name, dim in zip(names, dims): 71 | metrics = torch.zeros(length, dim) 72 | for output in output_batch: 73 | for i, idx in enumerate(output['idx']): 74 | metrics[idx] = output[name] 75 | metrics = reduce_value(metrics, average=False, name=name) 76 | metrics_dict[name] = (metrics / seen.view(-1, 1)).mean(0) 77 | # Append metrics dictionary to the list 78 | all_metrics_dict.append(metrics_dict) 79 | # Return list of metrics dictionary 80 | return all_metrics_dict 81 | 82 | ######################################################################################################################## 83 | 84 | def collate_metrics(output_data_batch, name='depth'): 85 | """ 86 | Collate epoch output to produce average metrics 87 | 88 | Parameters 89 | ---------- 90 | output_data_batch : list 91 | List of outputs for each batch 92 | name : str 93 | Name of the task for the metric 94 | 95 | Returns 96 | ------- 97 | metrics_data : list 98 | List of collated metrics 99 | """ 100 | # If there is only one dataset, wrap in a list 101 | if isinstance(output_data_batch[0], dict): 102 | output_data_batch = [output_data_batch] 103 | # Calculate the mean of all metrics 104 | metrics_data = [] 105 | # For all datasets 106 | for i, output_batch in enumerate(output_data_batch): 107 | metrics = OrderedDict() 108 | # For all keys (assume they are the same for all batches) 109 | for key, val in output_batch[0].items(): 110 | if key.startswith(name): 111 | metrics[key] = torch.stack([output[key] for output in output_batch], 0) 112 | metrics[key] = torch.mean(metrics[key], 0) 113 | metrics_data.append(metrics) 114 | # Return metrics data 115 | return metrics_data 116 | 117 | def create_dict(metrics_data, metrics_keys, metrics_modes, 118 | dataset, name='depth'): 119 | """
Creates a dictionary from collated metrics 121 | 122 | Parameters 123 | ---------- 124 | metrics_data : list 125 | List containing collated metrics 126 | metrics_keys : list 127 | List of keys for the metrics 128 | metrics_modes 129 | List of modes for the metrics 130 | dataset : CfgNode 131 | Dataset configuration file 132 | name : str 133 | Name of the task for the metric 134 | 135 | Returns 136 | ------- 137 | metrics_dict : dict 138 | Metrics dictionary 139 | """ 140 | # Create metrics dictionary 141 | metrics_dict = {} 142 | # For all datasets 143 | for n, metrics in enumerate(metrics_data): 144 | if metrics: # If there are calculated metrics 145 | prefix = prepare_dataset_prefix(dataset, n) 146 | # For all keys 147 | for i, key in enumerate(metrics_keys): 148 | for mode in metrics_modes: 149 | metrics_dict['{}-{}{}'.format(prefix, key, mode)] =\ 150 | metrics['{}{}'.format(name, mode)][i].item() 151 | # Return metrics dictionary 152 | return metrics_dict 153 | 154 | ######################################################################################################################## 155 | 156 | def average_key(batch_list, key): 157 | """ 158 | Average key in a list of batches 159 | 160 | Parameters 161 | ---------- 162 | batch_list : list of dict 163 | List containing dictionaries with the same keys 164 | key : str 165 | Key to be averaged 166 | 167 | Returns 168 | ------- 169 | average : float 170 | Average of the value contained in key for all batches 171 | """ 172 | values = [batch[key] for batch in batch_list] 173 | return sum(values) / len(values) 174 | 175 | def average_sub_key(batch_list, key, sub_key): 176 | """ 177 | Average subkey in a dictionary in a list of batches 178 | 179 | Parameters 180 | ---------- 181 | batch_list : list of dict 182 | List containing dictionaries with the same keys 183 | key : str 184 | Key to be averaged 185 | sub_key : 186 | Sub key to be averaged (belonging to key) 187 | 188 | Returns 189 | ------- 190 | average : float 191 | Average of the value contained in the sub_key of key for all batches 192 | """ 193 | values = [batch[key][sub_key] for batch in batch_list] 194 | return sum(values) / len(values) 195 | 196 | def average_loss_and_metrics(batch_list, prefix): 197 | """ 198 | Average loss and metrics values in a list of batches 199 | 200 | Parameters 201 | ---------- 202 | batch_list : list of dict 203 | List containing dictionaries with the same keys 204 | prefix : str 205 | Prefix string for metrics logging 206 | 207 | Returns 208 | ------- 209 | values : dict 210 | Dictionary containing a 'loss' float entry and a 'metrics' dict entry 211 | """ 212 | values = OrderedDict() 213 | key = 'loss' 214 | values['{}-{}'.format(prefix, key)] = \ 215 | average_key(batch_list, key) 216 | key = 'metrics' 217 | for sub_key in batch_list[0][key].keys(): 218 | values['{}-{}'.format(prefix, sub_key)] = \ 219 | average_sub_key(batch_list, key, sub_key) 220 | return values 221 | 222 | ######################################################################################################################## 223 | -------------------------------------------------------------------------------- /dro_sfm/utils/save.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import os 4 | 5 | from dro_sfm.utils.image import write_image 6 | from dro_sfm.utils.depth import write_depth, inv2depth, viz_inv_depth 7 | from dro_sfm.utils.logging import prepare_dataset_prefix 8 | 9 | 10 | def save_depth(batch, output, 
args, dataset, save): 11 | """ 12 | Save depth predictions in various ways 13 | 14 | Parameters 15 | ---------- 16 | batch : dict 17 | Batch from dataloader 18 | output : dict 19 | Output from model 20 | args : tuple 21 | Step arguments 22 | dataset : CfgNode 23 | Dataset configuration 24 | save : CfgNode 25 | Save configuration 26 | """ 27 | # If there is no save folder, don't save 28 | if save.folder == '': 29 | return 30 | 31 | # If we want to save 32 | if save.depth.rgb or save.depth.viz or save.depth.npz or save.depth.png: 33 | # Retrieve useful tensors 34 | rgb = batch['rgb'] 35 | pred_inv_depth = output['inv_depth'] 36 | 37 | # Prepare path strings 38 | filename = batch['filename'] 39 | dataset_idx = 0 if len(args) == 1 else args[1] 40 | save_path = os.path.join(save.folder, 'depth', 41 | prepare_dataset_prefix(dataset, dataset_idx), 42 | os.path.basename(save.pretrained).split('.')[0]) 43 | # Create folder 44 | os.makedirs(save_path, exist_ok=True) 45 | 46 | # For each image in the batch 47 | length = rgb.shape[0] 48 | for i in range(length): 49 | # Save numpy depth maps 50 | if save.depth.npz: 51 | write_depth('{}/{}_depth.npz'.format(save_path, filename[i]), 52 | depth=inv2depth(pred_inv_depth[i]), 53 | intrinsics=batch['intrinsics'][i] if 'intrinsics' in batch else None) 54 | # Save png depth maps 55 | if save.depth.png: 56 | write_depth('{}/{}_depth.png'.format(save_path, filename[i]), 57 | depth=inv2depth(pred_inv_depth[i])) 58 | # Save rgb images 59 | if save.depth.rgb: 60 | rgb_i = rgb[i].permute(1, 2, 0).detach().cpu().numpy() * 255 61 | write_image('{}/{}_rgb.png'.format(save_path, filename[i]), rgb_i) 62 | # Save inverse depth visualizations 63 | if save.depth.viz: 64 | viz_i = viz_inv_depth(pred_inv_depth[i]) * 255 65 | write_image('{}/{}_viz.png'.format(save_path, filename[i]), viz_i) 66 | -------------------------------------------------------------------------------- /dro_sfm/utils/types.py: -------------------------------------------------------------------------------- 1 | 2 | import yacs 3 | import numpy as np 4 | import torch 5 | 6 | ######################################################################################################################## 7 | 8 | def is_numpy(data): 9 | """Checks if data is a numpy array.""" 10 | return isinstance(data, np.ndarray) 11 | 12 | def is_tensor(data): 13 | """Checks if data is a torch tensor.""" 14 | return type(data) == torch.Tensor 15 | 16 | def is_tuple(data): 17 | """Checks if data is a tuple.""" 18 | return isinstance(data, tuple) 19 | 20 | def is_list(data): 21 | """Checks if data is a list.""" 22 | return isinstance(data, list) 23 | 24 | def is_dict(data): 25 | """Checks if data is a dictionary.""" 26 | return isinstance(data, dict) 27 | 28 | def is_str(data): 29 | """Checks if data is a string.""" 30 | return isinstance(data, str) 31 | 32 | def is_int(data): 33 | """Checks if data is an integer.""" 34 | return isinstance(data, int) 35 | 36 | def is_seq(data): 37 | """Checks if data is a list or tuple.""" 38 | return is_tuple(data) or is_list(data) 39 | 40 | def is_cfg(data): 41 | """Checks if data is a configuration node""" 42 | return type(data) == yacs.config.CfgNode 43 | 44 | ######################################################################################################################## -------------------------------------------------------------------------------- /media/figs/demo_kitti.gif: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/aliyun/dro-sfm/8707e2e0ef799d7d47418a018060f503ef449fe3/media/figs/demo_kitti.gif -------------------------------------------------------------------------------- /media/figs/demo_scannet.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aliyun/dro-sfm/8707e2e0ef799d7d47418a018060f503ef449fe3/media/figs/demo_scannet.gif -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | mkdir -p $(dirname $2) 2 | 3 | NGPUS=$(nvidia-smi -L | wc -l) 4 | mpirun -allow-run-as-root -np ${NGPUS} -H localhost:${NGPUS} -x MASTER_ADDR=127.0.0.1 -x MASTER_PORT=23457 -x HOROVOD_TIMELINE -x OMP_NUM_THREADS=1 -x KMP_AFFINITY='granularity=fine,compact,1,0' -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x NCCL_MIN_NRINGS=4 --report-bindings $1 2>&1 | tee -a $2 5 | -------------------------------------------------------------------------------- /scripts/eval.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | lib_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 4 | sys.path.append(lib_dir) 5 | import argparse 6 | import torch 7 | 8 | from dro_sfm.models.model_wrapper import ModelWrapper 9 | from dro_sfm.trainers.horovod_trainer import HorovodTrainer 10 | from dro_sfm.utils.config import parse_test_file 11 | from dro_sfm.utils.load import set_debug 12 | from dro_sfm.utils.horovod import hvd_init 13 | 14 | 15 | def parse_args(): 16 | """Parse arguments for training script""" 17 | parser = argparse.ArgumentParser(description='dro-sfm evaluation script') 18 | parser.add_argument('--checkpoint', type=str, help='Checkpoint (.ckpt)') 19 | parser.add_argument('--config', type=str, default=None, help='Configuration (.yaml)') 20 | parser.add_argument('--half', action="store_true", help='Use half precision (fp16)') 21 | args = parser.parse_args() 22 | assert args.checkpoint.endswith('.ckpt'), \ 23 | 'You need to provide a .ckpt file as checkpoint' 24 | assert args.config is None or args.config.endswith('.yaml'), \ 25 | 'You need to provide a .yaml file as configuration' 26 | return args 27 | 28 | 29 | def test(ckpt_file, cfg_file, half): 30 | """ 31 | Monocular depth estimation test script. 
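# Usage sketch (the checkpoint path is hypothetical; the .yaml is one of the files under configs/):
#     python scripts/eval.py --checkpoint /path/to/model.ckpt --config configs/train_kitti_mf_gt.yaml
# Add --half for fp16 inference, or wrap the command with run.sh to spread evaluation over all
# visible GPUs through the Horovod/MPI launcher.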
32 | 33 | Parameters 34 | ---------- 35 | ckpt_file : str 36 | Checkpoint path for a pretrained model 37 | cfg_file : str 38 | Configuration file 39 | half: bool 40 | use half precision (fp16) 41 | """ 42 | # Initialize horovod 43 | hvd_init() 44 | 45 | # Parse arguments 46 | config, state_dict = parse_test_file(ckpt_file, cfg_file) 47 | 48 | # Set debug if requested 49 | set_debug(config.debug) 50 | 51 | # Initialize monodepth model from checkpoint arguments 52 | model_wrapper = ModelWrapper(config) 53 | # Restore model state 54 | model_wrapper.load_state_dict(state_dict) 55 | 56 | # change to half precision for evaluation if requested 57 | config.arch["dtype"] = torch.float16 if half else None 58 | 59 | # Create trainer with args.arch parameters 60 | trainer = HorovodTrainer(**config.arch) 61 | 62 | # Test model 63 | trainer.test(model_wrapper) 64 | 65 | 66 | if __name__ == '__main__': 67 | args = parse_args() 68 | test(args.checkpoint, args.config, args.half) 69 | -------------------------------------------------------------------------------- /scripts/evaluate_depth_maps.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | lib_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 4 | sys.path.append(lib_dir) 5 | import argparse 6 | import numpy as np 7 | import torch 8 | 9 | from glob import glob 10 | from argparse import Namespace 11 | from dro_sfm.utils.depth import load_depth 12 | from tqdm import tqdm 13 | 14 | from dro_sfm.utils.depth import load_depth, compute_depth_metrics 15 | 16 | 17 | def parse_args(): 18 | """Parse arguments for benchmark script""" 19 | parser = argparse.ArgumentParser(description='dro-sfm benchmark script') 20 | parser.add_argument('--pred_folder', type=str, 21 | help='Folder containing predicted depth maps (.npz with key "depth")') 22 | parser.add_argument('--gt_folder', type=str, 23 | help='Folder containing ground-truth depth maps (.npz with key "depth")') 24 | parser.add_argument('--use_gt_scale', action='store_true', 25 | help='Use ground-truth median scaling on predicted depth maps') 26 | parser.add_argument('--min_depth', type=float, default=0., 27 | help='Minimum distance to consider during evaluation') 28 | parser.add_argument('--max_depth', type=float, default=80., 29 | help='Maximum distance to consider during evaluation') 30 | parser.add_argument('--crop', type=str, default='', choices=['', 'garg'], 31 | help='Which crop to use during evaluation') 32 | args = parser.parse_args() 33 | return args 34 | 35 | 36 | def main(args): 37 | # Get and sort ground-truth and predicted files 38 | exts = ('npz', 'png') 39 | gt_files, pred_files = [], [] 40 | for ext in exts: 41 | gt_files.extend(glob(os.path.join(args.gt_folder, '*.{}'.format(ext)))) 42 | pred_files.extend(glob(os.path.join(args.pred_folder, '*.{}'.format(ext)))) 43 | # Sort ground-truth and prediction 44 | gt_files.sort() 45 | pred_files.sort() 46 | # Loop over all files 47 | metrics = [] 48 | progress_bar = tqdm(zip(gt_files, pred_files), total=len(gt_files)) 49 | for gt, pred in progress_bar: 50 | # Get and prepare ground-truth and predictions 51 | gt = torch.tensor(load_depth(gt)).unsqueeze(0).unsqueeze(0) 52 | pred = torch.tensor(load_depth(pred)).unsqueeze(0).unsqueeze(0) 53 | # Calculate metrics 54 | metrics.append(compute_depth_metrics( 55 | args, gt, pred, use_gt_scale=args.use_gt_scale)) 56 | # Get and print average value 57 | metrics = (sum(metrics) / len(metrics)).detach().cpu().numpy() 58 | names = 
['abs_rel', 'sqr_rel', 'rmse', 'rmse_log', 'a1', 'a2', 'a3'] 59 | for name, metric in zip(names, metrics): 60 | print('{} = {}'.format(name, metric)) 61 | 62 | 63 | if __name__ == '__main__': 64 | args = parse_args() 65 | main(args) 66 | -------------------------------------------------------------------------------- /scripts/infer_pose.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | lib_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 4 | sys.path.append(lib_dir) 5 | import argparse 6 | import numpy as np 7 | import torch 8 | from pytorch3d import transforms 9 | import json 10 | 11 | from glob import glob 12 | from cv2 import imwrite 13 | 14 | from dro_sfm.models.model_wrapper import ModelWrapper 15 | from dro_sfm.datasets.augmentations import resize_image, to_tensor 16 | from dro_sfm.utils.horovod import hvd_init, rank, world_size, print0 17 | from dro_sfm.utils.image import load_image 18 | from dro_sfm.utils.config import parse_test_file 19 | from dro_sfm.utils.load import set_debug 20 | from dro_sfm.utils.depth import write_depth, inv2depth, viz_inv_depth 21 | from dro_sfm.utils.logging import pcolor 22 | 23 | poses = dict() 24 | 25 | def is_image(file, ext=('.png', '.jpg',)): 26 | """Check if a file is an image with certain extensions""" 27 | return file.endswith(ext) 28 | 29 | def parse_args(): 30 | parser = argparse.ArgumentParser(description='dro-sfm inference of camera poses from images') 31 | parser.add_argument('--checkpoint', type=str, help='Checkpoint (.ckpt)') 32 | parser.add_argument('--input', type=str, help='Input file or folder') 33 | parser.add_argument('--output', type=str, help='Output file or folder') 34 | parser.add_argument('--image_shape', type=int, nargs='+', default=None, 35 | help='Input and output image shape ' 36 | '(default: checkpoint\'s config.datasets.augmentation.image_shape)') 37 | parser.add_argument('--half', action="store_true", help='Use half precision (fp16)') 38 | parser.add_argument('--save', type=str, choices=['npz', 'png'], default=None, 39 | help='Save format (npz or png). 
Default is None (no depth map is saved).') 40 | parser.add_argument('--limit', type=int, default=None, 41 | help='Limit the amount of files to process in folder') 42 | parser.add_argument('--offset', type=int, default=None, 43 | help='Start at offset for files to process in folder') 44 | args = parser.parse_args() 45 | assert args.checkpoint.endswith('.ckpt'), \ 46 | 'You need to provide a .ckpt file as checkpoint' 47 | assert args.image_shape is None or len(args.image_shape) == 2, \ 48 | 'You need to provide a 2-dimensional tuple as shape (H,W)' 49 | assert (is_image(args.input) and is_image(args.output)) or \ 50 | (not is_image(args.input) and not is_image(args.output)), \ 51 | 'Input and output must both be images or folders' 52 | return args 53 | 54 | 55 | @torch.no_grad() 56 | def infer_and_save_pose(input_file_refs, input_file, model_wrapper, image_shape, half, save): 57 | """ 58 | Process a single input file to estimate its pose relative to the reference images 59 | 60 | Parameters 61 | ---------- 62 | input_file_refs : list(str) 63 | Reference image file paths 64 | input_file : str 65 | Image file for pose estimation 66 | model_wrapper : nn.Module 67 | Model wrapper used for inference 68 | image_shape : tuple 69 | Input image shape 70 | half: bool 71 | use half precision (fp16) 72 | save: str 73 | Save format (npz or png) 74 | """ 75 | base_name = os.path.basename(input_file) 76 | 77 | # change to half precision for evaluation if requested 78 | dtype = torch.float16 if half else None 79 | 80 | # Load image 81 | def process_image(filename): 82 | image = load_image(filename) 83 | # Resize and to tensor 84 | image = resize_image(image, image_shape) 85 | image = to_tensor(image).unsqueeze(0) 86 | 87 | # Send image to GPU if available 88 | if torch.cuda.is_available(): 89 | image = image.to('cuda:{}'.format(rank()), dtype=dtype) 90 | return image 91 | image_ref = [process_image(input_file_ref) for input_file_ref in input_file_refs] 92 | image = process_image(input_file) 93 | 94 | # Pose inference (returns the predicted relative pose as translation + Euler angles) 95 | pose_tensor = model_wrapper.pose(image, image_ref)[0][0] # take the pose from 1st to 2nd image 96 | 97 | rot_matrix = transforms.euler_angles_to_matrix(pose_tensor[3:], convention="ZYX") 98 | translation = pose_tensor[:3] 99 | 100 | poses[base_name] = (rot_matrix, translation) 101 | 102 | def main(args): 103 | 104 | # Initialize horovod 105 | hvd_init() 106 | 107 | # Parse arguments 108 | config, state_dict = parse_test_file(args.checkpoint) 109 | 110 | # If no image shape is provided, use the checkpoint one 111 | image_shape = args.image_shape 112 | if image_shape is None: 113 | image_shape = config.datasets.augmentation.image_shape 114 | 115 | print(image_shape) 116 | # Set debug if requested 117 | set_debug(config.debug) 118 | 119 | # Initialize model wrapper from checkpoint arguments 120 | model_wrapper = ModelWrapper(config, load_datasets=False) 121 | # Restore monodepth_model state 122 | model_wrapper.load_state_dict(state_dict) 123 | 124 | # change to half precision for evaluation if requested 125 | dtype = torch.float16 if args.half else None 126 | 127 | # Send model to GPU if available 128 | if torch.cuda.is_available(): 129 | model_wrapper = model_wrapper.to('cuda:{}'.format(rank()), dtype=dtype) 130 | 131 | # Set to eval mode 132 | model_wrapper.eval() 133 | 134 | if os.path.isdir(args.input): 135 | # If input file is a folder, search for image files 136 | files = [] 137 | for ext in ['png', 'jpg']: 138 | files.extend(glob((os.path.join(args.input, 
'*.{}'.format(ext))))) 139 | files.sort() 140 | print0('Found {} files'.format(len(files))) 141 | else: 142 | raise RuntimeError("Input needs directory, not file") 143 | 144 | if not os.path.isdir(args.output): 145 | root, file_name = os.path.split(args.output) 146 | os.makedirs(root, exist_ok=True) 147 | else: 148 | raise RuntimeError("Output needs to be a file") 149 | 150 | 151 | # Process each file 152 | list_of_files = list(zip(files[rank() :-2:world_size()], 153 | files[rank()+1:-1:world_size()], 154 | files[rank()+2: :world_size()])) 155 | if args.offset: 156 | list_of_files = list_of_files[args.offset:] 157 | if args.limit: 158 | list_of_files = list_of_files[:args.limit] 159 | for fn1, fn2, fn3 in list_of_files: 160 | infer_and_save_pose([fn1, fn3], fn2, model_wrapper, image_shape, args.half, args.save) 161 | 162 | position = np.zeros(3) 163 | orientation = np.eye(3) 164 | for key in sorted(poses.keys()): 165 | rot_matrix, translation = poses[key] 166 | orientation = orientation.dot(rot_matrix.tolist()) 167 | position += orientation.dot(translation.tolist()) 168 | poses[key] = {"rot": rot_matrix.tolist(), 169 | "trans": translation.tolist(), 170 | "pose": [*orientation[0], position[0], 171 | *orientation[1], position[1], 172 | *orientation[2], position[2], 173 | 0, 0, 0, 1]} 174 | 175 | json.dump(poses, open(args.output, "w"), sort_keys=True) 176 | print(f"Written poses of {len(list_of_files)} images to {args.output}") 177 | 178 | 179 | 180 | 181 | if __name__ == '__main__': 182 | args = parse_args() 183 | main(args) -------------------------------------------------------------------------------- /scripts/infer_pose.sh: -------------------------------------------------------------------------------- 1 | python scripts/infer_pose.py --checkpoint ckpt/PackNet01_MR_selfsup_D.ckpt --input /data0/datasets/kitti/KITTI_tiny/2011_09_26/2011_09_26_drive_0023_sync/image_02/data --output results/eval/pose --save png 2 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | import os 4 | lib_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 5 | sys.path.append(lib_dir) 6 | import argparse 7 | from dro_sfm.models.model_wrapper import ModelWrapper 8 | from dro_sfm.models.model_checkpoint import ModelCheckpoint 9 | from dro_sfm.trainers.horovod_trainer import HorovodTrainer 10 | from dro_sfm.utils.config import parse_train_file 11 | from dro_sfm.utils.load import set_debug, filter_args_create 12 | from dro_sfm.utils.horovod import hvd_init, rank 13 | from dro_sfm.loggers import WandbLogger 14 | 15 | 16 | def parse_args(): 17 | """Parse arguments for training script""" 18 | parser = argparse.ArgumentParser(description='dro-sfm training script') 19 | parser.add_argument('file', type=str, help='Input file (.ckpt or .yaml)') 20 | parser.add_argument('--config', type=str, default=None, help='Input file (yaml)') 21 | args = parser.parse_args() 22 | assert args.file.endswith(('.ckpt', '.yaml')), \ 23 | 'You need to provide a .ckpt or .yaml file' 24 | return args 25 | 26 | 27 | def train(file, config): 28 | """ 29 | Monocular depth estimation training script. 30 | 31 | Parameters 32 | ---------- 33 | file : str 34 | Filepath, can be either a 35 | **.yaml** for a yacs configuration file or a 36 | **.ckpt** for a pre-trained checkpoint file. 
37 | """ 38 | # Initialize horovod 39 | hvd_init() 40 | 41 | # Produce configuration and checkpoint from filename 42 | config, ckpt = parse_train_file(file, config) 43 | 44 | # Set debug if requested 45 | set_debug(config.debug) 46 | 47 | # Wandb Logger 48 | logger = None if config.wandb.dry_run or rank() > 0 \ 49 | else filter_args_create(WandbLogger, config.wandb) 50 | 51 | # model checkpoint 52 | checkpoint = None if config.checkpoint.filepath is '' or rank() > 0 else \ 53 | filter_args_create(ModelCheckpoint, config.checkpoint) 54 | 55 | # Initialize model wrapper 56 | model_wrapper = ModelWrapper(config, resume=ckpt, logger=logger) 57 | 58 | # Create trainer with args.arch parameters 59 | trainer = HorovodTrainer(**config.arch, checkpoint=checkpoint) 60 | 61 | # Train model 62 | trainer.fit(model_wrapper) 63 | 64 | 65 | if __name__ == '__main__': 66 | args = parse_args() 67 | train(args.file, args.config) 68 | --------------------------------------------------------------------------------