├── .gitignore ├── README.md ├── assets ├── ft_visualize.png ├── represent.png ├── res.png └── structure_overview.png ├── checkpoints └── ckpt_best.pt ├── config ├── config_kitti.yaml ├── config_pretrain.yaml └── config_things.yaml ├── dataset ├── __init__.py ├── augmentation.py ├── flyingthings_subset.py ├── kitti.py └── preprocess_data.py ├── env.yaml ├── models ├── __init__.py ├── base.py ├── camlipwc.py ├── camlipwc2d_core.py ├── camlipwc3d_core.py ├── camlipwc_core.py ├── csrc │ ├── __init__.py │ ├── correlation │ │ ├── correlation.cpp │ │ ├── correlation.h │ │ ├── correlation_backward_kernel.cu │ │ ├── correlation_forward_kernel.cu │ │ └── correlation_test.cpp │ ├── furthest_point_sampling │ │ ├── furthest_point_sampling.cpp │ │ ├── furthest_point_sampling.h │ │ ├── furthest_point_sampling_kernel.cu │ │ └── furthest_point_sampling_test.cpp │ ├── k_nearest_neighbor │ │ ├── k_nearest_neighbor.cpp │ │ ├── k_nearest_neighbor.h │ │ ├── k_nearest_neighbor_kernel.cu │ │ └── k_nearest_neighbor_test.cpp │ ├── setup.py │ └── wrapper.py ├── fusion_module.py ├── losses2d.py ├── losses3d.py ├── pointconv.py └── utils.py ├── ops_pytorch ├── __init__.py ├── fused_conv_select │ ├── fused_conv_g.cpp │ ├── fused_conv_go.cu │ ├── fused_conv_gpu.h │ ├── fused_conv_select_k.py │ └── setup.py └── gpu_threenn_sample │ ├── no_sort_knn.py │ ├── no_sort_knn_g.cpp │ ├── no_sort_knn_go.cu │ ├── no_sort_knn_gpu.h │ └── setup.py ├── train.py └── utils ├── __init__.py ├── average_meter.py ├── build_utils.py ├── evaluation_utils.py ├── geometry.py ├── log_utils.py └── train_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.zip 2 | output/* 3 | .vscode 4 | # 3rdparty 5 | Datasets 6 | SUMMARY 7 | *.out 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | *.log 13 | # C extensions 14 | *.so 15 | *.ply 16 | *.npy 17 | *.log 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | exps*/ 22 | ckpts/* 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | pip-wheel-metadata/ 35 | share/python-wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | output/ 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .nox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | *.py,cover 62 | .hypothesis/ 63 | .pytest_cache/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | db.sqlite3-journal 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 106 | __pypackages__/ 107 | 108 | # Celery stuff 109 | celerybeat-schedule 110 | celerybeat.pid 111 | 112 | # SageMath parsed files 113 | *.sage.py 114 | 115 | # Environments 116 | .env 117 | .venv 118 | env/ 119 | venv/ 120 | ENV/ 121 | env.bak/ 122 | venv.bak/ 123 | 124 | # Spyder project settings 125 | .spyderproject 126 | .spyproject 127 | 128 | # Rope project settings 129 | .ropeproject 130 | 131 | # mkdocs documentation 132 | /site 133 | 134 | # mypy 135 | .mypy_cache/ 136 | .dmypy.json 137 | dmypy.json 138 | 139 | # Pyre type checker 140 | .pyre/ 141 | *.ppt 142 | *.pptx 143 | *.log -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# DELFlow: Dense Efficient Learning of Scene Flow for Large-Scale Point Clouds
2 | 
3 | 
4 | 
5 | ![representation comparison](assets/represent.png)
6 | 
7 | 
8 | 
9 | 
10 | 
11 | ## :bookmark_tabs: Table of Contents
12 | 
13 | 1. [Introduction](#clapper-introduction)
14 | 2. [Installation](#memo-installation)
15 | 3. [Dataset](#file_folder-dataset)
16 | 4. [Usage](#computer-usage)
17 | 5. [Qualitative Results](#art-visualization)
18 | 
19 | 
20 | 
21 | 
22 | ## :clapper: Introduction
23 | 
24 | Point clouds are naturally sparse, while image pixels are dense. This inconsistency limits feature fusion from the two modalities for point-wise scene flow estimation. Previous methods rarely predict scene flow for the entire point cloud of a scene in a single inference pass, due to the memory inefficiency and the heavy overhead of the distance computation and sorting involved in the farthest point sampling, KNN, and ball query algorithms commonly used for local feature aggregation.
25 | 
26 | To mitigate these issues in scene flow learning, we regularize raw points into a dense format by storing 3D coordinates in 2D grids. Unlike the sampling operations commonly used in existing works, the dense 2D representation
27 | 
28 | - preserves most points in the given scene,
29 | - brings a significant boost in efficiency, and
30 | - eliminates the density gap between points and pixels, allowing us to perform effective feature fusion (see the sketch below).
31 | 
32 | We also present a novel warping projection technique to alleviate the information loss caused by multiple points being mapped into the same grid cell during projection when computing the cost volume. Extensive experiments demonstrate the efficiency and effectiveness of our method, which outperforms prior arts on the FlyingThings3D and KITTI datasets.
33 | 
34 | 
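As a rough, hedged sketch of the projection step (our illustration, not the repository's actual pipeline; the loaders in `dataset/` instead store precomputed per-pixel masks), scattering points into a dense grid with pinhole intrinsics `(f, cx, cy)` looks like this:

```python
import numpy as np

def to_dense_grid(points, f, cx, cy, h, w):
    """Scatter an (N, 3) point cloud into an (H, W, 3) grid of 3D coordinates.

    Naive sketch: points that project into the same cell overwrite each other,
    which is exactly the information loss the warping projection technique
    is designed to alleviate.
    """
    grid = np.zeros((h, w, 3), dtype=np.float32)
    points = points[points[:, 2] > 0]  # keep points in front of the camera
    u = np.round(f * points[:, 0] / points[:, 2] + cx).astype(np.int64)  # u = f * x / z + cx
    v = np.round(f * points[:, 1] / points[:, 2] + cy).astype(np.int64)  # v = f * y / z + cy
    keep = (u >= 0) & (u < w) & (v >= 0) & (v < h)
    grid[v[keep], u[keep]] = points[keep]
    return grid
```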
35 | 
36 | ![structure comparison](assets/structure_overview.png)
37 | 
38 | 
39 | 
40 | For more details, please refer to our [paper](https://openaccess.thecvf.com/content/ICCV2023/html/Peng_DELFlow_Dense_Efficient_Learning_of_Scene_Flow_for_Large-Scale_Point_ICCV_2023_paper.html) ([arXiv](https://arxiv.org/abs/2308.04383)).
41 | 
42 | 
43 | 
44 | ## :memo: Installation
45 | 
46 | Create a PyTorch environment using `conda`:
47 | ```bash
48 | git clone https://github.com/IRMVLab/DELFlow.git
49 | cd DELFlow
50 | conda env create -f env.yaml
51 | conda activate delflow
52 | ```
53 | 
54 | Compile the CUDA extensions (paths below are relative to the repository root):
55 | ```bash
56 | cd models/csrc
57 | python setup.py build_ext --inplace
58 | 
59 | cd ../../ops_pytorch/fused_conv_select
60 | python setup.py install
61 | 
62 | cd ../gpu_threenn_sample
63 | python setup.py install
64 | ```
65 | 
66 | 
67 | 
68 | ## :file_folder: Dataset
69 | 
70 | ### FlyingThings3D
71 | 
72 | First, download the [FlyingThings3D subset](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html), then preprocess it as follows (you may need to change `--input_dir` and `--output_dir`):
73 | 
74 | ```bash
75 | python dataset/preprocess_data.py --input_dir /data/FlyingThings3D_subset --output_dir /data/processed_flyingthings3d
76 | ```
77 | 
78 | ### KITTI
79 | 
80 | First, download the following parts:
81 | 
82 | - Main data: [data_scene_flow.zip](https://s3.eu-central-1.amazonaws.com/avg-kitti/data_scene_flow.zip)
83 | - Calibration files: [data_scene_flow_calib.zip](https://s3.eu-central-1.amazonaws.com/avg-kitti/data_scene_flow_calib.zip)
84 | - Disparity estimation (from GA-Net): [disp_ganet.zip](https://drive.google.com/file/d/1ieFpOVzqCzT8TXNk1zm2d9RLkrcaI78o/view?usp=sharing)
85 | - Semantic segmentation (from DDR-Net): [semantic_ddr.zip](https://drive.google.com/file/d/1dVSJeE9BBmVv2rCe5TR0PVanEv2WzwIy/view?usp=sharing)
86 | 
87 | 
88 | 
89 | Unzip them and organize the directory as follows:
90 | 
91 | 
92 | ```
93 | dataset/kitti_scene_flow
94 | ├── testing
95 | │   ├── calib_cam_to_cam
96 | │   ├── calib_imu_to_velo
97 | │   ├── calib_velo_to_cam
98 | │   ├── disp_ganet
99 | │   ├── flow_occ
100 | │   ├── image_2
101 | │   ├── image_3
102 | │   └── semantic_ddr
103 | └── training
104 |     ├── calib_cam_to_cam
105 |     ├── calib_imu_to_velo
106 |     ├── calib_velo_to_cam
107 |     ├── disp_ganet
108 |     ├── disp_occ_0
109 |     ├── disp_occ_1
110 |     ├── flow_occ
111 |     ├── image_2
112 |     ├── image_3
113 |     ├── obj_map
114 |     └── semantic_ddr
115 | ```
116 | 
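With the files in place, the helpers in `dataset/kitti.py` can lift a single frame to a camera-frame point cloud. A minimal sketch (the frame index `000000` is only an example; note the loader itself reads GA-Net disparities from a `disp_ganet_training` folder, so match the folder name to your unzipped layout):

```python
import os
from dataset.kitti import load_calib, load_disp_png, disp2pc

root = 'dataset/kitti_scene_flow/training'

proj_mat = load_calib(os.path.join(root, 'calib_cam_to_cam', '000000.txt'))
f, cx, cy = proj_mat[0, 0], proj_mat[0, 2], proj_mat[1, 2]

# dense GA-Net disparity -> per-pixel 3D coordinates (H x W x 3)
disp, valid = load_disp_png(os.path.join(root, 'disp_ganet', '000000_10.png'))
pc = disp2pc(disp, baseline=0.54, f=f, cx=cx, cy=cy)  # baseline follows the loader
print(pc.shape, valid.mean())
```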
117 | 
118 | 
119 | ## :computer: Usage
120 | 
121 | Run the code using the following command (pick the config that matches your experiment):
122 | 
123 | ```bash
124 | python train.py config/config_things.yaml  # or config/config_kitti.yaml, config/config_pretrain.yaml
125 | ```
126 | 
127 | 
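A pretrained checkpoint ships as `checkpoints/ckpt_best.pt`, while `config/config_things.yaml` points `ckpt.path` at `./checkpoints/best.pt`, so adjust the path (or rename the file) when fine-tuning. Its internal layout is not documented here; a quick, hedged way to inspect it before wiring it into a config:

```python
import torch

# Peek at the released checkpoint; the key layout is an assumption,
# so print what is actually stored rather than relying on guessed names.
ckpt = torch.load('checkpoints/ckpt_best.pt', map_location='cpu')
print(type(ckpt))
if isinstance(ckpt, dict):
    print(list(ckpt.keys()))
```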
128 | 
129 | ![res comparison](assets/res.png)
130 | 
131 | 
132 | 
133 | 
134 | 
135 | ## :book: Citation
136 | 
137 | If you use this codebase or model in your research, please cite:
138 | 
139 | ```
140 | @inproceedings{peng2023delflow,
141 |   title={{DELFlow}: Dense Efficient Learning of Scene Flow for Large-Scale Point Clouds},
142 |   author={Peng, Chensheng and Wang, Guangming and Lo, Xian Wan and Wu, Xinrui and Xu, Chenfeng and Tomizuka, Masayoshi and Zhan, Wei and Wang, Hesheng},
143 |   booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
144 |   pages={16901--16910},
145 |   year={2023}
146 | }
147 | ```
148 | 
149 | 
150 | 
151 | ## :art: Visualization
152 | 
153 | In this section, we present illustrative examples that demonstrate the effectiveness of our approach.
154 | 
155 | 
156 | 
157 | ![qualitative comparison](assets/ft_visualize.png)
158 | 
159 | 
160 | 
161 | 
162 | 
163 | ## :pray: Acknowledgements
164 | 
165 | This code benefits a lot from [CamLiFlow](https://github.com/MCG-NJU/CamLiFlow). Thanks for making the code publicly available.
166 | 
--------------------------------------------------------------------------------
/assets/ft_visualize.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IRMVLab/DELFlow/3e12fdec9fd26ace6f42436497ceb236656c6a23/assets/ft_visualize.png
--------------------------------------------------------------------------------
/assets/represent.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IRMVLab/DELFlow/3e12fdec9fd26ace6f42436497ceb236656c6a23/assets/represent.png
--------------------------------------------------------------------------------
/assets/res.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IRMVLab/DELFlow/3e12fdec9fd26ace6f42436497ceb236656c6a23/assets/res.png
--------------------------------------------------------------------------------
/assets/structure_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IRMVLab/DELFlow/3e12fdec9fd26ace6f42436497ceb236656c6a23/assets/structure_overview.png
--------------------------------------------------------------------------------
/checkpoints/ckpt_best.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IRMVLab/DELFlow/3e12fdec9fd26ace6f42436497ceb236656c6a23/checkpoints/ckpt_best.pt
--------------------------------------------------------------------------------
/config/config_kitti.yaml:
--------------------------------------------------------------------------------
1 | trainset: 2 | name: kitti 3 | root_dir: /data/KITTI_SCENE 4 | split: training160 5 | n_workers: 8 6 | max_depth: 90 7 | drop_last: true 8 | 9 | valset: 10 | name: kitti 11 | root_dir: /data/KITTI_SCENE 12 | split: training40 13 | n_workers: 8 14 | max_depth: 90 15 | 16 | model: 17 | name: kitti 18 | height: 320 19 | width: 1280 20 | stride_h: [2, 2, 2, 2, 2, 2] 21 | stride_w: [2, 2, 2, 2, 2, 2] 22 | freeze_bn: false 23 | batch_size: 12 24 | 25 | pwc2d: 26 | norm: 27 | feature_pyramid: batch_norm 28 | flow_estimator: null 29 | context_network: null 30 | max_displacement: 4 31 | 32 | pwc3d: 33 | norm: 34 | feature_pyramid: batch_norm 35 | correlation: null 36 | flow_estimator: null 37 | k: 16 38 | kernel_size: [10, 30] 39 | 40 | loss2d: 41 | level_weights: [16, 8, 4, 2, 1] 42 | order: robust 43 | 44 | loss3d: 45 | level_weights: [8, 4, 2, 1, 0.5] 46 | order: robust 47 | 48 | 49 | training: 50 | opt: adamw 51 | momentum: 0.9 52 | weight_decay: 0.000001 53 | grad_max_norm: 25 54 | accum_iter: 1 55 | 56 | sched: cosine 57 | epochs: 1200 58 | # lr_2d: 0.0005 # LR for 2D branch 59 | # lr_3d: 0.001 # LR for 3D branch 60 | lr_2d: 0.0005 # LR for 2D branch 61 | lr_3d: 0.001 # LR for 3D branch 62 | min_lr: 0.00001 63 | warmup_lr: 0.00001 64 | warmup_epochs: 0 65 | cooldown_epochs: 0 66 | 67 | log: 68 | dir: ./experiment 69 | run_name: kitti 70 | save_ckpt: true 71 | save_scalar_summary: true 72 | save_ckpt_every_n_epochs: 1 73 | save_summary_every_n_steps: 100 74 | 75 | ckpt: 76 | path: null 77 | save_path: ./checkpoints 78 | resume: false 79 | 80 | val_interval: 4 81 | gpu: 1 82 | port: random # for multi-gpu training 83 | amp: false 84 | debug:
false 85 | sync_bn: true 86 | -------------------------------------------------------------------------------- /config/config_pretrain.yaml: -------------------------------------------------------------------------------- 1 | trainset: 2 | name: flyingthings3d 3 | root_dir: /data/processed_flyingthings3d 4 | split: train 5 | n_workers: 4 6 | drop_last: true 7 | full: false 8 | 9 | augmentation: 10 | enabled: false 11 | color_jitter: 12 | enabled: false 13 | random_horizontal_flip: 14 | enabled: false 15 | random_vertical_flip: 16 | enabled: false 17 | random_crop: 18 | enabled: false 19 | random_scale: 20 | enabled: false 21 | 22 | valset: 23 | name: flyingthings3d 24 | root_dir: /data/processed_flyingthings3d 25 | split: val 26 | n_workers: 4 27 | full: false 28 | 29 | augmentation: 30 | enabled: false 31 | 32 | model: 33 | name: pretrain 34 | height: 256 35 | width: 448 36 | stride_h: [2, 2, 2, 2, 2, 2] 37 | stride_w: [2, 2, 2, 2, 2, 2] 38 | freeze_bn: false 39 | batch_size: 64 40 | 41 | pwc2d: 42 | norm: 43 | feature_pyramid: batch_norm 44 | flow_estimator: null 45 | context_network: null 46 | max_displacement: 4 47 | 48 | pwc3d: 49 | norm: 50 | feature_pyramid: batch_norm 51 | correlation: null 52 | flow_estimator: null 53 | k: 16 54 | kernel_size: [10, 20] 55 | 56 | loss2d: 57 | level_weights: [8, 4, 2, 1, 0.5] 58 | order: l2 59 | 60 | loss3d: 61 | level_weights: [8, 4, 2, 1, 0.5] 62 | order: l2 63 | 64 | 65 | training: 66 | opt: adamw 67 | momentum: 0.9 68 | weight_decay: 0.00001 69 | grad_max_norm: 20 70 | accum_iter: 1 71 | 72 | sched: cosine 73 | epochs: 800 74 | lr_2d: 0.0005 # LR for 2D branch 75 | lr_3d: 0.001 # LR for 3D branch 76 | min_lr: 0.00001 77 | warmup_lr: 0.00001 78 | warmup_epochs: 2 79 | cooldown_epochs: 0 80 | 81 | log: 82 | dir: ./experiment 83 | run_name: pretrain 84 | save_ckpt: true 85 | save_scalar_summary: true 86 | save_ckpt_every_n_epochs: 1 87 | save_summary_every_n_steps: 100 88 | 89 | ckpt: 90 | path: null 91 | save_path: ./checkpoints 92 | resume: false 93 | 94 | val_interval: 10 95 | gpu: 1 96 | port: random # for multi-gpu training 97 | amp: false 98 | debug: false 99 | sync_bn: true 100 | -------------------------------------------------------------------------------- /config/config_things.yaml: -------------------------------------------------------------------------------- 1 | trainset: 2 | name: flyingthings3d 3 | root_dir: /data/processed_flyingthings3d 4 | split: train 5 | n_workers: 4 6 | drop_last: true 7 | full: true 8 | 9 | augmentation: 10 | enabled: true 11 | color_jitter: 12 | enabled: true 13 | brightness: 0.3 14 | contrast: 0.3 15 | saturation: 0.3 16 | hue: 0.159 # 0.5/3.14 17 | random_horizontal_flip: 18 | enabled: false 19 | random_vertical_flip: 20 | enabled: false 21 | random_crop: 22 | enabled: false 23 | crop_size: [640, 384] # [640, 384] 24 | random_scale: 25 | enabled: false 26 | random_down: 27 | enabled: true 28 | type: train 29 | 30 | valset: 31 | name: flyingthings3d 32 | root_dir: /data/processed_flyingthings3d 33 | split: val 34 | n_workers: 4 35 | full: true 36 | 37 | augmentation: 38 | enabled: false 39 | color_jitter: 40 | enabled: false 41 | random_horizontal_flip: 42 | enabled: false 43 | random_vertical_flip: 44 | enabled: false 45 | random_crop: 46 | enabled: false 47 | crop_size: [640, 384] # [640, 384] 48 | random_scale: 49 | enabled: false 50 | random_down: 51 | enabled: true 52 | type: eval 53 | 54 | model: 55 | name: flyingthings 56 | height: 256 57 | width: 448 58 | stride_h: [2, 2, 2, 2, 2, 2] 59 | stride_w: 
[2, 2, 2, 2, 2, 2] 60 | freeze_bn: false 61 | batch_size: 64 62 | 63 | pwc2d: 64 | norm: 65 | feature_pyramid: batch_norm 66 | flow_estimator: null 67 | context_network: null 68 | max_displacement: 4 69 | 70 | pwc3d: 71 | norm: 72 | feature_pyramid: batch_norm 73 | correlation: null 74 | flow_estimator: null 75 | k: 16 76 | kernel_size: [10, 20] 77 | 78 | loss2d: 79 | level_weights: [8, 4, 2, 1, 0.5] 80 | order: robust 81 | 82 | loss3d: 83 | level_weights: [8, 4, 2, 1, 0.5] 84 | order: robust 85 | 86 | 87 | training: 88 | opt: adamw 89 | momentum: 0.9 90 | weight_decay: 0.000001 91 | grad_max_norm: 20 92 | accum_iter: 1 93 | 94 | sched: cosine 95 | epochs: 400 96 | lr_2d: 0.00005 # LR for 2D branch 97 | lr_3d: 0.0001 # LR for 3D branch 98 | min_lr: 0.00001 99 | warmup_lr: 0.00001 100 | warmup_epochs: 2 101 | cooldown_epochs: 0 102 | 103 | log: 104 | dir: ./experiment 105 | run_name: flyingthings 106 | save_ckpt: true 107 | save_scalar_summary: true 108 | save_ckpt_every_n_epochs: 1 109 | save_summary_every_n_steps: 100 110 | 111 | ckpt: 112 | path: ./checkpoints/best.pt 113 | save_path: ./checkpoints 114 | resume: false 115 | 116 | val_interval: 10 117 | gpu: 0 118 | port: random # for multi-gpu training 119 | amp: false 120 | debug: false 121 | sync_bn: true 122 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IRMVLab/DELFlow/3e12fdec9fd26ace6f42436497ceb236656c6a23/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/augmentation.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import torchvision 4 | import numpy as np 5 | 6 | 7 | def color_jitter(image1, image2, brightness, contrast, saturation, hue): 8 | assert image1.shape == image2.shape 9 | cj_module = torchvision.transforms.ColorJitter(brightness, contrast, saturation, hue) 10 | 11 | images = np.concatenate([image1, image2], axis=0) 12 | images_t = torch.from_numpy(images.transpose([2, 0, 1]).copy()) 13 | images_t = cj_module.forward(images_t / 255.0) * 255.0 14 | images = images_t.numpy().astype(np.uint8).transpose(1, 2, 0) 15 | image1, image2 = images[:image1.shape[0]], images[image1.shape[0]:] 16 | 17 | return image1, image2 18 | 19 | 20 | def flip_point_cloud(pc, image_h, image_w, f, cx, cy, flip_mode): 21 | assert flip_mode in ['lr', 'ud'] 22 | pc_x, pc_y, depth = pc[..., 0], pc[..., 1], pc[..., 2] 23 | 24 | image_x = cx + (f / depth) * pc_x 25 | image_y = cy + (f / depth) * pc_y 26 | 27 | if flip_mode == 'lr': 28 | image_x = image_w - 1 - image_x 29 | else: 30 | image_y = image_h - 1 - image_y 31 | 32 | pc_x = (image_x - cx) * depth / f 33 | pc_y = (image_y - cy) * depth / f 34 | pc = np.concatenate([pc_x[:, None], pc_y[:, None], depth[:, None]], axis=-1) 35 | 36 | return pc 37 | 38 | 39 | def flip_scene_flow(pc1, flow_3d, image_h, image_w, f, cx, cy, flip_mode): 40 | new_pc1 = flip_point_cloud(pc1, image_h, image_w, f, cx, cy, flip_mode) 41 | new_pc1_warp = flip_point_cloud(pc1 + flow_3d[:, :3], image_h, image_w, f, cx, cy, flip_mode) 42 | return np.concatenate([new_pc1_warp - new_pc1, flow_3d[:, 3:]], axis=-1) 43 | 44 | 45 | def flip_image(image, flip_mode): 46 | if flip_mode == 'lr': 47 | return np.fliplr(image).copy() 48 | else: 49 | return np.flipud(image).copy() 50 | 51 | 52 | def flip_optical_flow(flow, flip_mode): 53 | 
    assert flip_mode in ['lr', 'ud']
54 |     if flip_mode == 'lr':
55 |         flow = np.fliplr(flow).copy()
56 |         flow[:, :, 0] *= -1
57 |     else:
58 |         flow = np.flipud(flow).copy()
59 |         flow[:, :, 1] *= -1
60 |     return flow
61 | 
62 | 
63 | def random_flip(image1, image2, pc1, pc2, flow_2d, flow_3d, mask, f, cx, cy, flip_mode):
64 |     assert flow_3d.shape[-1] <= 4
65 |     assert flip_mode in ['lr', 'ud']
66 |     image_h, image_w = image1.shape[:2]
67 | 
68 |     if np.random.rand() < 0.5:  # do nothing
69 |         return image1, image2, pc1, pc2, flow_2d, flow_3d, mask, f, cx, cy
70 | 
71 |     if flip_mode == 'lr':
72 |         new_f = -1.0 * f
73 |         new_cx = image_w - 1 - cx
74 |         new_cy = cy
75 |     else:
76 |         new_f = -1.0 * f
77 |         new_cy = image_h - 1 - cy
78 |         new_cx = cx
79 | 
80 |     new_mask = flip_image(mask, flip_mode)
81 |     # flip images
82 |     new_image1 = flip_image(image1, flip_mode)
83 |     new_image2 = flip_image(image2, flip_mode)
84 | 
85 |     # flip point clouds
86 |     new_pc1 = flip_image(pc1, flip_mode)
87 |     new_pc2 = flip_image(pc2, flip_mode)
88 | 
89 |     # flip optical flow and scene flow
90 |     new_flow_2d = flip_optical_flow(flow_2d, flip_mode)
91 |     new_flow_3d = flip_image(flow_3d, flip_mode)
92 |     # new_flow_3d = flip_scene_flow(pc1, flow_3d, image_h, image_w, f, cx, cy, flip_mode)
93 | 
94 |     return new_image1, new_image2, new_pc1, new_pc2, new_flow_2d, new_flow_3d, new_mask, new_f, new_cx, new_cy
95 | 
96 | 
97 | def crop_image_with_pc(image1, image2, pc1, pc2, flow_2d, flow_3d, mask, f, cx, cy, crop_window):
98 |     assert len(crop_window) == 4  # [x1, y1, x2, y2]
99 | 
100 |     x1, y1, x2, y2 = crop_window
101 |     # image_h, image_w = image1.shape[:2]
102 | 
103 |     # crop images
104 |     image1 = image1[y1:y2, x1:x2].copy()
105 |     image2 = image2[y1:y2, x1:x2].copy()
106 |     flow_2d = flow_2d[y1:y2, x1:x2].copy()
107 | 
108 |     # crop pc hw3
109 |     pc1 = pc1[y1:y2, x1:x2].copy()
110 |     pc2 = pc2[y1:y2, x1:x2].copy()
111 |     flow_3d = flow_3d[y1:y2, x1:x2].copy()
112 |     mask = mask[y1:y2, x1:x2].copy()
113 | 
114 |     # adjust camera params
115 |     cx = cx - x1
116 |     cy = cy - y1
117 | 
118 |     return image1, image2, pc1, pc2, flow_2d, flow_3d, mask, f, cx, cy
119 | 
120 | 
121 | def random_crop(image1, image2, pc1, pc2, flow_2d, flow_3d, mask, f, cx, cy, crop_size):
122 |     assert flow_3d.shape[-1] <= 4
123 |     assert len(crop_size) == 2
124 |     crop_w, crop_h = crop_size
125 | 
126 |     image_h, image_w = image1.shape[:2]
127 |     assert crop_w <= image_w and crop_h <= image_h
128 | 
129 |     # top left of the cropping window
130 |     x1 = np.random.randint(low=0, high=image_w - crop_w + 1)
131 |     y1 = np.random.randint(low=0, high=image_h - crop_h + 1)
132 |     crop_window = [x1, y1, x1 + crop_w, y1 + crop_h]
133 | 
134 |     return crop_image_with_pc(image1, image2, pc1, pc2, flow_2d, flow_3d, mask, f, cx, cy, crop_window)
135 | 
136 | 
137 | def resize_sparse_flow_map(flow, target_w, target_h):
138 |     curr_h, curr_w = flow.shape[:2]
139 | 
140 |     coords = np.meshgrid(np.arange(curr_w), np.arange(curr_h))
141 |     coords = np.stack(coords, axis=-1).astype(np.float32)
142 | 
143 |     mask = flow[..., -1] > 0
144 |     coords0, flow0 = coords[mask], flow[mask][:, :2]
145 | 
146 |     scale_ratio_w = (target_w - 1) / (curr_w - 1)
147 |     scale_ratio_h = (target_h - 1) / (curr_h - 1)
148 | 
149 |     coords1 = coords0 * [scale_ratio_w, scale_ratio_h]
150 |     flow1 = flow0 * [scale_ratio_w, scale_ratio_h]
151 | 
152 |     xx = np.round(coords1[:, 0]).astype(np.int32)
153 |     yy = np.round(coords1[:, 1]).astype(np.int32)
154 |     valid = (xx >= 0) & (xx < target_w) & (yy >= 0) & (yy < target_h)
155 |     xx, yy, flow1 = xx[valid], yy[valid], flow1[valid]
np.random.seed(0) 45 | 46 | idx1 = self.indices[i] 47 | idx2 = idx1 + 1 48 | data_dict = {'index': idx1} 49 | 50 | # camera intrinsics 51 | f, cx, cy = 1050, 479.5, 269.5 52 | 53 | # load data 54 | pcs = np.load(os.path.join(self.split_dir, 'pc', '%07d.npz' % idx1)) 55 | pc1, pc2 = pcs['pc1'], pcs['pc2'] 56 | 57 | flow_2d, flow_mask_2d = load_flow_png(os.path.join(self.split_dir, 'flow_2d', '%07d.png' % idx1)) 58 | flow_3d = np.load(os.path.join(self.split_dir, 'flow_3d', '%07d.npy' % idx1)) 59 | 60 | occ_mask_3d = np.load(os.path.join(self.split_dir, 'occ_mask_3d', '%07d.npy' % idx1)) 61 | occ_mask_3d = np.unpackbits(occ_mask_3d, count=len(pc1)) 62 | 63 | image1 = cv2.imread(os.path.join(self.split_dir, 'image', '%07d.png' % idx1))[..., ::-1].copy().astype(np.float32) 64 | image2 = cv2.imread(os.path.join(self.split_dir, 'image', '%07d.png' % idx2))[..., ::-1].copy().astype(np.float32) 65 | 66 | H, W = image1.shape[:2] 67 | 68 | pc1_mask = np.load(os.path.join(self.split_dir, 'pc1_mask', '%07d.npy' % idx1)) 69 | pc1_mask = np.unpackbits(pc1_mask).reshape(H, W).astype(np.bool_) 70 | 71 | pc2_mask = np.load(os.path.join(self.split_dir, 'pc2_mask', '%07d.npy' % idx1)) 72 | pc2_mask = np.unpackbits(pc2_mask).reshape(H, W).astype(np.bool_) 73 | 74 | 75 | # ignore fast moving objects 76 | flow_mask_2d = np.logical_and(flow_mask_2d, np.linalg.norm(flow_2d, axis=-1) < 250.0) 77 | flow_2d = np.concatenate([flow_2d, flow_mask_2d[..., None].astype(np.float32)], axis=2) 78 | 79 | 80 | pc1_hw3 = np.zeros((H, W, 3), dtype = np.float32) 81 | pc2_hw3 = np.zeros((H, W, 3), dtype = np.float32) 82 | flow_hw3 = np.zeros((H, W, 3), dtype = np.float32) 83 | valid_mask = np.zeros((H, W), dtype = np.bool_) 84 | 85 | 86 | pc1_hw3[pc1_mask] = pc1 87 | pc2_hw3[pc2_mask] = pc2 88 | flow_hw3[pc1_mask] = flow_3d 89 | valid_mask[pc1_mask] = np.logical_not(occ_mask_3d) 90 | # valid_mask = pc1_mask 91 | 92 | image1, image2, pc1_hw3, pc2_hw3, flow_2d, flow_hw3, valid_mask, f, cx, cy = joint_augmentation( 93 | image1, image2, pc1_hw3, pc2_hw3, flow_2d, flow_hw3, valid_mask, f, cx, cy, self.cfgs.augmentation 94 | ) 95 | 96 | # image1, image2, pc1_hw3, pc2_hw3, flow_2d, flow_hw3, valid_mask = scale_crop(image1, image2, pc1_hw3, pc2_hw3, flow_2d, flow_hw3, valid_mask) 97 | 98 | intrinsics = np.float32([f, f, cx, cy]) 99 | 100 | image1 = np.ascontiguousarray(image1.transpose([2, 0, 1])) 101 | image2 = np.ascontiguousarray(image2.transpose([2, 0, 1])) 102 | pc1_3hw = np.ascontiguousarray(pc1_hw3.transpose([2, 0, 1])) 103 | pc2_3hw = np.ascontiguousarray(pc2_hw3.transpose([2, 0, 1])) 104 | flow_3hw = np.ascontiguousarray(flow_hw3.transpose([2, 0, 1])) 105 | valid_mask = np.ascontiguousarray(valid_mask) 106 | pc1_mask = np.ascontiguousarray(pc1_mask) 107 | data_dict['image1'] = image1 # 3 x H x W 108 | data_dict['image2'] = image2 # 3 x H x W 109 | data_dict['pc1'] = pc1_3hw # 3 x H x W 110 | data_dict['pc2'] = pc2_3hw # 3 x H x W 111 | data_dict['flow_3d'] = flow_3hw # 3 x H x W 112 | data_dict['flow_2d'] = flow_2d.transpose([2, 0, 1]) # 3 x H x W 113 | data_dict['mask'] = valid_mask # H x W 114 | 115 | non_zero_mask = np.linalg.norm(pc1_3hw, axis = 0) > 1e-8 116 | data_dict['nonzero_mask'] = non_zero_mask # H x W 117 | data_dict['intrinsics'] = intrinsics # 4 118 | 119 | return data_dict 120 | -------------------------------------------------------------------------------- /dataset/kitti.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | 
import torch.utils.data
8 | 
9 | 
10 | def load_flow_png(filepath, scale=64.0):
11 |     # for KITTI which uses 16bit PNG images
12 |     # see 'https://github.com/ClementPinard/FlowNetPytorch/blob/master/datasets/KITTI.py'
13 |     # The -1 is here to specify not to change the image depth (16bit), and is compatible
14 |     # with both OpenCV2 and OpenCV3
15 |     flow_img = cv2.imread(filepath, -1)
16 |     flow = flow_img[:, :, 2:0:-1].astype(np.float32)
17 |     mask = flow_img[:, :, 0] > 0
18 |     flow = flow - 32768.0
19 |     flow = flow / scale
20 |     return flow, mask
21 | 
22 | def load_disp_png(filepath):
23 |     array = cv2.imread(filepath, -1)
24 |     valid_mask = array > 0
25 |     disp = array.astype(np.float32) / 256.0
26 |     disp[np.logical_not(valid_mask)] = -1.0
27 |     return disp, valid_mask
28 | 
29 | def load_calib(filepath):
30 |     with open(filepath) as f:
31 |         lines = f.readlines()
32 |     for line in lines:
33 |         if line.startswith('P_rect_02'):
34 |             proj_mat = line.split()[1:]
35 |             proj_mat = [float(param) for param in proj_mat]
36 |             proj_mat = np.array(proj_mat, dtype=np.float32).reshape(3, 4)
37 |             assert proj_mat[0, 1] == proj_mat[1, 0] == 0
38 |             assert proj_mat[2, 0] == proj_mat[2, 1] == 0
39 |             assert proj_mat[0, 0] == proj_mat[1, 1]
40 |             assert proj_mat[2, 2] == 1
41 | 
42 |     return proj_mat
43 | 
44 | 
45 | def disp2pc(disp, baseline, f, cx, cy, flow=None):
46 |     h, w = disp.shape
47 |     depth = baseline * f / (disp + 1e-5)
48 | 
49 |     xx = np.tile(np.arange(w, dtype=np.float32)[None, :], (h, 1))
50 |     yy = np.tile(np.arange(h, dtype=np.float32)[:, None], (1, w))
51 | 
52 |     if flow is None:
53 |         x = (xx - cx) * depth / f
54 |         y = (yy - cy) * depth / f
55 |     else:
56 |         x = (xx - cx + flow[..., 0]) * depth / f
57 |         y = (yy - cy + flow[..., 1]) * depth / f
58 | 
59 |     pc = np.concatenate([
60 |         x[:, :, None],
61 |         y[:, :, None],
62 |         depth[:, :, None],
63 |     ], axis=-1)
64 | 
65 |     return pc
66 | 
67 | 
68 | 
69 | def zero_padding(inputs, pad_h, pad_w):
70 |     input_dim = len(inputs.shape)
71 |     assert input_dim in [2, 3]
72 | 
73 |     if input_dim == 2:
74 |         inputs = inputs[..., None]
75 | 
76 |     h, w, c = inputs.shape
77 |     assert h <= pad_h and w <= pad_w
78 | 
79 |     result = np.zeros([pad_h, pad_w, c], dtype=inputs.dtype)
80 |     result[:h, :w, :c] = inputs
81 | 
82 |     if input_dim == 2:
83 |         result = result[..., 0]
84 | 
85 |     return result
86 | 
87 | 
88 | class KITTI(torch.utils.data.Dataset):
89 |     def __init__(self, cfgs):
90 |         assert os.path.isdir(cfgs.root_dir)
91 |         assert cfgs.split in ['training200', 'training160', 'training40']
92 | 
93 |         self.root_dir = os.path.join(cfgs.root_dir, 'training')
94 |         self.split = cfgs.split
95 |         self.cfgs = cfgs
96 |         self.crop = 80
97 | 
98 |         if self.split == 'training200':
99 |             self.indices = np.arange(200)
100 |         elif self.split == 'training160':
101 |             self.indices = [i for i in range(200) if i % 5 != 0]
102 |         elif self.split == 'training40':
103 |             self.indices = [i for i in range(200) if i % 5 == 0]
104 | 
105 |     def __len__(self):
106 |         return len(self.indices)
107 | 
108 |     def __getitem__(self, i):
109 |         np.random.seed(23333)
110 | 
111 |         index = self.indices[i]
112 |         data_dict = {'index': index}
113 | 
114 |         proj_mat = load_calib(os.path.join(self.root_dir, 'calib_cam_to_cam', '%06d.txt' % index))
115 |         f, cx, cy = proj_mat[0, 0], proj_mat[0, 2], proj_mat[1, 2]
116 | 
117 |         image1 = cv2.imread(os.path.join(self.root_dir, 'image_2', '%06d_10.png' % index))[..., ::-1].copy().astype(np.float32)
118 |         image2 = cv2.imread(os.path.join(self.root_dir, 'image_2', '%06d_11.png' % index))[..., ::-1].copy().astype(np.float32)
119 |         flow_2d, flow_2d_mask = load_flow_png(os.path.join(self.root_dir, 'flow_occ', '%06d_10.png' % index))
120 | 
121 |         data_dict['input_h'] = image1.shape[0]
122 |         data_dict['input_w'] = image1.shape[1]
123 | 
124 |         disp1, mask1 = load_disp_png(os.path.join(self.root_dir, 'disp_occ_0', '%06d_10.png' % index))
125 |         disp2, mask2 = load_disp_png(os.path.join(self.root_dir, 'disp_occ_1', '%06d_10.png' % index))
126 |         valid = np.logical_and(np.logical_and(mask1, mask2), flow_2d_mask)
127 | 
128 |         valid = np.logical_and(valid, disp2 > 0.0)
129 | 
130 |         disp1_dense, mask1_dense = load_disp_png(os.path.join(self.root_dir, 'disp_ganet_training', '%06d_10.png' % index))
131 |         disp2_dense, mask2_dense = load_disp_png(os.path.join(self.root_dir, 'disp_ganet_training', '%06d_11.png' % index))
132 | 
133 | 
134 |         flow_3d = disp2pc(disp2, baseline=0.54, f=f, cx=cx, cy=cy, flow=flow_2d) - disp2pc(disp1, baseline=0.54, f=f, cx=cx, cy=cy)
135 | 
136 |         pc1 = disp2pc(disp1_dense, 0.54, f=f, cx=cx, cy=cy)
137 |         pc2 = disp2pc(disp2_dense, 0.54, f=f, cx=cx, cy=cy)
138 | 
139 | 
140 |         pc1 = pc1[self.crop:]
141 |         pc2 = pc2[self.crop:]
142 | 
143 | 
144 |         # zero out points beyond the maximum depth
145 |         pc1[pc1[..., -1] > self.cfgs.max_depth] = 0.0
146 |         pc2[pc2[..., -1] > self.cfgs.max_depth] = 0.0
147 | 
148 |         image1 = image1[self.crop:]
149 |         image2 = image2[self.crop:]
150 |         intrinsics = np.array([f, f, cx, cy - self.crop])
151 |         flow_3d = flow_3d[self.crop:]
152 |         flow_2d = flow_2d[self.crop:]
153 |         valid = valid[self.crop:]
154 | 
155 | 
156 |         padding_h, padding_w = 320, 1280
157 |         image1 = zero_padding(image1, padding_h, padding_w)
158 |         image2 = zero_padding(image2, padding_h, padding_w)
159 |         pc1 = zero_padding(pc1, padding_h, padding_w)
160 |         pc2 = zero_padding(pc2, padding_h, padding_w)
161 |         valid = zero_padding(valid, padding_h, padding_w)
162 |         flow_3d = zero_padding(flow_3d, padding_h, padding_w)
163 |         flow_2d = zero_padding(flow_2d, padding_h, padding_w)
164 | 
165 |         data_dict['image1'] = np.ascontiguousarray(image1.transpose([2, 0, 1]))  # 3 x H x W
166 |         data_dict['image2'] = np.ascontiguousarray(image2.transpose([2, 0, 1]))  # 3 x H x W
167 |         data_dict['pc1'] = np.ascontiguousarray(pc1.transpose([2, 0, 1]))  # 3 x H x W
168 |         data_dict['pc2'] = np.ascontiguousarray(pc2.transpose([2, 0, 1]))  # 3 x H x W
169 |         non_zero_mask = np.linalg.norm(data_dict['pc1'], axis=0) > 1e-8
170 |         data_dict['nonzero_mask'] = non_zero_mask  # H x W
171 |         data_dict['intrinsics'] = intrinsics  # 4
172 |         flow_2d = np.concatenate([flow_2d, zero_padding(flow_2d_mask[self.crop:].astype(np.float32), padding_h, padding_w)[..., None]], axis=-1)  # H x W x 3 (mask must be cropped/padded like the flow)
173 |         data_dict['flow_3d'] = np.ascontiguousarray(flow_3d.transpose([2, 0, 1]))  # 3 x H x W
174 |         data_dict['flow_2d'] = np.ascontiguousarray(flow_2d.transpose([2, 0, 1]))  # 3 x H x W
175 |         data_dict['mask'] = valid & non_zero_mask  # H x W
176 | 
177 | 
178 |         return data_dict
--------------------------------------------------------------------------------
/dataset/preprocess_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import shutil
4 | import logging
5 | import argparse
6 | import torch.utils.data
7 | import numpy as np
8 | from tqdm import tqdm
9 | import re
10 | import sys
11 | 
12 | 
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--input_dir', default='/data/FlyingThings3D_subset', help='Path to the FlyingThings3D subset')
15 | parser.add_argument('--output_dir', required=False,
default='/data/processed_flyingthings3d') 16 | parser.add_argument('--max_depth', required=False, default=35.0) 17 | parser.add_argument('--remove_occluded_points', action='store_true') 18 | args = parser.parse_args() 19 | 20 | def init_logging(filename=None, debug=False): 21 | logging.root = logging.RootLogger('DEBUG' if debug else 'INFO') 22 | formatter = logging.Formatter('[%(asctime)s][%(levelname)s] - %(message)s') 23 | 24 | stream_handler = logging.StreamHandler(sys.stdout) 25 | stream_handler.setFormatter(formatter) 26 | logging.root.addHandler(stream_handler) 27 | 28 | if filename is not None: 29 | file_handler = logging.FileHandler(filename) 30 | file_handler.setFormatter(formatter) 31 | logging.root.addHandler(file_handler) 32 | 33 | def load_fpm(filename): 34 | with open(filename, 'rb') as f: 35 | header = f.readline().rstrip() 36 | if header.decode("ascii") == 'PF': 37 | color = True 38 | elif header.decode("ascii") == 'Pf': 39 | color = False 40 | else: 41 | raise Exception('Not a PFM file.') 42 | 43 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', f.readline().decode("ascii")) 44 | if dim_match: 45 | width, height = list(map(int, dim_match.groups())) 46 | else: 47 | raise Exception('Malformed PFM header.') 48 | 49 | scale = float(f.readline().decode("ascii").rstrip()) 50 | if scale < 0: # little-endian 51 | endian = '<' 52 | else: 53 | endian = '>' # big-endian 54 | 55 | data = np.fromfile(f, endian + 'f') 56 | shape = (height, width, 3) if color else (height, width) 57 | data = np.reshape(data, shape) 58 | data = np.flipud(data) 59 | 60 | return data 61 | 62 | def disp2pc(disp, baseline, f, cx, cy, flow=None): 63 | h, w = disp.shape 64 | depth = baseline * f / (disp + 1e-5) 65 | 66 | xx = np.tile(np.arange(w, dtype=np.float32)[None, :], (h, 1)) 67 | yy = np.tile(np.arange(h, dtype=np.float32)[:, None], (1, w)) 68 | 69 | if flow is None: 70 | x = (xx - cx) * depth / f 71 | y = (yy - cy) * depth / f 72 | else: 73 | x = (xx - cx + flow[..., 0]) * depth / f 74 | y = (yy - cy + flow[..., 1]) * depth / f 75 | 76 | pc = np.concatenate([ 77 | x[:, :, None], 78 | y[:, :, None], 79 | depth[:, :, None], 80 | ], axis=-1) 81 | 82 | return pc 83 | 84 | def load_flow(filepath): 85 | with open(filepath, 'rb') as f: 86 | magic = np.fromfile(f, np.float32, count=1) 87 | assert (202021.25 == magic), 'Invalid .flo file: incorrect magic number' 88 | w = np.fromfile(f, np.int32, count=1)[0] 89 | h = np.fromfile(f, np.int32, count=1)[0] 90 | flow = np.fromfile(f, np.float32, count=2 * w * h).reshape([h, w, 2]) 91 | 92 | return flow 93 | 94 | 95 | def save_flow_png(filepath, flow, mask=None, scale=64.0): 96 | assert flow.shape[2] == 2 97 | assert np.abs(flow).max() < 32767.0 / scale 98 | flow = flow * scale 99 | flow = flow + 32768.0 100 | 101 | if mask is None: 102 | mask = np.ones_like(flow)[..., 0] 103 | else: 104 | mask = np.float32(mask > 0) 105 | 106 | flow_img = np.concatenate([ 107 | mask[..., None], 108 | flow[..., 1:2], 109 | flow[..., 0:1] 110 | ], axis=-1).astype(np.uint16) 111 | 112 | cv2.imwrite(filepath, flow_img) 113 | 114 | 115 | 116 | 117 | class Preprocessor(torch.utils.data.Dataset): 118 | def __init__(self, input_dir, output_dir, split, max_depth, remove_occluded_points): 119 | super(Preprocessor, self).__init__() 120 | 121 | self.input_dir = input_dir 122 | self.output_dir = output_dir 123 | self.split = split 124 | self.max_depth = max_depth 125 | self.remove_occluded_points = remove_occluded_points 126 | 127 | self.indices = [] 128 | for filename in 
os.listdir(os.path.join(input_dir, split, 'flow', 'left', 'into_future')):
129 |             index = int(filename.split('.')[0])
130 |             self.indices.append(index)
131 | 
132 |     def __len__(self):
133 |         return len(self.indices)
134 | 
135 |     def __getitem__(self, i):
136 |         np.random.seed(0)
137 | 
138 |         index1 = self.indices[i]
139 |         index2 = index1 + 1
140 | 
141 |         # camera intrinsics
142 |         baseline, f, cx, cy = 1.0, 1050.0, 479.5, 269.5
143 | 
144 |         # load data
145 |         disp1 = -load_fpm(os.path.join(
146 |             self.input_dir, self.split, 'disparity', 'left', '%07d.pfm' % index1
147 |         ))
148 |         disp2 = -load_fpm(os.path.join(
149 |             self.input_dir, self.split, 'disparity', 'left', '%07d.pfm' % index2
150 |         ))
151 |         disp1_change = -load_fpm(os.path.join(
152 |             self.input_dir, self.split, 'disparity_change', 'left', 'into_future', '%07d.pfm' % index1
153 |         ))
154 |         flow_2d = load_flow(os.path.join(
155 |             self.input_dir, self.split, 'flow', 'left', 'into_future', '%07d.flo' % index1
156 |         ))
157 |         occ_mask_2d = cv2.imread(os.path.join(
158 |             self.input_dir, self.split, 'flow_occlusions', 'left', 'into_future', '%07d.png' % index1
159 |         ))
160 |         occ_mask_2d = occ_mask_2d[..., 0] > 1
161 | 
162 |         if self.remove_occluded_points:
163 |             pc1 = disp2pc(disp1, baseline, f, cx, cy)
164 |             pc2 = disp2pc(disp1 + disp1_change, baseline, f, cx, cy, flow_2d)
165 | 
166 |             # apply non-occlusion mask
167 |             noc_mask_2d = np.logical_not(occ_mask_2d)
168 |             pc1, pc2 = pc1[noc_mask_2d], pc2[noc_mask_2d]
169 | 
170 |             # apply depth mask
171 |             mask = np.logical_and(pc1[..., -1] < self.max_depth, pc2[..., -1] < self.max_depth)
172 |             pc1, pc2 = pc1[mask], pc2[mask]
173 | 
174 |             # NaN check
175 |             mask = np.logical_not(np.isnan(np.sum(pc1, axis=-1) + np.sum(pc2, axis=-1)))
176 |             pc1, pc2 = pc1[mask], pc2[mask]
177 | 
178 |             # compute scene flow
179 |             flow_3d = pc2 - pc1
180 |             occ_mask_3d = np.zeros(len(pc1), dtype=np.bool_)  # np.bool was removed in NumPy >= 1.24
181 |         else:
182 |             pc1 = disp2pc(disp1, baseline, f, cx, cy)
183 |             pc2 = disp2pc(disp2, baseline, f, cx, cy)
184 |             flow_3d = disp2pc(disp1 + disp1_change, baseline, f, cx, cy, flow_2d) - pc1  # h x w x 3
185 | 
186 | 
187 |             x, y = np.meshgrid(np.arange(pc1.shape[1]), np.arange(pc1.shape[0]))
188 |             coords = np.concatenate([y[:, :, None], x[:, :, None]], axis=-1)  # h x w x 2
189 | 
190 |             # apply depth mask and NaN check
191 |             mask1 = np.logical_and((pc1[..., -1] < self.max_depth), np.logical_not(np.isnan(np.sum(pc1, axis=-1) + np.sum(flow_3d, axis=-1))))
192 |             mask2 = np.logical_and((pc2[..., -1] < self.max_depth), np.logical_not(np.isnan(np.sum(pc2, axis=-1))))
193 | 
194 | 
195 | 
196 |             pc1, pc2, flow_3d, occ_mask_3d = pc1[mask1], pc2[mask2], flow_3d[mask1], occ_mask_2d[mask1]
197 | 
198 | 
199 |         # save point clouds and occ mask
200 |         np.savez(
201 |             os.path.join(self.output_dir, self.split, 'pc', '%07d.npz' % index1),
202 |             pc1=pc1, pc2=pc2
203 |         )
204 |         np.save(
205 |             os.path.join(self.output_dir, self.split, 'occ_mask_3d', '%07d.npy' % index1),
206 |             np.packbits(occ_mask_3d)
207 |         )
208 |         if not self.remove_occluded_points:  # pc1_mask/pc2_mask only exist in the dense branch above
209 |             np.save(
210 |                 os.path.join(self.output_dir, self.split, 'pc1_mask', '%07d.npy' % index1),
211 |                 np.packbits(mask1)
212 |             )
213 | 
214 |             np.save(
215 |                 os.path.join(self.output_dir, self.split, 'pc2_mask', '%07d.npy' % index1),
216 |                 np.packbits(mask2)
217 |             )
218 | 
219 |         # mask regions moving extremely fast
220 |         flow_mask = np.logical_and(np.abs(flow_2d[..., 0]) < 500, np.abs(flow_2d[..., 1]) < 500)
221 |         flow_2d[np.logical_not(flow_mask)] = 0.0
222 | 
223 |         # save ground-truth flow
224 |         save_flow_png(
225 |             os.path.join(self.output_dir, self.split,
'flow_2d', '%07d.png' % index1), 226 | flow_2d, flow_mask 227 | ) 228 | np.save( 229 | os.path.join(self.output_dir, self.split, 'flow_3d', '%07d.npy' % index1), 230 | flow_3d 231 | ) 232 | 233 | return 0 234 | 235 | 236 | def main(): 237 | for split_idx, split in enumerate(['train', 'val']): 238 | 239 | if not os.path.exists(os.path.join(args.input_dir, split)): 240 | print(os.path.join(args.input_dir, split)) 241 | continue 242 | 243 | logging.info('Processing "%s" split...' % split) 244 | 245 | os.makedirs(os.path.join(args.output_dir, split, 'pc'), exist_ok=True) 246 | os.makedirs(os.path.join(args.output_dir, split, 'flow_2d'), exist_ok=True) 247 | os.makedirs(os.path.join(args.output_dir, split, 'flow_3d'), exist_ok=True) 248 | os.makedirs(os.path.join(args.output_dir, split, 'pc1_mask'), exist_ok=True) 249 | os.makedirs(os.path.join(args.output_dir, split, 'pc2_mask'), exist_ok=True) 250 | 251 | if not os.path.exists(os.path.join(args.output_dir, split, 'image')): 252 | logging.info('Copying images...') 253 | shutil.copytree( 254 | src=os.path.join(args.input_dir, split, 'image_clean', 'left'), 255 | dst=os.path.join(args.output_dir, split, 'image') 256 | ) 257 | 258 | if not os.path.exists(os.path.join(args.output_dir, split, 'occ_mask_2d')): 259 | logging.info('Copying occ_mask_2d...') 260 | shutil.copytree( 261 | src=os.path.join(args.input_dir, split, 'flow_occlusions', 'left', 'into_future'), 262 | dst=os.path.join(args.output_dir, split, 'occ_mask_2d') 263 | ) 264 | 265 | logging.info('Generating point clouds...') 266 | preprocessor = Preprocessor( 267 | args.input_dir, 268 | args.output_dir, 269 | split, 270 | args.max_depth, 271 | args.remove_occluded_points, 272 | ) 273 | preprocessor = torch.utils.data.DataLoader(dataset=preprocessor, num_workers=4) 274 | 275 | for i in tqdm(preprocessor): 276 | # print(i) 277 | # if i > 5: 278 | # break 279 | pass 280 | 281 | 282 | if __name__ == '__main__': 283 | init_logging() 284 | main() 285 | logging.info('All done.') 286 | -------------------------------------------------------------------------------- /env.yaml: -------------------------------------------------------------------------------- 1 | name: delflow 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=5.1=1_gnu 9 | - archspec=0.2.3=pyhd3eb1b0_0 10 | - asttokens=2.0.5=pyhd3eb1b0_0 11 | - attrs=23.1.0=py310h06a4308_0 12 | - beautifulsoup4=4.12.2=py310h06a4308_0 13 | - blas=1.0=mkl 14 | - boltons=23.0.0=py310h06a4308_0 15 | - brotli-python=1.0.9=py310h6a678d5_7 16 | - bzip2=1.0.8=h7b6447c_0 17 | - c-ares=1.19.1=h5eee18b_0 18 | - ca-certificates=2024.3.11=h06a4308_0 19 | - certifi=2024.2.2=py310h06a4308_0 20 | - cffi=1.16.0=py310h5eee18b_0 21 | - chardet=4.0.0=py310h06a4308_1003 22 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 23 | - click=8.1.7=py310h06a4308_0 24 | - cmake=3.26.4=h96355d8_0 25 | - conda=23.9.0=py310h06a4308_0 26 | - conda-build=24.3.0=py310h06a4308_0 27 | - conda-content-trust=0.2.0=py310h06a4308_0 28 | - conda-index=0.4.0=pyhd3eb1b0_0 29 | - conda-libmamba-solver=23.7.0=py310h06a4308_0 30 | - conda-package-handling=2.2.0=py310h06a4308_0 31 | - conda-package-streaming=0.9.0=py310h06a4308_0 32 | - cryptography=42.0.5=py310hdda0065_0 33 | - cuda-cudart=12.1.105=0 34 | - cuda-cupti=12.1.105=0 35 | - cuda-libraries=12.1.0=0 36 | - cuda-nvrtc=12.1.105=0 37 | - cuda-nvtx=12.1.105=0 38 | - cuda-opencl=12.4.99=0 39 | - cuda-runtime=12.1.0=0 40 | - decorator=5.1.1=pyhd3eb1b0_0 41 | - 
distro=1.8.0=py310h06a4308_0 42 | - exceptiongroup=1.2.0=py310h06a4308_0 43 | - executing=0.8.3=pyhd3eb1b0_0 44 | - expat=2.5.0=h6a678d5_0 45 | - ffmpeg=4.3=hf484d3e_0 46 | - filelock=3.13.1=py310h06a4308_0 47 | - fmt=9.1.0=hdb19cb5_0 48 | - freetype=2.12.1=h4a9f257_0 49 | - gmp=6.2.1=h295c915_3 50 | - gmpy2=2.1.2=py310heeb90bb_0 51 | - gnutls=3.6.15=he1e5248_0 52 | - icu=73.1=h6a678d5_0 53 | - idna=3.4=py310h06a4308_0 54 | - intel-openmp=2023.1.0=hdb19cb5_46306 55 | - ipython=8.20.0=py310h06a4308_0 56 | - jedi=0.18.1=py310h06a4308_1 57 | - jinja2=3.1.3=py310h06a4308_0 58 | - jpeg=9e=h5eee18b_1 59 | - jsonpatch=1.32=pyhd3eb1b0_0 60 | - jsonpointer=2.1=pyhd3eb1b0_0 61 | - jsonschema=4.19.2=py310h06a4308_0 62 | - jsonschema-specifications=2023.7.1=py310h06a4308_0 63 | - krb5=1.20.1=h143b758_1 64 | - lame=3.100=h7b6447c_0 65 | - lcms2=2.12=h3be6417_0 66 | - ld_impl_linux-64=2.38=h1181459_1 67 | - lerc=3.0=h295c915_0 68 | - libarchive=3.6.2=h6ac8c49_2 69 | - libcublas=12.1.0.26=0 70 | - libcufft=11.0.2.4=0 71 | - libcufile=1.9.0.20=0 72 | - libcurand=10.3.5.119=0 73 | - libcurl=8.5.0=h251f7ec_0 74 | - libcusolver=11.4.4.55=0 75 | - libcusparse=12.0.2.55=0 76 | - libdeflate=1.17=h5eee18b_1 77 | - libedit=3.1.20230828=h5eee18b_0 78 | - libev=4.33=h7f8727e_1 79 | - libffi=3.4.4=h6a678d5_0 80 | - libgcc-ng=11.2.0=h1234567_1 81 | - libgomp=11.2.0=h1234567_1 82 | - libiconv=1.16=h7f8727e_2 83 | - libidn2=2.3.4=h5eee18b_0 84 | - libjpeg-turbo=2.0.0=h9bf148f_0 85 | - liblief=0.12.3=h6a678d5_0 86 | - libmamba=1.5.3=haf1ee3a_0 87 | - libmambapy=1.5.3=py310h2dafd23_0 88 | - libnghttp2=1.57.0=h2d74bed_0 89 | - libnpp=12.0.2.50=0 90 | - libnvjitlink=12.1.105=0 91 | - libnvjpeg=12.1.1.14=0 92 | - libpng=1.6.39=h5eee18b_0 93 | - libsolv=0.7.24=he621ea3_0 94 | - libssh2=1.10.0=hdbd6064_2 95 | - libstdcxx-ng=11.2.0=h1234567_1 96 | - libtasn1=4.19.0=h5eee18b_0 97 | - libtiff=4.5.1=h6a678d5_0 98 | - libunistring=0.9.10=h27cfd23_0 99 | - libuuid=1.41.5=h5eee18b_0 100 | - libuv=1.44.2=h5eee18b_0 101 | - libwebp-base=1.3.2=h5eee18b_0 102 | - libxml2=2.10.4=hf1b16e4_1 103 | - llvm-openmp=14.0.6=h9e868ea_0 104 | - lz4-c=1.9.4=h6a678d5_0 105 | - markupsafe=2.1.3=py310h5eee18b_0 106 | - matplotlib-inline=0.1.6=py310h06a4308_0 107 | - menuinst=2.0.2=py310h06a4308_0 108 | - mkl=2023.1.0=h213fc3f_46344 109 | - mkl-service=2.4.0=py310h5eee18b_1 110 | - mkl_fft=1.3.8=py310h5eee18b_0 111 | - mkl_random=1.2.4=py310hdb19cb5_0 112 | - more-itertools=10.1.0=py310h06a4308_0 113 | - mpc=1.1.0=h10f8cd9_1 114 | - mpfr=4.0.2=hb69a4c5_1 115 | - mpmath=1.3.0=py310h06a4308_0 116 | - ncurses=6.4=h6a678d5_0 117 | - nettle=3.7.3=hbbd107a_1 118 | - numpy=1.26.4=py310h5f9d8c6_0 119 | - numpy-base=1.26.4=py310hb5e798b_0 120 | - openh264=2.1.1=h4ff587b_0 121 | - openjpeg=2.4.0=h3ad879b_0 122 | - openssl=3.0.13=h7f8727e_0 123 | - packaging=23.2=py310h06a4308_0 124 | - parso=0.8.3=pyhd3eb1b0_0 125 | - patch=2.7.6=h7b6447c_1001 126 | - patchelf=0.17.2=h6a678d5_0 127 | - pcre2=10.42=hebb0a14_0 128 | - pexpect=4.8.0=pyhd3eb1b0_3 129 | - pillow=10.2.0=py310h5eee18b_0 130 | - pip=23.3.1=py310h06a4308_0 131 | - pkginfo=1.9.6=py310h06a4308_0 132 | - platformdirs=3.10.0=py310h06a4308_0 133 | - pluggy=1.0.0=py310h06a4308_1 134 | - prompt-toolkit=3.0.43=py310h06a4308_0 135 | - prompt_toolkit=3.0.43=hd3eb1b0_0 136 | - psutil=5.9.0=py310h5eee18b_0 137 | - ptyprocess=0.7.0=pyhd3eb1b0_2 138 | - pure_eval=0.2.2=pyhd3eb1b0_0 139 | - py-lief=0.12.3=py310h6a678d5_0 140 | - pybind11-abi=4=hd3eb1b0_1 141 | - pycosat=0.6.6=py310h5eee18b_0 142 | - 
pycparser=2.21=pyhd3eb1b0_0 143 | - pygments=2.15.1=py310h06a4308_1 144 | - pyopenssl=24.0.0=py310h06a4308_0 145 | - pysocks=1.7.1=py310h06a4308_0 146 | - python=3.10.14=h955ad1f_0 147 | - python-libarchive-c=2.9=pyhd3eb1b0_1 148 | - pytorch=2.2.2=py3.10_cuda12.1_cudnn8.9.2_0 149 | - pytorch-cuda=12.1=ha16c6d3_5 150 | - pytorch-mutex=1.0=cuda 151 | - pytz=2023.3.post1=py310h06a4308_0 152 | - pyyaml=6.0.1=py310h5eee18b_0 153 | - readline=8.2=h5eee18b_0 154 | - referencing=0.30.2=py310h06a4308_0 155 | - reproc=14.2.4=h295c915_1 156 | - reproc-cpp=14.2.4=h295c915_1 157 | - rhash=1.4.3=hdbd6064_0 158 | - rpds-py=0.10.6=py310hb02cf49_0 159 | - ruamel.yaml=0.17.21=py310h5eee18b_0 160 | - ruamel.yaml.clib=0.2.6=py310h5eee18b_1 161 | - six=1.16.0=pyhd3eb1b0_1 162 | - soupsieve=2.5=py310h06a4308_0 163 | - sqlite=3.41.2=h5eee18b_0 164 | - stack_data=0.2.0=pyhd3eb1b0_0 165 | - sympy=1.12=py310h06a4308_0 166 | - tbb=2021.8.0=hdb19cb5_0 167 | - tk=8.6.12=h1ccaba5_0 168 | - tomli=2.0.1=py310h06a4308_0 169 | - toolz=0.12.0=py310h06a4308_0 170 | - torchaudio=2.2.2=py310_cu121 171 | - torchtriton=2.2.0=py310 172 | - torchvision=0.17.2=py310_cu121 173 | - tqdm=4.65.0=py310h2f386ee_0 174 | - traitlets=5.7.1=py310h06a4308_0 175 | - truststore=0.8.0=py310h06a4308_0 176 | - typing_extensions=4.9.0=py310h06a4308_1 177 | - wcwidth=0.2.5=pyhd3eb1b0_0 178 | - wheel=0.41.2=py310h06a4308_0 179 | - xz=5.4.6=h5eee18b_0 180 | - yaml=0.2.5=h7b6447c_0 181 | - yaml-cpp=0.8.0=h6a678d5_0 182 | - zlib=1.2.13=h5eee18b_0 183 | - zstandard=0.19.0=py310h5eee18b_0 184 | - zstd=1.5.5=hc292b87_0 185 | - pip: 186 | - absl-py==2.1.0 187 | - accelerate==0.29.1 188 | - addict==2.4.0 189 | - aliyun-python-sdk-core==2.15.0 190 | - aliyun-python-sdk-kms==2.16.2 191 | - antlr4-python3-runtime==4.8 192 | - astunparse==1.6.3 193 | - cachetools==5.3.3 194 | - colorama==0.4.6 195 | - contourpy==1.2.0 196 | - crcmod==1.7 197 | - cycler==0.12.1 198 | - diffusers==0.27.2 199 | - dnspython==2.6.1 200 | - expecttest==0.2.1 201 | - fonttools==4.50.0 202 | - fsspec==2024.3.1 203 | - google-auth==2.29.0 204 | - google-auth-oauthlib==0.4.6 205 | - grpcio==1.62.1 206 | - huggingface-hub==0.22.2 207 | - hypothesis==6.99.13 208 | - importlib-metadata==7.1.0 209 | - jmespath==0.10.0 210 | - kiwisolver==1.4.5 211 | - markdown==3.6 212 | - markdown-it-py==3.0.0 213 | - matplotlib==3.8.3 214 | - mdurl==0.1.2 215 | - mmcv==2.1.0 216 | - mmdet==3.3.0 217 | - mmengine==0.10.3 218 | - model-index==0.1.11 219 | - networkx==3.2.1 220 | - ninja==1.11.1.1 221 | - oauthlib==3.2.2 222 | - omegaconf==2.1.0 223 | - opencv-python==4.6.0.66 224 | - opendatalab==0.0.10 225 | - openmim==0.3.9 226 | - openxlab==0.0.37 227 | - optree==0.11.0 228 | - ordered-set==4.1.0 229 | - oss2==2.17.0 230 | - pandas==2.2.1 231 | - protobuf==3.20.3 232 | - pyasn1==0.6.0 233 | - pyasn1-modules==0.4.0 234 | - pycocotools==2.0.7 235 | - pycryptodome==3.20.0 236 | - pyparsing==3.1.2 237 | - python-dateutil==2.9.0.post0 238 | - python-etcd==0.4.5 239 | - regex==2023.12.25 240 | - requests==2.28.2 241 | - requests-oauthlib==2.0.0 242 | - rich==13.4.2 243 | - rsa==4.9 244 | - safetensors==0.4.2 245 | - scipy==1.12.0 246 | - setuptools==60.2.0 247 | - shapely==2.0.3 248 | - sortedcontainers==2.4.0 249 | - tabulate==0.9.0 250 | - tensorboard==2.11.2 251 | - tensorboard-data-server==0.6.1 252 | - tensorboard-plugin-wit==1.8.1 253 | - tensorboardx==2.6 254 | - termcolor==2.4.0 255 | - terminaltables==3.1.10 256 | - timm==0.6.13 257 | - torchelastic==0.2.2 258 | - types-dataclasses==0.6.6 259 | - 
typing-extensions==4.10.0 260 | - tzdata==2024.1 261 | - urllib3==1.26.18 262 | - werkzeug==3.0.1 263 | - yapf==0.40.2 264 | - zipp==3.18.1 265 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .camlipwc import CamLiPWC 2 | -------------------------------------------------------------------------------- /models/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def dist_reduce_sum(value): 5 | if torch.distributed.is_initialized(): 6 | value_t = torch.Tensor([value]).cuda() 7 | torch.distributed.all_reduce(value_t) 8 | return value_t.item() # return a plain scalar, matching the non-distributed branch 9 | else: 10 | return value 11 | 12 | 13 | class BaseModel(nn.Module): 14 | def __init__(self): 15 | super().__init__() 16 | self.loss = None 17 | self.metrics = {} 18 | 19 | def clear_metrics(self): 20 | self.metrics = {} 21 | 22 | @torch.no_grad() 23 | def update_metrics(self, name, var): 24 | if isinstance(var, torch.Tensor): 25 | var = var.reshape(-1) 26 | count = var.shape[0] 27 | var = var.float().sum().item() 28 | else: count = 1 # a plain Python scalar counts as a single sample 29 | var = dist_reduce_sum(var) 30 | count = dist_reduce_sum(count) 31 | 32 | if count <= 0: 33 | return 34 | 35 | if name not in self.metrics.keys(): 36 | self.metrics[name] = [0, 0] # [var, count] 37 | 38 | self.metrics[name][0] += var 39 | self.metrics[name][1] += count 40 | 41 | def get_metrics(self): 42 | results = {} 43 | for name, (var, count) in self.metrics.items(): 44 | results[name] = var / count 45 | return results 46 | 47 | def get_loss(self): 48 | if self.loss is None: 49 | raise ValueError('Loss is empty.') 50 | return self.loss 51 | 52 | @staticmethod 53 | def is_better(curr_metrics, best_metrics): 54 | raise NotImplementedError('Function `is_better` must be implemented.') 55 | 56 | 57 | class FlowModel(BaseModel): 58 | def __init__(self): 59 | super(FlowModel, self).__init__() 60 | 61 | @torch.no_grad() 62 | def update_2d_metrics(self, pred, target): 63 | if target.shape[1] == 3: # sparse evaluation 64 | mask = target[:, 2, :, :] > 0 65 | target = target[:, :2, :, :] 66 | else: # dense evaluation 67 | mask = torch.ones_like(target)[:, 0, :, :] > 0 # B x H x W 68 | 69 | # compute endpoint error 70 | diff = pred - target 71 | epe2d_map = torch.linalg.norm(diff, dim=1) # B x H x W 72 | self.update_metrics('epe2d', epe2d_map[mask]) 73 | 74 | # compute 1px accuracy 75 | acc2d_map = epe2d_map < 1.0 76 | self.update_metrics('acc2d_1px', acc2d_map[mask]) 77 | 78 | # compute flow outliers 79 | mag = torch.linalg.norm(target, dim=1) + 1e-5 80 | out2d_map = torch.logical_and(epe2d_map > 3.0, epe2d_map / mag > 0.05) 81 | self.update_metrics('outlier2d', out2d_map[mask]) 82 | 83 | @torch.no_grad() 84 | def update_3d_metrics(self, pred, target, mask, noc_mask = None): 85 | 86 | # compute endpoint error 87 | diff = pred - target # [B, 3, H, W] 88 | epe3d_map = torch.linalg.norm(diff, dim=1) # [B, H, W] 89 | acc5_3d_map = epe3d_map < 0.05 # compute 5cm accuracy 90 | acc10_3d_map = epe3d_map < 0.10 # compute 10cm accuracy 91 | 92 | if noc_mask is not None: 93 | self.update_metrics('epe3d(noc)', epe3d_map[noc_mask]) 94 | self.update_metrics('acc3d_5cm(noc)', acc5_3d_map[noc_mask]) 95 | self.update_metrics('acc3d_10cm(noc)', acc10_3d_map[noc_mask]) 96 | else: 97 | self.update_metrics('epe3d', epe3d_map[mask]) 98 | self.update_metrics('acc3d_5cm', acc5_3d_map[mask]) 99 | self.update_metrics('acc3d_10cm', acc10_3d_map[mask])
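100 | 101 | if __name__ == '__main__': 102 | # Usage sketch (illustrative only, single-process on CPU, so dist_reduce_sum 103 | # is a no-op because torch.distributed is never initialized here): 104 | # update_metrics keeps a running [sum, count] per metric and get_metrics 105 | # returns the mean over everything accumulated so far. 106 | model = BaseModel() 107 | model.update_metrics('epe2d', torch.tensor([0.5, 1.5, 2.0])) # running sum 4.0, count 3 108 | model.update_metrics('epe2d', torch.tensor([1.0])) # running sum 5.0, count 4 109 | print(model.get_metrics()) # {'epe2d': 1.25}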
-------------------------------------------------------------------------------- /models/camlipwc.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .camlipwc_core import CamLiPWC_Core 3 | from .base import FlowModel 4 | from .losses2d import calc_supervised_loss_2d 5 | from .losses3d import calc_supervised_loss_3d 6 | from .utils import stride_sample, stride_sample_pc, normalize_image, resize_to_64x, resize_flow2d, build_pc_pyramid, build_labels 7 | 8 | 9 | class CamLiPWC(FlowModel): 10 | def __init__(self, cfgs): 11 | super(CamLiPWC, self).__init__() 12 | self.cfgs = cfgs 13 | self.dense = cfgs.dense 14 | self.core = CamLiPWC_Core(cfgs.pwc2d, cfgs.pwc3d, self.dense) 15 | self.stride_h_list = cfgs.stride_h 16 | self.stride_w_list = cfgs.stride_w 17 | self.loss = None 18 | 19 | def train(self, mode=True): 20 | 21 | self.training = mode 22 | 23 | for module in self.children(): 24 | module.train(mode) 25 | 26 | if self.cfgs.freeze_bn: 27 | for m in self.modules(): 28 | if isinstance(m, nn.modules.batchnorm._BatchNorm): 29 | m.eval() 30 | 31 | return self 32 | 33 | def eval(self): 34 | return self.train(False) 35 | 36 | def forward(self, inputs): 37 | image1 = inputs['image1'].float() / 255.0 38 | image2 = inputs['image2'].float() / 255.0 39 | pc1, pc2 = inputs['pc1'].float(), inputs['pc2'].float() 40 | intrinsics = inputs['intrinsics'].float() 41 | 42 | 43 | # assert images.shape[2] % 64 == 0 and images.shape[3] % 64 == 0 44 | origin_h, origin_w = image1.shape[2:] 45 | cam_info = {'sensor_h': origin_h, 'sensor_w': origin_w, 'intrinsics': intrinsics, 'type': "dense" if self.dense else "sparse"} 46 | 47 | # encode features 48 | if self.dense: 49 | xyzs1, xyzs2 = stride_sample_pc(pc1, pc2, self.stride_h_list, self.stride_w_list) 50 | else: 51 | xyzs1, xyzs2, sample_indices1, _ = build_pc_pyramid( 52 | pc1, pc2, [8192, 4096, 2048, 1024, 512, 256] # 1/4 53 | ) 54 | 55 | feats1_2d, feats1_3d = self.core.encode(image1, xyzs1) 56 | feats2_2d, feats2_3d = self.core.encode(image2, xyzs2) 57 | 58 | 59 | # predict flows (1->2) 60 | flows_2d, flows_3d = self.core.decode(xyzs1[1:], xyzs2[1:], feats1_2d, feats2_2d, feats1_3d, feats2_3d, pc1, pc2, cam_info) 61 | 62 | # final_flow_2d = resize_flow2d(flows_2d[0], origin_h, origin_w) 63 | final_flow_2d = flows_2d[0] 64 | final_flow_3d = flows_3d[0] 65 | 66 | if 'flow_2d' not in inputs or 'flow_3d' not in inputs: 67 | return {'flow_2d': final_flow_2d, 'flow_3d': final_flow_3d} 68 | 69 | target_2d = inputs['flow_2d'].float() 70 | target_3d = inputs['flow_3d'].float() 71 | valid_mask = inputs['nonzero_mask'] 72 | if self.dense: 73 | labels_3d = stride_sample(target_3d, self.stride_h_list[:-2], self.stride_w_list[:-2]) 74 | masks_3d = stride_sample(valid_mask, self.stride_h_list[:-2], self.stride_w_list[:-2]) 75 | else: 76 | labels_3d, masks_3d = build_labels(target_3d, sample_indices1) 77 | 78 | # calculate losses 79 | loss_2d = calc_supervised_loss_2d(flows_2d, target_2d, self.cfgs.loss2d) 80 | loss_3d = calc_supervised_loss_3d(flows_3d, labels_3d, self.cfgs.loss3d, masks_3d) 81 | self.loss = loss_2d + loss_3d 82 | 83 | # prepare scalar summary 84 | self.update_metrics('loss', self.loss) 85 | self.update_metrics('loss2d', loss_2d) 86 | self.update_metrics('loss3d', loss_3d) 87 | self.update_2d_metrics(final_flow_2d, target_2d) 88 | self.update_3d_metrics(final_flow_3d, target_3d, valid_mask) 89 | 90 | if 'mask' in inputs: 91 | self.update_3d_metrics(final_flow_3d, target_3d, valid_mask, 
inputs['mask']) 92 | 93 | return {'flow_2d': final_flow_2d, 'flow_3d': final_flow_3d} 94 | 95 | @staticmethod 96 | def is_better(curr_summary, best_summary): 97 | if best_summary is None: 98 | return True 99 | return curr_summary['epe2d'] < best_summary['epe2d'] 100 | 101 | 102 | -------------------------------------------------------------------------------- /models/camlipwc2d_core.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .utils import Conv2dNormRelu 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_channels, out_channels, down_sample=True, norm=None): 8 | super().__init__() 9 | 10 | if down_sample: 11 | self.down0 = Conv2dNormRelu(in_channels, out_channels, stride=2, norm=norm, act=None) 12 | self.conv0 = Conv2dNormRelu(in_channels, out_channels, kernel_size=3, stride=2, padding=1, norm=norm) 13 | self.conv1 = Conv2dNormRelu(out_channels, out_channels, kernel_size=3, stride=1, padding=1, norm=norm, act=None) 14 | else: 15 | self.down0 = nn.Identity() 16 | self.conv0 = Conv2dNormRelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, norm=norm) 17 | self.conv1 = Conv2dNormRelu(out_channels, out_channels, kernel_size=3, stride=1, padding=1, norm=norm, act=None) 18 | 19 | self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 20 | 21 | def forward(self, x): 22 | out = self.conv0(x) 23 | out = self.conv1(out) 24 | out = self.relu(out + self.down0(x)) 25 | return out 26 | 27 | 28 | class FeaturePyramid2D(nn.Module): 29 | def __init__(self, n_channels, norm=None): 30 | super().__init__() 31 | self.pyramid_convs = nn.ModuleList() 32 | for in_channels, out_channels in zip(n_channels[:-1], n_channels[1:]): 33 | self.pyramid_convs.append(ResidualBlock(in_channels, out_channels, norm=norm)) 34 | 35 | def forward(self, x): 36 | outputs = [] 37 | for conv in self.pyramid_convs: 38 | x = conv(x) 39 | outputs.append(x) 40 | return outputs 41 | 42 | 43 | class FlowEstimatorDense2D(nn.Module): 44 | def __init__(self, n_channels, norm=None, conv_last=True): 45 | super().__init__() 46 | self.conv1 = Conv2dNormRelu( 47 | n_channels[0], 48 | n_channels[1], 49 | kernel_size=3, padding=1, norm=norm 50 | ) 51 | self.conv2 = Conv2dNormRelu( 52 | n_channels[0] + n_channels[1], 53 | n_channels[2], 54 | kernel_size=3, padding=1, norm=norm 55 | ) 56 | self.conv3 = Conv2dNormRelu( 57 | n_channels[0] + n_channels[1] + n_channels[2], 58 | n_channels[3], 59 | kernel_size=3, padding=1, norm=norm 60 | ) 61 | self.conv4 = Conv2dNormRelu( 62 | n_channels[0] + n_channels[1] + n_channels[2] + n_channels[3], 63 | n_channels[4], 64 | kernel_size=3, padding=1, norm=norm 65 | ) 66 | self.conv5 = Conv2dNormRelu( 67 | n_channels[0] + n_channels[1] + n_channels[2] + n_channels[3] + n_channels[4], 68 | n_channels[5], 69 | kernel_size=3, padding=1, norm=norm 70 | ) 71 | self.flow_feat_dim = sum(n_channels) 72 | 73 | if conv_last: 74 | self.conv_last = nn.Conv2d(self.flow_feat_dim, 2, kernel_size=3, stride=1, padding=1) 75 | else: 76 | self.conv_last = None 77 | 78 | def forward(self, x): 79 | x1 = torch.cat([self.conv1(x), x], dim=1) 80 | x2 = torch.cat([self.conv2(x1), x1], dim=1) 81 | x3 = torch.cat([self.conv3(x2), x2], dim=1) 82 | x4 = torch.cat([self.conv4(x3), x3], dim=1) 83 | flow_feat = torch.cat([self.conv5(x4), x4], dim=1) 84 | 85 | if self.conv_last is not None: 86 | flow = self.conv_last(flow_feat) 87 | return flow_feat, flow 88 | else: 89 | return flow_feat 90 | 91 | 92 | class 
ContextNetwork2D(nn.Module): 93 | def __init__(self, n_channels, dilations, norm=None): 94 | super().__init__() 95 | self.convs = nn.ModuleList() 96 | for in_channels, out_channels, dilation in zip(n_channels[:-1], n_channels[1:], dilations): 97 | self.convs.append(Conv2dNormRelu(in_channels, out_channels, kernel_size=3, padding=dilation, dilation=dilation, norm=norm)) 98 | self.conv_last = nn.Conv2d(n_channels[-1], 2, kernel_size=3, stride=1, padding=1) 99 | 100 | def forward(self, x): 101 | for conv in self.convs: 102 | x = conv(x) 103 | outputs = self.conv_last(x) 104 | return x, outputs 105 | -------------------------------------------------------------------------------- /models/camlipwc3d_core.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .pointconv import PointConv, PointConvS 4 | from .utils import MLP1d, MLP2d, Conv1dNormRelu, k_nearest_neighbor, batch_indexing, knn_grouping_2d, get_hw_idx, mask_batch_selecting 5 | from ops_pytorch.gpu_threenn_sample.no_sort_knn import no_sort_knn 6 | 7 | 8 | def get_selected_idx(batch_size: int, out_H: int, out_W: int, stride_H: int, stride_W: int): 9 | 10 | select_h_idx = torch.arange(0, out_H * stride_H, stride_H, device = "cuda") # [out_H] 11 | select_w_idx = torch.arange(0, out_W * stride_W, stride_W, device = "cuda") # [out_W] 12 | height_indices = torch.reshape(select_h_idx, (1, 1, -1, 1)).expand(batch_size, 1, out_H, out_W) # [B, 1, out_H, out_W] 13 | width_indices = torch.reshape(select_w_idx, (1, 1, 1, -1)).expand(batch_size, 1, out_H, out_W) # [B, 1, out_H, out_W] 14 | select_idx = torch.cat([height_indices, width_indices], dim = 1) 15 | return select_idx 16 | 17 | def fast_index(inputs, idx): 18 | """ 19 | Input: 20 | inputs: input data, [B, C, H, W] (or [B, H, W]) 21 | idx: sample index data, [B, 2, h, w] 22 | Return: 23 | outputs: indexed data, [B, C, h, w] (or [B, h, w]) 24 | """ 25 | if len(inputs.shape) == 4: 26 | B, C, H, W = inputs.shape 27 | _, _, h, w = idx.shape 28 | neighbor_idx = idx[:, 0] * W + idx[:, 1] # (B, h, w) 29 | neighbor_idx = neighbor_idx.reshape(B, 1, h*w) # (B, 1, h*w) 30 | inputs_bcn = inputs.reshape(B, C, H*W) # (B, C, H*W) 31 | gather_feat = torch.gather(inputs_bcn, 2, neighbor_idx.expand(-1, C, -1)) # [B, C, h*w] 32 | outputs = torch.reshape(gather_feat, [B, C, h, w]) 33 | else: 34 | B, H, W = inputs.shape 35 | _, _, h, w = idx.shape 36 | neighbor_idx = idx[:, 0] * W + idx[:, 1] # (B, h, w) 37 | neighbor_idx = neighbor_idx.reshape(B, h*w) # (B, h*w) 38 | inputs_bcn = inputs.reshape(B, H*W) # (B, H*W) 39 | gather_feat = torch.gather(inputs_bcn, 1, neighbor_idx) # [B, h*w] 40 | outputs = torch.reshape(gather_feat, [B, h, w]) 41 | 42 | return outputs 43 | 44 | def stride_sample_gather(pc1, pc2, label, mask, stride_H_list, stride_W_list): 45 | """ 46 | Input: 47 | pc1: input points data, [B, 3, H, W] 48 | pc2: input points data, [B, 3, H, W] 49 | mask: [B, H, W] 50 | label: [B, 3, H, W] 51 | stride_H_list, stride_W_list: lists of sampling strides 52 | Return: 53 | xyzs1, xyzs2: pyramids of sampled points; labels[1:], masks[1:]: sampled labels and masks 54 | """ 55 | B = pc1.shape[0] 56 | H_list = [pc1.shape[2]]; W_list = [pc1.shape[3]] 57 | 58 | xyzs1 = [pc1]; xyzs2 = [pc2]; labels = [label]; masks = [mask] 59 | 60 | for s_h, s_w in zip(stride_H_list, stride_W_list): 61 | H_list.append(H_list[-1] // s_h) 62 | W_list.append(W_list[-1] // s_w) 63 | idx = get_selected_idx(B, H_list[-1], W_list[-1], s_h, s_w) 64 | xyzs1.append(fast_index(xyzs1[-1], idx)) 65 | xyzs2.append(fast_index(xyzs2[-1], idx)) 66 | labels.append(fast_index(labels[-1],
idx)) 67 | masks.append(fast_index(masks[-1], idx)) 68 | 69 | return xyzs1, xyzs2, labels[1:], masks[1:] 70 | 71 | class KnnUpsampler3D(nn.Module): 72 | def __init__(self, stride_h, stride_w, ks = [10, 20], dist = 100.0, k=3) -> None: 73 | super().__init__() 74 | self.k = k 75 | self.stride_h = stride_h 76 | self.stride_w = stride_w 77 | self.dist = dist 78 | self.ks = ks 79 | 80 | @torch.no_grad() 81 | def knn_grouping(self, query_xyz, input_xyz): 82 | """ 83 | :param query_xyz: [batch_size, 3, H, W] 84 | :param input_xyz: [batch_size, 3, h, w] 85 | :return grouped idx: [batch_size, H*W, k] 86 | """ 87 | B, C, h, w = input_xyz.shape 88 | _, _, H, W = query_xyz.shape 89 | n_sampled = H * W 90 | 91 | assert H // h == self.stride_h and W // w == self.stride_w, "size mismatch" 92 | 93 | idx_hw = get_hw_idx(B, H, W).contiguous() 94 | random_HW = torch.arange(0, self.ks[0] * self.ks[1], device = "cuda", dtype = torch.int) 95 | input_xyz_hw3 = input_xyz.permute(0, 2, 3, 1).contiguous() # [B, H, W, 3] 96 | query_xyz_hw3 = query_xyz.permute(0, 2, 3, 1).contiguous() # [B, H, W, 3] 97 | 98 | # Initialize 99 | select_b_idx = torch.zeros(B, n_sampled, self.k, 1, device = 'cuda').long().detach() # (B N nsample_q 1) 100 | select_h_idx = torch.zeros(B, n_sampled, self.k, 1, device = 'cuda').long().detach() 101 | select_w_idx = torch.zeros(B, n_sampled, self.k, 1, device = 'cuda').long().detach() 102 | valid_mask = torch.zeros(B, n_sampled, self.k, 1, device = 'cuda').float().detach() 103 | 104 | # with torch.no_grad(): 105 | # Sample QNN of (M neighbour points from sampled n points in PC1) in PC2 106 | select_b_idx, select_h_idx, select_w_idx, valid_mask = no_sort_knn\ 107 | (query_xyz_hw3, input_xyz_hw3, idx_hw, random_HW, H, W, n_sampled, self.ks[0], self.ks[1],\ 108 | self.k, 1, self.dist, self.stride_h, self.stride_w, select_b_idx, select_h_idx, select_w_idx, valid_mask) 109 | 110 | neighbor_idx = select_h_idx * w + select_w_idx # [B, H*W, k, 1] 111 | 112 | return neighbor_idx.squeeze(-1), valid_mask.squeeze(-1) 113 | # return neighbor_idx.squeeze(-1), valid_mask.squeeze(-1) 114 | 115 | def forward(self, query_xyz, input_xyz, input_features): 116 | """ 117 | :param input_xyz: 3D locations of input points, [B, 3, h, w] 118 | :param input_features: features of input points, [B, C, h, w] 119 | :param query_xyz: 3D locations of query points, [B, 3, H, W] 120 | :param k: k-nearest neighbor, int 121 | :return interpolated features: [B, C, H, W] 122 | """ 123 | B, _, H, W = query_xyz.shape 124 | 125 | 126 | # knn_indices: [B, H*W, 3] 127 | knn_indices, valid_knn_mask = self.knn_grouping(query_xyz, input_xyz) 128 | knn_xyz = mask_batch_selecting(input_xyz, knn_indices, valid_knn_mask) # [B, 3, H*W, 3] 129 | query_xyz = query_xyz.view(B, 3, H*W) 130 | knn_dists = torch.linalg.norm(knn_xyz - query_xyz[..., None], dim = 1).clamp(1e-8) 131 | # knn_weights: [B, H*W, 3] 132 | knn_weights = 1.0 / knn_dists 133 | knn_weights = knn_weights / torch.sum(knn_weights, dim = -1, keepdim = True) 134 | knn_features = mask_batch_selecting(input_features, knn_indices, valid_knn_mask) # [B, C, H*W, 3] 135 | 136 | # interpolated: [B, C, H*W] 137 | interpolated = torch.sum(knn_features * knn_weights[:, None, :, :], dim=-1) 138 | 139 | interpolated = interpolated.view(B, -1, H, W) 140 | return interpolated 141 | 142 | class FeaturePyramid3D(nn.Module): 143 | def __init__(self, n_channels, norm=None, k=16, ks = [10, 20]): 144 | super().__init__() 145 | 146 | self.mlps = nn.ModuleList([MLP2d(3, [n_channels[0], n_channels[0]])]) 147 
| self.convs = nn.ModuleList([PointConv(n_channels[0], n_channels[0], norm=norm, k=k, ks = ks)]) 148 | 149 | for i in range(1, len(n_channels)): 150 | self.mlps.append(MLP2d(n_channels[i - 1], [n_channels[i - 1], n_channels[i]])) 151 | self.convs.append( 152 | PointConv(n_channels[i], n_channels[i], norm=norm, k=k, ks = ks) 153 | ) 154 | 155 | def forward(self, xyzs): 156 | """ 157 | :param xyzs: pyramid of points 158 | :return feats: pyramid of features 159 | """ 160 | assert len(xyzs) == len(self.mlps) + 1 161 | 162 | input_feat = xyzs[0] # [bs, 3, h, w] 163 | # input_feat = self.level0_mlp(inputs) 164 | feats = [] 165 | 166 | for i in range(len(xyzs) - 1): 167 | if i == 0: 168 | feat = self.mlps[i](input_feat) 169 | else: 170 | feat = self.mlps[i](feats[-1]) 171 | 172 | feat = self.convs[i](xyzs[i], feat, xyzs[i + 1]) 173 | feats.append(feat) 174 | 175 | return feats 176 | 177 | class FeaturePyramid3DS(nn.Module): 178 | def __init__(self, n_channels, norm=None, k=16): 179 | super().__init__() 180 | 181 | self.level0_mlp = MLP1d(3, [n_channels[0], n_channels[0]]) 182 | 183 | self.pyramid_mlps = nn.ModuleList() 184 | self.pyramid_convs = nn.ModuleList() 185 | 186 | for i in range(len(n_channels) - 1): 187 | self.pyramid_mlps.append(MLP1d(n_channels[i], [n_channels[i], n_channels[i + 1]])) 188 | self.pyramid_convs.append(PointConvS(n_channels[i + 1], n_channels[i + 1], norm=norm, k=k)) 189 | 190 | def forward(self, xyzs): 191 | """ 192 | :param xyzs: pyramid of points 193 | :return feats: pyramid of features 194 | """ 195 | assert len(xyzs) == len(self.pyramid_mlps) + 1 196 | 197 | inputs = xyzs[0] # [bs, 3, n_points] 198 | feats = [self.level0_mlp(inputs)] 199 | 200 | for i in range(len(xyzs) - 1): 201 | feat = self.pyramid_mlps[i](feats[-1]) 202 | feats.append(self.pyramid_convs[i](xyzs[i], feat, xyzs[i + 1])) 203 | 204 | return feats 205 | 206 | class Costvolume3D(nn.Module): 207 | def __init__(self, in_channels, out_channels, ks = [10, 20], dist = 100.0, k=16): 208 | super().__init__() 209 | 210 | self.k = k 211 | self.ks = ks 212 | self.dist = dist 213 | self.cost_mlp = MLP2d(3 + 2 * in_channels, [out_channels, out_channels], act='leaky_relu') 214 | self.weight_net1 = MLP2d(3, [8, 8, out_channels], act='relu') 215 | self.weight_net2 = MLP2d(3, [8, 8, out_channels], act='relu') 216 | 217 | 218 | def forward(self, xyz1, feat1, xyz2, feat2, idx_fetching=None, knn_indices_1in1=None, valid_mask_1in1 = None): 219 | """ 220 | :param xyz1: [batch_size, 3, H, W] 221 | :param feat1: [batch_size, in_channels, H, W] 222 | :param xyz2: [batch_size, 3, H, W] 223 | :param feat2: [batch_size, in_channels, H, W] 224 | :param warping idx: for each warped point in xyz1, find its position in xyz2, [batch_size, H * W, 2] 225 | :return cost volume: [batch_size, n_cost_channels, H, W] 226 | """ 227 | B, C, H, W = feat1.shape 228 | feat1 = feat1.view(B, C, H * W) 229 | 230 | knn_indices_1in2, valid_mask_1in2 = knn_grouping_2d(query_xyz=xyz1, input_xyz=xyz2, k=self.k, idx_fetching = idx_fetching) 231 | # knn_xyz2: [B, 3, H*W, k], 232 | knn_xyz2 = mask_batch_selecting(xyz2, knn_indices_1in2, valid_mask_1in2) 233 | # knn_xyz2_norm: [B, 3, H*W, k] 234 | knn_xyz2_norm = knn_xyz2 - xyz1.view(B, 3, H * W, 1) 235 | # knn_features2: [B, C, H*W, k] 236 | knn_features2 = mask_batch_selecting(feat2, knn_indices_1in2, valid_mask_1in2) 237 | # features1_expand: [B, C, H*W, k] 238 | features1_expand = feat1[:, :, :, None].expand(B, C, H * W, self.k) 239 | # concatenated_features: [B, 2C+3, H*W, k] 240 | 
concatenated_features = torch.cat([features1_expand, knn_features2, knn_xyz2_norm], dim=1) 241 | # p2p_cost (point-to-point cost): [B, out_channels, H*W, k] 242 | p2p_cost = self.cost_mlp(concatenated_features) 243 | 244 | # weights2: [B, out_channels, H*W, k] 245 | weights2 = self.weight_net2(knn_xyz2_norm) 246 | # p2n_cost (point-to-neighbor cost): [B, out_channels, H * W] 247 | p2n_cost = torch.sum(weights2 * p2p_cost, dim=3) 248 | 249 | if knn_indices_1in1 is not None: 250 | assert knn_indices_1in1.shape[:2] == torch.Size([B, H * W]) 251 | assert knn_indices_1in1.shape[2] >= self.k 252 | knn_indices_1in1 = knn_indices_1in1[:, :, :self.k] 253 | valid_mask_1in1 = valid_mask_1in1[:, :, :self.k] 254 | else: 255 | knn_indices_1in1, valid_mask_1in1 = knn_grouping_2d(query_xyz=xyz1, input_xyz=xyz1, k=self.k) # [bs, n_points, k] 256 | 257 | # knn_xyz1: [B, 3, H*W, k] 258 | knn_xyz1 = mask_batch_selecting(xyz1, knn_indices_1in1, valid_mask_1in1) 259 | # knn_xyz1_norm: [B, 3, H*W, k] 260 | knn_xyz1_norm = knn_xyz1 - xyz1.view(B, 3, H * W, 1) 261 | # weights1: [B, out_channels, H*W, k] 262 | weights1 = self.weight_net1(knn_xyz1_norm) 263 | # n2n_cost: [B, out_channels, H*W, k] 264 | n2n_cost = mask_batch_selecting(p2n_cost, knn_indices_1in1, valid_mask_1in1) 265 | # n2n_cost: [B, out_channels, H * W] 266 | n2n_cost = torch.sum(weights1 * n2n_cost, dim=3) 267 | 268 | n2n_cost = n2n_cost.view(B, -1, H, W) 269 | 270 | return n2n_cost 271 | 272 | class Costvolume3DS(nn.Module): 273 | def __init__(self, in_channels, out_channels, align_channels=None, k=16): 274 | super().__init__() 275 | self.k = k 276 | 277 | self.cost_mlp = MLP2d(3 + 2 * in_channels, [out_channels, out_channels], act='leaky_relu') 278 | self.weight_net1 = MLP2d(3, [8, 8, out_channels], act='relu') 279 | self.weight_net2 = MLP2d(3, [8, 8, out_channels], act='relu') 280 | 281 | if align_channels is not None: 282 | self.feat_aligner = Conv1dNormRelu(out_channels, align_channels) 283 | else: 284 | self.feat_aligner = nn.Identity() 285 | 286 | def forward(self, xyz1, feat1, xyz2, feat2, knn_indices_1in1=None): 287 | """ 288 | :param xyz1: [batch_size, 3, n_points] 289 | :param feat1: [batch_size, in_channels, n_points] 290 | :param xyz2: [batch_size, 3, n_points] 291 | :param feat2: [batch_size, in_channels, n_points] 292 | :param knn_indices_1in1: for each point in xyz1, find its neighbors in xyz1, [batch_size, n_points, k] 293 | :return cost volume for each point in xyz1: [batch_size, n_cost_channels, n_points] 294 | """ 295 | batch_size, in_channels, n_points = feat1.shape 296 | 297 | # Step1: for each point in xyz1, find its neighbors in xyz2 298 | knn_indices_1in2 = k_nearest_neighbor(input_xyz=xyz2, query_xyz=xyz1, k=self.k) 299 | # knn_xyz2: [bs, 3, n_points, k] 300 | knn_xyz2 = batch_indexing(xyz2, knn_indices_1in2) 301 | # knn_xyz2_norm: [bs, 3, n_points, k] 302 | knn_xyz2_norm = knn_xyz2 - xyz1.view(batch_size, 3, n_points, 1) 303 | # knn_features2: [bs, in_channels, n_points, k] 304 | knn_features2 = batch_indexing(feat2, knn_indices_1in2) 305 | # features1_expand: [bs, in_channels, n_points, k] 306 | features1_expand = feat1[:, :, :, None].expand(batch_size, in_channels, n_points, self.k) 307 | # concatenated_features: [bs, 2 * in_channels + 3, n_points, k] 308 | concatenated_features = torch.cat([features1_expand, knn_features2, knn_xyz2_norm], dim=1) 309 | # p2p_cost (point-to-point cost): [bs, out_channels, n_points, k] 310 | p2p_cost = self.cost_mlp(concatenated_features) 311 | 312 | # weights2: [bs, out_channels, 
n_points, k] 313 | weights2 = self.weight_net2(knn_xyz2_norm) 314 | # p2n_cost (point-to-neighbor cost): [bs, out_channels, n_points] 315 | p2n_cost = torch.sum(weights2 * p2p_cost, dim=3) 316 | 317 | # Step2: for each point in xyz1, find its neighbors in xyz1 318 | if knn_indices_1in1 is not None: 319 | assert knn_indices_1in1.shape[:2] == torch.Size([batch_size, n_points]) 320 | assert knn_indices_1in1.shape[2] >= self.k 321 | knn_indices_1in1 = knn_indices_1in1[:, :, :self.k] 322 | else: 323 | knn_indices_1in1 = k_nearest_neighbor(input_xyz=xyz1, query_xyz=xyz1, k=self.k) # [bs, n_points, k] 324 | # knn_xyz1: [bs, 3, n_points, k] 325 | knn_xyz1 = batch_indexing(xyz1, knn_indices_1in1) 326 | # knn_xyz1_norm: [bs, 3, n_points, k] 327 | knn_xyz1_norm = knn_xyz1 - xyz1.view(batch_size, 3, n_points, 1) 328 | 329 | # weights1: [bs, out_channels, n_points, k] 330 | weights1 = self.weight_net1(knn_xyz1_norm) 331 | # n2n_cost: [bs, out_channels, n_points, k] 332 | n2n_cost = batch_indexing(p2n_cost, knn_indices_1in1) 333 | # n2n_cost (neighbor-to-neighbor cost): [bs, out_channels, n_points] 334 | n2n_cost = torch.sum(weights1 * n2n_cost, dim=3) 335 | # align features (optional) 336 | n2n_cost = self.feat_aligner(n2n_cost) 337 | 338 | return n2n_cost 339 | 340 | class FlowPredictor3D(nn.Module): 341 | def __init__(self, n_channels, norm=None, conv_last=True, k=16): 342 | super().__init__() 343 | self.point_conv1 = PointConv(in_channels=n_channels[0], out_channels=n_channels[1], norm=norm, k=k) 344 | self.point_conv2 = PointConv(in_channels=n_channels[1], out_channels=n_channels[2], norm=norm, k=k) 345 | self.mlp = MLP2d(n_channels[2], [n_channels[2], n_channels[3]]) 346 | self.flow_feat_dim = n_channels[3] 347 | 348 | if conv_last: 349 | self.conv_last = nn.Conv2d(n_channels[3], 3, kernel_size=1) 350 | else: 351 | self.conv_last = None 352 | 353 | def forward(self, xyz, feat, knn_indices, valid_knn_mask): 354 | """ 355 | :param xyz: 3D locations of points, [B, 3, H, W] 356 | :param feat: features of points, [B, in_channels, H, W] 357 | :return flow_feat: [B, 64, H, W] 358 | :return flow: [B, 3, H, W] 359 | """ 360 | feat = self.point_conv1(xyz, feat, knn_indices = knn_indices, valid_knn_mask = valid_knn_mask) # [bs, 128, H, W] 361 | feat = self.point_conv2(xyz, feat, knn_indices = knn_indices, valid_knn_mask = valid_knn_mask) # [bs, 128, H, W] 362 | feat = self.mlp(feat) # [bs, 64, H, W] 363 | 364 | if self.conv_last is not None: 365 | flow = self.conv_last(feat) # [bs, 3, H, W] 366 | return feat, flow 367 | else: 368 | return feat 369 | 370 | 371 | class FlowPredictor3DS(nn.Module): 372 | def __init__(self, n_channels, norm=None, conv_last=False, k=16): 373 | super().__init__() 374 | self.point_conv1 = PointConvS(in_channels=n_channels[0], out_channels=n_channels[1], norm=norm, k=k) 375 | self.point_conv2 = PointConvS(in_channels=n_channels[1], out_channels=n_channels[2], norm=norm, k=k) 376 | self.mlp = MLP1d(n_channels[2], [n_channels[2], n_channels[3]]) 377 | self.flow_feat_dim = n_channels[3] 378 | 379 | if conv_last: 380 | self.conv_last = nn.Conv1d(n_channels[3], 3, kernel_size=1) 381 | else: 382 | self.conv_last = None 383 | 384 | def forward(self, xyz, feat, knn_indices, mask): 385 | """ 386 | :param xyz: 3D locations of points, [batch_size, 3, n_points] 387 | :param feat: features of points, [batch_size, in_channels, n_points] 388 | :param knn_indices: knn indices of points, [batch_size, n_points, k] 389 | :return flow_feat: [batch_size, 64, n_points] 390 | :return flow: [batch_size, 3, 
n_points] 391 | """ 392 | feat = self.point_conv1(xyz, feat, knn_indices=knn_indices) 393 | feat = self.point_conv2(xyz, feat, knn_indices=knn_indices) 394 | feat = self.mlp(feat) 395 | 396 | if self.conv_last is not None: 397 | flow = self.conv_last(feat) 398 | return feat, flow 399 | else: 400 | return feat -------------------------------------------------------------------------------- /models/camlipwc_core.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.functional import leaky_relu, interpolate 4 | from .camlipwc2d_core import FeaturePyramid2D, FlowEstimatorDense2D, ContextNetwork2D 5 | from .camlipwc3d_core import FeaturePyramid3D, FeaturePyramid3DS, Costvolume3D, Costvolume3DS, FlowPredictor3D, FlowPredictor3DS, KnnUpsampler3D 6 | from .utils import Conv1dNormRelu, Conv2dNormRelu 7 | from .utils import backwarp_2d, backwarp_3d, mesh_grid, forwarp_3d, knn_interpolation, convex_upsample, project_3d_to_2d, knn_grouping_2d, mask_batch_selecting 8 | from .csrc import correlation2d, k_nearest_neighbor 9 | from .fusion_module import GlobalFuser, GlobalFuserS 10 | 11 | class CamLiPWC_Core(nn.Module): 12 | def __init__(self, cfgs2d, cfgs3d, dense=False, debug=False): 13 | super().__init__() 14 | 15 | self.cfgs2d, self.cfgs3d, self.debug, self.dense = cfgs2d, cfgs3d, debug, dense 16 | corr_channels_2d = (2 * cfgs2d.max_displacement + 1) ** 2 17 | 18 | ## PWC-Net 2D (IRR-PWC) 19 | self.branch_2d_fnet = FeaturePyramid2D( 20 | [3, 16, 32, 64, 96, 128, 192], 21 | norm=cfgs2d.norm.feature_pyramid 22 | ) 23 | self.branch_2d_fnet_aligners = nn.ModuleList([ 24 | nn.Identity(), 25 | Conv2dNormRelu(32, 64), 26 | Conv2dNormRelu(64, 64), 27 | Conv2dNormRelu(96, 64), 28 | Conv2dNormRelu(128, 64), 29 | Conv2dNormRelu(192, 64), 30 | ]) 31 | self.branch_2d_flow_estimator = FlowEstimatorDense2D( 32 | [64 + corr_channels_2d + 2 + 32, 128, 128, 96, 64, 32], 33 | norm=cfgs2d.norm.flow_estimator, 34 | conv_last=False, 35 | ) 36 | self.branch_2d_context_network = ContextNetwork2D( 37 | [self.branch_2d_flow_estimator.flow_feat_dim + 2, 128, 128, 128, 96, 64, 32], 38 | dilations=[1, 2, 4, 8, 16, 1], 39 | norm=cfgs2d.norm.context_network 40 | ) 41 | self.branch_2d_up_mask_head = nn.Sequential( # for convex upsampling (see RAFT) 42 | nn.Conv2d(32, 256, kernel_size=3, stride=1, padding=1), 43 | nn.ReLU(inplace=True), 44 | nn.Conv2d(256, 4 * 4 * 9, kernel_size=1, stride=1, padding=0), 45 | ) 46 | self.branch_2d_conv_last = nn.Conv2d(self.branch_2d_flow_estimator.flow_feat_dim, 2, kernel_size=3, stride=1, padding=1) 47 | 48 | if dense: 49 | ## PWC-Net 3D (Point-PWC) 50 | self.branch_3d_fnet = FeaturePyramid3D( 51 | [16, 32, 64, 96, 128, 192], 52 | norm=cfgs3d.norm.feature_pyramid, 53 | k=cfgs3d.k, 54 | ks=cfgs3d.kernel_size 55 | ) 56 | self.branch_3d_fnet_aligners = nn.ModuleList([ 57 | nn.Identity(), 58 | Conv2dNormRelu(32, 64), 59 | Conv2dNormRelu(64, 64), 60 | Conv2dNormRelu(96, 64), 61 | Conv2dNormRelu(128, 64), 62 | Conv2dNormRelu(192, 64), 63 | ]) 64 | self.branch_3d_correlations = nn.ModuleList([ 65 | nn.Identity(), 66 | Costvolume3D(32, 32, k=self.cfgs3d.k), 67 | Costvolume3D(64, 64, k=self.cfgs3d.k), 68 | Costvolume3D(96, 96, k=self.cfgs3d.k), 69 | Costvolume3D(128, 128, k=self.cfgs3d.k), 70 | Costvolume3D(192, 192, k=self.cfgs3d.k), 71 | ]) 72 | self.branch_3d_correlation_aligners = nn.ModuleList([ 73 | nn.Identity(), 74 | Conv2dNormRelu(32, 64), 75 | Conv2dNormRelu(64, 64), 76 | Conv2dNormRelu(96, 64), 77 | 
Conv2dNormRelu(128, 64), 78 | Conv2dNormRelu(192, 64), 79 | ]) 80 | self.branch_3d_flow_estimator = FlowPredictor3D( 81 | [64 + 64 + 3 + 64, 128, 128, 64], 82 | cfgs3d.norm.flow_estimator, 83 | conv_last=False, 84 | k=self.cfgs3d.k, 85 | ) 86 | 87 | ## Bi-CLFM for pyramid features 88 | self.pyramid_feat_fusers = nn.ModuleList([ 89 | nn.Identity(), 90 | GlobalFuser(32, 32, norm=cfgs2d.norm.feature_pyramid), 91 | GlobalFuser(64, 64, norm=cfgs2d.norm.feature_pyramid), 92 | GlobalFuser(96, 96, norm=cfgs2d.norm.feature_pyramid), 93 | GlobalFuser(128, 128, norm=cfgs2d.norm.feature_pyramid), 94 | GlobalFuser(192, 192, norm=cfgs2d.norm.feature_pyramid), 95 | ]) 96 | 97 | ## Bi-CLFM for correlation features 98 | self.corr_feat_fusers = nn.ModuleList([ 99 | nn.Identity(), 100 | GlobalFuser(corr_channels_2d, 32), 101 | GlobalFuser(corr_channels_2d, 64), 102 | GlobalFuser(corr_channels_2d, 96), 103 | GlobalFuser(corr_channels_2d, 128), 104 | GlobalFuser(corr_channels_2d, 192), 105 | ]) 106 | 107 | self.intepolate_3d = KnnUpsampler3D(stride_h = 2, stride_w = 2, k = 3) 108 | self.knn_upconv = KnnUpsampler3D(stride_h = 4, stride_w = 4, k = 3) 109 | 110 | ## Bi-CLFM for decoder features 111 | self.estimator_feat_fuser = GlobalFuser(self.branch_2d_flow_estimator.flow_feat_dim, self.branch_3d_flow_estimator.flow_feat_dim) 112 | 113 | self.branch_3d_conv_last = nn.Conv2d(self.branch_3d_flow_estimator.flow_feat_dim, 3, kernel_size=1) 114 | 115 | else: 116 | self.branch_3d_fnet = FeaturePyramid3DS( 117 | n_channels=[16, 32, 64, 96, 128, 192], # 1/4 118 | norm=cfgs3d.norm.feature_pyramid, 119 | k=cfgs3d.k, 120 | ) 121 | self.branch_3d_fnet_aligners = nn.ModuleList([ 122 | nn.Identity(), 123 | Conv1dNormRelu(32, 64), # 1/4 124 | Conv1dNormRelu(64, 64), 125 | Conv1dNormRelu(96, 64), 126 | Conv1dNormRelu(128, 64), 127 | Conv1dNormRelu(192, 64), 128 | ]) 129 | self.branch_3d_correlations = nn.ModuleList([ 130 | nn.Identity(), 131 | Costvolume3DS(32, 32, k=self.cfgs3d.k), # 1/4 132 | Costvolume3DS(64, 64, k=self.cfgs3d.k), 133 | Costvolume3DS(96, 96, k=self.cfgs3d.k), 134 | Costvolume3DS(128, 128, k=self.cfgs3d.k), 135 | Costvolume3DS(192, 192, k=self.cfgs3d.k), 136 | ]) 137 | self.branch_3d_correlation_aligners = nn.ModuleList([ 138 | nn.Identity(), 139 | Conv1dNormRelu(32, 64), # 1/4 140 | Conv1dNormRelu(64, 64), 141 | Conv1dNormRelu(96, 64), 142 | Conv1dNormRelu(128, 64), 143 | Conv1dNormRelu(192, 64), 144 | ]) 145 | self.branch_3d_flow_estimator = FlowPredictor3DS( 146 | [64 + 64 + 3 + 64, 128, 128, 64], 147 | cfgs3d.norm.flow_estimator, 148 | k=self.cfgs3d.k, 149 | ) 150 | 151 | 152 | self.pyramid_feat_fusers = nn.ModuleList([ 153 | nn.Identity(), 154 | GlobalFuserS(32, 32, norm=cfgs2d.norm.feature_pyramid), # 1/4 155 | GlobalFuserS(64, 64, norm=cfgs2d.norm.feature_pyramid), 156 | GlobalFuserS(96, 96, norm=cfgs2d.norm.feature_pyramid), 157 | GlobalFuserS(128, 128, norm=cfgs2d.norm.feature_pyramid), 158 | GlobalFuserS(192, 192, norm=cfgs2d.norm.feature_pyramid), 159 | ]) 160 | 161 | self.corr_feat_fusers = nn.ModuleList([ 162 | nn.Identity(), 163 | GlobalFuserS(corr_channels_2d, 32), # 1/4 164 | GlobalFuserS(corr_channels_2d, 64), 165 | GlobalFuserS(corr_channels_2d, 96), 166 | GlobalFuserS(corr_channels_2d, 128), 167 | GlobalFuserS(corr_channels_2d, 192), 168 | ]) 169 | 170 | self.estimator_feat_fuser = GlobalFuserS(self.branch_2d_flow_estimator.flow_feat_dim, self.branch_3d_flow_estimator.flow_feat_dim) 171 | self.branch_3d_conv_last = nn.Conv1d(self.branch_3d_flow_estimator.flow_feat_dim, 3, kernel_size=1) 
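# Note: decode() below proceeds coarse-to-fine over the feature pyramids; at each level the image (2D) branch and the point (3D) branch exchange information at three stages: pyramid features (pyramid_feat_fusers), correlation features (corr_feat_fusers), and decoder features (estimator_feat_fuser), i.e. the Bi-CLFM fusion named in the comments above, so the two flow fields are estimated jointly rather than refined independently.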
172 | 173 | 174 | def encode(self, image, xyzs): 175 | feats_2d = self.branch_2d_fnet(image) 176 | feats_3d = self.branch_3d_fnet(xyzs) 177 | return feats_2d, feats_3d 178 | 179 | def decode(self, xyzs1, xyzs2, feats1_2d, feats2_2d, feats1_3d, feats2_3d, raw_pc1, raw_pc2, camera_info): 180 | assert len(xyzs1) == len(xyzs2) == len(feats1_2d) == len(feats2_2d) == len(feats1_3d) == len(feats2_3d) 181 | 182 | flows_2d, flows_3d, flow_feats_2d, flow_feats_3d = [], [], [], [] 183 | for level in range(len(xyzs1) - 1, 0, -1): 184 | xyz1, feat1_2d, feat1_3d = xyzs1[level], feats1_2d[level], feats1_3d[level] 185 | xyz2, feat2_2d, feat2_3d = xyzs2[level], feats2_2d[level], feats2_3d[level] 186 | 187 | batch_size, image_h, image_w = feat1_2d.shape[0], feat1_2d.shape[2], feat1_2d.shape[3] 188 | if not self.dense: 189 | n_points = xyz1.shape[-1] 190 | 191 | # project point cloud to image 192 | uv1, uv_mask1 = project_3d_to_2d(xyz1, camera_info, image_h, image_w) 193 | uv2, uv_mask2 = project_3d_to_2d(xyz2, camera_info, image_h, image_w) 194 | 195 | # pre-compute knn indices 196 | if self.dense: 197 | knn_1in1, valid_mask_1in1 = knn_grouping_2d(xyz1, xyz1, k=self.cfgs3d.k) # [bs, n_points, k] 198 | else: 199 | grid = mesh_grid(batch_size, image_h, image_w, uv1.device) # [B, 2, H, W] 200 | grid = grid.reshape([batch_size, 2, -1]) # [B, 2, HW] 201 | knn_1in1 = k_nearest_neighbor(xyz1, xyz1, k=self.cfgs3d.k) # [bs, n_points, k] 202 | valid_mask_1in1 = None 203 | 204 | # fuse pyramid features 205 | feat1_2d, feat1_3d = self.pyramid_feat_fusers[level](uv1, uv_mask1, feat1_2d, feat1_3d) 206 | feat2_2d, feat2_3d = self.pyramid_feat_fusers[level](uv2, uv_mask2, feat2_2d, feat2_3d) 207 | 208 | 209 | if level == len(xyzs1) - 1: 210 | last_flow_2d = torch.zeros([batch_size, 2, image_h, image_w], dtype=xyz1.dtype, device=xyz1.device) 211 | last_flow_feat_2d = torch.zeros([batch_size, 32, image_h, image_w], dtype=xyz1.dtype, device=xyz1.device) 212 | 213 | if self.dense: 214 | last_flow_3d = torch.zeros([batch_size, 3, image_h, image_w], dtype=xyz1.dtype, device=xyz1.device) 215 | last_flow_feat_3d = torch.zeros([batch_size, 64, image_h, image_w], dtype=xyz1.dtype, device=xyz1.device) 216 | xyz1_warp, feat2_2d_warp = xyz1, feat2_2d 217 | warping_idx = None 218 | else: 219 | last_flow_3d = torch.zeros([batch_size, 3, n_points], dtype=xyz1.dtype, device=xyz1.device) 220 | last_flow_feat_3d = torch.zeros([batch_size, 64, n_points], dtype=xyz1.dtype, device=xyz1.device) 221 | xyz2_warp, feat2_2d_warp = xyz2, feat2_2d 222 | else: 223 | # upsample 2d flow and backwarp 224 | last_flow_2d = interpolate(flows_2d[-1] * 2, scale_factor=2, mode='bilinear', align_corners=True) 225 | last_flow_feat_2d = interpolate(flow_feats_2d[-1], scale_factor=2, mode='bilinear', align_corners=True) 226 | feat2_2d_warp = backwarp_2d(feat2_2d, last_flow_2d, padding_mode='border') 227 | 228 | if self.dense: 229 | # upsample 3d flow and backwarp 230 | flow_with_feat_3d = torch.cat([flows_3d[-1], flow_feats_3d[-1]], dim=1) 231 | flow_with_feat_upsampled_3d = self.intepolate_3d(xyz1, xyzs1[level + 1], flow_with_feat_3d) 232 | last_flow_3d, last_flow_feat_3d = torch.split(flow_with_feat_upsampled_3d, [3, 64], dim=1) 233 | xyz1_warp, warping_idx = forwarp_3d(xyz1, xyz2, last_flow_3d, camera_info) 234 | else: 235 | last_flow_3d, last_flow_feat_3d = torch.split( 236 | knn_interpolation( 237 | xyzs1[level + 1], 238 | torch.cat([flows_3d[-1], flow_feats_3d[-1]], dim=1), 239 | xyz1 240 | ), [3, 64], dim=1) 241 | xyz2_warp = backwarp_3d(xyz1, xyz2, 
last_flow_3d) 242 | 243 | # correlation (2D & 3D) 244 | if self.dense: 245 | feat_corr_3d = self.branch_3d_correlations[level](xyz1_warp, feat1_3d, xyz2, feat2_3d, warping_idx, knn_1in1, valid_mask_1in1) 246 | else: 247 | feat_corr_3d = self.branch_3d_correlations[level](xyz1, feat1_3d, xyz2_warp, feat2_3d, knn_1in1) 248 | feat_corr_2d = leaky_relu(correlation2d(feat1_2d, feat2_2d_warp, self.cfgs2d.max_displacement), 0.1) 249 | 250 | # fuse correlation features 251 | feat_corr_2d, feat_corr_3d = self.corr_feat_fusers[level](uv1, uv_mask1, feat_corr_2d, feat_corr_3d) 252 | 253 | # align features using 1x1 convolution 254 | feat1_2d = self.branch_2d_fnet_aligners[level](feat1_2d) 255 | feat1_3d = self.branch_3d_fnet_aligners[level](feat1_3d) 256 | feat_corr_3d = self.branch_3d_correlation_aligners[level](feat_corr_3d) 257 | 258 | # flow decoder (or estimator) 259 | x_2d = torch.cat([feat_corr_2d, feat1_2d, last_flow_2d, last_flow_feat_2d], dim=1) 260 | x_3d = torch.cat([feat_corr_3d, feat1_3d, last_flow_3d, last_flow_feat_3d], dim=1) 261 | flow_feat_2d = self.branch_2d_flow_estimator(x_2d) # [bs, flow_feat_dim, image_h, image_w] 262 | flow_feat_3d = self.branch_3d_flow_estimator(xyz1, x_3d, knn_1in1, valid_mask_1in1) # [bs, 64, H, W] (dense) or [bs, 64, n_points] (sparse) 263 | 264 | # fuse decoder features 265 | flow_feat_2d, flow_feat_3d = self.estimator_feat_fuser(uv1, uv_mask1, flow_feat_2d, flow_feat_3d) 266 | 267 | # flow prediction 268 | flow_delta_2d = self.branch_2d_conv_last(flow_feat_2d) 269 | flow_delta_3d = self.branch_3d_conv_last(flow_feat_3d) 270 | 271 | # residual connection 272 | flow_2d = last_flow_2d + flow_delta_2d 273 | flow_3d = last_flow_3d + flow_delta_3d 274 | 275 | # context network (2D only) 276 | flow_feat_2d, flow_delta_2d = self.branch_2d_context_network(torch.cat([flow_feat_2d, flow_2d], dim=1)) 277 | flow_2d = flow_delta_2d + flow_2d 278 | 279 | # clip 280 | flow_2d = torch.clip(flow_2d, min=-1000, max=1000) 281 | flow_3d = torch.clip(flow_3d, min=-100, max=100) 282 | 283 | # save results 284 | flows_2d.append(flow_2d) 285 | flows_3d.append(flow_3d) 286 | flow_feats_2d.append(flow_feat_2d) 287 | flow_feats_3d.append(flow_feat_3d) 288 | 289 | flows_2d = [f.float() for f in flows_2d][::-1] 290 | flows_3d = [f.float() for f in flows_3d][::-1] 291 | 292 | # convex upsampling module, from RAFT 293 | flows_2d[0] = convex_upsample(flows_2d[0], self.branch_2d_up_mask_head(flow_feats_2d[-1]), scale_factor=4) 294 | 295 | for i in range(1, len(flows_2d)): 296 | flows_2d[i] = interpolate(flows_2d[i] * 4, scale_factor=4, mode='bilinear', align_corners=True) 297 | 298 | for i in range(len(flows_3d)): 299 | if self.dense: 300 | if i == 0: 301 | flows_3d[i] = self.knn_upconv(raw_pc1, xyzs1[i + 1], flows_3d[i]) 302 | else: 303 | flows_3d[i] = self.knn_upconv(xyzs1[i - 1], xyzs1[i + 1], flows_3d[i]) 304 | else: 305 | flows_3d[i] = knn_interpolation(xyzs1[i + 1], flows_3d[i], xyzs1[i]) 306 | 307 | return flows_2d, flows_3d 308 | -------------------------------------------------------------------------------- /models/csrc/__init__.py: -------------------------------------------------------------------------------- 1 | from .wrapper import correlation2d, furthest_point_sampling, squared_distance, k_nearest_neighbor 2 | -------------------------------------------------------------------------------- /models/csrc/correlation/correlation.cpp: -------------------------------------------------------------------------------- 1 | #include "correlation.h" 2 | 3 | void correlation_forward_kernel_wrapper(float* output, const float*
input1, const float* input2, 4 | int n_batches, int in_channels, int height, int width, int max_displacement); 5 | 6 | void correlation_backward_kernel_wrapper(const float* grad_output, 7 | float* grad_input1, float* grad_input2, 8 | const float* input1, const float* input2, 9 | int n_batches, int in_channels, int height, int width, int max_displacement); 10 | 11 | torch::Tensor correlation_forward_cuda(torch::Tensor& input1, torch::Tensor& input2, int max_displacement) { 12 | TORCH_CHECK(input1.is_contiguous(), "input1 must be a contiguous tensor"); 13 | TORCH_CHECK(input2.is_contiguous(), "input2 must be a contiguous tensor"); 14 | 15 | int batch_size = input1.size(0), height = input1.size(1), width = input1.size(2), in_channels = input1.size(3); 16 | int out_channels = (max_displacement * 2 + 1) * (max_displacement * 2 + 1); 17 | torch::Tensor output = torch::zeros({batch_size, out_channels, height, width}, torch::device(input1.device())); 18 | 19 | correlation_forward_kernel_wrapper(output.data_ptr(), input1.data_ptr(), input2.data_ptr(), 20 | batch_size, in_channels, height, width, max_displacement); 21 | return output; 22 | } 23 | 24 | std::pair correlation_backward_cuda(torch::Tensor& grad_output, torch::Tensor& input1, torch::Tensor& input2, int max_displacement) { 25 | int batch_size = input1.size(0), height = input1.size(1), width = input1.size(2), in_channels = input1.size(3); 26 | torch::Tensor grad_input1 = torch::empty({batch_size, in_channels, height, width}, torch::device(input1.device())); 27 | torch::Tensor grad_input2 = torch::empty({batch_size, in_channels, height, width}, torch::device(input2.device())); 28 | 29 | correlation_backward_kernel_wrapper(grad_output.data_ptr(), 30 | grad_input1.data_ptr(), grad_input2.data_ptr(), 31 | input1.data_ptr(), input2.data_ptr(), 32 | batch_size, in_channels, height, width, max_displacement); 33 | 34 | return std::pair(grad_input1, grad_input2); 35 | } 36 | 37 | #ifdef TORCH_EXTENSION_NAME 38 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 39 | m.def("_correlation_forward_cuda", &correlation_forward_cuda, "Correlation forward pass (CUDA)"); 40 | m.def("_correlation_backward_cuda", &correlation_backward_cuda, "Correlation backward pass (CUDA)"); 41 | } 42 | #endif 43 | -------------------------------------------------------------------------------- /models/csrc/correlation/correlation.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | torch::Tensor correlation_forward_cuda(torch::Tensor& input1, torch::Tensor& input2, int max_displacement); 6 | std::pair correlation_backward_cuda(torch::Tensor& grad_output, torch::Tensor& input1, torch::Tensor& input2, int max_displacement); 7 | -------------------------------------------------------------------------------- /models/csrc/correlation/correlation_backward_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | __global__ void correlation_backward_input1_kernel(float* __restrict__ grad_input1, 5 | const float* __restrict__ grad_output, 6 | const float* __restrict__ input2, 7 | int in_channels, int height, int width, int max_displacement) { 8 | int n = blockIdx.x, y1 = blockIdx.y, x1 = blockIdx.z, c = threadIdx.x; 9 | 10 | int displacement_size = 2 * max_displacement + 1; 11 | int out_channels = displacement_size * displacement_size; 12 | 13 | int in_stride0 = height * width * in_channels; 14 | int in_stride1 = width * in_channels; 15 | int 
in_stride2 = in_channels; 16 | 17 | int out_stride0 = out_channels * height * width; 18 | int out_stride1 = height * width; 19 | int out_stride2 = width; 20 | 21 | int grad_in_stride0 = in_channels * height* width; 22 | int grad_in_stride1 = height * width; 23 | int grad_in_stride2 = width; 24 | 25 | float sum1 = 0; 26 | for (int tx = -max_displacement; tx <= max_displacement; ++tx) 27 | for (int ty = -max_displacement; ty <= max_displacement; ++ty) { 28 | int x2 = x1 + ty, y2 = y1 + tx; 29 | if (x2 < 0 || y2 < 0 || x2 >= width || y2 >= height) continue; 30 | int tc = (tx + max_displacement) * displacement_size + (ty + max_displacement); 31 | int idx = n * out_stride0 + tc * out_stride1 + y1 * out_stride2 + x1; 32 | int idx2 = n * in_stride0 + y2 * in_stride1 + x2 * in_stride2 + c; 33 | sum1 += grad_output[idx] * input2[idx2]; 34 | } 35 | 36 | int idx1 = n * grad_in_stride0 + c * grad_in_stride1 + y1 * grad_in_stride2 + x1; 37 | grad_input1[idx1] = sum1 / in_channels; 38 | } 39 | 40 | __global__ void correlation_backward_input2_kernel(float* __restrict__ grad_input2, 41 | const float* __restrict__ grad_output, 42 | const float* __restrict__ input1, 43 | int in_channels, int height, int width, int max_displacement) { 44 | int n = blockIdx.x, y2 = blockIdx.y, x2 = blockIdx.z, c = threadIdx.x; 45 | 46 | int displacement_size = 2 * max_displacement + 1; 47 | int out_channels = displacement_size * displacement_size; 48 | 49 | int in_stride0 = height * width * in_channels; 50 | int in_stride1 = width * in_channels; 51 | int in_stride2 = in_channels; 52 | 53 | int out_stride0 = out_channels * height * width; 54 | int out_stride1 = height * width; 55 | int out_stride2 = width; 56 | 57 | int grad_in_stride0 = in_channels * height* width; 58 | int grad_in_stride1 = height * width; 59 | int grad_in_stride2 = width; 60 | 61 | float sum2 = 0; 62 | for (int tx = -max_displacement; tx <= max_displacement; ++tx) 63 | for (int ty = -max_displacement; ty <= max_displacement; ++ty) { 64 | int x1 = x2 - ty, y1 = y2 - tx; 65 | if (x1 < 0 || y1 < 0 || x1 >= width || y1 >= height) continue; 66 | int tc = (tx + max_displacement) * displacement_size + (ty + max_displacement); 67 | int idx = n * out_stride0 + tc * out_stride1 + y1 * out_stride2 + x1; 68 | int idx1 = n * in_stride0 + y1 * in_stride1 + x1 * in_stride2 + c; 69 | sum2 += grad_output[idx] * input1[idx1]; 70 | } 71 | 72 | int idx2 = n * grad_in_stride0 + c * grad_in_stride1 + y2 * grad_in_stride2 + x2; 73 | grad_input2[idx2] = sum2 / in_channels; 74 | } 75 | 76 | void correlation_backward_kernel_wrapper(const float* grad_output, 77 | float* grad_input1, float* grad_input2, 78 | const float* input1, const float* input2, 79 | int n_batches, int in_channels, int height, int width, int max_displacement) { 80 | dim3 totalBlocksCorr(n_batches, height, width); 81 | correlation_backward_input1_kernel<<>>( 82 | grad_input1, grad_output, input2, 83 | in_channels, height, width, max_displacement 84 | ); 85 | correlation_backward_input2_kernel<<>>( 86 | grad_input2, grad_output, input1, 87 | in_channels, height, width, max_displacement 88 | ); 89 | } 90 | -------------------------------------------------------------------------------- /models/csrc/correlation/correlation_forward_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #define BLOCK_SIZE 32 4 | 5 | __forceinline__ __device__ float warpReduceSum(float val) { 6 | for (int offset = 16; offset > 0; offset /= 2) 7 | val += 
__shfl_down_sync(0xffffffff, val, offset); 8 | return val; 9 | } 10 | 11 | __global__ void correlation_forward_kernel(float* __restrict__ output, 12 | const float* __restrict__ input1, 13 | const float* __restrict__ input2, 14 | int in_channels, int height, int width, int max_displacement) { 15 | int n = blockIdx.x, y1 = blockIdx.y, x1 = blockIdx.z; 16 | 17 | int displacement_size = 2 * max_displacement + 1; 18 | int out_channels = displacement_size * displacement_size; 19 | 20 | int in_stride0 = height * width * in_channels; 21 | int in_stride1 = width * in_channels; 22 | int in_stride2 = in_channels; 23 | 24 | int out_stride0 = out_channels * height * width; 25 | int out_stride1 = height * width; 26 | int out_stride2 = width; 27 | 28 | for (int tx = -max_displacement; tx <= max_displacement; ++tx) 29 | for (int ty = -max_displacement; ty <= max_displacement; ++ty) { 30 | int x2 = x1 + ty, y2 = y1 + tx; 31 | if (x2 < 0 || y2 < 0 || x2 >= width || y2 >= height) continue; 32 | 33 | float sum = 0.0f; 34 | for (int c = threadIdx.x; c < in_channels; c += BLOCK_SIZE) { 35 | int idx1 = n * in_stride0 + y1 * in_stride1 + x1 * in_stride2 + c; 36 | int idx2 = n * in_stride0 + y2 * in_stride1 + x2 * in_stride2 + c; 37 | sum += input1[idx1] * input2[idx2]; 38 | } 39 | 40 | __syncthreads(); 41 | sum = warpReduceSum(sum); 42 | 43 | if (threadIdx.x == 0) { 44 | int tc = (tx + max_displacement) * displacement_size + (ty + max_displacement); 45 | int idx = n * out_stride0 + tc * out_stride1 + y1 * out_stride2 + x1; 46 | output[idx] = sum / in_channels; 47 | } 48 | } 49 | } 50 | 51 | void correlation_forward_kernel_wrapper(float* output, const float* input1, const float* input2, 52 | int n_batches, int in_channels, int height, int width, int max_displacement) { 53 | dim3 number_of_blocks(n_batches, height, width); 54 | correlation_forward_kernel<<>>(output, input1, input2, in_channels, height, width, max_displacement); 55 | } 56 | -------------------------------------------------------------------------------- /models/csrc/correlation/correlation_test.cpp: -------------------------------------------------------------------------------- 1 | #include "correlation.h" 2 | #include 3 | #include 4 | using namespace std; 5 | 6 | void _checkCudaErrors(cudaError_t result, char const *const func, const char *const file, int const line) { 7 | if (result != cudaSuccess) { 8 | fprintf(stderr, "CUDA error at %s:%d code=%d(%s) \"%s\" \n", file, line, 9 | static_cast(result), cudaGetErrorName(result), func); 10 | cudaDeviceReset(); 11 | exit(EXIT_FAILURE); 12 | } 13 | } 14 | #define checkCudaErrors(val) _checkCudaErrors((val), #val, __FILE__, __LINE__) 15 | 16 | std::vector run_cuda_implementation(torch::Tensor input1, torch::Tensor input2, torch::Tensor grad_output, int max_displacement) { 17 | // nchw -> nhwc 18 | input1 = input1.permute({0, 2, 3, 1}).contiguous(); 19 | input2 = input2.permute({0, 2, 3, 1}).contiguous(); 20 | 21 | torch::Tensor output = correlation_forward_cuda(input1, input2, max_displacement); 22 | std::pair grads_cuda = correlation_backward_cuda(grad_output, input1, input2, max_displacement); 23 | 24 | return {output, grads_cuda.first, grads_cuda.second}; 25 | } 26 | 27 | std::vector run_naive_implementation(torch::Tensor input1, torch::Tensor input2, torch::Tensor grad_output, int max_displacement) { 28 | int batch_size = input1.size(0), in_channels = input1.size(1), height = input1.size(2), width = input1.size(3); 29 | torch::Tensor padded_input2 = torch::nn::functional::detail::pad(input2, 
{max_displacement, max_displacement, max_displacement, max_displacement}, torch::kConstant, 0);
30 | 
31 | std::vector<torch::Tensor> cost_volumes;
32 | for (int i = 0; i < 2 * max_displacement + 1; i++)
33 | for (int j = 0; j < 2 * max_displacement + 1; j++) {
34 | torch::Tensor cost = input1 * padded_input2.slice(2, i, i + height).slice(3, j, j + width);
35 | cost_volumes.push_back(torch::mean(cost, 1, true));
36 | }
37 | 
38 | torch::Tensor output = torch::cat(cost_volumes, 1);
39 | output.backward(grad_output);
40 | 
41 | return {output, input1.grad(), input2.grad()};
42 | }
43 | 
44 | int main() {
45 | constexpr int batch_size = 32;
46 | constexpr int in_channels = 128;
47 | constexpr int height = 144;
48 | constexpr int width = 240;
49 | constexpr int max_displacement = 4;
50 | constexpr int out_channels = (max_displacement * 2 + 1) * (max_displacement * 2 + 1);
51 | 
52 | if (!torch::cuda::is_available()) {
53 | cout << "CUDA is not available, exiting..." << endl;
54 | return 1;
55 | }
56 | 
57 | torch::manual_seed(0);
58 | torch::Tensor input1 = torch::rand({batch_size, in_channels, height, width}, torch::requires_grad().device("cuda"));
59 | torch::Tensor input2 = torch::rand({batch_size, in_channels, height, width}, torch::requires_grad().device("cuda"));
60 | torch::Tensor grad_output = torch::rand({batch_size, out_channels, height, width}, torch::requires_grad(false).device("cuda"));
61 | 
62 | cout << "Running NAIVE implementation of correlation... " << flush;
63 | auto naive_t1 = chrono::high_resolution_clock::now();
64 | std::vector<torch::Tensor> results_naive = run_naive_implementation(input1, input2, grad_output, max_displacement);
65 | checkCudaErrors(cudaDeviceSynchronize());
66 | auto naive_t2 = chrono::high_resolution_clock::now();
67 | cout << "(" << chrono::duration_cast<chrono::milliseconds>(naive_t2 - naive_t1).count() << "ms)" << endl;
68 | 
69 | // warm up...
70 | for (int t = 0; t < 2; t++) {
71 | run_cuda_implementation(input1, input2, grad_output, max_displacement);
72 | checkCudaErrors(cudaDeviceSynchronize());
73 | }
74 | 
75 | cout << "Running CUDA implementation of correlation... " << flush;
76 | auto cuda_t1 = chrono::high_resolution_clock::now();
77 | std::vector<torch::Tensor> results_cuda = run_cuda_implementation(input1, input2, grad_output, max_displacement);
78 | checkCudaErrors(cudaDeviceSynchronize());
79 | auto cuda_t2 = chrono::high_resolution_clock::now();
80 | cout << "(" << chrono::duration_cast<chrono::milliseconds>(cuda_t2 - cuda_t1).count() << "ms)" << endl;
81 | 
82 | float diff = torch::mean(torch::abs(results_cuda[0].cpu() - results_naive[0].cpu())).data_ptr<float>()[0];
83 | cout << "Checking forward results... " << (diff < 1e-6 ? "OK" : "Failed") << endl;
84 | 
85 | diff = torch::mean(torch::abs(results_cuda[1].cpu() - results_naive[1].cpu())).data_ptr<float>()[0];
86 | cout << "Checking backward results for input1... " << (diff < 1e-6 ? "OK" : "Failed") << endl;
87 | 
88 | diff = torch::mean(torch::abs(results_cuda[2].cpu() - results_naive[2].cpu())).data_ptr<float>()[0];
89 | cout << "Checking backward results for input2... " << (diff < 1e-6 ? "OK" : "Failed") << endl;
90 | 
91 | return 0;
92 | }
93 | 
-------------------------------------------------------------------------------- /models/csrc/furthest_point_sampling/furthest_point_sampling.cpp: --------------------------------------------------------------------------------
1 | #include "furthest_point_sampling.h"
2 | 
3 | void furthest_point_sampling_kernel_wrapper(float* batched_points_xyz, float* batched_dists_temp, int n_batch, int n_points, int n_samples, int64_t* batched_furthest_indices);
4 | 
5 | torch::Tensor furthest_point_sampling_cuda(torch::Tensor points_xyz, const int n_samples) {
6 | TORCH_CHECK(points_xyz.is_contiguous(), "points_xyz must be a contiguous tensor");
7 | TORCH_CHECK(points_xyz.is_cuda(), "points_xyz must be a CUDA tensor");
8 | TORCH_CHECK(points_xyz.scalar_type() == torch::ScalarType::Float, "points_xyz must be a float tensor");
9 | 
10 | int64_t batch_size = points_xyz.size(0), n_points = points_xyz.size(1);
11 | torch::Tensor furthest_indices = torch::empty({ batch_size, n_samples }, torch::TensorOptions().dtype(torch::kInt64).device(points_xyz.device()));
12 | torch::Tensor dists_temp = torch::ones({batch_size, n_points}, torch::TensorOptions().dtype(torch::kFloat32).device(points_xyz.device())) * 1e10;
13 | furthest_point_sampling_kernel_wrapper(points_xyz.data_ptr<float>(), dists_temp.data_ptr<float>(), batch_size, n_points, n_samples, furthest_indices.data_ptr<int64_t>());
14 | 
15 | return furthest_indices;
16 | }
17 | 
18 | #ifdef TORCH_EXTENSION_NAME
19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
20 | m.def("_furthest_point_sampling_cuda", &furthest_point_sampling_cuda, "CUDA implementation of furthest-point-sampling (FPS)");
21 | }
22 | #endif
-------------------------------------------------------------------------------- /models/csrc/furthest_point_sampling/furthest_point_sampling.h: --------------------------------------------------------------------------------
1 | #pragma once
2 | 
3 | #include <torch/extension.h>
4 | 
5 | torch::Tensor furthest_point_sampling_cuda(torch::Tensor points_xyz, const int n_samples);
-------------------------------------------------------------------------------- /models/csrc/furthest_point_sampling/furthest_point_sampling_kernel.cu: --------------------------------------------------------------------------------
1 | #include <cuda.h> // NOTE: the include targets in this file were lost in extraction; these are assumed reconstructions
2 | #include <cuda_runtime.h>
3 | #include <cstdint>
4 | 
5 | #define argmaxReduceMacro(arr, arr_idx, idx1, idx2) {\
6 | if (arr[idx1] <= arr[idx2]) {\
7 | arr[idx1] = arr[idx2];\
8 | arr_idx[idx1] = arr_idx[idx2];\
9 | }\
10 | }
11 | 
12 | template <int block_size>
13 | __device__ void warpReduce(volatile float* max_values, volatile int* max_values_idx, int tid) {
14 | if (block_size >= 64) argmaxReduceMacro(max_values, max_values_idx, tid, tid + 32);
15 | if (block_size >= 32) argmaxReduceMacro(max_values, max_values_idx, tid, tid + 16);
16 | if (block_size >= 16) argmaxReduceMacro(max_values, max_values_idx, tid, tid + 8);
17 | if (block_size >= 8) argmaxReduceMacro(max_values, max_values_idx, tid, tid + 4);
18 | if (block_size >= 4) argmaxReduceMacro(max_values, max_values_idx, tid, tid + 2);
19 | if (block_size >= 2) argmaxReduceMacro(max_values, max_values_idx, tid, tid + 1);
20 | }
21 | 
22 | template <int block_size>
23 | __device__ int64_t argmax(float* max_values, int* max_values_idx) {
24 | int tid = threadIdx.x;
25 | if (block_size >= 1024) { if (tid < 512) argmaxReduceMacro(max_values, max_values_idx, tid, tid + 512); __syncthreads(); }
26 | if (block_size >= 512) { if (tid < 256) argmaxReduceMacro(max_values, max_values_idx, tid, tid + 256); __syncthreads(); }
27 | if (block_size >= 256) { if (tid < 128) argmaxReduceMacro(max_values, max_values_idx, tid, tid + 128); __syncthreads(); }
28 | if (block_size >= 128) { if (tid < 64) argmaxReduceMacro(max_values, max_values_idx, tid, tid + 64); __syncthreads(); }
29 | if (tid < 32) warpReduce<block_size>(max_values, max_values_idx, tid);
30 | __syncthreads();
31 | return max_values_idx[0];
32 | }
33 | 
34 | template <int block_size>
35 | __global__ void furthest_point_sampling_kernel(float* __restrict__ batched_points_xyz,
36 | float* __restrict__ batched_dists_temp,
37 | int n_points, int n_samples,
38 | int64_t* __restrict__ batched_furthest_indices) {
39 | int bid = blockIdx.x, tid = threadIdx.x;
40 | 
41 | float* __restrict__ points_xyz = batched_points_xyz + bid * n_points * 3;
42 | float* __restrict__ dists_temp = batched_dists_temp + bid * n_points;
43 | int64_t* __restrict__ furthest_indices = batched_furthest_indices + bid * n_samples;
44 | 
45 | __shared__ float max_dists[block_size];
46 | __shared__ int max_dists_idx[block_size];
47 | 
48 | int64_t curr_furthest_idx = 0;
49 | for (int s = 0; s < n_samples; s++) {
50 | if (tid == 0) furthest_indices[s] = curr_furthest_idx;
51 | 
52 | float x1 = points_xyz[curr_furthest_idx * 3 + 0];
53 | float y1 = points_xyz[curr_furthest_idx * 3 + 1];
54 | float z1 = points_xyz[curr_furthest_idx * 3 + 2];
55 | 
56 | float local_max_dist = -1;
57 | int local_max_dist_idx = 0;
58 | 
59 | for (int i = tid; i < n_points; i += block_size) {
60 | float x2 = points_xyz[i * 3 + 0];
61 | float y2 = points_xyz[i * 3 + 1];
62 | float z2 = points_xyz[i * 3 + 2];
63 | float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
64 | float new_dist = min(dists_temp[i], d);
65 | if (new_dist > local_max_dist) {
66 | local_max_dist = new_dist;
67 | local_max_dist_idx = i;
68 | }
69 | dists_temp[i] = new_dist;
70 | }
71 | 
72 | max_dists[tid] = local_max_dist;
73 | max_dists_idx[tid] = local_max_dist_idx;
74 | 
75 | __syncthreads();
76 | 
77 | curr_furthest_idx = argmax<block_size>(max_dists, max_dists_idx);
78 | }
79 | }
80 | 
81 | void furthest_point_sampling_kernel_wrapper(float* batched_points_xyz, float* batched_dists_temp,
82 | int n_batch, int n_points, int n_samples,
83 | int64_t* batched_furthest_indices) {
84 | furthest_point_sampling_kernel<1024> <<<n_batch, 1024>>> (batched_points_xyz, batched_dists_temp, n_points, n_samples, batched_furthest_indices);
85 | }
-------------------------------------------------------------------------------- /models/csrc/furthest_point_sampling/furthest_point_sampling_test.cpp: --------------------------------------------------------------------------------
1 | #include "furthest_point_sampling.h"
2 | #include <chrono> // NOTE: include targets assumed; originals stripped in extraction
3 | #include <iostream>
4 | using namespace std;
5 | 
6 | void _checkCudaErrors(cudaError_t result, char const *const func, const char *const file, int const line) {
7 | if (result != cudaSuccess) {
8 | fprintf(stderr, "CUDA error at %s:%d code=%d(%s) \"%s\" \n", file, line,
9 | static_cast<int>(result), cudaGetErrorName(result), func);
10 | cudaDeviceReset();
11 | exit(EXIT_FAILURE);
12 | }
13 | }
14 | #define checkCudaErrors(val) _checkCudaErrors((val), #val, __FILE__, __LINE__)
15 | 
16 | torch::Tensor furthest_point_sampling_std(torch::Tensor points_xyz, const int n_samples) {
17 | int64_t batch_size = points_xyz.size(0), n_points = points_xyz.size(1);
18 | torch::Tensor furthest_indices = torch::empty({batch_size, n_samples}, torch::TensorOptions().dtype(torch::kInt64));
19 | 
20 | for (int64_t b = 0; b < batch_size; b++) {
21 | torch::Tensor dists = torch::ones(n_points) * 1e10;
22 | int
curr_furthest_idx = 0;
23 | for (int s = 0; s < n_samples; s++) {
24 | furthest_indices[b][s] = curr_furthest_idx;
25 | dists = torch::min(dists, torch::sum((points_xyz[b] - points_xyz[b][curr_furthest_idx]).pow(2), -1));
26 | curr_furthest_idx = dists.argmax().data_ptr<int64_t>()[0];
27 | }
28 | }
29 | 
30 | return furthest_indices;
31 | }
32 | 
33 | int main() {
34 | constexpr int batch_size = 64;
35 | constexpr int n_points = 4096;
36 | constexpr int n_samples = 1024;
37 | 
38 | if (!torch::cuda::is_available()) {
39 | cout << "CUDA is not available, exiting..." << endl;
40 | return 1;
41 | }
42 | 
43 | torch::manual_seed(0);
44 | torch::Tensor points_xyz = torch::rand({batch_size, n_points, 3}), points_xyz_cuda = points_xyz.cuda();
45 | 
46 | // warm up...
47 | for (int t = 0; t < 3; t++) furthest_point_sampling_cuda(points_xyz_cuda, n_samples);
48 | checkCudaErrors(cudaDeviceSynchronize());
49 | 
50 | cout << "Running furthest-point-sampling on GPU... " << flush;
51 | auto t1 = chrono::high_resolution_clock::now();
52 | torch::Tensor indices_gpu = furthest_point_sampling_cuda(points_xyz_cuda, n_samples);
53 | checkCudaErrors(cudaDeviceSynchronize());
54 | auto t2 = chrono::high_resolution_clock::now();
55 | cout << "(" << chrono::duration_cast<chrono::milliseconds>(t2 - t1).count() << "ms)" << endl;
56 | 
57 | cout << "Running furthest-point-sampling on CPU... " << flush;
58 | t1 = chrono::high_resolution_clock::now();
59 | torch::Tensor indices_std = furthest_point_sampling_std(points_xyz, n_samples);
60 | t2 = chrono::high_resolution_clock::now();
61 | cout << "(" << chrono::duration_cast<chrono::milliseconds>(t2 - t1).count() << "ms)" << endl;
62 | 
63 | cout << "Checking results... " << (torch::equal(indices_std.cpu(), indices_gpu.cpu()) ? "OK" : "Failed") << endl;
64 | }
-------------------------------------------------------------------------------- /models/csrc/k_nearest_neighbor/k_nearest_neighbor.cpp: --------------------------------------------------------------------------------
1 | #include "k_nearest_neighbor.h"
2 | 
3 | void k_nearest_neighbor_2d_kernel_wrapper(int b, int n, int m, int k, const float *query_xyz, const float *input_xyz, int64_t *indices);
4 | void k_nearest_neighbor_3d_kernel_wrapper(int b, int n, int m, int k, const float *query_xyz, const float *input_xyz, int64_t *indices);
5 | 
6 | torch::Tensor k_nearest_neighbor_cuda(torch::Tensor input_xyz, torch::Tensor query_xyz, int k) {
7 | TORCH_CHECK(input_xyz.is_contiguous(), "input_xyz must be a contiguous tensor");
8 | TORCH_CHECK(input_xyz.is_cuda(), "input_xyz must be a CUDA tensor");
9 | TORCH_CHECK(input_xyz.scalar_type() == torch::ScalarType::Float, "input_xyz must be a float tensor");
10 | 
11 | TORCH_CHECK(query_xyz.is_contiguous(), "query_xyz must be a contiguous tensor");
12 | TORCH_CHECK(query_xyz.is_cuda(), "query_xyz must be a CUDA tensor");
13 | TORCH_CHECK(query_xyz.scalar_type() == torch::ScalarType::Float, "query_xyz must be a float tensor");
14 | 
15 | int batch_size = query_xyz.size(0), n_queries = query_xyz.size(1), n_inputs = input_xyz.size(1), n_dim = query_xyz.size(2);
16 | torch::Tensor indices = torch::zeros({batch_size, n_queries, k}, torch::device(query_xyz.device()).dtype(torch::ScalarType::Long));
17 | 
18 | if (n_dim == 2)
19 | k_nearest_neighbor_2d_kernel_wrapper(batch_size, n_queries, n_inputs, k, query_xyz.data_ptr<float>(), input_xyz.data_ptr<float>(), indices.data_ptr<int64_t>());
20 | else
21 | k_nearest_neighbor_3d_kernel_wrapper(batch_size, n_queries, n_inputs, k, query_xyz.data_ptr<float>(), input_xyz.data_ptr<float>(), indices.data_ptr<int64_t>());
22 | 
23 | return indices;
24 | }
25 | 
26 | #ifdef TORCH_EXTENSION_NAME
27 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
28 | m.def("_k_nearest_neighbor_cuda", &k_nearest_neighbor_cuda, "CUDA implementation of KNN");
29 | }
30 | #endif
-------------------------------------------------------------------------------- /models/csrc/k_nearest_neighbor/k_nearest_neighbor.h: --------------------------------------------------------------------------------
1 | #pragma once
2 | 
3 | #include <torch/extension.h>
4 | 
5 | torch::Tensor k_nearest_neighbor_cuda(torch::Tensor input_xyz, torch::Tensor query_xyz, int k);
-------------------------------------------------------------------------------- /models/csrc/k_nearest_neighbor/k_nearest_neighbor_kernel.cu: --------------------------------------------------------------------------------
1 | #include <cuda.h> // NOTE: include targets assumed; originals stripped in extraction
2 | #include <cuda_runtime.h>
3 | #include <cstdint>
4 | 
5 | #define THREADS_PER_BLOCK 256
6 | #define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
7 | 
8 | __global__ void k_nearest_neighbor_2d_kernel(int b, int n, int m, int k,
9 | const float *__restrict__ query_xyz,
10 | const float *__restrict__ input_xyz,
11 | int64_t *__restrict__ indices) {
12 | int bs_idx = blockIdx.y;
13 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
14 | if (bs_idx >= b || pt_idx >= n) return;
15 | 
16 | query_xyz += bs_idx * n * 2 + pt_idx * 2;
17 | input_xyz += bs_idx * m * 2;
18 | indices += bs_idx * n * k + pt_idx * k;
19 | 
20 | float ux = query_xyz[0];
21 | float uy = query_xyz[1];
22 | 
23 | // initialize (fixed-size buffers: assumes k <= 32)
24 | float nn_dists[32]; int nn_indices[32];
25 | for (int i = 0; i < 32; i++) {
26 | nn_dists[i] = 1e9;
27 | nn_indices[i] = 0;
28 | }
29 | 
30 | for (int idx = 0; idx < m; idx++) {
31 | float x = input_xyz[idx * 2 + 0];
32 | float y = input_xyz[idx * 2 + 1];
33 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y);
34 | if (d > nn_dists[k - 1]) continue;
35 | 
36 | int j = min(idx, k - 1);
37 | while (j > 0 && nn_dists[j - 1] > d) {
38 | nn_dists[j] = nn_dists[j - 1];
39 | nn_indices[j] = nn_indices[j - 1];
40 | j--;
41 | }
42 | 
43 | nn_dists[j] = d;
44 | nn_indices[j] = idx;
45 | }
46 | 
47 | for (int i = 0; i < k; i++)
48 | indices[i] = nn_indices[i];
49 | }
50 | 
51 | __global__ void k_nearest_neighbor_3d_kernel(int b, int n, int m, int k,
52 | const float *__restrict__ query_xyz,
53 | const float *__restrict__ input_xyz,
54 | int64_t *__restrict__ indices) {
55 | int bs_idx = blockIdx.y;
56 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
57 | if (bs_idx >= b || pt_idx >= n) return;
58 | 
59 | query_xyz += bs_idx * n * 3 + pt_idx * 3;
60 | input_xyz += bs_idx * m * 3;
61 | indices += bs_idx * n * k + pt_idx * k;
62 | 
63 | float ux = query_xyz[0];
64 | float uy = query_xyz[1];
65 | float uz = query_xyz[2];
66 | 
67 | // initialize (fixed-size buffers: assumes k <= 32)
68 | float nn_dists[32]; int nn_indices[32];
69 | for (int i = 0; i < 32; i++) {
70 | nn_dists[i] = 1e9;
71 | nn_indices[i] = 0;
72 | }
73 | 
74 | for (int idx = 0; idx < m; idx++) {
75 | float x = input_xyz[idx * 3 + 0];
76 | float y = input_xyz[idx * 3 + 1];
77 | float z = input_xyz[idx * 3 + 2];
78 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
79 | if (d > nn_dists[k - 1]) continue;
80 | 
81 | int j = min(idx, k - 1);
82 | while (j > 0 && nn_dists[j - 1] > d) {
83 | nn_dists[j] = nn_dists[j - 1];
84 | nn_indices[j] = nn_indices[j - 1];
85 | j--;
86 | }
87 | 
88 | nn_dists[j] = d;
89 | nn_indices[j] = idx;
90 | }
91 | 
92 | for (int i = 0; i < k; i++)
93 | indices[i] = nn_indices[i];
94 | }
95 | 
96 | void k_nearest_neighbor_2d_kernel_wrapper(int b, int n, int m, int k,
97 | const float *query_xyz,
98 | const float *input_xyz,
99 | int64_t *indices) {
100 | dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row)
101 | dim3 threads(THREADS_PER_BLOCK);
102 | k_nearest_neighbor_2d_kernel<<<blocks, threads>>>(b, n, m, k, query_xyz, input_xyz, indices);
103 | }
104 | 
105 | void k_nearest_neighbor_3d_kernel_wrapper(int b, int n, int m, int k,
106 | const float *query_xyz,
107 | const float *input_xyz,
108 | int64_t *indices) {
109 | dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row)
110 | dim3 threads(THREADS_PER_BLOCK);
111 | k_nearest_neighbor_3d_kernel<<<blocks, threads>>>(b, n, m, k, query_xyz, input_xyz, indices);
112 | }
113 | 
-------------------------------------------------------------------------------- /models/csrc/k_nearest_neighbor/k_nearest_neighbor_test.cpp: --------------------------------------------------------------------------------
1 | #include "k_nearest_neighbor.h"
2 | #include <chrono> // NOTE: include targets assumed; originals stripped in extraction
3 | #include <iostream>
4 | using namespace std;
5 | 
6 | void _checkCudaErrors(cudaError_t result, char const *const func, const char *const file, int const line) {
7 | if (result != cudaSuccess) {
8 | fprintf(stderr, "CUDA error at %s:%d code=%d(%s) \"%s\" \n", file, line,
9 | static_cast<int>(result), cudaGetErrorName(result), func);
10 | cudaDeviceReset();
11 | exit(EXIT_FAILURE);
12 | }
13 | }
14 | #define checkCudaErrors(val) _checkCudaErrors((val), #val, __FILE__, __LINE__)
15 | 
16 | torch::Tensor k_nearest_neighbor_std(torch::Tensor input_xyz, torch::Tensor query_xyz, int k) {
17 | int64_t batch_size = input_xyz.size(0), n_points_1 = input_xyz.size(1), n_points_2 = query_xyz.size(1);
18 | torch::Tensor dists = -2 * torch::matmul(query_xyz, input_xyz.permute({0, 2, 1}));
19 | dists += torch::sum(query_xyz.pow(2), -1).view({batch_size, n_points_2, 1});
20 | dists += torch::sum(input_xyz.pow(2), -1).view({batch_size, 1, n_points_1});
21 | return std::get<1>(dists.topk(k, 2, false));
22 | }
23 | 
24 | int main() {
25 | constexpr int batch_size = 8;
26 | constexpr int n_points_input = 8192;
27 | constexpr int n_points_query = 8192;
28 | constexpr int k = 16;
29 | constexpr int dim = 3;
30 | 
31 | if (!torch::cuda::is_available()) {
32 | cout << "CUDA is not available, exiting..." << endl;
33 | return 1;
34 | }
35 | 
36 | torch::manual_seed(0);
37 | torch::Tensor input_xyz = torch::rand({batch_size, n_points_input, dim}).cuda();
38 | torch::Tensor query_xyz = torch::rand({batch_size, n_points_query, dim}).cuda();
39 | 
40 | // warm up...
41 | for (int t = 0; t < 3; t++) {
42 | k_nearest_neighbor_cuda(input_xyz, query_xyz, k);
43 | k_nearest_neighbor_std(input_xyz, query_xyz, k);
44 | }
45 | checkCudaErrors(cudaDeviceSynchronize());
46 | 
47 | cout << "Running KNN using custom CUDA implementation... " << flush;
48 | auto t1 = chrono::high_resolution_clock::now();
49 | torch::Tensor indices_gpu = k_nearest_neighbor_cuda(input_xyz, query_xyz, k);
50 | checkCudaErrors(cudaDeviceSynchronize());
51 | auto t2 = chrono::high_resolution_clock::now();
52 | cout << "(" << chrono::duration_cast<chrono::milliseconds>(t2 - t1).count() << "ms)" << endl;
53 | 
54 | cout << "Running KNN using Torch's API... 
" << flush; 55 | t1 = chrono::high_resolution_clock::now(); 56 | torch::Tensor indices_std = k_nearest_neighbor_std(input_xyz, query_xyz, k); 57 | checkCudaErrors(cudaDeviceSynchronize()); 58 | t2 = chrono::high_resolution_clock::now(); 59 | cout << "(" << chrono::duration_cast(t2 - t1).count() << "ms)" << endl; 60 | 61 | torch::Scalar diff_num = (indices_std.cpu() != indices_gpu.cpu()).sum().item(); 62 | torch::Scalar total_num = batch_size * n_points_query * k; 63 | cout << "Checking results... " << diff_num << " of " << total_num << " elements are mismatched." << endl; 64 | 65 | return 0; 66 | } -------------------------------------------------------------------------------- /models/csrc/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | 5 | def get_ext_modules(): 6 | return [ 7 | CUDAExtension( 8 | name='_correlation_cuda', 9 | sources=[ 10 | 'correlation/correlation.cpp', 11 | 'correlation/correlation_forward_kernel.cu', 12 | 'correlation/correlation_backward_kernel.cu' 13 | ], 14 | include_dirs=['correlation'] 15 | ), 16 | CUDAExtension( 17 | name='_furthest_point_sampling_cuda', 18 | sources=[ 19 | 'furthest_point_sampling/furthest_point_sampling.cpp', 20 | 'furthest_point_sampling/furthest_point_sampling_kernel.cu' 21 | ], 22 | include_dirs=['furthest_point_sampling'] 23 | ), 24 | CUDAExtension( 25 | name='_k_nearest_neighbor_cuda', 26 | sources=[ 27 | 'k_nearest_neighbor/k_nearest_neighbor.cpp', 28 | 'k_nearest_neighbor/k_nearest_neighbor_kernel.cu' 29 | ], 30 | include_dirs=['k_nearest_neighbor'] 31 | ) 32 | ] 33 | 34 | 35 | setup( 36 | name='csrc', 37 | ext_modules=get_ext_modules(), 38 | cmdclass={'build_ext': BuildExtension} 39 | ) 40 | -------------------------------------------------------------------------------- /models/csrc/wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional 3 | 4 | try: 5 | from ._correlation_cuda import _correlation_forward_cuda 6 | from ._correlation_cuda import _correlation_backward_cuda 7 | from ._furthest_point_sampling_cuda import _furthest_point_sampling_cuda 8 | from ._k_nearest_neighbor_cuda import _k_nearest_neighbor_cuda 9 | except ImportError as e: 10 | _correlation_forward_cuda = None 11 | _correlation_backward_cuda = None 12 | _furthest_point_sampling_cuda = None 13 | _k_nearest_neighbor_cuda = None 14 | print('Failed to load one or more CUDA extensions, performance may be hurt.') 15 | print('Error message:', e) 16 | 17 | 18 | class CorrelationFunction(torch.autograd.Function): 19 | @staticmethod 20 | def forward(ctx, input1, input2, max_displacement): 21 | ctx.save_for_backward(input1, input2) 22 | ctx.max_displacement = max_displacement 23 | assert callable(_correlation_forward_cuda) 24 | return _correlation_forward_cuda(input1, input2, max_displacement) 25 | 26 | @staticmethod 27 | def backward(ctx, grad_output): 28 | input1, input2 = ctx.saved_tensors 29 | 30 | assert callable(_correlation_backward_cuda) 31 | grad_input1, grad_input2 = _correlation_backward_cuda( 32 | grad_output, input1, input2, ctx.max_displacement 33 | ) 34 | grad_input1 = grad_input1.permute(0, 2, 3, 1).contiguous() 35 | grad_input2 = grad_input2.permute(0, 2, 3, 1).contiguous() 36 | 37 | return grad_input1, grad_input2, None 38 | 39 | 40 | def squared_distance(xyz1: torch.Tensor, xyz2: torch.Tensor): 41 | """ 42 | Calculate the 
Euclidean squared distance between every two points. 43 | :param xyz1: the 1st set of points, [batch_size, n_points_1, 3] 44 | :param xyz2: the 2nd set of points, [batch_size, n_points_2, 3] 45 | :return: squared distance between every two points, [batch_size, n_points_1, n_points_2] 46 | """ 47 | assert xyz1.shape[-1] == xyz2.shape[-1] and xyz1.shape[-1] <= 3 # assert channel_last 48 | batch_size, n_points1, n_points2 = xyz1.shape[0], xyz1.shape[1], xyz2.shape[1] 49 | dist = -2 * torch.matmul(xyz1, xyz2.permute(0, 2, 1)) 50 | dist += torch.sum(xyz1 ** 2, -1).view(batch_size, n_points1, 1) 51 | dist += torch.sum(xyz2 ** 2, -1).view(batch_size, 1, n_points2) 52 | return dist 53 | 54 | 55 | def correlation2d(input1: torch.Tensor, input2: torch.Tensor, max_displacement: int, cpp_impl=True): 56 | def _correlation_py(_input1, _input2, _max_displacement): 57 | height, width = _input1.shape[2:] 58 | _input2 = torch.nn.functional.pad(_input2, [_max_displacement] * 4) 59 | cost_volumes = [] 60 | for i in range(2 * _max_displacement + 1): 61 | for j in range(2 * _max_displacement + 1): 62 | cost_volume = _input1 * _input2[:, :, i:(i + height), j:(j + width)] 63 | cost_volume = torch.mean(cost_volume, 1, keepdim=True) 64 | cost_volumes.append(cost_volume) 65 | return torch.cat(cost_volumes, 1) 66 | 67 | if cpp_impl and callable(_correlation_forward_cuda) and callable(_correlation_backward_cuda) and input1.is_cuda and input2.is_cuda: 68 | input1 = input1.permute(0, 2, 3, 1).contiguous().float() 69 | input2 = input2.permute(0, 2, 3, 1).contiguous().float() 70 | return CorrelationFunction.apply(input1, input2, max_displacement) 71 | else: 72 | return _correlation_py(input1, input2, max_displacement) 73 | 74 | 75 | def furthest_point_sampling(xyz: torch.Tensor, n_samples: int, cpp_impl=True): 76 | """ 77 | Perform furthest point sampling on a set of points. 78 | :param xyz: a set of points, [batch_size, n_points, 3] 79 | :param n_samples: number of samples, int 80 | :param cpp_impl: whether to use the CUDA C++ implementation of furthest-point-sampling 81 | :return: indices of sampled points, [batch_size, n_samples] 82 | """ 83 | def _furthest_point_sampling_py(_xyz: torch.Tensor, _n_samples: int): 84 | batch_size, n_points, _ = _xyz.shape 85 | farthest_indices = torch.zeros(batch_size, _n_samples, dtype=torch.int64, device=_xyz.device) 86 | distances = torch.ones(batch_size, n_points, device=_xyz.device) * 1e10 87 | batch_indices = torch.arange(batch_size, dtype=torch.int64, device=_xyz.device) 88 | curr_farthest_idx = torch.zeros(batch_size, dtype=torch.int64, device=_xyz.device) 89 | for i in range(_n_samples): 90 | farthest_indices[:, i] = curr_farthest_idx 91 | curr_farthest = _xyz[batch_indices, curr_farthest_idx, :].view(batch_size, 1, 3) 92 | new_distances = torch.sum((_xyz - curr_farthest) ** 2, -1) 93 | mask = new_distances < distances 94 | distances[mask] = new_distances[mask] 95 | curr_farthest_idx = torch.max(distances, -1)[1] 96 | return farthest_indices 97 | 98 | assert xyz.shape[2] == 3 and xyz.shape[1] > n_samples 99 | 100 | if cpp_impl and callable(_furthest_point_sampling_cuda) and xyz.is_cuda: 101 | return _furthest_point_sampling_cuda(xyz.contiguous(), n_samples).to(torch.int64) 102 | else: 103 | return _furthest_point_sampling_py(xyz, n_samples).to(torch.int64) 104 | 105 | 106 | def k_nearest_neighbor(input_xyz: torch.Tensor, query_xyz: torch.Tensor, k: int, cpp_impl=True): 107 | """ 108 | Calculate k-nearest neighbor for each query. 
109 | :param input_xyz: a set of points, [batch_size, n_points, 3] or [batch_size, 3, n_points] 110 | :param query_xyz: a set of centroids, [batch_size, n_queries, 3] or [batch_size, 3, n_queries] 111 | :param k: int 112 | :param cpp_impl: whether to use the CUDA C++ implementation of k-nearest-neighbor 113 | :return: indices of k-nearest neighbors, [batch_size, n_queries, k] 114 | """ 115 | def _k_nearest_neighbor_py(_input_xyz: torch.Tensor, _query_xyz: torch.Tensor, _k: int): 116 | dists = squared_distance(_query_xyz, _input_xyz) 117 | return dists.topk(_k, dim=2, largest=False).indices.to(torch.long) 118 | 119 | if input_xyz.shape[1] <= 3: # channel_first to channel_last 120 | assert query_xyz.shape[1] == input_xyz.shape[1] 121 | input_xyz = input_xyz.transpose(1, 2).contiguous() 122 | query_xyz = query_xyz.transpose(1, 2).contiguous() 123 | 124 | if cpp_impl and callable(_k_nearest_neighbor_cuda) and input_xyz.is_cuda and query_xyz.is_cuda: 125 | return _k_nearest_neighbor_cuda(input_xyz.contiguous(), query_xyz.contiguous(), k) 126 | else: 127 | return _k_nearest_neighbor_py(input_xyz, query_xyz, k) 128 | -------------------------------------------------------------------------------- /models/losses2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.functional import interpolate 4 | from .utils import backwarp_2d, resize_flow2d 5 | 6 | 7 | def calc_supervised_loss_2d(flows, target, cfgs): 8 | assert len(flows) <= len(cfgs.level_weights) 9 | 10 | total_loss = 0 11 | for pred, level_weight in zip(flows, cfgs.level_weights): 12 | assert pred.shape[1] == 2 # [B, 2, H, W] 13 | 14 | flow_mask = target[:, 2] > 0 15 | 16 | diff = torch.abs(resize_flow2d(pred, target.shape[2], target.shape[3]) - target[:, :2]) 17 | 18 | if cfgs.order == 'robust': 19 | loss_l1_map = torch.pow(diff.sum(dim=1) + 0.01, 0.4) 20 | loss_l1 = loss_l1_map[flow_mask].mean() 21 | total_loss += level_weight * loss_l1 22 | elif cfgs.order == 'l2-norm': 23 | loss_l2_map = torch.linalg.norm(diff, dim=1) 24 | loss_l2 = loss_l2_map[flow_mask].mean() 25 | total_loss += level_weight * loss_l2 26 | else: 27 | raise NotImplementedError 28 | 29 | return total_loss 30 | 31 | 32 | def calc_census_loss_2d(image1, image2, noc_mask=None, max_distance=1): 33 | """ 34 | Calculate photometric loss based on census transform. 
35 | :param image1: [N, 3, H, W] float tensor, ranging from 0 to 1, RGB 36 | :param image2: [N, 3, H, W] float tensor, ranging from 0 to 1, RGB 37 | :param noc_mask: [N, 1, H, W] float tensor, ranging from 0 to 1 38 | :param max_distance: int 39 | """ 40 | def rgb_to_grayscale(image): 41 | grayscale = image[:, 0, :, :] * 0.2989 + \ 42 | image[:, 1, :, :] * 0.5870 + \ 43 | image[:, 2, :, :] * 0.1140 44 | return grayscale.unsqueeze(1) * 255.0 45 | 46 | def census_transform(gray_image): 47 | patch_size = 2 * max_distance + 1 48 | out_channels = patch_size * patch_size # 9 49 | weights = torch.eye(out_channels, dtype=gray_image.dtype, device=gray_image.device) 50 | weights = weights.view([out_channels, 1, patch_size, patch_size]) # [9, 1, 3, 3] 51 | patches = nn.functional.conv2d(gray_image, weights, padding=max_distance) 52 | result = patches - gray_image 53 | result = result / torch.sqrt(0.81 + torch.pow(result, 2)) 54 | return result 55 | 56 | if noc_mask is not None: 57 | image1 = noc_mask * image1 58 | image2 = noc_mask * image2 59 | 60 | gray_image1 = rgb_to_grayscale(image1) 61 | gray_image2 = rgb_to_grayscale(image2) 62 | 63 | t1 = census_transform(gray_image1) 64 | t2 = census_transform(gray_image2) 65 | 66 | dist = torch.pow(t1 - t2, 2) 67 | dist_norm = dist / (0.1 + dist) 68 | dist_mean = torch.mean(dist_norm, 1, keepdim=True) # instead of sum 69 | 70 | n, _, h, w = image1.shape 71 | inner = torch.ones([n, 1, h - 2 * max_distance, w - 2 * max_distance], dtype=image1.dtype, device=image1.device) 72 | inner_mask = nn.functional.pad(inner, [max_distance] * 4) 73 | loss = dist_mean * inner_mask 74 | 75 | if noc_mask is not None: 76 | return loss.mean() / (noc_mask.mean() + 1e-7) 77 | else: 78 | return loss.mean() 79 | 80 | 81 | @torch.cuda.amp.autocast(enabled=False) 82 | def calc_smooth_loss_2d(image, flow, derivative='first'): 83 | """ 84 | :param image: [N, 3, H, W] float tensor, ranging from 0 to 1, RGB 85 | :param flow: [N, 2, H, W] float tensor 86 | :param derivative: 'first' or 'second' 87 | """ 88 | def gradient(inputs): 89 | dy = inputs[:, :, 1:, :] - inputs[:, :, :-1, :] 90 | dx = inputs[:, :, :, 1:] - inputs[:, :, :, :-1] 91 | return dx, dy 92 | 93 | image_dx, image_dy = gradient(image) 94 | flow_dx, flow_dy = gradient(flow) 95 | 96 | weights_x = torch.exp(-torch.mean(image_dx.abs(), 1, keepdim=True) * 10) 97 | weights_y = torch.exp(-torch.mean(image_dy.abs(), 1, keepdim=True) * 10) 98 | 99 | if derivative == 'first': 100 | loss_x = weights_x * flow_dx.abs() / 2.0 101 | loss_y = weights_y * flow_dy.abs() / 2.0 102 | elif derivative == 'second': 103 | flow_dx2 = flow_dx[:, :, :, 1:] - flow_dx[:, :, :, :-1] 104 | flow_dy2 = flow_dy[:, :, 1:, :] - flow_dy[:, :, :-1, :] 105 | loss_x = weights_x[:, :, :, 1:] * flow_dx2.abs() 106 | loss_y = weights_y[:, :, 1:, :] * flow_dy2.abs() 107 | else: 108 | raise NotImplementedError('Unknown derivative: %s' % derivative) 109 | 110 | return loss_x.mean() / 2 + loss_y.mean() / 2 111 | 112 | 113 | def calc_ssim_loss_2d(image1, image2, noc_mask=None, max_distance=1): 114 | """ 115 | Calculate photometric loss based on SSIM. 
116 | :param image1: [N, 3, H, W] float tensor, ranging from 0 to 1, RGB 117 | :param image2: [N, 3, H, W] float tensor, ranging from 0 to 1, RGB 118 | :param noc_mask: [N, 1, H, W] float tensor, ranging from 0 to 1 119 | :param max_distance: int 120 | """ 121 | patch_size = 2 * max_distance + 1 122 | c1, c2 = 0.01 ** 2, 0.03 ** 2 123 | 124 | if noc_mask is not None: 125 | image1 = noc_mask * image1 126 | image2 = noc_mask * image2 127 | 128 | mu_x = nn.AvgPool2d(patch_size, 1, 0)(image1) 129 | mu_y = nn.AvgPool2d(patch_size, 1, 0)(image2) 130 | mu_x_square, mu_y_square = mu_x.pow(2), mu_y.pow(2) 131 | mu_xy = mu_x * mu_y 132 | 133 | sigma_x = nn.AvgPool2d(patch_size, 1, 0)(image1 * image1) - mu_x_square 134 | sigma_y = nn.AvgPool2d(patch_size, 1, 0)(image2 * image2) - mu_y_square 135 | sigma_xy = nn.AvgPool2d(patch_size, 1, 0)(image1 * image2) - mu_xy 136 | 137 | ssim_n = (2 * mu_xy + c1) * (2 * sigma_xy + c2) 138 | ssim_d = (mu_x_square + mu_y_square + c1) * (sigma_x + sigma_y + c2) 139 | ssim = ssim_n / ssim_d 140 | loss = torch.clamp((1 - ssim) / 2, min=0.0, max=1.0) 141 | 142 | if noc_mask is not None: 143 | return loss.mean() / (noc_mask.mean() + 1e-7) 144 | else: 145 | return loss.mean() 146 | 147 | 148 | def calc_unsupervised_loss_2d(pyramid_flows12, pyramid_flows21, image1, image2, occ_mask1, occ_mask2, cfgs): 149 | photo_loss = smooth_loss = 0 150 | for lv, (pyramid_flow12, pyramid_flow21) in enumerate(zip(pyramid_flows12, pyramid_flows21)): 151 | if lv == 0: 152 | image1_scaled, noc_mask1_scaled = image1, 1 - occ_mask1[:, None, :, :] 153 | image2_scaled, noc_mask2_scaled = image2, 1 - occ_mask2[:, None, :, :] 154 | else: 155 | curr_h, curr_w = pyramid_flow12.shape[2], pyramid_flow12.shape[3] 156 | image1_scaled = interpolate(image1, (curr_h, curr_w), mode='area') 157 | image2_scaled = interpolate(image2, (curr_h, curr_w), mode='area') 158 | noc_mask1_scaled = 1 - interpolate(occ_mask1[:, None, :, :], (curr_h, curr_w), mode='nearest') 159 | noc_mask2_scaled = 1 - interpolate(occ_mask2[:, None, :, :], (curr_h, curr_w), mode='nearest') 160 | 161 | image1_scaled_warp = backwarp_2d(image1_scaled, pyramid_flow21, padding_mode='border') 162 | image2_scaled_warp = backwarp_2d(image2_scaled, pyramid_flow12, padding_mode='border') 163 | 164 | # calculate photometric loss 165 | if cfgs.photometric_loss == 'ssim': 166 | photo_loss1 = calc_ssim_loss_2d(image1_scaled, image2_scaled_warp, noc_mask1_scaled) 167 | photo_loss2 = calc_ssim_loss_2d(image2_scaled, image1_scaled_warp, noc_mask2_scaled) 168 | elif cfgs.photometric_loss == 'census': 169 | photo_loss1 = calc_census_loss_2d(image1_scaled, image2_scaled_warp, noc_mask1_scaled) 170 | photo_loss2 = calc_census_loss_2d(image2_scaled, image1_scaled_warp, noc_mask2_scaled) 171 | else: 172 | raise NotImplementedError('Unknown photometric loss: %s' % cfgs.photometric_loss) 173 | photo_loss += cfgs.photometric_weights[lv] * (photo_loss1 + photo_loss2) / 2 174 | 175 | # calculate smooth loss 176 | scale = min(pyramid_flows12[0].shape[2], pyramid_flows12[0].shape[3]) 177 | smooth_loss1 = calc_smooth_loss_2d(image1_scaled, pyramid_flow12 / scale, cfgs.smooth_derivative) 178 | smooth_loss2 = calc_smooth_loss_2d(image2_scaled, pyramid_flow21 / scale, cfgs.smooth_derivative) 179 | smooth_loss += cfgs.smooth_weights[lv] * (smooth_loss1 + smooth_loss2) / 2 180 | 181 | return photo_loss, smooth_loss 182 | -------------------------------------------------------------------------------- /models/losses3d.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | 5 | def calc_supervised_loss_3d(flows, targets, cfgs, masks): 6 | assert len(flows) <= len(cfgs.level_weights) 7 | 8 | total_loss = 0 9 | for idx, (flow, level_weight) in enumerate(zip(flows, cfgs.level_weights)): 10 | level_target = targets[idx] 11 | 12 | mask = masks[idx] 13 | 14 | diff = flow - level_target # B, 3, H, W 15 | 16 | if cfgs.order == 'robust': 17 | epe_l1 = torch.pow(diff.abs().sum(dim=1) + 0.01, 0.4)[mask].mean() 18 | total_loss += level_weight * epe_l1 19 | elif cfgs.order == 'l2-norm': 20 | epe_l2 = torch.linalg.norm(diff, dim=1)[mask].mean() 21 | total_loss += level_weight * epe_l2 22 | else: 23 | raise NotImplementedError 24 | 25 | return total_loss 26 | 27 | 28 | -------------------------------------------------------------------------------- /models/pointconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .utils import MLP2d, LayerNormCF1d, batch_indexing, knn_grouping_2d, mask_batch_selecting 4 | 5 | 6 | class PointConvS(nn.Module): 7 | def __init__(self, in_channels, out_channels, norm=None, act='leaky_relu', k=16): 8 | super().__init__() 9 | self.k = k 10 | 11 | self.weight_net = MLP2d(3, [8, 16], act=act) 12 | self.linear = nn.Linear(16 * (in_channels + 3), out_channels) 13 | 14 | if norm == 'batch_norm': 15 | self.norm_fn = nn.BatchNorm1d(out_channels, affine=True) 16 | elif norm == 'instance_norm': 17 | self.norm_fn = nn.InstanceNorm1d(out_channels, affine=True) 18 | elif norm == 'layer_norm': 19 | self.norm_fn = LayerNormCF1d(out_channels) 20 | elif norm is None: 21 | self.norm_fn = nn.Identity() 22 | else: 23 | raise NotImplementedError('Unknown normalization function: %s' % norm) 24 | 25 | if act == 'relu': 26 | self.act_fn = nn.ReLU(inplace=True) 27 | elif act == 'leaky_relu': 28 | self.act_fn = nn.LeakyReLU(negative_slope=0.1, inplace=True) 29 | elif act is None: 30 | self.act_fn = nn.Identity() 31 | else: 32 | raise NotImplementedError('Unknown activation function: %s' % act) 33 | 34 | def forward(self, xyz, features, sampled_xyz=None, knn_indices=None): 35 | """ 36 | :param xyz: 3D locations of points, [batch_size, 3, n_points] 37 | :param features: features of points, [batch_size, in_channels, n_points] 38 | :param sampled_xyz: 3D locations of sampled points, [batch_size, 3, n_samples] 39 | :return weighted_features: features of sampled points, [batch_size, out_channels, n_samples] 40 | """ 41 | if sampled_xyz is None: 42 | sampled_xyz = xyz 43 | 44 | bs, n_samples = sampled_xyz.shape[0], sampled_xyz.shape[-1] 45 | features = torch.cat([xyz, features], dim=1) # [bs, in_channels + 3, n_points] 46 | features_cl = features.transpose(1, 2) # [bs, n_points, n_channels + 3] 47 | 48 | # Calculate k nearest neighbors 49 | if knn_indices is None: 50 | knn_indices = k_nearest_neighbor(xyz, sampled_xyz, self.k) # [bs, n_samples, k] 51 | else: 52 | assert knn_indices.shape[:2] == torch.Size([bs, n_samples]) 53 | assert knn_indices.shape[2] >= self.k 54 | knn_indices = knn_indices[:, :, :self.k] 55 | 56 | # Calculate weights 57 | knn_xyz = batch_indexing(xyz, knn_indices) # [bs, 3, n_samples, k] 58 | knn_xyz_norm = knn_xyz - sampled_xyz[:, :, :, None] # [bs, 3, n_samples, k] 59 | weights = self.weight_net(knn_xyz_norm) # [bs, n_weights, n_samples, k] 60 | 61 | # Calculate weighted features 62 | weights = weights.transpose(1, 2) # [bs, n_samples, n_weights, k] 63 | 
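# NOTE: the matmul two lines below is the core PointConv step -- for each sampled
# point it multiplies the learned, position-dependent weights [n_weights, k] with
# the gathered neighbor features [k, 3 + in_channels] and sums over the k neighbors,
# i.e. a continuous convolution over an irregular neighborhood whose kernel is
# predicted by weight_net from relative coordinates.
# (k_nearest_neighbor used earlier in this method is not imported at the top of
# this file; it presumably comes from models/csrc/wrapper.py -- an assumption.)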
knn_features = batch_indexing(features_cl, knn_indices, layout='channel_last') # [bs, n_samples, k, 3 + in_channels] 64 | out = torch.matmul(weights, knn_features) # [bs, n_samples, n_weights, 3 + in_channels] 65 | out = out.view(bs, n_samples, -1) # [bs, n_samples, (3 + in_channels) * n_weights] 66 | out = self.linear(out) # [bs, n_samples, out_channels] 67 | out = self.act_fn(self.norm_fn(out.transpose(1, 2))) # [bs, out_channels, n_samples] 68 | 69 | return out 70 | 71 | 72 | class PointConv(nn.Module): 73 | 74 | def __init__(self, in_channels, out_channels, ks = [10, 20], dist = 100.0, norm=None, act='leaky_relu', k=16): 75 | super().__init__() 76 | self.k = k 77 | self.kernel_size = ks 78 | self.distance = dist 79 | self.in_channels = in_channels 80 | self.out_channels = out_channels 81 | self.weight_net = MLP2d(3, [8, 16], act=act) 82 | self.linear = nn.Linear(16 * (in_channels + 3), out_channels) 83 | 84 | if norm == 'batch_norm': 85 | self.norm_fn = nn.BatchNorm1d(out_channels) 86 | elif norm == 'instance_norm': 87 | self.norm_fn = nn.InstanceNorm1d(out_channels) 88 | elif norm is None: 89 | self.norm_fn = nn.Identity() 90 | else: 91 | raise NotImplementedError('Unknown normalization function: %s' % norm) 92 | 93 | if act == 'relu': 94 | self.act_fn = nn.ReLU(inplace=True) 95 | elif act == 'leaky_relu': 96 | self.act_fn = nn.LeakyReLU(negative_slope=0.1, inplace=True) 97 | elif act is None: 98 | self.act_fn = nn.Identity() 99 | else: 100 | raise NotImplementedError('Unknown act function: %s' % act) 101 | 102 | def forward(self, xyz, features, sampled_xyz = None, knn_indices=None, valid_knn_mask = None): 103 | 104 | """ 105 | :param xyz: [batch_size, 3, H, W] 106 | :param features: [batch_size, C, H, W] 107 | :param sampled_xyz: [batch_size, 3, h, w] 108 | :return: out: [batch_size, C', h, w] 109 | """ 110 | if sampled_xyz == None: 111 | sampled_xyz = xyz 112 | 113 | 114 | B, C, H, W = features.shape 115 | h, w = sampled_xyz.shape[2:] 116 | features = torch.cat([xyz, features], dim=1).reshape(B, C + 3, H * W) # [B, in_channels + 3, H, W] 117 | features_cl = features.transpose(1, 2) # [B, H * W, in_channels + 3] 118 | 119 | ################# Calculate k nearest neighbors 120 | if knn_indices is None: 121 | # [B, h*w, k], [B, h*w, k] 122 | knn_indices, valid_knn_mask = knn_grouping_2d(sampled_xyz, xyz, self.k) 123 | else: 124 | assert knn_indices.shape[:2] == torch.Size([B, h * w]) 125 | assert knn_indices.shape[2] >= self.k 126 | knn_indices = knn_indices[:, :, :self.k] 127 | valid_knn_mask = valid_knn_mask[:, :, :self.k] 128 | # valid_mask = valid_mask[:, :, :self.k] 129 | 130 | # Calculate weights 131 | knn_xyz = mask_batch_selecting(xyz, knn_indices, valid_knn_mask) # [B, 3, h * w, k] 132 | new_xyz = sampled_xyz.reshape(B, 3,-1) # [B, 3, h*w] 133 | knn_xyz_norm = knn_xyz - new_xyz[:, :, :, None] # [B, 3, h*w, k] 134 | weights = self.weight_net(knn_xyz_norm) # [B, n_weights, h*w, k] 135 | 136 | # Calculate weighted features 137 | weights = weights.transpose(1, 2) # [B, h*w, n_weights, k] 138 | knn_features = mask_batch_selecting(features_cl, knn_indices, valid_knn_mask, layout='channel_last') # [B, h*w, k, C + 3] 139 | 140 | out = torch.matmul(weights, knn_features) # [B, h*w, n_weights, C+3] 141 | out = out.view(B, h*w, -1) # [B, h*w, n_weights*(C+3)] 142 | out = self.linear(out) # [B, h*w, out_channels] 143 | out = self.act_fn(self.norm_fn(out.transpose(1, 2))) # [B, out_channels, h * w] 144 | 145 | # out = torch.reshape(out, [B, -1, h, w]) 146 | out = out.view(B, -1, h, w) 147 | 
return out 148 | 149 | 150 | class PointConvDW(nn.Module): 151 | def __init__(self, in_channels, out_channels, norm=None, act='leaky_relu', k=16): 152 | super().__init__() 153 | self.k = k 154 | self.mlp = MLP2d(in_channels, [out_channels], norm, act) 155 | self.weight_net = MLP2d(3, [8, 32, out_channels], act='relu') 156 | 157 | def forward(self, xyz, features, sampled_xyz=None, knn_indices=None, valid_knn_mask = None): 158 | """ 159 | :param xyz: [batch_size, 3, H, W] 160 | :param features: [batch_size, C, H, W] 161 | :param sampled_xyz: [batch_size, 3, h, w] 162 | :return: out: [batch_size, C', h, w] 163 | """ 164 | 165 | if sampled_xyz is None: 166 | sampled_xyz = xyz 167 | 168 | B, C, H, W = features.shape 169 | h, w = sampled_xyz.shape[2:] 170 | 171 | # Calculate k nearest neighbors 172 | if knn_indices is None: 173 | # [B, h*w, k], [B, h*w, k] 174 | knn_indices, valid_knn_mask = knn_grouping_2d(sampled_xyz, xyz, self.k) 175 | else: 176 | assert knn_indices.shape[:2] == torch.Size([B, h * w]) 177 | assert knn_indices.shape[2] >= self.k 178 | knn_indices = knn_indices[:, :, :self.k] 179 | valid_knn_mask = valid_knn_mask[:, :, :self.k] 180 | 181 | # Calculate weights 182 | knn_xyz = mask_batch_selecting(xyz, knn_indices, valid_knn_mask) # [B, 3, h * w, k] 183 | new_xyz = sampled_xyz.reshape(B, 3,-1) # [B, 3, h*w] 184 | knn_offset = knn_xyz - new_xyz[:, :, :, None] # [B, 3, h*w, k] 185 | 186 | features = self.mlp(features) # [B, C_out, H, W] 187 | features = mask_batch_selecting(features, knn_indices, valid_knn_mask) # [B, C_out, h * w, k] 188 | features = features * self.weight_net(knn_offset) # [B, C_out, h * w, k] * [B, C_out, h*w, k] 189 | features = torch.max(features, dim=-1)[0] # [B, C_out, h*w] 190 | features = features.view(B, -1, h, w) 191 | 192 | return features -------------------------------------------------------------------------------- /ops_pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IRMVLab/DELFlow/3e12fdec9fd26ace6f42436497ceb236656c6a23/ops_pytorch/__init__.py -------------------------------------------------------------------------------- /ops_pytorch/fused_conv_select/fused_conv_g.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "fused_conv_gpu.h" 6 | 7 | extern THCState *state; 8 | 9 | //#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ") 10 | //#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x, " must be contiguous ") 11 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ") 12 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ") 13 | #define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x) 14 | 15 | void torch_FusedConvSelectKLauncher( 16 | torch::Tensor xyz_tensor, 17 | torch::Tensor xyz2_tensor, 18 | torch::Tensor idx_n2_tensor, 19 | torch::Tensor idx_fetching_tensor, 20 | torch::Tensor random_hw_tensor, 21 | int H, 22 | int W, 23 | int npoints, 24 | int kernel_size_H, 25 | int kernel_size_W, 26 | int K, 27 | bool flag_copy, 28 | float distance, 29 | int stride_h, 30 | int stride_w, 31 | torch::Tensor selected_b_idx_tensor, 32 | torch::Tensor selected_h_idx_tensor, 33 | torch::Tensor selected_w_idx_tensor, 34 | torch::Tensor selected_mask_tensor, 35 | int small_h, 36 | int small_w ){ 37 | 38 | CHECK_INPUT(xyz_tensor); 39 | CHECK_INPUT(xyz2_tensor); 40 | 
CHECK_INPUT(idx_n2_tensor);
41 | CHECK_INPUT(idx_fetching_tensor);
42 | CHECK_INPUT(random_hw_tensor);
43 | CHECK_INPUT(selected_b_idx_tensor);
44 | CHECK_INPUT(selected_h_idx_tensor);
45 | CHECK_INPUT(selected_w_idx_tensor);
46 | CHECK_INPUT(selected_mask_tensor);
47 | 
48 | const auto batch_size = xyz_tensor.size(0);
49 | const float *xyz1 = xyz_tensor.data_ptr<float>();
50 | const float *xyz2 = xyz2_tensor.data_ptr<float>();
51 | const int *idx_n2 = idx_n2_tensor.data_ptr<int>();
52 | const int *idx_fetching = idx_fetching_tensor.data_ptr<int>();
53 | const int *random_hw = random_hw_tensor.data_ptr<int>();
54 | long *selected_b_idx = selected_b_idx_tensor.data_ptr<long>();
55 | long *selected_h_idx = selected_h_idx_tensor.data_ptr<long>();
56 | long *selected_w_idx = selected_w_idx_tensor.data_ptr<long>();
57 | float *selected_mask = selected_mask_tensor.data_ptr<float>();
58 | 
59 | 
60 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
61 | 
62 | FusedConvSelectKLauncher(batch_size, H, W, npoints, kernel_size_H, kernel_size_W, K, flag_copy, distance,
63 | stride_h, stride_w, xyz1, xyz2, idx_n2, idx_fetching, random_hw, selected_b_idx, selected_h_idx, selected_w_idx, selected_mask, small_h, small_w, stream);
64 | }
65 | 
66 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
67 | m.def("fused_conv_select_k",
68 | &torch_FusedConvSelectKLauncher,
69 | "torch_FusedConvSelectKLauncher kernel wrapper");
70 | }
-------------------------------------------------------------------------------- /ops_pytorch/fused_conv_select/fused_conv_go.cu: --------------------------------------------------------------------------------
1 | // input: kernel_size(h,w), stride_size(h,w), distance(float), flag_padding, xyz(b,H,W,3), bhw_idx(b,H,W,3)
2 | // output: selected_xyz(b, npoints, h*w, 3), selected_feature(b, npoints, h*w, 3)
3 | #include <cuda.h> // NOTE: include targets assumed; originals stripped in extraction
4 | #include <cuda_runtime.h>
5 | #include <stdlib.h> /* srand, rand */
6 | #include <time.h> /* time */
7 | #include <cstdlib> // Header file needed to use rand
8 | #include "fused_conv_gpu.h"
9 | 
10 | __global__ void fused_conv_select_k_gpu(int batch_size, int H, int W, int npoints, int kernel_size_H,
11 | int kernel_size_W, int K, int flag_copy, float distance, int stride_h, int stride_w, const float *xyz1,
12 | const float *xyz2, const int *idx_n2, const int *idx_fetching, const int *random_hw, long *selected_b_idx, long *selected_h_idx, long *selected_w_idx,
13 | float *selected_mask, int small_h, int small_w)
14 | {
15 | 
16 | int batch_index = blockIdx.x; // index of the current thread block
17 | int index_thread = threadIdx.x;
18 | int stride_thread = blockDim.x;
19 | 
20 | int kernel_total = kernel_size_H * kernel_size_W; // number of positions in one kernel window
21 | int selected_W_idx = 0, selected_H_idx = 0;
22 | int fetched_H_idx = 0, fetched_W_idx = 0;
23 | 
24 | float dist_square = distance * distance;
25 | 
26 | int kernel_half_H = kernel_size_H / 2;
27 | int kernel_half_W = kernel_size_W / 2;
28 | // with stride 1: small_h = H, small_w = W
29 | xyz1 += batch_index * H * W * 3; // point cloud of current image
30 | xyz2 += batch_index * small_h * small_w * 3;
31 | idx_n2 += batch_index * npoints * 2;
32 | idx_fetching += batch_index * npoints * 2; // 2d coordinates of central points
33 | selected_b_idx += batch_index * npoints * K * 1; //(b, npoints, k, 1), batch index of the K points selected around each central point
34 | selected_h_idx += batch_index * npoints * K * 1; //(b, npoints, k, 1),
35 | selected_w_idx += batch_index * npoints * K * 1; //(b, npoints, k, 1),
36 | 
37 | 
38 | selected_mask += batch_index * npoints * K * 1; //(b, npoints, k, 1), points whose coordinates and distances are both valid, including copied points (the nearest neighbor is reused)
39 | 
40 | ////////////// Fused
Conv Between 41 | const int MaxPoint = 960; 42 | for (int current_n = index_thread; current_n < npoints; current_n += stride_thread) // output_W circle 43 | { 44 | 45 | int idx_w[MaxPoint], idx_h[MaxPoint]; 46 | float Dist[MaxPoint]; 47 | 48 | for (int ii = 0; ii < MaxPoint; ++ii) 49 | { 50 | idx_w[ii] = 0; 51 | idx_h[ii] = 0; 52 | Dist[ii] = 1e10f; 53 | } 54 | 55 | int m_idx = 0; // mth point in each kernel 56 | int num_select = 0; // the number of selected points in each kernel 57 | // idx_n2: subsampled index 58 | selected_H_idx = idx_n2[current_n * 2 + 0]; // the central points H idx of input 2d frame 59 | selected_W_idx = idx_n2[current_n * 2 + 1]; // the central points W idx of input 2d frame 60 | 61 | float x_c = xyz1[selected_H_idx * W * 3 + selected_W_idx * 3 + 0]; 62 | float y_c = xyz1[selected_H_idx * W * 3 + selected_W_idx * 3 + 1]; 63 | float z_c = xyz1[selected_H_idx * W * 3 + selected_W_idx * 3 + 2]; 64 | 65 | float Dist_c = max((x_c - 0) * (x_c - 0) + (y_c - 0) * (y_c - 0) + (z_c - 0) * (z_c - 0), 1e-10f); 66 | 67 | // NOTE: the following is different 68 | if (Dist_c <= 1e-10f) // not valid central points of xyz1 69 | { 70 | continue; 71 | } 72 | 73 | fetched_H_idx = idx_fetching[current_n * 2 + 0]; 74 | fetched_W_idx = idx_fetching[current_n * 2 + 1]; 75 | // valid central points of xyz2 76 | if (fetched_H_idx == -1) 77 | { 78 | continue; 79 | } 80 | 81 | // compute the distance 82 | for (int current_HW_idx = 0; current_HW_idx < kernel_total; ++current_HW_idx) // select points in every kernel element 83 | { 84 | 85 | int kernel_HW_idx = random_hw[current_HW_idx]; 86 | 87 | int kernel_select_H_idx = fetched_H_idx / stride_h + kernel_HW_idx / kernel_size_W - kernel_half_H; // random select ??? 88 | int kernel_select_W_idx = fetched_W_idx / stride_w + kernel_HW_idx % kernel_size_W - kernel_half_W; // random select ??? 
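// Scanning the window in the shuffled order given by random_hw means that when
// more than K candidates pass the distance test, the K that are kept form a
// random subset of the window rather than always its top-left corner. Dividing
// fetched_H_idx / fetched_W_idx by the strides maps the central point from the
// dense H x W frame into the (possibly downsampled) small_h x small_w frame of xyz2.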
89 | 90 | if ((kernel_select_H_idx < 0) || (kernel_select_H_idx >= small_h) || (kernel_select_W_idx < 0) || (kernel_select_W_idx >= small_w)) // the region of padding points (not valid) 91 | { 92 | ++m_idx; 93 | continue; 94 | } 95 | 96 | // not the padding points 97 | 98 | float x_q = xyz2[kernel_select_H_idx * small_w * 3 + kernel_select_W_idx * 3 + 0]; 99 | float y_q = xyz2[kernel_select_H_idx * small_w * 3 + kernel_select_W_idx * 3 + 1]; 100 | float z_q = xyz2[kernel_select_H_idx * small_w * 3 + kernel_select_W_idx * 3 + 2]; 101 | 102 | float Dist_q_0 = x_q * x_q + y_q * y_q + z_q * z_q; 103 | 104 | if (Dist_q_0 <= 1e-10f) // not valid xyz2 points 105 | { 106 | ++m_idx; 107 | continue; 108 | } 109 | 110 | // valid xyz2 points, calculate the distance 111 | 112 | 113 | float Dist_q = max((x_c - x_q) * (x_c - x_q) + (y_c - y_q) * (y_c - y_q) + (z_c - z_q) * (z_c - z_q), 1e-10f); 114 | 115 | if (Dist_q > dist_square) // too far from the central points, regarding as not valid 116 | { 117 | ++m_idx; 118 | continue; 119 | } 120 | 121 | 122 | 123 | Dist[m_idx] = Dist_q; 124 | idx_h[m_idx] = kernel_select_H_idx; 125 | idx_w[m_idx] = kernel_select_W_idx; 126 | 127 | ++m_idx; 128 | ++num_select; 129 | 130 | if (num_select >= kernel_total) // search all position 131 | break; 132 | } 133 | 134 | //?int sort_num = 0; 135 | 136 | for (int s_idx = 0; s_idx < K; ++s_idx) // knn 137 | { 138 | int min_idx = s_idx; // min_idx idx 139 | 140 | // find the min_idx 141 | for (int t = s_idx + 1; t < kernel_total; ++t) 142 | { 143 | if (Dist[t] < Dist[min_idx]) 144 | { 145 | min_idx = t; 146 | } 147 | } 148 | 149 | // swap min_idx-th and i-th element 150 | if (min_idx != s_idx) 151 | { 152 | float tmp_dist = Dist[min_idx]; 153 | int tmp_idx_w = idx_w[min_idx]; 154 | int tmp_idx_h = idx_h[min_idx]; 155 | 156 | Dist[min_idx] = Dist[s_idx]; 157 | idx_w[min_idx] = idx_w[s_idx]; 158 | idx_h[min_idx] = idx_h[s_idx]; 159 | 160 | Dist[s_idx] = tmp_dist; 161 | idx_w[s_idx] = tmp_idx_w; 162 | idx_h[s_idx] = tmp_idx_h; 163 | } 164 | 165 | if ((flag_copy & 0x1) && (s_idx == 0)) // copy the first selected point in xyz2 for K times 166 | { 167 | for (int k_idx = 0; k_idx < K; ++k_idx) 168 | { 169 | 170 | selected_b_idx[current_n * K + k_idx] = batch_index; 171 | selected_h_idx[current_n * K + k_idx] = idx_h[s_idx]; 172 | selected_w_idx[current_n * K + k_idx] = idx_w[s_idx]; 173 | selected_mask[current_n * K * 1 + k_idx * 1 + 0] = 1.0; 174 | } 175 | 176 | } // copy done 177 | 178 | if (Dist[s_idx] < 1e10f) // whether this is a valid points or not 179 | { 180 | 181 | selected_b_idx[current_n * K + s_idx] = batch_index; 182 | selected_h_idx[current_n * K + s_idx] = idx_h[s_idx]; 183 | selected_w_idx[current_n * K + s_idx] = idx_w[s_idx]; 184 | selected_mask[current_n * K * 1 + s_idx * 1 + 0] = 1.0; 185 | } 186 | } 187 | } 188 | } 189 | 190 | void FusedConvSelectKLauncher(int batch_size, int H, int W, int npoints, int kernel_size_H, 191 | int kernel_size_W, int K, int flag_copy, float distance, int stride_h, int stride_w, 192 | const float *xyz1, const float *xyz2, const int *idx_n2, const int *idx_fetching, const int *random_hw, 193 | long *selected_b_idx, long *selected_h_idx, long *selected_w_idx, 194 | float *selected_mask, int small_h, int small_w, cudaStream_t stream) 195 | { 196 | 197 | cudaError_t err; 198 | 199 | fused_conv_select_k_gpu<<>>(batch_size, H, W, npoints, kernel_size_H, kernel_size_W, K, flag_copy, distance, stride_h, stride_w, xyz1, xyz2, idx_n2, idx_fetching, random_hw, selected_b_idx, selected_h_idx, 
selected_w_idx, selected_mask, small_h, small_w); 200 | 201 | // cudaDeviceSynchronize(); 202 | err = cudaGetLastError(); 203 | if (cudaSuccess != err) 204 | { 205 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 206 | exit(-1); 207 | } 208 | } 209 | -------------------------------------------------------------------------------- /ops_pytorch/fused_conv_select/fused_conv_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _FUSE_CONV_GPU_H_ 2 | #define _FUSE_CONV_GPU_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | void torch_FusedConvSelectKLauncher( 11 | torch::Tensor xyz_tensor, 12 | torch::Tensor xyz2_tensor, 13 | torch::Tensor idx_n2_tensor, 14 | torch::Tensor idx_fetching_tensor, 15 | torch::Tensor random_hw_tensor, 16 | int H, 17 | int W, 18 | int npoints, 19 | int kernel_size_H, 20 | int kernel_size_W, 21 | int K, 22 | bool flag_copy, 23 | float distance, 24 | int stride_h, 25 | int stride_w, 26 | torch::Tensor select_b_idx_tensor, 27 | torch::Tensor select_h_idx_tensor, 28 | torch::Tensor select_w_idx_tensor, 29 | torch::Tensor select_mask_tensor, 30 | int small_h, 31 | int small_w); 32 | 33 | void FusedConvSelectKLauncher( 34 | int batch_size, 35 | int H, 36 | int W, 37 | int npoints, 38 | int kernel_size_H, 39 | int kernel_size_W, 40 | int K, 41 | int flag_copy, 42 | float distance, 43 | int stride_h, 44 | int stride_w, 45 | const float *xyz1, 46 | const float *xyz2, 47 | const int *idx_n2, 48 | const int *idx_fetching, 49 | const int *random_hw, 50 | long *selected_b_idx, 51 | long *selected_h_idx, 52 | long *selected_w_idx, 53 | float *selected_mask, 54 | int small_h, 55 | int small_w, 56 | cudaStream_t stream); 57 | 58 | #endif -------------------------------------------------------------------------------- /ops_pytorch/fused_conv_select/fused_conv_select_k.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | import os 4 | import numpy as np 5 | import fused_conv_select_k_cuda as fused_conv_select_k_module 6 | 7 | 8 | def fused_conv_select_k(xyz1, xyz2, idx_n2, idx_fetching, random_hw, H, W, 9 | npoints, kernel_size_H, kernel_size_W, K, flag_copy, 10 | distance, stride_h, stride_w, select_b_idx, 11 | select_h_idx, select_w_idx, select_mask, small_h, small_w): 12 | ''' 13 | Input: 14 | xyz1:(b, h, w, 3) float, projected xyz1 points 15 | xyz2_feature:(b, h, w, c+3) float, projected xyz2 points with features 16 | idx_n2: (b, n, 2) int array, query idx of central points 17 | H, W : Input shape 18 | kernel_size_H, kernel_size_W: (size, size) int32 array, size 19 | k: the number of selected points (knn) 20 | distance: ( distance ) float distance 21 | flag_copy (bool) whether copy or not for the output points 22 | 23 | Output: 24 | space_weight:(batch_size, npoint, size*size , c) 25 | ''' 26 | 27 | fused_conv_select_k_module.fused_conv_select_k( 28 | xyz1, xyz2, idx_n2, idx_fetching, random_hw, H, W, npoints, 29 | kernel_size_H, kernel_size_W, K, flag_copy, distance, stride_h, 30 | stride_w, select_b_idx, select_h_idx, select_w_idx, select_mask, small_h, small_w) 31 | return select_b_idx, select_h_idx, select_w_idx, select_mask 32 | 33 | -------------------------------------------------------------------------------- /ops_pytorch/fused_conv_select/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import 
BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name="fused_conv_select_k", 6 | ext_modules=[ 7 | CUDAExtension( 8 | "fused_conv_select_k_cuda", 9 | ["fused_conv_g.cpp", "fused_conv_go.cu"], 10 | extra_compile_args={'cxx': ['-g'], 11 | 'nvcc': ['-O2']}) 12 | ], 13 | cmdclass={ 14 | "build_ext": BuildExtension 15 | } 16 | ) -------------------------------------------------------------------------------- /ops_pytorch/gpu_threenn_sample/no_sort_knn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | import os 4 | import numpy as np 5 | 6 | import no_sort_knn_cuda as no_sort_knn_module 7 | 8 | def no_sort_knn(xyz1, xyz2, idx_n2, random_hw, H, W, npoints, kernel_size_H, kernel_size_W, K, flag_copy, distance, stride_h, stride_w, select_b_idx, select_h_idx, select_w_idx, select_mask): 9 | 10 | no_sort_knn_module.no_sort_knn(xyz1, xyz2, idx_n2, random_hw, H, W, npoints, kernel_size_H, kernel_size_W, K, flag_copy, distance, stride_h, stride_w, select_b_idx, select_h_idx, select_w_idx, select_mask) 11 | return select_b_idx, select_h_idx, select_w_idx, select_mask 12 | 13 | -------------------------------------------------------------------------------- /ops_pytorch/gpu_threenn_sample/no_sort_knn_g.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "no_sort_knn_gpu.h" 6 | 7 | extern THCState *state; 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ") 10 | #define CHECK_INPUT(x) \ 11 | CHECK_CUDA(x); \ 12 | CHECK_CONTIGUOUS(x) 13 | 14 | void torch_NoSortKnnLauncher( 15 | torch::Tensor xyz1_tensor, 16 | torch::Tensor xyz2_tensor, 17 | torch::Tensor idx_n2_tensor, 18 | torch::Tensor random_hw_tensor, 19 | int H, 20 | int W, 21 | int npoints, 22 | int kernel_size_H, 23 | int kernel_size_W, 24 | int K, 25 | bool flag_copy, 26 | float distance, 27 | int stride_h, 28 | int stride_w, 29 | torch::Tensor selected_b_idx_tensor, 30 | torch::Tensor selected_h_idx_tensor, 31 | torch::Tensor selected_w_idx_tensor, 32 | torch::Tensor selected_mask_tensor) 33 | { 34 | CHECK_INPUT(xyz1_tensor); 35 | CHECK_INPUT(xyz2_tensor); 36 | CHECK_INPUT(idx_n2_tensor); 37 | CHECK_INPUT(random_hw_tensor); 38 | CHECK_INPUT(selected_b_idx_tensor); 39 | CHECK_INPUT(selected_h_idx_tensor); 40 | CHECK_INPUT(selected_w_idx_tensor); 41 | CHECK_INPUT(selected_mask_tensor); 42 | 43 | const auto batch_size = xyz1_tensor.size(0); 44 | const float *xyz1 = xyz1_tensor.data(); 45 | const float *xyz2 = xyz2_tensor.data(); 46 | const int *idx_n2 = idx_n2_tensor.data(); 47 | const int *random_hw = random_hw_tensor.data(); 48 | long *selected_b_idx = selected_b_idx_tensor.data(); 49 | long *selected_h_idx = selected_h_idx_tensor.data(); 50 | long *selected_w_idx = selected_w_idx_tensor.data(); 51 | float *selected_mask = selected_mask_tensor.data(); 52 | 53 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 54 | 55 | NoSortKnnLauncher(batch_size, H, W, npoints, kernel_size_H, kernel_size_W, K, flag_copy, distance, 56 | stride_h, stride_w, xyz1, xyz2, idx_n2, random_hw, selected_b_idx, selected_h_idx, selected_w_idx, selected_mask, stream); 57 | } 58 | 59 | 60 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 61 | m.def( 62 | "no_sort_knn", 63 | &torch_NoSortKnnLauncher, 64 | "torch_NoSortKnnLauncher kernel warpper" 65 | ); 66 | } 
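Usage note: the binding above allocates nothing itself -- every output tensor must be pre-allocated by the caller and is filled in place. Below is a minimal sketch of driving the Python wrapper; the output shapes follow the pointer arithmetic in the kernel that follows, while the cloud size, window size, dtypes and distance threshold are illustrative assumptions rather than values taken from this repository:

```python
import torch
from no_sort_knn import no_sort_knn  # wrapper defined above (import path assumed)

B, H, W, K = 2, 64, 128, 3           # the kernel keeps at most the 3 closest points
kH = kW = 5                          # search window size, assumed
npoints = H * W
stride_h = stride_w = 1

xyz1 = torch.randn(B, H, W, 3, device='cuda').contiguous()  # dense cloud
xyz2 = xyz1.clone()                                         # sparse cloud (same resolution here)

# (h, w) coordinates of every query point, flattened to [B, npoints, 2]
ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
idx_n2 = torch.stack([ys, xs], -1).view(1, -1, 2).repeat(B, 1, 1).int().cuda().contiguous()

# shuffled scan order over the kernel window
random_hw = torch.randperm(kH * kW, dtype=torch.int32, device='cuda')

# outputs, pre-allocated as the launcher requires (shapes follow the kernel's indexing)
sel_b = torch.zeros(B, npoints, K, 1, dtype=torch.long, device='cuda')
sel_h = torch.zeros_like(sel_b)
sel_w = torch.zeros_like(sel_b)
mask = torch.zeros(B, npoints, K, 1, dtype=torch.float32, device='cuda')

sel_b, sel_h, sel_w, mask = no_sort_knn(
    xyz1, xyz2, idx_n2, random_hw, H, W, npoints, kH, kW, K,
    True, 10.0, stride_h, stride_w, sel_b, sel_h, sel_w, mask)
```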
-------------------------------------------------------------------------------- /ops_pytorch/gpu_threenn_sample/no_sort_knn_go.cu: --------------------------------------------------------------------------------
1 | #include <cuda.h>
2 | #include <cuda_runtime.h>
3 | #include <stdlib.h> /* srand, rand */
4 | #include <time.h> /* time */
5 | #include <cstdlib> // Header file needed to use rand
6 | #include <math.h>
7 | #include <stdio.h>
8 | 
9 | struct special_idx {
10 |     int idx_h;
11 |     int idx_w;
12 | };
13 | 
14 | __global__ void no_sort_knn_gpu(int batch_size, int H, int W, int npoints, int kernel_size_H, int kernel_size_W, int K, int flag_copy, float distance, int stride_h, int stride_w, const float *xyz1, const float *xyz2, const int *idx_n2, const int *random_hw, long *selected_b_idx, long *selected_h_idx, long *selected_w_idx, float *selected_mask, int small_h, int small_w)
15 | {
16 |     // for each query point, select only the 3 closest points (no full sort)
17 |     int batch_index = blockIdx.x;
18 |     int index_thread = threadIdx.x;
19 |     int stride_thread = blockDim.x;
20 | 
21 |     int kernel_total = kernel_size_H * kernel_size_W;
22 |     int selected_W_idx = 0, selected_H_idx = 0;
23 | 
24 |     float dist_square = distance * distance;
25 | 
26 |     int kernel_half_H = kernel_size_H / 2;
27 |     int kernel_half_W = kernel_size_W / 2;
28 | 
29 |     // advance all pointers to this batch element
30 |     xyz1 += batch_index * H * W * 3;
31 |     xyz2 += batch_index * small_h * small_w * 3;
32 |     idx_n2 += batch_index * npoints * 2;
33 |     selected_b_idx += batch_index * npoints * K * 1;
34 |     selected_h_idx += batch_index * npoints * K * 1;
35 |     selected_w_idx += batch_index * npoints * K * 1;
36 |     selected_mask += batch_index * npoints * K * 1; // (b, npoints, K, 1)
37 | 
38 |     for (int dense_i = index_thread; dense_i < npoints; dense_i += stride_thread) {
39 |         // for each point in the dense point cloud, search the closest three points in the sparse one
40 |         selected_H_idx = idx_n2[dense_i * 2];
41 |         selected_W_idx = idx_n2[dense_i * 2 + 1];
42 |         int central_idx = selected_H_idx * W * 3 + selected_W_idx * 3;
43 |         float dense_x = xyz1[central_idx];
44 |         float dense_y = xyz1[central_idx + 1];
45 |         float dense_z = xyz1[central_idx + 2];
46 |         int num_select = 0;
47 | 
48 |         float dist_from_origin = dense_x * dense_x + dense_y * dense_y + dense_z * dense_z;
49 |         if (dist_from_origin <= 1e-10) {
50 |             continue; // the central point itself is invalid (all-zero padding)
51 |         }
52 |         float best1 = 1e30, best2 = 1e30, best3 = 1e30;
53 |         special_idx besti[3];
54 |         for (int i = 0; i < 3; ++i) {
55 |             besti[i].idx_h = besti[i].idx_w = 0;
56 |         }
57 | 
58 |         // once the central point in xyz1 is valid, scan the (shuffled) search window over xyz2
59 |         for (int current_HW_idx = 0; current_HW_idx < kernel_total; ++current_HW_idx) {
60 |             int kernel_hw_idx = random_hw[current_HW_idx];
61 |             int kernel_select_h_idx = selected_H_idx / stride_h + kernel_hw_idx / kernel_size_W - kernel_half_H;
62 |             int kernel_select_w_idx = selected_W_idx / stride_w + kernel_hw_idx % kernel_size_W - kernel_half_W;
63 | 
64 |             if ((kernel_select_h_idx < 0) || (kernel_select_w_idx < 0) || (kernel_select_h_idx >= small_h) || (kernel_select_w_idx >= small_w)) {
65 |                 continue; // outside the sparse map
66 |             }
67 |             int select_idx = kernel_select_h_idx * small_w * 3 + kernel_select_w_idx * 3;
68 |             float sparse_x = xyz2[select_idx];
69 |             float sparse_y = xyz2[select_idx + 1];
70 |             float sparse_z = xyz2[select_idx + 2];
71 | 
72 |             float queried_dist_from_center = sparse_x * sparse_x + sparse_y * sparse_y + sparse_z * sparse_z;
73 |             if (queried_dist_from_center <= 1e-10) {
74 |                 continue; // queried point is an invalid (all-zero) point
75 |             }
76 |             float dist_from_query = (dense_x - sparse_x) * (dense_x - sparse_x) + (dense_y - sparse_y) * (dense_y - sparse_y) + (dense_z - sparse_z) * (dense_z - sparse_z);
77 | 
78 |             if (dist_from_query < 1e-10) {
79 |                 // coincident point: treat it as the original point and make it the first neighbor
80 |                 best3 = best2;
81 |                 besti[2] = besti[1];
82 |                 best2 = best1;
83 |                 besti[1] = besti[0];
84 |                 best1 = dist_from_query;
85 |                 besti[0].idx_h = kernel_select_h_idx;
86 |                 besti[0].idx_w = kernel_select_w_idx;
87 |                 continue;
88 |             }
89 |             if (dist_from_query > dist_square) {
90 |                 continue; // outside the search radius
91 |             }
92 | 
93 |             ++num_select;
94 |             // keep the three smallest distances seen so far (insertion, no full sort)
95 |             if (dist_from_query < best1) {
96 |                 best3 = best2;
97 |                 besti[2] = besti[1];
98 |                 best2 = best1;
99 |                 besti[1] = besti[0];
100 |                 best1 = dist_from_query;
101 |                 besti[0].idx_h = kernel_select_h_idx;
102 |                 besti[0].idx_w = kernel_select_w_idx;
103 |             } else if (dist_from_query < best2) {
104 |                 best3 = best2;
105 |                 besti[2] = besti[1];
106 |                 best2 = dist_from_query;
107 |                 besti[1].idx_h = kernel_select_h_idx;
108 |                 besti[1].idx_w = kernel_select_w_idx;
109 |             } else if (dist_from_query < best3) {
110 |                 best3 = dist_from_query;
111 |                 besti[2].idx_h = kernel_select_h_idx;
112 |                 besti[2].idx_w = kernel_select_w_idx;
113 |             }
114 |         }
115 | 
116 |         int max_points = num_select < K ? num_select : K;
117 |         int temp;
118 |         for (int k = 0; k < max_points; ++k) {
119 |             temp = dense_i * K + k;
120 |             selected_b_idx[temp] = batch_index;
121 |             selected_h_idx[temp] = besti[k].idx_h;
122 |             selected_w_idx[temp] = besti[k].idx_w;
123 |             selected_mask[temp] = 1.0;
124 |         }
125 |         if (flag_copy) {
126 |             // pad the remaining slots by copying the closest neighbor
127 |             for (int k = max_points; k < K; ++k) {
128 |                 int temp = dense_i * K + k;
129 |                 selected_b_idx[temp] = batch_index;
130 |                 selected_h_idx[temp] = besti[0].idx_h;
131 |                 selected_w_idx[temp] = besti[0].idx_w;
132 |                 selected_mask[temp] = 1.0;
133 |             }
134 |         }
135 |     }
136 | }
137 | 
138 | 
139 | void NoSortKnnLauncher(int batch_size, int H, int W, int npoints, int kernel_size_H, int kernel_size_W, int K, int flag_copy, float distance, int stride_h, int stride_w, const float *xyz1, const float *xyz2, const int *idx_n2, const int *random_hw, long *selected_b_idx, long *selected_h_idx, long *selected_w_idx, float *selected_mask, cudaStream_t stream)
140 | {
141 |     // integer ceil division: ceil(H / stride_h) on two ints would truncate
142 |     // before the ceil is applied, so round up explicitly
143 |     int small_h = (H + stride_h - 1) / stride_h;
144 |     int small_w = (W + stride_w - 1) / stride_w;
145 |     cudaError_t err;
146 |     // one thread block per batch element, 256 threads striding over the query points
147 |     no_sort_knn_gpu<<<batch_size, 256, 0, stream>>>(batch_size, H, W, npoints, kernel_size_H, kernel_size_W, K, flag_copy, distance, stride_h, stride_w, xyz1, xyz2, idx_n2, random_hw, selected_b_idx, selected_h_idx, selected_w_idx, selected_mask, small_h, small_w);
148 | 
149 |     err = cudaGetLastError();
150 |     if (cudaSuccess != err) {
151 |         fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
152 |         exit(-1);
153 |     }
154 | }
155 | 
-------------------------------------------------------------------------------- /ops_pytorch/gpu_threenn_sample/no_sort_knn_gpu.h: --------------------------------------------------------------------------------
1 | #ifndef _NO_SORT_KNN_GPU_H
2 | #define _NO_SORT_KNN_GPU_H
3 | 
4 | #include <torch/extension.h>
5 | #include <ATen/cuda/CUDAContext.h>
6 | #include <cuda.h>
7 | #include <cuda_runtime.h>
8 | #include <vector>
9 | 
10 | void torch_NoSortKnnLauncher(
11 |     torch::Tensor xyz1_tensor,
12 |     torch::Tensor xyz2_tensor,
13 |     torch::Tensor idx_n2_tensor,
14 |     torch::Tensor random_hw_tensor,
15 |     int H,
16 |     int W,
17 |     int npoints,
18 |     int kernel_size_H,
19 |     int kernel_size_W,
20 |     int K,
21 |     bool flag_copy,
22 |     float distance,
23 |     int stride_h,
24 |     int stride_w,
25 |     torch::Tensor selected_b_idx_tensor,
26 |     torch::Tensor selected_h_idx_tensor,
27 |     torch::Tensor selected_w_idx_tensor,
28 |     torch::Tensor selected_mask_tensor);
29 | 
30 | void NoSortKnnLauncher(
31 |     int batch_size, int H, int W, int npoints, int kernel_size_H, int kernel_size_W, int K, int flag_copy, float distance, int stride_h, int stride_w, const float *xyz1, const float *xyz2, const int *idx_n2, const int *random_hw, long *selected_b_idx, long *selected_h_idx, long *selected_w_idx, float *selected_mask, cudaStream_t stream);
32 | 
33 | 
34 | 
35 | #endif
-------------------------------------------------------------------------------- /ops_pytorch/gpu_threenn_sample/setup.py: --------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 | 
4 | setup(
5 |     name="no_sort_knn",
6 |     ext_modules=[
7 |         CUDAExtension(
8 |             "no_sort_knn_cuda",
9 |             ["no_sort_knn_g.cpp",
"no_sort_knn_go.cu"], 10 | extra_compile_args={ 11 | "cxx": ['-g'], 12 | "nvcc": ['-O2'] 13 | }) 14 | 15 | ], 16 | cmdclass={ 17 | "build_ext": BuildExtension 18 | } 19 | ) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import sys 4 | import random 5 | import logging 6 | import torch 7 | import torch.optim 8 | import torch.multiprocessing as mp 9 | import torch.backends.cudnn as cudnn 10 | from datetime import datetime 11 | from omegaconf import DictConfig, OmegaConf 12 | from torch.nn.parallel import DistributedDataParallel 13 | from torch.distributed import init_process_group 14 | from torch.utils.data.distributed import DistributedSampler 15 | from dataset.flyingthings_subset import FlyingThings3D 16 | from dataset.kitti import KITTI 17 | from models import CamLiPWC 18 | from torch.utils.tensorboard import SummaryWriter 19 | from torch.cuda.amp.grad_scaler import GradScaler 20 | from utils import copy_to_device, build_optim_and_sched, FastDataLoader, init_log 21 | 22 | 23 | class Trainer: 24 | def __init__(self, device: torch.device, cfgs: DictConfig): 25 | os.environ["MKL_NUM_THREADS"] = "1" 26 | os.environ["OPENBLAS_NUM_THREADS"] = "1" 27 | os.environ["NCCL_IB_DISABLE"] = "1" 28 | os.environ["NCCL_P2P_DISABLE"] = "1" 29 | 30 | self.cfgs = cfgs 31 | self.curr_epoch = 1 32 | self.device = device 33 | self.n_gpus = torch.cuda.device_count() 34 | self.is_main = device.index is None or device.index == 0 35 | self.cfgs.log.dir = ( 36 | self.cfgs.log.dir 37 | + "/{}/".format(self.cfgs.model.name) 38 | + str(datetime.now().strftime("%Y-%m-%d_%H-%M")) 39 | ) 40 | os.makedirs(self.cfgs.log.dir, exist_ok=True) 41 | init_log(os.path.join(self.cfgs.log.dir, "train.log")) 42 | 43 | if device.index is None: 44 | logging.info("No CUDA device detected, using CPU for training") 45 | else: 46 | logging.info( 47 | "Using GPU %d: %s" % (device.index, torch.cuda.get_device_name(device)) 48 | ) 49 | logging.info( 50 | "PID:{}".format(os.getpid()) 51 | ) 52 | if self.n_gpus > 1: 53 | init_process_group( 54 | "nccl", 55 | "tcp://localhost:%d" % self.cfgs.port, 56 | world_size=self.n_gpus, 57 | rank=self.device.index, 58 | ) 59 | self.cfgs.model.batch_size = int( 60 | self.cfgs.model.batch_size / self.n_gpus 61 | ) 62 | self.cfgs.trainset.n_workers = int( 63 | self.cfgs.trainset.n_workers / self.n_gpus 64 | ) 65 | self.cfgs.valset.n_workers = int( 66 | self.cfgs.valset.n_workers / self.n_gpus 67 | ) 68 | 69 | cudnn.benchmark = False 70 | torch.cuda.set_device(self.device) 71 | 72 | if self.is_main: 73 | logging.info("Logs will be saved to %s" % self.cfgs.log.dir) 74 | self.summary_writer = SummaryWriter(self.cfgs.log.dir) 75 | logging.info("Configurations:\n" + OmegaConf.to_yaml(self.cfgs)) 76 | os.system("cp -r %s %s" % ("models", self.cfgs.log.dir)) 77 | os.system("cp -r %s %s" % ("dataset", self.cfgs.log.dir)) 78 | os.system("cp -r %s %s" % ("config", self.cfgs.log.dir)) 79 | os.system("cp %s %s" % ("train.py", self.cfgs.log.dir)) 80 | else: 81 | logging.root.disabled = True 82 | 83 | if self.cfgs.trainset.name == "flyingthings3d": 84 | self.train_dataset = FlyingThings3D(self.cfgs.trainset) 85 | self.val_dataset = FlyingThings3D(self.cfgs.valset) 86 | elif self.cfgs.trainset.name == "kitti": 87 | self.train_dataset = KITTI(self.cfgs.trainset) 88 | self.val_dataset = KITTI(self.cfgs.valset) 89 | else: 90 | raise NotImplementedError 91 | 92 | 
logging.info("Loading training set from %s" % self.cfgs.trainset.root_dir) 93 | self.train_sampler = ( 94 | DistributedSampler(self.train_dataset) if self.n_gpus > 1 else None 95 | ) 96 | logging.info("Loading validation set from %s" % self.cfgs.valset.root_dir) 97 | self.val_sampler = ( 98 | DistributedSampler(self.val_dataset) if self.n_gpus > 1 else None 99 | ) 100 | 101 | self.train_loader = FastDataLoader( 102 | dataset=self.train_dataset, 103 | batch_size=self.cfgs.model.batch_size, 104 | shuffle=(self.train_sampler is None), 105 | num_workers=self.cfgs.trainset.n_workers, 106 | pin_memory=True, 107 | sampler=self.train_sampler, 108 | drop_last=self.cfgs.trainset.drop_last, 109 | ) 110 | 111 | self.val_loader = FastDataLoader( 112 | dataset=self.val_dataset, 113 | batch_size=self.cfgs.model.batch_size, 114 | shuffle=False, 115 | num_workers=self.cfgs.valset.n_workers, 116 | pin_memory=True, 117 | sampler=self.val_sampler, 118 | ) 119 | 120 | logging.info("Creating model: %s" % self.cfgs.model.name) 121 | 122 | self.model = CamLiPWC(self.cfgs.model) 123 | self.model.to(device=self.device) 124 | 125 | n_params = sum([p.numel() for p in self.model.parameters() if p.requires_grad]) 126 | logging.info("Trainable parameters: %d (%.1fM)" % (n_params, n_params / 1e6)) 127 | 128 | if self.n_gpus > 1: 129 | if self.cfgs.sync_bn: 130 | self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model) 131 | self.ddp = DistributedDataParallel(self.model, [self.device.index]) 132 | else: 133 | self.ddp = self.model 134 | 135 | self.best_metrics = None 136 | if self.cfgs.ckpt.path is not None: 137 | self.load_ckpt(self.cfgs.ckpt.path, resume=self.cfgs.ckpt.resume) 138 | 139 | logging.info("Creating optimizer: %s" % self.cfgs.training.opt) 140 | self.optimizer, self.scheduler = build_optim_and_sched( 141 | self.cfgs.training, self.model 142 | ) 143 | self.scheduler.step(self.curr_epoch - 1) 144 | 145 | self.amp_scaler = GradScaler(enabled=self.cfgs.amp) 146 | 147 | def run(self): 148 | while self.curr_epoch <= self.cfgs.training.epochs: 149 | if self.train_sampler is not None: 150 | self.train_sampler.set_epoch(self.curr_epoch) 151 | if self.val_sampler is not None: 152 | self.val_sampler.set_epoch(self.curr_epoch) 153 | 154 | self.train_one_epoch() 155 | 156 | if self.curr_epoch % self.cfgs.val_interval == 0: 157 | self.validate() 158 | 159 | self.save_ckpt() 160 | self.scheduler.step(self.curr_epoch) 161 | 162 | self.curr_epoch += 1 163 | 164 | def train_one_epoch(self): 165 | logging.info("Start training...") 166 | 167 | self.ddp.train() 168 | self.model.clear_metrics() 169 | self.optimizer.zero_grad() 170 | 171 | lr = self.optimizer.param_groups[0]["lr"] 172 | self.save_scalar_summary({"learning_rate": lr}, prefix="train") 173 | 174 | logging.info("Epoch: [%d/%d]" % (self.curr_epoch, self.cfgs.training.epochs)) 175 | logging.info("Current learning rate: %.8f" % lr) 176 | 177 | for i, inputs in enumerate(self.train_loader): 178 | inputs = copy_to_device(inputs, self.device) 179 | 180 | # forward 181 | with torch.cuda.amp.autocast(enabled=self.cfgs.amp): 182 | self.ddp.forward(inputs) 183 | loss = self.model.get_loss() 184 | 185 | # backward 186 | self.amp_scaler.scale(loss).backward() 187 | 188 | # grad clip 189 | if "grad_max_norm" in self.cfgs.training.keys(): 190 | self.amp_scaler.unscale_(self.optimizer) 191 | torch.nn.utils.clip_grad_norm_( 192 | parameters=self.model.parameters(), 193 | max_norm=self.cfgs.training.grad_max_norm, 194 | ) 195 | 196 | # update 197 | 
self.amp_scaler.step(self.optimizer) 198 | self.amp_scaler.update() 199 | self.optimizer.zero_grad() 200 | 201 | metrics = self.model.get_metrics() 202 | self.save_scalar_summary(metrics, prefix="train") 203 | 204 | @torch.no_grad() 205 | def validate(self): 206 | logging.info("Start validating...") 207 | 208 | self.ddp.eval() 209 | self.model.clear_metrics() 210 | 211 | for i, inputs in enumerate(self.val_loader): 212 | inputs = copy_to_device(inputs, self.device) 213 | 214 | with torch.cuda.amp.autocast(enabled=False): 215 | self.ddp.forward(inputs) 216 | 217 | metrics = self.model.get_metrics() 218 | self.save_scalar_summary(metrics, prefix="val") 219 | 220 | for k, v in metrics.items(): 221 | logging.info("%s: %.4f" % (k, v)) 222 | 223 | if self.model.is_better(metrics, self.best_metrics): 224 | self.best_metrics = metrics 225 | self.save_ckpt("best.pt") 226 | 227 | def save_scalar_summary(self, scalar_summary: dict, prefix): 228 | if self.is_main and self.cfgs.log.save_scalar_summary: 229 | for name in scalar_summary.keys(): 230 | self.summary_writer.add_scalar( 231 | prefix + "/" + name, scalar_summary[name], self.curr_epoch 232 | ) 233 | 234 | def save_ckpt(self, filename=None): 235 | if self.is_main and self.cfgs.log.save_ckpt: 236 | ckpt_dir = os.path.join(self.cfgs.log.dir, "ckpts") 237 | os.makedirs(ckpt_dir, exist_ok=True) 238 | # filepath = os.path.join( 239 | # ckpt_dir, filename or "epoch-%03d.pt" % self.curr_epoch 240 | # ) 241 | filepath = os.path.join( 242 | ckpt_dir, filename or "epoch-latest.pt" 243 | ) 244 | logging.info("Saving checkpoint to %s" % filepath) 245 | torch.save( 246 | { 247 | "last_epoch": self.curr_epoch, 248 | "state_dict": self.model.state_dict(), 249 | "best_metrics": self.best_metrics, 250 | }, 251 | filepath, 252 | ) 253 | 254 | def load_ckpt(self, filepath, resume=True): 255 | logging.info("Loading checkpoint from %s" % filepath) 256 | checkpoint = torch.load(filepath, self.device) 257 | if resume: 258 | self.curr_epoch = checkpoint["last_epoch"] + 1 259 | self.best_metrics = checkpoint["best_metrics"] 260 | logging.info("Current best metrics: %s" % str(self.best_metrics)) 261 | # self.model.load_state_dict(checkpoint["state_dict"], strict=True) 262 | self.model.load_state_dict(checkpoint["state_dict"], strict=False) 263 | 264 | 265 | def create_trainer(device_id, cfgs): 266 | device = torch.device("cpu" if device_id is None else "cuda:%d" % device_id) 267 | trainer = Trainer(device, cfgs) 268 | trainer.run() 269 | 270 | 271 | def main(cfgs: DictConfig): 272 | # set num_workers of data loader 273 | if not cfgs.debug: 274 | n_devices = max(torch.cuda.device_count(), 1) 275 | cfgs.trainset.n_workers = min( 276 | os.cpu_count(), cfgs.trainset.n_workers * n_devices 277 | ) 278 | cfgs.valset.n_workers = min(os.cpu_count(), cfgs.valset.n_workers * n_devices) 279 | else: 280 | cfgs.trainset.n_workers = 0 281 | cfgs.valset.n_workers = 0 282 | 283 | if cfgs.port == "random": 284 | cfgs.port = random.randint(10000, 20000) 285 | 286 | if cfgs.training.accum_iter > 1: 287 | cfgs.model.batch_size //= int(cfgs.training.accum_iter) 288 | 289 | # create trainers 290 | if torch.cuda.device_count() == 0: # CPU 291 | create_trainer(None, cfgs) 292 | elif torch.cuda.device_count() == 1: # Single GPU 293 | create_trainer(0, cfgs) 294 | elif torch.cuda.device_count() > 1: # Multiple GPUs 295 | mp.spawn(create_trainer, (cfgs,), torch.cuda.device_count()) 296 | 297 | 298 | if __name__ == "__main__": 299 | path = sys.argv[1] 300 | with open(path, encoding="utf-8") as f: 
301 | cfgs = DictConfig(yaml.load(f, Loader=yaml.FullLoader)) 302 | main(cfgs) 303 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .average_meter import * 2 | from .build_utils import build_optim_and_sched 3 | from .train_utils import copy_to_device, copy_to_cuda 4 | from .log_utils import init_experiment_dir, init_logging, init_log 5 | 6 | import cv2 7 | import numpy as np 8 | import torch 9 | 10 | class _RepeatSampler(object): 11 | def __init__(self, sampler): 12 | self.sampler = sampler 13 | 14 | def __iter__(self): 15 | while True: 16 | yield from iter(self.sampler) 17 | 18 | 19 | class FastDataLoader(torch.utils.data.dataloader.DataLoader): 20 | def __init__(self, *args, **kwargs): 21 | super().__init__(*args, **kwargs) 22 | object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler)) 23 | self.iterator = super().__iter__() 24 | 25 | def __len__(self): 26 | return len(self.batch_sampler.sampler) 27 | 28 | def __iter__(self): 29 | for i in range(len(self)): 30 | yield next(self.iterator) 31 | 32 | def size_of_batch(inputs): 33 | if isinstance(inputs, list): 34 | return size_of_batch(inputs[0]) 35 | elif isinstance(inputs, dict): 36 | return size_of_batch(list(inputs.values())[0]) 37 | elif isinstance(inputs, torch.Tensor): 38 | return inputs.shape[0] 39 | else: 40 | raise TypeError('Unknown type: %s' % str(type(inputs))) 41 | 42 | def save_flow_png(filepath, flow, mask=None, scale=64.0): 43 | assert flow.shape[2] == 2 44 | assert np.abs(flow).max() < 32767.0 / scale 45 | flow = flow * scale 46 | flow = flow + 32768.0 47 | 48 | if mask is None: 49 | mask = np.ones_like(flow)[..., 0] 50 | else: 51 | mask = np.float32(mask > 0) 52 | 53 | flow_img = np.concatenate([ 54 | mask[..., None], 55 | flow[..., 1:2], 56 | flow[..., 0:1] 57 | ], axis=-1).astype(np.uint16) 58 | 59 | cv2.imwrite(filepath, flow_img) 60 | 61 | 62 | def load_disp_png(filepath): 63 | array = cv2.imread(filepath, -1) 64 | valid_mask = array > 0 65 | disp = array.astype(np.float32) / 256.0 66 | disp[np.logical_not(valid_mask)] = -1.0 67 | return disp, valid_mask 68 | 69 | 70 | def save_disp_png(filepath, disp, mask=None): 71 | if mask is None: 72 | mask = disp > 0 73 | disp = np.uint16(disp * 256.0) 74 | disp[np.logical_not(mask)] = 0 75 | cv2.imwrite(filepath, disp) 76 | 77 | def disp2pc(disp, baseline, f, cx, cy, flow=None): 78 | h, w = disp.shape 79 | depth = baseline * f / (disp + 1e-5) 80 | 81 | xx = np.tile(np.arange(w, dtype=np.float32)[None, :], (h, 1)) 82 | yy = np.tile(np.arange(h, dtype=np.float32)[:, None], (1, w)) 83 | 84 | if flow is None: 85 | x = (xx - cx) * depth / f 86 | y = (yy - cy) * depth / f 87 | else: 88 | x = (xx - cx + flow[..., 0]) * depth / f 89 | y = (yy - cy + flow[..., 1]) * depth / f 90 | 91 | pc = np.concatenate([ 92 | x[:, :, None], 93 | y[:, :, None], 94 | depth[:, :, None], 95 | ], axis=-1) 96 | 97 | return pc 98 | 99 | mesh_grid_cache = {} 100 | def mesh_grid(n, h, w, device, channel_first=True): 101 | global mesh_grid_cache 102 | str_id = '%d,%d,%d,%s,%s' % (n, h, w, device, channel_first) 103 | if str_id not in mesh_grid_cache: 104 | x_base = torch.arange(0, w, dtype=torch.float32, device=device)[None, None, :].expand(n, h, w) 105 | y_base = torch.arange(0, h, dtype=torch.float32, device=device)[None, None, :].expand(n, w, h) # NWH 106 | grid = torch.stack([x_base, y_base.transpose(1, 2)], 1) # B2HW 107 | if not channel_first: 
108 |             grid = grid.permute(0, 2, 3, 1)  # BHW2
109 |         mesh_grid_cache[str_id] = grid
110 |     return mesh_grid_cache[str_id]
-------------------------------------------------------------------------------- /utils/average_meter.py: --------------------------------------------------------------------------------
1 | 
2 | class AverageMeter(object):
3 |     """Computes and stores the average and current value"""
4 | 
5 |     def __init__(self):
6 |         self.reset()
7 | 
8 |     def reset(self):
9 |         self.val = 0
10 |         self.avg = 0
11 |         self.sum = 0
12 |         self.count = 0
13 | 
14 |     def update(self, val, n=1):
15 |         self.val = val
16 |         self.sum += val * n
17 |         self.count += n
18 |         self.avg = self.sum / self.count
19 | 
-------------------------------------------------------------------------------- /utils/build_utils.py: --------------------------------------------------------------------------------
1 | from torch.optim import Adam, AdamW
2 | from timm.scheduler import create_scheduler
3 | 
4 | def build_optim_and_sched(cfgs, model):
5 |     params_2d_decay = []
6 |     params_3d_decay = []
7 | 
8 |     params_2d_no_decay = []
9 |     params_3d_no_decay = []
10 | 
11 |     for name, param in model.named_parameters():
12 |         if not param.requires_grad:
13 |             continue  # frozen weights
14 | 
15 |         if len(param.shape) == 1 or name.endswith(".bias"):
16 |             if name.startswith('core.branch_3d'):
17 |                 params_3d_no_decay.append(param)
18 |             else:
19 |                 params_2d_no_decay.append(param)
20 |         else:
21 |             if name.startswith('core.branch_3d'):
22 |                 params_3d_decay.append(param)
23 |             else:
24 |                 params_2d_decay.append(param)
25 | 
26 |     lr = getattr(cfgs, 'lr', None)
27 |     lr_2d = getattr(cfgs, 'lr_2d', lr)
28 |     lr_3d = getattr(cfgs, 'lr_3d', lr)
29 | 
30 |     params = [
31 |         {'params': params_2d_decay, 'weight_decay': cfgs.weight_decay, 'lr': lr_2d},
32 |         {'params': params_3d_decay, 'weight_decay': cfgs.weight_decay, 'lr': lr_3d},
33 |         {'params': params_2d_no_decay, 'weight_decay': 0, 'lr': lr_2d},
34 |         {'params': params_3d_no_decay, 'weight_decay': 0, 'lr': lr_3d},
35 |     ]
36 | 
37 |     if cfgs.opt == 'adam':
38 |         optimizer = Adam(params)
39 |     elif cfgs.opt == 'adamw':
40 |         optimizer = AdamW(params)
41 |     else:
42 |         raise NotImplementedError
43 | 
44 |     scheduler = create_scheduler(cfgs, optimizer)[0]
45 | 
46 |     return optimizer, scheduler
47 | 
-------------------------------------------------------------------------------- /utils/evaluation_utils.py: --------------------------------------------------------------------------------
1 | """
2 | Evaluation metrics
3 | Borrowed from HPLFlowNet
4 | Date: May 2020
5 | 
6 | @inproceedings{HPLFlowNet,
7 |     title={HPLFlowNet: Hierarchical Permutohedral Lattice FlowNet for
8 |            Scene Flow Estimation on Large-scale Point Clouds},
9 |     author={Gu, Xiuye and Wang, Yijie and Wu, Chongruo and Lee, Yong Jae and Wang, Panqu},
10 |     booktitle={Computer Vision and Pattern Recognition (CVPR), 2019 IEEE International Conference on},
11 |     year={2019}
12 | }
13 | """
14 | 
15 | import numpy as np
16 | 
17 | 
18 | def evaluate_3d(sf_pred, sf_gt):
19 |     """
20 |     sf_pred: (N, 3)
21 |     sf_gt: (N, 3)
22 |     """
23 |     l2_norm = np.linalg.norm(sf_gt - sf_pred, axis=-1)
24 |     EPE3D = l2_norm.mean()
25 | 
26 |     sf_norm = np.linalg.norm(sf_gt, axis=-1)
27 |     relative_err = l2_norm / (sf_norm + 1e-4)
28 | 
29 |     acc3d_strict = (np.logical_or(l2_norm < 0.05, relative_err < 0.05)).astype(np.float32).mean()
30 |     acc3d_relax = (np.logical_or(l2_norm < 0.1, relative_err < 0.1)).astype(np.float32).mean()
31 |     outlier = (np.logical_or(l2_norm > 0.3, relative_err > 0.1)).astype(np.float32).mean()
32 | 
33 |     return EPE3D, acc3d_strict, acc3d_relax, outlier
34 | 
35 | 
36 | def evaluate_2d(flow_pred, flow_gt):
37 |     """
38 |     flow_pred: (N, 2)
39 |     flow_gt: (N, 2)
40 |     """
41 | 
42 |     epe2d = np.linalg.norm(flow_gt - flow_pred, axis=-1)
43 |     epe2d_mean = epe2d.mean()
44 | 
45 |     flow_gt_norm = np.linalg.norm(flow_gt, axis=-1)
46 |     relative_err = epe2d / (flow_gt_norm + 1e-5)
47 | 
48 |     acc2d = (np.logical_or(epe2d < 3., relative_err < 0.05)).astype(np.float32).mean()
49 | 
50 |     return epe2d_mean, acc2d
51 | 
-------------------------------------------------------------------------------- /utils/geometry.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | def get_batch_2d_flow(pc1, pc2, predicted_pc2, intrinsics):
5 |     focallengths = intrinsics[0][0]
6 |     cxs = intrinsics[0][2]
7 |     cys = intrinsics[0][3]
8 |     constx = 0
9 |     consty = 0
10 |     constz = 0
11 | 
12 |     px1, py1 = project_3d_to_2d(pc1, f=focallengths, cx=cxs, cy=cys,
13 |                                 constx=constx, consty=consty, constz=constz)
14 |     px2, py2 = project_3d_to_2d(predicted_pc2, f=focallengths, cx=cxs, cy=cys,
15 |                                 constx=constx, consty=consty, constz=constz)
16 |     px2_gt, py2_gt = project_3d_to_2d(pc2, f=focallengths, cx=cxs, cy=cys,
17 |                                       constx=constx, consty=consty, constz=constz)
18 | 
19 |     flow_x = px2 - px1
20 |     flow_y = py2 - py1
21 | 
22 |     flow_x_gt = px2_gt - px1
23 |     flow_y_gt = py2_gt - py1
24 | 
25 |     flow_pred = np.concatenate((flow_x[..., None], flow_y[..., None]), axis=-1)
26 |     flow_gt = np.concatenate((flow_x_gt[..., None], flow_y_gt[..., None]), axis=-1)
27 |     return flow_pred, flow_gt
28 | 
29 | 
30 | def project_3d_to_2d(pc, f=1050., cx=479.5, cy=269.5, constx=0, consty=0, constz=0):
31 |     x = (pc[..., 0] * f + cx * pc[..., 2] + constx) / (pc[..., 2] + constz + 1e-9)
32 |     y = (pc[..., 1] * f + cy * pc[..., 2] + consty) / (pc[..., 2] + constz + 1e-9)
33 | 
34 |     return x, y
35 | 
-------------------------------------------------------------------------------- /utils/log_utils.py: --------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import datetime
3 | import os
4 | import logging
5 | import sys
6 | 
7 | def init_experiment_dir(cfg):
8 |     experiment_dir = Path(cfg.log.dir)
9 |     experiment_dir.mkdir(exist_ok=True)
10 |     file_dir = Path(
11 |         str(experiment_dir)
12 |         + "/Flyingthings3d-"
13 |         + str(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M"))
14 |     )
15 |     file_dir.mkdir(exist_ok=True)
16 |     checkpoints_dir = file_dir.joinpath(cfg.ckpt.save_path)
17 |     checkpoints_dir.mkdir(exist_ok=True)
18 |     log_dir = file_dir.joinpath("logs/")
19 |     log_dir.mkdir(exist_ok=True)
20 | 
21 |     os.system("cp -r %s %s" % ("models", log_dir))
22 |     os.system("cp -r %s %s" % ("dataset", log_dir))
23 |     os.system("cp -r %s %s" % ("config", log_dir))
24 |     os.system("cp %s %s" % ("train.py", log_dir))
25 | 
26 |     return checkpoints_dir, log_dir
27 | 
28 | def init_logging(log_dir, cfgs):
29 |     logger = logging.getLogger(cfgs.model.name)
30 |     formatter = logging.Formatter("[%(asctime)s][%(levelname)s] - %(message)s")
31 |     file_handler = logging.FileHandler(str(log_dir) + "/log_%s.log" % cfgs.model.name)
32 |     file_handler.setFormatter(formatter)
33 |     stream_handler = logging.StreamHandler()
34 |     stream_handler.setFormatter(formatter)
35 |     logger.setLevel(logging.INFO)
36 |     logger.addHandler(file_handler)
37 |     logger.addHandler(stream_handler)
38 |     return logger
39 | 
40 | def init_log(filename=None, debug=False):
41 |     logging.root = logging.RootLogger('DEBUG' if debug else 'INFO')
42 |     formatter = logging.Formatter('[%(asctime)s][%(levelname)s] - %(message)s')
43 | 
44 |     stream_handler = logging.StreamHandler(sys.stdout)
45 |     stream_handler.setFormatter(formatter)
46 |     logging.root.addHandler(stream_handler)
47 | 
48 |     if filename is not None:
49 |         file_handler = logging.FileHandler(filename)
50 |         file_handler.setFormatter(formatter)
51 |         logging.root.addHandler(file_handler)
-------------------------------------------------------------------------------- /utils/train_utils.py: --------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def copy_to_device(inputs, device, non_blocking=True):
5 |     if isinstance(inputs, list):
6 |         inputs = [copy_to_device(item, device, non_blocking) for item in inputs]
7 |     elif isinstance(inputs, dict):
8 |         inputs = {k: copy_to_device(v, device, non_blocking) for k, v in inputs.items()}
9 |     elif isinstance(inputs, torch.Tensor):
10 |         inputs = inputs.to(device=device, non_blocking=non_blocking)
11 |     else:
12 |         raise TypeError("Unknown type: %s" % str(type(inputs)))
13 |     return inputs
14 | 
15 | 
16 | def copy_to_cuda(inputs, non_blocking=True):
17 |     assert isinstance(inputs, dict), "copy_to_cuda expects a dict of tensors"
18 |     # pass non_blocking as a keyword argument: the first positional argument
19 |     # of Tensor.cuda() is the target device, not the non_blocking flag
20 |     inputs = {k: v.cuda(non_blocking=non_blocking) for k, v in inputs.items()}
21 |     return inputs
22 | 
--------------------------------------------------------------------------------
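Usage note (a sketch based on the entry points above, not on repository documentation): each CUDA extension is a standard PyTorch C++/CUDA extension and is built from its own directory with its setup.py, e.g.

cd ops_pytorch/gpu_threenn_sample && python setup.py install

Training is then launched by passing the path of a YAML configuration file as the single command-line argument to train.py, which is what its __main__ block parses:

python train.py config/config_things.yaml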