├── LICENSE
├── README.md
├── augs.py
├── checkpoints
│   └── .gitignore
├── configs
│   ├── config.yaml
│   ├── data
│   │   ├── KITTI.yaml
│   │   └── NYU.yaml
│   ├── hydra.yaml
│   ├── loss
│   │   └── MSMSE.yaml
│   ├── metric
│   │   ├── MetricALL.yaml
│   │   └── RMSE.yaml
│   ├── net
│   │   └── PMP.yaml
│   ├── optim
│   │   └── AdamW.yaml
│   └── sched
│       └── lr
│           └── NoiseOneCycleCosMo.yaml
├── criteria.py
├── datas
│   └── .gitignore
├── datasets
│   ├── __init__.py
│   ├── kitti.py
│   └── nyu.py
├── demo.py
├── demo.sh
├── environment.yml
├── exts
│   ├── bp_cuda.cpp
│   ├── bp_cuda.h
│   ├── bp_cuda_kernel.cu
│   └── setup.py
├── models
│   ├── BPNet.py
│   ├── __init__.py
│   └── utils.py
├── optimizers.py
├── rpnloss.py
├── run.sh
├── schedulers.py
├── test.py
├── train_distill.py
├── utils.py
└── utils_infer.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Jie Tang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DMD³C: Distilling Monocular Foundation Model for Fine-grained Depth Completion
2 | **Official Code for the CVPR 2025 Paper**
3 | **"Distilling Monocular Foundation Model for Fine-grained Depth Completion"**
4 |
5 | [📄 Paper on arXiv](https://arxiv.org/abs/2503.16970)
6 |
7 | ---
8 |
9 | ## 🆕 Update Log
10 |
11 | - **[2025.04.23]** We have released the **2nd-stage training code**! 🎉
12 | - **[2025.04.11]** We have released the **inference code**! 🎉
13 |
14 | ## ✅ To Do
15 |
16 | - [ ] 📦 Easy-to-use **data generation pipeline**
17 | ---
18 |
19 |
20 |
21 |
22 |
23 | ---
24 |
25 | ## 🔍 Overview
26 |
27 | DMD³C introduces a novel framework for **fine-grained depth completion** that distills knowledge from **monocular foundation models**. This significantly improves depth completion accuracy from sparse inputs, especially in regions without ground-truth supervision.
28 |
29 | ---
30 |
31 |
32 |
33 | ---
34 |
35 |
36 |
37 | ## 🚀 Getting Started
38 |
39 | ### 1. Clone Base Repository
40 |
41 | ```bash
42 | git clone https://github.com/kakaxi314/BP-Net.git
43 | ```
44 |
45 | ### 2. Copy This Repo into the BP-Net Directory
46 |
47 | ```bash
48 | cp DMD3C/* BP-Net/
49 | cd BP-Net/
50 | ```
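
Before running anything, make sure the Python environment and the BP-Net CUDA extension are in place. A minimal sketch, assuming the conda environment shipped in `environment.yml` and a standard `setup.py` build of the extension in `exts/` (defer to the upstream BP-Net instructions if they differ):

```bash
# create and activate the conda environment (named "bp" in environment.yml)
conda env create -f environment.yml
conda activate bp

# build the BP-Net CUDA extension used by the models (assumed standard build command)
cd exts
python setup.py install
cd ..
```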
51 |
52 | ### 3. Prepare KITTI Raw Data
53 |
54 | Download any sequence from the **KITTI Raw dataset**, which includes:
55 |
56 | - Camera intrinsics
57 | - Velodyne point cloud
58 | - Image sequences
59 |
60 | Make sure the structure follows the **standard KITTI format**.
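
For reference, `demo.py` expects the raw sequence and calibration files under `datas/kitti/raw/` roughly as follows (paths taken from the demo script; the drive id is just an example):

```
datas/kitti/raw/2011_09_26/
├── calib_cam_to_cam.txt
├── calib_velo_to_cam.txt
└── 2011_09_26_drive_0048_sync/
    ├── image_02/data/*.png
    └── velodyne_points/data/*.bin
```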
61 |
62 | ### 4. Modify the Sequence in `demo.py` for Inference
63 |
64 | Open `demo.py` and go to **line 338**, where the input sequence path can be changed to point at your downloaded KITTI data.
65 |
66 | ```python
67 | # demo.py (Line 338)
68 | base = "/path/to/your/kitti/sequence"
69 | ```
70 |
71 | Download pre-trained weights:
72 |
73 | ```bash
74 | wget https://github.com/Sharpiless/DMD3C/releases/download/pretrain-checkpoints/dmd3c_distillation_depth_anything_v2.pth
75 | mv dmd3c_distillation_depth_anything_v2.pth checkpoints
76 | ```
77 |
78 | Run inference:
79 | ```bash
80 | bash demo.sh
81 | ```
82 |
83 | You will get results like this:
84 |
85 | *(demo result visualization)*
86 |
87 | ### 5. Train on KITTI
88 |
89 | Run monocular depth estimation on all KITTI raw images and store the results using the following data structure:
90 | ```
91 | ├── datas/kitti/raw/
92 | │ ├── 2011_09_26
93 | │ │ ├── 2011_09_26_drive_0001_sync
94 | │ │ │ ├── image_02
95 | │ │ │ │ ├── data/*.png
96 | │ │ │ │ ├── disp/*.png
97 | │ │ │ ├── image_03
98 | │ │ ├── 2011_09_26_drive_0002_sync.......
99 | ```
100 |
101 | Disparity images are stored as single-channel grayscale PNGs.
102 |
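The data-generation pipeline is still on the to-do list above. As a stop-gap, the snippet below is a minimal sketch of how such grayscale disparity maps could be produced with a monocular model through the Hugging Face `depth-estimation` pipeline; the model id and the 8-bit normalization are assumptions, not the exact recipe used for the paper.

```python
# Hypothetical helper: render relative-depth (disparity-like) maps as 8-bit grayscale
# PNGs under image_0x/disp/, mirroring image_0x/data/. The model id is an assumption.
import glob, os
import numpy as np
from PIL import Image
from transformers import pipeline

depth_estimator = pipeline("depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf")

for img_path in sorted(glob.glob("datas/kitti/raw/*/*_sync/image_0[23]/data/*.png")):
    disp_dir = os.path.join(os.path.dirname(os.path.dirname(img_path)), "disp")
    os.makedirs(disp_dir, exist_ok=True)
    img = Image.open(img_path)
    pred = depth_estimator(img)["predicted_depth"].squeeze().numpy()
    # scale the relative prediction to 8 bits; augs.Norm later maps it back to [0, 1] via /255
    pred = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8) * 255.0
    Image.fromarray(pred.astype(np.uint8)).resize(img.size).save(
        os.path.join(disp_dir, os.path.basename(img_path)))
```
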
103 | Download the pre-trained checkpoint:
104 | ```bash
105 | wget https://github.com/Sharpiless/DMD3C/releases/download/pretrain-checkpoints/pretrained_mixed_singleview_256.pth
106 | mv pretrained_mixed_singleview_256.pth checkpoints
107 | ```
108 |
109 | Zero-shot performance on the KITTI validation set:
110 |
111 | | Training Data | RMSE | MAE | iRMSE | REL |
112 | |----------------------|----------|----------|----------|----------|
113 | | Single-view Images | 1.4251 | 0.3722 | 0.0056 | 0.0235 |
114 |
115 |
116 | Run metric fine-tuning on the KITTI dataset:
117 | ```bash
118 | torchrun --nproc_per_node=4 --master_port 4321 train_distill.py \
119 | gpus=[0,1,2,3] num_workers=4 name=DMD3D_BP_KITTI \
120 | ++chpt=checkpoints/pretrained_mixed_singleview_256.pth \
121 | net=PMP data=KITTI \
122 | lr=5e-4 train_batch_size=2 test_batch_size=1 \
123 | sched/lr=NoiseOneCycleCosMo sched.lr.policy.max_momentum=0.90 \
124 | nepoch=30 test_epoch=25 ++net.sbn=true
125 | ```
126 |
--------------------------------------------------------------------------------
/augs.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | # @Filename: augs.py
4 | # @Project: BP-Net
5 | # @Author: jie
6 | # @Time: 2021/3/14 8:27 PM
7 |
8 |
9 | import numpy as np
10 |
11 | __all__ = [
12 | 'Compose',
13 | 'Norm',
14 | 'Jitter',
15 | 'Flip',
16 | ]
17 |
18 |
19 | class Compose(object):
20 | """
21 | Sequential operations on the input sample (i.e. rgb, disp, lidar and depth).
22 | """
23 |
24 | def __init__(self, transforms):
25 | self.transforms = transforms
26 |
27 | def __call__(self, rgb, disp, lidar, depth, K_cam):
28 | for t in self.transforms:
29 | rgb, disp, lidar, depth, K_cam = t(rgb, disp, lidar, depth, K_cam)
30 | return rgb, disp, lidar, depth, K_cam
31 |
32 |
33 | class Norm(object):
34 | """
35 | Normalize the rgb image and scale disp to [0, 1].
36 | """
37 |
38 | def __init__(self, mean, std):
39 | self.mean = np.array(mean)
40 | self.std = np.array(std)
41 |
42 | def __call__(self, rgb, disp, lidar, depth, K_cam):
43 | rgb = (rgb - self.mean) / self.std
44 | disp = disp / 255
45 | return rgb, disp, lidar, depth, K_cam
46 |
47 | class Jitter(object):
48 | """
49 | PCA color jitter, borrowed from https://github.com/kujason/avod/blob/master/avod/datasets/kitti/kitti_aug.py
50 | """
51 |
52 | def __call__(self, rgb, disp, lidar, depth, K_cam):
53 | pca = compute_pca(rgb)
54 | rgb = add_pca_jitter(rgb, pca)
55 | return rgb, disp, lidar, depth, K_cam
56 |
57 |
58 |
59 | class Flip(object):
60 | """
61 | random horizontal flip of images.
62 | """
63 |
64 | def __call__(self, rgb, disp, lidar, depth, CamK):
65 | width = rgb.shape[1]
66 | flip = bool(np.random.randint(2))
67 | if flip:
68 | rgb = rgb[:, ::-1, :]
69 | lidar = lidar[:, ::-1, :]
70 | depth = depth[:, ::-1, :]
71 | CamK[0, 2] = (width - 1) - CamK[0, 2]
72 | return rgb, disp, lidar, depth, CamK
73 |
74 |
75 | def compute_pca(image):
76 | """
77 | calculate PCA of image
78 | """
79 |
80 | reshaped_data = image.reshape(-1, 3)
81 | reshaped_data = (reshaped_data / 255.0).astype(np.float32)
82 | covariance = np.cov(reshaped_data.T)
83 | e_vals, e_vecs = np.linalg.eigh(covariance)
84 | pca = np.sqrt(e_vals) * e_vecs
85 | return pca
86 |
87 |
88 | def add_pca_jitter(img_data, pca):
89 | """
90 | add a multiple of principal components with Gaussian noise
91 | """
92 | new_img_data = np.copy(img_data).astype(np.float32) / 255.0
93 | magnitude = np.random.randn(3) * 0.1
94 | noise = (pca * magnitude).sum(axis=1)
95 |
96 | new_img_data = new_img_data + noise
97 | np.clip(new_img_data, 0.0, 1.0, out=new_img_data)
98 | new_img_data = (new_img_data * 255).astype(np.uint8)
99 |
100 | return new_img_data
--------------------------------------------------------------------------------
/checkpoints/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
--------------------------------------------------------------------------------
/configs/config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - hydra
3 | - optim: AdamW
4 | - sched/lr: NoiseOneCycleCosMo
5 | - data: KITTI
6 | - net: PMP
7 | - loss: MSMSE
8 | - metric: RMSE
9 | - _self_
10 | #
11 |
12 | nepoch: 30
13 | test_epoch: 25
14 | test_iter: 1000
15 | lr: 0.001
16 | gpus:
17 | - 0
18 | train_batch_size: 2
19 | test_batch_size: 2
20 | num_workers: 4
21 | manual_seed: 1
22 | vis_iter: 100
23 | start_epoch: 0
24 | gpu_id: ???
25 | name: ???
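# "???" marks required values that must be supplied at runtime (e.g. name=... on the command line)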
26 | device:
27 | _target_: torch.device
28 | device: cuda:${gpu_id}
--------------------------------------------------------------------------------
/configs/data/KITTI.yaml:
--------------------------------------------------------------------------------
1 | mean:
2 | - 90.9950
3 | - 96.2278
4 | - 94.3213
5 | std:
6 | - 79.2382
7 | - 80.5267
8 | - 82.1483
9 | train_sample_size: 85898
10 | test_sample_size: 1000
11 | height: 256
12 | width: 1216
13 |
14 | trainset:
15 | _target_: datasets.KITTI
16 | mode: 'train'
17 | RandCrop: true
18 | path: ${data.path}
19 | height: ${data.height}
20 | width: ${data.width}
21 | mean: ${data.mean}
22 | std: ${data.std}
23 |
24 | testset:
25 | _target_: datasets.KITTI
26 | mode: 'selval'
27 | RandCrop: false
28 | path: ${data.path}
29 | height: ${data.height}
30 | width: ${data.width}
31 | mean: ${data.mean}
32 | std: ${data.std}
33 |
34 | path: datas/kitti
35 | mul_factor: 1.0
--------------------------------------------------------------------------------
/configs/data/NYU.yaml:
--------------------------------------------------------------------------------
1 | mean:
2 | - 117.
3 | - 97.
4 | - 91.
5 | std:
6 | - 70.
7 | - 71.
8 | - 74.
9 | train_sample_size: 47584
10 | test_sample_size: 654
11 | height: 256
12 | width: 320
13 |
14 |
15 | trainset:
16 | _target_: datasets.NYU
17 | mode: 'train'
18 | path: ${data.path}
19 | num_sample: ${data.npoints}
20 | mul_factor: ${data.mul_factor}
21 | num_mask: ${data.num_mask}
22 | scale_kcam: true
23 |
24 |
25 | testset:
26 | _target_: datasets.NYU
27 | mode: 'val'
28 | path: ${data.path}
29 | num_sample: ${data.npoints}
30 | mul_factor: ${data.mul_factor}
31 | num_mask: ${data.num_mask}
32 | scale_kcam: false
33 |
34 |
35 | path: datas/nyu
36 | npoints: 500
37 | mul_factor: 10.0
38 | num_mask: 1
--------------------------------------------------------------------------------
/configs/hydra.yaml:
--------------------------------------------------------------------------------
1 | hydra:
2 | run:
3 | dir: outputs/${name}/${now:%Y-%m-%d_%H-%M-%S}
4 | output_subdir: .
5 | job:
6 | chdir: false
7 | # job_logging:
8 | # handlers:
9 | # file:
10 | # filename: ${name}.log
--------------------------------------------------------------------------------
/configs/loss/MSMSE.yaml:
--------------------------------------------------------------------------------
1 | _target_: criteria.MSMSE
--------------------------------------------------------------------------------
/configs/metric/MetricALL.yaml:
--------------------------------------------------------------------------------
1 | _target_: criteria.MetricALL
2 | mul_factor: ${data.mul_factor}
--------------------------------------------------------------------------------
/configs/metric/RMSE.yaml:
--------------------------------------------------------------------------------
1 | _target_: criteria.RMSE
2 | mul_factor: ${data.mul_factor}
--------------------------------------------------------------------------------
/configs/net/PMP.yaml:
--------------------------------------------------------------------------------
1 | # Propagation Field Baseline Relu6 Clip EMA
2 |
3 | name: PMP
4 |
5 | model:
6 | _target_: models.Pre_MF_Post
7 |
8 |
9 | ema:
10 | _target_: models.EMA
11 | decay: 0.9999
12 |
13 |
14 | clip:
15 | _target_: utils.clip_grad_norm_
16 | max_norm: 0.1
--------------------------------------------------------------------------------
/configs/optim/AdamW.yaml:
--------------------------------------------------------------------------------
1 | _target_: optimizers.AdamW
2 | lr: ${lr}
3 | weight_decay: 0.05
4 | betas:
5 | - 0.9
6 | - 0.999
--------------------------------------------------------------------------------
/configs/sched/lr/NoiseOneCycleCosMo.yaml:
--------------------------------------------------------------------------------
1 | #Noise One Cycle Cos With cycle_momentum
2 | policy:
3 | _target_: schedulers.NoiseLR
4 | lr_sched: OneCycleLR
5 | anneal_strategy: cos
6 | epochs: ${nepoch}
7 | div_factor: 40.0
8 | final_div_factor: 0.1
9 | last_epoch: -1
10 | max_lr: ${lr}
11 | cycle_momentum: true
12 | base_momentum: 0.85
13 | max_momentum: 0.95
14 | pct_start: 0.1
15 | noise_pct: 0.1
16 | steps_per_epoch: ${sched.lr.steps_per_epoch}
17 |
18 | iter: true
19 |
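# steps_per_epoch = floor(ceil(train_sample_size / len(gpus)) / train_batch_size)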
20 | steps_per_epoch:
21 | _target_: utils.FloorDiv
22 | _args_:
23 | - _target_: utils.CeilDiv
24 | _args_:
25 | - ${data.train_sample_size}
26 | - _target_: builtins.len
27 | _args_:
28 | - ${gpus}
29 | - ${train_batch_size}
--------------------------------------------------------------------------------
/criteria.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | # @Filename: criteria.py
4 | # @Project: BP-Net
5 | # @Author: jie
6 | # @Time: 2021/3/14 7:51 PM
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 | __all__ = [
12 | 'RMSE',
13 | 'MSMSE',
14 | 'MetricALL',
15 | ]
16 |
17 |
18 | class RMSE(nn.Module):
19 |
20 | def __init__(self, mul_factor=1.):
21 | super().__init__()
22 | self.mul_factor = mul_factor
23 | self.metric_name = [
24 | 'RMSE',
25 | ]
26 |
27 | def forward(self, outputs, target):
28 | outputs = outputs / self.mul_factor
29 | target = target / self.mul_factor
30 | val_pixels = (target > 1e-3).float()
31 | err = (target * val_pixels - outputs * val_pixels) ** 2
32 | loss = torch.sum(err.view(err.size(0), 1, -1), -1, keepdim=True)
33 | cnt = torch.sum(val_pixels.view(val_pixels.size(0), 1, -1), -1, keepdim=True)
34 | return torch.sqrt(loss / cnt).mean(),
35 |
36 |
37 | class MSMSE(nn.Module):
38 | """
39 | Multi-Scale MSE
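Each prediction in outputs is weighted by its corresponding delta (default 4^-5, 4^-4, ..., 4^0).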
40 | """
41 |
42 | def __init__(self, deltas=(2 ** (-5 * 2), 2 ** (-4 * 2), 2 ** (-3 * 2), 2 ** (-2 * 2), 2 ** (-1 * 2), 1)):
43 | super().__init__()
44 | self.deltas = deltas
45 |
46 | def mse(self, est, gt):
47 | valid = (gt > 1e-3).float()
48 | loss = est * valid - gt * valid
49 | return (loss ** 2).mean()
50 |
51 | def forward(self, outputs, target):
52 | loss = [delta * self.mse(ests, target) for ests, delta in zip(outputs, self.deltas)]
53 | return loss
54 |
55 |
56 | class MetricALL(nn.Module):
57 | def __init__(self, mul_factor):
58 | super().__init__()
59 | self.t_valid = 0.0001
60 | self.mul_factor = mul_factor
61 | self.metric_name = [
62 | 'RMSE', 'MAE', 'iRMSE', 'iMAE', 'REL', 'D^1', 'D^2', 'D^3', 'D102', 'D105', 'D110'
63 | ]
64 |
65 | def forward(self, pred, gt):
66 | with torch.no_grad():
67 | pred = pred.detach() / self.mul_factor
68 | gt = gt.detach() / self.mul_factor
69 | pred_inv = 1.0 / (pred + 1e-8)
70 | gt_inv = 1.0 / (gt + 1e-8)
71 |
72 | # For numerical stability
73 | mask = gt > self.t_valid
74 | # num_valid = mask.sum()
75 | B = mask.size(0)
76 | num_valid = torch.sum(mask.view(B, -1), -1, keepdim=True)
77 |
78 | # pred = pred[mask]
79 | # gt = gt[mask]
80 | pred = pred * mask
81 | gt = gt * mask
82 |
83 | # pred_inv = pred_inv[mask]
84 | # gt_inv = gt_inv[mask]
85 | pred_inv = pred_inv * mask
86 | gt_inv = gt_inv * mask
87 |
88 | # pred_inv[pred <= self.t_valid] = 0.0
89 | # gt_inv[gt <= self.t_valid] = 0.0
90 |
91 | # RMSE / MAE
92 | diff = pred - gt
93 | diff_abs = torch.abs(diff)
94 | diff_sqr = torch.pow(diff, 2)
95 |
96 | rmse = torch.sum(diff_sqr.view(B, -1), -1, keepdim=True) / (num_valid + 1e-8)
97 | rmse = torch.sqrt(rmse)
98 |
99 | mae = torch.sum(diff_abs.view(B, -1), -1, keepdim=True) / (num_valid + 1e-8)
100 |
101 | # iRMSE / iMAE
102 | diff_inv = pred_inv - gt_inv
103 | diff_inv_abs = torch.abs(diff_inv)
104 | diff_inv_sqr = torch.pow(diff_inv, 2)
105 |
106 | irmse = torch.sum(diff_inv_sqr.view(B, -1), -1, keepdim=True) / (num_valid + 1e-8)
107 | irmse = torch.sqrt(irmse)
108 |
109 | imae = torch.sum(diff_inv_abs.view(B, -1), -1, keepdim=True) / (num_valid + 1e-8)
110 |
111 | # Rel
112 | rel = diff_abs / (gt + 1e-8)
113 | rel = torch.sum(rel.view(B, -1), -1, keepdim=True) / (num_valid + 1e-8)
114 |
115 | # delta
116 | r1 = gt / (pred + 1e-8)
117 | r2 = pred / (gt + 1e-8)
118 | ratio = torch.max(r1, r2)
119 |
120 | ratio = torch.max(ratio, 10000 * (1 - mask.float()))
121 |
122 | del_1 = (ratio < 1.25).type_as(ratio)
123 | del_2 = (ratio < 1.25 ** 2).type_as(ratio)
124 | del_3 = (ratio < 1.25 ** 3).type_as(ratio)
125 | del_102 = (ratio < 1.02).type_as(ratio)
126 | del_105 = (ratio < 1.05).type_as(ratio)
127 | del_110 = (ratio < 1.10).type_as(ratio)
128 |
129 | del_1 = torch.sum(del_1.view(B, -1), -1, keepdim=True) / (num_valid + 1e-8)
130 | del_2 = torch.sum(del_2.view(B, -1), -1, keepdim=True) / (num_valid + 1e-8)
131 | del_3 = torch.sum(del_3.view(B, -1), -1, keepdim=True) / (num_valid + 1e-8)
132 | del_102 = torch.sum(del_102.view(B, -1), -1, keepdim=True) / (num_valid + 1e-8)
133 | del_105 = torch.sum(del_105.view(B, -1), -1, keepdim=True) / (num_valid + 1e-8)
134 | del_110 = torch.sum(del_110.view(B, -1), -1, keepdim=True) / (num_valid + 1e-8)
135 |
136 | result = [rmse, mae, irmse, imae, rel, del_1, del_2, del_3, del_102, del_105, del_110]
137 |
138 | return result
139 |
--------------------------------------------------------------------------------
/datas/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .nyu import *
2 | from .kitti import *
--------------------------------------------------------------------------------
/datasets/kitti.py:
--------------------------------------------------------------------------------
1 | import random
2 | import os
3 | import numpy as np
4 | import glob
5 | from PIL import Image
6 | import torch
7 | import augs
8 | from torchvision.utils import save_image
9 |
10 | __all__ = [
11 | 'KITTI',
12 | ]
13 |
14 | def read_calib_file(filepath):
15 | """Read in a calibration file and parse into a dictionary."""
16 | data = {}
17 |
18 | with open(filepath, 'r') as f:
19 | for line in f.readlines():
20 | key, value = line.split(':', 1)
21 | # The only non-float values in these files are dates, which
22 | # we don't care about anyway
23 | try:
24 | data[key] = np.array([float(x) for x in value.split()])
25 | except ValueError:
26 | pass
27 |
28 | return data
29 |
30 |
31 | class KITTI(torch.utils.data.Dataset):
32 | """
33 | kitti depth completion dataset: http://www.cvlibs.net/datasets/kitti/eval_depth.php?benchmark=depth_completion
34 | """
35 |
36 | def __init__(self, path='datas/kitti', mode='train', height=256, width=1216, mean=(90.9950, 96.2278, 94.3213),
37 | std=(79.2382, 80.5267, 82.1483), RandCrop=False, tp_min=50, *args, **kwargs):
38 | self.base_dir = path
39 | self.height = height
40 | self.width = width
41 | self.mode = mode
42 | if mode == 'train':
43 | self.transform = augs.Compose([
44 | augs.Jitter(),
45 | augs.Flip(),
46 | augs.Norm(mean=mean, std=std),
47 | ])
48 | else:
49 | self.transform = augs.Compose([
50 | augs.Norm(mean=mean, std=std),
51 | ])
52 | self.RandCrop = RandCrop and mode == 'train'
53 | self.tp_min = tp_min
54 | if mode in ['train', 'val']:
55 | self.depth_path = os.path.join(self.base_dir, 'data_depth_annotated', mode)
56 | self.lidar_path = os.path.join(self.base_dir, 'data_depth_velodyne', mode)
57 | self.depths = list(sorted(glob.iglob(self.depth_path + "/**/*.png", recursive=True)))
58 | self.lidars = list(sorted(glob.iglob(self.lidar_path + "/**/*.png", recursive=True)))
59 | elif mode == 'selval':
60 | self.depth_path = os.path.join(self.base_dir, 'val_selection_cropped', 'groundtruth_depth')
61 | self.lidar_path = os.path.join(self.base_dir, 'val_selection_cropped', 'velodyne_raw')
62 | self.image_path = os.path.join(self.base_dir, 'val_selection_cropped', 'image')
63 | self.depths = list(sorted(glob.iglob(self.depth_path + "/*.png", recursive=True)))
64 | self.lidars = list(sorted(glob.iglob(self.lidar_path + "/*.png", recursive=True)))
65 | self.images = list(sorted(glob.iglob(self.image_path + "/*.png", recursive=True)))
66 | elif mode == 'test':
67 | self.lidar_path = os.path.join(self.base_dir, 'test_depth_completion_anonymous', 'velodyne_raw')
68 | self.image_path = os.path.join(self.base_dir, 'test_depth_completion_anonymous', 'image')
69 | self.lidars = list(sorted(glob.iglob(self.lidar_path + "/*.png", recursive=True)))
70 | self.images = list(sorted(glob.iglob(self.image_path + "/*.png", recursive=True)))
71 | self.depths = self.lidars
72 | else:
73 | raise ValueError("Unknown mode: {}".format(mode))
74 | assert (len(self.depths) == len(self.lidars))
75 | self.names = [os.path.split(path)[-1] for path in self.depths]
76 | x = np.arange(width)
77 | y = np.arange(height)
78 | xx, yy = np.meshgrid(x, y)
79 | xy = np.stack((xx, yy), axis=-1)
80 | self.xy = xy
81 |
82 | def __len__(self):
83 | return len(self.depths)
84 |
85 | def __getitem__(self, index):
86 | return self.get_item(index)
87 |
88 | def get_item(self, index):
89 | depth = self.pull_DEPTH(self.depths[index])
90 | depth = np.expand_dims(depth, axis=2)
91 | lidar = self.pull_DEPTH(self.lidars[index])
92 | lidar = np.expand_dims(lidar, axis=2)
93 | K_cam = self.pull_K_cam(index).astype(np.float32)
94 |
95 | file_names = self.depths[index].split('/')
96 | if self.mode in ['train', 'val']:
97 | rgb_path = os.path.join(*file_names[:-7], 'raw', file_names[-5].split('_drive')[0], file_names[-5],
98 | file_names[-2], 'data', file_names[-1])
99 | disp_path = os.path.join(*file_names[:-7], 'raw', file_names[-5].split('_drive')[0], file_names[-5],
100 | file_names[-2], 'disp', file_names[-1])
101 | elif self.mode in ['selval', 'test']:
102 | rgb_path = self.images[index]
103 | disp_path = self.images[index]
104 | else:
105 | raise ValueError("Unknown mode: {}".format(self.mode))
106 | rgb = self.pull_RGB(rgb_path)
107 | disp = self.pull_DISP(disp_path)
108 | rgb = rgb.astype(np.float32)
109 | disp = disp.astype(np.float32)[:, :, None]
110 | lidar = lidar.astype(np.float32)
111 | depth = depth.astype(np.float32)
112 | if self.transform:
113 | rgb, disp, lidar, depth, K_cam = self.transform(rgb, disp, lidar, depth, K_cam)
114 | rgb = rgb.transpose(2, 0, 1).astype(np.float32)
115 | disp = disp.transpose(2, 0, 1).astype(np.float32)
116 | lidar = lidar.transpose(2, 0, 1).astype(np.float32)
117 | depth = depth.transpose(2, 0, 1).astype(np.float32)
118 | tp = rgb.shape[1] - self.height
119 | lp = (rgb.shape[2] - self.width) // 2
120 | if self.RandCrop:
121 | tp = random.randint(self.tp_min, tp)
122 | lp = random.randint(0, rgb.shape[2] - self.width)
123 | rgb = rgb[:, tp:tp + self.height, lp:lp + self.width]
124 | disp = disp[:, tp:tp + self.height, lp:lp + self.width]
125 | lidar = lidar[:, tp:tp + self.height, lp:lp + self.width]
126 | depth = depth[:, tp:tp + self.height, lp:lp + self.width]
127 | K_cam[0, 2] -= lp
128 | K_cam[1, 2] -= tp
129 |
130 | return rgb, disp, lidar, K_cam, depth
131 |
132 | def pull_DISP(self, path):
133 | disp = np.array(Image.open(path).convert('L'), dtype=np.uint8)
134 | return disp
135 |
136 | def pull_RGB(self, path):
137 | img = np.array(Image.open(path).convert('RGB'), dtype=np.uint8)
138 | return img
139 |
140 | def pull_DEPTH(self, path):
141 | depth_png = np.array(Image.open(path), dtype=int)
142 | assert (np.max(depth_png) > 255)
143 | depth_image = (depth_png / 256.).astype(np.float32)
144 | return depth_image
145 |
146 | def pull_K_cam(self, index):
147 | file_names = self.depths[index].split('/')
148 | if self.mode in ['train', 'val', 'trainval']:
149 | calib_path = os.path.join(*file_names[:-7], 'raw', file_names[-5].split('_drive')[0],
150 | 'calib_cam_to_cam.txt')
151 | filedata = read_calib_file(calib_path)
152 | P_rect_20 = np.reshape(filedata['P_rect_02'], (3, 4))
153 | P_rect_30 = np.reshape(filedata['P_rect_03'], (3, 4))
154 | if file_names[-2] == 'image_02':
155 | K_cam = P_rect_20[0:3, 0:3]
156 | elif file_names[-2] == 'image_03':
157 | K_cam = P_rect_30[0:3, 0:3]
158 | else:
159 | raise ValueError("Unknown mode: {}".format(file_names[-2]))
160 |
161 | elif self.mode in ['selval', 'test']:
162 | fns = self.images[index].split('/')
163 | calib_path = os.path.join(*fns[:-2], 'intrinsics', fns[-1][:-3] + 'txt')
164 | with open(calib_path, 'r') as f:
165 | K_cam = f.read().split()
166 | K_cam = np.array(K_cam, dtype=np.float32).reshape(3, 3)
167 | else:
168 | raise ValueError("Unknown mode: {}".format(self.mode))
169 | return K_cam
170 |
--------------------------------------------------------------------------------
/datasets/nyu.py:
--------------------------------------------------------------------------------
1 | """
2 | Borrowed from CompletionFormer https://github.com/youmi-zym/CompletionFormer
3 | ======================================================================
4 |
5 | NYU Depth V2 Dataset Helper
6 | """
7 |
8 | import os
9 | import warnings
10 |
11 | import numpy as np
12 | import json
13 | import h5py
14 |
15 | from PIL import Image
16 | import torch
17 | import torchvision.transforms as T
18 | import torchvision.transforms.functional as TF
19 | from torch.utils.data import Dataset
20 | import glob
21 |
22 | __all__ = [
23 | 'NYU',
24 | ]
25 |
26 |
27 | class BaseDataset(Dataset):
28 | def __init__(self, args, mode):
29 | self.args = args
30 | self.mode = mode
31 |
32 | def __len__(self):
33 | pass
34 |
35 | def __getitem__(self, idx):
36 | pass
37 |
38 | # A workaround for a pytorch bug
39 | # https://github.com/pytorch/vision/issues/2194
40 | class ToNumpy:
41 | def __call__(self, sample):
42 | return np.array(sample)
43 |
44 |
45 | class NYU(BaseDataset):
46 | def __init__(self, mode, path='../datas/nyudepthv2', num_sample=500, mul_factor=1., num_mask=8, scale_kcam=False,
47 | rand_scale=True, *args, **kwargs):
48 | super(NYU, self).__init__(None, mode)
49 |
50 | self.mode = mode
51 | self.num_sample = num_sample
52 | self.num_mask = num_mask
53 | self.mul_factor = mul_factor
54 | self.scale_kcam = scale_kcam
55 | self.rand_scale = rand_scale
56 |
57 | if mode != 'train' and mode != 'val':
58 | raise NotImplementedError
59 |
60 | height, width = (240, 320)
61 | crop_size = (228, 304)
62 |
63 | self.height = height
64 | self.width = width
65 | self.crop_size = crop_size
66 |
67 | self.Kcam = torch.from_numpy(np.array(
68 | [
69 | [5.1885790117450188e+02, 0, 3.2558244941119034e+02],
70 | [0, 5.1946961112127485e+02, 2.5373616633400465e+02],
71 | [0, 0, 1.],
72 | ], dtype=np.float32
73 | )
74 | )
75 |
76 | base_dir = path
77 |
78 | self.sample_list = list(sorted(glob.glob(os.path.join(base_dir, mode, "**/**.h5"))))
79 |
80 | def __len__(self):
81 | if self.mode == 'train':
82 | return len(self.sample_list)
83 | elif self.mode == 'val':
84 | return self.num_mask * len(self.sample_list)
85 | else:
86 | raise NotImplementedError
87 |
88 | def __getitem__(self, idx):
89 | if self.mode == 'val':
90 | seed = idx % self.num_mask
91 | idx = idx // self.num_mask
92 |
93 | path_file = self.sample_list[idx]
94 |
95 | f = h5py.File(path_file, 'r')
96 | rgb_h5 = f['rgb'][:].transpose(1, 2, 0)
97 | dep_h5 = f['depth'][:]
98 |
99 | rgb = Image.fromarray(rgb_h5, mode='RGB')
100 | dep = Image.fromarray(dep_h5.astype('float32'), mode='F')
101 |
102 | Kcam = self.Kcam.clone()
103 |
104 | if self.mode == 'train':
105 | if self.rand_scale:
106 | _scale = np.random.uniform(1.0, 1.5)
107 | else:
108 | _scale = 1.0
109 | scale = int(self.height * _scale)
110 | degree = np.random.uniform(-5.0, 5.0)
111 | flip = np.random.uniform(0.0, 1.0)
112 |
113 | if flip > 0.5:
114 | rgb = TF.hflip(rgb)
115 | dep = TF.hflip(dep)
116 | Kcam[0, 2] = rgb.width - 1 - Kcam[0, 2]
117 |
118 | rgb = TF.rotate(rgb, angle=degree)
119 | dep = TF.rotate(dep, angle=degree)
120 |
121 | t_rgb = T.Compose([
122 | T.Resize(scale),
123 | T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
124 | T.CenterCrop(self.crop_size),
125 | T.ToTensor(),
126 | T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
127 | ])
128 |
129 | t_dep = T.Compose([
130 | T.Resize(scale),
131 | T.CenterCrop(self.crop_size),
132 | self.ToNumpy(),
133 | T.ToTensor()
134 | ])
135 |
136 | rgb = t_rgb(rgb)
137 | dep = t_dep(dep)
138 |
139 | if self.scale_kcam:
140 | Kcam[:2] = Kcam[:2] * _scale
141 |
142 | else:
143 | t_rgb = T.Compose([
144 | T.Resize(self.height),
145 | T.CenterCrop(self.crop_size),
146 | T.ToTensor(),
147 | T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
148 | ])
149 |
150 | t_dep = T.Compose([
151 | T.Resize(self.height),
152 | T.CenterCrop(self.crop_size),
153 | self.ToNumpy(),
154 | T.ToTensor()
155 | ])
156 |
157 | rgb = t_rgb(rgb)
158 | dep = t_dep(dep)
159 |
160 | if self.mode == 'train':
161 | dep_sp = self.get_sparse_depth(dep, self.num_sample)
162 | elif self.mode == 'val':
163 | dep_sp = self.mask_sparse_depth(dep, self.num_sample, seed)
164 | else:
165 | raise NotImplementedError
166 |
167 | rgb = TF.pad(rgb, padding=[8, 14], padding_mode='edge')
168 | # rgb = TF.pad(rgb, padding=[8, 14], padding_mode='constant')
169 | dep_sp = TF.pad(dep_sp, padding=[8, 14], padding_mode='constant')
170 | dep = TF.pad(dep, padding=[8, 14], padding_mode='constant')
171 |
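# adjust intrinsics for the half-resolution input, the 228x304 center crop, and the (8, 14) padding above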
172 | Kcam[:2] = Kcam[:2] / 2.
173 | Kcam[0, 2] += 8 - 8
174 | Kcam[1, 2] += -6 + 14
175 |
176 | dep_sp *= self.mul_factor
177 | dep *= self.mul_factor
178 | return rgb, dep_sp, Kcam, dep
179 |
180 | def get_sparse_depth(self, dep, num_sample):
181 | channel, height, width = dep.shape
182 |
183 | assert channel == 1
184 |
185 | idx_nnz = torch.nonzero(dep.view(-1) > 0.0001, as_tuple=False)
186 |
187 | num_idx = len(idx_nnz)
188 | idx_sample = torch.randperm(num_idx)[:num_sample]
189 |
190 | idx_nnz = idx_nnz[idx_sample[:]]
191 |
192 | mask = torch.zeros((channel * height * width))
193 | mask[idx_nnz] = 1.0
194 | mask = mask.view((channel, height, width))
195 |
196 | dep_sp = dep * mask.type_as(dep)
197 |
198 | if num_idx == 0:
199 | dep_sp[:, 20:-20:10, 20:-20:10] = 3.
200 |
201 | return dep_sp
202 |
203 | def mask_sparse_depth(self, dep, num_sample, seed):
204 | channel, height, width = dep.shape
205 | dep = dep.numpy().reshape(-1)
206 | np.random.seed(seed)
207 | index = np.random.choice(height * width, num_sample, replace=False)
208 | dep_sp = np.zeros_like(dep)
209 | dep_sp[index] = dep[index]
210 | dep_sp = dep_sp.reshape(channel, height, width)
211 | dep_sp = torch.from_numpy(dep_sp)
212 | return dep_sp
213 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import open3d as o3d
2 | import numpy as np
3 | import cv2
4 | import os
5 | import matplotlib.pyplot as plt
6 | import torch
7 | import hydra
8 | from utils_infer import Trainer
9 | from mpl_toolkits.mplot3d import Axes3D
10 | from tqdm import tqdm
11 | from PIL import Image
12 | import imutils
13 |
14 | def save_point_cloud_to_image(pcd, image_size=(1600, 1200)):
15 | vis = o3d.visualization.Visualizer()
16 | vis.create_window(visible=False, width=image_size[0], height=image_size[1])  # set the window size without displaying it
17 |
18 | # add the point cloud to the visualizer
19 | vis.add_geometry(pcd)
20 |
21 | # get the ViewControl object and set a custom viewpoint
22 | view_control = vis.get_view_control()
23 |
24 | # viewpoint parameters
25 | parameters = {
26 | "boundingbox_max" : [ -1.5977719363706235, 11.519330868353832, 84.127326965332031 ],
27 | "boundingbox_min" : [ -56.623178268627761, -50.724836932361825, 4.6948032379150391 ],
28 | "field_of_view" : 60.0,
29 | "front" : [ 0.36788389015943362, 0.28372788418091033, -0.8855280521244856 ],
30 | "lookat" : [ -29.110475102499194, -19.602753032003996, 44.411065101623535 ],
31 | "up" : [ -0.88654778269461509, -0.18027422226025128, -0.42606834403382188 ],
32 | "zoom" : 0.5199999999999998
33 | }
34 |
35 | # apply the viewpoint settings
36 | ctr = vis.get_view_control()
37 | ctr.set_lookat(parameters["lookat"])
38 | ctr.set_front(parameters["front"])
39 | ctr.set_up(parameters["up"])
40 | ctr.set_zoom(parameters["zoom"])
41 |
42 | # render the point cloud
43 | vis.poll_events()
44 | vis.update_renderer()
45 |
46 | # capture the rendered 2D image of the point cloud
47 | depth_image = np.asarray(vis.capture_screen_float_buffer(do_render=True))
48 | vis.destroy_window()
49 |
50 | # format the captured image
51 | depth_image = (depth_image * 255).astype(np.uint8)  # convert to an 8-bit image
52 | depth_image = cv2.cvtColor(depth_image, cv2.COLOR_RGB2BGR)  # convert to BGR for OpenCV compatibility
53 |
54 | return depth_image
55 |
56 | def depth_to_point_cloud(depth_map, K_cam):
57 | h, w = depth_map.shape
58 | i, j = np.meshgrid(np.arange(w), np.arange(h))
59 |
60 | # convert pixel coordinates to camera coordinates
61 | z = depth_map
62 | x = (j - K_cam[0, 2]) * z / K_cam[0, 0]
63 | y = (i - K_cam[1, 2]) * z / K_cam[1, 1]
64 | y = -y
65 |
66 | points_3d = np.stack((x, y, z), axis=-1).reshape(-1, 3)
67 |
68 | return points_3d
69 |
70 |
71 | def read_calib_file(filepath):
72 | """Read in a calibration file and parse into a dictionary."""
73 | data = {}
74 |
75 | with open(filepath, 'r') as f:
76 | for line in f.readlines():
77 | key, value = line.split(':', 1)
78 | # The only non-float values in these files are dates, which
79 | # we don't care about anyway
80 | try:
81 | data[key] = np.array([float(x) for x in value.split()])
82 | except ValueError:
83 | pass
84 |
85 | return data
86 |
87 |
88 | def pull_K_cam(calib_path):
89 | filedata = read_calib_file(calib_path)
90 | P_rect_20 = np.reshape(filedata['P_rect_02'], (3, 4))
91 | K_cam = P_rect_20[0:3, 0:3]
92 | return K_cam
93 |
94 |
95 | def depth_color(val, min_d=0, max_d=120):
96 | """
97 | Return the color (HSV H value) corresponding to distance (m):
98 | close distance = red, far distance = blue
99 | """
100 | np.clip(val, 0, max_d, out=val)  # clip distances; values beyond max_d (120 m) are rare
101 | return (((val - min_d) / (max_d - min_d)) * 120).astype(np.uint8)
102 |
103 |
104 | def in_h_range_points(points, m, n, fov):
105 | """ extract horizontal in-range points """
106 | return np.logical_and(np.arctan2(n, m) > (-fov[1] * np.pi / 180),
107 | np.arctan2(n, m) < (-fov[0] * np.pi / 180))
108 |
109 |
110 | def in_v_range_points(points, m, n, fov):
111 | """ extract vertical in-range points """
112 | return np.logical_and(np.arctan2(n, m) < (fov[1] * np.pi / 180),
113 | np.arctan2(n, m) > (fov[0] * np.pi / 180))
114 |
115 |
116 | def fov_setting(points, x, y, z, dist, h_fov, v_fov):
117 | """ filter points based on h,v FOV """
118 |
119 | if h_fov[1] == 180 and h_fov[0] == -180 and v_fov[1] == 2.0 and v_fov[0] == -24.9:
120 | return points
121 |
122 | if h_fov[1] == 180 and h_fov[0] == -180:
123 | return points[in_v_range_points(points, dist, z, v_fov)]
124 | elif v_fov[1] == 2.0 and v_fov[0] == -24.9:
125 | return points[in_h_range_points(points, x, y, h_fov)]
126 | else:
127 | h_points = in_h_range_points(points, x, y, h_fov)
128 | v_points = in_v_range_points(points, dist, z, v_fov)
129 | return points[np.logical_and(h_points, v_points)]
130 |
131 |
132 | def in_range_points(points, size):
133 | """ extract in-range points """
134 | return np.logical_and(points > 0, points < size)
135 |
136 |
137 | def velo_points_filter(points, v_fov, h_fov):
138 | """ extract points corresponding to FOV setting """
139 |
140 | # Projecting to 2D
141 | x = points[:, 0]
142 | y = points[:, 1]
143 | z = points[:, 2]
144 | dist = np.sqrt(x ** 2 + y ** 2 + z ** 2)
145 |
146 | if h_fov[0] < -90:
147 | h_fov = (-90,) + h_fov[1:]
148 | if h_fov[1] > 90:
149 | h_fov = h_fov[:1] + (90,)
150 |
151 | x_lim = fov_setting(x, x, y, z, dist, h_fov, v_fov)[:, None]
152 | y_lim = fov_setting(y, x, y, z, dist, h_fov, v_fov)[:, None]
153 | z_lim = fov_setting(z, x, y, z, dist, h_fov, v_fov)[:, None]
154 |
155 | # Stack arrays in sequence horizontally
156 | xyz_ = np.hstack((x_lim, y_lim, z_lim))
157 | xyz_ = xyz_.T
158 |
159 | # stack (1,n) arrays filled with the number 1
160 | one_mat = np.full((1, xyz_.shape[1]), 1)
161 | xyz_ = np.concatenate((xyz_, one_mat), axis=0)
162 |
163 | # need dist info for points color
164 | dist_lim = fov_setting(dist, x, y, z, dist, h_fov, v_fov)
165 | color = depth_color(dist_lim, 0, 70)
166 |
167 | return xyz_, color
168 |
169 |
170 | def calib_velo2cam(filepath):
171 | """
172 | get Rotation(R : 3x3), Translation(T : 3x1) matrix info
173 | using R,T matrix, we can convert velodyne coordinates to camera coordinates
174 | """
175 | with open(filepath, "r") as f:
176 | file = f.readlines()
177 |
178 | for line in file:
179 | (key, val) = line.split(':', 1)
180 | if key == 'R':
181 | R = np.fromstring(val, sep=' ')
182 | R = R.reshape(3, 3)
183 | if key == 'T':
184 | T = np.fromstring(val, sep=' ')
185 | T = T.reshape(3, 1)
186 | return R, T
187 |
188 |
189 | def calib_cam2cam(filepath, mode):
190 | with open(filepath, "r") as f:
191 | file = f.readlines()
192 |
193 | for line in file:
194 | (key, val) = line.split(':', 1)
195 | if key == ('P_rect_' + mode):
196 | P_ = np.fromstring(val, sep=' ')
197 | P_ = P_.reshape(3, 4)
198 | # erase 4th column ([0,0,0])
199 | P_ = P_[:3, :3]
200 | return P_
201 |
202 |
203 | def velo3d_2_camera2d_points(points, v_fov, h_fov, vc_path, cc_path, mode='02', image_shape=None):
204 | """
205 | Return velodyne 3D points corresponding to camera 2D image and sparse depth map
206 | """
207 |
208 | # R_vc = Rotation matrix ( velodyne -> camera )
209 | # T_vc = Translation matrix ( velodyne -> camera )
210 | R_vc, T_vc = calib_velo2cam(vc_path)
211 |
212 | # P_ = Projection matrix ( camera coordinates 3d points -> image plane 2d points )
213 | P_ = calib_cam2cam(cc_path, mode)
214 | xyz_v, c_ = velo_points_filter(points, v_fov, h_fov)
215 |
216 | RT_ = np.concatenate((R_vc, T_vc), axis=1)
217 |
218 | # Initialize sparse depth map
219 | if image_shape is None:
220 | raise ValueError(
221 | "Image shape must be provided to generate sparse depth map")
222 |
223 | # Create an empty (all-zero) depth map
224 | depth_map = np.full(image_shape[:2], 0)
225 |
226 | # Convert velodyne coordinates(X_v, Y_v, Z_v) to camera coordinates(X_c, Y_c, Z_c)
227 | for i in range(xyz_v.shape[1]):
228 | xyz_v[:3, i] = np.matmul(RT_, xyz_v[:, i])
229 |
230 | xyz_c = np.delete(xyz_v, 3, axis=0)
231 |
232 | # Convert camera coordinates(X_c, Y_c, Z_c) to image(pixel) coordinates(x,y)
233 | for i in range(xyz_c.shape[1]):
234 | xyz_c[:, i] = np.matmul(P_, xyz_c[:, i])
235 |
236 | # Normalize by the third coordinate to get 2D pixel coordinates
237 | xy_i = xyz_c[:2, :] / xyz_c[2, :]
238 | depth_values = xyz_c[2, :] # Z-coordinate (depth) in camera space
239 |
240 | # Filter out points that are out of image bounds
241 | valid_mask = np.logical_and.reduce((
242 | xy_i[0, :] >= 0,
243 | xy_i[0, :] < image_shape[1], # x coordinate within image width
244 | xy_i[1, :] >= 0,
245 | xy_i[1, :] < image_shape[0], # y coordinate within image height
246 | ))
247 |
248 | valid_points = xy_i[:, valid_mask]
249 | valid_depths = depth_values[valid_mask]
250 |
251 | # Fill the depth map with depth values
252 | for i in range(valid_points.shape[1]):
253 | x = int(valid_points[0, i])
254 | y = int(valid_points[1, i])
255 | depth_map[y, x] = valid_depths[i]
256 |
257 | # Return both the projected points and the sparse depth map
258 | return valid_points, c_, depth_map
259 |
260 |
261 | def print_projection_cv2(points, color, image):
262 | """ project converted velodyne points into camera image """
263 |
264 | hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
265 |
266 | for i in range(points.shape[1]):
267 | cv2.circle(hsv_image, (np.int32(points[0][i]), np.int32(
268 | points[1][i])), 2, (int(color[i]), 255, 255), -1)
269 |
270 | return cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)
271 |
272 |
273 | def print_projection_plt(points, color, image):
274 | """ project converted velodyne points into camera image """
275 |
276 | hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
277 |
278 | for i in range(points.shape[1]):
279 | cv2.circle(hsv_image, (np.int32(points[0][i]), np.int32(
280 | points[1][i])), 2, (int(color[i]), 255, 255), -1)
281 |
282 | return cv2.cvtColor(hsv_image, cv2.COLOR_HSV2RGB)
283 |
284 |
285 | def load_from_bin(bin_path):
286 | obj = np.fromfile(bin_path, dtype=np.float32).reshape(-1, 4)
287 | # ignore reflectivity info
288 | return obj[:, :3]
289 |
290 |
291 | def set_axes_equal(ax):
292 | """使 3D 图的刻度长短一致"""
293 | x_limits = ax.get_xlim3d()
294 | y_limits = ax.get_ylim3d()
295 | z_limits = ax.get_zlim3d()
296 |
297 | # find the center and range of each axis
298 | x_range = abs(x_limits[1] - x_limits[0])
299 | x_middle = np.mean(x_limits)
300 | y_range = abs(y_limits[1] - y_limits[0])
301 | y_middle = np.mean(y_limits)
302 | z_range = abs(z_limits[1] - z_limits[0])
303 | z_middle = np.mean(z_limits)
304 |
305 | # compute the largest range
306 | plot_radius = 0.5 * max([x_range, y_range, z_range])
307 |
308 | # set each axis range so that they are equal
309 | ax.set_xlim3d([x_middle - plot_radius, x_middle + plot_radius])
310 | ax.set_ylim3d([y_middle - plot_radius, y_middle + plot_radius])
311 | ax.set_zlim3d([z_middle - plot_radius, z_middle + plot_radius])
312 |
313 |
314 | def save_point_cloud_to_ply_open3d(point_cloud, image, filename):
315 | # create an open3d point cloud object
316 | pcd = o3d.geometry.PointCloud()
317 |
318 | # set the point coordinates
319 | pcd.points = o3d.utility.Vector3dVector(point_cloud)
320 |
321 | # flatten the image into (N, 3) color data
322 | colors = image.reshape(-1, 3) / 255.0  # normalize colors to [0, 1]
323 |
324 | # set the point colors
325 | pcd.colors = o3d.utility.Vector3dVector(colors)
326 |
327 | # save the point cloud as a .ply file
328 | o3d.io.write_point_cloud(filename, pcd)
329 | return pcd
330 |
331 |
332 | @hydra.main(config_path='configs', config_name='config', version_base='1.2')
333 | def main(cfg):
334 | with Trainer(cfg) as run:
335 | net = run.net_ema.module.cuda()
336 | net.eval()
337 | # read the left (color) camera images
338 | # base = "datas/kitti/raw/2011_09_26/2011_09_26_drive_0002_sync"
339 | base = "datas/kitti/raw/2011_09_26/2011_09_26_drive_0048_sync"
340 |
341 | image_type = 'color' # 'grayscale' or 'color' image
342 |
343 | mode = '00' if image_type == 'grayscale' else '02'
344 |
345 | v2c_filepath = 'datas/kitti/raw/2011_09_26/calib_velo_to_cam.txt'
346 | c2c_filepath = 'datas/kitti/raw/2011_09_26/calib_cam_to_cam.txt'
347 | image_mean = np.array([90.9950, 96.2278, 94.3213])
348 | image_std = np.array([79.2382, 80.5267, 82.1483])
349 | image_height = 352
350 | image_width = 1216
351 |
352 | for i in tqdm(range(1000)):
353 |
354 | image_path = os.path.join(base, 'image_' + mode + f'/data/0000000{i:03d}.png')
355 | if not os.path.exists(image_path):
356 | # stop once the sequence runs out of frames (Image.open raises instead of returning None)
357 | break
358 | image = np.array(Image.open(image_path).convert('RGB'), dtype=np.uint8)
359 | width, height = image.shape[1], image.shape[0]
360 |
361 | # bin file -> numpy array
362 | velo_points = load_from_bin(os.path.join(
363 | base, f'velodyne_points/data/0000000{i:03d}.bin'))
364 |
365 | image_type = 'color' # 'grayscale' or 'color' image
366 | # image_00 = 'grayscale image' , image_02 = 'color image'
367 | mode = '00' if image_type == 'grayscale' else '02'
368 |
369 | ans, c_, lidar = velo3d_2_camera2d_points(velo_points, v_fov=(-24.9, 2.0), h_fov=(-45, 45),
370 | vc_path=v2c_filepath, cc_path=c2c_filepath, mode=mode,
371 | image_shape=image.shape)
372 |
373 | image_vis = print_projection_plt(points=ans, color=c_, image=image.copy())
374 |
375 | # depth completion
376 | K_cam = torch.from_numpy(pull_K_cam(
377 | c2c_filepath).astype(np.float32)).cuda()
378 |
379 | tp = image.shape[0] - image_height
380 | lp = (image.shape[1] - image_width) // 2
381 | image = image[tp:tp + image_height, lp:lp + image_width]
382 | lidar = lidar[tp:tp + image_height, lp:lp + image_width, None]
383 | image_vis = image_vis[tp:tp + image_height, lp:lp + image_width]
384 | K_cam[0, 2] -= lp
385 | K_cam[1, 2] -= tp
386 |
387 | image = (image - image_mean) / image_std
388 |
389 | image_tensor = image.transpose(2, 0, 1).astype(np.float32)[None]
390 | lidar_tensor = lidar.transpose(2, 0, 1).astype(np.float32)[None]
391 |
392 | image_tensor = torch.from_numpy(image_tensor)
393 | lidar_tensor = torch.from_numpy(lidar_tensor)
394 |
395 | output = net(image_tensor.cuda(), None,
396 | lidar_tensor.cuda(), K_cam[None].cuda())
397 | if isinstance(output, (list, tuple)):
398 | output = output[-1]
399 |
400 | output = output.squeeze().detach().cpu().numpy()
401 | image = image * image_std + image_mean
402 |
403 | output_max, output_min = output.max(), output.min()
404 | output_norm = (output - output_min) / (output_max - output_min) * 255
405 | output_norm = output_norm.astype('uint8')
406 | output_color = cv2.applyColorMap(output_norm, cv2.COLORMAP_JET)
407 | cv2.imwrite(f'outputs/0000000{i:03d}_depth.png', output_color)
408 |
409 | cv2.imwrite(f'outputs/0000000{i:03d}_image.png', image.astype(np.uint8)[:, :, ::-1])
410 | cv2.imwrite(f'outputs/0000000{i:03d}_lidar.png', lidar.astype(np.uint8) * 3)
411 | cv2.imwrite(f'outputs/0000000{i:03d}_image_vis.png', image_vis)
412 |
413 |
414 | if __name__ == '__main__':
415 | main()
416 |
--------------------------------------------------------------------------------
/demo.sh:
--------------------------------------------------------------------------------
1 | python demo.py gpus=[2] name=BP_KITTI ++chpt=checkpoints/dmd3c_kitti.pth \
2 | net=PMP num_workers=1 \
3 | data=KITTI data.testset.mode=test data.testset.height=352 \
4 | test_batch_size=1 metric=RMSE ++save=true
5 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: bp
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - _libgcc_mutex=0.1=conda_forge
9 | - _openmp_mutex=4.5=2_gnu
10 | - abseil-cpp=20220623.0=h8cdb687_6
11 | - absl-py=2.1.0=pyhd8ed1ab_0
12 | - alsa-lib=1.2.8=h166bdaf_0
13 | - antlr-python-runtime=4.9.3=pyhd8ed1ab_1
14 | - aom=3.5.0=h27087fc_0
15 | - attr=2.5.1=h166bdaf_1
16 | - binutils=2.40=hdd6e379_0
17 | - binutils_impl_linux-64=2.40=hf600244_0
18 | - binutils_linux-64=2.40=hdade7a5_3
19 | - blas=1.0=mkl
20 | - brotli-python=1.0.9=py39h6a678d5_7
21 | - bzip2=1.0.8=hd590300_5
22 | - c-ares=1.28.1=hd590300_0
23 | - c-compiler=1.6.0=hd590300_0
24 | - ca-certificates=2024.3.11=h06a4308_0
25 | - cached-property=1.5.2=hd8ed1ab_1
26 | - cached_property=1.5.2=pyha770c72_1
27 | - cairo=1.16.0=ha61ee94_1014
28 | - certifi=2024.2.2=pyhd8ed1ab_0
29 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
30 | - colorama=0.4.6=pyhd8ed1ab_0
31 | - compilers=1.6.0=ha770c72_0
32 | - cuda-cudart=12.1.105=0
33 | - cuda-cupti=12.1.105=0
34 | - cuda-libraries=12.1.0=0
35 | - cuda-nvrtc=12.1.105=0
36 | - cuda-nvtx=12.1.105=0
37 | - cuda-opencl=12.4.127=0
38 | - cuda-runtime=12.1.0=0
39 | - cxx-compiler=1.6.0=h00ab1b0_0
40 | - dbus=1.13.6=h5008d03_3
41 | - einops=0.7.0=pyhd8ed1ab_1
42 | - expat=2.6.2=h59595ed_0
43 | - ffmpeg=5.1.2=gpl_h8dda1f0_106
44 | - fftw=3.3.10=nompi_hc118613_108
45 | - filelock=3.13.1=py39h06a4308_0
46 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0
47 | - font-ttf-inconsolata=3.000=h77eed37_0
48 | - font-ttf-source-code-pro=2.038=h77eed37_0
49 | - font-ttf-ubuntu=0.83=h77eed37_1
50 | - fontconfig=2.14.2=h14ed4e7_0
51 | - fonts-conda-ecosystem=1=0
52 | - fonts-conda-forge=1=0
53 | - fortran-compiler=1.6.0=heb67821_0
54 | - freeglut=3.2.2=h9c3ff4c_1
55 | - freetype=2.12.1=h4a9f257_0
56 | - fsspec=2024.3.1=pyhca7485f_0
57 | - gcc=12.3.0=h95e488c_3
58 | - gcc_impl_linux-64=12.3.0=he2b93b0_5
59 | - gcc_linux-64=12.3.0=h6477408_3
60 | - gettext=0.21.1=h27087fc_0
61 | - gfortran=12.3.0=h7389182_3
62 | - gfortran_impl_linux-64=12.3.0=hfcedea8_5
63 | - gfortran_linux-64=12.3.0=h617cb40_3
64 | - glib=2.80.0=hf2295e7_1
65 | - glib-tools=2.80.0=hde27a5a_1
66 | - gmp=6.2.1=h295c915_3
67 | - gmpy2=2.1.2=py39heeb90bb_0
68 | - gnutls=3.7.9=hb077bed_0
69 | - graphite2=1.3.13=h59595ed_1003
70 | - grpc-cpp=1.51.1=h27aab58_1
71 | - grpcio=1.51.1=py39he859823_1
72 | - gst-plugins-base=1.22.0=h4243ec0_2
73 | - gstreamer=1.22.0=h25f0c4b_2
74 | - gstreamer-orc=0.4.38=hd590300_0
75 | - gtest=1.14.0=hdb19cb5_0
76 | - gxx=12.3.0=h95e488c_3
77 | - gxx_impl_linux-64=12.3.0=he2b93b0_5
78 | - gxx_linux-64=12.3.0=h4a1b8e8_3
79 | - h5py=3.9.0=nompi_py39h4dfffb9_100
80 | - harfbuzz=6.0.0=h8e241bc_0
81 | - hdf5=1.14.0=nompi_hb72d44e_103
82 | - huggingface_hub=0.22.2=pyhd8ed1ab_0
83 | - hydra-core=1.3.2=pyhd8ed1ab_0
84 | - icu=70.1=h27087fc_0
85 | - idna=3.4=py39h06a4308_0
86 | - importlib-metadata=7.1.0=pyha770c72_0
87 | - importlib_resources=6.4.0=pyhd8ed1ab_0
88 | - intel-openmp=2023.1.0=hdb19cb5_46306
89 | - jack=1.9.22=h11f4161_0
90 | - jasper=2.0.33=h0ff4b12_1
91 | - jinja2=3.1.3=py39h06a4308_0
92 | - jpeg=9e=h5eee18b_1
93 | - kernel-headers_linux-64=2.6.32=he073ed8_17
94 | - keyutils=1.6.1=h166bdaf_0
95 | - krb5=1.20.1=h81ceb04_0
96 | - lame=3.100=h7b6447c_0
97 | - lcms2=2.12=h3be6417_0
98 | - ld_impl_linux-64=2.40=h41732ed_0
99 | - lerc=3.0=h295c915_0
100 | - libabseil=20220623.0=cxx17_h05df665_6
101 | - libaec=1.1.3=h59595ed_0
102 | - libblas=3.9.0=1_h86c2bf4_netlib
103 | - libcap=2.67=he9d0100_0
104 | - libcblas=3.9.0=5_h92ddd45_netlib
105 | - libclang=15.0.7=default_h127d8a8_5
106 | - libclang13=15.0.7=default_h5d6823c_5
107 | - libcublas=12.1.0.26=0
108 | - libcufft=11.0.2.4=0
109 | - libcufile=1.9.1.3=0
110 | - libcups=2.3.3=h36d4200_3
111 | - libcurand=10.3.5.147=0
112 | - libcurl=8.1.2=h409715c_0
113 | - libcusolver=11.4.4.55=0
114 | - libcusparse=12.0.2.55=0
115 | - libdb=6.2.32=h9c3ff4c_0
116 | - libdeflate=1.8=h7f8727e_5
117 | - libdrm=2.4.120=hd590300_0
118 | - libedit=3.1.20191231=he28a2e2_2
119 | - libev=4.33=hd590300_2
120 | - libevent=2.1.10=h28343ad_4
121 | - libexpat=2.6.2=h59595ed_0
122 | - libffi=3.4.2=h7f98852_5
123 | - libflac=1.4.3=h59595ed_0
124 | - libgcc-devel_linux-64=12.3.0=h8bca6fd_105
125 | - libgcc-ng=13.2.0=h807b86a_5
126 | - libgcrypt=1.10.3=hd590300_0
127 | - libgfortran-ng=13.2.0=h69a702a_5
128 | - libgfortran5=13.2.0=ha4646dd_5
129 | - libglib=2.80.0=hf2295e7_1
130 | - libglu=9.0.0=he1b5a44_1001
131 | - libgomp=13.2.0=h807b86a_5
132 | - libgpg-error=1.48=h71f35ed_0
133 | - libgrpc=1.51.1=h4fad500_1
134 | - libiconv=1.17=hd590300_2
135 | - libidn2=2.3.4=h5eee18b_0
136 | - libjpeg-turbo=2.0.0=h9bf148f_0
137 | - liblapack=3.9.0=5_h92ddd45_netlib
138 | - liblapacke=3.9.0=5_h92ddd45_netlib
139 | - libllvm15=15.0.7=hadd5161_1
140 | - libnghttp2=1.58.0=h47da74e_0
141 | - libnpp=12.0.2.50=0
142 | - libnsl=2.0.1=hd590300_0
143 | - libnvjitlink=12.1.105=0
144 | - libnvjpeg=12.1.1.14=0
145 | - libogg=1.3.4=h7f98852_1
146 | - libopencv=4.7.0=py39h2ca4621_1
147 | - libopus=1.3.1=h7f98852_1
148 | - libpciaccess=0.18=hd590300_0
149 | - libpng=1.6.39=h5eee18b_0
150 | - libpq=15.3=hbcd7760_1
151 | - libprotobuf=3.21.12=hfc55251_2
152 | - libsanitizer=12.3.0=h0f45ef3_5
153 | - libsndfile=1.2.2=hc60ed4a_1
154 | - libsqlite=3.45.2=h2797004_0
155 | - libssh2=1.11.0=h0841786_0
156 | - libstdcxx-devel_linux-64=12.3.0=h8bca6fd_105
157 | - libstdcxx-ng=13.2.0=h7e041cc_5
158 | - libsystemd0=253=h8c4010b_1
159 | - libtasn1=4.19.0=h5eee18b_0
160 | - libtiff=4.5.0=h6a678d5_1
161 | - libtool=2.4.7=h27087fc_0
162 | - libudev1=253=h0b41bf4_1
163 | - libunistring=0.9.10=h27cfd23_0
164 | - libuuid=2.38.1=h0b41bf4_0
165 | - libva=2.18.0=h0b41bf4_0
166 | - libvorbis=1.3.7=h9c3ff4c_0
167 | - libvpx=1.11.0=h9c3ff4c_3
168 | - libwebp-base=1.3.2=h5eee18b_0
169 | - libxcb=1.13=h7f98852_1004
170 | - libxcrypt=4.4.36=hd590300_1
171 | - libxkbcommon=1.5.0=h79f4944_1
172 | - libxml2=2.10.3=hca2bb57_4
173 | - libzlib=1.2.13=hd590300_5
174 | - llvm-openmp=14.0.6=h9e868ea_0
175 | - lz4-c=1.9.4=h6a678d5_0
176 | - markdown=3.6=pyhd8ed1ab_0
177 | - markupsafe=2.1.3=py39h5eee18b_0
178 | - mkl=2023.1.0=h213fc3f_46344
179 | - mkl-service=2.4.0=py39h5eee18b_1
180 | - mkl_fft=1.3.8=py39h5eee18b_0
181 | - mkl_random=1.2.4=py39hdb19cb5_0
182 | - mpc=1.1.0=h10f8cd9_1
183 | - mpfr=4.0.2=hb69a4c5_1
184 | - mpg123=1.32.6=h59595ed_0
185 | - mpmath=1.3.0=py39h06a4308_0
186 | - mysql-common=8.0.33=hf1915f5_2
187 | - mysql-libs=8.0.33=hca2cd23_2
188 | - ncurses=6.4.20240210=h59595ed_0
189 | - nettle=3.9.1=h7ab15ed_0
190 | - networkx=3.1=py39h06a4308_0
191 | - nspr=4.35=h27087fc_0
192 | - nss=3.98=h1d7d5a4_0
193 | - numpy=1.26.4=py39h5f9d8c6_0
194 | - numpy-base=1.26.4=py39hb5e798b_0
195 | - omegaconf=2.3.0=pyhd8ed1ab_0
196 | - opencv=4.7.0=py39hf3d152e_1
197 | - openh264=2.3.1=hcb278e6_2
198 | - openjpeg=2.4.0=h3ad879b_0
199 | - openssl=3.1.5=hd590300_0
200 | - p11-kit=0.24.1=hc5aa10d_0
201 | - packaging=24.0=pyhd8ed1ab_0
202 | - pcre2=10.43=hcad00b1_0
203 | - pillow=10.2.0=py39h5eee18b_0
204 | - pip=24.0=pyhd8ed1ab_0
205 | - pixman=0.43.2=h59595ed_0
206 | - protobuf=4.21.12=py39h227be39_0
207 | - pthread-stubs=0.4=h36c2ea0_1001
208 | - pulseaudio=16.1=hcb278e6_3
209 | - pulseaudio-client=16.1=h5195f5e_3
210 | - pulseaudio-daemon=16.1=ha8d29e2_3
211 | - py-opencv=4.7.0=py39hcca971b_1
212 | - pysocks=1.7.1=py39h06a4308_0
213 | - python=3.9.18=h0755675_0_cpython
214 | - python_abi=3.9=4_cp39
215 | - pytorch=2.2.2=py3.9_cuda12.1_cudnn8.9.2_0
216 | - pytorch-cuda=12.1=ha16c6d3_5
217 | - pytorch-mutex=1.0=cuda
218 | - pyyaml=6.0.1=py39h5eee18b_0
219 | - qt-main=5.15.8=h5d23da1_6
220 | - re2=2023.02.01=hcb278e6_0
221 | - readline=8.2=h8228510_1
222 | - requests=2.31.0=py39h06a4308_1
223 | - safetensors=0.4.2=py39h9fdd4d6_0
224 | - setuptools=69.2.0=pyhd8ed1ab_0
225 | - six=1.16.0=pyh6c4a22f_0
226 | - svt-av1=1.4.1=hcb278e6_0
227 | - sympy=1.12=pypyh9d50eac_103
228 | - sysroot_linux-64=2.12=he073ed8_17
229 | - tbb=2021.8.0=hdb19cb5_0
230 | - tensorboard=2.16.2=pyhd8ed1ab_0
231 | - tensorboard-data-server=0.7.0=py39hd4f0224_1
232 | - timm=0.9.16=pyhd8ed1ab_0
233 | - tk=8.6.13=noxft_h4845f30_101
234 | - torchaudio=2.2.2=py39_cu121
235 | - torchtriton=2.2.0=py39
236 | - torchvision=0.17.2=py39_cu121
237 | - tqdm=4.66.2=pyhd8ed1ab_0
238 | - typing-extensions=4.9.0=py39h06a4308_1
239 | - typing_extensions=4.9.0=py39h06a4308_1
240 | - tzdata=2024a=h0c530f3_0
241 | - urllib3=2.1.0=py39h06a4308_1
242 | - werkzeug=3.0.2=pyhd8ed1ab_0
243 | - wheel=0.43.0=pyhd8ed1ab_1
244 | - x264=1!164.3095=h166bdaf_2
245 | - x265=3.5=h924138e_3
246 | - xcb-util=0.4.0=h516909a_0
247 | - xcb-util-image=0.4.0=h166bdaf_0
248 | - xcb-util-keysyms=0.4.0=h516909a_0
249 | - xcb-util-renderutil=0.3.9=h166bdaf_0
250 | - xcb-util-wm=0.4.1=h516909a_0
251 | - xkeyboard-config=2.38=h0b41bf4_0
252 | - xorg-fixesproto=5.0=h7f98852_1002
253 | - xorg-inputproto=2.3.2=h7f98852_1002
254 | - xorg-kbproto=1.0.7=h7f98852_1002
255 | - xorg-libice=1.1.1=hd590300_0
256 | - xorg-libsm=1.2.4=h7391055_0
257 | - xorg-libx11=1.8.4=h0b41bf4_0
258 | - xorg-libxau=1.0.11=hd590300_0
259 | - xorg-libxdmcp=1.1.3=h7f98852_0
260 | - xorg-libxext=1.3.4=h0b41bf4_2
261 | - xorg-libxfixes=5.0.3=h7f98852_1004
262 | - xorg-libxi=1.7.10=h7f98852_0
263 | - xorg-libxrender=0.9.10=h7f98852_1003
264 | - xorg-renderproto=0.11.1=h7f98852_1002
265 | - xorg-xextproto=7.3.0=h0b41bf4_1003
266 | - xorg-xproto=7.0.31=h7f98852_1007
267 | - xz=5.4.6=h5eee18b_0
268 | - yaml=0.2.5=h7b6447c_0
269 | - zipp=3.17.0=pyhd8ed1ab_0
270 | - zlib=1.2.13=hd590300_5
271 | - zstd=1.5.2=ha4553b6_0
272 |
273 |
--------------------------------------------------------------------------------
/exts/bp_cuda.cpp:
--------------------------------------------------------------------------------
1 | //
2 | // Created by jie on 09/02/19.
3 | //
4 | #include <torch/extension.h>
5 | #include <ATen/ATen.h>
6 | #include <vector>
7 | #include <tuple>
8 | #include <cmath>
9 | #include <cstdio>
10 | #include <cstdlib>
11 | #include "bp_cuda.h"
12 |
13 |
14 | void dist(
15 | at::Tensor Pc,
16 | at::Tensor IPCnum,
17 | at::Tensor args,
18 | int H,
19 | int W
20 | ) {
21 | int B, Cc, N, M, num;
22 | B = Pc.size(0);
23 | Cc = Pc.size(1);
24 | M = Pc.size(2);
25 | num = args.size(1);
26 | N = args.size(2);
27 | Dist_Cuda(Pc, IPCnum, args, B, Cc, N, M, num, H, W);
28 | }
29 |
30 |
31 | at::Tensor Conv2dLocal_F(
32 | at::Tensor a,
33 | at::Tensor b
34 | ) {
35 | int N1, N2, Ci, Co, K, B;
36 | B = a.size(0);
37 | Ci = a.size(1);
38 | N1 = a.size(2);
39 | N2 = a.size(3);
40 | Co = Ci;
41 | K = sqrt(b.size(1) / Co);
42 | auto c = at::zeros_like(a);
43 | Conv2d_LF_Cuda(a, b, c, N1, N2, Ci, Co, B, K);
44 | return c;
45 | }
46 |
47 |
48 | std::tuple<at::Tensor, at::Tensor> Conv2dLocal_B(
49 | at::Tensor a,
50 | at::Tensor b,
51 | at::Tensor gc
52 | ) {
53 | int N1, N2, Ci, Co, K, B;
54 | B = a.size(0);
55 | Ci = a.size(1);
56 | N1 = a.size(2);
57 | N2 = a.size(3);
58 | Co = Ci;
59 | K = sqrt(b.size(1) / Co);
60 | auto ga = at::zeros_like(a);
61 | auto gb = at::zeros_like(b);
62 | Conv2d_LB_Cuda(a, b, ga, gb, gc, N1, N2, Ci, Co, B, K);
63 | return std::make_tuple(ga, gb);
64 | }
65 |
66 |
67 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m
68 | ) {
69 | m.def("Dist", &dist, "calculate distance on 2D image");
70 | m.def("Conv2dLocal_F", &Conv2dLocal_F, "Conv2dLocal Forward (CUDA)");
71 | m.def("Conv2dLocal_B", &Conv2dLocal_B, "Conv2dLocal Backward (CUDA)");
72 | }
73 |
74 |
--------------------------------------------------------------------------------
/exts/bp_cuda.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by jie on 12/8/22.
3 | //
4 |
5 | #ifndef BP_CUDA_H
6 | #define BP_CUDA_H
7 |
8 | #include <torch/extension.h>
9 | #include <ATen/ATen.h>
10 | #include <vector>
11 | #include <tuple>
12 | #include <cstddef>
13 | #include <cmath>
14 | #include <cstdio>
15 |
16 |
17 | void dist(
18 | at::Tensor Pc,
19 | at::Tensor IPCnum,
20 | at::Tensor args,
21 | int H,
22 | int W
23 | );
24 |
25 | at::Tensor Conv2dLocal_F(
26 | at::Tensor a,
27 | at::Tensor b
28 | );
29 |
30 | std::tuple<at::Tensor, at::Tensor> Conv2dLocal_B(
31 | at::Tensor a,
32 | at::Tensor b,
33 | at::Tensor gc
34 | );
35 |
36 | void Dist_Cuda(at::Tensor Pc, at::Tensor IPCnum, at::Tensor args,
37 | size_t B, size_t Cc, size_t N, size_t M, size_t num, size_t H, size_t W);
38 |
39 | void Conv2d_LF_Cuda(at::Tensor x, at::Tensor y, at::Tensor z, size_t N1, size_t N2, size_t Ci, size_t Co, size_t B,
40 | size_t K);
41 |
42 | void
43 | Conv2d_LB_Cuda(at::Tensor x, at::Tensor y, at::Tensor gx, at::Tensor gy, at::Tensor gz, size_t N1, size_t N2, size_t Ci,
44 | size_t Co, size_t B, size_t K);
45 |
46 |
47 | #endif
48 |
--------------------------------------------------------------------------------
/exts/bp_cuda_kernel.cu:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <ATen/ATen.h>
3 | #include <cuda.h>
4 | #include <cuda_runtime.h>
5 | #include <vector>
6 | #include "bp_cuda.h"
7 | #include <cmath>
8 | #include <cstdlib>
9 |
10 | namespace {
11 |
12 | template <typename scalar_t>
13 | __global__ void
14 | dist_kernel4(const scalar_t *__restrict__ Pc, scalar_t *__restrict__ IPCnum,
15 | long *__restrict__ args, int B, int Cc, int N, int M, int numfw, int H, int W) {
16 | int num = 4;
17 | int h = threadIdx.x + blockIdx.x * blockDim.x;
18 | int w = threadIdx.y + blockIdx.y * blockDim.y;
19 | int b = threadIdx.z + blockIdx.z * blockDim.z;
20 | if ((h < H) && (w < W) && (b < B)) {
21 | scalar_t best[4];
22 | long argbest[4];
23 | #pragma unroll
24 | for (int i = 0; i < num; i++) {
25 | best[i] = 1e20;
26 | argbest[i] = 0;
27 | }
28 | for (int m = 0; m < M; m++) {
29 | scalar_t res1, res2, dist;
30 | res1 = Pc[b * Cc * M + 0 * M + m] - w;
31 | res2 = Pc[b * Cc * M + 1 * M + m] - h;
32 | dist = res1 * res1 + res2 * res2;
33 | #pragma unroll
34 | for (int i = 0; i < num; i++) {
35 | if (best[i] >= dist) {
36 | #pragma unroll
37 | for (int j = num - 1; j > i; j--) {
38 | best[j] = best[j - 1];
39 | argbest[j] = argbest[j - 1];
40 | }
41 | best[i] = dist;
42 | argbest[i] = m;
43 | break;
44 | }
45 | }
46 | }
47 | #pragma unroll
48 | for (int i = 0; i < num; i++) {
49 | if (best[i] >= 1e20) {
50 | argbest[i] = argbest[i - 1];
51 | }
52 | args[b * num * N + i * N + h * W + w] = argbest[i];
53 | IPCnum[b * Cc * num * N + 0 * num * N + i * N + h * W + w] =
54 | Pc[b * Cc * M + 0 * M + argbest[i]] - w;
55 | IPCnum[b * Cc * num * N + 1 * num * N + i * N + h * W + w] =
56 | Pc[b * Cc * M + 1 * M + argbest[i]] - h;
57 | }
58 | }
59 | }
60 |
61 | template <typename scalar_t>
62 | __global__ void
63 | conv2d_kernel_lf(scalar_t *__restrict__ x, scalar_t *__restrict__ y, scalar_t *__restrict__ z, size_t N1,
64 | size_t N2, size_t Ci, size_t Co, size_t B,
65 | size_t K) {
66 | int col_index = threadIdx.x + blockIdx.x * blockDim.x;
67 | int row_index = threadIdx.y + blockIdx.y * blockDim.y;
68 | int cha_index = threadIdx.z + blockIdx.z * blockDim.z;
69 | if ((row_index < N1) && (col_index < N2) && (cha_index < Co)) {
70 | for (int b = 0; b < B; b++) {
71 | scalar_t result = 0;
72 | for (int i = -int((K - 1) / 2.); i < (K + 1) / 2.; i++) {
73 | for (int j = -int((K - 1) / 2.); j < (K + 1) / 2.; j++) {
74 |
75 | if ((row_index + i < 0) || (row_index + i >= N1) || (col_index + j < 0) ||
76 | (col_index + j >= N2)) {
77 | continue;
78 | }
79 |
80 | result += x[b * N1 * N2 * Ci + cha_index * N1 * N2 + (row_index + i) * N2 + col_index + j] *
81 | y[b * N1 * N2 * Ci * K * K + cha_index * N1 * N2 * K * K +
82 | (i + (K - 1) / 2) * K * N1 * N2 +
83 | (j + (K - 1) / 2) * N1 * N2 + row_index * N2 + col_index];
84 | }
85 | }
86 | z[b * N1 * N2 * Co + cha_index * N1 * N2 + row_index * N2 + col_index] = result;
87 | }
88 | }
89 | }
90 |
91 |
92 | template <typename scalar_t>
93 | __global__ void conv2d_kernel_lb(scalar_t *__restrict__ x, scalar_t *__restrict__ y, scalar_t *__restrict__ gx,
94 | scalar_t *__restrict__ gy, scalar_t *__restrict__ gz, size_t N1, size_t N2,
95 | size_t Ci, size_t Co, size_t B,
96 | size_t K) {
97 | int col_index = threadIdx.x + blockIdx.x * blockDim.x;
98 | int row_index = threadIdx.y + blockIdx.y * blockDim.y;
99 | int cha_index = threadIdx.z + blockIdx.z * blockDim.z;
100 | if ((row_index < N1) && (col_index < N2) && (cha_index < Co)) {
101 | for (int b = 0; b < B; b++) {
102 | scalar_t result = 0;
103 | for (int i = -int((K - 1) / 2.); i < (K + 1) / 2.; i++) {
104 | for (int j = -int((K - 1) / 2.); j < (K + 1) / 2.; j++) {
105 |
106 | if ((row_index - i < 0) || (row_index - i >= N1) || (col_index - j < 0) ||
107 | (col_index - j >= N2)) {
108 | continue;
109 | }
110 | result += gz[b * N1 * N2 * Ci + cha_index * N1 * N2 + (row_index - i) * N2 + col_index - j
111 | ] *
112 | y[b * N1 * N2 * Ci * K * K + cha_index * N1 * N2 * K * K +
113 | (i + (K - 1) / 2) * K * N1 * N2 +
114 | (j + (K - 1) / 2) * N1 * N2 + (row_index - i) * N2 + col_index - j];
115 | gy[b * N1 * N2 * Ci * K * K + cha_index * N1 * N2 * K * K +
116 | (i + (K - 1) / 2) * K * N1 * N2 +
117 | (j + (K - 1) / 2) * N1 * N2 + (row_index - i) * N2 + col_index - j] =
118 | gz[b * N1 * N2 * Ci + cha_index * N1 * N2 + (row_index - i) * N2 + col_index - j
119 | ] * x[b * N1 * N2 * Ci + cha_index * N1 * N2 + row_index * N2 + col_index];
120 |
121 | }
122 | }
123 | gx[b * N1 * N2 * Co + cha_index * N1 * N2 + row_index * N2 + col_index] = result;
124 | }
125 | }
126 | }
127 |
128 |
129 | }
130 |
131 |
132 | void Dist_Cuda(at::Tensor Pc, at::Tensor IPCnum, at::Tensor args,
133 | size_t B, size_t Cc, size_t N, size_t M, size_t num, size_t H, size_t W) {
134 | dim3 blockSize(32, 32, 1);
135 | dim3 gridSize((H + blockSize.x - 1) / blockSize.x, (W + blockSize.x - 1) / blockSize.x, B);
136 | switch (num) {
137 | case 4:
138 | AT_DISPATCH_FLOATING_TYPES(Pc.type(), "DistF1_Cuda", ([&] {
139 |                                            dist_kernel4<scalar_t><<<gridSize, blockSize>>>(
140 |                                                    Pc.data_ptr<scalar_t>(), IPCnum.data_ptr<scalar_t>(), args.data_ptr<long>(),
141 | B, Cc, N, M, num, H, W);
142 | }));
143 | break;
144 | default:
145 | exit(-1);
146 | }
147 | }
148 |
149 | void Conv2d_LF_Cuda(at::Tensor x, at::Tensor y, at::Tensor z, size_t N1, size_t N2, size_t Ci, size_t Co, size_t B,
150 | size_t K) {
151 | dim3 blockSize(32, 32, 1);
152 | dim3 gridSize((N2 + blockSize.x - 1) / blockSize.x, (N1 + blockSize.y - 1) / blockSize.y,
153 | (Co + blockSize.z - 1) / blockSize.z);
154 | AT_DISPATCH_FLOATING_TYPES(x.type(), "Conv2d_LF", ([&] {
155 |                                            conv2d_kernel_lf<scalar_t><<<gridSize, blockSize>>>(
156 |                                                    x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(), z.data_ptr<scalar_t>(),
157 | N1, N2, Ci, Co, B, K);
158 | }));
159 | }
160 |
161 |
162 | void
163 | Conv2d_LB_Cuda(at::Tensor x, at::Tensor y, at::Tensor gx, at::Tensor gy, at::Tensor gz, size_t N1, size_t N2,
164 | size_t Ci,
165 | size_t Co, size_t B, size_t K) {
166 | dim3 blockSize(32, 32, 1);
167 | dim3 gridSize((N2 + blockSize.x - 1) / blockSize.x, (N1 + blockSize.y - 1) / blockSize.y,
168 | (Co + blockSize.z - 1) / blockSize.z);
169 | AT_DISPATCH_FLOATING_TYPES(x.type(), "Conv2d_LB", ([&] {
170 |                                            conv2d_kernel_lb<scalar_t><<<gridSize, blockSize>>>(
171 |                                                    x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
172 |                                                    gx.data_ptr<scalar_t>(), gy.data_ptr<scalar_t>(), gz.data_ptr<scalar_t>(),
173 | N1, N2, Ci, Co, B, K);
174 | }));
175 | }
176 |
--------------------------------------------------------------------------------
/exts/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 |
4 | setup(
5 | name='BpOps',
6 | ext_modules=[
7 | CUDAExtension('BpOps',
8 | [
9 | 'bp_cuda.cpp',
10 | 'bp_cuda_kernel.cu',
11 | ],
12 | extra_compile_args={'cxx': ['-g'],
13 | 'nvcc': ['-O3']}
14 | ),
15 | ],
16 | cmdclass={
17 | 'build_ext': BuildExtension
18 | })
19 |
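20 | # Build sketch (illustrative; assumes a CUDA toolchain matching the installed PyTorch):
21 | #   cd exts && python setup.py install
22 | # After installation, the compiled ops are importable as `import BpOps`,
23 | # which is how models/utils.py consumes them.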
--------------------------------------------------------------------------------
/models/BPNet.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : BPNet.py
3 | # @Project: BP-Net
4 | # @Author : jie
5 | # @Time : 4/8/23 12:43 PM
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import numpy as np
10 | import functools
11 | from .utils import Conv1x1, Basic2d, BasicBlock, weights_init, inplace_relu, GenKernel, PMP
12 |
13 | __all__ = [
14 | 'Pre_MF_Post',
15 | ]
16 |
17 |
18 | class Net(nn.Module):
19 | """
20 | network
21 | """
22 |
23 | def __init__(self, block=BasicBlock, bc=16, img_layers=[2, 2, 2, 2, 2, 2],
24 | drop_path=0.1, norm_layer=nn.BatchNorm2d, padding_mode='zeros', drift=1e6):
25 | super().__init__()
26 | self.drift = drift
27 | self._norm_layer = norm_layer
28 | self._padding_mode = padding_mode
29 | self._img_dpc = 0
30 | self._img_dprs = np.linspace(0, drop_path, sum(img_layers))
31 |
32 | self.inplanes = bc * 2
33 | self.conv_img = nn.Sequential(
34 | Basic2d(3, bc * 2, norm_layer=norm_layer, kernel_size=3, padding=1),
35 | self._make_layer(block, bc * 2, 2, stride=1)
36 | )
37 |
38 | self.layer1_img = self._make_layer(block, bc * 4, img_layers[1], stride=2)
39 |
40 | self.inplanes = bc * 4
41 | self.layer2_img = self._make_layer(block, bc * 8, img_layers[2], stride=2)
42 |
43 | self.inplanes = bc * 8
44 | self.layer3_img = self._make_layer(block, bc * 16, img_layers[3], stride=2)
45 |
46 | self.inplanes = bc * 16
47 | self.layer4_img = self._make_layer(block, bc * 16, img_layers[4], stride=2)
48 |
49 | self.inplanes = bc * 16
50 | self.layer5_img = self._make_layer(block, bc * 16, img_layers[5], stride=2)
51 |
52 | self.pred0 = PMP(level=0, in_ch=bc * 4, out_ch=bc * 2, drop_path=drop_path, pool=False)
53 | self.pred1 = PMP(level=1, in_ch=bc * 8, out_ch=bc * 4, drop_path=drop_path)
54 | self.pred2 = PMP(level=2, in_ch=bc * 16, out_ch=bc * 8, drop_path=drop_path)
55 | self.pred3 = PMP(level=3, in_ch=bc * 16, out_ch=bc * 16, drop_path=drop_path)
56 | self.pred4 = PMP(level=4, in_ch=bc * 16, out_ch=bc * 16, drop_path=drop_path)
57 | self.pred5 = PMP(level=5, in_ch=bc * 16, out_ch=bc * 16, drop_path=drop_path, up=False)
58 |
59 | def forward(self, I, DISP, S, K):
60 | """
61 |         I: Bx3xHxW, DISP: Bx1xHxW (not used in this forward pass)
62 | S: Bx1xHxW
63 | K: Bx3x3
64 | """
65 | output = []
66 | # torch.Size([2, 3, 256, 1216])
67 | XI0 = self.conv_img(I) # torch.Size([2, 32, 256, 1216])
68 | XI1 = self.layer1_img(XI0) # torch.Size([2, 64, 128, 608])
69 | XI2 = self.layer2_img(XI1) # torch.Size([2, 128, 64, 304])
70 | XI3 = self.layer3_img(XI2) # torch.Size([2, 256, 32, 152])
71 | XI4 = self.layer4_img(XI3) # torch.Size([2, 256, 16, 76])
72 | XI5 = self.layer5_img(XI4) # torch.Size([2, 256, 8, 38])
73 |
74 | # import IPython
75 | # IPython.embed()
76 | # exit()
77 |
78 | # S: torch.Size([2, 1, 256, 1216])
79 | # XI5: torch.Size([2, 256, 8, 38])
80 |
81 | fout, dout = self.pred5(fout=None, dout=None, XI=XI5, S=S, K=K)
82 | output.append(F.interpolate(dout, scale_factor=2 ** 5, mode='bilinear', align_corners=True))
83 |
84 | fout, dout = self.pred4(fout=fout, dout=dout, XI=XI4, S=S, K=K)
85 | output.append(F.interpolate(dout, scale_factor=2 ** 4, mode='bilinear', align_corners=True))
86 |
87 | fout, dout = self.pred3(fout=fout, dout=dout, XI=XI3, S=S, K=K)
88 | output.append(F.interpolate(dout, scale_factor=2 ** 3, mode='bilinear', align_corners=True))
89 |
90 | fout, dout = self.pred2(fout=fout, dout=dout, XI=XI2, S=S, K=K)
91 | output.append(F.interpolate(dout, scale_factor=2 ** 2, mode='bilinear', align_corners=True))
92 |
93 | fout, dout = self.pred1(fout=fout, dout=dout, XI=XI1, S=S, K=K)
94 | output.append(F.interpolate(dout, scale_factor=2 ** 1, mode='bilinear', align_corners=True))
95 |
96 | fout, dout = self.pred0(fout=fout, dout=dout, XI=XI0, S=S, K=K)
97 | output.append(dout)
98 | return output
99 |
100 | def _make_layer(self, block, planes, blocks, stride=1):
101 | norm_layer = self._norm_layer
102 | padding_mode = self._padding_mode
103 | downsample = None
104 | if norm_layer is None:
105 | bias = True
106 | norm_layer = nn.Identity
107 | else:
108 | bias = False
109 | if stride != 1 or self.inplanes != planes * block.expansion:
110 | downsample = nn.Sequential(
111 | Conv1x1(self.inplanes, planes * block.expansion, stride, bias=bias),
112 | norm_layer(planes * block.expansion),
113 | )
114 |
115 | layers = []
116 | layers.append(
117 | block(self.inplanes, planes, stride, downsample, norm_layer=norm_layer, padding_mode=padding_mode,
118 | drop_path=self._img_dprs[self._img_dpc]))
119 | self._img_dpc += 1
120 | self.inplanes = planes * block.expansion
121 | for _ in range(1, blocks):
122 | layers.append(block(self.inplanes, planes, norm_layer=norm_layer, padding_mode=padding_mode,
123 | drop_path=self._img_dprs[self._img_dpc]))
124 | self._img_dpc += 1
125 |
126 | return nn.Sequential(*layers)
127 |
128 |
129 | def Pre_MF_Post():
130 | """
131 | Pre.+MF.+Post.
132 | """
133 | net = Net()
134 | net.apply(functools.partial(weights_init, mode='trunc'))
135 | for m in net.modules():
136 | if isinstance(m, GenKernel) and m.conv[1].conv.bn.weight is not None:
137 | nn.init.constant_(m.conv[1].conv.bn.weight, 0)
138 | net.apply(inplace_relu)
139 | return net
140 |
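141 | # Forward-pass sketch (shapes follow the docstring above; values are illustrative):
142 | #   net = Pre_MF_Post()
143 | #   preds = net(I, DISP, S, K)   # list of Bx1xHxW depth maps from coarse to fine
144 | #   depth = preds[-1]            # full-resolution prediction, as used in test.py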
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import *
2 | from .BPNet import *
--------------------------------------------------------------------------------
/models/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Filename : utils
4 | # @Date : 2022-05-10
5 | # @Project: BP-Net
6 | # @AUTHOR : jie
7 |
8 | from copy import deepcopy
9 | import torch
10 | import torch.nn as nn
11 | from torch.autograd import Function
12 | import torch.nn.functional as F
13 | import math
14 | import numpy as np
15 | from collections import OrderedDict
16 | import BpOps
17 | import torch.distributed as dist
18 | from einops.layers.torch import Rearrange
19 | from timm.models.layers import DropPath
20 | from torch.cuda.amp import custom_fwd, custom_bwd
21 | import functools
22 |
23 | __all__ = [
24 | 'EMA',
25 | 'Conv1x1',
26 | 'Conv3x3',
27 | 'Basic2d',
28 | 'weights_init',
29 | 'inplace_relu',
30 | 'PMP',
31 | ]
32 |
33 |
34 | class BpDist(Function):
35 |
36 | @staticmethod
37 | @custom_fwd(cast_inputs=torch.float32)
38 | def forward(ctx, xy, idx, Valid, num, H, W):
39 | """
40 | """
41 | assert xy.is_contiguous()
42 | assert Valid.is_contiguous()
43 | _, Cc, M = xy.shape
44 | B = Valid.shape[0]
45 | N = H * W
46 | args = torch.zeros((B, num, N), dtype=torch.long, device=xy.device)
47 | IPCnum = torch.zeros((B, Cc, num, N), dtype=xy.dtype, device=xy.device)
48 | for b in range(B):
49 | Pc = torch.masked_select(xy, Valid[b:b + 1].view(1, 1, N)).reshape(1, 2, -1)
50 | BpOps.Dist(Pc, IPCnum[b:b + 1], args[b:b + 1], H, W)
51 | idx_valid = torch.masked_select(idx, Valid[b:b + 1].view(1, 1, N))
52 | args[b:b + 1] = torch.index_select(idx_valid, 0, args[b:b + 1].reshape(-1)).reshape(1, num, N)
53 | return IPCnum, args
54 |
55 | @staticmethod
56 | @custom_bwd
57 | def backward(ctx, ga=None, gb=None):
58 | return None, None, None, None
59 |
60 |
61 | class BpConvLocal(Function):
62 | @staticmethod
63 | def forward(ctx, input, weight):
64 | assert input.is_contiguous()
65 | assert weight.is_contiguous()
66 | ctx.save_for_backward(input, weight)
67 | output = BpOps.Conv2dLocal_F(input, weight)
68 | return output
69 |
70 | @staticmethod
71 | def backward(ctx, grad_output):
72 | input, weight = ctx.saved_tensors
73 | grad_output = grad_output.contiguous()
74 | grad_input, grad_weight = BpOps.Conv2dLocal_B(input, weight, grad_output)
75 | return grad_input, grad_weight
76 |
77 |
78 | bpdist = BpDist.apply
79 | bpconvlocal = BpConvLocal.apply
80 |
81 |
82 | class EMA(nn.Module):
83 | """ Model Exponential Moving Average V2 borrow from timm https://timm.fast.ai/
84 |
85 | Keep a moving average of everything in the model state_dict (parameters and buffers).
86 | V2 of this module is simpler, it does not match params/buffers based on name but simply
87 | iterates in order. It works with torchscript (JIT of full model).
88 |
89 | This is intended to allow functionality like
90 | https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
91 |
92 | A smoothed version of the weights is necessary for some training schemes to perform well.
93 | E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
94 | RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
95 | smoothing of weights to match results. Pay attention to the decay constant you are using
96 | relative to your update count per epoch.
97 |
98 | To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
99 | disable validation of the EMA weights. Validation will have to be done manually in a separate
100 | process, or after the training stops converging.
101 |
102 | This class is sensitive where it is initialized in the sequence of model init,
103 | GPU assignment and distributed training wrappers.
104 | """
105 |
106 | def __init__(self, model, decay=0.9999, ddp=False):
107 | super().__init__()
108 | # make a copy of the model for accumulating moving average of weights
109 | self.module = deepcopy(model)
110 | self.module.eval()
111 | if ddp:
112 | self.broadcast()
113 | self.decay = decay
114 |
115 | def broadcast(self):
116 | for ema_v in self.module.state_dict().values():
117 | dist.broadcast(ema_v, src=0, async_op=False)
118 |
119 | def _update(self, model, update_fn):
120 | with torch.no_grad():
121 | for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
122 | ema_v.copy_(update_fn(ema_v, model_v))
123 |
124 | def update(self, model):
125 | self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
126 |
127 | def set(self, model):
128 | self._update(model, update_fn=lambda e, m: m)
129 |
130 |
131 | def weights_init(m, mode='trunc'):
132 | from torch.nn.init import _calculate_fan_in_and_fan_out
133 | classname = m.__class__.__name__
134 | if classname.find('Conv2d') != -1:
135 | if hasattr(m, 'weight'):
136 | if mode == 'trunc':
137 | fan_in, fan_out = _calculate_fan_in_and_fan_out(m.weight.data)
138 | std = math.sqrt(2.0 / float(fan_in + fan_out))
139 | torch.nn.init.trunc_normal_(m.weight.data, mean=0, std=std)
140 | elif mode == 'xavier':
141 | torch.nn.init.xavier_normal_(m.weight.data)
142 | else:
143 | raise ValueError(f'unknown mode = {mode}')
144 | if hasattr(m, 'bias') and m.bias is not None:
145 | torch.nn.init.constant_(m.bias.data, 0.0)
146 | if classname.find('Conv1d') != -1:
147 | if hasattr(m, 'weight'):
148 | if mode == 'trunc':
149 | fan_in, fan_out = _calculate_fan_in_and_fan_out(m.weight.data)
150 | std = math.sqrt(2.0 / float(fan_in + fan_out))
151 | torch.nn.init.trunc_normal_(m.weight.data, mean=0, std=std)
152 | elif mode == 'xavier':
153 | torch.nn.init.xavier_normal_(m.weight.data)
154 | else:
155 | raise ValueError(f'unknown mode = {mode}')
156 | if hasattr(m, 'bias') and m.bias is not None:
157 | torch.nn.init.constant_(m.bias.data, 0.0)
158 | elif classname.find('Linear') != -1:
159 | if mode == 'trunc':
160 | fan_in, fan_out = _calculate_fan_in_and_fan_out(m.weight.data)
161 | std = math.sqrt(2.0 / float(fan_in + fan_out))
162 | torch.nn.init.trunc_normal_(m.weight.data, mean=0, std=std)
163 | elif mode == 'xavier':
164 | torch.nn.init.xavier_normal_(m.weight.data)
165 | else:
166 | raise ValueError(f'unknown mode = {mode}')
167 | if m.bias is not None:
168 | torch.nn.init.constant_(m.bias.data, 0.0)
169 |
170 |
171 | def Conv1x1(in_planes, out_planes, stride=1, bias=False, groups=1, dilation=1, padding_mode='zeros'):
172 | """1x1 convolution"""
173 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias)
174 |
175 |
176 | def Conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1, padding_mode='zeros', bias=False):
177 | """3x3 convolution with padding"""
178 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
179 | padding=dilation, padding_mode=padding_mode, groups=groups, bias=bias, dilation=dilation)
180 |
181 |
182 | class Basic2d(nn.Module):
183 | def __init__(self, in_channels, out_channels, norm_layer=None, kernel_size=3, padding=1, padding_mode='zeros',
184 | act=nn.ReLU, stride=1):
185 | super().__init__()
186 | if norm_layer:
187 | conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
188 | stride=stride, padding=padding, bias=False, padding_mode=padding_mode)
189 | else:
190 | conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
191 | stride=stride, padding=padding, bias=True, padding_mode=padding_mode)
192 | self.conv = nn.Sequential(OrderedDict([('conv', conv)]))
193 | if norm_layer:
194 | self.conv.add_module('bn', norm_layer(out_channels))
195 | self.conv.add_module('relu', act())
196 |
197 | def forward(self, x):
198 | out = self.conv(x)
199 | return out
200 |
201 |
202 |
203 | def inplace_relu(m):
204 | classname = m.__class__.__name__
205 | if classname.find('ReLU') != -1:
206 | m.inplace = True
207 |
208 |
209 | class Basic2dTrans(nn.Module):
210 | def __init__(self, in_channels, out_channels, norm_layer=None, act=nn.ReLU):
211 | super().__init__()
212 | if norm_layer is None:
213 | bias = True
214 | norm_layer = nn.Identity
215 | else:
216 | bias = False
217 | self.conv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4,
218 | stride=2, padding=1, bias=bias)
219 | self.bn = norm_layer(out_channels)
220 | self.relu = act()
221 |
222 | def forward(self, x):
223 | out = self.conv(x.contiguous())
224 | out = self.bn(out)
225 | out = self.relu(out)
226 | return out
227 |
228 |
229 | class UpCC(nn.Module):
230 | def __init__(self, in_channels, mid_channels, out_channels, norm_layer=None, kernel_size=3, padding=1,
231 | padding_mode='zeros', act=nn.ReLU):
232 | super().__init__()
233 | self.upf = Basic2dTrans(in_channels, out_channels, norm_layer=norm_layer, act=act)
234 | self.conv = Basic2d(mid_channels + out_channels, out_channels,
235 | norm_layer=norm_layer, kernel_size=kernel_size,
236 | padding=padding, padding_mode=padding_mode, act=act)
237 |
238 | def forward(self, x, y):
239 | """
240 | """
241 | out = self.upf(x)
242 | fout = torch.cat([out, y], dim=1)
243 | fout = self.conv(fout)
244 | return fout
245 |
246 | class GenKernel(nn.Module):
247 | def __init__(self, in_channels, pk, norm_layer=nn.BatchNorm2d, act=nn.ReLU, eps=1e-6):
248 | super().__init__()
249 | self.eps = eps
250 | self.conv = nn.Sequential(
251 | Basic2d(in_channels, in_channels, norm_layer=norm_layer, act=act),
252 | Basic2d(in_channels, pk * pk - 1, norm_layer=norm_layer, act=nn.Identity),
253 | )
254 |
255 | def forward(self, fout):
256 | weight = self.conv(fout)
257 | weight_sum = torch.sum(weight.abs(), dim=1, keepdim=True)
258 | weight = torch.div(weight, weight_sum + self.eps)
259 | weight_mid = 1 - torch.sum(weight, dim=1, keepdim=True)
260 | weight_pre, weight_post = torch.split(weight, [weight.shape[1] // 2, weight.shape[1] // 2], dim=1)
261 | weight = torch.cat([weight_pre, weight_mid, weight_post], dim=1).contiguous()
262 | return weight
263 |
264 |
265 | class CSPN(nn.Module):
266 | """
267 | implementation of CSPN++
268 | """
269 |
270 | def __init__(self, in_channels, pt, norm_layer=nn.BatchNorm2d, act=nn.ReLU, eps=1e-6):
271 | super().__init__()
272 | self.pt = pt
273 | self.weight3x3 = GenKernel(in_channels, 3, norm_layer=norm_layer, act=act, eps=eps)
274 | self.weight5x5 = GenKernel(in_channels, 5, norm_layer=norm_layer, act=act, eps=eps)
275 | self.weight7x7 = GenKernel(in_channels, 7, norm_layer=norm_layer, act=act, eps=eps)
276 | self.convmask = nn.Sequential(
277 | Basic2d(in_channels, in_channels, norm_layer=norm_layer, act=act),
278 | Basic2d(in_channels, 3, norm_layer=None, act=nn.Sigmoid),
279 | )
280 | self.convck = nn.Sequential(
281 | Basic2d(in_channels, in_channels, norm_layer=norm_layer, act=act),
282 | Basic2d(in_channels, 3, norm_layer=None, act=functools.partial(nn.Softmax, dim=1)),
283 | )
284 | self.convct = nn.Sequential(
285 | Basic2d(in_channels + 3, in_channels, norm_layer=norm_layer, act=act),
286 | Basic2d(in_channels, 3, norm_layer=None, act=functools.partial(nn.Softmax, dim=1)),
287 | )
288 |
289 | @custom_fwd(cast_inputs=torch.float32)
290 | def forward(self, fout, hn, h0):
291 | weight3x3 = self.weight3x3(fout)
292 | weight5x5 = self.weight5x5(fout)
293 | weight7x7 = self.weight7x7(fout)
294 | mask3x3, mask5x5, mask7x7 = torch.split(self.convmask(fout) * (h0 > 1e-3).float(), 1, dim=1)
295 | conf3x3, conf5x5, conf7x7 = torch.split(self.convck(fout), 1, dim=1)
296 | hn3x3 = hn5x5 = hn7x7 = hn
297 | hns = [hn, ]
298 | for i in range(self.pt):
299 | hn3x3 = (1. - mask3x3) * bpconvlocal(hn3x3, weight3x3) + mask3x3 * h0
300 | hn5x5 = (1. - mask5x5) * bpconvlocal(hn5x5, weight5x5) + mask5x5 * h0
301 | hn7x7 = (1. - mask7x7) * bpconvlocal(hn7x7, weight7x7) + mask7x7 * h0
302 | if i == self.pt // 2 - 1:
303 | hns.append(conf3x3 * hn3x3 + conf5x5 * hn5x5 + conf7x7 * hn7x7)
304 | hns.append(conf3x3 * hn3x3 + conf5x5 * hn5x5 + conf7x7 * hn7x7)
305 | hns = torch.cat(hns, dim=1)
306 | wt = self.convct(torch.cat([fout, hns], dim=1))
307 | hn = torch.sum(wt * hns, dim=1, keepdim=True)
308 | return hn
309 |
310 |
311 | class Coef(nn.Module):
312 | """
313 | """
314 | def __init__(self, in_channels, out_channels=3, kernel_size=1, padding=0):
315 | super().__init__()
316 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
317 | stride=1, padding=padding, bias=True)
318 |
319 | def forward(self, x):
320 | feat = self.conv(x)
321 | XF, XB, XW = torch.split(feat, [1, 1, 1], dim=1)
322 | return XF, XB, XW
323 |
324 |
325 | class Dist(nn.Module):
326 | """
327 | """
328 |
329 | def __init__(self, num):
330 | super().__init__()
331 | """
332 | """
333 | self.num = num
334 |
335 | def forward(self, S, xx, yy):
336 | """
337 | """
338 | num = self.num
339 | B, _, height, width = S.shape
340 | N = height * width
341 | S = S.reshape(B, 1, N)
342 | Valid = (S > 1e-3)
343 | xy = torch.stack((xx, yy), axis=0).reshape(1, 2, -1).float()
344 | idx = torch.arange(N, device=S.device).reshape(1, 1, N)
345 | Ofnum, args = bpdist(xy, idx, Valid, num, height, width)
346 | return Ofnum, args
347 |
348 |
349 | class BasicBlock(nn.Module):
350 | expansion = 1
351 |
352 | def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=None, padding_mode='zeros', act=nn.ReLU,
353 | last=True, drop_path=0.0):
354 | super().__init__()
355 | bias = False
356 | if norm_layer is None:
357 | bias = True
358 | norm_layer = nn.Identity
359 | self.conv1 = Conv3x3(inplanes, planes, stride, padding_mode=padding_mode, bias=bias)
360 | self.bn1 = norm_layer(planes)
361 | self.relu1 = act()
362 | self.conv2 = Conv3x3(planes, planes, padding_mode=padding_mode, bias=bias)
363 | self.bn2 = norm_layer(planes)
364 | self.downsample = downsample
365 | self.stride = stride
366 | self.last = last
367 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
368 | if last:
369 | self.relu2 = act()
370 |
371 | def forward(self, x):
372 | identity = x
373 | out = self.conv1(x)
374 | out = self.bn1(out)
375 | out = self.relu1(out)
376 | out = self.conv2(out)
377 | out = self.bn2(out)
378 | if self.downsample is not None:
379 | identity = self.downsample(x)
380 | out = self.drop_path(out) + identity
381 | if self.last:
382 | out = self.relu2(out)
383 | return out
384 |
385 |
386 | class UBNet(nn.Module):
387 | def __init__(self, inplanes, dplanes=1, norm_layer=nn.BatchNorm2d, padding_mode='zeros', act=nn.ReLU,
388 | blocknum=2, uplayer=UpCC, depth=1, block=BasicBlock, drop_path=0.):
389 | super().__init__()
390 | self._norm_layer = norm_layer
391 | self._act = act
392 | self._padding_mode = padding_mode
393 | bc = inplanes // 2
394 | self.inplanes = bc * 2
395 | self._img_dpc = 0
396 | self._img_dprs = np.linspace(0, drop_path, blocknum * (1 + depth))
397 | encoder_list = [nn.Sequential(
398 | Basic2d(inplanes + dplanes, bc * 2, norm_layer=norm_layer, kernel_size=3, padding=1,
399 | padding_mode=padding_mode, act=act),
400 | self._make_layer(block, bc * 2, blocknum, stride=1)
401 | ), ]
402 | decoder_list = []
403 | in_channels = bc * 2
404 | for i in range(depth):
405 | self.inplanes = in_channels
406 | out_channels = min(in_channels * 2, 256)
407 | encoder_list.append(self._make_layer(block, out_channels, blocknum, stride=2))
408 | decoder_list.append(uplayer(out_channels, in_channels, in_channels, norm_layer=norm_layer))
409 | in_channels = min(in_channels * 2, 256)
410 | self.encoder = nn.ModuleList(encoder_list)
411 | self.decoder = nn.ModuleList(decoder_list)
412 |
413 | def forward(self, x, d):
414 | feat = []
415 | if d is not None:
416 | x = torch.cat([x, d], dim=1)
417 | for layer in self.encoder:
418 | x = layer(x)
419 | feat.append(x)
420 | out = feat[-1]
421 | for idx in range(len(feat) - 2, -1, -1):
422 | out = self.decoder[idx](out, feat[idx])
423 | return out
424 |
425 | def _make_layer(self, block, planes, blocks, stride=1):
426 | norm_layer = self._norm_layer
427 | act = self._act
428 | padding_mode = self._padding_mode
429 | downsample = None
430 | if norm_layer is None:
431 | bias = True
432 | norm_layer = nn.Identity
433 | else:
434 | bias = False
435 | if stride != 1 or self.inplanes != planes * block.expansion:
436 | downsample = nn.Sequential(
437 | Conv1x1(self.inplanes, planes * block.expansion, stride, bias=bias),
438 | norm_layer(planes * block.expansion),
439 | )
440 | layers = []
441 | layers.append(block(self.inplanes, planes, stride, downsample, norm_layer, act=act, padding_mode=padding_mode,
442 | drop_path=self._img_dprs[self._img_dpc]))
443 | self._img_dpc += 1
444 | self.inplanes = planes * block.expansion
445 | for _ in range(1, blocks):
446 | layers.append(block(self.inplanes, planes, norm_layer=norm_layer, act=act, padding_mode=padding_mode,
447 | drop_path=self._img_dprs[self._img_dpc]))
448 | self._img_dpc += 1
449 |
450 | return nn.Sequential(*layers)
451 |
452 |
453 | class Permute(nn.Module):
454 | def __init__(self, in_channels, out_channels=1, stride=2, norm_layer=nn.BatchNorm2d, act=nn.ReLU):
455 | super().__init__()
456 | self.stride = stride
457 | self.out_channels = out_channels
458 | self.conv = nn.Sequential(
459 | Basic2d(in_channels=in_channels, out_channels=in_channels, norm_layer=norm_layer, act=act, kernel_size=1,
460 | padding=0),
461 | Basic2d(in_channels=in_channels, out_channels=in_channels, norm_layer=norm_layer, act=act, kernel_size=1,
462 | padding=0),
463 | Conv1x1(in_channels, out_channels * stride ** 2, bias=True),
464 | Rearrange('b (c h2 w2) h w -> b c (h h2) (w w2)', c=out_channels, h2=stride, w2=stride),
465 | )
466 |
467 | def forward(self, x):
468 | """
469 | """
470 | fout = self.conv(x)
471 | return fout
472 |
473 |
474 | class WPool(nn.Module):
475 | def __init__(self, in_ch, level, drift=1e6):
476 | super().__init__()
477 | self.level = level
478 | self.drift = drift
479 | self.permute = Permute(in_ch, stride=2 ** level)
480 |
481 | def forward(self, S, fout):
482 | W = self.permute(fout)
483 | size = int(2 ** self.level)
484 | M = (S > 1e-3).float()
485 | with torch.no_grad():
486 | maxW = F.max_pool2d((W + self.drift) * M, size, stride=[size, size]) - self.drift
487 |             maxW = F.interpolate(maxW, scale_factor=size, mode='nearest') * M
488 | expW = torch.exp(W * M - maxW) * M
489 | avgS = F.avg_pool2d(S * expW, kernel_size=size, stride=size)
490 | avgexpW = F.avg_pool2d(expW, kernel_size=size, stride=size)
491 | Sp = avgS / (avgexpW + 1e-6)
492 | return Sp
493 |
494 |
495 | class UpCat(nn.Module):
496 | def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm2d, kernel_size=3, padding=1,
497 | padding_mode='zeros', act=nn.ReLU):
498 | super().__init__()
499 | self.upf = Basic2dTrans(in_channels + 1, out_channels, norm_layer=norm_layer, act=act)
500 | self.conv = Basic2d(out_channels * 2, out_channels,
501 | norm_layer=norm_layer, kernel_size=kernel_size,
502 | padding=padding, padding_mode=padding_mode, act=act)
503 |
504 | def forward(self, y, x, d):
505 | """
506 | x is
507 | """
508 | fout = self.upf(torch.cat([x, d], dim=1))
509 | fout = self.conv(torch.cat([fout, y], dim=1))
510 | return fout
511 |
512 |
513 | class Ident(nn.Module):
514 | def __init__(self, *args, **kwargs) -> None:
515 | super().__init__()
516 |
517 | def forward(self, *args):
518 | return args[0]
519 |
520 |
521 | class Prop(nn.Module):
522 | """
523 | """
524 |
525 | def __init__(self, Cfi, Cfp=3, Cfo=2, act=nn.GELU, norm_layer=nn.BatchNorm2d):
526 | super().__init__()
527 | """
528 | """
529 | self.dist = lambda x: (x * x).sum(1)
530 | Ct = Cfo + Cfi + Cfi + Cfp
531 | self.convXF = nn.Sequential(
532 | Basic2d(in_channels=Ct, out_channels=Cfi, norm_layer=norm_layer, act=act, kernel_size=1,
533 | padding=0),
534 | Basic2d(in_channels=Cfi, out_channels=Cfi, norm_layer=norm_layer, act=act, kernel_size=1,
535 | padding=0),
536 | )
537 | self.convXL = nn.Sequential(
538 | Basic2d(in_channels=Cfi, out_channels=Cfi, norm_layer=norm_layer, act=act, kernel_size=1,
539 | padding=0),
540 | Basic2d(in_channels=Cfi, out_channels=Cfi, norm_layer=norm_layer, act=nn.Identity, kernel_size=1,
541 | padding=0),
542 | )
543 | self.act = act()
544 | self.coef = Coef(Cfi, 3)
545 |
546 | def forward(self, If, Pf, Ofnum, args):
547 | """
548 | """
549 | num = args.shape[-2]
550 | B, Cfi, H, W = If.shape
551 | N = H * W
552 | B, Cfp, M = Pf.shape
553 | If = If.view(B, Cfi, 1, N)
554 | Pf = Pf.view(B, Cfp, 1, M)
555 | Ifnum = If.expand(B, Cfi, num, N) ## Ifnum is BxCfixnumxN
556 | IPfnum = torch.gather(
557 | input=If.expand(B, Cfi, num, N),
558 | dim=-1,
559 | index=args.view(B, 1, num, N).expand(B, Cfi, num, N)) ## IPfnum is BxCfixnumxN
560 | Pfnum = torch.gather(
561 | input=Pf.expand(B, Cfp, num, M),
562 | dim=-1,
563 | index=args.view(B, 1, num, N).expand(B, Cfp, num, N)) ## Pfnum is BxCfpxnumxN
564 | X = torch.cat([Ifnum, IPfnum, Pfnum, Ofnum], dim=1)
565 | XF = self.convXF(X)
566 | XF = self.act(XF + self.convXL(XF))
567 | Alpha, Beta, Omega = self.coef(XF)
568 | Omega = torch.softmax(Omega, dim=2)
569 | dout = torch.sum(((Alpha + 1) * Pfnum[:, -1:] + Beta) * Omega, dim=2, keepdim=True)
570 | return dout.view(B, 1, H, W)
571 |
572 |
573 | class PMP(nn.Module):
574 | """
575 | Pre+MF+Post
576 | """
577 |
578 | def __init__(self, level, in_ch, out_ch, drop_path, up=True, pool=True):
579 | super().__init__()
580 | self.level = level
581 | if up:
582 | self.upcat = UpCat(in_ch, out_ch)
583 | else:
584 | self.upcat = Ident()
585 | if pool:
586 | self.wpool = WPool(out_ch, level=level)
587 | else:
588 | self.wpool = Ident()
589 | self.dist = Dist(num=4)
590 | self.prop = Prop(out_ch)
591 | self.fuse = UBNet(out_ch, dplanes=3, blocknum=2, depth=5 - level, drop_path=drop_path)
592 | self.conv = Conv3x3(out_ch, 1, bias=True)
593 | self.cspn = CSPN(out_ch, pt=2 * (6 - level))
594 |
595 | def pinv(self, S, K, xx, yy):
596 | fx, fy, cx, cy = K[:, 0:1, 0:1], K[:, 1:2, 1:2], K[:, 0:1, 2:3], K[:, 1:2, 2:3]
597 | S = S.view(S.shape[0], 1, -1)
598 | xx = xx.reshape(1, 1, -1)
599 | yy = yy.reshape(1, 1, -1)
600 | Px = S * (xx - cx) / fx
601 | Py = S * (yy - cy) / fy
602 | Pz = S
603 | Pxyz = torch.cat([Px, Py, Pz], dim=1).contiguous()
604 | return Pxyz
605 |
606 | def forward(self, fout, dout, XI, S, K):
607 | fout = self.upcat(XI, fout, dout)
608 | Sp = self.wpool(S, fout)
609 | Kp = K.clone()
610 | Kp[:, :2] = Kp[:, :2] / 2 ** self.level
611 | B, _, height, width = Sp.shape
612 | xx, yy = torch.meshgrid(torch.arange(width, device=Sp.device), torch.arange(height, device=Sp.device),
613 | indexing='xy')
614 | ###############################################################
615 | # Pre
616 | Pxyz = self.pinv(Sp, Kp, xx, yy)
617 | Ofnum, args = self.dist(Sp, xx, yy)
618 | dout = self.prop(fout, Pxyz, Ofnum, args)
619 | ###############################################################
620 | # MF
621 | Pxyz = self.pinv(dout, Kp, xx, yy).view(dout.shape[0], 3, dout.shape[2], dout.shape[3])
622 | fout = self.fuse(fout, Pxyz)
623 | res = self.conv(fout)
624 | dout = dout + res
625 | ###############################################################
626 | # Post
627 | dout = self.cspn(fout, dout, Sp)
628 | return fout, dout
629 |
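630 | # EMA usage sketch (illustrative; mirrors how the Trainer in utils.py wires it up):
631 | #   ema = EMA(net, decay=0.9999, ddp=False)
632 | #   ...                     # one optimizer step on `net`
633 | #   ema.update(net)         # fold the new weights into the moving average
634 | #   eval_net = ema.module   # evaluate with the smoothed copy, as test.py does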
--------------------------------------------------------------------------------
/optimizers.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : optimizers.py
3 | # @Project: BP-Net
4 | # @Author : jie
5 | # @Time : 5/11/22 3:47 PM
6 |
7 | from torch.optim import SGD, AdamW, Adam
8 |
9 | __all__ = [
10 | 'Adam',
11 | 'AdamW',
12 | 'SGD',
13 | ]
14 |
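15 | # These classes are re-exported so the Hydra optimizer configs (e.g. configs/optim/AdamW.yaml)
16 | # can refer to them by name; a manual equivalent with illustrative hyper-parameters would be
17 | #   AdamW(params=config_param(net), lr=1e-3, weight_decay=0.05)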
--------------------------------------------------------------------------------
/rpnloss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class RandomProposalNormalizationLoss(nn.Module):
5 | def __init__(self, deltas=(2 ** (-5 * 2), 2 ** (-4 * 2), 2 ** (-3 * 2), 2 ** (-2 * 2), 2 ** (-1 * 2), 1), num_proposals=32, min_crop_ratio=0.125, max_crop_ratio=0.5):
6 | super(RandomProposalNormalizationLoss, self).__init__()
7 | self.num_proposals = num_proposals
8 | self.min_crop_ratio = min_crop_ratio
9 | self.max_crop_ratio = max_crop_ratio
10 | self.deltas = deltas
11 |
12 | def forward(self, outputs, target):
13 | """
14 |         :param outputs: list of multi-scale predicted depth maps, each (B, 1, H, W)
15 |         :param target: ground-truth depth map (B, 1, H, W)
16 |         :return: list of per-scale RPNL loss values, weighted by self.deltas
17 | """
18 | loss = [delta * self.compute_loss(ests, target) for ests, delta in zip(outputs, self.deltas)]
19 | return loss
20 |
21 | def compute_loss(self, predicted_depth, ground_truth_depth):
22 | B, _, H, W = predicted_depth.shape
23 | loss = 0.0
24 |
25 | for _ in range(self.num_proposals):
26 | # Randomly select crop size
27 | crop_ratio = torch.rand(1).item() * (self.max_crop_ratio - self.min_crop_ratio) + self.min_crop_ratio
28 | crop_h, crop_w = int(H * crop_ratio), int(W * crop_ratio)
29 |
30 | # Randomly select top-left corner of the crop
31 | top = torch.randint(0, H - crop_h + 1, (1,)).item()
32 | left = torch.randint(0, W - crop_w + 1, (1,)).item()
33 |
34 | # Extract patches
35 | pred_patch = predicted_depth[:, :, top:top + crop_h, left:left + crop_w]
36 | gt_patch = ground_truth_depth[:, :, top:top + crop_h, left:left + crop_w]
37 |
38 | # Flatten patches for normalization
39 | pred_patch = pred_patch.reshape(B, -1)
40 | gt_patch = gt_patch.reshape(B, -1)
41 |
42 | # Normalize patches using median absolute deviation normalization
43 | pred_median = pred_patch.median(dim=1, keepdim=True)[0]
44 | gt_median = gt_patch.median(dim=1, keepdim=True)[0]
45 |
46 | pred_mad = torch.median(torch.abs(pred_patch - pred_median), dim=1, keepdim=True)[0]
47 | gt_mad = torch.median(torch.abs(gt_patch - gt_median), dim=1, keepdim=True)[0]
48 |
49 | pred_patch_norm = (pred_patch - pred_median) / (pred_mad + 1e-6)
50 | gt_patch_norm = (gt_patch - gt_median) / (gt_mad + 1e-6)
51 |
52 | # Calculate the L1 difference between normalized patches
53 | patch_loss = torch.mean(torch.abs(pred_patch_norm - gt_patch_norm))
54 | loss += patch_loss
55 |
56 | loss /= self.num_proposals
57 | return loss
58 |
59 | # Example usage:
60 | if __name__ == "__main__":
61 |     # Create random multi-scale predictions (one per delta) and a ground-truth depth map
62 |     predicted_depths = [torch.rand(2, 1, 256, 256) for _ in range(6)]  # (batch_size, channels, height, width)
63 |     ground_truth_depth = torch.rand(2, 1, 256, 256)
64 | 
65 |     # Initialize the loss function
66 |     rpn_loss = RandomProposalNormalizationLoss()
67 | 
68 |     # Compute the per-scale losses and reduce them to a scalar
69 |     loss_values = rpn_loss(predicted_depths, ground_truth_depth)
70 |     print(f"RPN Loss: {sum(loss_values).item()}")
71 |
72 | class RegularizationLoss(nn.Module):
73 | """
74 | Enforce losses on pixels without any gts.
75 | """
76 | def __init__(self, loss_weight=0.1):
77 | super(RegularizationLoss, self).__init__()
78 | self.loss_weight = loss_weight
79 | self.eps = 1e-6
80 |
81 | def forward(self, prediction, gt):
82 | mask = gt > 1e-3
83 | pred_wo_gt = prediction[~mask]
84 |         loss = 1 / (torch.sum(pred_wo_gt) / (pred_wo_gt.numel() + self.eps))  # reciprocal of the mean prediction on unsupervised pixels
85 | return loss * self.loss_weight
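86 | 
87 | # RegularizationLoss usage sketch (illustrative): the loss is computed only on pixels where the
88 | # sparse ground truth has no return (gt <= 1e-3).
89 | #   reg_loss = RegularizationLoss(loss_weight=0.1)
90 | #   penalty = reg_loss(predicted_depth, sparse_gt)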
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | # torchrun --nproc_per_node=4 --master_port 4321 train_distill.py \
2 | # gpus=[0,1,2,3] num_workers=4 name=BP_KITTI \
3 | # net=PMP data=KITTI \
4 | # lr=1e-3 train_batch_size=2 test_batch_size=2 \
5 | # sched/lr=NoiseOneCycleCosMo sched.lr.policy.max_momentum=0.90 \
6 | # nepoch=30 test_epoch=25 ++net.sbn=true
7 |
8 | torchrun --nproc_per_node=1 --master_port 4321 train_distill.py \
9 | gpus=[0] num_workers=1 name=BP_KITTI \
10 | net=PMP data=KITTI \
11 | lr=1e-3 train_batch_size=2 test_batch_size=2 \
12 | sched/lr=NoiseOneCycleCosMo sched.lr.policy.max_momentum=0.90 \
13 | nepoch=30 test_epoch=25 ++net.sbn=true
--------------------------------------------------------------------------------
/schedulers.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : schedulers.py
3 | # @Project: BP-Net
4 | # @Author : jie
5 | # @Time : 5/11/22 3:50 PM
6 | import random
7 | import sys
8 | from torch.optim.lr_scheduler import StepLR, MultiStepLR, OneCycleLR, LambdaLR, LinearLR, ExponentialLR
9 | import numpy as np
10 | import torch
11 |
12 |
13 | def NoiseLR(**kwargs):
14 | lr_sched = getattr(sys.modules[__name__], kwargs.pop('lr_sched', 'OneCycleLR'))
15 |
16 | class sched(lr_sched):
17 | def __init__(self, **kwargs):
18 | self.noise_pct = kwargs.pop('noise_pct', 0.1)
19 | self.noise_seed = kwargs.pop('noise_seed', 0)
20 | super().__init__(**kwargs)
21 |
22 | def get_lr(self):
23 | """
24 | lrn: Learning Rate with Noise
25 | """
26 | g = torch.Generator()
27 | g.manual_seed(self.noise_seed + self.last_epoch)
28 | noise = 2 * torch.rand(1, generator=g).item() - 1
29 | lrs = super().get_lr()
30 | lrn = []
31 | for lr in lrs:
32 | lrn.append(lr * (1 + self.noise_pct * noise))
33 | return lrn
34 |
35 | return sched(**kwargs)
36 |
37 |
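38 | # Example (illustrative values): a OneCycle schedule whose per-step LR is jittered by up to ±10%.
39 | #   scheduler = NoiseLR(lr_sched='OneCycleLR', optimizer=optimizer, max_lr=1e-3,
40 | #                       total_steps=10000, noise_pct=0.1, noise_seed=0)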
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File    : test.py
3 | # @Project: BP-Net
4 | # @Author : jie
5 | # @Time : 10/27/21 3:58 PM
6 |
7 | import torch
8 | from tqdm import tqdm
9 | import hydra
10 | from PIL import Image
11 | import os
12 | from omegaconf import OmegaConf
13 | from utils import *
14 |
15 |
16 | def test(run, mode='selval', save=False):
17 | dataloader = run.testloader
18 | net = run.net_ema.module
19 | net.eval()
20 | tops = [AverageMeter() for i in range(len(run.metric.metric_name))]
21 | if save:
22 | dir_path = f'results/{run.cfg.name}/{mode}'
23 | os.makedirs(dir_path, exist_ok=True)
24 | with torch.no_grad():
25 | for idx, datas in enumerate(
26 | tqdm(dataloader, desc="test ", dynamic_ncols=True, leave=False, disable=run.rank)):
27 | datas = run.init_cuda(*datas)
28 | output = net(*datas[:-1])
29 | if isinstance(output, (list, tuple)):
30 | output = output[-1]
31 | precs = run.metric(output, datas[-1])
32 | for prec, top in zip(precs, tops):
33 | top.update(prec.mean().detach().cpu().item())
34 | if save:
35 | for i in range(output.shape[0]):
36 | index = idx * output.shape[0] + i
37 | file_path = os.path.join(dir_path, f'{index:010d}.png')
38 |                     img = (output[i, 0] * 256.0).detach().cpu().numpy().astype('uint16')  # KITTI convention: depth in metres * 256 as uint16
39 | Img = Image.fromarray(img)
40 | Img.save(file_path)
41 | logs = ""
42 | for name, top in zip(run.metric.metric_name, tops):
43 | logs += f" {name}:{top.avg:.7f} "
44 | run.ddp_log(logs, always=True)
45 |
46 |
47 | @hydra.main(config_path='configs', config_name='config', version_base='1.2')
48 | def main(cfg):
49 | with Trainer(cfg) as run:
50 | test(run, mode=cfg.data.testset.mode, save=OmegaConf.select(cfg, 'save', default=False))
51 |
52 |
53 |
54 | if __name__ == '__main__':
55 | main()
56 |
--------------------------------------------------------------------------------
/train_distill.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File    : train_distill.py
3 | # @Project: BP-Net
4 | # @Author : jie
5 | # @Time : 10/27/21 3:58 PM
6 |
7 | import torch
8 | from tqdm import tqdm
9 | import hydra
10 | import torch.distributed as dist
11 | from utils import *
12 | from rpnloss import RandomProposalNormalizationLoss
13 |
14 | def train(run):
15 |
16 | compute_disp_loss = RandomProposalNormalizationLoss()
17 |
18 | for datas in tqdm(run.trainloader, desc="train", dynamic_ncols=True, leave=False, disable=run.rank):
19 | if run.epoch >= run.cfg.test_epoch:
20 | if run.iter % run.cfg.test_iter == 0:
21 | test(run, iter=True)
22 | datas = run.init_cuda(*datas)
23 | run.net.train()
24 | run.optimizer.zero_grad(set_to_none=True)
25 | output = run.net(*datas[:-1])
26 | loss = run.criterion(output, datas[-1])
27 |
28 | disp = 1 / (datas[1] + 0.5)
29 | # disp_loss = compute_disp_loss(output, disp, torch.ones_like(disp).float())
30 | disp_loss = compute_disp_loss(output, disp)
31 |
32 | loss += disp_loss
33 |
34 | sum(loss).backward()
35 | if run.clip:
36 | grad_norm = run.clip(run.net.parameters())
37 | run.optimizer.step()
38 | run.net_ema.update(run.net)
39 | if run.lr_iter:
40 | run.lr_scheduler.step()
41 | if run.iter % run.cfg.vis_iter == 0:
42 | run.writer.add_scalar("Lr", run.optimizer.param_groups[0]['lr'], run.iter)
43 | run.writer.add_scalars("Loss", {f"{idx}": l.item() for idx, l in enumerate(loss)}, run.iter)
44 | if run.clip and (grad_norm is not None):
45 | run.writer.add_scalar("GradNorm", grad_norm.item(), run.iter)
46 | run.iter += 1
47 | if not run.lr_iter:
48 | run.lr_scheduler.step()
49 | run.writer.flush()
50 |
51 |
52 | def test(run, iter=False):
53 | top1 = AverageMeter()
54 | net = run.net_ema.module
55 | best_metric_name = "best_metric_ema"
56 | legand = 'net_ema'
57 | net.eval()
58 | with torch.no_grad():
59 | for datas in tqdm(run.testloader, desc="test ", dynamic_ncols=True, leave=False, disable=run.rank):
60 | datas = run.init_cuda(*datas)
61 | output = net(*datas[:-1])
62 | if isinstance(output, (list, tuple)):
63 | output = output[-1]
64 | prec1 = run.metric(output, datas[-1])
65 | if isinstance(prec1, (list, tuple)):
66 | prec1 = prec1[0]
67 | if run.ddp:
68 | dist.reduce(prec1, 0, dist.ReduceOp.AVG)
69 | top1.update(prec1.item())
70 | if iter:
71 | run.writer.add_scalars("RMSE_Iter", {legand: top1.avg}, run.iter)
72 | else:
73 | run.writer.add_scalars("RMSE", {legand: top1.avg}, run.epoch)
74 |
75 | if top1.avg < getattr(run, best_metric_name):
76 | setattr(run, best_metric_name, top1.avg)
77 | run.save_state()
78 | run.ddp_cout(f'Epoch: {run.epoch} {best_metric_name}: {top1.avg:.7f}\n')
79 | else:
80 | best_metric_value = getattr(run, best_metric_name)
81 | run.ddp_cout(f'Epoch: {run.epoch} current: {top1.avg:.7f}, best: {best_metric_value:.7f}. \n')
82 |
83 |
84 | @hydra.main(config_path='configs', config_name='config', version_base='1.2')
85 | def main(cfg):
86 | with Trainer(cfg) as run:
87 | for epoch in tqdm(range(run.cfg.start_epoch, run.cfg.nepoch), desc="epoch", dynamic_ncols=True):
88 | run.epoch = epoch
89 | if run.train_sampler:
90 | run.train_sampler.set_epoch(epoch)
91 | train(run)
92 | torch.cuda.synchronize()
93 | test(run)
94 | torch.cuda.synchronize()
95 |
96 | if __name__ == '__main__':
97 | main()
98 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Filename : utils
4 | # @Date : 2022-05-06
5 | # @Project: BP-Net
6 | # @AUTHOR : jie
7 |
8 | import os
9 | import torch
10 | import random
11 | import numpy as np
12 | from torch.utils.tensorboard import SummaryWriter
13 | import time
14 | from tqdm import tqdm
15 | from torch.nn.parallel import DistributedDataParallel as DDP
16 | import logging
17 | from hydra.utils import instantiate, get_class
18 | from omegaconf import OmegaConf
19 | import cv2
20 | import augs
21 | from torch.nn.utils import clip_grad_norm_, clip_grad_value_
22 | from collections import OrderedDict
23 | import math
24 |
25 | __all__ = [
26 | 'AverageMeter',
27 | 'Trainer',
28 | ]
29 |
30 |
31 | class AverageMeter(object):
32 | def __init__(self):
33 | self.reset()
34 |
35 | def reset(self):
36 | self.val = 0
37 | self.avg = 0
38 | self.sum = 0
39 | self.count = 0
40 |
41 | def update(self, val, n=1):
42 | self.val = val
43 | self.sum += val * n
44 | self.count += n
45 | self.avg = self.sum / self.count
46 |
47 |
48 | class Trainer(object):
49 | def __init__(self, cfg):
50 | self.cfg = cfg
51 | self.rank = int(os.environ["LOCAL_RANK"]) if "LOCAL_RANK" in os.environ else 0
52 | self.cfg.gpu_id = self.cfg.gpus[self.rank]
53 | self.init_gpu()
54 | self.ddp = len(self.cfg.gpus) > 1
55 | self.iter = 0
56 | self.epoch = 0
57 | self.best_metric_ema = 100
58 | #####################################################################################
59 | self.log = self.init_log()
60 | self.init_device()
61 | self.init_seed()
62 | self.writer = self.init_viz()
63 | self.trainloader, self.testloader = self.init_dataset()
64 | net = self.init_net()
65 | criterion = self.init_loss()
66 | metric = self.init_metric()
67 | self.net, self.criterion, self.metric = self.init_cuda(net, criterion, metric)
68 | self.net_ema = self.init_ema()
69 | if self.ddp:
70 | self.net = DDP(self.net)
71 | self.optimizer = self.init_optim()
72 | self.lr_scheduler = self.init_sched_lr()
73 | self.lr_iter = OmegaConf.select(self.cfg.sched.lr, 'iter', default=False)
74 | self.clip = self.init_clip()
75 |
76 | def init_log(self):
77 | return Blank() if self.rank else logging.getLogger(f'{self.cfg.name}')
78 |
79 | def init_device(self):
80 | torch.cuda.set_device(f'cuda:{self.cfg.gpu_id}')
81 | if self.ddp:
82 | torch.distributed.init_process_group(backend='nccl', init_method='env://')
83 | self.ddp_log(f'device is {self.cfg.gpu_id}', always=True)
84 |
85 | def init_seed(self):
86 | manual_seed = self.cfg.manual_seed
87 | self.ddp_log(f"Random Seed: {manual_seed:04d}")
88 | torch.initial_seed()
89 | random.seed(manual_seed)
90 | np.random.seed(manual_seed)
91 | torch.manual_seed(manual_seed)
92 | torch.cuda.manual_seed_all(manual_seed)
93 |
94 | def init_viz(self):
95 | if self.rank:
96 | return Blank()
97 | else:
98 | writer_name = os.path.join('runs', f'{self.cfg.name}')
99 | return SummaryWriter(writer_name)
100 |
101 | def init_dataset(self):
102 | trainset = instantiate(self.cfg.data.trainset)
103 | testset = instantiate(self.cfg.data.testset)
104 | if self.ddp:
105 | train_sampler = torch.utils.data.distributed.DistributedSampler(
106 | trainset,
107 | num_replicas=len(self.cfg.gpus),
108 | rank=self.rank,
109 | shuffle=True,
110 | )
111 | test_sampler = torch.utils.data.distributed.DistributedSampler(
112 | testset,
113 | num_replicas=len(self.cfg.gpus),
114 | rank=self.rank,
115 | shuffle=False,
116 | )
117 | else:
118 | train_sampler = None
119 | test_sampler = None
120 | self.train_sampler = train_sampler
121 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=self.cfg.train_batch_size,
122 | num_workers=self.cfg.num_workers, shuffle=(train_sampler is None),
123 | sampler=train_sampler,
124 | drop_last=True, pin_memory=True)
125 | testloader = torch.utils.data.DataLoader(testset, batch_size=self.cfg.test_batch_size,
126 | num_workers=self.cfg.num_workers, shuffle=False,
127 | sampler=test_sampler,
128 | drop_last=True, pin_memory=True)
129 | self.ddp_log(f'num_train = {len(trainloader)}, num_test = {len(testloader)}')
130 | return trainloader, testloader
131 |
132 |
133 | def init_net(self):
134 | model = instantiate(self.cfg.net.model)
135 | if 'chpt' in self.cfg:
136 | self.ddp_log(f'resume CHECKPOINTS')
137 | save_path = os.path.join('checkpoints', self.cfg.chpt)
138 | cp = torch.load(self.cfg.chpt, map_location=torch.device('cpu'))
139 | model.load_state_dict(cp['net'], strict=True)
140 | self.best_metric_ema = cp['best_metric_ema']
141 | del cp
142 | if self.ddp and OmegaConf.select(self.cfg.net, 'sbn', default=False):
143 | self.ddp_log('sbn')
144 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
145 | else:
146 | """
147 | SBN is not compatible with torch.compile
148 | """
149 | if OmegaConf.select(self.cfg.net, 'compile', default=False):
150 | self.ddp_log('compile')
151 | model = torch.compile(model)
152 | return model
153 |
154 | def init_ema(self):
155 | return instantiate(self.cfg.net.ema, model=self.net, ddp=self.ddp)
156 |
157 | def init_loss(self):
158 | return instantiate(self.cfg.loss)
159 |
160 | def init_metric(self):
161 | return instantiate(self.cfg.metric)
162 |
163 | def init_cuda(self, *modules):
164 | modules = [module.to(f'cuda:{self.cfg.gpu_id}') for module in modules]
165 | return modules
166 |
167 | def init_gpu(self):
168 | with torch.cuda.device(f'cuda:{self.cfg.gpu_id}'):
169 | torch.cuda.empty_cache()
170 | torch.backends.cudnn.benchmark = True
171 |
172 | def init_optim(self):
173 | optim = instantiate(self.cfg.optim, _partial_=True)
174 | return optim(params=config_param(self.net))
175 |
176 | def init_clip(self):
177 | if 'clip' in self.cfg.net:
178 | return instantiate(self.cfg.net.clip, _partial_=True)
179 | else:
180 | return None
181 |
182 | def init_sched_lr(self):
183 | sched = instantiate(self.cfg.sched.lr.policy, _partial_=True)
184 | return sched(optimizer=self.optimizer)
185 |
186 |
187 | def ddp_log(self, content, always=False):
188 | # self.log.info(f'{content}')
189 | if (not self.rank) or always:
190 | self.log.info(f'{content}')
191 |
192 | def ddp_cout(self, content, always=False):
193 | # tqdm.write(f'{content}')
194 | if (not self.rank) or always:
195 | tqdm.write(f'{content}')
196 |
197 | def save_state(self):
198 | if self.rank:
199 | return
200 | save_path = os.path.join('checkpoints', self.cfg.name)
201 | os.makedirs(save_path, exist_ok=True)
202 | model = self.net_ema.module
203 | if hasattr(model, 'module'):
204 | model = model.module
205 | model_state_dict = model.state_dict()
206 | state_dict = {
207 | 'net': model_state_dict,
208 | 'epoch': self.epoch,
209 | 'best_metric_ema': self.best_metric_ema,
210 | }
211 | torch.save(state_dict, os.path.join(save_path, 'result_ema.pth'))
212 |
213 | def __enter__(self):
214 | return self
215 |
216 | def __exit__(self, *args, **kwargs):
217 | self.writer.close()
218 | self.ddp_log(f'best_metric_ema={self.best_metric_ema:.4f}')
219 |
220 |
221 | def config_param(model):
222 | param_groups = []
223 | other_params = []
224 | for name, param in model.named_parameters():
225 |         if len(param.shape) == 1:  # 1-D tensors (biases, norm scales/shifts): exclude from weight decay
226 |             g = {'params': [param], 'weight_decay': 0.0}
227 | param_groups.append(g)
228 | else:
229 | other_params.append(param)
230 | param_groups.append({'params': other_params})
231 | return param_groups
232 |
233 |
234 | def set_requires_grad(model, requires_grad=True):
235 | for p in model.parameters():
236 | p.requires_grad = requires_grad
237 |
238 |
239 |
240 | class Blank(object):
241 | def __getattr__(self, name):
242 | def wrapper(*args, **kwargs):
243 | return None
244 | return wrapper
245 |
246 | FloorDiv = lambda a, b: a // b
247 |
248 | CeilDiv = lambda a, b: math.ceil(a / b)
249 |
250 | Div = lambda a, b: a / b
251 |
252 | Mul = lambda a, b: a * b
253 |
254 |
255 |
--------------------------------------------------------------------------------
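Note on `config_param` (defined identically here and again in `utils_infer.py` below): every 1-D parameter tensor (biases and normalization affine parameters) is placed in its own group with `weight_decay=0.0`, while all remaining parameters share a final group that inherits the optimizer's default decay. The sketch below is illustrative only; the toy network and the AdamW hyper-parameters are assumptions, not values taken from this repository's configs.

```python
# Minimal, self-contained sketch of the config_param grouping rule (illustrative).
import torch
import torch.nn as nn


def config_param(model):
    # 1-D parameters (biases, norm scales/shifts) -> no weight decay.
    param_groups, other_params = [], []
    for _, param in model.named_parameters():
        if len(param.shape) == 1:
            param_groups.append({'params': [param], 'weight_decay': 0.0})
        else:
            other_params.append(param)
    param_groups.append({'params': other_params})  # decayed by the optimizer default
    return param_groups


net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
optimizer = torch.optim.AdamW(config_param(net), lr=1e-3, weight_decay=0.05)
# The conv bias and the two BatchNorm affine parameters form three no-decay groups;
# the conv weight lands in the final group and receives the default decay of 0.05.
print([g['weight_decay'] for g in optimizer.param_groups])  # [0.0, 0.0, 0.0, 0.05]
```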
/utils_infer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Filename : utils_infer
4 | # @Date : 2022-05-06
5 | # @Project: BP-Net
6 | # @AUTHOR : jie
7 |
8 | import os
9 | import torch
10 | import random
11 | import numpy as np
12 | from torch.utils.tensorboard import SummaryWriter
13 | import time
14 | from tqdm import tqdm
15 | from torch.nn.parallel import DistributedDataParallel as DDP
16 | import logging
17 | from hydra.utils import instantiate, get_class
18 | from omegaconf import OmegaConf
19 | import cv2
20 | import augs
21 | from torch.nn.utils import clip_grad_norm_, clip_grad_value_
22 | from collections import OrderedDict
23 | import math
24 |
25 | __all__ = [
26 | 'AverageMeter',
27 | 'Trainer',
28 | ]
29 |
30 |
31 | class AverageMeter(object):
32 | def __init__(self):
33 | self.reset()
34 |
35 | def reset(self):
36 | self.val = 0
37 | self.avg = 0
38 | self.sum = 0
39 | self.count = 0
40 |
41 | def update(self, val, n=1):
42 | self.val = val
43 | self.sum += val * n
44 | self.count += n
45 | self.avg = self.sum / self.count
46 |
47 |
48 | class Trainer(object):
49 | def __init__(self, cfg):
50 | self.cfg = cfg
51 | self.rank = int(os.environ["LOCAL_RANK"]) if "LOCAL_RANK" in os.environ else 0
52 | self.cfg.gpu_id = self.cfg.gpus[self.rank]
53 | self.init_gpu()
54 | self.ddp = len(self.cfg.gpus) > 1
55 | self.iter = 0
56 | self.epoch = 0
57 | self.best_metric_ema = 100
58 | #####################################################################################
59 | self.log = self.init_log()
60 | self.init_device()
61 | self.init_seed()
62 | self.writer = self.init_viz()
63 | net = self.init_net()
64 | criterion = self.init_loss()
65 | metric = self.init_metric()
66 | self.net, self.criterion, self.metric = self.init_cuda(net, criterion, metric)
67 | self.net_ema = self.init_ema()
68 | if self.ddp:
69 | self.net = DDP(self.net)
70 | self.optimizer = self.init_optim()
71 | self.lr_scheduler = self.init_sched_lr()
72 | self.lr_iter = OmegaConf.select(self.cfg.sched.lr, 'iter', default=False)
73 | self.clip = self.init_clip()
74 |
75 | def init_log(self):
76 | return Blank() if self.rank else logging.getLogger(f'{self.cfg.name}')
77 |
78 | def init_device(self):
79 | torch.cuda.set_device(f'cuda:{self.cfg.gpu_id}')
80 | if self.ddp:
81 | torch.distributed.init_process_group(backend='nccl', init_method='env://')
82 | self.ddp_log(f'device is {self.cfg.gpu_id}', always=True)
83 |
84 | def init_seed(self):
85 | manual_seed = self.cfg.manual_seed
86 | self.ddp_log(f"Random Seed: {manual_seed:04d}")
87 |         torch.initial_seed()  # note: only queries the current seed; the calls below do the actual seeding
88 | random.seed(manual_seed)
89 | np.random.seed(manual_seed)
90 | torch.manual_seed(manual_seed)
91 | torch.cuda.manual_seed_all(manual_seed)
92 |
93 | def init_viz(self):
94 | if self.rank:
95 | return Blank()
96 | else:
97 | writer_name = os.path.join('runs', f'{self.cfg.name}')
98 | return SummaryWriter(writer_name)
99 |
100 | def init_net(self):
101 | model = instantiate(self.cfg.net.model)
102 | if 'chpt' in self.cfg:
103 |             self.ddp_log(f'resume checkpoint: {self.cfg.chpt}')
104 | save_path = os.path.join('checkpoints', self.cfg.chpt)
105 | cp = torch.load(os.path.join(save_path, 'result_ema.pth'), map_location=torch.device('cpu'))
106 | model.load_state_dict(cp['net'], strict=True)
107 | self.best_metric_ema = cp['best_metric_ema']
108 |             self.ddp_log(f'resume epoch: {cp["epoch"]}')
109 | del cp
110 | if self.ddp and OmegaConf.select(self.cfg.net, 'sbn', default=False):
111 | self.ddp_log('sbn')
112 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
113 |         else:
114 |             # SyncBatchNorm (SBN) is not compatible with torch.compile,
115 |             # so the model is only compiled on the non-SBN path.
116 |
117 | if OmegaConf.select(self.cfg.net, 'compile', default=False):
118 | self.ddp_log('compile')
119 | model = torch.compile(model)
120 | return model
121 |
122 | def init_ema(self):
123 | return instantiate(self.cfg.net.ema, model=self.net, ddp=self.ddp)
124 |
125 | def init_loss(self):
126 | return instantiate(self.cfg.loss)
127 |
128 | def init_metric(self):
129 | return instantiate(self.cfg.metric)
130 |
131 | def init_cuda(self, *modules):
132 | modules = [module.to(f'cuda:{self.cfg.gpu_id}') for module in modules]
133 | return modules
134 |
135 | def init_gpu(self):
136 | with torch.cuda.device(f'cuda:{self.cfg.gpu_id}'):
137 | torch.cuda.empty_cache()
138 | torch.backends.cudnn.benchmark = True
139 |
140 | def init_optim(self):
141 | optim = instantiate(self.cfg.optim, _partial_=True)
142 | return optim(params=config_param(self.net))
143 |
144 | def init_clip(self):
145 | if 'clip' in self.cfg.net:
146 | return instantiate(self.cfg.net.clip, _partial_=True)
147 | else:
148 | return None
149 |
150 | def init_sched_lr(self):
151 | sched = instantiate(self.cfg.sched.lr.policy, _partial_=True)
152 | return sched(optimizer=self.optimizer)
153 |
154 |
155 | def ddp_log(self, content, always=False):
156 | # self.log.info(f'{content}')
157 | if (not self.rank) or always:
158 | self.log.info(f'{content}')
159 |
160 | def ddp_cout(self, content, always=False):
161 | # tqdm.write(f'{content}')
162 | if (not self.rank) or always:
163 | tqdm.write(f'{content}')
164 |
165 | def save_state(self):
166 | if self.rank:
167 | return
168 | save_path = os.path.join('checkpoints', self.cfg.name)
169 | os.makedirs(save_path, exist_ok=True)
170 | model = self.net_ema.module
171 | if hasattr(model, 'module'):
172 | model = model.module
173 | model_state_dict = model.state_dict()
174 | state_dict = {
175 | 'net': model_state_dict,
176 | 'epoch': self.epoch,
177 | 'best_metric_ema': self.best_metric_ema,
178 | }
179 | torch.save(state_dict, os.path.join(save_path, 'result_ema.pth'))
180 |
181 | def __enter__(self):
182 | return self
183 |
184 | def __exit__(self, *args, **kwargs):
185 | self.writer.close()
186 | self.ddp_log(f'best_metric_ema={self.best_metric_ema:.4f}')
187 |
188 |
189 | def config_param(model):
190 | param_groups = []
191 | other_params = []
192 | for name, param in model.named_parameters():
193 | if len(param.shape) == 1:
194 | g = {'params': [param], 'weight_decay': 0.0}
195 | param_groups.append(g)
196 | else:
197 | other_params.append(param)
198 | param_groups.append({'params': other_params})
199 | return param_groups
200 |
201 |
202 | def set_requires_grad(model, requires_grad=True):
203 | for p in model.parameters():
204 | p.requires_grad = requires_grad
205 |
206 |
207 |
208 | class Blank(object):
209 | def __getattr__(self, name):
210 | def wrapper(*args, **kwargs):
211 | return None
212 | return wrapper
213 |
214 | FloorDiv = lambda a, b: a // b
215 |
216 | CeilDiv = lambda a, b: math.ceil(a / b)
217 |
218 | Div = lambda a, b: a / b
219 |
220 | Mul = lambda a, b: a * b
221 |
222 |
223 |
--------------------------------------------------------------------------------
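Note: `Trainer` is designed to be used as a context manager; `__exit__` closes the TensorBoard `SummaryWriter` and logs the final `best_metric_ema`. `save_state()` writes the EMA weights together with the current epoch and best metric to `checkpoints/<cfg.name>/result_ema.pth`, which is the same file that `init_net()` in `utils_infer.py` reloads when `cfg.chpt` is set. A minimal sketch of inspecting such a checkpoint outside the trainer follows; the run name is hypothetical.

```python
# Illustrative sketch: inspect a checkpoint written by Trainer.save_state().
# 'my_run' is a hypothetical cfg.name; point the path at an existing run.
import torch

cp = torch.load('checkpoints/my_run/result_ema.pth', map_location='cpu')
print(sorted(cp.keys()))              # ['best_metric_ema', 'epoch', 'net']
print('epoch:', cp['epoch'])
print('best metric (EMA):', cp['best_metric_ema'])

# Restoring the weights mirrors init_net() in utils_infer.py:
#   model = instantiate(cfg.net.model)
#   model.load_state_dict(cp['net'], strict=True)
```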