├── common
│   ├── loss
│   │   ├── __init__.py
│   │   ├── loss_functions.py
│   │   └── inverse_warp.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── logger.py
│   │   ├── utils.py
│   │   └── custom_transforms.py
│   ├── data_prepare
│   │   ├── __init__.py
│   │   ├── well_lit_from_varying.txt
│   │   ├── prepare_train_data_VIVID.py
│   │   └── VIVID_raw_loader.py
│   └── models
│       ├── __init__.py
│       ├── ResDispPoseNet.py
│       ├── PoseResNet.py
│       ├── DispResNet.py
│       └── resnet_encoder.py
├── eval_vivid
│   ├── __init__.py
│   ├── pose_evaluation_utils.py
│   └── eval_depth.py
├── media
│   ├── overview_a.png
│   ├── overview_b.png
│   ├── qualitative_1.png
│   ├── qualitative_2.png
│   └── qualitative_3.png
├── scripts
│   ├── prepare_vivid_data.sh
│   ├── train_vivid_resnet18_indoor.sh
│   ├── train_vivid_resnet18_outdoor.sh
│   ├── run_vivid_inference.sh
│   ├── display_result.sh
│   ├── test_vivid_outdoor.sh
│   └── test_vivid_indoor.sh
├── .gitignore
├── dataloader
│   ├── VIVID_sequence_folders.py
│   └── VIVID_validation_folders.py
├── environment.yml
├── test_disp.py
├── run_inference.py
├── test_pose.py
├── README.md
└── train.py

/common/loss/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /eval_vivid/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /common/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /common/data_prepare/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /media/overview_a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UkcheolShin/ThermalMonoDepth/HEAD/media/overview_a.png -------------------------------------------------------------------------------- /media/overview_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UkcheolShin/ThermalMonoDepth/HEAD/media/overview_b.png -------------------------------------------------------------------------------- /media/qualitative_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UkcheolShin/ThermalMonoDepth/HEAD/media/qualitative_1.png -------------------------------------------------------------------------------- /media/qualitative_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UkcheolShin/ThermalMonoDepth/HEAD/media/qualitative_2.png -------------------------------------------------------------------------------- /media/qualitative_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UkcheolShin/ThermalMonoDepth/HEAD/media/qualitative_3.png -------------------------------------------------------------------------------- /common/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .DispResNet import DispResNet 2 | from .PoseResNet import PoseResNet 3 | from .ResDispPoseNet import DispPoseResNet 4 | -------------------------------------------------------------------------------- /scripts/prepare_vivid_data.sh:
-------------------------------------------------------------------------------- 1 | # for VIVID rgb-thermal dataset 2 | DATASET=/HDD/Dataset/KAIST_VIVID 3 | TRAIN_SET=/HDD/Dataset_processed/VIVID_256/ 4 | mkdir -p $TRAIN_SET 5 | python common/data_prepare/prepare_train_data_VIVID.py $DATASET --dump-root $TRAIN_SET --width 320 --height 256 --num-threads 16 --with-depth --with-pose 6 | -------------------------------------------------------------------------------- /scripts/train_vivid_resnet18_indoor.sh: -------------------------------------------------------------------------------- 1 | DATA_ROOT=/HDD/Dataset_processed 2 | TRAIN_SET=$DATA_ROOT/VIVID_256 3 | GPU_ID=1 4 | 5 | CUDA_VISIBLE_DEVICES=${GPU_ID} \ 6 | python train.py $TRAIN_SET \ 7 | --resnet-layers 18 \ 8 | --num-scales 1 \ 9 | --scene_type indoor \ 10 | -b 4 \ 11 | --sequence-length 3 \ 12 | --with-ssim 1 \ 13 | --with-auto-mask 1 \ 14 | --with-pretrain 1 \ 15 | --rearrange-bin 30 \ 16 | --clahe-clip 3.0 \ 17 | --log-output \ 18 | --name T_vivid_resnet18_indoor 19 | -------------------------------------------------------------------------------- /scripts/train_vivid_resnet18_outdoor.sh: -------------------------------------------------------------------------------- 1 | DATA_ROOT=/HDD/Dataset_processed 2 | TRAIN_SET=$DATA_ROOT/VIVID_256 3 | GPU_ID=2 4 | 5 | CUDA_VISIBLE_DEVICES=${GPU_ID} \ 6 | python train.py $TRAIN_SET \ 7 | --resnet-layers 18 \ 8 | --num-scales 1 \ 9 | --scene_type outdoor \ 10 | -b 4 \ 11 | --sequence-length 3 \ 12 | --with-ssim 1 \ 13 | --with-auto-mask 1 \ 14 | --with-pretrain 1 \ 15 | --rearrange-bin 30 \ 16 | --clahe-clip 2.0 \ 17 | --log-output \ 18 | --name T_vivid_resnet18_outdoor 19 | -------------------------------------------------------------------------------- /scripts/run_vivid_inference.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # run script : bash run_vivid_inference.sh 3 | 4 | NAMES=("T_vivid_resnet18_indoor") 5 | 6 | DATA_ROOT=/HDD/Dataset_processed/VIVID_256/ 7 | RESNET=18 8 | IMG_H=256 9 | IMG_W=320 10 | GPU_ID=0 11 | 12 | for NAME in ${NAMES[@]}; do 13 | echo "Name : ${NAME}" 14 | RESULTS_DIR=results/${NAME} 15 | POSE_NET=checkpoints/${NAME}/exp_pose_disp_model_best.pth.tar 16 | DISP_NET=checkpoints/${NAME}/dispnet_disp_model_best.pth.tar 17 | 18 | SEQS=("indoor_robust_dark" "indoor_robust_varying" "indoor_aggresive_dark" 19 | "indoor_aggresive_local" "indoor_unstable_dark" "indoor_robust_varying_well_lit") 20 | 21 | for SEQ in ${SEQS[@]}; do 22 | echo "Seq_name : ${SEQ}" 23 | 24 | CUDA_VISIBLE_DEVICES=${GPU_ID} python run_inference.py ${DATA_ROOT} --resnet-layers $RESNET \ 25 | --output-dir $RESULTS_DIR --sequence ${SEQ} --scene_type indoor --pretrained-model $DISP_NET 26 | done 27 | 28 | SEQS=("outdoor_robust_night1" "outdoor_robust_night2") 29 | 30 | for SEQ in ${SEQS[@]}; do 31 | echo "Seq_name : ${SEQ}" 32 | 33 | CUDA_VISIBLE_DEVICES=${GPU_ID} python run_inference.py ${DATA_ROOT} --resnet-layers $RESNET \ 34 | --output-dir $RESULTS_DIR --sequence ${SEQ} --scene_type outdoor --pretrained-model $DISP_NET 35 | done 36 | done 37 | -------------------------------------------------------------------------------- /scripts/display_result.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # run script : bash 3 | 4 | NAMES=("T_vivid_resnet18_indoor") 5 | RESULTS_DIR=results 6 | 7 | for NAME in ${NAMES[@]}; do 8 | echo "NAME : ${NAME}" 9 | 10 | # depth 11 | echo "Indoor Depth : 
well-lit results" 12 | SEQS=("indoor_aggresive_local" "indoor_robust_varying_well_lit") 13 | for SEQ in ${SEQS[@]}; do 14 | echo "Seq_name : ${SEQ}" 15 | cat ${RESULTS_DIR}/${NAME}/Depth/${SEQ}/eval_depth.txt 16 | done 17 | 18 | echo "Indoor Depth : low-light results" 19 | SEQS=("indoor_robust_dark" "indoor_unstable_dark" "indoor_aggresive_dark" ) 20 | for SEQ in ${SEQS[@]}; do 21 | echo "Seq_name : ${SEQ}" 22 | cat ${RESULTS_DIR}/${NAME}/Depth/${SEQ}/eval_depth.txt 23 | done 24 | 25 | # pose 26 | SEQS=("indoor_aggresive_local" "indoor_robust_varying" "indoor_robust_dark" "indoor_unstable_dark" "indoor_aggresive_dark" ) 27 | echo "Indoor Pose : All results" 28 | for SEQ in ${SEQS[@]}; do 29 | echo "Seq_name : ${SEQ}" 30 | cat ${RESULTS_DIR}/${NAME}/POSE/${SEQ}/eval_pose.txt 31 | done 32 | done 33 | 34 | NAMES=("T_vivid_resnet18_outdoor") 35 | for NAME in ${NAMES[@]}; do 36 | echo "NAME : ${NAME}" 37 | 38 | # depth 39 | echo "Outdoor Depth : night-time results" 40 | SEQS=("outdoor_robust_night1" "outdoor_robust_night2" ) 41 | for SEQ in ${SEQS[@]}; do 42 | echo "Seq_name : ${SEQ}" 43 | cat ${RESULTS_DIR}/${NAME}/Depth/${SEQ}/eval_depth.txt 44 | done 45 | 46 | # pose 47 | SEQS=("outdoor_robust_night1" "outdoor_robust_night2" ) 48 | echo "Outdoor Pose : night-time results" 49 | for SEQ in ${SEQS[@]}; do 50 | echo "Seq_name : ${SEQ}" 51 | cat ${RESULTS_DIR}/${NAME}/POSE/${SEQ}/eval_pose.txt 52 | done 53 | done 54 | 55 | -------------------------------------------------------------------------------- /scripts/test_vivid_outdoor.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # run script : bash test_vivid_outdoor.sh 3 | 4 | GPU_ID=0 5 | DATA_ROOT=/HDD/Dataset_processed/VIVID_256/ 6 | RESULTS_DIR=results/ 7 | 8 | RESNET=18 9 | IMG_H=256 10 | IMG_W=320 11 | DATASET=VIVID 12 | INPUT_TYPE=T 13 | DEPTH_GT_DIR=Depth_T 14 | POSE_GT=poses_T.txt 15 | 16 | # outdoor testset 17 | NAMES=("T_vivid_resnet18_outdoor") 18 | 19 | for NAME in ${NAMES[@]}; do 20 | SEQS=("outdoor_robust_night1" "outdoor_robust_night2" ) 21 | POSE_NET=checkpoints/${NAME}/exp_pose_pose_model_best.pth.tar 22 | DISP_NET=checkpoints/${NAME}/dispnet_disp_model_best.pth.tar 23 | echo "${NAME}" 24 | 25 | for SEQ in ${SEQS[@]}; do 26 | echo "Seq_name : ${SEQ}" 27 | SCENE=outdoor 28 | 29 | #mkdir -p ${RESULTS_DIR} 30 | DATA_DIR=${DATA_ROOT}/${SEQ}/ 31 | OUTPUT_DEPTH_DIR=${RESULTS_DIR}/${NAME}/Depth/${SEQ}/ 32 | OUTPUT_POSE_DIR=${RESULTS_DIR}/${NAME}/POSE/${SEQ}/ 33 | mkdir -p ${OUTPUT_DEPTH_DIR} 34 | mkdir -p ${OUTPUT_POSE_DIR} 35 | 36 | # Detph Evaulation 37 | CUDA_VISIBLE_DEVICES=${GPU_ID} python test_disp.py \ 38 | --resnet-layers $RESNET --pretrained-dispnet $DISP_NET \ 39 | --img-height $IMG_H --img-width $IMG_W \ 40 | --dataset-dir ${DATA_DIR} --output-dir $OUTPUT_DEPTH_DIR >> ${OUTPUT_DEPTH_DIR}/disp.txt 41 | 42 | CUDA_VISIBLE_DEVICES=${GPU_ID} python eval_vivid/eval_depth.py \ 43 | --dataset $DATASET --pred_depth ${OUTPUT_DEPTH_DIR}/predictions.npy \ 44 | --gt_depth ${DATA_DIR}/${DEPTH_GT_DIR} --scene outdoor >> ${OUTPUT_DEPTH_DIR}/eval_depth.txt 45 | 46 | rm ${OUTPUT_DEPTH_DIR}/predictions.npy 47 | 48 | 49 | # Pose Evaulation 50 | CUDA_VISIBLE_DEVICES=${GPU_ID} python test_pose.py \ 51 | --resnet-layers $RESNET --pretrained-posenet $POSE_NET \ 52 | --img-height $IMG_H --img-width $IMG_W \ 53 | --dataset-dir ${DATA_ROOT} --output-dir ${OUTPUT_POSE_DIR} \ 54 | --sequences ${SEQ} >> ${OUTPUT_POSE_DIR}/eval_pose.txt 55 | done 56 | done 57 | 58 | 
-------------------------------------------------------------------------------- /common/models/ResDispPoseNet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from .resnet_encoder import * 9 | from .PoseResNet import PoseDecoder 10 | from .DispResNet import DepthDecoder 11 | 12 | class DispPoseResNet(nn.Module): 13 | def __init__(self, num_layers = 18, pretrained = True, num_channel=1): 14 | super(DispPoseResNet, self).__init__() 15 | self.encoder = ResnetEncoder(num_layers = num_layers, pretrained = pretrained, num_input_images=1, num_channel=num_channel) 16 | self.depth_decoder = DepthDecoder(self.encoder.num_ch_enc) 17 | self.pose_decoder = PoseDecoder(self.encoder.num_ch_enc*2) 18 | 19 | self.DispResNet = DispResNet(self.encoder, self.depth_decoder) 20 | self.PoseResNet = PoseResNet(self.encoder, self.pose_decoder) 21 | 22 | def init_weights(self): 23 | pass 24 | 25 | class DispResNet(nn.Module): 26 | 27 | def __init__(self, encoder, decoder): 28 | super(DispResNet, self).__init__() 29 | self.encoder = encoder 30 | self.decoder = decoder 31 | 32 | def init_weights(self): 33 | pass 34 | 35 | def forward(self, x): 36 | enc_features = self.encoder(x) 37 | outputs = self.decoder(enc_features) 38 | 39 | if self.training: 40 | return outputs 41 | else: 42 | return outputs[0] 43 | 44 | class PoseResNet(nn.Module): 45 | 46 | def __init__(self, encoder, decoder): 47 | super(PoseResNet, self).__init__() 48 | self.encoder = encoder 49 | self.decoder = decoder 50 | 51 | def init_weights(self): 52 | pass 53 | 54 | def forward(self, img1, img2): 55 | features1 = self.encoder(img1) 56 | features2 = self.encoder(img2) 57 | 58 | features = [] 59 | for k in range(0, len(features1)) : 60 | features.append(torch.cat([features1[k],features2[k]],dim=1)) 61 | pose = self.decoder([features]) 62 | return pose 63 | -------------------------------------------------------------------------------- /scripts/test_vivid_indoor.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # run script : bash test_vivid_indoor.sh 3 | GPU_ID=0 4 | DATA_ROOT=/HDD/Dataset_processed/VIVID_256/ 5 | RESULTS_DIR=results/ 6 | 7 | RESNET=18 8 | IMG_H=256 9 | IMG_W=320 10 | DATASET=VIVID 11 | DEPTH_GT_DIR=Depth_T 12 | POSE_GT=poses_T.txt 13 | 14 | NAMES=("T_vivid_resnet18_indoor") 15 | 16 | for NAME in ${NAMES[@]}; do 17 | # indoor testset 18 | SEQS=("indoor_robust_dark" "indoor_aggresive_dark" "indoor_aggresive_local" "indoor_unstable_dark" "indoor_robust_varying_well_lit") 19 | POSE_NET=checkpoints/${NAME}/exp_pose_pose_model_best.pth.tar 20 | DISP_NET=checkpoints/${NAME}/dispnet_disp_model_best.pth.tar 21 | echo "${NAME}" 22 | 23 | # depth 24 | for SEQ in ${SEQS[@]}; do 25 | echo "Seq_name : ${SEQ}" 26 | 27 | #mkdir -p ${RESULTS_DIR} 28 | DATA_DIR=${DATA_ROOT}/${SEQ}/ 29 | OUTPUT_DEPTH_DIR=${RESULTS_DIR}/${NAME}/Depth/${SEQ}/ 30 | mkdir -p ${OUTPUT_DEPTH_DIR} 31 | 32 | # Depth Evaluation 33 | CUDA_VISIBLE_DEVICES=${GPU_ID} python test_disp.py \ 34 | --resnet-layers $RESNET --pretrained-dispnet $DISP_NET \ 35 | --img-height $IMG_H --img-width $IMG_W \ 36 | --dataset-dir ${DATA_DIR} --output-dir $OUTPUT_DEPTH_DIR >> ${OUTPUT_DEPTH_DIR}/disp.txt 37 | 38 | CUDA_VISIBLE_DEVICES=${GPU_ID} python eval_vivid/eval_depth.py \ 39 | --dataset $DATASET --pred_depth ${OUTPUT_DEPTH_DIR}/predictions.npy \ 40 | --gt_depth ${DATA_DIR}/${DEPTH_GT_DIR} --scene indoor >> ${OUTPUT_DEPTH_DIR}/eval_depth.txt 41 | 42 | rm ${OUTPUT_DEPTH_DIR}/predictions.npy 43 | done 44 | 45 | # pose 46 | SEQS=("indoor_robust_dark" "indoor_robust_varying" "indoor_aggresive_dark" "indoor_aggresive_local" "indoor_unstable_dark") 47 | for SEQ in ${SEQS[@]}; do 48 | echo "Seq_name : ${SEQ}" 49 | 50 | #mkdir -p ${RESULTS_DIR} 51 | DATA_DIR=${DATA_ROOT}/${SEQ}/ 52 | OUTPUT_POSE_DIR=${RESULTS_DIR}/${NAME}/POSE/${SEQ}/ 53 | mkdir -p ${OUTPUT_POSE_DIR} 54 | 55 | # Pose Evaluation 56 | CUDA_VISIBLE_DEVICES=${GPU_ID} python test_pose.py \ 57 | --resnet-layers $RESNET --pretrained-posenet $POSE_NET \ 58 | --img-height $IMG_H --img-width $IMG_W \ 59 | --dataset-dir ${DATA_ROOT} --output-dir ${OUTPUT_POSE_DIR} \ 60 | --sequences ${SEQ} >> ${OUTPUT_POSE_DIR}/eval_pose.txt 61 | done 62 | done 63 | -------------------------------------------------------------------------------- /eval_vivid/pose_evaluation_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from path import Path 3 | from imageio import imread 4 | from tqdm import tqdm 5 | 6 | 7 | class test_framework_VIVID(object): 8 | def __init__(self, root, sequence_set, seq_length=3, step=1): 9 | self.root = root 10 | self.img_files, self.poses, self.sample_indices = read_scene_data(self.root, sequence_set, seq_length, step) 11 | 12 | def generator(self): 13 | for img_list, pose_list, sample_list in zip(self.img_files, self.poses, self.sample_indices): 14 | for snippet_indices in sample_list: 15 | imgs = [np.expand_dims(imread(img_list[i]).astype(np.float32), axis=2) for i in snippet_indices] 16 | 17 | poses = np.stack([pose_list[i] for i in snippet_indices]) 18 | first_pose = poses[0] 19 | poses[:,:,-1] -= first_pose[:,-1] 20 | compensated_poses = np.linalg.inv(first_pose[:,:3]) @ poses 21 | 22 | yield {'imgs': imgs, 23 | 'path': img_list[0], 24 | 'poses': compensated_poses 25 | } 26 | 27 | def __iter__(self): 28 | return self.generator() 29 | 30 | def __len__(self): 31 | return sum(len(imgs) for imgs in self.img_files) 32 | 33 | 34 | def read_scene_data(data_root, sequence_set, seq_length=3, step=1): 35 | data_root = Path(data_root) 36 | im_sequences = [] 37 | poses_sequences = [] 38 | indices_sequences = [] 39 | demi_length = (seq_length - 1) // 2 40 | shift_range = np.array([step*i for i in range(-demi_length, demi_length + 1)]).reshape(1, -1) 41 | 42 | sequences = [] 43 | for seq in sequence_set: 44 | sequences.append((data_root/seq)) 45 | 46 | print('getting test metadata for these sequences : {}'.format(sequences)) 47 | for sequence in tqdm(sequences): 48 | imgs = sorted((sequence/'Thermal').files('*.png')) 49 | poses = np.genfromtxt(sequence/'poses_T.txt').astype(np.float64).reshape(-1, 3, 4) 50 | 51 | # construct snippet sequences 52 | tgt_indices = np.arange(demi_length, len(imgs) - demi_length).reshape(-1, 1) 53 | snippet_indices = shift_range + tgt_indices 54 | im_sequences.append(imgs) 55 | poses_sequences.append(poses) 56 | indices_sequences.append(snippet_indices) 57 | return im_sequences, poses_sequences, indices_sequences 58 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # User defined 7 | checkpoints 8 | results 9 | 10 | 11 | # C
extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | -------------------------------------------------------------------------------- /common/models/PoseResNet.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 
6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import torch 10 | import torch.nn as nn 11 | from collections import OrderedDict 12 | from .resnet_encoder import * 13 | 14 | class PoseDecoder(nn.Module): 15 | def __init__(self, num_ch_enc, num_input_features=1, num_frames_to_predict_for=1, stride=1): 16 | super(PoseDecoder, self).__init__() 17 | 18 | self.num_ch_enc = num_ch_enc 19 | self.num_input_features = num_input_features 20 | 21 | if num_frames_to_predict_for is None: 22 | num_frames_to_predict_for = num_input_features - 1 23 | self.num_frames_to_predict_for = num_frames_to_predict_for 24 | 25 | self.convs = OrderedDict() 26 | self.convs[("squeeze")] = nn.Conv2d(self.num_ch_enc[-1], 256, 1) 27 | self.convs[("pose", 0)] = nn.Conv2d(num_input_features * 256, 256, 3, stride, 1) 28 | self.convs[("pose", 1)] = nn.Conv2d(256, 256, 3, stride, 1) 29 | self.convs[("pose", 2)] = nn.Conv2d(256, 6 * num_frames_to_predict_for, 1) 30 | 31 | self.relu = nn.ReLU() 32 | 33 | self.net = nn.ModuleList(list(self.convs.values())) 34 | 35 | def forward(self, input_features): 36 | last_features = [f[-1] for f in input_features] 37 | 38 | cat_features = [self.relu(self.convs["squeeze"](f)) for f in last_features] 39 | cat_features = torch.cat(cat_features, 1) 40 | 41 | out = cat_features 42 | for i in range(3): 43 | out = self.convs[("pose", i)](out) 44 | if i != 2: 45 | out = self.relu(out) 46 | 47 | out = out.mean(3).mean(2) 48 | 49 | pose = 0.01 * out.view(-1, 6) 50 | 51 | return pose 52 | 53 | 54 | class PoseResNet(nn.Module): 55 | 56 | def __init__(self, num_layers = 18, pretrained = True, num_channel=3): 57 | super(PoseResNet, self).__init__() 58 | self.encoder = ResnetEncoder(num_layers = num_layers, pretrained = pretrained, num_input_images=2, num_channel=num_channel) 59 | self.decoder = PoseDecoder(self.encoder.num_ch_enc) 60 | 61 | def init_weights(self): 62 | pass 63 | 64 | def forward(self, img1, img2): 65 | x = torch.cat([img1,img2],1) 66 | features = self.encoder(x) 67 | pose = self.decoder([features]) 68 | return pose 69 | -------------------------------------------------------------------------------- /common/utils/logger.py: -------------------------------------------------------------------------------- 1 | from blessings import Terminal 2 | import progressbar 3 | import sys 4 | 5 | 6 | class TermLogger(object): 7 | def __init__(self, n_epochs, train_size, valid_size): 8 | self.n_epochs = n_epochs 9 | self.train_size = train_size 10 | self.valid_size = valid_size 11 | self.t = Terminal() 12 | s = 10 13 | e = 1 # epoch bar position 14 | tr = 3 # train bar position 15 | ts = 6 # valid bar position 16 | value = self.t.height 17 | h = int(0 if value is None else value) 18 | 19 | for i in range(10): 20 | print('') 21 | self.epoch_bar = progressbar.ProgressBar( 22 | max_value=n_epochs, fd=Writer(self.t, (0, h-s+e))) 23 | 24 | self.train_writer = Writer(self.t, (0, h-s+tr)) 25 | self.train_bar_writer = Writer(self.t, (0, h-s+tr+1)) 26 | 27 | self.valid_writer = Writer(self.t, (0, h-s+ts-1)) 28 | self.valid_writer2 = Writer(self.t, (0, h-s+ts)) 29 | self.valid_bar_writer = Writer(self.t, (0, h-s+ts+1)) 30 | 31 | self.reset_train_bar() 32 | self.reset_valid_bar() 33 | 34 | def reset_train_bar(self): 35 | self.train_bar = progressbar.ProgressBar( 36 | max_value=self.train_size, fd=self.train_bar_writer) 37 | 38 | def reset_valid_bar(self): 39 | self.valid_bar = progressbar.ProgressBar( 40 | max_value=self.valid_size, fd=self.valid_bar_writer) 41 | 42 | 43 | class 
Writer(object): 44 | """Create an object with a write method that writes to a 45 | specific place on the screen, defined at instantiation. 46 | 47 | This is the glue between blessings and progressbar. 48 | """ 49 | 50 | def __init__(self, t, location): 51 | """ 52 | Input: location - tuple of ints (x, y), the position 53 | of the bar in the terminal 54 | """ 55 | self.location = location 56 | self.t = t 57 | 58 | def write(self, string): 59 | with self.t.location(*self.location): 60 | sys.stdout.write("\033[K") 61 | print(string) 62 | 63 | def flush(self): 64 | return 65 | 66 | 67 | class AverageMeter(object): 68 | """Computes and stores the average and current value""" 69 | 70 | def __init__(self, i=1, precision=3): 71 | self.meters = i 72 | self.precision = precision 73 | self.reset(self.meters) 74 | 75 | def reset(self, i): 76 | self.val = [0]*i 77 | self.avg = [0]*i 78 | self.sum = [0]*i 79 | self.count = 0 80 | 81 | def update(self, val, n=1): 82 | if not isinstance(val, list): 83 | val = [val] 84 | assert(len(val) == self.meters) 85 | self.count += n 86 | for i, v in enumerate(val): 87 | self.val[i] = v 88 | self.sum[i] += v * n 89 | self.avg[i] = self.sum[i] / self.count 90 | 91 | def __repr__(self): 92 | val = ' '.join(['{:.{}f}'.format(v, self.precision) for v in self.val]) 93 | avg = ' '.join(['{:.{}f}'.format(a, self.precision) for a in self.avg]) 94 | return '{} ({})'.format(val, avg) 95 | -------------------------------------------------------------------------------- /common/data_prepare/well_lit_from_varying.txt: -------------------------------------------------------------------------------- 1 | 000020 2 | 000116 3 | 000168 4 | 000216 5 | 000302 6 | 000325 7 | 000376 8 | 000420 9 | 000443 10 | 000492 11 | 000538 12 | 000590 13 | 000023 14 | 000118 15 | 000169 16 | 000217 17 | 000303 18 | 000326 19 | 000377 20 | 000421 21 | 000444 22 | 000493 23 | 000539 24 | 000591 25 | 000042 26 | 000120 27 | 000170 28 | 000218 29 | 000304 30 | 000327 31 | 000378 32 | 000422 33 | 000445 34 | 000494 35 | 000540 36 | 000592 37 | 000044 38 | 000121 39 | 000171 40 | 000219 41 | 000305 42 | 000328 43 | 000379 44 | 000423 45 | 000446 46 | 000495 47 | 000541 48 | 000593 49 | 000045 50 | 000122 51 | 000172 52 | 000220 53 | 000306 54 | 000329 55 | 000380 56 | 000424 57 | 000447 58 | 000496 59 | 000542 60 | 000594 61 | 000046 62 | 000123 63 | 000173 64 | 000247 65 | 000307 66 | 000330 67 | 000381 68 | 000425 69 | 000448 70 | 000497 71 | 000543 72 | 000595 73 | 000047 74 | 000124 75 | 000174 76 | 000248 77 | 000308 78 | 000331 79 | 000382 80 | 000426 81 | 000449 82 | 000498 83 | 000544 84 | 000596 85 | 000065 86 | 000125 87 | 000175 88 | 000249 89 | 000309 90 | 000332 91 | 000383 92 | 000427 93 | 000450 94 | 000499 95 | 000545 96 | 000597 97 | 000067 98 | 000126 99 | 000176 100 | 000250 101 | 000310 102 | 000333 103 | 000385 104 | 000428 105 | 000477 106 | 000500 107 | 000546 108 | 000598 109 | 000069 110 | 000127 111 | 000177 112 | 000251 113 | 000311 114 | 000334 115 | 000387 116 | 000429 117 | 000478 118 | 000501 119 | 000547 120 | 000599 121 | 000071 122 | 000128 123 | 000178 124 | 000252 125 | 000312 126 | 000335 127 | 000389 128 | 000430 129 | 000479 130 | 000502 131 | 000548 132 | 000600 133 | 000073 134 | 000129 135 | 000179 136 | 000253 137 | 000313 138 | 000336 139 | 000392 140 | 000431 141 | 000480 142 | 000503 143 | 000549 144 | 000601 145 | 000076 146 | 000130 147 | 000180 148 | 000254 149 | 000314 150 | 000337 151 | 000394 152 | 000432 153 | 000481 154 | 000504 155 | 000550 156 | 000602 
157 | 000078 158 | 000131 159 | 000181 160 | 000255 161 | 000315 162 | 000338 163 | 000396 164 | 000433 165 | 000482 166 | 000505 167 | 000551 168 | 000603 169 | 000079 170 | 000159 171 | 000182 172 | 000257 173 | 000316 174 | 000339 175 | 000397 176 | 000434 177 | 000483 178 | 000506 179 | 000552 180 | 000604 181 | 000080 182 | 000160 183 | 000183 184 | 000259 185 | 000317 186 | 000340 187 | 000398 188 | 000435 189 | 000484 190 | 000507 191 | 000553 192 | 000081 193 | 000161 194 | 000184 195 | 000261 196 | 000318 197 | 000369 198 | 000399 199 | 000436 200 | 000485 201 | 000508 202 | 000554 203 | 000082 204 | 000162 205 | 000185 206 | 000263 207 | 000319 208 | 000370 209 | 000400 210 | 000437 211 | 000486 212 | 000509 213 | 000581 214 | 000083 215 | 000163 216 | 000186 217 | 000266 218 | 000320 219 | 000371 220 | 000401 221 | 000438 222 | 000487 223 | 000510 224 | 000583 225 | 000084 226 | 000164 227 | 000187 228 | 000269 229 | 000321 230 | 000372 231 | 000402 232 | 000439 233 | 000488 234 | 000534 235 | 000585 236 | 000111 237 | 000165 238 | 000213 239 | 000274 240 | 000322 241 | 000373 242 | 000403 243 | 000440 244 | 000489 245 | 000535 246 | 000587 247 | 000112 248 | 000166 249 | 000214 250 | 000280 251 | 000323 252 | 000374 253 | 000404 254 | 000441 255 | 000490 256 | 000536 257 | 000588 258 | 000114 259 | 000167 260 | 000215 261 | 000282 262 | 000324 263 | 000375 264 | 000419 265 | 000442 266 | 000491 267 | 000537 268 | 000589 -------------------------------------------------------------------------------- /dataloader/VIVID_sequence_folders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import numpy as np 4 | 5 | import math 6 | import random 7 | from imageio import imread 8 | from path import Path 9 | 10 | def load_as_float(path): 11 | return np.expand_dims(imread(path).astype(np.float32), axis=2) # HW -> HW1 12 | 13 | class SequenceFolder(data.Dataset): 14 | """A sequence data loader where the files are arranged in this way: 15 | root/scene_1/Thermal/0000000.png 16 | root/scene_1/Thermal/0000001.png 17 | .. 18 | root/scene_1/cam.txt 19 | root/scene_2/Thermal/0000000.png 20 | . 
21 | transform functions must take in a list of images and a numpy array (usually the intrinsics matrix) 22 | """ 23 | def __init__(self, root, seed=None, train=True, sequence_length=3,\ 24 | tf_share=None, tf_input=None, tf_loss=None, \ 25 | scene_type='indoor', interval=1): 26 | np.random.seed(seed) 27 | random.seed(seed) 28 | 29 | self.root = Path(root) 30 | if scene_type == 'indoor': 31 | folder_list_path = self.root/'train_indoor.txt' if train else self.root/'val_indoor.txt' 32 | elif scene_type == 'outdoor': 33 | folder_list_path = self.root/'train_outdoor.txt' if train else self.root/'val_outdoor.txt' 34 | 35 | self.folders = [self.root/folder[:-1] for folder in open(folder_list_path)] 36 | self.tf_share = tf_share 37 | self.tf_input = tf_input 38 | self.tf_loss = tf_loss 39 | self.crawl_folders(sequence_length, interval) 40 | 41 | def crawl_folders(self, sequence_length, interval): 42 | sequence_set = [] 43 | demi_length = (sequence_length-1)//2 + interval - 1 44 | shifts = list(range(-demi_length, demi_length + 1)) 45 | for i in range(1, 2*demi_length): 46 | shifts.pop(1) 47 | 48 | for folder in self.folders: 49 | imgs = sorted((folder/"Thermal").files('*.png')) 50 | intrinsics = np.genfromtxt(folder/'cam_T.txt').astype(np.float32).reshape((3, 3)) 51 | 52 | for i in range(demi_length, len(imgs)-demi_length): 53 | sample = {'intrinsics': intrinsics, 'tgt': imgs[i], 'ref_imgs': []} 54 | for j in shifts: 55 | sample['ref_imgs'].append(imgs[i+j]) 56 | sequence_set.append(sample) 57 | 58 | random.shuffle(sequence_set) 59 | self.samples = sequence_set 60 | 61 | def __getitem__(self, index): 62 | sample = self.samples[index] 63 | 64 | # Read thermal images 65 | tgt_img = load_as_float(sample['tgt']) 66 | ref_imgs = [load_as_float(ref_img) for ref_img in sample['ref_imgs']] 67 | 68 | # Pre-process thermal images for network input & loss calculation 69 | imgs, intrinsics = self.tf_share([tgt_img] + ref_imgs, np.expand_dims(np.copy(sample['intrinsics']),axis=0)) 70 | imgs_input, _ = self.tf_input(imgs, None) 71 | imgs_loss, _ = self.tf_loss(imgs, None) 72 | 73 | tgt_img_input = imgs_input[0] 74 | ref_imgs_input = imgs_input[1:] 75 | tgt_img_loss = imgs_loss[0] 76 | ref_imgs_loss = imgs_loss[1:] 77 | 78 | return tgt_img_input, ref_imgs_input, tgt_img_loss, ref_imgs_loss, intrinsics.squeeze() 79 | 80 | def __len__(self): 81 | return len(self.samples) 82 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: ThermalDepth 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - blas=1.0=mkl 9 | - ca-certificates=2020.6.20=hecda079_0 10 | - certifi=2020.6.20=py37hc8dfbb8_0 11 | - cudatoolkit=10.2.89=hfd86e86_1 12 | - freetype=2.10.2=h5ab3b9f_0 13 | - intel-openmp=2020.1=217 14 | - jpeg=9b=h024ee3a_2 15 | - ld_impl_linux-64=2.33.1=h53a641e_7 16 | - libedit=3.1.20191231=h7b6447c_0 17 | - libffi=3.2.1=hd88cf55_4 18 | - libgcc-ng=9.1.0=hdf63c60_0 19 | - libgfortran-ng=7.3.0=hdf63c60_0 20 | - libpng=1.6.37=hbc83047_0 21 | - libprotobuf=3.12.3=h8b12597_1 22 | - libstdcxx-ng=9.1.0=hdf63c60_0 23 | - libtiff=4.1.0=h2733197_1 24 | - lz4-c=1.9.2=he6710b0_0 25 | - mkl=2020.1=217 26 | - mkl-service=2.3.0=py37he904b0f_0 27 | - mkl_fft=1.1.0=py37h23d657b_0 28 | - mkl_random=1.1.1=py37h0573a6f_0 29 | - ncurses=6.2=he6710b0_1 30 | - ninja=1.9.0=py37hfd86e86_0 31 | - numpy=1.18.5=py37ha1c710e_0 32 | - 
numpy-base=1.18.5=py37hde5b4d6_0 33 | - olefile=0.46=py37_0 34 | - openssl=1.1.1g=h516909a_0 35 | - pillow=7.1.2=py37hb39fc2d_0 36 | - pip=20.1.1=py37_1 37 | - python=3.7.6=h0371630_2 38 | - python_abi=3.7=1_cp37m 39 | - pytorch=1.5.1=py3.7_cuda10.2.89_cudnn7.6.5_0 40 | - pyyaml=5.3.1=py37hb5d75c8_1 41 | - readline=7.0=h7b6447c_5 42 | - setuptools=47.3.1=py37_0 43 | - six=1.15.0=py_0 44 | - sqlite=3.32.3=h62c20be_0 45 | - tensorboardx=2.1=py_0 46 | - termcolor=1.1.0=py_2 47 | - tk=8.6.10=hbc83047_0 48 | - torchvision=0.6.1=py37_cu102 49 | - wheel=0.34.2=py37_0 50 | - xz=5.2.5=h7b6447c_0 51 | - yacs=0.1.6=py_0 52 | - yaml=0.2.5=h516909a_0 53 | - zlib=1.2.11=h7b6447c_3 54 | - zstd=1.4.4=h0b5b093_3 55 | - pip: 56 | - absl-py==0.9.0 57 | - argparse==1.4.0 58 | - astunparse==1.6.3 59 | - blessings==1.7 60 | - cachetools==4.1.1 61 | - catkin-pkg==0.4.22 62 | - cffi==1.14.4 63 | - chardet==3.0.4 64 | - click==7.1.2 65 | - cloudpickle==1.6.0 66 | - configparser==5.0.1 67 | - cycler==0.10.0 68 | - decorator==4.4.2 69 | - distro==1.5.0 70 | - docker-pycreds==0.4.0 71 | - docutils==0.16 72 | - future==0.18.2 73 | - gast==0.3.3 74 | - gitdb==4.0.5 75 | - gitpython==3.1.11 76 | - google-auth==1.18.0 77 | - google-auth-oauthlib==0.4.1 78 | - google-pasta==0.2.0 79 | - grpcio==1.30.0 80 | - h5py==2.10.0 81 | - horovod==0.21.0 82 | - idna==2.10 83 | - imageio==2.9.0 84 | - importlib-metadata==1.7.0 85 | - keras-preprocessing==1.1.2 86 | - kiwisolver==1.2.0 87 | - markdown==3.2.2 88 | - matplotlib==3.2.2 89 | - networkx==2.4 90 | - oauthlib==3.1.0 91 | - opencv-python==4.3.0.36 92 | - opt-einsum==3.2.1 93 | - path==14.0.1 94 | - pebble==4.5.3 95 | - progressbar2==3.51.4 96 | - promise==2.3 97 | - protobuf==3.12.2 98 | - psutil==5.8.0 99 | - pyasn1==0.4.8 100 | - pyasn1-modules==0.2.8 101 | - pycparser==2.20 102 | - pyparsing==2.4.7 103 | - python-dateutil==2.8.1 104 | - python-utils==2.4.0 105 | - pywavelets==1.1.1 106 | - requests==2.24.0 107 | - requests-oauthlib==1.3.0 108 | - rospkg==1.2.8 109 | - rsa==4.6 110 | - scikit-image==0.17.2 111 | - scipy==1.1.0 112 | - sentry-sdk==0.19.5 113 | - shortuuid==1.0.1 114 | - smmap==3.0.4 115 | - subprocess32==3.5.4 116 | - tensorboard==2.2.2 117 | - tensorboard-plugin-wit==1.7.0 118 | - tensorflow==2.2.0 119 | - tensorflow-estimator==2.2.0 120 | - tifffile==2020.7.4 121 | - torchgeometry==0.1.2 122 | - tqdm==4.47.0 123 | - urllib3==1.25.9 124 | - wandb==0.10.12 125 | - watchdog==1.0.2 126 | - werkzeug==1.0.1 127 | - wrapt==1.12.1 128 | - zipp==3.1.0 129 | prefix: /home/user/anaconda3/envs/ThermalDepth 130 | 131 | -------------------------------------------------------------------------------- /common/models/DispResNet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from .resnet_encoder import * 8 | 9 | 10 | import numpy as np 11 | from collections import OrderedDict 12 | 13 | class ConvBlock(nn.Module): 14 | """Layer to perform a convolution followed by ELU 15 | """ 16 | def __init__(self, in_channels, out_channels): 17 | super(ConvBlock, self).__init__() 18 | 19 | self.conv = Conv3x3(in_channels, out_channels) 20 | self.nonlin = nn.ELU(inplace=True) 21 | 22 | def forward(self, x): 23 | out = self.conv(x) 24 | out = self.nonlin(out) 25 | return out 26 | 27 | class Conv3x3(nn.Module): 28 | """Layer to pad and convolve input 29 | """ 30 | def 
__init__(self, in_channels, out_channels, use_refl=True): 31 | super(Conv3x3, self).__init__() 32 | 33 | if use_refl: 34 | self.pad = nn.ReflectionPad2d(1) 35 | else: 36 | self.pad = nn.ZeroPad2d(1) 37 | self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3) 38 | 39 | def forward(self, x): 40 | out = self.pad(x) 41 | out = self.conv(out) 42 | return out 43 | 44 | def upsample(x): 45 | """Upsample input tensor by a factor of 2 46 | """ 47 | return F.interpolate(x, scale_factor=2, mode="nearest") 48 | 49 | class DepthDecoder(nn.Module): 50 | def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True): 51 | super(DepthDecoder, self).__init__() 52 | 53 | self.alpha = 10 54 | self.beta = 0.01 55 | 56 | self.num_output_channels = num_output_channels 57 | self.use_skips = use_skips 58 | self.upsample_mode = 'nearest' 59 | self.scales = scales 60 | 61 | self.num_ch_enc = num_ch_enc 62 | self.num_ch_dec = np.array([16, 32, 64, 128, 256]) 63 | 64 | # decoder 65 | self.convs = OrderedDict() 66 | for i in range(4, -1, -1): 67 | # upconv_0 68 | num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1] 69 | num_ch_out = self.num_ch_dec[i] 70 | self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out) 71 | 72 | # upconv_1 73 | num_ch_in = self.num_ch_dec[i] 74 | if self.use_skips and i > 0: 75 | num_ch_in += self.num_ch_enc[i - 1] 76 | num_ch_out = self.num_ch_dec[i] 77 | self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out) 78 | 79 | for s in self.scales: 80 | self.convs[("dispconv", s)] = Conv3x3(self.num_ch_dec[s], self.num_output_channels) 81 | 82 | self.decoder = nn.ModuleList(list(self.convs.values())) 83 | self.sigmoid = nn.Sigmoid() 84 | 85 | def forward(self, input_features): 86 | self.outputs = [] 87 | 88 | # decoder 89 | x = input_features[-1] 90 | for i in range(4, -1, -1): 91 | x = self.convs[("upconv", i, 0)](x) 92 | x = [upsample(x)] 93 | if self.use_skips and i > 0: 94 | x += [input_features[i - 1]] 95 | x = torch.cat(x, 1) 96 | x = self.convs[("upconv", i, 1)](x) 97 | if i in self.scales: 98 | self.outputs.append(self.alpha * self.sigmoid(self.convs[("dispconv", i)](x)) + self.beta) 99 | 100 | self.outputs = self.outputs[::-1] 101 | 102 | return self.outputs 103 | 104 | class DispResNet(nn.Module): 105 | 106 | def __init__(self, num_layers = 18, pretrained = True, num_channel=3): 107 | super(DispResNet, self).__init__() 108 | self.encoder = ResnetEncoder(num_layers = num_layers, pretrained = pretrained, num_input_images=1, num_channel=num_channel) 109 | self.decoder = DepthDecoder(self.encoder.num_ch_enc) 110 | def init_weights(self): 111 | pass 112 | 113 | def forward(self, x): 114 | features = self.encoder(x) 115 | outputs = self.decoder(features) 116 | 117 | if self.training: 118 | return outputs 119 | else: 120 | return outputs[0] 121 | -------------------------------------------------------------------------------- /test_disp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from skimage.transform import resize as imresize 3 | from imageio import imread 4 | import numpy as np 5 | from path import Path 6 | import argparse 7 | from tqdm import tqdm 8 | import time 9 | 10 | import sys 11 | sys.path.append('./common/') 12 | import models 13 | 14 | parser = argparse.ArgumentParser(description='Script for DispNet testing with corresponding groundTruth', 15 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 16 | parser.add_argument("--pretrained-dispnet", 
required=True, type=str, help="pretrained DispNet path") 17 | parser.add_argument("--img-height", default=256, type=int, help="Image height") # 256 (kitti) 18 | parser.add_argument("--img-width", default=320, type=int, help="Image width") # 832 (kitti) 19 | parser.add_argument("--min-depth", default=1e-3) 20 | parser.add_argument("--max-depth", default=80) 21 | parser.add_argument("--dataset-dir", default='.', type=str, help="Dataset directory") 22 | parser.add_argument("--dataset-list", default=None, type=str, help="Dataset list file") 23 | parser.add_argument("--output-dir", default=None, required=True, type=str, help="Output directory for saving predictions in a big 3D numpy file") 24 | parser.add_argument('--resnet-layers', required=True, type=int, default=18, choices=[18, 50], help='depth network architecture.') 25 | 26 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 27 | 28 | def load_tensor_image(filename, args): 29 | img = np.expand_dims(imread(filename).astype(np.float32), axis=2) 30 | h,w,_ = img.shape 31 | if (h != args.img_height or w != args.img_width): 32 | img = imresize(img, (args.img_height, args.img_width)).astype(np.float32) 33 | img = np.transpose(img, (2, 0, 1)) 34 | img = (torch.from_numpy(img).float() / 2**14) 35 | tensor_img = ((img.unsqueeze(0)-0.45)/0.225).to(device) 36 | return tensor_img 37 | 38 | @torch.no_grad() 39 | def main(): 40 | args = parser.parse_args() 41 | 42 | # load models 43 | disp_pose_net = models.DispPoseResNet(args.resnet_layers, False, num_channel=1).to(device) 44 | disp_net = disp_pose_net.DispResNet 45 | 46 | weights = torch.load(args.pretrained_dispnet) 47 | disp_net.load_state_dict(weights['state_dict']) 48 | disp_net.eval() 49 | 50 | dataset_dir = Path(args.dataset_dir) 51 | 52 | # read file list 53 | if args.dataset_list is not None: 54 | with open(args.dataset_list, 'r') as f: 55 | test_files = list(f.read().splitlines()) 56 | else: 57 | test_files=sorted((dataset_dir+'Thermal').files('*.png')) 58 | 59 | print('{} files to test'.format(len(test_files))) 60 | 61 | output_dir = Path(args.output_dir) 62 | output_dir.makedirs_p() 63 | 64 | test_disp_avg = 0 65 | test_disp_std = 0 66 | test_depth_avg = 0 67 | test_depth_std = 0 68 | 69 | avg_time = 0 70 | for j in tqdm(range(len(test_files))): 71 | tgt_img = load_tensor_image(test_files[j], args) 72 | 73 | # compute speed 74 | torch.cuda.synchronize() 75 | t_start = time.time() 76 | 77 | output = disp_net(tgt_img) 78 | 79 | torch.cuda.synchronize() 80 | elapsed_time = time.time() - t_start 81 | 82 | avg_time += elapsed_time 83 | 84 | pred_disp = output.squeeze().cpu().numpy() 85 | 86 | if j == 0: 87 | predictions = np.zeros((len(test_files), *pred_disp.shape)) 88 | predictions[j] = 1/pred_disp 89 | 90 | test_disp_avg += pred_disp.mean() 91 | test_disp_std += pred_disp.std() 92 | test_depth_avg += predictions.mean() 93 | test_depth_std += predictions.std() 94 | 95 | np.save(output_dir/'predictions.npy', predictions) 96 | 97 | avg_time /= len(test_files) 98 | print('Avg Time: ', avg_time, ' seconds.') 99 | print('Avg Speed: ', 1.0 / avg_time, ' fps') 100 | 101 | print('Avg disp : {0:0.3f}, std disp : {1:0.5f}'.format(test_disp_avg/len(test_files), test_disp_std/len(test_files))) 102 | print('Avg depth: {0:0.3f}, std depth: {1:0.5f}'.format(test_depth_avg/len(test_files), test_depth_std/len(test_files))) 103 | 104 | 105 | if __name__ == '__main__': 106 | main() 107 | -------------------------------------------------------------------------------- 
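Note: load_tensor_image() in test_disp.py above treats VIVID thermal frames as 14-bit data: raw pixel values are divided by 2**14 to land in [0, 1] and then standardized with mean 0.45 and std 0.225 before being fed to the network. A minimal sketch of the same preprocessing for one frame (resizing omitted; 'thermal_frame.png' is a placeholder path):

import numpy as np
import torch
from imageio import imread

img = imread('thermal_frame.png').astype(np.float32)        # HxW 14-bit thermal frame
img = np.transpose(np.expand_dims(img, axis=2), (2, 0, 1))  # HxW -> 1xHxW
img = torch.from_numpy(img).float() / 2**14                 # scale to [0, 1]
tensor = (img.unsqueeze(0) - 0.45) / 0.225                  # 1x1xHxW network input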
/common/utils/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import shutil 3 | import numpy as np 4 | import torch 5 | from path import Path 6 | import datetime 7 | from collections import OrderedDict 8 | from matplotlib import cm 9 | from matplotlib.colors import ListedColormap, LinearSegmentedColormap 10 | import matplotlib.pyplot as plt 11 | 12 | def high_res_colormap(low_res_cmap, resolution=1000, max_value=1): 13 | # Construct the list colormap, with interpolated values for higher resolution 14 | # For a linear segmented colormap, you can just specify the number of points in 15 | # cm.get_cmap(name, lutsize) with the parameter lutsize 16 | x = np.linspace(0, 1, low_res_cmap.N) 17 | low_res = low_res_cmap(x) 18 | new_x = np.linspace(0, max_value, resolution) 19 | high_res = np.stack([np.interp(new_x, x, low_res[:, i]) 20 | for i in range(low_res.shape[1])], axis=1) 21 | return ListedColormap(high_res) 22 | 23 | def opencv_rainbow(resolution=1000): 24 | # Construct the opencv equivalent of Rainbow 25 | opencv_rainbow_data = ( 26 | (0.000, (1.00, 0.00, 0.00)), 27 | (0.400, (1.00, 1.00, 0.00)), 28 | (0.600, (0.00, 1.00, 0.00)), 29 | (0.800, (0.00, 0.00, 1.00)), 30 | (1.000, (0.60, 0.00, 1.00)) 31 | ) 32 | 33 | return LinearSegmentedColormap.from_list('opencv_rainbow', opencv_rainbow_data, resolution) 34 | 35 | def opencv_rainbow_inv(resolution=1000): 36 | # Construct the opencv equivalent of Rainbow 37 | opencv_rainbow_data = ( 38 | (0.000, (0.60, 0.00, 1.00)), 39 | (0.400, (0.00, 0.00, 0.10)), 40 | (0.600, (0.00, 1.00, 0.00)), 41 | (0.800, (1.00, 1.00, 0.00)), 42 | (1.000, (1.00, 0.00, 0.00)) 43 | ) 44 | 45 | return LinearSegmentedColormap.from_list('opencv_rainbow', opencv_rainbow_data, resolution) 46 | 47 | COLORMAPS = {'rainbow': opencv_rainbow(), 48 | 'rainbow_inv' : plt.get_cmap('jet'), #opencv_rainbow_inv(), 49 | 'magma': high_res_colormap(cm.get_cmap('magma')), 50 | 'magma_inv': high_res_colormap(cm.get_cmap('magma').reversed()), 51 | 'bone': cm.get_cmap('bone', 10000), 52 | 'bone_inv': cm.get_cmap('bone', 10000).reversed()} 53 | 54 | def ind2rgb(im): 55 | cmap = plt.get_cmap('jet') 56 | im = im.cpu().squeeze().numpy() 57 | im = cmap(im) 58 | im = im[:,:,0:3] 59 | # put it from HWC to CHW format 60 | im = np.transpose(im, (2, 0, 1)) 61 | return torch.from_numpy(im).float() 62 | 63 | def tensor2array(tensor, max_value=None, colormap='rainbow'): 64 | tensor = tensor.detach().cpu() 65 | if max_value is None: 66 | max_value = tensor[~np.isinf(tensor).type(torch.bool)].max() 67 | tensor[np.isinf(tensor).type(torch.bool)] = max_value 68 | max_value = tensor.max().item() 69 | if tensor.ndimension() == 2 or tensor.size(0) == 1: 70 | norm_array = tensor.squeeze().numpy()/max_value 71 | array = COLORMAPS[colormap](norm_array).astype(np.float32) 72 | array = array.transpose(2, 0, 1) 73 | 74 | elif tensor.ndimension() == 3: 75 | if( tensor.size(0) == 3) : 76 | array = 0.45 + tensor.numpy()*0.225 77 | elif (tensor.size(0) == 2): 78 | array = tensor.numpy() 79 | 80 | return array 81 | 82 | def tensor2array_thermal(tensor): 83 | tensor = tensor.detach().cpu() 84 | array = np.expand_dims(0.45 + tensor.detach().cpu().squeeze().numpy()*0.225, 0) 85 | return array 86 | 87 | def save_checkpoint(save_path, dispnet_state, exp_pose_state, is_depth_best, is_pose_best, filename='checkpoint.pth.tar'): 88 | file_prefixes = ['dispnet', 'exp_pose'] 89 | states = [dispnet_state, exp_pose_state] 90 | for (prefix, state) in
zip(file_prefixes, states): 91 | torch.save(state, save_path/'{}_{}'.format(prefix, filename)) 92 | 93 | if (is_depth_best&is_pose_best): 94 | for prefix in file_prefixes: 95 | shutil.copyfile(save_path/'{}_{}'.format(prefix, filename), 96 | save_path/'{}_both_model_best.pth.tar'.format(prefix)) 97 | elif is_depth_best : 98 | for prefix in file_prefixes: 99 | shutil.copyfile(save_path/'{}_{}'.format(prefix, filename), 100 | save_path/'{}_disp_model_best.pth.tar'.format(prefix)) 101 | elif is_pose_best : 102 | for prefix in file_prefixes: 103 | shutil.copyfile(save_path/'{}_{}'.format(prefix, filename), 104 | save_path/'{}_pose_model_best.pth.tar'.format(prefix)) 105 | -------------------------------------------------------------------------------- /common/models/resnet_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import numpy as np 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torchvision.models as models 14 | import torch.utils.model_zoo as model_zoo 15 | 16 | class ResNetMultiImageInput(models.ResNet): 17 | """Constructs a resnet model with varying number of input images. 18 | Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 19 | """ 20 | def __init__(self, block, layers, num_classes=1000, num_input_images=1, num_channel=3): 21 | super(ResNetMultiImageInput, self).__init__(block, layers) 22 | self.inplanes = 64 23 | self.conv1 = nn.Conv2d(num_input_images * num_channel, 64, kernel_size=7, stride=2, padding=3, bias=False) 24 | self.bn1 = nn.BatchNorm2d(64) 25 | self.relu = nn.ReLU(inplace=True) 26 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 27 | self.layer1 = self._make_layer(block, 64, layers[0]) 28 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 29 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 30 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 31 | 32 | for m in self.modules(): 33 | if isinstance(m, nn.Conv2d): 34 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 35 | elif isinstance(m, nn.BatchNorm2d): 36 | nn.init.constant_(m.weight, 1) 37 | nn.init.constant_(m.bias, 0) 38 | 39 | 40 | 41 | def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1, num_channel=3): 42 | """Constructs a ResNet model. 43 | Args: 44 | num_layers (int): Number of resnet layers. 
Must be 18 or 50 45 | pretrained (bool): If True, returns a model pre-trained on ImageNet 46 | num_input_images (int): Number of frames stacked as input 47 | """ 48 | assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet" 49 | blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers] 50 | block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers] 51 | model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images, num_channel=num_channel) 52 | 53 | if pretrained: 54 | loaded = model_zoo.load_url(models.resnet.model_urls['resnet{}'.format(num_layers)]) 55 | if num_channel == 3: 56 | loaded['conv1.weight'] = torch.cat([loaded['conv1.weight']] * num_input_images, 1) / num_input_images 57 | else : 58 | loaded['conv1.weight'] = torch.cat([loaded['conv1.weight'].mean(1).unsqueeze(1)] * num_input_images, 1) / num_input_images 59 | model.load_state_dict(loaded) 60 | return model 61 | 62 | 63 | class ResnetEncoder(nn.Module): 64 | """Pytorch module for a resnet encoder 65 | """ 66 | def __init__(self, num_layers, pretrained, num_input_images=1, num_channel=3): 67 | super(ResnetEncoder, self).__init__() 68 | 69 | self.num_ch_enc = np.array([64, 64, 128, 256, 512]) 70 | 71 | resnets = {18: models.resnet18, 72 | 34: models.resnet34, 73 | 50: models.resnet50, 74 | 101: models.resnet101, 75 | 152: models.resnet152} 76 | 77 | if num_layers not in resnets: 78 | raise ValueError("{} is not a valid number of resnet layers".format(num_layers)) 79 | 80 | self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images, num_channel) 81 | """ 82 | if num_input_images > 1: 83 | self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images) 84 | else: 85 | self.encoder = resnets[num_layers](pretrained) 86 | """ 87 | if num_layers > 34: 88 | self.num_ch_enc[1:] *= 4 89 | 90 | def forward(self, input_image): 91 | self.features = [] 92 | x = input_image 93 | x = self.encoder.conv1(x) 94 | x = self.encoder.bn1(x) 95 | self.features.append(self.encoder.relu(x)) 96 | self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1]))) 97 | self.features.append(self.encoder.layer2(self.features[-1])) 98 | self.features.append(self.encoder.layer3(self.features[-1])) 99 | self.features.append(self.encoder.layer4(self.features[-1])) 100 | 101 | return self.features -------------------------------------------------------------------------------- /eval_vivid/eval_depth.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import matplotlib as mpl 4 | import matplotlib.cm as cm 5 | import numpy as np 6 | import os 7 | from tqdm import tqdm 8 | from path import Path 9 | 10 | parser = argparse.ArgumentParser(description="Depth options") 11 | parser.add_argument("--dataset", required=True, help="VIVID", choices=['VIVID'], type=str) 12 | parser.add_argument("--scene", required=True, help="scene type for VIVID", choices=['indoor', 'outdoor'], type=str) 13 | parser.add_argument("--pred_depth", required=True, help="depth predictions npy", type=str) 14 | parser.add_argument("--gt_depth", required=True, help="gt depth folder for VIVID", type=str) 15 | parser.add_argument("--ratio_name", help="names for saving ratios", type=str) 16 | 17 | 18 | args = parser.parse_args() 19 | 20 | def compute_depth_errors(gt, pred): 21 | """Computation of error metrics between predicted and ground truth depths 22 | Args: 23 | gt (N): ground truth depth 24 | pred (N): 
predicted depth 25 | """ 26 | thresh = np.maximum((gt / pred), (pred / gt)) 27 | a1 = (thresh < 1.25).mean() 28 | a2 = (thresh < 1.25 ** 2).mean() 29 | a3 = (thresh < 1.25 ** 3).mean() 30 | 31 | rmse = (gt - pred) ** 2 32 | rmse = np.sqrt(rmse.mean()) 33 | rmse_log = (np.log(gt) - np.log(pred)) ** 2 34 | rmse_log = np.sqrt(rmse_log.mean()) 35 | 36 | log10 = np.mean(np.abs((np.log10(gt) - np.log10(pred)))) 37 | 38 | abs_rel = np.mean(np.abs(gt - pred) / gt) 39 | sq_rel = np.mean(((gt - pred) ** 2) / gt) 40 | 41 | return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3 42 | 43 | class DepthEval(): 44 | def __init__(self): 45 | 46 | self.min_depth = 1e-3 47 | 48 | if args.dataset == 'VIVID': 49 | if args.scene == 'indoor' : 50 | self.max_depth = 10. 51 | elif args.scene == 'outdoor' : 52 | self.max_depth = 80. 53 | 54 | def main(self): 55 | pred_depths = [] 56 | 57 | """ Get result """ 58 | # Read precomputed result 59 | pred_depths = np.load(os.path.join(args.pred_depth)) 60 | 61 | """ Evaluation """ 62 | if args.dataset == 'VIVID': 63 | gt_depths = [] 64 | for gt_f in sorted(Path(args.gt_depth).files("*.npy")): 65 | gt_depths.append(np.load(gt_f)) 66 | 67 | pred_depths = self.evaluate_depth(gt_depths, pred_depths, eval_mono=True) 68 | 69 | 70 | def evaluate_depth(self, gt_depths, pred_depths, eval_mono=True): 71 | """evaluate depth result 72 | Args: 73 | gt_depths (NxHxW): gt depths 74 | pred_depths (NxHxW): predicted depths 75 | split (str): data split for evaluation 76 | - depth_eigen 77 | eval_mono (bool): use median scaling if True 78 | """ 79 | errors = [] 80 | ratios = [] 81 | resized_pred_depths = [] 82 | 83 | print("==> Evaluating depth result...") 84 | for i in tqdm(range(pred_depths.shape[0])): 85 | if pred_depths[i].mean() != -1: 86 | gt_depth = gt_depths[i] 87 | gt_height, gt_width = gt_depth.shape[:2] 88 | 89 | # resizing prediction (based on inverse depth) 90 | pred_inv_depth = 1 / (pred_depths[i] + 1e-6) 91 | pred_inv_depth = cv2.resize(pred_inv_depth, (gt_width, gt_height)) 92 | pred_depth = 1 / (pred_inv_depth + 1e-6) 93 | 94 | mask = np.logical_and(gt_depth > self.min_depth, gt_depth < self.max_depth) 95 | val_pred_depth = pred_depth[mask] 96 | val_gt_depth = gt_depth[mask] 97 | 98 | # median scaling is used for monocular evaluation 99 | ratio = 1 100 | if eval_mono: 101 | ratio = np.median(val_gt_depth) / np.median(val_pred_depth) 102 | ratios.append(ratio) 103 | val_pred_depth *= ratio 104 | 105 | resized_pred_depths.append(pred_depth * ratio) 106 | 107 | val_pred_depth[val_pred_depth < self.min_depth] = self.min_depth 108 | val_pred_depth[val_pred_depth > self.max_depth] = self.max_depth 109 | 110 | errors.append(compute_depth_errors(val_gt_depth, val_pred_depth)) 111 | 112 | if eval_mono: 113 | ratios = np.array(ratios) 114 | med = np.median(ratios) 115 | print(" Scaling ratios | med: {:0.3f} | std: {:0.3f}".format(med, np.std(ratios / med))) 116 | print(" Scaling ratios | mean: {:0.3f} +- std: {:0.3f}".format(np.mean(ratios), np.std(ratios))) 117 | if args.ratio_name: 118 | np.savetxt(args.ratio_name, ratios, fmt='%.4f') 119 | 120 | mean_errors = np.array(errors).mean(0) 121 | print("\n " + ("{:>8} | " * 7).format("abs_rel", "sq_rel", "rmse", "rmse_log", "a1", "a2", "a3")) 122 | print(("&{: 8.3f} " * 7).format(*mean_errors.tolist()) + "\\\\") 123 | 124 | return resized_pred_depths 125 | 126 | 127 | eval = DepthEval() 128 | eval.main() 129 | -------------------------------------------------------------------------------- /run_inference.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torch.utils.data 4 | 5 | from imageio import imsave 6 | import numpy as np 7 | from path import Path 8 | import argparse 9 | 10 | import sys 11 | sys.path.append('./common/') 12 | 13 | import models 14 | import utils.custom_transforms as custom_transforms 15 | 16 | 17 | import matplotlib as mpl 18 | import matplotlib.cm as cm 19 | 20 | parser = argparse.ArgumentParser(description='Inference script for DispNet trained with \ 21 | self-supervised Structure from Motion on the VIVID thermal dataset', 22 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 23 | 24 | parser.add_argument('data', metavar='DIR', help='path to dataset') 25 | parser.add_argument("--sequence", default='.', type=str, help="sequence folder to run inference on") 26 | parser.add_argument("--output-dir", default='output', type=str, help="Output directory") 27 | parser.add_argument("--img-exts", default='jpg', choices=['png', 'jpg', 'bmp'], nargs='*', type=str, help="images extensions to glob") 28 | 29 | parser.add_argument('--resnet-layers', required=True, type=int, default=18, choices=[18, 50], 30 | help='depth network architecture.') 31 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers') 32 | parser.add_argument('-b', '--batch-size', default=1, type=int, metavar='N', help='mini-batch size') 33 | 34 | parser.add_argument('--interval', type=int, help='Interval of sequence', metavar='N', default=1) 35 | parser.add_argument('--sequence_length', type=int, help='Length of sequence', metavar='N', default=3) 36 | parser.add_argument('--pretrained-model', dest='pretrained_model', default=None, metavar='PATH', help='path to pre-trained model') 37 | parser.add_argument('--scene_type', type=str, choices=['indoor', 'outdoor'], default='indoor', required=True) 38 | 39 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 40 | 41 | def depth_visualizer(inv_depth): 42 | """ 43 | Args: 44 | inv_depth (HxW): inverse depth map 45 | Returns: 46 | vis_data (HxWx3): depth visualization (RGB) 47 | """ 48 | inv_depth = inv_depth.squeeze().detach().cpu() 49 | vmax = np.percentile(inv_depth, 98) 50 | normalizer = mpl.colors.Normalize(vmin=inv_depth.min(), vmax=vmax) 51 | mapper = cm.ScalarMappable(norm=normalizer, cmap='magma') 52 | vis_data = (mapper.to_rgba(inv_depth)[:, :, :3] * 255).astype(np.uint8) 53 | return vis_data 54 | 55 | @torch.no_grad() 56 | def main(): 57 | args = parser.parse_args() 58 | 59 | ArrToTen = custom_transforms.ArrayToTensor(max_value=2**14) 60 | ArrToTen4Loss = custom_transforms.ArrayToTensorWithLocalWindow() 61 | TenThrRearran = custom_transforms.TensorThermalRearrange(bin_num=30, CLAHE_clip=2) 62 | normalize = custom_transforms.Normalize(mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225]) 63 | 64 | transform4input = custom_transforms.Compose([ArrToTen, normalize]) 65 | transform4loss = custom_transforms.Compose([ArrToTen4Loss, TenThrRearran, normalize]) 66 | 67 | from dataloader.VIVID_validation_folders import ValidationSet 68 | val_set = ValidationSet( 69 | args.data, 70 | tf_input=transform4input, 71 | tf_loss=transform4loss, 72 | sequence_length=args.sequence_length, 73 | interval=args.interval, 74 | scene_type=args.scene_type, 75 | inference_folder=args.sequence, 76 | ) 77 | 78 | val_loader = torch.utils.data.DataLoader( val_set, batch_size=args.batch_size,
shuffle=False, 80 | num_workers=args.workers, pin_memory=True) 81 | 82 | # 1. Load models 83 | 84 | print("=> creating model") 85 | disp_pose_net = models.DispPoseResNet(args.resnet_layers, False, num_channel=1).to(device) 86 | disp_net = disp_pose_net.DispResNet 87 | 88 | # load parameters 89 | print("=> using pre-trained weights for DispResNet") 90 | weights = torch.load(args.pretrained_model) 91 | disp_net.load_state_dict(weights['state_dict']) 92 | disp_net.eval() 93 | 94 | # 2. Prepare output directories 95 | output_dir = Path(args.output_dir+'/'+args.sequence) 96 | output_dir.makedirs_p() 97 | (output_dir/'thr_img').makedirs_p() 98 | (output_dir/'thr_invdepth').makedirs_p() 99 | 100 | 101 | 102 | for idx, (tgt_img, tgt_img_vis, depth_gt) in enumerate(val_loader): 103 | 104 | # depth ground truth is not used during inference 105 | tgt_img = tgt_img.to(device) 106 | 107 | 108 | # compute output 109 | output_disp = disp_net(tgt_img) 110 | 111 | tgt_img_vis = (255*(0.45 + tgt_img_vis.squeeze().detach().cpu().numpy()*0.225)).astype(np.uint8) 112 | tgt_disp = depth_visualizer(output_disp) 113 | 114 | # Save images 115 | file_name = '{:06d}'.format(idx) 116 | imsave(output_dir/'thr_img'/'{}.{}'.format(file_name, args.img_exts), tgt_img_vis) 117 | imsave(output_dir/'thr_invdepth'/'{}.{}'.format(file_name, args.img_exts), tgt_disp) 118 | 119 | if __name__ == '__main__': 120 | main() 121 | -------------------------------------------------------------------------------- /dataloader/VIVID_validation_folders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import numpy as np 4 | 5 | 6 | 7 | from imageio import imread 8 | from path import Path 9 | 10 | def load_as_float(path): 11 | return np.expand_dims(imread(path).astype(np.float32), axis=2) # HW -> HW1 12 | 13 | class ValidationSet(data.Dataset): 14 | """A sequence data loader where the files are arranged in this way: 15 | root/scene_1/Thermal/0000000.png 16 | root/scene_1/Depth/0000000.npy 17 | root/scene_1/Thermal/0000001.png 18 | root/scene_1/Depth/0000001.npy 19 | .. 20 | root/scene_2/0000000.png 21 | root/scene_2/0000000.npy 22 | .
23 | 24 | transform functions must take in a list of images and a numpy array (which can be None) 25 | """ 26 | 27 | def __init__(self, root, tf_input=None, tf_loss=None, inference_folder='', sequence_length=3, interval=1, scene_type='indoor'): 28 | self.root = Path(root) 29 | 30 | if inference_folder == '': 31 | if scene_type == 'indoor': 32 | folder_list_path = self.root/'val_indoor.txt' 33 | elif scene_type == 'outdoor': 34 | folder_list_path = self.root/'val_outdoor.txt' 35 | self.folders = [self.root/folder[:-1] for folder in open(folder_list_path)] 36 | else: 37 | self.folders = [self.root/inference_folder] 38 | 39 | self.tf_input = tf_input 40 | self.tf_loss = tf_loss 41 | 42 | self.crawl_folders(sequence_length, interval) 43 | 44 | def crawl_folders(self, sequence_length=3, interval=1): 45 | sequence_set = [] 46 | demi_length = (sequence_length-1)//2 + interval - 1 47 | # the first/last demi_length frames have no full snippet around them and are skipped 48 | 49 | 50 | 51 | for folder in self.folders: 52 | imgs = sorted((folder/"Thermal").files('*.png')) 53 | for i in range(demi_length, len(imgs)-demi_length): 54 | depth = folder/"Depth_T"/(imgs[i].name[:-4] + '.npy') 55 | sample = {'tgt_img': imgs[i], 'tgt_depth': depth} 56 | sequence_set.append(sample) 57 | 58 | self.samples = sequence_set 59 | 60 | def __getitem__(self, index): 61 | sample = self.samples[index] 62 | tgt_img = load_as_float(sample['tgt_img']) 63 | depth = np.load(sample['tgt_depth']).astype(np.float32) 64 | 65 | img_input, _ = self.tf_input([tgt_img], None) 66 | img_loss, _ = self.tf_loss([tgt_img], None) # used for visualization only 67 | 68 | tgt_img_input = img_input[0] 69 | tgt_img_loss = img_loss[0] 70 | 71 | return tgt_img_input, tgt_img_loss, depth 72 | 73 | def __len__(self): 74 | return len(self.samples) 75 | 76 | 77 | class ValidationSetPose(data.Dataset): 78 | """A sequence data loader where the files are arranged in this way: 79 | root/scene_1/Thermal/0000000.png 80 | root/scene_1/Depth/0000000.npy 81 | root/scene_1/Thermal/0000001.png 82 | root/scene_1/Depth/0000001.npy 83 | .. 84 | root/scene_2/0000000.png 85 | root/scene_2/0000000.npy 86 | .
87 | 88 | transform functions must take in a list of images and a numpy array (which can be None) 89 | """ 90 | def __init__(self, root, tf_input=None, sequence_length=3, interval=1, scene_type='indoor'): 91 | self.root = Path(root) 92 | if scene_type == 'indoor': 93 | scene_list_path = self.root/'val_indoor.txt' 94 | elif scene_type == 'outdoor': 95 | scene_list_path = self.root/'val_outdoor.txt' 96 | 97 | self.folders = [self.root/folder[:-1] for folder in open(scene_list_path)] 98 | self.tf_input = tf_input 99 | self.crawl_folders(sequence_length, step=1) 100 | 101 | def crawl_folders(self, sequence_length=3, step=1): 102 | sequence_set = [] 103 | demi_length = (sequence_length - 1) // 2 104 | shift_range = np.array([step*i for i in range(-demi_length, demi_length + 1)]).reshape(1, -1) 105 | 106 | for folder in self.folders: 107 | imgs = sorted((folder/"Thermal").files('*.png')) 108 | poses = np.genfromtxt(folder/'poses_T.txt').astype(np.float64).reshape(-1, 3, 4) 109 | 110 | # construct snippet sequences of length sequence_length 111 | tgt_indices = np.arange(demi_length, len(imgs) - demi_length).reshape(-1, 1) 112 | snippet_indices = shift_range + tgt_indices 113 | 114 | for indices in snippet_indices: 115 | sample = {'imgs': [], 'poses': []} 116 | for i in indices: 117 | sample['imgs'].append(imgs[i]) 118 | sample['poses'].append(poses[i]) 119 | sequence_set.append(sample) 120 | 121 | self.samples = sequence_set 122 | 123 | def __getitem__(self, index): 124 | sample = self.samples[index] 125 | imgs = [load_as_float(img) for img in sample['imgs']] 126 | imgs, _ = self.tf_input(imgs, None) 127 | 128 | poses = np.stack([pose for pose in sample['poses']]) 129 | first_pose = poses[0] 130 | poses[:,:,-1] -= first_pose[:,-1] 131 | compensated_poses = np.linalg.inv(first_pose[:,:3]) @ poses 132 | 133 | return imgs, compensated_poses 134 | 135 | def __len__(self): 136 | return len(self.samples) 137 | -------------------------------------------------------------------------------- /test_pose.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from skimage.transform import resize as imresize 3 | import numpy as np 4 | from path import Path 5 | import argparse 6 | from tqdm import tqdm 7 | 8 | 9 | import sys 10 | sys.path.append('./common/') 11 | import models 12 | from loss.inverse_warp import pose_vec2mat 13 | 14 | parser = argparse.ArgumentParser(description='Script for PoseNet testing with corresponding ground truth from VIVID odometry sequences', 15 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 16 | parser.add_argument("--pretrained-posenet", required=True, type=str, help="pretrained PoseNet path") 17 | parser.add_argument("--img-height", default=256, type=int, help="Image height") 18 | parser.add_argument("--img-width", default=320, type=int, help="Image width") 19 | parser.add_argument("--no-resize", action='store_true', help="no resizing is done") 20 | 21 | parser.add_argument("--dataset-dir", type=str, help="Dataset directory") 22 | parser.add_argument('--sequence-length', type=int, metavar='N', help='sequence length for testing', default=5) 23 | parser.add_argument("--sequences", default=['indoor_aggresive_dark'], type=str, nargs='*', help="sequences to test") 24 | parser.add_argument("--output-dir", default=None, type=str, help="Output directory for saving predictions in a big 3D numpy file") 25 | parser.add_argument('--resnet-layers', required=True, type=int, default=18, choices=[18, 50], help='depth network architecture.')
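# Example invocation (a sketch only; the checkpoint and dataset paths below are
# placeholders, not files shipped with this repository):
#   python test_pose.py --pretrained-posenet /path/to/pose_checkpoint.pth.tar \
#       --resnet-layers 18 --dataset-dir /path/to/VIVID_256 \
#       --sequences indoor_aggresive_dark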
26 | 27 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 28 | 29 | def load_tensor_image(img, args): 30 | h, w, _ = img.shape 31 | if h != args.img_height or w != args.img_width: 32 | img = imresize(img, (args.img_height, args.img_width)).astype(np.float32) 33 | img = np.transpose(img, (2, 0, 1)) 34 | img = (torch.from_numpy(img).float() / 2**14) 35 | tensor_img = ((img.unsqueeze(0)-0.45)/0.225).to(device) 36 | return tensor_img 37 | 38 | @torch.no_grad() 39 | def main(): 40 | args = parser.parse_args() 41 | 42 | # load models 43 | disp_pose_net = models.DispPoseResNet(args.resnet_layers, False, num_channel=1).to(device) 44 | pose_net = disp_pose_net.PoseResNet 45 | 46 | weights = torch.load(args.pretrained_posenet) 47 | pose_net.load_state_dict(weights['state_dict'], strict=False) 48 | pose_net.eval() 49 | 50 | seq_length = args.sequence_length 51 | 52 | # load data loader 53 | from eval_vivid.pose_evaluation_utils import test_framework_VIVID as test_framework 54 | dataset_dir = Path(args.dataset_dir) 55 | framework = test_framework(dataset_dir, args.sequences, seq_length=seq_length, step=1) 56 | 57 | print('{} snippets to test'.format(len(framework))) 58 | errors = np.zeros((len(framework), 2), np.float32) 59 | if args.output_dir is not None: 60 | output_dir = Path(args.output_dir) 61 | output_dir.makedirs_p() 62 | predictions_array = np.zeros((len(framework), seq_length, 3, 4)) 63 | 64 | for j, sample in enumerate(tqdm(framework)): 65 | imgs = sample['imgs'] 66 | sequence_imgs = [] 67 | for img in imgs: 68 | img = load_tensor_image(img, args) 69 | sequence_imgs.append(img) 70 | 71 | global_pose = np.eye(4) 72 | poses = [] 73 | poses.append(global_pose[0:3, :]) 74 | 75 | for i in range(seq_length - 1): 76 | pose = pose_net(sequence_imgs[i], sequence_imgs[i + 1]) 77 | pose_mat = pose_vec2mat(pose).squeeze(0).cpu().numpy() 78 | 79 | pose_mat = np.vstack([pose_mat, np.array([0, 0, 0, 1])]) 80 | global_pose = global_pose @ np.linalg.inv(pose_mat) 81 | poses.append(global_pose[0:3, :]) 82 | 83 | final_poses = np.stack(poses, axis=0) 84 | 85 | if args.output_dir is not None: 86 | predictions_array[j] = final_poses 87 | 88 | ATE, RE = compute_pose_error(sample['poses'], final_poses) 89 | errors[j] = ATE, RE 90 | 91 | mean_errors = errors.mean(0) 92 | std_errors = errors.std(0) 93 | error_names = ['ATE', 'RE'] 94 | print('') 95 | print("Results") 96 | print("\t {:>10}, {:>10}".format(*error_names)) 97 | print("mean \t {:10.4f}, {:10.4f}".format(*mean_errors)) 98 | print("std \t {:10.4f}, {:10.4f}".format(*std_errors)) 99 | 100 | if args.output_dir is not None: 101 | np.save(output_dir/'predictions.npy', predictions_array) 102 | 103 | def compute_pose_error(gt, pred): 104 | RE = 0 105 | snippet_length = gt.shape[0] 106 | scale_factor = np.sum(gt[:, :, -1] * pred[:, :, -1])/np.sum(pred[:, :, -1] ** 2) 107 | ATE = np.linalg.norm((gt[:, :, -1] - scale_factor * pred[:, :, -1]).reshape(-1)) 108 | for gt_pose, pred_pose in zip(gt, pred): 109 | # Residual rotation matrix from which we compute the angle's sin and cos 110 | R = gt_pose[:, :3] @ np.linalg.inv(pred_pose[:, :3]) 111 | s = np.linalg.norm([R[0, 1]-R[1, 0], 112 | R[1, 2]-R[2, 1], 113 | R[0, 2]-R[2, 0]]) 114 | c = np.trace(R) - 1 115 | # Note: we actually compute double of cos and sin, but arctan2 is invariant to scale 116 | RE += np.arctan2(s, c) 117 | 118 | return ATE/snippet_length, RE/snippet_length 119 | 120 | 121 | def compute_pose(pose_net, tgt_img, ref_imgs): 122 | poses = [] 123 | for ref_img in ref_imgs: 124 |
pose = pose_net(tgt_img, ref_img).unsqueeze(1) 125 | poses.append(pose) 126 | poses = torch.cat(poses, 1) 127 | return poses 128 | 129 | 130 | if __name__ == '__main__': 131 | main() 132 | -------------------------------------------------------------------------------- /common/utils/custom_transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import random 4 | import numpy as np 5 | import cv2 6 | 7 | '''Set of random transform routines that take a list of inputs as arguments, 8 | in order to apply random but coherent transformations.''' 9 | 10 | class Compose(object): 11 | def __init__(self, transforms): 12 | self.transforms = transforms 13 | 14 | def __call__(self, images, intrinsics): 15 | for t in self.transforms: 16 | images, intrinsics = t(images, intrinsics) 17 | return images, intrinsics 18 | 19 | class Normalize(object): 20 | def __init__(self, mean, std): 21 | self.mean = mean 22 | self.std = std 23 | 24 | def __call__(self, images, intrinsics): 25 | for tensor in images: 26 | for t, m, s in zip(tensor, self.mean, self.std): 27 | t.sub_(m).div_(s) 28 | return images, intrinsics 29 | 30 | class ArrayToTensor(object): 31 | """Converts a list of numpy.ndarray (H x W x C) along with an intrinsics matrix to a list of torch.FloatTensor of shape (C x H x W) with an intrinsics tensor.""" 32 | def __init__(self, max_value=255): 33 | self.max_value = max_value 34 | 35 | def __call__(self, images, intrinsics): 36 | tensors = [] 37 | for im in images: 38 | # put it from HWC to CHW format 39 | im = np.transpose(im, (2, 0, 1)) 40 | # handle numpy array 41 | tensors.append(torch.from_numpy(im).float()/self.max_value) 42 | return tensors, intrinsics 43 | 44 | 45 | class ArrayToTensorWithLocalWindow(object): 46 | """Converts a list of numpy.ndarray (H x W x C) along with an intrinsics matrix to a list of torch.FloatTensor of shape (C x H x W) with an intrinsics tensor.""" 47 | def __init__(self, trust_ratio=1.00): 48 | self.max_ratio = trust_ratio 49 | self.min_ratio = 1.0-trust_ratio 50 | 51 | def __call__(self, images, intrinsics): 52 | tensors = [] 53 | # Decide min-max values of local image window 54 | tmin = 0 55 | tmax = 0 56 | for im in images: 57 | im = im.squeeze() # HW 58 | im_srt = np.sort(im.reshape(-1)) 59 | tmax += im_srt[round(len(im_srt)*self.max_ratio)-1] 60 | tmin += im_srt[round(len(im_srt)*self.min_ratio)] 61 | tmax /= len(images) 62 | tmin /= len(images) 63 | 64 | for im in images: 65 | # put it from HWC to CHW format 66 | im = np.transpose(im, (2, 0, 1)) 67 | im = torch.clamp(torch.from_numpy(im).float(), tmin, tmax) 68 | tensors.append((im - tmin)/(tmax - tmin)) # CHW 69 | return tensors, intrinsics 70 | 71 | class TensorThermalRearrange(object): 72 | def __init__(self, bin_num=30, CLAHE_clip=2, CLAHE_tilesize=8): 73 | self.bins = bin_num 74 | self.CLAHE = cv2.createCLAHE(clipLimit=CLAHE_clip, tileGridSize=(CLAHE_tilesize, CLAHE_tilesize)) 75 | def __call__(self, images, intrinsics): 76 | imgs = [] 77 | 78 | tmp_img = torch.cat(images, axis=0) 79 | hist = torch.histc(tmp_img, bins=self.bins) 80 | imgs_max = tmp_img.max() 81 | imgs_min = tmp_img.min() 82 | itv = (imgs_max - imgs_min)/self.bins 83 | total_num = hist.sum() 84 | 85 | for im in images: # CHW 86 | _, H, W = im.shape 87 | mul_mask_ = torch.zeros((self.bins, H, W)) 88 | sub_mask_ = torch.zeros((self.bins, H, W)) 89 | subhist_new_min = imgs_min.clone() 90 | 91 | for x in range(0, self.bins): 92 | subhist = (im > imgs_min+itv*x)
& (im <= imgs_min+itv*(x+1)) 93 | if subhist.sum() == 0: 94 | continue 95 | subhist_new_itv = hist[x]/total_num 96 | mul_mask_[x, ...] = subhist * (subhist_new_itv / itv) 97 | sub_mask_[x, ...] = subhist * (subhist_new_itv / itv * -(imgs_min+itv*x) + subhist_new_min) 98 | subhist_new_min += subhist_new_itv 99 | 100 | mul_mask = mul_mask_.sum(axis=0, keepdim=True).detach() 101 | sub_mask = sub_mask_.sum(axis=0, keepdim=True).detach() 102 | im_ = mul_mask*im + sub_mask 103 | 104 | im_ = self.CLAHE.apply((im_.squeeze()*255).numpy().astype(np.uint8)).astype(np.float32) 105 | im_ = np.expand_dims(im_, axis=2) 106 | img_out = torch.from_numpy(np.transpose(im_/255., (2, 0, 1))) 107 | imgs.append(img_out) # CHW 108 | return imgs, intrinsics 109 | 110 | class RandomHorizontalFlip(object): 111 | """Randomly horizontally flips the given numpy array with a probability of 0.5""" 112 | 113 | def __call__(self, images, intrinsics): 114 | assert intrinsics is not None 115 | if random.random() < 0.5: 116 | output_intrinsics = np.copy(intrinsics) 117 | output_images = [np.copy(np.fliplr(im)) for im in images] 118 | w = output_images[0].shape[1] 119 | output_intrinsics[:, 0, 2] = w - output_intrinsics[:, 0, 2] 120 | else: 121 | output_images = images 122 | output_intrinsics = intrinsics 123 | return output_images, output_intrinsics 124 | 125 | class RandomScaleCrop(object): 126 | """Randomly zooms images up to 15% and crop them to keep same size as before.""" 127 | 128 | def __call__(self, images, intrinsics): 129 | assert intrinsics is not None 130 | output_intrinsics = np.copy(intrinsics) 131 | 132 | in_h, in_w, ch = images[0].shape 133 | x_scaling, y_scaling = np.random.uniform(1, 1.15, 2) 134 | scaled_h, scaled_w = int(in_h * y_scaling), int(in_w * x_scaling) 135 | 136 | output_intrinsics[:, 0] *= x_scaling 137 | output_intrinsics[:, 1] *= y_scaling 138 | 139 | if ch == 1: 140 | scaled_images = [np.expand_dims(cv2.resize(im, (scaled_w, scaled_h)), axis=2) for im in images] 141 | else: 142 | scaled_images = [cv2.resize(im, (scaled_w, scaled_h)) for im in images] 143 | 144 | offset_y = np.random.randint(scaled_h - in_h + 1) 145 | offset_x = np.random.randint(scaled_w - in_w + 1) 146 | cropped_images = [im[offset_y:offset_y + in_h, offset_x:offset_x + in_w] for im in scaled_images] 147 | 148 | output_intrinsics[:, 0, 2] -= offset_x 149 | output_intrinsics[:, 1, 2] -= offset_y 150 | 151 | return cropped_images, output_intrinsics -------------------------------------------------------------------------------- /common/data_prepare/prepare_train_data_VIVID.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | from pebble import ProcessPool 5 | 6 | from tqdm import tqdm 7 | from path import Path 8 | import cv2 9 | import shutil 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("dataset_dir", metavar='DIR', 13 | help='path to original dataset') 14 | parser.add_argument("--with-depth", action='store_true', 15 | help="If available, will store depth ground truth along with images, for validation") 16 | parser.add_argument("--with-pose", action='store_true', 17 | help="If available, will store pose ground truth along with images, for validation") 18 | parser.add_argument("--no-train-gt", action='store_true', 19 | help="If selected, will delete ground truth depth to save space") 20 | parser.add_argument("--dump-root", type=str, default='dump', help="Where to dump the data") 21 |
parser.add_argument("--height", type=int, default=256, help="image height") 22 | parser.add_argument("--width", type=int, default=320, help="image width") 23 | parser.add_argument("--num-threads", type=int, default=4, help="number of threads to use") 24 | 25 | args = parser.parse_args() 26 | 27 | def dump_example(args, scene): 28 | scene_name = scene.split('/')[-1] 29 | scene_data = data_loader.collect_scenes(scene,) 30 | dump_dir = args.dump_root/scene_data['rel_path'] 31 | dump_dir.makedirs_p() 32 | dump_dir_ther = dump_dir/"Thermal" 33 | dump_dir_rgb = dump_dir/"RGB" 34 | dump_dir_depth_T = dump_dir/"Depth_T" 35 | dump_dir_depth_RGB = dump_dir/"Depth_RGB" 36 | dump_dir_ther.makedirs_p() 37 | dump_dir_rgb.makedirs_p() 38 | dump_dir_depth_T.makedirs_p() 39 | dump_dir_depth_RGB.makedirs_p() 40 | 41 | # save intrinsic param 42 | intrinsics_T = scene_data['intrinsics_T'] 43 | dump_cam_file = dump_dir/'cam_T.txt' 44 | np.savetxt(dump_cam_file, intrinsics_T) 45 | 46 | intrinsics_RGB = scene_data['intrinsics_RGB'] 47 | dump_cam_file = dump_dir/'cam_RGB.txt' 48 | np.savetxt(dump_cam_file, intrinsics_RGB) 49 | 50 | extrinsics_T2RGB = scene_data['Tr_T2RGB'] 51 | dump_cam_file = dump_dir/'Tr_T2RGB.txt' 52 | np.savetxt(dump_cam_file, extrinsics_T2RGB) 53 | 54 | poses_T_file = dump_dir/'poses_T.txt' 55 | poses_RGB_file = dump_dir/'poses_RGB.txt' 56 | 57 | poses_T = [] 58 | poses_RGB = [] 59 | 60 | # save each files + pose + depth 61 | for sample in data_loader.get_scene_imgs(scene_data): 62 | frame_nb = sample["id"] 63 | cv2.imwrite(dump_dir_rgb/'{}.png'.format(frame_nb), sample['Img_RGB']) 64 | cv2.imwrite(dump_dir_ther/'{}.png'.format(frame_nb), sample['Img_Ther']) 65 | 66 | if "pose_T" in sample.keys(): 67 | poses_T.append(sample["pose_T"].tolist()) 68 | poses_RGB.append(sample["pose_RGB"].tolist()) 69 | if "depth_T" in sample.keys(): 70 | dump_depth_T_file = dump_dir_depth_T/'{}.npy'.format(frame_nb) 71 | np.save(dump_depth_T_file, sample["depth_T"]) 72 | dump_depth_RGB_file = dump_dir_depth_RGB/'{}.npy'.format(frame_nb) 73 | np.save(dump_depth_RGB_file, sample["depth_RGB"]) 74 | 75 | if len(poses_T) != 0: 76 | np.savetxt(poses_T_file, np.array(poses_T).reshape(-1, 12), fmt='%.6e') 77 | np.savetxt(poses_RGB_file, np.array(poses_RGB).reshape(-1, 12), fmt='%.6e') 78 | 79 | if len(dump_dir_rgb.files('*.png')) < 3: 80 | dump_dir.rmtree() 81 | 82 | def extract_well_lit_images(args): 83 | tgt_dir = args.dump_root/'indoor_robust_varying' 84 | dump_dir = args.dump_root/'indoor_robust_varying_well_lit' 85 | dump_dir.makedirs_p() 86 | dump_dir_ther = dump_dir/"Thermal" 87 | dump_dir_rgb = dump_dir/"RGB" 88 | dump_dir_depth_T = dump_dir/"Depth_T" 89 | dump_dir_depth_RGB = dump_dir/"Depth_RGB" 90 | dump_dir_ther.makedirs_p() 91 | dump_dir_rgb.makedirs_p() 92 | dump_dir_depth_T.makedirs_p() 93 | dump_dir_depth_RGB.makedirs_p() 94 | 95 | # read well-lit image list 96 | img_list = np.genfromtxt('./common/data_prepare/well_lit_from_varying.txt').astype(int) # 97 | 98 | for frame_nb in img_list : 99 | dump_img_T_file = dump_dir_ther/'{:06d}.png'.format(frame_nb) 100 | dump_img_RGB_file = dump_dir_rgb/'{:06d}.png'.format(frame_nb) 101 | dump_depth_T_file = dump_dir_depth_T/'{:06d}.npy'.format(frame_nb) 102 | dump_depth_RGB_file = dump_dir_depth_RGB/'{:06d}.npy'.format(frame_nb) 103 | 104 | shutil.copy(tgt_dir/"Thermal"/'{:06d}.png'.format(frame_nb), dump_img_T_file) 105 | shutil.copy(tgt_dir/"RGB"/'{:06d}.png'.format(frame_nb), dump_img_RGB_file) 106 | shutil.copy(tgt_dir/"Depth_T"/'{:06d}.npy'.format(frame_nb), 
dump_depth_T_file) 107 | shutil.copy(tgt_dir/"Depth_RGB"/'{:06d}.npy'.format(frame_nb), dump_depth_RGB_file) 108 | 109 | def main(): 110 | args.dump_root = Path(args.dump_root) 111 | args.dump_root.mkdir_p() 112 | 113 | global data_loader 114 | 115 | from VIVID_raw_loader import VIVIDRawLoader 116 | data_loader = VIVIDRawLoader(args.dataset_dir, 117 | img_height=args.height, 118 | img_width=args.width, 119 | get_depth=args.with_depth, 120 | get_pose=args.with_pose) 121 | 122 | n_scenes = len(data_loader.scenes) 123 | print('Found {} potential scenes'.format(n_scenes)) 124 | print('Retrieving frames') 125 | if args.num_threads == 1: 126 | for scene in tqdm(data_loader.scenes): 127 | dump_example(args, scene) 128 | else: 129 | with ProcessPool(max_workers=args.num_threads) as pool: 130 | tasks = pool.map(dump_example, [args]*n_scenes, data_loader.scenes) 131 | try: 132 | for _ in tqdm(tasks.result(), total=n_scenes): 133 | pass 134 | except KeyboardInterrupt as e: 135 | tasks.cancel() 136 | raise e 137 | 138 | print('Extracting well-lit images from the varying illumination set') 139 | extract_well_lit_images(args) 140 | 141 | print('Generating train/val lists') 142 | with open(args.dump_root / 'train_indoor.txt', 'w') as tf: 143 | for seq in data_loader.indoor_train_list: 144 | tf.write('{}\n'.format(seq)) 145 | with open(args.dump_root / 'val_indoor.txt', 'w') as tf: 146 | for seq in data_loader.indoor_val_list: 147 | tf.write('{}\n'.format(seq)) 148 | with open(args.dump_root / 'test_indoor.txt', 'w') as tf: 149 | for seq in data_loader.indoor_test_list: 150 | tf.write('{}\n'.format(seq)) 151 | with open(args.dump_root / 'train_outdoor.txt', 'w') as tf: 152 | for seq in data_loader.outdoor_train_list: 153 | tf.write('{}\n'.format(seq)) 154 | with open(args.dump_root / 'val_outdoor.txt', 'w') as tf: 155 | for seq in data_loader.outdoor_val_list: 156 | tf.write('{}\n'.format(seq)) 157 | with open(args.dump_root / 'test_outdoor.txt', 'w') as tf: 158 | for seq in data_loader.outdoor_test_list: 159 | tf.write('{}\n'.format(seq)) 160 | 161 | print('Done!') 162 | 163 | if __name__ == '__main__': 164 | main() 165 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Maximizing Self-supervision from Thermal Image for Effective Self-supervised Learning of Depth and Ego-motion 2 | 3 | This repository is the official implementation of the paper: 4 | 5 | >Maximizing Self-supervision from Thermal Image for Effective Self-supervised Learning of Depth and Ego-motion 6 | > 7 | >[Ukcheol Shin](https://ukcheolshin.github.io/), [Kyunghyun Lee](https://scholar.google.co.kr/citations?user=WOBfTQoAAAAJ&hl=ko&oi=ao), [Byeong-Uk Lee](https://sites.google.com/view/bulee), In So Kweon 8 | > 9 | >**Robotics and Automation Letters (RA-L) 2022 & IROS 2022** 10 | > 11 | >[[PDF](https://arxiv.org/abs/2201.04387)] [[Project webpage](https://sites.google.com/view/thermal-monodepth-project-page)] [[Full paper](https://arxiv.org/abs/2201.04387)] [[Youtube](https://youtu.be/qIBcOuLYr70)] 12 | 13 | ## Introduction 14 | Recently, self-supervised learning of depth and ego-motion from thermal images has shown strong robustness and reliability under challenging lighting and weather conditions. 15 | However, inherent thermal image properties such as weak contrast, blurry edges, and noise hinder the generation of effective self-supervision from thermal images.
16 | Therefore, most previous research relies on additional self-supervisory sources such as RGB video, generative models, and LiDAR information. 17 | In this paper, we conduct an in-depth analysis of the thermal image characteristics that degrade self-supervision from thermal images. 18 | Based on this analysis, we propose an effective thermal image mapping method that significantly increases image information, such as overall structure, contrast, and details, while preserving temporal consistency. 19 | By resolving this fundamental problem of thermal images, our depth and pose networks trained only with thermal images achieve state-of-the-art results without utilizing any extra self-supervisory source. 20 | To the best of our knowledge, this work is the first self-supervised learning approach to train monocular depth and relative pose networks with only thermal images. 21 | 22 |
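## Example: Thermal Image Mapping

As a concrete illustration, the sketch below applies the proposed thermal image mapping (local-window normalization, histogram-based rearrangement, and CLAHE, implemented in `common/utils/custom_transforms.py`) to a single raw thermal frame. The input file name is a placeholder, and the parameter values simply mirror the defaults used in `run_inference.py`:

```python
import sys
sys.path.append('./common/')

import numpy as np
from imageio import imread
import utils.custom_transforms as custom_transforms

# Placeholder path to a raw 14-bit thermal frame; expand H x W -> H x W x 1.
thermal = np.expand_dims(imread('sample_thermal.png').astype(np.float32), axis=2)

# Local-window normalization -> histogram rearrangement + CLAHE -> standardization.
mapping = custom_transforms.Compose([
    custom_transforms.ArrayToTensorWithLocalWindow(),
    custom_transforms.TensorThermalRearrange(bin_num=30, CLAHE_clip=2),
    custom_transforms.Normalize(mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225]),
])

# Transforms take a list of images and an (optional) intrinsics matrix.
mapped, _ = mapping([thermal], None)
print(mapped[0].shape)  # torch.Size([1, H, W])
```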
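Predicted depths can then be scored with `eval_vivid/eval_depth.py`, which applies the median scaling and the error metrics defined above. A minimal sketch, assuming the predictions have been saved as a single `.npy` stack and the processed ground-truth `Depth_T` folder is available (all paths below are placeholders):

```bash
python eval_vivid/eval_depth.py --dataset VIVID --scene indoor \
    --pred_depth /path/to/predictions_depth.npy \
    --gt_depth /path/to/VIVID_256/SEQUENCE/Depth_T \
    --ratio_name /path/to/ratios.txt
```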