├── LICENSE
├── README.md
├── augs.py
├── checkpoints
│   └── .gitignore
├── configs
│   ├── config.yaml
│   ├── data
│   │   ├── KITTI.yaml
│   │   └── NYU.yaml
│   ├── hydra.yaml
│   ├── loss
│   │   └── MSMSE.yaml
│   ├── metric
│   │   ├── MetricALL.yaml
│   │   └── RMSE.yaml
│   ├── net
│   │   └── PMP.yaml
│   ├── optim
│   │   └── AdamW.yaml
│   └── sched
│       └── lr
│           └── NoiseOneCycleCosMo.yaml
├── criteria.py
├── datas
│   └── .gitignore
├── datasets
│   ├── __init__.py
│   ├── kitti.py
│   └── nyu.py
├── demo.py
├── demo.sh
├── environment.yml
├── exts
│   ├── bp_cuda.cpp
│   ├── bp_cuda.h
│   ├── bp_cuda_kernel.cu
│   └── setup.py
├── models
│   ├── BPNet.py
│   ├── __init__.py
│   └── utils.py
├── optimizers.py
├── rpnloss.py
├── run.sh
├── schedulers.py
├── test.py
├── train_distill.py
├── utils.py
└── utils_infer.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Jie Tang
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DMD³C: Distilling Monocular Foundation Model for Fine-grained Depth Completion
2 | **Official Code for the CVPR 2025 Paper**
3 | **"[CVPR 2025] Distilling Monocular Foundation Model for Fine-grained Depth Completion"**
4 | 
5 | [📄 Paper on arXiv](https://arxiv.org/abs/2503.16970)
6 | 
7 | ---
8 | 
9 | ## 🆕 Update Log
10 | 
11 | - **[2025.04.23]** We have released the **2nd-stage training code**! 🎉
12 | - **[2025.04.11]** We have released the **inference code**! 🎉
13 | 
14 | ## ✅ To Do
15 | 
16 | - [ ] 📦 Easy-to-use **data generation pipeline**
17 | ---
18 | 
19 | 
20 | DMD3C Results
21 | 
22 | 
23 | ---
24 | 
25 | ## 🔍 Overview
26 | 
27 | DMD³C introduces a novel framework for **fine-grained depth completion** by distilling knowledge from **monocular foundation models**. This approach significantly improves depth estimation accuracy from sparse input data, especially in regions without ground-truth supervision.
28 | 
29 | ---
30 | ![image](https://github.com/user-attachments/assets/f24eef8e-5dc2-483a-bb70-67671ff5e4e9)
31 | 
32 | 
33 | ---
34 | 
35 | 
36 | 
37 | ## 🚀 Getting Started
38 | 
39 | ### 1. Clone the Base Repository
40 | 
41 | ```bash
42 | git clone https://github.com/kakaxi314/BP-Net.git
43 | ```
44 | 
45 | ### 2. Copy This Repo into the BP-Net Directory
46 | 
47 | ```bash
48 | cp DMD3C/* BP-Net/
49 | cd BP-Net/DMD3C/
50 | ```
51 | 
52 | ### 3. Prepare KITTI Raw Data
53 | 
54 | Download any sequence from the **KITTI Raw dataset**, which includes:
55 | 
56 | - Camera intrinsics
57 | - Velodyne point clouds
58 | - Image sequences
59 | 
60 | Make sure the directory structure follows the **standard KITTI format**.
61 | 
62 | ### 4. Modify the Sequence in `demo.py` for Inference
63 | 
64 | Open `demo.py` and go to **line 338**, where you can set the input sequence path according to your downloaded KITTI data.
65 | 
66 | ```python
67 | # demo.py (Line 338)
68 | sequence = "/path/to/your/kitti/sequence"
69 | ```
70 | 
71 | Download the pre-trained weights:
72 | 
73 | ```
74 | wget https://github.com/Sharpiless/DMD3C/releases/download/pretrain-checkpoints/dmd3c_distillation_depth_anything_v2.pth
75 | mv dmd3c_distillation_depth_anything_v2.pth checkpoints
76 | ```
77 | 
78 | Run inference:
79 | ```bash
80 | bash demo.sh
81 | ```
82 | 
83 | You will get results like this:
84 | 
85 | ![supp-video 00_00_00-00_00_30](https://github.com/user-attachments/assets/a1412bca-c368-4d19-a081-79eeabaa2901)
86 | 
87 | ### 5. Train on KITTI
88 | 
89 | Run monocular depth estimation on all KITTI raw images to generate disparity maps. Expected data structure:
90 | ```
91 | ├── datas/kitti/raw/
92 | │   ├── 2011_09_26
93 | │   │   ├── 2011_09_26_drive_0001_sync
94 | │   │   │   ├── image_02
95 | │   │   │   │   ├── data/*.png
96 | │   │   │   │   ├── disp/*.png
97 | │   │   │   ├── image_03
98 | │   │   ├── 2011_09_26_drive_0002_sync.......
99 | ```
100 | 
101 | The disparity maps are stored as grayscale images (see the generation sketch below).
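A ready-made generation script is not included in this repository yet (see the To-Do list above), so the following is a minimal sketch of how the `disp/*.png` maps could be produced. It assumes the relative-depth teacher is Depth Anything V2 served through the Hugging Face `transformers` depth-estimation pipeline; the model id, glob pattern, and device choice below are illustrative assumptions rather than part of this codebase.

```python
# Sketch only: writes a disp/ folder next to each image_02/data and image_03/data folder.
import glob
import os

from PIL import Image
from transformers import pipeline  # assumes a transformers version that ships the depth-estimation pipeline

# Model id is an assumption; any monocular model that outputs relative (inverse) depth should work.
pipe = pipeline("depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf", device=0)

for img_path in sorted(glob.glob("datas/kitti/raw/*/*_sync/image_0[23]/data/*.png")):
    cam_dir = os.path.dirname(os.path.dirname(img_path))  # .../image_02 or .../image_03
    disp_dir = os.path.join(cam_dir, "disp")
    os.makedirs(disp_dir, exist_ok=True)
    out_path = os.path.join(disp_dir, os.path.basename(img_path))
    if os.path.exists(out_path):
        continue  # allow resuming an interrupted run
    result = pipe(Image.open(img_path))
    # The pipeline returns an 8-bit relative-depth image resized to the input resolution.
    result["depth"].convert("L").save(out_path)
```

Saved this way, each map is a single-channel 8-bit PNG at the original image resolution, which is what `datasets/kitti.py` loads via `pull_DISP` and what `augs.Norm` later rescales by dividing by 255.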
102 | 
103 | Download the pre-trained checkpoint:
104 | ```
105 | wget https://github.com/Sharpiless/DMD3C/releases/download/pretrain-checkpoints/pretrained_mixed_singleview_256.pth
106 | mv pretrained_mixed_singleview_256.pth checkpoints
107 | ```
108 | 
109 | Zero-shot performance on the KITTI validation set:
110 | 
111 | | Training Data      | RMSE   | MAE    | iRMSE  | REL    |
112 | |--------------------|--------|--------|--------|--------|
113 | | Single-view Images | 1.4251 | 0.3722 | 0.0056 | 0.0235 |
114 | 
115 | 
116 | Run metric fine-tuning on the KITTI dataset:
117 | ```
118 | torchrun --nproc_per_node=4 --master_port 4321 train_distill.py \
119 |     gpus=[0,1,2,3] num_workers=4 name=DMD3D_BP_KITTI \
120 |     ++chpt=checkpoints/pretrained_mixed_singleview_256.pth \
121 |     net=PMP data=KITTI \
122 |     lr=5e-4 train_batch_size=2 test_batch_size=1 \
123 |     sched/lr=NoiseOneCycleCosMo sched.lr.policy.max_momentum=0.90 \
124 |     nepoch=30 test_epoch=25 ++net.sbn=true
125 | ```
126 | 
--------------------------------------------------------------------------------
/augs.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | # @Filename: augs.py
4 | # @Project: BP-Net
5 | # @Author: jie
6 | # @Time: 2021/3/14 8:27 PM
7 | 
8 | 
9 | import numpy as np
10 | 
11 | __all__ = [
12 |     'Compose',
13 |     'Norm',
14 |     'Jitter',
15 |     'Flip',
16 | ]
17 | 
18 | 
19 | class Compose(object):
20 |     """
21 |     Sequential operations on input images, (i.e. rgb, lidar and depth).
22 |     """
23 | 
24 |     def __init__(self, transforms):
25 |         self.transforms = transforms
26 | 
27 |     def __call__(self, rgb, disp, lidar, depth, K_cam):
28 |         for t in self.transforms:
29 |             rgb, disp, lidar, depth, K_cam = t(rgb, disp, lidar, depth, K_cam)
30 |         return rgb, disp, lidar, depth, K_cam
31 | 
32 | 
33 | class Norm(object):
34 |     """
35 |     normalize rgb image.
36 |     """
37 | 
38 |     def __init__(self, mean, std):
39 |         self.mean = np.array(mean)
40 |         self.std = np.array(std)
41 | 
42 |     def __call__(self, rgb, disp, lidar, depth, K_cam):
43 |         rgb = (rgb - self.mean) / self.std
44 |         disp = disp / 255
45 |         return rgb, disp, lidar, depth, K_cam
46 | 
47 | class Jitter(object):
48 |     """
49 |     borrow from https://github.com/kujason/avod/blob/master/avod/datasets/kitti/kitti_aug.py
50 |     """
51 | 
52 |     def __call__(self, rgb, disp, lidar, depth, K_cam):
53 |         pca = compute_pca(rgb)
54 |         rgb = add_pca_jitter(rgb, pca)
55 |         return rgb, disp, lidar, depth, K_cam
56 | 
57 | 
58 | 
59 | class Flip(object):
60 |     """
61 |     random horizontal flip of images.
62 | """ 63 | 64 | def __call__(self, rgb, disp, lidar, depth, CamK): 65 | width = rgb.shape[1] 66 | flip = bool(np.random.randint(2)) 67 | if flip: 68 | rgb = rgb[:, ::-1, :] 69 | lidar = lidar[:, ::-1, :] 70 | depth = depth[:, ::-1, :] 71 | CamK[0, 2] = (width - 1) - CamK[0, 2] 72 | return rgb, disp, lidar, depth, CamK 73 | 74 | 75 | def compute_pca(image): 76 | """ 77 | calculate PCA of image 78 | """ 79 | 80 | reshaped_data = image.reshape(-1, 3) 81 | reshaped_data = (reshaped_data / 255.0).astype(np.float32) 82 | covariance = np.cov(reshaped_data.T) 83 | e_vals, e_vecs = np.linalg.eigh(covariance) 84 | pca = np.sqrt(e_vals) * e_vecs 85 | return pca 86 | 87 | 88 | def add_pca_jitter(img_data, pca): 89 | """ 90 | add a multiple of principle components with Gaussian noise 91 | """ 92 | new_img_data = np.copy(img_data).astype(np.float32) / 255.0 93 | magnitude = np.random.randn(3) * 0.1 94 | noise = (pca * magnitude).sum(axis=1) 95 | 96 | new_img_data = new_img_data + noise 97 | np.clip(new_img_data, 0.0, 1.0, out=new_img_data) 98 | new_img_data = (new_img_data * 255).astype(np.uint8) 99 | 100 | return new_img_data -------------------------------------------------------------------------------- /checkpoints/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - hydra 3 | - optim: AdamW 4 | - sched/lr: NoiseOneCycleCosMo 5 | - data: KITTI 6 | - net: PMP 7 | - loss: MSMSE 8 | - metric: RMSE 9 | - _self_ 10 | # 11 | 12 | nepoch: 30 13 | test_epoch: 25 14 | test_iter: 1000 15 | lr: 0.001 16 | gpus: 17 | - 0 18 | train_batch_size: 2 19 | test_batch_size: 2 20 | num_workers: 4 21 | manual_seed: 1 22 | vis_iter: 100 23 | start_epoch: 0 24 | gpu_id: ??? 25 | name: ??? 26 | device: 27 | _target_: torch.device 28 | device: cuda:${gpu_id} -------------------------------------------------------------------------------- /configs/data/KITTI.yaml: -------------------------------------------------------------------------------- 1 | mean: 2 | - 90.9950 3 | - 96.2278 4 | - 94.3213 5 | std: 6 | - 79.2382 7 | - 80.5267 8 | - 82.1483 9 | train_sample_size: 85898 10 | test_sample_size: 1000 11 | height: 256 12 | width: 1216 13 | 14 | trainset: 15 | _target_: datasets.KITTI 16 | mode: 'train' 17 | RandCrop: true 18 | path: ${data.path} 19 | height: ${data.height} 20 | width: ${data.width} 21 | mean: ${data.mean} 22 | std: ${data.std} 23 | 24 | testset: 25 | _target_: datasets.KITTI 26 | mode: 'selval' 27 | RandCrop: false 28 | path: ${data.path} 29 | height: ${data.height} 30 | width: ${data.width} 31 | mean: ${data.mean} 32 | std: ${data.std} 33 | 34 | path: datas/kitti 35 | mul_factor: 1.0 -------------------------------------------------------------------------------- /configs/data/NYU.yaml: -------------------------------------------------------------------------------- 1 | mean: 2 | - 117. 3 | - 97. 4 | - 91. 5 | std: 6 | - 70. 7 | - 71. 8 | - 74. 
9 | train_sample_size: 47584 10 | test_sample_size: 654 11 | height: 256 12 | width: 320 13 | 14 | 15 | trainset: 16 | _target_: datasets.NYU 17 | mode: 'train' 18 | path: ${data.path} 19 | num_sample: ${data.npoints} 20 | mul_factor: ${data.mul_factor} 21 | num_mask: ${data.num_mask} 22 | scale_kcam: true 23 | 24 | 25 | testset: 26 | _target_: datasets.NYU 27 | mode: 'val' 28 | path: ${data.path} 29 | num_sample: ${data.npoints} 30 | mul_factor: ${data.mul_factor} 31 | num_mask: ${data.num_mask} 32 | scale_kcam: false 33 | 34 | 35 | path: datas/nyu 36 | npoints: 500 37 | mul_factor: 10.0 38 | num_mask: 1 -------------------------------------------------------------------------------- /configs/hydra.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: outputs/${name}/${now:%Y-%m-%d_%H-%M-%S} 4 | output_subdir: . 5 | job: 6 | chdir: false 7 | # job_logging: 8 | # handlers: 9 | # file: 10 | # filename: ${name}.log -------------------------------------------------------------------------------- /configs/loss/MSMSE.yaml: -------------------------------------------------------------------------------- 1 | _target_: criteria.MSMSE -------------------------------------------------------------------------------- /configs/metric/MetricALL.yaml: -------------------------------------------------------------------------------- 1 | _target_: criteria.MetricALL 2 | mul_factor: ${data.mul_factor} -------------------------------------------------------------------------------- /configs/metric/RMSE.yaml: -------------------------------------------------------------------------------- 1 | _target_: criteria.RMSE 2 | mul_factor: ${data.mul_factor} -------------------------------------------------------------------------------- /configs/net/PMP.yaml: -------------------------------------------------------------------------------- 1 | # Propagation Field Baseline Relu6 Clip EMA 2 | 3 | name: PMP 4 | 5 | model: 6 | _target_: models.Pre_MF_Post 7 | 8 | 9 | ema: 10 | _target_: models.EMA 11 | decay: 0.9999 12 | 13 | 14 | clip: 15 | _target_: utils.clip_grad_norm_ 16 | max_norm: 0.1 -------------------------------------------------------------------------------- /configs/optim/AdamW.yaml: -------------------------------------------------------------------------------- 1 | _target_: optimizers.AdamW 2 | lr: ${lr} 3 | weight_decay: 0.05 4 | betas: 5 | - 0.9 6 | - 0.999 -------------------------------------------------------------------------------- /configs/sched/lr/NoiseOneCycleCosMo.yaml: -------------------------------------------------------------------------------- 1 | #Noise One Cycle Cos With cycle_momentum 2 | policy: 3 | _target_: schedulers.NoiseLR 4 | lr_sched: OneCycleLR 5 | anneal_strategy: cos 6 | epochs: ${nepoch} 7 | div_factor: 40.0 8 | final_div_factor: 0.1 9 | last_epoch: -1 10 | max_lr: ${lr} 11 | cycle_momentum: true 12 | base_momentum: 0.85 13 | max_momentum: 0.95 14 | pct_start: 0.1 15 | noise_pct: 0.1 16 | steps_per_epoch: ${sched.lr.steps_per_epoch} 17 | 18 | iter: true 19 | 20 | steps_per_epoch: 21 | _target_: utils.FloorDiv 22 | _args_: 23 | - _target_: utils.CeilDiv 24 | _args_: 25 | - ${data.train_sample_size} 26 | - _target_: builtins.len 27 | _args_: 28 | - ${gpus} 29 | - ${train_batch_size} -------------------------------------------------------------------------------- /criteria.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # 
@Filename: criteria.py 4 | # @Project: BP-Net 5 | # @Author: jie 6 | # @Time: 2021/3/14 7:51 PM 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | __all__ = [ 12 | 'RMSE', 13 | 'MSMSE', 14 | 'MetricALL', 15 | ] 16 | 17 | 18 | class RMSE(nn.Module): 19 | 20 | def __init__(self, mul_factor=1.): 21 | super().__init__() 22 | self.mul_factor = mul_factor 23 | self.metric_name = [ 24 | 'RMSE', 25 | ] 26 | 27 | def forward(self, outputs, target): 28 | outputs = outputs / self.mul_factor 29 | target = target / self.mul_factor 30 | val_pixels = (target > 1e-3).float() 31 | err = (target * val_pixels - outputs * val_pixels) ** 2 32 | loss = torch.sum(err.view(err.size(0), 1, -1), -1, keepdim=True) 33 | cnt = torch.sum(val_pixels.view(val_pixels.size(0), 1, -1), -1, keepdim=True) 34 | return torch.sqrt(loss / cnt).mean(), 35 | 36 | 37 | class MSMSE(nn.Module): 38 | """ 39 | Multi-Scale MSE 40 | """ 41 | 42 | def __init__(self, deltas=(2 ** (-5 * 2), 2 ** (-4 * 2), 2 ** (-3 * 2), 2 ** (-2 * 2), 2 ** (-1 * 2), 1)): 43 | super().__init__() 44 | self.deltas = deltas 45 | 46 | def mse(self, est, gt): 47 | valid = (gt > 1e-3).float() 48 | loss = est * valid - gt * valid 49 | return (loss ** 2).mean() 50 | 51 | def forward(self, outputs, target): 52 | loss = [delta * self.mse(ests, target) for ests, delta in zip(outputs, self.deltas)] 53 | return loss 54 | 55 | 56 | class MetricALL(nn.Module): 57 | def __init__(self, mul_factor): 58 | super().__init__() 59 | self.t_valid = 0.0001 60 | self.mul_factor = mul_factor 61 | self.metric_name = [ 62 | 'RMSE', 'MAE', 'iRMSE', 'iMAE', 'REL', 'D^1', 'D^2', 'D^3', 'D102', 'D105', 'D110' 63 | ] 64 | 65 | def forward(self, pred, gt): 66 | with torch.no_grad(): 67 | pred = pred.detach() / self.mul_factor 68 | gt = gt.detach() / self.mul_factor 69 | pred_inv = 1.0 / (pred + 1e-8) 70 | gt_inv = 1.0 / (gt + 1e-8) 71 | 72 | # For numerical stability 73 | mask = gt > self.t_valid 74 | # num_valid = mask.sum() 75 | B = mask.size(0) 76 | num_valid = torch.sum(mask.view(B, -1), -1, keepdim=True) 77 | 78 | # pred = pred[mask] 79 | # gt = gt[mask] 80 | pred = pred * mask 81 | gt = gt * mask 82 | 83 | # pred_inv = pred_inv[mask] 84 | # gt_inv = gt_inv[mask] 85 | pred_inv = pred_inv * mask 86 | gt_inv = gt_inv * mask 87 | 88 | # pred_inv[pred <= self.t_valid] = 0.0 89 | # gt_inv[gt <= self.t_valid] = 0.0 90 | 91 | # RMSE / MAE 92 | diff = pred - gt 93 | diff_abs = torch.abs(diff) 94 | diff_sqr = torch.pow(diff, 2) 95 | 96 | rmse = torch.sum(diff_sqr.view(B, -1), -1, keepdim=True) / (num_valid + 1e-8) 97 | rmse = torch.sqrt(rmse) 98 | 99 | mae = torch.sum(diff_abs.view(B, -1), -1, keepdim=True) / (num_valid + 1e-8) 100 | 101 | # iRMSE / iMAE 102 | diff_inv = pred_inv - gt_inv 103 | diff_inv_abs = torch.abs(diff_inv) 104 | diff_inv_sqr = torch.pow(diff_inv, 2) 105 | 106 | irmse = torch.sum(diff_inv_sqr.view(B, -1), -1, keepdim=True) / (num_valid + 1e-8) 107 | irmse = torch.sqrt(irmse) 108 | 109 | imae = torch.sum(diff_inv_abs.view(B, -1), -1, keepdim=True) / (num_valid + 1e-8) 110 | 111 | # Rel 112 | rel = diff_abs / (gt + 1e-8) 113 | rel = torch.sum(rel.view(B, -1), -1, keepdim=True) / (num_valid + 1e-8) 114 | 115 | # delta 116 | r1 = gt / (pred + 1e-8) 117 | r2 = pred / (gt + 1e-8) 118 | ratio = torch.max(r1, r2) 119 | 120 | ratio = torch.max(ratio, 10000 * (1 - mask.float())) 121 | 122 | del_1 = (ratio < 1.25).type_as(ratio) 123 | del_2 = (ratio < 1.25 ** 2).type_as(ratio) 124 | del_3 = (ratio < 1.25 ** 3).type_as(ratio) 125 | del_102 = (ratio < 1.02).type_as(ratio) 126 | 
del_105 = (ratio < 1.05).type_as(ratio) 127 | del_110 = (ratio < 1.10).type_as(ratio) 128 | 129 | del_1 = torch.sum(del_1.view(B, -1), -1, keepdim=True) / (num_valid + 1e-8) 130 | del_2 = torch.sum(del_2.view(B, -1), -1, keepdim=True) / (num_valid + 1e-8) 131 | del_3 = torch.sum(del_3.view(B, -1), -1, keepdim=True) / (num_valid + 1e-8) 132 | del_102 = torch.sum(del_102.view(B, -1), -1, keepdim=True) / (num_valid + 1e-8) 133 | del_105 = torch.sum(del_105.view(B, -1), -1, keepdim=True) / (num_valid + 1e-8) 134 | del_110 = torch.sum(del_110.view(B, -1), -1, keepdim=True) / (num_valid + 1e-8) 135 | 136 | result = [rmse, mae, irmse, imae, rel, del_1, del_2, del_3, del_102, del_105, del_110] 137 | 138 | return result 139 | -------------------------------------------------------------------------------- /datas/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .nyu import * 2 | from .kitti import * -------------------------------------------------------------------------------- /datasets/kitti.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import numpy as np 4 | import glob 5 | from PIL import Image 6 | import torch 7 | import augs 8 | from torchvision.utils import save_image 9 | 10 | __all__ = [ 11 | 'KITTI', 12 | ] 13 | 14 | def read_calib_file(filepath): 15 | """Read in a calibration file and parse into a dictionary.""" 16 | data = {} 17 | 18 | with open(filepath, 'r') as f: 19 | for line in f.readlines(): 20 | key, value = line.split(':', 1) 21 | # The only non-float values in these files are dates, which 22 | # we don't care about anyway 23 | try: 24 | data[key] = np.array([float(x) for x in value.split()]) 25 | except ValueError: 26 | pass 27 | 28 | return data 29 | 30 | 31 | class KITTI(torch.utils.data.Dataset): 32 | """ 33 | kitti depth completion dataset: http://www.cvlibs.net/datasets/kitti/eval_depth.php?benchmark=depth_completion 34 | """ 35 | 36 | def __init__(self, path='datas/kitti', mode='train', height=256, width=1216, mean=(90.9950, 96.2278, 94.3213), 37 | std=(79.2382, 80.5267, 82.1483), RandCrop=False, tp_min=50, *args, **kwargs): 38 | self.base_dir = path 39 | self.height = height 40 | self.width = width 41 | self.mode = mode 42 | if mode == 'train': 43 | self.transform = augs.Compose([ 44 | augs.Jitter(), 45 | augs.Flip(), 46 | augs.Norm(mean=mean, std=std), 47 | ]) 48 | else: 49 | self.transform = augs.Compose([ 50 | augs.Norm(mean=mean, std=std), 51 | ]) 52 | self.RandCrop = RandCrop and mode == 'train' 53 | self.tp_min = tp_min 54 | if mode in ['train', 'val']: 55 | self.depth_path = os.path.join(self.base_dir, 'data_depth_annotated', mode) 56 | self.lidar_path = os.path.join(self.base_dir, 'data_depth_velodyne', mode) 57 | self.depths = list(sorted(glob.iglob(self.depth_path + "/**/*.png", recursive=True))) 58 | self.lidars = list(sorted(glob.iglob(self.lidar_path + "/**/*.png", recursive=True))) 59 | elif mode == 'selval': 60 | self.depth_path = os.path.join(self.base_dir, 'val_selection_cropped', 'groundtruth_depth') 61 | self.lidar_path = os.path.join(self.base_dir, 'val_selection_cropped', 'velodyne_raw') 62 | self.image_path = os.path.join(self.base_dir, 
'val_selection_cropped', 'image') 63 | self.depths = list(sorted(glob.iglob(self.depth_path + "/*.png", recursive=True))) 64 | self.lidars = list(sorted(glob.iglob(self.lidar_path + "/*.png", recursive=True))) 65 | self.images = list(sorted(glob.iglob(self.image_path + "/*.png", recursive=True))) 66 | elif mode == 'test': 67 | self.lidar_path = os.path.join(self.base_dir, 'test_depth_completion_anonymous', 'velodyne_raw') 68 | self.image_path = os.path.join(self.base_dir, 'test_depth_completion_anonymous', 'image') 69 | self.lidars = list(sorted(glob.iglob(self.lidar_path + "/*.png", recursive=True))) 70 | self.images = list(sorted(glob.iglob(self.image_path + "/*.png", recursive=True))) 71 | self.depths = self.lidars 72 | else: 73 | raise ValueError("Unknown mode: {}".format(mode)) 74 | assert (len(self.depths) == len(self.lidars)) 75 | self.names = [os.path.split(path)[-1] for path in self.depths] 76 | x = np.arange(width) 77 | y = np.arange(height) 78 | xx, yy = np.meshgrid(x, y) 79 | xy = np.stack((xx, yy), axis=-1) 80 | self.xy = xy 81 | 82 | def __len__(self): 83 | return len(self.depths) 84 | 85 | def __getitem__(self, index): 86 | return self.get_item(index) 87 | 88 | def get_item(self, index): 89 | depth = self.pull_DEPTH(self.depths[index]) 90 | depth = np.expand_dims(depth, axis=2) 91 | lidar = self.pull_DEPTH(self.lidars[index]) 92 | lidar = np.expand_dims(lidar, axis=2) 93 | K_cam = self.pull_K_cam(index).astype(np.float32) 94 | 95 | file_names = self.depths[index].split('/') 96 | if self.mode in ['train', 'val']: 97 | rgb_path = os.path.join(*file_names[:-7], 'raw', file_names[-5].split('_drive')[0], file_names[-5], 98 | file_names[-2], 'data', file_names[-1]) 99 | disp_path = os.path.join(*file_names[:-7], 'raw', file_names[-5].split('_drive')[0], file_names[-5], 100 | file_names[-2], 'disp', file_names[-1]) 101 | elif self.mode in ['selval', 'test']: 102 | rgb_path = self.images[index] 103 | disp_path = self.images[index] 104 | else: 105 | raise ValueError("Unknown mode: {}".format(self.mode)) 106 | rgb = self.pull_RGB(rgb_path) 107 | disp = self.pull_DISP(disp_path) 108 | rgb = rgb.astype(np.float32) 109 | disp = disp.astype(np.float32)[:, :, None] 110 | lidar = lidar.astype(np.float32) 111 | depth = depth.astype(np.float32) 112 | if self.transform: 113 | rgb, disp, lidar, depth, K_cam = self.transform(rgb, disp, lidar, depth, K_cam) 114 | rgb = rgb.transpose(2, 0, 1).astype(np.float32) 115 | disp = disp.transpose(2, 0, 1).astype(np.float32) 116 | lidar = lidar.transpose(2, 0, 1).astype(np.float32) 117 | depth = depth.transpose(2, 0, 1).astype(np.float32) 118 | tp = rgb.shape[1] - self.height 119 | lp = (rgb.shape[2] - self.width) // 2 120 | if self.RandCrop: 121 | tp = random.randint(self.tp_min, tp) 122 | lp = random.randint(0, rgb.shape[2] - self.width) 123 | rgb = rgb[:, tp:tp + self.height, lp:lp + self.width] 124 | disp = disp[:, tp:tp + self.height, lp:lp + self.width] 125 | lidar = lidar[:, tp:tp + self.height, lp:lp + self.width] 126 | depth = depth[:, tp:tp + self.height, lp:lp + self.width] 127 | K_cam[0, 2] -= lp 128 | K_cam[1, 2] -= tp 129 | 130 | return rgb, disp, lidar, K_cam, depth 131 | 132 | def pull_DISP(self, path): 133 | disp = np.array(Image.open(path).convert('L'), dtype=np.uint8) 134 | return disp 135 | 136 | def pull_RGB(self, path): 137 | img = np.array(Image.open(path).convert('RGB'), dtype=np.uint8) 138 | return img 139 | 140 | def pull_DEPTH(self, path): 141 | depth_png = np.array(Image.open(path), dtype=int) 142 | assert (np.max(depth_png) > 
255) 143 | depth_image = (depth_png / 256.).astype(np.float32) 144 | return depth_image 145 | 146 | def pull_K_cam(self, index): 147 | file_names = self.depths[index].split('/') 148 | if self.mode in ['train', 'val', 'trainval']: 149 | calib_path = os.path.join(*file_names[:-7], 'raw', file_names[-5].split('_drive')[0], 150 | 'calib_cam_to_cam.txt') 151 | filedata = read_calib_file(calib_path) 152 | P_rect_20 = np.reshape(filedata['P_rect_02'], (3, 4)) 153 | P_rect_30 = np.reshape(filedata['P_rect_03'], (3, 4)) 154 | if file_names[-2] == 'image_02': 155 | K_cam = P_rect_20[0:3, 0:3] 156 | elif file_names[-2] == 'image_03': 157 | K_cam = P_rect_30[0:3, 0:3] 158 | else: 159 | raise ValueError("Unknown mode: {}".format(file_names[-2])) 160 | 161 | elif self.mode in ['selval', 'test']: 162 | fns = self.images[index].split('/') 163 | calib_path = os.path.join(*fns[:-2], 'intrinsics', fns[-1][:-3] + 'txt') 164 | with open(calib_path, 'r') as f: 165 | K_cam = f.read().split() 166 | K_cam = np.array(K_cam, dtype=np.float32).reshape(3, 3) 167 | else: 168 | raise ValueError("Unknown mode: {}".format(self.mode)) 169 | return K_cam 170 | -------------------------------------------------------------------------------- /datasets/nyu.py: -------------------------------------------------------------------------------- 1 | """ 2 | Berrow from CompletionFormer https://github.com/youmi-zym/CompletionFormer 3 | ====================================================================== 4 | 5 | NYU Depth V2 Dataset Helper 6 | """ 7 | 8 | import os 9 | import warnings 10 | 11 | import numpy as np 12 | import json 13 | import h5py 14 | 15 | from PIL import Image 16 | import torch 17 | import torchvision.transforms as T 18 | import torchvision.transforms.functional as TF 19 | from torch.utils.data import Dataset 20 | import glob 21 | 22 | __all__ = [ 23 | 'NYU', 24 | ] 25 | 26 | 27 | class BaseDataset(Dataset): 28 | def __init__(self, args, mode): 29 | self.args = args 30 | self.mode = mode 31 | 32 | def __len__(self): 33 | pass 34 | 35 | def __getitem__(self, idx): 36 | pass 37 | 38 | # A workaround for a pytorch bug 39 | # https://github.com/pytorch/vision/issues/2194 40 | class ToNumpy: 41 | def __call__(self, sample): 42 | return np.array(sample) 43 | 44 | 45 | class NYU(BaseDataset): 46 | def __init__(self, mode, path='../datas/nyudepthv2', num_sample=500, mul_factor=1., num_mask=8, scale_kcam=False, 47 | rand_scale=True, *args, **kwargs): 48 | super(NYU, self).__init__(None, mode) 49 | 50 | self.mode = mode 51 | self.num_sample = num_sample 52 | self.num_mask = num_mask 53 | self.mul_factor = mul_factor 54 | self.scale_kcam = scale_kcam 55 | self.rand_scale = rand_scale 56 | 57 | if mode != 'train' and mode != 'val': 58 | raise NotImplementedError 59 | 60 | height, width = (240, 320) 61 | crop_size = (228, 304) 62 | 63 | self.height = height 64 | self.width = width 65 | self.crop_size = crop_size 66 | 67 | self.Kcam = torch.from_numpy(np.array( 68 | [ 69 | [5.1885790117450188e+02, 0, 3.2558244941119034e+02], 70 | [0, 5.1946961112127485e+02, 2.5373616633400465e+02], 71 | [0, 0, 1.], 72 | ], dtype=np.float32 73 | ) 74 | ) 75 | 76 | base_dir = path 77 | 78 | self.sample_list = list(sorted(glob.glob(os.path.join(base_dir, mode, "**/**.h5")))) 79 | 80 | def __len__(self): 81 | if self.mode == 'train': 82 | return len(self.sample_list) 83 | elif self.mode == 'val': 84 | return self.num_mask * len(self.sample_list) 85 | else: 86 | raise NotImplementedError 87 | 88 | def __getitem__(self, idx): 89 | if self.mode == 
'val': 90 | seed = idx % self.num_mask 91 | idx = idx // self.num_mask 92 | 93 | path_file = self.sample_list[idx] 94 | 95 | f = h5py.File(path_file, 'r') 96 | rgb_h5 = f['rgb'][:].transpose(1, 2, 0) 97 | dep_h5 = f['depth'][:] 98 | 99 | rgb = Image.fromarray(rgb_h5, mode='RGB') 100 | dep = Image.fromarray(dep_h5.astype('float32'), mode='F') 101 | 102 | Kcam = self.Kcam.clone() 103 | 104 | if self.mode == 'train': 105 | if self.rand_scale: 106 | _scale = np.random.uniform(1.0, 1.5) 107 | else: 108 | _scale = 1.0 109 | scale = int(self.height * _scale) 110 | degree = np.random.uniform(-5.0, 5.0) 111 | flip = np.random.uniform(0.0, 1.0) 112 | 113 | if flip > 0.5: 114 | rgb = TF.hflip(rgb) 115 | dep = TF.hflip(dep) 116 | Kcam[0, 2] = rgb.width - 1 - Kcam[0, 2] 117 | 118 | rgb = TF.rotate(rgb, angle=degree) 119 | dep = TF.rotate(dep, angle=degree) 120 | 121 | t_rgb = T.Compose([ 122 | T.Resize(scale), 123 | T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4), 124 | T.CenterCrop(self.crop_size), 125 | T.ToTensor(), 126 | T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 127 | ]) 128 | 129 | t_dep = T.Compose([ 130 | T.Resize(scale), 131 | T.CenterCrop(self.crop_size), 132 | self.ToNumpy(), 133 | T.ToTensor() 134 | ]) 135 | 136 | rgb = t_rgb(rgb) 137 | dep = t_dep(dep) 138 | 139 | if self.scale_kcam: 140 | Kcam[:2] = Kcam[:2] * _scale 141 | 142 | else: 143 | t_rgb = T.Compose([ 144 | T.Resize(self.height), 145 | T.CenterCrop(self.crop_size), 146 | T.ToTensor(), 147 | T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 148 | ]) 149 | 150 | t_dep = T.Compose([ 151 | T.Resize(self.height), 152 | T.CenterCrop(self.crop_size), 153 | self.ToNumpy(), 154 | T.ToTensor() 155 | ]) 156 | 157 | rgb = t_rgb(rgb) 158 | dep = t_dep(dep) 159 | 160 | if self.mode == 'train': 161 | dep_sp = self.get_sparse_depth(dep, self.num_sample) 162 | elif self.mode == 'val': 163 | dep_sp = self.mask_sparse_depth(dep, self.num_sample, seed) 164 | else: 165 | raise NotImplementedError 166 | 167 | rgb = TF.pad(rgb, padding=[8, 14], padding_mode='edge') 168 | # rgb = TF.pad(rgb, padding=[8, 14], padding_mode='constant') 169 | dep_sp = TF.pad(dep_sp, padding=[8, 14], padding_mode='constant') 170 | dep = TF.pad(dep, padding=[8, 14], padding_mode='constant') 171 | 172 | Kcam[:2] = Kcam[:2] / 2. 173 | Kcam[0, 2] += 8 - 8 174 | Kcam[1, 2] += -6 + 14 175 | 176 | dep_sp *= self.mul_factor 177 | dep *= self.mul_factor 178 | return rgb, dep_sp, Kcam, dep 179 | 180 | def get_sparse_depth(self, dep, num_sample): 181 | channel, height, width = dep.shape 182 | 183 | assert channel == 1 184 | 185 | idx_nnz = torch.nonzero(dep.view(-1) > 0.0001, as_tuple=False) 186 | 187 | num_idx = len(idx_nnz) 188 | idx_sample = torch.randperm(num_idx)[:num_sample] 189 | 190 | idx_nnz = idx_nnz[idx_sample[:]] 191 | 192 | mask = torch.zeros((channel * height * width)) 193 | mask[idx_nnz] = 1.0 194 | mask = mask.view((channel, height, width)) 195 | 196 | dep_sp = dep * mask.type_as(dep) 197 | 198 | if num_idx == 0: 199 | dep_sp[:, 20:-20:10, 20:-20:10] = 3. 
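            # fallback: when a sample has no valid depth at all (num_idx == 0), the assignment above seeds a coarse 10-pixel grid with a constant 3.0 (metres, before the mul_factor scaling) so the sparse input is never empty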
200 | 201 | return dep_sp 202 | 203 | def mask_sparse_depth(self, dep, num_sample, seed): 204 | channel, height, width = dep.shape 205 | dep = dep.numpy().reshape(-1) 206 | np.random.seed(seed) 207 | index = np.random.choice(height * width, num_sample, replace=False) 208 | dep_sp = np.zeros_like(dep) 209 | dep_sp[index] = dep[index] 210 | dep_sp = dep_sp.reshape(channel, height, width) 211 | dep_sp = torch.from_numpy(dep_sp) 212 | return dep_sp 213 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import open3d as o3d 2 | import numpy as np 3 | import cv2 4 | import os 5 | import matplotlib.pyplot as plt 6 | import torch 7 | import hydra 8 | from utils_infer import Trainer 9 | from mpl_toolkits.mplot3d import Axes3D 10 | from tqdm import tqdm 11 | from PIL import Image 12 | import imutils 13 | 14 | def save_point_cloud_to_image(pcd, image_size=(1600, 1200)): 15 | vis = o3d.visualization.Visualizer() 16 | vis.create_window(visible=False, width=image_size[0], height=image_size[1]) # 设置窗口大小并不显示 17 | 18 | # 添加点云到可视化器中 19 | vis.add_geometry(pcd) 20 | 21 | # 获取 ViewControl 对象并设置自定义视角 22 | view_control = vis.get_view_control() 23 | 24 | # 设置视角参数 25 | parameters = { 26 | "boundingbox_max" : [ -1.5977719363706235, 11.519330868353832, 84.127326965332031 ], 27 | "boundingbox_min" : [ -56.623178268627761, -50.724836932361825, 4.6948032379150391 ], 28 | "field_of_view" : 60.0, 29 | "front" : [ 0.36788389015943362, 0.28372788418091033, -0.8855280521244856 ], 30 | "lookat" : [ -29.110475102499194, -19.602753032003996, 44.411065101623535 ], 31 | "up" : [ -0.88654778269461509, -0.18027422226025128, -0.42606834403382188 ], 32 | "zoom" : 0.5199999999999998 33 | } 34 | 35 | # 应用视角设置 36 | ctr = vis.get_view_control() 37 | ctr.set_lookat(parameters["lookat"]) 38 | ctr.set_front(parameters["front"]) 39 | ctr.set_up(parameters["up"]) 40 | ctr.set_zoom(parameters["zoom"]) 41 | 42 | # 渲染点云为图像 43 | vis.poll_events() 44 | vis.update_renderer() 45 | 46 | # 获取点云的2D图像 47 | depth_image = np.asarray(vis.capture_screen_float_buffer(do_render=True)) 48 | vis.destroy_window() 49 | 50 | # 调整尺寸并格式化图像 51 | depth_image = (depth_image * 255).astype(np.uint8) # 转换为8位图像 52 | depth_image = cv2.cvtColor(depth_image, cv2.COLOR_RGB2BGR) # 转换为BGR格式以便与OpenCV兼容 53 | 54 | return depth_image 55 | 56 | def depth_to_point_cloud(depth_map, K_cam): 57 | h, w = depth_map.shape 58 | i, j = np.meshgrid(np.arange(w), np.arange(h)) 59 | 60 | # 将像素坐标变换为相机坐标 61 | z = depth_map 62 | x = (j - K_cam[0, 2]) * z / K_cam[0, 0] 63 | y = (i - K_cam[1, 2]) * z / K_cam[1, 1] 64 | y = -y 65 | 66 | points_3d = np.stack((x, y, z), axis=-1).reshape(-1, 3) 67 | 68 | return points_3d 69 | 70 | 71 | def read_calib_file(filepath): 72 | """Read in a calibration file and parse into a dictionary.""" 73 | data = {} 74 | 75 | with open(filepath, 'r') as f: 76 | for line in f.readlines(): 77 | key, value = line.split(':', 1) 78 | # The only non-float values in these files are dates, which 79 | # we don't care about anyway 80 | try: 81 | data[key] = np.array([float(x) for x in value.split()]) 82 | except ValueError: 83 | pass 84 | 85 | return data 86 | 87 | 88 | def pull_K_cam(calib_path): 89 | filedata = read_calib_file(calib_path) 90 | P_rect_20 = np.reshape(filedata['P_rect_02'], (3, 4)) 91 | K_cam = P_rect_20[0:3, 0:3] 92 | return K_cam 93 | 94 | 95 | def depth_color(val, min_d=0, max_d=120): 96 | """ 97 | print Color(HSV's H value) corresponding 
to distance(m) 98 | close distance = red , far distance = blue 99 | """ 100 | np.clip(val, 0, max_d, out=val) # max distance is 120m but usually not usual 101 | return (((val - min_d) / (max_d - min_d)) * 120).astype(np.uint8) 102 | 103 | 104 | def in_h_range_points(points, m, n, fov): 105 | """ extract horizontal in-range points """ 106 | return np.logical_and(np.arctan2(n, m) > (-fov[1] * np.pi / 180), 107 | np.arctan2(n, m) < (-fov[0] * np.pi / 180)) 108 | 109 | 110 | def in_v_range_points(points, m, n, fov): 111 | """ extract vertical in-range points """ 112 | return np.logical_and(np.arctan2(n, m) < (fov[1] * np.pi / 180), 113 | np.arctan2(n, m) > (fov[0] * np.pi / 180)) 114 | 115 | 116 | def fov_setting(points, x, y, z, dist, h_fov, v_fov): 117 | """ filter points based on h,v FOV """ 118 | 119 | if h_fov[1] == 180 and h_fov[0] == -180 and v_fov[1] == 2.0 and v_fov[0] == -24.9: 120 | return points 121 | 122 | if h_fov[1] == 180 and h_fov[0] == -180: 123 | return points[in_v_range_points(points, dist, z, v_fov)] 124 | elif v_fov[1] == 2.0 and v_fov[0] == -24.9: 125 | return points[in_h_range_points(points, x, y, h_fov)] 126 | else: 127 | h_points = in_h_range_points(points, x, y, h_fov) 128 | v_points = in_v_range_points(points, dist, z, v_fov) 129 | return points[np.logical_and(h_points, v_points)] 130 | 131 | 132 | def in_range_points(points, size): 133 | """ extract in-range points """ 134 | return np.logical_and(points > 0, points < size) 135 | 136 | 137 | def velo_points_filter(points, v_fov, h_fov): 138 | """ extract points corresponding to FOV setting """ 139 | 140 | # Projecting to 2D 141 | x = points[:, 0] 142 | y = points[:, 1] 143 | z = points[:, 2] 144 | dist = np.sqrt(x ** 2 + y ** 2 + z ** 2) 145 | 146 | if h_fov[0] < -90: 147 | h_fov = (-90,) + h_fov[1:] 148 | if h_fov[1] > 90: 149 | h_fov = h_fov[:1] + (90,) 150 | 151 | x_lim = fov_setting(x, x, y, z, dist, h_fov, v_fov)[:, None] 152 | y_lim = fov_setting(y, x, y, z, dist, h_fov, v_fov)[:, None] 153 | z_lim = fov_setting(z, x, y, z, dist, h_fov, v_fov)[:, None] 154 | 155 | # Stack arrays in sequence horizontally 156 | xyz_ = np.hstack((x_lim, y_lim, z_lim)) 157 | xyz_ = xyz_.T 158 | 159 | # stack (1,n) arrays filled with the number 1 160 | one_mat = np.full((1, xyz_.shape[1]), 1) 161 | xyz_ = np.concatenate((xyz_, one_mat), axis=0) 162 | 163 | # need dist info for points color 164 | dist_lim = fov_setting(dist, x, y, z, dist, h_fov, v_fov) 165 | color = depth_color(dist_lim, 0, 70) 166 | 167 | return xyz_, color 168 | 169 | 170 | def calib_velo2cam(filepath): 171 | """ 172 | get Rotation(R : 3x3), Translation(T : 3x1) matrix info 173 | using R,T matrix, we can convert velodyne coordinates to camera coordinates 174 | """ 175 | with open(filepath, "r") as f: 176 | file = f.readlines() 177 | 178 | for line in file: 179 | (key, val) = line.split(':', 1) 180 | if key == 'R': 181 | R = np.fromstring(val, sep=' ') 182 | R = R.reshape(3, 3) 183 | if key == 'T': 184 | T = np.fromstring(val, sep=' ') 185 | T = T.reshape(3, 1) 186 | return R, T 187 | 188 | 189 | def calib_cam2cam(filepath, mode): 190 | with open(filepath, "r") as f: 191 | file = f.readlines() 192 | 193 | for line in file: 194 | (key, val) = line.split(':', 1) 195 | if key == ('P_rect_' + mode): 196 | P_ = np.fromstring(val, sep=' ') 197 | P_ = P_.reshape(3, 4) 198 | # erase 4th column ([0,0,0]) 199 | P_ = P_[:3, :3] 200 | return P_ 201 | 202 | 203 | def velo3d_2_camera2d_points(points, v_fov, h_fov, vc_path, cc_path, mode='02', image_shape=None): 204 | """ 205 | 
Return velodyne 3D points corresponding to camera 2D image and sparse depth map 206 | """ 207 | 208 | # R_vc = Rotation matrix ( velodyne -> camera ) 209 | # T_vc = Translation matrix ( velodyne -> camera ) 210 | R_vc, T_vc = calib_velo2cam(vc_path) 211 | 212 | # P_ = Projection matrix ( camera coordinates 3d points -> image plane 2d points ) 213 | P_ = calib_cam2cam(cc_path, mode) 214 | xyz_v, c_ = velo_points_filter(points, v_fov, h_fov) 215 | 216 | RT_ = np.concatenate((R_vc, T_vc), axis=1) 217 | 218 | # Initialize sparse depth map 219 | if image_shape is None: 220 | raise ValueError( 221 | "Image shape must be provided to generate sparse depth map") 222 | 223 | # Create a depth map with NaN values 224 | depth_map = np.full(image_shape[:2], 0) 225 | 226 | # Convert velodyne coordinates(X_v, Y_v, Z_v) to camera coordinates(X_c, Y_c, Z_c) 227 | for i in range(xyz_v.shape[1]): 228 | xyz_v[:3, i] = np.matmul(RT_, xyz_v[:, i]) 229 | 230 | xyz_c = np.delete(xyz_v, 3, axis=0) 231 | 232 | # Convert camera coordinates(X_c, Y_c, Z_c) to image(pixel) coordinates(x,y) 233 | for i in range(xyz_c.shape[1]): 234 | xyz_c[:, i] = np.matmul(P_, xyz_c[:, i]) 235 | 236 | # Normalize by the third coordinate to get 2D pixel coordinates 237 | xy_i = xyz_c[:2, :] / xyz_c[2, :] 238 | depth_values = xyz_c[2, :] # Z-coordinate (depth) in camera space 239 | 240 | # Filter out points that are out of image bounds 241 | valid_mask = np.logical_and.reduce(( 242 | xy_i[0, :] >= 0, 243 | xy_i[0, :] < image_shape[1], # x coordinate within image width 244 | xy_i[1, :] >= 0, 245 | xy_i[1, :] < image_shape[0], # y coordinate within image height 246 | )) 247 | 248 | valid_points = xy_i[:, valid_mask] 249 | valid_depths = depth_values[valid_mask] 250 | 251 | # Fill the depth map with depth values 252 | for i in range(valid_points.shape[1]): 253 | x = int(valid_points[0, i]) 254 | y = int(valid_points[1, i]) 255 | depth_map[y, x] = valid_depths[i] 256 | 257 | # Return both the projected points and the sparse depth map 258 | return valid_points, c_, depth_map 259 | 260 | 261 | def print_projection_cv2(points, color, image): 262 | """ project converted velodyne points into camera image """ 263 | 264 | hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 265 | 266 | for i in range(points.shape[1]): 267 | cv2.circle(hsv_image, (np.int32(points[0][i]), np.int32( 268 | points[1][i])), 2, (int(color[i]), 255, 255), -1) 269 | 270 | return cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR) 271 | 272 | 273 | def print_projection_plt(points, color, image): 274 | """ project converted velodyne points into camera image """ 275 | 276 | hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 277 | 278 | for i in range(points.shape[1]): 279 | cv2.circle(hsv_image, (np.int32(points[0][i]), np.int32( 280 | points[1][i])), 2, (int(color[i]), 255, 255), -1) 281 | 282 | return cv2.cvtColor(hsv_image, cv2.COLOR_HSV2RGB) 283 | 284 | 285 | def load_from_bin(bin_path): 286 | obj = np.fromfile(bin_path, dtype=np.float32).reshape(-1, 4) 287 | # ignore reflectivity info 288 | return obj[:, :3] 289 | 290 | 291 | def set_axes_equal(ax): 292 | """使 3D 图的刻度长短一致""" 293 | x_limits = ax.get_xlim3d() 294 | y_limits = ax.get_ylim3d() 295 | z_limits = ax.get_zlim3d() 296 | 297 | # 找到所有坐标的中心和范围 298 | x_range = abs(x_limits[1] - x_limits[0]) 299 | x_middle = np.mean(x_limits) 300 | y_range = abs(y_limits[1] - y_limits[0]) 301 | y_middle = np.mean(y_limits) 302 | z_range = abs(z_limits[1] - z_limits[0]) 303 | z_middle = np.mean(z_limits) 304 | 305 | # 计算出最大的范围 306 | plot_radius = 
0.5 * max([x_range, y_range, z_range]) 307 | 308 | # 设置每个坐标轴的范围,使其相等 309 | ax.set_xlim3d([x_middle - plot_radius, x_middle + plot_radius]) 310 | ax.set_ylim3d([y_middle - plot_radius, y_middle + plot_radius]) 311 | ax.set_zlim3d([z_middle - plot_radius, z_middle + plot_radius]) 312 | 313 | 314 | def save_point_cloud_to_ply_open3d(point_cloud, image, filename): 315 | # 创建 open3d 点云对象 316 | pcd = o3d.geometry.PointCloud() 317 | 318 | # 设置点云的坐标 319 | pcd.points = o3d.utility.Vector3dVector(point_cloud) 320 | 321 | # 将图像展开为 (N, 3) 形式的颜色数据 322 | colors = image.reshape(-1, 3) / 255.0 # 归一化颜色到 [0, 1] 323 | 324 | # 设置点云的颜色 325 | pcd.colors = o3d.utility.Vector3dVector(colors) 326 | 327 | # 保存点云为 .ply 文件 328 | o3d.io.write_point_cloud(filename, pcd) 329 | return pcd 330 | 331 | 332 | @hydra.main(config_path='configs', config_name='config', version_base='1.2') 333 | def main(cfg): 334 | with Trainer(cfg) as run: 335 | net = run.net_ema.module.cuda() 336 | net.eval() 337 | # 读取左目图像 338 | # base = "datas/kitti/raw/2011_09_26/2011_09_26_drive_0002_sync" 339 | base = "datas/kitti/raw/2011_09_26/2011_09_26_drive_0048_sync" 340 | 341 | image_type = 'color' # 'grayscale' or 'color' image 342 | 343 | mode = '00' if image_type == 'grayscale' else '02' 344 | 345 | v2c_filepath = 'datas/kitti/raw/2011_09_26/calib_velo_to_cam.txt' 346 | c2c_filepath = 'datas/kitti/raw/2011_09_26/calib_cam_to_cam.txt' 347 | image_mean = np.array([90.9950, 96.2278, 94.3213]) 348 | image_std = np.array([79.2382, 80.5267, 82.1483]) 349 | image_height = 352 350 | image_width = 1216 351 | 352 | for i in tqdm(range(1000)): 353 | 354 | image = np.array(Image.open(os.path.join( 355 | base, 'image_' + mode + f'/data/0000000{i:03d}.png')).convert('RGB'), dtype=np.uint8) 356 | 357 | if image is None: 358 | break 359 | width, height = image.shape[1], image.shape[0] 360 | 361 | # bin file -> numpy array 362 | velo_points = load_from_bin(os.path.join( 363 | base, f'velodyne_points/data/0000000{i:03d}.bin')) 364 | 365 | image_type = 'color' # 'grayscale' or 'color' image 366 | # image_00 = 'grayscale image' , image_02 = 'color image' 367 | mode = '00' if image_type == 'grayscale' else '02' 368 | 369 | ans, c_, lidar = velo3d_2_camera2d_points(velo_points, v_fov=(-24.9, 2.0), h_fov=(-45, 45), 370 | vc_path=v2c_filepath, cc_path=c2c_filepath, mode=mode, 371 | image_shape=image.shape) 372 | 373 | image_vis = print_projection_plt(points=ans, color=c_, image=image.copy()) 374 | 375 | # depth completion 376 | K_cam = torch.from_numpy(pull_K_cam( 377 | c2c_filepath).astype(np.float32)).cuda() 378 | 379 | tp = image.shape[0] - image_height 380 | lp = (image.shape[1] - image_width) // 2 381 | image = image[tp:tp + image_height, lp:lp + image_width] 382 | lidar = lidar[tp:tp + image_height, lp:lp + image_width, None] 383 | image_vis = image_vis[tp:tp + image_height, lp:lp + image_width] 384 | K_cam[0, 2] -= lp 385 | K_cam[1, 2] -= tp 386 | 387 | image = (image - image_mean) / image_std 388 | 389 | image_tensor = image.transpose(2, 0, 1).astype(np.float32)[None] 390 | lidar_tensor = lidar.transpose(2, 0, 1).astype(np.float32)[None] 391 | 392 | image_tensor = torch.from_numpy(image_tensor) 393 | lidar_tensor = torch.from_numpy(lidar_tensor) 394 | 395 | output = net(image_tensor.cuda(), None, 396 | lidar_tensor.cuda(), K_cam[None].cuda()) 397 | if isinstance(output, (list, tuple)): 398 | output = output[-1] 399 | 400 | output = output.squeeze().detach().cpu().numpy() 401 | image = image * image_std + image_mean 402 | 403 | output_max, output_min = 
output.max(), output.min() 404 | output_norm = (output - output_min) / (output_max - output_min) * 255 405 | output_norm = output_norm.astype('uint8') 406 | output_color = cv2.applyColorMap(output_norm, cv2.COLORMAP_JET) 407 | cv2.imwrite(f'outputs/0000000{i:03d}_depth.png', output_color) 408 | 409 | cv2.imwrite(f'outputs/0000000{i:03d}_image.png', image.astype(np.uint8)[:, :, ::-1]) 410 | cv2.imwrite(f'outputs/0000000{i:03d}_lidar.png', lidar.astype(np.uint8) * 3) 411 | cv2.imwrite(f'outputs/0000000{i:03d}_image_vis.png', image_vis) 412 | 413 | 414 | if __name__ == '__main__': 415 | main() 416 | -------------------------------------------------------------------------------- /demo.sh: -------------------------------------------------------------------------------- 1 | python demo.py gpus=[2] name=BP_KITTI ++chpt=checkpoints/dmd3c_kitti.pth \ 2 | net=PMP num_workers=1 \ 3 | data=KITTI data.testset.mode=test data.testset.height=352 \ 4 | test_batch_size=1 metric=RMSE ++save=true 5 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: bp 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=conda_forge 9 | - _openmp_mutex=4.5=2_gnu 10 | - abseil-cpp=20220623.0=h8cdb687_6 11 | - absl-py=2.1.0=pyhd8ed1ab_0 12 | - alsa-lib=1.2.8=h166bdaf_0 13 | - antlr-python-runtime=4.9.3=pyhd8ed1ab_1 14 | - aom=3.5.0=h27087fc_0 15 | - attr=2.5.1=h166bdaf_1 16 | - binutils=2.40=hdd6e379_0 17 | - binutils_impl_linux-64=2.40=hf600244_0 18 | - binutils_linux-64=2.40=hdade7a5_3 19 | - blas=1.0=mkl 20 | - brotli-python=1.0.9=py39h6a678d5_7 21 | - bzip2=1.0.8=hd590300_5 22 | - c-ares=1.28.1=hd590300_0 23 | - c-compiler=1.6.0=hd590300_0 24 | - ca-certificates=2024.3.11=h06a4308_0 25 | - cached-property=1.5.2=hd8ed1ab_1 26 | - cached_property=1.5.2=pyha770c72_1 27 | - cairo=1.16.0=ha61ee94_1014 28 | - certifi=2024.2.2=pyhd8ed1ab_0 29 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 30 | - colorama=0.4.6=pyhd8ed1ab_0 31 | - compilers=1.6.0=ha770c72_0 32 | - cuda-cudart=12.1.105=0 33 | - cuda-cupti=12.1.105=0 34 | - cuda-libraries=12.1.0=0 35 | - cuda-nvrtc=12.1.105=0 36 | - cuda-nvtx=12.1.105=0 37 | - cuda-opencl=12.4.127=0 38 | - cuda-runtime=12.1.0=0 39 | - cxx-compiler=1.6.0=h00ab1b0_0 40 | - dbus=1.13.6=h5008d03_3 41 | - einops=0.7.0=pyhd8ed1ab_1 42 | - expat=2.6.2=h59595ed_0 43 | - ffmpeg=5.1.2=gpl_h8dda1f0_106 44 | - fftw=3.3.10=nompi_hc118613_108 45 | - filelock=3.13.1=py39h06a4308_0 46 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 47 | - font-ttf-inconsolata=3.000=h77eed37_0 48 | - font-ttf-source-code-pro=2.038=h77eed37_0 49 | - font-ttf-ubuntu=0.83=h77eed37_1 50 | - fontconfig=2.14.2=h14ed4e7_0 51 | - fonts-conda-ecosystem=1=0 52 | - fonts-conda-forge=1=0 53 | - fortran-compiler=1.6.0=heb67821_0 54 | - freeglut=3.2.2=h9c3ff4c_1 55 | - freetype=2.12.1=h4a9f257_0 56 | - fsspec=2024.3.1=pyhca7485f_0 57 | - gcc=12.3.0=h95e488c_3 58 | - gcc_impl_linux-64=12.3.0=he2b93b0_5 59 | - gcc_linux-64=12.3.0=h6477408_3 60 | - gettext=0.21.1=h27087fc_0 61 | - gfortran=12.3.0=h7389182_3 62 | - gfortran_impl_linux-64=12.3.0=hfcedea8_5 63 | - gfortran_linux-64=12.3.0=h617cb40_3 64 | - glib=2.80.0=hf2295e7_1 65 | - glib-tools=2.80.0=hde27a5a_1 66 | - gmp=6.2.1=h295c915_3 67 | - gmpy2=2.1.2=py39heeb90bb_0 68 | - gnutls=3.7.9=hb077bed_0 69 | - graphite2=1.3.13=h59595ed_1003 70 | - grpc-cpp=1.51.1=h27aab58_1 71 | - 
grpcio=1.51.1=py39he859823_1 72 | - gst-plugins-base=1.22.0=h4243ec0_2 73 | - gstreamer=1.22.0=h25f0c4b_2 74 | - gstreamer-orc=0.4.38=hd590300_0 75 | - gtest=1.14.0=hdb19cb5_0 76 | - gxx=12.3.0=h95e488c_3 77 | - gxx_impl_linux-64=12.3.0=he2b93b0_5 78 | - gxx_linux-64=12.3.0=h4a1b8e8_3 79 | - h5py=3.9.0=nompi_py39h4dfffb9_100 80 | - harfbuzz=6.0.0=h8e241bc_0 81 | - hdf5=1.14.0=nompi_hb72d44e_103 82 | - huggingface_hub=0.22.2=pyhd8ed1ab_0 83 | - hydra-core=1.3.2=pyhd8ed1ab_0 84 | - icu=70.1=h27087fc_0 85 | - idna=3.4=py39h06a4308_0 86 | - importlib-metadata=7.1.0=pyha770c72_0 87 | - importlib_resources=6.4.0=pyhd8ed1ab_0 88 | - intel-openmp=2023.1.0=hdb19cb5_46306 89 | - jack=1.9.22=h11f4161_0 90 | - jasper=2.0.33=h0ff4b12_1 91 | - jinja2=3.1.3=py39h06a4308_0 92 | - jpeg=9e=h5eee18b_1 93 | - kernel-headers_linux-64=2.6.32=he073ed8_17 94 | - keyutils=1.6.1=h166bdaf_0 95 | - krb5=1.20.1=h81ceb04_0 96 | - lame=3.100=h7b6447c_0 97 | - lcms2=2.12=h3be6417_0 98 | - ld_impl_linux-64=2.40=h41732ed_0 99 | - lerc=3.0=h295c915_0 100 | - libabseil=20220623.0=cxx17_h05df665_6 101 | - libaec=1.1.3=h59595ed_0 102 | - libblas=3.9.0=1_h86c2bf4_netlib 103 | - libcap=2.67=he9d0100_0 104 | - libcblas=3.9.0=5_h92ddd45_netlib 105 | - libclang=15.0.7=default_h127d8a8_5 106 | - libclang13=15.0.7=default_h5d6823c_5 107 | - libcublas=12.1.0.26=0 108 | - libcufft=11.0.2.4=0 109 | - libcufile=1.9.1.3=0 110 | - libcups=2.3.3=h36d4200_3 111 | - libcurand=10.3.5.147=0 112 | - libcurl=8.1.2=h409715c_0 113 | - libcusolver=11.4.4.55=0 114 | - libcusparse=12.0.2.55=0 115 | - libdb=6.2.32=h9c3ff4c_0 116 | - libdeflate=1.8=h7f8727e_5 117 | - libdrm=2.4.120=hd590300_0 118 | - libedit=3.1.20191231=he28a2e2_2 119 | - libev=4.33=hd590300_2 120 | - libevent=2.1.10=h28343ad_4 121 | - libexpat=2.6.2=h59595ed_0 122 | - libffi=3.4.2=h7f98852_5 123 | - libflac=1.4.3=h59595ed_0 124 | - libgcc-devel_linux-64=12.3.0=h8bca6fd_105 125 | - libgcc-ng=13.2.0=h807b86a_5 126 | - libgcrypt=1.10.3=hd590300_0 127 | - libgfortran-ng=13.2.0=h69a702a_5 128 | - libgfortran5=13.2.0=ha4646dd_5 129 | - libglib=2.80.0=hf2295e7_1 130 | - libglu=9.0.0=he1b5a44_1001 131 | - libgomp=13.2.0=h807b86a_5 132 | - libgpg-error=1.48=h71f35ed_0 133 | - libgrpc=1.51.1=h4fad500_1 134 | - libiconv=1.17=hd590300_2 135 | - libidn2=2.3.4=h5eee18b_0 136 | - libjpeg-turbo=2.0.0=h9bf148f_0 137 | - liblapack=3.9.0=5_h92ddd45_netlib 138 | - liblapacke=3.9.0=5_h92ddd45_netlib 139 | - libllvm15=15.0.7=hadd5161_1 140 | - libnghttp2=1.58.0=h47da74e_0 141 | - libnpp=12.0.2.50=0 142 | - libnsl=2.0.1=hd590300_0 143 | - libnvjitlink=12.1.105=0 144 | - libnvjpeg=12.1.1.14=0 145 | - libogg=1.3.4=h7f98852_1 146 | - libopencv=4.7.0=py39h2ca4621_1 147 | - libopus=1.3.1=h7f98852_1 148 | - libpciaccess=0.18=hd590300_0 149 | - libpng=1.6.39=h5eee18b_0 150 | - libpq=15.3=hbcd7760_1 151 | - libprotobuf=3.21.12=hfc55251_2 152 | - libsanitizer=12.3.0=h0f45ef3_5 153 | - libsndfile=1.2.2=hc60ed4a_1 154 | - libsqlite=3.45.2=h2797004_0 155 | - libssh2=1.11.0=h0841786_0 156 | - libstdcxx-devel_linux-64=12.3.0=h8bca6fd_105 157 | - libstdcxx-ng=13.2.0=h7e041cc_5 158 | - libsystemd0=253=h8c4010b_1 159 | - libtasn1=4.19.0=h5eee18b_0 160 | - libtiff=4.5.0=h6a678d5_1 161 | - libtool=2.4.7=h27087fc_0 162 | - libudev1=253=h0b41bf4_1 163 | - libunistring=0.9.10=h27cfd23_0 164 | - libuuid=2.38.1=h0b41bf4_0 165 | - libva=2.18.0=h0b41bf4_0 166 | - libvorbis=1.3.7=h9c3ff4c_0 167 | - libvpx=1.11.0=h9c3ff4c_3 168 | - libwebp-base=1.3.2=h5eee18b_0 169 | - libxcb=1.13=h7f98852_1004 170 | - libxcrypt=4.4.36=hd590300_1 171 
| - libxkbcommon=1.5.0=h79f4944_1 172 | - libxml2=2.10.3=hca2bb57_4 173 | - libzlib=1.2.13=hd590300_5 174 | - llvm-openmp=14.0.6=h9e868ea_0 175 | - lz4-c=1.9.4=h6a678d5_0 176 | - markdown=3.6=pyhd8ed1ab_0 177 | - markupsafe=2.1.3=py39h5eee18b_0 178 | - mkl=2023.1.0=h213fc3f_46344 179 | - mkl-service=2.4.0=py39h5eee18b_1 180 | - mkl_fft=1.3.8=py39h5eee18b_0 181 | - mkl_random=1.2.4=py39hdb19cb5_0 182 | - mpc=1.1.0=h10f8cd9_1 183 | - mpfr=4.0.2=hb69a4c5_1 184 | - mpg123=1.32.6=h59595ed_0 185 | - mpmath=1.3.0=py39h06a4308_0 186 | - mysql-common=8.0.33=hf1915f5_2 187 | - mysql-libs=8.0.33=hca2cd23_2 188 | - ncurses=6.4.20240210=h59595ed_0 189 | - nettle=3.9.1=h7ab15ed_0 190 | - networkx=3.1=py39h06a4308_0 191 | - nspr=4.35=h27087fc_0 192 | - nss=3.98=h1d7d5a4_0 193 | - numpy=1.26.4=py39h5f9d8c6_0 194 | - numpy-base=1.26.4=py39hb5e798b_0 195 | - omegaconf=2.3.0=pyhd8ed1ab_0 196 | - opencv=4.7.0=py39hf3d152e_1 197 | - openh264=2.3.1=hcb278e6_2 198 | - openjpeg=2.4.0=h3ad879b_0 199 | - openssl=3.1.5=hd590300_0 200 | - p11-kit=0.24.1=hc5aa10d_0 201 | - packaging=24.0=pyhd8ed1ab_0 202 | - pcre2=10.43=hcad00b1_0 203 | - pillow=10.2.0=py39h5eee18b_0 204 | - pip=24.0=pyhd8ed1ab_0 205 | - pixman=0.43.2=h59595ed_0 206 | - protobuf=4.21.12=py39h227be39_0 207 | - pthread-stubs=0.4=h36c2ea0_1001 208 | - pulseaudio=16.1=hcb278e6_3 209 | - pulseaudio-client=16.1=h5195f5e_3 210 | - pulseaudio-daemon=16.1=ha8d29e2_3 211 | - py-opencv=4.7.0=py39hcca971b_1 212 | - pysocks=1.7.1=py39h06a4308_0 213 | - python=3.9.18=h0755675_0_cpython 214 | - python_abi=3.9=4_cp39 215 | - pytorch=2.2.2=py3.9_cuda12.1_cudnn8.9.2_0 216 | - pytorch-cuda=12.1=ha16c6d3_5 217 | - pytorch-mutex=1.0=cuda 218 | - pyyaml=6.0.1=py39h5eee18b_0 219 | - qt-main=5.15.8=h5d23da1_6 220 | - re2=2023.02.01=hcb278e6_0 221 | - readline=8.2=h8228510_1 222 | - requests=2.31.0=py39h06a4308_1 223 | - safetensors=0.4.2=py39h9fdd4d6_0 224 | - setuptools=69.2.0=pyhd8ed1ab_0 225 | - six=1.16.0=pyh6c4a22f_0 226 | - svt-av1=1.4.1=hcb278e6_0 227 | - sympy=1.12=pypyh9d50eac_103 228 | - sysroot_linux-64=2.12=he073ed8_17 229 | - tbb=2021.8.0=hdb19cb5_0 230 | - tensorboard=2.16.2=pyhd8ed1ab_0 231 | - tensorboard-data-server=0.7.0=py39hd4f0224_1 232 | - timm=0.9.16=pyhd8ed1ab_0 233 | - tk=8.6.13=noxft_h4845f30_101 234 | - torchaudio=2.2.2=py39_cu121 235 | - torchtriton=2.2.0=py39 236 | - torchvision=0.17.2=py39_cu121 237 | - tqdm=4.66.2=pyhd8ed1ab_0 238 | - typing-extensions=4.9.0=py39h06a4308_1 239 | - typing_extensions=4.9.0=py39h06a4308_1 240 | - tzdata=2024a=h0c530f3_0 241 | - urllib3=2.1.0=py39h06a4308_1 242 | - werkzeug=3.0.2=pyhd8ed1ab_0 243 | - wheel=0.43.0=pyhd8ed1ab_1 244 | - x264=1!164.3095=h166bdaf_2 245 | - x265=3.5=h924138e_3 246 | - xcb-util=0.4.0=h516909a_0 247 | - xcb-util-image=0.4.0=h166bdaf_0 248 | - xcb-util-keysyms=0.4.0=h516909a_0 249 | - xcb-util-renderutil=0.3.9=h166bdaf_0 250 | - xcb-util-wm=0.4.1=h516909a_0 251 | - xkeyboard-config=2.38=h0b41bf4_0 252 | - xorg-fixesproto=5.0=h7f98852_1002 253 | - xorg-inputproto=2.3.2=h7f98852_1002 254 | - xorg-kbproto=1.0.7=h7f98852_1002 255 | - xorg-libice=1.1.1=hd590300_0 256 | - xorg-libsm=1.2.4=h7391055_0 257 | - xorg-libx11=1.8.4=h0b41bf4_0 258 | - xorg-libxau=1.0.11=hd590300_0 259 | - xorg-libxdmcp=1.1.3=h7f98852_0 260 | - xorg-libxext=1.3.4=h0b41bf4_2 261 | - xorg-libxfixes=5.0.3=h7f98852_1004 262 | - xorg-libxi=1.7.10=h7f98852_0 263 | - xorg-libxrender=0.9.10=h7f98852_1003 264 | - xorg-renderproto=0.11.1=h7f98852_1002 265 | - xorg-xextproto=7.3.0=h0b41bf4_1003 266 | - 
xorg-xproto=7.0.31=h7f98852_1007 267 | - xz=5.4.6=h5eee18b_0 268 | - yaml=0.2.5=h7b6447c_0 269 | - zipp=3.17.0=pyhd8ed1ab_0 270 | - zlib=1.2.13=hd590300_5 271 | - zstd=1.5.2=ha4553b6_0 272 | 273 | -------------------------------------------------------------------------------- /exts/bp_cuda.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by jie on 09/02/19. 3 | // 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include "bp_cuda.h" 12 | 13 | 14 | void dist( 15 | at::Tensor Pc, 16 | at::Tensor IPCnum, 17 | at::Tensor args, 18 | int H, 19 | int W 20 | ) { 21 | int B, Cc, N, M, num; 22 | B = Pc.size(0); 23 | Cc = Pc.size(1); 24 | M = Pc.size(2); 25 | num = args.size(1); 26 | N = args.size(2); 27 | Dist_Cuda(Pc, IPCnum, args, B, Cc, N, M, num, H, W); 28 | } 29 | 30 | 31 | at::Tensor Conv2dLocal_F( 32 | at::Tensor a, 33 | at::Tensor b 34 | ) { 35 | int N1, N2, Ci, Co, K, B; 36 | B = a.size(0); 37 | Ci = a.size(1); 38 | N1 = a.size(2); 39 | N2 = a.size(3); 40 | Co = Ci; 41 | K = sqrt(b.size(1) / Co); 42 | auto c = at::zeros_like(a); 43 | Conv2d_LF_Cuda(a, b, c, N1, N2, Ci, Co, B, K); 44 | return c; 45 | } 46 | 47 | 48 | std::tuple Conv2dLocal_B( 49 | at::Tensor a, 50 | at::Tensor b, 51 | at::Tensor gc 52 | ) { 53 | int N1, N2, Ci, Co, K, B; 54 | B = a.size(0); 55 | Ci = a.size(1); 56 | N1 = a.size(2); 57 | N2 = a.size(3); 58 | Co = Ci; 59 | K = sqrt(b.size(1) / Co); 60 | auto ga = at::zeros_like(a); 61 | auto gb = at::zeros_like(b); 62 | Conv2d_LB_Cuda(a, b, ga, gb, gc, N1, N2, Ci, Co, B, K); 63 | return std::make_tuple(ga, gb); 64 | } 65 | 66 | 67 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m 68 | ) { 69 | m.def("Dist", &dist, "calculate distance on 2D image"); 70 | m.def("Conv2dLocal_F", &Conv2dLocal_F, "Conv2dLocal Forward (CUDA)"); 71 | m.def("Conv2dLocal_B", &Conv2dLocal_B, "Conv2dLocal Backward (CUDA)"); 72 | } 73 | 74 | -------------------------------------------------------------------------------- /exts/bp_cuda.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by jie on 12/8/22. 
3 | // 4 | 5 | #ifndef BP_CUDA_H 6 | #define BP_CUDA_H 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | 17 | void dist( 18 | at::Tensor Pc, 19 | at::Tensor IPCnum, 20 | at::Tensor args, 21 | int H, 22 | int W 23 | ); 24 | 25 | at::Tensor Conv2dLocal_F( 26 | at::Tensor a, 27 | at::Tensor b 28 | ); 29 | 30 | std::tuple Conv2dLocal_B( 31 | at::Tensor a, 32 | at::Tensor b, 33 | at::Tensor gc 34 | ); 35 | 36 | void Dist_Cuda(at::Tensor Pc, at::Tensor IPCnum, at::Tensor args, 37 | size_t B, size_t Cc, size_t N, size_t M, size_t num, size_t H, size_t W); 38 | 39 | void Conv2d_LF_Cuda(at::Tensor x, at::Tensor y, at::Tensor z, size_t N1, size_t N2, size_t Ci, size_t Co, size_t B, 40 | size_t K); 41 | 42 | void 43 | Conv2d_LB_Cuda(at::Tensor x, at::Tensor y, at::Tensor gx, at::Tensor gy, at::Tensor gz, size_t N1, size_t N2, size_t Ci, 44 | size_t Co, size_t B, size_t K); 45 | 46 | 47 | #endif 48 | -------------------------------------------------------------------------------- /exts/bp_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "bp_cuda.h" 7 | #include 8 | #include 9 | 10 | namespace { 11 | 12 | template 13 | __global__ void 14 | dist_kernel4(const scalar_t *__restrict__ Pc, scalar_t *__restrict__ IPCnum, 15 | long *__restrict__ args, int B, int Cc, int N, int M, int numfw, int H, int W) { 16 | int num = 4; 17 | int h = threadIdx.x + blockIdx.x * blockDim.x; 18 | int w = threadIdx.y + blockIdx.y * blockDim.y; 19 | int b = threadIdx.z + blockIdx.z * blockDim.z; 20 | if ((h < H) && (w < W) && (b < B)) { 21 | scalar_t best[4]; 22 | long argbest[4]; 23 | #pragma unroll 24 | for (int i = 0; i < num; i++) { 25 | best[i] = 1e20; 26 | argbest[i] = 0; 27 | } 28 | for (int m = 0; m < M; m++) { 29 | scalar_t res1, res2, dist; 30 | res1 = Pc[b * Cc * M + 0 * M + m] - w; 31 | res2 = Pc[b * Cc * M + 1 * M + m] - h; 32 | dist = res1 * res1 + res2 * res2; 33 | #pragma unroll 34 | for (int i = 0; i < num; i++) { 35 | if (best[i] >= dist) { 36 | #pragma unroll 37 | for (int j = num - 1; j > i; j--) { 38 | best[j] = best[j - 1]; 39 | argbest[j] = argbest[j - 1]; 40 | } 41 | best[i] = dist; 42 | argbest[i] = m; 43 | break; 44 | } 45 | } 46 | } 47 | #pragma unroll 48 | for (int i = 0; i < num; i++) { 49 | if (best[i] >= 1e20) { 50 | argbest[i] = argbest[i - 1]; 51 | } 52 | args[b * num * N + i * N + h * W + w] = argbest[i]; 53 | IPCnum[b * Cc * num * N + 0 * num * N + i * N + h * W + w] = 54 | Pc[b * Cc * M + 0 * M + argbest[i]] - w; 55 | IPCnum[b * Cc * num * N + 1 * num * N + i * N + h * W + w] = 56 | Pc[b * Cc * M + 1 * M + argbest[i]] - h; 57 | } 58 | } 59 | } 60 | 61 | template 62 | __global__ void 63 | conv2d_kernel_lf(scalar_t *__restrict__ x, scalar_t *__restrict__ y, scalar_t *__restrict__ z, size_t N1, 64 | size_t N2, size_t Ci, size_t Co, size_t B, 65 | size_t K) { 66 | int col_index = threadIdx.x + blockIdx.x * blockDim.x; 67 | int row_index = threadIdx.y + blockIdx.y * blockDim.y; 68 | int cha_index = threadIdx.z + blockIdx.z * blockDim.z; 69 | if ((row_index < N1) && (col_index < N2) && (cha_index < Co)) { 70 | for (int b = 0; b < B; b++) { 71 | scalar_t result = 0; 72 | for (int i = -int((K - 1) / 2.); i < (K + 1) / 2.; i++) { 73 | for (int j = -int((K - 1) / 2.); j < (K + 1) / 2.; j++) { 74 | 75 | if ((row_index + i < 0) || (row_index + i >= N1) || (col_index + j < 0) || 76 | (col_index + j >= N2)) { 77 | continue; 
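                        // taps falling outside the image are skipped, which amounts to zero padding at the border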
78 | } 79 | 80 | result += x[b * N1 * N2 * Ci + cha_index * N1 * N2 + (row_index + i) * N2 + col_index + j] * 81 | y[b * N1 * N2 * Ci * K * K + cha_index * N1 * N2 * K * K + 82 | (i + (K - 1) / 2) * K * N1 * N2 + 83 | (j + (K - 1) / 2) * N1 * N2 + row_index * N2 + col_index]; 84 | } 85 | } 86 | z[b * N1 * N2 * Co + cha_index * N1 * N2 + row_index * N2 + col_index] = result; 87 | } 88 | } 89 | } 90 | 91 | 92 | template 93 | __global__ void conv2d_kernel_lb(scalar_t *__restrict__ x, scalar_t *__restrict__ y, scalar_t *__restrict__ gx, 94 | scalar_t *__restrict__ gy, scalar_t *__restrict__ gz, size_t N1, size_t N2, 95 | size_t Ci, size_t Co, size_t B, 96 | size_t K) { 97 | int col_index = threadIdx.x + blockIdx.x * blockDim.x; 98 | int row_index = threadIdx.y + blockIdx.y * blockDim.y; 99 | int cha_index = threadIdx.z + blockIdx.z * blockDim.z; 100 | if ((row_index < N1) && (col_index < N2) && (cha_index < Co)) { 101 | for (int b = 0; b < B; b++) { 102 | scalar_t result = 0; 103 | for (int i = -int((K - 1) / 2.); i < (K + 1) / 2.; i++) { 104 | for (int j = -int((K - 1) / 2.); j < (K + 1) / 2.; j++) { 105 | 106 | if ((row_index - i < 0) || (row_index - i >= N1) || (col_index - j < 0) || 107 | (col_index - j >= N2)) { 108 | continue; 109 | } 110 | result += gz[b * N1 * N2 * Ci + cha_index * N1 * N2 + (row_index - i) * N2 + col_index - j 111 | ] * 112 | y[b * N1 * N2 * Ci * K * K + cha_index * N1 * N2 * K * K + 113 | (i + (K - 1) / 2) * K * N1 * N2 + 114 | (j + (K - 1) / 2) * N1 * N2 + (row_index - i) * N2 + col_index - j]; 115 | gy[b * N1 * N2 * Ci * K * K + cha_index * N1 * N2 * K * K + 116 | (i + (K - 1) / 2) * K * N1 * N2 + 117 | (j + (K - 1) / 2) * N1 * N2 + (row_index - i) * N2 + col_index - j] = 118 | gz[b * N1 * N2 * Ci + cha_index * N1 * N2 + (row_index - i) * N2 + col_index - j 119 | ] * x[b * N1 * N2 * Ci + cha_index * N1 * N2 + row_index * N2 + col_index]; 120 | 121 | } 122 | } 123 | gx[b * N1 * N2 * Co + cha_index * N1 * N2 + row_index * N2 + col_index] = result; 124 | } 125 | } 126 | } 127 | 128 | 129 | } 130 | 131 | 132 | void Dist_Cuda(at::Tensor Pc, at::Tensor IPCnum, at::Tensor args, 133 | size_t B, size_t Cc, size_t N, size_t M, size_t num, size_t H, size_t W) { 134 | dim3 blockSize(32, 32, 1); 135 | dim3 gridSize((H + blockSize.x - 1) / blockSize.x, (W + blockSize.x - 1) / blockSize.x, B); 136 | switch (num) { 137 | case 4: 138 | AT_DISPATCH_FLOATING_TYPES(Pc.type(), "DistF1_Cuda", ([&] { 139 | dist_kernel4 << < gridSize, blockSize >> > ( 140 | Pc.data_ptr(), IPCnum.data_ptr(), args.data_ptr(), 141 | B, Cc, N, M, num, H, W); 142 | })); 143 | break; 144 | default: 145 | exit(-1); 146 | } 147 | } 148 | 149 | void Conv2d_LF_Cuda(at::Tensor x, at::Tensor y, at::Tensor z, size_t N1, size_t N2, size_t Ci, size_t Co, size_t B, 150 | size_t K) { 151 | dim3 blockSize(32, 32, 1); 152 | dim3 gridSize((N2 + blockSize.x - 1) / blockSize.x, (N1 + blockSize.y - 1) / blockSize.y, 153 | (Co + blockSize.z - 1) / blockSize.z); 154 | AT_DISPATCH_FLOATING_TYPES(x.type(), "Conv2d_LF", ([&] { 155 | conv2d_kernel_lf << < gridSize, blockSize >> > ( 156 | x.data(), y.data(), z.data(), 157 | N1, N2, Ci, Co, B, K); 158 | })); 159 | } 160 | 161 | 162 | void 163 | Conv2d_LB_Cuda(at::Tensor x, at::Tensor y, at::Tensor gx, at::Tensor gy, at::Tensor gz, size_t N1, size_t N2, 164 | size_t Ci, 165 | size_t Co, size_t B, size_t K) { 166 | dim3 blockSize(32, 32, 1); 167 | dim3 gridSize((N2 + blockSize.x - 1) / blockSize.x, (N1 + blockSize.y - 1) / blockSize.y, 168 | (Co + blockSize.z - 1) / blockSize.z); 169 | 
AT_DISPATCH_FLOATING_TYPES(x.type(), "Conv2d_LB", ([&] { 170 | conv2d_kernel_lb << < gridSize, blockSize >> > ( 171 | x.data(), y.data(), 172 | gx.data(), gy.data(), gz.data(), 173 | N1, N2, Ci, Co, B, K); 174 | })); 175 | } 176 | -------------------------------------------------------------------------------- /exts/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='BpOps', 6 | ext_modules=[ 7 | CUDAExtension('BpOps', 8 | [ 9 | 'bp_cuda.cpp', 10 | 'bp_cuda_kernel.cu', 11 | ], 12 | extra_compile_args={'cxx': ['-g'], 13 | 'nvcc': ['-O3']} 14 | ), 15 | ], 16 | cmdclass={ 17 | 'build_ext': BuildExtension 18 | }) 19 | -------------------------------------------------------------------------------- /models/BPNet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : BPNet.py 3 | # @Project: BP-Net 4 | # @Author : jie 5 | # @Time : 4/8/23 12:43 PM 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import numpy as np 10 | import functools 11 | from .utils import Conv1x1, Basic2d, BasicBlock, weights_init, inplace_relu, GenKernel, PMP 12 | 13 | __all__ = [ 14 | 'Pre_MF_Post', 15 | ] 16 | 17 | 18 | class Net(nn.Module): 19 | """ 20 | network 21 | """ 22 | 23 | def __init__(self, block=BasicBlock, bc=16, img_layers=[2, 2, 2, 2, 2, 2], 24 | drop_path=0.1, norm_layer=nn.BatchNorm2d, padding_mode='zeros', drift=1e6): 25 | super().__init__() 26 | self.drift = drift 27 | self._norm_layer = norm_layer 28 | self._padding_mode = padding_mode 29 | self._img_dpc = 0 30 | self._img_dprs = np.linspace(0, drop_path, sum(img_layers)) 31 | 32 | self.inplanes = bc * 2 33 | self.conv_img = nn.Sequential( 34 | Basic2d(3, bc * 2, norm_layer=norm_layer, kernel_size=3, padding=1), 35 | self._make_layer(block, bc * 2, 2, stride=1) 36 | ) 37 | 38 | self.layer1_img = self._make_layer(block, bc * 4, img_layers[1], stride=2) 39 | 40 | self.inplanes = bc * 4 41 | self.layer2_img = self._make_layer(block, bc * 8, img_layers[2], stride=2) 42 | 43 | self.inplanes = bc * 8 44 | self.layer3_img = self._make_layer(block, bc * 16, img_layers[3], stride=2) 45 | 46 | self.inplanes = bc * 16 47 | self.layer4_img = self._make_layer(block, bc * 16, img_layers[4], stride=2) 48 | 49 | self.inplanes = bc * 16 50 | self.layer5_img = self._make_layer(block, bc * 16, img_layers[5], stride=2) 51 | 52 | self.pred0 = PMP(level=0, in_ch=bc * 4, out_ch=bc * 2, drop_path=drop_path, pool=False) 53 | self.pred1 = PMP(level=1, in_ch=bc * 8, out_ch=bc * 4, drop_path=drop_path) 54 | self.pred2 = PMP(level=2, in_ch=bc * 16, out_ch=bc * 8, drop_path=drop_path) 55 | self.pred3 = PMP(level=3, in_ch=bc * 16, out_ch=bc * 16, drop_path=drop_path) 56 | self.pred4 = PMP(level=4, in_ch=bc * 16, out_ch=bc * 16, drop_path=drop_path) 57 | self.pred5 = PMP(level=5, in_ch=bc * 16, out_ch=bc * 16, drop_path=drop_path, up=False) 58 | 59 | def forward(self, I, DISP, S, K): 60 | """ 61 | I: Bx3xHxW 62 | S: Bx1xHxW 63 | K: Bx3x3 64 | """ 65 | output = [] 66 | # torch.Size([2, 3, 256, 1216]) 67 | XI0 = self.conv_img(I) # torch.Size([2, 32, 256, 1216]) 68 | XI1 = self.layer1_img(XI0) # torch.Size([2, 64, 128, 608]) 69 | XI2 = self.layer2_img(XI1) # torch.Size([2, 128, 64, 304]) 70 | XI3 = self.layer3_img(XI2) # torch.Size([2, 256, 32, 152]) 71 | XI4 = self.layer4_img(XI3) # torch.Size([2, 256, 16, 76]) 72 | 
XI5 = self.layer5_img(XI4) # torch.Size([2, 256, 8, 38]) 73 | 74 | # import IPython 75 | # IPython.embed() 76 | # exit() 77 | 78 | # S: torch.Size([2, 1, 256, 1216]) 79 | # XI5: torch.Size([2, 256, 8, 38]) 80 | 81 | fout, dout = self.pred5(fout=None, dout=None, XI=XI5, S=S, K=K) 82 | output.append(F.interpolate(dout, scale_factor=2 ** 5, mode='bilinear', align_corners=True)) 83 | 84 | fout, dout = self.pred4(fout=fout, dout=dout, XI=XI4, S=S, K=K) 85 | output.append(F.interpolate(dout, scale_factor=2 ** 4, mode='bilinear', align_corners=True)) 86 | 87 | fout, dout = self.pred3(fout=fout, dout=dout, XI=XI3, S=S, K=K) 88 | output.append(F.interpolate(dout, scale_factor=2 ** 3, mode='bilinear', align_corners=True)) 89 | 90 | fout, dout = self.pred2(fout=fout, dout=dout, XI=XI2, S=S, K=K) 91 | output.append(F.interpolate(dout, scale_factor=2 ** 2, mode='bilinear', align_corners=True)) 92 | 93 | fout, dout = self.pred1(fout=fout, dout=dout, XI=XI1, S=S, K=K) 94 | output.append(F.interpolate(dout, scale_factor=2 ** 1, mode='bilinear', align_corners=True)) 95 | 96 | fout, dout = self.pred0(fout=fout, dout=dout, XI=XI0, S=S, K=K) 97 | output.append(dout) 98 | return output 99 | 100 | def _make_layer(self, block, planes, blocks, stride=1): 101 | norm_layer = self._norm_layer 102 | padding_mode = self._padding_mode 103 | downsample = None 104 | if norm_layer is None: 105 | bias = True 106 | norm_layer = nn.Identity 107 | else: 108 | bias = False 109 | if stride != 1 or self.inplanes != planes * block.expansion: 110 | downsample = nn.Sequential( 111 | Conv1x1(self.inplanes, planes * block.expansion, stride, bias=bias), 112 | norm_layer(planes * block.expansion), 113 | ) 114 | 115 | layers = [] 116 | layers.append( 117 | block(self.inplanes, planes, stride, downsample, norm_layer=norm_layer, padding_mode=padding_mode, 118 | drop_path=self._img_dprs[self._img_dpc])) 119 | self._img_dpc += 1 120 | self.inplanes = planes * block.expansion 121 | for _ in range(1, blocks): 122 | layers.append(block(self.inplanes, planes, norm_layer=norm_layer, padding_mode=padding_mode, 123 | drop_path=self._img_dprs[self._img_dpc])) 124 | self._img_dpc += 1 125 | 126 | return nn.Sequential(*layers) 127 | 128 | 129 | def Pre_MF_Post(): 130 | """ 131 | Pre.+MF.+Post. 
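    Builds the full model: a six-stage image encoder (strides 1 to 32) feeds the PMP heads coarse-to-fine (pred5 -> pred0),
    and every intermediate depth map is upsampled to full resolution for multi-scale supervision.
    The last BN in each GenKernel is zero-initialised so the CSPN kernels start out as identity (centre-only) propagation.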
132 | """ 133 | net = Net() 134 | net.apply(functools.partial(weights_init, mode='trunc')) 135 | for m in net.modules(): 136 | if isinstance(m, GenKernel) and m.conv[1].conv.bn.weight is not None: 137 | nn.init.constant_(m.conv[1].conv.bn.weight, 0) 138 | net.apply(inplace_relu) 139 | return net 140 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | from .BPNet import * -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env 2 | # -*- coding: utf-8 -*- 3 | # @Filename : utils 4 | # @Date : 2022-05-10 5 | # @Project: BP-Net 6 | # @AUTHOR : jie 7 | 8 | from copy import deepcopy 9 | import torch 10 | import torch.nn as nn 11 | from torch.autograd import Function 12 | import torch.nn.functional as F 13 | import math 14 | import numpy as np 15 | from collections import OrderedDict 16 | import BpOps 17 | import torch.distributed as dist 18 | from einops.layers.torch import Rearrange 19 | from timm.models.layers import DropPath 20 | from torch.cuda.amp import custom_fwd, custom_bwd 21 | import functools 22 | 23 | __all__ = [ 24 | 'EMA', 25 | 'Conv1x1', 26 | 'Conv3x3', 27 | 'Basic2d', 28 | 'weights_init', 29 | 'inplace_relu', 30 | 'PMP', 31 | ] 32 | 33 | 34 | class BpDist(Function): 35 | 36 | @staticmethod 37 | @custom_fwd(cast_inputs=torch.float32) 38 | def forward(ctx, xy, idx, Valid, num, H, W): 39 | """ 40 | """ 41 | assert xy.is_contiguous() 42 | assert Valid.is_contiguous() 43 | _, Cc, M = xy.shape 44 | B = Valid.shape[0] 45 | N = H * W 46 | args = torch.zeros((B, num, N), dtype=torch.long, device=xy.device) 47 | IPCnum = torch.zeros((B, Cc, num, N), dtype=xy.dtype, device=xy.device) 48 | for b in range(B): 49 | Pc = torch.masked_select(xy, Valid[b:b + 1].view(1, 1, N)).reshape(1, 2, -1) 50 | BpOps.Dist(Pc, IPCnum[b:b + 1], args[b:b + 1], H, W) 51 | idx_valid = torch.masked_select(idx, Valid[b:b + 1].view(1, 1, N)) 52 | args[b:b + 1] = torch.index_select(idx_valid, 0, args[b:b + 1].reshape(-1)).reshape(1, num, N) 53 | return IPCnum, args 54 | 55 | @staticmethod 56 | @custom_bwd 57 | def backward(ctx, ga=None, gb=None): 58 | return None, None, None, None 59 | 60 | 61 | class BpConvLocal(Function): 62 | @staticmethod 63 | def forward(ctx, input, weight): 64 | assert input.is_contiguous() 65 | assert weight.is_contiguous() 66 | ctx.save_for_backward(input, weight) 67 | output = BpOps.Conv2dLocal_F(input, weight) 68 | return output 69 | 70 | @staticmethod 71 | def backward(ctx, grad_output): 72 | input, weight = ctx.saved_tensors 73 | grad_output = grad_output.contiguous() 74 | grad_input, grad_weight = BpOps.Conv2dLocal_B(input, weight, grad_output) 75 | return grad_input, grad_weight 76 | 77 | 78 | bpdist = BpDist.apply 79 | bpconvlocal = BpConvLocal.apply 80 | 81 | 82 | class EMA(nn.Module): 83 | """ Model Exponential Moving Average V2 borrow from timm https://timm.fast.ai/ 84 | 85 | Keep a moving average of everything in the model state_dict (parameters and buffers). 86 | V2 of this module is simpler, it does not match params/buffers based on name but simply 87 | iterates in order. It works with torchscript (JIT of full model). 
88 | 89 | This is intended to allow functionality like 90 | https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage 91 | 92 | A smoothed version of the weights is necessary for some training schemes to perform well. 93 | E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use 94 | RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA 95 | smoothing of weights to match results. Pay attention to the decay constant you are using 96 | relative to your update count per epoch. 97 | 98 | To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but 99 | disable validation of the EMA weights. Validation will have to be done manually in a separate 100 | process, or after the training stops converging. 101 | 102 | This class is sensitive where it is initialized in the sequence of model init, 103 | GPU assignment and distributed training wrappers. 104 | """ 105 | 106 | def __init__(self, model, decay=0.9999, ddp=False): 107 | super().__init__() 108 | # make a copy of the model for accumulating moving average of weights 109 | self.module = deepcopy(model) 110 | self.module.eval() 111 | if ddp: 112 | self.broadcast() 113 | self.decay = decay 114 | 115 | def broadcast(self): 116 | for ema_v in self.module.state_dict().values(): 117 | dist.broadcast(ema_v, src=0, async_op=False) 118 | 119 | def _update(self, model, update_fn): 120 | with torch.no_grad(): 121 | for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): 122 | ema_v.copy_(update_fn(ema_v, model_v)) 123 | 124 | def update(self, model): 125 | self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) 126 | 127 | def set(self, model): 128 | self._update(model, update_fn=lambda e, m: m) 129 | 130 | 131 | def weights_init(m, mode='trunc'): 132 | from torch.nn.init import _calculate_fan_in_and_fan_out 133 | classname = m.__class__.__name__ 134 | if classname.find('Conv2d') != -1: 135 | if hasattr(m, 'weight'): 136 | if mode == 'trunc': 137 | fan_in, fan_out = _calculate_fan_in_and_fan_out(m.weight.data) 138 | std = math.sqrt(2.0 / float(fan_in + fan_out)) 139 | torch.nn.init.trunc_normal_(m.weight.data, mean=0, std=std) 140 | elif mode == 'xavier': 141 | torch.nn.init.xavier_normal_(m.weight.data) 142 | else: 143 | raise ValueError(f'unknown mode = {mode}') 144 | if hasattr(m, 'bias') and m.bias is not None: 145 | torch.nn.init.constant_(m.bias.data, 0.0) 146 | if classname.find('Conv1d') != -1: 147 | if hasattr(m, 'weight'): 148 | if mode == 'trunc': 149 | fan_in, fan_out = _calculate_fan_in_and_fan_out(m.weight.data) 150 | std = math.sqrt(2.0 / float(fan_in + fan_out)) 151 | torch.nn.init.trunc_normal_(m.weight.data, mean=0, std=std) 152 | elif mode == 'xavier': 153 | torch.nn.init.xavier_normal_(m.weight.data) 154 | else: 155 | raise ValueError(f'unknown mode = {mode}') 156 | if hasattr(m, 'bias') and m.bias is not None: 157 | torch.nn.init.constant_(m.bias.data, 0.0) 158 | elif classname.find('Linear') != -1: 159 | if mode == 'trunc': 160 | fan_in, fan_out = _calculate_fan_in_and_fan_out(m.weight.data) 161 | std = math.sqrt(2.0 / float(fan_in + fan_out)) 162 | torch.nn.init.trunc_normal_(m.weight.data, mean=0, std=std) 163 | elif mode == 'xavier': 164 | torch.nn.init.xavier_normal_(m.weight.data) 165 | else: 166 | raise ValueError(f'unknown mode = {mode}') 167 | if m.bias is not None: 168 | torch.nn.init.constant_(m.bias.data, 0.0) 169 | 170 | 171 | def 
Conv1x1(in_planes, out_planes, stride=1, bias=False, groups=1, dilation=1, padding_mode='zeros'): 172 | """1x1 convolution""" 173 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias) 174 | 175 | 176 | def Conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1, padding_mode='zeros', bias=False): 177 | """3x3 convolution with padding""" 178 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 179 | padding=dilation, padding_mode=padding_mode, groups=groups, bias=bias, dilation=dilation) 180 | 181 | 182 | class Basic2d(nn.Module): 183 | def __init__(self, in_channels, out_channels, norm_layer=None, kernel_size=3, padding=1, padding_mode='zeros', 184 | act=nn.ReLU, stride=1): 185 | super().__init__() 186 | if norm_layer: 187 | conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 188 | stride=stride, padding=padding, bias=False, padding_mode=padding_mode) 189 | else: 190 | conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 191 | stride=stride, padding=padding, bias=True, padding_mode=padding_mode) 192 | self.conv = nn.Sequential(OrderedDict([('conv', conv)])) 193 | if norm_layer: 194 | self.conv.add_module('bn', norm_layer(out_channels)) 195 | self.conv.add_module('relu', act()) 196 | 197 | def forward(self, x): 198 | out = self.conv(x) 199 | return out 200 | 201 | 202 | 203 | def inplace_relu(m): 204 | classname = m.__class__.__name__ 205 | if classname.find('ReLU') != -1: 206 | m.inplace = True 207 | 208 | 209 | class Basic2dTrans(nn.Module): 210 | def __init__(self, in_channels, out_channels, norm_layer=None, act=nn.ReLU): 211 | super().__init__() 212 | if norm_layer is None: 213 | bias = True 214 | norm_layer = nn.Identity 215 | else: 216 | bias = False 217 | self.conv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4, 218 | stride=2, padding=1, bias=bias) 219 | self.bn = norm_layer(out_channels) 220 | self.relu = act() 221 | 222 | def forward(self, x): 223 | out = self.conv(x.contiguous()) 224 | out = self.bn(out) 225 | out = self.relu(out) 226 | return out 227 | 228 | 229 | class UpCC(nn.Module): 230 | def __init__(self, in_channels, mid_channels, out_channels, norm_layer=None, kernel_size=3, padding=1, 231 | padding_mode='zeros', act=nn.ReLU): 232 | super().__init__() 233 | self.upf = Basic2dTrans(in_channels, out_channels, norm_layer=norm_layer, act=act) 234 | self.conv = Basic2d(mid_channels + out_channels, out_channels, 235 | norm_layer=norm_layer, kernel_size=kernel_size, 236 | padding=padding, padding_mode=padding_mode, act=act) 237 | 238 | def forward(self, x, y): 239 | """ 240 | """ 241 | out = self.upf(x) 242 | fout = torch.cat([out, y], dim=1) 243 | fout = self.conv(fout) 244 | return fout 245 | 246 | class GenKernel(nn.Module): 247 | def __init__(self, in_channels, pk, norm_layer=nn.BatchNorm2d, act=nn.ReLU, eps=1e-6): 248 | super().__init__() 249 | self.eps = eps 250 | self.conv = nn.Sequential( 251 | Basic2d(in_channels, in_channels, norm_layer=norm_layer, act=act), 252 | Basic2d(in_channels, pk * pk - 1, norm_layer=norm_layer, act=nn.Identity), 253 | ) 254 | 255 | def forward(self, fout): 256 | weight = self.conv(fout) 257 | weight_sum = torch.sum(weight.abs(), dim=1, keepdim=True) 258 | weight = torch.div(weight, weight_sum + self.eps) 259 | weight_mid = 1 - torch.sum(weight, dim=1, keepdim=True) 260 | weight_pre, weight_post = torch.split(weight, [weight.shape[1] // 2, weight.shape[1] // 2], dim=1) 
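        # re-insert the centre weight (1 - sum of the normalised neighbours) so the assembled pk*pk affinity kernel sums to one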
261 | weight = torch.cat([weight_pre, weight_mid, weight_post], dim=1).contiguous() 262 | return weight 263 | 264 | 265 | class CSPN(nn.Module): 266 | """ 267 | implementation of CSPN++ 268 | """ 269 | 270 | def __init__(self, in_channels, pt, norm_layer=nn.BatchNorm2d, act=nn.ReLU, eps=1e-6): 271 | super().__init__() 272 | self.pt = pt 273 | self.weight3x3 = GenKernel(in_channels, 3, norm_layer=norm_layer, act=act, eps=eps) 274 | self.weight5x5 = GenKernel(in_channels, 5, norm_layer=norm_layer, act=act, eps=eps) 275 | self.weight7x7 = GenKernel(in_channels, 7, norm_layer=norm_layer, act=act, eps=eps) 276 | self.convmask = nn.Sequential( 277 | Basic2d(in_channels, in_channels, norm_layer=norm_layer, act=act), 278 | Basic2d(in_channels, 3, norm_layer=None, act=nn.Sigmoid), 279 | ) 280 | self.convck = nn.Sequential( 281 | Basic2d(in_channels, in_channels, norm_layer=norm_layer, act=act), 282 | Basic2d(in_channels, 3, norm_layer=None, act=functools.partial(nn.Softmax, dim=1)), 283 | ) 284 | self.convct = nn.Sequential( 285 | Basic2d(in_channels + 3, in_channels, norm_layer=norm_layer, act=act), 286 | Basic2d(in_channels, 3, norm_layer=None, act=functools.partial(nn.Softmax, dim=1)), 287 | ) 288 | 289 | @custom_fwd(cast_inputs=torch.float32) 290 | def forward(self, fout, hn, h0): 291 | weight3x3 = self.weight3x3(fout) 292 | weight5x5 = self.weight5x5(fout) 293 | weight7x7 = self.weight7x7(fout) 294 | mask3x3, mask5x5, mask7x7 = torch.split(self.convmask(fout) * (h0 > 1e-3).float(), 1, dim=1) 295 | conf3x3, conf5x5, conf7x7 = torch.split(self.convck(fout), 1, dim=1) 296 | hn3x3 = hn5x5 = hn7x7 = hn 297 | hns = [hn, ] 298 | for i in range(self.pt): 299 | hn3x3 = (1. - mask3x3) * bpconvlocal(hn3x3, weight3x3) + mask3x3 * h0 300 | hn5x5 = (1. - mask5x5) * bpconvlocal(hn5x5, weight5x5) + mask5x5 * h0 301 | hn7x7 = (1. 
- mask7x7) * bpconvlocal(hn7x7, weight7x7) + mask7x7 * h0 302 | if i == self.pt // 2 - 1: 303 | hns.append(conf3x3 * hn3x3 + conf5x5 * hn5x5 + conf7x7 * hn7x7) 304 | hns.append(conf3x3 * hn3x3 + conf5x5 * hn5x5 + conf7x7 * hn7x7) 305 | hns = torch.cat(hns, dim=1) 306 | wt = self.convct(torch.cat([fout, hns], dim=1)) 307 | hn = torch.sum(wt * hns, dim=1, keepdim=True) 308 | return hn 309 | 310 | 311 | class Coef(nn.Module): 312 | """ 313 | """ 314 | def __init__(self, in_channels, out_channels=3, kernel_size=1, padding=0): 315 | super().__init__() 316 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 317 | stride=1, padding=padding, bias=True) 318 | 319 | def forward(self, x): 320 | feat = self.conv(x) 321 | XF, XB, XW = torch.split(feat, [1, 1, 1], dim=1) 322 | return XF, XB, XW 323 | 324 | 325 | class Dist(nn.Module): 326 | """ 327 | """ 328 | 329 | def __init__(self, num): 330 | super().__init__() 331 | """ 332 | """ 333 | self.num = num 334 | 335 | def forward(self, S, xx, yy): 336 | """ 337 | """ 338 | num = self.num 339 | B, _, height, width = S.shape 340 | N = height * width 341 | S = S.reshape(B, 1, N) 342 | Valid = (S > 1e-3) 343 | xy = torch.stack((xx, yy), axis=0).reshape(1, 2, -1).float() 344 | idx = torch.arange(N, device=S.device).reshape(1, 1, N) 345 | Ofnum, args = bpdist(xy, idx, Valid, num, height, width) 346 | return Ofnum, args 347 | 348 | 349 | class BasicBlock(nn.Module): 350 | expansion = 1 351 | 352 | def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=None, padding_mode='zeros', act=nn.ReLU, 353 | last=True, drop_path=0.0): 354 | super().__init__() 355 | bias = False 356 | if norm_layer is None: 357 | bias = True 358 | norm_layer = nn.Identity 359 | self.conv1 = Conv3x3(inplanes, planes, stride, padding_mode=padding_mode, bias=bias) 360 | self.bn1 = norm_layer(planes) 361 | self.relu1 = act() 362 | self.conv2 = Conv3x3(planes, planes, padding_mode=padding_mode, bias=bias) 363 | self.bn2 = norm_layer(planes) 364 | self.downsample = downsample 365 | self.stride = stride 366 | self.last = last 367 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 368 | if last: 369 | self.relu2 = act() 370 | 371 | def forward(self, x): 372 | identity = x 373 | out = self.conv1(x) 374 | out = self.bn1(out) 375 | out = self.relu1(out) 376 | out = self.conv2(out) 377 | out = self.bn2(out) 378 | if self.downsample is not None: 379 | identity = self.downsample(x) 380 | out = self.drop_path(out) + identity 381 | if self.last: 382 | out = self.relu2(out) 383 | return out 384 | 385 | 386 | class UBNet(nn.Module): 387 | def __init__(self, inplanes, dplanes=1, norm_layer=nn.BatchNorm2d, padding_mode='zeros', act=nn.ReLU, 388 | blocknum=2, uplayer=UpCC, depth=1, block=BasicBlock, drop_path=0.): 389 | super().__init__() 390 | self._norm_layer = norm_layer 391 | self._act = act 392 | self._padding_mode = padding_mode 393 | bc = inplanes // 2 394 | self.inplanes = bc * 2 395 | self._img_dpc = 0 396 | self._img_dprs = np.linspace(0, drop_path, blocknum * (1 + depth)) 397 | encoder_list = [nn.Sequential( 398 | Basic2d(inplanes + dplanes, bc * 2, norm_layer=norm_layer, kernel_size=3, padding=1, 399 | padding_mode=padding_mode, act=act), 400 | self._make_layer(block, bc * 2, blocknum, stride=1) 401 | ), ] 402 | decoder_list = [] 403 | in_channels = bc * 2 404 | for i in range(depth): 405 | self.inplanes = in_channels 406 | out_channels = min(in_channels * 2, 256) 407 | encoder_list.append(self._make_layer(block, out_channels, blocknum, stride=2)) 408 | decoder_list.append(uplayer(out_channels, in_channels, in_channels, norm_layer=norm_layer)) 409 | in_channels = min(in_channels * 2, 256) 410 | self.encoder = nn.ModuleList(encoder_list) 411 | self.decoder = nn.ModuleList(decoder_list) 412 | 413 | def forward(self, x, d): 414 | feat = [] 415 | if d is not None: 416 | x = torch.cat([x, d], dim=1) 417 | for layer in self.encoder: 418 | x = layer(x) 419 | feat.append(x) 420 | out = feat[-1] 421 | for idx in range(len(feat) - 2, -1, -1): 422 | out = self.decoder[idx](out, feat[idx]) 423 | return out 424 | 425 | def _make_layer(self, block, planes, blocks, stride=1): 426 | norm_layer = self._norm_layer 427 | act = self._act 428 | padding_mode = self._padding_mode 429 | downsample = None 430 | if norm_layer is None: 431 | bias = True 432 | norm_layer = nn.Identity 433 | else: 434 | bias = False 435 | if stride != 1 or self.inplanes != planes * block.expansion: 436 | downsample = nn.Sequential( 437 | Conv1x1(self.inplanes, planes * block.expansion, stride, bias=bias), 438 | norm_layer(planes * block.expansion), 439 | ) 440 | layers = [] 441 | layers.append(block(self.inplanes, planes, stride, downsample, norm_layer, act=act, padding_mode=padding_mode, 442 | drop_path=self._img_dprs[self._img_dpc])) 443 | self._img_dpc += 1 444 | self.inplanes = planes * block.expansion 445 | for _ in range(1, blocks): 446 | layers.append(block(self.inplanes, planes, norm_layer=norm_layer, act=act, padding_mode=padding_mode, 447 | drop_path=self._img_dprs[self._img_dpc])) 448 | self._img_dpc += 1 449 | 450 | return nn.Sequential(*layers) 451 | 452 | 453 | class Permute(nn.Module): 454 | def __init__(self, in_channels, out_channels=1, stride=2, norm_layer=nn.BatchNorm2d, act=nn.ReLU): 455 | super().__init__() 456 | self.stride = stride 457 | self.out_channels = out_channels 458 | self.conv = nn.Sequential( 459 | Basic2d(in_channels=in_channels, out_channels=in_channels, norm_layer=norm_layer, act=act, kernel_size=1, 460 | padding=0), 461 | Basic2d(in_channels=in_channels, out_channels=in_channels, norm_layer=norm_layer, act=act, kernel_size=1, 462 | padding=0), 463 | 
Conv1x1(in_channels, out_channels * stride ** 2, bias=True), 464 | Rearrange('b (c h2 w2) h w -> b c (h h2) (w w2)', c=out_channels, h2=stride, w2=stride), 465 | ) 466 | 467 | def forward(self, x): 468 | """ 469 | """ 470 | fout = self.conv(x) 471 | return fout 472 | 473 | 474 | class WPool(nn.Module): 475 | def __init__(self, in_ch, level, drift=1e6): 476 | super().__init__() 477 | self.level = level 478 | self.drift = drift 479 | self.permute = Permute(in_ch, stride=2 ** level) 480 | 481 | def forward(self, S, fout): 482 | W = self.permute(fout) 483 | size = int(2 ** self.level) 484 | M = (S > 1e-3).float() 485 | with torch.no_grad(): 486 | maxW = F.max_pool2d((W + self.drift) * M, size, stride=[size, size]) - self.drift 487 | maxW = F.upsample_nearest(maxW, scale_factor=size) * M 488 | expW = torch.exp(W * M - maxW) * M 489 | avgS = F.avg_pool2d(S * expW, kernel_size=size, stride=size) 490 | avgexpW = F.avg_pool2d(expW, kernel_size=size, stride=size) 491 | Sp = avgS / (avgexpW + 1e-6) 492 | return Sp 493 | 494 | 495 | class UpCat(nn.Module): 496 | def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d, kernel_size=3, padding=1, 497 | padding_mode='zeros', act=nn.ReLU): 498 | super().__init__() 499 | self.upf = Basic2dTrans(in_channels + 1, out_channels, norm_layer=norm_layer, act=act) 500 | self.conv = Basic2d(out_channels * 2, out_channels, 501 | norm_layer=norm_layer, kernel_size=kernel_size, 502 | padding=padding, padding_mode=padding_mode, act=act) 503 | 504 | def forward(self, y, x, d): 505 | """ 506 | x is 507 | """ 508 | fout = self.upf(torch.cat([x, d], dim=1)) 509 | fout = self.conv(torch.cat([fout, y], dim=1)) 510 | return fout 511 | 512 | 513 | class Ident(nn.Module): 514 | def __init__(self, *args, **kwargs) -> None: 515 | super().__init__() 516 | 517 | def forward(self, *args): 518 | return args[0] 519 | 520 | 521 | class Prop(nn.Module): 522 | """ 523 | """ 524 | 525 | def __init__(self, Cfi, Cfp=3, Cfo=2, act=nn.GELU, norm_layer=nn.BatchNorm2d): 526 | super().__init__() 527 | """ 528 | """ 529 | self.dist = lambda x: (x * x).sum(1) 530 | Ct = Cfo + Cfi + Cfi + Cfp 531 | self.convXF = nn.Sequential( 532 | Basic2d(in_channels=Ct, out_channels=Cfi, norm_layer=norm_layer, act=act, kernel_size=1, 533 | padding=0), 534 | Basic2d(in_channels=Cfi, out_channels=Cfi, norm_layer=norm_layer, act=act, kernel_size=1, 535 | padding=0), 536 | ) 537 | self.convXL = nn.Sequential( 538 | Basic2d(in_channels=Cfi, out_channels=Cfi, norm_layer=norm_layer, act=act, kernel_size=1, 539 | padding=0), 540 | Basic2d(in_channels=Cfi, out_channels=Cfi, norm_layer=norm_layer, act=nn.Identity, kernel_size=1, 541 | padding=0), 542 | ) 543 | self.act = act() 544 | self.coef = Coef(Cfi, 3) 545 | 546 | def forward(self, If, Pf, Ofnum, args): 547 | """ 548 | """ 549 | num = args.shape[-2] 550 | B, Cfi, H, W = If.shape 551 | N = H * W 552 | B, Cfp, M = Pf.shape 553 | If = If.view(B, Cfi, 1, N) 554 | Pf = Pf.view(B, Cfp, 1, M) 555 | Ifnum = If.expand(B, Cfi, num, N) ## Ifnum is BxCfixnumxN 556 | IPfnum = torch.gather( 557 | input=If.expand(B, Cfi, num, N), 558 | dim=-1, 559 | index=args.view(B, 1, num, N).expand(B, Cfi, num, N)) ## IPfnum is BxCfixnumxN 560 | Pfnum = torch.gather( 561 | input=Pf.expand(B, Cfp, num, M), 562 | dim=-1, 563 | index=args.view(B, 1, num, N).expand(B, Cfp, num, N)) ## Pfnum is BxCfpxnumxN 564 | X = torch.cat([Ifnum, IPfnum, Pfnum, Ofnum], dim=1) 565 | XF = self.convXF(X) 566 | XF = self.act(XF + self.convXL(XF)) 567 | Alpha, Beta, Omega = self.coef(XF) 568 | Omega = 
torch.softmax(Omega, dim=2) 569 | dout = torch.sum(((Alpha + 1) * Pfnum[:, -1:] + Beta) * Omega, dim=2, keepdim=True) 570 | return dout.view(B, 1, H, W) 571 | 572 | 573 | class PMP(nn.Module): 574 | """ 575 | Pre+MF+Post 576 | """ 577 | 578 | def __init__(self, level, in_ch, out_ch, drop_path, up=True, pool=True): 579 | super().__init__() 580 | self.level = level 581 | if up: 582 | self.upcat = UpCat(in_ch, out_ch) 583 | else: 584 | self.upcat = Ident() 585 | if pool: 586 | self.wpool = WPool(out_ch, level=level) 587 | else: 588 | self.wpool = Ident() 589 | self.dist = Dist(num=4) 590 | self.prop = Prop(out_ch) 591 | self.fuse = UBNet(out_ch, dplanes=3, blocknum=2, depth=5 - level, drop_path=drop_path) 592 | self.conv = Conv3x3(out_ch, 1, bias=True) 593 | self.cspn = CSPN(out_ch, pt=2 * (6 - level)) 594 | 595 | def pinv(self, S, K, xx, yy): 596 | fx, fy, cx, cy = K[:, 0:1, 0:1], K[:, 1:2, 1:2], K[:, 0:1, 2:3], K[:, 1:2, 2:3] 597 | S = S.view(S.shape[0], 1, -1) 598 | xx = xx.reshape(1, 1, -1) 599 | yy = yy.reshape(1, 1, -1) 600 | Px = S * (xx - cx) / fx 601 | Py = S * (yy - cy) / fy 602 | Pz = S 603 | Pxyz = torch.cat([Px, Py, Pz], dim=1).contiguous() 604 | return Pxyz 605 | 606 | def forward(self, fout, dout, XI, S, K): 607 | fout = self.upcat(XI, fout, dout) 608 | Sp = self.wpool(S, fout) 609 | Kp = K.clone() 610 | Kp[:, :2] = Kp[:, :2] / 2 ** self.level 611 | B, _, height, width = Sp.shape 612 | xx, yy = torch.meshgrid(torch.arange(width, device=Sp.device), torch.arange(height, device=Sp.device), 613 | indexing='xy') 614 | ############################################################### 615 | # Pre 616 | Pxyz = self.pinv(Sp, Kp, xx, yy) 617 | Ofnum, args = self.dist(Sp, xx, yy) 618 | dout = self.prop(fout, Pxyz, Ofnum, args) 619 | ############################################################### 620 | # MF 621 | Pxyz = self.pinv(dout, Kp, xx, yy).view(dout.shape[0], 3, dout.shape[2], dout.shape[3]) 622 | fout = self.fuse(fout, Pxyz) 623 | res = self.conv(fout) 624 | dout = dout + res 625 | ############################################################### 626 | # Post 627 | dout = self.cspn(fout, dout, Sp) 628 | return fout, dout 629 | -------------------------------------------------------------------------------- /optimizers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : optimizers.py 3 | # @Project: BP-Net 4 | # @Author : jie 5 | # @Time : 5/11/22 3:47 PM 6 | 7 | from torch.optim import SGD, AdamW, Adam 8 | 9 | __all__ = [ 10 | 'Adam', 11 | 'AdamW', 12 | 'SGD', 13 | ] 14 | -------------------------------------------------------------------------------- /rpnloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class RandomProposalNormalizationLoss(nn.Module): 5 | def __init__(self, deltas=(2 ** (-5 * 2), 2 ** (-4 * 2), 2 ** (-3 * 2), 2 ** (-2 * 2), 2 ** (-1 * 2), 1), num_proposals=32, min_crop_ratio=0.125, max_crop_ratio=0.5): 6 | super(RandomProposalNormalizationLoss, self).__init__() 7 | self.num_proposals = num_proposals 8 | self.min_crop_ratio = min_crop_ratio 9 | self.max_crop_ratio = max_crop_ratio 10 | self.deltas = deltas 11 | 12 | def forward(self, outputs, target): 13 | """ 14 | :param predicted_depth: Predicted depth map (B, 1, H, W) 15 | :param ground_truth_depth: Ground truth depth map (B, 1, H, W) 16 | :return: RPNL loss value 17 | """ 18 | loss = [delta * self.compute_loss(ests, target) for ests, delta in 
zip(outputs, self.deltas)] 19 | return loss 20 | 21 | def compute_loss(self, predicted_depth, ground_truth_depth): 22 | B, _, H, W = predicted_depth.shape 23 | loss = 0.0 24 | 25 | for _ in range(self.num_proposals): 26 | # Randomly select crop size 27 | crop_ratio = torch.rand(1).item() * (self.max_crop_ratio - self.min_crop_ratio) + self.min_crop_ratio 28 | crop_h, crop_w = int(H * crop_ratio), int(W * crop_ratio) 29 | 30 | # Randomly select top-left corner of the crop 31 | top = torch.randint(0, H - crop_h + 1, (1,)).item() 32 | left = torch.randint(0, W - crop_w + 1, (1,)).item() 33 | 34 | # Extract patches 35 | pred_patch = predicted_depth[:, :, top:top + crop_h, left:left + crop_w] 36 | gt_patch = ground_truth_depth[:, :, top:top + crop_h, left:left + crop_w] 37 | 38 | # Flatten patches for normalization 39 | pred_patch = pred_patch.reshape(B, -1) 40 | gt_patch = gt_patch.reshape(B, -1) 41 | 42 | # Normalize patches using median absolute deviation normalization 43 | pred_median = pred_patch.median(dim=1, keepdim=True)[0] 44 | gt_median = gt_patch.median(dim=1, keepdim=True)[0] 45 | 46 | pred_mad = torch.median(torch.abs(pred_patch - pred_median), dim=1, keepdim=True)[0] 47 | gt_mad = torch.median(torch.abs(gt_patch - gt_median), dim=1, keepdim=True)[0] 48 | 49 | pred_patch_norm = (pred_patch - pred_median) / (pred_mad + 1e-6) 50 | gt_patch_norm = (gt_patch - gt_median) / (gt_mad + 1e-6) 51 | 52 | # Calculate the L1 difference between normalized patches 53 | patch_loss = torch.mean(torch.abs(pred_patch_norm - gt_patch_norm)) 54 | loss += patch_loss 55 | 56 | loss /= self.num_proposals 57 | return loss 58 | 59 | # Example usage: 60 | if __name__ == "__main__": 61 | # Create random predicted and ground truth depth maps 62 | predicted_depth = torch.rand(2, 1, 256, 256) # (batch_size, channels, height, width) 63 | ground_truth_depth = torch.rand(2, 1, 256, 256) 64 | 65 | # Initialize the loss function 66 | rpn_loss = RandomProposalNormalizationLoss() 67 | 68 | # Compute the per-scale losses (forward expects a list of predictions, one per scale) 69 | loss_values = rpn_loss([predicted_depth], ground_truth_depth) 70 | print(f"RPN Loss: {sum(loss_values).item()}") 71 | 72 | class RegularizationLoss(nn.Module): 73 | """ 74 | Enforce losses on pixels without any gts. 
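    Implemented as the reciprocal of the mean prediction over pixels without ground truth, scaled by loss_weight;
    this discourages the prediction from collapsing towards zero in regions with no supervision.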
75 | """ 76 | def __init__(self, loss_weight=0.1): 77 | super(RegularizationLoss, self).__init__() 78 | self.loss_weight = loss_weight 79 | self.eps = 1e-6 80 | 81 | def forward(self, prediction, gt): 82 | mask = gt > 1e-3 83 | pred_wo_gt = prediction[~mask] 84 | loss = 1/ (torch.sum(pred_wo_gt) / (pred_wo_gt.numel() + self.eps)) 85 | return loss * self.loss_weight -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | # torchrun --nproc_per_node=4 --master_port 4321 train_distill.py \ 2 | # gpus=[0,1,2,3] num_workers=4 name=BP_KITTI \ 3 | # net=PMP data=KITTI \ 4 | # lr=1e-3 train_batch_size=2 test_batch_size=2 \ 5 | # sched/lr=NoiseOneCycleCosMo sched.lr.policy.max_momentum=0.90 \ 6 | # nepoch=30 test_epoch=25 ++net.sbn=true 7 | 8 | torchrun --nproc_per_node=1 --master_port 4321 train_distill.py \ 9 | gpus=[0] num_workers=1 name=BP_KITTI \ 10 | net=PMP data=KITTI \ 11 | lr=1e-3 train_batch_size=2 test_batch_size=2 \ 12 | sched/lr=NoiseOneCycleCosMo sched.lr.policy.max_momentum=0.90 \ 13 | nepoch=30 test_epoch=25 ++net.sbn=true -------------------------------------------------------------------------------- /schedulers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : schedulers.py 3 | # @Project: BP-Net 4 | # @Author : jie 5 | # @Time : 5/11/22 3:50 PM 6 | import random 7 | import sys 8 | from torch.optim.lr_scheduler import StepLR, MultiStepLR, OneCycleLR, LambdaLR, LinearLR, ExponentialLR 9 | import numpy as np 10 | import torch 11 | 12 | 13 | def NoiseLR(**kwargs): 14 | lr_sched = getattr(sys.modules[__name__], kwargs.pop('lr_sched', 'OneCycleLR')) 15 | 16 | class sched(lr_sched): 17 | def __init__(self, **kwargs): 18 | self.noise_pct = kwargs.pop('noise_pct', 0.1) 19 | self.noise_seed = kwargs.pop('noise_seed', 0) 20 | super().__init__(**kwargs) 21 | 22 | def get_lr(self): 23 | """ 24 | lrn: Learning Rate with Noise 25 | """ 26 | g = torch.Generator() 27 | g.manual_seed(self.noise_seed + self.last_epoch) 28 | noise = 2 * torch.rand(1, generator=g).item() - 1 29 | lrs = super().get_lr() 30 | lrn = [] 31 | for lr in lrs: 32 | lrn.append(lr * (1 + self.noise_pct * noise)) 33 | return lrn 34 | 35 | return sched(**kwargs) 36 | 37 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : train_amp.py 3 | # @Project: BP-Net 4 | # @Author : jie 5 | # @Time : 10/27/21 3:58 PM 6 | 7 | import torch 8 | from tqdm import tqdm 9 | import hydra 10 | from PIL import Image 11 | import os 12 | from omegaconf import OmegaConf 13 | from utils import * 14 | 15 | 16 | def test(run, mode='selval', save=False): 17 | dataloader = run.testloader 18 | net = run.net_ema.module 19 | net.eval() 20 | tops = [AverageMeter() for i in range(len(run.metric.metric_name))] 21 | if save: 22 | dir_path = f'results/{run.cfg.name}/{mode}' 23 | os.makedirs(dir_path, exist_ok=True) 24 | with torch.no_grad(): 25 | for idx, datas in enumerate( 26 | tqdm(dataloader, desc="test ", dynamic_ncols=True, leave=False, disable=run.rank)): 27 | datas = run.init_cuda(*datas) 28 | output = net(*datas[:-1]) 29 | if isinstance(output, (list, tuple)): 30 | output = output[-1] 31 | precs = run.metric(output, datas[-1]) 32 | for prec, top in zip(precs, tops): 33 | 
top.update(prec.mean().detach().cpu().item()) 34 | if save: 35 | for i in range(output.shape[0]): 36 | index = idx * output.shape[0] + i 37 | file_path = os.path.join(dir_path, f'{index:010d}.png') 38 | img = (output[i, 0] * 256.0).detach().cpu().numpy().astype('uint16') 39 | Img = Image.fromarray(img) 40 | Img.save(file_path) 41 | logs = "" 42 | for name, top in zip(run.metric.metric_name, tops): 43 | logs += f" {name}:{top.avg:.7f} " 44 | run.ddp_log(logs, always=True) 45 | 46 | 47 | @hydra.main(config_path='configs', config_name='config', version_base='1.2') 48 | def main(cfg): 49 | with Trainer(cfg) as run: 50 | test(run, mode=cfg.data.testset.mode, save=OmegaConf.select(cfg, 'save', default=False)) 51 | 52 | 53 | 54 | if __name__ == '__main__': 55 | main() 56 | -------------------------------------------------------------------------------- /train_distill.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : train_amp.py 3 | # @Project: BP-Net 4 | # @Author : jie 5 | # @Time : 10/27/21 3:58 PM 6 | 7 | import torch 8 | from tqdm import tqdm 9 | import hydra 10 | import torch.distributed as dist 11 | from utils import * 12 | from rpnloss import RandomProposalNormalizationLoss 13 | 14 | def train(run): 15 | 16 | compute_disp_loss = RandomProposalNormalizationLoss() 17 | 18 | for datas in tqdm(run.trainloader, desc="train", dynamic_ncols=True, leave=False, disable=run.rank): 19 | if run.epoch >= run.cfg.test_epoch: 20 | if run.iter % run.cfg.test_iter == 0: 21 | test(run, iter=True) 22 | datas = run.init_cuda(*datas) 23 | run.net.train() 24 | run.optimizer.zero_grad(set_to_none=True) 25 | output = run.net(*datas[:-1]) 26 | loss = run.criterion(output, datas[-1]) 27 | 28 | disp = 1 / (datas[1] + 0.5) 29 | # disp_loss = compute_disp_loss(output, disp, torch.ones_like(disp).float()) 30 | disp_loss = compute_disp_loss(output, disp) 31 | 32 | loss += disp_loss 33 | 34 | sum(loss).backward() 35 | if run.clip: 36 | grad_norm = run.clip(run.net.parameters()) 37 | run.optimizer.step() 38 | run.net_ema.update(run.net) 39 | if run.lr_iter: 40 | run.lr_scheduler.step() 41 | if run.iter % run.cfg.vis_iter == 0: 42 | run.writer.add_scalar("Lr", run.optimizer.param_groups[0]['lr'], run.iter) 43 | run.writer.add_scalars("Loss", {f"{idx}": l.item() for idx, l in enumerate(loss)}, run.iter) 44 | if run.clip and (grad_norm is not None): 45 | run.writer.add_scalar("GradNorm", grad_norm.item(), run.iter) 46 | run.iter += 1 47 | if not run.lr_iter: 48 | run.lr_scheduler.step() 49 | run.writer.flush() 50 | 51 | 52 | def test(run, iter=False): 53 | top1 = AverageMeter() 54 | net = run.net_ema.module 55 | best_metric_name = "best_metric_ema" 56 | legand = 'net_ema' 57 | net.eval() 58 | with torch.no_grad(): 59 | for datas in tqdm(run.testloader, desc="test ", dynamic_ncols=True, leave=False, disable=run.rank): 60 | datas = run.init_cuda(*datas) 61 | output = net(*datas[:-1]) 62 | if isinstance(output, (list, tuple)): 63 | output = output[-1] 64 | prec1 = run.metric(output, datas[-1]) 65 | if isinstance(prec1, (list, tuple)): 66 | prec1 = prec1[0] 67 | if run.ddp: 68 | dist.reduce(prec1, 0, dist.ReduceOp.AVG) 69 | top1.update(prec1.item()) 70 | if iter: 71 | run.writer.add_scalars("RMSE_Iter", {legand: top1.avg}, run.iter) 72 | else: 73 | run.writer.add_scalars("RMSE", {legand: top1.avg}, run.epoch) 74 | 75 | if top1.avg < getattr(run, best_metric_name): 76 | setattr(run, best_metric_name, top1.avg) 77 | run.save_state() 78 | 
run.ddp_cout(f'Epoch: {run.epoch} {best_metric_name}: {top1.avg:.7f}\n') 79 | else: 80 | best_metric_value = getattr(run, best_metric_name) 81 | run.ddp_cout(f'Epoch: {run.epoch} current: {top1.avg:.7f}, best: {best_metric_value:.7f}. \n') 82 | 83 | 84 | @hydra.main(config_path='configs', config_name='config', version_base='1.2') 85 | def main(cfg): 86 | with Trainer(cfg) as run: 87 | for epoch in tqdm(range(run.cfg.start_epoch, run.cfg.nepoch), desc="epoch", dynamic_ncols=True): 88 | run.epoch = epoch 89 | if run.train_sampler: 90 | run.train_sampler.set_epoch(epoch) 91 | train(run) 92 | torch.cuda.synchronize() 93 | test(run) 94 | torch.cuda.synchronize() 95 | 96 | if __name__ == '__main__': 97 | main() 98 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env 2 | # -*- coding: utf-8 -*- 3 | # @Filename : utils 4 | # @Date : 2022-05-06 5 | # @Project: BP-Net 6 | # @AUTHOR : jie 7 | 8 | import os 9 | import torch 10 | import random 11 | import numpy as np 12 | from torch.utils.tensorboard import SummaryWriter 13 | import time 14 | from tqdm import tqdm 15 | from torch.nn.parallel import DistributedDataParallel as DDP 16 | import logging 17 | from hydra.utils import instantiate, get_class 18 | from omegaconf import OmegaConf 19 | import cv2 20 | import augs 21 | from torch.nn.utils import clip_grad_norm_, clip_grad_value_ 22 | from collections import OrderedDict 23 | import math 24 | 25 | __all__ = [ 26 | 'AverageMeter', 27 | 'Trainer', 28 | ] 29 | 30 | 31 | class AverageMeter(object): 32 | def __init__(self): 33 | self.reset() 34 | 35 | def reset(self): 36 | self.val = 0 37 | self.avg = 0 38 | self.sum = 0 39 | self.count = 0 40 | 41 | def update(self, val, n=1): 42 | self.val = val 43 | self.sum += val * n 44 | self.count += n 45 | self.avg = self.sum / self.count 46 | 47 | 48 | class Trainer(object): 49 | def __init__(self, cfg): 50 | self.cfg = cfg 51 | self.rank = int(os.environ["LOCAL_RANK"]) if "LOCAL_RANK" in os.environ else 0 52 | self.cfg.gpu_id = self.cfg.gpus[self.rank] 53 | self.init_gpu() 54 | self.ddp = len(self.cfg.gpus) > 1 55 | self.iter = 0 56 | self.epoch = 0 57 | self.best_metric_ema = 100 58 | ##################################################################################### 59 | self.log = self.init_log() 60 | self.init_device() 61 | self.init_seed() 62 | self.writer = self.init_viz() 63 | self.trainloader, self.testloader = self.init_dataset() 64 | net = self.init_net() 65 | criterion = self.init_loss() 66 | metric = self.init_metric() 67 | self.net, self.criterion, self.metric = self.init_cuda(net, criterion, metric) 68 | self.net_ema = self.init_ema() 69 | if self.ddp: 70 | self.net = DDP(self.net) 71 | self.optimizer = self.init_optim() 72 | self.lr_scheduler = self.init_sched_lr() 73 | self.lr_iter = OmegaConf.select(self.cfg.sched.lr, 'iter', default=False) 74 | self.clip = self.init_clip() 75 | 76 | def init_log(self): 77 | return Blank() if self.rank else logging.getLogger(f'{self.cfg.name}') 78 | 79 | def init_device(self): 80 | torch.cuda.set_device(f'cuda:{self.cfg.gpu_id}') 81 | if self.ddp: 82 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 83 | self.ddp_log(f'device is {self.cfg.gpu_id}', always=True) 84 | 85 | def init_seed(self): 86 | manual_seed = self.cfg.manual_seed 87 | self.ddp_log(f"Random Seed: {manual_seed:04d}") 88 | torch.initial_seed() 89 | random.seed(manual_seed) 90 | 
np.random.seed(manual_seed)
91 |         torch.manual_seed(manual_seed)
92 |         torch.cuda.manual_seed_all(manual_seed)
93 | 
94 |     def init_viz(self):
95 |         if self.rank:
96 |             return Blank()
97 |         else:
98 |             writer_name = os.path.join('runs', f'{self.cfg.name}')
99 |             return SummaryWriter(writer_name)
100 | 
101 |     def init_dataset(self):
102 |         trainset = instantiate(self.cfg.data.trainset)
103 |         testset = instantiate(self.cfg.data.testset)
104 |         if self.ddp:
105 |             train_sampler = torch.utils.data.distributed.DistributedSampler(
106 |                 trainset,
107 |                 num_replicas=len(self.cfg.gpus),
108 |                 rank=self.rank,
109 |                 shuffle=True,
110 |             )
111 |             test_sampler = torch.utils.data.distributed.DistributedSampler(
112 |                 testset,
113 |                 num_replicas=len(self.cfg.gpus),
114 |                 rank=self.rank,
115 |                 shuffle=False,
116 |             )
117 |         else:
118 |             train_sampler = None
119 |             test_sampler = None
120 |         self.train_sampler = train_sampler
121 |         trainloader = torch.utils.data.DataLoader(trainset, batch_size=self.cfg.train_batch_size,
122 |                                                   num_workers=self.cfg.num_workers, shuffle=(train_sampler is None),
123 |                                                   sampler=train_sampler,
124 |                                                   drop_last=True, pin_memory=True)
125 |         testloader = torch.utils.data.DataLoader(testset, batch_size=self.cfg.test_batch_size,
126 |                                                  num_workers=self.cfg.num_workers, shuffle=False,
127 |                                                  sampler=test_sampler,
128 |                                                  drop_last=True, pin_memory=True)
129 |         self.ddp_log(f'num_train = {len(trainloader)}, num_test = {len(testloader)}')
130 |         return trainloader, testloader
131 | 
132 | 
133 |     def init_net(self):
134 |         model = instantiate(self.cfg.net.model)
135 |         if 'chpt' in self.cfg:
136 |             self.ddp_log(f'resume CHECKPOINTS')
137 |             save_path = os.path.join('checkpoints', self.cfg.chpt)
138 |             cp = torch.load(self.cfg.chpt, map_location=torch.device('cpu'))
139 |             model.load_state_dict(cp['net'], strict=True)
140 |             self.best_metric_ema = cp['best_metric_ema']
141 |             del cp
142 |         if self.ddp and OmegaConf.select(self.cfg.net, 'sbn', default=False):
143 |             self.ddp_log('sbn')
144 |             model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
145 |         else:
146 |             """
147 |             SBN is not compatible with torch.compile
148 |             """
149 |             if OmegaConf.select(self.cfg.net, 'compile', default=False):
150 |                 self.ddp_log('compile')
151 |                 model = torch.compile(model)
152 |         return model
153 | 
154 |     def init_ema(self):
155 |         return instantiate(self.cfg.net.ema, model=self.net, ddp=self.ddp)
156 | 
157 |     def init_loss(self):
158 |         return instantiate(self.cfg.loss)
159 | 
160 |     def init_metric(self):
161 |         return instantiate(self.cfg.metric)
162 | 
163 |     def init_cuda(self, *modules):
164 |         modules = [module.to(f'cuda:{self.cfg.gpu_id}') for module in modules]
165 |         return modules
166 | 
167 |     def init_gpu(self):
168 |         with torch.cuda.device(f'cuda:{self.cfg.gpu_id}'):
169 |             torch.cuda.empty_cache()
170 |         torch.backends.cudnn.benchmark = True
171 | 
172 |     def init_optim(self):
173 |         optim = instantiate(self.cfg.optim, _partial_=True)
174 |         return optim(params=config_param(self.net))
175 | 
176 |     def init_clip(self):
177 |         if 'clip' in self.cfg.net:
178 |             return instantiate(self.cfg.net.clip, _partial_=True)
179 |         else:
180 |             return None
181 | 
182 |     def init_sched_lr(self):
183 |         sched = instantiate(self.cfg.sched.lr.policy, _partial_=True)
184 |         return sched(optimizer=self.optimizer)
185 | 
186 | 
187 |     def ddp_log(self, content, always=False):
188 |         # self.log.info(f'{content}')
189 |         if (not self.rank) or always:
190 |             self.log.info(f'{content}')
191 | 
192 |     def ddp_cout(self, content, always=False):
193 |         # tqdm.write(f'{content}')
194 |         if (not self.rank) or always:
195 |             tqdm.write(f'{content}')
196 | 
197 |     def save_state(self):
198 |         if self.rank:
199 |             return
200 |         save_path = os.path.join('checkpoints', self.cfg.name)
201 |         os.makedirs(save_path, exist_ok=True)
202 |         model = self.net_ema.module
203 |         if hasattr(model, 'module'):
204 |             model = model.module
205 |         model_state_dict = model.state_dict()
206 |         state_dict = {
207 |             'net': model_state_dict,
208 |             'epoch': self.epoch,
209 |             'best_metric_ema': self.best_metric_ema,
210 |         }
211 |         torch.save(state_dict, os.path.join(save_path, 'result_ema.pth'))
212 | 
213 |     def __enter__(self):
214 |         return self
215 | 
216 |     def __exit__(self, *args, **kwargs):
217 |         self.writer.close()
218 |         self.ddp_log(f'best_metric_ema={self.best_metric_ema:.4f}')
219 | 
220 | 
221 | def config_param(model):
222 |     param_groups = []
223 |     other_params = []
224 |     for name, param in model.named_parameters():
225 |         if len(param.shape) == 1:
226 |             g = {'params': [param], 'weight_decay': 0.0}
227 |             param_groups.append(g)
228 |         else:
229 |             other_params.append(param)
230 |     param_groups.append({'params': other_params})
231 |     return param_groups
232 | 
233 | 
234 | def set_requires_grad(model, requires_grad=True):
235 |     for p in model.parameters():
236 |         p.requires_grad = requires_grad
237 | 
238 | 
239 | 
240 | class Blank(object):
241 |     def __getattr__(self, name):
242 |         def wrapper(*args, **kwargs):
243 |             return None
244 |         return wrapper
245 | 
246 | FloorDiv = lambda a, b: a // b
247 | 
248 | CeilDiv = lambda a, b: math.ceil(a / b)
249 | 
250 | Div = lambda a, b: a / b
251 | 
252 | Mul = lambda a, b: a * b
253 | 
254 | 
255 | 
--------------------------------------------------------------------------------
/utils_infer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Filename : utils
4 | # @Date : 2022-05-06
5 | # @Project: BP-Net
6 | # @AUTHOR : jie
7 | 
8 | import os
9 | import torch
10 | import random
11 | import numpy as np
12 | from torch.utils.tensorboard import SummaryWriter
13 | import time
14 | from tqdm import tqdm
15 | from torch.nn.parallel import DistributedDataParallel as DDP
16 | import logging
17 | from hydra.utils import instantiate, get_class
18 | from omegaconf import OmegaConf
19 | import cv2
20 | import augs
21 | from torch.nn.utils import clip_grad_norm_, clip_grad_value_
22 | from collections import OrderedDict
23 | import math
24 | 
25 | __all__ = [
26 |     'AverageMeter',
27 |     'Trainer',
28 | ]
29 | 
30 | 
31 | class AverageMeter(object):
32 |     def __init__(self):
33 |         self.reset()
34 | 
35 |     def reset(self):
36 |         self.val = 0
37 |         self.avg = 0
38 |         self.sum = 0
39 |         self.count = 0
40 | 
41 |     def update(self, val, n=1):
42 |         self.val = val
43 |         self.sum += val * n
44 |         self.count += n
45 |         self.avg = self.sum / self.count
46 | 
47 | 
48 | class Trainer(object):
49 |     def __init__(self, cfg):
50 |         self.cfg = cfg
51 |         self.rank = int(os.environ["LOCAL_RANK"]) if "LOCAL_RANK" in os.environ else 0
52 |         self.cfg.gpu_id = self.cfg.gpus[self.rank]
53 |         self.init_gpu()
54 |         self.ddp = len(self.cfg.gpus) > 1
55 |         self.iter = 0
56 |         self.epoch = 0
57 |         self.best_metric_ema = 100
58 |         #####################################################################################
59 |         self.log = self.init_log()
60 |         self.init_device()
61 |         self.init_seed()
62 |         self.writer = self.init_viz()
63 |         net = self.init_net()
64 |         criterion = self.init_loss()
65 |         metric = self.init_metric()
66 |         self.net, self.criterion, self.metric = self.init_cuda(net, criterion, metric)
67 |         self.net_ema = self.init_ema()
68 |         if self.ddp:
69 |             self.net = DDP(self.net)
70 |         self.optimizer = self.init_optim()
71 |         self.lr_scheduler = self.init_sched_lr()
72 |         self.lr_iter = OmegaConf.select(self.cfg.sched.lr, 'iter', default=False)
73 |         self.clip = self.init_clip()
74 | 
75 |     def init_log(self):
76 |         return Blank() if self.rank else logging.getLogger(f'{self.cfg.name}')
77 | 
78 |     def init_device(self):
79 |         torch.cuda.set_device(f'cuda:{self.cfg.gpu_id}')
80 |         if self.ddp:
81 |             torch.distributed.init_process_group(backend='nccl', init_method='env://')
82 |         self.ddp_log(f'device is {self.cfg.gpu_id}', always=True)
83 | 
84 |     def init_seed(self):
85 |         manual_seed = self.cfg.manual_seed
86 |         self.ddp_log(f"Random Seed: {manual_seed:04d}")
87 |         torch.initial_seed()
88 |         random.seed(manual_seed)
89 |         np.random.seed(manual_seed)
90 |         torch.manual_seed(manual_seed)
91 |         torch.cuda.manual_seed_all(manual_seed)
92 | 
93 |     def init_viz(self):
94 |         if self.rank:
95 |             return Blank()
96 |         else:
97 |             writer_name = os.path.join('runs', f'{self.cfg.name}')
98 |             return SummaryWriter(writer_name)
99 | 
100 |     def init_net(self):
101 |         model = instantiate(self.cfg.net.model)
102 |         if 'chpt' in self.cfg:
103 |             self.ddp_log(f'resume CHECKPOINTS')
104 |             save_path = os.path.join('checkpoints', self.cfg.chpt)
105 |             cp = torch.load(os.path.join(save_path, 'result_ema.pth'), map_location=torch.device('cpu'))
106 |             model.load_state_dict(cp['net'], strict=True)
107 |             self.best_metric_ema = cp['best_metric_ema']
108 |             print("Epoch:", cp["epoch"])
109 |             del cp
110 |         if self.ddp and OmegaConf.select(self.cfg.net, 'sbn', default=False):
111 |             self.ddp_log('sbn')
112 |             model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
113 |         else:
114 |             """
115 |             SBN is not compatible with torch.compile
116 |             """
117 |             if OmegaConf.select(self.cfg.net, 'compile', default=False):
118 |                 self.ddp_log('compile')
119 |                 model = torch.compile(model)
120 |         return model
121 | 
122 |     def init_ema(self):
123 |         return instantiate(self.cfg.net.ema, model=self.net, ddp=self.ddp)
124 | 
125 |     def init_loss(self):
126 |         return instantiate(self.cfg.loss)
127 | 
128 |     def init_metric(self):
129 |         return instantiate(self.cfg.metric)
130 | 
131 |     def init_cuda(self, *modules):
132 |         modules = [module.to(f'cuda:{self.cfg.gpu_id}') for module in modules]
133 |         return modules
134 | 
135 |     def init_gpu(self):
136 |         with torch.cuda.device(f'cuda:{self.cfg.gpu_id}'):
137 |             torch.cuda.empty_cache()
138 |         torch.backends.cudnn.benchmark = True
139 | 
140 |     def init_optim(self):
141 |         optim = instantiate(self.cfg.optim, _partial_=True)
142 |         return optim(params=config_param(self.net))
143 | 
144 |     def init_clip(self):
145 |         if 'clip' in self.cfg.net:
146 |             return instantiate(self.cfg.net.clip, _partial_=True)
147 |         else:
148 |             return None
149 | 
150 |     def init_sched_lr(self):
151 |         sched = instantiate(self.cfg.sched.lr.policy, _partial_=True)
152 |         return sched(optimizer=self.optimizer)
153 | 
154 | 
155 |     def ddp_log(self, content, always=False):
156 |         # self.log.info(f'{content}')
157 |         if (not self.rank) or always:
158 |             self.log.info(f'{content}')
159 | 
160 |     def ddp_cout(self, content, always=False):
161 |         # tqdm.write(f'{content}')
162 |         if (not self.rank) or always:
163 |             tqdm.write(f'{content}')
164 | 
165 |     def save_state(self):
166 |         if self.rank:
167 |             return
168 |         save_path = os.path.join('checkpoints', self.cfg.name)
169 |         os.makedirs(save_path, exist_ok=True)
170 |         model = self.net_ema.module
171 |         if hasattr(model, 'module'):
172 |             model = model.module
173 |         model_state_dict = model.state_dict()
174 |         state_dict = {
175 |             'net': model_state_dict,
176 |             'epoch': self.epoch,
177 |             'best_metric_ema': self.best_metric_ema,
178 |         }
179 |         torch.save(state_dict, os.path.join(save_path, 'result_ema.pth'))
180 | 
181 |     def __enter__(self):
182 |         return self
183 | 
184 |     def __exit__(self, *args, **kwargs):
185 |         self.writer.close()
186 |         self.ddp_log(f'best_metric_ema={self.best_metric_ema:.4f}')
187 | 
188 | 
189 | def config_param(model):
190 |     param_groups = []
191 |     other_params = []
192 |     for name, param in model.named_parameters():
193 |         if len(param.shape) == 1:
194 |             g = {'params': [param], 'weight_decay': 0.0}
195 |             param_groups.append(g)
196 |         else:
197 |             other_params.append(param)
198 |     param_groups.append({'params': other_params})
199 |     return param_groups
200 | 
201 | 
202 | def set_requires_grad(model, requires_grad=True):
203 |     for p in model.parameters():
204 |         p.requires_grad = requires_grad
205 | 
206 | 
207 | 
208 | class Blank(object):
209 |     def __getattr__(self, name):
210 |         def wrapper(*args, **kwargs):
211 |             return None
212 |         return wrapper
213 | 
214 | FloorDiv = lambda a, b: a // b
215 | 
216 | CeilDiv = lambda a, b: math.ceil(a / b)
217 | 
218 | Div = lambda a, b: a / b
219 | 
220 | Mul = lambda a, b: a * b
221 | 
222 | 
223 | 
--------------------------------------------------------------------------------
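A note on the optimizer plumbing shared by `utils.py` and `utils_infer.py`: `init_optim` builds the optimizer from a Hydra partial (`configs/optim/AdamW.yaml`) and feeds it the groups returned by `config_param`, which places every 1-D parameter (biases and normalization weights) in a group with `weight_decay=0.0` while all other parameters keep the optimizer's default decay. The sketch below reproduces that grouping on a toy model; the model and the `lr`/`weight_decay` values are illustrative placeholders, not the values from the repo's config.

```python
import torch
import torch.nn as nn


def config_param(model):
    # Same grouping rule as the helper in utils.py / utils_infer.py:
    # 1-D tensors (biases, norm scales) are excluded from weight decay.
    param_groups = []
    other_params = []
    for name, param in model.named_parameters():
        if len(param.shape) == 1:
            param_groups.append({'params': [param], 'weight_decay': 0.0})
        else:
            other_params.append(param)
    param_groups.append({'params': other_params})
    return param_groups


# Toy stand-in; the project instantiates its real network (models/BPNet.py) via Hydra.
model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU())

# lr and weight_decay here are placeholders, not the settings in configs/optim/AdamW.yaml.
optimizer = torch.optim.AdamW(config_param(model), lr=1e-3, weight_decay=1e-2)

for group in optimizer.param_groups:
    print(len(group['params']), 'param tensor(s), weight_decay =', group['weight_decay'])
```

Excluding biases and normalization parameters from weight decay is standard practice for AdamW-style training; the only difference from the in-repo path is that the optimizer is constructed directly here instead of through `instantiate(self.cfg.optim, _partial_=True)`.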