├── LICENSE
├── README.md
├── asset
│   ├── github-demo-2024_0517.jpg
│   └── overview_24_0127.png
├── data
│   ├── dataset.py
│   └── demo
│       ├── image
│       │   ├── 000009_09.png
│       │   ├── 000009_10.png
│       │   └── 000009_11.png
│       └── pred
│           └── flow_000009_10_for_check.png
├── model
│   ├── __pycache__
│   │   ├── attention.cpython-38.pyc
│   │   ├── corr.cpython-38.pyc
│   │   ├── extractor.cpython-38.pyc
│   │   ├── model_splatflow.cpython-38.pyc
│   │   ├── softsplat.cpython-38.pyc
│   │   └── update.cpython-38.pyc
│   ├── attention.py
│   ├── corr.py
│   ├── extractor.py
│   ├── model_splatflow.py
│   ├── softsplat.py
│   ├── update.py
│   └── util
│       ├── __pycache__
│       │   ├── augmentor.cpython-38.pyc
│       │   └── util.cpython-38.pyc
│       ├── augmentor.py
│       └── util.py
├── run_demo.py
├── run_test.py
├── run_train.py
└── script
    ├── demo.sh
    ├── test_kitti.sh
    ├── test_things.sh
    ├── train_kitti.sh
    ├── train_sintel.sh
    └── train_things.sh

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024, Bo Wang
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [//]: # (@File: README.md)
2 | 
3 | [//]: # (@Project: SplatFlow)
4 | 
5 | [//]: # (@Author : wangbo)
6 | 
7 | [//]: # (@Time : 2024.07.12)
8 | 
9 | # SplatFlow: Learning Multi-frame Optical Flow via Splatting
10 | This repository contains the source code for our paper:
11 | - SplatFlow: Learning Multi-frame Optical Flow via Splatting (IJCV 2024) | [Paper](https://arxiv.org/pdf/2306.08887.pdf)
12 | ![](./asset/github-demo-2024_0517.jpg)
13 | - [x] We propose SplatFlow, a novel multi-frame optical flow estimation (MOFE) framework designed explicitly for single-resolution iterative two-frame backbones.
14 | - [x] Compared with the original backbone, SplatFlow achieves significantly higher estimation accuracy, especially in occluded regions, while maintaining a high inference speed.
15 | - [x] At the time of submission, SplatFlow achieved state-of-the-art results on both the [Sintel](http://sintel.is.tue.mpg.de/quant?metric_id=0&selected_pass=1) and [KITTI2015](https://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) benchmarks, including a significant 19.4% error reduction over the previous best result submitted on the Sintel benchmark.
16 | 
17 | ## Updates
18 | 
19 | - [2024.04.24] 📣 The code of SplatFlow is now available!
20 | - [2024.01.02] 📣 The SplatFlow paper has been accepted by IJCV 2024!
21 | 
22 | ## Environment
23 | 
24 | Our code has been successfully tested in the following environment:
25 | 
26 | * NVIDIA 3090 GPU
27 | * CUDA 11.1
28 | * Python 3.8
29 | * PyTorch 1.8.2
30 | ```
31 | conda create -n splatflow python=3.8
32 | conda activate splatflow
33 | 
34 | pip install torch==1.8.2 torchvision==0.9.2 --extra-index-url https://download.pytorch.org/whl/lts/1.8/cu111
35 | pip install einops==0.4.1
36 | pip install cupy-cuda111
37 | pip install pillow==9.5.0
38 | pip install opencv-python==4.1.2.30
39 | ```
40 | 
41 | ## Trained Weights
42 | 
43 | Download the weights below and place them in the `exp/0-pretrain` directory.
44 | 
45 | | Model | Training process | Weights | Comments |
46 | |-----------|------------------|-------------------------------------------------------------------------------------|---------------------------|
47 | | SplatFlow | K-finetune | splatflow_kitti_50k.pth
[Huggingface](https://huggingface.co/wwcreator/SplatFlow) & [BaiduNetdisk](https://pan.baidu.com/s/1JDPCexLqlj-ULLt1TPsGxA&pwd=wang) | Best performance on KITTI | 48 | 49 | 50 | ## Demo 51 | 52 | * Quick start. 53 | ```Shell 54 | bash script/demo.sh 55 | ``` 56 | 57 | ## Datasets 58 | 59 | To train / test SplatFlow, you will need to download the required datasets and update `data_root` in `data/dataset.py`. 60 | 61 | * [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) 62 | * [Sintel](http://sintel.is.tue.mpg.de/) 63 | * [KITTI2015](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) 64 | * [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/) 65 | 66 | ```text 67 | data_root/ 68 | │ 69 | ├─ FlyingThings3D/ 70 | │ ├─ frames_cleanpass/ 71 | │ ├─ frames_finalpass/ 72 | │ └─ optical_flow/ 73 | │ 74 | ├─ Sintel/ 75 | │ ├─ training/ 76 | │ └─ test/ 77 | │ 78 | ├─ KITTI/ 79 | │ ├─ training/ 80 | │ └─ testing/ 81 | │ 82 | ├─ HD1k/ 83 | │ ├─ hd1k_input/ 84 | │ └─ hd1k_flow_gt/ 85 | │ 86 | └─ demo/ 87 | ├─ image/ 88 | └─ pred/ 89 | ``` 90 | 91 | ## Training 92 | 93 | * Train SplatFlow under the C+T training process. 94 | ```Shell 95 | bash script/train_things.sh 96 | ``` 97 | 98 | * Train SplatFlow under the S-finetune training process. 99 | ```bash 100 | bash script/train_sintel.sh 101 | ``` 102 | 103 | * Train SplatFlow under the K-finetune training process. 104 | ```bash 105 | bash script/train_kitti.sh 106 | ``` 107 | 108 | ## Testing 109 | 110 | * Test SplatFlow on Things. 111 | ```Shell 112 | bash script/test_things.sh 113 | ``` 114 | 115 | * Test SplatFlow on KITTI. 116 | ```Shell 117 | bash script/test_kitti.sh 118 | ``` 119 | 120 | ## Acknowledgments 121 | We would like to thank [RAFT](https://github.com/princeton-vl/RAFT), [GMA](https://github.com/zacjiang/GMA) and [SoftSplat](https://github.com/JHLew/SoftSplat-Full) for publicly releasing their code and data. 
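## Minimal Python Inference (Sketch)

For reference, here is a minimal, unofficial sketch of loading the released checkpoint and running the three-frame demo directly in Python; `script/demo.sh` / `run_demo.py` remain the supported entry points. The sketch assumes a CUDA device, that the `.pth` file is a plain `state_dict` (possibly carrying a `module.` prefix left over from DDP training), and that the input height and width are multiples of 8.

```python
import cv2
import torch

from model.model_splatflow import SplatFlow

def load_image(path, device):
    img = cv2.imread(path)[..., ::-1].copy()              # BGR -> RGB
    ten = torch.from_numpy(img).permute(2, 0, 1).float()  # (3, H, W), values in [0, 255]
    return ten[None].to(device)                           # the model normalizes internally

device = 'cuda'
model = SplatFlow().to(device).eval()

# Assumption: the checkpoint is a state_dict; strip a possible DDP 'module.' prefix.
ckpt = torch.load('exp/0-pretrain/splatflow_kitti_50k.pth', map_location='cpu')
model.load_state_dict({k.replace('module.', ''): v for k, v in ckpt.items()})

img0 = load_image('data/demo/image/000009_09.png', device)
img1 = load_image('data/demo/image/000009_10.png', device)
img2 = load_image('data/demo/image/000009_11.png', device)
# NOTE: H and W must be multiples of 8 for the convex upsampler; pad beforehand
# if necessary (e.g., with a RAFT-style InputPadder).

with torch.no_grad():
    flow_predictions = model.infer(model, [img0, img1, img2], iters=12)
flow = flow_predictions[-1][0]  # (2, H, W) flow from frame 10 to frame 11
```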
122 | 
123 | ## Citing this Work
124 | 
125 | If you find our repository useful, please consider giving it a star ⭐ and citing our paper in your work:
126 | 
127 | ```bibtex
128 | @article{wang2024splatflow,
129 | title={SplatFlow: Learning multi-frame optical flow via splatting},
130 | author={Wang, Bo and Zhang, Yifan and Li, Jian and Yu, Yang and Sun, Zhenping and Liu, Li and Hu, Dewen},
131 | journal={International Journal of Computer Vision},
132 | volume={132},
133 | number={8},
134 | pages={3023--3045},
135 | year={2024},
136 | publisher={Springer}
137 | }
138 | ```
139 | 
--------------------------------------------------------------------------------
/asset/github-demo-2024_0517.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wwsource/SplatFlow/41c53ffa5a7b665bf23ef0fa1afc3fb69837b79d/asset/github-demo-2024_0517.jpg
--------------------------------------------------------------------------------
/asset/overview_24_0127.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wwsource/SplatFlow/41c53ffa5a7b665bf23ef0fa1afc3fb69837b79d/asset/overview_24_0127.png
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
1 | # @File: dataset.py
2 | # @Project: SplatFlow
3 | # @Author : wangbo
4 | # @Time : 2024.07.03
5 | import os
6 | import os.path as osp
7 | from glob import glob
8 | import numpy as np
9 | import torch
10 | from torch.utils.data import Dataset
11 | from torch.utils.data import DataLoader
12 | from model.util import util
13 | from model.util.augmentor import FlowAugmentor, SparseFlowAugmentor
14 | 
15 | data_root = '/data1/wangbo/data/'
16 | things_root = data_root + 'FlyingThings3D'
17 | sintel_root = data_root + 'Sintel'
18 | kitti_root = data_root + 'KITTI'
19 | hd1k_root = data_root + 'HD1k'
20 | 
21 | class FlowDataset(Dataset):
22 | def __init__(self, aug_params=None, sparse=False):
23 | 
24 | self.sparse = sparse
25 | self.aug_params = None
26 | self.augmentor = None
27 | if aug_params is not None:
28 | if sparse:
29 | self.augmentor = SparseFlowAugmentor(**aug_params)
30 | else:
31 | self.augmentor = FlowAugmentor(**aug_params)
32 | self.image_list = []
33 | self.flow_list = []
34 | self.occ_list = []
35 | 
36 | def __getitem__(self, index):
37 | 
38 | index = index % len(self.image_list)
39 | 
40 | img1_dir = self.image_list[index][0]
41 | img2_dir = self.image_list[index][1]
42 | img3_dir = self.image_list[index][2]
43 | 
44 | img1 = util.read_gen(img1_dir)
45 | img2 = util.read_gen(img2_dir)
46 | img3 = util.read_gen(img3_dir)
47 | img1 = np.array(img1).astype(np.uint8)
48 | img2 = np.array(img2).astype(np.uint8)
49 | img3 = np.array(img3).astype(np.uint8)
50 | 
51 | flow1_dir = self.flow_list[index][0]
52 | flow2_dir = self.flow_list[index][1]
53 | 
54 | if flow2_dir is not None:
55 | need_init = 0
56 | valid1 = None
57 | valid2 = None
58 | if self.sparse:
59 | flow1, valid1 = flow2, valid2 = util.readFlowKITTI(flow2_dir)
60 | if flow1_dir is not None:
61 | flow1, valid1 = util.readFlowKITTI(flow1_dir)
62 | need_init = 1
63 | else:
64 | flow1 = flow2 = util.read_gen(flow2_dir)
65 | if flow1_dir is not None:
66 | flow1 = util.read_gen(flow1_dir)
67 | need_init = 1
68 | 
69 | flow1 = np.array(flow1).astype(np.float32)
70 | flow2 = np.array(flow2).astype(np.float32)
71 | 
72 | if len(img2.shape) == 2:
73 | img1 = np.tile(img1[...,
None], (1, 1, 3)) 74 | img2 = np.tile(img2[..., None], (1, 1, 3)) 75 | img3 = np.tile(img3[..., None], (1, 1, 3)) 76 | else: 77 | img1 = img1[..., :3] 78 | img2 = img2[..., :3] 79 | img3 = img3[..., :3] 80 | 81 | if self.augmentor is not None: 82 | if self.sparse: 83 | img1, img2, img3, flow1, valid1, flow2, valid2 = self.augmentor(img1, img2, img3, flow1, valid1, 84 | flow2, valid2) 85 | else: 86 | img1, img2, img3, flow1, flow2 = self.augmentor(img1, img2, img3, flow1, flow2) 87 | 88 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 89 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 90 | img3 = torch.from_numpy(img3).permute(2, 0, 1).float() 91 | flow1 = torch.from_numpy(flow1).permute(2, 0, 1).float() 92 | flow2 = torch.from_numpy(flow2).permute(2, 0, 1).float() 93 | 94 | if valid1 is not None: 95 | valid1 = torch.from_numpy(valid1) 96 | else: 97 | valid1 = (flow1[0].abs() < 1000) & (flow1[1].abs() < 1000) 98 | 99 | if need_init == 0: 100 | valid1 = torch.zeros_like(valid1).bool() 101 | 102 | if valid2 is not None: 103 | valid2 = torch.from_numpy(valid2) 104 | else: 105 | valid2 = (flow2[0].abs() < 1000) & (flow2[1].abs() < 1000) 106 | 107 | valid1 = valid1.float() 108 | valid2 = valid2.float() 109 | 110 | return img1, img2, img3, flow1, valid1, flow2, valid2 111 | 112 | def __rmul__(self, v): 113 | self.image_list = v * self.image_list 114 | self.flow_list = v * self.flow_list 115 | return self 116 | 117 | def __len__(self): 118 | return len(self.image_list) 119 | 120 | class Things(FlowDataset): 121 | def __init__(self, aug_params=None, root=things_root, split='train', ptype='clean'): 122 | super(Things, self).__init__(aug_params) 123 | 124 | split_dir = {'train': 'TRAIN', 'val': 'TEST'}[split] 125 | ptype_dir = {'clean': 'frames_cleanpass', 'final': 'frames_finalpass'}[ptype] 126 | 127 | for cam in ['left']: 128 | for direction in ['into_future', 'into_past']: 129 | image_paths = sorted(glob(osp.join(root, ptype_dir, split_dir, '*/*'))) 130 | image_paths = sorted([osp.join(f, cam) for f in image_paths]) 131 | 132 | flow_paths = sorted(glob(osp.join(root, 'optical_flow', split_dir, '*/*'))) 133 | flow_paths = sorted([osp.join(f, direction, cam) for f in flow_paths]) 134 | 135 | for ipath, fpath in zip(image_paths, flow_paths): 136 | images = sorted(glob(osp.join(ipath, '*.png'))) 137 | flows = sorted(glob(osp.join(fpath, '*.pfm'))) 138 | 139 | images = [img.replace('\\', '/') for img in images] 140 | flows = [flow.replace('\\', '/') for flow in flows] 141 | 142 | if split == 'train': 143 | for i in range(len(flows) - 2): 144 | if direction == 'into_future': 145 | self.image_list += [[images[i], images[i + 1], images[i + 2]]] 146 | self.flow_list += [[flows[i], flows[i + 1]]] 147 | 148 | elif direction == 'into_past': 149 | self.image_list += [[images[i + 2], images[i + 1], images[i]]] 150 | self.flow_list += [[flows[i + 2], flows[i + 1]]] 151 | 152 | elif split == 'val': 153 | if direction == 'into_future': 154 | self.image_list += [[images[3], images[4], images[5]]] 155 | self.flow_list += [[None, flows[4]]] 156 | 157 | elif direction == 'into_past': 158 | self.image_list += [[images[6], images[5], images[4]]] 159 | self.flow_list += [[None, flows[5]]] 160 | 161 | class Sintel(FlowDataset): 162 | def __init__(self, aug_params=None, root=sintel_root, split='train', ptype='clean'): 163 | super(Sintel, self).__init__(aug_params) 164 | 165 | split_dir = {'train': 'training', 'test': 'test'}[split] 166 | ptype_dir = {'clean': 'clean', 'final': 'final'}[ptype] 167 | 168 | 
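# Assumed Sintel layout: <root>/<training|test>/<clean|final>/<scene>/frame_xxxx.png,
# with ground-truth .flo files under <root>/training/flow/<scene>/.
# Each sample is a frame triplet (i, i+1, i+2) paired with the flows i->i+1 and i+1->i+2.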
image_root = osp.join(root, split_dir, ptype_dir)
169 | flow_root = osp.join(root, split_dir, 'flow')
170 | 
171 | for scene in os.listdir(image_root):
172 | image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
173 | image_list = [img.replace('\\', '/') for img in image_list]
174 | 
175 | for i in range(len(image_list) - 2):
176 | self.image_list += [[image_list[i], image_list[i + 1], image_list[i + 2]]]
177 | if split == 'train':
178 | flow_list = sorted(glob(osp.join(flow_root, scene, '*.flo')))
179 | flow_list = [flow.replace('\\', '/') for flow in flow_list]
180 | self.flow_list += [[f1, f2] for f1, f2 in zip(flow_list[:-1], flow_list[1:])]
181 | 
182 | elif split == 'test':
183 | self.flow_list += [[None, None] for _ in range(len(image_list) - 2)]
184 | 
185 | class KITTI(FlowDataset):
186 | def __init__(self, aug_params=None, root=kitti_root, split='train'):
187 | super(KITTI, self).__init__(aug_params, sparse=True)
188 | 
189 | split_dir = {'train': 'training', 'test': 'testing'}[split]
190 | 
191 | root = osp.join(root, split_dir)
192 | 
193 | imgs1 = sorted(glob(osp.join(root, 'image_2_multiview/*_09.png')))
194 | imgs2 = sorted(glob(osp.join(root, 'image_2_multiview/*_10.png')))
195 | imgs3 = sorted(glob(osp.join(root, 'image_2_multiview/*_11.png')))
196 | 
197 | imgs1 = [img.replace('\\', '/') for img in imgs1]
198 | imgs2 = [img.replace('\\', '/') for img in imgs2]
199 | imgs3 = [img.replace('\\', '/') for img in imgs3]
200 | 
201 | for img1, img2, img3 in zip(imgs1, imgs2, imgs3):
202 | self.image_list += [[img1, img2, img3]]
203 | 
204 | if split == 'train':
205 | flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
206 | flow_list = [flow.replace('\\', '/') for flow in flow_list]
207 | self.flow_list += [[None, flow] for flow in flow_list]
208 | elif split == 'test':
209 | self.flow_list += [[None, None] for _ in range(len(self.image_list))]
210 | 
211 | class HD1K(FlowDataset):
212 | def __init__(self, aug_params=None, root=hd1k_root):
213 | super(HD1K, self).__init__(aug_params, sparse=True)
214 | 
215 | seq_ix = 0
216 | while True:
217 | flows = sorted(glob(osp.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
218 | images = sorted(glob(osp.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
219 | 
220 | if len(flows) == 0:
221 | break
222 | 
223 | for i in range(len(flows) - 2):
224 | self.image_list += [[images[i], images[i + 1], images[i + 2]]]
225 | self.flow_list += [[flows[i], flows[i + 1]]]
226 | 
227 | seq_ix += 1
228 | 
229 | from contextlib import contextmanager
230 | @contextmanager
231 | def torch_distributed_zero_first(local_rank: int):
232 | """
233 | Context manager that makes all processes in distributed training wait for the local master to do something first.
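Rank 0 (or rank -1 when not distributed) runs the guarded block first while the
other ranks wait at a barrier; once rank 0 finishes, it enters the barrier itself,
releasing the others to run the same block (typically hitting a cache prepared by rank 0).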
234 | """
235 | if local_rank not in [-1, 0]:
236 | torch.distributed.barrier()
237 | yield
238 | if local_rank == 0:
239 | torch.distributed.barrier()
240 | 
241 | def fetch_dataloader(config):
242 | def prepare_data(config):
243 | 
244 | if config.stage == 'things':
245 | aug_params = {'crop_size': config.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
246 | things_clean = Things(aug_params, split='train', ptype='clean')
247 | things_final = Things(aug_params, split='train', ptype='final')
248 | dataset = things_clean + things_final
249 | 
250 | elif config.stage == 'sintel':
251 | aug_params = {'crop_size': config.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
252 | things_clean = Things(aug_params, split='train', ptype='clean')
253 | sintel_clean = Sintel(aug_params, split='train', ptype='clean')
254 | sintel_final = Sintel(aug_params, split='train', ptype='final')
255 | aug_params = {'crop_size': config.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True}
256 | kitti = KITTI(aug_params, split='train')
257 | aug_params = {'crop_size': config.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True}
258 | hd1k = HD1K(aug_params)
259 | dataset = 100 * sintel_clean + 100 * sintel_final + 200 * kitti + 5 * hd1k + 1 * things_clean
260 | 
261 | elif config.stage == 'kitti':
262 | aug_params = {'crop_size': config.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
263 | things_clean = Things(aug_params, split='train', ptype='clean')
264 | sintel_clean = Sintel(aug_params, split='train', ptype='clean')
265 | aug_params = {'crop_size': config.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': False}
266 | kitti = KITTI(aug_params, split='train')
267 | aug_params = {'crop_size': config.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': False}
268 | hd1k = HD1K(aug_params)
269 | dataset = 20 * sintel_clean + 420 * kitti + 10 * hd1k + things_clean
270 | 
271 | return dataset
272 | 
273 | if config.is_ddp:
274 | with torch_distributed_zero_first(config.rank):
275 | dataset = prepare_data(config)
276 | else:
277 | dataset = prepare_data(config)
278 | 
279 | batch_size_tmp = config.batch_size // config.world_size
280 | 
281 | dataloader = DataLoader(dataset,
282 | batch_size=batch_size_tmp,
283 | pin_memory=True,
284 | sampler=torch.utils.data.distributed.DistributedSampler(dataset) if config.is_ddp else None,
285 | num_workers=8 if config.is_ddp else 0,
286 | drop_last=True)
287 | 
288 | if config.is_master:
289 | print('Training with %d image triplets' % len(dataset))
290 | 
291 | return dataloader
292 | 
293 | 
--------------------------------------------------------------------------------
/data/demo/image/000009_09.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wwsource/SplatFlow/41c53ffa5a7b665bf23ef0fa1afc3fb69837b79d/data/demo/image/000009_09.png
--------------------------------------------------------------------------------
/data/demo/image/000009_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wwsource/SplatFlow/41c53ffa5a7b665bf23ef0fa1afc3fb69837b79d/data/demo/image/000009_10.png
--------------------------------------------------------------------------------
/data/demo/image/000009_11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wwsource/SplatFlow/41c53ffa5a7b665bf23ef0fa1afc3fb69837b79d/data/demo/image/000009_11.png -------------------------------------------------------------------------------- /data/demo/pred/flow_000009_10_for_check.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wwsource/SplatFlow/41c53ffa5a7b665bf23ef0fa1afc3fb69837b79d/data/demo/pred/flow_000009_10_for_check.png -------------------------------------------------------------------------------- /model/__pycache__/attention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wwsource/SplatFlow/41c53ffa5a7b665bf23ef0fa1afc3fb69837b79d/model/__pycache__/attention.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/corr.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wwsource/SplatFlow/41c53ffa5a7b665bf23ef0fa1afc3fb69837b79d/model/__pycache__/corr.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/extractor.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wwsource/SplatFlow/41c53ffa5a7b665bf23ef0fa1afc3fb69837b79d/model/__pycache__/extractor.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/model_splatflow.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wwsource/SplatFlow/41c53ffa5a7b665bf23ef0fa1afc3fb69837b79d/model/__pycache__/model_splatflow.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/softsplat.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wwsource/SplatFlow/41c53ffa5a7b665bf23ef0fa1afc3fb69837b79d/model/__pycache__/softsplat.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/update.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wwsource/SplatFlow/41c53ffa5a7b665bf23ef0fa1afc3fb69837b79d/model/__pycache__/update.cpython-38.pyc -------------------------------------------------------------------------------- /model/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | from einops import rearrange 4 | 5 | class Attention(nn.Module): 6 | def __init__( 7 | self, 8 | *, 9 | dim, 10 | heads = 4, 11 | dim_head = 128, 12 | ): 13 | super().__init__() 14 | self.heads = heads 15 | self.scale = dim_head ** -0.5 16 | inner_dim = heads * dim_head 17 | 18 | self.to_qk = nn.Conv2d(dim, inner_dim * 2, 1, bias=False) 19 | 20 | def forward(self, fmap): 21 | heads, b, c, h, w = self.heads, *fmap.shape 22 | 23 | q, k = self.to_qk(fmap).chunk(2, dim=1) 24 | 25 | q, k = map(lambda t: rearrange(t, 'b (h d) x y -> b h x y d', h=heads), (q, k)) 26 | q = self.scale * q 27 | 28 | sim = einsum('b h x y d, b h u v d -> b h x y u v', q, k) 29 | 30 | sim = rearrange(sim, 'b h x y u v -> b h (x y) (u v)') 31 | attn = sim.softmax(dim=-1) 32 | 33 | return attn 34 | 35 | 
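# Aggregate (below) consumes the (b, heads, H*W, H*W) attention map produced by
# Attention above: values projected from `fmap` are mixed according to the attention
# weights and added back to `fmap` through the zero-initialized learnable gate
# `gamma`, so the aggregation starts out as an identity mapping.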
class Aggregate(nn.Module): 36 | def __init__( 37 | self, 38 | dim, 39 | heads = 4, 40 | dim_head = 128, 41 | ): 42 | super().__init__() 43 | self.heads = heads 44 | self.scale = dim_head ** -0.5 45 | inner_dim = heads * dim_head 46 | 47 | self.to_v = nn.Conv2d(dim, inner_dim, 1, bias=False) 48 | 49 | self.gamma = nn.Parameter(torch.zeros(1)) 50 | 51 | if dim != inner_dim: 52 | self.project = nn.Conv2d(inner_dim, dim, 1, bias=False) 53 | else: 54 | self.project = None 55 | 56 | def forward(self, attn, fmap): 57 | heads, b, c, h, w = self.heads, *fmap.shape 58 | 59 | v = self.to_v(fmap) 60 | v = rearrange(v, 'b (h d) x y -> b h (x y) d', h=heads) 61 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 62 | out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w) 63 | 64 | if self.project is not None: 65 | out = self.project(out) 66 | 67 | out = fmap + self.gamma * out 68 | 69 | return out 70 | -------------------------------------------------------------------------------- /model/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def bilinear_sampler(img, coords): 6 | 7 | H, W = img.shape[-2:] 8 | xgrid, ygrid = coords.split([1, 1], dim=-1) 9 | xgrid = 2 * xgrid / (W - 1) - 1 10 | ygrid = 2 * ygrid / (H - 1) - 1 11 | 12 | grid = torch.cat([xgrid, ygrid], dim=-1) 13 | img = F.grid_sample(img, grid, align_corners=True) 14 | 15 | return img 16 | 17 | 18 | class CorrBlock: 19 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 20 | self.num_levels = num_levels 21 | self.radius = radius 22 | self.corr_pyramid = [] 23 | 24 | # all pairs correlation 25 | corr = CorrBlock.corr(fmap1, fmap2) 26 | 27 | batch, h1, w1, dim, h2, w2 = corr.shape 28 | corr = corr.reshape(batch * h1 * w1, dim, h2, w2) 29 | 30 | self.corr_pyramid.append(corr) 31 | for i in range(self.num_levels - 1): 32 | corr = F.avg_pool2d(corr, 2, stride=2) 33 | self.corr_pyramid.append(corr) 34 | 35 | def __call__(self, coords): 36 | r = self.radius 37 | coords = coords.permute(0, 2, 3, 1) 38 | batch, h1, w1, _ = coords.shape 39 | 40 | out_pyramid = [] 41 | for i in range(self.num_levels): 42 | corr = self.corr_pyramid[i] 43 | dx = torch.linspace(-r, r, 2 * r + 1) 44 | dy = torch.linspace(-r, r, 2 * r + 1) 45 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) 46 | 47 | centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2 ** i 48 | delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) 49 | coords_lvl = centroid_lvl + delta_lvl 50 | 51 | corr = bilinear_sampler(corr, coords_lvl) 52 | 53 | corr = corr.view(batch, h1, w1, -1) 54 | out_pyramid.append(corr) 55 | 56 | out = torch.cat(out_pyramid, dim=-1) 57 | return out.permute(0, 3, 1, 2).contiguous().float() 58 | 59 | @staticmethod 60 | def corr(fmap1, fmap2): 61 | batch, dim, ht, wd = fmap1.shape 62 | fmap1 = fmap1.view(batch, dim, ht * wd) 63 | fmap2 = fmap2.view(batch, dim, ht * wd) 64 | 65 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2) 66 | corr = corr.view(batch, ht, wd, 1, ht, wd) 67 | return corr / torch.sqrt(torch.tensor(dim).float()) 68 | -------------------------------------------------------------------------------- /model/extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class ResidualBlock(nn.Module): 5 | def __init__(self, in_planes, planes, norm_fn='group', stride=1, rate=(1, 1)): 6 | super(ResidualBlock, self).__init__() 7 | self.conv1 
= nn.Conv2d(in_planes, planes, kernel_size=(3, 3), padding=(1, 1), stride=(stride, stride)) 8 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=(3, 3), padding=(1, 1)) 9 | self.relu = nn.ReLU(inplace=True) 10 | 11 | if norm_fn == 'batch': 12 | self.norm1 = nn.BatchNorm2d(planes) 13 | self.norm2 = nn.BatchNorm2d(planes) 14 | if not stride == 1 or in_planes != planes: 15 | self.norm3 = nn.BatchNorm2d(planes) 16 | 17 | elif norm_fn == 'instance': 18 | self.norm1 = nn.InstanceNorm2d(planes) 19 | self.norm2 = nn.InstanceNorm2d(planes) 20 | if not stride == 1 or in_planes != planes: 21 | self.norm3 = nn.InstanceNorm2d(planes) 22 | 23 | if stride == 1 and in_planes == planes: 24 | self.downsample = None 25 | else: 26 | self.downsample = nn.Sequential( 27 | nn.Conv2d(in_planes, planes, kernel_size=(1, 1), stride=(stride, stride)), 28 | self.norm3 29 | ) 30 | 31 | def forward(self, x): 32 | y = x 33 | y = self.relu(self.norm1(self.conv1(y))) 34 | y = self.relu(self.norm2(self.conv2(y))) 35 | 36 | if self.downsample is not None: 37 | x = self.downsample(x) 38 | 39 | return self.relu(x + y) 40 | 41 | 42 | class BasicEncoder(nn.Module): 43 | def __init__(self, output_dim=128, norm_fn='batch'): 44 | super(BasicEncoder, self).__init__() 45 | self.norm_fn = norm_fn 46 | 47 | if self.norm_fn == 'batch': 48 | self.norm1 = nn.BatchNorm2d(64) 49 | 50 | elif self.norm_fn == 'instance': 51 | self.norm1 = nn.InstanceNorm2d(64) 52 | 53 | 54 | self.conv1 = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3)) 55 | self.relu1 = nn.ReLU(inplace=True) 56 | 57 | self.in_planes = 64 58 | self.layer1 = self._make_layer(64, stride=1) 59 | self.layer2 = self._make_layer(96, stride=2) 60 | self.layer3 = self._make_layer(128, stride=2) 61 | 62 | # output convolution 63 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=(1, 1)) 64 | 65 | for m in self.modules(): 66 | if isinstance(m, nn.Conv2d): 67 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 68 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 69 | if m.weight is not None: 70 | nn.init.constant_(m.weight, 1) 71 | if m.bias is not None: 72 | nn.init.constant_(m.bias, 0) 73 | 74 | def _make_layer(self, dim, stride=1): 75 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 76 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 77 | layers = (layer1, layer2) 78 | 79 | self.in_planes = dim 80 | return nn.Sequential(*layers) 81 | 82 | def forward(self, x): 83 | 84 | # if input is list, combine batch dimension 85 | is_list = isinstance(x, tuple) or isinstance(x, list) 86 | if is_list: 87 | batch_dim = x[0].shape[0] 88 | x = torch.cat(x, dim=0) 89 | 90 | x = self.conv1(x) 91 | x = self.norm1(x) 92 | x = self.relu1(x) 93 | 94 | x = self.layer1(x) 95 | x = self.layer2(x) 96 | x = self.layer3(x) 97 | 98 | x = self.conv2(x) 99 | 100 | if is_list: 101 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 102 | 103 | return x 104 | -------------------------------------------------------------------------------- /model/model_splatflow.py: -------------------------------------------------------------------------------- 1 | # @File: model_splatflow.py 2 | # @Project: SplatFlow 3 | # @Author : wangbo 4 | # @Time : 2024.07.03 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from .extractor import BasicEncoder 10 | from .corr import CorrBlock 11 | from .attention import Attention 12 | from .softsplat import FunctionSoftsplat as forward_warping 13 | from 
.update import Update 14 | autocast = torch.cuda.amp.autocast 15 | fast_inference = False 16 | import torch.distributed as dist 17 | 18 | class SplatFlow(nn.Module): 19 | def __init__(self, config=None): 20 | super(SplatFlow, self).__init__() 21 | 22 | self.hdim = self.cdim = 128 23 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance') 24 | self.cnet = BasicEncoder(output_dim=self.hdim + self.cdim, norm_fn='batch') 25 | self.att = Attention(dim=self.cdim, heads=1, dim_head=self.cdim) 26 | 27 | if config != None and config.part_params_train: 28 | for p in self.parameters(): 29 | p.requires_grad = False 30 | 31 | self.update = Update(config, hidden_dim=self.hdim) 32 | 33 | def init_coord(self, fmap): 34 | f_shape = fmap.shape 35 | H, W = f_shape[-2:] 36 | y0, x0 = torch.meshgrid( 37 | torch.arange(H).to(fmap.device).float(), 38 | torch.arange(W).to(fmap.device).float()) 39 | coord = torch.stack([x0, y0], dim=0) # shape: (2, H, W) 40 | coord = coord.unsqueeze(0).repeat(f_shape[0], 1, 1, 1) 41 | return coord 42 | 43 | def initialize_flow(self, fmap): 44 | 45 | coords0 = self.init_coord(fmap) 46 | coords1 = self.init_coord(fmap) 47 | 48 | return coords0, coords1 49 | 50 | def cvx_upsample(self, data, mask): 51 | 52 | N, C, H, W = data.shape 53 | mask = mask.view(N, 1, 9, 8, 8, H, W) 54 | mask = torch.softmax(mask, dim=2) 55 | 56 | up_flow = F.unfold(data, [3, 3], padding=1) 57 | up_flow = up_flow.view(N, C, 9, 1, 1, H, W) 58 | 59 | up_flow = torch.sum(mask * up_flow, dim=2) 60 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 61 | return up_flow.reshape(N, C, 8 * H, 8 * W) 62 | 63 | def forward(self, image1, image2, iters=12, mf_t=None): 64 | 65 | image1 = 2 * (image1 / 255.0) - 1.0 66 | image2 = 2 * (image2 / 255.0) - 1.0 67 | image1 = image1.contiguous() 68 | image2 = image2.contiguous() 69 | 70 | with autocast(enabled=fast_inference): 71 | fmap1, fmap2 = self.fnet([image1, image2]) 72 | 73 | fmap1 = fmap1.float() 74 | fmap2 = fmap2.float() 75 | corr_fn = CorrBlock(fmap1, fmap2, radius=4) 76 | 77 | coords0, coords1 = self.initialize_flow(fmap1) 78 | 79 | with autocast(enabled=fast_inference): 80 | cnet = self.cnet(image1) 81 | net, inp = torch.split(cnet, [self.hdim, self.cdim], dim=1) 82 | net = torch.tanh(net) 83 | inp = torch.relu(inp) 84 | atte_s = self.att(inp) 85 | 86 | flow_predictions = [] 87 | 88 | for itr in range(iters): 89 | coords1 = coords1.detach() 90 | 91 | corr = corr_fn(coords1) 92 | 93 | flow = coords1 - coords0 94 | with autocast(enabled=fast_inference): 95 | net, up_mask, delta_flow, mf = self.update(net, inp, corr, flow, atte_s, mf_t) 96 | coords1 = coords1 + delta_flow 97 | 98 | if (fast_inference and (itr == iters - 1)) or (not fast_inference): 99 | 100 | flow_up = self.cvx_upsample(8 * (coords1 - coords0), up_mask) 101 | flow_predictions.append(flow_up) 102 | 103 | low = coords1 - coords0 104 | 105 | return flow_predictions, mf, low, fmap1, fmap2 106 | 107 | def Loss(self, flow_prs_01, gt_01, valid_01, flow_prs_12, gt_12, valid_12): 108 | 109 | MAX_FLOW = 400 110 | 111 | n_predictions = len(flow_prs_12) 112 | loss = 0 113 | 114 | valid_01 = ((valid_01 >= 0.5) & ((gt_01 ** 2).sum(dim=1).sqrt() < MAX_FLOW)).view(-1) >= 0.5 115 | valid_12 = ((valid_12 >= 0.5) & ((gt_12 ** 2).sum(dim=1).sqrt() < MAX_FLOW)).view(-1) >= 0.5 116 | 117 | for i in range(n_predictions): 118 | i_weight = 0.8 ** (n_predictions - i - 1) 119 | tmp_01 = ((flow_prs_01[i] - gt_01).abs().sum(dim=1)).view(-1)[valid_01] 120 | tmp_12 = ((flow_prs_12[i] - gt_12).abs().sum(dim=1)).view(-1)[valid_12] 
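# RAFT-style sequence loss: predictions from later iterations receive exponentially
# larger weights (decay factor 0.8), and the L1 error is averaged jointly over the
# valid pixels of both predicted flow fields.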
121 | loss += i_weight * torch.cat([tmp_01, tmp_12]).mean()
122 | 
123 | with torch.no_grad():
124 | epe = torch.sum((flow_prs_12[-1] - gt_12) ** 2, dim=1).sqrt()
125 | epe = epe.view(-1)[valid_12.view(-1)]
126 | epe_sum = epe.sum()
127 | px1_sum = (epe < 1).float().sum()
128 | px3_sum = (epe < 3).float().sum()
129 | px5_sum = (epe < 5).float().sum()
130 | valid_12_sum = valid_12.sum()
131 | 
132 | dist.all_reduce(epe_sum)
133 | dist.all_reduce(px1_sum)
134 | dist.all_reduce(px3_sum)
135 | dist.all_reduce(px5_sum)
136 | dist.all_reduce(valid_12_sum)
137 | 
138 | epe = epe_sum / valid_12_sum
139 | px1 = px1_sum / valid_12_sum
140 | px3 = px3_sum / valid_12_sum
141 | px5 = px5_sum / valid_12_sum
142 | 
143 | metric_list = [
144 | ['epe', epe.item()],
145 | ['px1', px1.item()],
146 | ['px3', px3.item()],
147 | ['px5', px5.item()]]
148 | 
149 | return loss, metric_list
150 | 
151 | def infer(self, model, input_list, iters=12, gt_list=None, mf_01=None, low_01=None):
152 | 
153 | img0, img1, img2 = input_list
154 | 
155 | if img0 is None:
156 | flow_prs_12, mf_12, low_12, fmap1, fmap2 = model(img1, img2, iters=iters)
157 | return flow_prs_12
158 | 
159 | if not (gt_list is None and mf_01 is not None and low_01 is not None):
160 | flow_prs_01, mf_01, low_01, fmap0, fmap1 = model(img0, img1, iters=iters)
161 | 
162 | mf_t = forward_warping(mf_01, low_01)
163 | 
164 | flow_prs_12, mf_12, low_12, fmap1, fmap2 = model(img1, img2, iters=iters, mf_t=mf_t)
165 | 
166 | if gt_list is not None: # training mode
167 | gt_01, valid_01, gt_12, valid_12 = gt_list
168 | loss, metric_list = self.Loss(flow_prs_01, gt_01, valid_01, flow_prs_12, gt_12, valid_12)
169 | return loss, metric_list
170 | 
171 | return flow_prs_12
172 | 
173 | def training_infer(self, model, step_data, device):
174 | 
175 | img0, img1, img2, gt_01, valid_01, gt_12, valid_12 = [x.to(device) for x in step_data]
176 | 
177 | loss, metric_list = model.module.infer(
178 | model,
179 | input_list=[img0, img1, img2],
180 | gt_list=[gt_01, valid_01, gt_12, valid_12])
181 | 
182 | return loss, metric_list
183 | 
184 | 
--------------------------------------------------------------------------------
/model/softsplat.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | import torch
4 | 
5 | import cupy
6 | import re
7 | 
8 | kernel_Softsplat_updateOutput = '''
9 | extern "C" __global__ void kernel_Softsplat_updateOutput(
10 | const int n,
11 | const float* input,
12 | const float* flow,
13 | double* output
14 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
15 | const int intN = ( intIndex / SIZE_3(output) / SIZE_2(output) / SIZE_1(output) ) % SIZE_0(output);
16 | const int intC = ( intIndex / SIZE_3(output) / SIZE_2(output) ) % SIZE_1(output);
17 | const int intY = ( intIndex / SIZE_3(output) ) % SIZE_2(output);
18 | const int intX = ( intIndex ) % SIZE_3(output);
19 | 
20 | float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX);
21 | float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX);
22 | 
23 | int intNorthwestX = (int) (floor(fltOutputX));
24 | int intNorthwestY = (int) (floor(fltOutputY));
25 | int intNortheastX = intNorthwestX + 1;
26 | int intNortheastY = intNorthwestY;
27 | int intSouthwestX = intNorthwestX;
28 | int intSouthwestY = intNorthwestY + 1;
29 | int intSoutheastX = intNorthwestX + 1;
30 | int intSoutheastY = intNorthwestY + 1;
31 | 
32 | float fltNorthwest = ((float) (intSoutheastX) - fltOutputX)
* ((float) (intSoutheastY) - fltOutputY); 33 | float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY); 34 | float fltSouthwest = ((float) (intNortheastX) - fltOutputX) * (fltOutputY - (float) (intNortheastY)); 35 | float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY)); 36 | 37 | float value = VALUE_4(input, intN, intC, intY, intX); 38 | double valueNorthwest = value * fltNorthwest; 39 | double valueNortheast = value * fltNortheast; 40 | double valueSouthwest = value * fltSouthwest; 41 | double valueSoutheast = value * fltSoutheast; 42 | 43 | if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(output)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(output))) { 44 | atomicAdd(&output[OFFSET_4(output, intN, intC, intNorthwestY, intNorthwestX)], valueNorthwest); 45 | } 46 | 47 | if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(output)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(output))) { 48 | atomicAdd(&output[OFFSET_4(output, intN, intC, intNortheastY, intNortheastX)], valueNortheast); 49 | } 50 | 51 | if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(output)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(output))) { 52 | atomicAdd(&output[OFFSET_4(output, intN, intC, intSouthwestY, intSouthwestX)], valueSouthwest); 53 | } 54 | 55 | if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(output)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(output))) { 56 | atomicAdd(&output[OFFSET_4(output, intN, intC, intSoutheastY, intSoutheastX)], valueSoutheast); 57 | } 58 | }} 59 | ''' 60 | 61 | kernel_Softsplat_updateGradInput = ''' 62 | extern "C" __global__ void kernel_Softsplat_updateGradInput( 63 | const int n, 64 | const float* input, 65 | const float* flow, 66 | const float* gradOutput, 67 | float* gradInput, 68 | float* gradFlow 69 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { 70 | const int intN = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) / SIZE_1(gradInput) ) % SIZE_0(gradInput); 71 | const int intC = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) ) % SIZE_1(gradInput); 72 | const int intY = ( intIndex / SIZE_3(gradInput) ) % SIZE_2(gradInput); 73 | const int intX = ( intIndex ) % SIZE_3(gradInput); 74 | 75 | float fltGradInput = 0.0; 76 | 77 | float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); 78 | float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); 79 | 80 | int intNorthwestX = (int) (floor(fltOutputX)); 81 | int intNorthwestY = (int) (floor(fltOutputY)); 82 | int intNortheastX = intNorthwestX + 1; 83 | int intNortheastY = intNorthwestY; 84 | int intSouthwestX = intNorthwestX; 85 | int intSouthwestY = intNorthwestY + 1; 86 | int intSoutheastX = intNorthwestX + 1; 87 | int intSoutheastY = intNorthwestY + 1; 88 | 89 | float fltNorthwest = ((float) (intSoutheastX) - fltOutputX) * ((float) (intSoutheastY) - fltOutputY); 90 | float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY); 91 | float fltSouthwest = ((float) (intNortheastX) - fltOutputX) * (fltOutputY - (float) (intNortheastY)); 92 | float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY)); 93 | 94 | if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) { 95 | fltGradInput += VALUE_4(gradOutput, intN, intC, intNorthwestY, intNorthwestX) * 
fltNorthwest; 96 | } 97 | 98 | if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) { 99 | fltGradInput += VALUE_4(gradOutput, intN, intC, intNortheastY, intNortheastX) * fltNortheast; 100 | } 101 | 102 | if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) { 103 | fltGradInput += VALUE_4(gradOutput, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest; 104 | } 105 | 106 | if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) { 107 | fltGradInput += VALUE_4(gradOutput, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast; 108 | } 109 | 110 | gradInput[intIndex] = fltGradInput; 111 | } } 112 | ''' 113 | 114 | kernel_Softsplat_updateGradFlow = ''' 115 | extern "C" __global__ void kernel_Softsplat_updateGradFlow( 116 | const int n, 117 | const float* input, 118 | const float* flow, 119 | const float* gradOutput, 120 | float* gradInput, 121 | float* gradFlow 122 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { 123 | float fltGradFlow = 0.0; 124 | 125 | const int intN = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) / SIZE_1(gradFlow) ) % SIZE_0(gradFlow); 126 | const int intC = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) ) % SIZE_1(gradFlow); 127 | const int intY = ( intIndex / SIZE_3(gradFlow) ) % SIZE_2(gradFlow); 128 | const int intX = ( intIndex ) % SIZE_3(gradFlow); 129 | 130 | float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); 131 | float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); 132 | 133 | int intNorthwestX = (int) (floor(fltOutputX)); 134 | int intNorthwestY = (int) (floor(fltOutputY)); 135 | int intNortheastX = intNorthwestX + 1; 136 | int intNortheastY = intNorthwestY; 137 | int intSouthwestX = intNorthwestX; 138 | int intSouthwestY = intNorthwestY + 1; 139 | int intSoutheastX = intNorthwestX + 1; 140 | int intSoutheastY = intNorthwestY + 1; 141 | 142 | float fltNorthwest = 0.0; 143 | float fltNortheast = 0.0; 144 | float fltSouthwest = 0.0; 145 | float fltSoutheast = 0.0; 146 | 147 | if (intC == 0) { 148 | fltNorthwest = ((float) (-1.0)) * ((float) (intSoutheastY) - fltOutputY); 149 | fltNortheast = ((float) (+1.0)) * ((float) (intSouthwestY) - fltOutputY); 150 | fltSouthwest = ((float) (-1.0)) * (fltOutputY - (float) (intNortheastY)); 151 | fltSoutheast = ((float) (+1.0)) * (fltOutputY - (float) (intNorthwestY)); 152 | 153 | } else if (intC == 1) { 154 | fltNorthwest = ((float) (intSoutheastX) - fltOutputX) * ((float) (-1.0)); 155 | fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (-1.0)); 156 | fltSouthwest = ((float) (intNortheastX) - fltOutputX) * ((float) (+1.0)); 157 | fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * ((float) (+1.0)); 158 | 159 | } 160 | 161 | for (int intChannel = 0; intChannel < SIZE_1(gradOutput); intChannel += 1) { 162 | float fltInput = VALUE_4(input, intN, intChannel, intY, intX); 163 | 164 | if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) { 165 | fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNorthwestY, intNorthwestX) * fltNorthwest; 166 | } 167 | 168 | if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) { 169 | 
fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNortheastY, intNortheastX) * fltNortheast; 170 | } 171 | 172 | if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) { 173 | fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSouthwestY, intSouthwestX) * fltSouthwest; 174 | } 175 | 176 | if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) { 177 | fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSoutheastY, intSoutheastX) * fltSoutheast; 178 | } 179 | } 180 | 181 | gradFlow[intIndex] = fltGradFlow; 182 | } } 183 | ''' 184 | 185 | def cupy_kernel(strFunction, objVariables): 186 | strKernel = globals()[strFunction] 187 | 188 | while True: 189 | objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) 190 | 191 | if objMatch is None: 192 | break 193 | # end 194 | 195 | intArg = int(objMatch.group(2)) 196 | 197 | strTensor = objMatch.group(4) 198 | intSizes = objVariables[strTensor].size() 199 | 200 | strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg])) 201 | # end 202 | 203 | while True: 204 | objMatch = re.search('(OFFSET_)([0-4])(\()([^\)]+)(\))', strKernel) 205 | 206 | if objMatch is None: 207 | break 208 | # end 209 | 210 | intArgs = int(objMatch.group(2)) 211 | strArgs = objMatch.group(4).split(',') 212 | 213 | strTensor = strArgs[0] 214 | intStrides = objVariables[strTensor].stride() 215 | strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] 216 | 217 | strKernel = strKernel.replace(objMatch.group(0), '(' + str.join('+', strIndex) + ')') 218 | # end 219 | 220 | while True: 221 | objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) 222 | 223 | if objMatch is None: 224 | break 225 | # end 226 | 227 | intArgs = int(objMatch.group(2)) 228 | strArgs = objMatch.group(4).split(',') 229 | 230 | strTensor = strArgs[0] 231 | intStrides = objVariables[strTensor].stride() 232 | strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] 233 | 234 | strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') 235 | # end 236 | 237 | return strKernel 238 | # end 239 | 240 | @cupy.memoize(for_each_device=True) 241 | def cupy_launch(strFunction, strKernel): 242 | return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) 243 | # end 244 | 245 | class _FunctionSoftsplat(torch.autograd.Function): 246 | @staticmethod 247 | def forward(self, input, flow): 248 | intSamples = input.shape[0] 249 | intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3] 250 | intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3] 251 | 252 | assert(intFlowDepth == 2) 253 | assert(intInputHeight == intFlowHeight) 254 | assert(intInputWidth == intFlowWidth) 255 | 256 | input = input.contiguous(); assert(input.is_cuda == True) 257 | flow = flow.contiguous(); assert(flow.is_cuda == True) 258 | 259 | output = input.new_zeros([ intSamples, intInputDepth, intInputHeight, intInputWidth ]).double() 260 | 261 | if input.is_cuda == True: 262 | n = output.nelement() 263 | cupy_launch('kernel_Softsplat_updateOutput', cupy_kernel('kernel_Softsplat_updateOutput', { 264 | 'input': input, 265 | 
'flow': flow, 266 | 'output': output 267 | }))( 268 | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), 269 | block=tuple([ 512, 1, 1 ]), 270 | args=[ cupy.int32(n), input.data_ptr(), flow.data_ptr(), output.data_ptr() ] 271 | ) 272 | 273 | elif input.is_cuda == False: 274 | raise NotImplementedError() 275 | 276 | # end 277 | 278 | self.save_for_backward(input, flow) 279 | 280 | return output.float() 281 | # end 282 | 283 | @staticmethod 284 | def backward(self, gradOutput): 285 | input, flow = self.saved_tensors 286 | 287 | intSamples = input.shape[0] 288 | intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3] 289 | intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3] 290 | 291 | assert(intFlowDepth == 2) 292 | assert(intInputHeight == intFlowHeight) 293 | assert(intInputWidth == intFlowWidth) 294 | 295 | gradOutput = gradOutput.contiguous(); assert(gradOutput.is_cuda == True) 296 | 297 | gradInput = input.new_zeros([ intSamples, intInputDepth, intInputHeight, intInputWidth ]) if self.needs_input_grad[0] == True else None 298 | gradFlow = input.new_zeros([ intSamples, intFlowDepth, intFlowHeight, intFlowWidth ]) if self.needs_input_grad[1] == True else None 299 | 300 | if input.is_cuda == True: 301 | if gradInput is not None: 302 | n = gradInput.nelement() 303 | cupy_launch('kernel_Softsplat_updateGradInput', cupy_kernel('kernel_Softsplat_updateGradInput', { 304 | 'input': input, 305 | 'flow': flow, 306 | 'gradOutput': gradOutput, 307 | 'gradInput': gradInput, 308 | 'gradFlow': gradFlow 309 | }))( 310 | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), 311 | block=tuple([ 512, 1, 1 ]), 312 | args=[ cupy.int32(n), input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), gradInput.data_ptr(), None ] 313 | ) 314 | # end 315 | 316 | if gradFlow is not None: 317 | n = gradFlow.nelement() 318 | cupy_launch('kernel_Softsplat_updateGradFlow', cupy_kernel('kernel_Softsplat_updateGradFlow', { 319 | 'input': input, 320 | 'flow': flow, 321 | 'gradOutput': gradOutput, 322 | 'gradInput': gradInput, 323 | 'gradFlow': gradFlow 324 | }))( 325 | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), 326 | block=tuple([ 512, 1, 1 ]), 327 | args=[ cupy.int32(n), input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), None, gradFlow.data_ptr() ] 328 | ) 329 | # end 330 | 331 | elif input.is_cuda == False: 332 | raise NotImplementedError() 333 | 334 | # end 335 | 336 | return gradInput, gradFlow 337 | # end 338 | # end 339 | 340 | def FunctionSoftsplat(tenInput, tenFlow, tenMetric=None, strType='average'): 341 | assert(tenMetric is None or tenMetric.shape[1] == 1) 342 | assert(strType in ['summation', 'average', 'linear', 'softmax']) 343 | 344 | if strType == 'average': 345 | tenInput = torch.cat([ tenInput, tenInput.new_ones(tenInput.shape[0], 1, tenInput.shape[2], tenInput.shape[3]) ], 1) 346 | 347 | elif strType == 'linear': 348 | tenInput = torch.cat([ tenInput * tenMetric, tenMetric ], 1) 349 | 350 | elif strType == 'softmax': 351 | tenInput = torch.cat([ tenInput * tenMetric.exp(), tenMetric.exp() ], 1) 352 | 353 | # end 354 | 355 | tenOutput = _FunctionSoftsplat.apply(tenInput, tenFlow) 356 | 357 | if strType != 'summation': 358 | tenNormalize = tenOutput[:, -1:, :, :] 359 | 360 | tenNormalize[tenNormalize == 0.0] = 1.0 361 | 362 | tenOutput = tenOutput[:, :-1, :, :] / tenNormalize 363 | # end 364 | 365 | return tenOutput 366 | # end 367 | 368 | class ModuleSoftsplat(torch.nn.Module): 369 | def __init__(self, strType): 370 | 
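# strType selects the splatting mode, one of 'summation', 'average', 'linear'
# or 'softmax' (validated in FunctionSoftsplat above).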
super().__init__() 371 | 372 | self.strType = strType 373 | # end 374 | 375 | def forward(self, tenInput, tenFlow, tenMetric): 376 | return FunctionSoftsplat(tenInput, tenFlow, tenMetric, self.strType) 377 | # end 378 | # end 379 | -------------------------------------------------------------------------------- /model/update.py: -------------------------------------------------------------------------------- 1 | # @File: update.py 2 | # @Project: SplatFlow 3 | # @Author : wangbo 4 | # @Time : 2024.07.03 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from .attention import Aggregate 10 | 11 | class BasicMotionEncoder(nn.Module): 12 | def __init__(self): 13 | super(BasicMotionEncoder, self).__init__() 14 | corr_levels = 4 15 | corr_radius = 4 16 | cor_planes = corr_levels * (2*corr_radius + 1)**2 17 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 18 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 19 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 20 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 21 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) 22 | 23 | def forward(self, flow, corr): 24 | cor = F.relu(self.convc1(corr)) 25 | cor = F.relu(self.convc2(cor)) 26 | flo = F.relu(self.convf1(flow)) 27 | flo = F.relu(self.convf2(flo)) 28 | 29 | cor_flo = torch.cat([cor, flo], dim=1) 30 | out = F.relu(self.conv(cor_flo)) 31 | return torch.cat([out, flow], dim=1) 32 | 33 | class SepConvGRU(nn.Module): 34 | def __init__(self, hidden_dim=128, input_dim=192+128): 35 | super(SepConvGRU, self).__init__() 36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 39 | 40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 43 | 44 | 45 | def forward(self, h, x): 46 | # horizontal 47 | hx = torch.cat([h, x], dim=1) 48 | z = torch.sigmoid(self.convz1(hx)) 49 | r = torch.sigmoid(self.convr1(hx)) 50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 51 | h = (1-z) * h + z * q 52 | 53 | # vertical 54 | hx = torch.cat([h, x], dim=1) 55 | z = torch.sigmoid(self.convz2(hx)) 56 | r = torch.sigmoid(self.convr2(hx)) 57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 58 | h = (1-z) * h + z * q 59 | 60 | return h 61 | 62 | class FlowHead(nn.Module): 63 | def __init__(self, input_dim=128, hidden_dim=256): 64 | super(FlowHead, self).__init__() 65 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 66 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 67 | self.relu = nn.ReLU(inplace=True) 68 | 69 | def forward(self, x): 70 | return self.conv2(self.relu(self.conv1(x))) 71 | 72 | class Update(nn.Module): 73 | def __init__(self, config=None, hidden_dim=128): 74 | super().__init__() 75 | self.encoder = BasicMotionEncoder() 76 | 77 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim + hidden_dim) 78 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 79 | self.mask = nn.Sequential( 80 | nn.Conv2d(128, 256, 3, padding=1), 81 | nn.ReLU(inplace=True), 82 | nn.Conv2d(256, 64 * 9, 1, padding=0)) 83 | 84 | self.aggregator = Aggregate(dim=128, dim_head=128, heads=1) 85 | 86 | if config != None and config.part_params_train: 87 | for 
p in self.parameters(): 88 | p.requires_grad = False 89 | 90 | self.gru_sp = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim + hidden_dim * 2) 91 | self.flow_head_sp = FlowHead(hidden_dim, hidden_dim=256) 92 | self.mask_sp = nn.Sequential( 93 | nn.Conv2d(128, 256, 3, padding=1), 94 | nn.ReLU(inplace=True), 95 | nn.Conv2d(256, 64 * 9, 1, padding=0)) 96 | 97 | def forward(self, net, inp, corr, flow, atte_s, mf_t=None): 98 | mf = self.encoder(flow, corr) 99 | mf_s = self.aggregator(atte_s, mf) 100 | 101 | if mf_t != None: 102 | inp_cat = torch.cat([inp, mf, mf_s, mf_t], dim=1) 103 | net = self.gru_sp(net, inp_cat) 104 | delta_flow = self.flow_head_sp(net) 105 | mask = .25 * self.mask_sp(net) 106 | else: 107 | inp_cat = torch.cat([inp, mf, mf_s], dim=1) 108 | net = self.gru(net, inp_cat) 109 | delta_flow = self.flow_head(net) 110 | mask = .25 * self.mask(net) 111 | return net, mask, delta_flow, mf 112 | -------------------------------------------------------------------------------- /model/util/__pycache__/augmentor.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wwsource/SplatFlow/41c53ffa5a7b665bf23ef0fa1afc3fb69837b79d/model/util/__pycache__/augmentor.cpython-38.pyc -------------------------------------------------------------------------------- /model/util/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wwsource/SplatFlow/41c53ffa5a7b665bf23ef0fa1afc3fb69837b79d/model/util/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /model/util/augmentor.py: -------------------------------------------------------------------------------- 1 | # @File: augmentor.py 2 | # @Project: SplatFlow 3 | # @Author : wangbo 4 | # @Time : 2024.07.03 5 | 6 | import numpy as np 7 | from PIL import Image 8 | import cv2 9 | cv2.setNumThreads(0) 10 | cv2.ocl.setUseOpenCL(False) 11 | from torchvision.transforms import ColorJitter 12 | 13 | 14 | class FlowAugmentor: 15 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): 16 | 17 | # spatial augmentation params 18 | self.crop_size = crop_size 19 | self.min_scale = min_scale 20 | self.max_scale = max_scale 21 | self.spatial_aug_prob = 0.8 22 | self.stretch_prob = 0.8 23 | self.max_stretch = 0.2 24 | 25 | # flip augmentation params 26 | self.do_flip = do_flip 27 | self.h_flip_prob = 0.5 28 | self.v_flip_prob = 0.1 29 | 30 | # photometric augmentation params 31 | self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5 / 3.14) 32 | self.asymmetric_color_aug_prob = 0.2 33 | self.eraser_aug_prob = 0.5 34 | 35 | def color_transform(self, img1, img2, img3): 36 | """ Photometric augmentation """ 37 | 38 | # asymmetric 39 | if np.random.rand() < self.asymmetric_color_aug_prob: 40 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) 41 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) 42 | img3 = np.array(self.photo_aug(Image.fromarray(img3)), dtype=np.uint8) 43 | 44 | # symmetric 45 | else: 46 | image_stack = np.concatenate([img1, img2, img3], axis=0) 47 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 48 | img1, img2, img3 = np.split(image_stack, 3, axis=0) 49 | 50 | return img1, img2, img3 51 | 52 | def eraser_transform(self, img2, img3, bounds=[50, 100]): 53 | """ Occlusion augmentation """ 54 
| 55 | ht, wd = img2.shape[:2] 56 | if np.random.rand() < self.eraser_aug_prob: 57 | mean_color = np.mean(img3.reshape(-1, 3), axis=0) 58 | for _ in range(np.random.randint(1, 3)): 59 | x0 = np.random.randint(0, wd) 60 | y0 = np.random.randint(0, ht) 61 | dx = np.random.randint(bounds[0], bounds[1]) 62 | dy = np.random.randint(bounds[0], bounds[1]) 63 | img3[y0:y0 + dy, x0:x0 + dx, :] = mean_color 64 | 65 | return img2, img3 66 | 67 | def spatial_transform(self, img1, img2, img3, flow1, flow2): 68 | # randomly sample scale 69 | ht, wd = img2.shape[:2] 70 | min_scale = np.maximum( 71 | (self.crop_size[0] + 8) / float(ht), 72 | (self.crop_size[1] + 8) / float(wd)) 73 | 74 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 75 | scale_x = scale 76 | scale_y = scale 77 | if np.random.rand() < self.stretch_prob: 78 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 79 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 80 | 81 | scale_x = np.clip(scale_x, min_scale, None) 82 | scale_y = np.clip(scale_y, min_scale, None) 83 | 84 | if np.random.rand() < self.spatial_aug_prob: 85 | # rescale the images 86 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 87 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 88 | img3 = cv2.resize(img3, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 89 | 90 | flow1 = cv2.resize(flow1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 91 | flow1 = flow1 * [scale_x, scale_y] 92 | 93 | flow2 = cv2.resize(flow2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 94 | flow2 = flow2 * [scale_x, scale_y] 95 | 96 | if self.do_flip: 97 | if np.random.rand() < self.h_flip_prob: # h-flip 98 | img1 = img1[:, ::-1] 99 | img2 = img2[:, ::-1] 100 | img3 = img3[:, ::-1] 101 | flow1 = flow1[:, ::-1] * [-1.0, 1.0] 102 | flow2 = flow2[:, ::-1] * [-1.0, 1.0] 103 | 104 | if np.random.rand() < self.v_flip_prob: # v-flip 105 | img1 = img1[::-1, :] 106 | img2 = img2[::-1, :] 107 | img3 = img3[::-1, :] 108 | flow1 = flow1[::-1, :] * [1.0, -1.0] 109 | flow2 = flow2[::-1, :] * [1.0, -1.0] 110 | 111 | if img2.shape[0] - self.crop_size[0] > 0: 112 | y0 = np.random.randint(0, img2.shape[0] - self.crop_size[0]) 113 | else: 114 | y0 = 0 115 | if img2.shape[1] - self.crop_size[1] > 0: 116 | x0 = np.random.randint(0, img2.shape[1] - self.crop_size[1]) 117 | else: 118 | x0 = 0 119 | 120 | img1 = img1[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] 121 | img2 = img2[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] 122 | img3 = img3[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] 123 | flow1 = flow1[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] 124 | flow2 = flow2[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]] 125 | 126 | return img1, img2, img3, flow1, flow2 127 | 128 | def __call__(self, img1, img2, img3, flow1, flow2): 129 | img1, img2, img3 = self.color_transform(img1, img2, img3) 130 | # img2, img3 = self.eraser_transform(img2, img3) 131 | img1, img2, img3, flow1, flow2 = self.spatial_transform(img1, img2, img3, flow1, flow2) 132 | 133 | img1 = np.ascontiguousarray(img1) 134 | img2 = np.ascontiguousarray(img2) 135 | img3 = np.ascontiguousarray(img3) 136 | flow1 = np.ascontiguousarray(flow1) 137 | flow2 = np.ascontiguousarray(flow2) 138 | 139 | return img1, img2, img3, flow1, flow2 140 | 141 | 142 | class SparseFlowAugmentor: 143 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, 
do_flip=False): 144 | # spatial augmentation params 145 | self.crop_size = crop_size 146 | self.min_scale = min_scale 147 | self.max_scale = max_scale 148 | self.spatial_aug_prob = 0.8 149 | self.stretch_prob = 0.8 150 | self.max_stretch = 0.2 151 | 152 | # flip augmentation params 153 | self.do_flip = do_flip 154 | self.h_flip_prob = 0.5 155 | self.v_flip_prob = 0.1 156 | 157 | # photometric augmentation params 158 | self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3 / 3.14) 159 | self.asymmetric_color_aug_prob = 0.2 160 | self.eraser_aug_prob = 0.5 161 | 162 | def color_transform(self, img1, img2, img3): 163 | image_stack = np.concatenate([img1, img2, img3], axis=0) 164 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 165 | img1, img2, img3 = np.split(image_stack, 3, axis=0) 166 | return img1, img2, img3 167 | 168 | def eraser_transform(self, img2, img3): 169 | ht, wd = img2.shape[:2] 170 | if np.random.rand() < self.eraser_aug_prob: 171 | mean_color = np.mean(img3.reshape(-1, 3), axis=0) 172 | for _ in range(np.random.randint(1, 3)): 173 | x0 = np.random.randint(0, wd) 174 | y0 = np.random.randint(0, ht) 175 | dx = np.random.randint(50, 100) 176 | dy = np.random.randint(50, 100) 177 | img3[y0:y0 + dy, x0:x0 + dx, :] = mean_color 178 | 179 | return img2, img3 180 | 181 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): 182 | ht, wd = flow.shape[:2] 183 | coords = np.meshgrid(np.arange(wd), np.arange(ht)) 184 | coords = np.stack(coords, axis=-1) 185 | 186 | coords = coords.reshape(-1, 2).astype(np.float32) 187 | flow = flow.reshape(-1, 2).astype(np.float32) 188 | valid = valid.reshape(-1).astype(np.float32) 189 | 190 | coords0 = coords[valid >= 1] 191 | flow0 = flow[valid >= 1] 192 | 193 | ht1 = int(round(ht * fy)) 194 | wd1 = int(round(wd * fx)) 195 | 196 | coords1 = coords0 * [fx, fy] 197 | flow1 = flow0 * [fx, fy] 198 | 199 | xx = np.round(coords1[:, 0]).astype(np.int32) 200 | yy = np.round(coords1[:, 1]).astype(np.int32) 201 | 202 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) 203 | xx = xx[v] 204 | yy = yy[v] 205 | flow1 = flow1[v] 206 | 207 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) 208 | valid_img = np.zeros([ht1, wd1], dtype=np.int32) 209 | 210 | flow_img[yy, xx] = flow1 211 | valid_img[yy, xx] = 1 212 | 213 | return flow_img, valid_img 214 | 215 | def spatial_transform(self, img1, img2, img3, flow1, valid1, flow2, valid2): 216 | # randomly sample scale 217 | 218 | ht, wd = img2.shape[:2] 219 | min_scale = np.maximum( 220 | (self.crop_size[0] + 1) / float(ht), 221 | (self.crop_size[1] + 1) / float(wd)) 222 | 223 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 224 | scale_x = np.clip(scale, min_scale, None) 225 | scale_y = np.clip(scale, min_scale, None) 226 | 227 | if np.random.rand() < self.spatial_aug_prob: 228 | # rescale the images 229 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 230 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 231 | img3 = cv2.resize(img3, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 232 | flow1, valid1 = self.resize_sparse_flow_map(flow1, valid1, fx=scale_x, fy=scale_y) 233 | flow2, valid2 = self.resize_sparse_flow_map(flow2, valid2, fx=scale_x, fy=scale_y) 234 | 235 | if self.do_flip: 236 | if np.random.rand() < 0.5: # h-flip 237 | img1 = img1[:, ::-1] 238 | img2 = img2[:, ::-1] 239 | img3 = img3[:, ::-1] 240 | flow1 = flow1[:, ::-1] * 
[-1.0, 1.0]
241 |                 valid1 = valid1[:, ::-1]
242 |                 flow2 = flow2[:, ::-1] * [-1.0, 1.0]
243 |                 valid2 = valid2[:, ::-1]
244 | 
245 |         margin_y = 20
246 |         margin_x = 50
247 | 
248 |         y0 = np.random.randint(0, img2.shape[0] - self.crop_size[0] + margin_y)
249 |         x0 = np.random.randint(-margin_x, img2.shape[1] - self.crop_size[1] + margin_x)
250 | 
251 |         y0 = np.clip(y0, 0, img2.shape[0] - self.crop_size[0])
252 |         x0 = np.clip(x0, 0, img2.shape[1] - self.crop_size[1])
253 | 
254 |         img1 = img1[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]]
255 |         img2 = img2[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]]
256 |         img3 = img3[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]]
257 |         flow1 = flow1[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]]
258 |         valid1 = valid1[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]]
259 |         flow2 = flow2[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]]
260 |         valid2 = valid2[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]]
261 |         return img1, img2, img3, flow1, valid1, flow2, valid2
262 | 
263 |     def __call__(self, img1, img2, img3, flow1, valid1, flow2, valid2):
264 |         img1, img2, img3 = self.color_transform(img1, img2, img3)
265 |         # img2, img3 = self.eraser_transform(img2, img3)
266 |         img1, img2, img3, flow1, valid1, flow2, valid2 = self.spatial_transform(img1, img2, img3, flow1, valid1, flow2, valid2)
267 | 
268 |         img1 = np.ascontiguousarray(img1)
269 |         img2 = np.ascontiguousarray(img2)
270 |         img3 = np.ascontiguousarray(img3)
271 |         flow1 = np.ascontiguousarray(flow1)
272 |         valid1 = np.ascontiguousarray(valid1)
273 |         flow2 = np.ascontiguousarray(flow2)
274 |         valid2 = np.ascontiguousarray(valid2)
275 | 
276 |         return img1, img2, img3, flow1, valid1, flow2, valid2
277 | 
--------------------------------------------------------------------------------
/model/util/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from os.path import splitext  # read_gen() below only needs splitext
4 | import re
5 | import torch.nn.functional as F
6 | import cv2
7 | cv2.setNumThreads(0)
8 | cv2.ocl.setUseOpenCL(False)
9 | 
10 | TAG_CHAR = np.array([202021.25], np.float32)  # magic number written at the head of Middlebury .flo files
11 | 
12 | def readImageKITTI(name_path):
13 |     return Image.open(name_path)
14 | 
15 | def readFlowKITTI(filename):
16 |     flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
17 |     flow = flow[:,:,::-1].astype(np.float32)
18 |     flow, valid = flow[:, :, :2], flow[:, :, 2]
19 |     flow = (flow - 2**15) / 64.0
20 |     return flow, valid
21 | 
22 | def writeFlowKITTI(filename, uv):
23 |     uv = 64.0 * uv + 2 ** 15
24 |     valid = np.ones([uv.shape[0], uv.shape[1], 1])
25 |     uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
26 |     cv2.imwrite(filename, uv[..., ::-1])
27 | 
28 | def readFlow(fn):
29 |     """ Read .flo file in Middlebury format"""
30 |     # Code adapted from:
31 |     # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
32 | 
33 |     # WARNING: this will work on little-endian architectures (eg Intel x86) only!
34 |     # print 'fn = %s'%(fn)
35 |     with open(fn, 'rb') as f:
36 |         magic = np.fromfile(f, np.float32, count=1)
37 |         if 202021.25 != magic:
38 |             print('Magic number incorrect.
Invalid .flo file') 39 | return None 40 | else: 41 | w = np.fromfile(f, np.int32, count=1) 42 | h = np.fromfile(f, np.int32, count=1) 43 | # print 'Reading %d x %d flo file\n' % (w, h) 44 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 45 | # Reshape data into 3D array (columns, rows, bands) 46 | # The reshape here is for visualization, the original code is (w,h,2) 47 | return np.resize(data, (int(h), int(w), 2)) 48 | 49 | def readPFM(file): 50 | file = open(file, 'rb') 51 | 52 | color = None 53 | width = None 54 | height = None 55 | scale = None 56 | endian = None 57 | 58 | header = file.readline().rstrip() 59 | if header == b'PF': 60 | color = True 61 | elif header == b'Pf': 62 | color = False 63 | else: 64 | raise Exception('Not a PFM file.') 65 | 66 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 67 | if dim_match: 68 | width, height = map(int, dim_match.groups()) 69 | else: 70 | raise Exception('Malformed PFM header.') 71 | 72 | scale = float(file.readline().rstrip()) 73 | if scale < 0: # little-endian 74 | endian = '<' 75 | scale = -scale 76 | else: 77 | endian = '>' # big-endian 78 | 79 | data = np.fromfile(file, endian + 'f') 80 | shape = (height, width, 3) if color else (height, width) 81 | 82 | data = np.reshape(data, shape) 83 | data = np.flipud(data) 84 | return data 85 | 86 | 87 | def writeFlow(filename, uv, v=None): 88 | """ Write optical flow to file. 89 | 90 | If v is None, uv is assumed to contain both u and v channels, 91 | stacked in depth. 92 | Original code by Deqing Sun, adapted from Daniel Scharstein. 93 | """ 94 | nBands = 2 95 | 96 | if v is None: 97 | assert (uv.ndim == 3) 98 | assert (uv.shape[2] == 2) 99 | u = uv[:, :, 0] 100 | v = uv[:, :, 1] 101 | else: 102 | u = uv 103 | 104 | assert (u.shape == v.shape) 105 | height, width = u.shape 106 | f = open(filename, 'wb') 107 | # write the header 108 | f.write(TAG_CHAR) 109 | np.array(width).astype(np.int32).tofile(f) 110 | np.array(height).astype(np.int32).tofile(f) 111 | # arrange into matrix form 112 | tmp = np.zeros((height, width * nBands)) 113 | tmp[:, np.arange(width) * 2] = u 114 | tmp[:, np.arange(width) * 2 + 1] = v 115 | tmp.astype(np.float32).tofile(f) 116 | f.close() 117 | 118 | def read_gen(file_name, pil=False): 119 | ext = splitext(file_name)[-1] 120 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 121 | return Image.open(file_name) 122 | elif ext == '.bin' or ext == '.raw': 123 | return np.load(file_name) 124 | elif ext == '.flo': 125 | return readFlow(file_name).astype(np.float32) 126 | elif ext == '.pfm': 127 | flow = readPFM(file_name).astype(np.float32) 128 | if len(flow.shape) == 2: 129 | return flow 130 | else: 131 | return flow[:, :, :-1] 132 | return [] 133 | 134 | class InputPadder: 135 | 136 | def __init__(self, dims, base=8): 137 | self.ht, self.wd = dims[-2:] 138 | pad_ht = (((self.ht // base) + 1) * base - self.ht) % base 139 | pad_wd = (((self.wd // base) + 1) * base - self.wd) % base 140 | self._pad = [0, pad_wd, 0, pad_ht] 141 | 142 | def pad(self, *inputs): 143 | outputs = [] 144 | 145 | for x in inputs: 146 | 147 | bhw_mode = 0 148 | 149 | if len(x.shape) == 3: 150 | bhw_mode = 1 151 | x = x.unsqueeze(1) 152 | 153 | x = F.pad(x, self._pad, mode='replicate') 154 | if bhw_mode: 155 | x = x.squeeze(1) 156 | outputs.append(x) 157 | 158 | return outputs 159 | 160 | def unpad(self, *inputs): 161 | 162 | ht, wd = inputs[0].shape[-2:] 163 | c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] 164 | 165 | return [x[..., 
c[0]:c[1], c[2]:c[3]] for x in inputs]
166 | 
167 | 
--------------------------------------------------------------------------------
/run_demo.py:
--------------------------------------------------------------------------------
1 | # @File: run_demo.py
2 | # @Project: SplatFlow
3 | # @Author : wangbo
4 | # @Time : 2024.07.03
5 | 
6 | import os
7 | 
8 | import torch
9 | from model.model_splatflow import SplatFlow
10 | from model.util.util import *
11 | 
12 | print('SplatFlow demo start...')
13 | 
14 | model = SplatFlow()
15 | model.load_state_dict(torch.load('exp/0-pretrain/splatflow_kitti_50k.pth'), strict=True)  # weight file named as in the README
16 | model.eval().cuda()
17 | print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
18 | 
19 | img_names = ['data/demo/image/000009_%02d.png'%i for i in [9, 10, 11]]
20 | 
21 | imgs = [torch.from_numpy(np.array(readImageKITTI(img_names[i])).astype(np.uint8)).permute(2, 0, 1).float()[None].cuda() for i in range(3)]
22 | img0, img1, img2 = imgs
23 | padder = InputPadder(img1.shape)
24 | img0, img1, img2 = padder.pad(img0, img1, img2)
25 | 
26 | with torch.no_grad():
27 |     outputs = model.infer(
28 |         model,
29 |         input_list=[img0, img1, img2],
30 |         iters=24)
31 | 
32 | pr_flow2d = padder.unpad(outputs[0])[0][0].permute(1, 2, 0).cpu().numpy()
33 | 
34 | output_path = 'exp/demo'
35 | if not os.path.exists(output_path):
36 |     os.makedirs(output_path)
37 | writeFlowKITTI(f'{output_path}/flow_000009_10.png', pr_flow2d)
38 | 
39 | print('Success!!!')
40 | 
--------------------------------------------------------------------------------
/run_test.py:
--------------------------------------------------------------------------------
1 | # @File: run_test.py
2 | # @Project: SplatFlow
3 | # @Author : wangbo
4 | # @Time : 2024.07.03
5 | 
6 | import argparse
7 | import torch.nn.functional as F
8 | 
9 | from model.model_splatflow import SplatFlow
10 | from data.dataset import *
11 | 
12 | def get_stamp(second):
13 |     m, s = divmod(second, 60)
14 |     h, m = divmod(m, 60)
15 |     d, h = divmod(h, 24)
16 |     return '{}/{}/{}'.format(int(d), int(h), int(m))
17 | 
18 | class InputPadder:
19 | 
20 |     def __init__(self, dims, mode='sintel', base=8):
21 |         self.ht, self.wd = dims[-2:]
22 |         pad_ht = (((self.ht // base) + 1) * base - self.ht) % base
23 |         pad_wd = (((self.wd // base) + 1) * base - self.wd) % base
24 |         if mode == 'sintel':
25 |             self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]  # center the padding
26 |         else:
27 |             self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]  # KITTI: keep the top row fixed, pad the bottom
28 | 
29 |     def pad(self, *inputs):
30 |         return [F.pad(x, self._pad, mode='replicate') for x in inputs]
31 | 
32 |     def unpad(self, x):
33 |         ht, wd = x.shape[-2:]
34 |         c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
35 |         return x[..., c[0]:c[1], c[2]:c[3]]
36 | 
37 | @torch.no_grad()
38 | def validate_things(model):
39 | 
40 |     print('Start testing splatflow on Things...')
41 | 
42 |     for ptype in ['clean', 'final']:
43 | 
44 |         epe_list = []
45 | 
46 |         val_dataset = Things(split='val', ptype=ptype)
47 |         data_num = len(val_dataset)
48 |         print(f'Dataset length {data_num}')
49 | 
50 |         for val_id in range(data_num):
51 | 
52 |             img1, img2, img3, gt1, valid1, gt2, valid2 = val_dataset[val_id]
53 | 
54 |             img1 = img1[None].cuda()
55 |             img2 = img2[None].cuda()
56 |             img3 = img3[None].cuda()
57 |             padder = InputPadder(img2.shape)
58 |             img1, img2, img3 = padder.pad(img1, img2, img3)
59 | 
60 |             flow_prs_23 = model.infer(
61 |                 model,
62 |                 input_list=[img1, img2, img3],
63 |                 iters=24)
64 |             pr2 =
padder.unpad(flow_prs_23[-1][0]).cpu() 65 | 66 | epe = torch.sum((pr2 - gt2) ** 2, dim=0).sqrt() 67 | epe = epe.view(-1) 68 | val2 = valid2.view(-1) >= 0.5 69 | epe_list.append(epe[val2].numpy()) 70 | 71 | if val_id % 50 == 0: 72 | print(f'{ptype}: {val_id}/{data_num}') 73 | 74 | epe_all = np.concatenate(epe_list) 75 | epe = np.mean(epe_all) 76 | 77 | print("Things (%s) EPE: %f" % (ptype, epe)) 78 | 79 | @torch.no_grad() 80 | def validate_kitti(model): 81 | 82 | out_list, epe_list = [], [] 83 | val_dataset = KITTI(split='train') 84 | data_num = len(val_dataset) 85 | print(f'Dataset length {data_num}') 86 | 87 | for val_id in range(data_num): 88 | img1, img2, img3, gt1, valid1, gt2, valid2 = val_dataset[val_id] 89 | img1 = img1[None].cuda() 90 | img2 = img2[None].cuda() 91 | img3 = img3[None].cuda() 92 | padder = InputPadder(img2.shape, mode='kitti') 93 | img1, img2, img3 = padder.pad(img1, img2, img3) 94 | 95 | flow_prs_23 = model.infer( 96 | model, 97 | input_list=[img1, img2, img3], 98 | iters=24) 99 | pr2 = padder.unpad(flow_prs_23[-1][0]).cpu() 100 | 101 | epe = torch.sum((pr2 - gt2) ** 2, dim=0).sqrt() 102 | mag = torch.sum(gt2 ** 2, dim=0).sqrt() 103 | epe = epe.view(-1) 104 | mag = mag.view(-1) 105 | val2 = valid2.view(-1) >= 0.5 106 | 107 | out = ((epe > 3.0) & ((epe / mag) > 0.05)).float() 108 | epe_list.append(epe[val2].mean().item()) 109 | out_list.append(out[val2].cpu().numpy()) 110 | 111 | if val_id % 20 == 0: 112 | print(f'kitti: {val_id}/{data_num}') 113 | 114 | epe_list = np.array(epe_list) 115 | out_list = np.concatenate(out_list) 116 | 117 | epe = np.mean(epe_list) 118 | fl = 100 * np.mean(out_list) 119 | 120 | print("Validation KITTI: %f, %f" % (epe, fl)) 121 | 122 | 123 | if __name__ == '__main__': 124 | 125 | parser = argparse.ArgumentParser() 126 | 127 | parser.add_argument('--dataset', default='things') 128 | parser.add_argument('--pre_name_path', default='exp/train_things_full/model.pth') 129 | 130 | args = parser.parse_args() 131 | 132 | model = SplatFlow() 133 | pre_replace_list = [ 134 | ['update_block', 'update'], ['module.', ''], 135 | ['gru_tf', 'gru_sp'], ['flow_head_tf', 'flow_head_sp'], ['mask_tf', 'mask_sp'] 136 | ] 137 | checkpoint = torch.load(args.pre_name_path) 138 | for l in pre_replace_list: 139 | checkpoint = {k.replace(l[0], l[1]): v for k, v in checkpoint.items()} 140 | model.load_state_dict(checkpoint, strict=False) 141 | print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") 142 | model.eval().cuda() 143 | 144 | if args.dataset == 'things': 145 | validate_things(model) 146 | 147 | if args.dataset == 'kitti': 148 | validate_kitti(model) 149 | 150 | 151 | -------------------------------------------------------------------------------- /run_train.py: -------------------------------------------------------------------------------- 1 | # @File: run_train.py 2 | # @Project: SplatFlow 3 | # @Author : wangbo 4 | # @Time : 2024.07.03 5 | 6 | import argparse 7 | import torch.nn as nn 8 | import torch.distributed as dist 9 | from torch.nn.parallel import DistributedDataParallel as DDP 10 | import time 11 | 12 | from model.model_splatflow import SplatFlow 13 | from data.dataset import * 14 | 15 | def get_stamp(second): 16 | m, s = divmod(second, 60) 17 | h, m = divmod(m, 60) 18 | d, h = divmod(h, 24) 19 | return '{}/{}/{}'.format(int(d), int(h), int(m)) 20 | 21 | if __name__ == '__main__': 22 | 23 | parser = argparse.ArgumentParser() 24 | 25 | parser.add_argument('--exp_name', default='train_things_part') 26 | 
parser.add_argument('--stage', default='things') 27 | parser.add_argument('--pre_name_path', default='exp/0-pretrain/gma-things.pth') 28 | parser.add_argument('--part_params_train', action='store_true') 29 | parser.add_argument('--image_size', type=int, nargs='+', default=[400, 720]) 30 | parser.add_argument('--batch_size', type=int, default=8) 31 | parser.add_argument('--lr', type=float, default=.000125) 32 | parser.add_argument('--wdecay', type=float, default=.0001) 33 | parser.add_argument('--step_max', type=int, default=100000) 34 | parser.add_argument('--log_train', type=int, default=100) 35 | parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify') 36 | 37 | args = parser.parse_args() 38 | 39 | args.rank = rank = args.local_rank 40 | args.is_master = is_master = True if rank in [0, -1] else False 41 | args.is_ddp = is_ddp = True if rank != -1 else False 42 | 43 | if is_ddp: 44 | torch.cuda.set_device(rank) 45 | device = torch.device('cuda', rank) 46 | dist.init_process_group(backend='nccl', init_method='env://') 47 | args.world_size = world_size = dist.get_world_size() 48 | else: 49 | exit() 50 | 51 | model = SplatFlow(args).to(device) 52 | pre_replace_list = [ 53 | ['update_block', 'update'], ['module.', ''], 54 | ['gru_tf', 'gru_sp'], ['flow_head_tf', 'flow_head_sp'], ['mask_tf', 'mask_sp'] 55 | ] 56 | checkpoint = torch.load(args.pre_name_path, map_location=device) 57 | for l in pre_replace_list: 58 | checkpoint = {k.replace(l[0], l[1]): v for k, v in checkpoint.items()} 59 | model.load_state_dict(checkpoint, strict=False) 60 | 61 | if is_master: print(f"Parameter Count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") 62 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device).train() 63 | model = DDP(model, device_ids=[rank], output_device=(rank)) 64 | 65 | train_loader = fetch_dataloader(args) 66 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8) 67 | optimizer.zero_grad() 68 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer=optimizer, max_lr=args.lr, total_steps=args.step_max + 100, 69 | pct_start=0.05, cycle_momentum=False, anneal_strategy='linear') 70 | 71 | t_1 = t_0 = time.time() 72 | step_i = 1 73 | epoch = 1 74 | 75 | while epoch: 76 | if is_ddp: 77 | train_loader.sampler.set_epoch(epoch) 78 | epoch = epoch + 1 79 | 80 | for step_data in train_loader: 81 | 82 | loss, metric_list = model.module.training_infer(model, step_data, device) 83 | 84 | loss.backward() 85 | 86 | nn.utils.clip_grad_norm_(model.parameters(), 1.0) 87 | optimizer.step() 88 | scheduler.step() 89 | optimizer.zero_grad() 90 | 91 | if step_i % args.log_train == 0: 92 | if args.is_master: 93 | t_now = time.time() 94 | t_have = t_now - t_0 95 | t_period = t_now - t_1 96 | t_1 = t_now 97 | t_left = t_period * (args.step_max - step_i) / args.log_train 98 | time_stamp = 'time: [' + get_stamp(t_have) + ',' + get_stamp(t_left) + ']' 99 | metric_log_list = [(mr[0] + ': %.3f') % mr[1] for mr in metric_list] 100 | metric_log = ' '.join(metric_log_list) 101 | print(f'{args.exp_name}\ttrain [{step_i}/{args.step_max}]\tloss: %.3f\t%s\tlr: %.6f\t' % (loss.item(), metric_log, scheduler.get_last_lr()[-1]) + time_stamp) 102 | 103 | dist.barrier() 104 | 105 | step_i += 1 106 | if step_i > args.step_max: 107 | if args.is_master: 108 | model_path = f'exp/{args.exp_name}' 109 | if not os.path.exists(model_path): 110 | os.mkdir(model_path) 111 | model_name_path = f'{model_path}/model.pth' 112 | 
torch.save(model.state_dict(), model_name_path) 113 | print('training finished') 114 | 115 | dist.barrier() 116 | 117 | exit() 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | -------------------------------------------------------------------------------- /script/demo.sh: -------------------------------------------------------------------------------- 1 | # @File: demo.sh 2 | # @Project: SplatFlow 3 | # @Author : wangbo 4 | # @Time : 2024.07.03 5 | 6 | python run_demo.py 7 | 8 | -------------------------------------------------------------------------------- /script/test_kitti.sh: -------------------------------------------------------------------------------- 1 | # @File: test_kitti.sh 2 | # @Project: SplatFlow 3 | # @Author : wangbo 4 | # @Time : 2024.07.03 5 | 6 | python run_test.py \ 7 | --dataset kitti \ 8 | --pre_name_path exp/train_things_full/model.pth \ 9 | -------------------------------------------------------------------------------- /script/test_things.sh: -------------------------------------------------------------------------------- 1 | # @File: test_things.sh 2 | # @Project: SplatFlow 3 | # @Author : wangbo 4 | # @Time : 2024.07.03 5 | 6 | python run_test.py \ 7 | --dataset things \ 8 | --pre_name_path exp/train_things_full/model.pth \ -------------------------------------------------------------------------------- /script/train_kitti.sh: -------------------------------------------------------------------------------- 1 | # @File: train_kitti.sh 2 | # @Project: SplatFlow 3 | # @Author : wangbo 4 | # @Time : 2024.07.03 5 | 6 | export CUDA_VISIBLE_DEVICES="0, 1, 2" 7 | python -m torch.distributed.launch --nproc_per_node 3 run_train.py \ 8 | --exp_name train_kitti \ 9 | --stage kitti \ 10 | --pre_name_path exp/train_sintel/model.pth \ 11 | --image_size 368 768 \ 12 | --batch_size 6 \ 13 | --lr 0.000125 \ 14 | --wdecay 0.00001 \ 15 | --step_max 50000 \ 16 | --log_train 100 \ -------------------------------------------------------------------------------- /script/train_sintel.sh: -------------------------------------------------------------------------------- 1 | # @File: train_sintel.sh 2 | # @Project: SplatFlow 3 | # @Author : wangbo 4 | # @Time : 2024.07.03 5 | 6 | export CUDA_VISIBLE_DEVICES="0, 1, 2, 3, 4, 5" 7 | python -m torch.distributed.launch --nproc_per_node 6 run_train.py \ 8 | --exp_name train_sintel \ 9 | --stage sintel \ 10 | --pre_name_path exp/train_things_full/model.pth \ 11 | --image_size 368 768 \ 12 | --batch_size 6 \ 13 | --lr 0.000125 \ 14 | --wdecay 0.00001 \ 15 | --step_max 120000 \ 16 | --log_train 100 \ -------------------------------------------------------------------------------- /script/train_things.sh: -------------------------------------------------------------------------------- 1 | # @File: train_things.sh 2 | # @Project: SplatFlow 3 | # @Author : wangbo 4 | # @Time : 2024.07.03 5 | 6 | export CUDA_VISIBLE_DEVICES="0, 1, 2, 3, 4, 5, 6, 7" 7 | python -m torch.distributed.launch --nproc_per_node 8 run_train.py \ 8 | --exp_name train_things_part \ 9 | --part_params_train \ 10 | --stage things \ 11 | --pre_name_path exp/0-pretrain/gma-things.pth \ 12 | --image_size 400 720 \ 13 | --batch_size 8 \ 14 | --lr 0.000125 \ 15 | --wdecay 0.0001 \ 16 | --step_max 100000 \ 17 | --log_train 100 \ 18 | 19 | python -m torch.distributed.launch --nproc_per_node 8 run_train.py \ 20 | --exp_name train_things_full \ 21 | --stage things \ 22 | --pre_name_path exp/train_things_part/model.pth \ 23 | --image_size 400 720 \ 24 | --batch_size 8 \ 25 | --lr 
0.000125 \ 26 | --wdecay 0.00001 \ 27 | --step_max 100000 \ 28 | --log_train 100 \ 29 | --------------------------------------------------------------------------------
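
The sketches below are illustrative usage notes for the utilities above, not files from the repository: every scratch filename and toy value in them is invented for demonstration, and all of them assume the repository root is on `PYTHONPATH` with the pinned environment from the README installed.

`writeFlowKITTI`/`readFlowKITTI` in `model/util/util.py` encode flow as a 16-bit PNG with `uv_png = 64 * uv + 2**15` plus an all-ones validity channel, so any flow in roughly [-512, 512] px survives a round trip to within 1/64 px. A minimal sketch, with `flow_check.png` as an assumed scratch output name:

```
# Round-trip sketch for the KITTI flow encoding in model/util/util.py.
# 'flow_check.png' is an illustrative output name, not a repo file.
import numpy as np
from model.util.util import readFlowKITTI, writeFlowKITTI

flow = np.random.uniform(-100, 100, size=(375, 1242, 2)).astype(np.float32)
writeFlowKITTI('flow_check.png', flow)

flow_back, valid = readFlowKITTI('flow_check.png')
print(np.abs(flow_back - flow).max())  # < 1/64 px (uint16 truncation error)
print(valid.min())                     # 1.0: writeFlowKITTI marks all pixels valid
```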
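Two different `InputPadder` classes exist: the one in `model/util/util.py` replicate-pads only on the right and bottom to the next multiple of `base`, while the one in `run_test.py` centers the padding for Sintel and pads the bottom for KITTI. A sketch of the util version's pad/unpad contract on an assumed KITTI-sized dummy tensor:

```
# Pad/unpad contract of InputPadder in model/util/util.py (toy tensor).
import torch
from model.util.util import InputPadder

img = torch.rand(1, 3, 375, 1242)      # KITTI-sized dummy frame
padder = InputPadder(img.shape, base=8)
(img_p,) = padder.pad(img)             # replicate-pad right/bottom
print(img_p.shape)                     # torch.Size([1, 3, 376, 1248])
(img_u,) = padder.unpad(img_p)         # crop back to the original size
assert img_u.shape == img.shape
```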
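`FlowAugmentor` in `model/util/augmentor.py` takes a three-frame clip with the two forward flows between the frames and returns `crop_size` crops; its `__call__` chains the photometric and spatial transforms (the eraser transform is commented out). A smoke-test sketch on random toy arrays, which must be at least `crop_size` in both dimensions:

```
# Smoke test for the dense three-frame augmentor (random toy inputs).
import numpy as np
from model.util.augmentor import FlowAugmentor

aug = FlowAugmentor(crop_size=(368, 768))
H, W = 400, 900                         # arbitrary, but >= crop_size in both dims
imgs = [np.random.randint(0, 255, (H, W, 3), dtype=np.uint8) for _ in range(3)]
flows = [np.random.uniform(-5, 5, (H, W, 2)).astype(np.float32) for _ in range(2)]

img1, img2, img3, flow1, flow2 = aug(*imgs, *flows)
print(img2.shape, flow2.shape)          # (368, 768, 3) (368, 768, 2)
```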
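Finally, the KITTI Fl metric in `run_test.py`'s `validate_kitti` marks a pixel as an outlier only when its endpoint error exceeds both 3 px and 5% of the ground-truth flow magnitude. A toy illustration with three made-up pixels along the first dimension (the repository applies the same formula channels-first per image):

```
# KITTI Fl outlier rule from run_test.py, on three toy pixels.
import torch

gt = torch.tensor([[10.0, 0.0], [0.5, 0.5], [40.0, 30.0]])
pr = torch.tensor([[14.0, 0.0], [0.6, 0.5], [41.0, 30.0]])

epe = torch.sum((pr - gt) ** 2, dim=1).sqrt()       # endpoint error per pixel
mag = torch.sum(gt ** 2, dim=1).sqrt()              # ground-truth magnitude
out = ((epe > 3.0) & ((epe / mag) > 0.05)).float()  # 1.0 marks an outlier
print(out)  # tensor([1., 0., 0.]): only the first pixel is an outlier
```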