├── LICENSE
├── Makefile
├── README.md
├── configs
│   ├── default_config.py
│   ├── overfit_kitti_mf_gt.yaml
│   ├── train_demon_mf_gt.yaml
│   ├── train_kitti_mf_gt.yaml
│   ├── train_kitti_mf_selfsup.yaml
│   ├── train_nyu_mf_gt.yaml
│   ├── train_rgbd_mf_gt.yaml
│   ├── train_scannet_mf_gt_view2.yaml
│   ├── train_scannet_mf_gt_view3.yaml
│   ├── train_scannet_mf_gt_view5.yaml
│   ├── train_scannet_mf_selfsup_view3.yaml
│   ├── train_scannet_mf_selfsup_view5.yaml
│   ├── train_scene11_mf_gt.yaml
│   ├── train_sun3d_mf_gt.yaml
│   └── train_video_mf_selfsup_out_random.yaml
├── docker
│   └── Dockerfile
├── dro_sfm
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── augmentations.py
│   │   ├── demon_dataset.py
│   │   ├── demon_mf_dataset.py
│   │   ├── dgp_dataset.py
│   │   ├── image_dataset.py
│   │   ├── kitti_dataset.py
│   │   ├── kitti_dataset_utils.py
│   │   ├── nyu_dataset_processed.py
│   │   ├── nyu_dataset_test_processed.py
│   │   ├── scannet_banet_dataset.py
│   │   ├── scannet_dataset.py
│   │   ├── scannet_test_dataset.py
│   │   ├── transforms.py
│   │   ├── video_dataset.py
│   │   └── video_random_dataset.py
│   ├── geometry
│   │   ├── __init__.py
│   │   ├── camera.py
│   │   ├── camera_utils.py
│   │   ├── pose.py
│   │   ├── pose_trans.py
│   │   └── pose_utils.py
│   ├── loggers
│   │   ├── __init__.py
│   │   └── wandb_logger.py
│   ├── losses
│   │   ├── __init__.py
│   │   ├── loss_base.py
│   │   ├── multiview_photometric_loss_mf.py
│   │   └── supervised_loss.py
│   ├── models
│   │   ├── SelfSupModelMF.py
│   │   ├── SemiSupModelMF.py
│   │   ├── SfmModel.py
│   │   ├── SfmModelMF.py
│   │   ├── SupModelMF.py
│   │   ├── __init__.py
│   │   ├── model_checkpoint.py
│   │   ├── model_utils.py
│   │   └── model_wrapper.py
│   ├── networks
│   │   ├── __init__.py
│   │   ├── depth_pose
│   │   │   └── DepthPoseNet.py
│   │   ├── layers
│   │   │   ├── PercepNet.py
│   │   │   └── resnet
│   │   │       ├── depth_decoder.py
│   │   │       ├── layers.py
│   │   │       ├── pose_decoder.py
│   │   │       ├── pose_res_decoder.py
│   │   │       └── resnet_encoder.py
│   │   └── optim
│   │       ├── __init__.py
│   │       ├── extractor.py
│   │       └── update.py
│   ├── trainers
│   │   ├── __init__.py
│   │   ├── base_trainer.py
│   │   └── horovod_trainer.py
│   └── utils
│       ├── __init__.py
│       ├── config.py
│       ├── depth.py
│       ├── horovod.py
│       ├── image.py
│       ├── image_gt.py
│       ├── load.py
│       ├── logging.py
│       ├── misc.py
│       ├── reduce.py
│       ├── save.py
│       └── types.py
├── media
│   └── figs
│       ├── demo_kitti.gif
│       └── demo_scannet.gif
├── run.sh
└── scripts
    ├── eval.py
    ├── evaluate_depth_maps.py
    ├── infer.py
    ├── infer_pose.py
    ├── infer_pose.sh
    ├── infer_video.py
    ├── train.py
    └── vis.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Alibaba Cloud
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | # Handy commands:
2 | # - `make docker-build`: builds DOCKER_IMAGE (default: `dro-sfm:latest`)
3 | PROJECT ?= dro-sfm
4 | WORKSPACE ?= /workspace/$(PROJECT)
5 | DOCKER_IMAGE ?= ${PROJECT}:latest
6 |
7 | SHMSIZE ?= 32G
8 | WANDB_MODE ?= run
9 | DOCKER_OPTS := \
10 | --name ${PROJECT} \
11 | -it \
12 | --shm-size=${SHMSIZE} \
13 | -e AWS_DEFAULT_REGION \
14 | -e AWS_ACCESS_KEY_ID \
15 | -e AWS_SECRET_ACCESS_KEY \
16 | -e WANDB_API_KEY \
17 | -e WANDB_ENTITY \
18 | -e WANDB_MODE \
19 | -e HOST_HOSTNAME= \
20 | -e OMP_NUM_THREADS=1 -e KMP_AFFINITY="granularity=fine,compact,1,0" \
21 | -e OMPI_ALLOW_RUN_AS_ROOT=1 \
22 | -e OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1 \
23 | -e NCCL_DEBUG=VERSION \
24 | -e DISPLAY=${DISPLAY} \
25 | -e XAUTHORITY \
26 | -e NVIDIA_DRIVER_CAPABILITIES=all \
27 | -v ~/.aws:/root/.aws \
28 | -v /root/.ssh:/root/.ssh \
29 | -v ~/.cache:/root/.cache \
30 | -v /data:/data \
31 | -v /data0:/data0 \
32 | -v /mnt:/mnt \
33 | -v /mnt/fsx/:/mnt/fsx \
34 | -v /dev/null:/dev/raw1394 \
35 | -v /tmp:/tmp \
36 | -v /tmp/.X11-unix/X0:/tmp/.X11-unix/X0 \
37 | -v /var/run/docker.sock:/var/run/docker.sock \
38 | -v ${PWD}:${WORKSPACE} \
39 | -w ${WORKSPACE} \
40 | --privileged \
41 | --ipc=host \
42 | --network=host
43 |
44 | NGPUS=$(shell nvidia-smi -L | wc -l)
45 | MPI_CMD=mpirun \
46 | -allow-run-as-root \
47 | -np ${NGPUS} \
48 | -H localhost:${NGPUS} \
49 | -x MASTER_ADDR=127.0.0.1 \
50 | -x MASTER_PORT=23457 \
51 | -x HOROVOD_TIMELINE \
52 | -x OMP_NUM_THREADS=1 \
53 | -x KMP_AFFINITY='granularity=fine,compact,1,0' \
54 | -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x NCCL_MIN_NRINGS=4 \
55 | --report-bindings
56 |
57 |
58 | .PHONY: all clean docker-build docker-overfit-pose
59 |
60 | all: clean
61 |
62 | clean:
63 | find . -name "*.pyc" | xargs rm -f && \
64 | find . -name "__pycache__" | xargs rm -rf
65 |
66 | docker-build:
67 | docker build \
68 | -f docker/Dockerfile \
69 | -t ${DOCKER_IMAGE} .
70 |
71 | docker-start-interactive:
72 | docker run --gpus=all ${DOCKER_OPTS} ${DOCKER_IMAGE} bash
73 |
74 | docker-start-jupyter:
75 | docker run --gpus=all ${DOCKER_OPTS} ${DOCKER_IMAGE} \
76 | bash -c "jupyter notebook --port=8888 --ip=0.0.0.0 --allow-root --no-browser"
77 |
78 | docker-run:
79 | docker run --gpus=all ${DOCKER_OPTS} ${DOCKER_IMAGE} \
80 | bash -c "${COMMAND}"
81 |
82 | docker-run-mpi:
83 | docker run --gpus=all ${DOCKER_OPTS} ${DOCKER_IMAGE} \
84 | bash -c "${MPI_CMD} ${COMMAND}"
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## DRO: Deep Recurrent Optimizer for Structure-from-Motion
2 |
3 | This is the official PyTorch implementation of DRO-sfm. For technical details, please refer to:
4 |
5 | **DRO: Deep Recurrent Optimizer for Structure-from-Motion**
6 | Xiaodong Gu\*, Weihao Yuan\*, Zuozhuo Dai, Chengzhou Tang, Siyu Zhu, Ping Tan
7 | **[[Paper](https://arxiv.org/abs/2103.13201)]**
8 |
9 |
10 | ![DRO-sfm demo on KITTI](media/figs/demo_kitti.gif)
11 |
12 | ![DRO-sfm demo on ScanNet](media/figs/demo_scannet.gif)
13 |
14 | ## Bibtex
15 | If you find this code useful in your research, please cite:
16 |
17 | ```
18 | @article{gu2021dro,
19 | title={DRO: Deep Recurrent Optimizer for Structure-from-Motion},
20 | author={Gu, Xiaodong and Yuan, Weihao and Dai, Zuozhuo and Tang, Chengzhou and Zhu, Siyu and Tan, Ping},
21 | journal={arXiv preprint arXiv:2103.13201},
22 | year={2021}
23 | }
24 | ```
25 |
26 | ## Contents
27 | 1. [Install](#install)
28 | 2. [Datasets](#datasets)
29 | 3. [Training](#training)
30 | 4. [Evaluation](#evaluation)
31 | 5. [Models](#models)
32 |
33 |
34 | ## Install
35 | + We recommend using [nvidia-docker2](https://github.com/NVIDIA/nvidia-docker) to have a reproducible environment.
36 |
37 | ```bash
38 | git clone https://github.com/aliyun/dro-sfm.git
39 | cd dro-sfm
40 | sudo make docker-build
41 | sudo make docker-start-interactive
42 | ```
43 | You can also download the pre-built docker image directly from [dro-sfm-image.tar](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/dro-sfm/dro-sfm-image.tar) and load it with:
44 |
45 | ```bash
46 | docker load < dro-sfm-image.tar
47 | ```
48 |
49 | + If you do not use docker, you can create an environment by following the steps in the Dockerfile.
50 |
51 | ```bash
52 | # Environment variables
53 | export PYTORCH_VERSION=1.4.0
54 | export TORCHVISION_VERSION=0.5.0
55 | export NCCL_VERSION=2.4.8-1+cuda10.1
56 | export HOROVOD_VERSION=65de4c961d1e5ad2828f2f6c4329072834f27661
57 | # Install NCCL
58 | sudo apt-get install libnccl2=${NCCL_VERSION} libnccl-dev=${NCCL_VERSION}
59 |
60 | # Install Open MPI
61 | mkdir /tmp/openmpi && \
62 | cd /tmp/openmpi && \
63 | wget https://www.open-mpi.org/software/ompi/v4.0/downloads/openmpi-4.0.0.tar.gz && \
64 | tar zxf openmpi-4.0.0.tar.gz && \
65 | cd openmpi-4.0.0 && \
66 | ./configure --enable-orterun-prefix-by-default && \
67 | make -j $(nproc) all && \
68 | make install && \
69 | ldconfig && \
70 | rm -rf /tmp/openmpi
71 |
72 | # Install PyTorch
73 | pip install torch==${PYTORCH_VERSION} torchvision==${TORCHVISION_VERSION} && ldconfig
74 |
75 | # Install horovod (for distributed training)
76 | sudo ldconfig /usr/local/cuda/targets/x86_64-linux/lib/stubs && HOROVOD_GPU_ALLREDUCE=NCCL HOROVOD_GPU_BROADCAST=NCCL HOROVOD_WITH_PYTORCH=1 pip install --no-cache-dir git+https://github.com/horovod/horovod.git@${HOROVOD_VERSION} && sudo ldconfig
77 | ```
78 |
79 | To verify that the environment is set up correctly, you can run a simple overfitting test:
80 |
81 | ```bash
82 | # download a tiny subset of KITTI
83 | cd dro-sfm
84 | curl -s https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/dro-sfm/datasets/KITTI_tiny.tar | tar xv -C /data/datasets/kitti/
85 | # in docker
86 | ./run.sh "python scripts/train.py configs/overfit_kitti_mf_gt.yaml" log.txt
87 | ```
88 |
89 | ## Datasets
90 | Datasets are assumed to be downloaded in `/data/datasets/` (can be a symbolic link).
91 |
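A quick way to confirm that the layout matches what the configs in `configs/` expect is a short check like the sketch below. The paths are taken from those config files; adjust them to your setup, and only the datasets you actually plan to use need to exist.

```python
# Sanity-check the dataset roots referenced by the configs in configs/.
import os

expected = [
    "/data/datasets/kitti/KITTI_tiny",   # overfit_kitti_mf_gt.yaml
    "/data/datasets/kitti/KITTI_raw",    # train_kitti_mf_*.yaml
    "/data/datasets/scannet/train",      # train_scannet_mf_*.yaml
    "/data/datasets/demon/train",        # train_demon_mf_gt.yaml / train_sun3d_mf_gt.yaml
]
for path in expected:
    print(f"{path}: {'ok' if os.path.isdir(path) else 'missing'}")
```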
92 | ### KITTI
93 | The KITTI (raw) dataset used in our experiments can be downloaded from the [KITTI website](http://www.cvlibs.net/datasets/kitti/raw_data.php).
94 | For convenience, you can download data from [packnet](https://tri-ml-public.s3.amazonaws.com/github/packnet-sfm/datasets/KITTI_raw.tar.gz) or [here](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/dro-sfm/datasets/KITTI_raw.tar)
95 |
96 | ### Tiny KITTI
97 | For simple tests, you can download a "tiny" version of [KITTI](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/dro-sfm/datasets/KITTI_tiny.tar).
98 |
99 |
100 | ### Scannet
101 | The Scannet (raw) dataset used in our experiments can be downloaded from the [Scannet website](http://www.scan-net.org).
102 | For convenience, you can download data from [here](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/dro-sfm/datasets/scannet.tar)
103 |
104 |
105 | ### DeMoN
106 | Download [DeMoN](https://github.com/lmb-freiburg/demon/tree/master/datasets).
107 |
108 | ```bash
109 | bash download_traindata.sh
110 | python ./dataset/preparation/preparedata_train.py
111 | bash download_testdata.sh
112 | python ./dataset/preparation/preparedata_test.py
113 | ```
114 |
115 | ## Training
116 | Any training, including fine-tuning, can be done by passing either a `.yaml` config file or a `.ckpt` model checkpoint to [scripts/train.py](./scripts/train.py):
117 |
118 | ```bash
119 | # kitti, checkpoints will be saved in ./results/model/
120 | ./run.sh 'python scripts/train.py configs/train_kitti_mf_gt.yaml' logs/kitti_sup.txt
121 | ./run.sh 'python scripts/train.py configs/train_kitti_mf_selfsup.yaml' logs/kitti_selfsup.txt
122 |
123 | # scannet
124 | ./run.sh 'python scripts/train.py configs/train_scannet_mf_gt_view3.yaml' logs/scannet_sup.txt
125 | ./run.sh 'python scripts/train.py configs/train_scannet_mf_selfsup_view3.yaml' logs/scannet_selfsup.txt
126 | ./run.sh 'python scripts/train.py configs/train_scannet_mf_gt_view5.yaml' logs/scannet_sup_view5.txt
127 |
128 | # demon
129 | ./run.sh 'python scripts/train.py configs/train_demon_mf_gt.yaml' logs/demon_sup.txt
130 | ```
131 |
132 |
133 | ## Evaluation
134 |
135 | ```bash
136 | python scripts/eval.py --checkpoint <checkpoint.ckpt> [--config <config.yaml>]
137 | # example: kitti, results will be saved in results/depth/
138 | python scripts/eval.py --checkpoint ckpt/outdoor_kitti.ckpt --config configs/train_kitti_mf_gt.yaml
139 |
140 | ```
141 |
142 | You can also directly run inference on a single image or video:
143 |
144 | ```bash
145 | # video or folder
146 | # indoor-scannet
147 | python scripts/infer_video.py --checkpoint ckpt/indoor_scannet.ckpt --input /path/to/video_or_folder --output /path/to/save_folder --sample_rate 1 --data_type scannet --ply_mode
148 | # indoor-general
149 | python scripts/infer_video.py --checkpoint ckpt/indoor_scannet.ckpt --input /path/to/video_or_folder --output /path/to/save_folder --sample_rate 1 --data_type general --ply_mode
150 |
151 | # outdoor
152 | python scripts/infer_video.py --checkpoint ckpt/outdoor_kitti.ckpt --input /path/to/video_or_folder --output /path/to/save_folder --sample_rate 1 --data_type kitti --ply_mode
153 |
154 | # image
155 | python scripts/infer.py --checkpoint <checkpoint.ckpt> --input <input_image_or_folder> --output <output_folder>
156 | ```
157 |
158 |
159 | ## Models
160 |
161 | | Model | Abs.Rel. | Sqr.Rel. | RMSE | RMSElog | a1 | a2 | a3 | SILog | L1_inv | rot_ang | t_ang | t_cm |
162 | | :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
163 | |[Kitti_sup](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/dro-sfm/models/outdoor_kitti.ckpt) | 0.045 | 0.193 | 2.570 | 0.080 | 0.971 | 0.994 | 0.998| 0.079 | 0.003 | -| -| -|
164 | |[Kitti_selfsup](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/dro-sfm/models/outdoor_kitti_selfsup.ckpt) | 0.053 |0.346 | 3.037 | 0.102 | 0.962 | 0.990| 0.996|0.101 | 0.004 | -| -| -|
165 | |[scannet_sup](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/dro-sfm/models/indoor_scannet.ckpt) | 0.053 | 0.017 | 0.165 | 0.080 | 0.967 | 0.994 | 0.998| 0.078 | 0.033| 0.472| 9.297| 1.160|
166 | |[scannet_sup(view5)](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/dro-sfm/models/indoor_scannet_view5.ckpt) |0.047 |0.014 | 0.151 | 0.072 | 0.976 | 0.996 | 0.999| 0.071 | 0.030 | 0.456| 8.502| 1.163|
167 | |[scannet_selfsup](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/dro-sfm/models/indoor_scannet_selfsup.ckpt) | 0.143 | 0.345 | 0.656 | 0.274 | 0.896 | 0.954 | 0.969|0.272 | 0.106 | 0.609| 10.779| 1.393 |
168 |
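The depth columns follow the usual monocular-depth evaluation protocol. As a reference, here is a sketch of the standard definitions behind these column names (not necessarily the exact code in `scripts/eval.py`):

```python
# Standard depth-error metrics corresponding to the column names above.
import numpy as np

def depth_metrics(gt, pred):
    """gt, pred: arrays of valid ground-truth / predicted depths (same shape, > 0)."""
    ratio = np.maximum(gt / pred, pred / gt)
    return {
        "abs_rel":  np.mean(np.abs(gt - pred) / gt),
        "sq_rel":   np.mean((gt - pred) ** 2 / gt),
        "rmse":     np.sqrt(np.mean((gt - pred) ** 2)),
        "rmse_log": np.sqrt(np.mean((np.log(gt) - np.log(pred)) ** 2)),
        "a1":       np.mean(ratio < 1.25),
        "a2":       np.mean(ratio < 1.25 ** 2),
        "a3":       np.mean(ratio < 1.25 ** 3),
    }
```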
169 |
170 |
171 | ## Acknowledgements
172 | Thanks to the Toyota Research Institute for open-sourcing their excellent work [packnet-sfm](https://github.com/TRI-ML/packnet-sfm), and to Zachary Teed for open-sourcing his excellent work [RAFT](https://github.com/princeton-vl/RAFT).
173 |
--------------------------------------------------------------------------------
/configs/overfit_kitti_mf_gt.yaml:
--------------------------------------------------------------------------------
1 | name: 'overfit_kitti_gt'
2 | arch:
3 | max_epochs: 50
4 | checkpoint:
5 | save_top_k: 10
6 | monitor: 'abs_rel_pp_gt'
7 | monitor_index: 0
8 | model:
9 | name: 'SupModelMF'
10 | optimizer:
11 | name: 'Adam'
12 | depth:
13 | lr: 0.0002
14 | pose:
15 | lr: 0.0002
16 | scheduler:
17 | name: 'StepLR'
18 | step_size: 30
19 | gamma: 0.5
20 | depth_net:
21 | name: 'DepthPoseNet'
22 | version: 'it12-h-out'
23 | loss:
24 | automask_loss: True
25 | photometric_reduce_op: 'min'
26 | params:
27 | crop: 'garg'
28 | min_depth: 0.2
29 | max_depth: 80.0
30 | datasets:
31 | augmentation:
32 | image_shape: (192, 640)
33 | train:
34 | batch_size: 4
35 | dataset: ['KITTI']
36 | path: ['/data/datasets/kitti/KITTI_tiny']
37 | split: ['kitti_tiny.txt']
38 | depth_type: ['velodyne']
39 | repeat: [100]
40 | forward_context: 1
41 | back_context: 1
42 | validation:
43 | dataset: ['KITTI']
44 | path: ['/data/datasets/kitti/KITTI_tiny']
45 | split: ['kitti_tiny.txt']
46 | depth_type: ['velodyne']
47 | forward_context: 1
48 | back_context: 1
49 | test:
50 | dataset: ['KITTI']
51 | path: ['/data/datasets/kitti/KITTI_tiny']
52 | split: ['kitti_tiny.txt']
53 | depth_type: ['velodyne']
54 | forward_context: 1
55 | back_context: 1
56 |
--------------------------------------------------------------------------------
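The YAML above follows the nested schema defined in `configs/default_config.py`. A minimal sketch for peeking at (or programmatically editing) such a config with plain PyYAML, independent of the repository's own config loader; the printed values correspond to the file above:

```python
# Inspect a training config with PyYAML (not the repo's own loader).
import yaml

with open("configs/overfit_kitti_mf_gt.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["model"]["name"])                  # 'SupModelMF'
print(cfg["model"]["depth_net"]["version"])  # 'it12-h-out'
print(cfg["datasets"]["train"]["path"])      # ['/data/datasets/kitti/KITTI_tiny']
```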
/configs/train_demon_mf_gt.yaml:
--------------------------------------------------------------------------------
1 | name: 'demon_sun3d_rgbd_gt'
2 | arch:
3 | max_epochs: 100
4 | checkpoint:
5 | save_top_k: 10
6 | monitor: 'abs_rel_pp_gt'
7 | monitor_index: 0
8 | model:
9 | # pretrain from xxx.ckpt
10 | # checkpoint_path: ./logs/models/xxc.ckpt
11 | name: 'SupModelMF'
12 | optimizer:
13 | name: 'Adam'
14 | depth:
15 | lr: 0.0002
16 | pose:
17 | lr: 0.0002
18 | scheduler:
19 | name: 'StepLR'
20 | step_size: 30
21 | gamma: 0.5
22 | depth_net:
23 | name: 'DepthPoseNet'
24 | version: 'it12-h-out'
25 | loss:
26 | automask_loss: True
27 | photometric_reduce_op: 'min'
28 | params:
29 | crop: ''
30 | min_depth: 0.2
31 | max_depth: 10.0
32 | datasets:
33 | augmentation:
34 | image_shape: (240, 320)
35 | train:
36 | batch_size: 12
37 | dataset: ['DemonMF']
38 | path: ['/data/datasets/demon/train']
39 | split: ['sun3d_train.txt', 'rgbd_train.txt']
40 | depth_type: ['groundtruth']
41 | repeat: [1, 5]
42 | forward_context: 1
43 | back_context: 0
44 | validation:
45 | dataset: ['Demon']
46 | path: ['/data/datasets/demon/test']
47 | split: ['sun3d_test.txt', 'rgbd_test.txt']
48 | depth_type: ['groundtruth']
49 | forward_context: 1
50 | back_context: 1
51 | test:
52 | dataset: ['Demon']
53 | path: ['/data/datasets/demon/test']
54 | split: ['sun3d_test.txt', 'rgbd_test.txt']
55 | depth_type: ['groundtruth']
56 | forward_context: 1
57 | back_context: 1
58 |
--------------------------------------------------------------------------------
/configs/train_kitti_mf_gt.yaml:
--------------------------------------------------------------------------------
1 | name: 'kitti_gt'
2 | save:
3 | folder: './results'
4 | arch:
5 | max_epochs: 50
6 | checkpoint:
7 | save_top_k: 10
8 | monitor: 'abs_rel_pp_gt'
9 | monitor_index: 0
10 | model:
11 | name: 'SupModelMF'
12 | optimizer:
13 | name: 'Adam'
14 | depth:
15 | lr: 0.0002
16 | pose:
17 | lr: 0.0002
18 | scheduler:
19 | name: 'StepLR'
20 | step_size: 30
21 | gamma: 0.5
22 | depth_net:
23 | name: 'DepthPoseNet'
24 | version: 'it12-h'
25 | loss:
26 | automask_loss: True
27 | photometric_reduce_op: 'min'
28 | params:
29 | crop: 'garg'
30 | min_depth: 0.2
31 | max_depth: 80.0
32 | datasets:
33 | augmentation:
34 | image_shape: (320, 960)
35 | train:
36 | batch_size: 2
37 | dataset: ['KITTI']
38 | path: ['/data/datasets/kitti/KITTI_raw']
39 | split: ['data_splits/eigen_zhou_files.txt']
40 | depth_type: ['groundtruth']
41 | repeat: [2]
42 | forward_context: 1
43 | back_context: 1
44 | validation:
45 | dataset: ['KITTI']
46 | path: ['/data/datasets/kitti/KITTI_raw']
47 | split: ['data_splits/eigen_val_files.txt',
48 | 'data_splits/eigen_test_files.txt']
49 | depth_type: ['groundtruth']
50 | forward_context: 1
51 | back_context: 0
52 | test:
53 | dataset: ['KITTI']
54 | path: ['/data/datasets/kitti/KITTI_raw']
55 | split: ['data_splits/eigen_test_files.txt']
56 | depth_type: ['groundtruth']
57 | forward_context: 1
58 | back_context: 0
--------------------------------------------------------------------------------
/configs/train_kitti_mf_selfsup.yaml:
--------------------------------------------------------------------------------
1 | name: 'kitt_selfsup_view3'
2 | arch:
3 | max_epochs: 50
4 | model:
5 | name: 'SelfSupModelMF'
6 | optimizer:
7 | name: 'Adam'
8 | depth:
9 | lr: 0.0002
10 | pose:
11 | lr: 0.0002
12 | scheduler:
13 | name: 'StepLR'
14 | step_size: 30
15 | gamma: 0.5
16 | depth_net:
17 | name: 'DepthPoseNet'
18 | version: 'it8-seq4-inter-out'
19 | loss:
20 | automask_loss: True
21 | photometric_reduce_op: 'min'
22 | params:
23 | crop: 'garg'
24 | min_depth: 0.5
25 | max_depth: 80.0
26 | datasets:
27 | augmentation:
28 | image_shape: (320, 960)
29 | train:
30 | batch_size: 2
31 | dataset: ['KITTI']
32 | path: ['/data/datasets/kitti/KITTI_raw']
33 | split: ['data_splits/eigen_zhou_files.txt']
34 | depth_type: ['groundtruth']
35 | repeat: [2]
36 | forward_context: 1
37 | back_context: 1
38 | validation:
39 | dataset: ['KITTI']
40 | path: ['/data/datasets/kitti/KITTI_raw']
41 | split: ['data_splits/eigen_val_files.txt',
42 | 'data_splits/eigen_test_files.txt']
43 | depth_type: ['groundtruth']
44 | forward_context: 1
45 | back_context: 0
46 | test:
47 | dataset: ['KITTI']
48 | path: ['/data/datasets/kitti/KITTI_raw']
49 | split: ['data_splits/eigen_test_files.txt']
50 | depth_type: ['groundtruth']
51 | forward_context: 1
52 | back_context: 0
53 |
--------------------------------------------------------------------------------
/configs/train_nyu_mf_gt.yaml:
--------------------------------------------------------------------------------
1 | name: 'nyu_semi'
2 | arch:
3 | max_epochs: 100
4 | checkpoint:
5 | save_top_k: 10
6 | monitor: 'abs_rel_pp_gt'
7 | monitor_index: 0
8 | model:
9 | name: 'SemiSupModelMF'
10 | optimizer:
11 | name: 'Adam'
12 | depth:
13 | lr: 0.0002
14 | pose:
15 | lr: 0.0002
16 | scheduler:
17 | name: 'StepLR'
18 | step_size: 30
19 | gamma: 0.5
20 | depth_net:
21 | name: 'DepthPoseNet'
22 | version: 'it12-h-out'
23 | loss:
24 | automask_loss: True
25 | photometric_reduce_op: 'min'
26 | params:
27 | crop: 'eigen_nyu'
28 | min_depth: 0.2
29 | max_depth: 10.0
30 | datasets:
31 | augmentation:
32 | image_shape: (240, 320)
33 | train:
34 | batch_size: 8
35 | dataset: ['NYU']
36 | path: ['/data/datasets/nyu/train_processed']
37 | split: ['']
38 | depth_type: ['groundtruth']
39 | repeat: [1]
40 | forward_context: 1
41 | back_context: 1
42 | validation:
43 | dataset: ['NYUtest']
44 | path: ['/data/datasets/nyu/test_processed']
45 | split: ['test_depth.txt']
46 | depth_type: ['groundtruth']
47 | forward_context: 1
48 | back_context: 1
49 | test:
50 | dataset: ['NYUtest']
51 | path: ['/data/datasets/nyu/test_processed']
52 | split: ['test_depth.txt']
53 | depth_type: ['groundtruth']
54 | forward_context: 1
55 | back_context: 1
56 |
--------------------------------------------------------------------------------
/configs/train_rgbd_mf_gt.yaml:
--------------------------------------------------------------------------------
1 | name: 'rgbd_gt'
2 | arch:
3 | max_epochs: 100
4 | checkpoint:
5 | save_top_k: 10
6 | monitor: 'abs_rel_pp_gt'
7 | monitor_index: 0
8 | model:
9 | name: 'SupModelMF'
10 | optimizer:
11 | name: 'Adam'
12 | depth:
13 | lr: 0.0002
14 | pose:
15 | lr: 0.0002
16 | scheduler:
17 | name: 'StepLR'
18 | step_size: 30
19 | gamma: 0.5
20 | depth_net:
21 | name: 'DepthPoseNet'
22 | version: 'it12-h-out'
23 | loss:
24 | automask_loss: True
25 | photometric_reduce_op: 'min'
26 | params:
27 | crop: ''
28 | min_depth: 0.2
29 | max_depth: 10.0
30 | datasets:
31 | augmentation:
32 | image_shape: (240, 320)
33 | train:
34 | batch_size: 12
35 | dataset: ['Demon']
36 | path: ['/data/datasets/demon/train']
37 | split: ['rgbd_train.txt']
38 | depth_type: ['groundtruth']
39 | repeat: [1]
40 | forward_context: 1
41 | back_context: 1
42 | validation:
43 | dataset: ['Demon']
44 | path: ['/data/datasets/demon/test']
45 | split: ['rgbd_test.txt']
46 | depth_type: ['groundtruth']
47 | forward_context: 1
48 | back_context: 1
49 | test:
50 | dataset: ['Demon']
51 | path: ['/data/datasets/demon/test']
52 | split: ['rgbd_test.txt']
53 | depth_type: ['groundtruth']
54 | forward_context: 1
55 | back_context: 1
56 |
--------------------------------------------------------------------------------
/configs/train_scannet_mf_gt_view2.yaml:
--------------------------------------------------------------------------------
1 | name: 'scannet_gt_view2'
2 | save:
3 | folder: './results'
4 | arch:
5 | max_epochs: 100
6 | checkpoint:
7 | save_top_k: 10
8 | monitor: 'abs_rel_pp_gt'
9 | monitor_index: 0
10 | model:
11 | name: 'SupModelMF'
12 | optimizer:
13 | name: 'Adam'
14 | depth:
15 | lr: 0.0002
16 | pose:
17 | lr: 0.0002
18 | scheduler:
19 | name: 'StepLR'
20 | step_size: 30
21 | gamma: 0.5
22 | depth_net:
23 | name: 'DepthPoseNet'
24 | version: 'it12-h-out'
25 | loss:
26 | automask_loss: True
27 | photometric_reduce_op: 'min'
28 | params:
29 | crop: ''
30 | min_depth: 0.2
31 | max_depth: 10.0
32 | datasets:
33 | augmentation:
34 | image_shape: (240, 320)
35 | train:
36 | batch_size: 12
37 | dataset: ['ScannetBA']
38 | path: ['/data/datasets/scannet/train']
39 | split: ['splits/train_all_list.txt']
40 | depth_type: ['groundtruth']
41 | repeat: [1]
42 | forward_context: 1
43 | back_context: 0
44 | num_workers: 32
45 | validation:
46 | dataset: ['ScannetTest']
47 | path: ['/data/datasets/scannet/test']
48 | split: ['splits/test_split.txt']
49 | depth_type: ['groundtruth']
50 | forward_context: 1
51 | back_context: 0
52 | test:
53 | dataset: ['ScannetTest']
54 | path: ['/data/datasets/scannet/test']
55 | split: ['splits/test_split.txt']
56 | depth_type: ['groundtruth']
57 | forward_context: 1
58 | back_context: 0
59 |
--------------------------------------------------------------------------------
/configs/train_scannet_mf_gt_view3.yaml:
--------------------------------------------------------------------------------
1 | name: 'scannet_gt_view3'
2 | save:
3 | folder: './results'
4 | arch:
5 | max_epochs: 100
6 | checkpoint:
7 | save_top_k: 10
8 | monitor: 'abs_rel_pp_gt'
9 | monitor_index: 0
10 | model:
11 | name: 'SupModelMF'
12 | optimizer:
13 | name: 'Adam'
14 | depth:
15 | lr: 0.0002
16 | pose:
17 | lr: 0.0002
18 | scheduler:
19 | name: 'StepLR'
20 | step_size: 30
21 | gamma: 0.5
22 | depth_net:
23 | name: 'DepthPoseNet'
24 | version: 'it12-h-out'
25 | loss:
26 | automask_loss: True
27 | photometric_reduce_op: 'min'
28 | params:
29 | crop: ''
30 | min_depth: 0.2
31 | max_depth: 10.0
32 | datasets:
33 | augmentation:
34 | image_shape: (240, 320)
35 | train:
36 | batch_size: 8
37 | dataset: ['ScannetBA']
38 | path: ['/data/datasets/scannet/train']
39 | split: ['splits/train_all_list.txt']
40 | depth_type: ['groundtruth']
41 | repeat: [1]
42 | forward_context: 1
43 | back_context: 1
44 | validation:
45 | dataset: ['ScannetTest']
46 | path: ['/data/datasets/scannet/test']
47 | split: ['splits/test_split.txt']
48 | depth_type: ['groundtruth']
49 | forward_context: 1
50 | back_context: 0
51 | test:
52 | dataset: ['ScannetTest']
53 | path: ['/data/datasets/scannet/test']
54 | split: ['splits/test_split.txt']
55 | depth_type: ['groundtruth']
56 | forward_context: 1
57 | back_context: 0
58 |
--------------------------------------------------------------------------------
/configs/train_scannet_mf_gt_view5.yaml:
--------------------------------------------------------------------------------
1 | name: 'scannet_gt_view5'
2 | save:
3 | folder: './results'
4 | arch:
5 | max_epochs: 100
6 | checkpoint:
7 | save_top_k: 10
8 | monitor: 'abs_rel_pp_gt'
9 | monitor_index: 0
10 | model:
11 | name: 'SupModelMF'
12 | optimizer:
13 | name: 'Adam'
14 | depth:
15 | lr: 0.0002
16 | pose:
17 | lr: 0.0002
18 | scheduler:
19 | name: 'StepLR'
20 | step_size: 30
21 | gamma: 0.5
22 | depth_net:
23 | name: 'DepthPoseNet'
24 | version: 'it12-h-out'
25 | loss:
26 | automask_loss: True
27 | photometric_reduce_op: 'min'
28 | params:
29 | crop: ''
30 | min_depth: 0.2
31 | max_depth: 10.0
32 | datasets:
33 | augmentation:
34 | image_shape: (384, 512)
35 | train:
36 | batch_size: 2
37 | dataset: ['ScannetBA']
38 | path: ['/data/datasets/scannet/train']
39 | split: ['splits/train_all_list.txt']
40 | depth_type: ['groundtruth']
41 | repeat: [1]
42 | forward_context: 2
43 | back_context: 2
44 | num_workers: 16
45 | validation:
46 | dataset: ['ScannetTest']
47 | path: ['/data/datasets/scannet/test']
48 | split: ['splits/test_split.txt']
49 | depth_type: ['groundtruth']
50 | forward_context: 1
51 | back_context: 0
52 | test:
53 | dataset: ['ScannetTest']
54 | path: ['/data/datasets/scannet/test']
55 | split: ['splits/test_split.txt']
56 | depth_type: ['groundtruth']
57 | forward_context: 2
58 | back_context: 2
59 |
--------------------------------------------------------------------------------
/configs/train_scannet_mf_selfsup_view3.yaml:
--------------------------------------------------------------------------------
1 | name: 'scannet_selfsup_view3'
2 | arch:
3 | max_epochs: 100
4 | checkpoint:
5 | save_top_k: 10
6 | monitor: 'abs_rel_pp_gt'
7 | monitor_index: 0
8 | model:
9 | name: 'SelfSupModelMF'
10 | optimizer:
11 | name: 'Adam'
12 | depth:
13 | lr: 0.0002
14 | pose:
15 | lr: 0.0002
16 | scheduler:
17 | name: 'StepLR'
18 | step_size: 30
19 | gamma: 0.5
20 | depth_net:
21 | name: 'DepthPoseNet'
22 | version: 'it12-h-out'
23 | loss:
24 | automask_loss: True
25 | photometric_reduce_op: 'min'
26 | params:
27 | crop: ''
28 | min_depth: 0.2
29 | max_depth: 10.0
30 | datasets:
31 | augmentation:
32 | image_shape: (320, 512)
33 | train:
34 | batch_size: 3
35 | dataset: ['ScannetBA']
36 | path: ['/data/datasets/scannet/train']
37 | split: ['splits/train_all_list.txt']
38 | depth_type: ['groundtruth']
39 | repeat: [1]
40 | forward_context: 1
41 | back_context: 1
42 | validation:
43 | dataset: ['ScannetTest']
44 | path: ['/data/datasets/scannet/test']
45 | split: ['splits/test_split.txt']
46 | depth_type: ['groundtruth']
47 | forward_context: 1
48 | back_context: 1
49 | test:
50 | dataset: ['ScannetTest']
51 | path: ['/data/datasets/scannet/test']
52 | split: ['splits/test_split.txt']
53 | depth_type: ['groundtruth']
54 | forward_context: 1
55 | back_context: 1
56 |
--------------------------------------------------------------------------------
/configs/train_scannet_mf_selfsup_view5.yaml:
--------------------------------------------------------------------------------
1 | name: 'scannet_selfsup_view5'
2 | arch:
3 | max_epochs: 100
4 | checkpoint:
5 | save_top_k: 10
6 | monitor: 'abs_rel_pp_gt'
7 | monitor_index: 0
8 | model:
9 | name: 'SelfSupModelMF'
10 | optimizer:
11 | name: 'Adam'
12 | depth:
13 | lr: 0.0002
14 | pose:
15 | lr: 0.0002
16 | scheduler:
17 | name: 'StepLR'
18 | step_size: 30
19 | gamma: 0.5
20 | depth_net:
21 | name: 'DepthPoseNet'
22 | version: 'it12-h-out'
23 | loss:
24 | automask_loss: True
25 | photometric_reduce_op: 'min'
26 | params:
27 | crop: ''
28 | min_depth: 0.2
29 | max_depth: 10.0
30 | datasets:
31 | augmentation:
32 | image_shape: (240, 320)
33 | train:
34 | batch_size: 4
35 | dataset: ['ScannetBA']
36 | path: ['/data/datasets/scannet/train']
37 | split: ['splits/train_all_list.txt']
38 | depth_type: ['groundtruth']
39 | repeat: [1]
40 | forward_context: 2
41 | back_context: 2
42 | num_workers: 16
43 | validation:
44 | dataset: ['ScannetTest']
45 | path: ['/data/datasets/scannet/test']
46 | split: ['splits/test_split.txt']
47 | depth_type: ['groundtruth']
48 | forward_context: 1
49 | back_context: 0
50 | test:
51 | dataset: ['ScannetTest']
52 | path: ['/data/datasets/scannet/test']
53 | split: ['splits/test_split.txt']
54 | depth_type: ['groundtruth']
55 | forward_context: 1
56 | back_context: 0
57 |
--------------------------------------------------------------------------------
/configs/train_scene11_mf_gt.yaml:
--------------------------------------------------------------------------------
1 | name: 'scene11_gt'
2 | arch:
3 | max_epochs: 100
4 | checkpoint:
5 | save_top_k: 10
6 | monitor: 'abs_rel_pp_gt'
7 | monitor_index: 0
8 | model:
9 | name: 'SupModelMF'
10 | optimizer:
11 | name: 'Adam'
12 | depth:
13 | lr: 0.0002
14 | pose:
15 | lr: 0.0002
16 | scheduler:
17 | name: 'StepLR'
18 | step_size: 30
19 | gamma: 0.5
20 | depth_net:
21 | name: 'DepthPoseNet'
22 | version: 'it12-h-out'
23 | loss:
24 | automask_loss: True
25 | photometric_reduce_op: 'min'
26 | params:
27 | crop: ''
28 | min_depth: 0.2
29 | max_depth: 10.0
30 | datasets:
31 | augmentation:
32 | image_shape: (240, 320)
33 | train:
34 | batch_size: 12
35 | dataset: ['Demon']
36 | path: ['/data/datasets/demon/train']
37 | split: ['scene11_train.txt']
38 | depth_type: ['groundtruth']
39 | repeat: [1]
40 | forward_context: 1
41 | back_context: 1
42 | validation:
43 | dataset: ['Demon']
44 | path: ['/data/datasets/demon/test']
45 | split: ['scene11_test.txt']
46 | depth_type: ['groundtruth']
47 | forward_context: 1
48 | back_context: 1
49 | test:
50 | dataset: ['Demon']
51 | path: ['/data/datasets/demon/test']
52 | split: ['scene11_test.txt']
53 | depth_type: ['groundtruth']
54 | forward_context: 1
55 | back_context: 1
56 |
--------------------------------------------------------------------------------
/configs/train_sun3d_mf_gt.yaml:
--------------------------------------------------------------------------------
1 | name: 'sun3d_gt'
2 | arch:
3 | max_epochs: 100
4 | checkpoint:
5 | save_top_k: 10
6 | monitor: 'abs_rel_pp_gt'
7 | monitor_index: 0
8 | model:
9 | name: 'SupModelMF'
10 | optimizer:
11 | name: 'Adam'
12 | depth:
13 | lr: 0.0002
14 | pose:
15 | lr: 0.0002
16 | scheduler:
17 | name: 'StepLR'
18 | step_size: 30
19 | gamma: 0.5
20 | depth_net:
21 | name: 'DepthPoseNet'
22 | version: 'it12-h-out'
23 | loss:
24 | automask_loss: True
25 | photometric_reduce_op: 'min'
26 | params:
27 | crop: ''
28 | min_depth: 0.2
29 | max_depth: 10.0
30 | datasets:
31 | augmentation:
32 | image_shape: (240, 320)
33 | train:
34 | batch_size: 12
35 | dataset: ['DemonMF']
36 | path: ['/data/datasets/demon/train']
37 | split: ['sun3d_train.txt']
38 | depth_type: ['groundtruth']
39 | repeat: [1]
40 | forward_context: 1
41 | back_context: 0
42 | validation:
43 | dataset: ['Demon']
44 | path: ['/data/datasets/demon/test']
45 | split: ['sun3d_test.txt']
46 | depth_type: ['groundtruth']
47 | forward_context: 1
48 | back_context: 1
49 | test:
50 | dataset: ['Demon']
51 | path: ['/data/datasets/demon/test']
52 | split: ['sun3d_test.txt']
53 | depth_type: ['groundtruth']
54 | forward_context: 1
55 | back_context: 1
56 |
--------------------------------------------------------------------------------
/configs/train_video_mf_selfsup_out_random.yaml:
--------------------------------------------------------------------------------
1 | name: 'video_selfsup_out_random'
2 | arch:
3 | max_epochs: 10
4 | checkpoint:
5 | save_top_k: -1
6 | monitor: 'abs_rel_pp_gt'
7 | monitor_index: 0
8 | model:
9 | checkpoint_path: "results/indoor_scannet.ckpt" #pretrain model
10 | name: 'SelfSupModelMF'
11 | optimizer:
12 | name: 'Adam'
13 | depth:
14 | lr: 0.0002
15 | pose:
16 | lr: 0.0002
17 | scheduler:
18 | name: 'StepLR'
19 | step_size: 30
20 | gamma: 0.5
21 | depth_net:
22 | name: 'DepthPoseNet'
23 | version: 'it12-h-out'
24 | loss:
25 | automask_loss: True
26 | photometric_reduce_op: 'min'
27 | params:
28 | crop: ''
29 | min_depth: 0.2
30 | max_depth: 10.0
31 | datasets:
32 | augmentation:
33 | image_shape: (240, 320)
34 | train:
35 | batch_size: 6
36 | dataset: ['Video_Random']
37 | path: ['/data/datasets/video/indoor']
38 | split: ['']
39 | depth_type: ['']
40 | repeat: [30]
41 | forward_context: 1
42 | back_context: 1
43 | strides: (2, )
44 | validation:
45 | dataset: ['Video']
46 | path: ['/data/datasets/video/indoor']
47 | split: ['']
48 | depth_type: ['']
49 | forward_context: 1
50 | back_context: 1
51 | strides: (1, )
52 | test:
53 | dataset: ['Video']
54 | path: ['/data/datasets/video/indoor']
55 | split: ['']
56 | depth_type: ['']
57 | forward_context: 1
58 | back_context: 1
59 | strides: (2, )
60 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 |
2 | FROM nvidia/cuda:10.1-devel-ubuntu18.04
3 |
4 | ENV PROJECT=dro-sfm
5 | ENV PYTORCH_VERSION=1.4.0
6 | ENV TORCHVISION_VERSION=0.5.0
7 | ENV CUDNN_VERSION=7.6.5.32-1+cuda10.1
8 | ENV NCCL_VERSION=2.4.8-1+cuda10.1
9 | ENV HOROVOD_VERSION=65de4c961d1e5ad2828f2f6c4329072834f27661
10 | ENV LC_ALL=C.UTF-8
11 | ENV LANG=C.UTF-8
12 |
13 | ARG python=3.6
14 | ENV PYTHON_VERSION=${python}
15 | ENV DEBIAN_FRONTEND=noninteractive
16 |
17 | # Set default shell to /bin/bash
18 | SHELL ["/bin/bash", "-cu"]
19 |
20 | RUN apt-get clean && cd /var/lib/apt && mv lists lists.old && mkdir -p lists/partial && apt-get clean && apt-get update
21 |
22 | RUN apt-get update && apt-get install -y --allow-downgrades --allow-change-held-packages --no-install-recommends \
23 | build-essential \
24 | cmake \
25 | g++-4.8 \
26 | git \
27 | curl \
28 | docker.io \
29 | vim \
30 | wget \
31 | ca-certificates \
32 | libcudnn7=${CUDNN_VERSION} \
33 | libnccl2=${NCCL_VERSION} \
34 | libnccl-dev=${NCCL_VERSION} \
35 | libjpeg-dev \
36 | libpng-dev \
37 | python${PYTHON_VERSION} \
38 | python${PYTHON_VERSION}-dev \
39 | python3-tk \
40 | librdmacm1 \
41 | libibverbs1 \
42 | ibverbs-providers \
43 | libgtk2.0-dev \
44 | unzip \
45 | bzip2 \
46 | htop \
47 | gnuplot \
48 | ffmpeg
49 |
50 | # Install Open MPI
51 | RUN mkdir /tmp/openmpi && \
52 | cd /tmp/openmpi && \
53 | wget https://www.open-mpi.org/software/ompi/v4.0/downloads/openmpi-4.0.0.tar.gz && \
54 | tar zxf openmpi-4.0.0.tar.gz && \
55 | cd openmpi-4.0.0 && \
56 | ./configure --enable-orterun-prefix-by-default && \
57 | make -j $(nproc) all && \
58 | make install && \
59 | ldconfig && \
60 | rm -rf /tmp/openmpi
61 |
62 | # Install OpenSSH for MPI to communicate between containers
63 | RUN apt-get install -y --no-install-recommends openssh-client openssh-server && \
64 | mkdir -p /var/run/sshd
65 |
66 | # Allow OpenSSH to talk to containers without asking for confirmation
67 | RUN cat /etc/ssh/ssh_config | grep -v StrictHostKeyChecking > /etc/ssh/ssh_config.new && \
68 | echo " StrictHostKeyChecking no" >> /etc/ssh/ssh_config.new && \
69 | mv /etc/ssh/ssh_config.new /etc/ssh/ssh_config
70 |
71 | # Install Python and pip
72 | RUN if [[ "${PYTHON_VERSION}" == "3.6" ]]; then \
73 | apt-get install -y python${PYTHON_VERSION}-distutils; \
74 | fi
75 |
76 | RUN ln -sf /usr/bin/python${PYTHON_VERSION} /usr/bin/python
77 |
78 | RUN curl -O https://bootstrap.pypa.io/get-pip.py && \
79 | python get-pip.py && \
80 | rm get-pip.py
81 |
82 | # Install Pydata and other deps
83 | RUN pip install future typing numpy pandas matplotlib jupyter h5py \
84 | awscli boto3 tqdm termcolor path.py pillow-simd opencv-python-headless \
85 | mpi4py onnx onnxruntime pycuda yacs cython==0.29.10
86 |
87 | # Install PyTorch
88 | RUN pip install torch==${PYTORCH_VERSION} \
89 | torchvision==${TORCHVISION_VERSION} && ldconfig
90 |
91 |
92 | # install horovod (for distributed training)
93 | RUN ldconfig /usr/local/cuda/targets/x86_64-linux/lib/stubs && \
94 | HOROVOD_GPU_ALLREDUCE=NCCL HOROVOD_GPU_BROADCAST=NCCL HOROVOD_WITH_PYTORCH=1 \
95 | pip install --no-cache-dir git+https://github.com/horovod/horovod.git@${HOROVOD_VERSION} && \
96 | ldconfig
97 |
98 |
99 | # Override DGP wandb with required version
100 | RUN pip install wandb==0.8.21
101 |
102 | # Expose Port for jupyter (8888)
103 | EXPOSE 8888
104 |
105 | # create project workspace dir
106 | RUN mkdir -p /workspace/experiments
107 | RUN mkdir -p /workspace/${PROJECT}
108 | WORKDIR /workspace/${PROJECT}
109 |
110 | # Copy project source last (to avoid cache busting)
111 | WORKDIR /workspace/${PROJECT}
112 | COPY . /workspace/${PROJECT}
113 | ENV PYTHONPATH="/workspace/${PROJECT}:$PYTHONPATH"
--------------------------------------------------------------------------------
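A small sanity check for the image built from the Dockerfile above: run inside the container, the pinned PyTorch/torchvision stack and the Horovod build should import and initialize (a sketch, assuming a single-process run):

```python
# Verify the container's deep-learning stack (run inside the built image).
import torch
import torchvision
import horovod.torch as hvd

print("torch:", torch.__version__, "cuda available:", torch.cuda.is_available())
print("torchvision:", torchvision.__version__)

hvd.init()                       # single-process init is enough for a smoke test
print("horovod world size:", hvd.size())
```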
/dro_sfm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aliyun/dro-sfm/8707e2e0ef799d7d47418a018060f503ef449fe3/dro_sfm/__init__.py
--------------------------------------------------------------------------------
/dro_sfm/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/dro_sfm/datasets/augmentations.py:
--------------------------------------------------------------------------------
1 |
2 | import cv2
3 | import numpy as np
4 | import random
5 | import torchvision.transforms as transforms
6 | from PIL import Image
7 |
8 | from dro_sfm.utils.misc import filter_dict
9 |
10 | ########################################################################################################################
11 |
12 | def resize_image(image, shape, interpolation=Image.ANTIALIAS):
13 | """
14 | Resizes input image.
15 |
16 | Parameters
17 | ----------
18 | image : Image.PIL
19 | Input image
20 | shape : tuple [H,W]
21 | Output shape
22 | interpolation : int
23 | Interpolation mode
24 |
25 | Returns
26 | -------
27 | image : Image.PIL
28 | Resized image
29 | """
30 | transform = transforms.Resize(shape, interpolation=interpolation)
31 | return transform(image)
32 |
33 | def resize_depth(depth, shape):
34 | """
35 | Resizes depth map.
36 |
37 | Parameters
38 | ----------
39 | depth : np.array [h,w]
40 | Depth map
41 | shape : tuple (H,W)
42 | Output shape
43 |
44 | Returns
45 | -------
46 | depth : np.array [H,W]
47 | Resized depth map
48 | """
49 | depth = cv2.resize(depth, dsize=shape[::-1],
50 | interpolation=cv2.INTER_NEAREST)
51 | return np.expand_dims(depth, axis=2)
52 |
53 | def resize_sample_image_and_intrinsics(sample, shape,
54 | image_interpolation=Image.ANTIALIAS):
55 | """
56 | Resizes the image and intrinsics of a sample
57 |
58 | Parameters
59 | ----------
60 | sample : dict
61 | Dictionary with sample values
62 | shape : tuple (H,W)
63 | Output shape
64 | image_interpolation : int
65 | Interpolation mode
66 |
67 | Returns
68 | -------
69 | sample : dict
70 | Resized sample
71 | """
72 | # Resize image and corresponding intrinsics
73 | image_transform = transforms.Resize(shape, interpolation=image_interpolation)
74 | (orig_w, orig_h) = sample['rgb'].size
75 | (out_h, out_w) = shape
76 | # Scale intrinsics
77 | for key in filter_dict(sample, [
78 | 'intrinsics'
79 | ]):
80 | intrinsics = np.copy(sample[key])
81 | intrinsics[0] *= out_w / orig_w
82 | intrinsics[1] *= out_h / orig_h
83 | sample[key] = intrinsics
84 | # Scale images
85 | for key in filter_dict(sample, [
86 | 'rgb', 'rgb_original',
87 | ]):
88 | sample[key] = image_transform(sample[key])
89 | # Scale context images
90 | for key in filter_dict(sample, [
91 | 'rgb_context', 'rgb_context_original',
92 | ]):
93 | sample[key] = [image_transform(k) for k in sample[key]]
94 | # Return resized sample
95 | return sample
96 |
97 | def resize_sample(sample, shape, image_interpolation=Image.ANTIALIAS):
98 | """
99 | Resizes a sample, including image, intrinsics and depth maps.
100 |
101 | Parameters
102 | ----------
103 | sample : dict
104 | Dictionary with sample values
105 | shape : tuple (H,W)
106 | Output shape
107 | image_interpolation : int
108 | Interpolation mode
109 |
110 | Returns
111 | -------
112 | sample : dict
113 | Resized sample
114 | """
115 | # Resize image and intrinsics
116 | sample = resize_sample_image_and_intrinsics(sample, shape, image_interpolation)
117 | # Resize depth maps
118 | for key in filter_dict(sample, [
119 | 'depth',
120 | ]):
121 | sample[key] = resize_depth(sample[key], shape)
122 | # Resize depth contexts
123 | for key in filter_dict(sample, [
124 | 'depth_context',
125 | ]):
126 | sample[key] = [resize_depth(k, shape) for k in sample[key]]
127 | # Return resized sample
128 | return sample
129 |
130 | ########################################################################################################################
131 |
132 | def to_tensor(image, tensor_type='torch.FloatTensor'):
133 | """Casts an image to a torch.Tensor"""
134 | transform = transforms.ToTensor()
135 | return transform(image).type(tensor_type)
136 |
137 | def to_tensor_sample(sample, tensor_type='torch.FloatTensor'):
138 | """
139 | Casts the keys of sample to tensors.
140 |
141 | Parameters
142 | ----------
143 | sample : dict
144 | Input sample
145 | tensor_type : str
146 | Type of tensor we are casting to
147 |
148 | Returns
149 | -------
150 | sample : dict
151 | Sample with keys cast as tensors
152 | """
153 | transform = transforms.ToTensor()
154 | # Convert single items
155 | for key in filter_dict(sample, [
156 | 'rgb', 'rgb_original', 'depth',
157 | ]):
158 | sample[key] = transform(sample[key]).type(tensor_type)
159 | # Convert lists
160 | for key in filter_dict(sample, [
161 | 'rgb_context', 'rgb_context_original', 'depth_context'
162 | ]):
163 | sample[key] = [transform(k).type(tensor_type) for k in sample[key]]
164 | # Return converted sample
165 | return sample
166 |
167 | ########################################################################################################################
168 |
169 | def duplicate_sample(sample):
170 | """
171 | Duplicates sample images and contexts to preserve their unaugmented versions.
172 |
173 | Parameters
174 | ----------
175 | sample : dict
176 | Input sample
177 |
178 | Returns
179 | -------
180 | sample : dict
181 | Sample including [+"_original"] keys with copies of images and contexts.
182 | """
183 | # Duplicate single items
184 | for key in filter_dict(sample, [
185 | 'rgb'
186 | ]):
187 | sample['{}_original'.format(key)] = sample[key].copy()
188 | # Duplicate lists
189 | for key in filter_dict(sample, [
190 | 'rgb_context'
191 | ]):
192 | sample['{}_original'.format(key)] = [k.copy() for k in sample[key]]
193 | # Return duplicated sample
194 | return sample
195 |
196 | def colorjitter_sample(sample, parameters, prob=1.0):
197 | """
198 | Jitters input images as data augmentation.
199 |
200 | Parameters
201 | ----------
202 | sample : dict
203 | Input sample
204 | parameters : tuple (brightness, contrast, saturation, hue)
205 | Color jittering parameters
206 | prob : float
207 | Jittering probability
208 |
209 | Returns
210 | -------
211 | sample : dict
212 | Jittered sample
213 | """
214 | if random.random() < prob:
215 | # Prepare transformation
216 | color_augmentation = transforms.ColorJitter()
217 | brightness, contrast, saturation, hue = parameters
218 | augment_image = color_augmentation.get_params(
219 | brightness=[max(0, 1 - brightness), 1 + brightness],
220 | contrast=[max(0, 1 - contrast), 1 + contrast],
221 | saturation=[max(0, 1 - saturation), 1 + saturation],
222 | hue=[-hue, hue])
223 | # Jitter single items
224 | for key in filter_dict(sample, [
225 | 'rgb'
226 | ]):
227 | sample[key] = augment_image(sample[key])
228 | # Jitter lists
229 | for key in filter_dict(sample, [
230 | 'rgb_context'
231 | ]):
232 | sample[key] = [augment_image(k) for k in sample[key]]
233 | # Return jittered (?) sample
234 | return sample
235 |
236 | ########################################################################################################################
237 |
238 |
239 |
--------------------------------------------------------------------------------
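A worked example of the intrinsics scaling performed in `resize_sample_image_and_intrinsics` above: the first row of `K` (fx, cx) scales with the width ratio and the second row (fy, cy) with the height ratio. The numbers below are illustrative, not taken from any dataset file:

```python
import numpy as np

K = np.array([[721.5,   0.0, 609.6],
              [  0.0, 721.5, 172.9],
              [  0.0,   0.0,   1.0]])   # illustrative pinhole intrinsics
orig_w, orig_h = 1242, 375               # original image size (W, H)
out_h, out_w = 192, 640                  # target image_shape, as in overfit_kitti_mf_gt.yaml

K_resized = np.copy(K)
K_resized[0] *= out_w / orig_w           # rescale fx and cx
K_resized[1] *= out_h / orig_h           # rescale fy and cy
print(K_resized)
```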
/dro_sfm/datasets/demon_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import re
3 | from collections import defaultdict
4 | import os
5 |
6 | from torch.utils.data import Dataset
7 | import numpy as np
8 | from dro_sfm.utils.image import load_image
9 | import cv2, IPython
10 | from PIL import Image
11 | from PIL import ImageFile
12 | ImageFile.LOAD_TRUNCATED_IMAGES = True
13 | import h5py
14 | ########################################################################################################################
15 | #### FUNCTIONS
16 | ########################################################################################################################
17 |
18 | def dummy_calibration(image):
19 | return np.array([[5.703422047415297129e+02 , 0. , 3.200000000000000000e+02],
20 | [0. , 5.703422047415297129e+02 , 2.400000000000000000e+02],
21 | [0. , 0. , 1. ]])
22 |
23 |
24 | ########################################################################################################################
25 | #### DATASET
26 | ########################################################################################################################
27 |
28 | class DemonDataset(Dataset):
29 | def __init__(self, root_dir, split, data_transform=None,
30 | forward_context=0, back_context=0, strides=(1,),
31 | depth_type=None, **kwargs):
32 | super().__init__()
33 | # Asserts
34 | # assert depth_type is None or depth_type == '', \
35 | # 'NYUDataset currently does not support depth types'
36 | assert len(strides) == 1 and strides[0] == 1, \
37 | 'DemonDataset currently only supports a stride of 1.'
38 |
39 | self.depth_type = depth_type
40 | self.with_depth = depth_type is not None and depth_type != ''
41 | self.root_dir = root_dir
42 | self.split = split
43 | with open(os.path.join(self.root_dir, split), "r") as f:
44 | data = f.readlines()
45 |
46 | self.paths = []
47 | # Get file list from data
48 | for i, fname in enumerate(data):
49 | path = os.path.join(root_dir, fname.split()[0])
50 | # if os.path.exists(path):
51 | self.paths.append(path)
52 |
53 |
54 | self.backward_context = back_context
55 | self.forward_context = forward_context
56 | self.has_context = self.backward_context + self.forward_context > 0
57 | self.strides = strides[0]
58 |
59 | self.data_transform = data_transform
60 |
61 | def __len__(self):
62 | return len(self.paths)
63 |
64 | def _change_idx(self, idx, filename):
65 | _, ext = os.path.splitext(os.path.basename(filename))
66 | return str(idx) + ext
67 |
68 | def __getitem__(self, idx):
69 | filepath = self.paths[idx]
70 |
71 | image = load_image(os.path.join(filepath, '0000.jpg'))
72 | depth = np.load(os.path.join(filepath, '0000.npy'))
73 |
74 | rgb_contexts = [load_image(os.path.join(filepath, '0001.jpg'))]
75 |
76 | poses = [p.reshape((3, 4)) for p in np.genfromtxt(os.path.join(filepath, 'poses.txt')).astype(np.float64)]
77 | pos0 = np.zeros((4, 4))
78 | pos1 = np.zeros((4, 4))
79 | pos0[:3, :] = poses[0]
80 | pos0[3, 3] = 1.
81 | pos1[:3, :] = poses[1]
82 | pos1[3, 3] = 1.
83 | pos = np.matmul(pos1, np.linalg.inv(pos0))
84 | # pos = np.matmul(np.linalg.inv(pos1), pos0)
85 | pose_context = [pos.astype(np.float32)]
86 |
87 | intr = np.genfromtxt(os.path.join(filepath, 'cam.txt'))
88 |
89 | sample = {
90 | 'idx': idx,
91 | 'filename': '%s' % (filepath.split('/')[-1]),
92 | 'rgb': image,
93 | 'depth': depth,
94 | 'pose_context': pose_context,
95 | 'intrinsics': intr
96 | }
97 |
98 | if self.has_context:
99 | sample['rgb_context'] = rgb_contexts
100 |
101 | if self.data_transform:
102 | sample = self.data_transform(sample)
103 |
104 | return sample
105 |
106 | ########################################################################################################################
107 |
--------------------------------------------------------------------------------
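The context pose in `__getitem__` above is built as `pos1 @ inv(pos0)`. Assuming each row of `poses.txt` stores a 3x4 world-to-camera matrix (an assumption about the DeMoN export, not stated in this file), that product maps points expressed in camera 0 into camera 1. A small numeric sketch:

```python
import numpy as np

def to_homogeneous(pose_3x4):
    """Lift a 3x4 [R|t] matrix to a 4x4 homogeneous transform."""
    pose = np.eye(4)
    pose[:3, :] = pose_3x4
    return pose

# Illustrative world-to-camera poses: pos0 is the identity, pos1 adds a 0.1 translation along x.
pos0 = to_homogeneous(np.hstack([np.eye(3), np.zeros((3, 1))]))
pos1 = to_homogeneous(np.hstack([np.eye(3), np.array([[0.1], [0.0], [0.0]])]))

relative = pos1 @ np.linalg.inv(pos0)    # same composition as in the dataset
point_cam0 = np.array([0.0, 0.0, 2.0, 1.0])
print(relative @ point_cam0)             # -> [0.1, 0.0, 2.0, 1.0], the point expressed in camera 1
```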
/dro_sfm/datasets/demon_mf_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import re
3 | from collections import defaultdict
4 | import os
5 |
6 | from torch.utils.data import Dataset
7 | import numpy as np
8 | from dro_sfm.utils.image import load_image
9 | import cv2, IPython
10 | from PIL import Image
11 | from PIL import ImageFile
12 | ImageFile.LOAD_TRUNCATED_IMAGES = True
13 | import h5py
14 | ########################################################################################################################
15 | #### FUNCTIONS
16 | ########################################################################################################################
17 |
18 | def dummy_calibration(image):
19 | return np.array([[5.703422047415297129e+02 , 0. , 3.200000000000000000e+02],
20 | [0. , 5.703422047415297129e+02 , 2.400000000000000000e+02],
21 | [0. , 0. , 1. ]])
22 |
23 |
24 | ########################################################################################################################
25 | #### DATASET
26 | ########################################################################################################################
27 |
28 | class DemonDataset(Dataset):
29 | def __init__(self, root_dir, split, data_transform=None,
30 | forward_context=0, back_context=0, strides=(1,),
31 | depth_type=None, **kwargs):
32 | super().__init__()
33 | # Asserts
34 | # assert depth_type is None or depth_type == '', \
35 | # 'NYUDataset currently does not support depth types'
36 | assert len(strides) == 1 and strides[0] == 1, \
37 | 'DemonDataset currently only supports a stride of 1.'
38 |
39 | self.depth_type = depth_type
40 | self.with_depth = depth_type is not None and depth_type != ''
41 | self.root_dir = root_dir
42 | self.split = split
43 | with open(os.path.join(self.root_dir, split), "r") as f:
44 | data = f.readlines()
45 |
46 | self.paths = []
47 | # Get file list from data
48 | for i, fname in enumerate(data):
49 | path = os.path.join(root_dir, fname.split()[0])
50 | view3_rgb = os.path.join(path, "0002.jpg")
51 | view3_depth = os.path.join(path, "0002.npy")
52 | if not (forward_context == 1 and back_context == 1):
53 | if os.path.exists(view3_rgb) and os.path.exists(view3_depth):
54 | self.paths.append((path, True))
55 | else:
56 | self.paths.append((path, False))
57 | else:
58 | if os.path.exists(view3_rgb) and os.path.exists(view3_depth):
59 | self.paths.append((path, True))
60 |
61 | self.backward_context = back_context
62 | self.forward_context = forward_context
63 | self.has_context = (self.backward_context > 0 ) or (self.forward_context > 0)
64 | self.strides = strides[0]
65 |
66 | self.data_transform = data_transform
67 |
68 | def __len__(self):
69 | return len(self.paths)
70 |
71 | def _change_idx(self, idx, filename):
72 | _, ext = os.path.splitext(os.path.basename(filename))
73 | return str(idx) + ext
74 |
75 | def _get_view2(self, filepath):
76 | image = load_image(os.path.join(filepath, '0000.jpg'))
77 | depth = np.load(os.path.join(filepath, '0000.npy'))
78 | rgb_contexts = [load_image(os.path.join(filepath, '0001.jpg'))]
79 |
80 | poses = [p.reshape((3, 4)) for p in np.genfromtxt(os.path.join(filepath, 'poses.txt')).astype(np.float64)]
81 | pos0 = np.zeros((4, 4))
82 | pos1 = np.zeros((4, 4))
83 | pos0[:3, :] = poses[0]
84 | pos0[3, 3] = 1.
85 | pos1[:3, :] = poses[1]
86 | pos1[3, 3] = 1.
87 |
88 | pos01 = np.matmul(pos1, np.linalg.inv(pos0))
89 | pose_context = [pos01.astype(np.float32)]
90 |
91 | return image, depth, rgb_contexts, pose_context
92 |
93 | def _get_view3_dummy(self, filepath):
94 | image = load_image(os.path.join(filepath, '0000.jpg'))
95 | depth = np.load(os.path.join(filepath, '0000.npy'))
96 | rgb_contexts = [load_image(os.path.join(filepath, '0001.jpg')), load_image(os.path.join(filepath, '0001.jpg'))]
97 |
98 | poses = [p.reshape((3, 4)) for p in np.genfromtxt(os.path.join(filepath, 'poses.txt')).astype(np.float64)]
99 | pos0 = np.zeros((4, 4))
100 | pos1 = np.zeros((4, 4))
101 | pos0[:3, :] = poses[0]
102 | pos0[3, 3] = 1.
103 | pos1[:3, :] = poses[1]
104 | pos1[3, 3] = 1.
105 |
106 | pos01 = np.matmul(pos1, np.linalg.inv(pos0))
107 | pose_context = [pos01.astype(np.float32), pos01.astype(np.float32)]
108 |
109 | return image, depth, rgb_contexts, pose_context
110 |
111 | def _get_view3(self, filepath):
112 | image = load_image(os.path.join(filepath, '0001.jpg'))
113 | depth = np.load(os.path.join(filepath, '0001.npy'))
114 | rgb_contexts = [load_image(os.path.join(filepath, '0000.jpg')), load_image(os.path.join(filepath, '0002.jpg'))]
115 |
116 | poses = [p.reshape((3, 4)) for p in np.genfromtxt(os.path.join(filepath, 'poses.txt')).astype(np.float64)]
117 |
118 | pos0 = np.eye(4)
119 | pos1 = np.eye(4)
120 | pos2 = np.eye(4)
121 | pos0[:3, :] = poses[0]
122 | pos1[:3, :] = poses[1]
123 | pos2[:3, :] = poses[2]
124 |
125 | pos10 = np.matmul(pos0, np.linalg.inv(pos1))
126 | pos12 = np.matmul(pos2, np.linalg.inv(pos1))
127 | pose_context = [pos10.astype(np.float32), pos12.astype(np.float32)]
128 |
129 | return image, depth, rgb_contexts, pose_context
130 |
131 |
132 | def __getitem__(self, idx):
133 | filepath, mf = self.paths[idx]
134 |
135 | if self.forward_context == 1 and self.backward_context == 0:
136 | image, depth, rgb_contexts, pose_context = self._get_view2(filepath)
137 |
138 | elif self.forward_context == 1 and self.backward_context == 1:
139 | image, depth, rgb_contexts, pose_context = self._get_view3(filepath)
140 |
141 | elif self.forward_context == 1 and self.backward_context == -1:
142 | if np.random.random() > 0.5 and mf:
143 | image, depth, rgb_contexts, pose_context = self._get_view3(filepath)
144 | else:
145 | image, depth, rgb_contexts, pose_context = self._get_view3_dummy(filepath)
146 | else:
147 | raise NotImplementedError
148 |
149 | intr = np.genfromtxt(os.path.join(filepath, 'cam.txt'))
150 |
151 | sample = {
152 | 'idx': idx,
153 | 'filename': '%s' % (filepath.split('/')[-1]),
154 | 'rgb': image,
155 | 'depth': depth,
156 | 'pose_context': pose_context,
157 | 'intrinsics': intr
158 | }
159 |
160 | if self.has_context:
161 | sample['rgb_context'] = rgb_contexts
162 |
163 | if self.data_transform:
164 | sample = self.data_transform(sample)
165 |
166 | return sample
167 |
168 | ########################################################################################################################
169 |
--------------------------------------------------------------------------------
/dro_sfm/datasets/image_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import re
3 | from collections import defaultdict
4 | import os
5 |
6 | from torch.utils.data import Dataset
7 | import numpy as np
8 | from dro_sfm.utils.image import load_image
9 |
10 | ########################################################################################################################
11 | #### FUNCTIONS
12 | ########################################################################################################################
13 |
14 | def dummy_calibration(image):
15 | w, h = [float(d) for d in image.size]
16 | return np.array([[1000. , 0. , w / 2. - 0.5],
17 | [0. , 1000. , h / 2. - 0.5],
18 | [0. , 0. , 1. ]])
19 |
20 | def get_idx(filename):
21 | return int(re.search(r'\d+', filename).group())
22 |
23 | def read_files(directory, ext=('.png', '.jpg', '.jpeg'), skip_empty=True):
24 | files = defaultdict(list)
25 | for entry in os.scandir(directory):
26 | relpath = os.path.relpath(entry.path, directory)
27 | if entry.is_dir():
28 | d_files = read_files(entry.path, ext=ext, skip_empty=skip_empty)
29 | if skip_empty and not len(d_files):
30 | continue
31 | files[relpath] = d_files[entry.path]
32 | elif entry.is_file():
33 | if ext is None or entry.path.lower().endswith(tuple(ext)):
34 | files[directory].append(relpath)
35 | return files
36 |
37 | ########################################################################################################################
38 | #### DATASET
39 | ########################################################################################################################
40 |
41 | class ImageDataset(Dataset):
42 | def __init__(self, root_dir, split, data_transform=None,
43 | forward_context=0, back_context=0, strides=(1,),
44 | depth_type=None, **kwargs):
45 | super().__init__()
46 | # Asserts
47 | assert depth_type is None or depth_type == '', \
48 | 'ImageDataset currently does not support depth types'
49 | assert len(strides) == 1 and strides[0] == 1, \
50 | 'ImageDataset currently only supports stride of 1.'
51 |
52 | self.root_dir = root_dir
53 | self.split = split
54 |
55 | self.backward_context = back_context
56 | self.forward_context = forward_context
57 | self.has_context = self.backward_context + self.forward_context > 0
58 | self.strides = 1
59 |
60 | self.files = []
61 | file_tree = read_files(root_dir)
62 | for k, v in file_tree.items():
63 | file_set = set(file_tree[k])
64 | files = [fname for fname in sorted(v) if self._has_context(fname, file_set)]
65 | self.files.extend([[k, fname] for fname in files])
66 |
67 | self.data_transform = data_transform
68 |
69 | def __len__(self):
70 | return len(self.files)
71 |
72 | def _change_idx(self, idx, filename):
73 | _, ext = os.path.splitext(os.path.basename(filename))
74 | return self.split.format(idx) + ext
75 |
76 | def _has_context(self, filename, file_set):
77 | context_paths = self._get_context_file_paths(filename)
78 | return all([f in file_set for f in context_paths])
79 |
80 | def _get_context_file_paths(self, filename):
81 | fidx = get_idx(filename)
82 | idxs = list(np.arange(-self.backward_context * self.strides, 0, self.strides)) + \
83 | list(np.arange(0, self.forward_context * self.strides, self.strides) + self.strides)
84 | return [self._change_idx(fidx + i, filename) for i in idxs]
85 |
86 | def _read_rgb_context_files(self, session, filename):
87 | context_paths = self._get_context_file_paths(filename)
88 | return [load_image(os.path.join(self.root_dir, session, filename))
89 | for filename in context_paths]
90 |
91 | def _read_rgb_file(self, session, filename):
92 | return load_image(os.path.join(self.root_dir, session, filename))
93 |
94 | def __getitem__(self, idx):
95 | session, filename = self.files[idx]
96 | image = self._read_rgb_file(session, filename)
97 |
98 | sample = {
99 | 'idx': idx,
100 | 'filename': '%s_%s' % (session, os.path.splitext(filename)[0]),
101 | #
102 | 'rgb': image,
103 | 'intrinsics': dummy_calibration(image)
104 | }
105 |
106 | if self.has_context:
107 | sample['rgb_context'] = \
108 | self._read_rgb_context_files(session, filename)
109 |
110 | if self.data_transform:
111 | sample = self.data_transform(sample)
112 |
113 | return sample
114 |
115 | ########################################################################################################################
116 |
--------------------------------------------------------------------------------
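The context lookup in ImageDataset above is plain index arithmetic around the numeric part of each filename. A standalone sketch of the same offset construction (the frame index below is a made-up example):

    import numpy as np

    backward_context, forward_context, stride = 1, 1, 1

    # Same offsets as ImageDataset._get_context_file_paths:
    # backward offsets first, then forward offsets shifted by one stride.
    idxs = list(np.arange(-backward_context * stride, 0, stride)) + \
           list(np.arange(0, forward_context * stride, stride) + stride)
    print(idxs)                            # backward then forward offsets: -1 and +1

    # Applied to a frame index parsed from a filename such as '0005.png':
    frame_idx = 5
    print([frame_idx + i for i in idxs])   # context frame indices around frame 5: 4 and 6
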
/dro_sfm/datasets/kitti_dataset_utils.py:
--------------------------------------------------------------------------------
1 | """Provides helper methods for loading and parsing KITTI data."""
2 |
3 | from collections import namedtuple
4 |
5 | import numpy as np
6 |
7 | __author__ = "Lee Clement"
8 | __email__ = "lee.clement@robotics.utias.utoronto.ca"
9 |
10 | # Per dataformat.txt
11 | OxtsPacket = namedtuple('OxtsPacket',
12 | 'lat, lon, alt, ' +
13 | 'roll, pitch, yaw, ' +
14 | 'vn, ve, vf, vl, vu, ' +
15 | 'ax, ay, az, af, al, au, ' +
16 | 'wx, wy, wz, wf, wl, wu, ' +
17 | 'pos_accuracy, vel_accuracy, ' +
18 | 'navstat, numsats, ' +
19 | 'posmode, velmode, orimode')
20 |
21 | # Bundle into an easy-to-access structure
22 | OxtsData = namedtuple('OxtsData', 'packet, T_w_imu')
23 |
24 |
25 | def rotx(t):
26 | """
27 | Rotation about the x-axis
28 |
29 | Parameters
30 | ----------
31 | t : float
32 | Theta angle
33 |
34 | Returns
35 | -------
36 | matrix : np.array [3,3]
37 | Rotation matrix
38 | """
39 | c = np.cos(t)
40 | s = np.sin(t)
41 | return np.array([[1, 0, 0],
42 | [0, c, -s],
43 | [0, s, c]])
44 |
45 |
46 | def roty(t):
47 | """
48 | Rotation about the y-axis
49 |
50 | Parameters
51 | ----------
52 | t : float
53 | Theta angle
54 |
55 | Returns
56 | -------
57 | matrix : np.array [3,3]
58 | Rotation matrix
59 | """
60 | c = np.cos(t)
61 | s = np.sin(t)
62 | return np.array([[c, 0, s],
63 | [0, 1, 0],
64 | [-s, 0, c]])
65 |
66 |
67 | def rotz(t):
68 | """
69 | Rotation about the z-axis
70 |
71 | Parameters
72 | ----------
73 | t : float
74 | Theta angle
75 |
76 | Returns
77 | -------
78 | matrix : np.array [3,3]
79 | Rotation matrix
80 | """
81 | c = np.cos(t)
82 | s = np.sin(t)
83 | return np.array([[c, -s, 0],
84 | [s, c, 0],
85 | [0, 0, 1]])
86 |
87 |
88 | def transform_from_rot_trans(R, t):
89 | """
90 | Transformation matrix from rotation matrix and translation vector.
91 |
92 | Parameters
93 | ----------
94 | R : np.array [3,3]
95 | Rotation matrix
96 | t : np.array [3]
97 | translation vector
98 |
99 | Returns
100 | -------
101 | matrix : np.array [4,4]
102 | Transformation matrix
103 | """
104 | R = R.reshape(3, 3)
105 | t = t.reshape(3, 1)
106 | return np.vstack((np.hstack([R, t]), [0, 0, 0, 1]))
107 |
108 |
109 | def read_calib_file(filepath):
110 | """
111 | Read in a calibration file and parse into a dictionary
112 |
113 | Parameters
114 | ----------
115 | filepath : str
116 | File path to read from
117 |
118 | Returns
119 | -------
120 | calib : dict
121 | Dictionary with calibration values
122 | """
123 | data = {}
124 |
125 | with open(filepath, 'r') as f:
126 | for line in f.readlines():
127 | key, value = line.split(':', 1)
128 | # The only non-float values in these files are dates, which
129 | # we don't care about anyway
130 | try:
131 | data[key] = np.array([float(x) for x in value.split()])
132 | except ValueError:
133 | pass
134 |
135 | return data
136 |
137 |
138 | def pose_from_oxts_packet(raw_data, scale):
139 | """
140 |     Helper method to compute the SE(3) rotation and translation from an OXTS packet
141 |
142 | Parameters
143 | ----------
144 |     raw_data : list or OxtsPacket
145 |         Raw OXTS values (unpacked into an OxtsPacket)
146 | scale : float
147 | Oxts scale
148 |
149 | Returns
150 | -------
151 | R : np.array [3,3]
152 | Rotation matrix
153 | t : np.array [3]
154 | Translation vector
155 | """
156 | packet = OxtsPacket(*raw_data)
157 | er = 6378137. # earth radius (approx.) in meters
158 |
159 | # Use a Mercator projection to get the translation vector
160 | tx = scale * packet.lon * np.pi * er / 180.
161 | ty = scale * er * \
162 | np.log(np.tan((90. + packet.lat) * np.pi / 360.))
163 | tz = packet.alt
164 | t = np.array([tx, ty, tz])
165 |
166 | # Use the Euler angles to get the rotation matrix
167 | Rx = rotx(packet.roll)
168 | Ry = roty(packet.pitch)
169 | Rz = rotz(packet.yaw)
170 | R = Rz.dot(Ry.dot(Rx))
171 |
172 |     # Return the rotation and translation (combined into a homogeneous transform by the caller)
173 | return R, t
174 |
175 |
176 | def load_oxts_packets_and_poses(oxts_files):
177 | """
178 |     Reads OXTS ground truth data.
179 | Poses are given in an East-North-Up coordinate system
180 | whose origin is the first GPS position.
181 |
182 | Parameters
183 | ----------
184 | oxts_files : list of str
185 | List of oxts files to read from
186 |
187 | Returns
188 | -------
189 |     oxts : list of OxtsData
190 |         List of OXTS ground-truth packets with their IMU-to-world poses
191 | """
192 | # Scale for Mercator projection (from first lat value)
193 | scale = None
194 | # Origin of the global coordinate system (first GPS position)
195 | origin = None
196 |
197 | oxts = []
198 |
199 | for filename in oxts_files:
200 | with open(filename, 'r') as f:
201 | for line in f.readlines():
202 | line = line.split()
203 | # Last five entries are flags and counts
204 | line[:-5] = [float(x) for x in line[:-5]]
205 | line[-5:] = [int(float(x)) for x in line[-5:]]
206 |
207 | packet = OxtsPacket(*line)
208 |
209 | if scale is None:
210 | scale = np.cos(packet.lat * np.pi / 180.)
211 |
212 | R, t = pose_from_oxts_packet(packet, scale)
213 |
214 | if origin is None:
215 | origin = t
216 |
217 | T_w_imu = transform_from_rot_trans(R, t - origin)
218 |
219 | oxts.append(OxtsData(packet, T_w_imu))
220 |
221 | return oxts
222 |
223 |
224 |
--------------------------------------------------------------------------------
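A short sketch of how the helpers above fit together: rotz(yaw) @ roty(pitch) @ rotx(roll) gives the OXTS rotation (the same composition as pose_from_oxts_packet), and transform_from_rot_trans packs it with a translation into a 4x4 pose. The angles and translation below are synthetic, and the snippet assumes the dro_sfm package is importable:

    import numpy as np
    from dro_sfm.datasets.kitti_dataset_utils import rotx, roty, rotz, transform_from_rot_trans

    roll, pitch, yaw = 0.01, -0.02, 1.5        # synthetic angles in radians
    t = np.array([1.0, 2.0, 0.5])              # synthetic translation

    R = rotz(yaw).dot(roty(pitch)).dot(rotx(roll))   # same order as pose_from_oxts_packet
    T_w_imu = transform_from_rot_trans(R, t)

    print(T_w_imu.shape)                                                 # (4, 4)
    print(np.allclose(T_w_imu[:3, :3] @ T_w_imu[:3, :3].T, np.eye(3)))   # rotation stays orthonormal
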
/dro_sfm/datasets/nyu_dataset_processed.py:
--------------------------------------------------------------------------------
1 |
2 | import re
3 | from collections import defaultdict
4 | import os
5 |
6 | from torch.utils.data import Dataset
7 | import numpy as np
8 | from dro_sfm.utils.image import load_image
9 | import cv2, IPython
10 | from PIL import Image
11 | from PIL import ImageFile
12 | ImageFile.LOAD_TRUNCATED_IMAGES = True
13 | import h5py
14 | ########################################################################################################################
15 | #### FUNCTIONS
16 | ########################################################################################################################
17 |
18 | def dummy_calibration(image):
19 | return np.array([[518.85790117450188 , 0. , 325.58244941119034],
20 | [0. , 519.46961112127485 , 253.73616633400465],
21 | [0. , 0. , 1. ]])
22 |
23 | def get_idx(filename):
24 | return int(re.search(r'\d+', filename).group())
25 |
26 | def read_files(directory, ext=('.depth', '.h5'), skip_empty=True):
27 | files = defaultdict(list)
28 | for entry in os.scandir(directory):
29 | relpath = os.path.relpath(entry.path, directory)
30 | if entry.is_dir():
31 | d_files = read_files(entry.path, ext=ext, skip_empty=skip_empty)
32 | if skip_empty and not len(d_files):
33 | continue
34 | files[relpath] = d_files[entry.path]
35 | elif entry.is_file():
36 | if (ext is None or entry.path.lower().endswith(tuple(ext))):
37 | files[directory].append(relpath)
38 | return files
39 |
40 | def read_npz_depth(file, depth_type):
41 | """Reads a .npz depth map given a certain depth_type."""
42 | depth = np.load(file)[depth_type + '_depth'].astype(np.float32)
43 | return np.expand_dims(depth, axis=2)
44 |
45 | def read_png_depth(file):
46 | """Reads a .png depth map."""
47 | depth_png = np.array(load_image(file), dtype=int)
48 |
49 | # assert (np.max(depth_png) > 255), 'Wrong .png depth file'
50 | if (np.max(depth_png) > 255):
51 |         depth = depth_png.astype(np.float64) / 256.
52 |     else:
53 |         depth = depth_png.astype(np.float64)
54 | depth[depth_png == 0] = -1.
55 | return np.expand_dims(depth, axis=2)
56 |
57 | ########################################################################################################################
58 | #### DATASET
59 | ########################################################################################################################
60 |
61 | class NYUDataset(Dataset):
62 | def __init__(self, root_dir, split, data_transform=None,
63 | forward_context=0, back_context=0, strides=(1,),
64 | depth_type=None, **kwargs):
65 | super().__init__()
66 | # Asserts
67 | # assert depth_type is None or depth_type == '', \
68 | # 'NYUDataset currently does not support depth types'
69 | assert len(strides) == 1 and strides[0] == 1, \
70 | 'NYUDataset currently only supports stride of 1.'
71 |
72 | self.depth_type = depth_type
73 |         self.with_depth = depth_type is not None and depth_type != ''
74 | self.root_dir = root_dir
75 | self.split = split
76 |
77 | self.backward_context = back_context
78 | self.forward_context = forward_context
79 | self.has_context = self.backward_context + self.forward_context > 0
80 | self.strides = strides[0]
81 |
82 | self.files = []
83 | self.file_tree = read_files(root_dir)
84 | for k, v in self.file_tree.items():
85 | file_list = sorted(v)
86 | files = [fname for fname in file_list if self._has_context(k, fname, file_list)]
87 | self.files.extend([[k, fname] for fname in files])
88 |
89 | self.data_transform = data_transform
90 |
91 | def __len__(self):
92 | return len(self.files)
93 |
94 | def _change_idx(self, idx, filename):
95 | _, ext = os.path.splitext(os.path.basename(filename))
96 | return str(idx) + ext
97 |
98 | def _has_context(self, session, filename, file_list):
99 | context_paths = self._get_context_file_paths(filename, file_list)
100 | return all([f in file_list for f in context_paths])
101 |
102 | def _get_context_file_paths(self, filename, filelist):
103 | # fidx = get_idx(filename)
104 | fidx = filelist.index(filename)
105 | idxs = list(np.arange(-self.backward_context * self.strides, 0, self.strides)) + \
106 | list(np.arange(0, self.forward_context * self.strides, self.strides) + self.strides)
107 | return [filelist[fidx+i] if 0 <= fidx+i < len(filelist) else 'none' for i in idxs]
108 |
109 | def _read_rgb_context_files(self, session, filename):
110 | context_paths = self._get_context_file_paths(filename, sorted(self.file_tree[session]))
111 |
112 | return [self._read_rgb_file(session, filename)
113 | for filename in context_paths]
114 |
115 | def _read_rgb_file(self, session, filename):
116 | file_path = os.path.join(self.root_dir, session, filename)
117 | h5f = h5py.File(file_path, "r")
118 | rgb = np.array(h5f['rgb'])
119 | image = np.transpose(rgb, (1, 2, 0))
120 | image_pil = Image.fromarray(image)
121 | return image_pil
122 |
123 | def __getitem__(self, idx):
124 | session, filename = self.files[idx]
125 | if session == self.root_dir:
126 | session = ''
127 |
128 | file_path = os.path.join(self.root_dir, session, filename)
129 | h5f = h5py.File(file_path, "r")
130 | rgb = np.array(h5f['rgb'])
131 | image = np.transpose(rgb, (1, 2, 0))
132 | # image = rgb[...,::-1]
133 | image_pil = Image.fromarray(image)
134 | depth = np.array(h5f['depth'])
135 |
136 | sample = {
137 | 'idx': idx,
138 | 'filename': '%s_%s' % (session, os.path.splitext(filename)[0]),
139 | 'rgb': image_pil,
140 | 'depth': depth,
141 | 'intrinsics': dummy_calibration(image)
142 | }
143 |
144 | if self.has_context:
145 | sample['rgb_context'] = \
146 | self._read_rgb_context_files(session, filename)
147 |
148 | if self.data_transform:
149 | sample = self.data_transform(sample)
150 |
151 | return sample
152 |
153 | ########################################################################################################################
154 |
--------------------------------------------------------------------------------
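A minimal sketch of the .h5 layout NYUDataset above expects: an 'rgb' array stored channel-first and a 'depth' array, with the RGB transposed to height-width-channel order before being wrapped in a PIL image. The file path below is hypothetical:

    import h5py
    import numpy as np
    from PIL import Image

    path = "/path/to/nyu_sample.h5"    # hypothetical sample file
    with h5py.File(path, "r") as h5f:
        rgb = np.array(h5f["rgb"])     # channel-first RGB, as read by _read_rgb_file
        depth = np.array(h5f["depth"]) # per-pixel depth map

    image = Image.fromarray(np.transpose(rgb, (1, 2, 0)))  # to HWC for PIL
    print(image.size, depth.shape)
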
/dro_sfm/datasets/transforms.py:
--------------------------------------------------------------------------------
1 |
2 | from functools import partial
3 | from dro_sfm.datasets.augmentations import resize_image, resize_sample, \
4 | duplicate_sample, colorjitter_sample, to_tensor_sample, resize_sample_image_and_intrinsics
5 |
6 | ########################################################################################################################
7 |
8 | def train_transforms(sample, image_shape, jittering):
9 | """
10 | Training data augmentation transformations
11 |
12 | Parameters
13 | ----------
14 | sample : dict
15 | Sample to be augmented
16 | image_shape : tuple (height, width)
17 | Image dimension to reshape
18 | jittering : tuple (brightness, contrast, saturation, hue)
19 | Color jittering parameters
20 |
21 | Returns
22 | -------
23 | sample : dict
24 | Augmented sample
25 | """
26 | if len(image_shape) > 0:
27 | sample = resize_sample(sample, image_shape)
28 | sample = duplicate_sample(sample)
29 | if len(jittering) > 0:
30 | sample = colorjitter_sample(sample, jittering)
31 | sample = to_tensor_sample(sample)
32 | return sample
33 |
34 | def validation_transforms(sample, image_shape):
35 | """
36 | Validation data augmentation transformations
37 |
38 | Parameters
39 | ----------
40 | sample : dict
41 | Sample to be augmented
42 | image_shape : tuple (height, width)
43 | Image dimension to reshape
44 |
45 | Returns
46 | -------
47 | sample : dict
48 | Augmented sample
49 | """
50 | # if len(image_shape) > 0:
51 | # sample['rgb'] = resize_image(sample['rgb'], image_shape)
52 | # if 'rgb_context' in sample:
53 | # sample['rgb_context'] = [resize_image(img, image_shape) for img in sample['rgb_context']]
54 |
55 | if len(image_shape) > 0:
56 | sample = resize_sample_image_and_intrinsics(sample, image_shape)
57 |
58 | sample = to_tensor_sample(sample)
59 | return sample
60 |
61 | def test_transforms(sample, image_shape):
62 | """
63 | Test data augmentation transformations
64 |
65 | Parameters
66 | ----------
67 | sample : dict
68 | Sample to be augmented
69 | image_shape : tuple (height, width)
70 | Image dimension to reshape
71 |
72 | Returns
73 | -------
74 | sample : dict
75 | Augmented sample
76 | """
77 | # if len(image_shape) > 0:
78 | # sample['rgb'] = resize_image(sample['rgb'], image_shape)
79 | # if 'rgb_context' in sample:
80 | # sample['rgb_context'] = [resize_image(img, image_shape) for img in sample['rgb_context']]
81 |
82 | if len(image_shape) > 0:
83 | sample = resize_sample_image_and_intrinsics(sample, image_shape)
84 |
85 | sample = to_tensor_sample(sample)
86 | return sample
87 |
88 | def get_transforms(mode, image_shape, jittering, **kwargs):
89 | """
90 | Get data augmentation transformations for each split
91 |
92 | Parameters
93 | ----------
94 | mode : str {'train', 'validation', 'test'}
95 | Mode from which we want the data augmentation transformations
96 | image_shape : tuple (height, width)
97 | Image dimension to reshape
98 | jittering : tuple (brightness, contrast, saturation, hue)
99 | Color jittering parameters
100 |
101 | Returns
102 | -------
103 | XXX_transform: Partial function
104 | Data augmentation transformation for that mode
105 | """
106 | if mode == 'train':
107 | return partial(train_transforms,
108 | image_shape=image_shape,
109 | jittering=jittering)
110 | elif mode == 'validation':
111 | return partial(validation_transforms,
112 | image_shape=image_shape)
113 | elif mode == 'test':
114 | return partial(test_transforms,
115 | image_shape=image_shape)
116 | else:
117 | raise ValueError('Unknown mode {}'.format(mode))
118 |
119 | ########################################################################################################################
120 |
121 |
--------------------------------------------------------------------------------
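A usage sketch for get_transforms above: each mode returns a partial function that only needs the sample dict, so datasets can call the result directly as data_transform. The image shape and jittering values below are illustrative:

    from dro_sfm.datasets.transforms import get_transforms

    # Training transform: resize to (H, W) and apply colour jittering.
    train_tf = get_transforms('train', image_shape=(256, 320),
                              jittering=(0.2, 0.2, 0.2, 0.05))

    # Validation/test transforms only resize (rescaling intrinsics) and convert to tensors.
    val_tf = get_transforms('validation', image_shape=(256, 320), jittering=())

    # Datasets then call: sample = data_transform(sample)
    print(train_tf.func.__name__, val_tf.func.__name__)
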
/dro_sfm/datasets/video_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import re
3 | from collections import defaultdict
4 | import os
5 | from numpy.core.defchararray import join
6 |
7 | from torch.utils.data import Dataset
8 | import numpy as np
9 | from dro_sfm.utils.image import load_image
10 | import cv2, IPython
11 | from PIL import Image
12 | from PIL import ImageFile
13 | ImageFile.LOAD_TRUNCATED_IMAGES = True
14 | import h5py
15 | ########################################################################################################################
16 | #### FUNCTIONS
17 | ########################################################################################################################
18 |
19 | def dummy_calibration(w, h):
20 | fx = w * 1.2
21 | fy = w * 1.2
22 | cx = w / 2.0
23 | cy = h / 2.0
24 | return np.array([[fx, 0., cx],
25 | [0., fy, cy],
26 | [0., 0., 1.]])
27 |
28 |
29 | def get_idx(filename):
30 | return int(re.search(r'\d+', filename).group())
31 |
32 | def read_files(directory, ext=('.depth', '.h5'), skip_empty=True):
33 | files = defaultdict(list)
34 | for entry in os.scandir(directory):
35 | relpath = os.path.relpath(entry.path, directory)
36 | if entry.is_dir():
37 | d_files = read_files(entry.path, ext=ext, skip_empty=skip_empty)
38 | if skip_empty and not len(d_files):
39 | continue
40 | files[relpath] = d_files[entry.path]
41 | elif entry.is_file():
42 | if (ext is None or entry.path.lower().endswith(tuple(ext))):
43 | files[directory].append(relpath)
44 | return files
45 |
46 | def read_npz_depth(file, depth_type):
47 | """Reads a .npz depth map given a certain depth_type."""
48 | depth = np.load(file)[depth_type + '_depth'].astype(np.float32)
49 | return np.expand_dims(depth, axis=2)
50 |
51 | def read_png_depth(file):
52 | """Reads a .png depth map."""
53 | depth_png = np.array(load_image(file), dtype=int)
54 |
55 | # assert (np.max(depth_png) > 255), 'Wrong .png depth file'
56 | if (np.max(depth_png) > 255):
57 |         depth = depth_png.astype(np.float64) / 256.
58 |     else:
59 |         depth = depth_png.astype(np.float64)
60 | depth[depth_png == 0] = -1.
61 | return np.expand_dims(depth, axis=2)
62 |
63 | ########################################################################################################################
64 | #### DATASET
65 | ########################################################################################################################
66 |
67 | class VideoDataset(Dataset):
68 | def __init__(self, root_dir, split, data_transform=None,
69 | forward_context=0, back_context=0, strides=(1,),
70 | depth_type=None, **kwargs):
71 | super().__init__()
72 | # Asserts
73 | assert depth_type is None or depth_type == '', \
74 | 'VideoDataset currently does not support depth types'
75 | self.depth_type = depth_type
76 |         self.with_depth = depth_type is not None and depth_type != ''
77 | self.root_dir = root_dir
78 | self.split = split
79 |
80 | self.backward_context = back_context
81 | self.forward_context = forward_context
82 | self.has_context = self.backward_context + self.forward_context > 0
83 | self.strides = strides[0]
84 |
85 | self.files = []
86 | self.file_tree = read_files(root_dir, ext=(".jpg", ".png", ".bmp", ".jpeg"))
87 | for k, v in self.file_tree.items():
88 | file_list = sorted(v)
89 | files = [fname for fname in file_list if self._has_context(k, fname, file_list)]
90 | self.files.extend([[k, fname] for fname in files])
91 |
92 | self.data_transform = data_transform
93 |
94 | def __len__(self):
95 | return len(self.files)
96 |
97 | def _has_context(self, session, filename, file_list):
98 | context_paths = self._get_context_file_paths(filename, file_list)
99 | return all([os.path.exists(os.path.join(self.root_dir, session, f)) for f in context_paths])
100 |
101 | def _get_context_file_paths(self, filename, filelist):
102 | # fidx = get_idx(filename)
103 | fidx = filelist.index(filename)
104 | idxs = list(np.arange(-self.backward_context * self.strides, 0, self.strides)) + \
105 | list(np.arange(0, self.forward_context * self.strides, self.strides) + self.strides)
106 |
107 | return [filelist[fidx+i] if (0 <= (fidx+i)) and ((fidx+i) < len(filelist)) else 'none' for i in idxs]
108 |
109 | def _read_rgb_context_files(self, session, filename):
110 | context_paths = self._get_context_file_paths(filename, sorted(self.file_tree[session]))
111 | return [self._read_rgb_file(session, filename)
112 | for filename in context_paths]
113 |
114 | def _read_rgb_file(self, session, filename):
115 | file_path = os.path.join(self.root_dir, session, filename)
116 | rgb = load_image(file_path)
117 | return rgb
118 |
119 | def __getitem__(self, idx):
120 | session, filename = self.files[idx]
121 | if session == self.root_dir:
122 | session = ''
123 |
124 | rgb = self._read_rgb_file(session, filename)
125 |
126 | sample = {
127 | 'idx': idx,
128 | 'filename': '%s_%s' % (session, os.path.splitext(filename)[0]),
129 | 'rgb': rgb,
130 | 'intrinsics': dummy_calibration(w=rgb.size[0], h=rgb.size[1])
131 | }
132 |
133 | # Add depth information if requested
134 | if self.with_depth:
135 | sample.update({
136 | 'depth': None,
137 | })
138 |
139 | if self.has_context:
140 | sample['rgb_context'] = \
141 | self._read_rgb_context_files(session, filename)
142 |
143 | if self.data_transform:
144 | sample = self.data_transform(sample)
145 |
146 | return sample
147 |
148 | ########################################################################################################################
149 |
--------------------------------------------------------------------------------
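When no calibration file is available, VideoDataset above falls back to dummy_calibration, which guesses the focal length as 1.2x the image width and places the principal point at the image centre. A quick numeric check for a 640x480 frame:

    from dro_sfm.datasets.video_dataset import dummy_calibration

    K = dummy_calibration(w=640, h=480)
    print(K)   # fx = fy = 768, principal point at (320, 240)
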
/dro_sfm/datasets/video_random_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import re
3 | from collections import defaultdict
4 | import os
5 | from numpy.core.defchararray import join
6 |
7 | from torch.utils.data import Dataset
8 | import numpy as np
9 | from dro_sfm.utils.image import load_image
10 | import cv2, IPython
11 | from PIL import Image
12 | from PIL import ImageFile
13 | ImageFile.LOAD_TRUNCATED_IMAGES = True
14 | import h5py
15 | ########################################################################################################################
16 | #### FUNCTIONS
17 | ########################################################################################################################
18 |
19 | def dummy_calibration(w, h):
20 | fx = w * 1.2
21 | fy = w * 1.2
22 | cx = w / 2.0
23 | cy = h / 2.0
24 | return np.array([[fx, 0., cx],
25 | [0., fy, cy],
26 | [0., 0., 1.]])
27 |
28 |
29 | def get_idx(filename):
30 | return int(re.search(r'\d+', filename).group())
31 |
32 | def read_files(directory, ext=('.depth', '.h5'), skip_empty=True):
33 | files = defaultdict(list)
34 | for entry in os.scandir(directory):
35 | relpath = os.path.relpath(entry.path, directory)
36 | if entry.is_dir():
37 | d_files = read_files(entry.path, ext=ext, skip_empty=skip_empty)
38 | if skip_empty and not len(d_files):
39 | continue
40 | files[relpath] = d_files[entry.path]
41 | elif entry.is_file():
42 | if (ext is None or entry.path.lower().endswith(tuple(ext))):
43 | files[directory].append(relpath)
44 | return files
45 |
46 | def read_npz_depth(file, depth_type):
47 | """Reads a .npz depth map given a certain depth_type."""
48 | depth = np.load(file)[depth_type + '_depth'].astype(np.float32)
49 | return np.expand_dims(depth, axis=2)
50 |
51 | def read_png_depth(file):
52 | """Reads a .png depth map."""
53 | depth_png = np.array(load_image(file), dtype=int)
54 |
55 | # assert (np.max(depth_png) > 255), 'Wrong .png depth file'
56 | if (np.max(depth_png) > 255):
57 |         depth = depth_png.astype(np.float64) / 256.
58 |     else:
59 |         depth = depth_png.astype(np.float64)
60 | depth[depth_png == 0] = -1.
61 | return np.expand_dims(depth, axis=2)
62 |
63 | ########################################################################################################################
64 | #### DATASET
65 | ########################################################################################################################
66 |
67 | class VideoRandomDataset(Dataset):
68 | def __init__(self, root_dir, split, data_transform=None,
69 | forward_context=0, back_context=0, strides=(1,),
70 | depth_type=None, **kwargs):
71 | super().__init__()
72 | # Asserts
73 | assert depth_type is None or depth_type == '', \
74 | 'VideoDataset currently does not support depth types'
75 | self.depth_type = depth_type
76 |         self.with_depth = depth_type is not None and depth_type != ''
77 | self.root_dir = root_dir
78 | self.split = split
79 |
80 | self.backward_context = back_context
81 | self.forward_context = forward_context
82 |
83 | self.max_len = 11
84 |
85 | self.has_context = self.backward_context + self.forward_context > 0
86 | self.strides = strides[0]
87 |
88 | self.files = []
89 | self.file_tree = read_files(root_dir, ext=(".jpg", ".png", ".bmp", ".jpeg"))
90 | for k, v in self.file_tree.items():
91 | file_list = sorted(v)
92 | files = [fname for fname in file_list if self._has_context(k, fname, file_list)]
93 | self.files.extend([[k, fname] for fname in files])
94 |
95 | self.data_transform = data_transform
96 |
97 | def __len__(self):
98 | return len(self.files)
99 |
100 | def _has_context(self, session, filename, file_list):
101 | context_paths = self._get_context_file_paths(filename, file_list)
102 | return all([os.path.exists(os.path.join(self.root_dir, session, f)) for f in context_paths])
103 |
104 | def _get_context_file_paths(self, filename, filelist):
105 | # fidx = get_idx(filename)
106 | fidx = filelist.index(filename)
107 | idxs = list(np.arange(-self.backward_context * self.strides, 0, self.strides)) + \
108 | list(np.arange(0, self.forward_context * self.strides, self.strides) + self.strides)
109 |
110 | return [filelist[fidx+i] if (0 <= (fidx+i)) and ((fidx+i) < len(filelist)) else 'none' for i in idxs]
111 |
112 |
113 | def _get_context_file_paths_random(self, filename, filelist):
114 | # fidx = get_idx(filename)
115 | fidx = filelist.index(filename)
116 | idxs_back = list(np.arange(-self.backward_context * self.strides, 0, 1))
117 |
118 | idx_forw= list(np.arange(0, self.forward_context * self.strides, 1) + 1)
119 |
120 | idxs_back = np.random.choice(idxs_back, self.backward_context).tolist()
121 |
122 | idx_forw = np.random.choice(idx_forw, self.forward_context).tolist()
123 |
124 | idxs = idxs_back + idx_forw
125 |
126 | return [filelist[i + fidx] for i in idxs]
127 |
128 |
129 | def _read_rgb_context_files(self, session, filename):
130 | context_paths = self._get_context_file_paths_random(filename, sorted(self.file_tree[session]))
131 | # print(filename, context_paths)
132 | return [self._read_rgb_file(session, filename)
133 | for filename in context_paths]
134 |
135 |
136 | def _read_rgb_file(self, session, filename):
137 | file_path = os.path.join(self.root_dir, session, filename)
138 | rgb = load_image(file_path)
139 | return rgb
140 |
141 | def __getitem__(self, idx):
142 | session, filename = self.files[idx]
143 | if session == self.root_dir:
144 | session = ''
145 |
146 | rgb = self._read_rgb_file(session, filename)
147 |
148 | sample = {
149 | 'idx': idx,
150 | 'filename': '%s_%s' % (session, os.path.splitext(filename)[0]),
151 | 'rgb': rgb,
152 | 'intrinsics': dummy_calibration(w=rgb.size[0], h=rgb.size[1])
153 | }
154 |
155 | # Add depth information if requested
156 | if self.with_depth:
157 | sample.update({
158 | 'depth': None,
159 | })
160 |
161 | if self.has_context:
162 | sample['rgb_context'] = \
163 | self._read_rgb_context_files(session, filename)
164 |
165 | if self.data_transform:
166 | sample = self.data_transform(sample)
167 |
168 | return sample
169 |
170 | ########################################################################################################################
171 |
--------------------------------------------------------------------------------
/dro_sfm/geometry/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aliyun/dro-sfm/8707e2e0ef799d7d47418a018060f503ef449fe3/dro_sfm/geometry/__init__.py
--------------------------------------------------------------------------------
/dro_sfm/geometry/camera.py:
--------------------------------------------------------------------------------
1 |
2 | from functools import lru_cache
3 | import torch
4 | import torch.nn as nn
5 |
6 | from dro_sfm.geometry.pose import Pose
7 | from dro_sfm.geometry.camera_utils import scale_intrinsics
8 | from dro_sfm.utils.image import image_grid
9 |
10 | ########################################################################################################################
11 |
12 | class Camera(nn.Module):
13 | """
14 | Differentiable camera class implementing reconstruction and projection
15 | functions for a pinhole model.
16 | """
17 | def __init__(self, K, Tcw=None):
18 | """
19 | Initializes the Camera class
20 |
21 | Parameters
22 | ----------
23 | K : torch.Tensor [3,3]
24 | Camera intrinsics
25 | Tcw : Pose
26 | Camera -> World pose transformation
27 | """
28 | super().__init__()
29 | self.K = K
30 | self.Tcw = Pose.identity(len(K)) if Tcw is None else Tcw
31 |
32 | def __len__(self):
33 | """Batch size of the camera intrinsics"""
34 | return len(self.K)
35 |
36 | def to(self, *args, **kwargs):
37 | """Moves object to a specific device"""
38 | self.K = self.K.to(*args, **kwargs)
39 | self.Tcw = self.Tcw.to(*args, **kwargs)
40 | return self
41 |
42 | ########################################################################################################################
43 |
44 | @property
45 | def fx(self):
46 | """Focal length in x"""
47 | return self.K[:, 0, 0]
48 |
49 | @property
50 | def fy(self):
51 | """Focal length in y"""
52 | return self.K[:, 1, 1]
53 |
54 | @property
55 | def cx(self):
56 | """Principal point in x"""
57 | return self.K[:, 0, 2]
58 |
59 | @property
60 | def cy(self):
61 | """Principal point in y"""
62 | return self.K[:, 1, 2]
63 |
64 | @property
65 | @lru_cache()
66 | def Twc(self):
67 | """World -> Camera pose transformation (inverse of Tcw)"""
68 | return self.Tcw.inverse()
69 |
70 | @property
71 | @lru_cache()
72 | def Kinv(self):
73 | """Inverse intrinsics (for lifting)"""
74 | Kinv = self.K.clone()
75 | Kinv[:, 0, 0] = 1. / self.fx
76 | Kinv[:, 1, 1] = 1. / self.fy
77 | Kinv[:, 0, 2] = -1. * self.cx / self.fx
78 | Kinv[:, 1, 2] = -1. * self.cy / self.fy
79 | return Kinv
80 |
81 | ########################################################################################################################
82 |
83 | def scaled(self, x_scale, y_scale=None):
84 | """
85 | Returns a scaled version of the camera (changing intrinsics)
86 |
87 | Parameters
88 | ----------
89 | x_scale : float
90 | Resize scale in x
91 | y_scale : float
92 | Resize scale in y. If None, use the same as x_scale
93 |
94 | Returns
95 | -------
96 | camera : Camera
97 |             Scaled version of the current camera
98 | """
99 | # If single value is provided, use for both dimensions
100 | if y_scale is None:
101 | y_scale = x_scale
102 | # If no scaling is necessary, return same camera
103 | if x_scale == 1. and y_scale == 1.:
104 | return self
105 | # Scale intrinsics and return new camera with same Pose
106 | K = scale_intrinsics(self.K.clone(), x_scale, y_scale)
107 | return Camera(K, Tcw=self.Tcw)
108 |
109 | ########################################################################################################################
110 |
111 | def reconstruct(self, depth, frame='w'):
112 | """
113 | Reconstructs pixel-wise 3D points from a depth map.
114 |
115 | Parameters
116 | ----------
117 | depth : torch.Tensor [B,1,H,W]
118 | Depth map for the camera
119 | frame : 'w'
120 | Reference frame: 'c' for camera and 'w' for world
121 |
122 | Returns
123 | -------
124 | points : torch.tensor [B,3,H,W]
125 | Pixel-wise 3D points
126 | """
127 | B, C, H, W = depth.shape
128 | assert C == 1
129 |
130 | # Create flat index grid
131 | grid = image_grid(B, H, W, depth.dtype, depth.device, normalized=False) # [B,3,H,W]
132 | flat_grid = grid.view(B, 3, -1) # [B,3,HW]
133 |
134 | # Estimate the outward rays in the camera frame
135 | xnorm = (self.Kinv.bmm(flat_grid)).view(B, 3, H, W)
136 | # Scale rays to metric depth
137 | Xc = xnorm * depth
138 |
139 | # If in camera frame of reference
140 | if frame == 'c':
141 | return Xc
142 | # If in world frame of reference
143 | elif frame == 'w':
144 | return self.Twc @ Xc
145 | # If none of the above
146 | else:
147 | raise ValueError('Unknown reference frame {}'.format(frame))
148 |
149 | def project(self, X, frame='w', normalize=True):
150 | """
151 | Projects 3D points onto the image plane
152 |
153 | Parameters
154 | ----------
155 | X : torch.Tensor [B,3,H,W]
156 | 3D points to be projected
157 | frame : 'w'
158 | Reference frame: 'c' for camera and 'w' for world
159 |
160 | Returns
161 | -------
162 | points : torch.Tensor [B,H,W,2]
163 | 2D projected points that are within the image boundaries
164 | """
165 | B, C, H, W = X.shape
166 | assert C == 3
167 |
168 | # Project 3D points onto the camera image plane
169 | if frame == 'c':
170 | Xc = self.K.bmm(X.view(B, 3, -1))
171 | elif frame == 'w':
172 | Xc = self.K.bmm((self.Tcw @ X).view(B, 3, -1))
173 | else:
174 | raise ValueError('Unknown reference frame {}'.format(frame))
175 |
176 | # Normalize points
177 | X = Xc[:, 0]
178 | Y = Xc[:, 1]
179 | Z = Xc[:, 2].clamp(min=1e-5)
180 | if normalize:
181 | Xnorm = 2 * (X / Z) / (W - 1) - 1. #(-1, 1)
182 | Ynorm = 2 * (Y / Z) / (H - 1) - 1.
183 | else:
184 | Xnorm = X / Z
185 | Ynorm = Y / Z
186 |
187 | # Clamp out-of-bounds pixels
188 | # Xmask = ((Xnorm > 1) + (Xnorm < -1)).detach()
189 | # Xnorm[Xmask] = 2.
190 | # Ymask = ((Ynorm > 1) + (Ynorm < -1)).detach()
191 | # Ynorm[Ymask] = 2.
192 |
193 | # Return pixel coordinates
194 | return torch.stack([Xnorm, Ynorm], dim=-1).view(B, H, W, 2)
195 |
--------------------------------------------------------------------------------
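A small round-trip sketch with the Camera class above: lift a constant depth map to 3D points in the camera frame and project them back, which recovers the pixel grid in the normalised (-1, 1) range used by project. The intrinsics and depth are synthetic and the batch size is 1:

    import torch
    from dro_sfm.geometry.camera import Camera

    K = torch.tensor([[[100., 0., 8.],
                       [0., 100., 4.],
                       [0., 0., 1.]]])           # [1,3,3] synthetic intrinsics
    cam = Camera(K)

    depth = torch.ones(1, 1, 8, 16) * 2.0         # constant depth map
    points = cam.reconstruct(depth, frame='c')    # [1,3,8,16] points in the camera frame
    coords = cam.project(points, frame='c')       # [1,8,16,2] normalised pixel coordinates

    print(points.shape, coords.shape)
    print(coords.min().item(), coords.max().item())   # close to -1 and 1
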
/dro_sfm/geometry/camera_utils.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn.functional as funct
4 |
5 | ########################################################################################################################
6 |
7 | def construct_K(fx, fy, cx, cy, dtype=torch.float, device=None):
8 | """Construct a [3,3] camera intrinsics from pinhole parameters"""
9 | return torch.tensor([[fx, 0, cx],
10 | [ 0, fy, cy],
11 | [ 0, 0, 1]], dtype=dtype, device=device)
12 |
13 | def scale_intrinsics(K, x_scale, y_scale):
14 | """Scale intrinsics given x_scale and y_scale factors"""
15 | K[..., 0, 0] *= x_scale
16 | K[..., 1, 1] *= y_scale
17 | K[..., 0, 2] = (K[..., 0, 2] + 0.5) * x_scale - 0.5
18 | K[..., 1, 2] = (K[..., 1, 2] + 0.5) * y_scale - 0.5
19 | return K
20 |
21 | ########################################################################################################################
22 |
23 | def view_synthesis(ref_image, depth, ref_cam, cam,
24 | mode='bilinear', padding_mode='zeros'):
25 | """
26 | Synthesize an image from another plus a depth map.
27 |
28 | Parameters
29 | ----------
30 | ref_image : torch.Tensor [B,3,H,W]
31 | Reference image to be warped
32 | depth : torch.Tensor [B,1,H,W]
33 | Depth map from the original image
34 | ref_cam : Camera
35 | Camera class for the reference image
36 | cam : Camera
37 | Camera class for the original image
38 | mode : str
39 | Interpolation mode
40 | padding_mode : str
41 | Padding mode for interpolation
42 |
43 | Returns
44 | -------
45 | ref_warped : torch.Tensor [B,3,H,W]
46 | Warped reference image in the original frame of reference
47 | """
48 | assert depth.size(1) == 1
49 | # Reconstruct world points from target_camera
50 | world_points = cam.reconstruct(depth, frame='w')
51 | # Project world points onto reference camera
52 | ref_coords = ref_cam.project(world_points, frame='w')
53 |
54 | # View-synthesis given the projected reference points
55 | return funct.grid_sample(ref_image, ref_coords, mode=mode,
56 | padding_mode=padding_mode, align_corners=True)
57 |
58 | ########################################################################################################################
59 |
60 |
--------------------------------------------------------------------------------
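Note that scale_intrinsics above modifies K in place, which is why Camera.scaled clones the intrinsics first. A small sketch of the scaling convention with synthetic intrinsics:

    import torch
    from dro_sfm.geometry.camera_utils import scale_intrinsics

    K = torch.tensor([[[100., 0., 49.5],
                       [0., 100., 29.5],
                       [0., 0., 1.]]])

    # Halving the resolution: focal lengths scale linearly and the principal
    # point follows the (c + 0.5) * scale - 0.5 pixel-centre convention.
    K_half = scale_intrinsics(K.clone(), x_scale=0.5, y_scale=0.5)
    print(K_half[0])   # fx = fy = 50, principal point at (24.5, 14.5)
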
/dro_sfm/geometry/pose.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from dro_sfm.geometry.pose_utils import invert_pose, pose_vec2mat
4 |
5 | ########################################################################################################################
6 |
7 | class Pose:
8 | """
9 | Pose class, that encapsulates a [4,4] transformation matrix
10 | for a specific reference frame
11 | """
12 | def __init__(self, mat):
13 | """
14 | Initializes a Pose object.
15 |
16 | Parameters
17 | ----------
18 | mat : torch.Tensor [B,4,4]
19 | Transformation matrix
20 | """
21 | assert tuple(mat.shape[-2:]) == (4, 4)
22 | if mat.dim() == 2:
23 | mat = mat.unsqueeze(0)
24 | assert mat.dim() == 3
25 | self.mat = mat
26 |
27 | def __len__(self):
28 | """Batch size of the transformation matrix"""
29 | return len(self.mat)
30 |
31 | ########################################################################################################################
32 |
33 | @classmethod
34 | def identity(cls, N=1, device=None, dtype=torch.float):
35 | """Initializes as a [4,4] identity matrix"""
36 | return cls(torch.eye(4, device=device, dtype=dtype).repeat([N,1,1]))
37 |
38 | @classmethod
39 | def from_vec(cls, vec, mode):
40 | """Initializes from a [B,6] batch vector"""
41 | mat = pose_vec2mat(vec, mode) # [B,3,4]
42 | pose = torch.eye(4, device=vec.device, dtype=vec.dtype).repeat([len(vec), 1, 1])
43 | pose[:, :3, :3] = mat[:, :3, :3]
44 | pose[:, :3, -1] = mat[:, :3, -1]
45 | return cls(pose)
46 |
47 | ########################################################################################################################
48 |
49 | @property
50 | def shape(self):
51 | """Returns the transformation matrix shape"""
52 | return self.mat.shape
53 |
54 | def item(self):
55 | """Returns the transformation matrix"""
56 | return self.mat
57 |
58 | def repeat(self, *args, **kwargs):
59 | """Repeats the transformation matrix multiple times"""
60 | self.mat = self.mat.repeat(*args, **kwargs)
61 | return self
62 |
63 | def inverse(self):
64 | """Returns a new Pose that is the inverse of this one"""
65 | return Pose(invert_pose(self.mat))
66 |
67 | def to(self, *args, **kwargs):
68 | """Moves object to a specific device"""
69 | self.mat = self.mat.to(*args, **kwargs)
70 | return self
71 |
72 | ########################################################################################################################
73 |
74 | def transform_pose(self, pose):
75 | """Creates a new pose object that compounds this and another one (self * pose)"""
76 | assert tuple(pose.shape[-2:]) == (4, 4)
77 | return Pose(self.mat.bmm(pose.item()))
78 |
79 | def transform_points(self, points):
80 | """Transforms 3D points using this object"""
81 | assert points.shape[1] == 3
82 | B, _, H, W = points.shape
83 | out = self.mat[:,:3,:3].bmm(points.view(B, 3, -1)) + \
84 | self.mat[:,:3,-1].unsqueeze(-1)
85 | return out.view(B, 3, H, W)
86 |
87 | def __matmul__(self, other):
88 | """Transforms the input (Pose or 3D points) using this object"""
89 | if isinstance(other, Pose):
90 | return self.transform_pose(other)
91 | elif isinstance(other, torch.Tensor):
92 | if other.shape[1] == 3 and other.dim() > 2:
93 | assert other.dim() == 3 or other.dim() == 4
94 | return self.transform_points(other)
95 | else:
96 | raise ValueError('Unknown tensor dimensions {}'.format(other.shape))
97 | else:
98 | raise NotImplementedError()
99 |
100 | ########################################################################################################################
101 |
--------------------------------------------------------------------------------
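A short sketch of composing a Pose above with its inverse, which should give the identity transform. The pose below is synthetic (identity rotation plus a translation):

    import torch
    from dro_sfm.geometry.pose import Pose

    mat = torch.eye(4)
    mat[:3, 3] = torch.tensor([1., 2., 3.])
    pose = Pose(mat)                      # promoted internally to a [1,4,4] batch

    identity = pose @ pose.inverse()      # composition via __matmul__ / transform_pose
    print(torch.allclose(identity.item(), torch.eye(4).unsqueeze(0), atol=1e-6))  # True
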
/dro_sfm/geometry/pose_utils.py:
--------------------------------------------------------------------------------
1 |
2 | from dro_sfm.geometry.pose_trans import axis_angle_to_matrix
3 | import torch
4 | import numpy as np
5 | from torch._C import dtype
6 |
7 | def mat2euler(mat):
8 | euler = torch.ones(mat.shape[0], 3, dtype=mat.dtype, device=mat.device)
9 | cy_thresh = 1e-6
10 | # try:
11 | # cy_thresh = np.finfo(mat.dtype).eps * 4
12 | # except ValueError:
13 | # cy_thresh = np.finfo(np.float).eps * 4.0
14 | # print("cy_thresh", cy_thresh)
15 | r11, r12, r13, r21, r22, r23, r31, r32, r33 = mat[:, 0, 0], mat[:, 0, 1], mat[:, 0, 2], \
16 | mat[:, 1, 0], mat[:, 1, 1], mat[:, 1, 2], \
17 | mat[:, 2, 0], mat[:, 2, 1], mat[:, 2, 2]
18 | # cy: sqrt((cos(y)*cos(z))**2 + (cos(x)*cos(y))**2)
19 | cy = torch.sqrt(r33 * r33 + r23 * r23)
20 |
21 | mask = cy > cy_thresh
22 |
23 |     if torch.sum(mask) > 0:
24 | euler[mask, 0] = torch.atan2(-r23, r33)[mask]
25 | euler[mask, 1] = torch.atan2(r13, cy)[mask]
26 | euler[mask, 2] = torch.atan2(-r12, r11)[mask]
27 |
28 | mask = cy <= cy_thresh
29 |     if torch.sum(mask) > 0:
30 |         print("mat2euler!!!!!!")
31 |         euler[mask, 0] = 0.0
32 |         euler[mask, 1] = torch.atan2(r13, cy)[mask]  # atan2(sin(y), cy)
33 |         euler[mask, 2] = torch.atan2(r21, r22)[mask]
34 |
35 | return euler
36 |
37 |
38 | ########################################################################################################################
39 |
40 | def euler2mat(angle):
41 | """Convert euler angles to rotation matrix"""
42 | B = angle.size(0)
43 | x, y, z = angle[:, 0], angle[:, 1], angle[:, 2]
44 |
45 | cosz = torch.cos(z)
46 | sinz = torch.sin(z)
47 |
48 | zeros = z.detach() * 0
49 | ones = zeros.detach() + 1
50 | zmat = torch.stack([cosz, -sinz, zeros,
51 | sinz, cosz, zeros,
52 | zeros, zeros, ones], dim=1).view(B, 3, 3)
53 |
54 | cosy = torch.cos(y)
55 | siny = torch.sin(y)
56 |
57 | ymat = torch.stack([cosy, zeros, siny,
58 | zeros, ones, zeros,
59 | -siny, zeros, cosy], dim=1).view(B, 3, 3)
60 |
61 | cosx = torch.cos(x)
62 | sinx = torch.sin(x)
63 |
64 | xmat = torch.stack([ones, zeros, zeros,
65 | zeros, cosx, -sinx,
66 | zeros, sinx, cosx], dim=1).view(B, 3, 3)
67 |
68 | rot_mat = xmat.bmm(ymat).bmm(zmat)
69 | return rot_mat
70 |
71 | ########################################################################################################################
72 |
73 | def pose_vec2mat(vec, mode='euler'):
74 | """Convert Euler parameters to transformation matrix."""
75 | if mode is None:
76 | return vec
77 | trans, rot = vec[:, :3].unsqueeze(-1), vec[:, 3:]
78 | if mode == 'euler':
79 | rot_mat = euler2mat(rot)
80 | elif mode == 'axis_angle':
81 | rot_mat = axis_angle_to_matrix(rot)
82 | else:
83 | raise ValueError('Rotation mode not supported {}'.format(mode))
84 | mat = torch.cat([rot_mat, trans], dim=2) # [B,3,4]
85 | return mat
86 |
87 | ########################################################################################################################
88 |
89 | def invert_pose(T):
90 | """Inverts a [B,4,4] torch.tensor pose"""
91 | Tinv = torch.eye(4, device=T.device, dtype=T.dtype).repeat([len(T), 1, 1])
92 | Tinv[:, :3, :3] = torch.transpose(T[:, :3, :3], -2, -1)
93 | Tinv[:, :3, -1] = torch.bmm(-1. * Tinv[:, :3, :3], T[:, :3, -1].unsqueeze(-1)).squeeze(-1)
94 | return Tinv
95 |
96 | ########################################################################################################################
97 |
98 | def invert_pose_numpy(T):
99 | """Inverts a [4,4] np.array pose"""
100 | Tinv = np.copy(T)
101 | R, t = Tinv[:3, :3], Tinv[:3, 3]
102 | Tinv[:3, :3], Tinv[:3, 3] = R.T, - np.matmul(R.T, t)
103 | return Tinv
104 |
105 | ########################################################################################################################
106 |
--------------------------------------------------------------------------------
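invert_pose and invert_pose_numpy above exploit the structure of a rigid transform (transpose the rotation, rotate and negate the translation) instead of a generic matrix inverse. A quick check against np.linalg.inv with a synthetic pose:

    import numpy as np
    from dro_sfm.geometry.pose_utils import invert_pose_numpy

    theta = np.deg2rad(30.)
    T = np.eye(4)
    T[:3, :3] = np.array([[np.cos(theta), -np.sin(theta), 0.],
                          [np.sin(theta),  np.cos(theta), 0.],
                          [0., 0., 1.]])   # rotation about z
    T[:3, 3] = [0.5, -1.0, 2.0]            # translation

    print(np.allclose(invert_pose_numpy(T), np.linalg.inv(T)))  # True
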
/dro_sfm/loggers/__init__.py:
--------------------------------------------------------------------------------
1 | from dro_sfm.loggers.wandb_logger import WandbLogger
2 |
3 | __all__ = ["WandbLogger"]
--------------------------------------------------------------------------------
/dro_sfm/losses/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aliyun/dro-sfm/8707e2e0ef799d7d47418a018060f503ef449fe3/dro_sfm/losses/__init__.py
--------------------------------------------------------------------------------
/dro_sfm/losses/loss_base.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import torch.nn as nn
4 | from dro_sfm.utils.types import is_list
5 |
6 | ########################################################################################################################
7 |
8 | class ProgressiveScaling:
9 | """
10 | Helper class to manage progressive scaling.
11 | After a certain training progress percentage, decrease the number of scales by 1.
12 |
13 | Parameters
14 | ----------
15 | progressive_scaling : float
16 | Training progress percentage where the number of scales is decreased
17 | num_scales : int
18 | Initial number of scales
19 | """
20 | def __init__(self, progressive_scaling, num_scales=4):
21 | self.num_scales = num_scales
22 | # Use it only if bigger than zero (make a list)
23 | if progressive_scaling > 0.0:
24 | self.progressive_scaling = np.float32(
25 | [progressive_scaling * (i + 1) for i in range(num_scales - 1)] + [1.0])
26 | # Otherwise, disable it
27 | else:
28 | self.progressive_scaling = progressive_scaling
29 | def __call__(self, progress):
30 | """
31 | Call for an update in the number of scales
32 |
33 | Parameters
34 | ----------
35 | progress : float
36 | Training progress percentage
37 |
38 | Returns
39 | -------
40 | num_scales : int
41 | New number of scales
42 | """
43 | if is_list(self.progressive_scaling):
44 | return int(self.num_scales -
45 | np.searchsorted(self.progressive_scaling, progress))
46 | else:
47 | return self.num_scales
48 |
49 | ########################################################################################################################
50 |
51 | class LossBase(nn.Module):
52 | """Base class for losses."""
53 | def __init__(self):
54 | """Initializes logs and metrics dictionaries"""
55 | super().__init__()
56 | self._logs = {}
57 | self._metrics = {}
58 |
59 | ########################################################################################################################
60 |
61 | @property
62 | def logs(self):
63 | """Return logs."""
64 | return self._logs
65 |
66 | @property
67 | def metrics(self):
68 | """Return metrics."""
69 | return self._metrics
70 |
71 | def add_metric(self, key, val):
72 | """Add a new metric to the dictionary and detach it."""
73 | self._metrics[key] = val.detach()
74 |
75 | ########################################################################################################################
76 |
--------------------------------------------------------------------------------
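The scale schedule in ProgressiveScaling above boils down to a searchsorted lookup over evenly spaced progress thresholds. A standalone sketch of that lookup with progressive_scaling=0.3 and four scales:

    import numpy as np

    progressive_scaling, num_scales = 0.3, 4

    # Same thresholds ProgressiveScaling builds internally: [0.3, 0.6, 0.9, 1.0]
    thresholds = np.float32(
        [progressive_scaling * (i + 1) for i in range(num_scales - 1)] + [1.0])

    for progress in (0.0, 0.4, 0.7, 0.95):
        active = int(num_scales - np.searchsorted(thresholds, progress))
        print(progress, active)   # 4, 3, 2 and 1 active scales respectively
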
/dro_sfm/models/SelfSupModelMF.py:
--------------------------------------------------------------------------------
1 | from dro_sfm.models.SfmModelMF import SfmModelMF
2 | from dro_sfm.losses.multiview_photometric_loss_mf import MultiViewPhotometricDecayLoss
3 | from dro_sfm.models.model_utils import merge_outputs
4 |
5 |
6 | class SelfSupModelMF(SfmModelMF):
7 | """
8 |     Model that inherits depth and pose networks from SfmModel and
9 | includes the photometric loss for self-supervised training.
10 |
11 | Parameters
12 | ----------
13 | kwargs : dict
14 | Extra parameters
15 | """
16 | def __init__(self, **kwargs):
17 | # Initializes SfmModel
18 | super().__init__(**kwargs)
19 | # Initializes the photometric loss
20 | self._photometric_loss = MultiViewPhotometricDecayLoss(**kwargs)
21 |
22 | @property
23 | def logs(self):
24 | """Return logs."""
25 | return {
26 | **super().logs,
27 | **self._photometric_loss.logs
28 | }
29 |
30 | def self_supervised_loss(self, image, ref_images, inv_depths, poses,
31 | intrinsics, return_logs=False, progress=0.0):
32 | """
33 | Calculates the self-supervised photometric loss.
34 |
35 | Parameters
36 | ----------
37 | image : torch.Tensor [B,3,H,W]
38 | Original image
39 | ref_images : list of torch.Tensor [B,3,H,W]
40 | Reference images from context
41 | inv_depths : torch.Tensor [B,1,H,W]
42 | Predicted inverse depth maps from the original image
43 | poses : list of Pose
44 | List containing predicted poses between original and context images
45 | intrinsics : torch.Tensor [B,3,3]
46 | Camera intrinsics
47 | return_logs : bool
48 | True if logs are stored
49 |         progress : float
50 | Training progress percentage
51 |
52 | Returns
53 | -------
54 | output : dict
55 |             Dictionary containing a "loss" scalar and a "metrics" dictionary
56 | """
57 | return self._photometric_loss(
58 | image, ref_images, inv_depths, intrinsics, intrinsics, poses,
59 | return_logs=return_logs, progress=progress)
60 |
61 | def forward(self, batch, return_logs=False, progress=0.0):
62 | """
63 | Processes a batch.
64 |
65 | Parameters
66 | ----------
67 | batch : dict
68 | Input batch
69 | return_logs : bool
70 | True if logs are stored
71 |         progress : float
72 | Training progress percentage
73 |
74 | Returns
75 | -------
76 | output : dict
77 | Dictionary containing a "loss" scalar and different metrics and predictions
78 | for logging and downstream usage.
79 | """
80 | # Calculate predicted depth and pose output
81 | output = super().forward(batch, return_logs=return_logs)
82 | if not self.training:
83 | # If not training, no need for self-supervised loss
84 | return output
85 | else:
86 | if output["poses"] is None:
87 | return None
88 | # Otherwise, calculate self-supervised loss
89 | self_sup_output = self.self_supervised_loss(
90 | batch['rgb_original'], batch['rgb_context_original'],
91 | output['inv_depths'], output['poses'], batch['intrinsics'],
92 | return_logs=return_logs, progress=progress)
93 | # Return loss and metrics
94 | return {
95 | 'loss': self_sup_output['loss'],
96 | **merge_outputs(output, self_sup_output),
97 | }
98 |
--------------------------------------------------------------------------------
/dro_sfm/models/SemiSupModelMF.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from dro_sfm.models.SelfSupModelMF import SelfSupModelMF, SfmModelMF
3 | from dro_sfm.losses.supervised_loss import SupervisedDepthPoseLoss as SupervisedLoss
4 | from dro_sfm.models.model_utils import merge_outputs
5 | from dro_sfm.utils.depth import depth2inv
6 |
7 |
8 | class SemiSupModelMFPose(SelfSupModelMF):
9 | """
10 |     Model that inherits depth and pose networks, plus the self-supervised loss from
11 | SelfSupModel and includes a supervised loss for semi-supervision.
12 |
13 | Parameters
14 | ----------
15 | supervised_loss_weight : float
16 | Weight for the supervised loss
17 | kwargs : dict
18 | Extra parameters
19 | """
20 | def __init__(self, supervised_loss_weight=0.9, **kwargs):
21 | # Initializes SelfSupModel
22 | super().__init__(**kwargs)
23 | # If supervision weight is 0.0, use SelfSupModel directly
24 | assert 0. < supervised_loss_weight <= 1., "Model requires (0, 1] supervision"
25 | # Store weight and initializes supervised loss
26 | self.supervised_loss_weight = supervised_loss_weight
27 | self._supervised_loss = SupervisedLoss(**kwargs)
28 |
29 | print(f"=================supervised_loss_weight:{supervised_loss_weight}====================")
30 | # # Pose network is only required if there is self-supervision
31 | # self._network_requirements['pose_net'] = self.supervised_loss_weight < 1
32 | # # GT depth is only required if there is supervision
33 | self._train_requirements['gt_depth'] = self.supervised_loss_weight > 0
34 | self._train_requirements['gt_pose'] = True
35 |
36 | @property
37 | def logs(self):
38 | """Return logs."""
39 | return {
40 | **super().logs,
41 | **self._supervised_loss.logs
42 | }
43 |
44 | def supervised_loss(self, image, ref_images, inv_depths, gt_depth, gt_poses, poses,
45 | intrinsics, return_logs=False, progress=0.0):
46 | """
47 |         Calculates the supervised depth and pose loss.
48 |
49 | Parameters
50 | ----------
51 | image : torch.Tensor [B,3,H,W]
52 | Original image
53 | ref_images : list of torch.Tensor [B,3,H,W]
54 | Reference images from context
55 | inv_depths : torch.Tensor [B,1,H,W]
56 | Predicted inverse depth maps from the original image
57 | poses : list of Pose
58 | List containing predicted poses between original and context images
59 | intrinsics : torch.Tensor [B,3,3]
60 | Camera intrinsics
61 | return_logs : bool
62 | True if logs are stored
63 |         progress : float
64 | Training progress percentage
65 |
66 | Returns
67 | -------
68 | output : dict
69 |             Dictionary containing a "loss" scalar and a "metrics" dictionary
70 | """
71 | return self._supervised_loss(
72 | image, ref_images, inv_depths, depth2inv(gt_depth), gt_poses, intrinsics, intrinsics, poses,
73 | return_logs=return_logs, progress=progress)
74 |
75 | def forward(self, batch, return_logs=False, progress=0.0):
76 | """
77 | Processes a batch.
78 |
79 | Parameters
80 | ----------
81 | batch : dict
82 | Input batch
83 | return_logs : bool
84 | True if logs are stored
85 | progress :
86 | Training progress percentage
87 |
88 | Returns
89 | -------
90 | output : dict
91 | Dictionary containing a "loss" scalar and different metrics and predictions
92 | for logging and downstream usage.
93 | """
94 | if not self.training:
95 | # If not training, no need for self-supervised loss
96 | return SfmModelMF.forward(self, batch)
97 | else:
98 | if self.supervised_loss_weight == 1.:
99 | # If no self-supervision, no need to calculate loss
100 | self_sup_output = SfmModelMF.forward(self, batch)
101 | loss = torch.tensor([0.]).type_as(batch['rgb'])
102 | else:
103 | # Otherwise, calculate and weight self-supervised loss
104 | self_sup_output = SelfSupModelMF.forward(self, batch)
105 | loss = (1.0 - self.supervised_loss_weight) * self_sup_output['loss']
106 | # Calculate and weight supervised loss
107 | sup_output = self.supervised_loss(
108 | batch['rgb_original'], batch['rgb_context_original'],
109 | self_sup_output['inv_depths'], batch['depth'], batch['pose_context'], self_sup_output['poses'], batch['intrinsics'],
110 | return_logs=return_logs, progress=progress)
111 | loss += self.supervised_loss_weight * sup_output['loss']
112 | # Merge and return outputs
113 | return {
114 | 'loss': loss,
115 | **merge_outputs(self_sup_output, sup_output),
116 | }
--------------------------------------------------------------------------------
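
Illustration (not a file in this repo): how the (1 - w) / w weighting in SemiSupModelMFPose.forward combines the two losses, with hypothetical scalar tensors standing in for the real loss outputs.

    import torch

    def combine_losses(self_sup_loss, sup_loss, supervised_loss_weight=0.9):
        # Mirrors the branch above: w == 1.0 skips the self-supervised term entirely
        if supervised_loss_weight == 1.0:
            return sup_loss
        return (1.0 - supervised_loss_weight) * self_sup_loss + supervised_loss_weight * sup_loss

    total = combine_losses(torch.tensor(0.12), torch.tensor(0.05))
    print(round(float(total), 3))  # 0.057 = 0.1 * 0.12 + 0.9 * 0.05
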
/dro_sfm/models/SfmModel.py:
--------------------------------------------------------------------------------
1 |
2 | import random
3 |
4 | import torch.nn as nn
5 | from dro_sfm.utils.image import flip_model, interpolate_scales
6 | from dro_sfm.geometry.pose import Pose
7 | from dro_sfm.utils.misc import make_list
8 | from dro_sfm.utils.depth import inv2depth
9 |
10 |
11 | class SfmModel(nn.Module):
12 | """
13 | Model class encapsulating pose and depth networks.
14 |
15 | Parameters
16 | ----------
17 | depth_net : nn.Module
18 | Depth network to be used
19 | pose_net : nn.Module
20 | Pose network to be used
21 | rotation_mode : str
22 | Rotation mode for the pose network
23 | flip_lr_prob : float
24 | Probability of flipping when using the depth network
25 | upsample_depth_maps : bool
26 | True if depth map scales are upsampled to highest resolution
27 | kwargs : dict
28 | Extra parameters
29 | """
30 | def __init__(self, depth_net=None, pose_net=None,
31 | rotation_mode='euler', flip_lr_prob=0.0,
32 | upsample_depth_maps=False, **kwargs):
33 | super().__init__()
34 | self.depth_net = depth_net
35 | self.pose_net = pose_net
36 | self.rotation_mode = rotation_mode
37 | self.flip_lr_prob = flip_lr_prob
38 | self.upsample_depth_maps = upsample_depth_maps
39 | self._logs = {}
40 | self._losses = {}
41 |
42 | self._network_requirements = {
43 | 'depth_net': True, # Depth network required
44 | 'pose_net': True, # Pose network required
45 | 'percep_net': False, # Perceptual network not required
46 | }
47 | self._train_requirements = {
48 | 'gt_depth': False, # No ground-truth depth required
49 | 'gt_pose': False, # No ground-truth pose required
50 | }
51 |
52 | @property
53 | def logs(self):
54 | """Return logs."""
55 | return self._logs
56 |
57 | @property
58 | def losses(self):
59 | """Return metrics."""
60 | return self._losses
61 |
62 | def add_loss(self, key, val):
63 | """Add a new loss to the dictionary and detaches it."""
64 | self._losses[key] = val.detach()
65 |
66 | @property
67 | def network_requirements(self):
68 | """
69 | Networks required to run the model
70 |
71 | Returns
72 | -------
73 | requirements : dict
74 | depth_net : bool
75 | Whether a depth network is required by the model
76 | pose_net : bool
77 | Whether a pose network is required by the model
78 | """
79 | return self._network_requirements
80 |
81 | @property
82 | def train_requirements(self):
83 | """
84 | Information required by the model at training stage
85 |
86 | Returns
87 | -------
88 | requirements : dict
89 | gt_depth : bool
90 | Whether ground truth depth is required by the model at training time
91 | gt_pose : bool
92 | Whether ground truth pose is required by the model at training time
93 | """
94 | return self._train_requirements
95 |
96 | def add_depth_net(self, depth_net):
97 | """Add a depth network to the model"""
98 | self.depth_net = depth_net
99 |
100 | def add_pose_net(self, pose_net):
101 | """Add a pose network to the model"""
102 | self.pose_net = pose_net
103 |
104 | def compute_inv_depths(self, image):
105 | """Computes inverse depth maps from single images"""
106 | # Randomly flip and estimate inverse depth maps
107 | flip_lr = random.random() < self.flip_lr_prob if self.training else False
108 | inv_depths = make_list(flip_model(self.depth_net, image, flip_lr))
109 | # If upsampling depth maps
110 | if self.upsample_depth_maps:
111 | inv_depths = interpolate_scales(
112 | inv_depths, mode='nearest', align_corners=None)
113 | # Return inverse depth maps
114 | return inv_depths
115 |
116 | def compute_poses(self, image, contexts, intrinsics, depth):
117 | """Compute poses from image and a sequence of context images"""
118 | pose_vec = self.pose_net(image, contexts, intrinsics, depth)
119 | if pose_vec is None:
120 | return None
121 | if pose_vec.shape[2] == 6:
122 | return [Pose.from_vec(pose_vec[:, i], self.rotation_mode)
123 | for i in range(pose_vec.shape[1])]
124 | else:
125 | return [Pose(pose_vec[:, i]) for i in range(pose_vec.shape[1])]
126 |
127 | def forward(self, batch, return_logs=False):
128 | """
129 | Processes a batch.
130 |
131 | Parameters
132 | ----------
133 | batch : dict
134 | Input batch
135 | return_logs : bool
136 | True if logs are stored
137 |
138 | Returns
139 | -------
140 | output : dict
141 | Dictionary containing predicted inverse depth maps and poses
142 | """
143 | # Generate inverse depth predictions
144 | inv_depths = self.compute_inv_depths(batch['rgb'])
145 | # Generate pose predictions if available
146 | pose = None
147 | if 'rgb_context' in batch and self.pose_net is not None:
148 | pose = self.compute_poses(batch['rgb'],
149 | batch['rgb_context'], batch["intrinsics"], inv2depth(inv_depths[0]))
150 | # Return output dictionary
151 | return {
152 | 'inv_depths': inv_depths,
153 | 'poses': pose,
154 | }
155 |
--------------------------------------------------------------------------------
/dro_sfm/models/SfmModelMF.py:
--------------------------------------------------------------------------------
1 |
2 | import random
3 | import torch.nn as nn
4 | from dro_sfm.utils.image import flip_lr_intr, flip_mf_model, interpolate_scales
5 | from dro_sfm.utils.image import flip_lr as flip_lr_img
6 | from dro_sfm.geometry.pose import Pose
7 | from dro_sfm.utils.misc import make_list
8 |
9 |
10 | class SfmModelMF(nn.Module):
11 | """
12 | Model class encapsulating pose and depth networks.
13 |
14 | Parameters
15 | ----------
16 | depth_net : nn.Module
17 | Depth network to be used
18 | pose_net : nn.Module
19 | Pose network to be used
20 | rotation_mode : str
21 | Rotation mode for the pose network
22 | flip_lr_prob : float
23 | Probability of flipping when using the depth network
24 | upsample_depth_maps : bool
25 | True if depth map scales are upsampled to highest resolution
26 | kwargs : dict
27 | Extra parameters
28 | """
29 | def __init__(self, depth_net=None, pose_net=None,
30 | rotation_mode='euler', flip_lr_prob=0.0,
31 | upsample_depth_maps=False, min_depth=0.1, max_depth=100, **kwargs):
32 | super().__init__()
33 | self.depth_net = depth_net
34 | self.pose_net = pose_net
35 | self.rotation_mode = rotation_mode
36 | self.flip_lr_prob = flip_lr_prob
37 | self.upsample_depth_maps = upsample_depth_maps
38 | self.min_depth = min_depth
39 | self.max_depth = max_depth
40 | self._logs = {}
41 | self._losses = {}
42 | self._network_requirements = {
43 | 'depth_net': True, # Depth network required
44 | 'pose_net': False, # Pose network required
45 | 'percep_net': False, # Pose network required
46 | }
47 | self._train_requirements = {
48 | 'gt_depth': False, # No ground-truth depth required
49 | 'gt_pose': False, # No ground-truth pose required
50 | }
51 |
52 | @property
53 | def logs(self):
54 | """Return logs."""
55 | return self._logs
56 |
57 | @property
58 | def losses(self):
59 | """Return metrics."""
60 | return self._losses
61 |
62 | def add_loss(self, key, val):
63 | """Add a new loss to the dictionary and detaches it."""
64 | self._losses[key] = val.detach()
65 |
66 | @property
67 | def network_requirements(self):
68 | """
69 | Networks required to run the model
70 |
71 | Returns
72 | -------
73 | requirements : dict
74 | depth_net : bool
75 | Whether a depth network is required by the model
76 | pose_net : bool
77 | Whether a pose network is required by the model
78 | """
79 | return self._network_requirements
80 |
81 | @property
82 | def train_requirements(self):
83 | """
84 | Information required by the model at training stage
85 |
86 | Returns
87 | -------
88 | requirements : dict
89 | gt_depth : bool
90 | Whether ground truth depth is required by the model at training time
91 | gt_pose : bool
92 | Whether ground truth pose is required by the model at training time
93 | """
94 | return self._train_requirements
95 |
96 | def add_depth_net(self, depth_net):
97 | """Add a depth network to the model"""
98 | self.depth_net = depth_net
99 |
100 | def add_pose_net(self, pose_net):
101 | """Add a pose network to the model"""
102 | self.pose_net = pose_net
103 |
104 | def compute_inv_depths(self, image, ref_imgs, intrinsics):
105 | """Computes inverse depth maps from single images"""
106 | # Randomly flip and estimate inverse depth maps
107 | flip_lr = random.random() < self.flip_lr_prob if self.training else False
108 | if flip_lr:
109 | intrinsics = flip_lr_intr(intrinsics, width=image.shape[3])
110 | inv_depths_with_poses = flip_mf_model(self.depth_net, image, ref_imgs, intrinsics, flip_lr)
111 | inv_depths, poses = inv_depths_with_poses
112 | inv_depths = make_list(inv_depths)
113 | if flip_lr:
114 | inv_depths = [flip_lr_img(inv_d) for inv_d in inv_depths]
115 | # If upsampling depth maps
116 | if self.upsample_depth_maps:
117 | inv_depths = interpolate_scales(
118 | inv_depths, mode='nearest', align_corners=None)
119 | # Return inverse depth maps
120 | return inv_depths, poses
121 |
122 | def compute_poses(self, image, contexts, intrinsics, depth):
123 | """Compute poses from image and a sequence of context images"""
124 | pose_vec = self.pose_net(image, contexts, intrinsics, depth)
125 | if pose_vec is None:
126 | return None
127 | if pose_vec.shape[2] == 6:
128 | return [Pose.from_vec(pose_vec[:, i], self.rotation_mode)
129 | for i in range(pose_vec.shape[1])]
130 | else:
131 | return [Pose(pose_vec[:, i]) for i in range(pose_vec.shape[1])]
132 |
133 | def forward(self, batch, return_logs=False):
134 | """
135 | Processes a batch.
136 |
137 | Parameters
138 | ----------
139 | batch : dict
140 | Input batch
141 | return_logs : bool
142 | True if logs are stored
143 |
144 | Returns
145 | -------
146 | output : dict
147 | Dictionary containing predicted inverse depth maps and poses
148 | """
149 | # Generate inverse depth predictions
150 | inv_depths, pose_vec = self.compute_inv_depths(batch['rgb'], batch['rgb_context'], batch["intrinsics"])
151 | # # Generate pose predictions if available
152 | # pose = None
153 | # if 'rgb_context' in batch and self.pose_net is not None:
154 | # pose = self.compute_poses(batch['rgb'],
155 | # batch['rgb_context'], batch["intrinsics"], inv2depth(inv_depths[0]))
156 | # Return output dictionary
157 | if pose_vec.shape[2] == 6:
158 | poses = [Pose.from_vec(pose_vec[:, i], self.rotation_mode)
159 | for i in range(pose_vec.shape[1])]
160 | elif pose_vec.shape[2] == 4 and pose_vec.shape[3] == 4:
161 | poses = [Pose(pose_vec[:, i]) for i in range(pose_vec.shape[1])]
162 | else:
163 | # pose_vec shape: (b, n_view, n_iter, 6)
164 | poses = []
165 | for i in range(pose_vec.shape[1]):
166 | poses_view = []
167 | for j in range(pose_vec.shape[2]):
168 | poses_view.append(Pose.from_vec(pose_vec[:, i, j], self.rotation_mode))
169 | poses.append(poses_view)  # [poses_view1, poses_view2, ...]; each view holds n_iter poses
170 |
171 | # print(poses[0][-1].shape, len(poses), len(poses[0]), len(inv_depths), inv_depths[0].shape)
172 | # print(poses[0][-1].mat[0], inv2depth(inv_depths)[-1][0, 0, 12, 40])
173 | # print("gt", batch["pose_context"][0][0])
174 | return {
175 | 'inv_depths': inv_depths,
176 | 'poses': poses,
177 | }
--------------------------------------------------------------------------------
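
Illustration (not a file in this repo): two of the pose_vec branches in SfmModelMF.forward in isolation, using dro_sfm.geometry.pose.Pose with a random tensor standing in for the depth-pose network output; the printed shape assumes Pose stores a [B,4,4] matrix in .mat, as the commented-out prints above suggest.

    import torch
    from dro_sfm.geometry.pose import Pose

    def vec_to_poses(pose_vec, rotation_mode='euler'):
        if pose_vec.dim() == 3 and pose_vec.shape[2] == 6:
            # [B, n_view, 6]: one translation+rotation vector per context view
            return [Pose.from_vec(pose_vec[:, i], rotation_mode)
                    for i in range(pose_vec.shape[1])]
        # [B, n_view, n_iter, 6]: every view carries one pose per refinement iteration
        return [[Pose.from_vec(pose_vec[:, i, j], rotation_mode)
                 for j in range(pose_vec.shape[2])]
                for i in range(pose_vec.shape[1])]

    poses = vec_to_poses(0.01 * torch.randn(2, 2, 3, 6))  # batch 2, 2 views, 3 iterations
    print(len(poses), len(poses[0]), poses[0][-1].mat.shape)  # 2 3 torch.Size([2, 4, 4])
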
/dro_sfm/models/SupModelMF.py:
--------------------------------------------------------------------------------
1 |
2 | from dro_sfm.utils.depth import depth2inv
3 | from dro_sfm.losses.supervised_loss import SupervisedDepthPoseLoss
4 | from dro_sfm.models.SfmModelMF import SfmModelMF
5 | from dro_sfm.models.model_utils import merge_outputs
6 |
7 |
8 | class SupModelMF(SfmModelMF):
9 | """
10 | Model that inherits a depth and pose network from SfmModelMF and
11 | includes a supervised depth and pose loss for supervised training.
12 |
13 | Parameters
14 | ----------
15 | kwargs : dict
16 | Extra parameters
17 | """
18 | def __init__(self, **kwargs):
19 | # Initializes SfmModel
20 | super().__init__(**kwargs)
21 | # Initializes network/training requirements and the supervised loss
22 |
23 | self._network_requirements = {
24 | 'depth_net': True, # Depth network required
25 | 'pose_net': False, # Pose network not required
26 | 'percep_net': False, # Perceptual network not required
27 | }
28 |
29 | self._train_requirements = {
30 | 'gt_depth': True, # Ground-truth depth required
31 | 'gt_pose': True, # Ground-truth pose required
32 | }
33 |
34 | # self._photometric_loss = MultiViewPhotometricLoss(**kwargs)
35 | self._loss = SupervisedDepthPoseLoss(**kwargs)
36 |
37 | @property
38 | def logs(self):
39 | """Return logs."""
40 | return {
41 | **super().logs,
42 | **self._loss.logs
43 | }
44 |
45 | def supervised_loss(self, image, ref_images, inv_depths, gt_depth, gt_poses, poses,
46 | intrinsics, return_logs=False, progress=0.0):
47 | """
48 | Calculates the supervised depth and pose loss.
49 |
50 | Parameters
51 | ----------
52 | image : torch.Tensor [B,3,H,W]
53 | Original image
54 | ref_images : list of torch.Tensor [B,3,H,W]
55 | Reference images from context
56 | inv_depths : torch.Tensor [B,1,H,W]
57 | Predicted inverse depth maps from the original image
58 | poses : list of Pose
59 | List containing predicted poses between original and context images
60 | intrinsics : torch.Tensor [B,3,3]
61 | Camera intrinsics
62 | return_logs : bool
63 | True if logs are stored
64 | progress :
65 | Training progress percentage
66 |
67 | Returns
68 | -------
69 | output : dict
70 | Dictionary containing a "loss" scalar and a "metrics" dictionary
71 | """
72 | return self._loss(
73 | image, ref_images, inv_depths, depth2inv(gt_depth), gt_poses, intrinsics, intrinsics, poses,
74 | return_logs=return_logs, progress=progress)
75 |
76 | def forward(self, batch, return_logs=False, progress=0.0):
77 | """
78 | Processes a batch.
79 |
80 | Parameters
81 | ----------
82 | batch : dict
83 | Input batch
84 | return_logs : bool
85 | True if logs are stored
86 | progress :
87 | Training progress percentage
88 |
89 | Returns
90 | -------
91 | output : dict
92 | Dictionary containing a "loss" scalar and different metrics and predictions
93 | for logging and downstream usage.
94 | """
95 | # Calculate predicted depth and pose output
96 | output = super().forward(batch, return_logs=return_logs)
97 | if not self.training:
99 | # If not training, no need for the supervised loss
99 | return output
100 | else:
101 | if output["poses"] is None:
102 | return None
103 | # Otherwise, calculate the supervised loss
104 | sup_output = self.supervised_loss(
105 | batch['rgb_original'], batch['rgb_context_original'],
106 | output['inv_depths'], batch['depth'], batch['pose_context'], output['poses'], batch['intrinsics'],
107 | return_logs=return_logs, progress=progress)
108 | # Return loss and metrics
109 | return {
110 | 'loss': sup_output['loss'],
111 | **merge_outputs(output, sup_output),
112 | }
--------------------------------------------------------------------------------
/dro_sfm/models/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Structure-from-Motion (SfM) Models and wrappers
3 | ===============================================
4 |
5 | - SfmModel is a torch.nn.Module wrapping both a Depth and a Pose network to enable training in a Structure-from-Motion setup (i.e. from videos)
6 | - SelfSupModel is an SfmModel specialized for self-supervised learning (using videos only)
7 | - SemiSupModel is an SfmModel specialized for semi-supervised learning (using videos and depth supervision)
8 | - ModelWrapper is a torch.nn.Module that wraps an SfmModel to enable easy training and eval with a trainer
9 | - ModelCheckpoint enables saving/restoring state of torch.nn.Module objects
10 |
11 | """
12 |
--------------------------------------------------------------------------------
/dro_sfm/models/model_checkpoint.py:
--------------------------------------------------------------------------------
1 |
2 | # Adapted from Pytorch-Lightning
3 | # https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/callbacks/model_checkpoint.py
4 |
5 | import os, re
6 | import numpy as np
7 | import torch
8 | from dro_sfm.utils.logging import pcolor
9 |
10 |
11 | def sync_s3_data(local, model):
12 | """Sync saved models with the s3 bucket"""
13 | remote = os.path.join(model.config.checkpoint.s3_path, model.config.name)
14 | command = 'aws s3 sync {} {} --acl bucket-owner-full-control --quiet --delete'.format(local, remote)
15 | os.system(command)
16 |
17 |
18 | def save_code(filepath):
19 | """Save code in the models folder"""
20 | os.system('tar cfz {}/code.tar.gz *'.format(filepath))
21 |
22 |
23 | class ModelCheckpoint:
24 | def __init__(self, filepath=None, monitor='val_loss',
25 | save_top_k=1, mode='auto', period=1,
26 | s3_path='', s3_frequency=5):
27 | super().__init__()
28 | # If save_top_k is zero, save all models
29 | if save_top_k == 0:
30 | save_top_k = 1e6
31 | # Create checkpoint folder
32 | self.dirpath, self.filename = os.path.split(filepath)
33 | print(self.dirpath, self.filename, filepath)
34 | os.makedirs(self.dirpath, exist_ok=True)
35 | # Store arguments
36 | self.monitor = monitor
37 | self.save_top_k = save_top_k
38 | self.period = period
39 | self.epoch_last_check = None
40 | self.best_k_models = {}
41 | self.kth_best_model = ''
42 | self.best = 0
43 | # Monitoring modes
44 | torch_inf = torch.tensor(np.inf)
45 | mode_dict = {
46 | 'min': (torch_inf, 'min'),
47 | 'max': (-torch_inf, 'max'),
48 | 'auto': (-torch_inf, 'max') if \
49 | 'acc' in self.monitor or \
50 | 'a1' in self.monitor or \
51 | self.monitor.startswith('fmeasure')
52 | else (torch_inf, 'min'),
53 | }
54 | self.kth_value, self.mode = mode_dict[mode]
55 |
56 | self.s3_path = s3_path
57 | self.s3_frequency = s3_frequency
58 | self.s3_enabled = s3_path != '' and s3_frequency > 0
59 | self.save_code = True
60 |
61 | @staticmethod
62 | def _del_model(filepath):
63 | if os.path.isfile(filepath):
64 | os.remove(filepath)
65 |
66 | def _save_model(self, filepath, model):
67 | # Create folder, save model and sync to s3
68 | os.makedirs(os.path.dirname(filepath), exist_ok=True)
69 | torch.save({
70 | 'config': model.config,
71 | 'epoch': model.current_epoch,
72 | 'state_dict': model.state_dict(),
73 | 'optimizer': model.optimizer.state_dict(),
74 | 'scheduler': model.scheduler.state_dict(),
75 | }, filepath)
76 | self._sync_s3(filepath, model)
77 |
78 | def _sync_s3(self, filepath, model):
79 | # If it's not time to sync, do nothing
80 | if self.s3_enabled and (model.current_epoch + 1) % self.s3_frequency == 0:
81 | filepath = os.path.dirname(filepath)
82 | # Print message and links
83 | print(pcolor('###### Syncing: {} -> {}'.format(filepath,
84 | model.config.checkpoint.s3_path), 'red', attrs=['bold']))
85 | print(pcolor('###### URL: {}'.format(
86 | model.config.checkpoint.s3_url), 'red', attrs=['bold']))
87 | # If it's time to save code
88 | if self.save_code:
89 | self.save_code = False
90 | save_code(filepath)
91 | # Sync model to s3
92 | sync_s3_data(filepath, model)
93 |
94 | def check_monitor_top_k(self, current):
95 | # If we don't have enough models
96 | if len(self.best_k_models) < self.save_top_k:
97 | return True
98 | # Convert to torch if necessary
99 | if not isinstance(current, torch.Tensor):
100 | current = torch.tensor(current)
101 | # Get monitoring operation
102 | monitor_op = {
103 | "min": torch.lt,
104 | "max": torch.gt,
105 | }[self.mode]
106 | # Compare and return
107 | return monitor_op(current, self.best_k_models[self.kth_best_model])
108 |
109 | def format_checkpoint_name(self, epoch, metrics):
110 | metrics['epoch'] = epoch
111 | filename = self.filename
112 | for tmp in re.findall(r'(\{.*?)[:\}]', self.filename):
113 | name = tmp[1:]
114 | filename = filename.replace(tmp, name + '={' + name)
115 | if name not in metrics:
116 | metrics[name] = 0
117 | filename = filename.format(**metrics)
118 | return os.path.join(self.dirpath, '{}.ckpt'.format(filename))
119 |
120 | def check_and_save(self, model, metrics):
121 | # Check saving interval
122 | epoch = model.current_epoch
123 | if self.epoch_last_check is not None and \
124 | (epoch - self.epoch_last_check) < self.period:
125 | return
126 | self.epoch_last_check = epoch
127 | # Prepare filepath
128 | filepath = self.format_checkpoint_name(epoch, metrics)
129 | while os.path.isfile(filepath):
130 | filepath = self.format_checkpoint_name(epoch, metrics)
131 | # Check if saving or not
132 | if self.save_top_k != -1:
133 | current = metrics.get(self.monitor)
134 | assert current is not None, 'Checkpoint metric is not available'
135 | if self.check_monitor_top_k(current):
136 | self._do_check_save(filepath, model, current)
137 | else:
138 | self._save_model(filepath, model)
139 |
140 | def _do_check_save(self, filepath, model, current):
141 | # List of models to delete
142 | del_list = []
143 | if len(self.best_k_models) == self.save_top_k and self.save_top_k > 0:
144 | delpath = self.kth_best_model
145 | self.best_k_models.pop(self.kth_best_model)
146 | del_list.append(delpath)
147 | # Monitor current models
148 | self.best_k_models[filepath] = current
149 | if len(self.best_k_models) == self.save_top_k:
150 | # Monitor dict has reached k elements
151 | _op = max if self.mode == 'min' else min
152 | self.kth_best_model = _op(self.best_k_models,
153 | key=self.best_k_models.get)
154 | self.kth_value = self.best_k_models[self.kth_best_model]
155 | # Determine best model
156 | _op = min if self.mode == 'min' else max
157 | self.best = _op(self.best_k_models.values())
158 | # Delete old models
159 | for cur_path in del_list:
160 | if cur_path != filepath:
161 | self._del_model(cur_path)
162 | # Save model
163 | self._save_model(filepath, model)
164 |
--------------------------------------------------------------------------------
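
Illustration (not a file in this repo): how the '{metric}' template in filepath is expanded by format_checkpoint_name; the /tmp/ckpts directory and the val_loss value are made up, and the constructor will create that directory.

    from dro_sfm.models.model_checkpoint import ModelCheckpoint

    ckpt = ModelCheckpoint(filepath='/tmp/ckpts/{epoch}_{val_loss:.2f}',
                           monitor='val_loss', mode='min')
    name = ckpt.format_checkpoint_name(epoch=5, metrics={'val_loss': 0.1234})
    print(name)  # /tmp/ckpts/epoch=5_val_loss=0.12.ckpt
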
/dro_sfm/models/model_utils.py:
--------------------------------------------------------------------------------
1 |
2 | from dro_sfm.utils.types import is_tensor, is_list, is_numpy
3 |
4 | def merge_outputs(*outputs):
5 | """
6 | Merges model outputs for logging
7 |
8 | Parameters
9 | ----------
10 | outputs : tuple of dict
11 | Outputs to be merged
12 |
13 | Returns
14 | -------
15 | output : dict
16 | Dictionary with a "metrics" key containing a dictionary with various metrics and
17 | all other keys that are not "loss" (it is handled differently).
18 | """
19 | ignore = ['loss'] # Keys to ignore
20 | combine = ['metrics'] # Keys to combine
21 | merge = {key: {} for key in combine}
22 | for output in outputs:
23 | # Iterate over all keys
24 | for key, val in output.items():
25 | # Combine these keys
26 | if key in combine:
27 | for sub_key, sub_val in output[key].items():
28 | assert sub_key not in merge[key].keys(), \
29 | 'Combining duplicated key {} to {}'.format(sub_key, key)
30 | merge[key][sub_key] = sub_val
31 | # Ignore these keys
32 | elif key not in ignore:
33 | assert key not in merge.keys(), \
34 | 'Adding duplicated key {}'.format(key)
35 | merge[key] = val
36 | return merge
37 |
38 |
39 | def stack_batch(batch):
40 | """
41 | Stack multi-camera batches (B,N,C,H,W becomes BN,C,H,W)
42 |
43 | Parameters
44 | ----------
45 | batch : dict
46 | Batch
47 |
48 | Returns
49 | -------
50 | batch : dict
51 | Stacked batch
52 | """
53 | # If there is multi-camera information
54 | if len(batch['rgb'].shape) == 5:
55 | assert batch['rgb'].shape[0] == 1, 'Only batch size 1 is supported for multi-cameras'
56 | # Loop over all keys
57 | for key in batch.keys():
58 | # If list, stack every item
59 | if is_list(batch[key]):
60 | if is_tensor(batch[key][0]) or is_numpy(batch[key][0]):
61 | batch[key] = [sample[0] for sample in batch[key]]
62 | # Else, stack single item
63 | else:
64 | batch[key] = batch[key][0]
65 | return batch
66 |
--------------------------------------------------------------------------------
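
Illustration (not a file in this repo): merge_outputs combines the 'metrics' sub-dictionaries and carries over remaining keys, while 'loss' is deliberately dropped (the calling model sums the losses itself). The values below are placeholders.

    from dro_sfm.models.model_utils import merge_outputs

    self_sup = {'loss': 1.0, 'metrics': {'photometric_loss': 0.8}, 'inv_depths': ['...']}
    sup = {'loss': 2.0, 'metrics': {'supervised_loss': 0.5}}

    merged = merge_outputs(self_sup, sup)
    print(sorted(merged['metrics']))  # ['photometric_loss', 'supervised_loss']
    print('loss' in merged)           # False
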
/dro_sfm/networks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aliyun/dro-sfm/8707e2e0ef799d7d47418a018060f503ef449fe3/dro_sfm/networks/__init__.py
--------------------------------------------------------------------------------
/dro_sfm/networks/layers/PercepNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision
4 |
5 |
6 | class PercepNet(torch.nn.Module):
7 | def __init__(self, requires_grad=False, resize=True):
8 | super(PercepNet, self).__init__()
9 | mean_rgb = torch.FloatTensor([0.485, 0.456, 0.406])
10 | std_rgb = torch.FloatTensor([0.229, 0.224, 0.225])
11 | self.register_buffer('mean_rgb', mean_rgb)
12 | self.register_buffer('std_rgb', std_rgb)
13 | self.resize = resize
14 |
15 | vgg_pretrained_features = torchvision.models.vgg16(pretrained=True).features
16 | self.slice1 = nn.Sequential()
17 | self.slice2 = nn.Sequential()
18 | self.slice3 = nn.Sequential()
19 | self.slice4 = nn.Sequential()
20 | for x in range(4):
21 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
22 | for x in range(4, 9):
23 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
24 | for x in range(9, 16):
25 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
26 | for x in range(16, 23):
27 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
28 | if not requires_grad:
29 | for param in self.parameters():
30 | param.requires_grad = False
31 |
32 | def normalize(self, x):
33 | out = (x - self.mean_rgb.view(1,3,1,1)) / self.std_rgb.view(1,3,1,1)
34 | if self.resize:
35 | out = nn.functional.interpolate(out, mode='bilinear',size=(224, 224), align_corners=False)
36 | return out
37 |
38 | def forward(self, im1, im2):
39 | im = torch.cat([im1,im2], 0)
40 | im = self.normalize(im) # normalize input
41 |
42 | ## compute features
43 | feats = []
44 | f = self.slice1(im)
45 | h, w = f.shape[-2:]
46 | feats += [torch.chunk(f, 2, dim=0)]
47 | f = self.slice2(f)
48 | feats += [torch.chunk(f, 2, dim=0)]
49 | f = self.slice3(f)
50 | feats += [torch.chunk(f, 2, dim=0)]
51 | f = self.slice4(f)
52 | feats += [torch.chunk(f, 2, dim=0)]
53 |
54 | losses = []
55 | #weights = [0.3, 0.3, 0.4]
56 | #weights = [0.15, 0.25, 0.25, 0.25]
57 | weights = [0.15, 0.25, 0.6]
58 | #weights = [1.0, 1.0, 1.0]
59 | for i, (f1, f2) in enumerate(feats[0:3]):
60 | loss = weights[i] * torch.abs(f1-f2).mean(1, True) #(B, 1, H, W)
61 | loss = nn.functional.interpolate(loss, mode='bilinear',size=(h, w), align_corners=False)
62 | losses += [loss]
63 |
64 | return sum(losses)
65 |
66 |
--------------------------------------------------------------------------------
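
Illustration (not a file in this repo): computing the perceptual distance map between two image batches; the input sizes are arbitrary, and instantiating PercepNet downloads torchvision's pretrained VGG16 weights on first use. With resize=True the returned map is at the 224x224 resolution the images are resampled to.

    import torch
    from dro_sfm.networks.layers.PercepNet import PercepNet

    net = PercepNet().eval()
    im1 = torch.rand(2, 3, 128, 160)  # images expected in [0, 1]
    im2 = torch.rand(2, 3, 128, 160)
    with torch.no_grad():
        loss_map = net(im1, im2)      # weighted L1 distance over the first VGG feature blocks
    print(loss_map.shape)             # torch.Size([2, 1, 224, 224])
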
/dro_sfm/networks/layers/resnet/layers.py:
--------------------------------------------------------------------------------
1 |
2 | # Adapted from monodepth2
3 | # https://github.com/nianticlabs/monodepth2/blob/master/layers.py
4 |
5 | from __future__ import absolute_import, division, print_function
6 |
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | def disp_to_depth(disp, min_depth, max_depth):
12 | """Convert network's sigmoid output into depth prediction
13 | The formula for this conversion is given in the 'additional considerations'
14 | section of the paper.
15 | """
16 | min_disp = 1 / max_depth
17 | max_disp = 1 / min_depth
18 | scaled_disp = min_disp + (max_disp - min_disp) * disp
19 | depth = 1 / scaled_disp
20 | return scaled_disp, depth
21 |
22 |
23 | class ConvBlock(nn.Module):
24 | """Layer to perform a convolution followed by ELU
25 | """
26 | def __init__(self, in_channels, out_channels):
27 | super(ConvBlock, self).__init__()
28 |
29 | self.conv = Conv3x3(in_channels, out_channels)
30 | self.nonlin = nn.ELU(inplace=True)
31 |
32 | def forward(self, x):
33 | out = self.conv(x)
34 | out = self.nonlin(out)
35 | return out
36 |
37 |
38 | class Conv3x3(nn.Module):
39 | """Layer to pad and convolve input
40 | """
41 | def __init__(self, in_channels, out_channels, use_refl=True):
42 | super(Conv3x3, self).__init__()
43 |
44 | if use_refl:
45 | self.pad = nn.ReflectionPad2d(1)
46 | else:
47 | self.pad = nn.ZeroPad2d(1)
48 | self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)
49 |
50 | def forward(self, x):
51 | out = self.pad(x)
52 | out = self.conv(out)
53 | return out
54 |
55 |
56 | def upsample(x):
57 | """Upsample input tensor by a factor of 2
58 | """
59 | return F.interpolate(x, scale_factor=2, mode="nearest")
60 |
61 |
62 |
--------------------------------------------------------------------------------
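
Illustration (not a file in this repo): disp_to_depth maps a sigmoid output in [0, 1] onto the [min_depth, max_depth] range, with 0 corresponding to the far plane and 1 to the near plane.

    import torch
    from dro_sfm.networks.layers.resnet.layers import disp_to_depth

    disp = torch.tensor([0.0, 0.5, 1.0])
    scaled_disp, depth = disp_to_depth(disp, min_depth=0.1, max_depth=100.0)
    print(depth)  # approximately [100.0, 0.2, 0.1]: disp=0 -> max_depth, disp=1 -> min_depth
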
/dro_sfm/networks/layers/resnet/pose_decoder.py:
--------------------------------------------------------------------------------
1 |
2 | # Adapted from monodepth2
3 | # https://github.com/nianticlabs/monodepth2/blob/master/networks/pose_decoder.py
4 |
5 | from __future__ import absolute_import, division, print_function
6 |
7 | import torch
8 | import torch.nn as nn
9 | from collections import OrderedDict
10 |
11 |
12 | class PoseDecoder(nn.Module):
13 | def __init__(self, num_ch_enc, num_input_features, num_frames_to_predict_for=None, stride=1):
14 | super(PoseDecoder, self).__init__()
15 |
16 | self.num_ch_enc = num_ch_enc
17 | self.num_input_features = num_input_features
18 |
19 | if num_frames_to_predict_for is None:
20 | num_frames_to_predict_for = num_input_features - 1
21 | self.num_frames_to_predict_for = num_frames_to_predict_for
22 |
23 | self.convs = OrderedDict()
24 | self.convs[("squeeze")] = nn.Conv2d(self.num_ch_enc[-1], 256, 1)
25 | self.convs[("pose", 0)] = nn.Conv2d(num_input_features * 256, 256, 3, stride, 1)
26 | self.convs[("pose", 1)] = nn.Conv2d(256, 256, 3, stride, 1)
27 | self.convs[("pose", 2)] = nn.Conv2d(256, 6 * num_frames_to_predict_for, 1)
28 |
29 | self.relu = nn.ReLU()
30 |
31 | self.net = nn.ModuleList(list(self.convs.values()))
32 |
33 | def forward(self, input_features):
34 | last_features = [f[-1] for f in input_features]
35 |
36 | cat_features = [self.relu(self.convs["squeeze"](f)) for f in last_features]
37 | cat_features = torch.cat(cat_features, 1)
38 |
39 | out = cat_features
40 | for i in range(3):
41 | out = self.convs[("pose", i)](out)
42 | if i != 2:
43 | out = self.relu(out)
44 | out = out.mean(3).mean(2)
45 |
46 | out = 0.01 * out.view(-1, self.num_frames_to_predict_for, 1, 6)
47 |
48 | axisangle = out[..., :3]
49 | translation = out[..., 3:]
50 |
51 | return axisangle, translation
52 |
--------------------------------------------------------------------------------
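
Illustration (not a file in this repo): feeding PoseDecoder with dummy encoder pyramids; only the last (coarsest) feature map of each pyramid is used, and the channel counts follow the ResNet18 encoder defined elsewhere in this package. All tensors are random placeholders.

    import torch
    from dro_sfm.networks.layers.resnet.pose_decoder import PoseDecoder

    num_ch_enc = [64, 64, 128, 256, 512]
    decoder = PoseDecoder(num_ch_enc, num_input_features=2, num_frames_to_predict_for=1)

    feats_a = [torch.randn(2, c, 6, 20) for c in num_ch_enc]  # pyramid for image A
    feats_b = [torch.randn(2, c, 6, 20) for c in num_ch_enc]  # pyramid for image B
    axisangle, translation = decoder([feats_a, feats_b])
    print(axisangle.shape, translation.shape)  # torch.Size([2, 1, 1, 3]) twice
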
/dro_sfm/networks/layers/resnet/pose_res_decoder.py:
--------------------------------------------------------------------------------
1 |
2 | # Adapted from monodepth2
3 | # https://github.com/nianticlabs/monodepth2/blob/master/networks/pose_decoder.py
4 |
5 | from __future__ import absolute_import, division, print_function
6 |
7 | import torch
8 | import torch.nn as nn
9 | from collections import OrderedDict
10 |
11 |
12 | class PoseResDecoder(nn.Module):
13 | def __init__(self, num_ch_enc, num_input_features, num_frames_to_predict_for=None, stride=1):
14 | super(PoseResDecoder, self).__init__()
15 |
16 | self.num_ch_enc = num_ch_enc
17 | self.num_input_features = num_input_features
18 |
19 | if num_frames_to_predict_for is None:
20 | num_frames_to_predict_for = num_input_features - 1
21 | self.num_frames_to_predict_for = num_frames_to_predict_for
22 |
23 | self.convs = OrderedDict()
24 | self.convs[("squeeze")] = nn.Conv2d(self.num_ch_enc[-1], 256, 1)
25 | self.convs[("pose", 0)] = nn.Conv2d(num_input_features * 256, 256, 3, stride, 1)
26 | self.convs[("pose", 1)] = nn.Conv2d(256, 256, 3, stride, 1)
27 | self.convs[("pose", 2)] = nn.Conv2d(256, 6 * num_frames_to_predict_for, 1)
28 |
29 | self.relu = nn.ReLU()
30 |
31 | self.net = nn.ModuleList(list(self.convs.values()))
32 |
33 | self.found_mat_net = nn.Sequential(nn.Linear(9, 18), nn.ReLU(inplace=True),
34 | nn.Linear(18, 18), nn.ReLU(inplace=True),
35 | nn.Linear(18, 6))
36 |
37 | self.fusion_net = nn.Sequential(nn.Linear(12, 24), nn.ReLU(inplace=True),
38 | nn.Linear(24, 24), nn.ReLU(inplace=True),
39 | nn.Linear(24, 6))
40 |
41 | def forward(self, input_features, fund_mat):
42 | last_features = [f[-1] for f in input_features]
43 |
44 | cat_features = [self.relu(self.convs["squeeze"](f)) for f in last_features]
45 | cat_features = torch.cat(cat_features, 1)
46 |
47 | out = cat_features
48 | for i in range(3):
49 | out = self.convs[("pose", i)](out)
50 | if i != 2:
51 | out = self.relu(out)
52 | out = out.mean(3).mean(2)
53 |
54 | # out = 0.01 * (out.view(-1, self.num_frames_to_predict_for, 1, 6) + self.found_mat_net(fund_mat.view(-1, 9)).view(-1, 1, 1, 6))
55 |
56 | fund_mat_proj = self.found_mat_net(fund_mat.view(-1, 9))
57 | out = 0.01 * (self.fusion_net(torch.cat([out.view(-1, 6), fund_mat_proj], dim=1)) + fund_mat_proj).view(-1, 1, 1, 6)
58 |
59 | print("out", out.view(-1, 6))
60 | print("fund", fund_mat_proj.view(-1, 6))
61 |
62 | axisangle = out[..., :3]
63 | translation = out[..., 3:]
64 | return axisangle, translation
65 |
66 |
67 |
68 | class PoseResAngleDecoder(nn.Module):
69 | def __init__(self, num_ch_enc, num_input_features, num_frames_to_predict_for=None, stride=1):
70 | super(PoseResAngleDecoder, self).__init__()
71 |
72 | self.num_ch_enc = num_ch_enc
73 | self.num_input_features = num_input_features
74 |
75 | if num_frames_to_predict_for is None:
76 | num_frames_to_predict_for = num_input_features - 1
77 | self.num_frames_to_predict_for = num_frames_to_predict_for
78 |
79 | self.convs = OrderedDict()
80 | self.convs[("squeeze")] = nn.Conv2d(self.num_ch_enc[-1], 256, 1)
81 | self.convs[("pose", 0)] = nn.Conv2d(num_input_features * 256, 256, 3, stride, 1)
82 | self.convs[("pose", 1)] = nn.Conv2d(256, 256, 3, stride, 1)
83 | self.convs[("pose", 2)] = nn.Conv2d(256, 7 * num_frames_to_predict_for, 1)
84 |
85 | self.relu = nn.ReLU()
86 |
87 | self.net = nn.ModuleList(list(self.convs.values()))
88 |
89 | self.found_mat_net = nn.Sequential(nn.Linear(6, 18), nn.ReLU(inplace=True),
90 | nn.Linear(18, 18), nn.ReLU(inplace=True),
91 | nn.Linear(18, 6))
92 |
93 | self.fusion_net = nn.Sequential(nn.Linear(12, 128), nn.ReLU(inplace=True),
94 | nn.Linear(128, 128), nn.ReLU(inplace=True),
95 | nn.Linear(128, 6))
96 |
97 | def forward(self, input_features, pose_geo):
98 | last_features = [f[-1] for f in input_features]
99 |
100 | cat_features = [self.relu(self.convs["squeeze"](f)) for f in last_features]
101 | cat_features = torch.cat(cat_features, 1)
102 |
103 | out = cat_features
104 | for i in range(3):
105 | out = self.convs[("pose", i)](out)
106 | if i != 2:
107 | out = self.relu(out)
108 | out = out.mean(3).mean(2).view(-1, 7)
109 |
110 | # out = 0.01 * (out.view(-1, self.num_frames_to_predict_for, 1, 6) + self.found_mat_net(foud_mat.view(-1, 9)).view(-1, 1, 1, 6))
111 |
112 | #trans_scale = 0.01 * pose_geo[:, 3:] * out[:, -1].unsqueeze(1)
113 | trans_scale = pose_geo[:, 3:] * out[:, -1].unsqueeze(1)
114 |
115 | pose_geo_new = torch.cat([pose_geo[:, :3], trans_scale], dim=1)
116 |
117 | #out = 0.01 * (self.fusion_net(torch.cat([out[:, :6], self.found_mat_net(pose_geo_new)], dim=1)))
118 | # out = 0.01 * (self.fusion_net(torch.cat([0.01 * out[:, :6], pose_geo_new], dim=1)))
119 | out = 0.01 * (self.fusion_net(torch.cat([out[:, :6], pose_geo_new], dim=1)))
120 |
121 | # out = 0.01 * out[:, :6] + pose_geo_new
122 |
123 | out = out.view(-1, 1, 1, 6)
124 |
125 | axisangle = out[..., :3]
126 | translation = out[..., 3:]
127 | return axisangle, translation
128 |
129 |
--------------------------------------------------------------------------------
/dro_sfm/networks/layers/resnet/resnet_encoder.py:
--------------------------------------------------------------------------------
1 |
2 | # Adapted from monodepth2
3 | # https://github.com/nianticlabs/monodepth2/blob/master/networks/resnet_encoder.py
4 |
5 | from __future__ import absolute_import, division, print_function
6 |
7 | import numpy as np
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torchvision.models as models
12 | import torch.utils.model_zoo as model_zoo
13 |
14 |
15 | class ResNetMultiImageInput(models.ResNet):
16 | """Constructs a resnet model with varying number of input images.
17 | Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
18 | """
19 | def __init__(self, block, layers, num_classes=1000, num_input_images=1):
20 | super(ResNetMultiImageInput, self).__init__(block, layers)
21 | self.inplanes = 64
22 | self.conv1 = nn.Conv2d(
23 | num_input_images * 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
24 | self.bn1 = nn.BatchNorm2d(64)
25 | self.relu = nn.ReLU(inplace=True)
26 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
27 | self.layer1 = self._make_layer(block, 64, layers[0])
28 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
29 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
30 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
31 |
32 | for m in self.modules():
33 | if isinstance(m, nn.Conv2d):
34 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
35 | elif isinstance(m, nn.BatchNorm2d):
36 | nn.init.constant_(m.weight, 1)
37 | nn.init.constant_(m.bias, 0)
38 |
39 |
40 | def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1):
41 | """Constructs a ResNet model.
42 | Args:
43 | num_layers (int): Number of resnet layers. Must be 18 or 50
44 | pretrained (bool): If True, returns a model pre-trained on ImageNet
45 | num_input_images (int): Number of frames stacked as input
46 | """
47 | assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet"
48 | blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers]
49 | block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers]
50 | model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images)
51 |
52 | if pretrained:
53 | loaded = model_zoo.load_url(models.resnet.model_urls['resnet{}'.format(num_layers)])
54 | loaded['conv1.weight'] = torch.cat(
55 | [loaded['conv1.weight']] * num_input_images, 1) / num_input_images
56 | model.load_state_dict(loaded)
57 | return model
58 |
59 |
60 | class ResnetEncoder(nn.Module):
61 | """Pytorch module for a resnet encoder
62 | """
63 | def __init__(self, num_layers, pretrained, num_input_images=1):
64 | super(ResnetEncoder, self).__init__()
65 | self.num_ch_enc = np.array([64, 64, 128, 256, 512])
66 |
67 | resnets = {18: models.resnet18,
68 | 34: models.resnet34,
69 | 50: models.resnet50,
70 | 101: models.resnet101,
71 | 152: models.resnet152}
72 |
73 | if num_layers not in resnets:
74 | raise ValueError("{} is not a valid number of resnet layers".format(num_layers))
75 |
76 | if num_input_images > 1:
77 | self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images)
78 | else:
79 | self.encoder = resnets[num_layers](pretrained)
80 |
81 | if num_layers > 34:
82 | self.num_ch_enc[1:] *= 4
83 |
84 | def forward(self, input_image):
85 | self.features = []
86 | x = (input_image - 0.45) / 0.225
87 | x = self.encoder.conv1(x)
88 | x = self.encoder.bn1(x)
89 | self.features.append(self.encoder.relu(x))
90 | self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1])))
91 | self.features.append(self.encoder.layer2(self.features[-1]))
92 | self.features.append(self.encoder.layer3(self.features[-1]))
93 | self.features.append(self.encoder.layer4(self.features[-1]))
94 |
95 | return self.features
96 |
--------------------------------------------------------------------------------
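
Illustration (not a file in this repo): the feature pyramid returned by ResnetEncoder for a KITTI-sized input; pretrained=False avoids a weight download. This assumes a torchvision version contemporary with the repository (newer releases have deprecated the pretrained flag used above).

    import torch
    from dro_sfm.networks.layers.resnet.resnet_encoder import ResnetEncoder

    encoder = ResnetEncoder(num_layers=18, pretrained=False)
    feats = encoder(torch.rand(1, 3, 192, 640))  # input already in [0, 1]
    print([tuple(f.shape[1:]) for f in feats])
    # [(64, 96, 320), (64, 48, 160), (128, 24, 80), (256, 12, 40), (512, 6, 20)]
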
/dro_sfm/networks/optim/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/dro_sfm/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Trainers
3 | ========
4 |
5 | Trainer classes providing an easy way to train and evaluate SfM models
6 | when wrapped in a ModelWrapper.
7 |
8 | Inspired by pytorch-lightning.
9 |
10 | """
11 |
12 | from dro_sfm.trainers.horovod_trainer import HorovodTrainer
13 |
14 | __all__ = ["HorovodTrainer"]
--------------------------------------------------------------------------------
/dro_sfm/trainers/base_trainer.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from tqdm import tqdm
4 | from dro_sfm.utils.logging import prepare_dataset_prefix
5 |
6 |
7 | def sample_to_cuda(data, dtype=None):
8 | if isinstance(data, str):
9 | return data
10 | elif isinstance(data, dict):
11 | return {key: sample_to_cuda(data[key], dtype) for key in data.keys()}
12 | elif isinstance(data, list):
13 | return [sample_to_cuda(val, dtype) for val in data]
14 | else:
15 | # only convert floats (e.g., to half); otherwise preserve the type (e.g., ints)
16 | dtype = dtype if torch.is_floating_point(data) else None
17 | return data.to('cuda', dtype=dtype)
18 |
19 |
20 | class BaseTrainer:
21 | def __init__(self, min_epochs=0, max_epochs=50,
22 | checkpoint=None, **kwargs):
23 |
24 | self.min_epochs = min_epochs
25 | self.max_epochs = max_epochs
26 |
27 | self.checkpoint = checkpoint
28 | self.module = None
29 |
30 | @property
31 | def proc_rank(self):
32 | raise NotImplementedError('Not implemented for BaseTrainer')
33 |
34 | @property
35 | def world_size(self):
36 | raise NotImplementedError('Not implemented for BaseTrainer')
37 |
38 | @property
39 | def is_rank_0(self):
40 | return self.proc_rank == 0
41 |
42 | def check_and_save(self, module, output):
43 | if self.checkpoint:
44 | self.checkpoint.check_and_save(module, output)
45 |
46 | def train_progress_bar(self, dataloader, config, ncols=120):
47 | return tqdm(enumerate(dataloader, 0),
48 | unit=' images', unit_scale=self.world_size * config.batch_size,
49 | total=len(dataloader), smoothing=0,
50 | disable=not self.is_rank_0, ncols=ncols,
51 | )
52 |
53 | def val_progress_bar(self, dataloader, config, n=0, ncols=120):
54 | return tqdm(enumerate(dataloader, 0),
55 | unit=' images', unit_scale=self.world_size * config.batch_size,
56 | total=len(dataloader), smoothing=0,
57 | disable=not self.is_rank_0, ncols=ncols,
58 | desc=prepare_dataset_prefix(config, n)
59 | )
60 |
61 | def test_progress_bar(self, dataloader, config, n=0, ncols=120):
62 | return tqdm(enumerate(dataloader, 0),
63 | unit=' images', unit_scale=self.world_size * config.batch_size,
64 | total=len(dataloader), smoothing=0,
65 | disable=not self.is_rank_0, ncols=ncols,
66 | desc=prepare_dataset_prefix(config, n)
67 | )
68 |
--------------------------------------------------------------------------------
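
Illustration (not a file in this repo): sample_to_cuda walks nested dicts and lists, moves tensors to the GPU, and passes strings through unchanged. The guard is needed because the helper assumes a CUDA device; the keys and values are made up.

    import torch
    from dro_sfm.trainers.base_trainer import sample_to_cuda

    batch = {'rgb': torch.rand(1, 3, 8, 8),
             'rgb_context': [torch.rand(1, 3, 8, 8)],
             'filename': 'scene0000/000001'}
    if torch.cuda.is_available():
        batch = sample_to_cuda(batch)
        print(batch['rgb'].device)              # cuda:0
        print(batch['rgb_context'][0].device)   # cuda:0
        print(batch['filename'])                # scene0000/000001
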
/dro_sfm/trainers/horovod_trainer.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import torch
4 | import horovod.torch as hvd
5 | from dro_sfm.trainers.base_trainer import BaseTrainer, sample_to_cuda
6 | from dro_sfm.utils.config import prep_logger_and_checkpoint
7 | from dro_sfm.utils.logging import print_config
8 | from dro_sfm.utils.logging import AvgMeter
9 |
10 |
11 | class HorovodTrainer(BaseTrainer):
12 | def __init__(self, **kwargs):
13 | super().__init__(**kwargs)
14 |
15 | hvd.init()
16 | torch.set_num_threads(int(os.environ.get("OMP_NUM_THREADS", 1)))
17 | torch.cuda.set_device(hvd.local_rank())
18 | torch.backends.cudnn.benchmark = True
19 |
20 | self.avg_loss = AvgMeter(50)
21 | self.avg_loss2 = AvgMeter(50)
22 | self.dtype = kwargs.get("dtype", None) # just for test for now
23 |
24 | @property
25 | def proc_rank(self):
26 | return hvd.rank()
27 |
28 | @property
29 | def world_size(self):
30 | return hvd.size()
31 |
32 | def fit(self, module):
33 |
34 | # Prepare module for training
35 | module.trainer = self
36 | # Update and print module configuration
37 | prep_logger_and_checkpoint(module)
38 | print_config(module.config)
39 |
40 | # Send module to GPU
41 | module = module.to('cuda')
42 | # Configure optimizer and scheduler
43 | module.configure_optimizers()
44 |
45 | # Create distributed optimizer
46 | compression = hvd.Compression.none
47 | optimizer = hvd.DistributedOptimizer(module.optimizer,
48 | named_parameters=module.named_parameters(), compression=compression)
49 | scheduler = module.scheduler
50 |
51 | # Get train and val dataloaders
52 | train_dataloader = module.train_dataloader()
53 | val_dataloaders = module.val_dataloader()
54 |
55 | # Epoch loop
56 | for epoch in range(module.current_epoch, self.max_epochs):
57 | # Train
58 | self.train(train_dataloader, module, optimizer)
59 | # Validation
60 | validation_output = self.validate(val_dataloaders, module)
61 | # Check and save model
62 | self.check_and_save(module, validation_output)
63 | # Update current epoch
64 | module.current_epoch += 1
65 | # Take a scheduler step
66 | scheduler.step()
67 |
68 | def train(self, dataloader, module, optimizer):
69 | # Set module to train
70 | module.train()
71 | # Shuffle dataloader sampler
72 | if hasattr(dataloader.sampler, "set_epoch"):
73 | dataloader.sampler.set_epoch(module.current_epoch)
74 | # Prepare progress bar
75 | progress_bar = self.train_progress_bar(
76 | dataloader, module.config.datasets.train)
77 | # Start training loop
78 | outputs = []
79 | # For all batches
80 | for i, batch in progress_bar:
81 | # Reset optimizer
82 | optimizer.zero_grad()
83 | # Send samples to GPU and take a training step
84 | batch = sample_to_cuda(batch)
85 | output = module.training_step(batch, i)
86 | if output is None:
87 | print("skip this training step.....")
88 | continue
89 | # Backprop through loss and take an optimizer step
90 | output['loss'].backward()
91 | optimizer.step()
92 | # Append output to list of outputs
93 | output['loss'] = output['loss'].detach()
94 | outputs.append(output)
95 | # Update progress bar if in rank 0
96 | if self.is_rank_0:
97 | progress_bar.set_description(
98 | 'Epoch {} | Avg.Loss {:.4f} Loss2 {:.4f}'.format(
99 | module.current_epoch, self.avg_loss(output['loss'].item()),
100 | self.avg_loss2(output['metrics']['pose_loss'].item() if 'pose_loss' in output['metrics'] else 0.0)))
101 | # Return outputs for epoch end
102 | # return module.training_epoch_end(outputs)
103 |
104 | def validate(self, dataloaders, module):
105 | # Set module to eval
106 | module.eval()
107 | # Start validation loop
108 | all_outputs = []
109 | # For all validation datasets
110 | for n, dataloader in enumerate(dataloaders):
111 | # Prepare progress bar for that dataset
112 | progress_bar = self.val_progress_bar(
113 | dataloader, module.config.datasets.validation, n)
114 | outputs = []
115 | # For all batches
116 | for i, batch in progress_bar:
117 | # Send batch to GPU and take a validation step
118 | batch = sample_to_cuda(batch)
119 | output = module.validation_step(batch, i, n)
120 | # Append output to list of outputs
121 | outputs.append(output)
122 | # Append dataset outputs to list of all outputs
123 | all_outputs.append(outputs)
124 | # Return all outputs for epoch end
125 | return module.validation_epoch_end(all_outputs)
126 |
127 | def test(self, module):
128 | # Send module to GPU
129 | module = module.to('cuda', dtype=self.dtype)
130 | # Get test dataloaders
131 | test_dataloaders = module.test_dataloader()
132 | # Run evaluation
133 | self.evaluate(test_dataloaders, module)
134 |
135 | @torch.no_grad()
136 | def evaluate(self, dataloaders, module):
137 | # Set module to eval
138 | module.eval()
139 | # Start evaluation loop
140 | all_outputs = []
141 | # For all test datasets
142 | for n, dataloader in enumerate(dataloaders):
143 | # Prepare progress bar for that dataset
144 | progress_bar = self.val_progress_bar(
145 | dataloader, module.config.datasets.test, n)
146 | outputs = []
147 | # For all batches
148 | for i, batch in progress_bar:
149 | # Send batch to GPU and take a test step
150 | batch = sample_to_cuda(batch, self.dtype)
151 | output = module.test_step(batch, i, n)
152 | # Append output to list of outputs
153 | outputs.append(output)
154 | # Append dataset outputs to list of all outputs
155 | all_outputs.append(outputs)
156 | # Return all outputs for epoch end
157 | return module.test_epoch_end(all_outputs)
158 |
--------------------------------------------------------------------------------
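
Illustration (not a file in this repo): the per-batch pattern used inside HorovodTrainer.train, reduced to a single-process miniature. The tiny linear model and training_step function are stand-ins for the wrapped SfM module; the real trainer additionally wraps the optimizer with hvd.DistributedOptimizer.

    import torch

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    def training_step(batch):
        x, y = batch
        return {'loss': torch.nn.functional.mse_loss(model(x), y), 'metrics': {}}

    for step in range(3):                         # stands in for "for i, batch in progress_bar"
        optimizer.zero_grad()                     # reset optimizer
        output = training_step((torch.rand(8, 4), torch.rand(8, 1)))
        output['loss'].backward()                 # backprop through the returned loss
        optimizer.step()                          # take an optimizer step
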
/dro_sfm/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aliyun/dro-sfm/8707e2e0ef799d7d47418a018060f503ef449fe3/dro_sfm/utils/__init__.py
--------------------------------------------------------------------------------
/dro_sfm/utils/horovod.py:
--------------------------------------------------------------------------------
1 |
2 | try:
3 | import horovod.torch as hvd
4 | HAS_HOROVOD = True
5 | except ImportError:
6 | HAS_HOROVOD = False
7 |
8 | def hvd_disable():
9 | global HAS_HOROVOD
10 | HAS_HOROVOD = False
11 |
12 | def hvd_init():
13 | if HAS_HOROVOD:
14 | hvd.init()
15 | return HAS_HOROVOD
16 |
17 | def on_rank_0(func):
18 | def wrapper(*args, **kwargs):
19 | if rank() == 0:
20 | func(*args, **kwargs)
21 | return wrapper
22 |
23 | def rank():
24 | return hvd.rank() if HAS_HOROVOD else 0
25 |
26 | def world_size():
27 | return hvd.size() if HAS_HOROVOD else 1
28 |
29 | @on_rank_0
30 | def print0(string='\n'):
31 | print(string)
32 |
33 | def reduce_value(value, average, name):
34 | """
35 | Reduce the mean value of a tensor from all GPUs
36 |
37 | Parameters
38 | ----------
39 | value : torch.Tensor
40 | Value to be reduced
41 | average : bool
42 | Whether values will be averaged or not
43 | name : str
44 | Value name
45 |
46 | Returns
47 | -------
48 | value : torch.Tensor
49 | reduced value
50 | """
51 | return hvd.allreduce(value, average=average, name=name)
52 |
--------------------------------------------------------------------------------
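
Illustration (not a file in this repo): the fallbacks when Horovod is missing or disabled; after hvd_disable() the helpers behave like a single-process run, and print0 only prints on rank 0.

    from dro_sfm.utils.horovod import hvd_disable, rank, world_size, print0

    hvd_disable()                  # force the single-process code path
    print(rank(), world_size())    # 0 1
    print0('visible only on rank 0')
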
/dro_sfm/utils/load.py:
--------------------------------------------------------------------------------
1 |
2 | import importlib
3 | import logging
4 | import os
5 | import warnings
6 | import torch
7 |
8 | from inspect import signature
9 | from collections import OrderedDict
10 |
11 | from dro_sfm.utils.misc import make_list, same_shape
12 | from dro_sfm.utils.logging import pcolor
13 | from dro_sfm.utils.horovod import print0
14 | from dro_sfm.utils.types import is_str
15 |
16 |
17 | def set_debug(debug):
18 | """
19 | Enable or disable debug terminal logging
20 |
21 | Parameters
22 | ----------
23 | debug : bool
24 | Debugging flag (True to enable)
25 | """
26 | # Disable logging if requested
27 | if not debug:
28 | os.environ['NCCL_DEBUG'] = ''
29 | os.environ['WANDB_SILENT'] = 'false'
30 | warnings.filterwarnings("ignore")
31 | logging.disable(logging.CRITICAL)
32 |
33 |
34 | def filter_args(func, keys):
35 | """
36 | Filters a dictionary so it only contains keys that are arguments of a function
37 |
38 | Parameters
39 | ----------
40 | func : Function
41 | Function for which we are filtering the dictionary
42 | keys : dict
43 | Dictionary with keys we are filtering
44 |
45 | Returns
46 | -------
47 | filtered : dict
48 | Dictionary containing only keys that are arguments of func
49 | """
50 | filtered = {}
51 | sign = list(signature(func).parameters.keys())
52 | for k, v in {**keys}.items():
53 | if k in sign:
54 | filtered[k] = v
55 | return filtered
56 |
57 |
58 | def filter_args_create(func, keys):
59 | """
60 | Filters a dictionary so it only contains keys that are arguments of a function
61 | and creates a function with those arguments
62 |
63 | Parameters
64 | ----------
65 | func : Function
66 | Function for which we are filtering the dictionary
67 | keys : dict
68 | Dictionary with keys we are filtering
69 |
70 | Returns
71 | -------
72 | func : Function
73 | Function with filtered keys as arguments
74 | """
75 | return func(**filter_args(func, keys))
76 |
77 |
78 | def load_class(filename, paths, concat=True):
79 | """
80 | Look for a file in different locations and return its method with the same name
81 | Optionally, you can use concat to search in path.filename instead
82 |
83 | Parameters
84 | ----------
85 | filename : str
86 | Name of the file we are searching for
87 | paths : str or list of str
88 | Folders in which the file will be searched
89 | concat : bool
90 | Flag to concatenate filename to each path during the search
91 |
92 | Returns
93 | -------
94 | method : Function
95 | Loaded method
96 | """
97 | # for each path in paths
98 | for path in make_list(paths):
99 | # Create full path
100 | full_path = '{}.{}'.format(path, filename) if concat else path
101 | if importlib.util.find_spec(full_path):
102 | # Return method with same name as the file
103 | return getattr(importlib.import_module(full_path), filename)
104 | raise ValueError('Unknown class {}'.format(filename))
105 |
106 |
107 | def load_class_args_create(filename, paths, args={}, concat=True):
108 | """Loads a class (filename) and returns an instance with filtered arguments (args)"""
109 | class_type = load_class(filename, paths, concat)
110 | return filter_args_create(class_type, args)
111 |
112 |
113 | def load_network(network, path, prefixes=''):
114 | """
115 | Loads a pretrained network
116 |
117 | Parameters
118 | ----------
119 | network : nn.Module
120 | Network that will receive the pretrained weights
121 | path : str
122 | File containing a 'state_dict' key with pretrained network weights
123 | prefixes : str or list of str
124 | Layer name prefixes to consider when loading the network
125 |
126 | Returns
127 | -------
128 | network : nn.Module
129 | Updated network with pretrained weights
130 | """
131 | prefixes = make_list(prefixes)
132 | # If path is a string
133 | if is_str(path):
134 | saved_state_dict = torch.load(path, map_location='cpu')['state_dict']
135 | if path.endswith('.pth.tar'):
136 | saved_state_dict = backwards_state_dict(saved_state_dict)
137 | # If state dict is already provided
138 | else:
139 | saved_state_dict = path
140 | # Get network state dict
141 | network_state_dict = network.state_dict()
142 |
143 | updated_state_dict = OrderedDict()
144 | n, n_total = 0, len(network_state_dict.keys())
145 | for key, val in saved_state_dict.items():
146 | for prefix in prefixes:
147 | prefix = prefix + '.'
148 | if prefix in key:
149 | idx = key.find(prefix) + len(prefix)
150 | key = key[idx:]
151 | if key in network_state_dict.keys() and \
152 | same_shape(val.shape, network_state_dict[key].shape):
153 | updated_state_dict[key] = val
154 | n += 1
155 | try:
156 | network.load_state_dict(updated_state_dict, strict=True)
157 | except Exception as e:
158 | print(e)
159 | network.load_state_dict(updated_state_dict, strict=False)
160 | base_color, attrs = 'cyan', ['bold', 'dark']
161 | color = 'green' if n == n_total else 'yellow' if n > 0 else 'red'
162 | print0(pcolor('=====###### Pretrained {} loaded:'.format(prefixes[0]), base_color, attrs=attrs) +
163 | pcolor(' {}/{} '.format(n, n_total), color, attrs=attrs) +
164 | pcolor('tensors', base_color, attrs=attrs))
165 | return network
166 |
167 |
168 | def backwards_state_dict(state_dict):
169 | """
170 | Modify the state dict of older models for backwards compatibility
171 |
172 | Parameters
173 | ----------
174 | state_dict : dict
175 | Model state dict with pretrained weights
176 |
177 | Returns
178 | -------
179 | state_dict : dict
180 | Updated model state dict with modified layer names
181 | """
182 | # List of layer names to change
183 | changes = (('model.model', 'model'),
184 | ('pose_network', 'pose_net'),
185 | ('disp_network', 'depth_net'))
186 | # Iterate over all keys and values
187 | updated_state_dict = OrderedDict()
188 | for key, val in state_dict.items():
189 | # Ad hoc changes due to version changes
190 | key = '{}.{}'.format('model', key)
191 | if 'disp_network' in key:
192 | key = key.replace('conv3.0.weight', 'conv3.weight')
193 | key = key.replace('conv3.0.bias', 'conv3.bias')
194 | # Change layer names
195 | for change in changes:
196 | key = key.replace('{}.'.format(change[0]),
197 | '{}.'.format(change[1]))
198 | updated_state_dict[key] = val
199 | # Return updated state dict
200 | return updated_state_dict
201 |
--------------------------------------------------------------------------------
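
A minimal usage sketch of the loading helpers above (the dummy class, search path, and config keys are only illustrative; it assumes the repository root and its dependencies are importable):

    from dro_sfm.utils.load import filter_args_create, load_class

    # filter_args_create drops any keys the constructor does not accept,
    # so an over-complete config dict can be passed straight through.
    class DummyConfigurable:
        def __init__(self, lr=1e-4):
            self.lr = lr

    obj = filter_args_create(DummyConfigurable, {'lr': 1e-3, 'unused_key': 42})
    print(obj.lr)  # 1e-3 -- 'unused_key' was filtered out

    # load_class imports dro_sfm.networks.depth_pose.DepthPoseNet and returns
    # the attribute with the same name as the file (the DepthPoseNet class).
    DepthPoseNet = load_class('DepthPoseNet', paths=['dro_sfm.networks.depth_pose'])
    print(DepthPoseNet)
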
/dro_sfm/utils/logging.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from termcolor import colored
4 | from functools import partial
5 |
6 | from dro_sfm.utils.horovod import on_rank_0
7 |
8 |
9 | def pcolor(string, color, on_color=None, attrs=None):
10 | """
11 | Produces a colored string for printing
12 |
13 | Parameters
14 | ----------
15 | string : str
16 | String that will be colored
17 | color : str
18 | Color to use
19 | on_color : str
20 | Background color to use
21 | attrs : list of str
22 | Different attributes for the string
23 |
24 | Returns
25 | -------
26 | string: str
27 | Colored string
28 | """
29 | return colored(string, color, on_color, attrs)
30 |
31 |
32 | def prepare_dataset_prefix(config, dataset_idx):
33 | """
34 | Concatenates dataset path and split for metrics logging
35 |
36 | Parameters
37 | ----------
38 | config : CfgNode
39 | Dataset configuration
40 | dataset_idx : int
41 | Dataset index for multiple datasets
42 |
43 | Returns
44 | -------
45 | prefix : str
46 | Dataset prefix for metrics logging
47 | """
48 | # Path is always available
49 | prefix = '{}'.format(os.path.splitext(config.path[dataset_idx].split('/')[-1])[0])
50 | # If split is available and does not contain { character
51 | if config.split[dataset_idx] != '' and '{' not in config.split[dataset_idx]:
52 | prefix += '-{}'.format(os.path.splitext(os.path.basename(config.split[dataset_idx]))[0])
53 | # If depth type is available
54 | if config.depth_type[dataset_idx] != '':
55 | prefix += '-{}'.format(config.depth_type[dataset_idx])
56 | # If we are using specific cameras
57 | if len(config.cameras[dataset_idx]) == 1: # only allows single cameras
58 | prefix += '-{}'.format(config.cameras[dataset_idx][0])
59 | # Return full prefix
60 | return prefix
61 |
62 |
63 | def s3_url(config):
64 | """
65 | Generate the s3 url where the models will be saved
66 |
67 | Parameters
68 | ----------
69 | config : CfgNode
70 | Model configuration
71 |
72 | Returns
73 | -------
74 | url : str
75 | String containing the URL pointing to the s3 bucket
76 | """
77 | return 'https://s3.console.aws.amazon.com/s3/buckets/{}/{}'.format(
78 | config.checkpoint.s3_path[5:], config.name)
79 |
80 |
81 | @on_rank_0
82 | def print_config(config, color=('blue', 'red', 'cyan'), attrs=('bold', 'dark')):
83 | """
84 | Prints header for model configuration
85 |
86 | Parameters
87 | ----------
88 | config : CfgNode
89 | Model configuration
90 | color : list of str
91 |         Color palette for the header
92 |     attrs : list of str
93 | Colored string attributes
94 | """
95 | # Recursive print function
96 | def print_recursive(rec_args, n=2, l=0):
97 | if l == 0:
98 | print(pcolor('config:', color[1], attrs=attrs))
99 | for key, val in rec_args.items():
100 | if isinstance(val, dict):
101 | print(pcolor('{} {}:'.format('-' * n, key), color[1], attrs=attrs))
102 | print_recursive(val, n + 2, l + 1)
103 | else:
104 | print('{}: {}'.format(pcolor('{} {}'.format('-' * n, key), color[2]), val))
105 |
106 | # Color partial functions
107 | pcolor1 = partial(pcolor, color='blue', attrs=['bold', 'dark'])
108 | pcolor2 = partial(pcolor, color='blue', attrs=['bold'])
109 | # Config and name
110 | line = pcolor1('#' * 120)
111 | path = pcolor1('### Config: ') + \
112 | pcolor2('{}'.format(config.default.replace('/', '.'))) + \
113 | pcolor1(' -> ') + \
114 | pcolor2('{}'.format(config.config.replace('/', '.')))
115 | name = pcolor1('### Name: ') + \
116 | pcolor2('{}'.format(config.name))
117 | # Add wandb link if available
118 | if not config.wandb.dry_run:
119 | name += pcolor1(' -> ') + \
120 | pcolor2('{}'.format(config.wandb.url))
121 | # Add s3 link if available
122 |     if config.checkpoint.s3_path != '':
123 | name += pcolor1('\n### s3:') + \
124 | pcolor2(' {}'.format(config.checkpoint.s3_url))
125 | # Create header string
126 | header = '%s\n%s\n%s\n%s' % (line, path, name, line)
127 |
128 | # Print header, config and header again
129 | print()
130 | print(header)
131 | print_recursive(config)
132 | print(header)
133 | print()
134 |
135 |
136 | class AvgMeter:
137 | """Average meter for logging"""
138 | def __init__(self, n_max=100):
139 | """
140 |         Initializes an AvgMeter object.
141 |
142 | Parameters
143 | ----------
144 | n_max : int
145 | Number of steps to average over
146 | """
147 | self.n_max = n_max
148 | self.values = []
149 |
150 | def __call__(self, value):
151 | """Appends new value and returns average"""
152 | self.values.append(value)
153 | if len(self.values) > self.n_max:
154 | self.values.pop(0)
155 | return self.get()
156 |
157 | def get(self):
158 | """Get current average"""
159 | return sum(self.values) / len(self.values)
160 |
161 | def reset(self):
162 | """Reset meter"""
163 | self.values.clear()
164 |
165 | def get_and_reset(self):
166 | """Gets current average and resets"""
167 | average = self.get()
168 | self.reset()
169 | return average
170 |
--------------------------------------------------------------------------------
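
A short standalone sketch of AvgMeter as a sliding-window average (arbitrary values; assumes dro_sfm and its dependencies are importable):

    from dro_sfm.utils.logging import AvgMeter

    meter = AvgMeter(n_max=3)      # average over at most the 3 most recent values
    print(meter(1.0))              # 1.0
    print(meter(2.0))              # 1.5
    print(meter(3.0))              # 2.0
    print(meter(5.0))              # ~3.33 -- the oldest value (1.0) was dropped
    print(meter.get_and_reset())   # returns the current average, then clears the window
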
/dro_sfm/utils/misc.py:
--------------------------------------------------------------------------------
1 |
2 | from dro_sfm.utils.types import is_list
3 |
4 | ########################################################################################################################
5 |
6 | def filter_dict(dictionary, keywords):
7 | """
8 | Returns only the keywords that are part of a dictionary
9 |
10 | Parameters
11 | ----------
12 | dictionary : dict
13 | Dictionary for filtering
14 | keywords : list of str
15 | Keywords that will be filtered
16 |
17 | Returns
18 | -------
19 | keywords : list of str
20 | List containing the keywords that are keys in dictionary
21 | """
22 | return [key for key in keywords if key in dictionary]
23 |
24 | ########################################################################################################################
25 |
26 | def make_list(var, n=None):
27 | """
28 | Wraps the input into a list, and optionally repeats it to be size n
29 |
30 | Parameters
31 | ----------
32 | var : Any
33 | Variable to be wrapped in a list
34 | n : int
35 |         How many times the wrapped variable will be repeated
36 |
37 | Returns
38 | -------
39 | var_list : list
40 | List generated from var
41 | """
42 | var = var if is_list(var) else [var]
43 | if n is None:
44 | return var
45 | else:
46 | assert len(var) == 1 or len(var) == n, 'Wrong list length for make_list'
47 | return var * n if len(var) == 1 else var
48 |
49 | ########################################################################################################################
50 |
51 | def same_shape(shape1, shape2):
52 | """
53 | Checks if two shapes are the same
54 |
55 | Parameters
56 | ----------
57 | shape1 : tuple
58 | First shape
59 | shape2 : tuple
60 | Second shape
61 |
62 | Returns
63 | -------
64 | flag : bool
65 | True if both shapes are the same (same length and dimensions)
66 | """
67 | if len(shape1) != len(shape2):
68 | return False
69 | for i in range(len(shape1)):
70 | if shape1[i] != shape2[i]:
71 | return False
72 | return True
73 |
74 | ########################################################################################################################
75 |
--------------------------------------------------------------------------------
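
The helpers in misc.py are small enough to illustrate in a few lines; a sketch with made-up inputs:

    from dro_sfm.utils.misc import filter_dict, make_list, same_shape

    print(make_list('velodyne'))        # ['velodyne']
    print(make_list('velodyne', n=3))   # ['velodyne', 'velodyne', 'velodyne']
    print(make_list([1, 2], n=2))       # [1, 2] -- already the requested length

    print(filter_dict({'rgb': 0, 'depth': 1}, ['depth', 'pose']))  # ['depth']

    print(same_shape((1, 3, 256, 320), (1, 3, 256, 320)))  # True
    print(same_shape((1, 3), (1, 3, 1)))                   # False
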
/dro_sfm/utils/reduce.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import numpy as np
4 | from collections import OrderedDict
5 | from dro_sfm.utils.horovod import reduce_value
6 | from dro_sfm.utils.logging import prepare_dataset_prefix
7 |
8 |
9 | def reduce_dict(data, to_item=False):
10 | """
11 | Reduce the mean values of a dictionary from all GPUs
12 |
13 | Parameters
14 | ----------
15 | data : dict
16 | Dictionary to be reduced
17 | to_item : bool
18 |         True if the reduced values will be returned as .item()
19 |
20 | Returns
21 | -------
22 | dict : dict
23 | Reduced dictionary
24 | """
25 | for key, val in data.items():
26 | data[key] = reduce_value(data[key], average=True, name=key)
27 | if to_item:
28 | data[key] = data[key].item()
29 | return data
30 |
31 | def all_reduce_metrics(output_data_batch, datasets, name='depth'):
32 | """
33 | Reduce metrics for all batches and all datasets using Horovod
34 |
35 | Parameters
36 | ----------
37 | output_data_batch : list
38 | List of outputs for each batch
39 | datasets : list
40 | List of all considered datasets
41 | name : str
42 | Name of the task for the metric
43 |
44 | Returns
45 | -------
46 | all_metrics_dict : list
47 | List of reduced metrics
48 | """
49 | # If there is only one dataset, wrap in a list
50 | if isinstance(output_data_batch[0], dict):
51 | output_data_batch = [output_data_batch]
52 | # Get metrics keys and dimensions
53 | names = [key for key in list(output_data_batch[0][0].keys()) if key.startswith(name)]
54 | dims = [output_data_batch[0][0][name].shape[0] for name in names]
55 | # List storing metrics for all datasets
56 | all_metrics_dict = []
57 | # Loop over all datasets and all batches
58 | for output_batch, dataset in zip(output_data_batch, datasets):
59 | metrics_dict = OrderedDict()
60 | length = len(dataset)
61 | # Count how many times each sample was seen
62 | seen = torch.zeros(length)
63 | for output in output_batch:
64 | for i, idx in enumerate(output['idx']):
65 | seen[idx] += 1
66 | seen = reduce_value(seen, average=False, name='idx')
67 | assert not np.any(seen.numpy() == 0), \
68 | 'Not all samples were seen during evaluation'
69 | # Reduce all relevant metrics
70 | for name, dim in zip(names, dims):
71 | metrics = torch.zeros(length, dim)
72 | for output in output_batch:
73 | for i, idx in enumerate(output['idx']):
74 | metrics[idx] = output[name]
75 | metrics = reduce_value(metrics, average=False, name=name)
76 | metrics_dict[name] = (metrics / seen.view(-1, 1)).mean(0)
77 | # Append metrics dictionary to the list
78 | all_metrics_dict.append(metrics_dict)
79 | # Return list of metrics dictionary
80 | return all_metrics_dict
81 |
82 | ########################################################################################################################
83 |
84 | def collate_metrics(output_data_batch, name='depth'):
85 | """
86 | Collate epoch output to produce average metrics
87 |
88 | Parameters
89 | ----------
90 | output_data_batch : list
91 | List of outputs for each batch
92 | name : str
93 | Name of the task for the metric
94 |
95 | Returns
96 | -------
97 | metrics_data : list
98 | List of collated metrics
99 | """
100 | # If there is only one dataset, wrap in a list
101 | if isinstance(output_data_batch[0], dict):
102 | output_data_batch = [output_data_batch]
103 | # Calculate the mean of all metrics
104 | metrics_data = []
105 | # For all datasets
106 | for i, output_batch in enumerate(output_data_batch):
107 | metrics = OrderedDict()
108 | # For all keys (assume they are the same for all batches)
109 | for key, val in output_batch[0].items():
110 | if key.startswith(name):
111 | metrics[key] = torch.stack([output[key] for output in output_batch], 0)
112 | metrics[key] = torch.mean(metrics[key], 0)
113 | metrics_data.append(metrics)
114 | # Return metrics data
115 | return metrics_data
116 |
117 | def create_dict(metrics_data, metrics_keys, metrics_modes,
118 | dataset, name='depth'):
119 | """
120 | Creates a dictionary from collated metrics
121 |
122 | Parameters
123 | ----------
124 | metrics_data : list
125 | List containing collated metrics
126 | metrics_keys : list
127 | List of keys for the metrics
128 |     metrics_modes : list
129 | List of modes for the metrics
130 | dataset : CfgNode
131 | Dataset configuration file
132 | name : str
133 | Name of the task for the metric
134 |
135 | Returns
136 | -------
137 | metrics_dict : dict
138 | Metrics dictionary
139 | """
140 | # Create metrics dictionary
141 | metrics_dict = {}
142 | # For all datasets
143 | for n, metrics in enumerate(metrics_data):
144 | if metrics: # If there are calculated metrics
145 | prefix = prepare_dataset_prefix(dataset, n)
146 | # For all keys
147 | for i, key in enumerate(metrics_keys):
148 | for mode in metrics_modes:
149 | metrics_dict['{}-{}{}'.format(prefix, key, mode)] =\
150 | metrics['{}{}'.format(name, mode)][i].item()
151 | # Return metrics dictionary
152 | return metrics_dict
153 |
154 | ########################################################################################################################
155 |
156 | def average_key(batch_list, key):
157 | """
158 | Average key in a list of batches
159 |
160 | Parameters
161 | ----------
162 | batch_list : list of dict
163 | List containing dictionaries with the same keys
164 | key : str
165 | Key to be averaged
166 |
167 | Returns
168 | -------
169 | average : float
170 | Average of the value contained in key for all batches
171 | """
172 | values = [batch[key] for batch in batch_list]
173 | return sum(values) / len(values)
174 |
175 | def average_sub_key(batch_list, key, sub_key):
176 | """
177 | Average subkey in a dictionary in a list of batches
178 |
179 | Parameters
180 | ----------
181 | batch_list : list of dict
182 | List containing dictionaries with the same keys
183 | key : str
184 | Key to be averaged
185 |     sub_key : str
186 | Sub key to be averaged (belonging to key)
187 |
188 | Returns
189 | -------
190 | average : float
191 | Average of the value contained in the sub_key of key for all batches
192 | """
193 | values = [batch[key][sub_key] for batch in batch_list]
194 | return sum(values) / len(values)
195 |
196 | def average_loss_and_metrics(batch_list, prefix):
197 | """
198 | Average loss and metrics values in a list of batches
199 |
200 | Parameters
201 | ----------
202 | batch_list : list of dict
203 | List containing dictionaries with the same keys
204 | prefix : str
205 | Prefix string for metrics logging
206 |
207 | Returns
208 | -------
209 | values : dict
210 | Dictionary containing a 'loss' float entry and a 'metrics' dict entry
211 | """
212 | values = OrderedDict()
213 | key = 'loss'
214 | values['{}-{}'.format(prefix, key)] = \
215 | average_key(batch_list, key)
216 | key = 'metrics'
217 | for sub_key in batch_list[0][key].keys():
218 | values['{}-{}'.format(prefix, sub_key)] = \
219 | average_sub_key(batch_list, key, sub_key)
220 | return values
221 |
222 | ########################################################################################################################
223 |
--------------------------------------------------------------------------------
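
A toy sketch of collate_metrics and average_loss_and_metrics (the metric names and values are invented for illustration):

    import torch
    from dro_sfm.utils.reduce import collate_metrics, average_loss_and_metrics

    # Two batches from one dataset, each carrying a 7-dimensional 'depth' metric vector
    output_batch = [{'depth': torch.full((7,), 0.1)},
                    {'depth': torch.full((7,), 0.3)}]
    collated = collate_metrics(output_batch, name='depth')
    print(collated[0]['depth'])  # tensor of 0.2 -- the mean over both batches

    # Averaging loss and metric entries across training batches
    batches = [{'loss': 1.0, 'metrics': {'photo_loss': 0.5}},
               {'loss': 3.0, 'metrics': {'photo_loss': 1.5}}]
    print(average_loss_and_metrics(batches, prefix='avg_train'))
    # OrderedDict with 'avg_train-loss' = 2.0 and 'avg_train-photo_loss' = 1.0
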
/dro_sfm/utils/save.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import os
4 |
5 | from dro_sfm.utils.image import write_image
6 | from dro_sfm.utils.depth import write_depth, inv2depth, viz_inv_depth
7 | from dro_sfm.utils.logging import prepare_dataset_prefix
8 |
9 |
10 | def save_depth(batch, output, args, dataset, save):
11 | """
12 | Save depth predictions in various ways
13 |
14 | Parameters
15 | ----------
16 | batch : dict
17 | Batch from dataloader
18 | output : dict
19 | Output from model
20 | args : tuple
21 | Step arguments
22 | dataset : CfgNode
23 | Dataset configuration
24 | save : CfgNode
25 | Save configuration
26 | """
27 | # If there is no save folder, don't save
28 |     if save.folder == '':
29 | return
30 |
31 | # If we want to save
32 | if save.depth.rgb or save.depth.viz or save.depth.npz or save.depth.png:
33 | # Retrieve useful tensors
34 | rgb = batch['rgb']
35 | pred_inv_depth = output['inv_depth']
36 |
37 | # Prepare path strings
38 | filename = batch['filename']
39 | dataset_idx = 0 if len(args) == 1 else args[1]
40 | save_path = os.path.join(save.folder, 'depth',
41 | prepare_dataset_prefix(dataset, dataset_idx),
42 | os.path.basename(save.pretrained).split('.')[0])
43 | # Create folder
44 | os.makedirs(save_path, exist_ok=True)
45 |
46 | # For each image in the batch
47 | length = rgb.shape[0]
48 | for i in range(length):
49 | # Save numpy depth maps
50 | if save.depth.npz:
51 | write_depth('{}/{}_depth.npz'.format(save_path, filename[i]),
52 | depth=inv2depth(pred_inv_depth[i]),
53 | intrinsics=batch['intrinsics'][i] if 'intrinsics' in batch else None)
54 | # Save png depth maps
55 | if save.depth.png:
56 | write_depth('{}/{}_depth.png'.format(save_path, filename[i]),
57 | depth=inv2depth(pred_inv_depth[i]))
58 | # Save rgb images
59 | if save.depth.rgb:
60 | rgb_i = rgb[i].permute(1, 2, 0).detach().cpu().numpy() * 255
61 | write_image('{}/{}_rgb.png'.format(save_path, filename[i]), rgb_i)
62 | # Save inverse depth visualizations
63 | if save.depth.viz:
64 | viz_i = viz_inv_depth(pred_inv_depth[i]) * 255
65 | write_image('{}/{}_viz.png'.format(save_path, filename[i]), viz_i)
66 |
--------------------------------------------------------------------------------
/dro_sfm/utils/types.py:
--------------------------------------------------------------------------------
1 |
2 | import yacs
3 | import numpy as np
4 | import torch
5 |
6 | ########################################################################################################################
7 |
8 | def is_numpy(data):
9 | """Checks if data is a numpy array."""
10 | return isinstance(data, np.ndarray)
11 |
12 | def is_tensor(data):
13 | """Checks if data is a torch tensor."""
14 | return type(data) == torch.Tensor
15 |
16 | def is_tuple(data):
17 | """Checks if data is a tuple."""
18 | return isinstance(data, tuple)
19 |
20 | def is_list(data):
21 | """Checks if data is a list."""
22 | return isinstance(data, list)
23 |
24 | def is_dict(data):
25 | """Checks if data is a dictionary."""
26 | return isinstance(data, dict)
27 |
28 | def is_str(data):
29 | """Checks if data is a string."""
30 | return isinstance(data, str)
31 |
32 | def is_int(data):
33 | """Checks if data is an integer."""
34 | return isinstance(data, int)
35 |
36 | def is_seq(data):
37 | """Checks if data is a list or tuple."""
38 | return is_tuple(data) or is_list(data)
39 |
40 | def is_cfg(data):
41 | """Checks if data is a configuration node"""
42 | return type(data) == yacs.config.CfgNode
43 |
44 | ########################################################################################################################
--------------------------------------------------------------------------------
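
The type helpers are thin wrappers around isinstance checks; a quick sketch:

    import numpy as np
    import torch
    from dro_sfm.utils.types import is_numpy, is_tensor, is_seq, is_str, is_int

    print(is_numpy(np.zeros(3)))        # True
    print(is_tensor(torch.zeros(3)))    # True
    print(is_seq((1, 2)), is_seq([1]))  # True True
    print(is_str('kitti'), is_int(3))   # True True
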
/media/figs/demo_kitti.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aliyun/dro-sfm/8707e2e0ef799d7d47418a018060f503ef449fe3/media/figs/demo_kitti.gif
--------------------------------------------------------------------------------
/media/figs/demo_scannet.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aliyun/dro-sfm/8707e2e0ef799d7d47418a018060f503ef449fe3/media/figs/demo_scannet.gif
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | mkdir -p $(dirname $2)
2 |
3 | NGPUS=$(nvidia-smi -L | wc -l)
4 | mpirun -allow-run-as-root -np ${NGPUS} -H localhost:${NGPUS} -x MASTER_ADDR=127.0.0.1 -x MASTER_PORT=23457 -x HOROVOD_TIMELINE -x OMP_NUM_THREADS=1 -x KMP_AFFINITY='granularity=fine,compact,1,0' -bind-to none -map-by slot -x NCCL_DEBUG=INFO -x NCCL_MIN_NRINGS=4 --report-bindings $1 2>&1 | tee -a $2
5 |
--------------------------------------------------------------------------------
/scripts/eval.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | lib_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
4 | sys.path.append(lib_dir)
5 | import argparse
6 | import torch
7 |
8 | from dro_sfm.models.model_wrapper import ModelWrapper
9 | from dro_sfm.trainers.horovod_trainer import HorovodTrainer
10 | from dro_sfm.utils.config import parse_test_file
11 | from dro_sfm.utils.load import set_debug
12 | from dro_sfm.utils.horovod import hvd_init
13 |
14 |
15 | def parse_args():
16 | """Parse arguments for training script"""
17 | parser = argparse.ArgumentParser(description='dro-sfm evaluation script')
18 | parser.add_argument('--checkpoint', type=str, help='Checkpoint (.ckpt)')
19 | parser.add_argument('--config', type=str, default=None, help='Configuration (.yaml)')
20 | parser.add_argument('--half', action="store_true", help='Use half precision (fp16)')
21 | args = parser.parse_args()
22 | assert args.checkpoint.endswith('.ckpt'), \
23 | 'You need to provide a .ckpt file as checkpoint'
24 | assert args.config is None or args.config.endswith('.yaml'), \
25 | 'You need to provide a .yaml file as configuration'
26 | return args
27 |
28 |
29 | def test(ckpt_file, cfg_file, half):
30 | """
31 | Monocular depth estimation test script.
32 |
33 | Parameters
34 | ----------
35 | ckpt_file : str
36 | Checkpoint path for a pretrained model
37 | cfg_file : str
38 | Configuration file
39 |     half : bool
40 |         Use half precision (fp16)
41 | """
42 | # Initialize horovod
43 | hvd_init()
44 |
45 | # Parse arguments
46 | config, state_dict = parse_test_file(ckpt_file, cfg_file)
47 |
48 | # Set debug if requested
49 | set_debug(config.debug)
50 |
51 | # Initialize monodepth model from checkpoint arguments
52 | model_wrapper = ModelWrapper(config)
53 | # Restore model state
54 | model_wrapper.load_state_dict(state_dict)
55 |
56 | # change to half precision for evaluation if requested
57 | config.arch["dtype"] = torch.float16 if half else None
58 |
59 | # Create trainer with args.arch parameters
60 | trainer = HorovodTrainer(**config.arch)
61 |
62 | # Test model
63 | trainer.test(model_wrapper)
64 |
65 |
66 | if __name__ == '__main__':
67 | args = parse_args()
68 | test(args.checkpoint, args.config, args.half)
69 |
--------------------------------------------------------------------------------
/scripts/evaluate_depth_maps.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | lib_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
4 | sys.path.append(lib_dir)
5 | import argparse
6 | import numpy as np
7 | import torch
8 |
9 | from glob import glob
10 | from argparse import Namespace
11 | from dro_sfm.utils.depth import load_depth
12 | from tqdm import tqdm
13 |
14 | from dro_sfm.utils.depth import compute_depth_metrics
15 |
16 |
17 | def parse_args():
18 | """Parse arguments for benchmark script"""
19 | parser = argparse.ArgumentParser(description='dro-sfm benchmark script')
20 | parser.add_argument('--pred_folder', type=str,
21 | help='Folder containing predicted depth maps (.npz with key "depth")')
22 | parser.add_argument('--gt_folder', type=str,
23 | help='Folder containing ground-truth depth maps (.npz with key "depth")')
24 | parser.add_argument('--use_gt_scale', action='store_true',
25 | help='Use ground-truth median scaling on predicted depth maps')
26 | parser.add_argument('--min_depth', type=float, default=0.,
27 | help='Minimum distance to consider during evaluation')
28 | parser.add_argument('--max_depth', type=float, default=80.,
29 | help='Maximum distance to consider during evaluation')
30 | parser.add_argument('--crop', type=str, default='', choices=['', 'garg'],
31 | help='Which crop to use during evaluation')
32 | args = parser.parse_args()
33 | return args
34 |
35 |
36 | def main(args):
37 | # Get and sort ground-truth and predicted files
38 | exts = ('npz', 'png')
39 | gt_files, pred_files = [], []
40 | for ext in exts:
41 | gt_files.extend(glob(os.path.join(args.gt_folder, '*.{}'.format(ext))))
42 | pred_files.extend(glob(os.path.join(args.pred_folder, '*.{}'.format(ext))))
43 | # Sort ground-truth and prediction
44 | gt_files.sort()
45 | pred_files.sort()
46 | # Loop over all files
47 | metrics = []
48 | progress_bar = tqdm(zip(gt_files, pred_files), total=len(gt_files))
49 | for gt, pred in progress_bar:
50 | # Get and prepare ground-truth and predictions
51 | gt = torch.tensor(load_depth(gt)).unsqueeze(0).unsqueeze(0)
52 | pred = torch.tensor(load_depth(pred)).unsqueeze(0).unsqueeze(0)
53 | # Calculate metrics
54 | metrics.append(compute_depth_metrics(
55 | args, gt, pred, use_gt_scale=args.use_gt_scale))
56 | # Get and print average value
57 | metrics = (sum(metrics) / len(metrics)).detach().cpu().numpy()
58 | names = ['abs_rel', 'sqr_rel', 'rmse', 'rmse_log', 'a1', 'a2', 'a3']
59 | for name, metric in zip(names, metrics):
60 | print('{} = {}'.format(name, metric))
61 |
62 |
63 | if __name__ == '__main__':
64 | args = parse_args()
65 | main(args)
66 |
--------------------------------------------------------------------------------
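
compute_depth_metrics can also be called directly on a pair of depth tensors; the sketch below mirrors the script's arguments with random data (shapes and value ranges are arbitrary):

    import torch
    from argparse import Namespace
    from dro_sfm.utils.depth import compute_depth_metrics

    # Mimic the CLI arguments the script forwards to compute_depth_metrics
    args = Namespace(min_depth=0., max_depth=80., crop='', use_gt_scale=False)

    gt = torch.rand(1, 1, 192, 640) * 80.    # fake ground-truth depth map
    pred = torch.rand(1, 1, 192, 640) * 80.  # fake predicted depth map

    metrics = compute_depth_metrics(args, gt, pred, use_gt_scale=args.use_gt_scale)
    print(metrics)  # abs_rel, sqr_rel, rmse, rmse_log, a1, a2, a3
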
/scripts/infer_pose.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | lib_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
4 | sys.path.append(lib_dir)
5 | import argparse
6 | import numpy as np
7 | import torch
8 | from pytorch3d import transforms
9 | import json
10 |
11 | from glob import glob
12 | from cv2 import imwrite
13 |
14 | from dro_sfm.models.model_wrapper import ModelWrapper
15 | from dro_sfm.datasets.augmentations import resize_image, to_tensor
16 | from dro_sfm.utils.horovod import hvd_init, rank, world_size, print0
17 | from dro_sfm.utils.image import load_image
18 | from dro_sfm.utils.config import parse_test_file
19 | from dro_sfm.utils.load import set_debug
20 | from dro_sfm.utils.depth import write_depth, inv2depth, viz_inv_depth
21 | from dro_sfm.utils.logging import pcolor
22 |
23 | poses = dict()
24 |
25 | def is_image(file, ext=('.png', '.jpg',)):
26 | """Check if a file is an image with certain extensions"""
27 | return file.endswith(ext)
28 |
29 | def parse_args():
30 | parser = argparse.ArgumentParser(description='dro-sfm inference of depth maps from images')
31 | parser.add_argument('--checkpoint', type=str, help='Checkpoint (.ckpt)')
32 | parser.add_argument('--input', type=str, help='Input file or folder')
33 | parser.add_argument('--output', type=str, help='Output file or folder')
34 | parser.add_argument('--image_shape', type=int, nargs='+', default=None,
35 | help='Input and output image shape '
36 | '(default: checkpoint\'s config.datasets.augmentation.image_shape)')
37 | parser.add_argument('--half', action="store_true", help='Use half precision (fp16)')
38 | parser.add_argument('--save', type=str, choices=['npz', 'png'], default=None,
39 | help='Save format (npz or png). Default is None (no depth map is saved).')
40 | parser.add_argument('--limit', type=int, default=None,
41 | help='Limit the amount of files to process in folder')
42 | parser.add_argument('--offset', type=int, default=None,
43 | help='Start at offset for files to process in folder')
44 | args = parser.parse_args()
45 | assert args.checkpoint.endswith('.ckpt'), \
46 | 'You need to provide a .ckpt file as checkpoint'
47 | assert args.image_shape is None or len(args.image_shape) == 2, \
48 | 'You need to provide a 2-dimensional tuple as shape (H,W)'
49 | assert (is_image(args.input) and is_image(args.output)) or \
50 |            (not is_image(args.input) and not is_image(args.output)), \
51 | 'Input and output must both be images or folders'
52 | return args
53 |
54 |
55 | @torch.no_grad()
56 | def infer_and_save_pose(input_file_refs, input_file, model_wrapper, image_shape, half, save):
57 | """
58 | Process a single input file to produce and save visualization
59 |
60 | Parameters
61 | ----------
62 | input_file_refs : list(str)
63 | Reference image file paths
64 | input_file : str
65 | Image file for pose estimation
66 | model_wrapper : nn.Module
67 | Model wrapper used for inference
68 |     image_shape : tuple
69 |         Input image shape (H, W)
70 |     half : bool
71 |         Use half precision (fp16)
72 |     save : str
73 |         Save format (npz or png)
74 | """
75 | base_name = os.path.basename(input_file)
76 |
77 | # change to half precision for evaluation if requested
78 | dtype = torch.float16 if half else None
79 |
80 | # Load image
81 | def process_image(filename):
82 | image = load_image(filename)
83 | # Resize and to tensor
84 | image = resize_image(image, image_shape)
85 | image = to_tensor(image).unsqueeze(0)
86 |
87 | # Send image to GPU if available
88 | if torch.cuda.is_available():
89 | image = image.to('cuda:{}'.format(rank()), dtype=dtype)
90 | return image
91 | image_ref = [process_image(input_file_ref) for input_file_ref in input_file_refs]
92 | image = process_image(input_file)
93 |
94 |     # Pose inference (returns the predicted relative pose between target and reference images)
95 | pose_tensor = model_wrapper.pose(image, image_ref)[0][0] # take the pose from 1st to 2nd image
96 |
97 | rot_matrix = transforms.euler_angles_to_matrix(pose_tensor[3:], convention="ZYX")
98 | translation = pose_tensor[:3]
99 |
100 | poses[base_name] = (rot_matrix, translation)
101 |
102 | def main(args):
103 |
104 | # Initialize horovod
105 | hvd_init()
106 |
107 | # Parse arguments
108 | config, state_dict = parse_test_file(args.checkpoint)
109 |
110 | # If no image shape is provided, use the checkpoint one
111 | image_shape = args.image_shape
112 | if image_shape is None:
113 | image_shape = config.datasets.augmentation.image_shape
114 |
115 | print(image_shape)
116 | # Set debug if requested
117 | set_debug(config.debug)
118 |
119 | # Initialize model wrapper from checkpoint arguments
120 | model_wrapper = ModelWrapper(config, load_datasets=False)
121 | # Restore monodepth_model state
122 | model_wrapper.load_state_dict(state_dict)
123 |
124 | # change to half precision for evaluation if requested
125 | dtype = torch.float16 if args.half else None
126 |
127 | # Send model to GPU if available
128 | if torch.cuda.is_available():
129 | model_wrapper = model_wrapper.to('cuda:{}'.format(rank()), dtype=dtype)
130 |
131 | # Set to eval mode
132 | model_wrapper.eval()
133 |
134 | if os.path.isdir(args.input):
135 | # If input file is a folder, search for image files
136 | files = []
137 | for ext in ['png', 'jpg']:
138 | files.extend(glob((os.path.join(args.input, '*.{}'.format(ext)))))
139 | files.sort()
140 | print0('Found {} files'.format(len(files)))
141 | else:
142 | raise RuntimeError("Input needs directory, not file")
143 |
144 | if not os.path.isdir(args.output):
145 | root, file_name = os.path.split(args.output)
146 | os.makedirs(root, exist_ok=True)
147 | else:
148 | raise RuntimeError("Output needs to be a file")
149 |
150 |
151 | # Process each file
152 | list_of_files = list(zip(files[rank() :-2:world_size()],
153 | files[rank()+1:-1:world_size()],
154 | files[rank()+2: :world_size()]))
155 | if args.offset:
156 | list_of_files = list_of_files[args.offset:]
157 | if args.limit:
158 | list_of_files = list_of_files[:args.limit]
159 | for fn1, fn2, fn3 in list_of_files:
160 | infer_and_save_pose([fn1, fn3], fn2, model_wrapper, image_shape, args.half, args.save)
161 |
162 | position = np.zeros(3)
163 | orientation = np.eye(3)
164 | for key in sorted(poses.keys()):
165 | rot_matrix, translation = poses[key]
166 | orientation = orientation.dot(rot_matrix.tolist())
167 | position += orientation.dot(translation.tolist())
168 | poses[key] = {"rot": rot_matrix.tolist(),
169 | "trans": translation.tolist(),
170 | "pose": [*orientation[0], position[0],
171 | *orientation[1], position[1],
172 | *orientation[2], position[2],
173 | 0, 0, 0, 1]}
174 |
175 | json.dump(poses, open(args.output, "w"), sort_keys=True)
176 | print(f"Written pose of {len(list_of_files)} images to {args.output}")
177 |
178 |
179 |
180 |
181 | if __name__ == '__main__':
182 | args = parse_args()
183 | main(args)
--------------------------------------------------------------------------------
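
The loop at the end of main() chains per-frame relative poses into a global trajectory. A standalone sketch of the same accumulation with 4x4 homogeneous matrices (toy motions, not the script's exact conventions):

    import numpy as np

    def to_homogeneous(rot, trans):
        """Pack a 3x3 rotation and a 3-vector translation into a 4x4 pose."""
        T = np.eye(4)
        T[:3, :3] = rot
        T[:3, 3] = trans
        return T

    # Toy relative motions: one metre forward, then a 90-degree turn about Z
    forward = to_homogeneous(np.eye(3), [0., 0., 1.])
    turn = to_homogeneous(np.array([[0., -1., 0.],
                                    [1.,  0., 0.],
                                    [0.,  0., 1.]]), [0., 0., 1.])

    # Each relative pose is right-multiplied onto the running global pose
    global_pose = np.eye(4)
    for rel in (forward, turn):
        global_pose = global_pose @ rel
    print(global_pose[:3, 3])  # accumulated camera position
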
/scripts/infer_pose.sh:
--------------------------------------------------------------------------------
1 | python scripts/infer_pose.py --checkpoint ckpt/PackNet01_MR_selfsup_D.ckpt --input /data0/datasets/kitti/KITTI_tiny/2011_09_26/2011_09_26_drive_0023_sync/image_02/data --output results/eval/pose --save png
2 |
--------------------------------------------------------------------------------
/scripts/train.py:
--------------------------------------------------------------------------------
1 |
2 | import sys
3 | import os
4 | lib_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
5 | sys.path.append(lib_dir)
6 | import argparse
7 | from dro_sfm.models.model_wrapper import ModelWrapper
8 | from dro_sfm.models.model_checkpoint import ModelCheckpoint
9 | from dro_sfm.trainers.horovod_trainer import HorovodTrainer
10 | from dro_sfm.utils.config import parse_train_file
11 | from dro_sfm.utils.load import set_debug, filter_args_create
12 | from dro_sfm.utils.horovod import hvd_init, rank
13 | from dro_sfm.loggers import WandbLogger
14 |
15 |
16 | def parse_args():
17 | """Parse arguments for training script"""
18 | parser = argparse.ArgumentParser(description='dro-sfm training script')
19 | parser.add_argument('file', type=str, help='Input file (.ckpt or .yaml)')
20 | parser.add_argument('--config', type=str, default=None, help='Input file (yaml)')
21 | args = parser.parse_args()
22 | assert args.file.endswith(('.ckpt', '.yaml')), \
23 |         'You need to provide a .ckpt or .yaml file'
24 | return args
25 |
26 |
27 | def train(file, config):
28 | """
29 | Monocular depth estimation training script.
30 |
31 | Parameters
32 | ----------
33 |     file : str
34 |         Filepath, either a **.yaml** yacs configuration file or a **.ckpt** pre-trained checkpoint file.
35 |     config : str
36 |         Optional extra .yaml configuration file (the --config argument).
37 | """
38 | # Initialize horovod
39 | hvd_init()
40 |
41 | # Produce configuration and checkpoint from filename
42 | config, ckpt = parse_train_file(file, config)
43 |
44 | # Set debug if requested
45 | set_debug(config.debug)
46 |
47 | # Wandb Logger
48 | logger = None if config.wandb.dry_run or rank() > 0 \
49 | else filter_args_create(WandbLogger, config.wandb)
50 |
51 | # model checkpoint
52 |     checkpoint = None if config.checkpoint.filepath == '' or rank() > 0 else \
53 | filter_args_create(ModelCheckpoint, config.checkpoint)
54 |
55 | # Initialize model wrapper
56 | model_wrapper = ModelWrapper(config, resume=ckpt, logger=logger)
57 |
58 | # Create trainer with args.arch parameters
59 | trainer = HorovodTrainer(**config.arch, checkpoint=checkpoint)
60 |
61 | # Train model
62 | trainer.fit(model_wrapper)
63 |
64 |
65 | if __name__ == '__main__':
66 | args = parse_args()
67 | train(args.file, args.config)
68 |
--------------------------------------------------------------------------------