├── .gitignore
├── README.md
├── assets
│   ├── ft_visualize.png
│   ├── represent.png
│   ├── res.png
│   └── structure_overview.png
├── checkpoints
│   └── ckpt_best.pt
├── config
│   ├── config_kitti.yaml
│   ├── config_pretrain.yaml
│   └── config_things.yaml
├── dataset
│   ├── __init__.py
│   ├── augmentation.py
│   ├── flyingthings_subset.py
│   ├── kitti.py
│   └── preprocess_data.py
├── env.yaml
├── models
│   ├── __init__.py
│   ├── base.py
│   ├── camlipwc.py
│   ├── camlipwc2d_core.py
│   ├── camlipwc3d_core.py
│   ├── camlipwc_core.py
│   ├── csrc
│   │   ├── __init__.py
│   │   ├── correlation
│   │   │   ├── correlation.cpp
│   │   │   ├── correlation.h
│   │   │   ├── correlation_backward_kernel.cu
│   │   │   ├── correlation_forward_kernel.cu
│   │   │   └── correlation_test.cpp
│   │   ├── furthest_point_sampling
│   │   │   ├── furthest_point_sampling.cpp
│   │   │   ├── furthest_point_sampling.h
│   │   │   ├── furthest_point_sampling_kernel.cu
│   │   │   └── furthest_point_sampling_test.cpp
│   │   ├── k_nearest_neighbor
│   │   │   ├── k_nearest_neighbor.cpp
│   │   │   ├── k_nearest_neighbor.h
│   │   │   ├── k_nearest_neighbor_kernel.cu
│   │   │   └── k_nearest_neighbor_test.cpp
│   │   ├── setup.py
│   │   └── wrapper.py
│   ├── fusion_module.py
│   ├── losses2d.py
│   ├── losses3d.py
│   ├── pointconv.py
│   └── utils.py
├── ops_pytorch
│   ├── __init__.py
│   ├── fused_conv_select
│   │   ├── fused_conv_g.cpp
│   │   ├── fused_conv_go.cu
│   │   ├── fused_conv_gpu.h
│   │   ├── fused_conv_select_k.py
│   │   └── setup.py
│   └── gpu_threenn_sample
│       ├── no_sort_knn.py
│       ├── no_sort_knn_g.cpp
│       ├── no_sort_knn_go.cu
│       ├── no_sort_knn_gpu.h
│       └── setup.py
├── train.py
└── utils
    ├── __init__.py
    ├── average_meter.py
    ├── build_utils.py
    ├── evaluation_utils.py
    ├── geometry.py
    ├── log_utils.py
    └── train_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.zip
2 | output/*
3 | .vscode
4 | # 3rdparty
5 | Datasets
6 | SUMMARY
7 | *.out
8 | # Byte-compiled / optimized / DLL files
9 | __pycache__/
10 | *.py[cod]
11 | *$py.class
12 | *.log
13 | # C extensions
14 | *.so
15 | *.ply
16 | *.npy
17 | *.log
18 | # Distribution / packaging
19 | .Python
20 | build/
21 | exps*/
22 | ckpts/*
23 | develop-eggs/
24 | dist/
25 | downloads/
26 | eggs/
27 | .eggs/
28 | lib/
29 | lib64/
30 | parts/
31 | sdist/
32 | var/
33 | wheels/
34 | pip-wheel-metadata/
35 | share/python-wheels/
36 | *.egg-info/
37 | .installed.cfg
38 | *.egg
39 | MANIFEST
40 | output/
41 | # PyInstaller
42 | # Usually these files are written by a python script from a template
43 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
44 | *.manifest
45 | *.spec
46 |
47 | # Installer logs
48 | pip-log.txt
49 | pip-delete-this-directory.txt
50 |
51 | # Unit test / coverage reports
52 | htmlcov/
53 | .tox/
54 | .nox/
55 | .coverage
56 | .coverage.*
57 | .cache
58 | nosetests.xml
59 | coverage.xml
60 | *.cover
61 | *.py,cover
62 | .hypothesis/
63 | .pytest_cache/
64 |
65 | # Translations
66 | *.mo
67 | *.pot
68 |
69 | # Django stuff:
70 | *.log
71 | local_settings.py
72 | db.sqlite3
73 | db.sqlite3-journal
74 |
75 | # Flask stuff:
76 | instance/
77 | .webassets-cache
78 |
79 | # Scrapy stuff:
80 | .scrapy
81 |
82 | # Sphinx documentation
83 | docs/_build/
84 |
85 | # PyBuilder
86 | target/
87 |
88 | # Jupyter Notebook
89 | .ipynb_checkpoints
90 |
91 | # IPython
92 | profile_default/
93 | ipython_config.py
94 |
95 | # pyenv
96 | .python-version
97 |
98 | # pipenv
99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
102 | # install all needed dependencies.
103 | #Pipfile.lock
104 |
105 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
106 | __pypackages__/
107 |
108 | # Celery stuff
109 | celerybeat-schedule
110 | celerybeat.pid
111 |
112 | # SageMath parsed files
113 | *.sage.py
114 |
115 | # Environments
116 | .env
117 | .venv
118 | env/
119 | venv/
120 | ENV/
121 | env.bak/
122 | venv.bak/
123 |
124 | # Spyder project settings
125 | .spyderproject
126 | .spyproject
127 |
128 | # Rope project settings
129 | .ropeproject
130 |
131 | # mkdocs documentation
132 | /site
133 |
134 | # mypy
135 | .mypy_cache/
136 | .dmypy.json
137 | dmypy.json
138 |
139 | # Pyre type checker
140 | .pyre/
141 | *.ppt
142 | *.pptx
143 | *.log
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DELFlow: Dense Efficient Learning of Scene Flow for Large-Scale Point Clouds
2 | 
11 | ## :bookmark_tabs: Table of Contents
12 |
13 | 1. [Introduction](#clapper-introduction)
14 | 2. [Installation](#memo-installation)
15 | 3. [Dataset](#file_folder-dataset)
16 | 4. [Usage](#computer-usage)
17 | 5. [Qualitative Results](#art-visualization)
18 |
19 |
20 |
21 |
22 | ## :clapper: Introduction
23 |
24 | Point clouds are naturally sparse, while image pixels are dense. This inconsistency limits feature fusion between the two modalities for point-wise scene flow estimation. Moreover, previous methods rarely predict scene flow for the entire point cloud of a scene in a single inference pass: the farthest point sampling, KNN, and ball-query algorithms commonly used for local feature aggregation are memory-inefficient and incur heavy overhead from distance computation and sorting.
25 |
26 | To mitigate these issues in scene flow learning, we regularize raw points into a dense format by storing 3D coordinates in 2D grids. Unlike the sampling operations commonly used in existing works, the dense 2D representation
27 |
28 | - preserves most points in the given scene,
29 | - brings a significant boost in efficiency, and
30 | - eliminates the density gap between points and pixels, allowing us to perform effective feature fusion.
31 |
32 | We also present a novel warping projection technique that alleviates the information loss caused by multiple points being mapped to the same grid cell during projection when computing the cost volume. Extensive experiments demonstrate the efficiency and effectiveness of our method, which outperforms prior art on the FlyingThings3D and KITTI datasets.
33 |
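As a rough illustration of the dense representation, the sketch below (a minimal example, not the repository's exact code) projects each point with the pinhole model and scatters its 3D coordinates into an `H x W x 3` grid; the intrinsics `f`, `cx`, `cy` are the FlyingThings3D values used in `dataset/flyingthings_subset.py`:

```python
import numpy as np

def points_to_grid(pc, H, W, f=1050.0, cx=479.5, cy=269.5):
    """Scatter an (N, 3) point cloud into a dense (H, W, 3) grid via pinhole projection."""
    x, y, z = pc[:, 0], pc[:, 1], pc[:, 2]
    u = np.round(f * x / z + cx).astype(np.int64)  # column index
    v = np.round(f * y / z + cy).astype(np.int64)  # row index
    keep = (z > 0) & (u >= 0) & (u < W) & (v >= 0) & (v < H)
    grid = np.zeros((H, W, 3), dtype=np.float32)   # empty cells stay zero and are masked downstream
    grid[v[keep], u[keep]] = pc[keep]              # colliding points: only one survives per cell
    return grid
```

When several points fall into the same cell, only one of them survives; this is exactly the information loss that the warping projection technique targets when computing the cost volume.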
34 |
35 |
36 |
37 |
38 |
39 |
40 | For more details, please refer to our [paper](https://openaccess.thecvf.com/content/ICCV2023/html/Peng_DELFlow_Dense_Efficient_Learning_of_Scene_Flow_for_Large-Scale_Point_ICCV_2023_paper.html) or the [arXiv version](https://arxiv.org/abs/2308.04383).
41 |
42 |
43 |
44 | ## :memo: Installation
45 |
46 | Create a PyTorch environment using `conda`:
47 | ```bash
48 | git clone https://github.com/IRMVLab/DELFlow.git
49 | cd DELFlow
50 | conda env create -f env.yaml
51 | conda activate delflow
52 | ```
53 |
54 | Compile the CUDA extensions (run each block from the repository root):
55 | ```bash
56 | cd models/csrc
57 | python setup.py build_ext --inplace
58 | cd ../..
59 | cd ops_pytorch/fused_conv_select
60 | python setup.py install
61 | cd ../..
62 | cd ops_pytorch/gpu_threenn_sample
63 | python setup.py install
64 | ```
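As an optional sanity check, confirm that the environment's PyTorch sees your GPU (the extensions above are compiled against the same CUDA toolkit):

```bash
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```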
65 |
66 |
67 |
68 | ## :file_folder: Dataset
69 |
70 | ### FlyingThings3D
71 |
72 | First, download the [FlyingThings3D subset](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) dataset, then preprocess it as follows (you may need to change `input_dir` and `output_dir`):
73 |
74 | ```bash
75 | python dataset/preprocess_data.py
76 | ```
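The paths can also be passed as command-line flags instead of editing the defaults (all flags, including the optional `--remove_occluded_points`, are defined at the top of `dataset/preprocess_data.py`):

```bash
python dataset/preprocess_data.py \
    --input_dir /data/FlyingThings3D_subset \
    --output_dir /data/processed_flyingthings3d \
    --max_depth 35.0
```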
77 |
78 | ### KITTI
79 |
80 | First, download the following parts:
81 |
82 | - Main data: [data_scene_flow.zip](https://s3.eu-central-1.amazonaws.com/avg-kitti/data_scene_flow.zip)
83 | - Calibration files: [data_scene_flow_calib.zip](https://s3.eu-central-1.amazonaws.com/avg-kitti/data_scene_flow_calib.zip)
84 | - Disparity estimation (from GA-Net): [disp_ganet.zip](https://drive.google.com/file/d/1ieFpOVzqCzT8TXNk1zm2d9RLkrcaI78o/view?usp=sharing)
85 | - Semantic segmentation (from DDR-Net): [semantic_ddr.zip](https://drive.google.com/file/d/1dVSJeE9BBmVv2rCe5TR0PVanEv2WzwIy/view?usp=sharing)
86 |
87 |
88 |
89 | Unzip them and organize the directory as follows:
90 |
91 |
92 | ```
93 | dataset/kitti_scene_flow
94 | ├── testing
95 | │   ├── calib_cam_to_cam
96 | │   ├── calib_imu_to_velo
97 | │   ├── calib_velo_to_cam
98 | │   ├── disp_ganet
99 | │   ├── flow_occ
100 | │   ├── image_2
101 | │   ├── image_3
102 | │   └── semantic_ddr
103 | └── training
104 |     ├── calib_cam_to_cam
105 |     ├── calib_imu_to_velo
106 |     ├── calib_velo_to_cam
107 |     ├── disp_ganet
108 |     ├── disp_occ_0
109 |     ├── disp_occ_1
110 |     ├── flow_occ
111 |     ├── image_2
112 |     ├── image_3
113 |     ├── obj_map
114 |     └── semantic_ddr
115 | ```
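After unzipping, point `trainset.root_dir` and `valset.root_dir` in `config/config_kitti.yaml` at this `kitti_scene_flow` directory; the loader in `dataset/kitti.py` appends `training/` itself.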
116 |
117 |
118 |
119 | ## :computer: Usage
120 |
121 | Run the code using the following command:
122 |
123 | ```bash
124 | python train.py config/<config_name>.yaml
125 | ```
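For example, with the provided configs:

```bash
python train.py config/config_pretrain.yaml  # FlyingThings3D, pretraining schedule
python train.py config/config_things.yaml    # FlyingThings3D, full data
python train.py config/config_kitti.yaml     # KITTI
```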
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 | ## :book: Citation
136 |
137 | If you use this codebase or model in your research, please cite:
138 |
139 | ```
140 | @inproceedings{peng2023delflow,
141 | title={Delflow: Dense efficient learning of scene flow for large-scale point clouds},
142 | author={Peng, Chensheng and Wang, Guangming and Lo, Xian Wan and Wu, Xinrui and Xu, Chenfeng and Tomizuka, Masayoshi and Zhan, Wei and Wang, Hesheng},
143 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
144 | pages={16901--16910},
145 | year={2023}
146 | }
147 | ```
148 |
149 |
150 |
151 | ## :art: Visualization
152 |
153 | In this section, we present qualitative examples that demonstrate the effectiveness of our approach.
154 |
155 |
156 | ![Qualitative results](assets/ft_visualize.png)
157 | 
163 | ## :pray: Acknowledgements
164 |
165 | This code benefits a lot from [CamLiFlow](https://github.com/MCG-NJU/CamLiFlow). Thanks for making the code publicly available.
166 |
--------------------------------------------------------------------------------
/assets/ft_visualize.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IRMVLab/DELFlow/3e12fdec9fd26ace6f42436497ceb236656c6a23/assets/ft_visualize.png
--------------------------------------------------------------------------------
/assets/represent.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IRMVLab/DELFlow/3e12fdec9fd26ace6f42436497ceb236656c6a23/assets/represent.png
--------------------------------------------------------------------------------
/assets/res.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IRMVLab/DELFlow/3e12fdec9fd26ace6f42436497ceb236656c6a23/assets/res.png
--------------------------------------------------------------------------------
/assets/structure_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IRMVLab/DELFlow/3e12fdec9fd26ace6f42436497ceb236656c6a23/assets/structure_overview.png
--------------------------------------------------------------------------------
/checkpoints/ckpt_best.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IRMVLab/DELFlow/3e12fdec9fd26ace6f42436497ceb236656c6a23/checkpoints/ckpt_best.pt
--------------------------------------------------------------------------------
/config/config_kitti.yaml:
--------------------------------------------------------------------------------
1 | trainset:
2 | name: kitti
3 | root_dir: /data/KITTI_SCENE
4 |   split: training160
5 | n_workers: 8
6 | max_depth: 90
7 | drop_last: true
8 |
9 | valset:
10 | name: kitti
11 | root_dir: /data/KITTI_SCENE
12 | split: training40
13 | n_workers: 8
14 | max_depth: 90
15 |
16 | model:
17 | name: kitti
18 | height: 320
19 | width: 1280
20 | stride_h: [2, 2, 2, 2, 2, 2]
21 | stride_w: [2, 2, 2, 2, 2, 2]
22 | freeze_bn: false
23 | batch_size: 12
24 |
25 | pwc2d:
26 | norm:
27 | feature_pyramid: batch_norm
28 | flow_estimator: null
29 | context_network: null
30 | max_displacement: 4
31 |
32 | pwc3d:
33 | norm:
34 | feature_pyramid: batch_norm
35 | correlation: null
36 | flow_estimator: null
37 | k: 16
38 | kernel_size: [10, 30]
39 |
40 | loss2d:
41 | level_weights: [16, 8, 4, 2, 1]
42 | order: robust
43 |
44 | loss3d:
45 | level_weights: [8, 4, 2, 1, 0.5]
46 | order: robust
47 |
48 |
49 | training:
50 | opt: adamw
51 | momentum: 0.9
52 | weight_decay: 0.000001
53 | grad_max_norm: 25
54 | accum_iter: 1
55 |
56 | sched: cosine
57 | epochs: 1200
58 |   lr_2d: 0.0005 # LR for 2D branch
59 |   lr_3d: 0.001 # LR for 3D branch
62 | min_lr: 0.00001
63 | warmup_lr: 0.00001
64 | warmup_epochs: 0
65 | cooldown_epochs: 0
66 |
67 | log:
68 | dir: ./experiment
69 | run_name: kitti
70 | save_ckpt: true
71 | save_scalar_summary: true
72 | save_ckpt_every_n_epochs: 1
73 | save_summary_every_n_steps: 100
74 |
75 | ckpt:
76 | path: null
77 | save_path: ./checkpoints
78 | resume: false
79 |
80 | val_interval: 4
81 | gpu: 1
82 | port: random # for multi-gpu training
83 | amp: false
84 | debug: false
85 | sync_bn: true
86 |
--------------------------------------------------------------------------------
/config/config_pretrain.yaml:
--------------------------------------------------------------------------------
1 | trainset:
2 | name: flyingthings3d
3 | root_dir: /data/processed_flyingthings3d
4 | split: train
5 | n_workers: 4
6 | drop_last: true
7 | full: false
8 |
9 | augmentation:
10 | enabled: false
11 | color_jitter:
12 | enabled: false
13 | random_horizontal_flip:
14 | enabled: false
15 | random_vertical_flip:
16 | enabled: false
17 | random_crop:
18 | enabled: false
19 | random_scale:
20 | enabled: false
21 |
22 | valset:
23 | name: flyingthings3d
24 | root_dir: /data/processed_flyingthings3d
25 | split: val
26 | n_workers: 4
27 | full: false
28 |
29 | augmentation:
30 | enabled: false
31 |
32 | model:
33 | name: pretrain
34 | height: 256
35 | width: 448
36 | stride_h: [2, 2, 2, 2, 2, 2]
37 | stride_w: [2, 2, 2, 2, 2, 2]
38 | freeze_bn: false
39 | batch_size: 64
40 |
41 | pwc2d:
42 | norm:
43 | feature_pyramid: batch_norm
44 | flow_estimator: null
45 | context_network: null
46 | max_displacement: 4
47 |
48 | pwc3d:
49 | norm:
50 | feature_pyramid: batch_norm
51 | correlation: null
52 | flow_estimator: null
53 | k: 16
54 | kernel_size: [10, 20]
55 |
56 | loss2d:
57 | level_weights: [8, 4, 2, 1, 0.5]
58 | order: l2
59 |
60 | loss3d:
61 | level_weights: [8, 4, 2, 1, 0.5]
62 | order: l2
63 |
64 |
65 | training:
66 | opt: adamw
67 | momentum: 0.9
68 | weight_decay: 0.00001
69 | grad_max_norm: 20
70 | accum_iter: 1
71 |
72 | sched: cosine
73 | epochs: 800
74 | lr_2d: 0.0005 # LR for 2D branch
75 | lr_3d: 0.001 # LR for 3D branch
76 | min_lr: 0.00001
77 | warmup_lr: 0.00001
78 | warmup_epochs: 2
79 | cooldown_epochs: 0
80 |
81 | log:
82 | dir: ./experiment
83 | run_name: pretrain
84 | save_ckpt: true
85 | save_scalar_summary: true
86 | save_ckpt_every_n_epochs: 1
87 | save_summary_every_n_steps: 100
88 |
89 | ckpt:
90 | path: null
91 | save_path: ./checkpoints
92 | resume: false
93 |
94 | val_interval: 10
95 | gpu: 1
96 | port: random # for multi-gpu training
97 | amp: false
98 | debug: false
99 | sync_bn: true
100 |
--------------------------------------------------------------------------------
/config/config_things.yaml:
--------------------------------------------------------------------------------
1 | trainset:
2 | name: flyingthings3d
3 | root_dir: /data/processed_flyingthings3d
4 | split: train
5 | n_workers: 4
6 | drop_last: true
7 | full: true
8 |
9 | augmentation:
10 | enabled: true
11 | color_jitter:
12 | enabled: true
13 | brightness: 0.3
14 | contrast: 0.3
15 | saturation: 0.3
16 | hue: 0.159 # 0.5/3.14
17 | random_horizontal_flip:
18 | enabled: false
19 | random_vertical_flip:
20 | enabled: false
21 | random_crop:
22 | enabled: false
23 | crop_size: [640, 384] # [640, 384]
24 | random_scale:
25 | enabled: false
26 | random_down:
27 | enabled: true
28 | type: train
29 |
30 | valset:
31 | name: flyingthings3d
32 | root_dir: /data/processed_flyingthings3d
33 | split: val
34 | n_workers: 4
35 | full: true
36 |
37 | augmentation:
38 | enabled: false
39 | color_jitter:
40 | enabled: false
41 | random_horizontal_flip:
42 | enabled: false
43 | random_vertical_flip:
44 | enabled: false
45 | random_crop:
46 | enabled: false
47 | crop_size: [640, 384] # [640, 384]
48 | random_scale:
49 | enabled: false
50 | random_down:
51 | enabled: true
52 | type: eval
53 |
54 | model:
55 | name: flyingthings
56 | height: 256
57 | width: 448
58 | stride_h: [2, 2, 2, 2, 2, 2]
59 | stride_w: [2, 2, 2, 2, 2, 2]
60 | freeze_bn: false
61 | batch_size: 64
62 |
63 | pwc2d:
64 | norm:
65 | feature_pyramid: batch_norm
66 | flow_estimator: null
67 | context_network: null
68 | max_displacement: 4
69 |
70 | pwc3d:
71 | norm:
72 | feature_pyramid: batch_norm
73 | correlation: null
74 | flow_estimator: null
75 | k: 16
76 | kernel_size: [10, 20]
77 |
78 | loss2d:
79 | level_weights: [8, 4, 2, 1, 0.5]
80 | order: robust
81 |
82 | loss3d:
83 | level_weights: [8, 4, 2, 1, 0.5]
84 | order: robust
85 |
86 |
87 | training:
88 | opt: adamw
89 | momentum: 0.9
90 | weight_decay: 0.000001
91 | grad_max_norm: 20
92 | accum_iter: 1
93 |
94 | sched: cosine
95 | epochs: 400
96 | lr_2d: 0.00005 # LR for 2D branch
97 | lr_3d: 0.0001 # LR for 3D branch
98 | min_lr: 0.00001
99 | warmup_lr: 0.00001
100 | warmup_epochs: 2
101 | cooldown_epochs: 0
102 |
103 | log:
104 | dir: ./experiment
105 | run_name: flyingthings
106 | save_ckpt: true
107 | save_scalar_summary: true
108 | save_ckpt_every_n_epochs: 1
109 | save_summary_every_n_steps: 100
110 |
111 | ckpt:
112 |   path: ./checkpoints/ckpt_best.pt
113 | save_path: ./checkpoints
114 | resume: false
115 |
116 | val_interval: 10
117 | gpu: 0
118 | port: random # for multi-gpu training
119 | amp: false
120 | debug: false
121 | sync_bn: true
122 |
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IRMVLab/DELFlow/3e12fdec9fd26ace6f42436497ceb236656c6a23/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/augmentation.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | import torchvision
4 | import numpy as np
5 |
6 |
7 | def color_jitter(image1, image2, brightness, contrast, saturation, hue):
8 | assert image1.shape == image2.shape
9 | cj_module = torchvision.transforms.ColorJitter(brightness, contrast, saturation, hue)
10 |
11 | images = np.concatenate([image1, image2], axis=0)
12 | images_t = torch.from_numpy(images.transpose([2, 0, 1]).copy())
13 | images_t = cj_module.forward(images_t / 255.0) * 255.0
14 | images = images_t.numpy().astype(np.uint8).transpose(1, 2, 0)
15 | image1, image2 = images[:image1.shape[0]], images[image1.shape[0]:]
16 |
17 | return image1, image2
18 |
19 |
20 | def flip_point_cloud(pc, image_h, image_w, f, cx, cy, flip_mode):
21 | assert flip_mode in ['lr', 'ud']
22 | pc_x, pc_y, depth = pc[..., 0], pc[..., 1], pc[..., 2]
23 |
24 | image_x = cx + (f / depth) * pc_x
25 | image_y = cy + (f / depth) * pc_y
26 |
27 | if flip_mode == 'lr':
28 | image_x = image_w - 1 - image_x
29 | else:
30 | image_y = image_h - 1 - image_y
31 |
32 | pc_x = (image_x - cx) * depth / f
33 | pc_y = (image_y - cy) * depth / f
34 | pc = np.concatenate([pc_x[:, None], pc_y[:, None], depth[:, None]], axis=-1)
35 |
36 | return pc
37 |
38 |
39 | def flip_scene_flow(pc1, flow_3d, image_h, image_w, f, cx, cy, flip_mode):
40 | new_pc1 = flip_point_cloud(pc1, image_h, image_w, f, cx, cy, flip_mode)
41 | new_pc1_warp = flip_point_cloud(pc1 + flow_3d[:, :3], image_h, image_w, f, cx, cy, flip_mode)
42 | return np.concatenate([new_pc1_warp - new_pc1, flow_3d[:, 3:]], axis=-1)
43 |
44 |
45 | def flip_image(image, flip_mode):
46 | if flip_mode == 'lr':
47 | return np.fliplr(image).copy()
48 | else:
49 | return np.flipud(image).copy()
50 |
51 |
52 | def flip_optical_flow(flow, flip_mode):
53 | assert flip_mode in ['lr', 'ud']
54 | if flip_mode == 'lr':
55 | flow = np.fliplr(flow).copy()
56 | flow[:, :, 0] *= -1
57 | else:
58 | flow = np.flipud(flow).copy()
59 | flow[:, :, 1] *= -1
60 | return flow
61 |
62 |
63 | def random_flip(image1, image2, pc1, pc2, flow_2d, flow_3d, mask, f, cx, cy, flip_mode):
64 | assert flow_3d.shape[1] <= 4
65 | assert flip_mode in ['lr', 'ud']
66 | image_h, image_w = image1.shape[:2]
67 |
68 | if np.random.rand() < 0.5: # do nothing
69 | return image1, image2, pc1, pc2, flow_2d, flow_3d, mask, f, cx, cy
70 |
71 | if flip_mode == 'lr':
72 | new_f = -1.0 * f
73 | new_cx = image_w - 1 - cx
74 | new_cy = cy
75 | else:
76 | new_f = -1.0 * f
77 | new_cy = image_h - 1 - cy
78 | new_cx = cx
79 |
80 | new_mask = flip_image(mask, flip_mode)
81 | # flip images
82 | new_image1 = flip_image(image1, flip_mode)
83 | new_image2 = flip_image(image2, flip_mode)
84 |
85 | # flip point clouds
86 | new_pc1 = flip_image(pc1, flip_mode)
87 | new_pc2 = flip_image(pc2, flip_mode)
88 |
89 | # flip optical flow and scene flow
90 | new_flow_2d = flip_optical_flow(flow_2d, flip_mode)
91 | new_flow_3d = flip_image(flow_3d, flip_mode)
92 | # new_flow_3d = flip_scene_flow(pc1, flow_3d, image_h, image_w, f, cx, cy, flip_mode)
93 |
94 |     return new_image1, new_image2, new_pc1, new_pc2, new_flow_2d, new_flow_3d, new_mask, new_f, new_cx, new_cy
95 |
96 |
97 | def crop_image_with_pc(image1, image2, pc1, pc2, flow_2d, flow_3d, mask, f, cx, cy, crop_window):
98 | assert len(crop_window) == 4 # [x1, y1, x2, y2]
99 |
100 | x1, y1, x2, y2 = crop_window
101 | # image_h, image_w = image1.shape[:2]
102 |
103 | # crop images
104 | image1 = image1[y1:y2, x1:x2].copy()
105 | image2 = image2[y1:y2, x1:x2].copy()
106 | flow_2d = flow_2d[y1:y2, x1:x2].copy()
107 |
108 | # crop pc hw3
109 | pc1 = pc1[y1:y2, x1:x2].copy()
110 | pc2 = pc2[y1:y2, x1:x2].copy()
111 | flow_3d = flow_3d[y1:y2, x1:x2].copy()
112 | mask = mask[y1:y2, x1:x2].copy()
113 |
114 | # adjust camera params
115 | cx = cx - x1
116 | cy = cy - y1
117 |
118 | return image1, image2, pc1, pc2, flow_2d, flow_3d, mask, f, cx, cy
119 |
120 |
121 | def random_crop(image1, image2, pc1, pc2, flow_2d, flow_3d, mask, f, cx, cy, crop_size):
122 | assert flow_3d.shape[-1] <= 4
123 | assert len(crop_size) == 2
124 | crop_w, crop_h = crop_size
125 |
126 | image_h, image_w = image1.shape[:2]
127 | assert crop_w <= image_w and crop_h <= image_h
128 |
129 | # top left of the cropping window
130 | x1 = np.random.randint(low=0, high=image_w - crop_w + 1)
131 | y1 = np.random.randint(low=0, high=image_h - crop_h + 1)
132 | crop_window = [x1, y1, x1 + crop_w, y1 + crop_h]
133 |
134 | return crop_image_with_pc(image1, image2, pc1, pc2, flow_2d, flow_3d, mask, f, cx, cy, crop_window)
135 |
136 |
137 | def resize_sparse_flow_map(flow, target_w, target_h):
138 | curr_h, curr_w = flow.shape[:2]
139 |
140 | coords = np.meshgrid(np.arange(curr_w), np.arange(curr_h))
141 | coords = np.stack(coords, axis=-1).astype(np.float32)
142 |
143 | mask = flow[..., -1] > 0
144 | coords0, flow0 = coords[mask], flow[mask][:, :2]
145 |
146 | scale_ratio_w = (target_w - 1) / (curr_w - 1)
147 | scale_ratio_h = (target_h - 1) / (curr_h - 1)
148 |
149 | coords1 = coords0 * [scale_ratio_w, scale_ratio_h]
150 | flow1 = flow0 * [scale_ratio_w, scale_ratio_h]
151 |
152 | xx = np.round(coords1[:, 0]).astype(np.int32)
153 | yy = np.round(coords1[:, 1]).astype(np.int32)
154 | valid = (xx >= 0) & (xx < target_w) & (yy >= 0) & (yy < target_h)
155 | xx, yy, flow1 = xx[valid], yy[valid], flow1[valid]
156 |
157 | flow_resized = np.zeros([target_h, target_w, 3], dtype=np.float32)
158 | flow_resized[yy, xx, :2] = flow1
159 | flow_resized[yy, xx, 2:] = 1.0
160 |
161 | return flow_resized
162 |
163 |
164 | def random_scale(image1, image2, pc1, pc2, flow_2d, flow_3d, f, cx, cy, scale_range):
165 | assert len(scale_range) == 2
166 | assert 1 <= scale_range[0] < scale_range[1]
167 |
168 | if np.random.rand() < 0.5:
169 | return image1, image2, pc1, pc2, flow_2d, flow_3d, f, cx, cy
170 |
171 | scale_ratio = np.random.uniform(scale_range[0], scale_range[1])
172 | image_h, image_w = image1.shape[:2]
173 | crop_h, crop_w = int(image_h / scale_ratio), int(image_w / scale_ratio)
174 |
175 | # top left of the cropping window
176 | x1 = np.random.randint(low=0, high=image_w - crop_w + 1)
177 | y1 = np.random.randint(low=0, high=image_h - crop_h + 1)
178 | crop_window = [x1, y1, x1 + crop_w, y1 + crop_h]
179 |
180 | image1, image2, pc1, pc2, flow_2d, flow_3d, f, cx, cy = crop_image_with_pc(
181 | image1, image2, pc1, pc2, flow_2d, flow_3d, f, cx, cy, crop_window
182 | )
183 |
184 | # resize images and optical flow
185 | image1 = cv2.resize(image1, (image_w, image_h), interpolation=cv2.INTER_LINEAR)
186 | image2 = cv2.resize(image2, (image_w, image_h), interpolation=cv2.INTER_LINEAR)
187 | flow_2d = resize_sparse_flow_map(flow_2d, image_w, image_h)
188 |
189 | # resize points and scene flow
190 | scale_ratio_w = (image_w - 1) / (crop_w - 1)
191 | scale_ratio_h = (image_h - 1) / (crop_h - 1)
192 | pc1[:, 0] *= scale_ratio_w
193 | pc1[:, 1] *= scale_ratio_h
194 | pc2[:, 0] *= scale_ratio_w
195 | pc2[:, 1] *= scale_ratio_h
196 | flow_3d[:, 0] *= scale_ratio_w
197 | flow_3d[:, 1] *= scale_ratio_h
198 |
199 | # adjust camera params
200 | cx *= scale_ratio_w
201 | cy *= scale_ratio_h
202 |
203 | return image1, image2, pc1, pc2, flow_2d, flow_3d, f, cx, cy
204 |
205 |
206 | def random_downsample(image1, image2, pc1, pc2, flow_2d, flow_3d, mask, f, cx, cy, down_type):
207 |     # keep every other row/column: random (row, col) offset for training, fixed (0, 0) for eval
208 |     if down_type == "eval":
209 |         a, b = 0, 0
210 |     else:
211 |         a, b = np.random.randint(0, 2, 2)
212 | 
213 |     return image1[a::2, b::2, :], image2[a::2, b::2, :], pc1[a::2, b::2, :], pc2[a::2, b::2, :], \
214 |         flow_2d[a::2, b::2, :] * 0.5, flow_3d[a::2, b::2, :], mask[a::2, b::2], f / 2.0, (cx - b) / 2.0, (cy - a) / 2.0
215 |
216 |
217 | def scale_crop(image1, image2, pc1, pc2, flow_2d, flow_3d, mask):
218 |
219 | return image1[:256, :448, ...], image2[:256, :448, ...], pc1[:256, :448, ...], pc2[:256, :448, ...], flow_2d[:256, :448, ...], flow_3d[:256, :448, ...], mask[:256, :448, ...]
220 |
221 |
222 |
223 | def joint_augmentation(image1, image2, pc1, pc2, flow_2d, flow_3d, mask, f, cx, cy, cfgs):
224 | if not cfgs.enabled:
225 | return image1, image2, pc1, pc2, flow_2d, flow_3d, mask, f, cx, cy
226 |
227 | if cfgs.color_jitter.enabled:
228 | image1, image2 = color_jitter(
229 | image1, image2,
230 | brightness=cfgs.color_jitter.brightness,
231 | contrast=cfgs.color_jitter.contrast,
232 | saturation=cfgs.color_jitter.saturation,
233 | hue=cfgs.color_jitter.hue,
234 | )
235 |
236 |     # u = fx * x / z + cx, v = fy * y / z + cy
237 |     # after a horizontal flip: w - u = -fx * x / z + (w - cx); vertically: h - v = -fy * y / z + (h - cy)
238 | if cfgs.random_horizontal_flip.enabled:
239 | image1, image2, pc1, pc2, flow_2d, flow_3d, mask, f, cx, cy = random_flip(
240 | image1, image2, pc1, pc2, flow_2d, flow_3d, mask, f, cx, cy, flip_mode='lr'
241 | )
242 |
243 | if cfgs.random_vertical_flip.enabled:
244 | image1, image2, pc1, pc2, flow_2d, flow_3d, mask, f, cx, cy = random_flip(
245 | image1, image2, pc1, pc2, flow_2d, flow_3d, mask, f, cx, cy, flip_mode='ud'
246 | )
247 |
248 | if cfgs.random_crop.enabled:
249 | image1, image2, pc1, pc2, flow_2d, flow_3d, mask, f, cx, cy = random_crop(
250 | image1, image2, pc1, pc2, flow_2d, flow_3d, mask, f, cx, cy,
251 | crop_size=cfgs.random_crop.crop_size
252 | )
253 |
254 | if cfgs.random_scale.enabled:
255 | image1, image2, pc1, pc2, flow_2d, flow_3d, f, cx, cy = random_scale(
256 | image1, image2, pc1, pc2, flow_2d, flow_3d, f, cx, cy,
257 | scale_range=cfgs.random_scale.scale_range
258 | )
259 |
260 | if cfgs.random_down.enabled:
261 | image1, image2, pc1, pc2, flow_2d, flow_3d, mask, f, cx, cy = random_downsample(
262 | image1, image2, pc1, pc2, flow_2d, flow_3d, mask, f, cx, cy, cfgs.random_down.type)
263 |
264 | image1, image2, pc1, pc2, flow_2d, flow_3d, mask = scale_crop(
265 | image1, image2, pc1, pc2, flow_2d, flow_3d, mask)
266 |
267 |
268 | return image1, image2, pc1, pc2, flow_2d, flow_3d, mask, f, cx, cy
269 |
--------------------------------------------------------------------------------
/dataset/flyingthings_subset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import glob
4 | import torch
5 | import cv2
6 | from .augmentation import joint_augmentation, random_downsample, scale_crop
7 |
8 | __all__ = ['FlyingThings3D']
9 |
10 | def load_flow_png(filepath, scale=64.0):
11 | # for KITTI which uses 16bit PNG images
12 | # see 'https://github.com/ClementPinard/FlowNetPytorch/blob/master/datasets/KITTI.py'
13 | # The -1 is here to specify not to change the image depth (16bit), and is compatible
14 | # with both OpenCV2 and OpenCV3
15 | flow_img = cv2.imread(filepath, -1)
16 | flow = flow_img[:, :, 2:0:-1].astype(np.float32)
17 | mask = flow_img[:, :, 0] > 0
18 | flow = flow - 32768.0
19 | flow = flow / scale
20 | return flow, mask
21 |
22 | class FlyingThings3D(torch.utils.data.Dataset):
23 | def __init__(self, cfgs):
24 | assert os.path.isdir(cfgs.root_dir)
25 |
26 | self.root_dir = str(cfgs.root_dir)
27 | self.split = str(cfgs.split)
28 | self.split_dir = os.path.join(self.root_dir, self.split)
29 | self.cfgs = cfgs
30 |
31 | self.indices = []
32 | for filename in os.listdir(os.path.join(self.root_dir, self.split, 'flow_2d')):
33 | self.indices.append(int(filename.split('.')[0]))
34 |
35 | if not cfgs.full:
36 | self.indices = self.indices[::4]
37 |
38 |
39 | def __len__(self):
40 | return len(self.indices)
41 |
42 | def __getitem__(self, i):
43 | if not self.cfgs.augmentation.enabled:
44 | np.random.seed(0)
45 |
46 | idx1 = self.indices[i]
47 | idx2 = idx1 + 1
48 | data_dict = {'index': idx1}
49 |
50 | # camera intrinsics
51 | f, cx, cy = 1050, 479.5, 269.5
52 |
53 | # load data
54 | pcs = np.load(os.path.join(self.split_dir, 'pc', '%07d.npz' % idx1))
55 | pc1, pc2 = pcs['pc1'], pcs['pc2']
56 |
57 | flow_2d, flow_mask_2d = load_flow_png(os.path.join(self.split_dir, 'flow_2d', '%07d.png' % idx1))
58 | flow_3d = np.load(os.path.join(self.split_dir, 'flow_3d', '%07d.npy' % idx1))
59 |
60 | occ_mask_3d = np.load(os.path.join(self.split_dir, 'occ_mask_3d', '%07d.npy' % idx1))
61 | occ_mask_3d = np.unpackbits(occ_mask_3d, count=len(pc1))
62 |
63 | image1 = cv2.imread(os.path.join(self.split_dir, 'image', '%07d.png' % idx1))[..., ::-1].copy().astype(np.float32)
64 | image2 = cv2.imread(os.path.join(self.split_dir, 'image', '%07d.png' % idx2))[..., ::-1].copy().astype(np.float32)
65 |
66 | H, W = image1.shape[:2]
67 |
68 | pc1_mask = np.load(os.path.join(self.split_dir, 'pc1_mask', '%07d.npy' % idx1))
69 | pc1_mask = np.unpackbits(pc1_mask).reshape(H, W).astype(np.bool_)
70 |
71 | pc2_mask = np.load(os.path.join(self.split_dir, 'pc2_mask', '%07d.npy' % idx1))
72 | pc2_mask = np.unpackbits(pc2_mask).reshape(H, W).astype(np.bool_)
73 |
74 |
75 | # ignore fast moving objects
76 | flow_mask_2d = np.logical_and(flow_mask_2d, np.linalg.norm(flow_2d, axis=-1) < 250.0)
77 | flow_2d = np.concatenate([flow_2d, flow_mask_2d[..., None].astype(np.float32)], axis=2)
78 |
79 |
80 | pc1_hw3 = np.zeros((H, W, 3), dtype = np.float32)
81 | pc2_hw3 = np.zeros((H, W, 3), dtype = np.float32)
82 | flow_hw3 = np.zeros((H, W, 3), dtype = np.float32)
83 | valid_mask = np.zeros((H, W), dtype = np.bool_)
84 |
85 |
86 | pc1_hw3[pc1_mask] = pc1
87 | pc2_hw3[pc2_mask] = pc2
88 | flow_hw3[pc1_mask] = flow_3d
89 | valid_mask[pc1_mask] = np.logical_not(occ_mask_3d)
90 | # valid_mask = pc1_mask
91 |
92 | image1, image2, pc1_hw3, pc2_hw3, flow_2d, flow_hw3, valid_mask, f, cx, cy = joint_augmentation(
93 | image1, image2, pc1_hw3, pc2_hw3, flow_2d, flow_hw3, valid_mask, f, cx, cy, self.cfgs.augmentation
94 | )
95 |
96 | # image1, image2, pc1_hw3, pc2_hw3, flow_2d, flow_hw3, valid_mask = scale_crop(image1, image2, pc1_hw3, pc2_hw3, flow_2d, flow_hw3, valid_mask)
97 |
98 | intrinsics = np.float32([f, f, cx, cy])
99 |
100 | image1 = np.ascontiguousarray(image1.transpose([2, 0, 1]))
101 | image2 = np.ascontiguousarray(image2.transpose([2, 0, 1]))
102 | pc1_3hw = np.ascontiguousarray(pc1_hw3.transpose([2, 0, 1]))
103 | pc2_3hw = np.ascontiguousarray(pc2_hw3.transpose([2, 0, 1]))
104 | flow_3hw = np.ascontiguousarray(flow_hw3.transpose([2, 0, 1]))
105 | valid_mask = np.ascontiguousarray(valid_mask)
106 | pc1_mask = np.ascontiguousarray(pc1_mask)
107 | data_dict['image1'] = image1 # 3 x H x W
108 | data_dict['image2'] = image2 # 3 x H x W
109 | data_dict['pc1'] = pc1_3hw # 3 x H x W
110 | data_dict['pc2'] = pc2_3hw # 3 x H x W
111 | data_dict['flow_3d'] = flow_3hw # 3 x H x W
112 | data_dict['flow_2d'] = flow_2d.transpose([2, 0, 1]) # 3 x H x W
113 | data_dict['mask'] = valid_mask # H x W
114 |
115 | non_zero_mask = np.linalg.norm(pc1_3hw, axis = 0) > 1e-8
116 | data_dict['nonzero_mask'] = non_zero_mask # H x W
117 | data_dict['intrinsics'] = intrinsics # 4
118 |
119 | return data_dict
120 |
--------------------------------------------------------------------------------
/dataset/kitti.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | import torch.utils.data
8 |
9 |
10 | def load_flow_png(filepath, scale=64.0):
11 | # for KITTI which uses 16bit PNG images
12 | # see 'https://github.com/ClementPinard/FlowNetPytorch/blob/master/datasets/KITTI.py'
13 | # The -1 is here to specify not to change the image depth (16bit), and is compatible
14 | # with both OpenCV2 and OpenCV3
15 | flow_img = cv2.imread(filepath, -1)
16 | flow = flow_img[:, :, 2:0:-1].astype(np.float32)
17 | mask = flow_img[:, :, 0] > 0
18 | flow = flow - 32768.0
19 | flow = flow / scale
20 | return flow, mask
21 |
22 | def load_disp_png(filepath):
23 | array = cv2.imread(filepath, -1)
24 | valid_mask = array > 0
25 | disp = array.astype(np.float32) / 256.0
26 | disp[np.logical_not(valid_mask)] = -1.0
27 | return disp, valid_mask
28 |
29 | def load_calib(filepath):
30 | with open(filepath) as f:
31 | lines = f.readlines()
32 | for line in lines:
33 | if line.startswith('P_rect_02'):
34 | proj_mat = line.split()[1:]
35 | proj_mat = [float(param) for param in proj_mat]
36 | proj_mat = np.array(proj_mat, dtype=np.float32).reshape(3, 4)
37 | assert proj_mat[0, 1] == proj_mat[1, 0] == 0
38 | assert proj_mat[2, 0] == proj_mat[2, 1] == 0
39 | assert proj_mat[0, 0] == proj_mat[1, 1]
40 | assert proj_mat[2, 2] == 1
41 |
42 | return proj_mat
43 |
44 |
45 | def disp2pc(disp, baseline, f, cx, cy, flow=None):
46 | h, w = disp.shape
47 | depth = baseline * f / (disp + 1e-5)
48 |
49 | xx = np.tile(np.arange(w, dtype=np.float32)[None, :], (h, 1))
50 | yy = np.tile(np.arange(h, dtype=np.float32)[:, None], (1, w))
51 |
52 | if flow is None:
53 | x = (xx - cx) * depth / f
54 | y = (yy - cy) * depth / f
55 | else:
56 | x = (xx - cx + flow[..., 0]) * depth / f
57 | y = (yy - cy + flow[..., 1]) * depth / f
58 |
59 | pc = np.concatenate([
60 | x[:, :, None],
61 | y[:, :, None],
62 | depth[:, :, None],
63 | ], axis=-1)
64 |
65 | return pc
66 |
67 |
68 |
69 | def zero_padding(inputs, pad_h, pad_w):
70 | input_dim = len(inputs.shape)
71 | assert input_dim in [2, 3]
72 |
73 | if input_dim == 2:
74 | inputs = inputs[..., None]
75 |
76 | h, w, c = inputs.shape
77 | assert h <= pad_h and w <= pad_w
78 |
79 | result = np.zeros([pad_h, pad_w, c], dtype=inputs.dtype)
80 | result[:h, :w, :c] = inputs
81 |
82 | if input_dim == 2:
83 | result = result[..., 0]
84 |
85 | return result
86 |
87 |
88 | class KITTI(torch.utils.data.Dataset):
89 | def __init__(self, cfgs):
90 | assert os.path.isdir(cfgs.root_dir)
91 | assert cfgs.split in ['training200', 'training160', 'training40']
92 |
93 | self.root_dir = os.path.join(cfgs.root_dir, 'training')
94 | self.split = cfgs.split
95 | self.cfgs = cfgs
96 | self.crop = 80
97 |
98 | if self.split == 'training200':
99 | self.indices = np.arange(200)
100 | elif self.split == 'training160':
101 | self.indices = [i for i in range(200) if i % 5 != 0]
102 | elif self.split == 'training40':
103 | self.indices = [i for i in range(200) if i % 5 == 0]
104 |
105 | def __len__(self):
106 | return len(self.indices)
107 |
108 | def __getitem__(self, i):
109 | np.random.seed(23333)
110 |
111 | index = self.indices[i]
112 | data_dict = {'index': index}
113 |
114 | proj_mat = load_calib(os.path.join(self.root_dir, 'calib_cam_to_cam', '%06d.txt' % index))
115 | f, cx, cy = proj_mat[0, 0], proj_mat[0, 2], proj_mat[1, 2]
116 |
117 | image1 = cv2.imread(os.path.join(self.root_dir, 'image_2', '%06d_10.png' % index))[..., ::-1].copy().astype(np.float32)
118 | image2 = cv2.imread(os.path.join(self.root_dir, 'image_2', '%06d_11.png' % index))[..., ::-1].copy().astype(np.float32)
119 | flow_2d, flow_2d_mask = load_flow_png(os.path.join(self.root_dir, 'flow_occ', '%06d_10.png' % index))
120 |
121 | data_dict['input_h'] = image1.shape[0]
122 | data_dict['input_w'] = image1.shape[1]
123 |
124 | disp1, mask1 = load_disp_png(os.path.join(self.root_dir, 'disp_occ_0', '%06d_10.png' % index))
125 | disp2, mask2 = load_disp_png(os.path.join(self.root_dir, 'disp_occ_1', '%06d_10.png' % index))
126 | valid = np.logical_and(np.logical_and(mask1, mask2), flow_2d_mask)
127 |
128 | valid = np.logical_and(valid, disp2 > 0.0)
129 |
130 | disp1_dense, mask1_dense = load_disp_png(os.path.join(self.root_dir, 'disp_ganet_training', '%06d_10.png' % index))
131 | disp2_dense, mask2_dense = load_disp_png(os.path.join(self.root_dir, 'disp_ganet_training', '%06d_11.png' % index))
132 |
133 |
134 | flow_3d = disp2pc(disp2, baseline=0.54, f=f, cx=cx, cy=cy, flow=flow_2d) - disp2pc(disp1, baseline=0.54, f=f, cx=cx, cy=cy)
135 |
136 | pc1 = disp2pc(disp1_dense, 0.54, f=f, cx=cx, cy=cy)
137 | pc2 = disp2pc(disp2_dense, 0.54, f=f, cx=cx, cy=cy)
138 |
139 |
140 | pc1 = pc1[self.crop:]
141 | pc2 = pc2[self.crop:]
142 |
143 |
144 |         # drop points beyond the maximum depth
145 |         pc1[pc1[..., -1] > self.cfgs.max_depth] = 0.0
146 |         pc2[pc2[..., -1] > self.cfgs.max_depth] = 0.0
147 |
148 | image1 = image1[self.crop:]
149 | image2 = image2[self.crop:]
150 | intrinsics = np.array([f, f, cx, cy - self.crop])
151 | flow_3d = flow_3d[self.crop:]
152 | flow_2d = flow_2d[self.crop:]
153 | valid = valid[self.crop:]
154 |
155 |
156 | padding_h, padding_w = 320, 1280
157 | image1 = zero_padding(image1, padding_h, padding_w)
158 | image2 = zero_padding(image2, padding_h, padding_w)
159 | pc1 = zero_padding(pc1, padding_h, padding_w)
160 | pc2 = zero_padding(pc2, padding_h, padding_w)
161 | valid = zero_padding(valid, padding_h, padding_w)
162 | flow_3d = zero_padding(flow_3d, padding_h, padding_w)
163 | flow_2d = zero_padding(flow_2d, padding_h, padding_w)
164 |
165 | data_dict['image1'] = np.ascontiguousarray(image1.transpose([2, 0, 1])) # 3 x H x W
166 | data_dict['image2'] = np.ascontiguousarray(image2.transpose([2, 0, 1])) # 3 x H x W
167 | data_dict['pc1'] = np.ascontiguousarray(pc1.transpose([2, 0, 1])) # 3 x H x W
168 | data_dict['pc2'] = np.ascontiguousarray(pc2.transpose([2, 0, 1])) # 3 x H x W
169 | non_zero_mask = np.linalg.norm(data_dict['pc1'], axis = 0) > 1e-8
170 | data_dict['nonzero_mask'] = non_zero_mask # H x W
171 | data_dict['intrinsics'] = intrinsics # 4
172 | flow_2d = np.concatenate([flow_2d, flow_2d_mask[..., None].astype(np.float32)], axis=-1) # H x W x 3
173 | data_dict['flow_3d'] = np.ascontiguousarray(flow_3d.transpose([2, 0, 1])) # 3 x H x W
174 | data_dict['flow_2d'] = np.ascontiguousarray(flow_2d.transpose([2, 0, 1])) # 3 x H x W
175 | data_dict['mask'] = valid & non_zero_mask # H x W
176 |
177 |
178 | return data_dict
--------------------------------------------------------------------------------
/dataset/preprocess_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import shutil
4 | import logging
5 | import argparse
6 | import torch.utils.data
7 | import numpy as np
8 | from tqdm import tqdm
9 | import re
10 | import sys
11 |
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--input_dir', default = '/data/FlyingThings3D_subset', help='Path to the FlyingThings3D subset')
15 | parser.add_argument('--output_dir', required=False, default='/data/processed_flyingthings3d')
16 | parser.add_argument('--max_depth', required=False, type=float, default=35.0)
17 | parser.add_argument('--remove_occluded_points', action='store_true')
18 | args = parser.parse_args()
19 |
20 | def init_logging(filename=None, debug=False):
21 | logging.root = logging.RootLogger('DEBUG' if debug else 'INFO')
22 | formatter = logging.Formatter('[%(asctime)s][%(levelname)s] - %(message)s')
23 |
24 | stream_handler = logging.StreamHandler(sys.stdout)
25 | stream_handler.setFormatter(formatter)
26 | logging.root.addHandler(stream_handler)
27 |
28 | if filename is not None:
29 | file_handler = logging.FileHandler(filename)
30 | file_handler.setFormatter(formatter)
31 | logging.root.addHandler(file_handler)
32 |
33 | def load_pfm(filename):
34 | with open(filename, 'rb') as f:
35 | header = f.readline().rstrip()
36 | if header.decode("ascii") == 'PF':
37 | color = True
38 | elif header.decode("ascii") == 'Pf':
39 | color = False
40 | else:
41 | raise Exception('Not a PFM file.')
42 |
43 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', f.readline().decode("ascii"))
44 | if dim_match:
45 | width, height = list(map(int, dim_match.groups()))
46 | else:
47 | raise Exception('Malformed PFM header.')
48 |
49 | scale = float(f.readline().decode("ascii").rstrip())
50 | if scale < 0: # little-endian
51 | endian = '<'
52 | else:
53 | endian = '>' # big-endian
54 |
55 | data = np.fromfile(f, endian + 'f')
56 | shape = (height, width, 3) if color else (height, width)
57 | data = np.reshape(data, shape)
58 | data = np.flipud(data)
59 |
60 | return data
61 |
62 | def disp2pc(disp, baseline, f, cx, cy, flow=None):
63 | h, w = disp.shape
64 | depth = baseline * f / (disp + 1e-5)
65 |
66 | xx = np.tile(np.arange(w, dtype=np.float32)[None, :], (h, 1))
67 | yy = np.tile(np.arange(h, dtype=np.float32)[:, None], (1, w))
68 |
69 | if flow is None:
70 | x = (xx - cx) * depth / f
71 | y = (yy - cy) * depth / f
72 | else:
73 | x = (xx - cx + flow[..., 0]) * depth / f
74 | y = (yy - cy + flow[..., 1]) * depth / f
75 |
76 | pc = np.concatenate([
77 | x[:, :, None],
78 | y[:, :, None],
79 | depth[:, :, None],
80 | ], axis=-1)
81 |
82 | return pc
83 |
84 | def load_flow(filepath):
85 | with open(filepath, 'rb') as f:
86 | magic = np.fromfile(f, np.float32, count=1)
87 | assert (202021.25 == magic), 'Invalid .flo file: incorrect magic number'
88 | w = np.fromfile(f, np.int32, count=1)[0]
89 | h = np.fromfile(f, np.int32, count=1)[0]
90 | flow = np.fromfile(f, np.float32, count=2 * w * h).reshape([h, w, 2])
91 |
92 | return flow
93 |
94 |
95 | def save_flow_png(filepath, flow, mask=None, scale=64.0):
96 | assert flow.shape[2] == 2
97 | assert np.abs(flow).max() < 32767.0 / scale
98 | flow = flow * scale
99 | flow = flow + 32768.0
100 |
101 | if mask is None:
102 | mask = np.ones_like(flow)[..., 0]
103 | else:
104 | mask = np.float32(mask > 0)
105 |
106 | flow_img = np.concatenate([
107 | mask[..., None],
108 | flow[..., 1:2],
109 | flow[..., 0:1]
110 | ], axis=-1).astype(np.uint16)
111 |
112 | cv2.imwrite(filepath, flow_img)
113 |
114 |
115 |
116 |
117 | class Preprocessor(torch.utils.data.Dataset):
118 | def __init__(self, input_dir, output_dir, split, max_depth, remove_occluded_points):
119 | super(Preprocessor, self).__init__()
120 |
121 | self.input_dir = input_dir
122 | self.output_dir = output_dir
123 | self.split = split
124 | self.max_depth = max_depth
125 | self.remove_occluded_points = remove_occluded_points
126 |
127 | self.indices = []
128 | for filename in os.listdir(os.path.join(input_dir, split, 'flow', 'left', 'into_future')):
129 | index = int(filename.split('.')[0])
130 | self.indices.append(index)
131 |
132 | def __len__(self):
133 | return len(self.indices)
134 |
135 | def __getitem__(self, i):
136 | np.random.seed(0)
137 |
138 | index1 = self.indices[i]
139 | index2 = index1 + 1
140 |
141 | # camera intrinsics
142 | baseline, f, cx, cy = 1.0, 1050.0, 479.5, 269.5
143 |
144 | # load data
145 |         disp1 = -load_pfm(os.path.join(
146 | self.input_dir, self.split, 'disparity', 'left', '%07d.pfm' % index1
147 | ))
148 |         disp2 = -load_pfm(os.path.join(
149 | self.input_dir, self.split, 'disparity', 'left', '%07d.pfm' % index2
150 | ))
151 |         disp1_change = -load_pfm(os.path.join(
152 | self.input_dir, self.split, 'disparity_change', 'left', 'into_future', '%07d.pfm' % index1
153 | ))
154 | flow_2d = load_flow(os.path.join(
155 | self.input_dir, self.split, 'flow', 'left', 'into_future', '%07d.flo' % index1
156 | ))
157 | occ_mask_2d = cv2.imread(os.path.join(
158 | self.input_dir, self.split, 'flow_occlusions', 'left', 'into_future', '%07d.png' % index1
159 | ))
160 | occ_mask_2d = occ_mask_2d[..., 0] > 1
161 |
162 | if self.remove_occluded_points:
163 | pc1 = disp2pc(disp1, baseline, f, cx, cy)
164 | pc2 = disp2pc(disp1 + disp1_change, baseline, f, cx, cy, flow_2d)
165 |
166 | # apply non-occlusion mask
167 | noc_mask_2d = np.logical_not(occ_mask_2d)
168 | pc1, pc2 = pc1[noc_mask_2d], pc2[noc_mask_2d]
169 |
170 | # apply depth mask
171 | mask = np.logical_and(pc1[..., -1] < self.max_depth, pc2[..., -1] < self.max_depth)
172 | pc1, pc2 = pc1[mask], pc2[mask]
173 |
174 | # NaN check
175 | mask = np.logical_not(np.isnan(np.sum(pc1, axis=-1) + np.sum(pc2, axis=-1)))
176 | pc1, pc2 = pc1[mask], pc2[mask]
177 |
178 | # compute scene flow
179 | flow_3d = pc2 - pc1
180 |             occ_mask_3d = np.zeros(len(pc1), dtype=np.bool_)
181 | else:
182 | pc1 = disp2pc(disp1, baseline, f, cx, cy)
183 | pc2 = disp2pc(disp2, baseline, f, cx, cy)
184 | flow_3d = disp2pc(disp1 + disp1_change, baseline, f, cx, cy, flow_2d) - pc1 # h x w x 3
185 |
186 |
187 | x, y = np.meshgrid(np.arange(pc1.shape[1]), np.arange(pc1.shape[0]))
188 | coords = np.concatenate([y[:, :, None], x[:, :, None]], axis= -1) # h x w x 2
189 |
190 | # apply depth mask and NaN check
191 | mask1 = np.logical_and((pc1[..., -1] < self.max_depth), np.logical_not(np.isnan(np.sum(pc1, axis=-1) + np.sum(flow_3d, axis=-1))))
192 | mask2 = np.logical_and((pc2[..., -1] < self.max_depth), np.logical_not(np.isnan(np.sum(pc2, axis=-1))))
193 |
194 |
195 |
196 | pc1, pc2, flow_3d, occ_mask_3d = pc1[mask1], pc2[mask2], flow_3d[mask1], occ_mask_2d[mask1]
197 |
198 |
199 | # save point clouds and occ mask
200 | np.savez(
201 | os.path.join(self.output_dir, self.split, 'pc', '%07d.npz' % index1),
202 | pc1=pc1, pc2=pc2
203 | )
204 | np.save(
205 | os.path.join(self.output_dir, self.split, 'occ_mask_3d', '%07d.npy' % index1),
206 | np.packbits(occ_mask_3d)
207 | )
208 |
209 | np.save(
210 | os.path.join(self.output_dir, self.split, 'pc1_mask', '%07d.npy' % index1),
211 | np.packbits(mask1)
212 | )
213 |
214 | np.save(
215 | os.path.join(self.output_dir, self.split, 'pc2_mask', '%07d.npy' % index1),
216 | np.packbits(mask2)
217 | )
218 |
219 | # mask regions moving extremely fast
220 | flow_mask = np.logical_and(np.abs(flow_2d[..., 0]) < 500, np.abs(flow_2d[..., 1]) < 500)
221 | flow_2d[np.logical_not(flow_mask)] = 0.0
222 |
223 | # save ground-truth flow
224 | save_flow_png(
225 | os.path.join(self.output_dir, self.split, 'flow_2d', '%07d.png' % index1),
226 | flow_2d, flow_mask
227 | )
228 | np.save(
229 | os.path.join(self.output_dir, self.split, 'flow_3d', '%07d.npy' % index1),
230 | flow_3d
231 | )
232 |
233 | return 0
234 |
235 |
236 | def main():
237 |     for split in ['train', 'val']:
238 | 
239 |         if not os.path.exists(os.path.join(args.input_dir, split)):
240 |             logging.warning('Skipping missing split: %s' % os.path.join(args.input_dir, split))
241 |             continue
242 |
243 | logging.info('Processing "%s" split...' % split)
244 |
245 | os.makedirs(os.path.join(args.output_dir, split, 'pc'), exist_ok=True)
246 | os.makedirs(os.path.join(args.output_dir, split, 'flow_2d'), exist_ok=True)
247 | os.makedirs(os.path.join(args.output_dir, split, 'flow_3d'), exist_ok=True)
248 | os.makedirs(os.path.join(args.output_dir, split, 'pc1_mask'), exist_ok=True)
249 |         os.makedirs(os.path.join(args.output_dir, split, 'pc2_mask'), exist_ok=True)
250 |         os.makedirs(os.path.join(args.output_dir, split, 'occ_mask_3d'), exist_ok=True)
251 | if not os.path.exists(os.path.join(args.output_dir, split, 'image')):
252 | logging.info('Copying images...')
253 | shutil.copytree(
254 | src=os.path.join(args.input_dir, split, 'image_clean', 'left'),
255 | dst=os.path.join(args.output_dir, split, 'image')
256 | )
257 |
258 | if not os.path.exists(os.path.join(args.output_dir, split, 'occ_mask_2d')):
259 | logging.info('Copying occ_mask_2d...')
260 | shutil.copytree(
261 | src=os.path.join(args.input_dir, split, 'flow_occlusions', 'left', 'into_future'),
262 | dst=os.path.join(args.output_dir, split, 'occ_mask_2d')
263 | )
264 |
265 | logging.info('Generating point clouds...')
266 | preprocessor = Preprocessor(
267 | args.input_dir,
268 | args.output_dir,
269 | split,
270 | args.max_depth,
271 | args.remove_occluded_points,
272 | )
273 | preprocessor = torch.utils.data.DataLoader(dataset=preprocessor, num_workers=4)
274 |
275 |         # iterate the loader so preprocessing runs in the DataLoader worker processes
276 |         for _ in tqdm(preprocessor):
277 |             pass
280 |
281 |
282 | if __name__ == '__main__':
283 | init_logging()
284 | main()
285 | logging.info('All done.')
286 |
--------------------------------------------------------------------------------
/env.yaml:
--------------------------------------------------------------------------------
1 | name: delflow
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=main
8 | - _openmp_mutex=5.1=1_gnu
9 | - archspec=0.2.3=pyhd3eb1b0_0
10 | - asttokens=2.0.5=pyhd3eb1b0_0
11 | - attrs=23.1.0=py310h06a4308_0
12 | - beautifulsoup4=4.12.2=py310h06a4308_0
13 | - blas=1.0=mkl
14 | - boltons=23.0.0=py310h06a4308_0
15 | - brotli-python=1.0.9=py310h6a678d5_7
16 | - bzip2=1.0.8=h7b6447c_0
17 | - c-ares=1.19.1=h5eee18b_0
18 | - ca-certificates=2024.3.11=h06a4308_0
19 | - certifi=2024.2.2=py310h06a4308_0
20 | - cffi=1.16.0=py310h5eee18b_0
21 | - chardet=4.0.0=py310h06a4308_1003
22 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
23 | - click=8.1.7=py310h06a4308_0
24 | - cmake=3.26.4=h96355d8_0
25 | - conda=23.9.0=py310h06a4308_0
26 | - conda-build=24.3.0=py310h06a4308_0
27 | - conda-content-trust=0.2.0=py310h06a4308_0
28 | - conda-index=0.4.0=pyhd3eb1b0_0
29 | - conda-libmamba-solver=23.7.0=py310h06a4308_0
30 | - conda-package-handling=2.2.0=py310h06a4308_0
31 | - conda-package-streaming=0.9.0=py310h06a4308_0
32 | - cryptography=42.0.5=py310hdda0065_0
33 | - cuda-cudart=12.1.105=0
34 | - cuda-cupti=12.1.105=0
35 | - cuda-libraries=12.1.0=0
36 | - cuda-nvrtc=12.1.105=0
37 | - cuda-nvtx=12.1.105=0
38 | - cuda-opencl=12.4.99=0
39 | - cuda-runtime=12.1.0=0
40 | - decorator=5.1.1=pyhd3eb1b0_0
41 | - distro=1.8.0=py310h06a4308_0
42 | - exceptiongroup=1.2.0=py310h06a4308_0
43 | - executing=0.8.3=pyhd3eb1b0_0
44 | - expat=2.5.0=h6a678d5_0
45 | - ffmpeg=4.3=hf484d3e_0
46 | - filelock=3.13.1=py310h06a4308_0
47 | - fmt=9.1.0=hdb19cb5_0
48 | - freetype=2.12.1=h4a9f257_0
49 | - gmp=6.2.1=h295c915_3
50 | - gmpy2=2.1.2=py310heeb90bb_0
51 | - gnutls=3.6.15=he1e5248_0
52 | - icu=73.1=h6a678d5_0
53 | - idna=3.4=py310h06a4308_0
54 | - intel-openmp=2023.1.0=hdb19cb5_46306
55 | - ipython=8.20.0=py310h06a4308_0
56 | - jedi=0.18.1=py310h06a4308_1
57 | - jinja2=3.1.3=py310h06a4308_0
58 | - jpeg=9e=h5eee18b_1
59 | - jsonpatch=1.32=pyhd3eb1b0_0
60 | - jsonpointer=2.1=pyhd3eb1b0_0
61 | - jsonschema=4.19.2=py310h06a4308_0
62 | - jsonschema-specifications=2023.7.1=py310h06a4308_0
63 | - krb5=1.20.1=h143b758_1
64 | - lame=3.100=h7b6447c_0
65 | - lcms2=2.12=h3be6417_0
66 | - ld_impl_linux-64=2.38=h1181459_1
67 | - lerc=3.0=h295c915_0
68 | - libarchive=3.6.2=h6ac8c49_2
69 | - libcublas=12.1.0.26=0
70 | - libcufft=11.0.2.4=0
71 | - libcufile=1.9.0.20=0
72 | - libcurand=10.3.5.119=0
73 | - libcurl=8.5.0=h251f7ec_0
74 | - libcusolver=11.4.4.55=0
75 | - libcusparse=12.0.2.55=0
76 | - libdeflate=1.17=h5eee18b_1
77 | - libedit=3.1.20230828=h5eee18b_0
78 | - libev=4.33=h7f8727e_1
79 | - libffi=3.4.4=h6a678d5_0
80 | - libgcc-ng=11.2.0=h1234567_1
81 | - libgomp=11.2.0=h1234567_1
82 | - libiconv=1.16=h7f8727e_2
83 | - libidn2=2.3.4=h5eee18b_0
84 | - libjpeg-turbo=2.0.0=h9bf148f_0
85 | - liblief=0.12.3=h6a678d5_0
86 | - libmamba=1.5.3=haf1ee3a_0
87 | - libmambapy=1.5.3=py310h2dafd23_0
88 | - libnghttp2=1.57.0=h2d74bed_0
89 | - libnpp=12.0.2.50=0
90 | - libnvjitlink=12.1.105=0
91 | - libnvjpeg=12.1.1.14=0
92 | - libpng=1.6.39=h5eee18b_0
93 | - libsolv=0.7.24=he621ea3_0
94 | - libssh2=1.10.0=hdbd6064_2
95 | - libstdcxx-ng=11.2.0=h1234567_1
96 | - libtasn1=4.19.0=h5eee18b_0
97 | - libtiff=4.5.1=h6a678d5_0
98 | - libunistring=0.9.10=h27cfd23_0
99 | - libuuid=1.41.5=h5eee18b_0
100 | - libuv=1.44.2=h5eee18b_0
101 | - libwebp-base=1.3.2=h5eee18b_0
102 | - libxml2=2.10.4=hf1b16e4_1
103 | - llvm-openmp=14.0.6=h9e868ea_0
104 | - lz4-c=1.9.4=h6a678d5_0
105 | - markupsafe=2.1.3=py310h5eee18b_0
106 | - matplotlib-inline=0.1.6=py310h06a4308_0
107 | - menuinst=2.0.2=py310h06a4308_0
108 | - mkl=2023.1.0=h213fc3f_46344
109 | - mkl-service=2.4.0=py310h5eee18b_1
110 | - mkl_fft=1.3.8=py310h5eee18b_0
111 | - mkl_random=1.2.4=py310hdb19cb5_0
112 | - more-itertools=10.1.0=py310h06a4308_0
113 | - mpc=1.1.0=h10f8cd9_1
114 | - mpfr=4.0.2=hb69a4c5_1
115 | - mpmath=1.3.0=py310h06a4308_0
116 | - ncurses=6.4=h6a678d5_0
117 | - nettle=3.7.3=hbbd107a_1
118 | - numpy=1.26.4=py310h5f9d8c6_0
119 | - numpy-base=1.26.4=py310hb5e798b_0
120 | - openh264=2.1.1=h4ff587b_0
121 | - openjpeg=2.4.0=h3ad879b_0
122 | - openssl=3.0.13=h7f8727e_0
123 | - packaging=23.2=py310h06a4308_0
124 | - parso=0.8.3=pyhd3eb1b0_0
125 | - patch=2.7.6=h7b6447c_1001
126 | - patchelf=0.17.2=h6a678d5_0
127 | - pcre2=10.42=hebb0a14_0
128 | - pexpect=4.8.0=pyhd3eb1b0_3
129 | - pillow=10.2.0=py310h5eee18b_0
130 | - pip=23.3.1=py310h06a4308_0
131 | - pkginfo=1.9.6=py310h06a4308_0
132 | - platformdirs=3.10.0=py310h06a4308_0
133 | - pluggy=1.0.0=py310h06a4308_1
134 | - prompt-toolkit=3.0.43=py310h06a4308_0
135 | - prompt_toolkit=3.0.43=hd3eb1b0_0
136 | - psutil=5.9.0=py310h5eee18b_0
137 | - ptyprocess=0.7.0=pyhd3eb1b0_2
138 | - pure_eval=0.2.2=pyhd3eb1b0_0
139 | - py-lief=0.12.3=py310h6a678d5_0
140 | - pybind11-abi=4=hd3eb1b0_1
141 | - pycosat=0.6.6=py310h5eee18b_0
142 | - pycparser=2.21=pyhd3eb1b0_0
143 | - pygments=2.15.1=py310h06a4308_1
144 | - pyopenssl=24.0.0=py310h06a4308_0
145 | - pysocks=1.7.1=py310h06a4308_0
146 | - python=3.10.14=h955ad1f_0
147 | - python-libarchive-c=2.9=pyhd3eb1b0_1
148 | - pytorch=2.2.2=py3.10_cuda12.1_cudnn8.9.2_0
149 | - pytorch-cuda=12.1=ha16c6d3_5
150 | - pytorch-mutex=1.0=cuda
151 | - pytz=2023.3.post1=py310h06a4308_0
152 | - pyyaml=6.0.1=py310h5eee18b_0
153 | - readline=8.2=h5eee18b_0
154 | - referencing=0.30.2=py310h06a4308_0
155 | - reproc=14.2.4=h295c915_1
156 | - reproc-cpp=14.2.4=h295c915_1
157 | - rhash=1.4.3=hdbd6064_0
158 | - rpds-py=0.10.6=py310hb02cf49_0
159 | - ruamel.yaml=0.17.21=py310h5eee18b_0
160 | - ruamel.yaml.clib=0.2.6=py310h5eee18b_1
161 | - six=1.16.0=pyhd3eb1b0_1
162 | - soupsieve=2.5=py310h06a4308_0
163 | - sqlite=3.41.2=h5eee18b_0
164 | - stack_data=0.2.0=pyhd3eb1b0_0
165 | - sympy=1.12=py310h06a4308_0
166 | - tbb=2021.8.0=hdb19cb5_0
167 | - tk=8.6.12=h1ccaba5_0
168 | - tomli=2.0.1=py310h06a4308_0
169 | - toolz=0.12.0=py310h06a4308_0
170 | - torchaudio=2.2.2=py310_cu121
171 | - torchtriton=2.2.0=py310
172 | - torchvision=0.17.2=py310_cu121
173 | - tqdm=4.65.0=py310h2f386ee_0
174 | - traitlets=5.7.1=py310h06a4308_0
175 | - truststore=0.8.0=py310h06a4308_0
176 | - typing_extensions=4.9.0=py310h06a4308_1
177 | - wcwidth=0.2.5=pyhd3eb1b0_0
178 | - wheel=0.41.2=py310h06a4308_0
179 | - xz=5.4.6=h5eee18b_0
180 | - yaml=0.2.5=h7b6447c_0
181 | - yaml-cpp=0.8.0=h6a678d5_0
182 | - zlib=1.2.13=h5eee18b_0
183 | - zstandard=0.19.0=py310h5eee18b_0
184 | - zstd=1.5.5=hc292b87_0
185 | - pip:
186 | - absl-py==2.1.0
187 | - accelerate==0.29.1
188 | - addict==2.4.0
189 | - aliyun-python-sdk-core==2.15.0
190 | - aliyun-python-sdk-kms==2.16.2
191 | - antlr4-python3-runtime==4.8
192 | - astunparse==1.6.3
193 | - cachetools==5.3.3
194 | - colorama==0.4.6
195 | - contourpy==1.2.0
196 | - crcmod==1.7
197 | - cycler==0.12.1
198 | - diffusers==0.27.2
199 | - dnspython==2.6.1
200 | - expecttest==0.2.1
201 | - fonttools==4.50.0
202 | - fsspec==2024.3.1
203 | - google-auth==2.29.0
204 | - google-auth-oauthlib==0.4.6
205 | - grpcio==1.62.1
206 | - huggingface-hub==0.22.2
207 | - hypothesis==6.99.13
208 | - importlib-metadata==7.1.0
209 | - jmespath==0.10.0
210 | - kiwisolver==1.4.5
211 | - markdown==3.6
212 | - markdown-it-py==3.0.0
213 | - matplotlib==3.8.3
214 | - mdurl==0.1.2
215 | - mmcv==2.1.0
216 | - mmdet==3.3.0
217 | - mmengine==0.10.3
218 | - model-index==0.1.11
219 | - networkx==3.2.1
220 | - ninja==1.11.1.1
221 | - oauthlib==3.2.2
222 | - omegaconf==2.1.0
223 | - opencv-python==4.6.0.66
224 | - opendatalab==0.0.10
225 | - openmim==0.3.9
226 | - openxlab==0.0.37
227 | - optree==0.11.0
228 | - ordered-set==4.1.0
229 | - oss2==2.17.0
230 | - pandas==2.2.1
231 | - protobuf==3.20.3
232 | - pyasn1==0.6.0
233 | - pyasn1-modules==0.4.0
234 | - pycocotools==2.0.7
235 | - pycryptodome==3.20.0
236 | - pyparsing==3.1.2
237 | - python-dateutil==2.9.0.post0
238 | - python-etcd==0.4.5
239 | - regex==2023.12.25
240 | - requests==2.28.2
241 | - requests-oauthlib==2.0.0
242 | - rich==13.4.2
243 | - rsa==4.9
244 | - safetensors==0.4.2
245 | - scipy==1.12.0
246 | - setuptools==60.2.0
247 | - shapely==2.0.3
248 | - sortedcontainers==2.4.0
249 | - tabulate==0.9.0
250 | - tensorboard==2.11.2
251 | - tensorboard-data-server==0.6.1
252 | - tensorboard-plugin-wit==1.8.1
253 | - tensorboardx==2.6
254 | - termcolor==2.4.0
255 | - terminaltables==3.1.10
256 | - timm==0.6.13
257 | - torchelastic==0.2.2
258 | - types-dataclasses==0.6.6
259 | - typing-extensions==4.10.0
260 | - tzdata==2024.1
261 | - urllib3==1.26.18
262 | - werkzeug==3.0.1
263 | - yapf==0.40.2
264 | - zipp==3.18.1
265 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .camlipwc import CamLiPWC
2 |
--------------------------------------------------------------------------------
/models/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | def dist_reduce_sum(value):
5 | if torch.distributed.is_initialized():
6 | value_t = torch.Tensor([value]).cuda()
7 | torch.distributed.all_reduce(value_t)
8 | return value_t
9 | else:
10 | return value
11 |
12 |
13 | class BaseModel(nn.Module):
14 | def __init__(self):
15 | super().__init__()
16 | self.loss = None
17 | self.metrics = {}
18 |
19 | def clear_metrics(self):
20 | self.metrics = {}
21 |
22 | @torch.no_grad()
23 | def update_metrics(self, name, var):
24 |         if not isinstance(var, torch.Tensor):
25 |             var = torch.tensor(var)  # accept Python scalars as well
26 |         var = var.reshape(-1).float()
27 |         count, var = var.shape[0], var.sum().item()
28 |
29 | var = dist_reduce_sum(var)
30 | count = dist_reduce_sum(count)
31 |
32 | if count <= 0:
33 | return
34 |
35 | if name not in self.metrics.keys():
36 | self.metrics[name] = [0, 0] # [var, count]
37 |
38 | self.metrics[name][0] += var
39 | self.metrics[name][1] += count
40 |
41 | def get_metrics(self):
42 | results = {}
43 | for name, (var, count) in self.metrics.items():
44 | results[name] = var / count
45 | return results
46 |
47 | def get_loss(self):
48 | if self.loss is None:
49 | raise ValueError('Loss is empty.')
50 | return self.loss
51 |
52 | @staticmethod
53 | def is_better(curr_metrics, best_metrics):
54 | raise RuntimeError('Function `is_better` must be implemented.')
55 |
56 |
57 | class FlowModel(BaseModel):
58 | def __init__(self):
59 | super(FlowModel, self).__init__()
60 |
61 | @torch.no_grad()
62 | def update_2d_metrics(self, pred, target):
63 | if target.shape[1] == 3: # sparse evaluation
64 | mask = target[:, 2, :, :] > 0
65 | target = target[:, :2, :, :]
66 | else: # dense evaluation
67 | mask = torch.ones_like(target)[:, 0, :, :] > 0 # B x H x W
68 |
69 | # compute endpoint error
70 | diff = pred - target
71 | epe2d_map = torch.linalg.norm(diff, dim=1) # B x H x W
72 | self.update_metrics('epe2d', epe2d_map[mask])
73 |
74 | # compute 1px accuracy
75 | acc2d_map = epe2d_map < 1.0
76 | self.update_metrics('acc2d_1px', acc2d_map[mask])
77 |
78 | # compute flow outliers
79 | mag = torch.linalg.norm(target, dim=1) + 1e-5
80 | out2d_map = torch.logical_and(epe2d_map > 3.0, epe2d_map / mag > 0.05)
81 | self.update_metrics('outlier2d', out2d_map[mask])
82 |
83 | @torch.no_grad()
84 | def update_3d_metrics(self, pred, target, mask, noc_mask = None):
85 |
86 | # compute endpoint error
87 | diff = pred - target # [B, 3, H, W]
88 | epe3d_map = torch.linalg.norm(diff, dim=1) # [B, H, W]
89 | acc5_3d_map = epe3d_map < 0.05 # compute 5cm accuracy
90 | acc10_3d_map = epe3d_map < 0.10 # compute 10cm accuracy
91 |
92 | if noc_mask is not None:
93 | self.update_metrics('epe3d(noc)', epe3d_map[noc_mask])
94 | self.update_metrics('acc3d_5cm(noc)', acc5_3d_map[noc_mask])
95 | self.update_metrics('acc3d_10cm(noc)', acc10_3d_map[noc_mask])
96 | else:
97 | self.update_metrics('epe3d', epe3d_map[mask])
98 | self.update_metrics('acc3d_5cm', acc5_3d_map[mask])
99 | self.update_metrics('acc3d_10cm', acc10_3d_map[mask])
--------------------------------------------------------------------------------
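
Usage sketch for `models/base.py` above: `update_metrics` keeps a running `[sum, count]` pair per metric and `get_metrics` reports the mean, so masked per-pixel tensors can be fed in batch by batch. A minimal single-process example (`DummyModel` is a hypothetical subclass for illustration; assumes the repository root is on `PYTHONPATH`, so `dist_reduce_sum` is a pass-through):

```python
import torch
from models.base import BaseModel

class DummyModel(BaseModel):
    @staticmethod
    def is_better(curr_metrics, best_metrics):
        return best_metrics is None or curr_metrics['epe'] < best_metrics['epe']

model = DummyModel()
model.update_metrics('epe', torch.tensor([1.0, 2.0, 3.0]))  # running sum = 6, count = 3
model.update_metrics('epe', torch.tensor([4.0]))            # running sum = 10, count = 4
print(model.get_metrics())  # {'epe': 2.5}
```
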
/models/camlipwc.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .camlipwc_core import CamLiPWC_Core
3 | from .base import FlowModel
4 | from .losses2d import calc_supervised_loss_2d
5 | from .losses3d import calc_supervised_loss_3d
6 | from .utils import stride_sample, stride_sample_pc, normalize_image, resize_to_64x, resize_flow2d, build_pc_pyramid, build_labels
7 |
8 |
9 | class CamLiPWC(FlowModel):
10 | def __init__(self, cfgs):
11 | super(CamLiPWC, self).__init__()
12 | self.cfgs = cfgs
13 | self.dense = cfgs.dense
14 | self.core = CamLiPWC_Core(cfgs.pwc2d, cfgs.pwc3d, self.dense)
15 | self.stride_h_list = cfgs.stride_h
16 | self.stride_w_list = cfgs.stride_w
17 | self.loss = None
18 |
19 | def train(self, mode=True):
20 |
21 | self.training = mode
22 |
23 | for module in self.children():
24 | module.train(mode)
25 |
26 | if self.cfgs.freeze_bn:
27 | for m in self.modules():
28 | if isinstance(m, nn.modules.batchnorm._BatchNorm):
29 | m.eval()
30 |
31 | return self
32 |
33 | def eval(self):
34 | return self.train(False)
35 |
36 | def forward(self, inputs):
37 | image1 = inputs['image1'].float() / 255.0
38 | image2 = inputs['image2'].float() / 255.0
39 | pc1, pc2 = inputs['pc1'].float(), inputs['pc2'].float()
40 | intrinsics = inputs['intrinsics'].float()
41 |
42 |
43 | # assert images.shape[2] % 64 == 0 and images.shape[3] % 64 == 0
44 | origin_h, origin_w = image1.shape[2:]
45 | cam_info = {'sensor_h': origin_h, 'sensor_w': origin_w, 'intrinsics': intrinsics, 'type': "dense" if self.dense else "sparse"}
46 |
47 | # encode features
48 | if self.dense:
49 | xyzs1, xyzs2 = stride_sample_pc(pc1, pc2, self.stride_h_list, self.stride_w_list)
50 | else:
51 | xyzs1, xyzs2, sample_indices1, _ = build_pc_pyramid(
52 | pc1, pc2, [8192, 4096, 2048, 1024, 512, 256] # 1/4
53 | )
54 |
55 | feats1_2d, feats1_3d = self.core.encode(image1, xyzs1)
56 | feats2_2d, feats2_3d = self.core.encode(image2, xyzs2)
57 |
58 |
59 | # predict flows (1->2)
60 | flows_2d, flows_3d = self.core.decode(xyzs1[1:], xyzs2[1:], feats1_2d, feats2_2d, feats1_3d, feats2_3d, pc1, pc2, cam_info)
61 |
62 | # final_flow_2d = resize_flow2d(flows_2d[0], origin_h, origin_w)
63 | final_flow_2d = flows_2d[0]
64 | final_flow_3d = flows_3d[0]
65 |
66 | if 'flow_2d' not in inputs or 'flow_3d' not in inputs:
67 | return {'flow_2d': final_flow_2d, 'flow_3d': final_flow_3d}
68 |
69 | target_2d = inputs['flow_2d'].float()
70 | target_3d = inputs['flow_3d'].float()
71 | valid_mask = inputs['nonzero_mask']
72 | if self.dense:
73 | labels_3d = stride_sample(target_3d, self.stride_h_list[:-2], self.stride_w_list[:-2])
74 | masks_3d = stride_sample(valid_mask, self.stride_h_list[:-2], self.stride_w_list[:-2])
75 | else:
76 | labels_3d, masks_3d = build_labels(target_3d, sample_indices1)
77 |
78 | # calculate losses
79 | loss_2d = calc_supervised_loss_2d(flows_2d, target_2d, self.cfgs.loss2d)
80 | loss_3d = calc_supervised_loss_3d(flows_3d, labels_3d, self.cfgs.loss3d, masks_3d)
81 | self.loss = loss_2d + loss_3d
82 |
83 | # prepare scalar summary
84 | self.update_metrics('loss', self.loss)
85 | self.update_metrics('loss2d', loss_2d)
86 | self.update_metrics('loss3d', loss_3d)
87 | self.update_2d_metrics(final_flow_2d, target_2d)
88 | self.update_3d_metrics(final_flow_3d, target_3d, valid_mask)
89 |
90 | if 'mask' in inputs:
91 | self.update_3d_metrics(final_flow_3d, target_3d, valid_mask, inputs['mask'])
92 |
93 | return {'flow_2d': final_flow_2d, 'flow_3d': final_flow_3d}
94 |
95 | @staticmethod
96 | def is_better(curr_summary, best_summary):
97 | if best_summary is None:
98 | return True
99 | return curr_summary['epe2d'] < best_summary['epe2d']
100 |
101 |
102 |
--------------------------------------------------------------------------------
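
Since `CamLiPWC.forward` branches on whether ground truth is present, inference only needs the five input keys read at the top of the method. A hedged sketch of dense-mode inference (all shapes and the `cfgs` object are illustrative assumptions; the real config comes from the yaml files via `utils/build_utils.py`, and the intrinsics layout is defined by the dataset loaders):

```python
import torch
from models import CamLiPWC

inputs = {
    'image1': torch.randint(0, 256, (1, 3, 256, 512), dtype=torch.uint8),
    'image2': torch.randint(0, 256, (1, 3, 256, 512), dtype=torch.uint8),
    'pc1': torch.randn(1, 3, 256, 512),  # dense mode: one 3D point per pixel
    'pc2': torch.randn(1, 3, 256, 512),
    'intrinsics': torch.randn(1, 4),     # placeholder; actual layout per the dataset code
}
model = CamLiPWC(cfgs).cuda().eval()    # cfgs: hypothetical parsed config node
with torch.no_grad():
    outputs = model({k: v.cuda() for k, v in inputs.items()})
# expected: outputs['flow_2d'] -> [1, 2, 256, 512], outputs['flow_3d'] -> [1, 3, 256, 512]
```
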
/models/camlipwc2d_core.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .utils import Conv2dNormRelu
4 |
5 |
6 | class ResidualBlock(nn.Module):
7 | def __init__(self, in_channels, out_channels, down_sample=True, norm=None):
8 | super().__init__()
9 |
10 | if down_sample:
11 | self.down0 = Conv2dNormRelu(in_channels, out_channels, stride=2, norm=norm, act=None)
12 | self.conv0 = Conv2dNormRelu(in_channels, out_channels, kernel_size=3, stride=2, padding=1, norm=norm)
13 | self.conv1 = Conv2dNormRelu(out_channels, out_channels, kernel_size=3, stride=1, padding=1, norm=norm, act=None)
14 | else:
15 | self.down0 = nn.Identity()
16 | self.conv0 = Conv2dNormRelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, norm=norm)
17 | self.conv1 = Conv2dNormRelu(out_channels, out_channels, kernel_size=3, stride=1, padding=1, norm=norm, act=None)
18 |
19 | self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
20 |
21 | def forward(self, x):
22 | out = self.conv0(x)
23 | out = self.conv1(out)
24 | out = self.relu(out + self.down0(x))
25 | return out
26 |
27 |
28 | class FeaturePyramid2D(nn.Module):
29 | def __init__(self, n_channels, norm=None):
30 | super().__init__()
31 | self.pyramid_convs = nn.ModuleList()
32 | for in_channels, out_channels in zip(n_channels[:-1], n_channels[1:]):
33 | self.pyramid_convs.append(ResidualBlock(in_channels, out_channels, norm=norm))
34 |
35 | def forward(self, x):
36 | outputs = []
37 | for conv in self.pyramid_convs:
38 | x = conv(x)
39 | outputs.append(x)
40 | return outputs
41 |
42 |
43 | class FlowEstimatorDense2D(nn.Module):
44 | def __init__(self, n_channels, norm=None, conv_last=True):
45 | super().__init__()
46 | self.conv1 = Conv2dNormRelu(
47 | n_channels[0],
48 | n_channels[1],
49 | kernel_size=3, padding=1, norm=norm
50 | )
51 | self.conv2 = Conv2dNormRelu(
52 | n_channels[0] + n_channels[1],
53 | n_channels[2],
54 | kernel_size=3, padding=1, norm=norm
55 | )
56 | self.conv3 = Conv2dNormRelu(
57 | n_channels[0] + n_channels[1] + n_channels[2],
58 | n_channels[3],
59 | kernel_size=3, padding=1, norm=norm
60 | )
61 | self.conv4 = Conv2dNormRelu(
62 | n_channels[0] + n_channels[1] + n_channels[2] + n_channels[3],
63 | n_channels[4],
64 | kernel_size=3, padding=1, norm=norm
65 | )
66 | self.conv5 = Conv2dNormRelu(
67 | n_channels[0] + n_channels[1] + n_channels[2] + n_channels[3] + n_channels[4],
68 | n_channels[5],
69 | kernel_size=3, padding=1, norm=norm
70 | )
71 | self.flow_feat_dim = sum(n_channels)
72 |
73 | if conv_last:
74 | self.conv_last = nn.Conv2d(self.flow_feat_dim, 2, kernel_size=3, stride=1, padding=1)
75 | else:
76 | self.conv_last = None
77 |
78 | def forward(self, x):
79 | x1 = torch.cat([self.conv1(x), x], dim=1)
80 | x2 = torch.cat([self.conv2(x1), x1], dim=1)
81 | x3 = torch.cat([self.conv3(x2), x2], dim=1)
82 | x4 = torch.cat([self.conv4(x3), x3], dim=1)
83 | flow_feat = torch.cat([self.conv5(x4), x4], dim=1)
84 |
85 | if self.conv_last is not None:
86 | flow = self.conv_last(flow_feat)
87 | return flow_feat, flow
88 | else:
89 | return flow_feat
90 |
91 |
92 | class ContextNetwork2D(nn.Module):
93 | def __init__(self, n_channels, dilations, norm=None):
94 | super().__init__()
95 | self.convs = nn.ModuleList()
96 | for in_channels, out_channels, dilation in zip(n_channels[:-1], n_channels[1:], dilations):
97 | self.convs.append(Conv2dNormRelu(in_channels, out_channels, kernel_size=3, padding=dilation, dilation=dilation, norm=norm))
98 | self.conv_last = nn.Conv2d(n_channels[-1], 2, kernel_size=3, stride=1, padding=1)
99 |
100 | def forward(self, x):
101 | for conv in self.convs:
102 | x = conv(x)
103 | outputs = self.conv_last(x)
104 | return x, outputs
105 |
--------------------------------------------------------------------------------
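
`FlowEstimatorDense2D` above is DenseNet-style: each `convN` output is concatenated back onto its input, so the running channel count is a prefix sum of `n_channels` and `flow_feat_dim = sum(n_channels)`. A standalone sketch of the same pattern in plain PyTorch (plain `Conv2d` stands in for `Conv2dNormRelu`; the 179-channel input assumes `max_displacement = 4`, i.e. 64 + 81 + 2 + 32):

```python
import torch
import torch.nn as nn

n_channels = [179, 128, 128, 96, 64, 32]  # 179 = 64 + 81 + 2 + 32, assuming max_displacement = 4
convs = nn.ModuleList()
in_dim = n_channels[0]
for out_dim in n_channels[1:]:
    convs.append(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1))
    in_dim += out_dim  # the next conv sees everything produced so far

x = torch.randn(1, n_channels[0], 36, 60)
for conv in convs:
    x = torch.cat([conv(x), x], dim=1)  # dense concatenation, as in the module above
print(x.shape[1], sum(n_channels))      # 627 627
```
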
/models/camlipwc3d_core.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .pointconv import PointConv, PointConvS
4 | from .utils import MLP1d, MLP2d, Conv1dNormRelu, k_nearest_neighbor, batch_indexing, knn_grouping_2d, get_hw_idx, mask_batch_selecting
5 | from ops_pytorch.gpu_threenn_sample.no_sort_knn import no_sort_knn
6 |
7 |
8 | def get_selected_idx(batch_size: int, out_H: int, out_W: int, stride_H: int, stride_W: int):
9 |
10 | select_h_idx = torch.arange(0, out_H * stride_H, stride_H, device = "cuda") # [out_H]
11 | select_w_idx = torch.arange(0, out_W * stride_W, stride_W, device = "cuda") # [out_W]
12 |     height_indices = torch.reshape(select_h_idx, (1, 1, -1, 1)).expand(batch_size, 1, out_H, out_W) # [B, 1, out_H, out_W]
13 |     width_indices = torch.reshape(select_w_idx, (1, 1, 1, -1)).expand(batch_size, 1, out_H, out_W) # [B, 1, out_H, out_W]
14 | select_idx = torch.cat([height_indices, width_indices], dim = 1)
15 | return select_idx
16 |
17 | def fast_index(inputs, idx):
18 | """
19 | Input:
20 | inputs: input points data, [B, 3, H, W]
21 | idx: sample index data, [B, 2, h, w]
22 | Return:
23 | outputs:, indexed points data, [B, 3, h, w]
24 | """
25 | if len(inputs.shape) == 4:
26 | B, C, H, W = inputs.shape
27 | _, _, h, w = idx.shape
28 | neighbor_idx = idx[:, 0] * W + idx[:, 1] # (B, h, w)
29 |         neighbor_idx = neighbor_idx.reshape(B, 1, h*w) # (B, 1, h*w)
30 | inputs_bcn = inputs.reshape(B, C, H*W) # (B, C, H*W)
31 | gather_feat = torch.gather(inputs_bcn, 2, neighbor_idx.expand(-1, C, -1)) # [B, C, h*w]
32 | outputs = torch.reshape(gather_feat, [B, C, h, w])
33 | else:
34 | B, H, W = inputs.shape
35 | _, _, h, w = idx.shape
36 | neighbor_idx = idx[:, 0] * W + idx[:, 1] # (B, h, w)
37 | neighbor_idx = neighbor_idx.reshape(B, h*w) # (B, h*w)
38 | inputs_bcn = inputs.reshape(B, H*W) # (B, H*W)
39 |         gather_feat = torch.gather(inputs_bcn, 1, neighbor_idx) # [B, h*w]
40 | outputs = torch.reshape(gather_feat, [B, h, w])
41 |
42 | return outputs
43 |
44 | def stride_sample_gather(pc1, pc2, label, mask, stride_H_list, stride_W_list):
45 | """
46 | Input:
47 | pc1: input points data, [B, 3, H, W]
48 |         pc2: input points data, [B, 3, H, W]
49 |         mask: [B, H, W]
50 |         label: [B, 3, H, W]
51 |         stride_H_list, stride_W_list: per-level sampling strides
52 |     Return:
53 |         xyzs1, xyzs2: pyramids of sampled points, plus matching sampled labels and masks
54 | """
55 | B = pc1.shape[0]
56 | H_list = [pc1.shape[2]]; W_list = [pc1.shape[3]]
57 |
58 | xyzs1 = [pc1]; xyzs2 = [pc2]; labels = [label]; masks = [mask]
59 |
60 | for s_h, s_w in zip(stride_H_list, stride_W_list):
61 | H_list.append(H_list[-1] // s_h)
62 | W_list.append(W_list[-1] // s_w)
63 | idx = get_selected_idx(B, H_list[-1], W_list[-1], s_h, s_w)
64 | xyzs1.append(fast_index(xyzs1[-1], idx))
65 | xyzs2.append(fast_index(xyzs2[-1], idx))
66 | labels.append(fast_index(labels[-1], idx))
67 | masks.append(fast_index(masks[-1], idx))
68 |
69 | return xyzs1, xyzs2, labels[1:], masks[1:]
70 |
71 | class KnnUpsampler3D(nn.Module):
72 | def __init__(self, stride_h, stride_w, ks = [10, 20], dist = 100.0, k=3) -> None:
73 | super().__init__()
74 | self.k = k
75 | self.stride_h = stride_h
76 | self.stride_w = stride_w
77 | self.dist = dist
78 | self.ks = ks
79 |
80 | @torch.no_grad()
81 | def knn_grouping(self, query_xyz, input_xyz):
82 | """
83 | :param query_xyz: [batch_size, 3, H, W]
84 | :param input_xyz: [batch_size, 3, h, w]
85 | :return grouped idx: [batch_size, H*W, k]
86 | """
87 | B, C, h, w = input_xyz.shape
88 | _, _, H, W = query_xyz.shape
89 | n_sampled = H * W
90 |
91 | assert H // h == self.stride_h and W // w == self.stride_w, "size mismatch"
92 |
93 | idx_hw = get_hw_idx(B, H, W).contiguous()
94 | random_HW = torch.arange(0, self.ks[0] * self.ks[1], device = "cuda", dtype = torch.int)
95 |         input_xyz_hw3 = input_xyz.permute(0, 2, 3, 1).contiguous() # [B, h, w, 3]
96 | query_xyz_hw3 = query_xyz.permute(0, 2, 3, 1).contiguous() # [B, H, W, 3]
97 |
98 | # Initialize
99 | select_b_idx = torch.zeros(B, n_sampled, self.k, 1, device = 'cuda').long().detach() # (B N nsample_q 1)
100 | select_h_idx = torch.zeros(B, n_sampled, self.k, 1, device = 'cuda').long().detach()
101 | select_w_idx = torch.zeros(B, n_sampled, self.k, 1, device = 'cuda').long().detach()
102 | valid_mask = torch.zeros(B, n_sampled, self.k, 1, device = 'cuda').float().detach()
103 |
104 | # with torch.no_grad():
105 |         # find the k nearest neighbors of each query point within the strided input grid
106 | select_b_idx, select_h_idx, select_w_idx, valid_mask = no_sort_knn\
107 | (query_xyz_hw3, input_xyz_hw3, idx_hw, random_HW, H, W, n_sampled, self.ks[0], self.ks[1],\
108 | self.k, 1, self.dist, self.stride_h, self.stride_w, select_b_idx, select_h_idx, select_w_idx, valid_mask)
109 |
110 | neighbor_idx = select_h_idx * w + select_w_idx # [B, H*W, k, 1]
111 |
112 | return neighbor_idx.squeeze(-1), valid_mask.squeeze(-1)
113 | # return neighbor_idx.squeeze(-1), valid_mask.squeeze(-1)
114 |
115 | def forward(self, query_xyz, input_xyz, input_features):
116 | """
117 | :param input_xyz: 3D locations of input points, [B, 3, h, w]
118 | :param input_features: features of input points, [B, C, h, w]
119 | :param query_xyz: 3D locations of query points, [B, 3, H, W]
120 | :param k: k-nearest neighbor, int
121 | :return interpolated features: [B, C, H, W]
122 | """
123 | B, _, H, W = query_xyz.shape
124 |
125 |
126 | # knn_indices: [B, H*W, 3]
127 | knn_indices, valid_knn_mask = self.knn_grouping(query_xyz, input_xyz)
128 | knn_xyz = mask_batch_selecting(input_xyz, knn_indices, valid_knn_mask) # [B, 3, H*W, 3]
129 | query_xyz = query_xyz.view(B, 3, H*W)
130 | knn_dists = torch.linalg.norm(knn_xyz - query_xyz[..., None], dim = 1).clamp(1e-8)
131 | # knn_weights: [B, H*W, 3]
132 | knn_weights = 1.0 / knn_dists
133 | knn_weights = knn_weights / torch.sum(knn_weights, dim = -1, keepdim = True)
134 | knn_features = mask_batch_selecting(input_features, knn_indices, valid_knn_mask) # [B, C, H*W, 3]
135 |
136 | # interpolated: [B, C, H*W]
137 | interpolated = torch.sum(knn_features * knn_weights[:, None, :, :], dim=-1)
138 |
139 | interpolated = interpolated.view(B, -1, H, W)
140 | return interpolated
141 |
142 | class FeaturePyramid3D(nn.Module):
143 | def __init__(self, n_channels, norm=None, k=16, ks = [10, 20]):
144 | super().__init__()
145 |
146 | self.mlps = nn.ModuleList([MLP2d(3, [n_channels[0], n_channels[0]])])
147 | self.convs = nn.ModuleList([PointConv(n_channels[0], n_channels[0], norm=norm, k=k, ks = ks)])
148 |
149 | for i in range(1, len(n_channels)):
150 | self.mlps.append(MLP2d(n_channels[i - 1], [n_channels[i - 1], n_channels[i]]))
151 | self.convs.append(
152 | PointConv(n_channels[i], n_channels[i], norm=norm, k=k, ks = ks)
153 | )
154 |
155 | def forward(self, xyzs):
156 | """
157 | :param xyzs: pyramid of points
158 | :return feats: pyramid of features
159 | """
160 | assert len(xyzs) == len(self.mlps) + 1
161 |
162 | input_feat = xyzs[0] # [bs, 3, h, w]
163 | # input_feat = self.level0_mlp(inputs)
164 | feats = []
165 |
166 | for i in range(len(xyzs) - 1):
167 | if i == 0:
168 | feat = self.mlps[i](input_feat)
169 | else:
170 | feat = self.mlps[i](feats[-1])
171 |
172 | feat = self.convs[i](xyzs[i], feat, xyzs[i + 1])
173 | feats.append(feat)
174 |
175 | return feats
176 |
177 | class FeaturePyramid3DS(nn.Module):
178 | def __init__(self, n_channels, norm=None, k=16):
179 | super().__init__()
180 |
181 | self.level0_mlp = MLP1d(3, [n_channels[0], n_channels[0]])
182 |
183 | self.pyramid_mlps = nn.ModuleList()
184 | self.pyramid_convs = nn.ModuleList()
185 |
186 | for i in range(len(n_channels) - 1):
187 | self.pyramid_mlps.append(MLP1d(n_channels[i], [n_channels[i], n_channels[i + 1]]))
188 | self.pyramid_convs.append(PointConvS(n_channels[i + 1], n_channels[i + 1], norm=norm, k=k))
189 |
190 | def forward(self, xyzs):
191 | """
192 | :param xyzs: pyramid of points
193 | :return feats: pyramid of features
194 | """
195 | assert len(xyzs) == len(self.pyramid_mlps) + 1
196 |
197 | inputs = xyzs[0] # [bs, 3, n_points]
198 | feats = [self.level0_mlp(inputs)]
199 |
200 | for i in range(len(xyzs) - 1):
201 | feat = self.pyramid_mlps[i](feats[-1])
202 | feats.append(self.pyramid_convs[i](xyzs[i], feat, xyzs[i + 1]))
203 |
204 | return feats
205 |
206 | class Costvolume3D(nn.Module):
207 | def __init__(self, in_channels, out_channels, ks = [10, 20], dist = 100.0, k=16):
208 | super().__init__()
209 |
210 | self.k = k
211 | self.ks = ks
212 | self.dist = dist
213 | self.cost_mlp = MLP2d(3 + 2 * in_channels, [out_channels, out_channels], act='leaky_relu')
214 | self.weight_net1 = MLP2d(3, [8, 8, out_channels], act='relu')
215 | self.weight_net2 = MLP2d(3, [8, 8, out_channels], act='relu')
216 |
217 |
218 | def forward(self, xyz1, feat1, xyz2, feat2, idx_fetching=None, knn_indices_1in1=None, valid_mask_1in1 = None):
219 | """
220 | :param xyz1: [batch_size, 3, H, W]
221 | :param feat1: [batch_size, in_channels, H, W]
222 | :param xyz2: [batch_size, 3, H, W]
223 | :param feat2: [batch_size, in_channels, H, W]
224 | :param warping idx: for each warped point in xyz1, find its position in xyz2, [batch_size, H * W, 2]
225 | :return cost volume: [batch_size, n_cost_channels, H, W]
226 | """
227 | B, C, H, W = feat1.shape
228 | feat1 = feat1.view(B, C, H * W)
229 |
230 | knn_indices_1in2, valid_mask_1in2 = knn_grouping_2d(query_xyz=xyz1, input_xyz=xyz2, k=self.k, idx_fetching = idx_fetching)
231 | # knn_xyz2: [B, 3, H*W, k],
232 | knn_xyz2 = mask_batch_selecting(xyz2, knn_indices_1in2, valid_mask_1in2)
233 | # knn_xyz2_norm: [B, 3, H*W, k]
234 | knn_xyz2_norm = knn_xyz2 - xyz1.view(B, 3, H * W, 1)
235 | # knn_features2: [B, C, H*W, k]
236 | knn_features2 = mask_batch_selecting(feat2, knn_indices_1in2, valid_mask_1in2)
237 | # features1_expand: [B, C, H*W, k]
238 | features1_expand = feat1[:, :, :, None].expand(B, C, H * W, self.k)
239 | # concatenated_features: [B, 2C+3, H*W, k]
240 | concatenated_features = torch.cat([features1_expand, knn_features2, knn_xyz2_norm], dim=1)
241 | # p2p_cost (point-to-point cost): [B, out_channels, H*W, k]
242 | p2p_cost = self.cost_mlp(concatenated_features)
243 |
244 | # weights2: [B, out_channels, H*W, k]
245 | weights2 = self.weight_net2(knn_xyz2_norm)
246 | # p2n_cost (point-to-neighbor cost): [B, out_channels, H * W]
247 | p2n_cost = torch.sum(weights2 * p2p_cost, dim=3)
248 |
249 | if knn_indices_1in1 is not None:
250 | assert knn_indices_1in1.shape[:2] == torch.Size([B, H * W])
251 | assert knn_indices_1in1.shape[2] >= self.k
252 | knn_indices_1in1 = knn_indices_1in1[:, :, :self.k]
253 | valid_mask_1in1 = valid_mask_1in1[:, :, :self.k]
254 | else:
255 | knn_indices_1in1, valid_mask_1in1 = knn_grouping_2d(query_xyz=xyz1, input_xyz=xyz1, k=self.k) # [bs, n_points, k]
256 |
257 | # knn_xyz1: [B, 3, H*W, k]
258 | knn_xyz1 = mask_batch_selecting(xyz1, knn_indices_1in1, valid_mask_1in1)
259 | # knn_xyz1_norm: [B, 3, H*W, k]
260 | knn_xyz1_norm = knn_xyz1 - xyz1.view(B, 3, H * W, 1)
261 | # weights1: [B, out_channels, H*W, k]
262 | weights1 = self.weight_net1(knn_xyz1_norm)
263 | # n2n_cost: [B, out_channels, H*W, k]
264 | n2n_cost = mask_batch_selecting(p2n_cost, knn_indices_1in1, valid_mask_1in1)
265 | # n2n_cost: [B, out_channels, H * W]
266 | n2n_cost = torch.sum(weights1 * n2n_cost, dim=3)
267 |
268 | n2n_cost = n2n_cost.view(B, -1, H, W)
269 |
270 | return n2n_cost
271 |
272 | class Costvolume3DS(nn.Module):
273 | def __init__(self, in_channels, out_channels, align_channels=None, k=16):
274 | super().__init__()
275 | self.k = k
276 |
277 | self.cost_mlp = MLP2d(3 + 2 * in_channels, [out_channels, out_channels], act='leaky_relu')
278 | self.weight_net1 = MLP2d(3, [8, 8, out_channels], act='relu')
279 | self.weight_net2 = MLP2d(3, [8, 8, out_channels], act='relu')
280 |
281 | if align_channels is not None:
282 | self.feat_aligner = Conv1dNormRelu(out_channels, align_channels)
283 | else:
284 | self.feat_aligner = nn.Identity()
285 |
286 | def forward(self, xyz1, feat1, xyz2, feat2, knn_indices_1in1=None):
287 | """
288 | :param xyz1: [batch_size, 3, n_points]
289 | :param feat1: [batch_size, in_channels, n_points]
290 | :param xyz2: [batch_size, 3, n_points]
291 | :param feat2: [batch_size, in_channels, n_points]
292 | :param knn_indices_1in1: for each point in xyz1, find its neighbors in xyz1, [batch_size, n_points, k]
293 | :return cost volume for each point in xyz1: [batch_size, n_cost_channels, n_points]
294 | """
295 | batch_size, in_channels, n_points = feat1.shape
296 |
297 | # Step1: for each point in xyz1, find its neighbors in xyz2
298 | knn_indices_1in2 = k_nearest_neighbor(input_xyz=xyz2, query_xyz=xyz1, k=self.k)
299 | # knn_xyz2: [bs, 3, n_points, k]
300 | knn_xyz2 = batch_indexing(xyz2, knn_indices_1in2)
301 | # knn_xyz2_norm: [bs, 3, n_points, k]
302 | knn_xyz2_norm = knn_xyz2 - xyz1.view(batch_size, 3, n_points, 1)
303 | # knn_features2: [bs, in_channels, n_points, k]
304 | knn_features2 = batch_indexing(feat2, knn_indices_1in2)
305 | # features1_expand: [bs, in_channels, n_points, k]
306 | features1_expand = feat1[:, :, :, None].expand(batch_size, in_channels, n_points, self.k)
307 | # concatenated_features: [bs, 2 * in_channels + 3, n_points, k]
308 | concatenated_features = torch.cat([features1_expand, knn_features2, knn_xyz2_norm], dim=1)
309 | # p2p_cost (point-to-point cost): [bs, out_channels, n_points, k]
310 | p2p_cost = self.cost_mlp(concatenated_features)
311 |
312 | # weights2: [bs, out_channels, n_points, k]
313 | weights2 = self.weight_net2(knn_xyz2_norm)
314 | # p2n_cost (point-to-neighbor cost): [bs, out_channels, n_points]
315 | p2n_cost = torch.sum(weights2 * p2p_cost, dim=3)
316 |
317 | # Step2: for each point in xyz1, find its neighbors in xyz1
318 | if knn_indices_1in1 is not None:
319 | assert knn_indices_1in1.shape[:2] == torch.Size([batch_size, n_points])
320 | assert knn_indices_1in1.shape[2] >= self.k
321 | knn_indices_1in1 = knn_indices_1in1[:, :, :self.k]
322 | else:
323 | knn_indices_1in1 = k_nearest_neighbor(input_xyz=xyz1, query_xyz=xyz1, k=self.k) # [bs, n_points, k]
324 | # knn_xyz1: [bs, 3, n_points, k]
325 | knn_xyz1 = batch_indexing(xyz1, knn_indices_1in1)
326 | # knn_xyz1_norm: [bs, 3, n_points, k]
327 | knn_xyz1_norm = knn_xyz1 - xyz1.view(batch_size, 3, n_points, 1)
328 |
329 | # weights1: [bs, out_channels, n_points, k]
330 | weights1 = self.weight_net1(knn_xyz1_norm)
331 | # n2n_cost: [bs, out_channels, n_points, k]
332 | n2n_cost = batch_indexing(p2n_cost, knn_indices_1in1)
333 | # n2n_cost (neighbor-to-neighbor cost): [bs, out_channels, n_points]
334 | n2n_cost = torch.sum(weights1 * n2n_cost, dim=3)
335 | # align features (optional)
336 | n2n_cost = self.feat_aligner(n2n_cost)
337 |
338 | return n2n_cost
339 |
340 | class FlowPredictor3D(nn.Module):
341 | def __init__(self, n_channels, norm=None, conv_last=True, k=16):
342 | super().__init__()
343 | self.point_conv1 = PointConv(in_channels=n_channels[0], out_channels=n_channels[1], norm=norm, k=k)
344 | self.point_conv2 = PointConv(in_channels=n_channels[1], out_channels=n_channels[2], norm=norm, k=k)
345 | self.mlp = MLP2d(n_channels[2], [n_channels[2], n_channels[3]])
346 | self.flow_feat_dim = n_channels[3]
347 |
348 | if conv_last:
349 | self.conv_last = nn.Conv2d(n_channels[3], 3, kernel_size=1)
350 | else:
351 | self.conv_last = None
352 |
353 | def forward(self, xyz, feat, knn_indices, valid_knn_mask):
354 | """
355 | :param xyz: 3D locations of points, [B, 3, H, W]
356 | :param feat: features of points, [B, in_channels, H, W]
357 | :return flow_feat: [B, 64, H, W]
358 | :return flow: [B, 3, H, W]
359 | """
360 | feat = self.point_conv1(xyz, feat, knn_indices = knn_indices, valid_knn_mask = valid_knn_mask) # [bs, 128, H, W]
361 | feat = self.point_conv2(xyz, feat, knn_indices = knn_indices, valid_knn_mask = valid_knn_mask) # [bs, 128, H, W]
362 | feat = self.mlp(feat) # [bs, 64, H, W]
363 |
364 | if self.conv_last is not None:
365 | flow = self.conv_last(feat) # [bs, 3, H, W]
366 | return feat, flow
367 | else:
368 | return feat
369 |
370 |
371 | class FlowPredictor3DS(nn.Module):
372 | def __init__(self, n_channels, norm=None, conv_last=False, k=16):
373 | super().__init__()
374 | self.point_conv1 = PointConvS(in_channels=n_channels[0], out_channels=n_channels[1], norm=norm, k=k)
375 | self.point_conv2 = PointConvS(in_channels=n_channels[1], out_channels=n_channels[2], norm=norm, k=k)
376 | self.mlp = MLP1d(n_channels[2], [n_channels[2], n_channels[3]])
377 | self.flow_feat_dim = n_channels[3]
378 |
379 | if conv_last:
380 | self.conv_last = nn.Conv1d(n_channels[3], 3, kernel_size=1)
381 | else:
382 | self.conv_last = None
383 |
384 | def forward(self, xyz, feat, knn_indices, mask):
385 | """
386 | :param xyz: 3D locations of points, [batch_size, 3, n_points]
387 | :param feat: features of points, [batch_size, in_channels, n_points]
388 | :param knn_indices: knn indices of points, [batch_size, n_points, k]
389 | :return flow_feat: [batch_size, 64, n_points]
390 | :return flow: [batch_size, 3, n_points]
391 | """
392 | feat = self.point_conv1(xyz, feat, knn_indices=knn_indices)
393 | feat = self.point_conv2(xyz, feat, knn_indices=knn_indices)
394 | feat = self.mlp(feat)
395 |
396 | if self.conv_last is not None:
397 | flow = self.conv_last(feat)
398 | return feat, flow
399 | else:
400 | return feat
--------------------------------------------------------------------------------
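
`KnnUpsampler3D` above interpolates coarse features onto a finer point set with inverse-distance weights over the 3 nearest neighbors; the CUDA `no_sort_knn` kernel only accelerates the neighbor search. A pure-PyTorch sketch of the same interpolation rule (using `torch.cdist` in place of the kernel, so it is a readability reference, not a fast path):

```python
import torch

def knn_interpolate(query_xyz, input_xyz, input_feat, k=3, eps=1e-8):
    """query_xyz: [B, 3, N], input_xyz: [B, 3, M], input_feat: [B, C, M] -> [B, C, N]"""
    dists = torch.cdist(query_xyz.transpose(1, 2), input_xyz.transpose(1, 2))  # [B, N, M]
    knn_dists, knn_idx = dists.topk(k, dim=-1, largest=False)                  # [B, N, k]
    weights = 1.0 / knn_dists.clamp(min=eps)
    weights = weights / weights.sum(dim=-1, keepdim=True)                      # normalized inverse distance
    B, C, M = input_feat.shape
    idx = knn_idx.reshape(B, 1, -1).expand(-1, C, -1)                          # [B, C, N*k]
    knn_feat = input_feat.gather(2, idx).reshape(B, C, -1, k)                  # [B, C, N, k]
    return (knn_feat * weights.unsqueeze(1)).sum(dim=-1)                       # [B, C, N]

q, p, f = torch.randn(2, 3, 1024), torch.randn(2, 3, 256), torch.randn(2, 64, 256)
print(knn_interpolate(q, p, f).shape)  # torch.Size([2, 64, 1024])
```
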
/models/camlipwc_core.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.functional import leaky_relu, interpolate
4 | from .camlipwc2d_core import FeaturePyramid2D, FlowEstimatorDense2D, ContextNetwork2D
5 | from .camlipwc3d_core import FeaturePyramid3D, FeaturePyramid3DS, Costvolume3D, Costvolume3DS, FlowPredictor3D, FlowPredictor3DS, KnnUpsampler3D
6 | from .utils import Conv1dNormRelu, Conv2dNormRelu
7 | from .utils import backwarp_2d, backwarp_3d, mesh_grid, forwarp_3d, knn_interpolation, convex_upsample, project_3d_to_2d, knn_grouping_2d, mask_batch_selecting
8 | from .csrc import correlation2d, k_nearest_neighbor
9 | from .fusion_module import GlobalFuser, GlobalFuserS
10 |
11 | class CamLiPWC_Core(nn.Module):
12 | def __init__(self, cfgs2d, cfgs3d, dense=False, debug=False):
13 | super().__init__()
14 |
15 | self.cfgs2d, self.cfgs3d, self.debug, self.dense = cfgs2d, cfgs3d, debug, dense
16 | corr_channels_2d = (2 * cfgs2d.max_displacement + 1) ** 2
17 |
18 | ## PWC-Net 2D (IRR-PWC)
19 | self.branch_2d_fnet = FeaturePyramid2D(
20 | [3, 16, 32, 64, 96, 128, 192],
21 | norm=cfgs2d.norm.feature_pyramid
22 | )
23 | self.branch_2d_fnet_aligners = nn.ModuleList([
24 | nn.Identity(),
25 | Conv2dNormRelu(32, 64),
26 | Conv2dNormRelu(64, 64),
27 | Conv2dNormRelu(96, 64),
28 | Conv2dNormRelu(128, 64),
29 | Conv2dNormRelu(192, 64),
30 | ])
31 | self.branch_2d_flow_estimator = FlowEstimatorDense2D(
32 | [64 + corr_channels_2d + 2 + 32, 128, 128, 96, 64, 32],
33 | norm=cfgs2d.norm.flow_estimator,
34 | conv_last=False,
35 | )
36 | self.branch_2d_context_network = ContextNetwork2D(
37 | [self.branch_2d_flow_estimator.flow_feat_dim + 2, 128, 128, 128, 96, 64, 32],
38 | dilations=[1, 2, 4, 8, 16, 1],
39 | norm=cfgs2d.norm.context_network
40 | )
41 | self.branch_2d_up_mask_head = nn.Sequential( # for convex upsampling (see RAFT)
42 | nn.Conv2d(32, 256, kernel_size=3, stride=1, padding=1),
43 | nn.ReLU(inplace=True),
44 | nn.Conv2d(256, 4 * 4 * 9, kernel_size=1, stride=1, padding=0),
45 | )
46 | self.branch_2d_conv_last = nn.Conv2d(self.branch_2d_flow_estimator.flow_feat_dim, 2, kernel_size=3, stride=1, padding=1)
47 |
48 | if dense:
49 | ## PWC-Net 3D (Point-PWC)
50 | self.branch_3d_fnet = FeaturePyramid3D(
51 | [16, 32, 64, 96, 128, 192],
52 | norm=cfgs3d.norm.feature_pyramid,
53 | k=cfgs3d.k,
54 | ks=cfgs3d.kernel_size
55 | )
56 | self.branch_3d_fnet_aligners = nn.ModuleList([
57 | nn.Identity(),
58 | Conv2dNormRelu(32, 64),
59 | Conv2dNormRelu(64, 64),
60 | Conv2dNormRelu(96, 64),
61 | Conv2dNormRelu(128, 64),
62 | Conv2dNormRelu(192, 64),
63 | ])
64 | self.branch_3d_correlations = nn.ModuleList([
65 | nn.Identity(),
66 | Costvolume3D(32, 32, k=self.cfgs3d.k),
67 | Costvolume3D(64, 64, k=self.cfgs3d.k),
68 | Costvolume3D(96, 96, k=self.cfgs3d.k),
69 | Costvolume3D(128, 128, k=self.cfgs3d.k),
70 | Costvolume3D(192, 192, k=self.cfgs3d.k),
71 | ])
72 | self.branch_3d_correlation_aligners = nn.ModuleList([
73 | nn.Identity(),
74 | Conv2dNormRelu(32, 64),
75 | Conv2dNormRelu(64, 64),
76 | Conv2dNormRelu(96, 64),
77 | Conv2dNormRelu(128, 64),
78 | Conv2dNormRelu(192, 64),
79 | ])
80 | self.branch_3d_flow_estimator = FlowPredictor3D(
81 | [64 + 64 + 3 + 64, 128, 128, 64],
82 | cfgs3d.norm.flow_estimator,
83 | conv_last=False,
84 | k=self.cfgs3d.k,
85 | )
86 |
87 | ## Bi-CLFM for pyramid features
88 | self.pyramid_feat_fusers = nn.ModuleList([
89 | nn.Identity(),
90 | GlobalFuser(32, 32, norm=cfgs2d.norm.feature_pyramid),
91 | GlobalFuser(64, 64, norm=cfgs2d.norm.feature_pyramid),
92 | GlobalFuser(96, 96, norm=cfgs2d.norm.feature_pyramid),
93 | GlobalFuser(128, 128, norm=cfgs2d.norm.feature_pyramid),
94 | GlobalFuser(192, 192, norm=cfgs2d.norm.feature_pyramid),
95 | ])
96 |
97 | ## Bi-CLFM for correlation features
98 | self.corr_feat_fusers = nn.ModuleList([
99 | nn.Identity(),
100 | GlobalFuser(corr_channels_2d, 32),
101 | GlobalFuser(corr_channels_2d, 64),
102 | GlobalFuser(corr_channels_2d, 96),
103 | GlobalFuser(corr_channels_2d, 128),
104 | GlobalFuser(corr_channels_2d, 192),
105 | ])
106 |
107 | self.intepolate_3d = KnnUpsampler3D(stride_h = 2, stride_w = 2, k = 3)
108 | self.knn_upconv = KnnUpsampler3D(stride_h = 4, stride_w = 4, k = 3)
109 |
110 | ## Bi-CLFM for decoder features
111 | self.estimator_feat_fuser = GlobalFuser(self.branch_2d_flow_estimator.flow_feat_dim, self.branch_3d_flow_estimator.flow_feat_dim)
112 |
113 | self.branch_3d_conv_last = nn.Conv2d(self.branch_3d_flow_estimator.flow_feat_dim, 3, kernel_size=1)
114 |
115 | else:
116 | self.branch_3d_fnet = FeaturePyramid3DS(
117 | n_channels=[16, 32, 64, 96, 128, 192], # 1/4
118 | norm=cfgs3d.norm.feature_pyramid,
119 | k=cfgs3d.k,
120 | )
121 | self.branch_3d_fnet_aligners = nn.ModuleList([
122 | nn.Identity(),
123 | Conv1dNormRelu(32, 64), # 1/4
124 | Conv1dNormRelu(64, 64),
125 | Conv1dNormRelu(96, 64),
126 | Conv1dNormRelu(128, 64),
127 | Conv1dNormRelu(192, 64),
128 | ])
129 | self.branch_3d_correlations = nn.ModuleList([
130 | nn.Identity(),
131 | Costvolume3DS(32, 32, k=self.cfgs3d.k), # 1/4
132 | Costvolume3DS(64, 64, k=self.cfgs3d.k),
133 | Costvolume3DS(96, 96, k=self.cfgs3d.k),
134 | Costvolume3DS(128, 128, k=self.cfgs3d.k),
135 | Costvolume3DS(192, 192, k=self.cfgs3d.k),
136 | ])
137 | self.branch_3d_correlation_aligners = nn.ModuleList([
138 | nn.Identity(),
139 | Conv1dNormRelu(32, 64), # 1/4
140 | Conv1dNormRelu(64, 64),
141 | Conv1dNormRelu(96, 64),
142 | Conv1dNormRelu(128, 64),
143 | Conv1dNormRelu(192, 64),
144 | ])
145 | self.branch_3d_flow_estimator = FlowPredictor3DS(
146 | [64 + 64 + 3 + 64, 128, 128, 64],
147 | cfgs3d.norm.flow_estimator,
148 | k=self.cfgs3d.k,
149 | )
150 |
151 |
152 | self.pyramid_feat_fusers = nn.ModuleList([
153 | nn.Identity(),
154 | GlobalFuserS(32, 32, norm=cfgs2d.norm.feature_pyramid), # 1/4
155 | GlobalFuserS(64, 64, norm=cfgs2d.norm.feature_pyramid),
156 | GlobalFuserS(96, 96, norm=cfgs2d.norm.feature_pyramid),
157 | GlobalFuserS(128, 128, norm=cfgs2d.norm.feature_pyramid),
158 | GlobalFuserS(192, 192, norm=cfgs2d.norm.feature_pyramid),
159 | ])
160 |
161 | self.corr_feat_fusers = nn.ModuleList([
162 | nn.Identity(),
163 | GlobalFuserS(corr_channels_2d, 32), # 1/4
164 | GlobalFuserS(corr_channels_2d, 64),
165 | GlobalFuserS(corr_channels_2d, 96),
166 | GlobalFuserS(corr_channels_2d, 128),
167 | GlobalFuserS(corr_channels_2d, 192),
168 | ])
169 |
170 | self.estimator_feat_fuser = GlobalFuserS(self.branch_2d_flow_estimator.flow_feat_dim, self.branch_3d_flow_estimator.flow_feat_dim)
171 | self.branch_3d_conv_last = nn.Conv1d(self.branch_3d_flow_estimator.flow_feat_dim, 3, kernel_size=1)
172 |
173 |
174 | def encode(self, image, xyzs):
175 | feats_2d = self.branch_2d_fnet(image)
176 | feats_3d = self.branch_3d_fnet(xyzs)
177 | return feats_2d, feats_3d
178 |
179 | def decode(self, xyzs1, xyzs2, feats1_2d, feats2_2d, feats1_3d, feats2_3d, raw_pc1, raw_pc2, camera_info):
180 | assert len(xyzs1) == len(xyzs2) == len(feats1_2d) == len(feats2_2d) == len(feats1_3d) == len(feats2_3d)
181 |
182 | flows_2d, flows_3d, flow_feats_2d, flow_feats_3d = [], [], [], []
183 | for level in range(len(xyzs1) - 1, 0, -1):
184 | xyz1, feat1_2d, feat1_3d = xyzs1[level], feats1_2d[level], feats1_3d[level]
185 | xyz2, feat2_2d, feat2_3d = xyzs2[level], feats2_2d[level], feats2_3d[level]
186 |
187 | batch_size, image_h, image_w = feat1_2d.shape[0], feat1_2d.shape[2], feat1_2d.shape[3]
188 | if not self.dense:
189 | n_points = xyz1.shape[-1]
190 |
191 | # project point cloud to image
192 | uv1, uv_mask1 = project_3d_to_2d(xyz1, camera_info, image_h, image_w)
193 | uv2, uv_mask2 = project_3d_to_2d(xyz2, camera_info, image_h, image_w)
194 |
195 | # pre-compute knn indices
196 | if self.dense:
197 | knn_1in1, valid_mask_1in1 = knn_grouping_2d(xyz1, xyz1, k=self.cfgs3d.k) # [bs, n_points, k]
198 | else:
199 | grid = mesh_grid(batch_size, image_h, image_w, uv1.device) # [B, 2, H, W]
200 | grid = grid.reshape([batch_size, 2, -1]) # [B, 2, HW]
201 | knn_1in1 = k_nearest_neighbor(xyz1, xyz1, k=self.cfgs3d.k) # [bs, n_points, k]
202 | valid_mask_1in1 = None
203 |
204 | # fuse pyramid features
205 | feat1_2d, feat1_3d = self.pyramid_feat_fusers[level](uv1, uv_mask1, feat1_2d, feat1_3d)
206 | feat2_2d, feat2_3d = self.pyramid_feat_fusers[level](uv2, uv_mask2, feat2_2d, feat2_3d)
207 |
208 |
209 | if level == len(xyzs1) - 1:
210 | last_flow_2d = torch.zeros([batch_size, 2, image_h, image_w], dtype=xyz1.dtype, device=xyz1.device)
211 | last_flow_feat_2d = torch.zeros([batch_size, 32, image_h, image_w], dtype=xyz1.dtype, device=xyz1.device)
212 |
213 | if self.dense:
214 | last_flow_3d = torch.zeros([batch_size, 3, image_h, image_w], dtype=xyz1.dtype, device=xyz1.device)
215 | last_flow_feat_3d = torch.zeros([batch_size, 64, image_h, image_w], dtype=xyz1.dtype, device=xyz1.device)
216 | xyz1_warp, feat2_2d_warp = xyz1, feat2_2d
217 | warping_idx = None
218 | else:
219 | last_flow_3d = torch.zeros([batch_size, 3, n_points], dtype=xyz1.dtype, device=xyz1.device)
220 | last_flow_feat_3d = torch.zeros([batch_size, 64, n_points], dtype=xyz1.dtype, device=xyz1.device)
221 | xyz2_warp, feat2_2d_warp = xyz2, feat2_2d
222 | else:
223 | # upsample 2d flow and backwarp
224 | last_flow_2d = interpolate(flows_2d[-1] * 2, scale_factor=2, mode='bilinear', align_corners=True)
225 | last_flow_feat_2d = interpolate(flow_feats_2d[-1], scale_factor=2, mode='bilinear', align_corners=True)
226 | feat2_2d_warp = backwarp_2d(feat2_2d, last_flow_2d, padding_mode='border')
227 |
228 | if self.dense:
229 | # upsample 3d flow and backwarp
230 | flow_with_feat_3d = torch.cat([flows_3d[-1], flow_feats_3d[-1]], dim=1)
231 | flow_with_feat_upsampled_3d = self.intepolate_3d(xyz1, xyzs1[level + 1], flow_with_feat_3d)
232 | last_flow_3d, last_flow_feat_3d = torch.split(flow_with_feat_upsampled_3d, [3, 64], dim=1)
233 | xyz1_warp, warping_idx = forwarp_3d(xyz1, xyz2, last_flow_3d, camera_info)
234 | else:
235 | last_flow_3d, last_flow_feat_3d = torch.split(
236 | knn_interpolation(
237 | xyzs1[level + 1],
238 | torch.cat([flows_3d[-1], flow_feats_3d[-1]], dim=1),
239 | xyz1
240 | ), [3, 64], dim=1)
241 | xyz2_warp = backwarp_3d(xyz1, xyz2, last_flow_3d)
242 |
243 | # correlation (2D & 3D)
244 | if self.dense:
245 | feat_corr_3d = self.branch_3d_correlations[level](xyz1_warp, feat1_3d, xyz2, feat2_3d, warping_idx, knn_1in1, valid_mask_1in1)
246 | else:
247 | feat_corr_3d = self.branch_3d_correlations[level](xyz1, feat1_3d, xyz2_warp, feat2_3d, knn_1in1)
248 | feat_corr_2d = leaky_relu(correlation2d(feat1_2d, feat2_2d_warp, self.cfgs2d.max_displacement), 0.1)
249 |
250 | # fuse correlation features
251 | feat_corr_2d, feat_corr_3d = self.corr_feat_fusers[level](uv1, uv_mask1, feat_corr_2d, feat_corr_3d)
252 |
253 | # align features using 1x1 convolution
254 | feat1_2d = self.branch_2d_fnet_aligners[level](feat1_2d)
255 | feat1_3d = self.branch_3d_fnet_aligners[level](feat1_3d)
256 | feat_corr_3d = self.branch_3d_correlation_aligners[level](feat_corr_3d)
257 |
258 | # flow decoder (or estimator)
259 | x_2d = torch.cat([feat_corr_2d, feat1_2d, last_flow_2d, last_flow_feat_2d], dim=1)
260 | x_3d = torch.cat([feat_corr_3d, feat1_3d, last_flow_3d, last_flow_feat_3d], dim=1)
261 | flow_feat_2d = self.branch_2d_flow_estimator(x_2d) # [bs, 96, image_h, image_w]
262 | flow_feat_3d = self.branch_3d_flow_estimator(xyz1, x_3d, knn_1in1, valid_mask_1in1) # [bs, 64, n_points]
263 |
264 | # fuse decoder features
265 | flow_feat_2d, flow_feat_3d = self.estimator_feat_fuser(uv1, uv_mask1, flow_feat_2d, flow_feat_3d)
266 |
267 | # flow prediction
268 | flow_delta_2d = self.branch_2d_conv_last(flow_feat_2d)
269 | flow_delta_3d = self.branch_3d_conv_last(flow_feat_3d)
270 |
271 | # residual connection
272 | flow_2d = last_flow_2d + flow_delta_2d
273 | flow_3d = last_flow_3d + flow_delta_3d
274 |
275 | # context network (2D only)
276 | flow_feat_2d, flow_delta_2d = self.branch_2d_context_network(torch.cat([flow_feat_2d, flow_2d], dim=1))
277 | flow_2d = flow_delta_2d + flow_2d
278 |
279 | # clip
280 | flow_2d = torch.clip(flow_2d, min=-1000, max=1000)
281 | flow_3d = torch.clip(flow_3d, min=-100, max=100)
282 |
283 | # save results
284 | flows_2d.append(flow_2d)
285 | flows_3d.append(flow_3d)
286 | flow_feats_2d.append(flow_feat_2d)
287 | flow_feats_3d.append(flow_feat_3d)
288 |
289 | flows_2d = [f.float() for f in flows_2d][::-1]
290 | flows_3d = [f.float() for f in flows_3d][::-1]
291 |
292 |         # convex upsampling module, from RAFT
293 | flows_2d[0] = convex_upsample(flows_2d[0], self.branch_2d_up_mask_head(flow_feats_2d[-1]), scale_factor=4)
294 |
295 | for i in range(1, len(flows_2d)):
296 | flows_2d[i] = interpolate(flows_2d[i] * 4, scale_factor=4, mode='bilinear', align_corners=True)
297 |
298 | for i in range(len(flows_3d)):
299 | if self.dense:
300 | if i == 0:
301 | flows_3d[i] = self.knn_upconv(raw_pc1, xyzs1[i + 1], flows_3d[i])
302 | else:
303 | flows_3d[i] = self.knn_upconv(xyzs1[i - 1], xyzs1[i + 1], flows_3d[i])
304 | else:
305 | flows_3d[i] = knn_interpolation(xyzs1[i + 1], flows_3d[i], xyzs1[i])
306 |
307 | return flows_2d, flows_3d
308 |
--------------------------------------------------------------------------------
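
`convex_upsample` (imported from `models/utils.py` above) follows RAFT: each full-resolution flow vector is a softmax-weighted convex combination of the 3x3 coarse neighborhood around it, with weights predicted by `branch_2d_up_mask_head` (hence its 4 * 4 * 9 output channels). A sketch written from the RAFT recipe, not copied from this repository's implementation:

```python
import torch
import torch.nn.functional as F

def convex_upsample_sketch(flow, mask, scale_factor=4):
    """flow: [B, 2, H, W], mask: [B, 9 * scale_factor**2, H, W] -> [B, 2, sH, sW]"""
    B, _, H, W = flow.shape
    s = scale_factor
    mask = mask.view(B, 1, 9, s, s, H, W).softmax(dim=2)  # convex weights over the 3x3 window
    up = F.unfold(flow * s, kernel_size=3, padding=1)     # [B, 2*9, H*W]; scale flow magnitudes too
    up = up.view(B, 2, 9, 1, 1, H, W)
    up = (mask * up).sum(dim=2)                           # [B, 2, s, s, H, W]
    up = up.permute(0, 1, 4, 2, 5, 3)                     # [B, 2, H, s, W, s]
    return up.reshape(B, 2, s * H, s * W)

flow = torch.randn(1, 2, 32, 64)
mask = torch.randn(1, 9 * 16, 32, 64)
print(convex_upsample_sketch(flow, mask).shape)  # torch.Size([1, 2, 128, 256])
```
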
/models/csrc/__init__.py:
--------------------------------------------------------------------------------
1 | from .wrapper import correlation2d, furthest_point_sampling, squared_distance, k_nearest_neighbor
2 |
--------------------------------------------------------------------------------
/models/csrc/correlation/correlation.cpp:
--------------------------------------------------------------------------------
1 | #include "correlation.h"
2 |
3 | void correlation_forward_kernel_wrapper(float* output, const float* input1, const float* input2,
4 | int n_batches, int in_channels, int height, int width, int max_displacement);
5 |
6 | void correlation_backward_kernel_wrapper(const float* grad_output,
7 | float* grad_input1, float* grad_input2,
8 | const float* input1, const float* input2,
9 | int n_batches, int in_channels, int height, int width, int max_displacement);
10 |
11 | torch::Tensor correlation_forward_cuda(torch::Tensor& input1, torch::Tensor& input2, int max_displacement) {
12 | TORCH_CHECK(input1.is_contiguous(), "input1 must be a contiguous tensor");
13 | TORCH_CHECK(input2.is_contiguous(), "input2 must be a contiguous tensor");
14 |
15 | int batch_size = input1.size(0), height = input1.size(1), width = input1.size(2), in_channels = input1.size(3);
16 | int out_channels = (max_displacement * 2 + 1) * (max_displacement * 2 + 1);
17 | torch::Tensor output = torch::zeros({batch_size, out_channels, height, width}, torch::device(input1.device()));
18 |
19 |     correlation_forward_kernel_wrapper(output.data_ptr<float>(), input1.data_ptr<float>(), input2.data_ptr<float>(),
20 | batch_size, in_channels, height, width, max_displacement);
21 | return output;
22 | }
23 |
24 | std::pair<torch::Tensor, torch::Tensor> correlation_backward_cuda(torch::Tensor& grad_output, torch::Tensor& input1, torch::Tensor& input2, int max_displacement) {
25 | int batch_size = input1.size(0), height = input1.size(1), width = input1.size(2), in_channels = input1.size(3);
26 | torch::Tensor grad_input1 = torch::empty({batch_size, in_channels, height, width}, torch::device(input1.device()));
27 | torch::Tensor grad_input2 = torch::empty({batch_size, in_channels, height, width}, torch::device(input2.device()));
28 |
29 |     correlation_backward_kernel_wrapper(grad_output.data_ptr<float>(),
30 |                                         grad_input1.data_ptr<float>(), grad_input2.data_ptr<float>(),
31 |                                         input1.data_ptr<float>(), input2.data_ptr<float>(),
32 | batch_size, in_channels, height, width, max_displacement);
33 |
34 |     return std::pair<torch::Tensor, torch::Tensor>(grad_input1, grad_input2);
35 | }
36 |
37 | #ifdef TORCH_EXTENSION_NAME
38 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
39 | m.def("_correlation_forward_cuda", &correlation_forward_cuda, "Correlation forward pass (CUDA)");
40 | m.def("_correlation_backward_cuda", &correlation_backward_cuda, "Correlation backward pass (CUDA)");
41 | }
42 | #endif
43 |
--------------------------------------------------------------------------------
/models/csrc/correlation/correlation.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <torch/extension.h>
4 |
5 | torch::Tensor correlation_forward_cuda(torch::Tensor& input1, torch::Tensor& input2, int max_displacement);
6 | std::pair<torch::Tensor, torch::Tensor> correlation_backward_cuda(torch::Tensor& grad_output, torch::Tensor& input1, torch::Tensor& input2, int max_displacement);
7 |
--------------------------------------------------------------------------------
/models/csrc/correlation/correlation_backward_kernel.cu:
--------------------------------------------------------------------------------
1 | #include <cuda.h>
2 | #include <cuda_runtime.h>
3 |
4 | __global__ void correlation_backward_input1_kernel(float* __restrict__ grad_input1,
5 | const float* __restrict__ grad_output,
6 | const float* __restrict__ input2,
7 | int in_channels, int height, int width, int max_displacement) {
8 | int n = blockIdx.x, y1 = blockIdx.y, x1 = blockIdx.z, c = threadIdx.x;
9 |
10 | int displacement_size = 2 * max_displacement + 1;
11 | int out_channels = displacement_size * displacement_size;
12 |
13 | int in_stride0 = height * width * in_channels;
14 | int in_stride1 = width * in_channels;
15 | int in_stride2 = in_channels;
16 |
17 | int out_stride0 = out_channels * height * width;
18 | int out_stride1 = height * width;
19 | int out_stride2 = width;
20 |
21 | int grad_in_stride0 = in_channels * height* width;
22 | int grad_in_stride1 = height * width;
23 | int grad_in_stride2 = width;
24 |
25 | float sum1 = 0;
26 | for (int tx = -max_displacement; tx <= max_displacement; ++tx)
27 | for (int ty = -max_displacement; ty <= max_displacement; ++ty) {
28 | int x2 = x1 + ty, y2 = y1 + tx;
29 | if (x2 < 0 || y2 < 0 || x2 >= width || y2 >= height) continue;
30 | int tc = (tx + max_displacement) * displacement_size + (ty + max_displacement);
31 | int idx = n * out_stride0 + tc * out_stride1 + y1 * out_stride2 + x1;
32 | int idx2 = n * in_stride0 + y2 * in_stride1 + x2 * in_stride2 + c;
33 | sum1 += grad_output[idx] * input2[idx2];
34 | }
35 |
36 | int idx1 = n * grad_in_stride0 + c * grad_in_stride1 + y1 * grad_in_stride2 + x1;
37 | grad_input1[idx1] = sum1 / in_channels;
38 | }
39 |
40 | __global__ void correlation_backward_input2_kernel(float* __restrict__ grad_input2,
41 | const float* __restrict__ grad_output,
42 | const float* __restrict__ input1,
43 | int in_channels, int height, int width, int max_displacement) {
44 | int n = blockIdx.x, y2 = blockIdx.y, x2 = blockIdx.z, c = threadIdx.x;
45 |
46 | int displacement_size = 2 * max_displacement + 1;
47 | int out_channels = displacement_size * displacement_size;
48 |
49 | int in_stride0 = height * width * in_channels;
50 | int in_stride1 = width * in_channels;
51 | int in_stride2 = in_channels;
52 |
53 | int out_stride0 = out_channels * height * width;
54 | int out_stride1 = height * width;
55 | int out_stride2 = width;
56 |
57 | int grad_in_stride0 = in_channels * height* width;
58 | int grad_in_stride1 = height * width;
59 | int grad_in_stride2 = width;
60 |
61 | float sum2 = 0;
62 | for (int tx = -max_displacement; tx <= max_displacement; ++tx)
63 | for (int ty = -max_displacement; ty <= max_displacement; ++ty) {
64 | int x1 = x2 - ty, y1 = y2 - tx;
65 | if (x1 < 0 || y1 < 0 || x1 >= width || y1 >= height) continue;
66 | int tc = (tx + max_displacement) * displacement_size + (ty + max_displacement);
67 | int idx = n * out_stride0 + tc * out_stride1 + y1 * out_stride2 + x1;
68 | int idx1 = n * in_stride0 + y1 * in_stride1 + x1 * in_stride2 + c;
69 | sum2 += grad_output[idx] * input1[idx1];
70 | }
71 |
72 | int idx2 = n * grad_in_stride0 + c * grad_in_stride1 + y2 * grad_in_stride2 + x2;
73 | grad_input2[idx2] = sum2 / in_channels;
74 | }
75 |
76 | void correlation_backward_kernel_wrapper(const float* grad_output,
77 | float* grad_input1, float* grad_input2,
78 | const float* input1, const float* input2,
79 | int n_batches, int in_channels, int height, int width, int max_displacement) {
80 | dim3 totalBlocksCorr(n_batches, height, width);
81 |     correlation_backward_input1_kernel<<<totalBlocksCorr, in_channels>>>(
82 | grad_input1, grad_output, input2,
83 | in_channels, height, width, max_displacement
84 | );
85 |     correlation_backward_input2_kernel<<<totalBlocksCorr, in_channels>>>(
86 | grad_input2, grad_output, input1,
87 | in_channels, height, width, max_displacement
88 | );
89 | }
90 |
--------------------------------------------------------------------------------
/models/csrc/correlation/correlation_forward_kernel.cu:
--------------------------------------------------------------------------------
1 | #include <cuda.h>
2 | #include <cuda_runtime.h>
3 | #define BLOCK_SIZE 32
4 |
5 | __forceinline__ __device__ float warpReduceSum(float val) {
6 | for (int offset = 16; offset > 0; offset /= 2)
7 | val += __shfl_down_sync(0xffffffff, val, offset);
8 | return val;
9 | }
10 |
11 | __global__ void correlation_forward_kernel(float* __restrict__ output,
12 | const float* __restrict__ input1,
13 | const float* __restrict__ input2,
14 | int in_channels, int height, int width, int max_displacement) {
15 | int n = blockIdx.x, y1 = blockIdx.y, x1 = blockIdx.z;
16 |
17 | int displacement_size = 2 * max_displacement + 1;
18 | int out_channels = displacement_size * displacement_size;
19 |
20 | int in_stride0 = height * width * in_channels;
21 | int in_stride1 = width * in_channels;
22 | int in_stride2 = in_channels;
23 |
24 | int out_stride0 = out_channels * height * width;
25 | int out_stride1 = height * width;
26 | int out_stride2 = width;
27 |
28 | for (int tx = -max_displacement; tx <= max_displacement; ++tx)
29 | for (int ty = -max_displacement; ty <= max_displacement; ++ty) {
30 | int x2 = x1 + ty, y2 = y1 + tx;
31 | if (x2 < 0 || y2 < 0 || x2 >= width || y2 >= height) continue;
32 |
33 | float sum = 0.0f;
34 | for (int c = threadIdx.x; c < in_channels; c += BLOCK_SIZE) {
35 | int idx1 = n * in_stride0 + y1 * in_stride1 + x1 * in_stride2 + c;
36 | int idx2 = n * in_stride0 + y2 * in_stride1 + x2 * in_stride2 + c;
37 | sum += input1[idx1] * input2[idx2];
38 | }
39 |
40 | __syncthreads();
41 | sum = warpReduceSum(sum);
42 |
43 | if (threadIdx.x == 0) {
44 | int tc = (tx + max_displacement) * displacement_size + (ty + max_displacement);
45 | int idx = n * out_stride0 + tc * out_stride1 + y1 * out_stride2 + x1;
46 | output[idx] = sum / in_channels;
47 | }
48 | }
49 | }
50 |
51 | void correlation_forward_kernel_wrapper(float* output, const float* input1, const float* input2,
52 | int n_batches, int in_channels, int height, int width, int max_displacement) {
53 | dim3 number_of_blocks(n_batches, height, width);
54 |     correlation_forward_kernel<<<number_of_blocks, BLOCK_SIZE>>>(output, input1, input2, in_channels, height, width, max_displacement);
55 | }
56 |
--------------------------------------------------------------------------------
/models/csrc/correlation/correlation_test.cpp:
--------------------------------------------------------------------------------
1 | #include "correlation.h"
2 | #include <iostream>
3 | #include <chrono>
4 | using namespace std;
5 |
6 | void _checkCudaErrors(cudaError_t result, char const *const func, const char *const file, int const line) {
7 | if (result != cudaSuccess) {
8 | fprintf(stderr, "CUDA error at %s:%d code=%d(%s) \"%s\" \n", file, line,
9 |             static_cast<int>(result), cudaGetErrorName(result), func);
10 | cudaDeviceReset();
11 | exit(EXIT_FAILURE);
12 | }
13 | }
14 | #define checkCudaErrors(val) _checkCudaErrors((val), #val, __FILE__, __LINE__)
15 |
16 | std::vector<torch::Tensor> run_cuda_implementation(torch::Tensor input1, torch::Tensor input2, torch::Tensor grad_output, int max_displacement) {
17 | // nchw -> nhwc
18 | input1 = input1.permute({0, 2, 3, 1}).contiguous();
19 | input2 = input2.permute({0, 2, 3, 1}).contiguous();
20 |
21 | torch::Tensor output = correlation_forward_cuda(input1, input2, max_displacement);
22 |     std::pair<torch::Tensor, torch::Tensor> grads_cuda = correlation_backward_cuda(grad_output, input1, input2, max_displacement);
23 |
24 | return {output, grads_cuda.first, grads_cuda.second};
25 | }
26 |
27 | std::vector<torch::Tensor> run_naive_implementation(torch::Tensor input1, torch::Tensor input2, torch::Tensor grad_output, int max_displacement) {
28 | int batch_size = input1.size(0), in_channels = input1.size(1), height = input1.size(2), width = input1.size(3);
29 | torch::Tensor padded_input2 = torch::nn::functional::detail::pad(input2, {max_displacement, max_displacement, max_displacement, max_displacement}, torch::kConstant, 0);
30 |
31 | std::vector cost_volumes;
32 | for (int i = 0; i < 2 * max_displacement + 1; i++)
33 | for (int j = 0; j < 2 * max_displacement + 1; j++) {
34 | torch::Tensor cost = input1 * padded_input2.slice(2, i, i + height).slice(3, j, j + width);
35 | cost_volumes.push_back(torch::mean(cost, 1, true));
36 | }
37 |
38 | torch::Tensor output = torch::cat(cost_volumes, 1);
39 | output.backward(grad_output);
40 |
41 | return {output, input1.grad(), input2.grad()};
42 | }
43 |
44 | int main() {
45 | constexpr int batch_size = 32;
46 | constexpr int in_channels = 128;
47 | constexpr int height = 144;
48 | constexpr int width = 240;
49 | constexpr int max_displacement = 4;
50 | constexpr int out_channels = (max_displacement * 2 + 1) * (max_displacement * 2 + 1);
51 |
52 | if (!torch::cuda::is_available()) {
53 | cout << "CUDA is not available, exiting..." << endl;
54 | return 1;
55 | }
56 |
57 | torch::manual_seed(0);
58 | torch::Tensor input1 = torch::rand({batch_size, in_channels, height, width}, torch::requires_grad().device("cuda"));
59 | torch::Tensor input2 = torch::rand({batch_size, in_channels, height, width}, torch::requires_grad().device("cuda"));
60 | torch::Tensor grad_output = torch::rand({batch_size, out_channels, height, width}, torch::requires_grad(false).device("cuda"));
61 |
62 | cout << "Running NAIVE implementation of correlation... " << flush;
63 | auto naive_t1 = chrono::high_resolution_clock::now();
64 |     std::vector<torch::Tensor> results_naive = run_naive_implementation(input1, input2, grad_output, max_displacement);
65 | checkCudaErrors(cudaDeviceSynchronize());
66 | auto naive_t2 = chrono::high_resolution_clock::now();
67 | cout << "(" << chrono::duration_cast(naive_t2 - naive_t1).count() << "ms)" << endl;
68 |
69 | // warm up...
70 | for (int t = 0; t < 2; t++) {
71 | run_cuda_implementation(input1, input2, grad_output, max_displacement);
72 | checkCudaErrors(cudaDeviceSynchronize());
73 | }
74 |
75 | cout << "Running CUDA implementation of correlation... " << flush;
76 | auto cuda_t1 = chrono::high_resolution_clock::now();
77 |     std::vector<torch::Tensor> results_cuda = run_cuda_implementation(input1, input2, grad_output, max_displacement);
78 | checkCudaErrors(cudaDeviceSynchronize());
79 | auto cuda_t2 = chrono::high_resolution_clock::now();
80 | cout << "(" << chrono::duration_cast(cuda_t2 - cuda_t1).count() << "ms)" << endl;
81 |
82 | float diff = torch::mean(torch::abs(results_cuda[0].cpu() - results_naive[0].cpu())).data_ptr<float>()[0];
83 | cout << "Checking forward results... " << (diff < 1e-6 ? "OK" : "Failed") << endl;
84 |
85 | diff = torch::mean(torch::abs(results_cuda[1].cpu() - results_naive[1].cpu())).data_ptr<float>()[0];
86 | cout << "Checking backward results for input1... " << (diff < 1e-6 ? "OK" : "Failed") << endl;
87 |
88 | diff = torch::mean(torch::abs(results_cuda[2].cpu() - results_naive[2].cpu())).data_ptr<float>()[0];
89 | cout << "Checking backward results for input2... " << (diff < 1e-6 ? "OK" : "Failed") << endl;
90 |
91 | return 0;
92 | }
93 |
--------------------------------------------------------------------------------
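For reference, the naive C++ check above mirrors the pure-Python fallback `_correlation_py` in `models/csrc/wrapper.py`. A minimal, self-contained PyTorch sketch of the same cost-volume computation (shapes are illustrative; `max_displacement=4` yields 81 output channels):

```python
import torch
import torch.nn.functional as F

def correlation_naive(input1, input2, max_displacement):
    """Cost volume: mean over channels of input1 * shifted input2."""
    height, width = input1.shape[2:]
    padded2 = F.pad(input2, [max_displacement] * 4)
    volumes = []
    for i in range(2 * max_displacement + 1):
        for j in range(2 * max_displacement + 1):
            cost = input1 * padded2[:, :, i:i + height, j:j + width]
            volumes.append(cost.mean(dim=1, keepdim=True))
    return torch.cat(volumes, dim=1)

x1, x2 = torch.rand(2, 8, 16, 24), torch.rand(2, 8, 16, 24)
out = correlation_naive(x1, x2, max_displacement=4)
assert out.shape == (2, 81, 16, 24)  # (2*4+1)^2 = 81 displacement channels
```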
/models/csrc/furthest_point_sampling/furthest_point_sampling.cpp:
--------------------------------------------------------------------------------
1 | #include "furthest_point_sampling.h"
2 |
3 | void furthest_point_sampling_kernel_wrapper(float* batched_points_xyz, float* batched_dists_temp, int n_batch, int n_points, int n_samples, int64_t* batched_furthest_indices);
4 |
5 | torch::Tensor furthest_point_sampling_cuda(torch::Tensor points_xyz, const int n_samples) {
6 | TORCH_CHECK(points_xyz.is_contiguous(), "points_xyz must be a contiguous tensor");
7 | TORCH_CHECK(points_xyz.is_cuda(), "points_xyz must be a CUDA tensor");
8 | TORCH_CHECK(points_xyz.scalar_type() == torch::ScalarType::Float, "points_xyz must be a float tensor");
9 |
10 | int64_t batch_size = points_xyz.size(0), n_points = points_xyz.size(1);
11 | torch::Tensor furthest_indices = torch::empty({ batch_size, n_samples }, torch::TensorOptions().dtype(torch::kInt64).device(points_xyz.device()));
12 | torch::Tensor dists_temp = torch::ones({batch_size, n_points}, torch::TensorOptions().dtype(torch::kFloat32).device(points_xyz.device())) * 1e10;
13 | furthest_point_sampling_kernel_wrapper(points_xyz.data_ptr<float>(), dists_temp.data_ptr<float>(), batch_size, n_points, n_samples, furthest_indices.data_ptr<int64_t>());
14 |
15 | return furthest_indices;
16 | }
17 |
18 | #ifdef TORCH_EXTENSION_NAME
19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
20 | m.def("_furthest_point_sampling_cuda", &furthest_point_sampling_cuda, "CUDA implementation of furthest-point-sampling (FPS)");
21 | }
22 | #endif
--------------------------------------------------------------------------------
/models/csrc/furthest_point_sampling/furthest_point_sampling.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <torch/extension.h>
4 |
5 | torch::Tensor furthest_point_sampling_cuda(torch::Tensor points_xyz, const int n_samples);
6 |
--------------------------------------------------------------------------------
/models/csrc/furthest_point_sampling/furthest_point_sampling_kernel.cu:
--------------------------------------------------------------------------------
1 | #include <cuda.h>
2 | #include <cuda_runtime.h>
3 | #include <cstdint>
4 |
5 | #define argmaxReduceMacro(arr, arr_idx, idx1, idx2) {\
6 | if (arr[idx1] <= arr[idx2]) {\
7 | arr[idx1] = arr[idx2];\
8 | arr_idx[idx1] = arr_idx[idx2];\
9 | }\
10 | }
11 |
12 | template <int block_size>
13 | __device__ void warpReduce(volatile float* max_values, volatile int* max_values_idx, int tid) {
14 | if (block_size >= 64) argmaxReduceMacro(max_values, max_values_idx, tid, tid + 32);
15 | if (block_size >= 32) argmaxReduceMacro(max_values, max_values_idx, tid, tid + 16);
16 | if (block_size >= 16) argmaxReduceMacro(max_values, max_values_idx, tid, tid + 8);
17 | if (block_size >= 8) argmaxReduceMacro(max_values, max_values_idx, tid, tid + 4);
18 | if (block_size >= 4) argmaxReduceMacro(max_values, max_values_idx, tid, tid + 2);
19 | if (block_size >= 2) argmaxReduceMacro(max_values, max_values_idx, tid, tid + 1);
20 | }
21 |
22 | template <int block_size>
23 | __device__ int64_t argmax(float* max_values, int* max_values_idx) {
24 | int tid = threadIdx.x;
25 | if (block_size >= 1024) { if (tid < 512) argmaxReduceMacro(max_values, max_values_idx, tid, tid + 512); __syncthreads(); }
26 | if (block_size >= 512) { if (tid < 256) argmaxReduceMacro(max_values, max_values_idx, tid, tid + 256); __syncthreads(); }
27 | if (block_size >= 256) { if (tid < 128) argmaxReduceMacro(max_values, max_values_idx, tid, tid + 128); __syncthreads(); }
28 | if (block_size >= 128) { if (tid < 64) argmaxReduceMacro(max_values, max_values_idx, tid, tid + 64); __syncthreads(); }
29 | if (tid < 32) warpReduce<block_size>(max_values, max_values_idx, tid);
30 | __syncthreads();
31 | return max_values_idx[0];
32 | }
33 |
34 | template <int block_size>
35 | __global__ void furthest_point_sampling_kernel(float* __restrict__ batched_points_xyz,
36 | float* __restrict__ batched_dists_temp,
37 | int n_points, int n_samples,
38 | int64_t* __restrict__ batched_furthest_indices) {
39 | int bid = blockIdx.x, tid = threadIdx.x;
40 |
41 | float* __restrict__ points_xyz = batched_points_xyz + bid * n_points * 3;
42 | float* __restrict__ dists_temp = batched_dists_temp + bid * n_points;
43 | int64_t* __restrict__ furthest_indices = batched_furthest_indices + bid * n_samples;
44 |
45 | __shared__ float max_dists[block_size];
46 | __shared__ int max_dists_idx[block_size];
47 |
48 | int64_t curr_furthest_idx = 0;
49 | for (int s = 0; s < n_samples; s++) {
50 | if (tid == 0) furthest_indices[s] = curr_furthest_idx;
51 |
52 | float x1 = points_xyz[curr_furthest_idx * 3 + 0];
53 | float y1 = points_xyz[curr_furthest_idx * 3 + 1];
54 | float z1 = points_xyz[curr_furthest_idx * 3 + 2];
55 |
56 | float local_max_dist = -1;
57 | int local_max_dist_idx = 0;
58 |
59 | for (int i = tid; i < n_points; i += block_size) {
60 | float x2 = points_xyz[i * 3 + 0];
61 | float y2 = points_xyz[i * 3 + 1];
62 | float z2 = points_xyz[i * 3 + 2];
63 | float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
64 | float new_dist = min(dists_temp[i], d);
65 | if (new_dist > local_max_dist) {
66 | local_max_dist = new_dist;
67 | local_max_dist_idx = i;
68 | }
69 | dists_temp[i] = new_dist;
70 | }
71 |
72 | max_dists[tid] = local_max_dist;
73 | max_dists_idx[tid] = local_max_dist_idx;
74 |
75 | __syncthreads();
76 |
77 | curr_furthest_idx = argmax<block_size>(max_dists, max_dists_idx);
78 | }
79 | }
80 |
81 | void furthest_point_sampling_kernel_wrapper(float* batched_points_xyz, float* batched_dists_temp,
82 | int n_batch, int n_points, int n_samples,
83 | int64_t* batched_furthest_indices) {
84 | furthest_point_sampling_kernel<1024> <<<n_batch, 1024>>> (batched_points_xyz, batched_dists_temp, n_points, n_samples, batched_furthest_indices);
85 | }
--------------------------------------------------------------------------------
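The kernel assigns one thread block per batch element and parallelizes the same greedy iteration as the Python fallback `_furthest_point_sampling_py` in `models/csrc/wrapper.py`: maintain each point's distance to the already-selected set, then argmax to pick the next seed. A compact single-batch sketch:

```python
import torch

def fps_single(xyz, n_samples):
    """Greedy furthest-point sampling for one cloud. xyz: [n_points, 3]."""
    indices = torch.empty(n_samples, dtype=torch.int64)
    dists = torch.full((xyz.shape[0],), 1e10)  # distance to the selected set
    curr = 0                                   # start from point 0, as the kernel does
    for s in range(n_samples):
        indices[s] = curr
        d = ((xyz - xyz[curr]) ** 2).sum(-1)   # squared distance to the newest seed
        dists = torch.minimum(dists, d)
        curr = int(dists.argmax())             # furthest remaining point
    return indices

print(fps_single(torch.rand(4096, 3), 16))
```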
/models/csrc/furthest_point_sampling/furthest_point_sampling_test.cpp:
--------------------------------------------------------------------------------
1 | #include "furthest_point_sampling.h"
2 | #include <cuda_runtime.h>
3 | #include <chrono>
4 | using namespace std;
5 |
6 | void _checkCudaErrors(cudaError_t result, char const *const func, const char *const file, int const line) {
7 | if (result != cudaSuccess) {
8 | fprintf(stderr, "CUDA error at %s:%d code=%d(%s) \"%s\" \n", file, line,
9 | static_cast<int>(result), cudaGetErrorName(result), func);
10 | cudaDeviceReset();
11 | exit(EXIT_FAILURE);
12 | }
13 | }
14 | #define checkCudaErrors(val) _checkCudaErrors((val), #val, __FILE__, __LINE__)
15 |
16 | torch::Tensor furthest_point_sampling_std(torch::Tensor points_xyz, const int n_samples) {
17 | int64_t batch_size = points_xyz.size(0), n_points = points_xyz.size(1);
18 | torch::Tensor furthest_indices = torch::empty({batch_size, n_samples}, torch::TensorOptions().dtype(torch::kInt64));
19 |
20 | for (int64_t b = 0; b < batch_size; b++) {
21 | torch::Tensor dists = torch::ones(n_points) * 1e10;
22 | int curr_furthest_idx = 0;
23 | for (int s = 0; s < n_samples; s++) {
24 | furthest_indices[b][s] = curr_furthest_idx;
25 | dists = torch::min(dists, torch::sum((points_xyz[b] - points_xyz[b][curr_furthest_idx]).pow(2), -1));
26 | curr_furthest_idx = dists.argmax().data_ptr<int64_t>()[0];
27 | }
28 | }
29 |
30 | return furthest_indices;
31 | }
32 |
33 | int main() {
34 | constexpr int batch_size = 64;
35 | constexpr int n_points = 4096;
36 | constexpr int n_samples = 1024;
37 |
38 | if (!torch::cuda::is_available()) {
39 | cout << "CUDA is not available, exiting..." << endl;
40 | return 1;
41 | }
42 |
43 | torch::manual_seed(0);
44 | torch::Tensor points_xyz = torch::rand({batch_size, n_points, 3}), points_xyz_cuda = points_xyz.cuda();
45 |
46 | // warm up...
47 | for (int t = 0; t < 3; t++) furthest_point_sampling_cuda(points_xyz_cuda, n_samples);
48 | checkCudaErrors(cudaDeviceSynchronize());
49 |
50 | cout << "Running furthest-point-sampling on GPU... " << flush;
51 | auto t1 = chrono::high_resolution_clock::now();
52 | torch::Tensor indices_gpu = furthest_point_sampling_cuda(points_xyz_cuda, n_samples);
53 | checkCudaErrors(cudaDeviceSynchronize());
54 | auto t2 = chrono::high_resolution_clock::now();
55 | cout << "(" << chrono::duration_cast(t2 - t1).count() << "ms)" << endl;
56 |
57 | cout << "Running furthest-point-sampling on CPU... " << flush;
58 | t1 = chrono::high_resolution_clock::now();
59 | torch::Tensor indices_std = furthest_point_sampling_std(points_xyz, n_samples);
60 | t2 = chrono::high_resolution_clock::now();
61 | cout << "(" << chrono::duration_cast(t2 - t1).count() << "ms)" << endl;
62 |
63 | cout << "Checking results... " << (torch::equal(indices_std.cpu(), indices_gpu.cpu()) ? "OK" : "Failed") << endl;
64 | }
--------------------------------------------------------------------------------
/models/csrc/k_nearest_neighbor/k_nearest_neighbor.cpp:
--------------------------------------------------------------------------------
1 | #include "k_nearest_neighbor.h"
2 |
3 | void k_nearest_neighbor_2d_kernel_wrapper(int b, int n, int m, int k, const float *query_xyz, const float *input_xyz, int64_t *indices);
4 | void k_nearest_neighbor_3d_kernel_wrapper(int b, int n, int m, int k, const float *query_xyz, const float *input_xyz, int64_t *indices);
5 |
6 | torch::Tensor k_nearest_neighbor_cuda(torch::Tensor input_xyz, torch::Tensor query_xyz, int k) {
7 | TORCH_CHECK(input_xyz.is_contiguous(), "input_xyz must be a contiguous tensor");
8 | TORCH_CHECK(input_xyz.is_cuda(), "input_xyz must be a CUDA tensor");
9 | TORCH_CHECK(input_xyz.scalar_type() == torch::ScalarType::Float, "input_xyz must be a float tensor");
10 |
11 | TORCH_CHECK(query_xyz.is_contiguous(), "query_xyz must be a contiguous tensor");
12 | TORCH_CHECK(query_xyz.is_cuda(), "query_xyz must be a CUDA tensor");
13 | TORCH_CHECK(query_xyz.scalar_type() == torch::ScalarType::Float, "query_xyz must be a float tensor");
14 |
15 | int batch_size = query_xyz.size(0), n_queries = query_xyz.size(1), n_inputs = input_xyz.size(1), n_dim = query_xyz.size(2);
16 | torch::Tensor indices = torch::zeros({batch_size, n_queries, k}, torch::device(query_xyz.device()).dtype(torch::ScalarType::Long));
17 |
18 | if (n_dim == 2)
19 | k_nearest_neighbor_2d_kernel_wrapper(batch_size, n_queries, n_inputs, k, query_xyz.data_ptr<float>(), input_xyz.data_ptr<float>(), indices.data_ptr<int64_t>());
20 | else
21 | k_nearest_neighbor_3d_kernel_wrapper(batch_size, n_queries, n_inputs, k, query_xyz.data_ptr<float>(), input_xyz.data_ptr<float>(), indices.data_ptr<int64_t>());
22 |
23 | return indices;
24 | }
25 |
26 | #ifdef TORCH_EXTENSION_NAME
27 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
28 | m.def("_k_nearest_neighbor_cuda", &k_nearest_neighbor_cuda, "CUDA implementation of KNN");
29 | }
30 | #endif
--------------------------------------------------------------------------------
/models/csrc/k_nearest_neighbor/k_nearest_neighbor.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <torch/extension.h>
4 |
5 | torch::Tensor k_nearest_neighbor_cuda(torch::Tensor input_xyz, torch::Tensor query_xyz, int k);
--------------------------------------------------------------------------------
/models/csrc/k_nearest_neighbor/k_nearest_neighbor_kernel.cu:
--------------------------------------------------------------------------------
1 | #include <cuda.h>
2 | #include <cuda_runtime.h>
3 | #include <cstdint>
4 |
5 | #define THREADS_PER_BLOCK 256
6 | #define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
7 |
8 | __global__ void k_nearest_neighbor_2d_kernel(int b, int n, int m, int k,
9 | const float *__restrict__ query_xyz,
10 | const float *__restrict__ input_xyz,
11 | int64_t *__restrict__ indices) {
12 | int bs_idx = blockIdx.y;
13 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
14 | if (bs_idx >= b || pt_idx >= n) return;
15 |
16 | query_xyz += bs_idx * n * 2 + pt_idx * 2;
17 | input_xyz += bs_idx * m * 2;
18 | indices += bs_idx * n * k + pt_idx * k;
19 |
20 | float ux = query_xyz[0];
21 | float uy = query_xyz[1];
22 |
23 | // initialize candidate lists (fixed capacity 32, so k must not exceed 32)
24 | float nn_dists[32]; int nn_indices[32];
25 | for (int i = 0; i < 32; i++) {
26 | nn_dists[i] = 1e9;
27 | nn_indices[i] = 0;
28 | }
29 |
30 | for (int idx = 0; idx < m; idx++) {
31 | float x = input_xyz[idx * 2 + 0];
32 | float y = input_xyz[idx * 2 + 1];
33 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y);
34 | if (d > nn_dists[k - 1]) continue;
35 |
36 | int j = min(idx, k - 1);
37 | while (j > 0 && nn_dists[j - 1] > d) {
38 | nn_dists[j] = nn_dists[j - 1];
39 | nn_indices[j] = nn_indices[j - 1];
40 | j--;
41 | }
42 |
43 | nn_dists[j] = d;
44 | nn_indices[j] = idx;
45 | }
46 |
47 | for (int i = 0; i < k; i++)
48 | indices[i] = nn_indices[i];
49 | }
50 |
51 | __global__ void k_nearest_neighbor_3d_kernel(int b, int n, int m, int k,
52 | const float *__restrict__ query_xyz,
53 | const float *__restrict__ input_xyz,
54 | int64_t *__restrict__ indices) {
55 | int bs_idx = blockIdx.y;
56 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
57 | if (bs_idx >= b || pt_idx >= n) return;
58 |
59 | query_xyz += bs_idx * n * 3 + pt_idx * 3;
60 | input_xyz += bs_idx * m * 3;
61 | indices += bs_idx * n * k + pt_idx * k;
62 |
63 | float ux = query_xyz[0];
64 | float uy = query_xyz[1];
65 | float uz = query_xyz[2];
66 |
67 | // initialize candidate lists (fixed capacity 32, so k must not exceed 32)
68 | float nn_dists[32]; int nn_indices[32];
69 | for (int i = 0; i < 32; i++) {
70 | nn_dists[i] = 1e9;
71 | nn_indices[i] = 0;
72 | }
73 |
74 | for (int idx = 0; idx < m; idx++) {
75 | float x = input_xyz[idx * 3 + 0];
76 | float y = input_xyz[idx * 3 + 1];
77 | float z = input_xyz[idx * 3 + 2];
78 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
79 | if (d > nn_dists[k - 1]) continue;
80 |
81 | int j = min(idx, k - 1);
82 | while (j > 0 && nn_dists[j - 1] > d) {
83 | nn_dists[j] = nn_dists[j - 1];
84 | nn_indices[j] = nn_indices[j - 1];
85 | j--;
86 | }
87 |
88 | nn_dists[j] = d;
89 | nn_indices[j] = idx;
90 | }
91 |
92 | for (int i = 0; i < k; i++)
93 | indices[i] = nn_indices[i];
94 | }
95 |
96 | void k_nearest_neighbor_2d_kernel_wrapper(int b, int n, int m, int k,
97 | const float *query_xyz,
98 | const float *input_xyz,
99 | int64_t *indices) {
100 | dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row)
101 | dim3 threads(THREADS_PER_BLOCK);
102 | k_nearest_neighbor_2d_kernel<<<blocks, threads>>>(b, n, m, k, query_xyz, input_xyz, indices);
103 | }
104 |
105 | void k_nearest_neighbor_3d_kernel_wrapper(int b, int n, int m, int k,
106 | const float *query_xyz,
107 | const float *input_xyz,
108 | int64_t *indices) {
109 | dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row)
110 | dim3 threads(THREADS_PER_BLOCK);
111 | k_nearest_neighbor_3d_kernel<<<blocks, threads>>>(b, n, m, k, query_xyz, input_xyz, indices);
112 | }
113 |
--------------------------------------------------------------------------------
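Both test harnesses validate against the expansion ||q - p||^2 = ||q||^2 + ||p||^2 - 2<q, p> followed by `topk`; the insertion-sort kernel above keeps its candidate lists in fixed-size buffers, which is why `k` is capped at 32. The matching PyTorch reference (same as `_k_nearest_neighbor_py` in `wrapper.py`):

```python
import torch

def knn_reference(input_xyz, query_xyz, k):
    """input_xyz: [B, M, 3], query_xyz: [B, N, 3] -> neighbor indices [B, N, k]."""
    dists = -2 * query_xyz @ input_xyz.transpose(1, 2)  # -2 <q, p>
    dists += (query_xyz ** 2).sum(-1, keepdim=True)     # + ||q||^2
    dists += (input_xyz ** 2).sum(-1).unsqueeze(1)      # + ||p||^2
    return dists.topk(k, dim=2, largest=False).indices

idx = knn_reference(torch.rand(2, 1024, 3), torch.rand(2, 256, 3), k=16)
assert idx.shape == (2, 256, 16)
```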
/models/csrc/k_nearest_neighbor/k_nearest_neighbor_test.cpp:
--------------------------------------------------------------------------------
1 | #include "k_nearest_neighbor.h"
2 | #include
3 | #include
4 | using namespace std;
5 |
6 | void _checkCudaErrors(cudaError_t result, char const *const func, const char *const file, int const line) {
7 | if (result != cudaSuccess) {
8 | fprintf(stderr, "CUDA error at %s:%d code=%d(%s) \"%s\" \n", file, line,
9 | static_cast(result), cudaGetErrorName(result), func);
10 | cudaDeviceReset();
11 | exit(EXIT_FAILURE);
12 | }
13 | }
14 | #define checkCudaErrors(val) _checkCudaErrors((val), #val, __FILE__, __LINE__)
15 |
16 | torch::Tensor k_nearest_neighbor_std(torch::Tensor input_xyz, torch::Tensor query_xyz, int k) {
17 | int64_t batch_size = input_xyz.size(0), n_points_1 = input_xyz.size(1), n_points_2 = query_xyz.size(1);
18 | torch::Tensor dists = -2 * torch::matmul(query_xyz, input_xyz.permute({0, 2, 1}));
19 | dists += torch::sum(query_xyz.pow(2), -1).view({batch_size, n_points_2, 1});
20 | dists += torch::sum(input_xyz.pow(2), -1).view({batch_size, 1, n_points_1});
21 | return std::get<1>(dists.topk(k, 2, false));
22 | }
23 |
24 | int main() {
25 | constexpr int batch_size = 8;
26 | constexpr int n_points_input = 8192;
27 | constexpr int n_points_query = 8192;
28 | constexpr int k = 16;
29 | constexpr int dim = 3;
30 |
31 | if (!torch::cuda::is_available()) {
32 | cout << "CUDA is not available, exiting..." << endl;
33 | return 1;
34 | }
35 |
36 | torch::manual_seed(0);
37 | torch::Tensor input_xyz = torch::rand({batch_size, n_points_input, dim}).cuda();
38 | torch::Tensor query_xyz = torch::rand({batch_size, n_points_query, dim}).cuda();
39 |
40 | // warm up...
41 | for (int t = 0; t < 3; t++) {
42 | k_nearest_neighbor_cuda(input_xyz, query_xyz, k);
43 | k_nearest_neighbor_std(input_xyz, query_xyz, k);
44 | }
45 | checkCudaErrors(cudaDeviceSynchronize());
46 |
47 | cout << "Running KNN using custom CUDA implementation... " << flush;
48 | auto t1 = chrono::high_resolution_clock::now();
49 | torch::Tensor indices_gpu = k_nearest_neighbor_cuda(input_xyz, query_xyz, k);
50 | checkCudaErrors(cudaDeviceSynchronize());
51 | auto t2 = chrono::high_resolution_clock::now();
52 | cout << "(" << chrono::duration_cast(t2 - t1).count() << "ms)" << endl;
53 |
54 | cout << "Running KNN using Torch's API... " << flush;
55 | t1 = chrono::high_resolution_clock::now();
56 | torch::Tensor indices_std = k_nearest_neighbor_std(input_xyz, query_xyz, k);
57 | checkCudaErrors(cudaDeviceSynchronize());
58 | t2 = chrono::high_resolution_clock::now();
59 | cout << "(" << chrono::duration_cast(t2 - t1).count() << "ms)" << endl;
60 |
61 | torch::Scalar diff_num = (indices_std.cpu() != indices_gpu.cpu()).sum().item();
62 | torch::Scalar total_num = batch_size * n_points_query * k;
63 | cout << "Checking results... " << diff_num << " of " << total_num << " elements are mismatched." << endl;
64 |
65 | return 0;
66 | }
--------------------------------------------------------------------------------
/models/csrc/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 |
4 |
5 | def get_ext_modules():
6 | return [
7 | CUDAExtension(
8 | name='_correlation_cuda',
9 | sources=[
10 | 'correlation/correlation.cpp',
11 | 'correlation/correlation_forward_kernel.cu',
12 | 'correlation/correlation_backward_kernel.cu'
13 | ],
14 | include_dirs=['correlation']
15 | ),
16 | CUDAExtension(
17 | name='_furthest_point_sampling_cuda',
18 | sources=[
19 | 'furthest_point_sampling/furthest_point_sampling.cpp',
20 | 'furthest_point_sampling/furthest_point_sampling_kernel.cu'
21 | ],
22 | include_dirs=['furthest_point_sampling']
23 | ),
24 | CUDAExtension(
25 | name='_k_nearest_neighbor_cuda',
26 | sources=[
27 | 'k_nearest_neighbor/k_nearest_neighbor.cpp',
28 | 'k_nearest_neighbor/k_nearest_neighbor_kernel.cu'
29 | ],
30 | include_dirs=['k_nearest_neighbor']
31 | )
32 | ]
33 |
34 |
35 | setup(
36 | name='csrc',
37 | ext_modules=get_ext_modules(),
38 | cmdclass={'build_ext': BuildExtension}
39 | )
40 |
--------------------------------------------------------------------------------
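Building these three extensions presumably follows the standard PyTorch workflow, e.g. `python setup.py build_ext --inplace` (or `install`) run from `models/csrc`; `wrapper.py` below then imports the resulting `_correlation_cuda`, `_furthest_point_sampling_cuda` and `_k_nearest_neighbor_cuda` modules, and falls back to pure-Python implementations when the import fails.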
/models/csrc/wrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional
3 |
4 | try:
5 | from ._correlation_cuda import _correlation_forward_cuda
6 | from ._correlation_cuda import _correlation_backward_cuda
7 | from ._furthest_point_sampling_cuda import _furthest_point_sampling_cuda
8 | from ._k_nearest_neighbor_cuda import _k_nearest_neighbor_cuda
9 | except ImportError as e:
10 | _correlation_forward_cuda = None
11 | _correlation_backward_cuda = None
12 | _furthest_point_sampling_cuda = None
13 | _k_nearest_neighbor_cuda = None
14 | print('Failed to load one or more CUDA extensions; performance may be degraded.')
15 | print('Error message:', e)
16 |
17 |
18 | class CorrelationFunction(torch.autograd.Function):
19 | @staticmethod
20 | def forward(ctx, input1, input2, max_displacement):
21 | ctx.save_for_backward(input1, input2)
22 | ctx.max_displacement = max_displacement
23 | assert callable(_correlation_forward_cuda)
24 | return _correlation_forward_cuda(input1, input2, max_displacement)
25 |
26 | @staticmethod
27 | def backward(ctx, grad_output):
28 | input1, input2 = ctx.saved_tensors
29 |
30 | assert callable(_correlation_backward_cuda)
31 | grad_input1, grad_input2 = _correlation_backward_cuda(
32 | grad_output, input1, input2, ctx.max_displacement
33 | )
34 | grad_input1 = grad_input1.permute(0, 2, 3, 1).contiguous()
35 | grad_input2 = grad_input2.permute(0, 2, 3, 1).contiguous()
36 |
37 | return grad_input1, grad_input2, None
38 |
39 |
40 | def squared_distance(xyz1: torch.Tensor, xyz2: torch.Tensor):
41 | """
42 | Calculate the Euclidean squared distance between every two points.
43 | :param xyz1: the 1st set of points, [batch_size, n_points_1, 3]
44 | :param xyz2: the 2nd set of points, [batch_size, n_points_2, 3]
45 | :return: squared distance between every two points, [batch_size, n_points_1, n_points_2]
46 | """
47 | assert xyz1.shape[-1] == xyz2.shape[-1] and xyz1.shape[-1] <= 3 # assert channel_last
48 | batch_size, n_points1, n_points2 = xyz1.shape[0], xyz1.shape[1], xyz2.shape[1]
49 | dist = -2 * torch.matmul(xyz1, xyz2.permute(0, 2, 1))
50 | dist += torch.sum(xyz1 ** 2, -1).view(batch_size, n_points1, 1)
51 | dist += torch.sum(xyz2 ** 2, -1).view(batch_size, 1, n_points2)
52 | return dist
53 |
54 |
55 | def correlation2d(input1: torch.Tensor, input2: torch.Tensor, max_displacement: int, cpp_impl=True):
56 | def _correlation_py(_input1, _input2, _max_displacement):
57 | height, width = _input1.shape[2:]
58 | _input2 = torch.nn.functional.pad(_input2, [_max_displacement] * 4)
59 | cost_volumes = []
60 | for i in range(2 * _max_displacement + 1):
61 | for j in range(2 * _max_displacement + 1):
62 | cost_volume = _input1 * _input2[:, :, i:(i + height), j:(j + width)]
63 | cost_volume = torch.mean(cost_volume, 1, keepdim=True)
64 | cost_volumes.append(cost_volume)
65 | return torch.cat(cost_volumes, 1)
66 |
67 | if cpp_impl and callable(_correlation_forward_cuda) and callable(_correlation_backward_cuda) and input1.is_cuda and input2.is_cuda:
68 | input1 = input1.permute(0, 2, 3, 1).contiguous().float()
69 | input2 = input2.permute(0, 2, 3, 1).contiguous().float()
70 | return CorrelationFunction.apply(input1, input2, max_displacement)
71 | else:
72 | return _correlation_py(input1, input2, max_displacement)
73 |
74 |
75 | def furthest_point_sampling(xyz: torch.Tensor, n_samples: int, cpp_impl=True):
76 | """
77 | Perform furthest point sampling on a set of points.
78 | :param xyz: a set of points, [batch_size, n_points, 3]
79 | :param n_samples: number of samples, int
80 | :param cpp_impl: whether to use the CUDA C++ implementation of furthest-point-sampling
81 | :return: indices of sampled points, [batch_size, n_samples]
82 | """
83 | def _furthest_point_sampling_py(_xyz: torch.Tensor, _n_samples: int):
84 | batch_size, n_points, _ = _xyz.shape
85 | farthest_indices = torch.zeros(batch_size, _n_samples, dtype=torch.int64, device=_xyz.device)
86 | distances = torch.ones(batch_size, n_points, device=_xyz.device) * 1e10
87 | batch_indices = torch.arange(batch_size, dtype=torch.int64, device=_xyz.device)
88 | curr_farthest_idx = torch.zeros(batch_size, dtype=torch.int64, device=_xyz.device)
89 | for i in range(_n_samples):
90 | farthest_indices[:, i] = curr_farthest_idx
91 | curr_farthest = _xyz[batch_indices, curr_farthest_idx, :].view(batch_size, 1, 3)
92 | new_distances = torch.sum((_xyz - curr_farthest) ** 2, -1)
93 | mask = new_distances < distances
94 | distances[mask] = new_distances[mask]
95 | curr_farthest_idx = torch.max(distances, -1)[1]
96 | return farthest_indices
97 |
98 | assert xyz.shape[2] == 3 and xyz.shape[1] > n_samples
99 |
100 | if cpp_impl and callable(_furthest_point_sampling_cuda) and xyz.is_cuda:
101 | return _furthest_point_sampling_cuda(xyz.contiguous(), n_samples).to(torch.int64)
102 | else:
103 | return _furthest_point_sampling_py(xyz, n_samples).to(torch.int64)
104 |
105 |
106 | def k_nearest_neighbor(input_xyz: torch.Tensor, query_xyz: torch.Tensor, k: int, cpp_impl=True):
107 | """
108 | Calculate k-nearest neighbor for each query.
109 | :param input_xyz: a set of points, [batch_size, n_points, 3] or [batch_size, 3, n_points]
110 | :param query_xyz: a set of centroids, [batch_size, n_queries, 3] or [batch_size, 3, n_queries]
111 | :param k: int
112 | :param cpp_impl: whether to use the CUDA C++ implementation of k-nearest-neighbor
113 | :return: indices of k-nearest neighbors, [batch_size, n_queries, k]
114 | """
115 | def _k_nearest_neighbor_py(_input_xyz: torch.Tensor, _query_xyz: torch.Tensor, _k: int):
116 | dists = squared_distance(_query_xyz, _input_xyz)
117 | return dists.topk(_k, dim=2, largest=False).indices.to(torch.long)
118 |
119 | if input_xyz.shape[1] <= 3: # channel_first to channel_last
120 | assert query_xyz.shape[1] == input_xyz.shape[1]
121 | input_xyz = input_xyz.transpose(1, 2).contiguous()
122 | query_xyz = query_xyz.transpose(1, 2).contiguous()
123 |
124 | if cpp_impl and callable(_k_nearest_neighbor_cuda) and input_xyz.is_cuda and query_xyz.is_cuda:
125 | return _k_nearest_neighbor_cuda(input_xyz.contiguous(), query_xyz.contiguous(), k)
126 | else:
127 | return _k_nearest_neighbor_py(input_xyz, query_xyz, k)
128 |
--------------------------------------------------------------------------------
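A minimal usage sketch of the three public wrappers (the import path assumes the repository root is on `PYTHONPATH`; `cpp_impl=False` forces the Python fallbacks so the snippet runs even without the compiled extensions):

```python
import torch
from models.csrc.wrapper import correlation2d, furthest_point_sampling, k_nearest_neighbor

xyz = torch.rand(2, 4096, 3)                                   # [batch, n_points, 3]
seed_idx = furthest_point_sampling(xyz, 1024, cpp_impl=False)  # [2, 1024]
knn_idx = k_nearest_neighbor(xyz, xyz, k=16, cpp_impl=False)   # [2, 4096, 16]

feat1, feat2 = torch.rand(2, 32, 48, 64), torch.rand(2, 32, 48, 64)
cost = correlation2d(feat1, feat2, max_displacement=4, cpp_impl=False)  # [2, 81, 48, 64]
print(seed_idx.shape, knn_idx.shape, cost.shape)
```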
/models/losses2d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.functional import interpolate
4 | from .utils import backwarp_2d, resize_flow2d
5 |
6 |
7 | def calc_supervised_loss_2d(flows, target, cfgs):
8 | assert len(flows) <= len(cfgs.level_weights)
9 |
10 | total_loss = 0
11 | for pred, level_weight in zip(flows, cfgs.level_weights):
12 | assert pred.shape[1] == 2 # [B, 2, H, W]
13 |
14 | flow_mask = target[:, 2] > 0
15 |
16 | diff = torch.abs(resize_flow2d(pred, target.shape[2], target.shape[3]) - target[:, :2])
17 |
18 | if cfgs.order == 'robust':
19 | loss_l1_map = torch.pow(diff.sum(dim=1) + 0.01, 0.4)
20 | loss_l1 = loss_l1_map[flow_mask].mean()
21 | total_loss += level_weight * loss_l1
22 | elif cfgs.order == 'l2-norm':
23 | loss_l2_map = torch.linalg.norm(diff, dim=1)
24 | loss_l2 = loss_l2_map[flow_mask].mean()
25 | total_loss += level_weight * loss_l2
26 | else:
27 | raise NotImplementedError
28 |
29 | return total_loss
30 |
31 |
32 | def calc_census_loss_2d(image1, image2, noc_mask=None, max_distance=1):
33 | """
34 | Calculate photometric loss based on census transform.
35 | :param image1: [N, 3, H, W] float tensor, ranging from 0 to 1, RGB
36 | :param image2: [N, 3, H, W] float tensor, ranging from 0 to 1, RGB
37 | :param noc_mask: [N, 1, H, W] float tensor, ranging from 0 to 1
38 | :param max_distance: int
39 | """
40 | def rgb_to_grayscale(image):
41 | grayscale = image[:, 0, :, :] * 0.2989 + \
42 | image[:, 1, :, :] * 0.5870 + \
43 | image[:, 2, :, :] * 0.1140
44 | return grayscale.unsqueeze(1) * 255.0
45 |
46 | def census_transform(gray_image):
47 | patch_size = 2 * max_distance + 1
48 | out_channels = patch_size * patch_size # 9
49 | weights = torch.eye(out_channels, dtype=gray_image.dtype, device=gray_image.device)
50 | weights = weights.view([out_channels, 1, patch_size, patch_size]) # [9, 1, 3, 3]
51 | patches = nn.functional.conv2d(gray_image, weights, padding=max_distance)
52 | result = patches - gray_image
53 | result = result / torch.sqrt(0.81 + torch.pow(result, 2))
54 | return result
55 |
56 | if noc_mask is not None:
57 | image1 = noc_mask * image1
58 | image2 = noc_mask * image2
59 |
60 | gray_image1 = rgb_to_grayscale(image1)
61 | gray_image2 = rgb_to_grayscale(image2)
62 |
63 | t1 = census_transform(gray_image1)
64 | t2 = census_transform(gray_image2)
65 |
66 | dist = torch.pow(t1 - t2, 2)
67 | dist_norm = dist / (0.1 + dist)
68 | dist_mean = torch.mean(dist_norm, 1, keepdim=True) # instead of sum
69 |
70 | n, _, h, w = image1.shape
71 | inner = torch.ones([n, 1, h - 2 * max_distance, w - 2 * max_distance], dtype=image1.dtype, device=image1.device)
72 | inner_mask = nn.functional.pad(inner, [max_distance] * 4)
73 | loss = dist_mean * inner_mask
74 |
75 | if noc_mask is not None:
76 | return loss.mean() / (noc_mask.mean() + 1e-7)
77 | else:
78 | return loss.mean()
79 |
80 |
81 | @torch.cuda.amp.autocast(enabled=False)
82 | def calc_smooth_loss_2d(image, flow, derivative='first'):
83 | """
84 | :param image: [N, 3, H, W] float tensor, ranging from 0 to 1, RGB
85 | :param flow: [N, 2, H, W] float tensor
86 | :param derivative: 'first' or 'second'
87 | """
88 | def gradient(inputs):
89 | dy = inputs[:, :, 1:, :] - inputs[:, :, :-1, :]
90 | dx = inputs[:, :, :, 1:] - inputs[:, :, :, :-1]
91 | return dx, dy
92 |
93 | image_dx, image_dy = gradient(image)
94 | flow_dx, flow_dy = gradient(flow)
95 |
96 | weights_x = torch.exp(-torch.mean(image_dx.abs(), 1, keepdim=True) * 10)
97 | weights_y = torch.exp(-torch.mean(image_dy.abs(), 1, keepdim=True) * 10)
98 |
99 | if derivative == 'first':
100 | loss_x = weights_x * flow_dx.abs() / 2.0
101 | loss_y = weights_y * flow_dy.abs() / 2.0
102 | elif derivative == 'second':
103 | flow_dx2 = flow_dx[:, :, :, 1:] - flow_dx[:, :, :, :-1]
104 | flow_dy2 = flow_dy[:, :, 1:, :] - flow_dy[:, :, :-1, :]
105 | loss_x = weights_x[:, :, :, 1:] * flow_dx2.abs()
106 | loss_y = weights_y[:, :, 1:, :] * flow_dy2.abs()
107 | else:
108 | raise NotImplementedError('Unknown derivative: %s' % derivative)
109 |
110 | return loss_x.mean() / 2 + loss_y.mean() / 2
111 |
112 |
113 | def calc_ssim_loss_2d(image1, image2, noc_mask=None, max_distance=1):
114 | """
115 | Calculate photometric loss based on SSIM.
116 | :param image1: [N, 3, H, W] float tensor, ranging from 0 to 1, RGB
117 | :param image2: [N, 3, H, W] float tensor, ranging from 0 to 1, RGB
118 | :param noc_mask: [N, 1, H, W] float tensor, ranging from 0 to 1
119 | :param max_distance: int
120 | """
121 | patch_size = 2 * max_distance + 1
122 | c1, c2 = 0.01 ** 2, 0.03 ** 2
123 |
124 | if noc_mask is not None:
125 | image1 = noc_mask * image1
126 | image2 = noc_mask * image2
127 |
128 | mu_x = nn.AvgPool2d(patch_size, 1, 0)(image1)
129 | mu_y = nn.AvgPool2d(patch_size, 1, 0)(image2)
130 | mu_x_square, mu_y_square = mu_x.pow(2), mu_y.pow(2)
131 | mu_xy = mu_x * mu_y
132 |
133 | sigma_x = nn.AvgPool2d(patch_size, 1, 0)(image1 * image1) - mu_x_square
134 | sigma_y = nn.AvgPool2d(patch_size, 1, 0)(image2 * image2) - mu_y_square
135 | sigma_xy = nn.AvgPool2d(patch_size, 1, 0)(image1 * image2) - mu_xy
136 |
137 | ssim_n = (2 * mu_xy + c1) * (2 * sigma_xy + c2)
138 | ssim_d = (mu_x_square + mu_y_square + c1) * (sigma_x + sigma_y + c2)
139 | ssim = ssim_n / ssim_d
140 | loss = torch.clamp((1 - ssim) / 2, min=0.0, max=1.0)
141 |
142 | if noc_mask is not None:
143 | return loss.mean() / (noc_mask.mean() + 1e-7)
144 | else:
145 | return loss.mean()
146 |
147 |
148 | def calc_unsupervised_loss_2d(pyramid_flows12, pyramid_flows21, image1, image2, occ_mask1, occ_mask2, cfgs):
149 | photo_loss = smooth_loss = 0
150 | for lv, (pyramid_flow12, pyramid_flow21) in enumerate(zip(pyramid_flows12, pyramid_flows21)):
151 | if lv == 0:
152 | image1_scaled, noc_mask1_scaled = image1, 1 - occ_mask1[:, None, :, :]
153 | image2_scaled, noc_mask2_scaled = image2, 1 - occ_mask2[:, None, :, :]
154 | else:
155 | curr_h, curr_w = pyramid_flow12.shape[2], pyramid_flow12.shape[3]
156 | image1_scaled = interpolate(image1, (curr_h, curr_w), mode='area')
157 | image2_scaled = interpolate(image2, (curr_h, curr_w), mode='area')
158 | noc_mask1_scaled = 1 - interpolate(occ_mask1[:, None, :, :], (curr_h, curr_w), mode='nearest')
159 | noc_mask2_scaled = 1 - interpolate(occ_mask2[:, None, :, :], (curr_h, curr_w), mode='nearest')
160 |
161 | image1_scaled_warp = backwarp_2d(image1_scaled, pyramid_flow21, padding_mode='border')
162 | image2_scaled_warp = backwarp_2d(image2_scaled, pyramid_flow12, padding_mode='border')
163 |
164 | # calculate photometric loss
165 | if cfgs.photometric_loss == 'ssim':
166 | photo_loss1 = calc_ssim_loss_2d(image1_scaled, image2_scaled_warp, noc_mask1_scaled)
167 | photo_loss2 = calc_ssim_loss_2d(image2_scaled, image1_scaled_warp, noc_mask2_scaled)
168 | elif cfgs.photometric_loss == 'census':
169 | photo_loss1 = calc_census_loss_2d(image1_scaled, image2_scaled_warp, noc_mask1_scaled)
170 | photo_loss2 = calc_census_loss_2d(image2_scaled, image1_scaled_warp, noc_mask2_scaled)
171 | else:
172 | raise NotImplementedError('Unknown photometric loss: %s' % cfgs.photometric_loss)
173 | photo_loss += cfgs.photometric_weights[lv] * (photo_loss1 + photo_loss2) / 2
174 |
175 | # calculate smooth loss
176 | scale = min(pyramid_flows12[0].shape[2], pyramid_flows12[0].shape[3])
177 | smooth_loss1 = calc_smooth_loss_2d(image1_scaled, pyramid_flow12 / scale, cfgs.smooth_derivative)
178 | smooth_loss2 = calc_smooth_loss_2d(image2_scaled, pyramid_flow21 / scale, cfgs.smooth_derivative)
179 | smooth_loss += cfgs.smooth_weights[lv] * (smooth_loss1 + smooth_loss2) / 2
180 |
181 | return photo_loss, smooth_loss
182 |
--------------------------------------------------------------------------------
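A quick smoke test of the photometric and smoothness terms (random RGB tensors in [0, 1] and an all-ones non-occlusion mask; shapes follow the docstrings above):

```python
import torch
from models.losses2d import calc_census_loss_2d, calc_smooth_loss_2d, calc_ssim_loss_2d

img1, img2 = torch.rand(2, 3, 64, 96), torch.rand(2, 3, 64, 96)
noc = torch.ones(2, 1, 64, 96)    # fully non-occluded
flow = torch.randn(2, 2, 64, 96)

print(calc_census_loss_2d(img1, img2, noc).item())
print(calc_ssim_loss_2d(img1, img2, noc).item())
print(calc_smooth_loss_2d(img1, flow, derivative='second').item())
```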
/models/losses3d.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 |
5 | def calc_supervised_loss_3d(flows, targets, cfgs, masks):
6 | assert len(flows) <= len(cfgs.level_weights)
7 |
8 | total_loss = 0
9 | for idx, (flow, level_weight) in enumerate(zip(flows, cfgs.level_weights)):
10 | level_target = targets[idx]
11 |
12 | mask = masks[idx]
13 |
14 | diff = flow - level_target # B, 3, H, W
15 |
16 | if cfgs.order == 'robust':
17 | epe_l1 = torch.pow(diff.abs().sum(dim=1) + 0.01, 0.4)[mask].mean()
18 | total_loss += level_weight * epe_l1
19 | elif cfgs.order == 'l2-norm':
20 | epe_l2 = torch.linalg.norm(diff, dim=1)[mask].mean()
21 | total_loss += level_weight * epe_l2
22 | else:
23 | raise NotImplementedError
24 |
25 | return total_loss
26 |
27 |
28 |
--------------------------------------------------------------------------------
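The two penalty orders differ only in how the per-pixel residual is reduced; side by side on a random residual (residual [B, 3, H, W] and boolean validity mask [B, H, W], as in the code above):

```python
import torch

diff = torch.randn(2, 3, 32, 48)    # flow - target, [B, 3, H, W]
mask = torch.rand(2, 32, 48) > 0.2  # valid-pixel mask, [B, H, W]

robust = torch.pow(diff.abs().sum(dim=1) + 0.01, 0.4)[mask].mean()  # 'robust'
l2 = torch.linalg.norm(diff, dim=1)[mask].mean()                    # 'l2-norm'
print(robust.item(), l2.item())
```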
/models/pointconv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .utils import MLP2d, LayerNormCF1d, batch_indexing, knn_grouping_2d, mask_batch_selecting, k_nearest_neighbor  # k_nearest_neighbor is used by PointConvS.forward; assumed re-exported by .utils (it is defined in csrc/wrapper.py)
4 |
5 |
6 | class PointConvS(nn.Module):
7 | def __init__(self, in_channels, out_channels, norm=None, act='leaky_relu', k=16):
8 | super().__init__()
9 | self.k = k
10 |
11 | self.weight_net = MLP2d(3, [8, 16], act=act)
12 | self.linear = nn.Linear(16 * (in_channels + 3), out_channels)
13 |
14 | if norm == 'batch_norm':
15 | self.norm_fn = nn.BatchNorm1d(out_channels, affine=True)
16 | elif norm == 'instance_norm':
17 | self.norm_fn = nn.InstanceNorm1d(out_channels, affine=True)
18 | elif norm == 'layer_norm':
19 | self.norm_fn = LayerNormCF1d(out_channels)
20 | elif norm is None:
21 | self.norm_fn = nn.Identity()
22 | else:
23 | raise NotImplementedError('Unknown normalization function: %s' % norm)
24 |
25 | if act == 'relu':
26 | self.act_fn = nn.ReLU(inplace=True)
27 | elif act == 'leaky_relu':
28 | self.act_fn = nn.LeakyReLU(negative_slope=0.1, inplace=True)
29 | elif act is None:
30 | self.act_fn = nn.Identity()
31 | else:
32 | raise NotImplementedError('Unknown activation function: %s' % act)
33 |
34 | def forward(self, xyz, features, sampled_xyz=None, knn_indices=None):
35 | """
36 | :param xyz: 3D locations of points, [batch_size, 3, n_points]
37 | :param features: features of points, [batch_size, in_channels, n_points]
38 | :param sampled_xyz: 3D locations of sampled points, [batch_size, 3, n_samples]
39 | :return weighted_features: features of sampled points, [batch_size, out_channels, n_samples]
40 | """
41 | if sampled_xyz is None:
42 | sampled_xyz = xyz
43 |
44 | bs, n_samples = sampled_xyz.shape[0], sampled_xyz.shape[-1]
45 | features = torch.cat([xyz, features], dim=1) # [bs, in_channels + 3, n_points]
46 | features_cl = features.transpose(1, 2) # [bs, n_points, in_channels + 3]
47 |
48 | # Calculate k nearest neighbors
49 | if knn_indices is None:
50 | knn_indices = k_nearest_neighbor(xyz, sampled_xyz, self.k) # [bs, n_samples, k]
51 | else:
52 | assert knn_indices.shape[:2] == torch.Size([bs, n_samples])
53 | assert knn_indices.shape[2] >= self.k
54 | knn_indices = knn_indices[:, :, :self.k]
55 |
56 | # Calculate weights
57 | knn_xyz = batch_indexing(xyz, knn_indices) # [bs, 3, n_samples, k]
58 | knn_xyz_norm = knn_xyz - sampled_xyz[:, :, :, None] # [bs, 3, n_samples, k]
59 | weights = self.weight_net(knn_xyz_norm) # [bs, n_weights, n_samples, k]
60 |
61 | # Calculate weighted features
62 | weights = weights.transpose(1, 2) # [bs, n_samples, n_weights, k]
63 | knn_features = batch_indexing(features_cl, knn_indices, layout='channel_last') # [bs, n_samples, k, 3 + in_channels]
64 | out = torch.matmul(weights, knn_features) # [bs, n_samples, n_weights, 3 + in_channels]
65 | out = out.view(bs, n_samples, -1) # [bs, n_samples, (3 + in_channels) * n_weights]
66 | out = self.linear(out) # [bs, n_samples, out_channels]
67 | out = self.act_fn(self.norm_fn(out.transpose(1, 2))) # [bs, out_channels, n_samples]
68 |
69 | return out
70 |
71 |
72 | class PointConv(nn.Module):
73 |
74 | def __init__(self, in_channels, out_channels, ks = [10, 20], dist = 100.0, norm=None, act='leaky_relu', k=16):
75 | super().__init__()
76 | self.k = k
77 | self.kernel_size = ks
78 | self.distance = dist
79 | self.in_channels = in_channels
80 | self.out_channels = out_channels
81 | self.weight_net = MLP2d(3, [8, 16], act=act)
82 | self.linear = nn.Linear(16 * (in_channels + 3), out_channels)
83 |
84 | if norm == 'batch_norm':
85 | self.norm_fn = nn.BatchNorm1d(out_channels)
86 | elif norm == 'instance_norm':
87 | self.norm_fn = nn.InstanceNorm1d(out_channels)
88 | elif norm is None:
89 | self.norm_fn = nn.Identity()
90 | else:
91 | raise NotImplementedError('Unknown normalization function: %s' % norm)
92 |
93 | if act == 'relu':
94 | self.act_fn = nn.ReLU(inplace=True)
95 | elif act == 'leaky_relu':
96 | self.act_fn = nn.LeakyReLU(negative_slope=0.1, inplace=True)
97 | elif act is None:
98 | self.act_fn = nn.Identity()
99 | else:
100 | raise NotImplementedError('Unknown activation function: %s' % act)
101 |
102 | def forward(self, xyz, features, sampled_xyz = None, knn_indices=None, valid_knn_mask = None):
103 |
104 | """
105 | :param xyz: [batch_size, 3, H, W]
106 | :param features: [batch_size, C, H, W]
107 | :param sampled_xyz: [batch_size, 3, h, w]
108 | :return: out: [batch_size, C', h, w]
109 | """
110 | if sampled_xyz is None:
111 | sampled_xyz = xyz
112 |
113 |
114 | B, C, H, W = features.shape
115 | h, w = sampled_xyz.shape[2:]
116 | features = torch.cat([xyz, features], dim=1).reshape(B, C + 3, H * W) # [B, C + 3, H * W]
117 | features_cl = features.transpose(1, 2) # [B, H * W, in_channels + 3]
118 |
119 | # Calculate k nearest neighbors
120 | if knn_indices is None:
121 | # [B, h*w, k], [B, h*w, k]
122 | knn_indices, valid_knn_mask = knn_grouping_2d(sampled_xyz, xyz, self.k)
123 | else:
124 | assert knn_indices.shape[:2] == torch.Size([B, h * w])
125 | assert knn_indices.shape[2] >= self.k
126 | knn_indices = knn_indices[:, :, :self.k]
127 | valid_knn_mask = valid_knn_mask[:, :, :self.k]
128 | # valid_mask = valid_mask[:, :, :self.k]
129 |
130 | # Calculate weights
131 | knn_xyz = mask_batch_selecting(xyz, knn_indices, valid_knn_mask) # [B, 3, h * w, k]
132 | new_xyz = sampled_xyz.reshape(B, 3,-1) # [B, 3, h*w]
133 | knn_xyz_norm = knn_xyz - new_xyz[:, :, :, None] # [B, 3, h*w, k]
134 | weights = self.weight_net(knn_xyz_norm) # [B, n_weights, h*w, k]
135 |
136 | # Calculate weighted features
137 | weights = weights.transpose(1, 2) # [B, h*w, n_weights, k]
138 | knn_features = mask_batch_selecting(features_cl, knn_indices, valid_knn_mask, layout='channel_last') # [B, h*w, k, C + 3]
139 |
140 | out = torch.matmul(weights, knn_features) # [B, h*w, n_weights, C+3]
141 | out = out.view(B, h*w, -1) # [B, h*w, n_weights*(C+3)]
142 | out = self.linear(out) # [B, h*w, out_channels]
143 | out = self.act_fn(self.norm_fn(out.transpose(1, 2))) # [B, out_channels, h * w]
144 |
145 | # out = torch.reshape(out, [B, -1, h, w])
146 | out = out.view(B, -1, h, w)
147 | return out
148 |
149 |
150 | class PointConvDW(nn.Module):
151 | def __init__(self, in_channels, out_channels, norm=None, act='leaky_relu', k=16):
152 | super().__init__()
153 | self.k = k
154 | self.mlp = MLP2d(in_channels, [out_channels], norm, act)
155 | self.weight_net = MLP2d(3, [8, 32, out_channels], act='relu')
156 |
157 | def forward(self, xyz, features, sampled_xyz=None, knn_indices=None, valid_knn_mask = None):
158 | """
159 | :param xyz: [batch_size, 3, H, W]
160 | :param features: [batch_size, C, H, W]
161 | :param sampled_xyz: [batch_size, 3, h, w]
162 | :return: out: [batch_size, C', h, w]
163 | """
164 |
165 | if sampled_xyz is None:
166 | sampled_xyz = xyz
167 |
168 | B, C, H, W = features.shape
169 | h, w = sampled_xyz.shape[2:]
170 |
171 | # Calculate k nearest neighbors
172 | if knn_indices is None:
173 | # [B, h*w, k], [B, h*w, k]
174 | knn_indices, valid_knn_mask = knn_grouping_2d(sampled_xyz, xyz, self.k)
175 | else:
176 | assert knn_indices.shape[:2] == torch.Size([B, h * w])
177 | assert knn_indices.shape[2] >= self.k
178 | knn_indices = knn_indices[:, :, :self.k]
179 | valid_knn_mask = valid_knn_mask[:, :, :self.k]
180 |
181 | # Calculate weights
182 | knn_xyz = mask_batch_selecting(xyz, knn_indices, valid_knn_mask) # [B, 3, h * w, k]
183 | new_xyz = sampled_xyz.reshape(B, 3,-1) # [B, 3, h*w]
184 | knn_offset = knn_xyz - new_xyz[:, :, :, None] # [B, 3, h*w, k]
185 |
186 | features = self.mlp(features) # [B, C_out, H, W]
187 | features = mask_batch_selecting(features, knn_indices, valid_knn_mask) # [B, C_out, h * w, k]
188 | features = features * self.weight_net(knn_offset) # [B, C_out, h * w, k] * [B, C_out, h*w, k]
189 | features = torch.max(features, dim=-1)[0] # [B, C_out, h*w]
190 | features = features.view(B, -1, h, w)
191 |
192 | return features
--------------------------------------------------------------------------------
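A shape-level sketch for `PointConvS` on an unstructured cloud (the grid variants `PointConv` and `PointConvDW` follow the same pattern on [B, 3, H, W] inputs via `knn_grouping_2d`); this assumes `k_nearest_neighbor` resolves to the fallback in `models/csrc/wrapper.py`:

```python
import torch
from models.pointconv import PointConvS

conv = PointConvS(in_channels=64, out_channels=128, norm='batch_norm', k=16)
xyz = torch.rand(2, 3, 4096)     # [B, 3, n_points]
feats = torch.rand(2, 64, 4096)  # [B, in_channels, n_points]
out = conv(xyz, feats)           # [B, out_channels, n_points]
print(out.shape)                 # torch.Size([2, 128, 4096])
```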
/ops_pytorch/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IRMVLab/DELFlow/3e12fdec9fd26ace6f42436497ceb236656c6a23/ops_pytorch/__init__.py
--------------------------------------------------------------------------------
/ops_pytorch/fused_conv_select/fused_conv_g.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <ATen/cuda/CUDAContext.h>
3 | #include <THC/THC.h>
4 | #include <vector>
5 | #include "fused_conv_gpu.h"
6 |
7 | extern THCState *state;
8 |
9 | //#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
10 | //#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
11 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
12 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
13 | #define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
14 |
15 | void torch_FusedConvSelectKLauncher(
16 | torch::Tensor xyz_tensor,
17 | torch::Tensor xyz2_tensor,
18 | torch::Tensor idx_n2_tensor,
19 | torch::Tensor idx_fetching_tensor,
20 | torch::Tensor random_hw_tensor,
21 | int H,
22 | int W,
23 | int npoints,
24 | int kernel_size_H,
25 | int kernel_size_W,
26 | int K,
27 | bool flag_copy,
28 | float distance,
29 | int stride_h,
30 | int stride_w,
31 | torch::Tensor selected_b_idx_tensor,
32 | torch::Tensor selected_h_idx_tensor,
33 | torch::Tensor selected_w_idx_tensor,
34 | torch::Tensor selected_mask_tensor,
35 | int small_h,
36 | int small_w ){
37 |
38 | CHECK_INPUT(xyz_tensor);
39 | CHECK_INPUT(xyz2_tensor);
40 | CHECK_INPUT(idx_n2_tensor);
41 | CHECK_INPUT(idx_fetching_tensor);
42 | CHECK_INPUT(random_hw_tensor);
43 | CHECK_INPUT(selected_b_idx_tensor);
44 | CHECK_INPUT(selected_h_idx_tensor);
45 | CHECK_INPUT(selected_w_idx_tensor);
46 | CHECK_INPUT(selected_mask_tensor);
47 |
48 | const auto batch_size = xyz_tensor.size(0);
49 | const float *xyz1 = xyz_tensor.data_ptr<float>();
50 | const float *xyz2 = xyz2_tensor.data_ptr<float>();
51 | const int *idx_n2 = idx_n2_tensor.data_ptr<int>();
52 | const int *idx_fetching = idx_fetching_tensor.data_ptr<int>();
53 | const int *random_hw = random_hw_tensor.data_ptr<int>();
54 | long *selected_b_idx = selected_b_idx_tensor.data_ptr<int64_t>();
55 | long *selected_h_idx = selected_h_idx_tensor.data_ptr<int64_t>();
56 | long *selected_w_idx = selected_w_idx_tensor.data_ptr<int64_t>();
57 | float *selected_mask = selected_mask_tensor.data_ptr<float>();
58 |
59 | //cudaStream_t stream = THCState_getCurrentStream(state);
60 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
61 |
62 | FusedConvSelectKLauncher(batch_size, H, W, npoints, kernel_size_H, kernel_size_W, K, flag_copy, distance,
63 | stride_h, stride_w, xyz1, xyz2, idx_n2, idx_fetching, random_hw, selected_b_idx, selected_h_idx, selected_w_idx, selected_mask, small_h, small_w, stream);
64 | }
65 |
66 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
67 | m.def("fused_conv_select_k",
68 | &torch_FusedConvSelectKLauncher,
69 | "torch_FusedConvSelectKLauncher kernel warpper");
70 | }
--------------------------------------------------------------------------------
/ops_pytorch/fused_conv_select/fused_conv_go.cu:
--------------------------------------------------------------------------------
1 | // input: kernel_size(h,w), stride_size(h,w), distance(float), flag_padding, xyz(b,H,W,3), bhw_idx(b,H,W,3)
2 | // output: selected_xyz(b, npoints, h*w, 3), selected_feature(b, npoints, h*w, 3)
3 | #include <cuda.h>
4 | #include <cuda_runtime.h>
5 | #include <stdlib.h> /* srand, rand */
6 | #include <time.h> /* time */
7 | #include <cstdlib> // Header file needed to use rand
8 | #include "fused_conv_gpu.h"
9 |
10 | __global__ void fused_conv_select_k_gpu(int batch_size, int H, int W, int npoints, int kernel_size_H,
11 | int kernel_size_W, int K, int flag_copy, float distance, int stride_h, int stride_w, const float *xyz1,
12 | const float *xyz2, const int *idx_n2, const int *idx_fetching, const int *random_hw, long *selected_b_idx, long *selected_h_idx, long *selected_w_idx,
13 | float *selected_mask, int small_h, int small_w)
14 | {
15 |
16 | int batch_index = blockIdx.x; // index of the current thread block
17 | int index_thread = threadIdx.x;
18 | int stride_thread = blockDim.x;
19 |
20 | int kernel_total = kernel_size_H * kernel_size_W; // number of positions in one kernel window
21 | int selected_W_idx = 0, selected_H_idx = 0;
22 | int fetched_H_idx = 0, fetched_W_idx = 0;
23 |
24 | float dist_square = distance * distance;
25 |
26 | int kernel_half_H = kernel_size_H / 2;
27 | int kernel_half_W = kernel_size_W / 2;
28 | // small_h = H, small_w = W
29 | xyz1 += batch_index * H * W * 3; // point cloud of current image
30 | xyz2 += batch_index * small_h * small_w * 3;
31 | idx_n2 += batch_index * npoints * 2;
32 | idx_fetching += batch_index * npoints * 2; // 2d coordinates of central points
33 | selected_b_idx += batch_index * npoints * K * 1; //(b, npoints, k, 3), // batch index of K selected points around central points
34 | selected_h_idx += batch_index * npoints * K * 1; //(b, npoints, k, 3),
35 | selected_w_idx += batch_index * npoints * K * 1; //(b, npoints, k, 3),
36 |
37 |
38 | selected_mask += batch_index * npoints * K * 1; //(b, npoints, k, 1), points with valid coordinates and valid distance, including copied points (the nearest neighbor reused)
39 |
40 | ////////////// Fused Conv Between
41 | const int MaxPoint = 960; // capacity of the per-thread candidate buffers; kernel_total must not exceed this
42 | for (int current_n = index_thread; current_n < npoints; current_n += stride_thread) // output_W circle
43 | {
44 |
45 | int idx_w[MaxPoint], idx_h[MaxPoint];
46 | float Dist[MaxPoint];
47 |
48 | for (int ii = 0; ii < MaxPoint; ++ii)
49 | {
50 | idx_w[ii] = 0;
51 | idx_h[ii] = 0;
52 | Dist[ii] = 1e10f;
53 | }
54 |
55 | int m_idx = 0; // mth point in each kernel
56 | int num_select = 0; // the number of selected points in each kernel
57 | // idx_n2: subsampled index
58 | selected_H_idx = idx_n2[current_n * 2 + 0]; // the central points H idx of input 2d frame
59 | selected_W_idx = idx_n2[current_n * 2 + 1]; // the central points W idx of input 2d frame
60 |
61 | float x_c = xyz1[selected_H_idx * W * 3 + selected_W_idx * 3 + 0];
62 | float y_c = xyz1[selected_H_idx * W * 3 + selected_W_idx * 3 + 1];
63 | float z_c = xyz1[selected_H_idx * W * 3 + selected_W_idx * 3 + 2];
64 |
65 | float Dist_c = max((x_c - 0) * (x_c - 0) + (y_c - 0) * (y_c - 0) + (z_c - 0) * (z_c - 0), 1e-10f);
66 |
67 | // NOTE: the following is different
68 | if (Dist_c <= 1e-10f) // not valid central points of xyz1
69 | {
70 | continue;
71 | }
72 |
73 | fetched_H_idx = idx_fetching[current_n * 2 + 0];
74 | fetched_W_idx = idx_fetching[current_n * 2 + 1];
75 | // valid central points of xyz2
76 | if (fetched_H_idx == -1)
77 | {
78 | continue;
79 | }
80 |
81 | // compute the distance
82 | for (int current_HW_idx = 0; current_HW_idx < kernel_total; ++current_HW_idx) // select points in every kernel element
83 | {
84 |
85 | int kernel_HW_idx = random_hw[current_HW_idx];
86 |
87 | int kernel_select_H_idx = fetched_H_idx / stride_h + kernel_HW_idx / kernel_size_W - kernel_half_H; // random select ???
88 | int kernel_select_W_idx = fetched_W_idx / stride_w + kernel_HW_idx % kernel_size_W - kernel_half_W; // random select ???
89 |
90 | if ((kernel_select_H_idx < 0) || (kernel_select_H_idx >= small_h) || (kernel_select_W_idx < 0) || (kernel_select_W_idx >= small_w)) // the region of padding points (not valid)
91 | {
92 | ++m_idx;
93 | continue;
94 | }
95 |
96 | // not the padding points
97 |
98 | float x_q = xyz2[kernel_select_H_idx * small_w * 3 + kernel_select_W_idx * 3 + 0];
99 | float y_q = xyz2[kernel_select_H_idx * small_w * 3 + kernel_select_W_idx * 3 + 1];
100 | float z_q = xyz2[kernel_select_H_idx * small_w * 3 + kernel_select_W_idx * 3 + 2];
101 |
102 | float Dist_q_0 = x_q * x_q + y_q * y_q + z_q * z_q;
103 |
104 | if (Dist_q_0 <= 1e-10f) // not valid xyz2 points
105 | {
106 | ++m_idx;
107 | continue;
108 | }
109 |
110 | // valid xyz2 points, calculate the distance
111 |
112 |
113 | float Dist_q = max((x_c - x_q) * (x_c - x_q) + (y_c - y_q) * (y_c - y_q) + (z_c - z_q) * (z_c - z_q), 1e-10f);
114 |
115 | if (Dist_q > dist_square) // too far from the central points, regarding as not valid
116 | {
117 | ++m_idx;
118 | continue;
119 | }
120 |
121 |
122 |
123 | Dist[m_idx] = Dist_q;
124 | idx_h[m_idx] = kernel_select_H_idx;
125 | idx_w[m_idx] = kernel_select_W_idx;
126 |
127 | ++m_idx;
128 | ++num_select;
129 |
130 | if (num_select >= kernel_total) // search all position
131 | break;
132 | }
133 |
134 | //?int sort_num = 0;
135 |
136 | for (int s_idx = 0; s_idx < K; ++s_idx) // knn
137 | {
138 | int min_idx = s_idx; // min_idx idx
139 |
140 | // find the min_idx
141 | for (int t = s_idx + 1; t < kernel_total; ++t)
142 | {
143 | if (Dist[t] < Dist[min_idx])
144 | {
145 | min_idx = t;
146 | }
147 | }
148 |
149 | // swap min_idx-th and i-th element
150 | if (min_idx != s_idx)
151 | {
152 | float tmp_dist = Dist[min_idx];
153 | int tmp_idx_w = idx_w[min_idx];
154 | int tmp_idx_h = idx_h[min_idx];
155 |
156 | Dist[min_idx] = Dist[s_idx];
157 | idx_w[min_idx] = idx_w[s_idx];
158 | idx_h[min_idx] = idx_h[s_idx];
159 |
160 | Dist[s_idx] = tmp_dist;
161 | idx_w[s_idx] = tmp_idx_w;
162 | idx_h[s_idx] = tmp_idx_h;
163 | }
164 |
165 | if ((flag_copy & 0x1) && (s_idx == 0)) // copy the first selected point in xyz2 for K times
166 | {
167 | for (int k_idx = 0; k_idx < K; ++k_idx)
168 | {
169 |
170 | selected_b_idx[current_n * K + k_idx] = batch_index;
171 | selected_h_idx[current_n * K + k_idx] = idx_h[s_idx];
172 | selected_w_idx[current_n * K + k_idx] = idx_w[s_idx];
173 | selected_mask[current_n * K * 1 + k_idx * 1 + 0] = 1.0;
174 | }
175 |
176 | } // copy done
177 |
178 | if (Dist[s_idx] < 1e10f) // whether this is a valid points or not
179 | {
180 |
181 | selected_b_idx[current_n * K + s_idx] = batch_index;
182 | selected_h_idx[current_n * K + s_idx] = idx_h[s_idx];
183 | selected_w_idx[current_n * K + s_idx] = idx_w[s_idx];
184 | selected_mask[current_n * K * 1 + s_idx * 1 + 0] = 1.0;
185 | }
186 | }
187 | }
188 | }
189 |
190 | void FusedConvSelectKLauncher(int batch_size, int H, int W, int npoints, int kernel_size_H,
191 | int kernel_size_W, int K, int flag_copy, float distance, int stride_h, int stride_w,
192 | const float *xyz1, const float *xyz2, const int *idx_n2, const int *idx_fetching, const int *random_hw,
193 | long *selected_b_idx, long *selected_h_idx, long *selected_w_idx,
194 | float *selected_mask, int small_h, int small_w, cudaStream_t stream)
195 | {
196 |
197 | cudaError_t err;
198 |
199 | fused_conv_select_k_gpu<<<batch_size, 256, 0, stream>>>(batch_size, H, W, npoints, kernel_size_H, kernel_size_W, K, flag_copy, distance, stride_h, stride_w, xyz1, xyz2, idx_n2, idx_fetching, random_hw, selected_b_idx, selected_h_idx, selected_w_idx, selected_mask, small_h, small_w); // one block per batch element; the thread count (256) is assumed, the original launch config was lost in extraction
200 |
201 | // cudaDeviceSynchronize();
202 | err = cudaGetLastError();
203 | if (cudaSuccess != err)
204 | {
205 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
206 | exit(-1);
207 | }
208 | }
209 |
--------------------------------------------------------------------------------
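In effect, for each central point the kernel gathers candidates from a `kernel_size_H x kernel_size_W` window around the point's (strided) 2D location in `xyz2`, discards zero (invalid) points and points beyond `distance`, and keeps the K nearest via a selection sort. A rough dense-tensor sketch of that selection rule for a single central point (the function name and structure are illustrative; the randomized traversal order and the `flag_copy` replication are omitted):

```python
import torch

def window_select_k(center_xyz, center_hw, xyz2, win_h, win_w, K, max_dist):
    """center_xyz: [3]; center_hw: (h, w) index into xyz2; xyz2: [h, w, 3]."""
    h, w = xyz2.shape[:2]
    hs = torch.arange(center_hw[0] - win_h // 2, center_hw[0] + win_h // 2 + 1)
    ws = torch.arange(center_hw[1] - win_w // 2, center_hw[1] + win_w // 2 + 1)
    hh, ww = torch.meshgrid(hs, ws, indexing='ij')
    hh, ww = hh.reshape(-1), ww.reshape(-1)
    inside = (hh >= 0) & (hh < h) & (ww >= 0) & (ww < w)  # drop padding positions
    hh, ww = hh[inside], ww[inside]
    cand = xyz2[hh, ww]                                   # [n, 3] window candidates
    dist = (cand - center_xyz).pow(2).sum(-1)
    dist[cand.pow(2).sum(-1) <= 1e-10] = 1e10             # zero points are invalid
    dist[dist > max_dist ** 2] = 1e10                     # outside the search radius
    d, order = dist.topk(min(K, dist.numel()), largest=False)
    return hh[order], ww[order], d < 1e10                 # h/w indices + validity mask

h_idx, w_idx, valid = window_select_k(torch.rand(3), (10, 20), torch.rand(32, 64, 3), 9, 9, 16, 2.5)
```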
/ops_pytorch/fused_conv_select/fused_conv_gpu.h:
--------------------------------------------------------------------------------
1 | #ifndef _FUSE_CONV_GPU_H_
2 | #define _FUSE_CONV_GPU_H_
3 |
4 | #include <torch/extension.h>
5 | #include <ATen/cuda/CUDAContext.h>
6 | #include <cuda.h>
7 | #include <cuda_runtime.h>
8 | #include <vector>
9 |
10 | void torch_FusedConvSelectKLauncher(
11 | torch::Tensor xyz_tensor,
12 | torch::Tensor xyz2_tensor,
13 | torch::Tensor idx_n2_tensor,
14 | torch::Tensor idx_fetching_tensor,
15 | torch::Tensor random_hw_tensor,
16 | int H,
17 | int W,
18 | int npoints,
19 | int kernel_size_H,
20 | int kernel_size_W,
21 | int K,
22 | bool flag_copy,
23 | float distance,
24 | int stride_h,
25 | int stride_w,
26 | torch::Tensor select_b_idx_tensor,
27 | torch::Tensor select_h_idx_tensor,
28 | torch::Tensor select_w_idx_tensor,
29 | torch::Tensor select_mask_tensor,
30 | int small_h,
31 | int small_w);
32 |
33 | void FusedConvSelectKLauncher(
34 | int batch_size,
35 | int H,
36 | int W,
37 | int npoints,
38 | int kernel_size_H,
39 | int kernel_size_W,
40 | int K,
41 | int flag_copy,
42 | float distance,
43 | int stride_h,
44 | int stride_w,
45 | const float *xyz1,
46 | const float *xyz2,
47 | const int *idx_n2,
48 | const int *idx_fetching,
49 | const int *random_hw,
50 | long *selected_b_idx,
51 | long *selected_h_idx,
52 | long *selected_w_idx,
53 | float *selected_mask,
54 | int small_h,
55 | int small_w,
56 | cudaStream_t stream);
57 |
58 | #endif
--------------------------------------------------------------------------------
/ops_pytorch/fused_conv_select/fused_conv_select_k.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import sys
3 | import os
4 | import numpy as np
5 | import fused_conv_select_k_cuda as fused_conv_select_k_module
6 |
7 |
8 | def fused_conv_select_k(xyz1, xyz2, idx_n2, idx_fetching, random_hw, H, W,
9 | npoints, kernel_size_H, kernel_size_W, K, flag_copy,
10 | distance, stride_h, stride_w, select_b_idx,
11 | select_h_idx, select_w_idx, select_mask, small_h, small_w):
12 | '''
13 | Input:
14 | xyz1: (b, h, w, 3) float, projected xyz1 points
15 | xyz2: (b, small_h, small_w, 3) float, projected xyz2 points
16 | idx_n2: (b, npoints, 2) int, 2D indices of the central points in xyz1
17 | idx_fetching: (b, npoints, 2) int, 2D indices of the central points in xyz2 (-1 marks invalid)
18 | random_hw: (kernel_size_H * kernel_size_W,) int, randomized traversal order over the kernel window
19 | H, W: input shape; kernel_size_H, kernel_size_W: search window size
20 | K: the number of selected points (knn)
21 | distance: float, maximum search radius
22 | flag_copy: bool, whether to copy the nearest point into all K output slots
23 | Output:
24 | select_b_idx / select_h_idx / select_w_idx: (b, npoints, K, 1) long, indices of selected points; select_mask: (b, npoints, K, 1) float validity mask
25 | '''
26 |
27 | fused_conv_select_k_module.fused_conv_select_k(
28 | xyz1, xyz2, idx_n2, idx_fetching, random_hw, H, W, npoints,
29 | kernel_size_H, kernel_size_W, K, flag_copy, distance, stride_h,
30 | stride_w, select_b_idx, select_h_idx, select_w_idx, select_mask, small_h, small_w)
31 | return select_b_idx, select_h_idx, select_w_idx, select_mask
32 |
33 |
--------------------------------------------------------------------------------
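Usage note (not part of the repository): a minimal sketch of how fused_conv_select_k is typically called, assuming the extension below has been built and imported. All sizes and tensors here are illustrative; the op fills the preallocated select_* buffers in place.

    import torch
    from ops_pytorch.fused_conv_select.fused_conv_select_k import fused_conv_select_k

    B, H, W, K = 2, 32, 64, 3                      # illustrative sizes
    kernel_h = kernel_w = 3
    stride_h = stride_w = 2
    small_h = (H + stride_h - 1) // stride_h
    small_w = (W + stride_w - 1) // stride_w
    npoints = 128

    xyz1 = torch.randn(B, H, W, 3, device='cuda')
    xyz2 = torch.randn(B, small_h, small_w, 3, device='cuda')
    idx_n2 = torch.stack([torch.randint(0, H, (B, npoints), device='cuda'),
                          torch.randint(0, W, (B, npoints), device='cuda')], dim=-1).int()
    idx_fetching = idx_n2.clone()                  # placeholder; real indices come from the model
    random_hw = torch.randperm(kernel_h * kernel_w, device='cuda').int()

    # preallocated output buffers, written in place by the CUDA op
    select_b_idx = torch.zeros(B, npoints, K, 1, dtype=torch.long, device='cuda')
    select_h_idx = torch.zeros_like(select_b_idx)
    select_w_idx = torch.zeros_like(select_b_idx)
    select_mask = torch.zeros(B, npoints, K, 1, dtype=torch.float32, device='cuda')

    fused_conv_select_k(xyz1, xyz2, idx_n2, idx_fetching, random_hw, H, W, npoints,
                        kernel_h, kernel_w, K, True, 100.0, stride_h, stride_w,
                        select_b_idx, select_h_idx, select_w_idx, select_mask,
                        small_h, small_w)
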
/ops_pytorch/fused_conv_select/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 |
4 | setup(
5 | name="fused_conv_select_k",
6 | ext_modules=[
7 | CUDAExtension(
8 | "fused_conv_select_k_cuda",
9 | ["fused_conv_g.cpp", "fused_conv_go.cu"],
10 | extra_compile_args={'cxx': ['-g'],
11 | 'nvcc': ['-O2']})
12 | ],
13 | cmdclass={
14 | "build_ext": BuildExtension
15 | }
16 | )
--------------------------------------------------------------------------------
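Build note (an alternative sketch, not part of the repository): besides installing ahead of time with setup.py, the same sources can be JIT-compiled with torch.utils.cpp_extension.load; paths assume the repository root as the working directory.

    from torch.utils.cpp_extension import load

    fused_conv_select_k_cuda = load(
        name='fused_conv_select_k_cuda',
        sources=['ops_pytorch/fused_conv_select/fused_conv_g.cpp',
                 'ops_pytorch/fused_conv_select/fused_conv_go.cu'],
        extra_cflags=['-g'],
        extra_cuda_cflags=['-O2'])
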
/ops_pytorch/gpu_threenn_sample/no_sort_knn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import sys
3 | import os
4 | import numpy as np
5 |
6 | import no_sort_knn_cuda as no_sort_knn_module
7 |
8 | def no_sort_knn(xyz1, xyz2, idx_n2, random_hw, H, W, npoints, kernel_size_H, kernel_size_W, K, flag_copy, distance, stride_h, stride_w, select_b_idx, select_h_idx, select_w_idx, select_mask):
9 | '''
10 | For each query index in idx_n2, select the K nearest valid xyz2 points inside
11 | the search window, writing the results into the preallocated select_* buffers
12 | (modified in place).
13 | '''
14 | no_sort_knn_module.no_sort_knn(xyz1, xyz2, idx_n2, random_hw, H, W, npoints, kernel_size_H, kernel_size_W, K, flag_copy, distance, stride_h, stride_w, select_b_idx, select_h_idx, select_w_idx, select_mask)
15 | return select_b_idx, select_h_idx, select_w_idx, select_mask
16 |
--------------------------------------------------------------------------------
/ops_pytorch/gpu_threenn_sample/no_sort_knn_g.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <THC/THC.h>
3 | #include <ATen/cuda/CUDAContext.h>
4 | #include <vector>
5 | #include "no_sort_knn_gpu.h"
6 |
7 | extern THCState *state;
8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
10 | #define CHECK_INPUT(x) \
11 | CHECK_CUDA(x); \
12 | CHECK_CONTIGUOUS(x)
13 |
14 | void torch_NoSortKnnLauncher(
15 | torch::Tensor xyz1_tensor,
16 | torch::Tensor xyz2_tensor,
17 | torch::Tensor idx_n2_tensor,
18 | torch::Tensor random_hw_tensor,
19 | int H,
20 | int W,
21 | int npoints,
22 | int kernel_size_H,
23 | int kernel_size_W,
24 | int K,
25 | bool flag_copy,
26 | float distance,
27 | int stride_h,
28 | int stride_w,
29 | torch::Tensor selected_b_idx_tensor,
30 | torch::Tensor selected_h_idx_tensor,
31 | torch::Tensor selected_w_idx_tensor,
32 | torch::Tensor selected_mask_tensor)
33 | {
34 | CHECK_INPUT(xyz1_tensor);
35 | CHECK_INPUT(xyz2_tensor);
36 | CHECK_INPUT(idx_n2_tensor);
37 | CHECK_INPUT(random_hw_tensor);
38 | CHECK_INPUT(selected_b_idx_tensor);
39 | CHECK_INPUT(selected_h_idx_tensor);
40 | CHECK_INPUT(selected_w_idx_tensor);
41 | CHECK_INPUT(selected_mask_tensor);
42 |
43 | const auto batch_size = xyz1_tensor.size(0);
44 | const float *xyz1 = xyz1_tensor.data<float>();
45 | const float *xyz2 = xyz2_tensor.data<float>();
46 | const int *idx_n2 = idx_n2_tensor.data<int>();
47 | const int *random_hw = random_hw_tensor.data<int>();
48 | long *selected_b_idx = selected_b_idx_tensor.data<long>();
49 | long *selected_h_idx = selected_h_idx_tensor.data<long>();
50 | long *selected_w_idx = selected_w_idx_tensor.data<long>();
51 | float *selected_mask = selected_mask_tensor.data<float>();
52 |
53 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
54 |
55 | NoSortKnnLauncher(batch_size, H, W, npoints, kernel_size_H, kernel_size_W, K, flag_copy, distance,
56 | stride_h, stride_w, xyz1, xyz2, idx_n2, random_hw, selected_b_idx, selected_h_idx, selected_w_idx, selected_mask, stream);
57 | }
58 |
59 |
60 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
61 | m.def(
62 | "no_sort_knn",
63 | &torch_NoSortKnnLauncher,
64 | "torch_NoSortKnnLauncher kernel warpper"
65 | );
66 | }
--------------------------------------------------------------------------------
/ops_pytorch/gpu_threenn_sample/no_sort_knn_go.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <math.h>
3 | #include <stdlib.h> /* srand, rand */
4 | #include <time.h> /* time */
5 | #include <cstdlib> // Header file needed to use rand
6 | #include <cuda.h>
7 | #include <cuda_runtime.h>
8 |
9 | struct special_idx{
10 | int idx_h;
11 | int idx_w;
12 | };
13 |
14 | __global__ void no_sort_knn_gpu(int batch_size, int H, int W, int npoints, int kernel_size_H, int kernel_size_W, int K, int flag_copy, float distance, int stride_h, int stride_w, const float *xyz1, const float *xyz2, const int *idx_n2, const int *random_hw, long *selected_b_idx, long * selected_h_idx, long * selected_w_idx, float *selected_mask, int small_h, int small_w)
15 | {
16 | // in this function, only select 3 closest points
17 | int batch_index = blockIdx.x;
18 | int index_thread = threadIdx.x;
19 | int stride_thread = blockDim.x;
20 |
21 | int kernel_total = kernel_size_H * kernel_size_W;
22 | int selected_W_idx = 0, selected_H_idx = 0;
23 |
24 | float dist_square = distance * distance;
25 |
26 | int kernel_half_H = kernel_size_H / 2;
27 | int kernel_half_W = kernel_size_W / 2;
28 |
29 | xyz1 += batch_index * H * W * 3;
30 | xyz2 += batch_index * small_h * small_w * 3;
31 | idx_n2 += batch_index * npoints * 2;
32 | selected_b_idx += batch_index * npoints * K * 1;
33 | selected_h_idx += batch_index * npoints * K * 1;
34 | selected_w_idx += batch_index * npoints * K * 1;
35 |
36 |
37 | // valid_idx += batch_index * npoints * kernel_total * 1 ; //(b, npoints, h*w, 1)
38 | // valid_in_dis_idx += batch_index * npoints * kernel_total * 1 ; //(b, npoints, h*w, 1)
39 |
40 | selected_mask += batch_index * npoints * K * 1 ; //(b, npoints, h*w, 1)
41 |
42 | for (int dense_i = index_thread; dense_i < npoints; dense_i += stride_thread) {
43 | // for each point in dense pl, search closest three points in sparse pl
44 | selected_H_idx = idx_n2[dense_i * 2];
45 | selected_W_idx = idx_n2[dense_i * 2 + 1];
46 | int central_idx = selected_H_idx * W * 3 + selected_W_idx * 3;
47 | float dense_x = xyz1[central_idx];
48 | float dense_y = xyz1[central_idx + 1];
49 | float dense_z = xyz1[central_idx + 2];
50 | // int num_valid_idx = 0;
51 | int num_select = 0;
52 |
53 | float dist_from_origin = dense_x * dense_x + dense_y * dense_y + dense_z * dense_z;
54 | if (dist_from_origin <= 1e-10) {
55 | continue;
56 | }
57 | float best1 = 1e30, best2 = 1e30, best3 = 1e30;
58 | special_idx besti[3];
59 | for (int i = 0; i < 3; ++i) {
60 | besti[i].idx_h = besti[i].idx_w = 0;
61 | }
62 |
63 | // once the central points in xyz1 are valid, begin to dispose xyz2 points
64 | for (int current_HW_idx = 0; current_HW_idx < kernel_total; ++current_HW_idx) {
65 | int kernel_hw_idx = random_hw[current_HW_idx];
66 | int kernel_select_h_idx = selected_H_idx / stride_h + kernel_hw_idx / kernel_size_W - kernel_half_H;
67 | int kernel_select_w_idx = selected_W_idx / stride_w + kernel_hw_idx % kernel_size_W - kernel_half_W;
68 |
69 | if ((kernel_select_h_idx < 0) || (kernel_select_w_idx < 0) || (kernel_select_h_idx >= small_h) || (kernel_select_w_idx >= small_w)) {
70 | continue;
71 | }
72 | int select_idx = kernel_select_h_idx * small_w * 3 + kernel_select_w_idx * 3;
73 | float sparse_x = xyz2[select_idx];
74 | float sparse_y = xyz2[select_idx + 1];
75 | float sparse_z = xyz2[select_idx + 2];
76 |
77 | float queried_dist_from_center = sparse_x * sparse_x + sparse_y * sparse_y + sparse_z * sparse_z;
78 | if (queried_dist_from_center <= 1e-10) {
79 | continue;
80 | // queried points are invalid points
81 | }
82 | float dist_from_query = (dense_x - sparse_x) * (dense_x - sparse_x) + (dense_y - sparse_y) * (dense_y - sparse_y) + (dense_z - sparse_z) * (dense_z - sparse_z);
83 |
84 | // if (num_select == 0 && flag_copy) {
85 | // // if at least one point is found, copy its information to all three points
86 | // best1 = best2 = best3 = dist_from_query;
87 | // special_idx temp;
88 | // temp.idx_h = kernel_select_h_idx;
89 | // temp.idx_w = kernel_select_w_idx;
90 | // besti[0] = besti[1] = besti[2] = temp;
91 | // }
92 |
93 |
94 | if (dist_from_query < 1e-10) {
95 | // the queried point coincides with the central point; record it as the closest match
96 | best3 = best2;
97 | besti[2] = besti[1];
98 | best2 = best1;
99 | besti[1] = besti[0];
100 | best1 = dist_from_query;
101 | besti[0].idx_h = kernel_select_h_idx;
102 | besti[0].idx_w = kernel_select_w_idx;
103 | continue;
104 | }
105 | if (dist_from_query > dist_square) {
106 | continue;
107 | }
108 |
109 | ++num_select;
110 | // given a central point, select the closest 3 points in a kernel
111 |
112 | if (dist_from_query < best1) {
113 | best3 = best2;
114 | besti[2] = besti[1];
115 | best2 = best1;
116 | besti[1] = besti[0];
117 | best1 = dist_from_query;
118 | besti[0].idx_h = kernel_select_h_idx;
119 | besti[0].idx_w = kernel_select_w_idx;
120 | } else if (dist_from_query < best2) {
121 | best3 = best2;
122 | besti[2] = besti[1];
123 | best2 = dist_from_query;
124 | besti[1].idx_h = kernel_select_h_idx;
125 | besti[1].idx_w = kernel_select_w_idx;
126 | } else if (dist_from_query < best3) {
127 | best3 = dist_from_query;
128 | besti[2].idx_h = kernel_select_h_idx;
129 | besti[2].idx_w = kernel_select_w_idx;
130 | }
131 | }
132 | // bool no_point_flag = (best1 >= 1e30 && best2 >= 1e30 && best3 >= 1e30);
133 |
134 | int max_points = num_select < K ? num_select : K;
135 | int temp;
136 | for (int k = 0; k < max_points; ++k) {
137 | temp = dense_i * K + k;
138 | selected_b_idx[temp] = batch_index;
139 | selected_h_idx[temp] = besti[k].idx_h;
140 | selected_w_idx[temp] = besti[k].idx_w;
141 | selected_mask[temp] = 1.0;
142 | }
143 | if (flag_copy) {
144 | // fill the remaining (K - max_points) slots by copying the closest item
145 | for (int k = max_points; k < K; ++k) {
146 | int temp = dense_i * K + k;
147 | selected_b_idx[temp] = batch_index;
148 | selected_h_idx[temp] = besti[0].idx_h;
149 | selected_w_idx[temp] = besti[0].idx_w;
150 | selected_mask[temp] = 1.0;
151 | }
152 | }
153 |
154 |
155 |
156 | // for (int k = 0; k < K; ++k) {
157 | // int temp = dense_i * K + k;
158 | // selected_b_idx[temp] = batch_index;
159 | // selected_h_idx[temp] = besti[k].idx_h;
160 | // selected_w_idx[temp] = besti[k].idx_w;
161 | // selected_mask[dense_i * K + k] = 1.0;
162 | // }
163 | }
164 | }
165 |
166 |
167 |
168 | void NoSortKnnLauncher(int batch_size, int H, int W, int npoints, int kernel_size_H, int kernel_size_W, int K, int flag_copy, float distance, int stride_h, int stride_w, const float *xyz1, const float *xyz2, const int *idx_n2, const int *random_hw, long *selected_b_idx, long *selected_h_idx, long *selected_w_idx, float *selected_mask, cudaStream_t stream)
169 | {
170 | int small_h = (H + stride_h - 1) / stride_h; // ceiling division; H / stride_h alone would truncate before ceil()
171 | int small_w = (W + stride_w - 1) / stride_w;
172 | cudaError_t err;
173 | no_sort_knn_gpu<<<batch_size, 256, 0, stream>>>(batch_size, H, W, npoints, kernel_size_H, kernel_size_W, K, flag_copy, distance, stride_h, stride_w, xyz1, xyz2, idx_n2, random_hw, selected_b_idx, selected_h_idx, selected_w_idx, selected_mask, small_h, small_w);
174 |
175 | err = cudaGetLastError();
176 | if (cudaSuccess != err) {
177 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
178 | exit(-1);
179 | }
180 | }
181 |
--------------------------------------------------------------------------------
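Sanity-check sketch (not part of the repository): the kernel above keeps the three smallest distances with a chain of compares instead of a full sort; the same logic in plain Python is handy for verifying the CUDA op against a reference.

    def top3_insert(dists):
        # mirrors the kernel's best1/best2/best3 insertion: shift down, then insert
        best = [float('inf')] * 3
        besti = [0] * 3
        for i, d in enumerate(dists):
            if d < best[0]:
                best[1:] = best[:2]
                besti[1:] = besti[:2]
                best[0], besti[0] = d, i
            elif d < best[1]:
                best[2], besti[2] = best[1], besti[1]
                best[1], besti[1] = d, i
            elif d < best[2]:
                best[2], besti[2] = d, i
        return besti

    assert top3_insert([5.0, 1.0, 4.0, 0.5, 2.0]) == [3, 1, 4]  # indices of the 3 smallest
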
/ops_pytorch/gpu_threenn_sample/no_sort_knn_gpu.h:
--------------------------------------------------------------------------------
1 | #ifndef _NO_SORT_KNN_GPU_H
2 | #define _NO_SORT_KNN_GPU_H
3 |
4 | #include <torch/extension.h>
5 | #include <ATen/cuda/CUDAContext.h>
6 | #include <cuda.h>
7 | #include <cuda_runtime.h>
8 | #include <vector>
9 |
10 | void torch_NoSortKnnLauncher(
11 | torch::Tensor xyz1_tensor,
12 | torch::Tensor xyz2_tensor,
13 | torch::Tensor idx_n2_tensor,
14 | torch::Tensor random_hw_tensor,
15 | int H,
16 | int W,
17 | int npoints,
18 | int kernel_size_H,
19 | int kernel_size_W,
20 | int K,
21 | bool flag_copy,
22 | float distance,
23 | int stride_h,
24 | int stride_w,
25 | torch::Tensor selected_b_idx_tensor,
26 | torch::Tensor selected_h_idx_tensor,
27 | torch::Tensor selected_w_idx_tensor,
28 | torch::Tensor selected_mask_tensor);
29 |
30 | void NoSortKnnLauncher(
31 | int batch_size, int H, int W, int npoints, int kernel_size_H, int kernel_size_W, int K, int flag_copy, float distance, int stride_h, int stride_w, const float *xyz1, const float *xyz2, const int *idx_n2, const int *random_hw, long *selected_b_idx, long *selected_h_idx, long *selected_w_idx, float *selected_mask, cudaStream_t stream);
32 |
33 |
34 |
35 | #endif
--------------------------------------------------------------------------------
/ops_pytorch/gpu_threenn_sample/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 |
4 | setup(
5 | name="no_sort_knn",
6 | ext_modules=[
7 | CUDAExtension(
8 | "no_sort_knn_cuda",
9 | ["no_sort_knn_g.cpp", "no_sort_knn_go.cu"],
10 | extra_compile_args={
11 | "cxx": ['-g'],
12 | "nvcc": ['-O2']
13 | })
14 |
15 | ],
16 | cmdclass={
17 | "build_ext": BuildExtension
18 | }
19 | )
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | import sys
4 | import random
5 | import logging
6 | import torch
7 | import torch.optim
8 | import torch.multiprocessing as mp
9 | import torch.backends.cudnn as cudnn
10 | from datetime import datetime
11 | from omegaconf import DictConfig, OmegaConf
12 | from torch.nn.parallel import DistributedDataParallel
13 | from torch.distributed import init_process_group
14 | from torch.utils.data.distributed import DistributedSampler
15 | from dataset.flyingthings_subset import FlyingThings3D
16 | from dataset.kitti import KITTI
17 | from models import CamLiPWC
18 | from torch.utils.tensorboard import SummaryWriter
19 | from torch.cuda.amp.grad_scaler import GradScaler
20 | from utils import copy_to_device, build_optim_and_sched, FastDataLoader, init_log
21 |
22 |
23 | class Trainer:
24 | def __init__(self, device: torch.device, cfgs: DictConfig):
25 | os.environ["MKL_NUM_THREADS"] = "1"
26 | os.environ["OPENBLAS_NUM_THREADS"] = "1"
27 | os.environ["NCCL_IB_DISABLE"] = "1"
28 | os.environ["NCCL_P2P_DISABLE"] = "1"
29 |
30 | self.cfgs = cfgs
31 | self.curr_epoch = 1
32 | self.device = device
33 | self.n_gpus = torch.cuda.device_count()
34 | self.is_main = device.index is None or device.index == 0
35 | self.cfgs.log.dir = (
36 | self.cfgs.log.dir
37 | + "/{}/".format(self.cfgs.model.name)
38 | + str(datetime.now().strftime("%Y-%m-%d_%H-%M"))
39 | )
40 | os.makedirs(self.cfgs.log.dir, exist_ok=True)
41 | init_log(os.path.join(self.cfgs.log.dir, "train.log"))
42 |
43 | if device.index is None:
44 | logging.info("No CUDA device detected, using CPU for training")
45 | else:
46 | logging.info(
47 | "Using GPU %d: %s" % (device.index, torch.cuda.get_device_name(device))
48 | )
49 | logging.info(
50 | "PID:{}".format(os.getpid())
51 | )
52 | if self.n_gpus > 1:
53 | init_process_group(
54 | "nccl",
55 | "tcp://localhost:%d" % self.cfgs.port,
56 | world_size=self.n_gpus,
57 | rank=self.device.index,
58 | )
59 | self.cfgs.model.batch_size = int(
60 | self.cfgs.model.batch_size / self.n_gpus
61 | )
62 | self.cfgs.trainset.n_workers = int(
63 | self.cfgs.trainset.n_workers / self.n_gpus
64 | )
65 | self.cfgs.valset.n_workers = int(
66 | self.cfgs.valset.n_workers / self.n_gpus
67 | )
68 |
69 | cudnn.benchmark = False
70 | torch.cuda.set_device(self.device)
71 |
72 | if self.is_main:
73 | logging.info("Logs will be saved to %s" % self.cfgs.log.dir)
74 | self.summary_writer = SummaryWriter(self.cfgs.log.dir)
75 | logging.info("Configurations:\n" + OmegaConf.to_yaml(self.cfgs))
76 | os.system("cp -r %s %s" % ("models", self.cfgs.log.dir))
77 | os.system("cp -r %s %s" % ("dataset", self.cfgs.log.dir))
78 | os.system("cp -r %s %s" % ("config", self.cfgs.log.dir))
79 | os.system("cp %s %s" % ("train.py", self.cfgs.log.dir))
80 | else:
81 | logging.root.disabled = True
82 |
83 | if self.cfgs.trainset.name == "flyingthings3d":
84 | self.train_dataset = FlyingThings3D(self.cfgs.trainset)
85 | self.val_dataset = FlyingThings3D(self.cfgs.valset)
86 | elif self.cfgs.trainset.name == "kitti":
87 | self.train_dataset = KITTI(self.cfgs.trainset)
88 | self.val_dataset = KITTI(self.cfgs.valset)
89 | else:
90 | raise NotImplementedError
91 |
92 | logging.info("Loading training set from %s" % self.cfgs.trainset.root_dir)
93 | self.train_sampler = (
94 | DistributedSampler(self.train_dataset) if self.n_gpus > 1 else None
95 | )
96 | logging.info("Loading validation set from %s" % self.cfgs.valset.root_dir)
97 | self.val_sampler = (
98 | DistributedSampler(self.val_dataset) if self.n_gpus > 1 else None
99 | )
100 |
101 | self.train_loader = FastDataLoader(
102 | dataset=self.train_dataset,
103 | batch_size=self.cfgs.model.batch_size,
104 | shuffle=(self.train_sampler is None),
105 | num_workers=self.cfgs.trainset.n_workers,
106 | pin_memory=True,
107 | sampler=self.train_sampler,
108 | drop_last=self.cfgs.trainset.drop_last,
109 | )
110 |
111 | self.val_loader = FastDataLoader(
112 | dataset=self.val_dataset,
113 | batch_size=self.cfgs.model.batch_size,
114 | shuffle=False,
115 | num_workers=self.cfgs.valset.n_workers,
116 | pin_memory=True,
117 | sampler=self.val_sampler,
118 | )
119 |
120 | logging.info("Creating model: %s" % self.cfgs.model.name)
121 |
122 | self.model = CamLiPWC(self.cfgs.model)
123 | self.model.to(device=self.device)
124 |
125 | n_params = sum([p.numel() for p in self.model.parameters() if p.requires_grad])
126 | logging.info("Trainable parameters: %d (%.1fM)" % (n_params, n_params / 1e6))
127 |
128 | if self.n_gpus > 1:
129 | if self.cfgs.sync_bn:
130 | self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
131 | self.ddp = DistributedDataParallel(self.model, [self.device.index])
132 | else:
133 | self.ddp = self.model
134 |
135 | self.best_metrics = None
136 | if self.cfgs.ckpt.path is not None:
137 | self.load_ckpt(self.cfgs.ckpt.path, resume=self.cfgs.ckpt.resume)
138 |
139 | logging.info("Creating optimizer: %s" % self.cfgs.training.opt)
140 | self.optimizer, self.scheduler = build_optim_and_sched(
141 | self.cfgs.training, self.model
142 | )
143 | self.scheduler.step(self.curr_epoch - 1)
144 |
145 | self.amp_scaler = GradScaler(enabled=self.cfgs.amp)
146 |
147 | def run(self):
148 | while self.curr_epoch <= self.cfgs.training.epochs:
149 | if self.train_sampler is not None:
150 | self.train_sampler.set_epoch(self.curr_epoch)
151 | if self.val_sampler is not None:
152 | self.val_sampler.set_epoch(self.curr_epoch)
153 |
154 | self.train_one_epoch()
155 |
156 | if self.curr_epoch % self.cfgs.val_interval == 0:
157 | self.validate()
158 |
159 | self.save_ckpt()
160 | self.scheduler.step(self.curr_epoch)
161 |
162 | self.curr_epoch += 1
163 |
164 | def train_one_epoch(self):
165 | logging.info("Start training...")
166 |
167 | self.ddp.train()
168 | self.model.clear_metrics()
169 | self.optimizer.zero_grad()
170 |
171 | lr = self.optimizer.param_groups[0]["lr"]
172 | self.save_scalar_summary({"learning_rate": lr}, prefix="train")
173 |
174 | logging.info("Epoch: [%d/%d]" % (self.curr_epoch, self.cfgs.training.epochs))
175 | logging.info("Current learning rate: %.8f" % lr)
176 |
177 | for i, inputs in enumerate(self.train_loader):
178 | inputs = copy_to_device(inputs, self.device)
179 |
180 | # forward
181 | with torch.cuda.amp.autocast(enabled=self.cfgs.amp):
182 | self.ddp(inputs)
183 | loss = self.model.get_loss()
184 |
185 | # backward
186 | self.amp_scaler.scale(loss).backward()
187 |
188 | # grad clip
189 | if "grad_max_norm" in self.cfgs.training.keys():
190 | self.amp_scaler.unscale_(self.optimizer)
191 | torch.nn.utils.clip_grad_norm_(
192 | parameters=self.model.parameters(),
193 | max_norm=self.cfgs.training.grad_max_norm,
194 | )
195 |
196 | # update
197 | self.amp_scaler.step(self.optimizer)
198 | self.amp_scaler.update()
199 | self.optimizer.zero_grad()
200 |
201 | metrics = self.model.get_metrics()
202 | self.save_scalar_summary(metrics, prefix="train")
203 |
204 | @torch.no_grad()
205 | def validate(self):
206 | logging.info("Start validating...")
207 |
208 | self.ddp.eval()
209 | self.model.clear_metrics()
210 |
211 | for i, inputs in enumerate(self.val_loader):
212 | inputs = copy_to_device(inputs, self.device)
213 |
214 | with torch.cuda.amp.autocast(enabled=False):
215 | self.ddp(inputs)
216 |
217 | metrics = self.model.get_metrics()
218 | self.save_scalar_summary(metrics, prefix="val")
219 |
220 | for k, v in metrics.items():
221 | logging.info("%s: %.4f" % (k, v))
222 |
223 | if self.model.is_better(metrics, self.best_metrics):
224 | self.best_metrics = metrics
225 | self.save_ckpt("best.pt")
226 |
227 | def save_scalar_summary(self, scalar_summary: dict, prefix):
228 | if self.is_main and self.cfgs.log.save_scalar_summary:
229 | for name in scalar_summary.keys():
230 | self.summary_writer.add_scalar(
231 | prefix + "/" + name, scalar_summary[name], self.curr_epoch
232 | )
233 |
234 | def save_ckpt(self, filename=None):
235 | if self.is_main and self.cfgs.log.save_ckpt:
236 | ckpt_dir = os.path.join(self.cfgs.log.dir, "ckpts")
237 | os.makedirs(ckpt_dir, exist_ok=True)
238 | # filepath = os.path.join(
239 | # ckpt_dir, filename or "epoch-%03d.pt" % self.curr_epoch
240 | # )
241 | filepath = os.path.join(
242 | ckpt_dir, filename or "epoch-latest.pt"
243 | )
244 | logging.info("Saving checkpoint to %s" % filepath)
245 | torch.save(
246 | {
247 | "last_epoch": self.curr_epoch,
248 | "state_dict": self.model.state_dict(),
249 | "best_metrics": self.best_metrics,
250 | },
251 | filepath,
252 | )
253 |
254 | def load_ckpt(self, filepath, resume=True):
255 | logging.info("Loading checkpoint from %s" % filepath)
256 | checkpoint = torch.load(filepath, self.device)
257 | if resume:
258 | self.curr_epoch = checkpoint["last_epoch"] + 1
259 | self.best_metrics = checkpoint["best_metrics"]
260 | logging.info("Current best metrics: %s" % str(self.best_metrics))
261 | # self.model.load_state_dict(checkpoint["state_dict"], strict=True)
262 | self.model.load_state_dict(checkpoint["state_dict"], strict=False)
263 |
264 |
265 | def create_trainer(device_id, cfgs):
266 | device = torch.device("cpu" if device_id is None else "cuda:%d" % device_id)
267 | trainer = Trainer(device, cfgs)
268 | trainer.run()
269 |
270 |
271 | def main(cfgs: DictConfig):
272 | # set num_workers of data loader
273 | if not cfgs.debug:
274 | n_devices = max(torch.cuda.device_count(), 1)
275 | cfgs.trainset.n_workers = min(
276 | os.cpu_count(), cfgs.trainset.n_workers * n_devices
277 | )
278 | cfgs.valset.n_workers = min(os.cpu_count(), cfgs.valset.n_workers * n_devices)
279 | else:
280 | cfgs.trainset.n_workers = 0
281 | cfgs.valset.n_workers = 0
282 |
283 | if cfgs.port == "random":
284 | cfgs.port = random.randint(10000, 20000)
285 |
286 | if cfgs.training.accum_iter > 1:
287 | cfgs.model.batch_size //= int(cfgs.training.accum_iter)
288 |
289 | # create trainers
290 | if torch.cuda.device_count() == 0: # CPU
291 | create_trainer(None, cfgs)
292 | elif torch.cuda.device_count() == 1: # Single GPU
293 | create_trainer(0, cfgs)
294 | elif torch.cuda.device_count() > 1: # Multiple GPUs
295 | mp.spawn(create_trainer, (cfgs,), torch.cuda.device_count())
296 |
297 |
298 | if __name__ == "__main__":
299 | path = sys.argv[1]
300 | with open(path, encoding="utf-8") as f:
301 | cfgs = DictConfig(yaml.load(f, Loader=yaml.FullLoader))
302 | main(cfgs)
303 |
--------------------------------------------------------------------------------
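Usage note: training is started as `python train.py <config>`; equivalently, the entry point can be driven programmatically (a sketch mirroring the __main__ block above, using one of the shipped configs).

    import yaml
    from omegaconf import DictConfig
    from train import main

    with open('config/config_things.yaml', encoding='utf-8') as f:
        cfgs = DictConfig(yaml.load(f, Loader=yaml.FullLoader))
    main(cfgs)
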
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .average_meter import *
2 | from .build_utils import build_optim_and_sched
3 | from .train_utils import copy_to_device, copy_to_cuda
4 | from .log_utils import init_experiment_dir, init_logging, init_log
5 |
6 | import cv2
7 | import numpy as np
8 | import torch
9 |
10 | class _RepeatSampler(object):
11 | def __init__(self, sampler):
12 | self.sampler = sampler
13 |
14 | def __iter__(self):
15 | while True:
16 | yield from iter(self.sampler)
17 |
18 |
19 | class FastDataLoader(torch.utils.data.dataloader.DataLoader):
20 | def __init__(self, *args, **kwargs):
21 | super().__init__(*args, **kwargs)
22 | object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))  # never-exhausting sampler keeps worker processes alive across epochs
23 | self.iterator = super().__iter__()
24 |
25 | def __len__(self):
26 | return len(self.batch_sampler.sampler)
27 |
28 | def __iter__(self):
29 | for i in range(len(self)):
30 | yield next(self.iterator)
31 |
32 | def size_of_batch(inputs):
33 | if isinstance(inputs, list):
34 | return size_of_batch(inputs[0])
35 | elif isinstance(inputs, dict):
36 | return size_of_batch(list(inputs.values())[0])
37 | elif isinstance(inputs, torch.Tensor):
38 | return inputs.shape[0]
39 | else:
40 | raise TypeError('Unknown type: %s' % str(type(inputs)))
41 |
42 | def save_flow_png(filepath, flow, mask=None, scale=64.0):
43 | assert flow.shape[2] == 2
44 | assert np.abs(flow).max() < 32767.0 / scale
45 | flow = flow * scale
46 | flow = flow + 32768.0
47 |
48 | if mask is None:
49 | mask = np.ones_like(flow)[..., 0]
50 | else:
51 | mask = np.float32(mask > 0)
52 |
53 | flow_img = np.concatenate([
54 | mask[..., None],
55 | flow[..., 1:2],
56 | flow[..., 0:1]
57 | ], axis=-1).astype(np.uint16)
58 |
59 | cv2.imwrite(filepath, flow_img)
60 |
61 |
62 | def load_disp_png(filepath):
63 | array = cv2.imread(filepath, -1)
64 | valid_mask = array > 0
65 | disp = array.astype(np.float32) / 256.0
66 | disp[np.logical_not(valid_mask)] = -1.0
67 | return disp, valid_mask
68 |
69 |
70 | def save_disp_png(filepath, disp, mask=None):
71 | if mask is None:
72 | mask = disp > 0
73 | disp = np.uint16(disp * 256.0)
74 | disp[np.logical_not(mask)] = 0
75 | cv2.imwrite(filepath, disp)
76 |
77 | def disp2pc(disp, baseline, f, cx, cy, flow=None):
78 | h, w = disp.shape
79 | depth = baseline * f / (disp + 1e-5)
80 |
81 | xx = np.tile(np.arange(w, dtype=np.float32)[None, :], (h, 1))
82 | yy = np.tile(np.arange(h, dtype=np.float32)[:, None], (1, w))
83 |
84 | if flow is None:
85 | x = (xx - cx) * depth / f
86 | y = (yy - cy) * depth / f
87 | else:
88 | x = (xx - cx + flow[..., 0]) * depth / f
89 | y = (yy - cy + flow[..., 1]) * depth / f
90 |
91 | pc = np.concatenate([
92 | x[:, :, None],
93 | y[:, :, None],
94 | depth[:, :, None],
95 | ], axis=-1)
96 |
97 | return pc
98 |
99 | mesh_grid_cache = {}
100 | def mesh_grid(n, h, w, device, channel_first=True):
101 | global mesh_grid_cache
102 | str_id = '%d,%d,%d,%s,%s' % (n, h, w, device, channel_first)
103 | if str_id not in mesh_grid_cache:
104 | x_base = torch.arange(0, w, dtype=torch.float32, device=device)[None, None, :].expand(n, h, w)
105 | y_base = torch.arange(0, h, dtype=torch.float32, device=device)[None, None, :].expand(n, w, h) # NWH
106 | grid = torch.stack([x_base, y_base.transpose(1, 2)], 1) # B2HW
107 | if not channel_first:
108 | grid = grid.permute(0, 2, 3, 1) # BHW2
109 | mesh_grid_cache[str_id] = grid
110 | return mesh_grid_cache[str_id]
--------------------------------------------------------------------------------
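Illustration (not part of the repository): mesh_grid returns a cached B2HW coordinate grid, where channel 0 holds x (column) and channel 1 holds y (row) coordinates.

    import torch
    from utils import mesh_grid  # assumes the repository root is on sys.path

    grid = mesh_grid(n=2, h=4, w=5, device=torch.device('cpu'))
    assert grid.shape == (2, 2, 4, 5)
    assert grid[0, 0, 0, 3].item() == 3.0   # channel 0: x coordinate
    assert grid[0, 1, 2, 0].item() == 2.0   # channel 1: y coordinate
    assert mesh_grid(2, 4, 5, torch.device('cpu')) is grid  # served from mesh_grid_cache
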
/utils/average_meter.py:
--------------------------------------------------------------------------------
1 |
2 | class AverageMeter(object):
3 | """Computes and stores the average and current value"""
4 |
5 | def __init__(self):
6 | self.reset()
7 |
8 | def reset(self):
9 | self.val = 0
10 | self.avg = 0
11 | self.sum = 0
12 | self.count = 0
13 |
14 | def update(self, val, n=1):
15 | self.val = val
16 | self.sum += val * n
17 | self.count += n
18 | self.avg = self.sum / self.count
19 |
--------------------------------------------------------------------------------
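Usage (illustrative numbers): update takes the batch mean and the batch size, so avg stays a correct running mean over unevenly sized batches.

    meter = AverageMeter()
    meter.update(0.5, n=8)    # batch of 8, mean loss 0.5
    meter.update(0.3, n=4)    # batch of 4, mean loss 0.3
    print(meter.val)          # 0.3 (most recent value)
    print(meter.avg)          # (0.5*8 + 0.3*4) / 12 = 0.4333...
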
/utils/build_utils.py:
--------------------------------------------------------------------------------
1 | from torch.optim import Adam, AdamW
2 | from timm.scheduler import create_scheduler
3 |
4 | def build_optim_and_sched(cfgs, model):
5 | params_2d_decay = []
6 | params_3d_decay = []
7 |
8 | params_2d_no_decay = []
9 | params_3d_no_decay = []
10 |
11 | for name, param in model.named_parameters():
12 | if not param.requires_grad:
13 | continue # frozen weights
14 |
15 | if len(param.shape) == 1 or name.endswith(".bias"):
16 | if name.startswith('core.branch_3d'):
17 | params_3d_no_decay.append(param)
18 | else:
19 | params_2d_no_decay.append(param)
20 | else:
21 | if name.startswith('core.branch_3d'):
22 | params_3d_decay.append(param)
23 | else:
24 | params_2d_decay.append(param)
25 |
26 | lr = getattr(cfgs, 'lr', None)
27 | lr_2d = getattr(cfgs, 'lr_2d', lr)
28 | lr_3d = getattr(cfgs, 'lr_3d', lr)
29 |
30 | params = [
31 | {'params': params_2d_decay, 'weight_decay': cfgs.weight_decay, 'lr': lr_2d},
32 | {'params': params_3d_decay, 'weight_decay': cfgs.weight_decay, 'lr': lr_3d},
33 | {'params': params_2d_no_decay, 'weight_decay': 0, 'lr': lr_2d},
34 | {'params': params_3d_no_decay, 'weight_decay': 0, 'lr': lr_3d},
35 | ]
36 |
37 | if cfgs.opt == 'adam':
38 | optimizer = Adam(params)
39 | elif cfgs.opt == 'adamw':
40 | optimizer = AdamW(params)
41 | else:
42 | raise NotImplementedError
43 |
44 | scheduler = create_scheduler(cfgs, optimizer)[0]
45 |
46 | return optimizer, scheduler
47 |
--------------------------------------------------------------------------------
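Configuration sketch (illustrative, not part of the repository): build_optim_and_sched reads opt, lr/lr_2d/lr_3d and weight_decay itself and hands the whole object to timm's create_scheduler, so the remaining field names depend on the installed timm version; `model` is assumed to be an already constructed module.

    from types import SimpleNamespace

    cfgs = SimpleNamespace(
        opt='adamw',
        lr=1e-4, lr_2d=1e-4, lr_3d=2e-4,   # per-branch learning rates
        weight_decay=1e-6,
        sched='cosine', epochs=100,        # consumed by timm's create_scheduler
        min_lr=1e-6, warmup_lr=1e-6, warmup_epochs=0,
        cooldown_epochs=0, decay_rate=0.1)

    optimizer, scheduler = build_optim_and_sched(cfgs, model)
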
/utils/evaluation_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Evaluation metrics
3 | Borrowed from HPLFlowNet
4 | Date: May 2020
5 |
6 | @inproceedings{HPLFlowNet,
7 | title={HPLFlowNet: Hierarchical Permutohedral Lattice FlowNet for
8 | Scene Flow Estimation on Large-scale Point Clouds},
9 | author={Gu, Xiuye and Wang, Yijie and Wu, Chongruo and Lee, Yong Jae and Wang, Panqu},
10 | booktitle={Computer Vision and Pattern Recognition (CVPR), 2019 IEEE International Conference on},
11 | year={2019}
12 | }
13 | """
14 |
15 | import numpy as np
16 |
17 |
18 | def evaluate_3d(sf_pred, sf_gt):
19 | """
20 | sf_pred: (N, 3)
21 | sf_gt: (N, 3)
22 | """
23 | l2_norm = np.linalg.norm(sf_gt - sf_pred, axis=-1)
24 | EPE3D = l2_norm.mean()
25 |
26 | sf_norm = np.linalg.norm(sf_gt, axis=-1)
27 | relative_err = l2_norm / (sf_norm + 1e-4)
28 |
29 | acc3d_strict = (np.logical_or(l2_norm < 0.05, relative_err < 0.05)).astype(np.float32).mean()
30 | acc3d_relax = (np.logical_or(l2_norm < 0.1, relative_err < 0.1)).astype(np.float32).mean()
31 | outlier = (np.logical_or(l2_norm > 0.3, relative_err > 0.1)).astype(np.float32).mean()
32 |
33 | return EPE3D, acc3d_strict, acc3d_relax, outlier
34 |
35 |
36 | def evaluate_2d(flow_pred, flow_gt):
37 | """
38 | flow_pred: (N, 2)
39 | flow_gt: (N, 2)
40 | """
41 |
42 | epe2d = np.linalg.norm(flow_gt - flow_pred, axis=-1)
43 | epe2d_mean = epe2d.mean()
44 |
45 | flow_gt_norm = np.linalg.norm(flow_gt, axis=-1)
46 | relative_err = epe2d / (flow_gt_norm + 1e-5)
47 |
48 | acc2d = (np.logical_or(epe2d < 3., relative_err < 0.05)).astype(np.float32).mean()
49 |
50 | return epe2d_mean, acc2d
51 |
--------------------------------------------------------------------------------
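Quick check (illustrative): both functions take flat (N, d) arrays, so predictions must be masked and flattened first.

    import numpy as np
    from utils.evaluation_utils import evaluate_3d

    rng = np.random.default_rng(0)
    sf_gt = rng.normal(size=(1000, 3)).astype(np.float32)
    sf_pred = sf_gt + 0.02 * rng.normal(size=(1000, 3)).astype(np.float32)

    epe3d, acc_strict, acc_relax, outlier = evaluate_3d(sf_pred, sf_gt)
    print('EPE3D %.4f  AccS %.3f  AccR %.3f  Out %.3f'
          % (epe3d, acc_strict, acc_relax, outlier))
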
/utils/geometry.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import os.path as osp
4 |
5 |
6 | def get_batch_2d_flow(pc1, pc2, predicted_pc2, intrinsics):
7 |
8 | focallengths = intrinsics[0][0]
9 | cxs = intrinsics[0][2]
10 | cys = intrinsics[0][3]
11 | constx = 0
12 | consty = 0
13 | constz = 0
14 |
15 | px1, py1 = project_3d_to_2d(pc1, f=focallengths, cx=cxs, cy=cys,
16 | constx=constx, consty=consty, constz=constz)
17 | px2, py2 = project_3d_to_2d(predicted_pc2, f=focallengths, cx=cxs, cy=cys,
18 | constx=constx, consty=consty, constz=constz)
19 | px2_gt, py2_gt = project_3d_to_2d(pc2, f=focallengths, cx=cxs, cy=cys,
20 | constx=constx, consty=consty, constz=constz)
21 |
22 |
23 | flow_x = px2 - px1
24 | flow_y = py2 - py1
25 |
26 | flow_x_gt = px2_gt - px1
27 | flow_y_gt = py2_gt - py1
28 |
29 | flow_pred = np.concatenate((flow_x[..., None], flow_y[..., None]), axis=-1)
30 | flow_gt = np.concatenate((flow_x_gt[..., None], flow_y_gt[..., None]), axis=-1)
31 | return flow_pred, flow_gt
32 |
33 |
34 | def project_3d_to_2d(pc, f=1050., cx=479.5, cy=269.5, constx=0, consty=0, constz=0):
35 | x = (pc[..., 0] * f + cx * pc[..., 2] + constx) / (pc[..., 2] + constz + 1e-9)
36 | y = (pc[..., 1] * f + cy * pc[..., 2] + consty) / (pc[..., 2] + constz + 1e-9)
37 |
38 | return x, y
39 |
--------------------------------------------------------------------------------
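Worked example (illustrative): project_3d_to_2d is the pinhole projection u = (f*X + cx*Z)/Z, v = (f*Y + cy*Z)/Z with optional const* offsets.

    import numpy as np
    from utils.geometry import project_3d_to_2d

    pc = np.array([[1.0, 2.0, 10.0]])   # one point, shape (N, 3)
    px, py = project_3d_to_2d(pc, f=1050., cx=479.5, cy=269.5)
    # u = (1*1050 + 479.5*10) / 10 = 584.5 ; v = (2*1050 + 269.5*10) / 10 = 479.5
    assert np.allclose([px[0], py[0]], [584.5, 479.5])
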
/utils/log_utils.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import datetime
3 | import os
4 | import logging
5 | import sys
6 |
7 | def init_experiment_dir(cfg):
8 | experiment_dir = Path(cfg.log.dir)
9 | experiment_dir.mkdir(exist_ok=True)
10 | file_dir = Path(
11 | str(experiment_dir)
12 | + "/Flyingthings3d-"
13 | + str(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M"))
14 | )
15 | file_dir.mkdir(exist_ok=True)
16 | checkpoints_dir = file_dir.joinpath(cfg.ckpt.save_path)
17 | checkpoints_dir.mkdir(exist_ok=True)
18 | log_dir = file_dir.joinpath("logs/")
19 | log_dir.mkdir(exist_ok=True)
20 |
21 | os.system("cp -r %s %s" % ("models", log_dir))
22 | os.system("cp -r %s %s" % ("dataset", log_dir))
23 | os.system("cp -r %s %s" % ("config", log_dir))
24 | os.system("cp %s %s" % ("train.py", log_dir))
25 |
26 | return checkpoints_dir, log_dir
27 |
28 | def init_logging(log_dir, cfgs):
29 | logger = logging.getLogger(cfgs.model.name)
30 | formatter = logging.Formatter("[%(asctime)s][%(levelname)s] - %(message)s")
31 | file_handler = logging.FileHandler(str(log_dir) + "/log_%s.log" % cfgs.model.name)
32 | file_handler.setFormatter(formatter)
33 | stream_handler = logging.StreamHandler()
34 | stream_handler.setFormatter(formatter)
35 | logger.setLevel(logging.INFO)
36 | logger.addHandler(file_handler)
37 | logger.addHandler(stream_handler)
38 | return logger
39 |
40 | def init_log(filename=None, debug=False):
41 | logging.root = logging.RootLogger('DEBUG' if debug else 'INFO')
42 | formatter = logging.Formatter('[%(asctime)s][%(levelname)s] - %(message)s')
43 |
44 | stream_handler = logging.StreamHandler(sys.stdout)
45 | stream_handler.setFormatter(formatter)
46 | logging.root.addHandler(stream_handler)
47 |
48 | if filename is not None:
49 | file_handler = logging.FileHandler(filename)
50 | file_handler.setFormatter(formatter)
51 | logging.root.addHandler(file_handler)
--------------------------------------------------------------------------------
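Usage sketch: init_log rebinds the root logger, so plain logging calls anywhere in the codebase reach stdout and, if given, the log file.

    import logging
    from utils.log_utils import init_log

    init_log('train.log')              # or init_log() for console-only
    logging.info('training started')   # goes to stdout and train.log
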
/utils/train_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import re
3 | import os
4 | import cv2
5 | import sys
6 | import smtplib
7 | import logging
8 | import numpy as np
9 |
10 | def copy_to_device(inputs, device, non_blocking=True):
11 |
12 | if isinstance(inputs, list):
13 | inputs = [copy_to_device(item, device, non_blocking) for item in inputs]
14 | elif isinstance(inputs, dict):
15 | inputs = {k: copy_to_device(v, device, non_blocking) for k, v in inputs.items()}
16 | elif isinstance(inputs, torch.Tensor):
17 | inputs = inputs.to(device=device, non_blocking=non_blocking)
18 | else:
19 | raise TypeError("Unknown type: %s" % str(type(inputs)))
20 | return inputs
21 |
22 | def copy_to_cuda(inputs, non_blocking=True):
23 |
24 | assert isinstance(inputs, dict), "inputs not dict"
25 |
26 | inputs = {k: v.cuda(non_blocking) for k, v in inputs.items()}
27 |
28 | return inputs
29 |
30 |
--------------------------------------------------------------------------------
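Usage sketch (illustrative shapes): copy_to_device recurses through nested lists and dicts of tensors.

    import torch
    from utils.train_utils import copy_to_device

    batch = {'image': torch.zeros(2, 3, 8, 8), 'pcs': [torch.zeros(2, 4096, 3)]}
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    batch = copy_to_device(batch, device)
    assert batch['pcs'][0].device.type == device.type
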