├── LICENSE ├── README.md ├── assets ├── megadepth_test_1500_scene_info │ ├── 0015_0.1_0.3.npz │ ├── 0015_0.3_0.5.npz │ ├── 0022_0.1_0.3.npz │ ├── 0022_0.3_0.5.npz │ ├── 0022_0.5_0.7.npz │ ├── megadepth_test_1500.txt │ └── megadepth_test_pairs_with_gt.txt ├── phototourism_sample_images │ ├── london_bridge_19481797_2295892421.jpg │ ├── london_bridge_49190386_5209386933.jpg │ ├── london_bridge_78916675_4568141288.jpg │ ├── london_bridge_94185272_3874562886.jpg │ ├── piazza_san_marco_06795901_3725050516.jpg │ ├── piazza_san_marco_15148634_5228701572.jpg │ ├── piazza_san_marco_18627786_5929294590.jpg │ ├── piazza_san_marco_43351518_2659980686.jpg │ ├── piazza_san_marco_58751010_4849458397.jpg │ ├── st_pauls_cathedral_30776973_2635313996.jpg │ ├── st_pauls_cathedral_37347628_10902811376.jpg │ ├── united_states_capitol_26757027_6717084061.jpg │ └── united_states_capitol_98169888_3347710852.jpg ├── scannet_sample_images │ ├── scene0711_00_frame-001680.jpg │ ├── scene0711_00_frame-001995.jpg │ ├── scene0713_00_frame-001320.jpg │ ├── scene0713_00_frame-002025.jpg │ ├── scene0721_00_frame-000375.jpg │ ├── scene0721_00_frame-002745.jpg │ ├── scene0722_00_frame-000045.jpg │ ├── scene0722_00_frame-000735.jpg │ ├── scene0726_00_frame-000135.jpg │ ├── scene0726_00_frame-000210.jpg │ ├── scene0737_00_frame-000930.jpg │ ├── scene0737_00_frame-001095.jpg │ ├── scene0738_00_frame-000885.jpg │ ├── scene0738_00_frame-001065.jpg │ ├── scene0743_00_frame-000000.jpg │ ├── scene0743_00_frame-001275.jpg │ ├── scene0744_00_frame-000585.jpg │ ├── scene0744_00_frame-002310.jpg │ ├── scene0747_00_frame-000000.jpg │ ├── scene0747_00_frame-001530.jpg │ ├── scene0752_00_frame-000075.jpg │ ├── scene0752_00_frame-001440.jpg │ ├── scene0755_00_frame-000120.jpg │ ├── scene0755_00_frame-002055.jpg │ ├── scene0758_00_frame-000165.jpg │ ├── scene0758_00_frame-000510.jpg │ ├── scene0768_00_frame-001095.jpg │ ├── scene0768_00_frame-003435.jpg │ ├── scene0806_00_frame-000225.jpg │ └── scene0806_00_frame-001095.jpg ├── scannet_test_1500 │ ├── intrinsics.npz │ ├── scannet_test.txt │ ├── statistics.json │ └── test.npz └── yfcc_test_4000 │ ├── 0000.npz │ └── yfcc_test_4000.txt ├── configs ├── data │ ├── __init__.py │ ├── base.py │ ├── debug │ │ └── .gitignore │ ├── megadepth_test_1500.py │ ├── megadepth_trainval_640.py │ ├── megadepth_trainval_832.py │ ├── scannet_test_1500.py │ ├── scannet_trainval.py │ └── yfcc_test_4000.py └── loftr │ ├── indoor │ ├── buggy_pos_enc │ │ ├── loftr_ds.py │ │ ├── loftr_ds_dense.py │ │ ├── loftr_ot.py │ │ └── loftr_ot_dense.py │ ├── debug │ │ └── .gitignore │ ├── loftr_ds.py │ ├── loftr_ds_dense.py │ ├── loftr_ds_quadtree.py │ ├── loftr_ot.py │ ├── loftr_ot_dense.py │ └── scannet │ │ ├── loftr_ds_eval.py │ │ ├── loftr_ds_eval_new.py │ │ └── loftr_ds_quadtree_eval.py │ └── outdoor │ ├── __pycache__ │ └── loftr_ds_quadtree.cpython-38.pyc │ ├── buggy_pos_enc │ ├── loftr_ds.py │ ├── loftr_ds_dense.py │ ├── loftr_ot.py │ └── loftr_ot_dense.py │ ├── debug │ └── .gitignore │ ├── loftr_ds.py │ ├── loftr_ds_dense.py │ ├── loftr_ds_quadtree.py │ ├── loftr_ot.py │ └── loftr_ot_dense.py ├── environment.yaml ├── images ├── .DS_Store └── main_figure_curvature.svg ├── requirements.txt ├── scripts └── reproduce_test │ ├── indoor_ds_quadtree_with_cse.sh │ ├── outdoor_ds_quadtree_with_cse_MEGA.sh │ └── outdoor_ds_quadtree_with_cse_YFCC.sh ├── src ├── DPT │ ├── curvature_extraction.py │ ├── depth_estimation.py │ ├── midas │ │ ├── __pycache__ │ │ │ ├── base_model.cpython-38.pyc │ │ │ ├── blocks.cpython-38.pyc │ │ │ ├── 
dpt_depth.cpython-38.pyc │ │ │ ├── midas_loss.cpython-38.pyc │ │ │ ├── transforms.cpython-38.pyc │ │ │ └── vit.cpython-38.pyc │ │ ├── base_model.py │ │ ├── blocks.py │ │ ├── dpt_depth.py │ │ ├── midas_loss.py │ │ ├── midas_net.py │ │ ├── midas_net_custom.py │ │ ├── transforms.py │ │ └── vit.py │ └── utils.py ├── __init__.py ├── config │ └── default.py ├── datasets │ ├── megadepth.py │ ├── sampler.py │ ├── scannet.py │ └── yfcc.py ├── lightning │ ├── data.py │ └── lightning_loftr.py ├── loftr │ ├── __init__.py │ ├── backbone │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ └── resnet_fpn.cpython-38.pyc │ │ └── resnet_fpn.py │ ├── loftr.py │ ├── loftr_module │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── fine_preprocess.cpython-38.pyc │ │ │ ├── linear_attention.cpython-38.pyc │ │ │ ├── quadtree_attention.cpython-38.pyc │ │ │ └── transformer.cpython-38.pyc │ │ ├── fine_preprocess.py │ │ ├── linear_attention.py │ │ ├── quadtree_attention.py │ │ └── transformer.py │ └── utils │ │ ├── __pycache__ │ │ ├── coarse_matching.cpython-38.pyc │ │ ├── cvpr_ds_config.cpython-38.pyc │ │ ├── fine_matching.cpython-38.pyc │ │ ├── geometry.cpython-38.pyc │ │ ├── position_encoding.cpython-38.pyc │ │ └── supervision.cpython-38.pyc │ │ ├── coarse_matching.py │ │ ├── cvpr_ds_config.py │ │ ├── fine_matching.py │ │ ├── geometry.py │ │ ├── position_encoding.py │ │ └── supervision.py ├── losses │ └── loftr_loss.py ├── optimizers │ └── __init__.py └── utils │ ├── augment.py │ ├── comm.py │ ├── dataloader.py │ ├── dataset.py │ ├── metrics.py │ ├── misc.py │ ├── plotting.py │ └── profiler.py └── test.py /README.md: -------------------------------------------------------------------------------- 1 | # [ICCV 2023] [Guiding Local Feature Matching with Surface Curvature](https://openaccess.thecvf.com/content/ICCV2023/papers/Wang_Guiding_Local_Feature_Matching_with_Surface_Curvature_ICCV_2023_paper.pdf) 2 | [Shuzhe Wang](https://scholar.google.com/citations?user=Kzq9fl4AAAAJ&hl=en&oi=ao), [Juho Kannala](https://scholar.google.com/citations?user=c4mWQPQAAAAJ&hl=en), [Marc Pollefeys](https://scholar.google.com/citations?user=YYH0BjEAAAAJ&hl=en), [Daniel Barath](https://scholar.google.com/citations?user=U9-D8DYAAAAJ&hl=en) 3 | 4 | We propose a new method, called curvature similarity extractor (CSE), for improving local feature matching across images. CSE calculates the curvature of the local 3D surface patch for each detected feature point in a viewpoint-invariant manner via fitting quadrics to predicted monocular depth maps. This curvature is then leveraged as an additional signal in feature matching with off-the-shelf matchers like SuperGlue and LoFTR. Additionally, CSE enables end-to-end joint training by connecting the matcher and depth predictor networks. Our experiments demonstrate on large-scale real-world datasets that CSE consistently improves the accuracy of state-of-the-art methods. Fine-tuning the depth prediction network further enhances the accuracy. The proposed approach achieves state-of-the-art results on the ScanNet dataset, showcasing the effectiveness of incorporating 3D geometric information into feature matching. 5 | 6 | ![](images/main_figure_curvature.svg) 7 | 8 | ## Environment 9 | 10 | Our curvature similarity extractor is an add-on component for advanced matchers. Here we consider the [QuadTree](https://github.com/Tangshitao/QuadTreeAttention) as the matcher and [DPT](https://github.com/isl-org/DPT) for the depth estimation. 
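
To give a concrete, self-contained illustration of the kind of quantity CSE extracts, the sketch below fits a simple quadric `z = h(x, y)` to a local 3D patch and reads off its Gaussian and mean curvature at the patch centre. This simplified Monge-patch version is for intuition only (the helper name and parameterisation are ours, not the repository's API); the actual batched, viewpoint-invariant quadric fitting lives in `src/DPT/curvature_extraction.py`.

```
import torch

def patch_curvature(points):
    """points: (N, 3) 3D points back-projected from a local depth patch.
    Fits z = a*x^2 + b*x*y + c*y^2 + d*x + e*y + f by least squares and
    returns the Gaussian (K) and mean (H) curvature at the patch centre."""
    p = points - points.mean(dim=0)  # centre the patch
    x, y, z = p[:, 0], p[:, 1], p[:, 2]
    A = torch.stack([x * x, x * y, y * y, x, y, torch.ones_like(x)], dim=1)
    a, b, c, d, e, _ = torch.linalg.lstsq(A, z.unsqueeze(1)).solution.squeeze(1)
    # Curvatures of the Monge patch z = h(x, y), evaluated at the origin
    hx, hy, hxx, hxy, hyy = d, e, 2 * a, b, 2 * c
    denom = 1 + hx ** 2 + hy ** 2
    K = (hxx * hyy - hxy ** 2) / denom ** 2
    H = ((1 + hy ** 2) * hxx - 2 * hx * hy * hxy + (1 + hx ** 2) * hyy) / (2 * denom ** 1.5)
    return K, H

# Sanity check: a patch sampled from a sphere of radius 2 gives K ~ 1/4 and |H| ~ 1/2.
t = torch.linspace(-0.3, 0.3, 7)
gx, gy = t.repeat_interleave(len(t)), t.repeat(len(t))
gz = torch.sqrt(4.0 - gx ** 2 - gy ** 2)
print(patch_curvature(torch.stack([gx, gy, gz], dim=1)))
```

In the full pipeline, such curvature values are computed per feature point from the DPT depth maps of both images and injected into the matcher as an additional similarity signal.
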
Please set up the [QuadTree](https://github.com/Tangshitao/QuadTreeAttention) environment with the following commands: 11 | 12 | ``` 13 | git clone git@github.com:Tangshitao/QuadTreeAttention.git 14 | cd QuadTreeAttention && python setup.py install 15 | ``` 16 | 17 | Download our CSE module and set up the environment with the following commands: 18 | 19 | ``` 20 | cd .. 21 | git clone git@github.com:AaltoVision/surface-curvature-estimator.git 22 | cd surface-curvature-estimator 23 | conda env create -f environment.yaml 24 | conda activate CSE 25 | ``` 26 | 27 | ## Dataset 28 | 29 | For the [MegaDepth](https://www.cs.cornell.edu/projects/megadepth/) and [ScanNet](https://github.com/ScanNet/ScanNet#scannet-data) datasets, please refer to [LoFTR](https://github.com/zju3dv/LoFTR) for dataset setup. For [YFCC100M](), you can use [OANet](https://github.com/zjhthu/OANet) to download it: 30 | 31 | ``` 32 | cd .. 33 | git clone https://github.com/zjhthu/OANet 34 | cd OANet 35 | bash download_data.sh raw_data raw_data_yfcc.tar.gz 0 8 36 | tar -xvf raw_data_yfcc.tar.gz 37 | # YFCC100M 38 | ln -s raw_data/yfcc100m/* /path_to/data/yfcc/test 39 | ``` 40 | 41 | ## Evaluation 42 | 43 | We provide the evaluation without any model fine-tuning. The batch size is set to 1 to allow single-GPU evaluation. Please download the models [here](https://drive.google.com/drive/folders/1lUpLCfkZJePPNEgSwS3OKkhIXvx_L5OR?usp=drive_link) and follow the commands below for the evaluation. The weights can also be downloaded from the original [QuadTree](https://github.com/Tangshitao/QuadTreeAttention) and [DPT](https://github.com/isl-org/DPT) repos. Note that we run the evaluation both with and without CSE for a fairer comparison. 44 | 45 | ``` 46 | # For ScanNet 47 | sh scripts/reproduce_test/indoor_ds_quadtree_with_cse.sh 48 | # For MegaDepth 49 | sh scripts/reproduce_test/outdoor_ds_quadtree_with_cse_MEGA.sh 50 | # For YFCC 51 | sh scripts/reproduce_test/outdoor_ds_quadtree_with_cse_YFCC.sh 52 | ``` 53 | 54 | ## Acknowledgements 55 | 56 | We appreciate the previous open-source repositories [QuadTree](https://github.com/Tangshitao/QuadTreeAttention), [LoFTR](https://github.com/zju3dv/LoFTR), and [DPT](https://github.com/isl-org/DPT). 57 | 58 | ## Citation 59 | 60 | Please consider citing our paper if you find this code useful for your research: 61 | 62 | ``` 63 | @InProceedings{Wang_2023_ICCV, 64 | author = {Wang, Shuzhe and Kannala, Juho and Pollefeys, Marc and Barath, Daniel}, 65 | title = {Guiding Local Feature Matching with Surface Curvature}, 66 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 67 | month = {October}, 68 | year = {2023}, 69 | pages = {17981-17991} 70 | } 71 | ``` -------------------------------------------------------------------------------- /assets/megadepth_test_1500_scene_info/0015_0.1_0.3.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/megadepth_test_1500_scene_info/0015_0.1_0.3.npz -------------------------------------------------------------------------------- /assets/megadepth_test_1500_scene_info/0015_0.3_0.5.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/megadepth_test_1500_scene_info/0015_0.3_0.5.npz
-------------------------------------------------------------------------------- /assets/megadepth_test_1500_scene_info/0022_0.1_0.3.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/megadepth_test_1500_scene_info/0022_0.1_0.3.npz -------------------------------------------------------------------------------- /assets/megadepth_test_1500_scene_info/0022_0.3_0.5.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/megadepth_test_1500_scene_info/0022_0.3_0.5.npz -------------------------------------------------------------------------------- /assets/megadepth_test_1500_scene_info/0022_0.5_0.7.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/megadepth_test_1500_scene_info/0022_0.5_0.7.npz -------------------------------------------------------------------------------- /assets/megadepth_test_1500_scene_info/megadepth_test_1500.txt: -------------------------------------------------------------------------------- 1 | 0015_0.1_0.3 2 | 0015_0.3_0.5 3 | 0022_0.1_0.3 4 | 0022_0.3_0.5 5 | 0022_0.5_0.7 -------------------------------------------------------------------------------- /assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg -------------------------------------------------------------------------------- /assets/phototourism_sample_images/london_bridge_49190386_5209386933.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/phototourism_sample_images/london_bridge_49190386_5209386933.jpg -------------------------------------------------------------------------------- /assets/phototourism_sample_images/london_bridge_78916675_4568141288.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/phototourism_sample_images/london_bridge_78916675_4568141288.jpg -------------------------------------------------------------------------------- /assets/phototourism_sample_images/london_bridge_94185272_3874562886.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/phototourism_sample_images/london_bridge_94185272_3874562886.jpg -------------------------------------------------------------------------------- /assets/phototourism_sample_images/piazza_san_marco_06795901_3725050516.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/phototourism_sample_images/piazza_san_marco_06795901_3725050516.jpg -------------------------------------------------------------------------------- /assets/phototourism_sample_images/piazza_san_marco_15148634_5228701572.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/phototourism_sample_images/piazza_san_marco_15148634_5228701572.jpg -------------------------------------------------------------------------------- /assets/phototourism_sample_images/piazza_san_marco_18627786_5929294590.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/phototourism_sample_images/piazza_san_marco_18627786_5929294590.jpg -------------------------------------------------------------------------------- /assets/phototourism_sample_images/piazza_san_marco_43351518_2659980686.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/phototourism_sample_images/piazza_san_marco_43351518_2659980686.jpg -------------------------------------------------------------------------------- /assets/phototourism_sample_images/piazza_san_marco_58751010_4849458397.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/phototourism_sample_images/piazza_san_marco_58751010_4849458397.jpg -------------------------------------------------------------------------------- /assets/phototourism_sample_images/st_pauls_cathedral_30776973_2635313996.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/phototourism_sample_images/st_pauls_cathedral_30776973_2635313996.jpg -------------------------------------------------------------------------------- /assets/phototourism_sample_images/st_pauls_cathedral_37347628_10902811376.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/phototourism_sample_images/st_pauls_cathedral_37347628_10902811376.jpg -------------------------------------------------------------------------------- /assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg -------------------------------------------------------------------------------- /assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg -------------------------------------------------------------------------------- /assets/scannet_sample_images/scene0711_00_frame-001680.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_sample_images/scene0711_00_frame-001680.jpg -------------------------------------------------------------------------------- /assets/scannet_sample_images/scene0711_00_frame-001995.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_sample_images/scene0711_00_frame-001995.jpg -------------------------------------------------------------------------------- /assets/scannet_sample_images/scene0713_00_frame-001320.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_sample_images/scene0713_00_frame-001320.jpg -------------------------------------------------------------------------------- /assets/scannet_sample_images/scene0713_00_frame-002025.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_sample_images/scene0713_00_frame-002025.jpg -------------------------------------------------------------------------------- /assets/scannet_sample_images/scene0721_00_frame-000375.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_sample_images/scene0721_00_frame-000375.jpg -------------------------------------------------------------------------------- /assets/scannet_sample_images/scene0721_00_frame-002745.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_sample_images/scene0721_00_frame-002745.jpg -------------------------------------------------------------------------------- /assets/scannet_sample_images/scene0722_00_frame-000045.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_sample_images/scene0722_00_frame-000045.jpg -------------------------------------------------------------------------------- /assets/scannet_sample_images/scene0722_00_frame-000735.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_sample_images/scene0722_00_frame-000735.jpg -------------------------------------------------------------------------------- /assets/scannet_sample_images/scene0726_00_frame-000135.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_sample_images/scene0726_00_frame-000135.jpg -------------------------------------------------------------------------------- /assets/scannet_sample_images/scene0726_00_frame-000210.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_sample_images/scene0726_00_frame-000210.jpg -------------------------------------------------------------------------------- /assets/scannet_sample_images/scene0737_00_frame-000930.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_sample_images/scene0737_00_frame-000930.jpg -------------------------------------------------------------------------------- /assets/scannet_sample_images/scene0737_00_frame-001095.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_sample_images/scene0737_00_frame-001095.jpg -------------------------------------------------------------------------------- /assets/scannet_sample_images/scene0738_00_frame-000885.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_sample_images/scene0738_00_frame-000885.jpg -------------------------------------------------------------------------------- /assets/scannet_sample_images/scene0738_00_frame-001065.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_sample_images/scene0738_00_frame-001065.jpg -------------------------------------------------------------------------------- /assets/scannet_sample_images/scene0743_00_frame-000000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_sample_images/scene0743_00_frame-000000.jpg -------------------------------------------------------------------------------- /assets/scannet_sample_images/scene0743_00_frame-001275.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_sample_images/scene0743_00_frame-001275.jpg -------------------------------------------------------------------------------- /assets/scannet_sample_images/scene0744_00_frame-000585.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_sample_images/scene0744_00_frame-000585.jpg -------------------------------------------------------------------------------- 
/assets/scannet_sample_images/scene0744_00_frame-002310.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_sample_images/scene0744_00_frame-002310.jpg -------------------------------------------------------------------------------- /assets/scannet_sample_images/scene0747_00_frame-000000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_sample_images/scene0747_00_frame-000000.jpg -------------------------------------------------------------------------------- /assets/scannet_sample_images/scene0747_00_frame-001530.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_sample_images/scene0747_00_frame-001530.jpg -------------------------------------------------------------------------------- /assets/scannet_sample_images/scene0752_00_frame-000075.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_sample_images/scene0752_00_frame-000075.jpg -------------------------------------------------------------------------------- /assets/scannet_sample_images/scene0752_00_frame-001440.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_sample_images/scene0752_00_frame-001440.jpg -------------------------------------------------------------------------------- /assets/scannet_sample_images/scene0755_00_frame-000120.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_sample_images/scene0755_00_frame-000120.jpg -------------------------------------------------------------------------------- /assets/scannet_sample_images/scene0755_00_frame-002055.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_sample_images/scene0755_00_frame-002055.jpg -------------------------------------------------------------------------------- /assets/scannet_sample_images/scene0758_00_frame-000165.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_sample_images/scene0758_00_frame-000165.jpg -------------------------------------------------------------------------------- /assets/scannet_sample_images/scene0758_00_frame-000510.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_sample_images/scene0758_00_frame-000510.jpg 
-------------------------------------------------------------------------------- /assets/scannet_sample_images/scene0768_00_frame-001095.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_sample_images/scene0768_00_frame-001095.jpg -------------------------------------------------------------------------------- /assets/scannet_sample_images/scene0768_00_frame-003435.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_sample_images/scene0768_00_frame-003435.jpg -------------------------------------------------------------------------------- /assets/scannet_sample_images/scene0806_00_frame-000225.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_sample_images/scene0806_00_frame-000225.jpg -------------------------------------------------------------------------------- /assets/scannet_sample_images/scene0806_00_frame-001095.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_sample_images/scene0806_00_frame-001095.jpg -------------------------------------------------------------------------------- /assets/scannet_test_1500/intrinsics.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_test_1500/intrinsics.npz -------------------------------------------------------------------------------- /assets/scannet_test_1500/scannet_test.txt: -------------------------------------------------------------------------------- 1 | test.npz -------------------------------------------------------------------------------- /assets/scannet_test_1500/statistics.json: -------------------------------------------------------------------------------- 1 | { 2 | "scene0707_00": 15, 3 | "scene0708_00": 15, 4 | "scene0709_00": 15, 5 | "scene0710_00": 15, 6 | "scene0711_00": 15, 7 | "scene0712_00": 15, 8 | "scene0713_00": 15, 9 | "scene0714_00": 15, 10 | "scene0715_00": 15, 11 | "scene0716_00": 15, 12 | "scene0717_00": 15, 13 | "scene0718_00": 15, 14 | "scene0719_00": 15, 15 | "scene0720_00": 15, 16 | "scene0721_00": 15, 17 | "scene0722_00": 15, 18 | "scene0723_00": 15, 19 | "scene0724_00": 15, 20 | "scene0725_00": 15, 21 | "scene0726_00": 15, 22 | "scene0727_00": 15, 23 | "scene0728_00": 15, 24 | "scene0729_00": 15, 25 | "scene0730_00": 15, 26 | "scene0731_00": 15, 27 | "scene0732_00": 15, 28 | "scene0733_00": 15, 29 | "scene0734_00": 15, 30 | "scene0735_00": 15, 31 | "scene0736_00": 15, 32 | "scene0737_00": 15, 33 | "scene0738_00": 15, 34 | "scene0739_00": 15, 35 | "scene0740_00": 15, 36 | "scene0741_00": 15, 37 | "scene0742_00": 15, 38 | "scene0743_00": 15, 39 | "scene0744_00": 15, 40 | "scene0745_00": 15, 41 | "scene0746_00": 15, 42 | "scene0747_00": 15, 43 | "scene0748_00": 15, 44 | "scene0749_00": 15, 45 | "scene0750_00": 15, 46 | "scene0751_00": 15, 47 | "scene0752_00": 15, 48 | "scene0753_00": 15, 49 | 
"scene0754_00": 15, 50 | "scene0755_00": 15, 51 | "scene0756_00": 15, 52 | "scene0757_00": 15, 53 | "scene0758_00": 15, 54 | "scene0759_00": 15, 55 | "scene0760_00": 15, 56 | "scene0761_00": 15, 57 | "scene0762_00": 15, 58 | "scene0763_00": 15, 59 | "scene0764_00": 15, 60 | "scene0765_00": 15, 61 | "scene0766_00": 15, 62 | "scene0767_00": 15, 63 | "scene0768_00": 15, 64 | "scene0769_00": 15, 65 | "scene0770_00": 15, 66 | "scene0771_00": 15, 67 | "scene0772_00": 15, 68 | "scene0773_00": 15, 69 | "scene0774_00": 15, 70 | "scene0775_00": 15, 71 | "scene0776_00": 15, 72 | "scene0777_00": 15, 73 | "scene0778_00": 15, 74 | "scene0779_00": 15, 75 | "scene0780_00": 15, 76 | "scene0781_00": 15, 77 | "scene0782_00": 15, 78 | "scene0783_00": 15, 79 | "scene0784_00": 15, 80 | "scene0785_00": 15, 81 | "scene0786_00": 15, 82 | "scene0787_00": 15, 83 | "scene0788_00": 15, 84 | "scene0789_00": 15, 85 | "scene0790_00": 15, 86 | "scene0791_00": 15, 87 | "scene0792_00": 15, 88 | "scene0793_00": 15, 89 | "scene0794_00": 15, 90 | "scene0795_00": 15, 91 | "scene0796_00": 15, 92 | "scene0797_00": 15, 93 | "scene0798_00": 15, 94 | "scene0799_00": 15, 95 | "scene0800_00": 15, 96 | "scene0801_00": 15, 97 | "scene0802_00": 15, 98 | "scene0803_00": 15, 99 | "scene0804_00": 15, 100 | "scene0805_00": 15, 101 | "scene0806_00": 15 102 | } -------------------------------------------------------------------------------- /assets/scannet_test_1500/test.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/scannet_test_1500/test.npz -------------------------------------------------------------------------------- /assets/yfcc_test_4000/0000.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/assets/yfcc_test_4000/0000.npz -------------------------------------------------------------------------------- /assets/yfcc_test_4000/yfcc_test_4000.txt: -------------------------------------------------------------------------------- 1 | 0000.npz 2 | -------------------------------------------------------------------------------- /configs/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/configs/data/__init__.py -------------------------------------------------------------------------------- /configs/data/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | The data config will be the last one merged into the main config. 3 | Setups in data configs will override all existed setups! 
4 | """ 5 | 6 | from yacs.config import CfgNode as CN 7 | _CN = CN() 8 | _CN.DATASET = CN() 9 | _CN.TRAINER = CN() 10 | 11 | # training data config 12 | _CN.DATASET.TRAIN_DATA_ROOT = None 13 | _CN.DATASET.TRAIN_POSE_ROOT = None 14 | _CN.DATASET.TRAIN_NPZ_ROOT = None 15 | _CN.DATASET.TRAIN_LIST_PATH = None 16 | _CN.DATASET.TRAIN_INTRINSIC_PATH = None 17 | # validation set config 18 | _CN.DATASET.VAL_DATA_ROOT = None 19 | _CN.DATASET.VAL_POSE_ROOT = None 20 | _CN.DATASET.VAL_NPZ_ROOT = None 21 | _CN.DATASET.VAL_LIST_PATH = None 22 | _CN.DATASET.VAL_INTRINSIC_PATH = None 23 | 24 | # testing data config 25 | _CN.DATASET.TEST_DATA_ROOT = None 26 | _CN.DATASET.TEST_POSE_ROOT = None 27 | _CN.DATASET.TEST_NPZ_ROOT = None 28 | _CN.DATASET.TEST_LIST_PATH = None 29 | _CN.DATASET.TEST_INTRINSIC_PATH = None 30 | 31 | # dataset config 32 | _CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4 33 | _CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val 34 | 35 | cfg = _CN 36 | -------------------------------------------------------------------------------- /configs/data/debug/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | */ 3 | !.gitignore 4 | -------------------------------------------------------------------------------- /configs/data/megadepth_test_1500.py: -------------------------------------------------------------------------------- 1 | from configs.data.base import cfg 2 | 3 | TEST_BASE_PATH = "assets/megadepth_test_1500_scene_info" 4 | 5 | cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth" 6 | cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test" 7 | cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}" 8 | cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/megadepth_test_1500.txt" 9 | 10 | cfg.DATASET.MGDPT_IMG_RESIZE = 832 11 | cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 12 | -------------------------------------------------------------------------------- /configs/data/megadepth_trainval_640.py: -------------------------------------------------------------------------------- 1 | from configs.data.base import cfg 2 | 3 | 4 | TRAIN_BASE_PATH = "data/megadepth/index" 5 | cfg.DATASET.TRAINVAL_DATA_SOURCE = "MegaDepth" 6 | cfg.DATASET.TRAIN_DATA_ROOT = "data/megadepth/train" 7 | cfg.DATASET.TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_info_0.1_0.7" 8 | cfg.DATASET.TRAIN_LIST_PATH = f"{TRAIN_BASE_PATH}/trainvaltest_list/train_list.txt" 9 | cfg.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.0 10 | 11 | TEST_BASE_PATH = "data/megadepth/index" 12 | cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth" 13 | cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test" 14 | cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}/scene_info_val_1500" 15 | cfg.DATASET.VAL_LIST_PATH = cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/trainvaltest_list/val_list.txt" 16 | cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val 17 | 18 | # 368 scenes in total for MegaDepth 19 | # (with difficulty balanced (further split each scene to 3 sub-scenes)) 20 | cfg.TRAINER.N_SAMPLES_PER_SUBSET = 100 21 | 22 | cfg.DATASET.MGDPT_IMG_RESIZE = 640 # for training on 11GB mem GPUs 23 | -------------------------------------------------------------------------------- /configs/data/megadepth_trainval_832.py: -------------------------------------------------------------------------------- 1 | from configs.data.base import cfg 2 | 3 | 4 | TRAIN_BASE_PATH = "data/megadepth/index" 5 | cfg.DATASET.TRAINVAL_DATA_SOURCE = "MegaDepth" 6 | cfg.DATASET.TRAIN_DATA_ROOT = 
"data/megadepth/train" 7 | cfg.DATASET.TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_info_0.1_0.7" 8 | cfg.DATASET.TRAIN_LIST_PATH = f"{TRAIN_BASE_PATH}/trainvaltest_list/train_list.txt" 9 | cfg.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.0 10 | 11 | TEST_BASE_PATH = "data/megadepth/index" 12 | cfg.DATASET.TEST_DATA_SOURCE = "MegaDepth" 13 | cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/megadepth/test" 14 | cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}/scene_info_val_1500" 15 | cfg.DATASET.VAL_LIST_PATH = cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/trainvaltest_list/val_list.txt" 16 | cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val 17 | 18 | # 368 scenes in total for MegaDepth 19 | # (with difficulty balanced (further split each scene to 3 sub-scenes)) 20 | cfg.TRAINER.N_SAMPLES_PER_SUBSET = 100 21 | 22 | cfg.DATASET.MGDPT_IMG_RESIZE = 832 # for training on 32GB meme GPUs 23 | -------------------------------------------------------------------------------- /configs/data/scannet_test_1500.py: -------------------------------------------------------------------------------- 1 | from configs.data.base import cfg 2 | 3 | TEST_BASE_PATH = "assets/scannet_test_1500" 4 | 5 | cfg.DATASET.TEST_DATA_SOURCE = "ScanNet" 6 | cfg.DATASET.TEST_DATA_ROOT = "data/scannet/test" 7 | cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}" 8 | cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/scannet_test.txt" 9 | cfg.DATASET.TEST_INTRINSIC_PATH = f"{TEST_BASE_PATH}/intrinsics.npz" 10 | 11 | cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 12 | 13 | 14 | # TEST_BASE_PATH = "data/scannet/index" 15 | 16 | # cfg.DATASET.TEST_DATA_SOURCE = "ScanNet" 17 | # cfg.DATASET.TEST_DATA_ROOT = "data/scannet/train" 18 | # cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}/scene_data/train" 19 | # cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/scene_data/train_list/scannet_all.txt" 20 | # cfg.DATASET.TEST_INTRINSIC_PATH = f"{TEST_BASE_PATH}/intrinsics.npz" 21 | 22 | # cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 23 | 24 | 25 | -------------------------------------------------------------------------------- /configs/data/scannet_trainval.py: -------------------------------------------------------------------------------- 1 | from configs.data.base import cfg 2 | 3 | 4 | TRAIN_BASE_PATH = "data/scannet/index" 5 | cfg.DATASET.TRAINVAL_DATA_SOURCE = "ScanNet" 6 | cfg.DATASET.TRAIN_DATA_ROOT = "data/scannet/train" 7 | cfg.DATASET.TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_data/train" 8 | cfg.DATASET.TRAIN_LIST_PATH = f"{TRAIN_BASE_PATH}/scene_data/train_list/scannet_debug.txt" 9 | cfg.DATASET.TRAIN_INTRINSIC_PATH = f"{TRAIN_BASE_PATH}/intrinsics.npz" 10 | 11 | TEST_BASE_PATH = "assets/scannet_test_1500" 12 | cfg.DATASET.TEST_DATA_SOURCE = "ScanNet" 13 | cfg.DATASET.VAL_DATA_ROOT = cfg.DATASET.TEST_DATA_ROOT = "data/scannet/test" 14 | cfg.DATASET.VAL_NPZ_ROOT = cfg.DATASET.TEST_NPZ_ROOT = TEST_BASE_PATH 15 | cfg.DATASET.VAL_LIST_PATH = cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/scannet_test.txt" 16 | cfg.DATASET.VAL_INTRINSIC_PATH = cfg.DATASET.TEST_INTRINSIC_PATH = f"{TEST_BASE_PATH}/intrinsics.npz" 17 | cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 # for both test and val 18 | -------------------------------------------------------------------------------- /configs/data/yfcc_test_4000.py: -------------------------------------------------------------------------------- 1 | from configs.data.base import cfg 2 | 3 | TEST_BASE_PATH = "assets/yfcc_test_4000" 4 | 5 | cfg.DATASET.TEST_DATA_SOURCE = "YFCC" 6 
| cfg.DATASET.TEST_DATA_ROOT = "data/yfcc/test" 7 | cfg.DATASET.TEST_NPZ_ROOT = f"{TEST_BASE_PATH}" 8 | cfg.DATASET.TEST_LIST_PATH = f"{TEST_BASE_PATH}/yfcc_test_4000.txt" 9 | 10 | cfg.DATASET.MGDPT_IMG_RESIZE = 832 11 | cfg.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 12 | -------------------------------------------------------------------------------- /configs/loftr/indoor/buggy_pos_enc/loftr_ds.py: -------------------------------------------------------------------------------- 1 | from src.config.default import _CN as cfg 2 | 3 | cfg.LOFTR.COARSE.TEMP_BUG_FIX = False 4 | cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' 5 | 6 | cfg.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12, 17, 20, 23, 26, 29] 7 | -------------------------------------------------------------------------------- /configs/loftr/indoor/buggy_pos_enc/loftr_ds_dense.py: -------------------------------------------------------------------------------- 1 | from src.config.default import _CN as cfg 2 | 3 | cfg.LOFTR.COARSE.TEMP_BUG_FIX = False 4 | cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' 5 | 6 | cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = False 7 | 8 | cfg.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12, 17, 20, 23, 26, 29] 9 | -------------------------------------------------------------------------------- /configs/loftr/indoor/buggy_pos_enc/loftr_ot.py: -------------------------------------------------------------------------------- 1 | from src.config.default import _CN as cfg 2 | 3 | cfg.LOFTR.COARSE.TEMP_BUG_FIX = False 4 | cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'sinkhorn' 5 | 6 | cfg.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12, 17, 20, 23, 26, 29] 7 | -------------------------------------------------------------------------------- /configs/loftr/indoor/buggy_pos_enc/loftr_ot_dense.py: -------------------------------------------------------------------------------- 1 | from src.config.default import _CN as cfg 2 | 3 | cfg.LOFTR.COARSE.TEMP_BUG_FIX = False 4 | cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'sinkhorn' 5 | 6 | cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = False 7 | 8 | cfg.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12, 17, 20, 23, 26, 29] 9 | -------------------------------------------------------------------------------- /configs/loftr/indoor/debug/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | */ 3 | !.gitignore 4 | -------------------------------------------------------------------------------- /configs/loftr/indoor/loftr_ds.py: -------------------------------------------------------------------------------- 1 | from src.config.default import _CN as cfg 2 | 3 | cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' 4 | 5 | cfg.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12, 17, 20, 23, 26, 29] 6 | -------------------------------------------------------------------------------- /configs/loftr/indoor/loftr_ds_dense.py: -------------------------------------------------------------------------------- 1 | from src.config.default import _CN as cfg 2 | 3 | cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' 4 | 5 | cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = False 6 | 7 | cfg.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12, 17, 20, 23, 26, 29] 8 | -------------------------------------------------------------------------------- /configs/loftr/indoor/loftr_ds_quadtree.py: -------------------------------------------------------------------------------- 1 | from src.config.default import _CN as cfg 2 | 3 | #cfg.LOFTR.COARSE.TEMP_BUG_FIX = False 4 | cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' 5 | 6 | 7 | 
cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = False 8 | cfg.LOFTR.RESNETFPN.INITIAL_DIM = 128 9 | cfg.LOFTR.RESNETFPN.BLOCK_DIMS=[128, 196, 256] 10 | cfg.LOFTR.COARSE.D_MODEL = 256 11 | cfg.LOFTR.COARSE.BLOCK_TYPE = 'quadtree' 12 | cfg.LOFTR.COARSE.ATTN_TYPE = 'B' 13 | cfg.LOFTR.COARSE.TOPKS=[32, 16, 16] 14 | cfg.LOFTR.FINE.D_MODEL = 128 15 | 16 | 17 | cfg.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12, 17, 20, 23, 26, 29] -------------------------------------------------------------------------------- /configs/loftr/indoor/loftr_ot.py: -------------------------------------------------------------------------------- 1 | from src.config.default import _CN as cfg 2 | 3 | cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'sinkhorn' 4 | 5 | cfg.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12, 17, 20, 23, 26, 29] 6 | -------------------------------------------------------------------------------- /configs/loftr/indoor/loftr_ot_dense.py: -------------------------------------------------------------------------------- 1 | from src.config.default import _CN as cfg 2 | 3 | cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'sinkhorn' 4 | 5 | cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = False 6 | 7 | cfg.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12, 17, 20, 23, 26, 29] 8 | -------------------------------------------------------------------------------- /configs/loftr/indoor/scannet/loftr_ds_eval.py: -------------------------------------------------------------------------------- 1 | """ A config only for reproducing the ScanNet evaluation results. 2 | 3 | We remove border matches by default, but the originally implemented 4 | `remove_border()` has a bug, leading to only two sides of 5 | all borders are actually removed. However, the [bug fix](https://github.com/zju3dv/LoFTR/commit/e9146c8144dea5f3cbdd98b225f3e147a171c216) 6 | makes the scannet evaluation results worse (auc@10=40.8 => 39.5), which should be 7 | caused by tiny result fluctuation of few image pairs. This config set `BORDER_RM` to 0 8 | to be consistent with the results in our paper. 9 | """ 10 | 11 | from src.config.default import _CN as cfg 12 | 13 | cfg.LOFTR.COARSE.TEMP_BUG_FIX = False 14 | cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' 15 | 16 | cfg.LOFTR.MATCH_COARSE.BORDER_RM = 0 17 | -------------------------------------------------------------------------------- /configs/loftr/indoor/scannet/loftr_ds_eval_new.py: -------------------------------------------------------------------------------- 1 | """ A config only for reproducing the ScanNet evaluation results. 2 | 3 | We remove border matches by default, but the originally implemented 4 | `remove_border()` has a bug, leading to only two sides of 5 | all borders are actually removed. However, the [bug fix](https://github.com/zju3dv/LoFTR/commit/e9146c8144dea5f3cbdd98b225f3e147a171c216) 6 | makes the scannet evaluation results worse (auc@10=40.8 => 39.5), which should be 7 | caused by tiny result fluctuation of few image pairs. This config set `BORDER_RM` to 0 8 | to be consistent with the results in our paper. 9 | 10 | Update: This config is for testing the re-trained model with the pos-enc bug fixed. 
11 | """ 12 | 13 | from src.config.default import _CN as cfg 14 | 15 | cfg.LOFTR.COARSE.TEMP_BUG_FIX = True 16 | cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' 17 | 18 | cfg.LOFTR.MATCH_COARSE.BORDER_RM = 0 19 | -------------------------------------------------------------------------------- /configs/loftr/indoor/scannet/loftr_ds_quadtree_eval.py: -------------------------------------------------------------------------------- 1 | from src.config.default import _CN as cfg 2 | 3 | #cfg.LOFTR.COARSE.TEMP_BUG_FIX = False 4 | cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' 5 | 6 | 7 | cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = False 8 | cfg.LOFTR.RESNETFPN.INITIAL_DIM = 128 9 | cfg.LOFTR.RESNETFPN.BLOCK_DIMS=[128, 196, 256] 10 | cfg.LOFTR.COARSE.D_MODEL = 256 11 | cfg.LOFTR.COARSE.BLOCK_TYPE = 'quadtree' 12 | cfg.LOFTR.COARSE.ATTN_TYPE = 'B' 13 | cfg.LOFTR.COARSE.TOPKS=[32, 16, 16] 14 | cfg.LOFTR.FINE.D_MODEL = 128 15 | 16 | 17 | cfg.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12, 17, 20, 23, 26, 29] -------------------------------------------------------------------------------- /configs/loftr/outdoor/__pycache__/loftr_ds_quadtree.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/configs/loftr/outdoor/__pycache__/loftr_ds_quadtree.cpython-38.pyc -------------------------------------------------------------------------------- /configs/loftr/outdoor/buggy_pos_enc/loftr_ds.py: -------------------------------------------------------------------------------- 1 | from src.config.default import _CN as cfg 2 | 3 | cfg.LOFTR.COARSE.TEMP_BUG_FIX = False 4 | cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' 5 | 6 | cfg.TRAINER.CANONICAL_LR = 8e-3 7 | cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs 8 | cfg.TRAINER.WARMUP_RATIO = 0.1 9 | cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24] 10 | 11 | # pose estimation 12 | cfg.TRAINER.RANSAC_PIXEL_THR = 0.5 13 | 14 | cfg.TRAINER.OPTIMIZER = "adamw" 15 | cfg.TRAINER.ADAMW_DECAY = 0.1 16 | cfg.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3 17 | -------------------------------------------------------------------------------- /configs/loftr/outdoor/buggy_pos_enc/loftr_ds_dense.py: -------------------------------------------------------------------------------- 1 | from src.config.default import _CN as cfg 2 | 3 | cfg.LOFTR.COARSE.TEMP_BUG_FIX = False 4 | cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' 5 | cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = False 6 | 7 | cfg.TRAINER.CANONICAL_LR = 8e-3 8 | cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs 9 | cfg.TRAINER.WARMUP_RATIO = 0.1 10 | cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24] 11 | 12 | # pose estimation 13 | cfg.TRAINER.RANSAC_PIXEL_THR = 0.5 14 | 15 | cfg.TRAINER.OPTIMIZER = "adamw" 16 | cfg.TRAINER.ADAMW_DECAY = 0.1 17 | cfg.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3 18 | -------------------------------------------------------------------------------- /configs/loftr/outdoor/buggy_pos_enc/loftr_ot.py: -------------------------------------------------------------------------------- 1 | from src.config.default import _CN as cfg 2 | 3 | cfg.LOFTR.COARSE.TEMP_BUG_FIX = False 4 | cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'sinkhorn' 5 | 6 | cfg.TRAINER.CANONICAL_LR = 8e-3 7 | cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs 8 | cfg.TRAINER.WARMUP_RATIO = 0.1 9 | cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24] 10 | 11 | # pose estimation 12 | cfg.TRAINER.RANSAC_PIXEL_THR = 0.5 13 | 
14 | cfg.TRAINER.OPTIMIZER = "adamw" 15 | cfg.TRAINER.ADAMW_DECAY = 0.1 16 | cfg.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3 17 | -------------------------------------------------------------------------------- /configs/loftr/outdoor/buggy_pos_enc/loftr_ot_dense.py: -------------------------------------------------------------------------------- 1 | from src.config.default import _CN as cfg 2 | 3 | cfg.LOFTR.COARSE.TEMP_BUG_FIX = False 4 | cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'sinkhorn' 5 | cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = False 6 | 7 | cfg.TRAINER.CANONICAL_LR = 8e-3 8 | cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs 9 | cfg.TRAINER.WARMUP_RATIO = 0.1 10 | cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24] 11 | 12 | # pose estimation 13 | cfg.TRAINER.RANSAC_PIXEL_THR = 0.5 14 | 15 | cfg.TRAINER.OPTIMIZER = "adamw" 16 | cfg.TRAINER.ADAMW_DECAY = 0.1 17 | cfg.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3 18 | -------------------------------------------------------------------------------- /configs/loftr/outdoor/debug/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | */ 3 | !.gitignore 4 | -------------------------------------------------------------------------------- /configs/loftr/outdoor/loftr_ds.py: -------------------------------------------------------------------------------- 1 | from src.config.default import _CN as cfg 2 | 3 | cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' 4 | 5 | cfg.TRAINER.CANONICAL_LR = 8e-3 6 | cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs 7 | cfg.TRAINER.WARMUP_RATIO = 0.1 8 | cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24] 9 | 10 | # pose estimation 11 | cfg.TRAINER.RANSAC_PIXEL_THR = 0.5 12 | 13 | cfg.TRAINER.OPTIMIZER = "adamw" 14 | cfg.TRAINER.ADAMW_DECAY = 0.1 15 | cfg.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3 16 | -------------------------------------------------------------------------------- /configs/loftr/outdoor/loftr_ds_dense.py: -------------------------------------------------------------------------------- 1 | from src.config.default import _CN as cfg 2 | 3 | cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' 4 | cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = False 5 | 6 | cfg.TRAINER.CANONICAL_LR = 8e-3 7 | cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs 8 | cfg.TRAINER.WARMUP_RATIO = 0.1 9 | cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24] 10 | 11 | # pose estimation 12 | cfg.TRAINER.RANSAC_PIXEL_THR = 0.5 13 | 14 | cfg.TRAINER.OPTIMIZER = "adamw" 15 | cfg.TRAINER.ADAMW_DECAY = 0.1 16 | cfg.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3 17 | -------------------------------------------------------------------------------- /configs/loftr/outdoor/loftr_ds_quadtree.py: -------------------------------------------------------------------------------- 1 | from src.config.default import _CN as cfg 2 | 3 | cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = "dual_softmax" 4 | 5 | cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = False 6 | cfg.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3 7 | 8 | cfg.LOFTR.RESNETFPN.INITIAL_DIM = 128 9 | cfg.LOFTR.RESNETFPN.BLOCK_DIMS = [128, 196, 256] 10 | cfg.LOFTR.COARSE.D_MODEL = 256 11 | cfg.LOFTR.COARSE.BLOCK_TYPE = "quadtree" 12 | cfg.LOFTR.COARSE.ATTN_TYPE = "B" 13 | cfg.LOFTR.COARSE.TOPKS=[16, 8, 8] 14 | 15 | cfg.LOFTR.FINE.D_MODEL = 128 16 | 17 | cfg.TRAINER.CANONICAL_LR = 8e-3 18 | cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs 19 | cfg.TRAINER.WARMUP_RATIO = 0.1 20 | cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24] 21 | 22 | cfg.TRAINER.RANSAC_PIXEL_THR = 0.5 23 | 24 | cfg.TRAINER.OPTIMIZER = 
"adamw" 25 | cfg.TRAINER.ADAMW_DECAY = 0.1 -------------------------------------------------------------------------------- /configs/loftr/outdoor/loftr_ot.py: -------------------------------------------------------------------------------- 1 | from src.config.default import _CN as cfg 2 | 3 | cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'sinkhorn' 4 | 5 | cfg.TRAINER.CANONICAL_LR = 8e-3 6 | cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs 7 | cfg.TRAINER.WARMUP_RATIO = 0.1 8 | cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24] 9 | 10 | # pose estimation 11 | cfg.TRAINER.RANSAC_PIXEL_THR = 0.5 12 | 13 | cfg.TRAINER.OPTIMIZER = "adamw" 14 | cfg.TRAINER.ADAMW_DECAY = 0.1 15 | cfg.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3 16 | -------------------------------------------------------------------------------- /configs/loftr/outdoor/loftr_ot_dense.py: -------------------------------------------------------------------------------- 1 | from src.config.default import _CN as cfg 2 | 3 | cfg.LOFTR.MATCH_COARSE.MATCH_TYPE = 'sinkhorn' 4 | cfg.LOFTR.MATCH_COARSE.SPARSE_SPVS = False 5 | 6 | cfg.TRAINER.CANONICAL_LR = 8e-3 7 | cfg.TRAINER.WARMUP_STEP = 1875 # 3 epochs 8 | cfg.TRAINER.WARMUP_RATIO = 0.1 9 | cfg.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24] 10 | 11 | # pose estimation 12 | cfg.TRAINER.RANSAC_PIXEL_THR = 0.5 13 | 14 | cfg.TRAINER.OPTIMIZER = "adamw" 15 | cfg.TRAINER.ADAMW_DECAY = 0.1 16 | cfg.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3 17 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: CSE 2 | channels: 3 | # - https://dx-mirrors.sensetime.com/anaconda/cloud/pytorch 4 | - pytorch 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - python=3.8 9 | - cudatoolkit=10.2 10 | - pytorch=1.9.0 11 | - pip 12 | - pip: 13 | - -r requirements.txt 14 | -------------------------------------------------------------------------------- /images/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/images/.DS_Store -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv_python==4.4.0.46 2 | albumentations==0.5.1 --no-binary=imgaug,albumentations 3 | ray>=1.0.1 4 | einops==0.3.0 5 | kornia==0.4.1 6 | loguru==0.5.3 7 | yacs>=0.1.8 8 | tqdm 9 | autopep8 10 | pylint 11 | ipython 12 | jupyterlab 13 | matplotlib 14 | timm==0.9.5 15 | h5py==3.1.0 16 | pytorch-lightning==1.3.5 17 | joblib>=1.0.1 18 | yacs==0.1.8 19 | 20 | 21 | -------------------------------------------------------------------------------- /scripts/reproduce_test/indoor_ds_quadtree_with_cse.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -l 2 | 3 | SCRIPTPATH=$(dirname $(readlink -f "$0")) 4 | PROJECT_DIR="${SCRIPTPATH}/../../" 5 | 6 | # conda activate loftr 7 | export PYTHONPATH=$PROJECT_DIR:$PYTHONPATH 8 | cd $PROJECT_DIR 9 | 10 | data_cfg_path="configs/data/scannet_test_1500.py" 11 | main_cfg_path="configs/loftr/indoor/scannet/loftr_ds_quadtree_eval.py" 12 | ckpt_path='weights/indoor.ckpt' 13 | # ckpt_path='weights/indoor_finetuning.ckpt' 14 | dump_dir="dump/loftr_ds_indoor" 15 | profiler_name="inference" 16 | n_nodes=1 # mannually keep this the same with --nodes 17 | 
n_gpus_per_node=1 18 | torch_num_workers=4 19 | batch_size=1 # per gpu 20 | 21 | # Running evaluation without cse 22 | python -u ./test.py \ 23 | ${data_cfg_path} \ 24 | ${main_cfg_path} \ 25 | --ckpt_path=${ckpt_path} \ 26 | --dump_dir=${dump_dir} \ 27 | --gpus=${n_gpus_per_node} --num_nodes=${n_nodes} --accelerator="ddp" \ 28 | --batch_size=${batch_size} --num_workers=${torch_num_workers}\ 29 | --profiler_name=${profiler_name} \ 30 | --benchmark \ 31 | 32 | # Running evaluation with cse 33 | python -u ./test.py \ 34 | ${data_cfg_path} \ 35 | ${main_cfg_path} \ 36 | --ckpt_path=${ckpt_path} \ 37 | --dump_dir=${dump_dir} \ 38 | --gpus=${n_gpus_per_node} --num_nodes=${n_nodes} --accelerator="ddp" \ 39 | --batch_size=${batch_size} --num_workers=${torch_num_workers}\ 40 | --profiler_name=${profiler_name} \ 41 | --benchmark \ 42 | --add_curvature 43 | 44 | 45 | -------------------------------------------------------------------------------- /scripts/reproduce_test/outdoor_ds_quadtree_with_cse_MEGA.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -l 2 | 3 | SCRIPTPATH=$(dirname $(readlink -f "$0")) 4 | PROJECT_DIR="${SCRIPTPATH}/../../" 5 | 6 | # conda activate loftr 7 | export PYTHONPATH=$PROJECT_DIR:$PYTHONPATH 8 | cd $PROJECT_DIR 9 | 10 | data_cfg_path="configs/data/megadepth_test_1500.py" 11 | main_cfg_path="configs/loftr/outdoor/loftr_ds_quadtree.py" 12 | ckpt_path='weights/outdoor.ckpt' 13 | # ckpt_path='weights/outdoor_finetuning.ckpt' 14 | dump_dir="dump/loftr_ds_outdoor" 15 | profiler_name="inference" 16 | n_nodes=1 # manually keep this the same as --nodes 17 | n_gpus_per_node=1 18 | torch_num_workers=4 19 | batch_size=1 # per gpu 20 | 21 | 22 | # Running evaluation without cse 23 | python -u ./test.py \ 24 | ${data_cfg_path} \ 25 | ${main_cfg_path} \ 26 | --ckpt_path=${ckpt_path} \ 27 | --dump_dir=${dump_dir} \ 28 | --gpus=${n_gpus_per_node} --num_nodes=${n_nodes} --accelerator="ddp" \ 29 | --batch_size=${batch_size} --num_workers=${torch_num_workers}\ 30 | --profiler_name=${profiler_name} \ 31 | --benchmark \ 32 | 33 | # Running evaluation with cse 34 | python -u ./test.py \ 35 | ${data_cfg_path} \ 36 | ${main_cfg_path} \ 37 | --ckpt_path=${ckpt_path} \ 38 | --dump_dir=${dump_dir} \ 39 | --gpus=${n_gpus_per_node} --num_nodes=${n_nodes} --accelerator="ddp" \ 40 | --batch_size=${batch_size} --num_workers=${torch_num_workers}\ 41 | --profiler_name=${profiler_name} \ 42 | --benchmark \ 43 | --add_curvature 44 | -------------------------------------------------------------------------------- /scripts/reproduce_test/outdoor_ds_quadtree_with_cse_YFCC.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -l 2 | 3 | SCRIPTPATH=$(dirname $(readlink -f "$0")) 4 | PROJECT_DIR="${SCRIPTPATH}/../../" 5 | 6 | # conda activate loftr 7 | export PYTHONPATH=$PROJECT_DIR:$PYTHONPATH 8 | cd $PROJECT_DIR 9 | 10 | data_cfg_path="configs/data/yfcc_test_4000.py" 11 | main_cfg_path="configs/loftr/outdoor/loftr_ds_quadtree.py" 12 | ckpt_path='weights/outdoor.ckpt' 13 | # ckpt_path='weights/outdoor_finetuning.ckpt' 14 | dump_dir="dump/loftr_ds_outdoor_yfcc" 15 | profiler_name="inference" 16 | n_nodes=1 # manually keep this the same as --nodes 17 | n_gpus_per_node=1 18 | torch_num_workers=4 19 | batch_size=1 # per gpu 20 | 21 | python -u ./test.py \ 22 | ${data_cfg_path} \ 23 | ${main_cfg_path} \ 24 | --ckpt_path=${ckpt_path} \ 25 | --dump_dir=${dump_dir} \ 26 | --gpus=${n_gpus_per_node} --num_nodes=${n_nodes} 
--accelerator="ddp" \ 27 | --batch_size=${batch_size} --num_workers=${torch_num_workers}\ 28 | --profiler_name=${profiler_name} \ 29 | --benchmark 30 | 31 | # Runing evaluation with cse 32 | python -u ./test.py \ 33 | ${data_cfg_path} \ 34 | ${main_cfg_path} \ 35 | --ckpt_path=${ckpt_path} \ 36 | --dump_dir=${dump_dir} \ 37 | --gpus=${n_gpus_per_node} --num_nodes=${n_nodes} --accelerator="ddp" \ 38 | --batch_size=${batch_size} --num_workers=${torch_num_workers}\ 39 | --profiler_name=${profiler_name} \ 40 | --benchmark \ 41 | --add_curvature 42 | -------------------------------------------------------------------------------- /src/DPT/curvature_extraction.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from src.DPT.utils import ellipsoid_plot 7 | from matplotlib import pyplot as plt 8 | 9 | import numpy as np 10 | 11 | def extract_curv_tensor(z_unfold, _device): 12 | bz, kpts, groups = z_unfold.shape[0], z_unfold.shape[1], z_unfold.shape[2] 13 | # set matrix C 14 | k = 4 15 | C = torch.zeros((bz, kpts, 6, 6), device = _device) 16 | D = torch.zeros((bz, kpts, 10, groups),device = _device) 17 | Coff = torch.zeros((bz, kpts, 4, 4),device = _device) 18 | T = torch.zeros((bz, kpts, 4, 4),device = _device) 19 | 20 | C[:,:,0,0] = -1 21 | C[:,:,1,1] = -1 22 | C[:,:,2,2] = -1 23 | C[:,:,3,3] = -k 24 | C[:,:,4,4] = -k 25 | C[:,:,5,5] = -k 26 | C[:,:,0, 1] = k/2 - 1 27 | C[:,:,0, 2] = k/2 - 1 28 | C[:,:,1, 0] = k/2 - 1 29 | C[:,:,1, 2] = k/2 - 1 30 | C[:,:,2, 0] = k/2 - 1 31 | C[:,:,2, 1] = k/2 - 1 32 | 33 | inv_C = torch.linalg.pinv(C) 34 | 35 | p1 = z_unfold[:,:,:,0] 36 | p2 = z_unfold[:,:,:,1] 37 | p3 = z_unfold[:,:,:,2] 38 | 39 | D[:,:,0,:] = p1 * p1 40 | D[:,:,1,:] = p2 * p2 41 | D[:,:,2,:] = p3 * p3 42 | D[:,:,3,:] = 2 * p1 * p2 43 | D[:,:,4,:] = 2 * p1 * p3 44 | D[:,:,5,:] = 2 * p2 * p3 45 | D[:,:,6,:] = 2 * p1 46 | D[:,:,7,:] = 2 * p2 47 | D[:,:,8,:] = 2 * p3 48 | 49 | D_T = D.transpose(3,2) 50 | S = torch.einsum('bkmn,bknp->bkmp', D, D_T) 51 | 52 | S11 = S[:,:, 0:6, 0:6] 53 | S12 = S[:,:, 0:6, 6:10] 54 | S22 = S[:,:, 6:10, 6:10] 55 | A = S11 - torch.einsum('bpij,bpjk,bpkm->bpim', S12, torch.linalg.inv(S22 + 1e-5 * torch.rand((bz, kpts, 4, 4), device = _device).float()), S12.transpose(3,2)) 56 | 57 | M = torch.einsum('bpij,bpjk->bpik',inv_C, A) + 1e-4 * torch.rand((bz, kpts, 6, 6), device = _device).float() 58 | eigvals, eigvecs = torch.linalg.eig(M) 59 | eigvals = eigvals.real 60 | eigvecs = eigvecs.real 61 | 62 | idx = torch.argmax(eigvals,dim=2) 63 | vecs1 = torch.zeros_like(eigvals) 64 | for i in torch.arange(bz): 65 | vecs1[i,:,:] = eigvecs[i, torch.arange(kpts), :, idx[i,:]] 66 | 67 | vecs2 = - torch.einsum('bpij, bpjk,bpk->bpi',torch.linalg.inv(S22+1e-5 * torch.rand((bz, kpts, 4, 4), device = _device).float()), S12.transpose(3,2), vecs1) 68 | 69 | vecs = torch.cat((vecs1, vecs2),2) 70 | Coff[:,:,0,0] = vecs[:,:,0] 71 | Coff[:,:,1,1] = vecs[:,:,1] 72 | Coff[:,:,2,2] = vecs[:,:,2] 73 | Coff[:,:,3,3] = vecs[:,:,9] 74 | 75 | Coff[:,:,0,1] = Coff[:,:,1,0] = vecs[:,:,3] 76 | Coff[:,:,0,2] = Coff[:,:,2,0] = vecs[:,:,4] 77 | Coff[:,:,0,3] = Coff[:,:,3,0] = vecs[:,:,6] 78 | Coff[:,:,1,2] = Coff[:,:,2,1] = vecs[:,:,5] 79 | Coff[:,:,1,3] = Coff[:,:,3,1] = vecs[:,:,7] 80 | Coff[:,:,2,3] = Coff[:,:,3,2] = vecs[:,:,8] 81 | 82 | Coff = Coff + 1e-7 * torch.rand(bz, kpts,4,4, device = _device) 83 | centre = torch.linalg.solve(-Coff[:,:,0:3, 0:3], 
vecs[:,:,6:9]) 84 | 85 | 86 | T[:,:,0,0] = T[:,:,1,1]=T[:,:,2,2] = T[:,:,3,3] = 1 87 | T[:,:,3, 0:3] = centre 88 | R = torch.einsum('bpij,bpjk,bpkm->bpim', T, Coff, T.transpose(3,2)) 89 | evals, evecs = torch.linalg.eig(R[:,:,0:3, 0:3] / -R[:,:,3, 3][:,:,None,None]) 90 | evecs = evecs.transpose(3,2) 91 | radii = torch.sqrt(1 / torch.clamp(torch.abs(evals), min=1e-9)) 92 | return centre, evecs, radii, vecs 93 | 94 | class GetCurvature(nn.Module): 95 | def __init__(self, img_w, img_h): 96 | super().__init__() 97 | x = torch.linspace(0,img_w-1, img_w) 98 | y = torch.linspace(0,img_h-1, img_h) 99 | [yy, xx]=torch.meshgrid(y,x) 100 | xx_unfold1 = F.unfold(xx[None, None, :,:], kernel_size=(8, 8), stride=8, padding=1)[:,None,:,:] 101 | yy_unfold1 = F.unfold(yy[None, None, :,:], kernel_size=(8, 8), stride=8, padding=1)[:,None,:,:] 102 | self.register_buffer('xx_unfold1', xx_unfold1, persistent=False) 103 | self.register_buffer('yy_unfold1', yy_unfold1, persistent=False) 104 | 105 | def get_curv_descs(self, depth): 106 | _device = depth.device 107 | # depth = depth + 1e-5 * torch.rand(depth.shape, device = _device).float() 108 | d_unfold1 = F.unfold(depth[None, None, :,:], kernel_size=(8, 8), stride=8, padding=1)[:,None,:,:] 109 | 110 | z_unfold1 = torch.cat((self.xx_unfold1, self.yy_unfold1, d_unfold1), axis=1).permute(0,3,2,1) 111 | center_points = z_unfold1[:,:, 36, :][:,:,None,:].repeat(1,1,64,1) 112 | norm_points = z_unfold1 - center_points 113 | num_kpts = norm_points.shape[1] 114 | descs_curv = torch.zeros((1, num_kpts , 1), device = _device) 115 | z_unfolds = [norm_points] 116 | for i, z_unfold in enumerate(z_unfolds): 117 | center, evecs, radii, vecs = extract_curv_tensor(z_unfold, _device) 118 | # # viz 119 | # ctr_points = norm_points[:,:1,:,:].reshape(-1,3).cpu().numpy() 120 | # fig = plt.figure() 121 | # ax = fig.add_subplot(111, projection='3d') 122 | # ax.scatter(ctr_points[:,0], ctr_points[:,1], ctr_points[:,2], marker='o', color='g') 123 | # ax.scatter(0,0,0, marker='o', color='r') 124 | # ellipsoid_plot(center[:,:1,:].reshape(-1,3).cpu().numpy().squeeze(), radii[:,:1,:].reshape(-1,3).cpu().numpy().squeeze(), evecs[:,:1,:,:].reshape(3,3).cpu().numpy(), ax=ax, plot_axes=True, cage_color='g') 125 | # plt.show() 126 | # # end viz 127 | simi_curv = torch.div(torch.min(radii, 2).values, torch.max(radii,2).values) 128 | descs_curv[:,:, i] = simi_curv 129 | return descs_curv 130 | 131 | def get_curv_descs_cos_similarity(self, depth): 132 | _device = depth.device 133 | # depth = depth + 1e-5 * torch.rand(depth.shape, device = _device).float() 134 | d_unfold1 = F.unfold(depth[None, None, :,:], kernel_size=(8, 8), stride=8, padding=1)[:,None,:,:] 135 | 136 | z_unfold1 = torch.cat((self.xx_unfold1, self.yy_unfold1, d_unfold1), axis=1).permute(0,3,2,1) 137 | center_points = z_unfold1[:,:, 36, :][:,:,None,:].repeat(1,1,64,1) 138 | norm_points = z_unfold1 - center_points 139 | num_kpts = norm_points.shape[1] 140 | descs_curv = torch.zeros((1, num_kpts , 1), device = _device) 141 | z_unfolds = [norm_points] 142 | for i, z_unfold in enumerate(z_unfolds): 143 | center, evecs, radii, vecs = extract_curv_tensor(z_unfold, _device) 144 | # # viz 145 | # ctr_points = norm_points[:,:1,:,:].reshape(-1,3).cpu().numpy() 146 | # fig = plt.figure() 147 | # ax = fig.add_subplot(111, projection='3d') 148 | # ax.scatter(ctr_points[:,0], ctr_points[:,1], ctr_points[:,2], marker='o', color='g') 149 | # ax.scatter(0,0,0, marker='o', color='r') 150 | # 
ellipsoid_plot(center[:,:1,:].reshape(-1,3).cpu().numpy().squeeze(), radii[:,:1,:].reshape(-1,3).cpu().numpy().squeeze(), evecs[:,:1,:,:].reshape(3,3).cpu().numpy(), ax=ax, plot_axes=True, cage_color='g') 151 | # plt.show() 152 | # # end viz 153 | modulus_radii = torch.linalg.vector_norm(radii, dim=2, keepdim=True) 154 | descs_curv = radii / modulus_radii.repeat(1,1,3) 155 | return descs_curv 156 | 157 | from pdb import set_trace as bb 158 | 159 | @torch.no_grad() 160 | def compute_patch_convexity(depth_map, patch_size=8): 161 | # Extract patches using unfold 162 | patches = depth_map.unfold(0, patch_size, patch_size).unfold(1, patch_size, patch_size) 163 | patches = patches.contiguous().view(-1, patch_size, patch_size) 164 | 165 | # Compute the first derivatives 166 | dx_filter = torch.Tensor([[-1, 0, 1],[-2, 0, 2],[-1, 0, 1]]).view(1, 1, 3, 3).to(depth_map.device) 167 | dy_filter = torch.Tensor([[-1, -2, -1],[0, 0, 0],[1, 2, 1]]).view(1, 1, 3, 3).to(depth_map.device) 168 | 169 | dx = F.conv2d(patches.unsqueeze(1), dx_filter, padding=1) 170 | dy = F.conv2d(patches.unsqueeze(1), dy_filter, padding=1) 171 | 172 | # Compute the second derivatives 173 | dxx = F.conv2d(dx, dx_filter, padding=1) 174 | dyy = F.conv2d(dy, dy_filter, padding=1) 175 | dxy = F.conv2d(dx, dy_filter, padding=1) 176 | 177 | # Compute Gaussian curvature at the centers 178 | center_idx = patch_size // 2 179 | K = (dxx[:, 0, center_idx, center_idx] * dyy[:, 0, center_idx, center_idx] - dxy[:, 0, center_idx, center_idx]**2) / (1 + dx[:, 0, center_idx, center_idx]**2 + dy[:, 0, center_idx, center_idx]**2)**2 180 | K_sign = torch.sign(K) 181 | return K_sign 182 | 183 | 184 | 185 | 186 | -------------------------------------------------------------------------------- /src/DPT/depth_estimation.py: -------------------------------------------------------------------------------- 1 | from src.DPT.midas.dpt_depth import DPTDepthModel 2 | # from src.DPT.midas.midas_net import MidasNet 3 | from src.DPT.curvature_extraction import GetCurvature, compute_patch_convexity 4 | import sys 5 | 6 | import cv2 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | from src.DPT.utils import write_depth_color, write_curv_color 11 | import os 12 | from pdb import set_trace as bb 13 | 14 | 15 | 16 | class MiDasDepth(nn.Module): 17 | def __init__(self, config): 18 | super().__init__() 19 | if config['trainer']['testing']: 20 | model_path = None 21 | else: 22 | model_path = 'weights/dpt_large-midas-2f21e586.pt' 23 | print('load_depth_model_from', model_path) 24 | model = DPTDepthModel( 25 | path= model_path, # load the pretrained model for fine-tune 26 | backbone="vitl16_384", 27 | non_negative=True, 28 | ) 29 | 30 | model.to(memory_format=torch.channels_last) 31 | self.model = model.train() 32 | if config['dataset']['test_data_source'] == 'MegaDepth' or config['dataset']['trainval_data_source'] == 'MegaDepth' or config['dataset']['test_data_source'] == 'YFCC': 33 | img_resize = config['dataset']['mgdpt_img_resize'] 34 | self.getcurvature = GetCurvature(img_resize, img_resize) 35 | elif config['dataset']['test_data_source'] == 'ScanNet' or config['dataset']['trainval_data_source'] == 'ScanNet': 36 | self.getcurvature = GetCurvature(640, 480) 37 | else: 38 | raise NotImplementedError() 39 | 40 | self.get_curv_map = None # or L2 41 | self.get_curv_desc = None # or 'cosine_similarity' 42 | self.temperature = 1 43 | self.viz = False 44 | self.scale_factor = 100 45 | self.sign = config['loftr']['match_coarse']['curv_sign'] # consider if 
the surface is convex or concave 46 | 47 | def forward(self, data): 48 | dataset_name = data['dataset_name'][0] 49 | img0 = data['img2depth0'] 50 | img1 = data['img2depth1'] 51 | sample0 = img0.to(memory_format=torch.channels_last) 52 | sample1 = img1.to(memory_format=torch.channels_last) 53 | sample = torch.cat((sample0, sample1), 0) 54 | prediction = self.model.forward(sample) 55 | depth0 = prediction[0,:,:] 56 | depth1 = prediction[1,:,:] 57 | if dataset_name == 'ScanNet': 58 | depth0 = ( 59 | torch.nn.functional.interpolate( 60 | depth0[None,None,:,:], 61 | size=(480,640), 62 | mode="bicubic", 63 | align_corners=False, 64 | ) 65 | .squeeze(0) 66 | ) 67 | 68 | depth1 = ( 69 | torch.nn.functional.interpolate( 70 | depth1[None,None,:,:], 71 | size=(480,640), 72 | mode="bicubic", 73 | align_corners=False, 74 | ) 75 | .squeeze(0) 76 | ) 77 | 78 | else: 79 | img0_size = (data['image0'].shape[2], data['image0'].shape[3]) 80 | img1_size = (data['image1'].shape[2], data['image1'].shape[3]) 81 | depth0 = ( 82 | torch.nn.functional.interpolate( 83 | depth0[None,None,:,:], 84 | size=img0_size, 85 | mode="bicubic", 86 | align_corners=False, 87 | ) 88 | .squeeze(0) 89 | ) 90 | 91 | depth1 = ( 92 | torch.nn.functional.interpolate( 93 | depth1[None,None,:,:], 94 | size=img1_size, 95 | mode="bicubic", 96 | align_corners=False, 97 | ) 98 | .squeeze(0) 99 | ) 100 | 101 | # print(depth0.min(), depth0.median(),depth0.max()) 102 | # print(data['pair_names']) 103 | if self.get_curv_desc == 'cosine_similarity': 104 | descs_curv0 = self.getcurvature.get_curv_descs_cos_similarity(depth0.squeeze() * self.scale_factor) 105 | descs_curv1 = self.getcurvature.get_curv_descs_cos_similarity(depth1.squeeze() * self.scale_factor) 106 | if self.get_curv_map == 'L2': 107 | curv_map = torch.cdist(descs_curv0, descs_curv1, p=2) 108 | norm_curv_map = 1 - (curv_map - curv_map.min())/(curv_map.max() - curv_map.min()) 109 | else: 110 | norm_curv_map = torch.einsum("nlc,nsc->nls", descs_curv0, descs_curv1) / self.temperature 111 | 112 | else: 113 | descs_curv0 = self.getcurvature.get_curv_descs(depth0.squeeze() * self.scale_factor) 114 | descs_curv1 = self.getcurvature.get_curv_descs(depth1.squeeze() * self.scale_factor) 115 | if self.sign: 116 | with torch.no_grad(): 117 | sign0 = compute_patch_convexity(depth0.squeeze()) 118 | sign1 = compute_patch_convexity(depth1.squeeze()) 119 | curv_map = torch.cdist(descs_curv0*sign0 , descs_curv1*sign1, p=2) 120 | else: 121 | curv_map = torch.cdist(descs_curv0 , descs_curv1, p=2) 122 | norm_curv_map = 1 - (curv_map - curv_map.min())/(curv_map.max() - curv_map.min()) 123 | 124 | data.update({'curv_map': norm_curv_map}) 125 | 126 | # if self.viz: 127 | # depth0 = depth0.squeeze().cpu().numpy() 128 | # depth1 = depth1.squeeze().cpu().numpy() 129 | # if dataset_name == 'ScanNet': 130 | # save_path0 = os.path.join('data/scannet/test', data['pair_names'][0][0]) 131 | # save_path1 = os.path.join('data/scannet/test', data['pair_names'][1][0]) 132 | 133 | # elif dataset_name == 'MegaDepth': 134 | # save_path0 = os.path.join('data/megadepth/test', data['pair_names'][0][0]) 135 | # save_path1 = os.path.join('data/megadepth/test', data['pair_names'][1][0]) 136 | # else: 137 | # raise NotImplementedError() 138 | # # write_depth_color(save_path0, depth0, bits=2) 139 | # # write_depth_color(save_path1, depth1, bits=2) 140 | # # add curvature representation 141 | # if dataset_name == 'ScanNet': 142 | # dsize = (640, 480) 143 | # descs_curv0 = descs_curv0.squeeze().cpu().numpy().reshape(60, 80) 144 | # 
descs_curv1 = descs_curv1.squeeze().cpu().numpy().reshape(60, 80) 145 | # elif dataset_name == 'MegaDepth': 146 | # dsize = (data['image0'].shape[2], data['image0'].shape[3]) 147 | # descs_curv0 = descs_curv0.squeeze().cpu().numpy().reshape(104, 104) 148 | # descs_curv1 = descs_curv1.squeeze().cpu().numpy().reshape(104, 104) 149 | # write_curv_color(save_path0, descs_curv0 , dsize, bits=2) 150 | # write_curv_color(save_path1, descs_curv1 , dsize, bits=2) 151 | -------------------------------------------------------------------------------- /src/DPT/midas/__pycache__/base_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/src/DPT/midas/__pycache__/base_model.cpython-38.pyc -------------------------------------------------------------------------------- /src/DPT/midas/__pycache__/blocks.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/src/DPT/midas/__pycache__/blocks.cpython-38.pyc -------------------------------------------------------------------------------- /src/DPT/midas/__pycache__/dpt_depth.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/src/DPT/midas/__pycache__/dpt_depth.cpython-38.pyc -------------------------------------------------------------------------------- /src/DPT/midas/__pycache__/midas_loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/src/DPT/midas/__pycache__/midas_loss.cpython-38.pyc -------------------------------------------------------------------------------- /src/DPT/midas/__pycache__/transforms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/src/DPT/midas/__pycache__/transforms.cpython-38.pyc -------------------------------------------------------------------------------- /src/DPT/midas/__pycache__/vit.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/src/DPT/midas/__pycache__/vit.cpython-38.pyc -------------------------------------------------------------------------------- /src/DPT/midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 
7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device('cpu')) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /src/DPT/midas/dpt_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | # FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | ): 36 | 37 | super(DPT, self).__init__() 38 | 39 | self.channels_last = channels_last 40 | 41 | hooks = { 42 | "vitb_rn50_384": [0, 1, 8, 11], 43 | "vitb16_384": [2, 5, 8, 11], 44 | "vitl16_384": [5, 11, 17, 23], 45 | } 46 | 47 | # Instantiate backbone and reassemble blocks 48 | self.pretrained, self.scratch = _make_encoder( 49 | backbone, 50 | features, 51 | False, # Set to true of you want to train from scratch, uses ImageNet weights 52 | groups=1, 53 | expand=False, 54 | exportable=False, 55 | hooks=hooks[backbone], 56 | use_readout=readout, 57 | ) 58 | for p in self.pretrained.parameters(): 59 | p.requires_grad = False 60 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 61 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 63 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 64 | 65 | self.scratch.output_conv = head 66 | 67 | 68 | def forward(self, x): 69 | if self.channels_last == True: 70 | x.contiguous(memory_format=torch.channels_last) 71 | 72 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 73 | 74 | layer_1_rn = self.scratch.layer1_rn(layer_1) 75 | layer_2_rn = self.scratch.layer2_rn(layer_2) 76 | layer_3_rn = self.scratch.layer3_rn(layer_3) 77 | layer_4_rn = self.scratch.layer4_rn(layer_4) 78 | 79 | path_4 = self.scratch.refinenet4(layer_4_rn) 80 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 81 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 82 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 83 | 84 | out = self.scratch.output_conv(path_1) 85 | 86 | return out 87 | 88 | 89 | class DPTDepthModel(DPT): 90 | def __init__(self, path=None, non_negative=True, **kwargs): 91 | features = kwargs["features"] if "features" in kwargs else 256 92 | 93 | head = nn.Sequential( 94 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 95 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 96 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 97 | nn.ReLU(True), 98 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 99 | nn.ReLU(True) if non_negative else nn.Identity(), 100 | nn.Identity(), 101 | ) 102 | 103 | super().__init__(head, **kwargs) 104 | 105 | if path is not None: 106 | self.load(path) 107 | 108 | def forward(self, x): 109 | return super().forward(x).squeeze(dim=1) 110 | 111 
| -------------------------------------------------------------------------------- /src/DPT/midas/midas_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | def masked_l1_loss(preds, target, mask_valid): 6 | element_wise_loss = abs(preds - target) 7 | element_wise_loss[~mask_valid] = 0 8 | return element_wise_loss.sum() / mask_valid.sum() 9 | 10 | 11 | def compute_scale_and_shift(prediction, target, mask): 12 | # system matrix: A = [[a_00, a_01], [a_10, a_11]] 13 | a_00 = torch.sum(mask * prediction * prediction, (1, 2)) 14 | a_01 = torch.sum(mask * prediction, (1, 2)) 15 | a_11 = torch.sum(mask, (1, 2)) 16 | 17 | # right hand side: b = [b_0, b_1] 18 | b_0 = torch.sum(mask * prediction * target, (1, 2)) 19 | b_1 = torch.sum(mask * target, (1, 2)) 20 | 21 | # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b 22 | x_0 = torch.zeros_like(b_0) 23 | x_1 = torch.zeros_like(b_1) 24 | 25 | det = a_00 * a_11 - a_01 * a_01 26 | valid = det.nonzero() 27 | 28 | x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / (det[valid] + 1e-6) 29 | x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / (det[valid] + 1e-6) 30 | 31 | return x_0, x_1 32 | 33 | 34 | def masked_shift_and_scale(depth_preds, depth_gt, mask_valid): 35 | depth_preds_nan = depth_preds.clone() 36 | depth_gt_nan = depth_gt.clone() 37 | # import pdb 38 | # pdb.set_trace() 39 | depth_preds_nan[~mask_valid] = np.nan 40 | depth_gt_nan[~mask_valid] = np.nan 41 | 42 | mask_diff = mask_valid.view(mask_valid.size()[:2] + (-1,)).sum(-1, keepdims=True) + 1 43 | t_gt = depth_gt_nan.view(depth_gt_nan.size()[:2] + (-1,)).nanmedian(-1, keepdims=True)[0].unsqueeze(-1) 44 | t_gt[torch.isnan(t_gt)] = 0 45 | diff_gt = torch.abs(depth_gt - t_gt) 46 | diff_gt[~mask_valid] = 0 47 | s_gt = (diff_gt.view(diff_gt.size()[:2] + (-1,)).sum(-1, keepdims=True) / mask_diff).unsqueeze(-1) 48 | depth_gt_aligned = (depth_gt - t_gt) / (s_gt + 1e-6) 49 | 50 | 51 | t_pred = depth_preds_nan.view(depth_preds_nan.size()[:2] + (-1,)).nanmedian(-1, keepdims=True)[0].unsqueeze(-1) 52 | t_pred[torch.isnan(t_pred)] = 0 53 | diff_pred = torch.abs(depth_preds - t_pred) 54 | diff_pred[~mask_valid] = 0 55 | s_pred = (diff_pred.view(diff_pred.size()[:2] + (-1,)).sum(-1, keepdims=True) / mask_diff).unsqueeze(-1) 56 | depth_pred_aligned = (depth_preds - t_pred) / (s_pred + 1e-6) 57 | 58 | return depth_pred_aligned, depth_gt_aligned 59 | 60 | 61 | def reduction_batch_based(image_loss, M): 62 | # average of all valid pixels of the batch 63 | 64 | # avoid division by 0 (if sum(M) = sum(sum(mask)) = 0: sum(image_loss) = 0) 65 | divisor = torch.sum(M) 66 | 67 | if divisor == 0: 68 | return 0 69 | else: 70 | return torch.sum(image_loss) / divisor 71 | 72 | 73 | def reduction_image_based(image_loss, M): 74 | # mean of average of valid pixels of an image 75 | 76 | # avoid division by 0 (if M = sum(mask) = 0: image_loss = 0) 77 | valid = M.nonzero() 78 | 79 | image_loss[valid] = image_loss[valid] / M[valid] 80 | 81 | return torch.mean(image_loss) 82 | 83 | 84 | 85 | def gradient_loss(prediction, target, mask, reduction=reduction_batch_based): 86 | 87 | M = torch.sum(mask, (1, 2)) 88 | 89 | diff = prediction - target 90 | diff = torch.mul(mask, diff) 91 | 92 | grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1]) 93 | mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1]) 94 | grad_x = torch.mul(mask_x, grad_x) 95 | 96 | grad_y = 
torch.abs(diff[:, 1:, :] - diff[:, :-1, :]) 97 | mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :]) 98 | grad_y = torch.mul(mask_y, grad_y) 99 | 100 | image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2)) 101 | 102 | return reduction(image_loss, M) 103 | 104 | 105 | 106 | class SSIMAE(nn.Module): 107 | def __init__(self): 108 | super().__init__() 109 | 110 | def forward(self, depth_preds, depth_gt, mask_valid): 111 | depth_pred_aligned, depth_gt_aligned = masked_shift_and_scale(depth_preds, depth_gt, mask_valid) 112 | ssi_mae_loss = masked_l1_loss(depth_pred_aligned, depth_gt_aligned, mask_valid) 113 | return ssi_mae_loss 114 | 115 | 116 | class GradientMatchingTerm(nn.Module): 117 | def __init__(self, scales=4, reduction='batch-based'): 118 | super().__init__() 119 | 120 | if reduction == 'batch-based': 121 | self.__reduction = reduction_batch_based 122 | else: 123 | self.__reduction = reduction_image_based 124 | 125 | self.__scales = scales 126 | 127 | def forward(self, prediction, target, mask): 128 | total = 0 129 | 130 | for scale in range(self.__scales): 131 | step = pow(2, scale) 132 | 133 | total += gradient_loss(prediction[:, ::step, ::step], target[:, ::step, ::step], 134 | mask[:, ::step, ::step], reduction=self.__reduction) 135 | 136 | return total 137 | 138 | 139 | class MidasLoss(nn.Module): 140 | def __init__(self, alpha=0.1, scales=4, reduction='image-based'): 141 | super().__init__() 142 | 143 | self.__ssi_mae_loss = SSIMAE() 144 | self.__gradient_matching_term = GradientMatchingTerm(scales=scales, reduction=reduction) 145 | self.__alpha = alpha 146 | self.__prediction_ssi = None 147 | 148 | def forward(self, prediction, target, mask): 149 | prediction_inverse = 1 / (prediction.squeeze(1)+1e-6) 150 | target_inverse = 1 / (target.squeeze(1)+1e-6) 151 | ssi_loss = self.__ssi_mae_loss(prediction, target, mask) 152 | 153 | scale, shift = compute_scale_and_shift(prediction_inverse, target_inverse, mask.squeeze(1)) 154 | self.__prediction_ssi = scale.view(-1, 1, 1) * prediction_inverse + shift.view(-1, 1, 1) 155 | reg_loss = self.__gradient_matching_term(self.__prediction_ssi, target_inverse, mask.squeeze(1)) 156 | if self.__alpha > 0: 157 | total = ssi_loss + self.__alpha * reg_loss 158 | 159 | return total, ssi_loss, reg_loss 160 | 161 | 162 | # import torch 163 | 164 | # a = torch.rand((1,1,1,480,640)) 165 | # b = torch.rand((1,1,1,480,640)) 166 | # c = torch.ones((1,1,1,480,640)).long() 167 | # depth_pred_aligned, depth_gt_aligned = masked_shift_and_scale(a,b,c) 168 | 169 | -------------------------------------------------------------------------------- /src/DPT/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 
22 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /src/DPT/midas/midas_net_custom.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet_small(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, 17 | blocks={'expand': True}): 18 | """Init. 19 | 20 | Args: 21 | path (str, optional): Path to saved model. Defaults to None. 22 | features (int, optional): Number of features. Defaults to 256. 23 | backbone (str, optional): Backbone network for encoder. 
Defaults to resnet50 24 | """ 25 | print("Loading weights: ", path) 26 | 27 | super(MidasNet_small, self).__init__() 28 | 29 | use_pretrained = False if path else True 30 | 31 | self.channels_last = channels_last 32 | self.blocks = blocks 33 | self.backbone = backbone 34 | 35 | self.groups = 1 36 | 37 | features1=features 38 | features2=features 39 | features3=features 40 | features4=features 41 | self.expand = False 42 | if "expand" in self.blocks and self.blocks['expand'] == True: 43 | self.expand = True 44 | features1=features 45 | features2=features*2 46 | features3=features*4 47 | features4=features*8 48 | 49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) 50 | 51 | self.scratch.activation = nn.ReLU(False) 52 | 53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) 57 | 58 | 59 | self.scratch.output_conv = nn.Sequential( 60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), 61 | Interpolate(scale_factor=2, mode="bilinear"), 62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), 63 | self.scratch.activation, 64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 65 | nn.ReLU(True) if non_negative else nn.Identity(), 66 | nn.Identity(), 67 | ) 68 | 69 | if path: 70 | self.load(path) 71 | 72 | 73 | def forward(self, x): 74 | """Forward pass. 
75 | 76 | Args: 77 | x (tensor): input data (image) 78 | 79 | Returns: 80 | tensor: depth 81 | """ 82 | if self.channels_last==True: 83 | print("self.channels_last = ", self.channels_last) 84 | x.contiguous(memory_format=torch.channels_last) 85 | 86 | 87 | layer_1 = self.pretrained.layer1(x) 88 | layer_2 = self.pretrained.layer2(layer_1) 89 | layer_3 = self.pretrained.layer3(layer_2) 90 | layer_4 = self.pretrained.layer4(layer_3) 91 | 92 | layer_1_rn = self.scratch.layer1_rn(layer_1) 93 | layer_2_rn = self.scratch.layer2_rn(layer_2) 94 | layer_3_rn = self.scratch.layer3_rn(layer_3) 95 | layer_4_rn = self.scratch.layer4_rn(layer_4) 96 | 97 | 98 | path_4 = self.scratch.refinenet4(layer_4_rn) 99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 102 | 103 | out = self.scratch.output_conv(path_1) 104 | 105 | return torch.squeeze(out, dim=1) 106 | 107 | 108 | 109 | def fuse_model(m): 110 | prev_previous_type = nn.Identity() 111 | prev_previous_name = '' 112 | previous_type = nn.Identity() 113 | previous_name = '' 114 | for name, module in m.named_modules(): 115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: 116 | # print("FUSED ", prev_previous_name, previous_name, name) 117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) 118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: 119 | # print("FUSED ", prev_previous_name, previous_name) 120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) 121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: 122 | # print("FUSED ", previous_name, name) 123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) 124 | 125 | prev_previous_type = previous_type 126 | prev_previous_name = previous_name 127 | previous_type = type(module) 128 | previous_name = name -------------------------------------------------------------------------------- /src/DPT/midas/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | 5 | 6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): 7 | """Rezise the sample to ensure the given size. Keeps aspect ratio. 8 | 9 | Args: 10 | sample (dict): sample 11 | size (tuple): image size 12 | 13 | Returns: 14 | tuple: new size 15 | """ 16 | shape = list(sample["disparity"].shape) 17 | 18 | if shape[0] >= size[0] and shape[1] >= size[1]: 19 | return sample 20 | 21 | scale = [0, 0] 22 | scale[0] = size[0] / shape[0] 23 | scale[1] = size[1] / shape[1] 24 | 25 | scale = max(scale) 26 | 27 | shape[0] = math.ceil(scale * shape[0]) 28 | shape[1] = math.ceil(scale * shape[1]) 29 | 30 | # resize 31 | sample["image"] = cv2.resize( 32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method 33 | ) 34 | 35 | sample["disparity"] = cv2.resize( 36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST 37 | ) 38 | sample["mask"] = cv2.resize( 39 | sample["mask"].astype(np.float32), 40 | tuple(shape[::-1]), 41 | interpolation=cv2.INTER_NEAREST, 42 | ) 43 | sample["mask"] = sample["mask"].astype(bool) 44 | 45 | return tuple(shape) 46 | 47 | 48 | class Resize(object): 49 | """Resize sample to given size (width, height). 
50 | """ 51 | 52 | def __init__( 53 | self, 54 | width, 55 | height, 56 | resize_target=True, 57 | keep_aspect_ratio=False, 58 | ensure_multiple_of=1, 59 | resize_method="lower_bound", 60 | image_interpolation_method=cv2.INTER_AREA, 61 | ): 62 | """Init. 63 | 64 | Args: 65 | width (int): desired output width 66 | height (int): desired output height 67 | resize_target (bool, optional): 68 | True: Resize the full sample (image, mask, target). 69 | False: Resize image only. 70 | Defaults to True. 71 | keep_aspect_ratio (bool, optional): 72 | True: Keep the aspect ratio of the input sample. 73 | Output sample might not have the given width and height, and 74 | resize behaviour depends on the parameter 'resize_method'. 75 | Defaults to False. 76 | ensure_multiple_of (int, optional): 77 | Output width and height is constrained to be multiple of this parameter. 78 | Defaults to 1. 79 | resize_method (str, optional): 80 | "lower_bound": Output will be at least as large as the given size. 81 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) 82 | "minimal": Scale as least as possible. (Output size might be smaller than given size.) 83 | Defaults to "lower_bound". 84 | """ 85 | self.__width = width 86 | self.__height = height 87 | 88 | self.__resize_target = resize_target 89 | self.__keep_aspect_ratio = keep_aspect_ratio 90 | self.__multiple_of = ensure_multiple_of 91 | self.__resize_method = resize_method 92 | self.__image_interpolation_method = image_interpolation_method 93 | 94 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None): 95 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) 96 | 97 | if max_val is not None and y > max_val: 98 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) 99 | 100 | if y < min_val: 101 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) 102 | 103 | return y 104 | 105 | def get_size(self, width, height): 106 | # determine new height and width 107 | scale_height = self.__height / height 108 | scale_width = self.__width / width 109 | 110 | if self.__keep_aspect_ratio: 111 | if self.__resize_method == "lower_bound": 112 | # scale such that output size is lower bound 113 | if scale_width > scale_height: 114 | # fit width 115 | scale_height = scale_width 116 | else: 117 | # fit height 118 | scale_width = scale_height 119 | elif self.__resize_method == "upper_bound": 120 | # scale such that output size is upper bound 121 | if scale_width < scale_height: 122 | # fit width 123 | scale_height = scale_width 124 | else: 125 | # fit height 126 | scale_width = scale_height 127 | elif self.__resize_method == "minimal": 128 | # scale as least as possbile 129 | if abs(1 - scale_width) < abs(1 - scale_height): 130 | # fit width 131 | scale_height = scale_width 132 | else: 133 | # fit height 134 | scale_width = scale_height 135 | else: 136 | raise ValueError( 137 | f"resize_method {self.__resize_method} not implemented" 138 | ) 139 | 140 | if self.__resize_method == "lower_bound": 141 | new_height = self.constrain_to_multiple_of( 142 | scale_height * height, min_val=self.__height 143 | ) 144 | new_width = self.constrain_to_multiple_of( 145 | scale_width * width, min_val=self.__width 146 | ) 147 | elif self.__resize_method == "upper_bound": 148 | new_height = self.constrain_to_multiple_of( 149 | scale_height * height, max_val=self.__height 150 | ) 151 | new_width = self.constrain_to_multiple_of( 152 | scale_width * width, max_val=self.__width 
153 | ) 154 | elif self.__resize_method == "minimal": 155 | new_height = self.constrain_to_multiple_of(scale_height * height) 156 | new_width = self.constrain_to_multiple_of(scale_width * width) 157 | else: 158 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 159 | 160 | return (new_width, new_height) 161 | 162 | def __call__(self, sample): 163 | width, height = self.get_size( 164 | sample["image"].shape[1], sample["image"].shape[0] 165 | ) 166 | 167 | # resize sample 168 | sample["image"] = cv2.resize( 169 | sample["image"], 170 | (width, height), 171 | interpolation=self.__image_interpolation_method, 172 | ) 173 | 174 | if self.__resize_target: 175 | if "disparity" in sample: 176 | sample["disparity"] = cv2.resize( 177 | sample["disparity"], 178 | (width, height), 179 | interpolation=cv2.INTER_NEAREST, 180 | ) 181 | 182 | if "depth" in sample: 183 | sample["depth"] = cv2.resize( 184 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST 185 | ) 186 | 187 | sample["mask"] = cv2.resize( 188 | sample["mask"].astype(np.float32), 189 | (width, height), 190 | interpolation=cv2.INTER_NEAREST, 191 | ) 192 | sample["mask"] = sample["mask"].astype(bool) 193 | 194 | return sample 195 | 196 | 197 | class NormalizeImage(object): 198 | """Normlize image by given mean and std. 199 | """ 200 | 201 | def __init__(self, mean, std): 202 | self.__mean = mean 203 | self.__std = std 204 | 205 | def __call__(self, sample): 206 | sample["image"] = (sample["image"] - self.__mean) / self.__std 207 | 208 | return sample 209 | 210 | 211 | class PrepareForNet(object): 212 | """Prepare sample for usage as network input. 213 | """ 214 | 215 | def __init__(self): 216 | pass 217 | 218 | def __call__(self, sample): 219 | image = np.transpose(sample["image"], (2, 0, 1)) 220 | sample["image"] = np.ascontiguousarray(image).astype(np.float32) 221 | 222 | if "mask" in sample: 223 | sample["mask"] = sample["mask"].astype(np.float32) 224 | sample["mask"] = np.ascontiguousarray(sample["mask"]) 225 | 226 | if "disparity" in sample: 227 | disparity = sample["disparity"].astype(np.float32) 228 | sample["disparity"] = np.ascontiguousarray(disparity) 229 | 230 | if "depth" in sample: 231 | depth = sample["depth"].astype(np.float32) 232 | sample["depth"] = np.ascontiguousarray(depth) 233 | 234 | return sample 235 | -------------------------------------------------------------------------------- /src/DPT/utils.py: -------------------------------------------------------------------------------- 1 | """Utils for monoDepth. 2 | """ 3 | import sys 4 | import re 5 | import numpy as np 6 | import cv2 7 | import torch 8 | import h5py 9 | import matplotlib.pyplot as plt 10 | 11 | 12 | 13 | def write_depth(path, depth, bits=1): 14 | """Write depth map to pfm and png file. 
15 | Args: 16 | path (str): filepath without extension 17 | depth (array): depth 18 | """ 19 | #write_pfm(path + ".pfm", depth.astype(np.float32)) 20 | # hf = h5py.File(path + ".h5", 'w') 21 | # hf.create_dataset('depth', data=depth.astype(np.float16)) 22 | # hf.close() 23 | 24 | 25 | depth_min = depth.min() 26 | depth_max = depth.max() 27 | # import pdb 28 | # pdb.set_trace() 29 | 30 | max_val = (2**(8*bits))-1 31 | 32 | if depth_max - depth_min > np.finfo("float").eps: 33 | out = max_val * (depth - depth_min) / (depth_max - depth_min) 34 | else: 35 | out = np.zeros(depth.shape, dtype=depth.type) 36 | save_path_rgb = path.replace('jpg', 'depth_epoch0_v51.png') 37 | 38 | if bits == 1: 39 | cv2.imwrite(save_path_rgb , out.astype("uint8")) 40 | elif bits == 2: 41 | cv2.imwrite(save_path_rgb , out.astype("uint16")) 42 | 43 | 44 | def write_curv_rgb(path, descs_curv, dsize, bits=2): 45 | curv_val = (2**(8*bits))-1 46 | descs_curv = np.abs(descs_curv) 47 | curv_min = descs_curv.min() 48 | curv_max = descs_curv.max() 49 | out = curv_val * (descs_curv - curv_min) / (curv_max - curv_min) 50 | out = (cv2.resize(out, dsize=dsize, interpolation=cv2.INTER_CUBIC)).astype("uint16") 51 | save_path_rgb = path.replace('jpg', 'curv_epoch0_v51.png') 52 | cv2.imwrite(save_path_rgb , out) 53 | 54 | 55 | def write_depth_color(path, depth, bits=2): 56 | """Write depth map to pfm and png file. 57 | Args: 58 | path (str): filepath without extension 59 | depth (array): depth 60 | """ 61 | depth_min = depth.min() 62 | depth_max = depth.max() 63 | 64 | # max_val = (2**(8*bits))-1 65 | max_val = 1 66 | 67 | if depth_max - depth_min > np.finfo("float").eps: 68 | out = 1 - max_val * (depth - depth_min) / (depth_max - depth_min) 69 | else: 70 | out = 1 - np.zeros(depth.shape, dtype=depth.type) 71 | save_path_rgb = path.replace('jpg', 'depth_new_llllllllast.png') 72 | 73 | plt.imsave(save_path_rgb, out, cmap='gray') 74 | 75 | def write_curv_color(path, descs_curv, dsize, bits=2): 76 | curv_val = (2**(8*bits))-1 77 | descs_curv = np.abs(descs_curv) 78 | curv_min = descs_curv.min() 79 | curv_max = descs_curv.max() 80 | out = curv_val * (descs_curv - curv_min) / (curv_max - curv_min) 81 | out = (cv2.resize(out, dsize=dsize, interpolation=cv2.INTER_CUBIC)).astype("uint16") 82 | save_path_rgb = path.replace('jpg', 'curv_no_train_sign.png') 83 | 84 | out_new = (out - np.min(out)) / (np.max(out) - np.min(out)) 85 | # plt.imsave(save_path_rgb, out_new, cmap='viridis') 86 | plt.imsave(save_path_rgb, out_new, cmap='viridis') 87 | 88 | def ellipsoid_plot(center, radii, rotation, ax, plot_axes=False, cage_color='b', cage_alpha=0.2): 89 | """Plot an ellipsoid""" 90 | 91 | u = np.linspace(0.0, 2.0 * np.pi, 100) 92 | v = np.linspace(0.0, np.pi, 100) 93 | 94 | # cartesian coordinates that correspond to the spherical angles: 95 | x = radii[0] * np.outer(np.cos(u), np.sin(v)) 96 | y = radii[1] * np.outer(np.sin(u), np.sin(v)) 97 | z = radii[2] * np.outer(np.ones_like(u), np.cos(v)) 98 | # rotate accordingly 99 | for i in range(len(x)): 100 | for j in range(len(x)): 101 | [x[i, j], y[i, j], z[i, j]] = np.dot([x[i, j], y[i, j], z[i, j]], rotation) + center 102 | 103 | if plot_axes: 104 | # make some purdy axes 105 | axes = np.array([[radii[0],0.0,0.0], 106 | [0.0,radii[1],0.0], 107 | [0.0,0.0,radii[2]]]) 108 | # rotate accordingly 109 | for i in range(len(axes)): 110 | axes[i] = np.dot(axes[i], rotation) 111 | 112 | # plot axes 113 | for p in axes: 114 | X3 = np.linspace(-p[0], p[0], 100) + center[0] 115 | Y3 = np.linspace(-p[1], p[1], 100) 
+ center[1] 116 | Z3 = np.linspace(-p[2], p[2], 100) + center[2] 117 | ax.plot(X3, Y3, Z3, color=cage_color) 118 | 119 | # plot ellipsoid 120 | ax.plot_wireframe(x, y, z, rstride=4, cstride=4, color=cage_color, alpha=cage_alpha) -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/src/__init__.py -------------------------------------------------------------------------------- /src/config/default.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | _CN = CN() 3 | 4 | ############## ↓ LoFTR Pipeline ↓ ############## 5 | _CN.LOFTR = CN() 6 | _CN.LOFTR.BACKBONE_TYPE = 'ResNetFPN' 7 | _CN.LOFTR.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)] 8 | _CN.LOFTR.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd 9 | _CN.LOFTR.FINE_CONCAT_COARSE_FEAT = True 10 | 11 | # 1. LoFTR-backbone (local feature CNN) config 12 | _CN.LOFTR.RESNETFPN = CN() 13 | _CN.LOFTR.RESNETFPN.INITIAL_DIM = 128 14 | _CN.LOFTR.RESNETFPN.BLOCK_DIMS = [128, 196, 256] # s1, s2, s3 15 | 16 | # 2. LoFTR-coarse module config 17 | _CN.LOFTR.COARSE = CN() 18 | _CN.LOFTR.COARSE.D_MODEL = 256 19 | _CN.LOFTR.COARSE.D_FFN = 256 20 | _CN.LOFTR.COARSE.NHEAD = 8 21 | _CN.LOFTR.COARSE.LAYER_NAMES = ['self', 'cross'] * 4 22 | _CN.LOFTR.COARSE.ATTENTION = 'linear' # options: ['linear', 'full'] 23 | _CN.LOFTR.COARSE.TEMP_BUG_FIX = True 24 | _CN.LOFTR.COARSE.BLOCK_TYPE = 'loftr' 25 | _CN.LOFTR.COARSE.ATTN_TYPE = 'B' 26 | _CN.LOFTR.COARSE.TOPKS = [16, 8, 8] 27 | 28 | # 3. Coarse-Matching config 29 | _CN.LOFTR.MATCH_COARSE = CN() 30 | _CN.LOFTR.MATCH_COARSE.THR = 0.2 31 | _CN.LOFTR.MATCH_COARSE.BORDER_RM = 2 32 | _CN.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' # options: ['dual_softmax, 'sinkhorn'] 33 | _CN.LOFTR.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1 34 | _CN.LOFTR.MATCH_COARSE.SKH_ITERS = 3 35 | _CN.LOFTR.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0 36 | _CN.LOFTR.MATCH_COARSE.SKH_PREFILTER = False 37 | _CN.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.2 # training tricks: save GPU memory 38 | _CN.LOFTR.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock 39 | _CN.LOFTR.MATCH_COARSE.SPARSE_SPVS = True 40 | _CN.LOFTR.MATCH_COARSE.ADD_CURV = False # if the curvature features are added 41 | _CN.LOFTR.MATCH_COARSE.CURV_SIGN = True 42 | 43 | # 4. LoFTR-fine module config 44 | _CN.LOFTR.FINE = CN() 45 | _CN.LOFTR.FINE.D_MODEL = 128 46 | _CN.LOFTR.FINE.D_FFN = 128 47 | _CN.LOFTR.FINE.NHEAD = 8 48 | _CN.LOFTR.FINE.LAYER_NAMES = ['self', 'cross'] * 1 49 | _CN.LOFTR.FINE.ATTENTION = 'linear' 50 | _CN.LOFTR.FINE.BLOCK_TYPE = 'loftr' 51 | 52 | # 5. LoFTR Losses 53 | # -- # coarse-level 54 | _CN.LOFTR.LOSS = CN() 55 | _CN.LOFTR.LOSS.COARSE_TYPE = 'focal' # ['focal', 'cross_entropy'] 56 | _CN.LOFTR.LOSS.COARSE_WEIGHT = 1.0 57 | # _CN.LOFTR.LOSS.SPARSE_SPVS = False 58 | # -- - -- # focal loss (coarse) 59 | _CN.LOFTR.LOSS.FOCAL_ALPHA = 0.25 60 | _CN.LOFTR.LOSS.FOCAL_GAMMA = 2.0 61 | _CN.LOFTR.LOSS.POS_WEIGHT = 1.0 62 | _CN.LOFTR.LOSS.NEG_WEIGHT = 1.0 63 | # _CN.LOFTR.LOSS.DUAL_SOFTMAX = False # whether coarse-level use dual-softmax or not. 
64 | # use `_CN.LOFTR.MATCH_COARSE.MATCH_TYPE` 65 | 66 | # -- # fine-level 67 | _CN.LOFTR.LOSS.FINE_TYPE = 'l2_with_std' # ['l2_with_std', 'l2'] 68 | _CN.LOFTR.LOSS.FINE_WEIGHT = 1.0 69 | _CN.LOFTR.LOSS.FINE_CORRECT_THR = 1.0 # for filtering valid fine-level gts (some gt matches might fall out of the fine-level window) 70 | 71 | 72 | ############## Dataset ############## 73 | _CN.DATASET = CN() 74 | # 1. data config 75 | # training and validating 76 | _CN.DATASET.TRAINVAL_DATA_SOURCE = None # options: ['ScanNet', 'MegaDepth'] 77 | _CN.DATASET.TRAIN_DATA_ROOT = None 78 | _CN.DATASET.TRAIN_POSE_ROOT = None # (optional directory for poses) 79 | _CN.DATASET.TRAIN_NPZ_ROOT = None 80 | _CN.DATASET.TRAIN_LIST_PATH = None 81 | _CN.DATASET.TRAIN_INTRINSIC_PATH = None 82 | _CN.DATASET.VAL_DATA_ROOT = None 83 | _CN.DATASET.VAL_POSE_ROOT = None # (optional directory for poses) 84 | _CN.DATASET.VAL_NPZ_ROOT = None 85 | _CN.DATASET.VAL_LIST_PATH = None # None if val data from all scenes are bundled into a single npz file 86 | _CN.DATASET.VAL_INTRINSIC_PATH = None 87 | # testing 88 | _CN.DATASET.TEST_DATA_SOURCE = None 89 | _CN.DATASET.TEST_DATA_ROOT = None 90 | _CN.DATASET.TEST_POSE_ROOT = None # (optional directory for poses) 91 | _CN.DATASET.TEST_NPZ_ROOT = None 92 | _CN.DATASET.TEST_LIST_PATH = None # None if test data from all scenes are bundled into a single npz file 93 | _CN.DATASET.TEST_INTRINSIC_PATH = None 94 | 95 | # 2. dataset config 96 | # general options 97 | _CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4 # discard data with overlap_score < min_overlap_score 98 | _CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0 99 | _CN.DATASET.AUGMENTATION_TYPE = None # options: [None, 'dark', 'mobile'] 100 | 101 | # MegaDepth options 102 | _CN.DATASET.MGDPT_IMG_RESIZE = 840 # resize the longer side, zero-pad bottom-right to square. 103 | _CN.DATASET.MGDPT_IMG_PAD = True # pad img to square with size = MGDPT_IMG_RESIZE 104 | _CN.DATASET.MGDPT_DEPTH_PAD = True # pad depthmap to square with size = 2000 105 | _CN.DATASET.MGDPT_DF = 8 106 | 107 | ############## Trainer ############## 108 | _CN.TRAINER = CN() 109 | _CN.TRAINER.WORLD_SIZE = 1 110 | _CN.TRAINER.CANONICAL_BS = 64 111 | _CN.TRAINER.CANONICAL_LR = 6e-3 112 | _CN.TRAINER.SCALING = None # this will be calculated automatically 113 | _CN.TRAINER.FIND_LR = False # use learning rate finder from pytorch-lightning 114 | 115 | # optimizer 116 | _CN.TRAINER.OPTIMIZER = "adamw" # [adam, adamw] 117 | _CN.TRAINER.TRUE_LR = None # this will be calculated automatically at runtime 118 | _CN.TRAINER.ADAM_DECAY = 0. # ADAM: for adam 119 | _CN.TRAINER.ADAMW_DECAY = 0.1 120 | 121 | # step-based warm-up 122 | _CN.TRAINER.WARMUP_TYPE = 'linear' # [linear, constant] 123 | _CN.TRAINER.WARMUP_RATIO = 0. 
124 | _CN.TRAINER.WARMUP_STEP = 100 125 | 126 | # learning rate scheduler 127 | _CN.TRAINER.SCHEDULER = 'MultiStepLR' # [MultiStepLR, CosineAnnealing, ExponentialLR] 128 | _CN.TRAINER.SCHEDULER_INTERVAL = 'epoch' # [epoch, step] 129 | _CN.TRAINER.MSLR_MILESTONES = [3, 6, 9, 12] # MSLR: MultiStepLR 130 | _CN.TRAINER.MSLR_GAMMA = 0.5 131 | _CN.TRAINER.COSA_TMAX = 30 # COSA: CosineAnnealing 132 | _CN.TRAINER.ELR_GAMMA = 0.999992 # ELR: ExponentialLR, this value for 'step' interval 133 | 134 | # plotting related 135 | _CN.TRAINER.ENABLE_PLOTTING = True 136 | _CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 32 # number of val/test pairs for plotting 137 | _CN.TRAINER.PLOT_MODE = 'evaluation' # ['evaluation', 'confidence'] 138 | _CN.TRAINER.PLOT_MATCHES_ALPHA = 'dynamic' 139 | 140 | # geometric metrics and pose solver 141 | _CN.TRAINER.EPI_ERR_THR = 5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue) 142 | _CN.TRAINER.POSE_GEO_MODEL = 'E' # ['E', 'F', 'H'] 143 | _CN.TRAINER.POSE_ESTIMATION_METHOD = 'RANSAC' # [RANSAC, DEGENSAC, MAGSAC] 144 | _CN.TRAINER.RANSAC_PIXEL_THR = 0.5 145 | _CN.TRAINER.RANSAC_CONF = 0.99999 146 | _CN.TRAINER.RANSAC_MAX_ITERS = 10000 147 | _CN.TRAINER.USE_MAGSACPP = False 148 | 149 | # data sampler for train_dataloader 150 | _CN.TRAINER.DATA_SAMPLER = 'scene_balance' # options: ['scene_balance', 'random', 'normal'] 151 | # 'scene_balance' config 152 | _CN.TRAINER.N_SAMPLES_PER_SUBSET = 20 153 | _CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = True # whether to sample each scene with replacement or not 154 | _CN.TRAINER.SB_SUBSET_SHUFFLE = True # after sampling from scenes, whether to shuffle within the epoch or not 155 | _CN.TRAINER.SB_REPEAT = 1 # repeat N times for training the sampled data 156 | # 'random' config 157 | _CN.TRAINER.RDM_REPLACEMENT = True 158 | _CN.TRAINER.RDM_NUM_SAMPLES = None 159 | 160 | # gradient clipping 161 | _CN.TRAINER.GRADIENT_CLIPPING = 0.5 162 | 163 | # reproducibility 164 | # This seed affects the data sampling. With the same seed, the data sampling is guaranteed 165 | # to be the same. When resuming training from a checkpoint, it's better to use a different 166 | # seed, otherwise the sampled data will be exactly the same as before resuming, which will 167 | # result in fewer unique data items being sampled during the entire training. 168 | # Use of different seed values might affect the final training result, since not all data items 169 | # are used during training on ScanNet. (60M pairs of images sampled during training from 230M pairs in total.) 
170 | _CN.TRAINER.SEED = 66 171 | _CN.TRAINER.TESTING = False 172 | 173 | def get_cfg_defaults(): 174 | """Get a yacs CfgNode object with default values for my_project.""" 175 | # Return a clone so that the defaults will not be altered 176 | # This is for the "local variable" use pattern 177 | return _CN.clone() 178 | -------------------------------------------------------------------------------- /src/datasets/megadepth.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.utils.data import Dataset 6 | from loguru import logger 7 | import deepdish as dd 8 | 9 | from src.utils.dataset import read_megadepth_gray, read_megadepth_depth, read_image_for_depth_megadepth 10 | 11 | 12 | class MegaDepthDataset(Dataset): 13 | def __init__(self, 14 | root_dir, 15 | npz_path, 16 | mode='train', 17 | min_overlap_score=0.4, 18 | img_resize=None, 19 | df=None, 20 | img_padding=False, 21 | depth_padding=False, 22 | augment_fn=None, 23 | **kwargs): 24 | """ 25 | Manage one scene(npz_path) of MegaDepth dataset. 26 | 27 | Args: 28 | root_dir (str): megadepth root directory that has `phoenix`. 29 | npz_path (str): {scene_id}.npz path. This contains image pair information of a scene. 30 | mode (str): options are ['train', 'val', 'test'] 31 | min_overlap_score (float): how much a pair should have in common. In range of [0, 1]. Set to 0 when testing. 32 | img_resize (int, optional): the longer edge of resized images. None for no resize. 640 is recommended. 33 | This is useful during training with batches and testing with memory intensive algorithms. 34 | df (int, optional): image size division factor. NOTE: this will change the final image size after img_resize. 35 | img_padding (bool): If set to 'True', zero-pad the image to squared size. This is useful during training. 36 | depth_padding (bool): If set to 'True', zero-pad depthmap to (2000, 2000). This is useful during training. 37 | augment_fn (callable, optional): augments images with pre-defined visual effects. 38 | """ 39 | super().__init__() 40 | self.root_dir = root_dir 41 | self.mode = mode 42 | self.scene_id = npz_path.split('.')[0] 43 | 44 | # prepare scene_info and pair_info 45 | if mode == 'test' and min_overlap_score != 0: 46 | logger.warning("You are using `min_overlap_score`!=0 in test mode. Set to 0.") 47 | min_overlap_score = 0 48 | self.scene_info = np.load(npz_path, allow_pickle=True) 49 | self.pair_infos = self.scene_info['pair_infos'].copy() 50 | del self.scene_info['pair_infos'] 51 | self.pair_infos = [pair_info for pair_info in self.pair_infos if pair_info[1] > min_overlap_score] 52 | 53 | # parameters for image resizing, padding and depthmap padding 54 | if mode == 'train': 55 | assert img_resize is not None and img_padding and depth_padding 56 | self.img_resize = img_resize 57 | self.df = df 58 | self.img_padding = img_padding 59 | self.depth_max_size = 2000 if depth_padding else None # the upperbound of depthmaps size in megadepth. 60 | 61 | # for training LoFTR 62 | self.augment_fn = augment_fn if mode == 'train' else None 63 | self.coarse_scale = getattr(kwargs, 'coarse_scale', 0.125) 64 | 65 | def __len__(self): 66 | return len(self.pair_infos) 67 | 68 | def __getitem__(self, idx): 69 | (idx0, idx1), overlap_score, central_matches = self.pair_infos[idx] 70 | 71 | # read grayscale image and mask. 
(1, h, w) and (h, w) 72 | img_name0 = osp.join(self.root_dir, self.scene_info['image_paths'][idx0]) 73 | img_name1 = osp.join(self.root_dir, self.scene_info['image_paths'][idx1]) 74 | 75 | # TODO: Support augmentation & handle seeds for each worker correctly. 76 | image0, mask0, scale0 = read_megadepth_gray( 77 | img_name0, self.img_resize, self.df, self.img_padding, None) 78 | # np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) 79 | image1, mask1, scale1 = read_megadepth_gray( 80 | img_name1, self.img_resize, self.df, self.img_padding, None) 81 | # np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) 82 | 83 | img2depth0 = read_image_for_depth_megadepth(img_name0, self.img_resize, self.df, self.img_padding, None) 84 | img2depth1 = read_image_for_depth_megadepth(img_name1, self.img_resize, self.df, self.img_padding, None) 85 | 86 | # read depth. shape: (h, w) 87 | if self.mode in ['train', 'val']: 88 | depth0 = read_megadepth_depth( 89 | osp.join(self.root_dir, self.scene_info['depth_paths'][idx0]), pad_to=self.depth_max_size) 90 | depth1 = read_megadepth_depth( 91 | osp.join(self.root_dir, self.scene_info['depth_paths'][idx1]), pad_to=self.depth_max_size) 92 | else: 93 | depth0 = depth1 = torch.tensor([]) 94 | # curv0_path = osp.join(self.root_dir, self.scene_info['depth_paths'][idx0]) 95 | # desc_curv0 = dd.io.load(curv0_path.replace('.h5', '.curv.h5')) 96 | # curv1_path = osp.join(self.root_dir, self.scene_info['depth_paths'][idx1]) 97 | # desc_curv1 = dd.io.load(curv1_path.replace('.h5', '.curv.h5')) 98 | 99 | 100 | # read intrinsics of original size 101 | K_0 = torch.tensor(self.scene_info['intrinsics'][idx0].copy(), dtype=torch.float).reshape(3, 3) 102 | K_1 = torch.tensor(self.scene_info['intrinsics'][idx1].copy(), dtype=torch.float).reshape(3, 3) 103 | 104 | # read and compute relative poses 105 | T0 = self.scene_info['poses'][idx0] 106 | T1 = self.scene_info['poses'][idx1] 107 | T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4] # (4, 4) 108 | T_1to0 = T_0to1.inverse() 109 | 110 | data = { 111 | 'image0': image0, # (1, h, w) 112 | 'depth0': depth0, # (h, w) 113 | 'img2depth0' : img2depth0, 114 | 'image1': image1, 115 | 'depth1': depth1, 116 | 'img2depth1' : img2depth1, 117 | 'T_0to1': T_0to1, # (4, 4) 118 | 'T_1to0': T_1to0, 119 | 'K0': K_0, # (3, 3) 120 | 'K1': K_1, 121 | 'scale0': scale0, # [scale_w, scale_h] 122 | 'scale1': scale1, 123 | 'dataset_name': 'MegaDepth', 124 | 'scene_id': self.scene_id, 125 | 'pair_id': idx, 126 | 'pair_names': (self.scene_info['image_paths'][idx0], self.scene_info['image_paths'][idx1]), 127 | } 128 | 129 | # for LoFTR training 130 | if mask0 is not None: # img_padding is True 131 | if self.coarse_scale: 132 | [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(), 133 | scale_factor=self.coarse_scale, 134 | mode='nearest', 135 | recompute_scale_factor=False)[0].bool() 136 | data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1}) 137 | return data 138 | -------------------------------------------------------------------------------- /src/datasets/sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Sampler, ConcatDataset 3 | 4 | 5 | class RandomConcatSampler(Sampler): 6 | """ Random sampler for ConcatDataset. At each epoch, `n_samples_per_subset` samples will be draw from each subset 7 | in the ConcatDataset. 
If `subset_replacement` is ``True``, sampling within each subset will be done with replacement. 8 | However, it is impossible to sample data without replacement between epochs, unless bulding a stateful sampler lived along the entire training phase. 9 | 10 | For current implementation, the randomness of sampling is ensured no matter the sampler is recreated across epochs or not and call `torch.manual_seed()` or not. 11 | Args: 12 | shuffle (bool): shuffle the random sampled indices across all sub-datsets. 13 | repeat (int): repeatedly use the sampled indices multiple times for training. 14 | [arXiv:1902.05509, arXiv:1901.09335] 15 | NOTE: Don't re-initialize the sampler between epochs (will lead to repeated samples) 16 | NOTE: This sampler behaves differently with DistributedSampler. 17 | It assume the dataset is splitted across ranks instead of replicated. 18 | TODO: Add a `set_epoch()` method to fullfill sampling without replacement across epochs. 19 | ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373 20 | """ 21 | def __init__(self, 22 | data_source: ConcatDataset, 23 | n_samples_per_subset: int, 24 | subset_replacement: bool=True, 25 | shuffle: bool=True, 26 | repeat: int=1, 27 | seed: int=None): 28 | if not isinstance(data_source, ConcatDataset): 29 | raise TypeError("data_source should be torch.utils.data.ConcatDataset") 30 | 31 | self.data_source = data_source 32 | self.n_subset = len(self.data_source.datasets) 33 | self.n_samples_per_subset = n_samples_per_subset 34 | self.n_samples = self.n_subset * self.n_samples_per_subset * repeat 35 | self.subset_replacement = subset_replacement 36 | self.repeat = repeat 37 | self.shuffle = shuffle 38 | self.generator = torch.manual_seed(seed) 39 | assert self.repeat >= 1 40 | 41 | def __len__(self): 42 | return self.n_samples 43 | 44 | def __iter__(self): 45 | indices = [] 46 | # sample from each sub-dataset 47 | for d_idx in range(self.n_subset): 48 | low = 0 if d_idx==0 else self.data_source.cumulative_sizes[d_idx-1] 49 | high = self.data_source.cumulative_sizes[d_idx] 50 | if self.subset_replacement: 51 | rand_tensor = torch.randint(low, high, (self.n_samples_per_subset, ), 52 | generator=self.generator, dtype=torch.int64) 53 | else: # sample without replacement 54 | len_subset = len(self.data_source.datasets[d_idx]) 55 | rand_tensor = torch.randperm(len_subset, generator=self.generator) + low 56 | if len_subset >= self.n_samples_per_subset: 57 | rand_tensor = rand_tensor[:self.n_samples_per_subset] 58 | else: # padding with replacement 59 | rand_tensor_replacement = torch.randint(low, high, (self.n_samples_per_subset - len_subset, ), 60 | generator=self.generator, dtype=torch.int64) 61 | rand_tensor = torch.cat([rand_tensor, rand_tensor_replacement]) 62 | indices.append(rand_tensor) 63 | indices = torch.cat(indices) 64 | if self.shuffle: # shuffle the sampled dataset (from multiple subsets) 65 | rand_tensor = torch.randperm(len(indices), generator=self.generator) 66 | indices = indices[rand_tensor] 67 | 68 | # repeat the sampled indices (can be used for RepeatAugmentation or pure RepeatSampling) 69 | if self.repeat > 1: 70 | repeat_indices = [indices.clone() for _ in range(self.repeat - 1)] 71 | if self.shuffle: 72 | _choice = lambda x: x[torch.randperm(len(x), generator=self.generator)] 73 | repeat_indices = map(_choice, repeat_indices) 74 | indices = torch.cat([indices, *repeat_indices], 0) 75 | 76 | assert indices.shape[0] == 
self.n_samples 77 | return iter(indices.tolist()) 78 | -------------------------------------------------------------------------------- /src/datasets/scannet.py: -------------------------------------------------------------------------------- 1 | from os import path as osp 2 | from readline import insert_text 3 | from typing import Dict 4 | from unicodedata import name 5 | import deepdish as dd 6 | 7 | import numpy as np 8 | import torch 9 | import torch.utils as utils 10 | from numpy.linalg import inv 11 | from src.utils.dataset import ( 12 | read_scannet_gray, 13 | read_scannet_depth, 14 | read_scannet_pose, 15 | read_scannet_intrinsic, 16 | read_image_for_depth 17 | ) 18 | 19 | 20 | class ScanNetDataset(utils.data.Dataset): 21 | def __init__(self, 22 | root_dir, 23 | npz_path, 24 | intrinsic_path, 25 | mode='train', 26 | min_overlap_score=0.4, 27 | augment_fn=None, 28 | pose_dir=None, 29 | **kwargs): 30 | """Manage one scene of ScanNet Dataset. 31 | Args: 32 | root_dir (str): ScanNet root directory that contains scene folders. 33 | npz_path (str): {scene_id}.npz path. This contains image pair information of a scene. 34 | intrinsic_path (str): path to depth-camera intrinsic file. 35 | mode (str): options are ['train', 'val', 'test']. 36 | augment_fn (callable, optional): augments images with pre-defined visual effects. 37 | pose_dir (str): ScanNet root directory that contains all poses. 38 | (we use a separate (optional) pose_dir since we store images and poses separately.) 39 | """ 40 | super().__init__() 41 | self.root_dir = root_dir 42 | self.pose_dir = pose_dir if pose_dir is not None else root_dir 43 | self.mode = mode 44 | 45 | # prepare data_names, intrinsics and extrinsics(T) 46 | with np.load(npz_path) as data: 47 | self.data_names = data['name'] 48 | if 'score' in data.keys() and mode not in ['val' or 'test']: 49 | kept_mask = data['score'] > min_overlap_score 50 | self.data_names = self.data_names[kept_mask] 51 | self.intrinsics = dict(np.load(intrinsic_path)) 52 | 53 | # for training LoFTR 54 | self.augment_fn = augment_fn if mode == 'train' else None 55 | 56 | 57 | def __len__(self): 58 | return len(self.data_names) 59 | 60 | def _read_abs_pose(self, scene_name, name): 61 | pth = osp.join(self.pose_dir, 62 | scene_name, 63 | 'pose', f'{name}.txt') 64 | return read_scannet_pose(pth) 65 | 66 | def _compute_rel_pose(self, scene_name, name0, name1): 67 | pose0 = self._read_abs_pose(scene_name, name0) 68 | pose1 = self._read_abs_pose(scene_name, name1) 69 | 70 | return np.matmul(pose1, inv(pose0)) # (4, 4) 71 | 72 | def __getitem__(self, idx): 73 | data_name = self.data_names[idx] 74 | scene_name, scene_sub_name, stem_name_0, stem_name_1 = data_name 75 | scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}' 76 | 77 | # read the grayscale image which will be resized to (1, 480, 640) 78 | img_name0 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_0}.jpg') 79 | img_name1 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_1}.jpg') 80 | 81 | # TODO: Support augmentation & handle seeds for each worker correctly. 
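# Both frames are read as grayscale and resized to (640, 480) for the matcher, and a
# separately pre-processed copy (img2depth0 / img2depth1) is read for the monocular depth
# network that feeds the curvature extractor. Ground-truth depth maps are only loaded in
# 'train'/'val' mode; at test time empty tensors are returned, since evaluation does not
# rely on depth supervision.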
82 | image0 = read_scannet_gray(img_name0, resize=(640, 480), augment_fn=None) 83 | # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) 84 | image1 = read_scannet_gray(img_name1, resize=(640, 480), augment_fn=None) 85 | # augment_fn=np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) 86 | # read image to predict depth 87 | img2depth0 = read_image_for_depth(img_name0, resize=(640, 480)) 88 | img2depth1 = read_image_for_depth(img_name1, resize=(640, 480)) 89 | # read the depthmap which is stored as (480, 640) 90 | if self.mode in ['train', 'val']: 91 | depth0 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_0}.png')) 92 | depth1 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_1}.png')) 93 | else: 94 | depth0 = depth1 = torch.tensor([]) 95 | 96 | # curv_path0 = osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_0}.curv_est_depth.h5') 97 | # curv_path1 = osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_1}.curv_est_depth.h5') 98 | # desc_curv0 = dd.io.load(curv_path0) 99 | # desc_curv1 = dd.io.load(curv_path1) 100 | 101 | 102 | # read the intrinsic of depthmap 103 | K_0 = K_1 = torch.tensor(self.intrinsics[scene_name].copy(), dtype=torch.float).reshape(3, 3) 104 | 105 | # read and compute relative poses 106 | T_0to1 = torch.tensor(self._compute_rel_pose(scene_name, stem_name_0, stem_name_1), 107 | dtype=torch.float32) 108 | T_1to0 = T_0to1.inverse() 109 | 110 | data = { 111 | 'image0': image0, # (1, h, w) 112 | 'depth0': depth0, # (h, w) 113 | 'img2depth0' : img2depth0, 114 | 'image1': image1, 115 | 'depth1': depth1, 116 | 'img2depth1' : img2depth1, 117 | 'T_0to1': T_0to1, # (4, 4) 118 | 'T_1to0': T_1to0, 119 | 'K0': K_0, # (3, 3) 120 | 'K1': K_1, 121 | 'dataset_name': 'ScanNet', 122 | 'scene_id': scene_name, 123 | 'pair_id': idx, 124 | 'pair_names': (osp.join(scene_name, 'color', f'{stem_name_0}.jpg'), 125 | osp.join(scene_name, 'color', f'{stem_name_1}.jpg')) 126 | } 127 | 128 | 129 | return data 130 | -------------------------------------------------------------------------------- /src/datasets/yfcc.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.utils.data import Dataset 6 | from loguru import logger 7 | import deepdish as dd 8 | 9 | from src.utils.dataset import read_megadepth_gray, read_megadepth_depth, read_image_for_depth_megadepth 10 | 11 | # change the format from megadepth to yfcc 12 | class YFCCDataset(Dataset): 13 | def __init__(self, 14 | root_dir, 15 | npz_path, 16 | mode='train', 17 | min_overlap_score=0.4, 18 | img_resize=None, 19 | df=None, 20 | img_padding=False, 21 | depth_padding=False, 22 | augment_fn=None, 23 | **kwargs): 24 | """ 25 | Manage one scene(npz_path) of MegaDepth dataset. 26 | 27 | Args: 28 | root_dir (str): megadepth root directory that has `phoenix`. 29 | npz_path (str): {scene_id}.npz path. This contains image pair information of a scene. 30 | mode (str): options are ['train', 'val', 'test'] 31 | min_overlap_score (float): how much a pair should have in common. In range of [0, 1]. Set to 0 when testing. 32 | img_resize (int, optional): the longer edge of resized images. None for no resize. 640 is recommended. 33 | This is useful during training with batches and testing with memory intensive algorithms. 34 | df (int, optional): image size division factor. 
NOTE: this will change the final image size after img_resize. 35 | img_padding (bool): If set to 'True', zero-pad the image to squared size. This is useful during training. 36 | depth_padding (bool): If set to 'True', zero-pad depthmap to (2000, 2000). This is useful during training. 37 | augment_fn (callable, optional): augments images with pre-defined visual effects. 38 | """ 39 | super().__init__() 40 | self.root_dir = root_dir 41 | self.mode = mode 42 | # self.scene_id = npz_path.split('.')[0] 43 | 44 | # prepare scene_info and pair_info 45 | if mode == 'test' and min_overlap_score != 0: 46 | logger.warning("You are using `min_overlap_score`!=0 in test mode. Set to 0.") 47 | min_overlap_score = 0 48 | 49 | # with open(test_path, 'r') as f: 50 | # self.pair_infos = [l.split() for l in f.readlines()] 51 | 52 | self.scene_info = np.load(npz_path) 53 | self.pair_infos = self.scene_info['pair_infos'].tolist().copy() 54 | # del self.scene_info['pair_infos'] 55 | # self.pair_infos = [pair_info for pair_info in self.pair_infos if pair_info[1] > min_overlap_score] 56 | 57 | # parameters for image resizing, padding and depthmap padding 58 | if mode == 'train': 59 | assert img_resize is not None and img_padding and depth_padding 60 | self.img_resize = img_resize 61 | self.df = df 62 | self.img_padding = img_padding 63 | self.depth_max_size = 2000 if depth_padding else None # the upperbound of depthmaps size in megadepth. 64 | 65 | # for training LoFTR 66 | self.augment_fn = augment_fn if mode == 'train' else None 67 | self.coarse_scale = getattr(kwargs, 'coarse_scale', 0.125) 68 | 69 | 70 | 71 | def __len__(self): 72 | return len(self.pair_infos) 73 | 74 | def __getitem__(self, idx): 75 | pairs = self.pair_infos[idx] 76 | name0, name1 = pairs[0], pairs[1] 77 | # read grayscale image and mask. (1, h, w) and (h, w) 78 | img_name0 = osp.join(self.root_dir, name0) 79 | img_name1 = osp.join(self.root_dir, name1) 80 | 81 | # reuse the megadepth code for image reading and resize YFCC images to 1600 82 | image0, mask0, scale0 = read_megadepth_gray( 83 | img_name0, self.img_resize, self.df, self.img_padding, None) 84 | # np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) 85 | image1, mask1, scale1 = read_megadepth_gray( 86 | img_name1, self.img_resize, self.df, self.img_padding, None) 87 | # np.random.choice([self.augment_fn, None], p=[0.5, 0.5])) 88 | # recuse the megadepth code for YFCC image 89 | img2depth0 = read_image_for_depth_megadepth(img_name0, self.img_resize, self.df, self.img_padding, None) 90 | img2depth1 = read_image_for_depth_megadepth(img_name1, self.img_resize, self.df, self.img_padding, None) 91 | 92 | # # read depth. 
shape: (h, w) 93 | # if self.mode in ['train', 'val']: 94 | # depth0 = read_megadepth_depth( 95 | # osp.join(self.root_dir, self.scene_info['depth_paths'][idx0]), pad_to=self.depth_max_size) 96 | # depth1 = read_megadepth_depth( 97 | # osp.join(self.root_dir, self.scene_info['depth_paths'][idx1]), pad_to=self.depth_max_size) 98 | # else: 99 | # no gt depth for test 100 | depth0 = depth1 = torch.tensor([]) 101 | # curv0_path = osp.join(self.root_dir, self.scene_info['depth_paths'][idx0]) 102 | # desc_curv0 = dd.io.load(curv0_path.replace('.h5', '.curv.h5')) 103 | # curv1_path = osp.join(self.root_dir, self.scene_info['depth_paths'][idx1]) 104 | # desc_curv1 = dd.io.load(curv1_path.replace('.h5', '.curv.h5')) 105 | 106 | 107 | # read intrinsics of original size 108 | K_0 = torch.from_numpy(np.array(pairs[4:13]).astype(float).reshape(3, 3)) 109 | K_1 = torch.from_numpy(np.array(pairs[13:22]).astype(float).reshape(3, 3)) 110 | T_0to1 = torch.from_numpy(np.array(pairs[22:]).astype(float).reshape(4, 4)) 111 | 112 | # read and compute relative poses 113 | T_1to0 = T_0to1.inverse() 114 | 115 | data = { 116 | 'image0': image0, # (1, h, w) 117 | 'depth0': depth0, # (h, w) 118 | 'img2depth0' : img2depth0, 119 | 'image1': image1, 120 | 'depth1': depth1, 121 | 'img2depth1' : img2depth1, 122 | 'T_0to1': T_0to1, # (4, 4) 123 | 'T_1to0': T_1to0, 124 | 'K0': K_0, # (3, 3) 125 | 'K1': K_1, 126 | 'scale0': scale0, # [scale_w, scale_h] 127 | 'scale1': scale1, 128 | 'dataset_name': 'YFCC', 129 | 'scene_id': idx, 130 | 'pair_id': idx, 131 | 'pair_names': (name0, name1), 132 | } 133 | 134 | # for LoFTR training 135 | if mask0 is not None: # img_padding is True 136 | if self.coarse_scale: 137 | [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(), 138 | scale_factor=self.coarse_scale, 139 | mode='nearest', 140 | recompute_scale_factor=False)[0].bool() 141 | data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1}) 142 | return data 143 | -------------------------------------------------------------------------------- /src/loftr/__init__.py: -------------------------------------------------------------------------------- 1 | from .loftr import LoFTR 2 | from .utils.cvpr_ds_config import default_cfg 3 | -------------------------------------------------------------------------------- /src/loftr/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet_fpn import ResNetFPN_8_2, ResNetFPN_16_4 2 | 3 | 4 | def build_backbone(config): 5 | if config['backbone_type'] == 'ResNetFPN': 6 | if config['resolution'] == (8, 2): 7 | return ResNetFPN_8_2(config['resnetfpn']) 8 | elif config['resolution'] == (16, 4): 9 | return ResNetFPN_16_4(config['resnetfpn']) 10 | else: 11 | raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.") 12 | -------------------------------------------------------------------------------- /src/loftr/backbone/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/src/loftr/backbone/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/loftr/backbone/__pycache__/resnet_fpn.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/src/loftr/backbone/__pycache__/resnet_fpn.cpython-38.pyc -------------------------------------------------------------------------------- /src/loftr/backbone/resnet_fpn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | def conv1x1(in_planes, out_planes, stride=1): 6 | """1x1 convolution without padding""" 7 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) 8 | 9 | 10 | def conv3x3(in_planes, out_planes, stride=1): 11 | """3x3 convolution with padding""" 12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 13 | 14 | 15 | class BasicBlock(nn.Module): 16 | def __init__(self, in_planes, planes, stride=1): 17 | super().__init__() 18 | self.conv1 = conv3x3(in_planes, planes, stride) 19 | self.conv2 = conv3x3(planes, planes) 20 | self.bn1 = nn.BatchNorm2d(planes) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | self.relu = nn.ReLU(inplace=True) 23 | 24 | if stride == 1: 25 | self.downsample = None 26 | else: 27 | self.downsample = nn.Sequential( 28 | conv1x1(in_planes, planes, stride=stride), 29 | nn.BatchNorm2d(planes) 30 | ) 31 | 32 | def forward(self, x): 33 | y = x 34 | y = self.relu(self.bn1(self.conv1(y))) 35 | y = self.bn2(self.conv2(y)) 36 | 37 | if self.downsample is not None: 38 | x = self.downsample(x) 39 | 40 | return self.relu(x+y) 41 | 42 | 43 | class ResNetFPN_8_2(nn.Module): 44 | """ 45 | ResNet+FPN, output resolution are 1/8 and 1/2. 46 | Each block has 2 layers. 47 | """ 48 | 49 | def __init__(self, config): 50 | super().__init__() 51 | # Config 52 | block = BasicBlock 53 | initial_dim = config['initial_dim'] 54 | block_dims = config['block_dims'] 55 | 56 | # Class Variable 57 | self.in_planes = initial_dim 58 | 59 | # Networks 60 | self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False) 61 | self.bn1 = nn.BatchNorm2d(initial_dim) 62 | self.relu = nn.ReLU(inplace=True) 63 | 64 | self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2 65 | self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4 66 | self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8 67 | 68 | # 3. 
FPN upsample 69 | self.layer3_outconv = conv1x1(block_dims[2], block_dims[2]) 70 | self.layer2_outconv = conv1x1(block_dims[1], block_dims[2]) 71 | self.layer2_outconv2 = nn.Sequential( 72 | conv3x3(block_dims[2], block_dims[2]), 73 | nn.BatchNorm2d(block_dims[2]), 74 | nn.LeakyReLU(), 75 | conv3x3(block_dims[2], block_dims[1]), 76 | ) 77 | self.layer1_outconv = conv1x1(block_dims[0], block_dims[1]) 78 | self.layer1_outconv2 = nn.Sequential( 79 | conv3x3(block_dims[1], block_dims[1]), 80 | nn.BatchNorm2d(block_dims[1]), 81 | nn.LeakyReLU(), 82 | conv3x3(block_dims[1], block_dims[0]), 83 | ) 84 | 85 | for m in self.modules(): 86 | if isinstance(m, nn.Conv2d): 87 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 88 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 89 | nn.init.constant_(m.weight, 1) 90 | nn.init.constant_(m.bias, 0) 91 | 92 | def _make_layer(self, block, dim, stride=1): 93 | layer1 = block(self.in_planes, dim, stride=stride) 94 | layer2 = block(dim, dim, stride=1) 95 | layers = (layer1, layer2) 96 | 97 | self.in_planes = dim 98 | return nn.Sequential(*layers) 99 | 100 | def forward(self, x): 101 | # ResNet Backbone 102 | x0 = self.relu(self.bn1(self.conv1(x))) 103 | x1 = self.layer1(x0) # 1/2 104 | x2 = self.layer2(x1) # 1/4 105 | x3 = self.layer3(x2) # 1/8 106 | 107 | # FPN 108 | x3_out = self.layer3_outconv(x3) 109 | 110 | x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True) 111 | x2_out = self.layer2_outconv(x2) 112 | x2_out = self.layer2_outconv2(x2_out+x3_out_2x) 113 | 114 | x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True) 115 | x1_out = self.layer1_outconv(x1) 116 | x1_out = self.layer1_outconv2(x1_out+x2_out_2x) 117 | 118 | return [x3_out, x1_out] 119 | 120 | 121 | class ResNetFPN_16_4(nn.Module): 122 | """ 123 | ResNet+FPN, output resolution are 1/16 and 1/4. 124 | Each block has 2 layers. 125 | """ 126 | 127 | def __init__(self, config): 128 | super().__init__() 129 | # Config 130 | block = BasicBlock 131 | initial_dim = config['initial_dim'] 132 | block_dims = config['block_dims'] 133 | 134 | # Class Variable 135 | self.in_planes = initial_dim 136 | 137 | # Networks 138 | self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False) 139 | self.bn1 = nn.BatchNorm2d(initial_dim) 140 | self.relu = nn.ReLU(inplace=True) 141 | 142 | self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2 143 | self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4 144 | self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8 145 | self.layer4 = self._make_layer(block, block_dims[3], stride=2) # 1/16 146 | 147 | # 3. 
FPN upsample 148 | self.layer4_outconv = conv1x1(block_dims[3], block_dims[3]) 149 | self.layer3_outconv = conv1x1(block_dims[2], block_dims[3]) 150 | self.layer3_outconv2 = nn.Sequential( 151 | conv3x3(block_dims[3], block_dims[3]), 152 | nn.BatchNorm2d(block_dims[3]), 153 | nn.LeakyReLU(), 154 | conv3x3(block_dims[3], block_dims[2]), 155 | ) 156 | 157 | self.layer2_outconv = conv1x1(block_dims[1], block_dims[2]) 158 | self.layer2_outconv2 = nn.Sequential( 159 | conv3x3(block_dims[2], block_dims[2]), 160 | nn.BatchNorm2d(block_dims[2]), 161 | nn.LeakyReLU(), 162 | conv3x3(block_dims[2], block_dims[1]), 163 | ) 164 | 165 | for m in self.modules(): 166 | if isinstance(m, nn.Conv2d): 167 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 168 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 169 | nn.init.constant_(m.weight, 1) 170 | nn.init.constant_(m.bias, 0) 171 | 172 | def _make_layer(self, block, dim, stride=1): 173 | layer1 = block(self.in_planes, dim, stride=stride) 174 | layer2 = block(dim, dim, stride=1) 175 | layers = (layer1, layer2) 176 | 177 | self.in_planes = dim 178 | return nn.Sequential(*layers) 179 | 180 | def forward(self, x): 181 | # ResNet Backbone 182 | x0 = self.relu(self.bn1(self.conv1(x))) 183 | x1 = self.layer1(x0) # 1/2 184 | x2 = self.layer2(x1) # 1/4 185 | x3 = self.layer3(x2) # 1/8 186 | x4 = self.layer4(x3) # 1/16 187 | 188 | # FPN 189 | x4_out = self.layer4_outconv(x4) 190 | 191 | x4_out_2x = F.interpolate(x4_out, scale_factor=2., mode='bilinear', align_corners=True) 192 | x3_out = self.layer3_outconv(x3) 193 | x3_out = self.layer3_outconv2(x3_out+x4_out_2x) 194 | 195 | x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True) 196 | x2_out = self.layer2_outconv(x2) 197 | x2_out = self.layer2_outconv2(x2_out+x3_out_2x) 198 | 199 | return [x4_out, x2_out] 200 | -------------------------------------------------------------------------------- /src/loftr/loftr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops.einops import rearrange 4 | 5 | from .backbone import build_backbone 6 | from .utils.position_encoding import PositionEncodingSine 7 | from .loftr_module import LocalFeatureTransformer, FinePreprocess 8 | from .utils.coarse_matching import CoarseMatching 9 | from .utils.fine_matching import FineMatching 10 | from pdb import set_trace as bb 11 | 12 | 13 | class LoFTR(nn.Module): 14 | def __init__(self, config): 15 | super().__init__() 16 | # Misc 17 | self.config = config 18 | 19 | # Modules 20 | self.backbone = build_backbone(config) 21 | self.pos_encoding = PositionEncodingSine( 22 | config['coarse']['d_model'], 23 | temp_bug_fix=config['coarse']['temp_bug_fix']) 24 | self.loftr_coarse = LocalFeatureTransformer(config['coarse']) 25 | self.coarse_matching = CoarseMatching(config['match_coarse'], ) 26 | self.fine_preprocess = FinePreprocess(config) 27 | self.loftr_fine = LocalFeatureTransformer(config["fine"]) 28 | self.fine_matching = FineMatching() 29 | for p in self.parameters(): 30 | p.requires_grad = False 31 | 32 | def forward(self, data): 33 | """ 34 | Update: 35 | data (dict): { 36 | 'image0': (torch.Tensor): (N, 1, H, W) 37 | 'image1': (torch.Tensor): (N, 1, H, W) 38 | 'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position 39 | 'mask1'(optional) : (torch.Tensor): (N, H, W) 40 | } 41 | """ 42 | # 1. 
Local Feature CNN 43 | data.update({ 44 | 'bs': data['image0'].size(0), 45 | 'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:] 46 | }) 47 | 48 | if data['hw0_i'] == data['hw1_i']: # faster & better BN convergence 49 | feats_c, feats_f = self.backbone(torch.cat([data['image0'], data['image1']], dim=0)) 50 | (feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(data['bs']), feats_f.split(data['bs']) 51 | else: # handle different input shapes 52 | (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(data['image0']), self.backbone(data['image1']) 53 | 54 | data.update({ 55 | 'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:], 56 | 'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:] 57 | }) 58 | 59 | # 2. coarse-level loftr module 60 | # add featmap with positional encoding, then flatten it to sequence [N, HW, C] 61 | feat_c0 = self.pos_encoding(feat_c0) 62 | feat_c1 = self.pos_encoding(feat_c1) 63 | 64 | mask_c0 = mask_c1 = None # mask is useful in training 65 | if 'mask0' in data: 66 | mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2) 67 | feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1) 68 | 69 | # 3. match coarse-level 70 | self.coarse_matching(feat_c0, feat_c1, data, mask_c0=mask_c0, mask_c1=mask_c1) 71 | 72 | # 4. fine-level refinement 73 | feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0, feat_c1, data) 74 | if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted 75 | feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold, feat_f1_unfold) 76 | 77 | # 5. match fine-level 78 | self.fine_matching(feat_f0_unfold, feat_f1_unfold, data) 79 | 80 | def load_state_dict(self, state_dict, *args, **kwargs): 81 | for k in list(state_dict.keys()): 82 | if k.startswith('matcher.'): 83 | state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k) 84 | return super().load_state_dict(state_dict, *args, **kwargs) 85 | -------------------------------------------------------------------------------- /src/loftr/loftr_module/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer import LocalFeatureTransformer 2 | from .fine_preprocess import FinePreprocess 3 | -------------------------------------------------------------------------------- /src/loftr/loftr_module/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/src/loftr/loftr_module/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/loftr/loftr_module/__pycache__/fine_preprocess.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/src/loftr/loftr_module/__pycache__/fine_preprocess.cpython-38.pyc -------------------------------------------------------------------------------- /src/loftr/loftr_module/__pycache__/linear_attention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/src/loftr/loftr_module/__pycache__/linear_attention.cpython-38.pyc 
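`LoFTR.forward` above only fills the `data` dict in place; matches are read back from the keys written by the coarse/fine matching modules (`mkpts0_f`, `mkpts1_f`, `mconf`). A minimal inference sketch is given below; the checkpoint path and the `state_dict` key are illustrative assumptions, and the curvature/DPT branch (driven by the Lightning module, which also supplies `img2depth0`/`img2depth1`) is not exercised here:

```python
# Hedged usage sketch: run the matcher on one grayscale pair (checkpoint path is hypothetical).
import cv2
import torch
from src.loftr import LoFTR, default_cfg

matcher = LoFTR(config=default_cfg)
ckpt = torch.load("weights/indoor_ds_quadtree.ckpt", map_location="cpu")  # hypothetical checkpoint
matcher.load_state_dict(ckpt["state_dict"], strict=False)                 # 'matcher.' prefixes are stripped
matcher = matcher.eval().cuda()

def load_gray(path, size=(640, 480)):
    img = cv2.resize(cv2.imread(path, cv2.IMREAD_GRAYSCALE), size)
    return torch.from_numpy(img)[None, None].float().cuda() / 255.0       # (1, 1, H, W)

data = {
    "image0": load_gray("assets/scannet_sample_images/scene0711_00_frame-001680.jpg"),
    "image1": load_gray("assets/scannet_sample_images/scene0711_00_frame-001995.jpg"),
}
with torch.no_grad():
    matcher(data)                              # forward() updates `data` in place
mkpts0 = data["mkpts0_f"].cpu().numpy()        # (M, 2) keypoints in image0
mkpts1 = data["mkpts1_f"].cpu().numpy()        # (M, 2) corresponding keypoints in image1
mconf = data["mconf"].cpu().numpy()            # (M,) coarse matching confidence
```

For evaluation, such matches would then typically be passed to an essential-matrix solver (e.g. OpenCV RANSAC with the `TRAINER.RANSAC_*` settings from the config above) to recover the relative pose.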
-------------------------------------------------------------------------------- /src/loftr/loftr_module/__pycache__/quadtree_attention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/src/loftr/loftr_module/__pycache__/quadtree_attention.cpython-38.pyc -------------------------------------------------------------------------------- /src/loftr/loftr_module/__pycache__/transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/src/loftr/loftr_module/__pycache__/transformer.cpython-38.pyc -------------------------------------------------------------------------------- /src/loftr/loftr_module/fine_preprocess.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops.einops import rearrange, repeat 5 | 6 | 7 | class FinePreprocess(nn.Module): 8 | def __init__(self, config): 9 | super().__init__() 10 | 11 | self.config = config 12 | self.cat_c_feat = config['fine_concat_coarse_feat'] 13 | self.W = self.config['fine_window_size'] 14 | 15 | d_model_c = self.config['coarse']['d_model'] 16 | d_model_f = self.config['fine']['d_model'] 17 | self.d_model_f = d_model_f 18 | if self.cat_c_feat: 19 | self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True) 20 | self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True) 21 | 22 | self._reset_parameters() 23 | 24 | def _reset_parameters(self): 25 | for p in self.parameters(): 26 | if p.dim() > 1: 27 | nn.init.kaiming_normal_(p, mode="fan_out", nonlinearity="relu") 28 | 29 | def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data): 30 | W = self.W 31 | stride = data['hw0_f'][0] // data['hw0_c'][0] 32 | 33 | data.update({'W': W}) 34 | if data['b_ids'].shape[0] == 0: 35 | feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device) 36 | feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device) 37 | return feat0, feat1 38 | 39 | # 1. unfold(crop) all local windows 40 | feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2) 41 | feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2) 42 | feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2) 43 | feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2) 44 | 45 | # 2. 
select only the predicted matches 46 | feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf] 47 | feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']] 48 | 49 | # option: use coarse-level loftr feature as context: concat and linear 50 | if self.cat_c_feat: 51 | feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']], 52 | feat_c1[data['b_ids'], data['j_ids']]], 0)) # [2n, c] 53 | feat_cf_win = self.merge_feat(torch.cat([ 54 | torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf] 55 | repeat(feat_c_win, 'n c -> n ww c', ww=W**2), # [2n, ww, cf] 56 | ], -1)) 57 | feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0) 58 | 59 | 60 | return feat_f0_unfold, feat_f1_unfold 61 | -------------------------------------------------------------------------------- /src/loftr/loftr_module/linear_attention.py: -------------------------------------------------------------------------------- 1 | """ 2 | Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" 3 | Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py 4 | """ 5 | 6 | import torch 7 | from torch.nn import Module, Dropout 8 | 9 | 10 | def elu_feature_map(x): 11 | return torch.nn.functional.elu(x) + 1 12 | 13 | 14 | class LinearAttention(Module): 15 | def __init__(self, eps=1e-6): 16 | super().__init__() 17 | self.feature_map = elu_feature_map 18 | self.eps = eps 19 | 20 | def forward(self, queries, keys, values, q_mask=None, kv_mask=None): 21 | """ Multi-Head linear attention proposed in "Transformers are RNNs" 22 | Args: 23 | queries: [N, L, H, D] 24 | keys: [N, S, H, D] 25 | values: [N, S, H, D] 26 | q_mask: [N, L] 27 | kv_mask: [N, S] 28 | Returns: 29 | queried_values: (N, L, H, D) 30 | """ 31 | Q = self.feature_map(queries) 32 | K = self.feature_map(keys) 33 | 34 | # set padded position to zero 35 | if q_mask is not None: 36 | Q = Q * q_mask[:, :, None, None] 37 | if kv_mask is not None: 38 | K = K * kv_mask[:, :, None, None] 39 | values = values * kv_mask[:, :, None, None] 40 | 41 | v_length = values.size(1) 42 | values = values / v_length # prevent fp16 overflow 43 | KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V 44 | Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps) 45 | queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length 46 | 47 | return queried_values.contiguous() 48 | 49 | 50 | class FullAttention(Module): 51 | def __init__(self, use_dropout=False, attention_dropout=0.1): 52 | super().__init__() 53 | self.use_dropout = use_dropout 54 | self.dropout = Dropout(attention_dropout) 55 | 56 | def forward(self, queries, keys, values, q_mask=None, kv_mask=None): 57 | """ Multi-head scaled dot-product attention, a.k.a full attention. 58 | Args: 59 | queries: [N, L, H, D] 60 | keys: [N, S, H, D] 61 | values: [N, S, H, D] 62 | q_mask: [N, L] 63 | kv_mask: [N, S] 64 | Returns: 65 | queried_values: (N, L, H, D) 66 | """ 67 | 68 | # Compute the unnormalized attention and apply the masks 69 | QK = torch.einsum("nlhd,nshd->nlsh", queries, keys) 70 | if kv_mask is not None: 71 | QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf')) 72 | 73 | # Compute the attention and the weighted average 74 | softmax_temp = 1. 
/ queries.size(3)**.5 # sqrt(D) 75 | A = torch.softmax(softmax_temp * QK, dim=2) 76 | if self.use_dropout: 77 | A = self.dropout(A) 78 | 79 | queried_values = torch.einsum("nlsh,nshd->nlhd", A, values) 80 | 81 | return queried_values.contiguous() 82 | -------------------------------------------------------------------------------- /src/loftr/loftr_module/quadtree_attention.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 6 | 7 | sys.path.append("../QuadTreeAttention") 8 | from QuadtreeAttention.modules.quadtree_attention import QTAttA, QTAttB 9 | 10 | 11 | class QuadtreeAttention(nn.Module): 12 | def __init__( 13 | self, 14 | dim, 15 | num_heads, 16 | topks, 17 | value_branch=False, 18 | act=nn.GELU(), 19 | qkv_bias=False, 20 | qk_scale=None, 21 | attn_drop=0.0, 22 | proj_drop=0.0, 23 | scale=1, 24 | attn_type="B", 25 | ): 26 | super().__init__() 27 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 28 | 29 | self.dim = dim 30 | self.num_heads = num_heads 31 | head_dim = dim // num_heads 32 | self.scale = qk_scale or head_dim ** -0.5 33 | 34 | self.q_proj = nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=qkv_bias) 35 | self.k_proj = nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=qkv_bias) 36 | self.v_proj = nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=qkv_bias) 37 | if attn_type == "A": 38 | self.py_att = QTAttA(num_heads, dim // num_heads, scale=scale, topks=topks) 39 | else: 40 | self.py_att = QTAttB(num_heads, dim // num_heads, scale=scale, topks=topks) 41 | 42 | self.attn_drop = nn.Dropout(attn_drop) 43 | self.proj = nn.Linear(dim, dim) 44 | self.proj_drop = nn.Dropout(proj_drop) 45 | 46 | self.scale = scale 47 | 48 | self.apply(self._init_weights) 49 | 50 | def _init_weights(self, m): 51 | if isinstance(m, nn.Linear): 52 | trunc_normal_(m.weight, std=0.02) 53 | if isinstance(m, nn.Linear) and m.bias is not None: 54 | nn.init.constant_(m.bias, 0) 55 | elif isinstance(m, nn.LayerNorm): 56 | nn.init.constant_(m.bias, 0) 57 | nn.init.constant_(m.weight, 1.0) 58 | elif isinstance(m, nn.Conv2d): 59 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 60 | fan_out //= m.groups 61 | # m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 62 | trunc_normal_(m.weight, std=0.02) 63 | m.init = True 64 | if m.bias is not None: 65 | m.bias.data.zero_() 66 | 67 | def forward(self, x, target, H, W, msg=None): 68 | 69 | B, N, C = x.shape 70 | x = x.permute(0, 2, 1).reshape(B, C, H, W) 71 | target = target.permute(0, 2, 1).reshape(B, C, H, W) 72 | keys = [] 73 | values = [] 74 | queries = [] 75 | 76 | q = self.q_proj(x) 77 | k = self.k_proj(target) 78 | v = self.v_proj(target) 79 | for i in range(self.scale): 80 | keys.append(k) 81 | values.append(v) 82 | queries.append(q) 83 | 84 | if i != self.scale - 1: 85 | k = F.avg_pool2d(k, kernel_size=2, stride=2) 86 | q = F.avg_pool2d(q, kernel_size=2, stride=2) 87 | v = F.avg_pool2d(v, kernel_size=2, stride=2) 88 | 89 | msg = self.py_att(queries, keys, values).view(B, -1, C) 90 | 91 | x = self.proj(msg) 92 | x = self.proj_drop(x) 93 | 94 | return x -------------------------------------------------------------------------------- /src/loftr/utils/__pycache__/coarse_matching.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/src/loftr/utils/__pycache__/coarse_matching.cpython-38.pyc -------------------------------------------------------------------------------- /src/loftr/utils/__pycache__/cvpr_ds_config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/src/loftr/utils/__pycache__/cvpr_ds_config.cpython-38.pyc -------------------------------------------------------------------------------- /src/loftr/utils/__pycache__/fine_matching.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/src/loftr/utils/__pycache__/fine_matching.cpython-38.pyc -------------------------------------------------------------------------------- /src/loftr/utils/__pycache__/geometry.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/src/loftr/utils/__pycache__/geometry.cpython-38.pyc -------------------------------------------------------------------------------- /src/loftr/utils/__pycache__/position_encoding.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/src/loftr/utils/__pycache__/position_encoding.cpython-38.pyc -------------------------------------------------------------------------------- /src/loftr/utils/__pycache__/supervision.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/surface-curvature-estimator/72504f5dd54b3cadf1b775644c1c94b290011fb0/src/loftr/utils/__pycache__/supervision.cpython-38.pyc -------------------------------------------------------------------------------- /src/loftr/utils/cvpr_ds_config.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | 4 | def lower_config(yacs_cfg): 5 | if not isinstance(yacs_cfg, CN): 6 | return yacs_cfg 7 | return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()} 8 | 9 | 10 | _CN = CN() 11 | _CN.BACKBONE_TYPE = 'ResNetFPN' 12 | _CN.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)] 13 | _CN.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd 14 | _CN.FINE_CONCAT_COARSE_FEAT = True 15 | 16 | 17 | # 1. LoFTR-backbone (local feature CNN) config 18 | _CN.RESNETFPN = CN() 19 | _CN.RESNETFPN.INITIAL_DIM = 128 20 | _CN.RESNETFPN.BLOCK_DIMS = [128, 196, 256] # s1, s2, s3 21 | 22 | # 2. LoFTR-coarse module config 23 | _CN.COARSE = CN() 24 | _CN.COARSE.D_MODEL = 256 25 | _CN.COARSE.D_FFN = 256 26 | _CN.COARSE.NHEAD = 8 27 | _CN.COARSE.LAYER_NAMES = ['self', 'cross'] * 4 28 | _CN.COARSE.ATTENTION = 'linear' # options: ['linear', 'full'] 29 | _CN.COARSE.TEMP_BUG_FIX = False 30 | _CN.COARSE.BLOCK_TYPE = 'quadtree' 31 | _CN.COARSE.ATTN_TYPE = 'B' 32 | _CN.COARSE.TOPKS=[32, 16, 16] 33 | 34 | # 3. 
Coarse-Matching config 35 | _CN.MATCH_COARSE = CN() 36 | _CN.MATCH_COARSE.THR = 0.2 37 | _CN.MATCH_COARSE.BORDER_RM = 2 38 | _CN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' # options: ['dual_softmax, 'sinkhorn'] 39 | _CN.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1 40 | _CN.MATCH_COARSE.SKH_ITERS = 3 41 | _CN.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0 42 | _CN.MATCH_COARSE.SKH_PREFILTER = True 43 | _CN.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.4 # training tricks: save GPU memory 44 | _CN.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock 45 | 46 | # 4. LoFTR-fine module config 47 | _CN.FINE = CN() 48 | _CN.FINE.D_MODEL = 128 49 | _CN.FINE.D_FFN = 128 50 | _CN.FINE.NHEAD = 8 51 | _CN.FINE.LAYER_NAMES = ['self', 'cross'] * 1 52 | _CN.FINE.ATTENTION = 'linear' 53 | _CN.FINE.BLOCK_TYPE = 'loftr' 54 | 55 | default_cfg = lower_config(_CN) 56 | -------------------------------------------------------------------------------- /src/loftr/utils/fine_matching.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | from kornia.geometry.subpix import dsnt 6 | from kornia.utils.grid import create_meshgrid 7 | 8 | 9 | class FineMatching(nn.Module): 10 | """FineMatching with s2d paradigm""" 11 | 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def forward(self, feat_f0, feat_f1, data): 16 | """ 17 | Args: 18 | feat0 (torch.Tensor): [M, WW, C] 19 | feat1 (torch.Tensor): [M, WW, C] 20 | data (dict) 21 | Update: 22 | data (dict):{ 23 | 'expec_f' (torch.Tensor): [M, 3], 24 | 'mkpts0_f' (torch.Tensor): [M, 2], 25 | 'mkpts1_f' (torch.Tensor): [M, 2]} 26 | """ 27 | M, WW, C = feat_f0.shape 28 | W = int(math.sqrt(WW)) 29 | scale = data['hw0_i'][0] / data['hw0_f'][0] 30 | self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale 31 | 32 | # corner case: if no coarse matches found 33 | if M == 0: 34 | assert self.training == False, "M is always >0, when training, see coarse_matching.py" 35 | # logger.warning('No matches found in coarse-level.') 36 | data.update({ 37 | 'expec_f': torch.empty(0, 3, device=feat_f0.device), 38 | 'mkpts0_f': data['mkpts0_c'], 39 | 'mkpts1_f': data['mkpts1_c'], 40 | }) 41 | return 42 | 43 | feat_f0_picked = feat_f0_picked = feat_f0[:, WW//2, :] 44 | sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1) 45 | softmax_temp = 1. 
/ C**.5 46 | heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W) 47 | 48 | # compute coordinates from heatmap 49 | coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] # [M, 2] 50 | grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2) # [1, WW, 2] 51 | 52 | # compute std over 53 | var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2 # [M, 2] 54 | std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M] clamp needed for numerical stability 55 | 56 | # for fine-level supervision 57 | data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)}) 58 | 59 | # compute absolute kpt coords 60 | self.get_fine_match(coords_normalized, data) 61 | 62 | @torch.no_grad() 63 | def get_fine_match(self, coords_normed, data): 64 | W, WW, C, scale = self.W, self.WW, self.C, self.scale 65 | 66 | # mkpts0_f and mkpts1_f 67 | mkpts0_f = data['mkpts0_c'] 68 | scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale 69 | mkpts1_f = data['mkpts1_c'] + (coords_normed * (W // 2) * scale1)[:len(data['mconf'])] 70 | 71 | data.update({ 72 | "mkpts0_f": mkpts0_f, 73 | "mkpts1_f": mkpts1_f 74 | }) 75 | -------------------------------------------------------------------------------- /src/loftr/utils/geometry.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | @torch.no_grad() 5 | def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1): 6 | """ Warp kpts0 from I0 to I1 with depth, K and Rt 7 | Also check covisibility and depth consistency. 8 | Depth is consistent if relative error < 0.2 (hard-coded). 9 | 10 | Args: 11 | kpts0 (torch.Tensor): [N, L, 2] - , 12 | depth0 (torch.Tensor): [N, H, W], 13 | depth1 (torch.Tensor): [N, H, W], 14 | T_0to1 (torch.Tensor): [N, 3, 4], 15 | K0 (torch.Tensor): [N, 3, 3], 16 | K1 (torch.Tensor): [N, 3, 3], 17 | Returns: 18 | calculable_mask (torch.Tensor): [N, L] 19 | warped_keypoints0 (torch.Tensor): [N, L, 2] 20 | """ 21 | kpts0_long = kpts0.round().long() 22 | 23 | # Sample depth, get calculable_mask on depth != 0 24 | kpts0_depth = torch.stack( 25 | [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0 26 | ) # (N, L) 27 | nonzero_mask = kpts0_depth != 0 28 | 29 | # Unproject 30 | kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3) 31 | kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L) 32 | 33 | # Rigid Transform 34 | w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L) 35 | w_kpts0_depth_computed = w_kpts0_cam[:, 2, :] 36 | 37 | # Project 38 | w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3) 39 | w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4) # (N, L, 2), +1e-4 to avoid zero depth 40 | 41 | # Covisible Check 42 | h, w = depth1.shape[1:3] 43 | covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \ 44 | (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1) 45 | w_kpts0_long = w_kpts0.long() 46 | w_kpts0_long[~covisible_mask, :] = 0 47 | 48 | w_kpts0_depth = torch.stack( 49 | [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0 50 | ) # (N, L) 51 | consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2 52 | valid_mask = nonzero_mask * covisible_mask * consistent_mask 53 | 54 | return valid_mask, w_kpts0 55 | 
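`warp_kpts` is the geometric core of the coarse-level supervision further below (supervision.py); a small self-contained sanity check on synthetic tensors (all numbers invented for illustration) behaves as follows:

```python
# Synthetic sanity check for warp_kpts (all tensors below are made-up illustrations).
import torch
from src.loftr.utils.geometry import warp_kpts

N, H, W = 1, 48, 64
K = torch.tensor([[[50., 0., 32.], [0., 50., 24.], [0., 0., 1.]]])  # (1, 3, 3) intrinsics
depth0 = torch.full((N, H, W), 2.0)                                 # fronto-parallel plane 2 m away
depth1 = torch.full((N, H, W), 2.0)
T_0to1 = torch.eye(4)[None]                                         # (1, 4, 4)
T_0to1[:, 0, 3] = 0.1                                               # 10 cm translation along x
kpts0 = torch.tensor([[[10., 10.], [40., 30.], [63., 47.]]])        # (1, 3, 2) pixel coordinates

valid, w_kpts0 = warp_kpts(kpts0, depth0, depth1, T_0to1, K, K)
# A pure x-translation over constant depth shifts every point by fx * tx / z = 50 * 0.1 / 2 = 2.5 px,
# so w_kpts0 is approximately kpts0 + [2.5, 0]. The last point warps to x = 65.5 >= W - 1 and is
# therefore rejected by the covisibility check:
print(w_kpts0)   # ~[[12.5, 10.], [42.5, 30.], [65.5, 47.]]
print(valid)     # tensor([[ True,  True, False]])
```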
-------------------------------------------------------------------------------- /src/loftr/utils/position_encoding.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | 5 | 6 | class PositionEncodingSine(nn.Module): 7 | """ 8 | This is a sinusoidal position encoding that generalized to 2-dimensional images 9 | """ 10 | 11 | def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=True): 12 | """ 13 | Args: 14 | max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels 15 | temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41), 16 | the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact 17 | on the final performance. For now, we keep both impls for backward compatability. 18 | We will remove the buggy impl after re-training all variants of our released models. 19 | """ 20 | super().__init__() 21 | 22 | pe = torch.zeros((d_model, *max_shape)) 23 | y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0) 24 | x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0) 25 | if temp_bug_fix: 26 | div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2))) 27 | else: # a buggy implementation (for backward compatability only) 28 | div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / d_model//2)) 29 | div_term = div_term[:, None, None] # [C//4, 1, 1] 30 | pe[0::4, :, :] = torch.sin(x_position * div_term) 31 | pe[1::4, :, :] = torch.cos(x_position * div_term) 32 | pe[2::4, :, :] = torch.sin(y_position * div_term) 33 | pe[3::4, :, :] = torch.cos(y_position * div_term) 34 | 35 | self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W] 36 | 37 | def forward(self, x): 38 | """ 39 | Args: 40 | x: [N, C, H, W] 41 | """ 42 | return x + self.pe[:, :, :x.size(2), :x.size(3)] 43 | -------------------------------------------------------------------------------- /src/loftr/utils/supervision.py: -------------------------------------------------------------------------------- 1 | from math import log 2 | from loguru import logger 3 | 4 | import torch 5 | from einops import repeat 6 | from kornia.utils import create_meshgrid 7 | 8 | from .geometry import warp_kpts 9 | 10 | ############## ↓ Coarse-Level supervision ↓ ############## 11 | 12 | 13 | @torch.no_grad() 14 | def mask_pts_at_padded_regions(grid_pt, mask): 15 | """For megadepth dataset, zero-padding exists in images""" 16 | mask = repeat(mask, 'n h w -> n (h w) c', c=2) 17 | grid_pt[~mask.bool()] = 0 18 | return grid_pt 19 | 20 | 21 | @torch.no_grad() 22 | def spvs_coarse(data, config): 23 | """ 24 | Update: 25 | data (dict): { 26 | "conf_matrix_gt": [N, hw0, hw1], 27 | 'spv_b_ids': [M] 28 | 'spv_i_ids': [M] 29 | 'spv_j_ids': [M] 30 | 'spv_w_pt0_i': [N, hw0, 2], in original image resolution 31 | 'spv_pt1_i': [N, hw1, 2], in original image resolution 32 | } 33 | 34 | NOTE: 35 | - for scannet dataset, there're 3 kinds of resolution {i, c, f} 36 | - for megadepth dataset, there're 4 kinds of resolution {i, i_resize, c, f} 37 | """ 38 | # 1. 
misc 39 | device = data['image0'].device 40 | N, _, H0, W0 = data['image0'].shape 41 | _, _, H1, W1 = data['image1'].shape 42 | scale = config['LOFTR']['RESOLUTION'][0] 43 | scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale 44 | scale1 = scale * data['scale1'][:, None] if 'scale0' in data else scale 45 | h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1]) 46 | 47 | # 2. warp grids 48 | # create kpts in meshgrid and resize them to image resolution 49 | grid_pt0_c = create_meshgrid(h0, w0, False, device).reshape(1, h0*w0, 2).repeat(N, 1, 1) # [N, hw, 2] 50 | grid_pt0_i = scale0 * grid_pt0_c 51 | grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1) 52 | grid_pt1_i = scale1 * grid_pt1_c 53 | 54 | # mask padded region to (0, 0), so no need to manually mask conf_matrix_gt 55 | if 'mask0' in data: 56 | grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data['mask0']) 57 | grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data['mask1']) 58 | 59 | # warp kpts bi-directionally and resize them to coarse-level resolution 60 | # (no depth consistency check, since it leads to worse results experimentally) 61 | # (unhandled edge case: points with 0-depth will be warped to the left-up corner) 62 | _, w_pt0_i = warp_kpts(grid_pt0_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1']) 63 | _, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0']) 64 | w_pt0_c = w_pt0_i / scale1 65 | w_pt1_c = w_pt1_i / scale0 66 | 67 | # 3. check if mutual nearest neighbor 68 | w_pt0_c_round = w_pt0_c[:, :, :].round().long() 69 | nearest_index1 = w_pt0_c_round[..., 0] + w_pt0_c_round[..., 1] * w1 70 | w_pt1_c_round = w_pt1_c[:, :, :].round().long() 71 | nearest_index0 = w_pt1_c_round[..., 0] + w_pt1_c_round[..., 1] * w0 72 | 73 | # corner case: out of boundary 74 | def out_bound_mask(pt, w, h): 75 | return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h) 76 | nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = 0 77 | nearest_index0[out_bound_mask(w_pt1_c_round, w0, h0)] = 0 78 | 79 | loop_back = torch.stack([nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0) 80 | correct_0to1 = loop_back == torch.arange(h0*w0, device=device)[None].repeat(N, 1) 81 | correct_0to1[:, 0] = False # ignore the top-left corner 82 | 83 | # 4. construct a gt conf_matrix 84 | conf_matrix_gt = torch.zeros(N, h0*w0, h1*w1, device=device) 85 | b_ids, i_ids = torch.where(correct_0to1 != 0) 86 | j_ids = nearest_index1[b_ids, i_ids] 87 | 88 | conf_matrix_gt[b_ids, i_ids, j_ids] = 1 89 | data.update({'conf_matrix_gt': conf_matrix_gt}) 90 | 91 | # 5. save coarse matches(gt) for training fine level 92 | if len(b_ids) == 0: 93 | logger.warning(f"No groundtruth coarse match found for: {data['pair_names']}") 94 | # this won't affect fine-level loss calculation 95 | b_ids = torch.tensor([0], device=device) 96 | i_ids = torch.tensor([0], device=device) 97 | j_ids = torch.tensor([0], device=device) 98 | 99 | data.update({ 100 | 'spv_b_ids': b_ids, 101 | 'spv_i_ids': i_ids, 102 | 'spv_j_ids': j_ids 103 | }) 104 | 105 | # 6. save intermediate results (for fast fine-level computation) 106 | data.update({ 107 | 'spv_w_pt0_i': w_pt0_i, 108 | 'spv_pt1_i': grid_pt1_i 109 | }) 110 | 111 | 112 | def compute_supervision_coarse(data, config): 113 | assert len(set(data['dataset_name'])) == 1, "Do not support mixed datasets training!" 
114 | data_source = data['dataset_name'][0] 115 | if data_source.lower() in ['scannet', 'megadepth']: 116 | spvs_coarse(data, config) 117 | else: 118 | raise ValueError(f'Unknown data source: {data_source}') 119 | 120 | 121 | ############## ↓ Fine-Level supervision ↓ ############## 122 | 123 | @torch.no_grad() 124 | def spvs_fine(data, config): 125 | """ 126 | Update: 127 | data (dict):{ 128 | "expec_f_gt": [M, 2]} 129 | """ 130 | # 1. misc 131 | # w_pt0_i, pt1_i = data.pop('spv_w_pt0_i'), data.pop('spv_pt1_i') 132 | w_pt0_i, pt1_i = data['spv_w_pt0_i'], data['spv_pt1_i'] 133 | scale = config['LOFTR']['RESOLUTION'][1] 134 | radius = config['LOFTR']['FINE_WINDOW_SIZE'] // 2 135 | 136 | # 2. get coarse prediction 137 | b_ids, i_ids, j_ids = data['b_ids'], data['i_ids'], data['j_ids'] 138 | 139 | # 3. compute gt 140 | scale = scale * data['scale1'][b_ids] if 'scale0' in data else scale 141 | # `expec_f_gt` might exceed the window, i.e. abs(*) > 1, which would be filtered later 142 | expec_f_gt = (w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scale / radius # [M, 2] 143 | data.update({"expec_f_gt": expec_f_gt}) 144 | 145 | 146 | def compute_supervision_fine(data, config): 147 | data_source = data['dataset_name'][0] 148 | if data_source.lower() in ['scannet', 'megadepth']: 149 | spvs_fine(data, config) 150 | else: 151 | raise NotImplementedError 152 | -------------------------------------------------------------------------------- /src/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR, ExponentialLR 3 | 4 | 5 | def build_optimizer(model, config): 6 | name = config.TRAINER.OPTIMIZER 7 | lr = 1e-5 8 | 9 | if name == "adam": 10 | return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAM_DECAY) 11 | elif name == "adamw": 12 | return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY) 13 | else: 14 | raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!") 15 | 16 | 17 | def build_scheduler(config, optimizer): 18 | """ 19 | Returns: 20 | scheduler (dict):{ 21 | 'scheduler': lr_scheduler, 22 | 'interval': 'step', # or 'epoch' 23 | 'monitor': 'val_f1', (optional) 24 | 'frequency': x, (optional) 25 | } 26 | """ 27 | scheduler = {'interval': config.TRAINER.SCHEDULER_INTERVAL} 28 | name = config.TRAINER.SCHEDULER 29 | 30 | if name == 'MultiStepLR': 31 | scheduler.update( 32 | {'scheduler': MultiStepLR(optimizer, config.TRAINER.MSLR_MILESTONES, gamma=config.TRAINER.MSLR_GAMMA)}) 33 | elif name == 'CosineAnnealing': 34 | scheduler.update( 35 | {'scheduler': CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)}) 36 | elif name == 'ExponentialLR': 37 | scheduler.update( 38 | {'scheduler': ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)}) 39 | else: 40 | raise NotImplementedError() 41 | 42 | return scheduler 43 | -------------------------------------------------------------------------------- /src/utils/augment.py: -------------------------------------------------------------------------------- 1 | import albumentations as A 2 | 3 | 4 | class DarkAug(object): 5 | """ 6 | Extreme dark augmentation aiming at Aachen Day-Night 7 | """ 8 | 9 | def __init__(self) -> None: 10 | self.augmentor = A.Compose([ 11 | A.RandomBrightnessContrast(p=0.75, brightness_limit=(-0.6, 0.0), contrast_limit=(-0.5, 0.3)), 12 | A.Blur(p=0.1, blur_limit=(3, 9)), 13 | A.MotionBlur(p=0.2, blur_limit=(3, 25)), 14 | 
A.RandomGamma(p=0.1, gamma_limit=(15, 65)), 15 | A.HueSaturationValue(p=0.1, val_shift_limit=(-100, -40)) 16 | ], p=0.75) 17 | 18 | def __call__(self, x): 19 | return self.augmentor(image=x)['image'] 20 | 21 | 22 | class MobileAug(object): 23 | """ 24 | Random augmentations aiming at images of mobile/handhold devices. 25 | """ 26 | 27 | def __init__(self): 28 | self.augmentor = A.Compose([ 29 | A.MotionBlur(p=0.25), 30 | A.ColorJitter(p=0.5), 31 | A.RandomRain(p=0.1), # random occlusion 32 | A.RandomSunFlare(p=0.1), 33 | A.JpegCompression(p=0.25), 34 | A.ISONoise(p=0.25) 35 | ], p=1.0) 36 | 37 | def __call__(self, x): 38 | return self.augmentor(image=x)['image'] 39 | 40 | 41 | def build_augmentor(method=None, **kwargs): 42 | if method is not None: 43 | raise NotImplementedError('Using of augmentation functions are not supported yet!') 44 | if method == 'dark': 45 | return DarkAug() 46 | elif method == 'mobile': 47 | return MobileAug() 48 | elif method is None: 49 | return None 50 | else: 51 | raise ValueError(f'Invalid augmentation method: {method}') 52 | 53 | 54 | if __name__ == '__main__': 55 | augmentor = build_augmentor('FDA') 56 | -------------------------------------------------------------------------------- /src/utils/comm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | [Copied from detectron2] 4 | This file contains primitives for multi-gpu communication. 5 | This is useful when doing distributed training. 6 | """ 7 | 8 | import functools 9 | import logging 10 | import numpy as np 11 | import pickle 12 | import torch 13 | import torch.distributed as dist 14 | 15 | _LOCAL_PROCESS_GROUP = None 16 | """ 17 | A torch process group which only includes processes that on the same machine as the current process. 18 | This variable is set when processes are spawned by `launch()` in "engine/launch.py". 19 | """ 20 | 21 | 22 | def get_world_size() -> int: 23 | if not dist.is_available(): 24 | return 1 25 | if not dist.is_initialized(): 26 | return 1 27 | return dist.get_world_size() 28 | 29 | 30 | def get_rank() -> int: 31 | if not dist.is_available(): 32 | return 0 33 | if not dist.is_initialized(): 34 | return 0 35 | return dist.get_rank() 36 | 37 | 38 | def get_local_rank() -> int: 39 | """ 40 | Returns: 41 | The rank of the current process within the local (per-machine) process group. 42 | """ 43 | if not dist.is_available(): 44 | return 0 45 | if not dist.is_initialized(): 46 | return 0 47 | assert _LOCAL_PROCESS_GROUP is not None 48 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP) 49 | 50 | 51 | def get_local_size() -> int: 52 | """ 53 | Returns: 54 | The size of the per-machine process group, 55 | i.e. the number of processes per machine. 
56 | """ 57 | if not dist.is_available(): 58 | return 1 59 | if not dist.is_initialized(): 60 | return 1 61 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) 62 | 63 | 64 | def is_main_process() -> bool: 65 | return get_rank() == 0 66 | 67 | 68 | def synchronize(): 69 | """ 70 | Helper function to synchronize (barrier) among all processes when 71 | using distributed training 72 | """ 73 | if not dist.is_available(): 74 | return 75 | if not dist.is_initialized(): 76 | return 77 | world_size = dist.get_world_size() 78 | if world_size == 1: 79 | return 80 | dist.barrier() 81 | 82 | 83 | @functools.lru_cache() 84 | def _get_global_gloo_group(): 85 | """ 86 | Return a process group based on gloo backend, containing all the ranks 87 | The result is cached. 88 | """ 89 | if dist.get_backend() == "nccl": 90 | return dist.new_group(backend="gloo") 91 | else: 92 | return dist.group.WORLD 93 | 94 | 95 | def _serialize_to_tensor(data, group): 96 | backend = dist.get_backend(group) 97 | assert backend in ["gloo", "nccl"] 98 | device = torch.device("cpu" if backend == "gloo" else "cuda") 99 | 100 | buffer = pickle.dumps(data) 101 | if len(buffer) > 1024 ** 3: 102 | logger = logging.getLogger(__name__) 103 | logger.warning( 104 | "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( 105 | get_rank(), len(buffer) / (1024 ** 3), device 106 | ) 107 | ) 108 | storage = torch.ByteStorage.from_buffer(buffer) 109 | tensor = torch.ByteTensor(storage).to(device=device) 110 | return tensor 111 | 112 | 113 | def _pad_to_largest_tensor(tensor, group): 114 | """ 115 | Returns: 116 | list[int]: size of the tensor, on each rank 117 | Tensor: padded tensor that has the max size 118 | """ 119 | world_size = dist.get_world_size(group=group) 120 | assert ( 121 | world_size >= 1 122 | ), "comm.gather/all_gather must be called from ranks within the given group!" 123 | local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device) 124 | size_list = [ 125 | torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size) 126 | ] 127 | dist.all_gather(size_list, local_size, group=group) 128 | 129 | size_list = [int(size.item()) for size in size_list] 130 | 131 | max_size = max(size_list) 132 | 133 | # we pad the tensor because torch all_gather does not support 134 | # gathering tensors of different shapes 135 | if local_size != max_size: 136 | padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device) 137 | tensor = torch.cat((tensor, padding), dim=0) 138 | return size_list, tensor 139 | 140 | 141 | def all_gather(data, group=None): 142 | """ 143 | Run all_gather on arbitrary picklable data (not necessarily tensors). 144 | 145 | Args: 146 | data: any picklable object 147 | group: a torch process group. By default, will use a group which 148 | contains all ranks on gloo backend. 
149 | 150 | Returns: 151 | list[data]: list of data gathered from each rank 152 | """ 153 | if get_world_size() == 1: 154 | return [data] 155 | if group is None: 156 | group = _get_global_gloo_group() 157 | if dist.get_world_size(group) == 1: 158 | return [data] 159 | 160 | tensor = _serialize_to_tensor(data, group) 161 | 162 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 163 | max_size = max(size_list) 164 | 165 | # receiving Tensor from all ranks 166 | tensor_list = [ 167 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list 168 | ] 169 | dist.all_gather(tensor_list, tensor, group=group) 170 | 171 | data_list = [] 172 | for size, tensor in zip(size_list, tensor_list): 173 | buffer = tensor.cpu().numpy().tobytes()[:size] 174 | data_list.append(pickle.loads(buffer)) 175 | 176 | return data_list 177 | 178 | 179 | def gather(data, dst=0, group=None): 180 | """ 181 | Run gather on arbitrary picklable data (not necessarily tensors). 182 | 183 | Args: 184 | data: any picklable object 185 | dst (int): destination rank 186 | group: a torch process group. By default, will use a group which 187 | contains all ranks on gloo backend. 188 | 189 | Returns: 190 | list[data]: on dst, a list of data gathered from each rank. Otherwise, 191 | an empty list. 192 | """ 193 | if get_world_size() == 1: 194 | return [data] 195 | if group is None: 196 | group = _get_global_gloo_group() 197 | if dist.get_world_size(group=group) == 1: 198 | return [data] 199 | rank = dist.get_rank(group=group) 200 | 201 | tensor = _serialize_to_tensor(data, group) 202 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 203 | 204 | # receiving Tensor from all ranks 205 | if rank == dst: 206 | max_size = max(size_list) 207 | tensor_list = [ 208 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list 209 | ] 210 | dist.gather(tensor, tensor_list, dst=dst, group=group) 211 | 212 | data_list = [] 213 | for size, tensor in zip(size_list, tensor_list): 214 | buffer = tensor.cpu().numpy().tobytes()[:size] 215 | data_list.append(pickle.loads(buffer)) 216 | return data_list 217 | else: 218 | dist.gather(tensor, [], dst=dst, group=group) 219 | return [] 220 | 221 | 222 | def shared_random_seed(): 223 | """ 224 | Returns: 225 | int: a random number that is the same across all workers. 226 | If workers need a shared RNG, they can use this shared seed to 227 | create one. 228 | 229 | All workers must call this function, otherwise it will deadlock. 230 | """ 231 | ints = np.random.randint(2 ** 31) 232 | all_ints = all_gather(ints) 233 | return all_ints[0] 234 | 235 | 236 | def reduce_dict(input_dict, average=True): 237 | """ 238 | Reduce the values in the dictionary from all processes so that process with rank 239 | 0 has the reduced results. 240 | 241 | Args: 242 | input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor. 243 | average (bool): whether to do average or sum 244 | 245 | Returns: 246 | a dict with the same keys as input_dict, after reduction. 
247 | """ 248 | world_size = get_world_size() 249 | if world_size < 2: 250 | return input_dict 251 | with torch.no_grad(): 252 | names = [] 253 | values = [] 254 | # sort the keys so that they are consistent across processes 255 | for k in sorted(input_dict.keys()): 256 | names.append(k) 257 | values.append(input_dict[k]) 258 | values = torch.stack(values, dim=0) 259 | dist.reduce(values, dst=0) 260 | if dist.get_rank() == 0 and average: 261 | # only main process gets accumulated, so only divide by 262 | # world_size in this case 263 | values /= world_size 264 | reduced_dict = {k: v for k, v in zip(names, values)} 265 | return reduced_dict 266 | -------------------------------------------------------------------------------- /src/utils/dataloader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | # --- PL-DATAMODULE --- 5 | 6 | def get_local_split(items: list, world_size: int, rank: int, seed: int): 7 | """ The local rank only loads a split of the dataset. """ 8 | n_items = len(items) 9 | items_permute = np.random.RandomState(seed).permutation(items) 10 | if n_items % world_size == 0: 11 | padded_items = items_permute 12 | else: 13 | padding = np.random.RandomState(seed).choice( 14 | items, 15 | world_size - (n_items % world_size), 16 | replace=True) 17 | padded_items = np.concatenate([items_permute, padding]) 18 | assert len(padded_items) % world_size == 0, \ 19 | f'len(padded_items): {len(padded_items)}; world_size: {world_size}; len(padding): {len(padding)}' 20 | n_per_rank = len(padded_items) // world_size 21 | local_items = padded_items[n_per_rank * rank: n_per_rank * (rank+1)] 22 | 23 | return local_items 24 | -------------------------------------------------------------------------------- /src/utils/dataset.py: -------------------------------------------------------------------------------- 1 | import io 2 | from loguru import logger 3 | 4 | import cv2 5 | import numpy as np 6 | import h5py 7 | import torch 8 | from numpy.linalg import inv 9 | from torchvision.transforms import Compose 10 | from src.DPT.midas.transforms import Resize, NormalizeImage, PrepareForNet 11 | 12 | 13 | try: 14 | # for internel use only 15 | from .client import MEGADEPTH_CLIENT, SCANNET_CLIENT 16 | except Exception: 17 | MEGADEPTH_CLIENT = SCANNET_CLIENT = None 18 | 19 | # --- DATA IO --- 20 | 21 | def load_array_from_s3( 22 | path, client, cv_type, 23 | use_h5py=False, 24 | ): 25 | byte_str = client.Get(path) 26 | try: 27 | if not use_h5py: 28 | raw_array = np.fromstring(byte_str, np.uint8) 29 | data = cv2.imdecode(raw_array, cv_type) 30 | else: 31 | f = io.BytesIO(byte_str) 32 | data = np.array(h5py.File(f, 'r')['/depth']) 33 | except Exception as ex: 34 | print(f"==> Data loading failure: {path}") 35 | raise ex 36 | 37 | assert data is not None 38 | return data 39 | 40 | 41 | def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT): 42 | cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None \ 43 | else cv2.IMREAD_COLOR 44 | if str(path).startswith('s3://'): 45 | image = load_array_from_s3(str(path), client, cv_type) 46 | else: 47 | image = cv2.imread(str(path), cv_type) 48 | 49 | if augment_fn is not None: 50 | image = cv2.imread(str(path), cv2.IMREAD_COLOR) 51 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 52 | image = augment_fn(image) 53 | image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) 54 | return image # (h, w) 55 | 56 | 57 | def get_resized_wh(w, h, resize=None): 58 | if resize is not None: # resize the longer edge 59 | 
scale = resize / max(h, w) 60 | w_new, h_new = int(round(w*scale)), int(round(h*scale)) 61 | else: 62 | w_new, h_new = w, h 63 | return w_new, h_new 64 | 65 | 66 | def get_divisible_wh(w, h, df=None): 67 | if df is not None: 68 | w_new, h_new = map(lambda x: int(x // df * df), [w, h]) 69 | else: 70 | w_new, h_new = w, h 71 | return w_new, h_new 72 | 73 | 74 | def pad_bottom_right(inp, pad_size, ret_mask=False): 75 | assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}" 76 | mask = None 77 | if inp.ndim == 2: 78 | padded = np.zeros((pad_size, pad_size), dtype=inp.dtype) 79 | padded[:inp.shape[0], :inp.shape[1]] = inp 80 | if ret_mask: 81 | mask = np.zeros((pad_size, pad_size), dtype=bool) 82 | mask[:inp.shape[0], :inp.shape[1]] = True 83 | elif inp.ndim == 3: 84 | padded = np.zeros((pad_size, pad_size, inp.shape[2]), dtype=inp.dtype) 85 | padded[:inp.shape[0], :inp.shape[1], :] = inp 86 | if ret_mask: 87 | mask = np.zeros((pad_size, pad_size, inp.shape[2] ), dtype=bool) 88 | mask[:inp.shape[0], :inp.shape[1], :] = True 89 | else: 90 | raise NotImplementedError() 91 | return padded, mask 92 | 93 | 94 | # --- MEGADEPTH --- 95 | 96 | def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None): 97 | """ 98 | Args: 99 | resize (int, optional): the longer edge of resized images. None for no resize. 100 | padding (bool): If set to 'True', zero-pad resized images to squared size. 101 | augment_fn (callable, optional): augments images with pre-defined visual effects 102 | Returns: 103 | image (torch.tensor): (1, h, w) 104 | mask (torch.tensor): (h, w) 105 | scale (torch.tensor): [w/w_new, h/h_new] 106 | """ 107 | # read image 108 | image = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT) 109 | 110 | # resize image 111 | w, h = image.shape[1], image.shape[0] 112 | w_new, h_new = get_resized_wh(w, h, resize) 113 | w_new, h_new = get_divisible_wh(w_new, h_new, df) 114 | 115 | image = cv2.resize(image, (w_new, h_new)) 116 | scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float) 117 | 118 | if padding: # padding 119 | pad_to = max(h_new, w_new) 120 | image, mask = pad_bottom_right(image, pad_to, ret_mask=True) 121 | else: 122 | mask = None 123 | 124 | image = torch.from_numpy(image).float()[None] / 255 # (h, w) -> (1, h, w) and normalized 125 | mask = torch.from_numpy(mask) 126 | 127 | return image, mask, scale 128 | 129 | 130 | def read_megadepth_depth(path, pad_to=None): 131 | if str(path).startswith('s3://'): 132 | depth = load_array_from_s3(path, MEGADEPTH_CLIENT, None, use_h5py=True) 133 | else: 134 | depth = np.array(h5py.File(path, 'r')['depth']) 135 | if pad_to is not None: 136 | depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False) 137 | depth = torch.from_numpy(depth).float() # (h, w) 138 | return depth 139 | 140 | 141 | # --- ScanNet --- 142 | 143 | def read_scannet_gray(path, resize=(640, 480), augment_fn=None): 144 | """ 145 | Args: 146 | resize (tuple): align image to depthmap, in (w, h). 
147 | augment_fn (callable, optional): augments images with pre-defined visual effects 148 | Returns: 149 | image (torch.tensor): (1, h, w) 150 | mask (torch.tensor): (h, w) 151 | scale (torch.tensor): [w/w_new, h/h_new] 152 | """ 153 | # read and resize image 154 | image = imread_gray(path, augment_fn) 155 | image = cv2.resize(image, resize) 156 | 157 | # (h, w) -> (1, h, w) and normalized 158 | image = torch.from_numpy(image).float()[None] / 255 159 | return image 160 | 161 | 162 | def read_scannet_depth(path): 163 | if str(path).startswith('s3://'): 164 | depth = load_array_from_s3(str(path), SCANNET_CLIENT, cv2.IMREAD_UNCHANGED) 165 | else: 166 | depth = cv2.imread(str(path), cv2.IMREAD_UNCHANGED) 167 | depth = depth / 1000 168 | depth = torch.from_numpy(depth).float() # (h, w) 169 | return depth 170 | 171 | 172 | def read_scannet_pose(path): 173 | """ Read ScanNet's Camera2World pose and transform it to World2Camera. 174 | 175 | Returns: 176 | pose_w2c (np.ndarray): (4, 4) 177 | """ 178 | cam2world = np.loadtxt(path, delimiter=' ') 179 | world2cam = inv(cam2world) 180 | return world2cam 181 | 182 | 183 | def read_scannet_intrinsic(path): 184 | """ Read ScanNet's intrinsic matrix and return the 3x3 matrix. 185 | """ 186 | intrinsic = np.loadtxt(path, delimiter=' ') 187 | return intrinsic[:-1, :-1] 188 | 189 | 190 | def resize_image_for_depth(img): 191 | net_w, net_h = 384, 384 192 | resize_mode="minimal" 193 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 194 | # normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 195 | transform = Compose( 196 | [ 197 | Resize( 198 | net_w, 199 | net_h, 200 | resize_target=None, 201 | keep_aspect_ratio=True, 202 | ensure_multiple_of=32, 203 | resize_method=resize_mode, 204 | image_interpolation_method=cv2.INTER_CUBIC, 205 | ), 206 | normalization, 207 | PrepareForNet(), 208 | ] 209 | ) 210 | img_input = transform({"image": img})["image"] 211 | return img_input 212 | 213 | 214 | def read_image_for_depth(path, resize = ((640, 480))): 215 | """Read image and output RGB image (0-1). 216 | 217 | Args: 218 | path (str): path to file 219 | 220 | Returns: 221 | array: RGB image (0-1) 222 | """ 223 | img = cv2.imread(path) 224 | img = cv2.resize(img, resize) 225 | 226 | if img.ndim == 2: 227 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 228 | 229 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 230 | 231 | img_input = resize_image_for_depth(img) 232 | 233 | return img_input 234 | 235 | 236 | # --- Megadepth --- 237 | 238 | 239 | def read_image_for_depth_megadepth(path, resize=None, df=None, padding=False, augment_fn=None): 240 | """Read image and output RGB image (0-1). 
241 | 242 | Args: 243 | path (str): path to file 244 | 245 | Returns: 246 | array: RGB image (0-1) 247 | """ 248 | img = cv2.imread(path) 249 | w, h = img.shape[1], img.shape[0] 250 | w_new, h_new = get_resized_wh(w, h, resize) 251 | w_new, h_new = get_divisible_wh(w_new, h_new, df) 252 | image = cv2.resize(img, (w_new, h_new)) 253 | scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float) 254 | if padding: # padding 255 | pad_to = max(h_new, w_new) 256 | image_after, mask = pad_bottom_right(image, pad_to, ret_mask=True) 257 | 258 | if img.ndim == 2: 259 | image_after = cv2.cvtColor(image_after, cv2.COLOR_GRAY2BGR) 260 | 261 | image_after = cv2.cvtColor(image_after, cv2.COLOR_BGR2RGB) / 255.0 262 | 263 | img_input = resize_image_for_depth(image_after) 264 | return img_input 265 | 266 | 267 | -------------------------------------------------------------------------------- /src/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import numpy as np 4 | from collections import OrderedDict 5 | from loguru import logger 6 | from kornia.geometry.epipolar import numeric 7 | from kornia.geometry.conversions import convert_points_to_homogeneous 8 | 9 | 10 | # --- METRICS --- 11 | 12 | def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0): 13 | # angle error between 2 vectors 14 | t_gt = T_0to1[:3, 3] 15 | n = np.linalg.norm(t) * np.linalg.norm(t_gt) 16 | t_err = np.rad2deg(np.arccos(np.clip(np.dot(t, t_gt) / n, -1.0, 1.0))) 17 | t_err = np.minimum(t_err, 180 - t_err) # handle E ambiguity 18 | if np.linalg.norm(t_gt) < ignore_gt_t_thr: # pure rotation is challenging 19 | t_err = 0 20 | 21 | # angle error between 2 rotation matrices 22 | R_gt = T_0to1[:3, :3] 23 | cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2 24 | cos = np.clip(cos, -1., 1.) # handle numercial errors 25 | R_err = np.rad2deg(np.abs(np.arccos(cos))) 26 | 27 | return t_err, R_err 28 | 29 | 30 | def symmetric_epipolar_distance(pts0, pts1, E, K0, K1): 31 | """Squared symmetric epipolar distance. 32 | This can be seen as a biased estimation of the reprojection error. 
33 | Args: 34 | pts0 (torch.Tensor): [N, 2] 35 | E (torch.Tensor): [3, 3] 36 | """ 37 | pts0 = (pts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None] 38 | pts1 = (pts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None] 39 | pts0 = convert_points_to_homogeneous(pts0) 40 | pts1 = convert_points_to_homogeneous(pts1) 41 | 42 | Ep0 = pts0 @ E.T # [N, 3] 43 | p1Ep0 = torch.sum(pts1 * Ep0, -1) # [N,] 44 | Etp1 = pts1 @ E # [N, 3] 45 | 46 | d = p1Ep0**2 * (1.0 / (Ep0[:, 0]**2 + Ep0[:, 1]**2) + 1.0 / (Etp1[:, 0]**2 + Etp1[:, 1]**2)) # N 47 | return d 48 | 49 | 50 | def compute_symmetrical_epipolar_errors(data): 51 | """ 52 | Update: 53 | data (dict):{"epi_errs": [M]} 54 | """ 55 | Tx = numeric.cross_product_matrix(data['T_0to1'][:, :3, 3]) 56 | E_mat = Tx @ data['T_0to1'][:, :3, :3] 57 | 58 | m_bids = data['m_bids'] 59 | pts0 = data['mkpts0_f'] 60 | pts1 = data['mkpts1_f'] 61 | 62 | epi_errs = [] 63 | for bs in range(Tx.size(0)): 64 | mask = m_bids == bs 65 | epi_errs.append( 66 | symmetric_epipolar_distance(pts0[mask], pts1[mask], E_mat[bs], data['K0'][bs], data['K1'][bs])) 67 | epi_errs = torch.cat(epi_errs, dim=0) 68 | 69 | data.update({'epi_errs': epi_errs}) 70 | 71 | 72 | def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999): 73 | if len(kpts0) < 5: 74 | return None 75 | # normalize keypoints 76 | kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None] 77 | kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None] 78 | 79 | # normalize ransac threshold 80 | ransac_thr = thresh / np.mean([K0[0, 0], K1[1, 1], K0[0, 0], K1[1, 1]]) 81 | 82 | # compute pose with cv2 83 | E, mask = cv2.findEssentialMat( 84 | kpts0, kpts1, np.eye(3), threshold=ransac_thr, prob=conf, method=cv2.RANSAC) 85 | if E is None: 86 | print("\nE is None while trying to recover pose.\n") 87 | return None 88 | 89 | # recover pose from E 90 | best_num_inliers = 0 91 | ret = None 92 | for _E in np.split(E, len(E) / 3): 93 | n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask) 94 | if n > best_num_inliers: 95 | ret = (R, t[:, 0], mask.ravel() > 0) 96 | best_num_inliers = n 97 | 98 | return ret 99 | 100 | 101 | def compute_pose_errors(data, config): 102 | """ 103 | Update: 104 | data (dict):{ 105 | "R_errs" List[float]: [N] 106 | "t_errs" List[float]: [N] 107 | "inliers" List[np.ndarray]: [N] 108 | } 109 | """ 110 | pixel_thr = config.TRAINER.RANSAC_PIXEL_THR # 0.5 111 | conf = config.TRAINER.RANSAC_CONF # 0.99999 112 | data.update({'R_errs': [], 't_errs': [], 'inliers': []}) 113 | 114 | m_bids = data['m_bids'].cpu().numpy() 115 | pts0 = data['mkpts0_f'].cpu().numpy() 116 | pts1 = data['mkpts1_f'].cpu().numpy() 117 | K0 = data['K0'].cpu().numpy() 118 | K1 = data['K1'].cpu().numpy() 119 | T_0to1 = data['T_0to1'].cpu().numpy() 120 | 121 | for bs in range(K0.shape[0]): 122 | mask = m_bids == bs 123 | ret = estimate_pose(pts0[mask], pts1[mask], K0[bs], K1[bs], pixel_thr, conf=conf) 124 | 125 | if ret is None: 126 | data['R_errs'].append(np.inf) 127 | data['t_errs'].append(np.inf) 128 | data['inliers'].append(np.array([]).astype(np.bool)) 129 | else: 130 | R, t, inliers = ret 131 | t_err, R_err = relative_pose_error(T_0to1[bs], R, t, ignore_gt_t_thr=0.0) 132 | data['R_errs'].append(R_err) 133 | data['t_errs'].append(t_err) 134 | data['inliers'].append(inliers) 135 | 136 | 137 | # --- METRIC AGGREGATION --- 138 | 139 | def error_auc(errors, thresholds): 140 | """ 141 | Args: 142 | errors (list): [N,] 143 | thresholds (list) 144 | """ 145 | errors = [0] + sorted(list(errors)) 
146 | recall = list(np.linspace(0, 1, len(errors))) 147 | 148 | aucs = [] 149 | thresholds = [5, 10, 20] 150 | for thr in thresholds: 151 | last_index = np.searchsorted(errors, thr) 152 | y = recall[:last_index] + [recall[last_index-1]] 153 | x = errors[:last_index] + [thr] 154 | aucs.append(np.trapz(y, x) / thr) 155 | 156 | return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)} 157 | 158 | 159 | def epidist_prec(errors, thresholds, ret_dict=False): 160 | precs = [] 161 | for thr in thresholds: 162 | prec_ = [] 163 | for errs in errors: 164 | correct_mask = errs < thr 165 | prec_.append(np.mean(correct_mask) if len(correct_mask) > 0 else 0) 166 | precs.append(np.mean(prec_) if len(prec_) > 0 else 0) 167 | if ret_dict: 168 | return {f'prec@{t:.0e}': prec for t, prec in zip(thresholds, precs)} 169 | else: 170 | return precs 171 | 172 | 173 | def aggregate_metrics(metrics, epi_err_thr=5e-4): 174 | """ Aggregate metrics for the whole dataset: 175 | (This method should be called once per dataset) 176 | 1. AUC of the pose error (angular) at the threshold [5, 10, 20] 177 | 2. Mean matching precision at the threshold 5e-4(ScanNet), 1e-4(MegaDepth) 178 | """ 179 | # filter duplicates 180 | unq_ids = OrderedDict((iden, id) for id, iden in enumerate(metrics['identifiers'])) 181 | unq_ids = list(unq_ids.values()) 182 | logger.info(f'Aggregating metrics over {len(unq_ids)} unique items...') 183 | 184 | # pose auc 185 | angular_thresholds = [5, 10, 20] 186 | pose_errors = np.max(np.stack([metrics['R_errs'], metrics['t_errs']]), axis=0)[unq_ids] 187 | aucs = error_auc(pose_errors, angular_thresholds) # (auc@5, auc@10, auc@20) 188 | 189 | # matching precision 190 | dist_thresholds = [epi_err_thr] 191 | precs = epidist_prec(np.array(metrics['epi_errs'], dtype=object)[unq_ids], dist_thresholds, True) # (prec@err_thr) 192 | 193 | return {**aucs, **precs} 194 | -------------------------------------------------------------------------------- /src/utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import contextlib 3 | import joblib 4 | from typing import Union 5 | from loguru import _Logger, logger 6 | from itertools import chain 7 | 8 | import torch 9 | from yacs.config import CfgNode as CN 10 | from pytorch_lightning.utilities import rank_zero_only 11 | 12 | 13 | def lower_config(yacs_cfg): 14 | if not isinstance(yacs_cfg, CN): 15 | return yacs_cfg 16 | return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()} 17 | 18 | 19 | def upper_config(dict_cfg): 20 | if not isinstance(dict_cfg, dict): 21 | return dict_cfg 22 | return {k.upper(): upper_config(v) for k, v in dict_cfg.items()} 23 | 24 | 25 | def log_on(condition, message, level): 26 | if condition: 27 | assert level in ['INFO', 'DEBUG', 'WARNING', 'ERROR', 'CRITICAL'] 28 | logger.log(level, message) 29 | 30 | 31 | def get_rank_zero_only_logger(logger: _Logger): 32 | if rank_zero_only.rank == 0: 33 | return logger 34 | else: 35 | for _level in logger._core.levels.keys(): 36 | level = _level.lower() 37 | setattr(logger, level, 38 | lambda x: None) 39 | logger._log = lambda x: None 40 | return logger 41 | 42 | 43 | def setup_gpus(gpus: Union[str, int]) -> int: 44 | """ A temporary fix for pytorch-lighting 1.3.x """ 45 | gpus = str(gpus) 46 | gpu_ids = [] 47 | 48 | if ',' not in gpus: 49 | n_gpus = int(gpus) 50 | return n_gpus if n_gpus != -1 else torch.cuda.device_count() 51 | else: 52 | gpu_ids = [i.strip() for i in gpus.split(',') if i != ''] 53 | 54 | # setup environment 
variables 55 | visible_devices = os.getenv('CUDA_VISIBLE_DEVICES') 56 | if visible_devices is None: 57 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 58 | os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(i) for i in gpu_ids) 59 | visible_devices = os.getenv('CUDA_VISIBLE_DEVICES') 60 | logger.warning(f'[Temporary Fix] manually set CUDA_VISIBLE_DEVICES when specifying gpus to use: {visible_devices}') 61 | else: 62 | logger.warning('[Temporary Fix] CUDA_VISIBLE_DEVICES already set by user or the main process.') 63 | return len(gpu_ids) 64 | 65 | 66 | def flattenList(x): 67 | return list(chain(*x)) 68 | 69 | 70 | @contextlib.contextmanager 71 | def tqdm_joblib(tqdm_object): 72 | """Context manager to patch joblib to report into tqdm progress bar given as argument 73 | 74 | Usage: 75 | with tqdm_joblib(tqdm(desc="My calculation", total=10)) as progress_bar: 76 | Parallel(n_jobs=16)(delayed(sqrt)(i**2) for i in range(10)) 77 | 78 | When iterating over a generator, directly use of tqdm is also a solutin (but monitor the task queuing, instead of finishing) 79 | ret_vals = Parallel(n_jobs=args.world_size)( 80 | delayed(lambda x: _compute_cov_score(pid, *x))(param) 81 | for param in tqdm(combinations(image_ids, 2), 82 | desc=f'Computing cov_score of [{pid}]', 83 | total=len(image_ids)*(len(image_ids)-1)/2)) 84 | Src: https://stackoverflow.com/a/58936697 85 | """ 86 | class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack): 87 | def __init__(self, *args, **kwargs): 88 | super().__init__(*args, **kwargs) 89 | 90 | def __call__(self, *args, **kwargs): 91 | tqdm_object.update(n=self.batch_size) 92 | return super().__call__(*args, **kwargs) 93 | 94 | old_batch_callback = joblib.parallel.BatchCompletionCallBack 95 | joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback 96 | try: 97 | yield tqdm_object 98 | finally: 99 | joblib.parallel.BatchCompletionCallBack = old_batch_callback 100 | tqdm_object.close() 101 | 102 | -------------------------------------------------------------------------------- /src/utils/plotting.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import matplotlib 5 | 6 | 7 | def _compute_conf_thresh(data): 8 | dataset_name = data['dataset_name'][0].lower() 9 | if dataset_name == 'scannet': 10 | thr = 5e-4 11 | elif dataset_name == 'megadepth': 12 | thr = 1e-4 13 | else: 14 | raise ValueError(f'Unknown dataset: {dataset_name}') 15 | return thr 16 | 17 | 18 | # --- VISUALIZATION --- # 19 | 20 | def make_matching_figure( 21 | img0, img1, mkpts0_raw, mkpts1_raw, color_raw, 22 | kpts0=None, kpts1=None, text=[], dpi=75, path=None, idxs = None): 23 | # draw image pair 24 | 25 | if idxs is not None: 26 | 27 | mkpts0 = mkpts0_raw[idxs] 28 | mkpts1 = mkpts1_raw[idxs] 29 | color = color_raw[idxs] 30 | else: 31 | mkpts0 = mkpts0_raw 32 | mkpts1 = mkpts1_raw 33 | color = color_raw 34 | 35 | 36 | assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. 
mkpts1: {mkpts1.shape[0]}' 37 | fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi) 38 | axes[0].imshow(img0, cmap='gray') 39 | axes[1].imshow(img1, cmap='gray') 40 | for i in range(2): # clear all frames 41 | axes[i].get_yaxis().set_ticks([]) 42 | axes[i].get_xaxis().set_ticks([]) 43 | for spine in axes[i].spines.values(): 44 | spine.set_visible(False) 45 | plt.tight_layout(pad=1) 46 | 47 | if kpts0 is not None: 48 | assert kpts1 is not None 49 | axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2) 50 | axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2) 51 | 52 | # draw matches 53 | if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0: 54 | fig.canvas.draw() 55 | transFigure = fig.transFigure.inverted() 56 | fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0)) 57 | fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1)) 58 | fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]), 59 | (fkpts0[i, 1], fkpts1[i, 1]), 60 | transform=fig.transFigure, c=color[i], linewidth=5) 61 | for i in range(len(mkpts0))] 62 | 63 | axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=15) 64 | axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=15) 65 | 66 | # put txts 67 | txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w' 68 | fig.text( 69 | 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes, 70 | fontsize=22, va='top', ha='left', color='k', weight='bold') 71 | 72 | # save or return figure 73 | if path: 74 | plt.savefig(str(path), bbox_inches='tight', pad_inches=0) 75 | plt.close() 76 | else: 77 | return fig 78 | 79 | 80 | def _make_evaluation_figure(data, b_id, alpha='dynamic'): 81 | b_mask = data['m_bids'] == b_id 82 | conf_thr = _compute_conf_thresh(data) 83 | 84 | img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32) 85 | img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32) 86 | kpts0 = data['mkpts0_f'][b_mask].cpu().numpy() 87 | kpts1 = data['mkpts1_f'][b_mask].cpu().numpy() 88 | 89 | # for megadepth, we visualize matches on the resized image 90 | if 'scale0' in data: 91 | kpts0 = kpts0 / data['scale0'][b_id].cpu().numpy()[[1, 0]] 92 | kpts1 = kpts1 / data['scale1'][b_id].cpu().numpy()[[1, 0]] 93 | 94 | epi_errs = data['epi_errs'][b_mask].cpu().numpy() 95 | correct_mask = epi_errs < conf_thr 96 | precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0 97 | n_correct = np.sum(correct_mask) 98 | n_gt_matches = int(data['conf_matrix_gt'][b_id].sum().cpu()) 99 | recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches) 100 | # recall might be larger than 1, since the calculation of conf_matrix_gt 101 | # uses groundtruth depths and camera poses, but epipolar distance is used here. 102 | 103 | # matching info 104 | if alpha == 'dynamic': 105 | alpha = dynamic_alpha(len(correct_mask)) 106 | color = error_colormap(epi_errs, conf_thr, alpha=alpha) 107 | 108 | text = [ 109 | f'#Matches {len(kpts0)}', 110 | f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}', 111 | f'Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}' 112 | ] 113 | 114 | # make the figure 115 | figure = make_matching_figure(img0, img1, kpts0, kpts1, 116 | color, text=text) 117 | return figure 118 | 119 | def _make_confidence_figure(data, b_id): 120 | # TODO: Implement confidence figure 121 | raise NotImplementedError() 122 | 123 | 124 | def make_matching_figures(data, config, mode='evaluation'): 125 | """ Make matching figures for a batch. 
126 | 127 | Args: 128 | data (Dict): a batch updated by PL_LoFTR. 129 | config (Dict): matcher config 130 | Returns: 131 | figures (Dict[str, List[plt.figure]] 132 | """ 133 | assert mode in ['evaluation', 'confidence'] # 'confidence' 134 | figures = {mode: []} 135 | for b_id in range(data['image0'].size(0)): 136 | if mode == 'evaluation': 137 | fig = _make_evaluation_figure( 138 | data, b_id, 139 | alpha=config.TRAINER.PLOT_MATCHES_ALPHA) 140 | elif mode == 'confidence': 141 | fig = _make_confidence_figure(data, b_id) 142 | else: 143 | raise ValueError(f'Unknown plot mode: {mode}') 144 | figures[mode].append(fig) 145 | return figures 146 | 147 | 148 | def dynamic_alpha(n_matches, 149 | milestones=[0, 300, 1000, 2000], 150 | alphas=[1.0, 0.8, 0.4, 0.2]): 151 | if n_matches == 0: 152 | return 1.0 153 | ranges = list(zip(alphas, alphas[1:] + [None])) 154 | loc = bisect.bisect_right(milestones, n_matches) - 1 155 | _range = ranges[loc] 156 | if _range[1] is None: 157 | return _range[0] 158 | return _range[1] + (milestones[loc + 1] - n_matches) / ( 159 | milestones[loc + 1] - milestones[loc]) * (_range[0] - _range[1]) 160 | 161 | 162 | def error_colormap(err, thr, alpha=1.0): 163 | assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}" 164 | x = 1 - np.clip(err / (thr * 2), 0, 1) 165 | return np.clip( 166 | np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1) 167 | -------------------------------------------------------------------------------- /src/utils/profiler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler 3 | from contextlib import contextmanager 4 | from pytorch_lightning.utilities import rank_zero_only 5 | 6 | 7 | class InferenceProfiler(SimpleProfiler): 8 | """ 9 | This profiler records duration of actions with cuda.synchronize() 10 | Use this in test time. 
11 | """ 12 | 13 | def __init__(self): 14 | super().__init__() 15 | self.start = rank_zero_only(self.start) 16 | self.stop = rank_zero_only(self.stop) 17 | self.summary = rank_zero_only(self.summary) 18 | 19 | @contextmanager 20 | def profile(self, action_name: str) -> None: 21 | try: 22 | torch.cuda.synchronize() 23 | self.start(action_name) 24 | yield action_name 25 | finally: 26 | torch.cuda.synchronize() 27 | self.stop(action_name) 28 | 29 | 30 | def build_profiler(name): 31 | if name == 'inference': 32 | return InferenceProfiler() 33 | elif name == 'pytorch': 34 | from pytorch_lightning.profiler import PyTorchProfiler 35 | return PyTorchProfiler(use_cuda=True, profile_memory=True, row_limit=100) 36 | elif name is None: 37 | return PassThroughProfiler() 38 | else: 39 | raise ValueError(f'Invalid profiler: {name}') 40 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import argparse 3 | import pprint 4 | from loguru import logger as loguru_logger 5 | 6 | from src.config.default import get_cfg_defaults 7 | from src.utils.profiler import build_profiler 8 | 9 | from src.lightning.data import MultiSceneDataModule 10 | from src.lightning.lightning_loftr import PL_LoFTR 11 | 12 | 13 | def parse_args(): 14 | # init a costum parser which will be added into pl.Trainer parser 15 | # check documentation: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-flags 16 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 17 | parser.add_argument( 18 | 'data_cfg_path', type=str, help='data config path') 19 | parser.add_argument( 20 | 'main_cfg_path', type=str, help='main config path') 21 | parser.add_argument( 22 | '--ckpt_path', type=str, default="weights/indoor_ds.ckpt", help='path to the checkpoint') 23 | parser.add_argument( 24 | '--dump_dir', type=str, default=None, help="if set, the matching results will be dump to dump_dir") 25 | parser.add_argument( 26 | '--profiler_name', type=str, default=None, help='options: [inference, pytorch], or leave it unset') 27 | parser.add_argument( 28 | '--batch_size', type=int, default=1, help='batch_size per gpu') 29 | parser.add_argument( 30 | '--num_workers', type=int, default=2) 31 | parser.add_argument( 32 | '--thr', type=float, default=None, help='modify the coarse-level matching threshold.') 33 | parser.add_argument( 34 | '--add_curvature', action="store_true", help='if the curvature features are added') 35 | 36 | parser = pl.Trainer.add_argparse_args(parser) 37 | return parser.parse_args() 38 | 39 | 40 | if __name__ == '__main__': 41 | # parse arguments 42 | args = parse_args() 43 | pprint.pprint(vars(args)) 44 | 45 | # init default-cfg and merge it with the main- and data-cfg 46 | config = get_cfg_defaults() 47 | config.merge_from_file(args.main_cfg_path) 48 | config.merge_from_file(args.data_cfg_path) 49 | pl.seed_everything(config.TRAINER.SEED) # reproducibility 50 | 51 | # tune when testing 52 | if args.thr is not None: 53 | config.LOFTR.MATCH_COARSE.THR = args.thr 54 | if args.add_curvature: 55 | config.LOFTR.MATCH_COARSE.ADD_CURV = True 56 | 57 | 58 | loguru_logger.info(f"Args and config initialized!") 59 | 60 | # lightning module 61 | profiler = build_profiler(args.profiler_name) 62 | model = PL_LoFTR(config, pretrained_ckpt=args.ckpt_path, profiler=profiler, dump_dir=args.dump_dir) 63 | loguru_logger.info(f"LoFTR-lightning 
initialized!") 64 | 65 |     # lightning data 66 |     data_module = MultiSceneDataModule(args, config) 67 |     loguru_logger.info(f"DataModule initialized!") 68 | 69 |     # lightning trainer 70 |     trainer = pl.Trainer.from_argparse_args(args, replace_sampler_ddp=False, logger=False) 71 |     if args.add_curvature: 72 |         loguru_logger.info(f"Start testing with CSE!") 73 |     else: 74 |         loguru_logger.info(f"Start testing!") 75 |     trainer.test(model, datamodule=data_module, verbose=False) 76 | 77 | --------------------------------------------------------------------------------
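As a usage note (not part of the repository files above), the two test-time switches `--thr` and `--add_curvature` handled in `test.py` only modify the merged yacs config. The sketch below shows the equivalent programmatic tuning, assuming the repository root is the working directory; the two config paths are placeholders to be replaced by the desired files under `configs/`.

```python
from src.config.default import get_cfg_defaults

# Mirror the config handling in test.py: defaults first, then the main and data configs.
# Both paths below are placeholders for files under configs/.
config = get_cfg_defaults()
config.merge_from_file("configs/loftr/indoor/loftr_ds_quadtree.py")  # main cfg (placeholder)
config.merge_from_file("configs/data/scannet_test_1500.py")          # data cfg (placeholder)

# Equivalent of passing `--thr 0.2 --add_curvature` on the command line.
config.LOFTR.MATCH_COARSE.THR = 0.2        # coarse-level matching threshold
config.LOFTR.MATCH_COARSE.ADD_CURV = True  # enable the curvature similarity features
```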