├── architecture ├── __init__.py ├── modeling │ ├── backbone │ │ ├── utils │ │ │ └── __init__.py │ │ ├── __init__.py │ │ ├── backbone.py │ │ └── builder.py │ ├── aggregation │ │ ├── TemporalStereo │ │ │ ├── __init__.py │ │ │ ├── precise.py │ │ │ ├── coarse.py │ │ │ ├── fine.py │ │ │ └── TemporalStereo.py │ │ ├── __init__.py │ │ ├── utils │ │ │ ├── __init__.py │ │ │ ├── SPP3D.py │ │ │ ├── cat_fms.py │ │ │ ├── correlation.py │ │ │ ├── dif_fms.py │ │ │ ├── block_cost.py │ │ │ └── raft_corr.py │ │ └── builder.py │ ├── losses │ │ ├── __init__.py │ │ ├── smooth_l1_loss.py │ │ └── warsserstein_distance_loss.py │ ├── prediction │ │ ├── __init__.py │ │ ├── builder.py │ │ ├── argmin.py │ │ └── soft_argmin.py │ ├── __init__.py │ └── layers │ │ ├── __init__.py │ │ ├── conv_gru.py │ │ └── inverse_warp_3d.py ├── data │ ├── datasets │ │ ├── vkitti │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ └── vkitti_2.py │ │ ├── tartanair │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ └── tartanair.py │ │ ├── scene_flow │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ └── scene_flow.py │ │ ├── kitti │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── kitti2015.py │ │ │ └── kittiraw.py │ │ ├── __init__.py │ │ └── builder.py │ ├── utils │ │ ├── calibration │ │ │ ├── __init__.py │ │ │ ├── utils.py │ │ │ └── kitti_calib.py │ │ ├── __init__.py │ │ ├── load_eth3d.py │ │ ├── load_scene_flow.py │ │ ├── load_tartanair.py │ │ ├── load_disparity.py │ │ ├── load_drivingstereo.py │ │ ├── load_vkitti.py │ │ ├── load_kitti.py │ │ └── load_flow.py │ ├── evaluation │ │ ├── __init__.py │ │ ├── flow_eval.py │ │ ├── pixel_error.py │ │ ├── flow_pixel_error.py │ │ └── eval.py │ └── __init__.py └── utils │ ├── visualization │ ├── __init__.py │ ├── colormap.py │ ├── flow_colormap.py │ └── disparity_colormap.py │ ├── __init__.py │ ├── config.py │ └── time_test_template.py ├── media └── architecture.png ├── projects └── TemporalStereo │ ├── video.sh │ ├── submit.sh │ ├── demo.sh │ ├── configs │ ├── tartanair.yaml │ ├── sceneflow.yaml │ ├── tartanair_full.yaml │ ├── kittiraw-multi.yaml │ ├── kitti2015-multi.yaml │ └── kitti2015.yaml │ ├── logger.py │ ├── dist_train.py │ └── config.py ├── requirements.txt ├── .gitignore └── README.md /architecture/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /architecture/modeling/backbone/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /architecture/modeling/aggregation/TemporalStereo/__init__.py: -------------------------------------------------------------------------------- 1 | from .TemporalStereo import TEMPORALSTEREO -------------------------------------------------------------------------------- /media/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/youmi-zym/TemporalStereo/HEAD/media/architecture.png -------------------------------------------------------------------------------- /architecture/data/datasets/vkitti/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import VKITTIStereoDatasetBase 2 | from .vkitti_2 import VKITTI2StereoDataset -------------------------------------------------------------------------------- /projects/TemporalStereo/video.sh: 
-------------------------------------------------------------------------------- 1 | echo Starting inference on a stereo video... 2 | 3 | python video_inference.py 4 | 5 | echo Done! -------------------------------------------------------------------------------- /architecture/data/datasets/tartanair/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import TARTANAIRStereoDatasetBase 2 | from .tartanair import TARTANAIRStereoDataset -------------------------------------------------------------------------------- /architecture/data/datasets/scene_flow/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import SceneFlowStereoDatasetBase 2 | from .scene_flow import SceneFlowStereoDataset -------------------------------------------------------------------------------- /architecture/modeling/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .smooth_l1_loss import DispSmoothL1Loss 2 | from .warsserstein_distance_loss import WarssersteinDistanceLoss -------------------------------------------------------------------------------- /architecture/modeling/prediction/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import PREDICTION_REGISTRY, build_prediction 2 | from .soft_argmin import SOFTARGMIN 3 | from .argmin import ARGMIN -------------------------------------------------------------------------------- /architecture/data/datasets/kitti/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import KITTIStereoDatasetBase 2 | from .kitti2015 import KITTI2015StereoDataset 3 | from .kittiraw import KITTIRAWStereoDataset -------------------------------------------------------------------------------- /architecture/data/utils/calibration/__init__.py: -------------------------------------------------------------------------------- 1 | from .projection import Projection 2 | from .kitti_calib import load_calib 3 | from .utils import trans2homo4x4, cart_to_homo, load_velodyne -------------------------------------------------------------------------------- /architecture/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbone import build_backbone 2 | from .aggregation import build_aggregation 3 | from .prediction import build_prediction 4 | from .losses import DispSmoothL1Loss, WarssersteinDistanceLoss -------------------------------------------------------------------------------- /architecture/data/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .pixel_error import calc_error 2 | from .eval import do_evaluation, do_occlusion_evaluation 3 | from .flow_pixel_error import flow_calc_error 4 | from .flow_eval import do_flow_evaluation -------------------------------------------------------------------------------- /architecture/modeling/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbone import Backbone 2 | from .builder import build_backbone, BACKBONE_REGISTRY 3 | from .TemporalStereo import TEMPORALSTEREO 4 | 5 | __all__ = [k for k in globals().keys() if not k.startswith("_")] -------------------------------------------------------------------------------- /architecture/utils/visualization/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .flow_colormap import flow_err_to_color, flow_to_color, flow_max_rad 2 | from .disparity_colormap import disp_map, disp_to_color, disp_err_to_color, disp_err_to_colorbar 3 | from .colormap import colormap -------------------------------------------------------------------------------- /architecture/modeling/aggregation/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import build_aggregation, AGGREGATION_REGISTRY 2 | from .TemporalStereo import TEMPORALSTEREO 3 | 4 | from .utils import cat_fms, dif_fms, SPP3D, block_cost, CorrBlock, FlowCorrBlock, correlation, correlation1d 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | timm==0.4.12 2 | omegaconf 3 | opencv-python 4 | opt-einsum 5 | pafy 6 | pypng 7 | PyQt5 8 | pytorch-lightning==1.5.2 9 | PyYAML 10 | seaborn 11 | scikit-image 12 | scikit-learn 13 | tensorboard 14 | thop 15 | tqdm 16 | typing-extensions 17 | yacs 18 | -------------------------------------------------------------------------------- /architecture/modeling/aggregation/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .cat_fms import cat_fms 2 | from .dif_fms import dif_fms 3 | from .SPP3D import SPP3D 4 | 5 | from .block_cost import block_cost 6 | from .raft_corr import CorrBlock, FlowCorrBlock 7 | 8 | from .correlation import correlation, correlation1d 9 | -------------------------------------------------------------------------------- /architecture/modeling/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .inverse_warp import inverse_warp, project_to_3d, mesh_grid 2 | from .inverse_warp_3d import inverse_warp_3d 3 | from .basic_layers import Conv1d, Conv2d, Conv3d, ConvTranspose1d, ConvTranspose2d, ConvTranspose3d, get_norm, get_activation 4 | from .conv_gru import ConvGRU 5 | from .softsplat import ModuleSoftsplat, FunctionSoftsplat -------------------------------------------------------------------------------- /architecture/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import (read_vkitti_extrinsic, read_vkitti_intrinsic, read_vkitti_png_flow, read_vkitti_png_depth, 2 | read_eth3d_intrinsic, read_eth3d_pfm_disparity, 3 | read_sceneflow_extrinsic, read_sceneflow_pfm_flow, read_sceneflow_pfm_disparity, 4 | read_tartantic_intrinsic, read_tartanair_depth, read_tartanair_extrinsic, read_tartanair_flow) 5 | -------------------------------------------------------------------------------- /architecture/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .vkitti import VKITTI2StereoDataset, VKITTIStereoDatasetBase 2 | from .scene_flow import SceneFlowStereoDataset, SceneFlowStereoDatasetBase 3 | from .tartanair import TARTANAIRStereoDataset, TARTANAIRStereoDatasetBase 4 | from .kitti import KITTI2015StereoDataset, KITTIStereoDatasetBase, KITTIRAWStereoDataset 5 | from .base import StereoDatasetBase 6 | from .builder import build_stereo_dataset -------------------------------------------------------------------------------- /projects/TemporalStereo/submit.sh: -------------------------------------------------------------------------------- 1 | 
#!/bin/bash 2 | 3 | CONFIG=./configs/kitti2015-multi.yaml 4 | EXP_ROOT=./exps/TemporalStereo/kitti2015/multi 5 | DATA_ROOT=./datasets/KITTI-Multiview/KITTI-2015/ 6 | CKPT=$EXP_ROOT/ckpt_best.ckpt 7 | ANN=./splits/view_11_train_all.yaml 8 | 9 | python kitti_submission.py --config-file $CONFIG \ 10 | --checkpoint-path $CKPT \ 11 | --data-root $DATA_ROOT \ 12 | --annfile $ANN \ 13 | --resize-to-shape 384 1284 \ 14 | --log_dir $EXP_ROOT/output 15 | -------------------------------------------------------------------------------- /architecture/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import CfgNode 2 | from .time_test_template import timeTestTemplate 3 | from .visualization import (disp_to_color, disp_err_to_color, disp_err_to_colorbar, disp_map, 4 | flow_to_color, flow_err_to_color, flow_max_rad, 5 | colormap) 6 | 7 | __all__ = [ 8 | "CfgNode", 9 | "timeTestTemplate", 10 | "disp_map", "disp_err_to_colorbar", "disp_err_to_color", "disp_to_color", 11 | "flow_err_to_color", "flow_to_color", 'flow_max_rad', 12 | "colormap", 13 | ] -------------------------------------------------------------------------------- /architecture/modeling/backbone/backbone.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | import torch.nn as nn 3 | 4 | __all__ = ["Backbone"] 5 | 6 | class Backbone(nn.Module, metaclass=ABCMeta): 7 | """ 8 | Abstract base class for network backbones. 9 | """ 10 | 11 | def __init__(self): 12 | """ 13 | The `__init__` method of any subclass can specify its own set of arguments. 14 | """ 15 | super().__init__() 16 | 17 | @abstractmethod 18 | def forward(self, *inputs): 19 | """ 20 | Subclasses must override this method, but adhere to the same return type. 21 | """ 22 | pass -------------------------------------------------------------------------------- /architecture/data/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .load_vkitti import read_vkitti_extrinsic, read_vkitti_intrinsic, read_vkitti_png_flow, read_vkitti_png_depth 2 | from .load_flow import load_flying_things_flow, load_flying_chairs_flow, load_kitti_flow, write_flo, write_flying_chairs_flow 3 | from .load_disparity import load_scene_flow_disp, load_eth3d_disp, load_middlebury_disp 4 | from .load_scene_flow import read_sceneflow_pfm_disparity, read_sceneflow_pfm_flow, read_sceneflow_extrinsic 5 | from .load_tartanair import read_tartanair_depth, read_tartanair_extrinsic, read_tartantic_intrinsic, read_tartanair_flow 6 | from .load_kitti import read_kitti_extrinsic, read_kitti_intrinsic, read_kitti_png_disparity 7 | from .load_eth3d import read_eth3d_intrinsic, read_eth3d_pfm_disparity -------------------------------------------------------------------------------- /projects/TemporalStereo/demo.sh: -------------------------------------------------------------------------------- 1 | CONFIG=./configs/sceneflow.yaml 2 | CKPT=./exps/TemporalStereo/sceneflow/epoch_best.ckpt 3 | LOGDIR=./exps/TemporalStereo/sceneflow/output/ 4 | DATA_ROOT=./datasets/SceneFlow/Flyingthings3D 5 | DATA_TYPE=SceneFlow 6 | ANNFILE=./splits/flyingthings3d/test.json 7 | H=544 8 | W=960 9 | DEVICE=cuda:0 10 | 11 | echo Starting running demo... 
12 | 13 | python demo.py --config-file $CONFIG \ 14 | --checkpoint-path $CKPT \ 15 | --resize-to-shape $H $W \ 16 | --data-type $DATA_TYPE \ 17 | --data-root $DATA_ROOT\ 18 | --annfile $ANNFILE \ 19 | --device $DEVICE \ 20 | --log-dir $LOGDIR 21 | 22 | echo Results are saved to $LOGDIR. 23 | echo done! -------------------------------------------------------------------------------- /architecture/modeling/prediction/builder.py: -------------------------------------------------------------------------------- 1 | from detectron2.utils.registry import Registry 2 | 3 | 4 | PREDICTION_REGISTRY = Registry("PREDICTION") 5 | PREDICTION_REGISTRY.__doc__ = """ 6 | Registry for predictions, which predict disparity maps from images 7 | The registered object must be a callable that accepts one argument: 8 | 1. A :class:`detectron2.config.CfgNode` 9 | Registered object must return instance of :class:`nn.Module`. 10 | """ 11 | 12 | 13 | def build_prediction(cfg): 14 | """ 15 | Build a prediction from `cfg.MODEL.PREDICTION.NAME`. 16 | Returns: 17 | an instance of :class:`nn.Module` 18 | """ 19 | 20 | prediction_name = cfg.MODEL.PREDICTION.NAME 21 | prediction = PREDICTION_REGISTRY.get(prediction_name)(cfg) 22 | return prediction 23 | 24 | 25 | -------------------------------------------------------------------------------- /architecture/modeling/aggregation/builder.py: -------------------------------------------------------------------------------- 1 | from detectron2.utils.registry import Registry 2 | 3 | AGGREGATION_REGISTRY = Registry("AGGREGATION") 4 | AGGREGATION_REGISTRY.__doc__ = """ 5 | Registry for cost aggregation, which estimates an aggregated cost volume from images 6 | The registered object must be a callable that accepts one argument: 7 | 1. A :class:`detectron2.config.CfgNode` 8 | Registered object must return instance of :class:`nn.Module`. 9 | """ 10 | 11 | 12 | def build_aggregation(cfg): 13 | """ 14 | Build a cost aggregation predictor from `cfg.MODEL.AGGREGATION.NAME`. 15 | Returns: 16 | an instance of :class:`nn.Module` 17 | """ 18 | 19 | aggregation_name = cfg.MODEL.AGGREGATION.NAME 20 | aggregation_predictor = AGGREGATION_REGISTRY.get(aggregation_name)(cfg) 21 | return aggregation_predictor -------------------------------------------------------------------------------- /architecture/modeling/backbone/builder.py: -------------------------------------------------------------------------------- 1 | from detectron2.utils.registry import Registry 2 | 3 | from .backbone import Backbone 4 | 5 | BACKBONE_REGISTRY = Registry("BACKBONE") 6 | BACKBONE_REGISTRY.__doc__ = """ 7 | Registry for backbones, which extract feature maps from images 8 | The registered object must be a callable that accepts two arguments: 9 | 1. A :class:`detectron2.config.CfgNode` 10 | 2. A :class:`detectron2.layers.ShapeSpec`, which contains the input shape specification. 11 | Registered object must return instance of :class:`Backbone`. 12 | """ 13 | 14 | 15 | def build_backbone(cfg): 16 | """ 17 | Build a backbone from `cfg.MODEL.BACKBONE.NAME`.
18 | Returns: 19 | an instance of :class:`Backbone` 20 | """ 21 | 22 | backbone_name = cfg.MODEL.BACKBONE.NAME 23 | backbone = BACKBONE_REGISTRY.get(backbone_name)(cfg) 24 | assert isinstance(backbone, Backbone) 25 | return backbone 26 | 27 | 28 | -------------------------------------------------------------------------------- /architecture/modeling/layers/conv_gru.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class ConvGRU(nn.Module): 5 | def __init__(self, 6 | in_planes: int, 7 | hidden_planes: int): 8 | super(ConvGRU, self).__init__() 9 | 10 | self.in_planes = in_planes 11 | self.hidden_planes = hidden_planes 12 | 13 | fuse_planes = in_planes + hidden_planes 14 | 15 | self.convz = nn.Conv2d(fuse_planes, hidden_planes, 3, padding=1) 16 | self.convr = nn.Conv2d(fuse_planes, hidden_planes, 3, padding=1) 17 | self.convq = nn.Conv2d(fuse_planes, hidden_planes, 3, padding=1) 18 | 19 | def forward(self, last_hidden, x): 20 | hx = torch.cat((last_hidden, x), dim=1) 21 | 22 | update_gate = torch.sigmoid(self.convz(hx)) 23 | reset_gate = torch.sigmoid(self.convr(hx)) 24 | cur_hidden = torch.tanh(self.convq(torch.cat((reset_gate * last_hidden, x), dim=1))) 25 | 26 | hidden = (1 - update_gate) * last_hidden + update_gate * cur_hidden 27 | 28 | return hidden 29 | -------------------------------------------------------------------------------- /architecture/data/datasets/vkitti/base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from architecture.data.datasets.base import StereoDatasetBase 4 | 5 | class VKITTIStereoDatasetBase(StereoDatasetBase): 6 | def __init__(self, annFile, root, height, width, frame_idxs, 7 | is_train=False, use_common_intrinsics=False, do_same_lr_transform=True, 8 | mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): 9 | 10 | super(VKITTIStereoDatasetBase, self).__init__(annFile, root, height, width, frame_idxs, 11 | is_train, use_common_intrinsics, do_same_lr_transform, 12 | mean, std) 13 | 14 | self.K = np.array([[725.0087/1242, 0, 620.5/1242, 0], 15 | [0, 725.0087/375, 187/375, 0], 16 | [0, 0, 1, 0], 17 | [0, 0, 0, 1]]) 18 | # (h, w) 19 | self.full_resolution = (375, 1242) 20 | 21 | self.baseline = 0.532725 22 | 23 | self.with_depth_gt = True 24 | self.with_disp_gt = False 25 | self.with_flow_gt = False 26 | self.with_pose_gt = True 27 | -------------------------------------------------------------------------------- /architecture/data/datasets/tartanair/base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from architecture.data.datasets.base import StereoDatasetBase 4 | 5 | class TARTANAIRStereoDatasetBase(StereoDatasetBase): 6 | def __init__(self, annFile, root, height, width, frame_idxs, 7 | is_train=False, use_common_intrinsics=False, do_same_lr_transform=True, 8 | mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): 9 | 10 | super(TARTANAIRStereoDatasetBase, self).__init__(annFile, root, height, width, frame_idxs, 11 | is_train, use_common_intrinsics, do_same_lr_transform, 12 | mean, std) 13 | 14 | self.K = np.array([[320.0/640.0, 0, 320.0/640.0, 0], 15 | [0, 320.0/480.0, 240.0/480.0, 0], 16 | [0, 0, 1, 0], 17 | [0, 0, 0, 1]]) 18 | # (h, w) 19 | self.full_resolution = (480, 640) 20 | 21 | self.baseline = 0.25 22 | 23 | self.with_depth_gt = True 24 | self.with_disp_gt = False 25 | self.with_flow_gt = False 26 | 
self.with_pose_gt = True 27 | -------------------------------------------------------------------------------- /architecture/data/datasets/kitti/base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from architecture.data.datasets.base import StereoDatasetBase 4 | 5 | class KITTIStereoDatasetBase(StereoDatasetBase): 6 | def __init__(self, annFile, root, height, width, frame_idxs, 7 | is_train=False, use_common_intrinsics=False, do_same_lr_transform=True, 8 | mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): 9 | 10 | super(KITTIStereoDatasetBase, self).__init__(annFile, root, height, width, frame_idxs, 11 | is_train, use_common_intrinsics, do_same_lr_transform, 12 | mean, std) 13 | 14 | self.K=np.array([[721.5377/1242, 0, 609.5593/1242, 0], 15 | [0, 721.5377/375, 172.854/375, 0], 16 | [0, 0, 1, 0], 17 | [0, 0, 0, 1]]) 18 | # (h, w) 19 | self.full_resolution = (375, 1242) 20 | 21 | self.baseline = 0.54 22 | 23 | self.with_depth_gt = False 24 | self.with_disp_gt = True 25 | self.with_flow_gt = False 26 | self.with_pose_gt = True 27 | -------------------------------------------------------------------------------- /architecture/data/evaluation/flow_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import warnings 3 | 4 | from .flow_pixel_error import flow_calc_error 5 | 6 | def do_flow_evaluation(est_flow, gt_flow, lb=0.0, ub=400, sparse=False): 7 | """ 8 | Do pixel error evaluation. (See KITTI evaluation protocols for details.) 9 | Args: 10 | est_flow: (Tensor), estimated flow map 11 | [..., 2, Height, Width] layout 12 | gt_flow: (Tensor), ground truth flow map 13 | [..., 2, Height, Width] layout 14 | lb: (scalar), the lower bound of disparity you want to mask out 15 | ub: (scalar), the upper bound of disparity you want to mask out 16 | sparse: (bool), whether the given flow is sparse, default False 17 | Returns: 18 | error_dict (dict): the error of 1px, 2px, 3px, 5px, in percent, 19 | range [0,100] and average error epe 20 | """ 21 | error_dict = {} 22 | if est_flow is None: 23 | warnings.warn('Estimated flow map is None') 24 | return error_dict 25 | if gt_flow is None: 26 | warnings.warn('Reference ground truth flow map is None') 27 | return error_dict 28 | 29 | if torch.is_tensor(est_flow): 30 | est_flow = est_flow.clone().cpu() 31 | 32 | if torch.is_tensor(gt_flow): 33 | gt_flow = gt_flow.clone().cpu() 34 | 35 | error_dict = flow_calc_error(est_flow, gt_flow, sparse=sparse) 36 | 37 | return error_dict 38 | -------------------------------------------------------------------------------- /architecture/utils/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Magic Leap, Inc. 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # Originating Author: Zak Murez (zak.murez.com) 16 | 17 | import argparse 18 | from fvcore.common.config import CfgNode as _CfgNode 19 | 20 | 21 | def convert_to_dict(cfg_node, key_list=[]): 22 | """ Convert a config node to dictionary """ 23 | _VALID_TYPES = {tuple, list, str, int, float, bool} 24 | if not isinstance(cfg_node, _CfgNode): 25 | if type(cfg_node) not in _VALID_TYPES: 26 | print("Key {} with value {} is not a valid type; valid types: {}".format( 27 | ".".join(key_list), type(cfg_node), _VALID_TYPES), ) 28 | return cfg_node 29 | else: 30 | cfg_dict = dict(cfg_node) 31 | for k, v in cfg_dict.items(): 32 | cfg_dict[k] = convert_to_dict(v, key_list + [k]) 33 | return cfg_dict 34 | 35 | class CfgNode(_CfgNode): 36 | """Remove once https://github.com/rbgirshick/yacs/issues/19 is merged""" 37 | def convert_to_dict(self): 38 | return convert_to_dict(self) 39 | -------------------------------------------------------------------------------- /architecture/data/utils/load_eth3d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from .load_disparity import load_eth3d_disp 4 | 5 | 6 | def read_eth3d_intrinsic(intrinsic_fn): 7 | data = {} 8 | with open(intrinsic_fn, 'r') as fp: 9 | # 0 PINHOLE 941 490 542.019 542.019 541.922 255.202 10 | lines = fp.readlines() 11 | values = lines[-1][10:].rstrip().split(' ') 12 | h, w = values[1], values[0] 13 | resolution = (h, w) 14 | camera = '02' 15 | key = 'K_cam{}'.format(camera) 16 | inv_key = 'inv_K_cam{}'.format(int(camera)) 17 | matrix = np.array([float(values[i]) for i in range(len(values))]) 18 | K = np.eye(4) 19 | K[0, 0] = matrix[2] 20 | K[1, 1] = matrix[3] 21 | K[0, 2] = matrix[4] 22 | K[1, 2] = matrix[5] 23 | item = { 24 | key: K, 25 | inv_key: np.linalg.pinv(K) 26 | } 27 | data['{}'.format(camera)] = item 28 | 29 | return data, resolution 30 | 31 | 32 | def read_eth3d_pfm_disparity(disp_fn, K=np.array([[541.764, 0, 553.869, 0], 33 | [0, 541.764, 232.396, 0], 34 | [0, 0, 1, 0], 35 | [0, 0, 0, 1]])): 36 | 37 | disp = load_eth3d_disp(disp_fn) 38 | # uint16 39 | valid_mask = disp > 0 40 | 41 | f = K[0, 0] 42 | b = 0.595499 # meter 43 | 44 | depth = b * f / (disp + 1e-12) 45 | depth = depth * valid_mask 46 | 47 | return depth, disp 48 | -------------------------------------------------------------------------------- /architecture/utils/time_test_template.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def timeTestTemplate(module, *args, **kwargs): 5 | """ 6 | Module time test, inputs can be in tuple/list or dict 7 | Args: 8 | module: (nn.Module, callable Function): the module we want to time test 9 | *args: (tuple, list): the inputs of module 10 | **kwargs: (dict): the inputs of module 11 | 12 | Returns: 13 | avg_time: (double, float): the average time of inference for one round 14 | """ 15 | device = kwargs.pop('device', None) 16 | assert device is not None, "param: device must be given, e.g., device=torch.device('cuda:0')" 17 | iters = kwargs.pop('iters', 1000) 18 | with torch.cuda.device(device): 19 | torch.cuda.empty_cache() 20 | if isinstance(module, nn.Module): 21 | module.eval().to(device) 22 | 23 | avg_time = 0.0 24 | count = 0 25 | 26 | with torch.no_grad(): 27 | for i in range(iters): 28 | start_time = torch.cuda.Event(enable_timing=True) 29 | end_time = torch.cuda.Event(enable_timing=True) 30 | start_time.record() 31 | if len(args) > 0: 32 | module(*args) 33 | if 
len(kwargs) > 0: 34 | module(**kwargs) 35 | end_time.record() 36 | torch.cuda.synchronize(device) 37 | if i >=100 and i < 900: 38 | avg_time += start_time.elapsed_time(end_time) 39 | count += 1 40 | avg_time = avg_time / count # in milliseconds 41 | avg_time = avg_time / 1000 # in second 42 | 43 | return avg_time 44 | -------------------------------------------------------------------------------- /architecture/data/datasets/scene_flow/base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from architecture.data.datasets.base import StereoDatasetBase 4 | 5 | class SceneFlowStereoDatasetBase(StereoDatasetBase): 6 | def __init__(self, annFile, root, height, width, frame_idxs, 7 | is_train=False, use_common_intrinsics=False, do_same_lr_transform=True, 8 | mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): 9 | super(SceneFlowStereoDatasetBase, self).__init__(annFile, root, height, width, frame_idxs, 10 | is_train, use_common_intrinsics, do_same_lr_transform, 11 | mean, std) 12 | 13 | # https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html#information 14 | # Most scenes use a virtual focal length of 35.0mm. For those scenes, the virtual camera intrinsics matrix is given by 15 | self.K = np.array([[1050.0/960, 0, 497.5/960, 0], 16 | [0, 1050.0/540, 269.5/540, 0], 17 | [0, 0, 1, 0], 18 | [0, 0, 0, 1]]) 19 | 20 | # Some scenes in the Driving subset use a virtual focal length of 15.0mm 21 | self.K15 = np.array([[450.0/960, 0, 497.5/960, 0], 22 | [0, 450.0/540, 269.5/540, 0], 23 | [0, 0, 1, 0], 24 | [0, 0, 0, 1]]) 25 | 26 | # (h, w) 27 | self.full_resolution = (540, 960) 28 | 29 | self.with_depth_gt = False 30 | self.with_disp_gt = True 31 | self.with_flow_gt = False 32 | self.with_pose_gt = True 33 | -------------------------------------------------------------------------------- /architecture/data/utils/calibration/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | # Reference include: 5 | # https://github.com/charlesq34/frustum-pointnets/blob/master/kitti/kitti_util.py 6 | 7 | def cart_to_homo(pts_3d): 8 | """ 9 | Convert Cartesian to Homogeneous 10 | Inputs: 11 | pts_3d, (numpy.ndarray): 3xn points in Cartesian 12 | Outputs: 13 | pts_4d, (numpy.ndarray): 4xn points in Homogeneous by pending 1 14 | """ 15 | c,n = pts_3d.shape 16 | assert c == 3 17 | pts_4d = np.vstack((pts_3d, np.ones(n))) 18 | assert pts_4d.shape == (4, n) 19 | return pts_4d 20 | 21 | def trans2homo4x4(T): 22 | """ 23 | Transform into homogeneous matrix 24 | Inputs: 25 | T, (numpy.ndarray): 3x3 or 3x4 matrix 26 | Outputs: 27 | homoT, (numpy.ndarray): 4x4 matrix 28 | """ 29 | h, w = T.shape 30 | assert h<=4 and w <=4 31 | homoT = np.eye(4) 32 | homoT[:h, :w] = T 33 | return homoT 34 | 35 | def load_velodyne(filepath, 36 | no_reflect=False, 37 | load_bin_without_reflect=False, 38 | dtype=np.float32): 39 | """ 40 | velodyne point cloud contains 4 values, where the first 3 values corresponds to x, y, z, 41 | and the last value is the reflectance information, often it's not used 42 | Args: 43 | filepath, (str): the velodyne file path 44 | no_reflect, (bool): weather return with the 4th value, i.e., reflectance information 45 | load_bin_without_reflect, (bool): weather load the 4th value, i.e., reflectance information 46 | dtype, (str or dtype): typecode or data-type to which the array is cast. 
47 | 48 | Outputs: 49 | velo, (dtype): the loaded velodyne point cloud, in nx3 or nx4 layout 50 | 51 | """ 52 | channels = 3 if load_bin_without_reflect else 4 53 | velo = np.fromfile(filepath, dtype=dtype).reshape(-1, channels) 54 | if no_reflect: 55 | return velo[:, :3] 56 | else: 57 | if channels == 3: 58 | velo = cart_to_homo(velo.transpose()).transpose() # fake reflectance 59 | return velo -------------------------------------------------------------------------------- /architecture/modeling/aggregation/utils/SPP3D.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from typing import Any, Dict, List, Optional, Tuple, Union 5 | 6 | from architecture.modeling.layers import Conv3d 7 | 8 | class SPP3D(nn.Module): 9 | """ 10 | 3D SPP 11 | Args: 12 | in_planes: (int), the channels of feature map 13 | norm: (str), the type of normalization layer 14 | activation: (str, list, tuple), the type of activation layer and its coefficient is needed 15 | """ 16 | def __init__(self, 17 | in_planes: int = 64, 18 | strides: Union[List, Tuple] = (2, 4, 8, 16), 19 | norm: str = 'BN3d', 20 | activation: Union[str, List, Tuple] = 'ReLU'): 21 | super(SPP3D, self).__init__() 22 | 23 | self.in_planes = in_planes 24 | self.strides = strides 25 | self.norm = norm 26 | self.activation = activation 27 | 28 | self.pools = nn.ModuleList() 29 | for stride in self.strides: 30 | self.pools.append( 31 | Conv3d(in_planes, 16, 1, 1, 0, 1, bias=False, norm=(norm, 16), activation=activation) 32 | ) 33 | self.fuse = nn.Sequential( 34 | Conv3d(16*len(strides)+in_planes, in_planes, 3, 1, 1, 1, bias=False, norm=(norm, in_planes), activation=activation), 35 | nn.Conv3d(in_planes, in_planes, 1, 1, 0, 1, bias=False), 36 | ) 37 | 38 | def forward(self, x): 39 | features = [x] 40 | B, C, D, H, W = x.shape 41 | for stride, pool in zip(self.strides, self.pools): 42 | stride = (min(D, stride), min(H, stride), min(W, stride)) 43 | out = F.avg_pool3d(x, kernel_size=stride, stride=stride) 44 | out = pool(out) 45 | out = F.interpolate(out, size=(D, H, W), mode='trilinear', align_corners=True) 46 | features.append(out) 47 | features = torch.cat(features, dim=1) 48 | out = self.fuse(features) 49 | return out 50 | 51 | -------------------------------------------------------------------------------- /architecture/modeling/prediction/argmin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from detectron2.config import configurable 5 | 6 | from .builder import PREDICTION_REGISTRY 7 | 8 | @PREDICTION_REGISTRY.register() 9 | class ARGMIN(nn.Module): 10 | """ 11 | A faster implementation of argmin. 12 | Args: 13 | dim, (int): perform argmin at dimension $dim 14 | 15 | Inputs: 16 | cost_volume, (Tensor): the matching cost after regularization, 17 | [BatchSize, disp_sample_number, Height, Width] layout 18 | disp_sample, (Tensor): the estimated disparity samples, 19 | [BatchSize, disp_sample_number, Height, Width] layout. 
20 | Returns: 21 | disp_map, (Tensor): a disparity map regressed from cost volume, 22 | [BatchSize, 1, Height, Width] layout 23 | """ 24 | @configurable 25 | def __init__(self, dim:int = 1): 26 | super(ARGMIN, self).__init__() 27 | self.dim = dim 28 | 29 | @classmethod 30 | def from_config(cls, cfg): 31 | return { 32 | "dim": cfg.MODEL.PREDICTION.get("DIM", 1), 33 | } 34 | 35 | def forward(self, cost_volume, disp_sample): 36 | 37 | # note, cost volume direct represent similarity 38 | # 'c' or '-c' do not affect the performance because feature-based cost volume provided flexibility. 39 | 40 | assert cost_volume.shape == disp_sample.shape, "{}, {}".format(cost_volume.shape, disp_sample.shape) 41 | 42 | _, indices = torch.max(cost_volume, dim=self.dim, keepdim=True) 43 | # compute disparity: (BatchSize, 1, Height, Width) 44 | disp_map = torch.gather(disp_sample, dim=self.dim, index=indices) 45 | 46 | return disp_map 47 | 48 | def __repr__(self): 49 | repr_str = '{}\n'.format(self.__class__.__name__) 50 | repr_str += ' ' * 4 + 'Dim: {}\n'.format(self.dim) 51 | 52 | return repr_str 53 | 54 | @property 55 | def name(self): 56 | return 'Argmin' -------------------------------------------------------------------------------- /architecture/data/utils/load_scene_flow.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from .load_flow import load_flying_things_flow 4 | from .load_disparity import load_scene_flow_disp 5 | 6 | 7 | def read_sceneflow_extrinsic(extrinsic_fn): 8 | data = {} 9 | with open(extrinsic_fn, 'r') as fp: 10 | lines = fp.readlines() 11 | item_num = len(lines) // 4 12 | for i in range(item_num): 13 | frame_info = lines[i*4+0] 14 | frame = int(frame_info.rstrip().split(' ')[-1]) 15 | 16 | # read left camera extrinsic 17 | left_extrinisc = lines[i*4+1] 18 | left_values = left_extrinisc.rstrip().split(' ') 19 | camera = 0 20 | key = 'T_cam{}'.format(int(camera)) 21 | inv_key = 'inv_T_cam{}'.format(int(camera)) 22 | matrix = np.array([float(left_values[i]) for i in range(1, len(left_values))]) 23 | matrix = matrix.reshape(4, 4) 24 | item = { 25 | key: matrix, 26 | inv_key: np.linalg.pinv(matrix), 27 | } 28 | data['Frame{}:{}'.format(frame, camera)] = item 29 | 30 | # read right camera extrinsic 31 | right_extrinisc = lines[i*4+2] 32 | right_values = right_extrinisc.rstrip().split(' ') 33 | camera = 1 34 | key = 'T_cam{}'.format(int(camera)) 35 | inv_key = 'inv_T_cam{}'.format(int(camera)) 36 | matrix = np.array([float(right_values[i]) for i in range(1, len(right_values))]) 37 | matrix = matrix.reshape(4, 4) 38 | item = { 39 | key: matrix, 40 | inv_key: np.linalg.pinv(matrix), 41 | } 42 | data['Frame{}:{}'.format(frame, camera)] = item 43 | 44 | return data 45 | 46 | 47 | 48 | def read_sceneflow_pfm_disparity(disp_fn, K): 49 | disp = load_scene_flow_disp(disp_fn) 50 | h, w = disp.shape 51 | disp = np.nan_to_num(disp, nan=0.0) 52 | disp[disp > w] = 0 53 | disp[disp < 0] = 0 54 | 55 | f = K[0, 0] 56 | # https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html#information, No.6 57 | b = 1.0 # meter 58 | eps = 1e-8 59 | 60 | depth = b * f / (disp + eps) 61 | 62 | return depth, disp 63 | 64 | def read_sceneflow_pfm_flow(flow_fn): 65 | """Convert from .png to (h, w, 2) (flow_x, flow_y) float32 array""" 66 | # read png to bgr in 16 bit unsigned short 67 | 68 | out_flow = load_flying_things_flow(flow_fn) 69 | 70 | return out_flow 71 | 
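A minimal usage sketch for the SceneFlow loader above, assuming a FlyingThings3D-style PFM file (the path below is hypothetical) and the 35.0mm-lens intrinsics in pixel units; read_sceneflow_pfm_disparity only reads the focal length from K[0, 0] and assumes the dataset's fixed 1.0-unit baseline:

import numpy as np
from architecture.data.utils import read_sceneflow_pfm_disparity

# Hypothetical sample path; adjust to the local SceneFlow layout.
disp_fn = "./datasets/SceneFlow/FlyingThings3D/disparity/TRAIN/A/0000/left/0006.pfm"

# Pixel-unit intrinsics for the 35.0mm-lens scenes (fx = fy = 1050);
# only K[0, 0] is used by the helper below.
K = np.array([[1050.0, 0.0, 497.5, 0.0],
              [0.0, 1050.0, 269.5, 0.0],
              [0.0, 0.0, 1.0, 0.0],
              [0.0, 0.0, 0.0, 1.0]])

depth, disp = read_sceneflow_pfm_disparity(disp_fn, K)
print(disp.shape, depth.shape)  # (540, 960) for both on full-resolution frames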
-------------------------------------------------------------------------------- /architecture/modeling/layers/inverse_warp_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def inverse_warp_3d(img, disp, padding_mode='zeros', disp_Y=None): 5 | """ 6 | Args: 7 | img: (Tensor), the source image (where to sample pixels) 8 | [B, C, H, W] or [B, C, D, H, W] 9 | disp: (Tensor), disparity map of the target image 10 | [B, D, H, W] 11 | padding_mode: (str), padding mode, default is zero padding 12 | disp_Y: (Tensor): disparity map of the target image along Y-axis, i.e., Height dimension 13 | [B, D, H, W] 14 | 15 | Returns: 16 | projected_img: (Tensor), source image warped to the target image 17 | [B, C, D, H, W] 18 | """ 19 | 20 | device = disp.device 21 | B, D, H, W = disp.shape 22 | 23 | if disp_Y is not None: 24 | assert disp.shape == disp_Y.shape, 'disparity map along x and y axis should have same shape!' 25 | if img.dim() == 4: 26 | _, C, iH, iW = img.shape 27 | img = img.unsqueeze(2).expand(B, C, D, iH, iW) 28 | elif img.dim() == 5: 29 | assert D == img.shape[2], 'The disparity number should be same between image and disparity map!' 30 | else: 31 | raise ValueError('image is only allowed with 4 or 5 dimensions, ' 32 | 'but got {} dimensions!'.format(img.dim())) 33 | 34 | # get mesh grid for each dimension 35 | grid_d = torch.linspace(0, D - 1, D).view(1, D, 1, 1).expand(B, D, H, W).to(device) 36 | grid_h = torch.linspace(0, H - 1, H).view(1, 1, H, 1).expand(B, D, H, W).to(device) 37 | grid_w = torch.linspace(0, W - 1, W).view(1, 1, 1, W).expand(B, D, H, W).to(device) 38 | 39 | # shift the index of W dimension with disparity 40 | grid_w = grid_w + disp 41 | if disp_Y is not None: 42 | grid_h = grid_h + disp_Y 43 | 44 | # normalize the grid value into [-1, 1]; (0, D-1), (0, H-1), (0, W-1) 45 | grid_d = (grid_d / (D - 1) * 2) - 1 46 | grid_h = (grid_h / (H - 1) * 2) - 1 47 | grid_w = (grid_w / (W - 1) * 2) - 1 48 | 49 | # concatenate the grid_* to [B, D, H, W, 3] 50 | grid_d = grid_d.unsqueeze(4) 51 | grid_h = grid_h.unsqueeze(4) 52 | grid_w = grid_w.unsqueeze(4) 53 | grid = torch.cat((grid_w, grid_h, grid_d), 4) 54 | 55 | # [B, C, D, H, W] 56 | projected_img = F.grid_sample(img, grid, padding_mode=padding_mode, align_corners=True) 57 | 58 | return projected_img -------------------------------------------------------------------------------- /projects/TemporalStereo/configs/tartanair.yaml: -------------------------------------------------------------------------------- 1 | LOG_DIR: "./exps/" 2 | FRAME_IDXS: [0, ] 3 | 4 | CHECKPOINT: 5 | EVERY_N_TRAIN_STEPS: 0 6 | EVERY_N_EPOCHS: 2 7 | 8 | TRAINER: 9 | NAME: 'TemporalStereo' 10 | NUM_GPUS: 1 11 | VERSION: "tartanair/baseline" 12 | MAX_EPOCHS: 40 13 | CHECK_VAL_EVERY_N_EPOCHS: 2 14 | 15 | 16 | SCHEDULER: 17 | TYPE: 'MultiStepLR' 18 | MULTI_STEP_LR: 19 | MILESTONES: [30, 40] 20 | GAMMA: 0.1 21 | 22 | OPTIMIZER: 23 | TYPE: 'RMSProp' 24 | RMSPROP: 25 | LR: 1e-3 26 | 27 | MODEL: 28 | WITH_PREVIOUS: False 29 | PREVIOUS_WITH_GRADIENT: False 30 | WITH_FLOW: False 31 | USE_LOCAL_MAP: False 32 | LOCAL_MAP_SIZE: 0 33 | VIS_FEATURE: False 34 | BACKBONE: 35 | NAME: "TEMPORALSTEREO" 36 | IN_PLANES: 3 37 | AGGREGATION: 38 | NAME: "TEMPORALSTEREO" 39 | COARSE: 40 | IN_PLANES: 256 41 | C: 32 42 | NUM_SAMPLE: 12 43 | DELTA: 1.0 44 | BLOCK_COST_SCALE: 3 45 | TOPK: 2 46 | SPATIAL_FUSION: True 47 | FINE: 48 | IN_PLANES: 128 49 | C: 16 50 | NUM_SAMPLE: 5 51 | DELTA: 1.0 
52 | BLOCK_COST_SCALE: 3 53 | TOPK: 2 54 | SPATIAL_FUSION: True 55 | PRECISE: 56 | IN_PLANES: 64 57 | C: 8 58 | NUM_SAMPLE: 5 59 | DELTA: 1.0 60 | BLOCK_COST_SCALE: 3 61 | TOPK: 2 62 | LOSSES: 63 | SMOOTH_L1_LOSS: 64 | GLOBAL_WEIGHT: 1.0 65 | WEIGHTS: [2.0, 1.0, 0.7, 0.5] 66 | WARSSERSTEIN_DISTANCE_LOSS: 67 | GLOBAL_WEIGHT: 2.0 68 | WEIGHTS: [1.0, 0.7, 0.5] 69 | 70 | DATA: 71 | TRAIN: 72 | DATA_ROOT: "./datasets/TartanAir" 73 | ANNFILE: "./splits/tartanair/view_1_train.json" 74 | TYPE: 'TartanAir' 75 | HEIGHT: 480 76 | WIDTH: 640 77 | DO_SAME_LR_TRANSFORM: False 78 | BATCH_SIZE: 16 79 | NUM_WORKERS: 16 80 | FRAME_IDXS: [0, ] 81 | VAL: 82 | DATA_ROOT: "./datasets/TartanAir" 83 | ANNFILE: "./splits/tartanair/view_1_test.json" 84 | TYPE: 'TartanAir' 85 | HEIGHT: 480 86 | WIDTH: 640 87 | DO_SAME_LR_TRANSFORM: True 88 | BATCH_SIZE: 8 89 | NUM_WORKERS: 8 90 | FRAME_IDXS: [0, ] 91 | TEST: 92 | DATA_ROOT: "./datasets/TartanAir" 93 | ANNFILE: "./splits/tartanair/view_1_test.json" 94 | TYPE: 'TartanAir' 95 | HEIGHT: 480 96 | WIDTH: 640 97 | DO_SAME_LR_TRANSFORM: True 98 | BATCH_SIZE: 1 99 | NUM_WORKERS: 4 100 | FRAME_IDXS: [0, ] 101 | 102 | VAL: 103 | EVAL_DISPARITY_IDS: [0, 1, 2, 3, 4, 5] 104 | DO_OCCLUSION_EVALUATION: True 105 | 106 | 107 | -------------------------------------------------------------------------------- /projects/TemporalStereo/configs/sceneflow.yaml: -------------------------------------------------------------------------------- 1 | LOG_DIR: "./exps/" 2 | FRAME_IDXS: [0, ] 3 | 4 | CHECKPOINT: 5 | EVERY_N_TRAIN_STEPS: 0 6 | EVERY_N_EPOCHS: 2 7 | 8 | TRAINER: 9 | NAME: 'TemporalStereo' 10 | NUM_GPUS: 1 11 | VERSION: "sceneflow/first_try" 12 | MAX_EPOCHS: 40 13 | CHECK_VAL_EVERY_N_EPOCHS: 2 14 | 15 | SCHEDULER: 16 | TYPE: 'MultiStepLR' 17 | MULTI_STEP_LR: 18 | MILESTONES: [30, 40] 19 | GAMMA: 0.1 20 | 21 | OPTIMIZER: 22 | TYPE: 'RMSProp' 23 | RMSPROP: 24 | LR: 1e-3 25 | 26 | MODEL: 27 | WITH_PREVIOUS: False 28 | PREVIOUS_WITH_GRADIENT: False 29 | WITH_FLOW: False 30 | USE_LOCAL_MAP: False 31 | LOCAL_MAP_SIZE: 0 32 | VIS_FEATURE: False 33 | BACKBONE: 34 | NAME: "TEMPORALSTEREO" 35 | IN_PLANES: 3 36 | AGGREGATION: 37 | NAME: "TEMPORALSTEREO" 38 | COARSE: 39 | IN_PLANES: 256 40 | C: 32 41 | NUM_SAMPLE: 12 42 | DELTA: 1.0 43 | BLOCK_COST_SCALE: 3 44 | TOPK: 2 45 | SPATIAL_FUSION: True 46 | FINE: 47 | IN_PLANES: 128 48 | C: 16 49 | NUM_SAMPLE: 5 50 | DELTA: 1.0 51 | BLOCK_COST_SCALE: 3 52 | TOPK: 2 53 | SPATIAL_FUSION: True 54 | PRECISE: 55 | IN_PLANES: 64 56 | C: 8 57 | NUM_SAMPLE: 5 58 | DELTA: 1.0 59 | BLOCK_COST_SCALE: 3 60 | TOPK: 2 61 | LOSSES: 62 | SMOOTH_L1_LOSS: 63 | GLOBAL_WEIGHT: 1.0 64 | WEIGHTS: [2.0, 1.0, 0.7, 0.5] 65 | WARSSERSTEIN_DISTANCE_LOSS: 66 | GLOBAL_WEIGHT: 2.0 67 | WEIGHTS: [1.0, 0.7, 0.5] 68 | 69 | DATA: 70 | TRAIN: 71 | DATA_ROOT: "./datasets/SceneFlow/FlyingThings3D/" 72 | ANNFILE: "./splits/flyingthings3d/train.json" 73 | TYPE: 'SceneFlow' 74 | HEIGHT: 512 75 | WIDTH: 960 76 | DO_SAME_LR_TRANSFORM: False 77 | BATCH_SIZE: 4 78 | NUM_WORKERS: 4 79 | FRAME_IDXS: [0, ] 80 | VAL: 81 | DATA_ROOT: "./datasets/SceneFlow/FlyingThings3D/" 82 | ANNFILE: "./splits/flyingthings3d/test.json" 83 | TYPE: 'SceneFlow' 84 | HEIGHT: 544 85 | WIDTH: 960 86 | DO_SAME_LR_TRANSFORM: True 87 | BATCH_SIZE: 4 88 | NUM_WORKERS: 4 89 | FRAME_IDXS: [0, ] 90 | TEST: 91 | DATA_ROOT: "./datasets/SceneFlow/FlyingThings3D/" 92 | ANNFILE: "./splits/flyingthings3d/test.json" 93 | TYPE: 'SceneFlow' 94 | HEIGHT: 544 95 | WIDTH: 960 96 | DO_SAME_LR_TRANSFORM: True 97 | BATCH_SIZE: 1 98 | 
NUM_WORKERS: 4 99 | FRAME_IDXS: [0, ] 100 | 101 | VAL: 102 | EVAL_DISPARITY_IDS: [0, 1, 2, 3, 4, 5, 6] 103 | VIS_INTERVAL: 8 104 | 105 | -------------------------------------------------------------------------------- /architecture/data/evaluation/pixel_error.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | 5 | 6 | def calc_error(est_disp=None, gt_disp=None, lb=None, ub=None): 7 | """ 8 | Args: 9 | est_disp (Tensor): in [..., Height, Width] layout 10 | gt_disp (Tensor): in [..., Height, Width] layout 11 | lb (scalar): the lower bound of disparity you want to mask out 12 | ub (scalar): the upper bound of disparity you want to mask out 13 | Output: 14 | dict: the error of 1px, 2px, 3px, 5px, in percent, 15 | range [0,100] and average error epe 16 | """ 17 | error1 = torch.Tensor([0.]) 18 | error2 = torch.Tensor([0.]) 19 | error3 = torch.Tensor([0.]) 20 | error5 = torch.Tensor([0.]) 21 | epe = torch.Tensor([0.]) 22 | 23 | if (not torch.is_tensor(est_disp)) or (not torch.is_tensor(gt_disp)): 24 | return { 25 | '1px': error1 * 100, 26 | '2px': error2 * 100, 27 | '3px': error3 * 100, 28 | '5px': error5 * 100, 29 | 'epe': epe 30 | } 31 | 32 | assert torch.is_tensor(est_disp) and torch.is_tensor(gt_disp) 33 | assert est_disp.shape == gt_disp.shape 34 | 35 | est_disp = est_disp.clone().cpu() 36 | gt_disp = gt_disp.clone().cpu() 37 | 38 | mask = torch.ones(gt_disp.shape, dtype=torch.bool, device=gt_disp.device) 39 | if lb is not None: 40 | mask = mask & (gt_disp > lb) 41 | if ub is not None: 42 | mask = mask & (gt_disp < ub) 43 | mask.detach_() 44 | if abs(mask.float().sum()) < 1.0: 45 | return { 46 | '1px': error1 * 100, 47 | '2px': error2 * 100, 48 | '3px': error3 * 100, 49 | '5px': error5 * 100, 50 | 'epe': epe 51 | } 52 | 53 | gt_disp = gt_disp[mask] 54 | est_disp = est_disp[mask] 55 | 56 | abs_error = torch.abs(gt_disp - est_disp) 57 | total_num = mask.float().sum() 58 | 59 | error1 = torch.sum(torch.gt(abs_error, 1).float()) / total_num 60 | error2 = torch.sum(torch.gt(abs_error, 2).float()) / total_num 61 | error3 = torch.sum(torch.gt(abs_error, 3).float()) / total_num 62 | error5 = torch.sum(torch.gt(abs_error, 5).float()) / total_num 63 | epe = abs_error.float().mean() 64 | 65 | # .mean() will get a tensor with size: torch.Size([]), after decorate with torch.Tensor, the size will be: torch.Size([1]) 66 | return { 67 | '1px': torch.Tensor([error1 * 100]), 68 | '2px': torch.Tensor([error2 * 100]), 69 | '3px': torch.Tensor([error3 * 100]), 70 | '5px': torch.Tensor([error5 * 100]), 71 | 'epe': torch.Tensor([epe]), 72 | } -------------------------------------------------------------------------------- /projects/TemporalStereo/configs/tartanair_full.yaml: -------------------------------------------------------------------------------- 1 | LOG_DIR: "./exps/" 2 | FRAME_IDXS: [-3, -2, -1, 0, ] 3 | 4 | CHECKPOINT: 5 | EVERY_N_TRAIN_STEPS: 0 6 | EVERY_N_EPOCHS: 2 7 | 8 | TRAINER: 9 | NAME: 'TemporalStereo' 10 | NUM_GPUS: 1 11 | VERSION: "tartanair/full" 12 | MAX_EPOCHS: 20 13 | CHECK_VAL_EVERY_N_EPOCHS: 2 14 | LOAD_FROM_CHECKPOINT: "./checkpoints/tartanair.ckpt" 15 | 16 | SCHEDULER: 17 | TYPE: 'MultiStepLR' 18 | MULTI_STEP_LR: 19 | MILESTONES: [10, 20] 20 | GAMMA: 0.1 21 | 22 | OPTIMIZER: 23 | TYPE: 'RMSProp' 24 | RMSPROP: 25 | LR: 1e-3 26 | 27 | MODEL: 28 | WITH_PREVIOUS: True 29 | PREVIOUS_WITH_GRADIENT: False 30 | USE_PAST_COST: True 31 | LOCAL_MAP_SIZE: 3 32 | BACKBONE: 33 | NAME: "TEMPORALSTEREO" 34 | IN_PLANES: 3 
35 | MEMORY_PERCENT: 0.5 36 | AGGREGATION: 37 | NAME: "TEMPORALSTEREO" 38 | COARSE: 39 | IN_PLANES: 256 40 | C: 32 41 | NUM_SAMPLE: 12 42 | DELTA: 1.0 43 | BLOCK_COST_SCALE: 3 44 | TOPK: 2 45 | SPATIAL_FUSION: True 46 | FINE: 47 | IN_PLANES: 128 48 | C: 16 49 | NUM_SAMPLE: 5 50 | DELTA: 1.0 51 | BLOCK_COST_SCALE: 3 52 | TOPK: 2 53 | SPATIAL_FUSION: True 54 | PRECISE: 55 | IN_PLANES: 64 56 | C: 8 57 | NUM_SAMPLE: 5 58 | DELTA: 1.0 59 | BLOCK_COST_SCALE: 3 60 | TOPK: 2 61 | LOSSES: 62 | SMOOTH_L1_LOSS: 63 | GLOBAL_WEIGHT: 1.0 64 | WEIGHTS: [2.0, 1.0, 0.7, 0.5] 65 | WARSSERSTEIN_DISTANCE_LOSS: 66 | GLOBAL_WEIGHT: 2.0 67 | WEIGHTS: [1.0, 0.7, 0.5] 68 | 69 | DATA: 70 | TRAIN: 71 | DATA_ROOT: "./datasets/TartanAir" 72 | ANNFILE: "./splits/tartanair/view_4_train.json" 73 | TYPE: 'TartanAir' 74 | HEIGHT: 480 75 | WIDTH: 640 76 | DO_SAME_LR_TRANSFORM: False 77 | BATCH_SIZE: 4 78 | NUM_WORKERS: 16 79 | FRAME_IDXS: [-3, -2, -1, 0, ] 80 | VAL: 81 | DATA_ROOT: "./datasets/TartanAir" 82 | ANNFILE: "./splits/tartanair/view_4_test.json" 83 | TYPE: 'TartanAir' 84 | HEIGHT: 480 85 | WIDTH: 640 86 | DO_SAME_LR_TRANSFORM: True 87 | BATCH_SIZE: 4 88 | NUM_WORKERS: 16 89 | FRAME_IDXS: [-3, -2, -1, 0, ] 90 | TEST: 91 | DATA_ROOT: "./datasets/TartanAir" 92 | ANNFILE: "./splits/tartanair/view_4_test.json" 93 | TYPE: 'TartanAir' 94 | HEIGHT: 480 95 | WIDTH: 640 96 | DO_SAME_LR_TRANSFORM: True 97 | BATCH_SIZE: 1 98 | NUM_WORKERS: 4 99 | FRAME_IDXS: [-3, -2, -1, 0, ] 100 | 101 | VAL: 102 | EVAL_DISPARITY_IDS: [0, 1, 2, 3, 4, 5] 103 | DO_OCCLUSION_EVALUATION: True 104 | -------------------------------------------------------------------------------- /architecture/data/utils/load_tartanair.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from scipy.spatial.transform import Rotation 4 | 5 | 6 | def read_tartantic_intrinsic(): 7 | K = np.eye(4) 8 | K[0, 0] = 320.0 9 | K[1, 1] = 320.0 10 | K[0, 2] = 320.0 11 | K[1, 2] = 240.0 12 | 13 | return K 14 | 15 | 16 | def read_tartanair_extrinsic(extrinsic_fn, side='left'): 17 | data = {} 18 | camera_id = {'left': 0, 'right': 1} 19 | with open(extrinsic_fn, 'r') as fp: 20 | lines = fp.readlines() 21 | # poses = np.loadtxt(extrinsic_fn) 22 | # for lineid, pose in enumerate(poses): 23 | for lineid, line in enumerate(lines): 24 | frame = int(lineid) 25 | camera = int(camera_id[side]) 26 | key = 'T_cam{}'.format(int(camera)) 27 | inv_key = 'inv_T_cam{}'.format(int(camera)) 28 | values = line.rstrip().split(' ') 29 | assert len(values) == 7, 'Pose must be quaterion format -- 7 params, but {} got'.format(len(values)) 30 | pose = np.array([float(values[i]) for i in range(len(values))]) 31 | tx, ty, tz, qx, qy, qz, qw = pose 32 | R = Rotation.from_quat((qx, qy, qz, qw)).as_matrix() 33 | t = np.array([tx, ty, tz]) 34 | matrix = np.eye(4) 35 | matrix[:3, :3] = R.transpose() 36 | matrix[:3, 3] = -R.transpose().dot(t) 37 | # ned(z-axis down) to z-axis forward 38 | m_correct = np.zeros_like(matrix) 39 | m_correct[0, 1] = 1 40 | m_correct[1, 2] = 1 41 | m_correct[2, 0] = 1 42 | m_correct[3, 3] = 1 43 | matrix = np.matmul(m_correct, matrix) 44 | 45 | item = { 46 | key: matrix, 47 | inv_key: np.linalg.pinv(matrix), 48 | } 49 | data['Frame{}:{}'.format(frame, camera)] = item 50 | lineid += 1 51 | 52 | return data 53 | 54 | 55 | def read_tartanair_depth(depth_fn, K=np.array([[320.0, 0, 320.0, 0], 56 | [0, 320, 240.0, 0], 57 | [0, 0, 1, 0], 58 | [0, 0, 0, 1]])): 59 | 60 | if '.npy' in depth_fn: 61 | depth = 
np.load(depth_fn) 62 | elif '.png' in depth_fn: 63 | depth = cv2.imread(depth_fn, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_ANYCOLOR) 64 | # [0, 655.35] meter 65 | depth = depth / 100.0 66 | else: 67 | raise TypeError('only support png and npy format, invalid type found: {}'.format(depth_fn)) 68 | 69 | f = K[0, 0] 70 | b = 0.25 # meter 71 | 72 | disp = b * f / (depth + 1e-5) 73 | 74 | return depth, disp 75 | 76 | def read_tartanair_flow(flow_fn): 77 | """Convert to (h, w, 2) (flow_x, flow_y) float32 array""" 78 | 79 | out_flow = np.load(flow_fn) 80 | 81 | return out_flow 82 | -------------------------------------------------------------------------------- /architecture/data/utils/load_disparity.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | 4 | 5 | def load_pfm(file_path): 6 | """ 7 | load image in PFM type. 8 | Args: 9 | file_path string: file path(absolute) 10 | Returns: 11 | data (numpy.array): data of image in (Height, Width[, 3]) layout 12 | scale (float): scale of image 13 | """ 14 | with open(file_path, encoding="ISO-8859-1") as fp: 15 | color = None 16 | width = None 17 | height = None 18 | scale = None 19 | endian = None 20 | 21 | # load file header and grab channels, if is 'PF' 3 channels else 1 channel(gray scale) 22 | header = fp.readline().rstrip() 23 | if header == 'PF': 24 | color = True 25 | elif header == 'Pf': 26 | color = False 27 | else: 28 | raise Exception('Not a PFM file.') 29 | 30 | # grab image dimensions 31 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', fp.readline()) 32 | if dim_match: 33 | width, height = map(int, dim_match.groups()) 34 | else: 35 | raise Exception('Malformed PFM header.') 36 | 37 | # grab image scale 38 | scale = float(fp.readline().rstrip()) 39 | if scale < 0: # little-endian 40 | endian = '<' 41 | scale = -scale 42 | else: 43 | endian = '>' # big-endian 44 | 45 | # grab image data 46 | data = np.fromfile(fp, endian + 'f') 47 | shape = (height, width, 3) if color else (height, width) 48 | 49 | # reshape data to [Height, Width, Channels] 50 | data = np.reshape(data, shape) 51 | data = np.flipud(data) 52 | 53 | return data, scale 54 | 55 | 56 | # load utils 57 | def load_scene_flow_disp(img_path): 58 | """load scene flow disparity image 59 | Args: 60 | img_path: 61 | Returns: 62 | """ 63 | assert img_path.endswith('.pfm'), "scene flow disparity image must end with .pfm" \ 64 | "but got {}".format(img_path) 65 | 66 | disp_img, __ = load_pfm(img_path) 67 | 68 | return disp_img 69 | 70 | 71 | def load_eth3d_disp(img_path): 72 | """load scene flow disparity image 73 | Args: 74 | img_path: 75 | Returns: 76 | """ 77 | assert img_path.endswith('.pfm'), "eth3d disparity image must end with .pfm" \ 78 | "but got {}".format(img_path) 79 | 80 | disp_img, __ = load_pfm(img_path) 81 | 82 | return disp_img 83 | 84 | 85 | def load_middlebury_disp(img_path): 86 | """load scene flow disparity image 87 | Args: 88 | img_path: 89 | Returns: 90 | """ 91 | assert img_path.endswith('.pfm'), "middlebury disparity image must end with .pfm" \ 92 | "but got {}".format(img_path) 93 | 94 | disp_img, __ = load_pfm(img_path) 95 | 96 | return disp_img 97 | -------------------------------------------------------------------------------- /architecture/data/datasets/kitti/kitti2015.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | import torch 5 | 6 | from architecture.data.utils import read_kitti_intrinsic, read_kitti_extrinsic, 
read_kitti_png_disparity 7 | 8 | from .base import KITTIStereoDatasetBase 9 | 10 | class KITTI2015StereoDataset(KITTIStereoDatasetBase): 11 | def __init__(self, annFile, root, height, width, frame_idxs, 12 | is_train=False, use_common_intrinsics=False, do_same_lr_transform=True, 13 | mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): 14 | super(KITTI2015StereoDataset, self).__init__(annFile, root, height, width, frame_idxs, 15 | is_train, use_common_intrinsics, do_same_lr_transform, 16 | mean, std) 17 | 18 | def Loader(self, image_path): 19 | image_path = os.path.join(self.root, image_path) 20 | img = self.img_loader(image_path) 21 | return img 22 | 23 | def dispLoader(self, disp_path, K=None): 24 | disp_path = os.path.join(self.root, disp_path) 25 | depth, disp = read_kitti_png_disparity(disp_path, K) 26 | 27 | depth = torch.from_numpy(depth.astype(np.float32)).unsqueeze(0) 28 | disp = torch.from_numpy(disp.astype(np.float32)).unsqueeze(0) 29 | 30 | return depth, disp 31 | 32 | def extrinsicLoader(self, extrinsic_path): 33 | # the transformation matrix is from world original point to current frame 34 | extrinsic_path = os.path.join(self.root, extrinsic_path) 35 | extrinsics = read_kitti_extrinsic(extrinsic_path) 36 | 37 | return extrinsics 38 | 39 | def getExtrinsic(self, extrinsics, image_path): 40 | # 'testing/sequences/000000/image_2/000000_10.png' 41 | img_id = int(image_path.split('/')[-1].split('.')[0].split('_')[-1]) 42 | # the transformation matrix is from world original point to current frame 43 | # 4x4 44 | left_T = torch.from_numpy(extrinsics['Frame{:02d}:02'.format(img_id)]['T_cam02']) 45 | left_inv_T = torch.from_numpy(extrinsics['Frame{:02d}:02'.format(img_id)]['inv_T_cam02']) 46 | pose = extrinsics.get('Frame{:02d}:03', None) 47 | if pose is not None: 48 | right_T = pose['T_cam03'] 49 | right_inv_T = pose['inv_T_cam03'] 50 | else: 51 | right_T = torch.eye(4) 52 | right_inv_T = torch.eye(4) 53 | 54 | return left_T, left_inv_T, right_T, right_inv_T 55 | 56 | def intrinsicLoader(self, intrinsic_path): 57 | intrinsic_path = os.path.join(self.root, intrinsic_path) 58 | K, resolution = read_kitti_intrinsic(intrinsic_path) 59 | K = K['02']['K_cam02'] 60 | full_K = K.copy() 61 | h, w = resolution 62 | norm_K = np.eye(4) 63 | norm_K[0, :] = K[0, :] / w 64 | norm_K[1, :] = K[1, :] / h 65 | return norm_K.copy(), full_K, resolution 66 | 67 | 68 | -------------------------------------------------------------------------------- /architecture/utils/visualization/colormap.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib.pyplot as plt 3 | from typing import Union 4 | 5 | def colormap(_cmap, *args, normalize:bool=True, format:str='HWC', **kwargs): 6 | """ 7 | color a given array 8 | Args: 9 | _cmap: (str, callable function), the colormap or direct give function handle 10 | normalize: (bool), whether perform max-min-normalization 11 | format: (str), the format of output colormap, 12 | if 'CHW' -> [Channel, Height, Width], elif 'HWC' -> [Height, Width, Channel] 13 | args: (Tensor, numpy.array), the inputs is required in cmap function 14 | [..., H, W] 15 | kwargs: (dict), args required in cmap function 16 | 17 | Outputs: 18 | 19 | """ 20 | inputs = [] 21 | for input in args: 22 | if isinstance(input, torch.Tensor): 23 | input = input.detach().cpu().numpy() 24 | 25 | if normalize: 26 | ma = float(input.max()) 27 | mi = float(input.min()) 28 | d = ma - mi if ma != mi else 1e5 29 | input = (input - mi) / d 30 | 31 | if 
input.ndim == 4: 32 | if input.shape[1] == 1: 33 | # one channel map 34 | input = input[0, 0] 35 | elif input.shape[1] == 2: 36 | # here, we fix it as flow 37 | input = input[0] 38 | input = input.transpose(1, 2, 0) 39 | else: 40 | input = input[0] 41 | 42 | elif input.ndim == 3: 43 | if input.shape[0] == 1: 44 | input = input[0] 45 | elif input.shape[0] == 2: 46 | input = input.transpose(1, 2, 0) 47 | else: 48 | pass 49 | 50 | elif input.ndim == 2: 51 | pass 52 | else: 53 | assert input.ndim > 1 and input.ndim < 5, input.ndim 54 | 55 | inputs.append(input) 56 | 57 | if isinstance(_cmap, str): 58 | if _cmap == 'plasma': 59 | _COLORMAP = plt.get_cmap('plasma', 256) # for plotting 60 | elif _cmap == 'gray': 61 | _COLORMAP = plt.get_cmap('gray') 62 | elif _cmap == 'jet': 63 | _COLORMAP = plt.get_cmap('jet') 64 | else: 65 | raise ValueError("invalid type: {} received, only support ['jet', 'plasma', 'gray']".format(cmap)) 66 | 67 | map = _COLORMAP(*inputs, **kwargs) 68 | 69 | elif callable(_cmap): 70 | map = _cmap(*inputs, **kwargs) 71 | 72 | else: 73 | raise ValueError(_cmap) 74 | 75 | # [H, W, C], value range [0, 1] 76 | map = map[..., :3] 77 | 78 | if format == 'HWC': 79 | pass 80 | elif format == 'CHW': 81 | map = map.transpose(2, 0, 1) 82 | else: 83 | raise ValueError(format) 84 | 85 | return map -------------------------------------------------------------------------------- /projects/TemporalStereo/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pytorch_lightning.loggers import TensorBoardLogger 3 | from pytorch_lightning.utilities import rank_zero_only 4 | import time 5 | 6 | from torch.utils.collect_env import get_pretty_env_info 7 | import PIL 8 | 9 | def get_pil_version(): 10 | return "\n Pillow ({})".format(PIL.__version__) 11 | 12 | def collect_env_info(): 13 | env_str = get_pretty_env_info() 14 | env_str += get_pil_version() 15 | return env_str 16 | 17 | def sec_to_hm(t): 18 | """Convert time in seconds to time in hours, minutes and seconds 19 | e.g. 10239 -> (2, 50, 39) 20 | """ 21 | t = int(t) 22 | s = t % 60 23 | t //= 60 24 | m = t % 60 25 | t //= 60 26 | return t, m, s 27 | 28 | 29 | def sec_to_hm_str(t): 30 | """Convert time in seconds to a nice string 31 | e.g. 
10239 -> '02h50m39s' 32 | """ 33 | h, m, s = sec_to_hm(t) 34 | return "{:02d}h{:02d}m{:02d}s".format(h, m, s) 35 | 36 | class FileWriter: 37 | def __init__(self, save_path): 38 | self._save_path = save_path 39 | self.start_time = time.time() 40 | self.num_total_steps = 0 41 | 42 | self.set_log_file() 43 | 44 | self.set_start_time(time.time()) 45 | 46 | def set_log_file(self, filename="log.txt"): 47 | os.makedirs(self._save_path, mode=0o777, exist_ok=True) 48 | self.log_file_path = os.path.join(self._save_path, filename) 49 | with open(self.log_file_path, "w") as fp: 50 | fp.write("Start Recording!") 51 | self.stdout(collect_env_info()) 52 | 53 | @rank_zero_only 54 | def stdout(self, outstr): 55 | with open(self.log_file_path, "a") as fp: 56 | fp.write(outstr+"\n") 57 | print(outstr) 58 | 59 | @rank_zero_only 60 | def set_num_total_steps(self, steps): 61 | self.num_total_steps = steps 62 | 63 | @rank_zero_only 64 | def set_start_time(self, tm): 65 | self.start_time = tm 66 | 67 | @rank_zero_only 68 | def log_time(self, current_step, current_epoch, batch_idx, batch_size, duration, loss): 69 | """Print a logging statement to the terminal 70 | """ 71 | samples_per_sec = batch_size / duration 72 | time_sofar = time.time() - self.start_time 73 | training_time_left = (self.num_total_steps / current_step - 1.0) * time_sofar if current_step > 0 else 0 74 | print_string = "\nEpoch {:>3} | batch {:>6} | examples/s: {:5.1f}" + \ 75 | " | loss: {:.5f} | time elapsed: {} | time left: {}" 76 | self.stdout(print_string.format(current_epoch, batch_idx, samples_per_sec, loss, 77 | sec_to_hm_str(time_sofar), sec_to_hm_str(training_time_left))) 78 | 79 | class Logger(TensorBoardLogger): 80 | def __init__(self, *args, **kwargs): 81 | super(Logger, self).__init__(*args, **kwargs) 82 | 83 | self._filewriter = FileWriter(self.log_dir) 84 | 85 | @property 86 | def filewriter(self): 87 | return self._filewriter 88 | -------------------------------------------------------------------------------- /projects/TemporalStereo/configs/kittiraw-multi.yaml: -------------------------------------------------------------------------------- 1 | LOG_DIR: "./exps/" 2 | FRAME_IDXS: [-7, -6, -5, -4, -3, -2, -1, 0, ] 3 | 4 | CHECKPOINT: 5 | EVERY_N_TRAIN_STEPS: 0 6 | EVERY_N_EPOCHS: 1 7 | 8 | TRAINER: 9 | NAME: "TemporalStereo" 10 | NUM_GPUS: 1 11 | GRADIENT_CLIP_VAL: 0.1 12 | VERSION: "kittiraw/multi" 13 | MAX_EPOCHS: 10 14 | CHECK_VAL_EVERY_N_EPOCHS: 1 15 | LOAD_FROM_CHECKPOINT: "./checkpoints/tartanair_multi.ckpt" 16 | 17 | SCHEDULER: 18 | TYPE: 'MultiStepLR' 19 | MULTI_STEP_LR: 20 | MILESTONES: [7,] 21 | GAMMA: 0.1 22 | 23 | OPTIMIZER: 24 | TYPE: 'RMSProp' 25 | RMSPROP: 26 | LR: 1e-3 27 | 28 | MODEL: 29 | WITH_PREVIOUS: True 30 | PREVIOUS_WITH_GRADIENT: False 31 | USE_PAST_COST: True 32 | LOCAL_MAP_SIZE: 3 33 | BACKBONE: 34 | NAME: "TEMPORALSTEREO" 35 | IN_PLANES: 3 36 | MEMORY_PERCENT: 0.5 37 | AGGREGATION: 38 | NAME: "TEMPORALSTEREO" 39 | COARSE: 40 | IN_PLANES: 256 41 | C: 32 42 | NUM_SAMPLE: 12 43 | DELTA: 1.0 44 | BLOCK_COST_SCALE: 3 45 | TOPK: 2 46 | SPATIAL_FUSION: True 47 | FINE: 48 | IN_PLANES: 128 49 | C: 16 50 | NUM_SAMPLE: 5 51 | DELTA: 1.0 52 | BLOCK_COST_SCALE: 3 53 | TOPK: 2 54 | SPATIAL_FUSION: True 55 | PRECISE: 56 | IN_PLANES: 64 57 | C: 8 58 | NUM_SAMPLE: 5 59 | DELTA: 1.0 60 | BLOCK_COST_SCALE: 3 61 | TOPK: 2 62 | LOSSES: 63 | SMOOTH_L1_LOSS: 64 | GLOBAL_WEIGHT: 1.0 65 | WEIGHTS: [2.0, 1.0, 0.7, 0.5] 66 | SPARSE: True 67 | WARSSERSTEIN_DISTANCE_LOSS: 68 | GLOBAL_WEIGHT: 2.0 69 | WEIGHTS: [1.0, 0.7, 0.5] 70 | 
SPARSE: True 71 | 72 | DATA: 73 | TRAIN: 74 | TYPE: "KITTIRAW" 75 | DATA_ROOT: "./datasets/KITTI-Multiview/KITTIRAW/" 76 | ANNFILE: "./splits/kittiraw/view_8_train_all.json" 77 | BATCH_SIZE: 4 78 | NUM_WORKERS: 16 79 | USE_COMMON_INTRINSICS: False 80 | DO_SAME_LR_TRANSFORM: False 81 | HEIGHT: 320 82 | WIDTH: 1184 83 | FRAME_IDXS: [-7, -6, -5, -4, -3, -2, -1, 0, ] 84 | VAL: 85 | TYPE: "KITTI2015" 86 | DATA_ROOT: "./datasets/KITTI-Multiview/KITTI-2015/" 87 | ANNFILE: "./splits/kitti2015/view_11_train_all.json" 88 | BATCH_SIZE: 1 89 | NUM_WORKERS: 8 90 | USE_COMMON_INTRINSICS: False 91 | DO_SAME_LR_TRANSFORM: True 92 | HEIGHT: 384 93 | WIDTH: 1248 94 | FRAME_IDXS: [-7, -6, -5, -4, -3, -2, -1, 0, ] 95 | TEST: 96 | TYPE: "KITTI2015" 97 | DATA_ROOT: "./datasets/KITTI-Multiview/KITTI-2015/" 98 | ANNFILE: "./splits/kitti2015/view_11_train_all.json" 99 | BATCH_SIZE: 1 100 | NUM_WORKERS: 2 101 | USE_COMMON_INTRINSICS: False 102 | DO_SAME_LR_TRANSFORM: True 103 | HEIGHT: 384 104 | WIDTH: 1248 105 | FRAME_IDXS: [-7, -6, -5, -4, -3, -2, -1, 0, ] 106 | 107 | VAL: 108 | EVAL_DISPARITY_IDS: [0, 1, 2, 3, 4, 5, 6, 7, 8] 109 | DO_OCCLUSION_EVALUATION: False 110 | VIS_INTERVAL: 3 111 | 112 | -------------------------------------------------------------------------------- /projects/TemporalStereo/configs/kitti2015-multi.yaml: -------------------------------------------------------------------------------- 1 | LOG_DIR: "./exps/" 2 | FRAME_IDXS: [-10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, ] 3 | 4 | CHECKPOINT: 5 | EVERY_N_TRAIN_STEPS: 0 6 | EVERY_N_EPOCHS: 4 7 | 8 | TRAINER: 9 | NAME: "TemporalStereo" 10 | NUM_GPUS: 1 11 | GRADIENT_CLIP_VAL: 0.1 12 | VERSION: "kitti2015/multi" 13 | MAX_EPOCHS: 16 14 | CHECK_VAL_EVERY_N_EPOCHS: 4 15 | LOAD_FROM_CHECKPOINT: "./checkpoints/kittiraw.ckpt" 16 | 17 | SCHEDULER: 18 | TYPE: 'MultiStepLR' 19 | MULTI_STEP_LR: 20 | MILESTONES: [12,] 21 | GAMMA: 0.1 22 | 23 | OPTIMIZER: 24 | TYPE: 'RMSProp' 25 | RMSPROP: 26 | LR: 1e-4 27 | 28 | MODEL: 29 | WITH_PREVIOUS: True 30 | PREVIOUS_WITH_GRADIENT: False 31 | USE_PAST_COST: True 32 | LOCAL_MAP_SIZE: 3 33 | BACKBONE: 34 | NAME: "TEMPORALSTEREO" 35 | IN_PLANES: 3 36 | MEMORY_PERCENT: 0.5 37 | AGGREGATION: 38 | NAME: "TEMPORALSTEREO" 39 | COARSE: 40 | IN_PLANES: 256 41 | C: 32 42 | NUM_SAMPLE: 12 43 | DELTA: 1.0 44 | BLOCK_COST_SCALE: 3 45 | TOPK: 2 46 | SPATIAL_FUSION: True 47 | FINE: 48 | IN_PLANES: 128 49 | C: 16 50 | NUM_SAMPLE: 5 51 | DELTA: 1.0 52 | BLOCK_COST_SCALE: 3 53 | TOPK: 2 54 | SPATIAL_FUSION: True 55 | PRECISE: 56 | IN_PLANES: 64 57 | C: 8 58 | NUM_SAMPLE: 5 59 | DELTA: 1.0 60 | BLOCK_COST_SCALE: 3 61 | TOPK: 2 62 | LOSSES: 63 | SMOOTH_L1_LOSS: 64 | GLOBAL_WEIGHT: 1.0 65 | WEIGHTS: [2.0, 1.0, 0.7, 0.5] 66 | SPARSE: True 67 | WARSSERSTEIN_DISTANCE_LOSS: 68 | GLOBAL_WEIGHT: 2.0 69 | WEIGHTS: [1.0, 0.7, 0.5] 70 | SPARSE: True 71 | 72 | DATA: 73 | TRAIN: 74 | TYPE: "KITTI2015" 75 | DATA_ROOT: "./datasets/KITTI-Multiview/KITTI-2015/" 76 | ANNFILE: "./splits/kitti2015/view_11_train_all.json" 77 | BATCH_SIZE: 4 78 | NUM_WORKERS: 8 79 | USE_COMMON_INTRINSICS: False 80 | DO_SAME_LR_TRANSFORM: False 81 | HEIGHT: 320 82 | WIDTH: 1184 83 | FRAME_IDXS: [-10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, ] 84 | VAL: 85 | TYPE: "KITTI2015" 86 | DATA_ROOT: "./datasets/KITTI-Multiview/KITTI-2015/" 87 | ANNFILE: "./splits/kitti2015/view_11_val.json" 88 | BATCH_SIZE: 1 89 | NUM_WORKERS: 2 90 | USE_COMMON_INTRINSICS: False 91 | DO_SAME_LR_TRANSFORM: True 92 | HEIGHT: 384 93 | WIDTH: 1248 94 | FRAME_IDXS: [-10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 
0, ] 95 | TEST: 96 | TYPE: "KITTI2015" 97 | DATA_ROOT: "./datasets/KITTI-Multiview/KITTI-2015/" 98 | ANNFILE: "./splits/kitti2015/view_11_train_all.json" 99 | BATCH_SIZE: 1 100 | NUM_WORKERS: 2 101 | USE_COMMON_INTRINSICS: False 102 | DO_SAME_LR_TRANSFORM: True 103 | HEIGHT: 384 104 | WIDTH: 1248 105 | FRAME_IDXS: [-10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, ] 106 | 107 | VAL: 108 | EVAL_DISPARITY_IDS: [0, 1, 2, 3, 4, 5, 6, 7, 8] 109 | DO_OCCLUSION_EVALUATION: False 110 | VIS_INTERVAL: 3 111 | 112 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .vscode 3 | outputs 4 | *.bin 5 | *.pyc 6 | _ext 7 | __pycache__ 8 | checkpoints 9 | results/ 10 | *.json 11 | 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | share/python-wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .nox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | *.py,cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | cover/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | db.sqlite3-journal 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | .pybuilder/ 87 | target/ 88 | 89 | # Jupyter Notebook 90 | .ipynb_checkpoints 91 | 92 | # IPython 93 | profile_default/ 94 | ipython_config.py 95 | 96 | # pyenv 97 | # For a library or package, you might want to ignore these files since the code is 98 | # intended to run in multiple environments; otherwise, check them in: 99 | # .python-version 100 | 101 | # pipenv 102 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 103 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 104 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 105 | # install all needed dependencies. 106 | #Pipfile.lock 107 | 108 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 109 | __pypackages__/ 110 | 111 | # Celery stuff 112 | celerybeat-schedule 113 | celerybeat.pid 114 | 115 | # SageMath parsed files 116 | *.sage.py 117 | 118 | # Environments 119 | .env 120 | .venv 121 | env/ 122 | venv/ 123 | ENV/ 124 | env.bak/ 125 | venv.bak/ 126 | 127 | # Spyder project settings 128 | .spyderproject 129 | .spyproject 130 | 131 | # Rope project settings 132 | .ropeproject 133 | 134 | # MacOS 135 | .DS_Store 136 | 137 | # mkdocs documentation 138 | /site 139 | 140 | # mypy 141 | .mypy_cache/ 142 | .dmypy.json 143 | dmypy.json 144 | 145 | # Pyre type checker 146 | .pyre/ 147 | 148 | # pytype static type analyzer 149 | .pytype/ 150 | 151 | # Cython debug symbols 152 | cython_debug/ 153 | 154 | *.pth 155 | *.pkl 156 | *.npy 157 | *.ipynb 158 | **/.ipynb_checkpoints/ 159 | *.swn 160 | *.swo 161 | *.swp 162 | *~ 163 | 164 | exps 165 | -------------------------------------------------------------------------------- /architecture/modeling/prediction/soft_argmin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from detectron2.config import configurable 5 | 6 | from .builder import PREDICTION_REGISTRY 7 | 8 | @PREDICTION_REGISTRY.register() 9 | class SOFTARGMIN(nn.Module): 10 | """ 11 | A faster implementation of soft argmin. 12 | Args: 13 | temperature, (float): a temperature will times with cost_volume, i.e., the temperature coefficient 14 | details can refer to: https://bouthilx.wordpress.com/2013/04/21/a-soft-argmax/ 15 | normalize, (bool): whether apply softmax on cost_volume, default True 16 | Inputs: 17 | cost_volume, (Tensor): the matching cost after regularization, 18 | [BatchSize, disp_sample_number, Height, Width] layout 19 | disp_sample, (Tensor): the estimated disparity samples, 20 | [BatchSize, disp_sample_number, Height, Width] layout. 21 | Returns: 22 | disp_map, (Tensor): a disparity map regressed from cost volume, 23 | [BatchSize, 1, Height, Width] layout 24 | """ 25 | @configurable 26 | def __init__(self, temperature:float=1.0, normalize:bool=True): 27 | super(SOFTARGMIN, self).__init__() 28 | self.temperature = temperature 29 | self.normalize = normalize 30 | 31 | @classmethod 32 | def from_config(cls, cfg): 33 | return { 34 | "temperature": cfg.MODEL.PREDICTION.get("TEMPERATURE", 1.0), 35 | "normalize": cfg.MODEL.PREDICTION.get("NORMALIZE", True), 36 | } 37 | 38 | def forward(self, cost_volume, disp_sample): 39 | 40 | # note, cost volume direct represent similarity 41 | # 'c' or '-c' do not affect the performance because feature-based cost volume provided flexibility. 42 | 43 | if cost_volume.dim() != 4: 44 | raise ValueError('expected 4D input (got {}D input)' 45 | .format(cost_volume.dim())) 46 | 47 | # scale cost volume with temperature 48 | cost_volume = cost_volume * self.temperature 49 | 50 | if self.normalize: 51 | prob_volume = F.softmax(cost_volume, dim=1) 52 | else: 53 | prob_volume = cost_volume 54 | 55 | assert prob_volume.shape == disp_sample.shape, 'The shape of disparity samples and cost volume should be' \ 56 | ' consistent!' 
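        # Soft argmin = expectation over the disparity samples:
        #   disp(x, y) = sum_d softmax(temperature * cost_d(x, y)) * disp_sample_d(x, y)
        # which is exactly the weighted sum computed below.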
57 | 58 | # compute disparity: (BatchSize, 1, Height, Width) 59 | disp_map = torch.sum(prob_volume * disp_sample, dim=1, keepdim=True) 60 | 61 | return disp_map 62 | 63 | def __repr__(self): 64 | repr_str = '{}\n'.format(self.__class__.__name__) 65 | repr_str += ' ' * 4 + 'Temperature: {}\n'.format(self.temperature) 66 | repr_str += ' ' * 4 + 'Normalize: {}\n'.format(self.normalize) 67 | 68 | return repr_str 69 | 70 | @property 71 | def name(self): 72 | return 'SoftArgmin' -------------------------------------------------------------------------------- /architecture/data/utils/load_drivingstereo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | 5 | def read_drivingstereo_intrinsic(intrinsic_fn): 6 | lineid = 0 7 | data = {} 8 | with open(intrinsic_fn, 'r') as fp: 9 | for line in fp.readlines(): 10 | values = line.rstrip().split(' ') 11 | if lineid == 0: 12 | lineid += 1 13 | continue 14 | frame = int(values[0]) 15 | camera = int(values[1]) 16 | key = 'K_cam{}'.format(camera) 17 | inv_key = 'inv_K_cam{}'.format(int(camera)) 18 | matrix = np.array([float(values[i]) for i in range(2, len(values))]) 19 | K = np.eye(4) 20 | K[0, 0] = matrix[0] 21 | K[1, 1] = matrix[1] 22 | K[0, 2] = matrix[2] 23 | K[1, 2] = matrix[3] 24 | item = { 25 | key: K, 26 | inv_key: np.linalg.pinv(K) 27 | } 28 | data['Frame{}:{}'.format(frame, camera)] = item 29 | lineid += 1 30 | 31 | return data 32 | 33 | 34 | def read_drivingstereo_extrinsic(extrinsic_fn): 35 | lineid = 0 36 | data = {} 37 | with open(extrinsic_fn, 'r') as fp: 38 | for line in fp.readlines(): 39 | values = line.rstrip().split(' ') 40 | if lineid == 0: 41 | lineid += 1 42 | continue 43 | frame = int(values[0]) 44 | camera = int(values[1]) 45 | key = 'T_cam{}'.format(int(camera)) 46 | inv_key = 'inv_T_cam{}'.format(int(camera)) 47 | matrix = np.array([float(values[i]) for i in range(2, len(values))]) 48 | matrix = matrix.reshape(4, 4) 49 | item = { 50 | key: matrix, 51 | inv_key: np.linalg.pinv(matrix), 52 | } 53 | data['Frame{}:{}'.format(frame, camera)] = item 54 | lineid += 1 55 | 56 | return data 57 | 58 | 59 | def read_drivingstereo_png_depth(depth_fn, K=np.array([[725.0087, 0, 620.5, 0], 60 | [0, 725.0087, 187, 0], 61 | [0, 0, 1, 0], 62 | [0, 0, 0, 1]]), div_factor=256.0): 63 | 64 | depth = cv2.imread(depth_fn, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_ANYCOLOR) 65 | # The disparity value and depth value for each pixel can be computed by 66 | # converting the uint16 value to float and dividing it by 256. 67 | # The zero values indicate the invalid pixels. 68 | # Different from half-resulution disparity maps, 69 | # the disparity value for each pixel in the full-resolution map 70 | # is computed by converting the uint16 value to float and dividing it by 128. 
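    # Worked example (half-resolution map): a stored uint16 value of 12800
    # gives depth = 12800 / 256.0 = 50.0 m, and with the default f = 725.0087 px
    # and b = 0.5443450 m the disparity below is 0.5443450 * 725.0087 / 50.0 ≈ 7.9 px.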
71 | # uint16, full div 128, hafl div 256 72 | valid = depth > 0 73 | depth = (depth * valid) / div_factor 74 | 75 | f = K[0, 0] 76 | b = 0.5443450 # meter 77 | 78 | disp = b * f / depth 79 | 80 | return depth, disp 81 | 82 | def read_drivingstereo_png_flow(flow_fn): 83 | raise NotImplementedError("DrivingStereo has no ground truth for optical flow!!!") 84 | 85 | -------------------------------------------------------------------------------- /projects/TemporalStereo/configs/kitti2015.yaml: -------------------------------------------------------------------------------- 1 | LOG_DIR: "./exps/" 2 | FRAME_IDXS: [0, ] 3 | 4 | CHECKPOINT: 5 | EVERY_N_TRAIN_STEPS: 0 6 | EVERY_N_EPOCHS: 25 7 | 8 | TRAINER: 9 | NAME: "TemporalStereo" 10 | NUM_GPUS: 1 11 | GRADIENT_CLIP_VAL: 0.1 12 | VERSION: "kitti2015/single" 13 | MAX_EPOCHS: 120 14 | CHECK_VAL_EVERY_N_EPOCHS: 5 15 | LOAD_FROM_CHECKPOINT: "/path/to/kitti_raw.ckpt" 16 | 17 | SCHEDULER: 18 | TYPE: 'MultiStepLR' 19 | MULTI_STEP_LR: 20 | MILESTONES: [60, 100,] 21 | GAMMA: 0.1 22 | 23 | OPTIMIZER: 24 | TYPE: 'RMSProp' 25 | RMSPROP: 26 | LR: 1e-4 27 | 28 | MODEL: 29 | WITH_PREVIOUS: False 30 | PREVIOUS_WITH_GRADIENT: False 31 | WITH_FLOW: False 32 | USE_LOCAL_MAP: False 33 | LOCAL_MAP_SIZE: 0 34 | VIS_FEATURE: False 35 | BACKBONE: 36 | NAME: "TEMPORALSTEREO" 37 | IN_PLANES: 3 38 | AGGREGATION: 39 | NAME: "TEMPORALSTEREO" 40 | COARSE: 41 | IN_PLANES: 256 42 | C: 32 43 | NUM_SAMPLE: 12 44 | DELTA: 1.0 45 | BLOCK_COST_SCALE: 3 46 | TOPK: 2 47 | SPATIAL_FUSION: True 48 | FINE: 49 | IN_PLANES: 128 50 | C: 16 51 | NUM_SAMPLE: 5 52 | DELTA: 1.0 53 | BLOCK_COST_SCALE: 3 54 | TOPK: 2 55 | SPATIAL_FUSION: True 56 | PRECISE: 57 | IN_PLANES: 64 58 | C: 8 59 | NUM_SAMPLE: 5 60 | DELTA: 1.0 61 | BLOCK_COST_SCALE: 3 62 | TOPK: 2 63 | LOSSES: 64 | SMOOTH_L1_LOSS: 65 | GLOBAL_WEIGHT: 1.0 66 | WEIGHTS: [2.0, 1.5, 1.0, 0.7, 0.5] 67 | SPARSE: True 68 | WARSSERSTEIN_DISTANCE_LOSS: 69 | GLOBAL_WEIGHT: 2.0 70 | WEIGHTS: [1.0, 0.7, 0.5] 71 | SPARSE: True 72 | 73 | DATA: 74 | TRAIN: 75 | TYPE: "KITTI2015" 76 | # DATA_ROOT: "./datasets/KITTI-Multiview/KITTI-2012/" 77 | # ANNFILE: "./splits/kitti2012/view_11_train_all.json" 78 | DATA_ROOT: "./datasets/KITTI-Multiview/KITTI-2015/" 79 | ANNFILE: "./splits/kitti2015/view_11_train_all.json" 80 | BATCH_SIZE: 2 81 | NUM_WORKERS: 4 82 | USE_COMMON_INTRINSICS: False 83 | DO_SAME_LR_TRANSFORM: False 84 | HEIGHT: 320 85 | WIDTH: 1184 86 | FRAME_IDXS: [0, ] 87 | VAL: 88 | TYPE: "KITTI2015" 89 | # DATA_ROOT: "./datasets/KITTI-Multiview/KITTI-2012/" 90 | # ANNFILE: "./splits/kitti2012/view_11_train_all.json" 91 | DATA_ROOT: "./datasets/KITTI-Multiview/KITTI-2015/" 92 | ANNFILE: "./splits/kitti2015/view_11_train_all.json" 93 | BATCH_SIZE: 1 94 | NUM_WORKERS: 2 95 | USE_COMMON_INTRINSICS: False 96 | DO_SAME_LR_TRANSFORM: True 97 | HEIGHT: 384 98 | WIDTH: 1248 99 | FRAME_IDXS: [0, ] 100 | TEST: 101 | TYPE: "KITTI2015" 102 | # DATA_ROOT: "./datasets/KITTI-Multiview/KITTI-2012/" 103 | # ANNFILE: "./splits/kitti2012/view_11_train_all.json" 104 | DATA_ROOT: "./datasets/KITTI-Multiview/KITTI-2015/" 105 | ANNFILE: "./splits/kitti2015/view_11_train_all.json" 106 | BATCH_SIZE: 1 107 | NUM_WORKERS: 2 108 | USE_COMMON_INTRINSICS: False 109 | DO_SAME_LR_TRANSFORM: True 110 | HEIGHT: 384 111 | WIDTH: 1248 112 | FRAME_IDXS: [0, ] 113 | 114 | VAL: 115 | EVAL_DISPARITY_IDS: [0, 1, 2, 3, 4, 5, 6, 7, 8] 116 | DO_OCCLUSION_EVALUATION: False 117 | VIS_INTERVAL: 3 118 | 119 | -------------------------------------------------------------------------------- 
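Note: the snippet below is a minimal, illustrative sketch (not a file of the repository) of how a config such as kitti2015.yaml above is consumed; it mirrors the self-test at the bottom of architecture/data/datasets/builder.py, and the config path is only an assumption about the working directory.

from projects.TemporalStereo.config import get_cfg, get_parser
from architecture.data.datasets.builder import build_stereo_dataset

args = get_parser().parse_args()
args.config_file = 'projects/TemporalStereo/configs/kitti2015.yaml'  # illustrative path
cfg = get_cfg(args)

# DATA.TRAIN carries the dataset type, annotation file, crop size and frame indices
train_set = build_stereo_dataset(cfg.DATA.TRAIN, 'train')
print('Dataset contains {} items'.format(len(train_set)))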
/architecture/data/utils/load_vkitti.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | 5 | def read_vkitti_intrinsic(intrinsic_fn): 6 | lineid = 0 7 | data = {} 8 | with open(intrinsic_fn, 'r') as fp: 9 | for line in fp.readlines(): 10 | values = line.rstrip().split(' ') 11 | if lineid == 0: 12 | lineid += 1 13 | continue 14 | frame = int(values[0]) 15 | camera = int(values[1]) 16 | key = 'K_cam{}'.format(camera) 17 | inv_key = 'inv_K_cam{}'.format(int(camera)) 18 | matrix = np.array([float(values[i]) for i in range(2, len(values))]) 19 | K = np.eye(4) 20 | K[0, 0] = matrix[0] 21 | K[1, 1] = matrix[1] 22 | K[0, 2] = matrix[2] 23 | K[1, 2] = matrix[3] 24 | item = { 25 | key: K, 26 | inv_key: np.linalg.pinv(K) 27 | } 28 | data['Frame{}:{}'.format(frame, camera)] = item 29 | lineid += 1 30 | 31 | return data 32 | 33 | 34 | def read_vkitti_extrinsic(extrinsic_fn): 35 | lineid = 0 36 | data = {} 37 | with open(extrinsic_fn, 'r') as fp: 38 | for line in fp.readlines(): 39 | values = line.rstrip().split(' ') 40 | if lineid == 0: 41 | lineid += 1 42 | continue 43 | frame = int(values[0]) 44 | camera = int(values[1]) 45 | key = 'T_cam{}'.format(int(camera)) 46 | inv_key = 'inv_T_cam{}'.format(int(camera)) 47 | matrix = np.array([float(values[i]) for i in range(2, len(values))]) 48 | matrix = matrix.reshape(4, 4) 49 | item = { 50 | key: matrix, 51 | inv_key: np.linalg.pinv(matrix), 52 | } 53 | data['Frame{}:{}'.format(frame, camera)] = item 54 | lineid += 1 55 | 56 | return data 57 | 58 | 59 | def read_vkitti_png_depth(depth_fn, K=np.array([[725.0087, 0, 620.5, 0], 60 | [0, 725.0087, 187, 0], 61 | [0, 0, 1, 0], 62 | [0, 0, 0, 1]])): 63 | 64 | depth = cv2.imread(depth_fn, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_ANYCOLOR) 65 | # [0, 655.35] meter 66 | depth = depth / 100.0 67 | 68 | f = K[0, 0] 69 | b = 0.532725 # meter 70 | 71 | disp = b * f / depth 72 | 73 | return depth, disp 74 | 75 | def read_vkitti_png_flow(flow_fn): 76 | """Convert from .png to (h, w, 2) (flow_x, flow_y) float32 array""" 77 | # read png to bgr in 16 bit unsigned short 78 | 79 | bgr = cv2.imread(flow_fn, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) 80 | h, w, _c = bgr.shape 81 | assert bgr.dtype == np.uint16 and _c == 3 82 | # b == invalid flow flag == 0 for sky or other invalid flow 83 | invalid = bgr[..., 0] == 0 84 | # g,r == flow_y,x normalized by height,width and scaled to [0;2**16 – 1] 85 | out_flow = 2.0 / (2**16 - 1.0) * bgr[..., 2:0:-1].astype('f4') - 1 86 | out_flow[..., 0] *= w - 1 87 | out_flow[..., 1] *= h - 1 88 | out_flow[invalid] = 0 # or another value (e.g., np.nan) 89 | 90 | return out_flow 91 | 92 | -------------------------------------------------------------------------------- /architecture/data/datasets/tartanair/tartanair.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | import torch 5 | 6 | from architecture.data.utils import read_tartanair_flow, read_tartanair_extrinsic, read_tartanair_depth 7 | 8 | from .base import TARTANAIRStereoDatasetBase 9 | 10 | class TARTANAIRStereoDataset(TARTANAIRStereoDatasetBase): 11 | def __init__(self, annFile, root, height, width, frame_idxs, 12 | is_train=False, use_common_intrinsics=False, do_same_lr_transform=True, 13 | mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): 14 | super(TARTANAIRStereoDataset, self).__init__(annFile, root, height, width, frame_idxs, 15 | is_train, use_common_intrinsics, 
do_same_lr_transform, 16 | mean, std) 17 | 18 | def Loader(self, image_path): 19 | image_path = os.path.join(self.root, image_path) 20 | img = self.img_loader(image_path) 21 | return img 22 | 23 | def depthLoader(self, depth_path, K=None): 24 | depth_path = os.path.join(self.root, depth_path) 25 | depth, disp = read_tartanair_depth(depth_path, K) 26 | 27 | depth = torch.from_numpy(depth.astype(np.float32)).unsqueeze(0) 28 | disp = torch.from_numpy(disp.astype(np.float32)).unsqueeze(0) 29 | 30 | return depth, disp 31 | 32 | def flowLoader(self, flow_path): 33 | flow_path = os.path.join(self.root, flow_path) 34 | flow = read_tartanair_flow(flow_path) 35 | 36 | flow = torch.from_numpy(flow.astype(np.float32)).permute(2, 0, 1).contiguous() 37 | 38 | return flow 39 | 40 | def extrinsicLoader(self, extrinsic_path): 41 | # the transformation matrix is from world original point to current frame 42 | left_extrinsic_path = os.path.join(self.root, extrinsic_path, 'pose_left.txt') 43 | left_extrinsics = read_tartanair_extrinsic(left_extrinsic_path, 'left') 44 | right_extrinsic_path = os.path.join(self.root, extrinsic_path, 'pose_right.txt') 45 | right_extrinsics = read_tartanair_extrinsic(right_extrinsic_path, 'right') 46 | extrinsics = {} 47 | extrinsics.update(left_extrinsics) 48 | extrinsics.update(right_extrinsics) 49 | 50 | return extrinsics 51 | 52 | def getExtrinsic(self, extrinsics, image_path): 53 | # 'hospital/Easy/P000/image_left/000021_left.png' 54 | img_id = int(image_path.split('/')[-1].split('.')[0].split('_')[0]) 55 | # the transformation matrix is from world original point to current frame 56 | # 4x4 57 | left_T = torch.from_numpy(extrinsics['Frame{}:0'.format(img_id)]['T_cam0']) 58 | left_inv_T = torch.from_numpy(extrinsics['Frame{}:0'.format(img_id)]['inv_T_cam0']) 59 | right_T = torch.from_numpy(extrinsics['Frame{}:1'.format(img_id)]['T_cam1']) 60 | right_inv_T = torch.from_numpy(extrinsics['Frame{}:1'.format(img_id)]['inv_T_cam1']) 61 | 62 | return left_T, left_inv_T, right_T, right_inv_T 63 | 64 | def intrinsicLoader(self, intrinsic_path): 65 | norm_K = self.K.copy() 66 | full_K = np.eye(4) 67 | resolution = self.full_resolution 68 | h, w = resolution 69 | full_K[0, :] = norm_K[0, :] * w 70 | full_K[1, :] = norm_K[1, :] * h 71 | 72 | return norm_K, full_K, resolution 73 | 74 | 75 | -------------------------------------------------------------------------------- /architecture/data/datasets/builder.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.insert(0, '/home/yzhang/projects/mono/StereoBenchmark/') 5 | from architecture.data.datasets import VKITTI2StereoDataset 6 | from architecture.data.datasets import SceneFlowStereoDataset 7 | from architecture.data.datasets import TARTANAIRStereoDataset 8 | from architecture.data.datasets import KITTI2015StereoDataset 9 | from architecture.data.datasets import KITTIRAWStereoDataset 10 | 11 | def build_stereo_dataset(cfg, phase): 12 | 13 | data_root = cfg.DATA_ROOT 14 | data_type = cfg.TYPE 15 | annFile = cfg.ANNFILE 16 | height = cfg.HEIGHT 17 | width = cfg.WIDTH 18 | frame_idxs = cfg.FRAME_IDXS 19 | use_common_intrinsics = cfg.get('USE_COMMON_INTRINSICS', True) 20 | do_same_lr_transform = cfg.get('DO_SAME_LR_TRANSFORM', True) 21 | mean = cfg.get('MEAN', (0.485, 0.456, 0.406)) 22 | std = cfg.get('STD', (0.229, 0.224, 0.225)) 23 | 24 | 25 | is_train = True if phase == 'train' else False 26 | 27 | if 'VKITTI2' in data_type: 28 | dataset = 
VKITTI2StereoDataset(annFile, data_root, height, width, frame_idxs, is_train, use_common_intrinsics, 29 | do_same_lr_transform, mean, std) 30 | 31 | elif 'SceneFlow' in data_type: 32 | dataset = SceneFlowStereoDataset(annFile, data_root, height, width, frame_idxs, is_train, use_common_intrinsics, 33 | do_same_lr_transform, mean, std) 34 | 35 | elif 'TartanAir' in data_type: 36 | dataset = TARTANAIRStereoDataset(annFile, data_root, height, width, frame_idxs, is_train, use_common_intrinsics, 37 | do_same_lr_transform, mean, std) 38 | 39 | elif 'KITTI2015' in data_type: 40 | dataset = KITTI2015StereoDataset(annFile, data_root, height, width, frame_idxs, is_train, use_common_intrinsics, 41 | do_same_lr_transform, mean, std) 42 | 43 | elif 'KITTIRAW' in data_type: 44 | dataset = KITTIRAWStereoDataset(annFile, data_root, height, width, frame_idxs, is_train, use_common_intrinsics, 45 | do_same_lr_transform, mean, std) 46 | 47 | else: 48 | raise ValueError("invalid data type: {}".format(data_type)) 49 | 50 | return dataset 51 | 52 | 53 | if __name__ == '__main__': 54 | """ 55 | Test the Stereo Dataset 56 | """ 57 | import sys 58 | sys.path.insert(0, '/home/yzhang/projects/stereo/TemporalStereo/') 59 | import matplotlib.pyplot as plt 60 | 61 | from projects.TemporalStereo.config import get_cfg, get_parser 62 | args = get_parser().parse_args() 63 | args.config_file = '/home/yzhang/projects/stereo/TemporalStereo/projects/TemporalStereo/configs/kitti2015.yaml' 64 | cfg = get_cfg(args) 65 | 66 | dataset = build_stereo_dataset(cfg.DATA.TRAIN, 'train') 67 | 68 | print(dataset) 69 | 70 | print("Dataset contains {} items".format(len(dataset))) 71 | 72 | idxs = [0, ] # 100, 1000] 73 | for i in idxs: 74 | sample = dataset[i] 75 | for key in list(sample): 76 | _include_keys = ['color_aug', 'color', 'disp_gt', 'depth_gt'] 77 | if key[0] in _include_keys: 78 | print('Key {} with shape: {}'.format(key, sample[key].shape)) 79 | img = sample[key].permute(1, 2, 0).cpu().numpy() 80 | plt.imshow(img) 81 | plt.show() 82 | 83 | print('Done!') 84 | 85 | 86 | -------------------------------------------------------------------------------- /architecture/data/evaluation/flow_pixel_error.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | 5 | def zero_mask(input, eps=1e-12): 6 | mask = abs(input) < eps 7 | return mask 8 | 9 | def flow_calc_error(est_flow=None, gt_flow=None, lb=0.0, ub=400, sparse=False): 10 | """ 11 | Args: 12 | est_flow: (Tensor), estimated flow map 13 | [..., 2, Height, Width] layout 14 | gt_flow: (Tensor), ground truth flow map 15 | [..., 2, Height, Width] layout 16 | lb: (scalar), the lower bound of disparity you want to mask out 17 | ub: (scalar), the upper bound of disparity you want to mask out 18 | sparse: (bool), whether the given flow is sparse, default False 19 | Output: 20 | dict: the error of 1px, 2px, 3px, 5px, in percent, 21 | range [0,100] and average error epe 22 | """ 23 | error1 = torch.Tensor([0.]) 24 | error2 = torch.Tensor([0.]) 25 | error3 = torch.Tensor([0.]) 26 | error5 = torch.Tensor([0.]) 27 | epe = torch.Tensor([0.]) 28 | 29 | if (not torch.is_tensor(est_flow)) or (not torch.is_tensor(gt_flow)): 30 | return { 31 | '1px': error1 * 100, 32 | '2px': error2 * 100, 33 | '3px': error3 * 100, 34 | '5px': error5 * 100, 35 | 'epe': epe 36 | } 37 | 38 | assert torch.is_tensor(est_flow) and torch.is_tensor(gt_flow) 39 | assert est_flow.shape == gt_flow.shape 40 | 41 | est_flow = est_flow.clone().cpu() 42 | 
gt_flow = gt_flow.clone().cpu() 43 | if len(gt_flow.shape) == 3: 44 | gt_flow = gt_flow.unsqueeze(0) 45 | est_flow = est_flow.unsqueeze(0) 46 | 47 | assert gt_flow.shape[1] == 2, "flow should have horizontal and vertical dimension, " \ 48 | "but got {}".format(gt_flow.shape[1]) 49 | 50 | # [B, 1, H, W] 51 | gt_u, gt_v = gt_flow[:, 0:1, :, :], gt_flow[:, 1:2, :, :] 52 | est_u, est_v = est_flow[:, 0:1, :, :], est_flow[:, 1:2, :, :] 53 | 54 | # get valid mask 55 | # [B, 1, H, W] 56 | mask = torch.ones(gt_u.shape, dtype=torch.bool, device=gt_u.device) 57 | if sparse: 58 | mask = mask & (~(zero_mask(gt_u) & zero_mask(gt_v))) 59 | mask = mask & (~(torch.isnan(gt_u) | torch.isnan(gt_v))) 60 | 61 | rad = torch.sqrt(gt_u**2 + gt_v**2) 62 | mask = mask & (rad > lb) & (rad < ub) 63 | 64 | mask.detach_() 65 | if abs(mask.float().sum()) < 1.0: 66 | return { 67 | '1px': error1 * 100, 68 | '2px': error2 * 100, 69 | '3px': error3 * 100, 70 | '5px': error5 * 100, 71 | 'epe': epe 72 | } 73 | 74 | 75 | 76 | gt_u = gt_u[mask] 77 | gt_v = gt_v[mask] 78 | est_u = est_u[mask] 79 | est_v = est_v[mask] 80 | 81 | abs_error = torch.sqrt((gt_u - est_u)**2 + (gt_v - est_v)**2) 82 | total_num = mask.float().sum() 83 | 84 | error1 = torch.sum(torch.gt(abs_error, 1).float()) / total_num 85 | error2 = torch.sum(torch.gt(abs_error, 2).float()) / total_num 86 | error3 = torch.sum(torch.gt(abs_error, 3).float()) / total_num 87 | error5 = torch.sum(torch.gt(abs_error, 5).float()) / total_num 88 | epe = abs_error.float().mean() 89 | 90 | return { 91 | '1px': error1 * 100, 92 | '2px': error2 * 100, 93 | '3px': error3 * 100, 94 | '5px': error5 * 100, 95 | 'epe': epe 96 | } -------------------------------------------------------------------------------- /architecture/modeling/aggregation/utils/cat_fms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from architecture.modeling.layers import inverse_warp_3d 4 | 5 | def cat_fms(reference_fm, target_fm, disp_sample): 6 | """ 7 | perform concatenation between left and rith image feature to construct 4D cost volume 8 | 9 | Args: 10 | reference_fm: (Tensor), the feature map of reference image, often left image 11 | [BatchSize, C, H, W] 12 | target_fm: (Tensor), the feature map of target image, often right image 13 | [BatchSize, C, H, W] 14 | disp_sample: (Tensor), the disparity samples/candidates for feature concatenation or matching 15 | [BatchSize, NumSamples, H, W] 16 | 17 | Returns: 18 | concat_fm: (Tensor), the concatenated feature map 19 | [BatchSize, 2C, NumSamples, H, W] 20 | """ 21 | B, C, H, W = reference_fm.shape 22 | 23 | # the number of disparity samples 24 | D = disp_sample.shape[1] 25 | 26 | # expand D dimension 27 | concat_reference_fm = reference_fm.unsqueeze(2).expand(B, C, D, H, W) 28 | concat_target_fm = target_fm.unsqueeze(2).expand(B, C, D, H, W) 29 | 30 | # shift target feature according to disparity samples 31 | concat_target_fm = inverse_warp_3d(concat_target_fm, -disp_sample, padding_mode='zeros') 32 | 33 | # [B, 2C, D, H, W) 34 | concat_fm = torch.cat((concat_reference_fm, concat_target_fm), dim=1) 35 | 36 | return concat_fm 37 | 38 | 39 | if __name__ == '__main__': 40 | """ 41 | GPU: GTX3090, CUDA:11.0, Torch:1.7.1 42 | CAT_FMS reference forward once takes 5.3421ms, i.e. 
187.19fps 43 | """ 44 | print("Feature Concatenation Test...") 45 | from architecture.utils import timeTestTemplate 46 | 47 | # -------------------------------------- Value Test-------------------------------------- # 48 | H, W = 3, 4 49 | device = torch.device('cuda:0') 50 | left = torch.linspace(1, H * W, H * W).reshape(1, 1, H, W).to(device) 51 | right = torch.linspace(H * W + 1, H * W * 2, H * W).reshape(1, 1, H, W).to(device) 52 | print('left: \n ', left) 53 | print('right: \n ', right) 54 | 55 | disp_samples = torch.linspace(-2, 2, 5).repeat(1, H, W, 1). \ 56 | permute(0, 3, 1, 2).contiguous().to(device) 57 | 58 | print("Disparity Samples/Candidates: \n", disp_samples) 59 | 60 | 61 | cost = cat_fms(left, right, disp_samples) 62 | print('Cost in shape: ', cost.shape) 63 | idx = 0 64 | for i in range(-2, 3, 1): 65 | print('Disparity {}:\n {}'.format(i, cost[:, :, idx, ])) 66 | idx += 1 67 | 68 | for i in range(cost.shape[1]): 69 | print('Channel {}:\n {}'.format(i, cost[:, i, ])) 70 | 71 | # -------------------------------------- Time Test-------------------------------------- # 72 | 73 | C, H, W = 32, 384//4, 1248//4 # size in KITTI 74 | device = torch.device('cuda:0') 75 | left = torch.rand(1, C, H, W).to(device) 76 | right = torch.rand(1, C, H, W).to(device) 77 | 78 | max_disp = 192//4 79 | disp_samples = torch.linspace(0, max_disp-1, max_disp).repeat(1, H, W, 1). \ 80 | permute(0, 3, 1, 2).contiguous().to(device) 81 | 82 | avg_time = timeTestTemplate(cat_fms, left, right, disp_samples, iters=50, device=torch.device('cuda:0')) 83 | 84 | print('{} reference forward once takes {:.4f}ms, i.e. {:.2f}fps'.format('CAT_FMS', avg_time * 1000, (1 / avg_time))) 85 | 86 | 87 | -------------------------------------------------------------------------------- /architecture/data/datasets/vkitti/vkitti_2.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | import torch 5 | 6 | from architecture.data.utils import read_vkitti_png_depth, read_vkitti_png_flow, read_vkitti_extrinsic, read_vkitti_intrinsic 7 | 8 | from .base import VKITTIStereoDatasetBase 9 | 10 | class VKITTI2StereoDataset(VKITTIStereoDatasetBase): 11 | def __init__(self, annFile, root, height, width, frame_idxs, 12 | is_train=False, use_common_intrinsics=False, do_same_lr_transform=True, 13 | mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): 14 | super(VKITTI2StereoDataset, self).__init__(annFile, root, height, width, frame_idxs, 15 | is_train, use_common_intrinsics, do_same_lr_transform, 16 | mean, std) 17 | 18 | def Loader(self, image_path): 19 | image_path = os.path.join(self.root, image_path) 20 | img = self.img_loader(image_path) 21 | return img 22 | 23 | def depthLoader(self, depth_path, K=None): 24 | if self.use_common_intrinsics: 25 | K = np.eye(4) 26 | K[0, :] = self.K[0, :] * self.full_resolution[1] 27 | K[1, :] = self.K[1, :] * self.full_resolution[0] 28 | else: 29 | # 'Scene01/15-deg-left/frames/depth/Camera_0/depth_00001.png' 30 | scene, variation, _, _, camera, depth_name = depth_path.split('/') 31 | intrinsic_path = os.path.join(self.root, scene, variation, 'intrinsic.txt') 32 | intrinsics = read_vkitti_intrinsic(intrinsic_path) 33 | frame_idx = int(depth_name.split('.')[0].split('_')[-1]) 34 | camera_idx = int(camera.split('_')[-1]) 35 | K = intrinsics['Frame{}:{}'.format(frame_idx, camera_idx)]['K_cam{}'.format(camera_idx)] 36 | 37 | depth_path = os.path.join(self.root, depth_path) 38 | depth, disp = read_vkitti_png_depth(depth_path, K) 
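        # read_vkitti_png_depth decodes the uint16 PNG (centimeters) to meters
        # and converts it to disparity with disp = b * f / depth, where
        # b = 0.532725 m is the VKITTI stereo baseline and f = K[0, 0].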
39 | 40 | depth = torch.from_numpy(depth.astype(np.float32)).unsqueeze(0) 41 | disp = torch.from_numpy(disp.astype(np.float32)).unsqueeze(0) 42 | 43 | return depth, disp 44 | 45 | def flowLoader(self, flow_path): 46 | flow_path = os.path.join(self.root, flow_path) 47 | flow = read_vkitti_png_flow(flow_path) 48 | 49 | flow = torch.from_numpy(flow.astype(np.float32)).permute(2, 0, 1).contiguous() 50 | 51 | return flow 52 | 53 | def extrinsicLoader(self, extrinsic_path): 54 | # the transformation matrix is from world original point to current frame 55 | extrinsic_path = os.path.join(self.root, extrinsic_path) 56 | extrinsics = read_vkitti_extrinsic(extrinsic_path) 57 | 58 | return extrinsics 59 | 60 | def getExtrinsic(self, extrinsics, image_path): 61 | # 'Scene01/15-deg-left/frames/rgb/Camera_0/rgb_00009.jpg' 62 | img_id = int(image_path.split('/')[-1].split('.')[0].split('_')[-1]) 63 | # the transformation matrix is from world original point to current frame 64 | # 4x4 65 | left_T = torch.from_numpy(extrinsics['Frame{}:0'.format(img_id)]['T_cam0']) 66 | left_inv_T = torch.from_numpy(extrinsics['Frame{}:0'.format(img_id)]['inv_T_cam0']) 67 | right_T = torch.from_numpy(extrinsics['Frame{}:1'.format(img_id)]['T_cam1']) 68 | right_inv_T = torch.from_numpy(extrinsics['Frame{}:1'.format(img_id)]['inv_T_cam1']) 69 | 70 | return left_T, left_inv_T, right_T, right_inv_T 71 | 72 | # todo: vkiit2 intrinsic loader 73 | def intrinsicLoader(self, intrinsic_path): 74 | return self.K.copy(), self.full_resolution 75 | 76 | 77 | -------------------------------------------------------------------------------- /architecture/modeling/aggregation/utils/correlation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | try: 5 | from spatial_correlation_sampler import SpatialCorrelationSampler 6 | except: 7 | pass 8 | 9 | 10 | def correlation(reference_fm, target_fm, patch_size=1, 11 | kernel_size=1, stride=1, padding=0, dilation=1, dilation_patch=1): 12 | # for a pixel of left image at (x, y), it will calculates correlation cost volume 13 | # with pixel of right image at (xr, y), where xr in [x-(patch_size-1)//2, x+(patch_size-1)//2] 14 | correlation_sampler = SpatialCorrelationSampler(kernel_size= kernel_size, 15 | patch_size=patch_size, 16 | stride=stride, 17 | padding=padding, 18 | dilation=dilation, 19 | dilation_patch=dilation_patch) 20 | # [B, pH, pW, H, W] 21 | out = correlation_sampler(reference_fm, target_fm) 22 | 23 | B, pH, pW, H, W = out.shape 24 | 25 | out = out.reshape(B, pH * pW, H, W) 26 | 27 | cost = F.leaky_relu(out, negative_slope=0.1, inplace=True) 28 | 29 | return cost 30 | 31 | 32 | def correlation1d(reference_fm, target_fm, max_disp=1, 33 | kernel_size=1, stride=1, padding=0, dilation=1, dilation_patch=1): 34 | 35 | patch_size = (1, 2 * max_disp - 1) 36 | # for a pixel of left image at (x, y), it will calculates correlation cost volume 37 | # with pixel of right image at (xr, y), where xr in [x-max_disp, x+max_disp] 38 | # but we only need the left half part, i.e., [x-max_disp, 0] 39 | correlation_sampler = SpatialCorrelationSampler(kernel_size= kernel_size, 40 | patch_size=patch_size, 41 | stride=stride, 42 | padding=padding, 43 | dilation=dilation, 44 | dilation_patch=dilation_patch) 45 | # [B, 1, 2*max_disp-1, H, W] 46 | out = correlation_sampler(reference_fm, target_fm) 47 | 48 | B, pH, pW, H, W = out.shape 49 | 50 | out = out.reshape(B, pH * pW, H, W) 51 | 52 | out = out[:, 
:max_disp, :, :] 53 | 54 | # [B, max_disp, H, W] 55 | cost = F.leaky_relu(out, negative_slope=0.1, inplace=True) 56 | 57 | return cost 58 | 59 | 60 | if __name__ == '__main__': 61 | print("Test Correlation...") 62 | import time 63 | 64 | iters = 50 65 | scale = 8 66 | B, C, H, W = 1, 40, 384, 1248 67 | device = 'cuda:0' 68 | 69 | prev = torch.randn(B, C, H//scale, W//scale, device=device) 70 | curr = torch.randn(B, C, H//scale, W//scale, device=device) 71 | 72 | cost = correlation(prev, curr, patch_size=9, kernel_size=1) 73 | print('cost with shape: ', cost.shape) 74 | 75 | start_time = time.time() 76 | 77 | for i in range(iters): 78 | with torch.no_grad(): 79 | correlation(prev, curr, patch_size=21, kernel_size=1) 80 | 81 | torch.cuda.synchronize(device) 82 | end_time = time.time() 83 | avg_time = (end_time - start_time) / iters 84 | 85 | 86 | print('{} reference forward once takes {:.4f}ms, i.e. {:.2f}fps'.format('Correlation', avg_time * 1000, (1 / avg_time))) 87 | 88 | print("Done!") 89 | 90 | """ 91 | Correlation2d: at scale=16, patch_size=21, reference forward once takes 0.6607ms, i.e. 1513.52fps 92 | """ 93 | -------------------------------------------------------------------------------- /architecture/data/datasets/scene_flow/scene_flow.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | import torch 5 | 6 | from architecture.data.utils import read_sceneflow_pfm_disparity, read_sceneflow_pfm_flow, read_sceneflow_extrinsic 7 | 8 | from .base import SceneFlowStereoDatasetBase 9 | 10 | class SceneFlowStereoDataset(SceneFlowStereoDatasetBase): 11 | def __init__(self, annFile, root, height, width, frame_idxs, 12 | is_train=False, use_common_intrinsics=False, do_same_lr_transform=True, 13 | mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): 14 | super(SceneFlowStereoDataset, self).__init__(annFile, root, height, width, frame_idxs, 15 | is_train, use_common_intrinsics, do_same_lr_transform, 16 | mean, std) 17 | 18 | def Loader(self, image_path): 19 | image_path = os.path.join(self.root, image_path) 20 | try: 21 | img = self.img_loader(image_path) 22 | return img 23 | except: 24 | print("Image Reading Error: {}".format(image_path)) 25 | exit(-1) 26 | 27 | def dispLoader(self, disp_path, K=None): 28 | disp_path = os.path.join(self.root, disp_path) 29 | depth, disp = read_sceneflow_pfm_disparity(disp_path, K) 30 | 31 | depth = torch.from_numpy(depth.copy().astype(np.float32)).unsqueeze(0) 32 | disp = torch.from_numpy(disp.copy().astype(np.float32)).unsqueeze(0) 33 | 34 | return depth, disp 35 | 36 | def flowLoader(self, flow_path): 37 | flow_path = os.path.join(self.root, flow_path) 38 | flow = read_sceneflow_pfm_flow(flow_path) 39 | 40 | flow = torch.from_numpy(flow.copy().astype(np.float32)).unsqueeze(0) 41 | 42 | return flow 43 | 44 | def intrinsicLoader(self, intrinsic_path): 45 | if '15mm' in intrinsic_path: 46 | norm_K = self.K15.copy() 47 | else: 48 | norm_K = self.K.copy() 49 | 50 | full_K = np.eye(4) 51 | resolution = self.full_resolution 52 | h, w = resolution 53 | full_K[0, :] = norm_K[0, :] * w 54 | full_K[1, :] = norm_K[1, :] * h 55 | 56 | return norm_K, full_K, resolution 57 | 58 | def extrinsicLoader(self, extrinsic_path): 59 | # the transformation matrix is from world original point to current frame 60 | extrinsic_path = os.path.join(self.root, extrinsic_path) 61 | extrinsics = read_sceneflow_extrinsic(extrinsic_path) 62 | return extrinsics 63 | 64 | def getExtrinsic(self, extrinsics, 
image_path): 65 | # 'SceneFlow/FlyingThings3D/frames_cleanpass/TRAIN/A/0000/left/0006.png' 66 | img_id = int(image_path.split('/')[-1].split('.')[0]) 67 | # the transformation matrix is from world original point to current frame 68 | # 4x4 69 | try: 70 | left_T = torch.from_numpy(extrinsics['Frame{}:0'.format(img_id)]['T_cam0']) 71 | left_inv_T = torch.from_numpy(extrinsics['Frame{}:0'.format(img_id)]['inv_T_cam0']) 72 | right_T = torch.from_numpy(extrinsics['Frame{}:1'.format(img_id)]['T_cam1']) 73 | right_inv_T = torch.from_numpy(extrinsics['Frame{}:1'.format(img_id)]['inv_T_cam1']) 74 | except: 75 | # print("There is no extrinsic parameter for image: {}, set to identical matrix!".format(image_path)) 76 | left_T = torch.eye(4, 4) 77 | left_inv_T = torch.eye(4, 4) 78 | right_T = torch.eye(4, 4) 79 | right_inv_T = torch.eye(4, 4) 80 | 81 | return left_T, left_inv_T, right_T, right_inv_T 82 | -------------------------------------------------------------------------------- /architecture/modeling/aggregation/utils/dif_fms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from architecture.modeling.layers import inverse_warp_3d 4 | 5 | def dif_fms(reference_fm, target_fm, disp_sample): 6 | """ 7 | perform substraction(i.e., compare difference) between left and rith image feature to construct 4D cost volume 8 | 9 | Args: 10 | reference_fm: (Tensor), the feature map of reference image, often left image 11 | [BatchSize, C, H, W] 12 | target_fm: (Tensor), the feature map of target image, often right image 13 | [BatchSize, C, H, W] 14 | disp_sample: (Tensor), the disparity samples/candidates for feature concatenation or matching 15 | [BatchSize, NumSamples, H, W] 16 | 17 | Returns: 18 | dif_fm: (Tensor), the substracted feature map 19 | [BatchSize, C, NumSamples, H, W] 20 | """ 21 | B, C, H, W = reference_fm.shape 22 | 23 | # the number of disparity samples 24 | D = disp_sample.shape[1] 25 | 26 | # expand D dimension 27 | dif_reference_fm = reference_fm.unsqueeze(2).expand(B, C, D, H, W) 28 | dif_target_fm = target_fm.unsqueeze(2).expand(B, C, D, H, W) 29 | 30 | # shift target feature according to disparity samples 31 | dif_target_fm = inverse_warp_3d(dif_target_fm, -disp_sample, padding_mode='zeros') 32 | 33 | # [B, C, D, H, W) 34 | dif_fm = torch.abs(dif_reference_fm - dif_target_fm) 35 | 36 | # fill the outliers with max cost 37 | max_dif = dif_fm.max() 38 | ones = torch.ones_like(dif_fm) 39 | no_empty_mask = (dif_target_fm > 0).float() 40 | # [B, C, D, H, W) 41 | dif_fm = (dif_fm * no_empty_mask) + (1 - no_empty_mask) * ones * max_dif 42 | 43 | 44 | return dif_fm 45 | 46 | 47 | if __name__ == '__main__': 48 | """ 49 | GPU: GTX3090, CUDA:11.0, Torch:1.7.1 50 | DIF_FMS reference forward once takes 8.3691ms, i.e. 119.49fps 51 | """ 52 | print("Feature Substraction Test...") 53 | from architecture.utils import timeTestTemplate 54 | 55 | # -------------------------------------- Value Test-------------------------------------- # 56 | H, W = 3, 4 57 | device = torch.device('cuda:0') 58 | left = torch.linspace(1, H * W, H * W).reshape(1, 1, H, W).to(device) 59 | right = torch.linspace(H * W + 1, H * W * 2, H * W).reshape(1, 1, H, W).to(device) 60 | print('left: \n ', left) 61 | print('right: \n ', right) 62 | 63 | disp_samples = torch.linspace(-2, 2, 5).repeat(1, H, W, 1). 
\ 64 | permute(0, 3, 1, 2).contiguous().to(device) 65 | 66 | print("Disparity Samples/Candidates: \n", disp_samples) 67 | 68 | cost = dif_fms(left, right, disp_samples) 69 | print('Cost in shape: ', cost.shape) 70 | idx = 0 71 | for i in range(-2, 3, 1): 72 | print('Disparity {}:\n {}'.format(i, cost[:, :, idx, ])) 73 | idx += 1 74 | 75 | for i in range(cost.shape[1]): 76 | print('Channel {}:\n {}'.format(i, cost[:, i, ])) 77 | 78 | # -------------------------------------- Time Test-------------------------------------- # 79 | 80 | C, H, W = 32, 384 // 4, 1248 // 4 # size in KITTI 81 | device = torch.device('cuda:0') 82 | left = torch.rand(1, C, H, W).to(device) 83 | right = torch.rand(1, C, H, W).to(device) 84 | 85 | max_disp = 192 // 4 86 | disp_samples = torch.linspace(0, max_disp - 1, max_disp).repeat(1, H, W, 1). \ 87 | permute(0, 3, 1, 2).contiguous().to(device) 88 | 89 | avg_time = timeTestTemplate(dif_fms, left, right, disp_samples) 90 | 91 | print('{} reference forward once takes {:.4f}ms, i.e. {:.2f}fps'.format('DIF_FMS', avg_time * 1000, (1 / avg_time))) 92 | 93 | 94 | -------------------------------------------------------------------------------- /architecture/data/datasets/kitti/kittiraw.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | import torch 5 | 6 | from architecture.data.utils.calibration.kitti_calib import load_calib, read_calib_file 7 | from architecture.data.utils import read_kitti_intrinsic, read_kitti_extrinsic, read_kitti_png_disparity 8 | 9 | from .base import KITTIStereoDatasetBase 10 | 11 | class KITTIRAWStereoDataset(KITTIStereoDatasetBase): 12 | def __init__(self, annFile, root, height, width, frame_idxs, 13 | is_train=False, use_common_intrinsics=False, do_same_lr_transform=True, 14 | mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): 15 | super(KITTIRAWStereoDataset, self).__init__(annFile, root, height, width, frame_idxs, 16 | is_train, use_common_intrinsics, do_same_lr_transform, 17 | mean, std) 18 | 19 | def Loader(self, image_path): 20 | image_path = os.path.join(self.root, image_path) 21 | img = self.img_loader(image_path) 22 | return img 23 | 24 | def dispLoader(self, disp_path, K=None): 25 | disp_path = os.path.join(self.root, disp_path) 26 | depth, disp = read_kitti_png_disparity(disp_path, K) 27 | 28 | depth = torch.from_numpy(depth.astype(np.float32)).unsqueeze(0) 29 | disp = torch.from_numpy(disp.astype(np.float32)).unsqueeze(0) 30 | 31 | return depth, disp 32 | 33 | def extrinsicLoader(self, extrinsic_path): 34 | lineid = 0 35 | data = {} 36 | extrinsic_path = os.path.join(self.root, extrinsic_path) 37 | with open(extrinsic_path, 'r') as fp: 38 | for line in fp.readlines(): 39 | values = line.rstrip().split(' ') 40 | frame = '{:04d}'.format(lineid) 41 | camera = '02' 42 | key = 'T_cam{}'.format(camera) 43 | inv_key = 'inv_T_cam{}'.format(camera) 44 | matrix = np.array([float(values[i]) for i in range(len(values))]) 45 | matrix = matrix.reshape(3, 4) 46 | matrix = np.concatenate((matrix, np.array([[0, 0, 0, 1]])), axis=0) 47 | item = { 48 | # inverse pose 49 | key: np.linalg.pinv(matrix), 50 | inv_key: matrix, 51 | } 52 | data['Frame{}:{}'.format(frame, camera)] = item 53 | lineid += 1 54 | 55 | return data 56 | 57 | def getExtrinsic(self, extrinsics, image_path): 58 | # rawdata/2011_09_26/2011_09_26_drive_0095_sync/image_03/data/0000000001.png 59 | img_id = int(image_path.split('/')[-1].split('.')[0]) 60 | # the transformation matrix is from world original 
point to current frame 61 | # 4x4 62 | left_T = torch.from_numpy(extrinsics['Frame{:04d}:02'.format(img_id)]['T_cam02']) 63 | left_inv_T = torch.from_numpy(extrinsics['Frame{:04d}:02'.format(img_id)]['inv_T_cam02']) 64 | pose = extrinsics.get('Frame{:04d}:03', None) 65 | if pose is not None: 66 | right_T = pose['T_cam03'] 67 | right_inv_T = pose['inv_T_cam03'] 68 | else: 69 | right_T = torch.eye(4) 70 | right_inv_T = torch.eye(4) 71 | 72 | return left_T, left_inv_T, right_T, right_inv_T 73 | 74 | def intrinsicLoader(self, intrinsic_path): 75 | intrinsic_path = os.path.join(self.root, intrinsic_path) 76 | data = read_calib_file(intrinsic_path) 77 | K = data['P_rect_02'].reshape(3, 4)[:3, :3] 78 | resolution = data['S_rect_02'][[1,0]] 79 | full_K = K.copy() 80 | h, w = resolution 81 | norm_K = np.eye(4) 82 | norm_K[0, :3] = K[0, :] / w 83 | norm_K[1, :3] = K[1, :] / h 84 | return norm_K.copy(), full_K, resolution 85 | 86 | 87 | -------------------------------------------------------------------------------- /architecture/data/utils/load_kitti.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | 5 | def read_kitti_intrinsic(intrinsic_fn): 6 | data = {} 7 | resolution = None 8 | with open(intrinsic_fn, 'r') as fp: 9 | for line in fp.readlines(): 10 | if line.find('P_rect_02') > -1: 11 | line = line[11:] 12 | values = line.rstrip().split(' ') 13 | camera = '02' 14 | key = 'K_cam{}'.format(camera) 15 | inv_key = 'inv_K_cam{}'.format(int(camera)) 16 | matrix = np.array([float(values[i]) for i in range(len(values))]) 17 | K = np.eye(4) 18 | K[0, 0] = matrix[0] 19 | K[0, 2] = matrix[2] 20 | K[1, 1] = matrix[5] 21 | K[1, 2] = matrix[6] 22 | item = { 23 | key: K, 24 | inv_key: np.linalg.pinv(K) 25 | } 26 | data['{}'.format(camera)] = item 27 | # S_rect_02: 1.242000e+03 3.750000e+02 28 | if line.find('S_rect_02') > -1: 29 | line = line[11:] 30 | values = line.rstrip().split(' ') 31 | w, h = float(values[0]), float(values[1]) 32 | resolution = (h, w) 33 | 34 | return data, resolution 35 | 36 | 37 | def read_kitti_extrinsic(extrinsic_fn): 38 | """ 39 | We assume the extrinsic is obtained by ORBSLAM3, 40 | The pose of it is from camera to world, but we need world to camera, so we have to inverse it 41 | """ 42 | lineid = 0 43 | data = {} 44 | with open(extrinsic_fn, 'r') as fp: 45 | for line in fp.readlines(): 46 | values = line.rstrip().split(' ') 47 | frame = '{:02d}'.format(lineid) 48 | camera = '02' 49 | key = 'T_cam{}'.format(camera) 50 | inv_key = 'inv_T_cam{}'.format(camera) 51 | matrix = np.array([float(values[i]) for i in range(len(values))]) 52 | matrix = matrix.reshape(3, 4) 53 | matrix = np.concatenate((matrix, np.array([[0, 0, 0, 1]])), axis=0) 54 | item = { 55 | # inverse pose 56 | key: np.linalg.pinv(matrix), 57 | inv_key: matrix, 58 | } 59 | data['Frame{}:{}'.format(frame, camera)] = item 60 | lineid += 1 61 | 62 | return data 63 | 64 | 65 | def read_kitti_png_disparity(disp_fn, K=np.array([[721.5377, 0, 609.5593, 0], 66 | [0, 721.5377, 172.854, 0], 67 | [0, 0, 1, 0], 68 | [0, 0, 0, 1]])): 69 | 70 | disp = cv2.imread(disp_fn, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_ANYCOLOR) 71 | # uint16 72 | valid_mask = disp > 0 73 | disp = disp / 256.0 74 | 75 | f = K[0, 0] 76 | b = 0.54 # meter 77 | 78 | depth = b * f / (disp + 1e-12) 79 | depth = depth * valid_mask 80 | 81 | return depth, disp 82 | 83 | def read_kitti_png_depth(depth_fn, K=np.array([[721.5377, 0, 609.5593, 0], 84 | [0, 721.5377, 172.854, 0], 85 | [0, 0, 1, 0], 86 | 
[0, 0, 0, 1]])): 87 | 88 | depth = cv2.imread(depth_fn, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_ANYCOLOR) 89 | # uint16 90 | valid_mask = depth > 0 91 | depth = depth / 256.0 92 | 93 | f = K[0, 0] 94 | b = 0.54 # meter 95 | 96 | disp = b * f / (depth + 1e-12) 97 | disp = disp * valid_mask 98 | 99 | return depth, disp 100 | 101 | def read_kitti_png_flow(flow_fn): 102 | raise NotImplementedError 103 | -------------------------------------------------------------------------------- /architecture/data/evaluation/eval.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | 5 | from architecture.modeling.layers import inverse_warp 6 | from .pixel_error import calc_error 7 | 8 | 9 | def do_evaluation(est_disp, gt_disp, lb, ub): 10 | """ 11 | Do pixel error evaluation. (See KITTI evaluation protocols for details.) 12 | Args: 13 | est_disp: (Tensor), estimated disparity map, 14 | [..., Height, Width] layout 15 | gt_disp: (Tensor), ground truth disparity map 16 | [..., Height, Width] layout 17 | lb: (scalar), the lower bound of disparity you want to mask out 18 | ub: (scalar), the upper bound of disparity you want to mask out 19 | Returns: 20 | error_dict: (dict), the error of 1px, 2px, 3px, 5px, in percent, 21 | range [0,100] and average error epe 22 | """ 23 | error_dict = {} 24 | if est_disp is None: 25 | warnings.warn('Estimated disparity map is None') 26 | return error_dict 27 | if gt_disp is None: 28 | warnings.warn('Reference ground truth disparity map is None') 29 | return error_dict 30 | 31 | if torch.is_tensor(est_disp): 32 | est_disp = est_disp.clone().cpu() 33 | 34 | if torch.is_tensor(gt_disp): 35 | gt_disp = gt_disp.clone().cpu() 36 | 37 | assert est_disp.shape == gt_disp.shape, "Estimated Disparity map with shape: {}, but GroundTruth Disparity map" \ 38 | " with shape: {}".format(est_disp.shape, gt_disp.shape) 39 | 40 | error_dict = calc_error(est_disp, gt_disp, lb=lb, ub=ub) 41 | 42 | return error_dict 43 | 44 | 45 | def do_occlusion_evaluation(est_disp, ref_gt_disp, target_gt_disp, lb, ub): 46 | """ 47 | Do occlusoin evaluation. 
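    A pixel is treated as occluded when the right-view ground truth disparity,
    warped into the left view, either disagrees with the left-view ground truth
    by more than 1 pixel or warps to an (almost) zero value; errors are then
    reported separately for occluded ('occ_') and non-occluded ('noc_') regions.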
48 | Args: 49 | est_disp: (Tensor), estimated disparity map 50 | [BatchSize, 1, Height, Width] layout 51 | ref_gt_disp: (Tensor), reference(left) ground truth disparity map 52 | [BatchSize, 1, Height, Width] layout 53 | target_gt_disp: (Tensor), target(right) ground truth disparity map, 54 | [BatchSize, 1, Height, Width] layout 55 | lb: (scalar): the lower bound of disparity you want to mask out 56 | ub: (scalar): the upper bound of disparity you want to mask out 57 | Returns: 58 | """ 59 | error_dict = {} 60 | if est_disp is None: 61 | warnings.warn('Estimated disparity map is None, expected given') 62 | return error_dict 63 | if ref_gt_disp is None: 64 | warnings.warn('Reference ground truth disparity map is None, expected given') 65 | return error_dict 66 | if target_gt_disp is None: 67 | warnings.warn('Target ground truth disparity map is None, expected given') 68 | return error_dict 69 | 70 | if torch.is_tensor(est_disp): 71 | est_disp = est_disp.clone().cpu() 72 | if torch.is_tensor(ref_gt_disp): 73 | ref_gt_disp = ref_gt_disp.clone().cpu() 74 | if torch.is_tensor(target_gt_disp): 75 | target_gt_disp = target_gt_disp.clone().cpu() 76 | 77 | assert est_disp.shape == ref_gt_disp.shape and target_gt_disp.shape == ref_gt_disp.shape, "{}, {}, {}".format( 78 | est_disp.shape, ref_gt_disp.shape, target_gt_disp.shape) 79 | 80 | warp_ref_gt_disp = inverse_warp(target_gt_disp.clone(), -ref_gt_disp.clone(), mode='disparity') 81 | theta = 1.0 82 | eps = 1e-6 83 | occlusion = ( 84 | (torch.abs(warp_ref_gt_disp.clone() - ref_gt_disp.clone()) > theta) | 85 | (torch.abs(warp_ref_gt_disp.clone()) < eps) 86 | ).prod(dim=1, keepdim=True).type_as(ref_gt_disp) 87 | occlusion = occlusion.clamp(0, 1) 88 | 89 | occlusion_error_dict = calc_error( 90 | est_disp.clone() * occlusion.clone(), 91 | ref_gt_disp.clone() * occlusion.clone(), 92 | lb=lb, ub=ub 93 | ) 94 | for key in occlusion_error_dict.keys(): 95 | error_dict['occ_' + key] = occlusion_error_dict[key] 96 | 97 | not_occlusion = 1.0 - occlusion 98 | not_occlusion_error_dict = calc_error( 99 | est_disp.clone() * not_occlusion.clone(), 100 | ref_gt_disp.clone() * not_occlusion.clone(), 101 | lb=lb, ub=ub 102 | ) 103 | for key in not_occlusion_error_dict.keys(): 104 | error_dict['noc_' + key] = not_occlusion_error_dict[key] 105 | 106 | return error_dict -------------------------------------------------------------------------------- /architecture/modeling/losses/smooth_l1_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from detectron2.config import configurable 5 | from typing import Optional, Dict, Tuple, Union, List 6 | import warnings 7 | 8 | 9 | class DispSmoothL1Loss(object): 10 | """ 11 | Args: 12 | max_disp: (int), the max of Disparity. default is 192 13 | start_disp: (int), the start searching disparity index, usually be 0 14 | weights: (list, tuple, float, optional): weight for each scale of est disparity map. 15 | sparse: (bool), whether the ground-truth disparity is sparse, 16 | for example, KITTI is sparse, but SceneFlow is not, default is False. 17 | Inputs: 18 | estDisp: (Tensor or List[Tensor]): the estimated disparity maps, 19 | [BatchSize, 1, Height, Width] layout. 20 | gtDisp: (Tensor), the ground truth disparity map, 21 | [BatchSize, 1, Height, Width] layout. 
22 | Outputs: 23 | loss: (dict), the loss of each level 24 | """ 25 | @configurable 26 | def __init__(self, max_disp:int, start_disp:int=0, global_weight:float=1.0, 27 | weights:Union[Tuple, List, float, None]=None, sparse:bool=False): 28 | self.max_disp = max_disp 29 | self.start_disp = start_disp 30 | self.global_weight = global_weight 31 | self.weights = weights 32 | self.sparse = sparse 33 | if sparse: 34 | # sparse disparity ==> max_pooling 35 | self.scale_func = F.adaptive_max_pool2d 36 | else: 37 | # dense disparity ==> avg_pooling 38 | self.scale_func = F.adaptive_avg_pool2d 39 | 40 | @classmethod 41 | def from_config(cls, cfg): 42 | return { 43 | "max_disp": cfg.get("MAX_DISP", 192), 44 | "start_disp": cfg.get("START_DISP", 0), 45 | "weights": cfg.get("WEIGHTS", None), 46 | "sparse": cfg.get("SPARSE", False), 47 | } 48 | 49 | def loss_per_level(self, estDisp, gtDisp): 50 | N, C, H, W = estDisp.shape 51 | scaled_gtDisp = gtDisp 52 | scale = 1.0 53 | if gtDisp.shape[-2] != H or gtDisp.shape[-1] != W: 54 | # compute scale per level and scale gtDisp 55 | scale = gtDisp.shape[-1] / (W * 1.0) 56 | scaled_gtDisp = gtDisp / scale 57 | scaled_gtDisp = self.scale_func(scaled_gtDisp, (H, W)) 58 | 59 | # mask for valid disparity 60 | # (start disparity, max disparity / scale) 61 | # Attention: the invalid disparity of KITTI is set as 0, be sure to mask it out 62 | mask = (scaled_gtDisp > self.start_disp) & (scaled_gtDisp < (self.max_disp / scale)) 63 | if mask.sum() < 1.0: 64 | warnings.warn('SmoothL1 loss: there is no point\'s disparity is in ({},{})!'.format(self.start_disp, 65 | self.max_disp / scale)) 66 | loss = (torch.abs(estDisp - scaled_gtDisp) * mask.float()).mean() 67 | return loss 68 | 69 | # smooth l1 loss 70 | loss = F.smooth_l1_loss(estDisp[mask], scaled_gtDisp[mask], reduction='mean') 71 | 72 | return loss 73 | 74 | def __call__(self, estDisp, gtDisp): 75 | if not isinstance(estDisp, (list, tuple)): 76 | estDisp = [estDisp] 77 | 78 | if self.weights is None: 79 | self.weights = [1.0] * len(estDisp) 80 | 81 | # compute loss for per level 82 | loss_all_level = [] 83 | for est_disp_per_lvl in estDisp: 84 | loss_all_level.append( 85 | self.loss_per_level(est_disp_per_lvl, gtDisp) 86 | ) 87 | 88 | # re-weight loss per level 89 | weighted_loss_all_level = dict() 90 | for i, loss_per_level in enumerate(loss_all_level): 91 | name = "l1_loss_lvl{}".format(i) 92 | weighted_loss_all_level[name] = self.weights[i] * loss_per_level * self.global_weight 93 | 94 | return weighted_loss_all_level 95 | 96 | def __repr__(self): 97 | repr_str = '{}\n'.format(self.__class__.__name__) 98 | repr_str += ' ' * 4 + 'Max Disparity: {}\n'.format(self.max_disp) 99 | repr_str += ' ' * 4 + 'Start disparity: {}\n'.format(self.start_disp) 100 | repr_str += ' ' * 4 + 'Global Loss weight: {}\n'.format(self.global_weight) 101 | repr_str += ' ' * 4 + 'Loss weights: {}\n'.format(self.weights) 102 | repr_str += ' ' * 4 + 'GT Disparity is sparse: {}\n'.format(self.sparse) 103 | 104 | return repr_str 105 | 106 | @property 107 | def name(self): 108 | return 'SmoothL1Loss' -------------------------------------------------------------------------------- /architecture/modeling/aggregation/utils/block_cost.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from architecture.modeling.layers import inverse_warp_3d 5 | 6 | def groupwise_correlation(fea1, fea2): 7 | B, C, D, H, W = fea1.shape 8 | channels_per_group = 8 9 | assert C % 
channels_per_group == 0 10 | num_groups = C // channels_per_group 11 | cost = -torch.pow((fea1 - fea2), 2.0).view([B, num_groups, channels_per_group, D, H, W]).sum(dim=2) 12 | assert cost.shape == (B, num_groups, D, H, W) 13 | return cost 14 | 15 | 16 | def block_cost(reference_fm, target_fm, disp_sample, block_cost_scale=3): 17 | """ 18 | perform concatenation and groupwise correlation between left and rith image feature to construct 4D cost volume 19 | 20 | Args: 21 | reference_fm: (Tensor), the feature map of reference image, often left image 22 | [BatchSize, C, H, W] 23 | target_fm: (Tensor), the feature map of target image, often right image 24 | [BatchSize, C, H, W] 25 | disp_sample: (Tensor), the disparity samples/candidates for feature concatenation or matching 26 | [BatchSize, NumSamples, H, W] 27 | 28 | Returns: 29 | concat_fm: (Tensor), the concatenated feature map 30 | [BatchSize, 2C, NumSamples, H, W] 31 | """ 32 | B, C, H, W = reference_fm.shape 33 | 34 | if isinstance(disp_sample, int): 35 | max_disp = disp_sample 36 | # [b, c, h, max_disp-1+w] 37 | padded_target_fm = F.pad(target_fm, pad=(max_disp-1, 0, 0, 0), mode='constant', value=0.0) 38 | unfolded_target_fm = F.unfold(padded_target_fm, kernel_size=(1, max_disp), dilation=(1, 1), padding=(0, 0), stride=(1, 1)) 39 | unfolded_target_fm = unfolded_target_fm.reshape(B, C, max_disp, H, W) 40 | # [max_disp-1, ..., 2, 1, 0] -> [0, 1, 2, ..., max_disp-1] 41 | target_fm = torch.flip(unfolded_target_fm, dims=[2, ]) 42 | 43 | reference_fm = reference_fm.reshape(B, C, 1, H, W).repeat(1, 1, max_disp, 1, 1) 44 | 45 | cost = -(reference_fm - target_fm) ** 2 46 | 47 | else: 48 | # the number of disparity samples 49 | D = disp_sample.shape[1] 50 | 51 | # expand D dimension 52 | reference_fm = reference_fm.unsqueeze(2).expand(B, C, D, H, W) 53 | target_fm = target_fm.unsqueeze(2).expand(B, C, D, H, W) 54 | # 55 | # # shift target feature according to disparity samples 56 | target_fm = inverse_warp_3d(target_fm, -disp_sample, padding_mode='zeros') 57 | 58 | cost = torch.cat([reference_fm, target_fm], dim=1) 59 | 60 | 61 | # [B, C, D, H, W) 62 | B, C, D, H, W = reference_fm.shape 63 | 64 | costs = [cost, ] 65 | block_cost_scale = int(block_cost_scale) 66 | for s in range(block_cost_scale): 67 | sD, sH, sW = 1, min(2**s, H), min(2**s, W) 68 | local_reference_fm = F.avg_pool3d(reference_fm, kernel_size=(sD, sH, sW), stride=(sD, sH, sW)) 69 | local_target_fm = F.avg_pool3d(target_fm, kernel_size=(sD, sH, sW), stride=(sD, sH, sW)) 70 | 71 | cost = groupwise_correlation(local_reference_fm, local_target_fm) 72 | 73 | # [B, C//8, D, H, W] 74 | cost = F.interpolate(cost, size=(D, H, W), mode='trilinear', align_corners=True) 75 | 76 | cost = cost.reshape(B, C//8, D, H, W).contiguous() 77 | 78 | costs.append(cost) 79 | 80 | # [B, 2C + C//8*local_scale, D, H, W] 81 | cost = torch.cat(costs, dim=1) 82 | 83 | return cost 84 | 85 | 86 | if __name__ == '__main__': 87 | """ 88 | GPU: GTX3090, CUDA:11.0, Torch:1.7.1 89 | SPARSE_CAT_FMS reference forward once takes 3.1339ms, i.e. 319.09fps at 1/4 resolution, 2.4887ms, i.e. 401.82fps with 2 scale 90 | SPARSE_CAT_FMS reference forward once takes 1.0396ms, i.e. 961.90fps at 1/8 resolution 91 | 92 | BLOCK_COST reference forward once takes 1.7147ms, i.e. 
583.18fps at 1/4 resolution, C=48, disp_samples=4 93 | """ 94 | print("Feature Concatenation Test...") 95 | from architecture.utils import timeTestTemplate 96 | 97 | # -------------------------------------- Time Test-------------------------------------- # 98 | 99 | scale = 16 100 | C, H, W = 192, 384//scale, 1248//scale # size in KITTI 101 | device = torch.device('cuda:0') 102 | left = torch.rand(1, C, H, W).to(device) 103 | right = torch.rand(1, C, H, W).to(device) 104 | 105 | disp_samples = (torch.randn(12) * W).repeat(1, H, W, 1). \ 106 | permute(0, 3, 1, 2).contiguous().to(device) 107 | # disp_samples = 12 108 | 109 | avg_time = timeTestTemplate(block_cost, left, right, disp_samples, iters=1000, device=torch.device('cuda:0')) 110 | 111 | print('{} reference forward once takes {:.4f}ms, i.e. {:.2f}fps'.format('BLOCK_COST', avg_time * 1000, (1 / avg_time))) 112 | 113 | 114 | -------------------------------------------------------------------------------- /architecture/modeling/aggregation/TemporalStereo/precise.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from typing import Any, Dict, List, Optional, Tuple, Union 6 | 7 | from architecture.modeling.aggregation.utils import block_cost 8 | 9 | from .module import PredictionHeads, ResidualBlock3D, DepthwiseConv3D, UNet 10 | 11 | class PreciseAggregation(nn.Module): 12 | def __init__(self, 13 | in_planes: int, 14 | C: int, 15 | num_sample: int, 16 | delta: float = 1, 17 | block_cost_scale: int = 3, 18 | topk: int = 2, 19 | norm: str = 'BN3d', 20 | activation: Union[str, Tuple, List] = 'SiLU'): 21 | super(PreciseAggregation, self).__init__() 22 | self.in_planes = in_planes 23 | self.C = C 24 | self.num_sample = num_sample 25 | self.delta = delta 26 | self.block_cost_scale = block_cost_scale 27 | self.topk = topk 28 | self.norm = norm 29 | self.activation = activation 30 | 31 | cost_planes = 4 * in_planes + block_cost_scale * 2 * in_planes // 8 32 | self.init3d = nn.Sequential( 33 | DepthwiseConv3D(cost_planes, C, 3, 1, 1, bias=True, norm=norm, activation=activation), 34 | ResidualBlock3D(in_planes=C, kernel_size=3, stride=2, padding=1, norm=norm, activation=activation), 35 | DepthwiseConv3D(C, C, 3, 1, padding=2, dilation=2, bias=False, norm=norm, activation=activation), 36 | ) 37 | 38 | self.pred_heads = PredictionHeads(in_planes=C, delta=delta, norm=norm, activation=activation) 39 | 40 | self.refinement = UNet(in_planes=3, out_planes=in_planes) 41 | 42 | self.weight_init() 43 | 44 | def weight_init(self): 45 | for m in self.modules(): 46 | if isinstance(m, nn.Conv2d): 47 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 48 | m.weight.data.normal_(0, math.sqrt(2. / n)) 49 | elif isinstance(m, nn.Conv3d): 50 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels 51 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 52 | elif isinstance(m, nn.BatchNorm2d): 53 | m.weight.data.fill_(1) 54 | m.bias.data.zero_() 55 | elif isinstance(m, nn.BatchNorm3d): 56 | m.weight.data.fill_(1) 57 | m.bias.data.zero_() 58 | elif isinstance(m, nn.Linear): 59 | m.bias.data.zero_() 60 | 61 | def predict_disp(self, cost, disp_sample, off, k=2): 62 | topk_cost, indices = torch.topk(cost, k=k, dim=1) 63 | prob = torch.softmax(topk_cost, dim=1) 64 | topk_disp = torch.gather(disp_sample+off, dim=1, index=indices) 65 | disp_map = torch.sum(prob*topk_disp, dim=1, keepdim=True) 66 | 67 | return disp_map, topk_disp, topk_cost 68 | 69 | def generate_disparity_sample(self, low_disparity, high_disparity, num_sample): 70 | batch_size, _, height, width = low_disparity.shape 71 | device = low_disparity.device 72 | # track with constant speed motion 73 | disp_sample = torch.Tensor([0, 3, 4, 5, 8]) 74 | num_sample = len(disp_sample) 75 | disp_sample = (disp_sample / disp_sample.max()).view(1, num_sample, 1, 1) 76 | disp_sample = disp_sample.expand(batch_size, num_sample, height, width).to(device) 77 | disp_sample = torch.abs(high_disparity - low_disparity) * disp_sample + torch.min(low_disparity, high_disparity) 78 | 79 | return disp_sample 80 | 81 | def forward(self, left, right, low_disparity, high_disparity, left_image, right_image, prev_info:dict): 82 | B, _, H, W = left.shape 83 | spx_left_feats, spx_right_feats = self.refinement.encoder(left_image, right_image) 84 | spx2l, spx4l = spx_left_feats 85 | spx2r, spx4r = spx_right_feats 86 | left, right = torch.cat([left, spx4l], dim=1), torch.cat([right, spx4r], dim=1) 87 | 88 | disp_sample = self.generate_disparity_sample(low_disparity, high_disparity, self.num_sample) 89 | raw_cost = block_cost(left, right, disp_sample, block_cost_scale=self.block_cost_scale) 90 | 91 | init_cost = self.init3d(raw_cost) 92 | final_cost, off = self.pred_heads(init_cost) 93 | 94 | # learn disparity 95 | disp, memory_sample, memory_volume = self.predict_disp(final_cost, disp_sample, off, k=self.topk) 96 | full_disp = self.refinement.decoder(disp, left, spx2l) 97 | 98 | prev_info['prev_disp'] = full_disp.detach() 99 | # save memory 100 | prev_info['cost_memory'] = { 101 | 'disp_sample': F.interpolate(memory_sample/2, scale_factor=1/2, mode='bilinear', align_corners=True), 102 | 'cost_volume': F.interpolate(memory_volume, scale_factor=1/2, mode='bilinear', align_corners=True), 103 | } 104 | 105 | return full_disp, disp, final_cost, off, disp_sample, prev_info 106 | -------------------------------------------------------------------------------- /projects/TemporalStereo/dist_train.py: -------------------------------------------------------------------------------- 1 | # Copyright CVLAB of University of Bologna 2021. Patent Pending. All rights reserved. 2 | # 3 | # This software is licensed under the terms of the StereoBenchmark licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 
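# Overview: this script builds the TemporalStereo model from a YAML config, backs up the source tree, trains with PyTorch Lightning (DDP), and finally evaluates the last-epoch checkpoint on the test split.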
6 | 7 | import os 8 | 9 | import pytorch_lightning as pl 10 | from pytorch_lightning import seed_everything 11 | from pytorch_lightning.callbacks import LearningRateMonitor 12 | from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging 13 | from pytorch_lightning.utilities import rank_zero_only 14 | 15 | seed_everything(43, workers=True) 16 | 17 | import torch 18 | torch.autograd.set_detect_anomaly(True) 19 | from torch.utils.data import DataLoader 20 | 21 | import sys 22 | import os.path as osp 23 | sys.path.insert(0, osp.join(osp.dirname(osp.abspath(__file__)), '../../')) 24 | 25 | from config import get_cfg, get_parser 26 | from TemporalStereo import TemporalStereo 27 | from architecture.data.datasets import build_stereo_dataset 28 | from logger import Logger 29 | 30 | import shutil 31 | 32 | @rank_zero_only 33 | def backup_code(save_dir): 34 | savedir = '{}/code/'.format(save_dir) 35 | datadirs = ['architecture', 36 | 'projects', 37 | ] 38 | if os.path.exists(savedir): 39 | shutil.rmtree(savedir) 40 | os.makedirs(savedir) 41 | root = osp.join(osp.dirname(osp.abspath(__file__)), '../../') 42 | for datadir in datadirs: 43 | shutil.copytree('{}{}'.format(root, datadir), 44 | '{}{}'.format(savedir, datadir), 45 | ignore=shutil.ignore_patterns('*.pyc', '*.npy', '*.pdf', '*.json', '*.bin', '.idea', 46 | '*.egg', '*.egg-info', 'build', 'dist', '*.so', 47 | '*.pth', '*.pkl', '*.ckpt', 48 | '__pycache__', '.DS_Store', )) 49 | 50 | if __name__ == "__main__": 51 | 52 | args = get_parser().parse_args() 53 | 54 | cfg = get_cfg(args) 55 | model = TemporalStereo(cfg.convert_to_dict()) 56 | 57 | save_path = os.path.join(cfg.LOG_DIR, cfg.TRAINER.NAME, cfg.TRAINER.VERSION) 58 | logger = Logger(cfg.LOG_DIR, cfg.TRAINER.NAME, cfg.TRAINER.VERSION) 59 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 60 | filename='{epoch:03d}' if cfg.CHECKPOINT.EVERY_N_EPOCHS > 0 else '{epoch}-{step}', 61 | dirpath=save_path, 62 | monitor=None, 63 | save_last=True, 64 | save_top_k=-1, 65 | every_n_train_steps=cfg.CHECKPOINT.EVERY_N_TRAIN_STEPS, 66 | every_n_epochs=cfg.CHECKPOINT.EVERY_N_EPOCHS,) 67 | lr_monitor = LearningRateMonitor(logging_interval='step') 68 | swa = StochasticWeightAveraging(swa_epoch_start=0.8) 69 | 70 | if os.path.isfile(cfg.TRAINER.LOAD_FROM_CHECKPOINT): 71 | checkpoint = torch.load(cfg.TRAINER.LOAD_FROM_CHECKPOINT) 72 | model.load_state_dict(checkpoint['state_dict'], strict=False) 73 | logger.filewriter.stdout("Load checkpoint from: {}!".format(cfg.TRAINER.LOAD_FROM_CHECKPOINT)) 74 | elif len(cfg.TRAINER.LOAD_FROM_CHECKPOINT) > 0: 75 | logger.filewriter.stdout("Warning: Checkpoint at {} doesn't exist!".format(cfg.TRAINER.LOAD_FROM_CHECKPOINT)) 76 | 77 | 78 | # backup code 79 | backup_code(save_dir=save_path) 80 | 81 | trainer = pl.Trainer( 82 | logger=logger, 83 | strategy='ddp', 84 | benchmark=True, # speed up training, about 2x 85 | enable_checkpointing=True, 86 | callbacks=[checkpoint_callback, lr_monitor, swa], 87 | check_val_every_n_epoch=cfg.TRAINER.CHECK_VAL_EVERY_N_EPOCHS, 88 | resume_from_checkpoint=cfg.TRAINER.RESUME_FROM_CHECKPOINT if os.path.isfile(cfg.TRAINER.RESUME_FROM_CHECKPOINT) else None, 89 | gpus=cfg.TRAINER.NUM_GPUS, 90 | num_nodes=cfg.TRAINER.NUM_NODES, 91 | max_epochs=cfg.TRAINER.MAX_EPOCHS, 92 | precision=cfg.TRAINER.PRECISION, 93 | amp_backend='native', 94 | sync_batchnorm=cfg.TRAINER.SYNC_BATCHNORM, 95 | detect_anomaly=True, 96 | gradient_clip_val= cfg.TRAINER.GRADIENT_CLIP_VAL, 97 | accumulate_grad_batches=1, 98 | fast_dev_run=False, 99 | # 
limit_train_batches=0.002, limit_val_batches=0.01, limit_test_batches=0.005, 100 | ) 101 | 102 | # ----------------------------------------- Train ----------------------------------------- # 103 | 104 | trainer.fit(model) 105 | 106 | # ----------------------------------------- Test ----------------------------------------- # 107 | dataset = build_stereo_dataset(cfg.DATA.TEST, 'test') 108 | dataloader = DataLoader( 109 | dataset, cfg.DATA.TEST.BATCH_SIZE, shuffle=False, 110 | num_workers=cfg.DATA.TEST.NUM_WORKERS, pin_memory=True, drop_last=False) 111 | 112 | trainer.test(model, 113 | test_dataloaders=dataloader, 114 | ckpt_path=os.path.join(save_path, 'epoch={:03d}.ckpt'.format(cfg.TRAINER.MAX_EPOCHS-1))) 115 | 116 | print("Done!") -------------------------------------------------------------------------------- /architecture/modeling/aggregation/TemporalStereo/coarse.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from typing import Any, Dict, List, Optional, Tuple, Union 6 | 7 | from architecture.modeling.layers import Conv2d, Conv3d 8 | from architecture.modeling.aggregation.utils import block_cost 9 | 10 | from .module import (ResidualBlock3D, DepthwiseConv3D, 11 | ConvexUpsample, PyramidFusion, PredictionHeads) 12 | 13 | class CoarseAggregation(nn.Module): 14 | def __init__(self, 15 | in_planes: int, 16 | C: int, 17 | num_sample: int, 18 | delta: float = 1, 19 | block_cost_scale: int = 3, 20 | topk: int = 2, 21 | spatial_fusion: bool = True, 22 | norm: str = 'BN3d', 23 | activation: Union[str, Tuple, List] = 'SiLU'): 24 | super(CoarseAggregation, self).__init__() 25 | self.in_planes = in_planes 26 | self.C = C 27 | self.num_sample = num_sample 28 | self.delta = delta 29 | self.block_cost_scale = block_cost_scale 30 | self.norm = norm 31 | self.activation = activation 32 | self.spatial_fusion = spatial_fusion 33 | self.topk = topk 34 | 35 | cost_planes = 1 * in_planes + block_cost_scale * in_planes // 8 36 | self.init3d = nn.Sequential( 37 | DepthwiseConv3D(cost_planes, C, 3, 1, 1, bias=True, norm=norm, activation=activation), 38 | ResidualBlock3D(in_planes=C, kernel_size=3, stride=2, padding=1, norm=norm, activation=activation), 39 | DepthwiseConv3D(C, C, 3, 1, padding=2, dilation=2, bias=False, norm=norm, activation=activation), 40 | ) 41 | 42 | self.past_conv = Conv3d(1, C, 1, 1, 0, bias=False, norm=(norm, C), activation=activation) 43 | if spatial_fusion: 44 | self.fuse = PyramidFusion(in_planes=C, norm=norm, activation=activation) 45 | 46 | self.pred_heads = PredictionHeads(in_planes=C, delta=delta, norm=norm, activation=activation) 47 | 48 | self.convex_upsample = ConvexUpsample(in_planes=in_planes, upscale_factor=2, window_size=3) 49 | 50 | self.weight_init() 51 | 52 | def weight_init(self): 53 | for m in self.modules(): 54 | if isinstance(m, nn.Conv2d): 55 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 56 | m.weight.data.normal_(0, math.sqrt(2. / n)) 57 | elif isinstance(m, nn.Conv3d): 58 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels 59 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 60 | elif isinstance(m, nn.BatchNorm2d): 61 | m.weight.data.fill_(1) 62 | m.bias.data.zero_() 63 | elif isinstance(m, nn.BatchNorm3d): 64 | m.weight.data.fill_(1) 65 | m.bias.data.zero_() 66 | elif isinstance(m, nn.Linear): 67 | m.bias.data.zero_() 68 | 69 | def predict_disp(self, cost, disp_sample, off, k=2): 70 | topk_cost, indices = torch.topk(cost, k=k, dim=1) 71 | prob = torch.softmax(topk_cost, dim=1) 72 | topk_disp = torch.gather(disp_sample+off, dim=1, index=indices) 73 | disp_map = torch.sum(prob*topk_disp, dim=1, keepdim=True) 74 | 75 | return disp_map, topk_disp, topk_cost 76 | 77 | def forward(self, left, right, prev_info:dict): 78 | B, _, H, W = left.shape 79 | raw_cost = block_cost(left, right, self.num_sample, block_cost_scale=self.block_cost_scale) 80 | disp_sample = torch.linspace(0, self.num_sample-1, self.num_sample).view(1, self.num_sample, 1, 1) 81 | disp_sample = disp_sample.expand(B, self.num_sample, H, W).to(left.device) 82 | init_cost = self.init3d(raw_cost) 83 | 84 | memory = prev_info.get('cost_memory', None) 85 | use_past_cost = prev_info.get('use_past_cost', False) 86 | if memory is None or not use_past_cost: 87 | memory_sample = torch.zeros_like(disp_sample[:, :self.topk]) 88 | memory_volume = torch.zeros_like(memory_sample).unsqueeze(dim=1) 89 | 90 | else: 91 | memory_sample = memory['disp_sample'] 92 | mh, mw = memory_sample.shape[-2:] 93 | memory_sample = F.interpolate(memory_sample*W/mw, size=(H, W), mode='bilinear', align_corners=True) 94 | memory_volume = memory['cost_volume'] 95 | memory_volume = F.interpolate(memory_volume, size=(H, W), mode='bilinear', align_corners=True) 96 | memory_volume = memory_volume.unsqueeze(dim=1) 97 | 98 | memory_volume = self.past_conv(memory_volume) 99 | # [B, D, H, W] 100 | disp_sample = torch.cat([disp_sample, memory_sample], dim=1) 101 | # [B, C, 2*D, H, W] 102 | init_cost = torch.cat([init_cost, memory_volume], dim=2) 103 | disp_sample, indices = torch.sort(disp_sample, dim=1) 104 | init_cost = torch.gather(init_cost, dim=2, index=indices.unsqueeze(dim=1).repeat(1, self.C, 1, 1, 1)) 105 | init_cost = init_cost.contiguous() 106 | 107 | if self.spatial_fusion: 108 | init_cost = self.fuse(init_cost) 109 | 110 | final_cost, off = self.pred_heads(init_cost) 111 | 112 | # learn disparity 113 | disp, memory_sample, memory_volume = self.predict_disp(final_cost, disp_sample, off, k=self.topk) 114 | disp = self.convex_upsample(left, disp) 115 | 116 | return disp, final_cost, off, disp_sample, prev_info 117 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 |

# TemporalStereo: Efficient Spatial-Temporal Stereo Matching Network

Youmin Zhang · Matteo Poggi · Stefano Mattoccia

Arxiv | Project Page | Youtube Video

(architecture overview figure)

TemporalStereo Architecture, the first supervised stereo network based on video.

25 | 26 | ## Codebase is almost complete 27 | 28 | Currently, our codebase supports training on FlyingThings3D, while the remaining parts will be ready soon... 29 | 30 | In addition, pretrained checkpoints on various datasets are already provided; please refer to the Checkpoints section below. 31 | 32 | ## ⚙️ Setup 33 | 34 | Assuming a fresh [Anaconda](https://www.anaconda.com/download/) distribution, you can install the dependencies with: 35 | ```shell 36 | conda create -n temporalstereo python=3.8 37 | conda activate temporalstereo 38 | ``` 39 | We ran our experiments with PyTorch 1.10.1+, CUDA 11.3, Python 3.8 and Ubuntu 20.04. 40 | 41 | 42 | 43 | #### NVIDIA Apex 44 | 45 | We used NVIDIA Apex (commit @ 4ef930c1c884fdca5f472ab2ce7cb9b505d26c1a) for multi-GPU training. 46 | 47 | Apex can be installed as follows: 48 | 49 | ```bash 50 | $ cd PATH_TO_INSTALL 51 | $ git clone https://github.com/NVIDIA/apex 52 | $ cd apex 53 | $ git reset --hard 4ef930c1c884fdca5f472ab2ce7cb9b505d26c1a 54 | $ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ 55 | ``` 56 | 57 | #### Detectron2 58 | 59 | ```bash 60 | python -m pip install 'git+https://github.com/facebookresearch/detectron2.git' 61 | ``` 62 | 63 | 64 | #### Cupy 65 | 66 | ```bash 67 | # for cuda 11.3, refer to https://docs.cupy.dev/en/stable/install.html 68 | pip install cupy-cuda113 69 | ``` 70 | 71 | #### Finally 72 | 73 | ```bash 74 | pip install -r requirements.txt 75 | ``` 76 | 77 | ## 💾 Datasets 78 | We used three datasets for training and evaluation. 79 | 80 | All annotation files (annfiles, *.json) are available [here](https://drive.google.com/drive/folders/1PNVF8_0lVteOJG8M2MvDq2fzcAqJ7_J2?usp=sharing). 81 | 82 | We have also put the generation scripts there, so you can create the annfiles yourself. 83 | 84 | #### Flyingthings3D 85 | 86 | The [Flyingthings3D/SceneFlow](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) dataset can be downloaded here. 87 | 88 | After that, you will get a data structure as follows: 89 | 90 | ``` 91 | FlyingThings3D 92 | ├── disparity 93 | │ ├── TEST 94 | │ ├── TRAIN 95 | └── frames_finalpass 96 | │ ├── TEST 97 | │ ├── TRAIN 98 | ``` 99 | 100 | #### KITTI 2012/2015 101 | 102 | The processed KITTI 2012/2015 dataset and KITTI Raw Sequences can be downloaded from [Baidu Wangpan](https://pan.baidu.com/s/1epoRBXRy1c4TELMEa-aovg?pwd=iros) (password: iros) or [DropBox](https://www.dropbox.com/sh/gt0pju2tb6ncqxq/AADZfasHy0gSnizQKpTeS-Oia?dl=0). 103 | 104 | Note that the above links only contain the pseudo labels of the KITTI Raw Sequences; for downloading the raw images, please refer to [this](https://github.com/youmi-zym/CompletionFormer#kitti-depth-completion-kitti-dc) for help. 105 | 106 | For KITTI 2012 & 2015, we provide stereo image sequences, poses estimated by ORBSLAM3, and calibration files. 107 | 108 | #### TartanAir 109 | 110 | The processed TartanAir dataset can be downloaded [here](https://www.dropbox.com/sh/gt0pju2tb6ncqxq/AADZfasHy0gSnizQKpTeS-Oia?dl=0). 111 | 112 | 113 | ## ⏳ Training 114 | 115 | Note: the batch size is set per GPU. 116 | 117 | ```bash 118 | $ cd THIS_PROJECT_ROOT/projects/TemporalStereo 119 | 120 | # sceneflow 121 | python dist_train.py --config-file ./configs/sceneflow.yaml 122 | ``` 123 | 124 | During training, TensorBoard logs are saved under the experiment directory. To launch TensorBoard: 125 | 126 | ```bash 127 | $ cd THIS_PROJECT_ROOT/ 128 | $ tensorboard --logdir=. 
--bind_all 129 | ``` 130 | 131 | Then you can access the tensorboard via http://YOUR_SERVER_IP:6006 132 | 133 | ### Checkpoints 134 | 135 | Pretrained checkpoints on various datasets are all available [here](https://drive.google.com/drive/folders/1dbOfdx6BQ6cRX_m-G4kvvji30uKJOqMH?usp=sharing). 136 | 137 | ## 📊 Testing 138 | 139 | ```bash 140 | $ cd THIS_PROJECT_ROOT/projects/TemporalStereo 141 | # please remember to modify the parameters according to your case 142 | 143 | # run a demo 144 | ./demo.sh 145 | 146 | # submit to kitti 147 | ./submit.sh 148 | 149 | # inference on a video 150 | ./video.sh 151 | 152 | ``` 153 | 154 | ## 👩‍⚖️ Acknowledgement 155 | Thanks the authors for their works: 156 | 157 | [AcfNet](https://github.com/DeepMotionAIResearch/DenseMatchingBenchmark), [CoEx](https://github.com/antabangun/coex), [Detectron2](https://github.com/facebookresearch/detectron2.git) 158 | 159 | 160 | ## Citation 161 | 162 | If you find our work useful in your research please consider citing our paper: 163 | 164 | ``` 165 | @inproceedings{Zhang2023TemporalStereo, 166 | title = {TemporalStereo: Efficient Spatial-Temporal Stereo Matching Network}, 167 | author = {Zhang, Youmin and Poggi, Matteo and Mattoccia, Stefano}, 168 | booktitle = {IROS}, 169 | year = {2023} 170 | } 171 | ``` 172 | -------------------------------------------------------------------------------- /architecture/data/utils/load_flow.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | import png 4 | 5 | 6 | def load_pfm(file_path): 7 | """ 8 | load image in PFM type. 9 | Args: 10 | file_path string: file path(absolute) 11 | Returns: 12 | data (numpy.array): data of image in (Height, Width[, 3]) layout 13 | scale (float): scale of image 14 | """ 15 | with open(file_path, encoding="ISO-8859-1") as fp: 16 | 17 | color = None 18 | width = None 19 | height = None 20 | scale = None 21 | endian = None 22 | 23 | # load file header and grab channels, if is 'PF' 3 channels else 1 channel(gray scale) 24 | header = fp.readline().rstrip() # .decode("ascii") # .decode('utf-8') 25 | if header == 'PF': 26 | color = True 27 | elif header == 'Pf': 28 | color = False 29 | else: 30 | raise Exception('Not a PFM file.') 31 | 32 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', fp.readline()) # .decode('ascii')) 33 | if dim_match: 34 | width, height = map(int, dim_match.groups()) 35 | else: 36 | raise Exception('Malformed PFM header.') 37 | 38 | scale = float(fp.readline().rstrip()) 39 | if scale < 0: # little-endian 40 | endian = '<' 41 | scale = -scale 42 | else: 43 | endian = '>' # big-endian 44 | 45 | data = np.fromfile(fp, endian + 'f') 46 | shape = (height, width, 3) if color else (height, width) 47 | 48 | data = np.reshape(data, shape) 49 | data = np.flipud(data) 50 | 51 | return data, scale 52 | 53 | 54 | def load_png(file_path): 55 | """ 56 | Read from KITTI .png file 57 | Args: 58 | file_path string: file path(absolute) 59 | Returns: 60 | data (numpy.array): data of image in (Height, Width, 3) layout 61 | """ 62 | flow_object = png.Reader(filename=file_path) 63 | flow_direct = flow_object.asDirect() 64 | flow_data = list(flow_direct[2]) 65 | (w, h) = flow_direct[3]['size'] 66 | 67 | flow = np.zeros((h, w, 3), dtype=np.float64) 68 | for i in range(len(flow_data)): 69 | flow[i, :, 0] = flow_data[i][0::3] 70 | flow[i, :, 1] = flow_data[i][1::3] 71 | flow[i, :, 2] = flow_data[i][2::3] 72 | 73 | invalid_idx = (flow[:, :, 2] == 0) 74 | flow[:, :, 0:2] = (flow[:, :, 0:2] - 2 ** 
15) / 64.0 75 | flow[invalid_idx, 0] = 0 76 | flow[invalid_idx, 1] = 0 77 | 78 | return flow.astype(np.float32) 79 | 80 | 81 | def load_flo(file_path): 82 | """ 83 | Read .flo file in MiddleBury format 84 | Code adapted from: 85 | http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 86 | WARNING: this will work on little-endian architectures (eg Intel x86) only! 87 | Args: 88 | file_path string: file path(absolute) 89 | Returns: 90 | flow (numpy.array): data of image in (Height, Width, 2) layout 91 | """ 92 | 93 | with open(file_path, 'rb') as f: 94 | magic = np.fromfile(f, np.float32, count=1) 95 | assert(magic == 202021.25) 96 | w = int(np.fromfile(f, np.int32, count=1)) 97 | h = int(np.fromfile(f, np.int32, count=1)) 98 | # print('Reading %d x %d flo file\n' % (w, h)) 99 | flow = np.fromfile(f, np.float32, count=2 * w * h) 100 | # Reshape data into 3D array (columns, rows, bands) 101 | # The reshape here is for visualization, the original code is (w,h,2) 102 | flow = np.resize(flow, (h, w, 2)) 103 | 104 | return flow 105 | 106 | 107 | def write_flo(file_path, uv, v=None): 108 | """ Write optical flow to file. 109 | If v is None, uv is assumed to contain both u and v channels, 110 | stacked in depth. 111 | Original code by Deqing Sun, adapted from Daniel Scharstein. 112 | """ 113 | nBands = 2 114 | 115 | if v is None: 116 | assert (uv.ndim == 3) 117 | assert (uv.shape[2] == 2) 118 | u = uv[:, :, 0] 119 | v = uv[:, :, 1] 120 | else: 121 | u = uv 122 | 123 | assert (u.shape == v.shape) 124 | height, width = u.shape 125 | f = open(file_path, 'wb') 126 | # write the header 127 | np.array([202021.25]).astype(np.float32).tofile(f) 128 | np.array(width).astype(np.int32).tofile(f) 129 | np.array(height).astype(np.int32).tofile(f) 130 | # arrange into matrix form 131 | tmp = np.zeros((height, width * nBands)) 132 | tmp[:, np.arange(width) * 2] = u 133 | tmp[:, np.arange(width) * 2 + 1] = v 134 | tmp.astype(np.float32).tofile(f) 135 | 136 | f.close() 137 | 138 | 139 | # load utils 140 | def load_flying_chairs_flow(img_path): 141 | """load flying chairs flow image 142 | Args: 143 | img_path: 144 | Returns: 145 | """ 146 | assert img_path.endswith('.flo'), "flying chairs flow image must end with .flo " \ 147 | "but got {}".format(img_path) 148 | 149 | flow_img = load_flo(img_path) 150 | 151 | return flow_img 152 | 153 | 154 | # load utils 155 | def write_flying_chairs_flow(img_path, uv, v=None): 156 | """write flying chairs flow image 157 | Args: 158 | img_path: 159 | Returns: 160 | """ 161 | assert img_path.endswith('.flo'), "flying chairs flow image must end with .flo " \ 162 | "but got {}".format(img_path) 163 | 164 | write_flo(img_path, uv, v) 165 | 166 | 167 | # load utils 168 | def load_flying_things_flow(img_path): 169 | """load flying things flow image 170 | Args: 171 | img_path: 172 | Returns: 173 | """ 174 | assert img_path.endswith('.pfm'), "flying things flow image must end with .pfm " \ 175 | "but got {}".format(img_path) 176 | 177 | flow_img, __ = load_pfm(img_path) 178 | if flow_img.shape[-1] > 2: 179 | flow_img = flow_img[:, :, :2] 180 | 181 | return flow_img 182 | 183 | 184 | # load utils 185 | def load_kitti_flow(img_path): 186 | """load KITTI 2012/2015 flow image 187 | Args: 188 | img_path: 189 | Returns: 190 | """ 191 | assert img_path.endswith('.png'), "KITTI 2012/2015 flow image must end with .png " \ 192 | "but got {}".format(img_path) 193 | 194 | flow_img = load_png(img_path) 195 | 196 | return flow_img 
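As a quick orientation to the loaders above, here is a minimal usage sketch (not part of the original file): the file paths are hypothetical placeholders, and the import path simply follows the repository layout of `architecture/data/utils/load_flow.py`.

```python
# Minimal usage sketch for the flow loaders above (hypothetical paths).
import numpy as np

from architecture.data.utils.load_flow import (
    load_flying_chairs_flow, load_kitti_flow, write_flo)

# FlyingChairs-style .flo file -> (H, W, 2) float32 flow field
flow = load_flying_chairs_flow('/path/to/sample.flo')        # hypothetical path
print(flow.shape, flow.dtype)

# Round trip: write the flow back to disk and reload it
write_flo('/tmp/flow_copy.flo', flow)
assert np.allclose(load_flying_chairs_flow('/tmp/flow_copy.flo'), flow)

# KITTI .png flow: channel 2 is the validity mask; invalid pixels are zeroed
kitti_flow = load_kitti_flow('/path/to/000000_10.png')       # hypothetical path
u, v, valid = kitti_flow[..., 0], kitti_flow[..., 1], kitti_flow[..., 2]
```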
-------------------------------------------------------------------------------- /architecture/modeling/aggregation/TemporalStereo/fine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from typing import Any, Dict, List, Optional, Tuple, Union 6 | 7 | from architecture.modeling.layers import Conv2d, Conv3d 8 | from architecture.modeling.aggregation.utils import block_cost 9 | 10 | from .module import ConvexUpsample, PredictionHeads, PyramidFusion, ResidualBlock3D, DepthwiseConv3D 11 | 12 | class FineAggregation(nn.Module): 13 | def __init__(self, 14 | in_planes: int, 15 | C: int, 16 | num_sample: int, 17 | delta: float = 1, 18 | block_cost_scale: int = 3, 19 | topk: int = 2, 20 | spatial_fusion: bool = True, 21 | norm: str = 'BN3d', 22 | activation: Union[str, Tuple, List] = 'SiLU'): 23 | super(FineAggregation, self).__init__() 24 | self.in_planes = in_planes 25 | self.C = C 26 | self.num_sample = num_sample 27 | self.delta = delta 28 | self.block_cost_scale = block_cost_scale 29 | self.topk = topk 30 | self.spatial_fusion = spatial_fusion 31 | self.norm = norm 32 | self.activation = activation 33 | self.phi = nn.Parameter(torch.Tensor([0.0, ]), requires_grad=True) 34 | 35 | cost_planes = 2 * in_planes + block_cost_scale * in_planes // 8 36 | self.init3d = nn.Sequential( 37 | DepthwiseConv3D(cost_planes, C, 3, 1, 1, bias=True, norm=norm, activation=activation), 38 | ResidualBlock3D(in_planes=C, kernel_size=3, stride=2, padding=1, norm=norm, activation=activation), 39 | DepthwiseConv3D(C, C, 3, 1, padding=2, dilation=2, bias=False, norm=norm, activation=activation), 40 | ) 41 | 42 | self.past_conv = Conv3d(1, C, 1, 1, 0, bias=False, norm=(norm, C), activation=activation) 43 | 44 | if self.spatial_fusion: 45 | self.fuse = PyramidFusion(in_planes=C, norm=norm, activation=activation) 46 | 47 | self.pred_heads = PredictionHeads(in_planes=C, delta=delta, norm=norm, activation=activation) 48 | 49 | self.convex_upsample = ConvexUpsample(in_planes=in_planes, upscale_factor=2, window_size=3) 50 | 51 | self.weight_init() 52 | 53 | def weight_init(self): 54 | for m in self.modules(): 55 | if isinstance(m, nn.Conv2d): 56 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 57 | m.weight.data.normal_(0, math.sqrt(2. / n)) 58 | elif isinstance(m, nn.Conv3d): 59 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels 60 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 61 | elif isinstance(m, nn.BatchNorm2d): 62 | m.weight.data.fill_(1) 63 | m.bias.data.zero_() 64 | elif isinstance(m, nn.BatchNorm3d): 65 | m.weight.data.fill_(1) 66 | m.bias.data.zero_() 67 | elif isinstance(m, nn.Linear): 68 | m.bias.data.zero_() 69 | 70 | def predict_disp(self, cost, disp_sample, off, k=2): 71 | topk_cost, indices = torch.topk(cost, k=k, dim=1) 72 | prob = torch.softmax(topk_cost, dim=1) 73 | topk_disp = torch.gather(disp_sample+off, dim=1, index=indices) 74 | disp_map = torch.sum(prob*topk_disp, dim=1, keepdim=True) 75 | 76 | return disp_map, topk_disp, topk_cost 77 | 78 | def generate_disparity_sample(self, low_disparity, high_disparity, num_sample, prev_info): 79 | batch_size, _, height, width = low_disparity.shape 80 | device = low_disparity.device 81 | # track with constant speed motion 82 | disp_sample = torch.Tensor([0, 3, 4, 5, 8]) 83 | num_sample = len(disp_sample) 84 | disp_sample = (disp_sample / disp_sample.max()).view(1, num_sample, 1, 1) 85 | disp_sample = disp_sample.expand(batch_size, num_sample, height, width).to(device) 86 | disp_sample = torch.abs(high_disparity - low_disparity) * disp_sample + torch.min(low_disparity, high_disparity) 87 | 88 | # track in local map 89 | local_map = prev_info.get('local_map', None) 90 | local_map_size = prev_info.get('local_map_size', 0) 91 | if local_map is not None and local_map_size > 0: 92 | local_map = F.interpolate(local_map*width/local_map.shape[-1], size=(height, width), mode='bilinear', align_corners=True) 93 | disp_sample = torch.cat([local_map, disp_sample], dim=1) 94 | 95 | return disp_sample 96 | 97 | def forward(self, left, right, low_disparity, high_disparity, prev_info:dict): 98 | B, _, H, W = left.shape 99 | disp_sample = self.generate_disparity_sample(low_disparity, high_disparity, self.num_sample, prev_info) 100 | raw_cost = block_cost(left, right, disp_sample, block_cost_scale=self.block_cost_scale) 101 | 102 | init_cost = self.init3d(raw_cost) 103 | 104 | # fuse temporal info 105 | memory = prev_info.get('cost_memory', None) 106 | use_past_cost = prev_info.get('use_past_cost', False) 107 | if memory is None or not use_past_cost: 108 | memory_sample = torch.zeros_like(disp_sample[:, :self.topk]) 109 | memory_volume = torch.zeros_like(memory_sample).unsqueeze(dim=1) 110 | else: 111 | memory_sample = memory['disp_sample'] 112 | memory_volume = memory['cost_volume'].unsqueeze(dim=1) 113 | 114 | memory_volume = self.past_conv(memory_volume) 115 | # [B, D, H, W] 116 | disp_sample = torch.cat([disp_sample, memory_sample], dim=1) 117 | # [B, C, 2*D, H, W] 118 | init_cost = torch.cat([init_cost, memory_volume], dim=2) 119 | # [B, D, H, W] 120 | disp_sample, indices = torch.sort(disp_sample, dim=1) 121 | init_cost = torch.gather(init_cost, dim=2, index=indices.unsqueeze(dim=1).repeat(1, self.C, 1, 1, 1)) 122 | init_cost = init_cost.contiguous() 123 | if self.spatial_fusion: 124 | init_cost = self.fuse(init_cost) 125 | 126 | final_cost, off = self.pred_heads(init_cost) 127 | 128 | # learn disparity 129 | disp, memory_sample, memory_volume = self.predict_disp(final_cost, disp_sample, off, k=self.topk) 130 | disp = self.convex_upsample(left, disp) 131 | 132 | return disp, final_cost, off, disp_sample, prev_info 133 | -------------------------------------------------------------------------------- /architecture/modeling/losses/warsserstein_distance_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import 
torch.nn.functional as F 4 | from detectron2.config import configurable 5 | from typing import Optional, Dict, Tuple, Union, List 6 | import warnings 7 | 8 | 9 | class WarssersteinDistanceLoss(object): 10 | """ 11 | Args: 12 | max_disp: (int), the max of Disparity. default is 192 13 | start_disp: (int), the start searching disparity index, usually be 0 14 | weights: (list, tuple, float, optional): weight for each scale of est disparity map. 15 | sparse: (bool), whether the ground-truth disparity is sparse, 16 | for example, KITTI is sparse, but SceneFlow is not, default is False. 17 | Inputs: 18 | estCosts: (Tensor or List[Tensor]): the estimated cost volumes, 19 | [BatchSize, NumSamples, Height, Width] layout. 20 | estOffsets: (Tensor or List[Tensor]): the estimated disparity offsets for each cost volume, 21 | [BatchSize, NumSamples, Height, Width] layout. 22 | gtDisp: (Tensor), the ground truth disparity map, 23 | [BatchSize, 1, Height, Width] layout. 24 | Outputs: 25 | loss: (dict), the loss of each level 26 | """ 27 | @configurable 28 | def __init__(self, max_disp:int, start_disp:int=0, global_weight:float=1.0, 29 | weights:Union[Tuple, List, float, None]=None, sparse:bool=False): 30 | self.max_disp = max_disp 31 | self.start_disp = start_disp 32 | self.global_weight = global_weight 33 | self.weights = weights 34 | self.sparse = sparse 35 | if sparse: 36 | # sparse disparity ==> max_pooling 37 | self.scale_func = F.adaptive_max_pool2d 38 | else: 39 | # dense disparity ==> avg_pooling 40 | self.scale_func = F.adaptive_avg_pool2d 41 | 42 | @classmethod 43 | def from_config(cls, cfg): 44 | return { 45 | "max_disp": cfg.get("MAX_DISP", 192), 46 | "start_disp": cfg.get("START_DISP", 0), 47 | "global_weight": cfg.get("GLOBAL_WEIGHT", 1.0), 48 | "weights": cfg.get("WEIGHTS", None), 49 | "sparse": cfg.get("SPARSE", False), 50 | } 51 | 52 | def loss_per_level(self, estCost, estOffset, dispSample, gtDisp): 53 | N, D, H, W = estCost.shape 54 | estProb = torch.softmax(estCost, dim=1) 55 | 56 | scaled_gtDisp = gtDisp 57 | scale = 1.0 58 | if gtDisp.shape[-2] != H or gtDisp.shape[-1] != W: 59 | # compute scale per level and scale gtDisp 60 | scale = gtDisp.shape[-1] / (W * 1.0) 61 | scaled_gtDisp = gtDisp / scale 62 | scaled_gtDisp = self.scale_func(scaled_gtDisp, (H, W)) 63 | 64 | # mask for valid disparity 65 | # (start disparity, max disparity / scale) 66 | # Attention: the invalid disparity of KITTI is set as 0, be sure to mask it out 67 | mask = (scaled_gtDisp > self.start_disp) & (scaled_gtDisp < (self.max_disp / scale)) 68 | if mask.sum() < 1.0: 69 | warnings.warn('Warsserstein distance loss: there is no point\'s disparity is in ({},{})!'.format(self.start_disp, 70 | self.max_disp / scale)) 71 | loss = (estProb * torch.abs(estOffset + dispSample - scaled_gtDisp) * mask.float()).sum(dim=1).mean() 72 | return loss 73 | 74 | # warsserstein distance loss 75 | # sum{ (0.2 + 0.8*P(d)) * | d* - (d + delta)| } 76 | war_loss = ((estProb*1.0 + 0.25) * torch.abs(estOffset + dispSample - scaled_gtDisp) * mask.float()).sum(dim=1).mean() 77 | 78 | return war_loss 79 | 80 | def __call__(self, estCosts, estOffsets, dispSamples, gtDisp): 81 | if not isinstance(estCosts, (list, tuple)): 82 | estCosts = [estCosts, ] 83 | 84 | if not isinstance(estOffsets, (list, tuple)): 85 | estOffsets = [estOffsets, ] 86 | 87 | if not isinstance(dispSamples, (list, tuple)): 88 | dispSamples = [dispSamples, ] * len(estCosts) 89 | 90 | assert len(estCosts) == len(estOffsets), "{}, {}".format(len(estCosts), len(estOffsets)) 91 | 
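        # If no per-level weights are configured, fall back to equal weighting; each level's loss is scaled below by weights[i] * global_weight.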
92 | if self.weights is None: 93 | self.weights = [1.0] * len(estCosts) 94 | 95 | # compute loss for per level 96 | loss_all_level = [] 97 | for est_cost_per_lvl, est_off_per_lvl, est_sample_per_lvl in zip(estCosts, estOffsets, dispSamples): 98 | assert est_sample_per_lvl.shape == est_cost_per_lvl.shape, "sample shape: {}, cost shape: {}".format(est_sample_per_lvl.shape, 99 | est_cost_per_lvl.shape) 100 | assert est_off_per_lvl.shape == est_cost_per_lvl.shape, "sample shape: {}, cost shape: {}".format(est_off_per_lvl.shape, 101 | est_cost_per_lvl.shape) 102 | 103 | loss_all_level.append( 104 | self.loss_per_level(est_cost_per_lvl, est_off_per_lvl, est_sample_per_lvl, gtDisp) 105 | ) 106 | 107 | # re-weight loss per level 108 | weighted_loss_all_level = dict() 109 | for i, loss_per_level in enumerate(loss_all_level): 110 | name = "wars_loss_lvl{}".format(i) 111 | weighted_loss_all_level[name] = self.weights[i] * loss_per_level * self.global_weight 112 | 113 | return weighted_loss_all_level 114 | 115 | def __repr__(self): 116 | repr_str = '{}\n'.format(self.__class__.__name__) 117 | repr_str += ' ' * 4 + 'Max Disparity: {}\n'.format(self.max_disp) 118 | repr_str += ' ' * 4 + 'Start disparity: {}\n'.format(self.start_disp) 119 | repr_str += ' ' * 4 + 'Global Loss weight: {}\n'.format(self.global_weight) 120 | repr_str += ' ' * 4 + 'Loss weights: {}\n'.format(self.weights) 121 | repr_str += ' ' * 4 + 'GT Disparity is sparse: {}\n'.format(self.sparse) 122 | 123 | return repr_str 124 | 125 | @property 126 | def name(self): 127 | return 'WarssersteinDistanceLoss' -------------------------------------------------------------------------------- /architecture/modeling/aggregation/TemporalStereo/TemporalStereo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from typing import Any, Dict, List, Optional, Tuple, Union 6 | from detectron2.config import configurable 7 | 8 | from ..builder import AGGREGATION_REGISTRY 9 | from .coarse import CoarseAggregation 10 | from .fine import FineAggregation 11 | from .precise import PreciseAggregation 12 | 13 | 14 | @AGGREGATION_REGISTRY.register() 15 | class TEMPORALSTEREO(nn.Module): 16 | """ 17 | Cost Aggregation method proposed by TemporalStereo 18 | Args: 19 | max_disp: (int), the max disparity value 20 | norm: (str), the type of normalization layer 21 | activation: (str, list, tuple), the type of activation layer and its coefficient is needed 22 | """ 23 | @configurable 24 | def __init__(self, 25 | coarse: nn.Module, 26 | fine: nn.Module, 27 | precise: nn.Module, 28 | norm: str = 'BN', 29 | activation: Union[str, List, Tuple] = 'SiLU'): 30 | super(TEMPORALSTEREO, self).__init__() 31 | self.norm = norm 32 | self.activation = activation 33 | 34 | self.coarse = coarse 35 | self.fine = fine 36 | self.precise = precise 37 | 38 | @classmethod 39 | def from_config(cls, cfg): 40 | coarse = CoarseAggregation( 41 | in_planes= cfg.MODEL.AGGREGATION.COARSE.get('IN_PLANES', 192), 42 | C= cfg.MODEL.AGGREGATION.COARSE.get('C', 32), 43 | num_sample= cfg.MODEL.AGGREGATION.COARSE.get('NUM_SAMPLE', 12), 44 | delta= cfg.MODEL.AGGREGATION.COARSE.get('DELTA', 1), 45 | block_cost_scale= cfg.MODEL.AGGREGATION.COARSE.get('BLOCK_COST_SCALE', 3), 46 | topk= cfg.MODEL.AGGREGATION.COARSE.get('TOPK', 2), 47 | spatial_fusion= cfg.MODEL.AGGREGATION.COARSE.get('SPATIAL_FUSION', True), 48 | norm= cfg.MODEL.AGGREGATION.COARSE.get('NORM', 'BN3d'), 49 | 
activation= cfg.MODEL.AGGREGATION.COARSE.get('ACTIVATION', 'SiLU'), 50 | ) 51 | fine = FineAggregation( 52 | in_planes= cfg.MODEL.AGGREGATION.FINE.get('IN_PLANES', 64), 53 | C= cfg.MODEL.AGGREGATION.FINE.get('C', 16), 54 | num_sample= cfg.MODEL.AGGREGATION.FINE.get('NUM_SAMPLE', 5), 55 | delta= cfg.MODEL.AGGREGATION.FINE.get('DELTA', 1), 56 | block_cost_scale= cfg.MODEL.AGGREGATION.FINE.get('BLOCK_COST_SCALE', 3), 57 | topk= cfg.MODEL.AGGREGATION.FINE.get('TOPK', 2), 58 | spatial_fusion= cfg.MODEL.AGGREGATION.FINE.get('SPATIAL_FUSION', True), 59 | norm= cfg.MODEL.AGGREGATION.FINE.get('NORM', 'BN3d'), 60 | activation= cfg.MODEL.AGGREGATION.FINE.get('ACTIVATION', 'SiLU'), 61 | ) 62 | precise = PreciseAggregation( 63 | in_planes= cfg.MODEL.AGGREGATION.PRECISE.get('IN_PLANES', 48), 64 | C= cfg.MODEL.AGGREGATION.PRECISE.get('C', 8), 65 | num_sample= cfg.MODEL.AGGREGATION.PRECISE.get('NUM_SAMPLE', 5), 66 | delta= cfg.MODEL.AGGREGATION.PRECISE.get('DELTA', 1), 67 | block_cost_scale= cfg.MODEL.AGGREGATION.PRECISE.get('BLOCK_COST_SCALE', 3), 68 | topk= cfg.MODEL.AGGREGATION.PRECISE.get('TOPK', 2), 69 | norm= cfg.MODEL.AGGREGATION.PRECISE.get('NORM', 'BN3d'), 70 | activation= cfg.MODEL.AGGREGATION.PRECISE.get('ACTIVATION', 'SiLU'), 71 | ) 72 | return { 73 | 'coarse': coarse, 74 | 'fine': fine, 75 | 'precise': precise, 76 | "norm": cfg.MODEL.AGGREGATION.get('NORM', 'BN'), 77 | "activation": cfg.MODEL.AGGREGATION.get('ACTIVATION', 'SiLU'), 78 | } 79 | 80 | def weight_init(self): 81 | for m in self.modules(): 82 | if isinstance(m, nn.Conv2d): 83 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 84 | m.weight.data.normal_(0, math.sqrt(2. / n)) 85 | elif isinstance(m, nn.Conv3d): 86 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels 87 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 88 | elif isinstance(m, nn.BatchNorm2d): 89 | m.weight.data.fill_(1) 90 | m.bias.data.zero_() 91 | elif isinstance(m, nn.BatchNorm3d): 92 | m.weight.data.fill_(1) 93 | m.bias.data.zero_() 94 | elif isinstance(m, nn.Linear): 95 | m.bias.data.zero_() 96 | 97 | def forward(self, left_feats, right_feats, left_image, right_image, prev_info: dict): 98 | disps = [] 99 | costs = [] 100 | offs = [] 101 | disp_samples = [] 102 | search_ranges = [] 103 | disp_range = 4 104 | 105 | left_feat, left_feat_8, left_feat_16 = left_feats 106 | right_feat, right_feat_8, right_feat_16 = right_feats 107 | 108 | # coarse prediction 109 | disp, cost, off, disp_sample, prev_info = self.coarse(left_feat_16, right_feat_16, prev_info) 110 | low, high = disp - disp_range, disp + disp_range 111 | disps.append(disp) 112 | disp_samples.append(disp_sample) 113 | search_ranges.append({'low': low, 'high': high}) 114 | costs.append(cost) 115 | offs.append(off) 116 | 117 | # fine prediction 118 | disp, cost, off, disp_sample, prev_info = self.fine(left_feat_8, right_feat_8, low, high, prev_info) 119 | low, high = disp - disp_range, disp + disp_range 120 | disps.append(disp) 121 | disp_samples.append(disp_sample) 122 | search_ranges.append({'low': low, 'high': high}) 123 | costs.append(cost) 124 | offs.append(off) 125 | 126 | # precise 127 | full_disp, disp, cost, off, disp_sample, prev_info = self.precise(left_feat, right_feat, low, high, 128 | left_image, right_image, prev_info) 129 | disps.append(disp) 130 | disps.append(full_disp) 131 | disp_samples.append(disp_sample) 132 | costs.append(cost) 133 | offs.append(off) 134 | 135 | return disps[::-1], costs[::-1], disp_samples[::-1], offs[::-1], search_ranges[::-1], prev_info 136 | 137 | -------------------------------------------------------------------------------- /architecture/utils/visualization/flow_colormap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import matplotlib.pyplot as plt 4 | 5 | UNKNOWN_FLOW_THRESH = 1e7 6 | 7 | 8 | def make_color_wheel(): 9 | """ 10 | Generate color wheel according MiddleBury color code 11 | Outputs: 12 | color_wheel, (numpy.ndarray): color wheel, in [nCols, 3] layout 13 | """ 14 | RY = 15 15 | YG = 6 16 | GC = 4 17 | CB = 11 18 | BM = 13 19 | MR = 6 20 | 21 | nCols = RY + YG + GC + CB + BM + MR 22 | 23 | color_wheel = np.zeros([nCols, 3]) 24 | 25 | col = 0 26 | 27 | # RY 28 | color_wheel[0:RY, 0] = 255 29 | color_wheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY)) 30 | col += RY 31 | 32 | # YG 33 | color_wheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG)) 34 | color_wheel[col:col+YG, 1] = 255 35 | col += YG 36 | 37 | # GC 38 | color_wheel[col:col+GC, 1] = 255 39 | color_wheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC)) 40 | col += GC 41 | 42 | # CB 43 | color_wheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB)) 44 | color_wheel[col:col+CB, 2] = 255 45 | col += CB 46 | 47 | # BM 48 | color_wheel[col:col+BM, 2] = 255 49 | color_wheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM)) 50 | col += + BM 51 | 52 | # MR 53 | color_wheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR)) 54 | color_wheel[col:col+MR, 0] = 255 55 | 56 | return color_wheel 57 | 58 | 59 | def flow_max_rad(flow): 60 | """ 61 | Maximum sqrt(f_x^2 + f_y^2) 62 | Inputs: 63 | flow, (numpy.ndarray): flow map, in [Height, Width, 2] layout 64 | Outputs: 65 | 
max_rad, (float): the max flow in Euclidean Distance 66 | """ 67 | u = flow[:, :, 0] 68 | v = flow[:, :, 1] 69 | 70 | idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH) | np.isnan(u) | np.isnan(v) 71 | u[idxUnknow] = 0 72 | v[idxUnknow] = 0 73 | 74 | rad = np.sqrt(u ** 2 + v ** 2) 75 | max_rad = max(-1, np.max(rad)) 76 | 77 | return max_rad 78 | 79 | 80 | def flow_color(flow, max_rad=None): 81 | """ 82 | compute optical flow color map 83 | Inputs: 84 | flow, (numpy.ndarray): optical flow map, in [Height, Width, 2] layout 85 | max_rad, (float): the max flow in Euclidean Distance 86 | Outputs: 87 | img, (numpy.ndarray): optical flow in color code, in [Height, Width, 3] layout, value range [0, 1] 88 | """ 89 | # [H, W] 90 | u = flow[:, :, 0] 91 | v = flow[:, :, 1] 92 | 93 | h, w = u.shape 94 | img = np.zeros([h, w, 3], dtype=np.float32) 95 | 96 | # [H, W] 97 | idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH) | np.isnan(u) | np.isnan(v) 98 | u[idxUnknow] = 0 99 | v[idxUnknow] = 0 100 | 101 | # [H, W] 102 | if max_rad is None: 103 | rad = np.sqrt(u ** 2 + v ** 2) 104 | max_rad = max(-1, np.max(rad)) 105 | 106 | # [H, W] 107 | u = u / (max_rad + np.finfo(float).eps) 108 | v = v / (max_rad + np.finfo(float).eps) 109 | 110 | color_wheel = make_color_wheel() 111 | nCols = np.size(color_wheel, 0) 112 | 113 | # [H, W] 114 | rad = np.sqrt(u**2+v**2) 115 | 116 | # angle, [H, W] 117 | a = np.arctan2(-v, -u) / np.pi 118 | 119 | fk = (a+1) / 2 * (nCols - 1) + 1 120 | 121 | k0 = np.floor(fk).astype(int) 122 | 123 | k1 = k0 + 1 124 | k1[k1 == nCols+1] = 1 125 | f = fk - k0 126 | 127 | for i in range(0, np.size(color_wheel, 1)): 128 | tmp = color_wheel[:, i] 129 | col0 = tmp[k0-1] / 255 130 | col1 = tmp[k1-1] / 255 131 | col = (1-f) * col0 + f * col1 132 | 133 | idx = (rad <= 1) 134 | col[idx] = 1-rad[idx]*(1-col[idx]) 135 | notIdx = np.logical_not(idx) 136 | 137 | col[notIdx] *= 0.75 138 | img[:, :, i] = col*(1-idxUnknow) 139 | 140 | return img 141 | 142 | 143 | def flow_to_color(flow, max_rad=None): 144 | """ 145 | Convert flow into MiddleBury color code image 146 | Inputs: 147 | flow, (numpy.ndarray): optical flow map, in [Height, Width, 2] layout 148 | max_rad, (float): the max flow in Euclidean Distance 149 | Outputs: 150 | img, (numpy.ndarray): optical flow image in MiddleBury color, in [Height, Width, 3] layout, value range [0,1] 151 | """ 152 | # [H, W] 153 | u = flow[:, :, 0] 154 | v = flow[:, :, 1] 155 | 156 | # [H, W] 157 | idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH) | np.isnan(u) | np.isnan(v) 158 | 159 | # [H, W, 3], [0, 1] 160 | img = flow_color(flow, max_rad=max_rad) 161 | 162 | # [H, W, 3] 163 | idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2) 164 | # [H, W, 3] 165 | img[idx] = 0 166 | 167 | return img 168 | 169 | 170 | def flow_err_to_color(F_est, F_gt, F_gt_val=None): 171 | """ 172 | Calculate the error map between optical flow estimation and optical flow ground-truth 173 | hot color -> big error, cold color -> small error 174 | Inputs: 175 | F_est, (numpy.ndarray): optical flow estimation map in (Height, Width, 2) layout 176 | F_gt, (numpy.ndarray): optical flow ground-truth map in (Height, Width, 2) layout 177 | Outputs: 178 | F_err, (numpy.ndarray): optical flow error map in (Height, Width, 3) layout, range [0,1] 179 | """ 180 | 181 | F_shape = F_gt.shape[:2] 182 | 183 | # error color map with interval (0, 0.1875, 0.375, 0.75, 1.5, 3, 6, 12, 24, 48, inf)/3.0 184 | # different interval corresponds 
to different 3-channel projection 185 | cols = np.array([ 186 | [0.0, 0.1875, 49, 54, 149], 187 | [0.1875, 0.375, 69, 117, 180], 188 | [0.375, 0.75, 116, 173, 209], 189 | [0.75, 1.5, 171, 217, 233], 190 | [1.5, 3.0, 224, 243, 248], 191 | [3.0, 6.0, 254, 224, 144], 192 | [6.0, 12.0, 253, 174, 97], 193 | [12.0, 24.0, 244, 109, 67], 194 | [24.0, 48.0, 215, 48, 39], 195 | [48.0, float('inf'), 165, 0, 38] 196 | ]) 197 | 198 | E_duv = F_gt - F_est 199 | E = np.square(E_duv) 200 | E = np.sqrt(E[:, :, 0] + E[:, :, 1]) 201 | 202 | if F_gt_val is not None: 203 | E = E * F_gt_val 204 | 205 | if F_gt_val is None: 206 | F_val = np.ones(F_shape, dtype=np.bool_) 207 | else: 208 | F_val = (F_gt_val != 0.0) 209 | 210 | F_err = np.zeros((F_gt.shape[0], F_gt.shape[1], 3), dtype=np.float32) 211 | for i in range(cols.shape[0]): 212 | F_find = F_val & (E >= cols[i, 0]) & (E <= cols[i, 1]) 213 | F_find = np.where(F_find) 214 | F_err[:, :, 0][F_find] = float(cols[i, 2]) 215 | F_err[:, :, 1][F_find] = float(cols[i, 3]) 216 | F_err[:, :, 2][F_find] = float(cols[i, 4]) 217 | 218 | F_err = F_err / 255.0 219 | 220 | return F_err 221 | 222 | -------------------------------------------------------------------------------- /architecture/data/utils/calibration/kitti_calib.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | # Extract calibration information from KITTI 5 | # We use https://github.com/utiasSTARS/pykitti/blob/master/pykitti/odometry.py as reference 6 | # for the calibration and poses loading 7 | 8 | def read_calib_file(filepath): 9 | """Read a calibration file and parse into a dictionary""" 10 | data = {} 11 | 12 | with open(filepath, 'r') as f: 13 | for line in f.readlines(): 14 | try: 15 | key, value = line.split(':', 1) 16 | # The only non-float values in these files are dates, which 17 | # we don't care about anyway 18 | data[key] = np.array([float(x) for x in value.split()]) 19 | except ValueError: 20 | pass 21 | return data 22 | 23 | def read_calib_from_video(cam2cam_filepath, velo2cam_filepath): 24 | """Read the cam_to_cam calibration and velo_to_cam calibration and parse into a dictionary""" 25 | data = {} 26 | cam2cam = read_calib_file(cam2cam_filepath) 27 | velo2cam = read_calib_file(velo2cam_filepath) 28 | 29 | # Transformation matrix from rotation matrix and translation vector, in 3x4 layout 30 | Tr_velo_to_cam = np.zeros((3, 4)) 31 | Tr_velo_to_cam[:3, :3] = np.reshape(velo2cam['R'], [3, 3]) 32 | Tr_velo_to_cam[:, 3] = velo2cam['T'] 33 | 34 | data['Tr_velo_to_cam'] = Tr_velo_to_cam 35 | data.update(cam2cam) 36 | return data 37 | 38 | def load_calib(calib_filepath, velo2cam_filepath=None): 39 | """ 40 | For KITTI visual odometry: 41 | calibration stored in one file: P0, P1, P2, P3, Tr 42 | P0/P1/P2/P3 are the 3x4 projection matrices after rectification 43 | Tr transforms a point from velodyne coordinates into the left rectified camera coordinate system 44 | For KITTI raw data: 45 | calibration stored in two files(calib_cam_to_cam.txt, calib_velo_to_cam.txt) 46 | S_0x, K_0x, D_0x, T_0x, S_rect_0x, R_rect_0x, P_rect_0x -> for x in [0, 1, 2, 3] 47 | For KITTI 2015: 48 | calibration stored in two files(calib_cam_to_cam/{:0>6d}.txt, calib_velo_to_cam/{:0>6d.txt}) 49 | contents are the same as KITTI raw data 50 | 51 | """ 52 | from .utils import cart_to_homo, trans2homo4x4 53 | 54 | data = {} 55 | assert os.path.isfile(calib_filepath), calib_filepath 56 | if velo2cam_filepath is not None: 57 | cam2cam_filepath = calib_filepath 
58 | assert os.path.isfile(velo2cam_filepath), velo2cam_filepath 59 | fileData = read_calib_from_video(cam2cam_filepath, velo2cam_filepath) 60 | else: 61 | fileData = read_calib_file(calib_filepath) 62 | 63 | # Create 3x4 projection matrices 64 | if 'P0' in fileData.keys(): # KITTI visual odometry 65 | P0, P1, P2, P3 = ['P{}'.format(x) for x in range(4)] 66 | elif 'P_rect_00' in fileData.keys(): # KITTI raw data, KITTI 2015 67 | P0, P1, P2, P3 = ['P_rect_0{}'.format(x) for x in range(4)] 68 | else: 69 | raise ValueError 70 | 71 | P_rect_00 = data['P_rect_00'] = np.reshape(fileData[P0], (3, 4)) 72 | P_rect_10 = data['P_rect_10'] = np.reshape(fileData[P1], (3, 4)) 73 | P_rect_20 = data['P_rect_20'] = np.reshape(fileData[P2], (3, 4)) 74 | P_rect_30 = data['P_rect_30'] = np.reshape(fileData[P3], (3, 4)) 75 | 76 | # Compute the camera intrinsics, in 3x3 layout 77 | K0 = data['K_cam0'] = P_rect_00[:3, :3] 78 | K1 = data['K_cam1'] = P_rect_10[:3, :3] 79 | K2 = data['K_cam2'] = P_rect_20[:3, :3] 80 | K3 = data['K_cam3'] = P_rect_30[:3, :3] 81 | 82 | # Compute the rectified extrinsic from cam0 to camN, in 4x4 layout 83 | # P_X0 = KX @ T_X0, so T_X0 = inv(KX) @ P_X0 84 | T0 = np.linalg.inv(trans2homo4x4(K0)) @ trans2homo4x4(P_rect_00) 85 | T1 = np.linalg.inv(trans2homo4x4(K1)) @ trans2homo4x4(P_rect_10) 86 | T2 = np.linalg.inv(trans2homo4x4(K2)) @ trans2homo4x4(P_rect_20) 87 | T3 = np.linalg.inv(trans2homo4x4(K3)) @ trans2homo4x4(P_rect_30) 88 | 89 | # Compute the velodyne to rectified camera coordinate transforms 90 | if 'Tr' in fileData.keys(): # KITTI visual odometry 91 | data['T_cam0_velo'] = np.reshape(fileData['Tr'], (3, 4)) 92 | else: 93 | """ 94 | R0_rect (3x3): Rotation from non-rectified to rectified camera coordinate system 95 | Tr_velo_to_cam (3x4): Rigid transformation from Velodyne to (non-rectified) camera coordinates 96 | For not odometry dataset, such as raw/kitti 2015/object, etc 97 | """ 98 | if 'R0_rect' in fileData.keys(): 99 | R0 = 'R0_rect' 100 | elif 'R_rect_00' in fileData.keys(): # KITTI raw data 101 | R0 = 'R_rect_00' 102 | else: 103 | raise ValueError 104 | 105 | R_rect_00 = fileData[R0].reshape((3,3)) 106 | Tr_cam0_velo = np.reshape(fileData['Tr_velo_to_cam'], (3,4)) 107 | data['T_cam0_velo'] = R_rect_00 @ Tr_cam0_velo 108 | 109 | T_cam0_velo = np.vstack((data['T_cam0_velo'], [0, 0, 0, 1])) # 4x4 110 | 111 | # project velodyne to cam0 and then translate from cam0 to camN, in 4x4 layout 112 | data['T_cam0_velo'] = T0.dot(T_cam0_velo) 113 | data['T_cam1_velo'] = T1.dot(T_cam0_velo) 114 | data['T_cam2_velo'] = T2.dot(T_cam0_velo) 115 | data['T_cam3_velo'] = T3.dot(T_cam0_velo) 116 | 117 | 118 | """ 119 | Compute the stereo baselines in meters by projecting the origin of each camera frame 120 | into the velodyne frame and computing the distances between them 121 | """ 122 | # the origin point of each camera frame 123 | p_cam = np.array([0, 0, 0, 1]) 124 | # project each camera frame to the velodyne frame 125 | p_velo0 = np.linalg.inv(data['T_cam0_velo']).dot(p_cam) 126 | p_velo1 = np.linalg.inv(data['T_cam1_velo']).dot(p_cam) 127 | p_velo2 = np.linalg.inv(data['T_cam2_velo']).dot(p_cam) 128 | p_velo3 = np.linalg.inv(data['T_cam3_velo']).dot(p_cam) 129 | 130 | # get baseline of two gray cameras and two rgb cameras 131 | data['b_gray'] = np.linalg.norm(p_velo1 - p_velo0) # baseline of gray cameras 132 | data['b_rgb'] = np.linalg.norm(p_velo3 - p_velo2) # baseline of rgb cameras 133 | 134 | # get resolution of image, the resolution of all camera is the same 135 | if 'S_rect_02' 
in fileData.keys(): 136 | # in (Height, Width) layout 137 | data['resolution'] = tuple(fileData['S_rect_02'][::-1].astype(np.int32)) 138 | 139 | else: 140 | data['resolution'] = None 141 | 142 | """ 143 | Double check ! 144 | For transformation from a point Y of velodyne to a point Z of rectified camera frame X 145 | 1. translate to camera frame X and then rectify with intrinsic 146 | Z = [K_camX|0] @ T_camX_velo @ Y 147 | 2. translate to camera frame 0 and move to camera frame X, then rectify with intrinsic 148 | Z = P_rect_X0 @ T_cam0_velo @ Y 149 | 150 | i.e. P_rect_X0 = [K_camX|0] @ T_camX_cam0 151 | """ 152 | for x in range(4): 153 | T_camX_velo = 'T_cam{}_velo'.format(x) 154 | K_camX = 'K_cam{}'.format(x) 155 | P_rect_X0 = 'P_rect_{}0'.format(x) 156 | 157 | Z1 = trans2homo4x4(data[P_rect_X0]) @ data['T_cam0_velo'] 158 | Z2 = trans2homo4x4(data[K_camX]) @ data[T_camX_velo] 159 | 160 | assert np.allclose(Z1, Z2) # Z1, Z2 should be the same 161 | 162 | export_ks = [ 163 | 'P_rect_00', 'P_rect_10', 'P_rect_20', 'P_rect_30', 164 | 'T_cam0_velo', 'T_cam1_velo', 'T_cam2_velo', 'T_cam3_velo', 165 | 'K_cam0', 'K_cam1', 'K_cam2', 'K_cam3', 166 | 'b_gray', 'b_rgb', 167 | 'resolution', 168 | ] 169 | 170 | calibration = dict() 171 | for k in export_ks: 172 | calibration[k] = data[k] 173 | 174 | return calibration 175 | 176 | 177 | -------------------------------------------------------------------------------- /projects/TemporalStereo/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from architecture.utils.config import CfgNode 4 | 5 | # ********************************************* CONFIG BEGIN ********************************************** # 6 | 7 | CN = CfgNode 8 | _C = CN() 9 | _C.MAX_DISP = 192 10 | _C.FRAME_IDXS = [0, -1] 11 | _C.LOG_DIR = os.path.join("./exps/") 12 | 13 | 14 | # ************************************************ DATA ************************************************ # 15 | _C.DATA = CN() 16 | 17 | _C.DATA.TRAIN = CN() 18 | _C.DATA.TRAIN.DATA_ROOT = os.path.join("./datasets/SceneFlow/Flyingthings3D") 19 | _C.DATA.TRAIN.TYPE = "SceneFlow" 20 | _C.DATA.TRAIN.ANNFILE = './splits/flyingthings3d/train.json' 21 | _C.DATA.TRAIN.HEIGHT = 512 22 | _C.DATA.TRAIN.WIDTH = 960 23 | _C.DATA.TRAIN.USE_COMMON_INTRINSICS = True 24 | _C.DATA.TRAIN.DO_SAME_LR_TRANSFORM =False 25 | _C.DATA.TRAIN.MEAN = (0.485, 0.456, 0.406) 26 | _C.DATA.TRAIN.STD = (0.229, 0.224, 0.225) 27 | _C.DATA.TRAIN.FRAME_IDXS = [0, ] 28 | _C.DATA.TRAIN.BATCH_SIZE = 8 29 | _C.DATA.TRAIN.NUM_WORKERS = 8 30 | 31 | _C.DATA.VAL = CN() 32 | _C.DATA.VAL.DATA_ROOT = os.path.join("./datasets/SceneFlow/Flyingthings3D") 33 | _C.DATA.VAL.TYPE = "SceneFlow" 34 | _C.DATA.VAL.ANNFILE = './splits/flyingthings3d/test.json' 35 | _C.DATA.VAL.HEIGHT = 544 36 | _C.DATA.VAL.WIDTH = 960 37 | _C.DATA.VAL.USE_COMMON_INTRINSICS = True 38 | _C.DATA.VAL.DO_SAME_LR_TRANSFORM = True 39 | _C.DATA.VAL.MEAN = (0.485, 0.456, 0.406) 40 | _C.DATA.VAL.STD = (0.229, 0.224, 0.225) 41 | _C.DATA.VAL.FRAME_IDXS = [0, ] 42 | _C.DATA.VAL.BATCH_SIZE = 4 43 | _C.DATA.VAL.NUM_WORKERS = 4 44 | 45 | _C.DATA.TEST = CN() 46 | _C.DATA.TEST.DATA_ROOT = os.path.join("./datasets/SceneFlow/Flyingthings3D") 47 | _C.DATA.TEST.TYPE = "SceneFlow" 48 | _C.DATA.TEST.ANNFILE = './splits/flyingthings3d/test.json' 49 | _C.DATA.TEST.HEIGHT = 544 50 | _C.DATA.TEST.WIDTH = 960 51 | _C.DATA.TEST.USE_COMMON_INTRINSICS = True 52 | _C.DATA.TEST.DO_SAME_LR_TRANSFORM =True 53 | _C.DATA.TEST.MEAN = (0.485, 0.456, 
0.406) 54 | _C.DATA.TEST.STD = (0.229, 0.224, 0.225) 55 | _C.DATA.TEST.FRAME_IDXS = [0, ] 56 | _C.DATA.TEST.BATCH_SIZE = 1 57 | _C.DATA.TEST.NUM_WORKERS = 2 58 | 59 | # ************************************************ TRAINER ************************************************ # 60 | _C.CHECKPOINT = CN() 61 | _C.CHECKPOINT.EVERY_N_TRAIN_STEPS = 0 62 | _C.CHECKPOINT.EVERY_N_EPOCHS = 1 63 | 64 | # ************************************************ TRAINER ************************************************ # 65 | _C.TRAINER = CN() 66 | _C.TRAINER.NAME = "TemporalStereo" 67 | _C.TRAINER.VERSION = "default" 68 | _C.TRAINER.NUM_GPUS = 1 69 | _C.TRAINER.NUM_NODES = 1 70 | _C.TRAINER.MAX_EPOCHS = 10 71 | _C.TRAINER.MIN_EPOCHS = 1 72 | # _C.TRAINER.MAX_STEPS = None 73 | _C.TRAINER.MIN_STEPS = 1000 74 | _C.TRAINER.PRECISION = 32 # 16bit or 32bit 75 | _C.TRAINER.AMP_LEVEL = '00' # 00, 01, 02, 03 76 | _C.TRAINER.SYNC_BATCHNORM = True 77 | _C.TRAINER.GRADIENT_CLIP_VAL = 0.1 78 | _C.TRAINER.LOG_EVERY_N_STEPS = 50 79 | _C.TRAINER.FLUSH_LOGS_EVERY_N_STEPS = 100 80 | _C.TRAINER.CHECK_VAL_EVERY_N_EPOCHS = 1 81 | _C.TRAINER.RESUME_FROM_CHECKPOINT = '' 82 | _C.TRAINER.LOAD_FROM_CHECKPOINT = '' 83 | _C.TRAINER.FAST_DEV_RUN = False 84 | 85 | 86 | # ************************************************ OPTIMIZER ************************************************ # 87 | _C.OPTIMIZER = CN() 88 | _C.OPTIMIZER.TYPE = "RMSProp" 89 | _C.OPTIMIZER.RMSPROP = CN() 90 | _C.OPTIMIZER.RMSPROP.LR = 1e-3 91 | _C.OPTIMIZER.ADAM = CN() 92 | _C.OPTIMIZER.ADAM.LR = 1e-3 93 | _C.OPTIMIZER.ADAM.BETAS = (0.9, 0.999) 94 | _C.OPTIMIZER.ADAMW = CN() 95 | _C.OPTIMIZER.ADAMW.LR = 1e-3 96 | _C.OPTIMIZER.ADAMW.BETAS = (0.9, 0.999) 97 | _C.OPTIMIZER.ADAMW.WEIGHT_DECAY = 1e-4 98 | _C.SCHEDULER = CN() 99 | _C.SCHEDULER.TYPE = "MultiStepLR" 100 | _C.SCHEDULER.MULTI_STEP_LR = CN() 101 | _C.SCHEDULER.MULTI_STEP_LR.MILESTONES = [10, 20] 102 | _C.SCHEDULER.MULTI_STEP_LR.GAMMA = 0.1 103 | _C.SCHEDULER.EXPONENTIAL_LR = CN() 104 | _C.SCHEDULER.EXPONENTIAL_LR.GAMMA = 0.9 105 | 106 | 107 | # ************************************************ MODEL ************************************************ # 108 | _C.MODEL = CN() 109 | _C.MODEL.WITH_PREVIOUS = False 110 | _C.MODEL.PREVIOUS_WITH_GRADIENT = False 111 | _C.MODEL.WITH_FLOW = False 112 | _C.MODEL.USE_LOCAL_MAP = False 113 | _C.MODEL.USE_PAST_COST = False 114 | _C.MODEL.LOCAL_MAP_SIZE = 0 115 | _C.MODEL.VIS_FEATURE = False 116 | 117 | 118 | # ----------------------------------------------- BACKBONE ---------------------------------------------- # 119 | _C.MODEL.BACKBONE = CN() 120 | _C.MODEL.BACKBONE.NAME = 'TEMPORALSTEREO' 121 | _C.MODEL.BACKBONE.IN_PLANES = 3 122 | _C.MODEL.BACKBONE.ALPHA = 1.0 123 | _C.MODEL.BACKBONE.USE_GRU = False 124 | _C.MODEL.BACKBONE.MEMORY_PERCENT = 1/8 125 | _C.MODEL.BACKBONE.NORM = 'BN' 126 | _C.MODEL.BACKBONE.ACTIVATION = 'SiLU' 127 | 128 | # ---------------------------------------------- AGGREGATION ---------------------------------------------- # 129 | _C.MODEL.AGGREGATION = CN() 130 | _C.MODEL.AGGREGATION.NAME = 'TEMPORALSTEREO' 131 | _C.MODEL.AGGREGATION.NORM = 'BN' 132 | _C.MODEL.AGGREGATION.ACTIVATION = 'SiLU' 133 | 134 | _C.MODEL.AGGREGATION.COARSE = CN() 135 | _C.MODEL.AGGREGATION.COARSE.IN_PLANES = 192 136 | _C.MODEL.AGGREGATION.COARSE.C = 16 137 | _C.MODEL.AGGREGATION.COARSE.NUM_SAMPLE = 12 138 | _C.MODEL.AGGREGATION.COARSE.DELTA = 1.0 139 | _C.MODEL.AGGREGATION.COARSE.BLOCK_COST_SCALE = 3 140 | _C.MODEL.AGGREGATION.COARSE.SPATIAL_FUSION = True 141 | 
_C.MODEL.AGGREGATION.COARSE.TOPK = 1 142 | _C.MODEL.AGGREGATION.COARSE.NORM = 'BN3d' 143 | _C.MODEL.AGGREGATION.COARSE.ACTIVATION = 'SiLU' 144 | 145 | _C.MODEL.AGGREGATION.FINE = CN() 146 | _C.MODEL.AGGREGATION.FINE.IN_PLANES = 64 147 | _C.MODEL.AGGREGATION.FINE.C = 8 148 | _C.MODEL.AGGREGATION.FINE.NUM_SAMPLE = 4 149 | _C.MODEL.AGGREGATION.FINE.DELTA = 1.0 150 | _C.MODEL.AGGREGATION.FINE.BLOCK_COST_SCALE = 3 151 | _C.MODEL.AGGREGATION.FINE.SPATIAL_FUSION = True 152 | _C.MODEL.AGGREGATION.FINE.TOPK = 1 153 | _C.MODEL.AGGREGATION.FINE.NORM = 'BN3d' 154 | _C.MODEL.AGGREGATION.FINE.ACTIVATION = 'SiLU' 155 | 156 | _C.MODEL.AGGREGATION.PRECISE = CN() 157 | _C.MODEL.AGGREGATION.PRECISE.IN_PLANES = 48 158 | _C.MODEL.AGGREGATION.PRECISE.C = 8 159 | _C.MODEL.AGGREGATION.PRECISE.NUM_SAMPLE = 4 160 | _C.MODEL.AGGREGATION.PRECISE.DELTA = 1.0 161 | _C.MODEL.AGGREGATION.PRECISE.BLOCK_COST_SCALE = 3 162 | _C.MODEL.AGGREGATION.PRECISE.TOPK = 1 163 | _C.MODEL.AGGREGATION.PRECISE.NORM = 'BN3d' 164 | _C.MODEL.AGGREGATION.PRECISE.ACTIVATION = 'SiLU' 165 | 166 | # ------------------------------------------------ LOSS ---------------------------------------------- # 167 | _C.MODEL.LOSSES = CN() 168 | _C.MODEL.LOSSES.WARSSERSTEIN_DISTANCE_LOSS = CN() 169 | _C.MODEL.LOSSES.WARSSERSTEIN_DISTANCE_LOSS.MAX_DISP = 192 170 | _C.MODEL.LOSSES.WARSSERSTEIN_DISTANCE_LOSS.START_DISP = 0 171 | _C.MODEL.LOSSES.WARSSERSTEIN_DISTANCE_LOSS.GLOBAL_WEIGHT = 1.0 172 | _C.MODEL.LOSSES.WARSSERSTEIN_DISTANCE_LOSS.WEIGHTS = [1.2, 0.3, 0.1] 173 | _C.MODEL.LOSSES.WARSSERSTEIN_DISTANCE_LOSS.SPARSE = False 174 | 175 | _C.MODEL.LOSSES.SMOOTH_L1_LOSS = CN() 176 | _C.MODEL.LOSSES.SMOOTH_L1_LOSS.MAX_DISP = 192 177 | _C.MODEL.LOSSES.SMOOTH_L1_LOSS.START_DISP = 0 178 | _C.MODEL.LOSSES.SMOOTH_L1_LOSS.GLOBAL_WEIGHT = 1.0 179 | _C.MODEL.LOSSES.SMOOTH_L1_LOSS.WEIGHTS = [1.0, 0.7, 0.5] 180 | _C.MODEL.LOSSES.SMOOTH_L1_LOSS.SPARSE = False 181 | 182 | # ************************************************ VAL ************************************************ # 183 | _C.VAL = CN() 184 | _C.VAL.VIS_INTERVAL = 8 185 | _C.VAL.VIS_BATCH_INDEX = 4 186 | _C.VAL.LOWERBOUND = 0 187 | _C.VAL.UPPERBOUND = 192 188 | _C.VAL.DO_OCCLUSION_EVALUATION = True 189 | _C.VAL.EVAL_DISPARITY_IDS = [0, 1, 2, 3] 190 | 191 | 192 | def get_parser(): 193 | parser = argparse.ArgumentParser(description="TemporalStereo Training") 194 | parser.add_argument("--config-file", default="", metavar="FILE", 195 | help="path to config file") 196 | parser.add_argument( 197 | "opts", 198 | help="Modify config options using the command-line", 199 | default=None, 200 | nargs=argparse.REMAINDER, 201 | ) 202 | 203 | return parser 204 | 205 | def get_cfg(args): 206 | cfg = _C.clone() 207 | if args.config_file: 208 | cfg.merge_from_file(args.config_file) 209 | cfg.merge_from_list(args.opts) 210 | cfg.freeze() 211 | return cfg 212 | 213 | -------------------------------------------------------------------------------- /architecture/modeling/aggregation/utils/raft_corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | class CorrBlock: 5 | # for disparity, i.e. 
stereo matching 6 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 7 | self.num_levels = num_levels 8 | self.radius = radius 9 | self.corr_pyramid = [] 10 | 11 | # generate grid index 12 | b, _, h, w = fmap1.shape 13 | self.coord_w = torch.arange(w).view(w, 1, 1, 1).repeat(b*h, 1, radius*2+1, 1).to(fmap1.device).float() 14 | self.coord_h = torch.ones_like(self.coord_w) * (-1.0) 15 | # all pairs correlation 16 | # [B*H*W, 1, 1, W] 17 | corr = CorrBlock.corr(fmap1, fmap2) 18 | 19 | self.corr_pyramid.append(corr) 20 | for i in range(self.num_levels-1): 21 | corr = F.avg_pool2d(corr, (1, 2), stride=(1, 2)) 22 | self.corr_pyramid.append(corr) 23 | 24 | def __call__(self, disp): 25 | r = self.radius 26 | b, c, h, w = disp.shape 27 | assert c == 1, "{} got".format(c) 28 | 29 | # left image's disparity map 30 | coord_w = self.coord_w.float() - disp.permute(0, 2, 3, 1).contiguous().reshape(b*h*w, 1).view(b*h*w, 1, 1, 1) 31 | # [B*H*W, 1, 1, 2] 32 | coords = torch.cat((coord_w, self.coord_h), dim=-1) 33 | 34 | out_pyramid = [] 35 | for i in range(self.num_levels): 36 | # [B*H*W, 1, 1, W] 37 | corr = self.corr_pyramid[i] 38 | # [1, 1, 2*r+1, 1] 39 | delta_w = torch.linspace(-r, r, 2*r+1).to(disp.device).view(1, 1, 2*r+1, 1) 40 | 41 | coords_lvl = coords / 2**i 42 | # [B*H*W, 1, 2*r+1, 2] 43 | coords_lvl[:, :, :, 0:1] = 2 * (coords_lvl[:, :, :, 0:1] + delta_w) / (w - 1) - 1 44 | 45 | # [B*H*W, 1, 1, 2*r+1] 46 | corr = F.grid_sample(corr, coords_lvl, mode='bilinear', padding_mode='zeros') 47 | # [B, H, W, 2*r+1] 48 | corr = corr.reshape(b, h, w, -1) 49 | out_pyramid.append(corr) 50 | 51 | out = torch.cat(out_pyramid, dim=-1) 52 | # [B, 4*(2*r+1), H, W] 53 | return out.permute(0, 3, 1, 2).contiguous().float() 54 | 55 | @staticmethod 56 | def corr(fmap1, fmap2): 57 | batch, dim, ht, wd = fmap1.shape 58 | # [B, H, W, C] 59 | fmap1 = fmap1.permute(0, 2, 3, 1) 60 | # [B, H, C, W] 61 | fmap2 = fmap2.permute(0, 2, 1, 3) 62 | 63 | # [B, H, W, W] 64 | corr = torch.matmul(fmap1, fmap2) 65 | # [B*H*W, 1, 1, W] 66 | corr = corr.reshape(batch*ht*wd, wd).unsqueeze(1).unsqueeze(1) 67 | return corr / torch.sqrt(torch.tensor(dim).float()) 68 | 69 | 70 | 71 | class FlowCorrBlock: 72 | # for optical flow 73 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 74 | self.num_levels = num_levels 75 | self.radius = radius 76 | self.corr_pyramid = [] 77 | 78 | # all pairs correlation 79 | corr = FlowCorrBlock.corr(fmap1, fmap2) 80 | 81 | batch, h1, w1, dim, h2, w2 = corr.shape 82 | corr = corr.reshape(batch * h1 * w1, dim, h2, w2) 83 | 84 | self.corr_pyramid.append(corr) 85 | for i in range(self.num_levels - 1): 86 | corr = F.avg_pool2d(corr, 2, stride=2) 87 | self.corr_pyramid.append(corr) 88 | 89 | def __call__(self, coords): 90 | r = self.radius 91 | coords = coords.permute(0, 2, 3, 1) 92 | batch, h1, w1, _ = coords.shape 93 | 94 | out_pyramid = [] 95 | for i in range(self.num_levels): 96 | corr = self.corr_pyramid[i] 97 | dx = torch.linspace(-r, r, 2 * r + 1, device=coords.device) 98 | dy = torch.linspace(-r, r, 2 * r + 1, device=coords.device) 99 | delta = torch.stack(torch.meshgrid(dy, dx), dim=-1) 100 | 101 | centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2 ** i 102 | delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) 103 | coords_lvl = centroid_lvl + delta_lvl 104 | 105 | corr = bilinear_sampler(corr, coords_lvl) 106 | corr = corr.view(batch, h1, w1, -1) 107 | out_pyramid.append(corr) 108 | 109 | out = torch.cat(out_pyramid, dim=-1) 110 | return out.permute(0, 3, 1, 2).contiguous().float() 
111 | 112 | @staticmethod 113 | def corr(fmap1, fmap2): 114 | batch, dim, ht, wd = fmap1.shape 115 | fmap1 = fmap1.view(batch, dim, ht * wd) 116 | fmap2 = fmap2.view(batch, dim, ht * wd) 117 | 118 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2) 119 | x2 = torch.matmul(fmap1.transpose(1, 2), fmap1) 120 | y2 = torch.matmul(fmap2.transpose(1, 2), fmap2) 121 | corr = (x2 - 2*corr + y2) 122 | corr = corr.view(batch, ht, wd, 1, ht, wd) 123 | return corr / torch.sqrt(torch.tensor(dim).float()) 124 | 125 | @staticmethod 126 | def init_flow(size, device, flow_init=None): 127 | assert len(size) == 4, "Excepted size with [B, C, H, W], but {} got".format(size) 128 | b, c, h, w = size 129 | 130 | 131 | """ construct pixel coordination in an image""" 132 | # [1, H, W] copy 0-width for h times : x coord 133 | x_range = torch.arange(0, w, device=device, dtype=torch.float).view(1, 1, 1, w).expand(b, 1, h, w) 134 | # [1, H, W] copy 0-height for w times : y coord 135 | y_range = torch.arange(0, h, device=device, dtype=torch.float).view(1, 1, h, 1).expand(b, 1, h, w) 136 | # [b, 2, h, w] 137 | pixel_coord = torch.cat((x_range, y_range), dim=1) 138 | 139 | ref_coord = pixel_coord.detach() 140 | tgt_coord = pixel_coord.detach() 141 | if flow_init is not None: 142 | tgt_coord = tgt_coord + flow_init 143 | 144 | return ref_coord, tgt_coord 145 | 146 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 147 | """ Wrapper for grid_sample, uses pixel coordinates """ 148 | H, W = img.shape[-2:] 149 | xgrid, ygrid = coords.split([1,1], dim=-1) 150 | xgrid = 2*xgrid/(W-1) - 1 151 | ygrid = 2*ygrid/(H-1) - 1 152 | 153 | grid = torch.cat([xgrid, ygrid], dim=-1) 154 | img = F.grid_sample(img, grid, mode=mode, align_corners=True) 155 | 156 | if mask: 157 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 158 | return img, mask.float() 159 | 160 | return img 161 | 162 | 163 | if __name__ == '__main__': 164 | print("Test CorrBlock...") 165 | import time 166 | 167 | mode = 'flow' 168 | # mode = 'stereo' 169 | 170 | if mode == 'flow': 171 | iters = 50 172 | scale = 16 173 | B, C, H, W = 1, 80, 384, 1248 174 | device = 'cuda:0' 175 | 176 | prev = torch.randn(B, C, H//scale, W//scale, device=device) 177 | curr = torch.randn(B, C, H//scale, W//scale, device=device) 178 | 179 | ref_coord, tgt_coord = FlowCorrBlock.init_flow(size=(B, C, H//scale, W//scale), device=device, flow_init=None) 180 | coords = tgt_coord 181 | 182 | start_time = time.time() 183 | 184 | for i in range(iters): 185 | with torch.no_grad(): 186 | corr_fn = FlowCorrBlock(prev, curr, num_levels=4, radius=4) 187 | cost = corr_fn(coords) 188 | 189 | torch.cuda.synchronize(device) 190 | end_time = time.time() 191 | avg_time = (end_time - start_time) / iters 192 | 193 | 194 | print('{} reference forward once takes {:.4f}ms, i.e. 
{:.2f}fps'.format('FlowCorrBlock', avg_time * 1000, (1 / avg_time))) 195 | 196 | elif mode=='stereo': 197 | iters = 50 198 | scale = 4 199 | B, C, H, W = 1, 32, 384, 1248 200 | device = 'cuda:0' 201 | 202 | left = torch.randn(B, C, H // scale, W // scale, device=device) 203 | right = torch.randn(B, C, H // scale, W // scale, device=device) 204 | 205 | disp = torch.randn(B, 1, H//scale, W//scale, device=device) * 192 206 | 207 | start_time = time.time() 208 | 209 | for i in range(iters): 210 | with torch.no_grad(): 211 | corr_fn = CorrBlock(left, right, num_levels=4, radius=4) 212 | cost = corr_fn(disp) 213 | 214 | torch.cuda.synchronize(device) 215 | end_time = time.time() 216 | avg_time = (end_time - start_time) / iters 217 | 218 | print('{} reference forward once takes {:.4f}ms, i.e. {:.2f}fps'.format('CorrBlock', avg_time * 1000, 219 | (1 / avg_time))) 220 | 221 | print("Done!") 222 | 223 | """ 224 | RAFT Flow: at scale=8, reference forward once takes 2.1072ms, i.e. 474.56fps 225 | RAFT Flow: at scale=8, reference forward once takes 1.1996ms, i.e. 833.65fps 226 | RAFT Stereo: at scale=4, reference forward once takes 1.7301ms, i.e. 578.01fps 227 | """ 228 | -------------------------------------------------------------------------------- /architecture/utils/visualization/disparity_colormap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | def disp_map(disp): 6 | """ 7 | Based on color histogram, convert the gray disp into color disp map. 8 | The histogram consists of 7 bins, value of each is e.g. [114.0, 185.0, 114.0, 174.0, 114.0, 185.0, 114.0] 9 | Accumulate each bin, named cbins, and scale it to [0,1], e.g. [0.114, 0.299, 0.413, 0.587, 0.701, 0.886, 1.0] 10 | For each value in disp, we have to find which bin it belongs to 11 | Therefore, we have to compare it with every value in cbins 12 | Finally, we have to get the ratio of it accounts for the bin, and then we can interpolate it with the histogram map 13 | For example, 0.780 belongs to the 5th bin, the ratio is (0.780-0.701)/0.114, 14 | then we can interpolate it into 3 channel with the 5th [0, 1, 0] and 6th [0, 1, 1] channel-map 15 | Inputs: 16 | disp: numpy array, disparity gray map in (Height * Width, 1) layout, value range [0,1] 17 | Outputs: 18 | disp: numpy array, disparity color map in (Height * Width, 3) layout, value range [0,1] 19 | """ 20 | map = np.array([ 21 | [0, 0, 0, 114], 22 | [0, 0, 1, 185], 23 | [1, 0, 0, 114], 24 | [1, 0, 1, 174], 25 | [0, 1, 0, 114], 26 | [0, 1, 1, 185], 27 | [1, 1, 0, 114], 28 | [1, 1, 1, 0] 29 | ]) 30 | # grab the last element of each column and convert into float type, e.g. 114 -> 114.0 31 | # the final result: [114.0, 185.0, 114.0, 174.0, 114.0, 185.0, 114.0] 32 | bins = map[0:map.shape[0] - 1, map.shape[1] - 1].astype(float) 33 | 34 | # reshape the bins from [7] into [7,1] 35 | bins = bins.reshape((bins.shape[0], 1)) 36 | 37 | # accumulate element in bins, and get [114.0, 299.0, 413.0, 587.0, 701.0, 886.0, 1000.0] 38 | cbins = np.cumsum(bins) 39 | 40 | # divide the last element in cbins, e.g. 1000.0 41 | bins = bins / cbins[cbins.shape[0] - 1] 42 | 43 | # divide the last element of cbins, e.g. 
1000.0, and reshape it, final shape [6,1] 44 | cbins = cbins[0:cbins.shape[0] - 1] / cbins[cbins.shape[0] - 1] 45 | cbins = cbins.reshape((cbins.shape[0], 1)) 46 | 47 | # transpose disp array, and repeat disp 6 times in axis-0, 1 times in axis-1, final shape=[6, Height*Width] 48 | ind = np.tile(disp.T, (6, 1)) 49 | tmp = np.tile(cbins, (1, disp.size)) 50 | 51 | # get the number of disp's elements bigger than each value in cbins, and sum up the 6 numbers 52 | b = (ind > tmp).astype(int) 53 | s = np.sum(b, axis=0) 54 | 55 | bins = 1 / bins 56 | 57 | # add an element 0 ahead of cbins, [0, cbins] 58 | t = cbins 59 | cbins = np.zeros((cbins.size + 1, 1)) 60 | cbins[1:] = t 61 | 62 | # get the ratio and interpolate it 63 | disp = (disp - cbins[s]) * bins[s] 64 | disp = map[s, 0:3] * np.tile(1 - disp, (1, 3)) + map[s + 1, 0:3] * np.tile(disp, (1, 3)) 65 | 66 | return disp 67 | 68 | 69 | def disp_to_color(disp, max_disp=None): 70 | """ 71 | Transfer disparity map to color map 72 | Args: 73 | disp (numpy.array): disparity map in (Height, Width) layout, value range [0, inf] 74 | max_disp (int): max disparity, optionally specifies the scaling factor 75 | Returns: 76 | disparity color map (numpy.array): disparity map in (Height, Width, 3) layout, 77 | range [0,1] 78 | """ 79 | # grab the disp shape(Height, Width) 80 | h, w = disp.shape 81 | 82 | # if max_disp not provided, set as the max value in disp 83 | if max_disp is None: 84 | max_disp = np.max(disp) 85 | 86 | # scale the disp to [0,1] by max_disp 87 | disp = disp.copy() / max_disp 88 | 89 | # reshape the disparity to [Height*Width, 1] 90 | disp = disp.reshape((h * w, 1)) 91 | 92 | # convert to color map, with shape [Height*Width, 3] 93 | disp = disp_map(disp) 94 | 95 | # convert to RGB-mode 96 | disp = disp.reshape((h, w, 3)) 97 | 98 | return disp 99 | 100 | 101 | 102 | def disp_err_to_color(disp_est, disp_gt): 103 | """ 104 | Calculate the error map between disparity estimation and disparity ground-truth 105 | hot color -> big error, cold color -> small error 106 | Args: 107 | disp_est (numpy.array): estimated disparity map 108 | in (Height, Width) layout, range [0,inf] 109 | disp_gt (numpy.array): ground truth disparity map 110 | in (Height, Width) layout, range [0,inf] 111 | Returns: 112 | disp_err (numpy.array): disparity error map 113 | in (Height, Width, 3) layout, range [0,1] 114 | """ 115 | """ matlab 116 | function D_err = disp_error_image (D_gt,D_est,tau,dilate_radius) 117 | if nargin==3 118 | dilate_radius = 1; 119 | end 120 | [E,D_val] = disp_error_map (D_gt,D_est); 121 | E = min(E/tau(1),(E./abs(D_gt))/tau(2)); 122 | cols = error_colormap(); 123 | D_err = zeros([size(D_gt) 3]); 124 | for i=1:size(cols,1) 125 | [v,u] = find(D_val > 0 & E >= cols(i,1) & E <= cols(i,2)); 126 | D_err(sub2ind(size(D_err),v,u,1*ones(length(v),1))) = cols(i,3); 127 | D_err(sub2ind(size(D_err),v,u,2*ones(length(v),1))) = cols(i,4); 128 | D_err(sub2ind(size(D_err),v,u,3*ones(length(v),1))) = cols(i,5); 129 | end 130 | D_err = imdilate(D_err,strel('disk',dilate_radius)); 131 | """ 132 | # error color map with interval (0, 0.1875, 0.375, 0.75, 1.5, 3, 6, 12, 24, 48, inf)/3.0 133 | # different interval corresponds to different 3-channel projection 134 | cols = np.array( 135 | [ 136 | [0 / 3.0, 0.1875 / 3.0, 49, 54, 149], 137 | [0.1875 / 3.0, 0.375 / 3.0, 69, 117, 180], 138 | [0.375 / 3.0, 0.75 / 3.0, 116, 173, 209], 139 | [0.75 / 3.0, 1.5 / 3.0, 171, 217, 233], 140 | [1.5 / 3.0, 3 / 3.0, 224, 243, 248], 141 | [3 / 3.0, 6 / 3.0, 254, 224, 144], 142 | [6 / 3.0, 
12 / 3.0, 253, 174, 97], 143 | [12 / 3.0, 24 / 3.0, 244, 109, 67], 144 | [24 / 3.0, 48 / 3.0, 215, 48, 39], 145 | [48 / 3.0, float("inf"), 165, 0, 38] 146 | ] 147 | ) 148 | 149 | # [0, 1] -> [0, 255.0] 150 | disp_est = disp_est.copy() * 255.0 151 | disp_gt = disp_gt.copy() * 255.0 152 | # get the error (<3px or <5%) map 153 | tau = [3.0, 0.05] 154 | E = np.abs(disp_est - disp_gt) 155 | 156 | not_empty = disp_gt > 0.0 157 | tmp = np.zeros_like(disp_gt) 158 | tmp[not_empty] = E[not_empty] / disp_gt[not_empty] / tau[1] 159 | E = np.minimum(E / tau[0], tmp) 160 | 161 | h, w = disp_gt.shape 162 | err_im = np.zeros(shape=(h, w, 3)).astype(np.uint8) 163 | for col in cols: 164 | y_x = not_empty & (E >= col[0]) & (E <= col[1]) 165 | err_im[y_x] = col[2:] 166 | 167 | # value range [0, 1], shape in [H, W 3] 168 | err_im = err_im.astype(np.float64) / 255.0 169 | 170 | return err_im 171 | 172 | def revalue(map, lower, upper, start, scale): 173 | mask = (map > lower) & (map <= upper) 174 | if np.sum(mask) >= 1.0: 175 | mn, mx = map[mask].min(), map[mask].max() 176 | map[mask] = ((map[mask] - mn) / (mx -mn + 1e-7)) * scale + start 177 | 178 | return map 179 | 180 | def disp_err_to_colorbar(est, gt, with_bar=False, cmap='jet'): 181 | error_bar_height = 50 182 | valid = gt > 0 183 | error_map = np.abs(est - gt) * valid 184 | h, w= error_map.shape 185 | 186 | maxvalue = error_map.max() 187 | # meanvalue = error_map.mean() 188 | # breakpoints = np.array([0, max(1, 1*meanvalue), max(4, 4*meanvalue), max(8, min(8*meanvalue, maxvalue/2)) , max(12, maxvalue)]) 189 | breakpoints = np.array([0, 1, 2, 4, 12, 16, max(192, maxvalue)]) 190 | points = np.array([0, 0.25, 0.38, 0.66, 0.83, 0.95, 1]) 191 | num_bins = np.array([0, w//8, w//8, w//4, w//4, w//8, w - (w//4 + w//4 + w//8 + w//8 + w//8)]) 192 | acc_num_bins = np.cumsum(num_bins) 193 | 194 | for i in range(1, len(breakpoints)): 195 | scale = points[i] - points[i-1] 196 | start = points[i-1] 197 | lower = breakpoints[i-1] 198 | upper = breakpoints[i] 199 | error_map = revalue(error_map, lower, upper, start, scale) 200 | 201 | # [0, 1], [H, W, 3] 202 | error_map = plt.cm.get_cmap(cmap)(error_map)[:, :, :3] 203 | 204 | if not with_bar: 205 | return error_map 206 | 207 | error_bar = np.array([]) 208 | for i in range(1, len(num_bins)): 209 | error_bar = np.concatenate((error_bar, np.linspace(points[i-1], points[i], num_bins[i]))) 210 | 211 | error_bar = np.repeat(error_bar, error_bar_height).reshape(w, error_bar_height).transpose(1, 0) # [error_bar_height, w] 212 | error_bar_map = plt.cm.get_cmap(cmap)(error_bar)[:, :, :3] 213 | plt.xticks(ticks=acc_num_bins, labels=breakpoints.astype(np.int32)) 214 | # plt.axis('off') 215 | 216 | # [0, 1], [H, W, 3] 217 | error_map = np.concatenate((error_map, error_bar_map), axis=0) 218 | 219 | return error_map 220 | --------------------------------------------------------------------------------
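
A minimal usage sketch for the colormap helpers above (architecture/utils/visualization/flow_colormap.py and architecture/utils/visualization/disparity_colormap.py). It assumes the repository root is on PYTHONPATH and that the random arrays stand in for real network outputs; pred_disp, gt_disp and flow are hypothetical names, not variables from the repository.

import numpy as np
import matplotlib.pyplot as plt

from architecture.utils.visualization.disparity_colormap import disp_to_color, disp_err_to_colorbar
from architecture.utils.visualization.flow_colormap import flow_to_color

# hypothetical stand-ins for a predicted disparity map, its ground truth, and an optical flow map
pred_disp = np.random.uniform(0.0, 192.0, size=(384, 1248)).astype(np.float32)      # [H, W]
gt_disp   = np.random.uniform(0.0, 192.0, size=(384, 1248)).astype(np.float32)      # [H, W]
flow      = np.random.uniform(-20.0, 20.0, size=(384, 1248, 2)).astype(np.float32)  # [H, W, 2]

disp_vis = disp_to_color(pred_disp, max_disp=192)    # [H, W, 3], values in [0, 1]
err_vis  = disp_err_to_colorbar(pred_disp, gt_disp)  # [H, W, 3], colormapped |est - gt|
flow_vis = flow_to_color(flow)                       # [H, W, 3], MiddleBury color code

plt.imsave('disp_color.png', disp_vis)
plt.imsave('disp_err.png', err_vis)
plt.imsave('flow_color.png', flow_vis)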
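
Similarly, a sketch of how load_calib from architecture/data/utils/calibration/kitti_calib.py might be called. The two file names are the standard KITTI raw calibration files and are assumptions about a local dataset layout; the sketch also assumes the calibration package imports cleanly from the repository root.

from architecture.data.utils.calibration.kitti_calib import load_calib

# KITTI raw / KITTI 2015 style: cam-to-cam and velo-to-cam calibrations live in two files
calib = load_calib('calib_cam_to_cam.txt', velo2cam_filepath='calib_velo_to_cam.txt')

K_left      = calib['K_cam2']        # 3x3 intrinsics of the left color camera
baseline    = calib['b_rgb']         # stereo baseline of the color pair, in meters
T_cam2_velo = calib['T_cam2_velo']   # 4x4 velodyne -> rectified cam2 transform
print(K_left, baseline, calib['resolution'])

# KITTI visual odometry style: everything is stored in a single file (P0..P3, Tr)
odo_calib = load_calib('calib.txt')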
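
Finally, a sketch of how the configuration defined in projects/TemporalStereo/config.py is typically consumed. It assumes the snippet is run from inside projects/TemporalStereo/ with the repository root on PYTHONPATH (so config.py, architecture.utils.config and the configs/ folder all resolve), and that CfgNode supports the yacs-style merge calls used by get_cfg itself; the override key/value pair is illustrative.

from config import get_parser, get_cfg

# parse a YAML config file plus command-line overrides
args = get_parser().parse_args(['--config-file', 'configs/sceneflow.yaml',
                                'DATA.TRAIN.BATCH_SIZE', '4'])
cfg = get_cfg(args)   # clone defaults, merge the YAML file, merge CLI overrides, then freeze

print(cfg.MODEL.BACKBONE.NAME, cfg.DATA.TRAIN.BATCH_SIZE, cfg.MAX_DISP)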