├── .gitignore
├── LICENSE
├── README.md
├── figs
│   ├── framework.png
│   ├── heterogeneous.png
│   └── homogeneous.png
├── image2image_translation
│   ├── cfg.py
│   ├── data
│   │   ├── train_domain.txt
│   │   └── val_domain.txt
│   ├── eval.py
│   ├── main.py
│   ├── models
│   │   ├── model_cfg.py
│   │   ├── model_pruning.py
│   │   └── modules.py
│   └── utils
│       ├── __init__.py
│       ├── dataset.py
│       ├── fid_kid.py
│       ├── inception.py
│       └── utils.py
├── object-detection-3d
│   ├── models
│   │   ├── __init__.py
│   │   ├── ap_helper.py
│   │   ├── detector.py
│   │   ├── dump_helper.py
│   │   ├── layers
│   │   │   ├── __init__.py
│   │   │   ├── drop.py
│   │   │   ├── helper.py
│   │   │   └── weight_init.py
│   │   ├── loss_helper.py
│   │   ├── losses.py
│   │   ├── matcher.py
│   │   ├── modules.py
│   │   ├── multi_head_attention.py
│   │   ├── pointnet.py
│   │   ├── transformer.py
│   │   └── vit.py
│   ├── pointnet2
│   │   ├── _ext_src
│   │   │   ├── include
│   │   │   │   ├── ball_query.h
│   │   │   │   ├── cuda_utils.h
│   │   │   │   ├── group_points.h
│   │   │   │   ├── interpolate.h
│   │   │   │   ├── sampling.h
│   │   │   │   └── utils.h
│   │   │   └── src
│   │   │       ├── ball_query.cpp
│   │   │       ├── ball_query_gpu.cu
│   │   │       ├── bindings.cpp
│   │   │       ├── group_points.cpp
│   │   │       ├── group_points_gpu.cu
│   │   │       ├── interpolate.cpp
│   │   │       ├── interpolate_gpu.cu
│   │   │       ├── sampling.cpp
│   │   │       └── sampling_gpu.cu
│   │   ├── pointnet2_modules.py
│   │   ├── pointnet2_test.py
│   │   ├── pointnet2_utils.py
│   │   ├── pytorch_utils.py
│   │   └── setup.py
│   ├── sunrgbd
│   │   ├── model_util_sunrgbd.py
│   │   ├── sunrgbd_data.py
│   │   ├── sunrgbd_detection_dataset.py
│   │   └── sunrgbd_utils.py
│   ├── train_dist.py
│   └── utils
│       ├── __init__.py
│       ├── box_ops.py
│       ├── box_util.py
│       ├── eval_det.py
│       ├── logger.py
│       ├── lr_scheduler.py
│       ├── metric_util.py
│       ├── misc.py
│       ├── nms.py
│       ├── nn_distance.py
│       ├── pc_util.py
│       └── visual.py
└── semantic_segmentation
    ├── config.py
    ├── data
    │   └── nyudv2
    │       ├── train.txt
    │       └── val.txt
    ├── main.py
    ├── models
    │   ├── __init__.py
    │   ├── mix_transformer.py
    │   ├── modules.py
    │   └── segformer.py
    └── utils
        ├── __init__.py
        ├── cmap.npy
        ├── datasets.py
        ├── helpers.py
        ├── load_state.txt
        ├── meter.py
        ├── optimizer.py
        └── transforms.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | __pycache__
3 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Yikai Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multimodal Token Fusion for Vision Transformers
2 |
3 | By Yikai Wang, Xinghao Chen, Lele Cao, Wenbing Huang, Fuchun Sun, Yunhe Wang.
4 |
5 | [**[Paper]**](https://arxiv.org/pdf/2204.08721.pdf)
6 |
7 | This repository is a PyTorch implementation of "Multimodal Token Fusion for Vision Transformers" (CVPR 2022).
8 |
9 |
10 | ![framework](figs/framework.png)
11 |
12 |
13 | Homogeneous predictions,
14 |
15 | ![homogeneous](figs/homogeneous.png)
16 |
17 |
18 | Heterogeneous predictions,
19 |
20 | ![heterogeneous](figs/heterogeneous.png)
21 |
22 |
23 |
24 | ## Datasets
25 |
26 | For the semantic segmentation task on NYUDv2 ([official dataset](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html)), we provide a link to download the preprocessed dataset [here](https://drive.google.com/drive/folders/1mXmOXVsd5l9-gYHk92Wpn6AcKAbE0m3X?usp=sharing). The provided dataset was originally preprocessed in this [repository](https://github.com/DrSleep/light-weight-refinenet), and we add depth data to it.
27 |
28 | For the image-to-image translation task, we use the sample dataset of [Taskonomy](http://taskonomy.stanford.edu/), which can be downloaded [here](https://github.com/alexsax/taskonomy-sample-model-1.git).
29 |
30 | Please modify the data paths in the code at the locations marked with the comment 'Modify data path'.
31 |
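For example, the Taskonomy root in `image2image_translation/eval.py` is set near the top of the script; the path below is only a placeholder and should point to your own copy of the sample dataset:

```python
# Directories for loading data and saving results
data_dir = '/path/to/taskonomy-sample-model-1'  # 'Modify data path'
```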
32 |
33 | ## Dependencies
34 | ```
35 | python==3.6
36 | pytorch==1.7.1
37 | torchvision==0.8.2
38 | numpy==1.19.2
39 | ```
40 |
41 |
42 | ## Semantic Segmentation
43 |
44 |
45 | First,
46 | ```
47 | cd semantic_segmentation
48 | ```
49 |
50 | Download a [SegFormer](https://github.com/NVlabs/SegFormer) backbone pretrained on ImageNet from [weights](https://drive.google.com/drive/folders/1b7bwrInTW4VLEm27YawHOAMSMikga2Ia), e.g., mit_b3.pth, and move it to the folder 'pretrained'.
51 |
52 | Training script for segmentation with RGB and depth inputs:
53 | ```
54 | python main.py --backbone mit_b3 -c exp_name --lamda 1e-6 --gpu 0 1 2
55 | ```
56 |
57 | Evaluation script:
58 | ```
59 | python main.py --gpu 0 --resume path_to_pth --evaluate # optionally use --save-img to visualize results
60 | ```
61 |
62 | Checkpoint models, training logs, mask ratios and the **single-scale** performance on NYUDv2 are provided as follows:
63 |
64 | | Method | Backbone | Pixel Acc. (%) | Mean Acc. (%) | Mean IoU (%) | Download |
65 | |:-----------:|:-----------:|:-----------:|:-----------:|:-----------:|:-----------:|
66 | |[CEN](https://github.com/yikaiw/CEN)| ResNet101 | 76.2 | 62.8 | 51.1 | [Google Drive](https://drive.google.com/drive/folders/1wim_cBG-HW0bdipwA1UbnGeDwjldPIwV?usp=sharing)|
67 | |[CEN](https://github.com/yikaiw/CEN)| ResNet152 | 77.0 | 64.4 | 51.6 | [Google Drive](https://drive.google.com/drive/folders/1DGF6vHLDgBgLrdUNJOLYdoXCuEKbIuRs?usp=sharing)|
68 | |Ours| SegFormer-B3 | 78.7 | 67.5 | 54.8 | [Google Drive](https://drive.google.com/drive/folders/14fi8aABFYqGF7LYKHkiJazHA58OBW1AW?usp=sharing)|
69 |
70 |
71 | A MindSpore implementation is available at: https://gitee.com/mindspore/models/tree/master/research/cv/TokenFusion
72 |
73 | ## Image-to-Image Translation
74 |
75 | First,
76 | ```
77 | cd image2image_translation
78 | ```
79 | Training script, translating from Shade and Texture to RGB:
80 | ```
81 | python main.py --gpu 0 -c exp_name
82 | ```
83 | This script automatically evaluates on the validation set every 5 training epochs.
84 |
85 | Predicted images will be automatically saved during training, in the following folder structure:
86 |
87 | ```
88 | code_root/ckpt/exp_name/results
89 | ├── input0 # 1st modality input
90 | ├── input1 # 2nd modality input
91 | ├── fake0 # 1st branch output
92 | ├── fake1 # 2nd branch output
93 | ├── fake2 # ensemble output
94 | ├── best # current best output
95 | │ ├── fake0
96 | │ ├── fake1
97 | │ └── fake2
98 | └── real # ground truth output
99 | ```
100 |
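The reported FID/KID come from `utils/fid_kid.py`, the same helper that `eval.py` calls. To re-score a saved results folder offline, a minimal sketch (run from `image2image_translation/`, with `exp_name` replaced by your experiment name):

```python
import os
from utils.fid_kid import calculate_given_paths

results = 'ckpt/exp_name/results'
# compare the ensemble outputs against the ground-truth images
paths = [os.path.join(results, 'fake2'), os.path.join(results, 'real')]
fid, kid = calculate_given_paths(paths, batch_size=50, cuda=True, dims=2048)
print('FID: %.3f  KID: %.3f' % (fid, kid))
```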
101 | Checkpoint models:
102 |
103 | | Method | Task | FID | KID | Download |
104 | |:-----------:|:-----------:|:-----------:|:-----------:|:-----------:|
105 | | [CEN](https://github.com/yikaiw/CEN) |Texture+Shade->RGB | 62.6 | 1.65 | - |
106 | | Ours | Texture+Shade->RGB | 45.5 | 1.00 | [Google Drive](https://drive.google.com/drive/folders/1vkcDv5bHKXZKxCg4dC7R56ts6nLLt6lh?usp=sharing)|
107 |
108 | ## 3D Object Detection (under construction)
109 |
110 | Data preparation, environments, and training scripts follow [Group-Free](https://github.com/zeliu98/Group-Free-3D) and [ImVoteNet](https://github.com/facebookresearch/imvotenet).
111 |
112 | E.g.,
113 | ```
114 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --master_port 2229 --nproc_per_node 4 train_dist.py --max_epoch 600 --val_freq 25 --save_freq 25 --lr_decay_epochs 420 480 540 --num_point 20000 --num_decoder_layers 6 --size_cls_agnostic --size_delta 0.0625 --heading_delta 0.04 --center_delta 0.1111111111111 --weight_decay 0.00000001 --query_points_generator_loss_coef 0.2 --obj_loss_coef 0.4 --dataset sunrgbd --data_root . --use_img --log_dir log/exp_name
115 | ```
116 |
117 | ## Citation
118 |
119 | If you find our work useful for your research, please consider citing the following paper.
120 | ```
121 | @inproceedings{wang2022tokenfusion,
122 | title={Multimodal Token Fusion for Vision Transformers},
123 | author={Wang, Yikai and Chen, Xinghao and Cao, Lele and Huang, Wenbing and Sun, Fuchun and Wang, Yunhe},
124 | booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
125 | year={2022}
126 | }
127 | ```
128 |
129 |
130 |
--------------------------------------------------------------------------------
/figs/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yikaiw/TokenFusion/3834ccf7765bb0bd50ea729069ad5adbd6de288d/figs/framework.png
--------------------------------------------------------------------------------
/figs/heterogeneous.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yikaiw/TokenFusion/3834ccf7765bb0bd50ea729069ad5adbd6de288d/figs/heterogeneous.png
--------------------------------------------------------------------------------
/figs/homogeneous.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yikaiw/TokenFusion/3834ccf7765bb0bd50ea729069ad5adbd6de288d/figs/homogeneous.png
--------------------------------------------------------------------------------
/image2image_translation/cfg.py:
--------------------------------------------------------------------------------
1 | num_parallel = None  # number of parallel modality branches, set by the entry script
2 | use_exchange = None  # whether to enable token/feature exchange between branches
3 | mask_threshold = None  # score threshold below which tokens are exchanged (see models/modules.py)
4 | lrnorm_threshold = None  # LayerNorm-scale threshold for channel exchange (see models/modules.py)
5 | logger = None  # log file handle, opened by the entry script
--------------------------------------------------------------------------------
/image2image_translation/eval.py:
--------------------------------------------------------------------------------
1 | import os, argparse
2 | import numpy as np
3 | import torch
4 | from tqdm import tqdm
5 | from torchvision import transforms
6 | from models.model_cfg import gen_b2, dis_b2
7 | import cfg
8 | from utils import *
9 |
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--batch-size', type=int, default=1,
12 | help='train batch size')
13 | parser.add_argument('--ngf', type=int, default=64)
14 | parser.add_argument('--ndf', type=int, default=64)
15 | parser.add_argument('--input-size', type=int, default=256,
16 | help='input size')
17 | parser.add_argument('--resize-scale', type=int, default=286,
18 | help='resize scale (0 is false)')
19 | parser.add_argument('--crop-size', type=int, default=256,
20 | help='crop size (0 is false)')
21 | parser.add_argument('--fliplr', type=bool, default=True,
22 | help='random horizontal flip, True or False')
23 | parser.add_argument('--num-epochs', type=int, default=300,
24 | help='number of train epochs')
25 | parser.add_argument('--val-every', type=int, default=5,
26 | help='how often to validate current architecture')
27 | parser.add_argument('--lrG', type=float, default=0.0002,
28 | help='learning rate for generator, default=0.0002')
29 | parser.add_argument('--lrD', type=float, default=0.0002,
30 | help='learning rate for discriminator, default=0.0002')
31 | parser.add_argument('--gama', type=float, default=100,
32 | help='gamma (weight) for the L1 loss')
33 | parser.add_argument('--beta1', type=float, default=0.5,
34 | help='beta1 for Adam optimizer')
35 | parser.add_argument('--beta2', type=float, default=0.999,
36 | help='beta2 for Adam optimizer')
37 | parser.add_argument('--print-loss', action='store_true', default=False,
38 | help='whether print losses during training')
39 | parser.add_argument('--gpu', type=int, nargs='+', default=[0],
40 | help='select gpu.')
41 | parser.add_argument('-c', '--ckpt', default='model', type=str, metavar='PATH',
42 | help='path to save checkpoint (default: model)')
43 | parser.add_argument('-i', '--img-types', default=[2, 7, 0], type=int, nargs='+',
44 | help='image types, last image is target, others are inputs')
45 | parser.add_argument('--exchange', type=int, default=1,
46 | help='whether use feature exchange')
47 | parser.add_argument('-l', '--lamda', type=float, default=1e-3,
48 | help='lamda for L1 norm on BN scales.')
49 | parser.add_argument('-t', '--insnorm-threshold', type=float, default=1e-2,
50 | help='threshold for slimming BNs')
51 | parser.add_argument('--enc', default=[0], type=int, nargs='+')
52 | parser.add_argument('--dec', default=[0], type=int, nargs='+')
53 | params = parser.parse_args()
54 |
55 | # Directories for loading data and saving results
56 | data_dir = '/mnt/beegfs/ssd_pool/docker/user/hadoop-automl/yikai/upload/taskonomy-sample-model-1' # 'Modify data path'
57 | # data_dir = '/data/wyk/datasets/taskonomy-sample-model-1'
58 | # data_dir = '/home1/wyk/data/taskonomy-sample-model-1'
59 | model_dir = os.path.join('ckpt', params.ckpt)
60 | save_dir = os.path.join(model_dir, 'results')
61 | save_dir_best = os.path.join(save_dir, 'best')
62 | os.makedirs(save_dir_best, exist_ok=True)
63 | os.makedirs(os.path.join(model_dir, 'insnorm_params'), exist_ok=True)
64 | os.system('cp -r *py models utils data %s' % model_dir)
65 | cfg.logger = open(os.path.join(model_dir, 'log.txt'), 'w+')
66 | print_log(params)
67 |
68 | train_file = './data/train_domain.txt'
69 | val_file = './data/val_domain.txt'
70 |
71 | domain_dicts = {0: 'rgb', 1: 'normal', 2: 'reshading', 3: 'depth_euclidean', 4: 'depth_zbuffer',
72 | 5: 'principal_curvature', 6: 'edge_occlusion', 7: 'edge_texture',
73 | 8: 'segment_unsup2d', 9: 'segment_unsup25d'}
74 | params.img_types = [domain_dicts[img_type] for img_type in params.img_types]
75 | print_log('\n' + ', '.join(params.img_types[:-1]) + ' -> ' + params.img_types[-1])
76 | num_parallel = len(params.img_types) - 1
77 |
78 | cfg.num_parallel = num_parallel
79 | cfg.use_exchange = params.exchange == 1
80 | cfg.insnorm_threshold = params.insnorm_threshold
81 | cfg.enc, cfg.dec = params.enc, params.dec
82 |
83 | # Data pre-processing
84 | transform = transforms.Compose([transforms.Resize(params.input_size),
85 | transforms.ToTensor(),
86 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
87 |
88 | # Train data
89 | train_data = DatasetFromFolder(data_dir, train_file, params.img_types, transform=transform,
90 | resize_scale=params.resize_scale, crop_size=params.crop_size,
91 | fliplr=params.fliplr)
92 | train_data_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=params.batch_size,
93 | shuffle=True, drop_last=False)
94 |
95 | # Test data
96 | test_data = DatasetFromFolder(data_dir, val_file, params.img_types, transform=transform)
97 | test_data_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=params.batch_size,
98 | shuffle=False, drop_last=False)
99 | # test_input, test_target = test_data_loader.__iter__().__next__()
100 |
101 | # Models
102 | torch.cuda.set_device(params.gpu[0])
103 | # G = Generator(3, params.ngf, 3)
104 | G = gen_b2()
105 | # D = Discriminator(6, params.ndf, 1)
106 | D = dis_b2(img_size=256, patch_size=4)
107 | G.cuda()
108 | G = torch.nn.DataParallel(G, params.gpu)
109 | D.cuda()
110 | D = torch.nn.DataParallel(D, params.gpu)
111 | state_dict = torch.load('./checkpoint-gen-399.pkl')
112 | G.load_state_dict(state_dict, strict=True)
113 |
114 | BCE_loss = torch.nn.BCELoss().cuda()
115 | L2_loss = torch.nn.MSELoss().cuda()
116 | L1_loss = torch.nn.L1Loss().cuda()
117 |
118 |
119 | def evaluate(G, epoch, training):
120 | num_parallel_ = 1 if num_parallel == 1 else num_parallel + 1
121 | l1_losses = init_lists(num_parallel_)
122 | l2_losses = init_lists(num_parallel_)
123 | fids = init_lists(num_parallel_)
124 | kids = init_lists(num_parallel_)
125 | for i, (test_inputs, test_target) in tqdm(enumerate(test_data_loader), miniters=25, total=len(test_data_loader)):
126 | # for i, (test_inputs, test_target) in enumerate(test_data_loader):
127 | # Show result for test image
128 | test_inputs_cuda = [test_input.cuda() for test_input in test_inputs]
129 | gen_images, alpha_soft, _ = G(test_inputs_cuda)
130 | test_target_cuda = test_target.cuda()
131 | for l, gen_image in enumerate(gen_images):
132 | if l < num_parallel or num_parallel > 1:
133 | l1_losses[l].append(L1_loss(gen_image, test_target_cuda).item())
134 | l2_losses[l].append(L2_loss(gen_image, test_target_cuda).item())
135 | gen_image = gen_image.cpu().data
136 | save_dir_ = os.path.join(save_dir, 'fake%d' % l)
137 | plot_test_result_single(gen_image, i, save_dir=save_dir_)
138 | if l < num_parallel:
139 | save_dir_ = os.path.join(save_dir, 'input%d' % l)
140 | if not os.path.exists(os.path.join(save_dir_, '%03d.png' % i)):
141 | plot_test_result_single(test_inputs[l], i, save_dir=save_dir_)
142 | save_dir_ = os.path.join(save_dir, 'real')
143 | if not os.path.exists(os.path.join(save_dir_, '%03d.png' % i)):
144 | plot_test_result_single(test_target, i, save_dir=save_dir_)
145 | # break
146 |
147 | for l in range(num_parallel_):
148 | paths = [os.path.join(save_dir, 'fake%d' % l), os.path.join(save_dir, 'real')]
149 | fid, kid = calculate_given_paths(paths, batch_size=50, cuda=True, dims=2048)
150 | fids[l], kids[l] = fid, kid
151 |
152 | l1_avg_losses = [torch.mean(torch.FloatTensor(l1_losses_)) for l1_losses_ in l1_losses]
153 | l2_avg_losses = [torch.mean(torch.FloatTensor(l2_losses_)) for l2_losses_ in l2_losses]
154 | return l1_avg_losses, l2_avg_losses, fids, kids
155 |
156 |
157 | l1_avg_losses, l2_avg_losses, fids, kids = evaluate(G, epoch=0, training=True)
158 | for l in range(len(l1_avg_losses)):
159 | l1_avg_loss, rl2_avg_loss = l1_avg_losses[l], l2_avg_losses[l]** 0.5
160 | fid, kid = fids[l], kids[l]
161 | if l < num_parallel:
162 | img_type_str = '(%s)' % params.img_types[l][:10]
163 | else:
164 | img_type_str = '(ens)'
165 | print_log('Epoch %3d %-15s l1_avg_loss: %.5f rl2_avg_loss: %.5f fid: %.3f kid: %.3f' % \
166 | (0, img_type_str, l1_avg_loss, rl2_avg_loss, fid, kid))
167 |
168 | cfg.logger.close()
169 |
--------------------------------------------------------------------------------
/image2image_translation/models/model_cfg.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .model_pruning import MixVisionTransformerGen, MixVisionTransformerDis
3 | from .modules import LayerNormParallel
4 | from functools import partial
5 |
6 |
7 | class gen_b0(MixVisionTransformerGen):
8 | def __init__(self, **kwargs):
9 | super(gen_b0, self).__init__(
10 | patch_size=4, embed_dims=[32, 64, 128, 256, 512], num_heads=[1, 2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4, 4],
11 | qkv_bias=True, depths=[2, 2, 2, 2, 2], sr_ratios=[16, 8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1)
12 |
13 |
14 | class gen_b1(MixVisionTransformerGen):
15 | def __init__(self, **kwargs):
16 | super(gen_b1, self).__init__(
17 | patch_size=4, embed_dims=[64, 128, 256, 512, 1024], num_heads=[1, 2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4, 4],
18 | qkv_bias=True, depths=[2, 2, 2, 2, 2], sr_ratios=[16, 8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1)
19 |
20 |
21 | class gen_b2(MixVisionTransformerGen):
22 | def __init__(self, **kwargs):
23 | super(gen_b2, self).__init__(
24 | patch_size=4, embed_dims=[64, 128, 256, 512, 1024], num_heads=[1, 2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4, 4],
25 | qkv_bias=True, depths=[3, 4, 4, 6, 3], sr_ratios=[16, 8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1)
26 |
27 |
28 | class gen_b3(MixVisionTransformerGen):
29 | def __init__(self, **kwargs):
30 | super(gen_b3, self).__init__(
31 | patch_size=4, embed_dims=[64, 128, 256, 512, 1024], num_heads=[1, 2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4, 4],
32 | qkv_bias=True, depths=[3, 4, 4, 18, 3], sr_ratios=[16, 8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1)
33 |
34 |
35 | class gen_b4(MixVisionTransformerGen):
36 | def __init__(self, **kwargs):
37 | super(gen_b4, self).__init__(
38 | patch_size=4, embed_dims=[64, 128, 256, 512, 1024], num_heads=[1, 2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4, 4],
39 | qkv_bias=True, depths=[3, 8, 8, 27, 3], sr_ratios=[16, 8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1)
40 |
41 |
42 | class gen_b5(MixVisionTransformerGen):
43 | def __init__(self, **kwargs):
44 | super(gen_b5, self).__init__(
45 | patch_size=4, embed_dims=[64, 128, 256, 512, 1024], num_heads=[1, 2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4, 4],
46 | qkv_bias=True, depths=[3, 6, 6, 40, 3], sr_ratios=[16, 8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1)
47 |
48 |
49 | class dis_b0(MixVisionTransformerDis):
50 | def __init__(self, **kwargs):
51 | super(dis_b0, self).__init__(
52 | patch_size=4, embed_dims=[32, 64, 128, 256, 512], num_heads=[1, 2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4, 4],
53 | qkv_bias=True, depths=[2, 2, 2, 2, 2], sr_ratios=[16, 8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1)
54 |
55 |
56 | class dis_b1(MixVisionTransformerDis):
57 | def __init__(self, **kwargs):
58 | super(dis_b1, self).__init__(
59 | patch_size=4, embed_dims=[64, 128, 256, 512, 1024], num_heads=[1, 2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4, 4],
60 | qkv_bias=True, depths=[2, 2, 2, 2, 2], sr_ratios=[16, 8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1)
61 |
62 |
63 | class dis_b2(MixVisionTransformerDis):
64 | def __init__(self, **kwargs):
65 | super(dis_b2, self).__init__(
66 | patch_size=4, embed_dims=[64, 128, 256, 512, 1024], num_heads=[1, 2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4, 4],
67 | qkv_bias=True, depths=[3, 4, 4, 6, 3], sr_ratios=[16, 8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1)
68 |
69 |
70 | class dis_b3(MixVisionTransformerDis):
71 | def __init__(self, **kwargs):
72 | super(dis_b3, self).__init__(
73 | patch_size=4, embed_dims=[64, 128, 256, 512, 1024], num_heads=[1, 2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4, 4],
74 | qkv_bias=True, depths=[3, 4, 4, 18, 3], sr_ratios=[16, 8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1)
75 |
76 |
77 | class dis_b4(MixVisionTransformerDis):
78 | def __init__(self, **kwargs):
79 | super(dis_b4, self).__init__(
80 | patch_size=4, embed_dims=[64, 128, 256, 512, 1024], num_heads=[1, 2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4, 4],
81 | qkv_bias=True, depths=[3, 8, 8, 27, 3], sr_ratios=[16, 8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1)
82 |
83 |
84 | class dis_b5(MixVisionTransformerDis):
85 | def __init__(self, **kwargs):
86 | super(dis_b5, self).__init__(
87 | patch_size=4, embed_dims=[64, 128, 256, 512, 1024], num_heads=[1, 2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4, 4],
88 | qkv_bias=True, depths=[3, 6, 6, 40, 3], sr_ratios=[16, 8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1)
89 |
--------------------------------------------------------------------------------
/image2image_translation/models/modules.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import cfg
3 | import torch
4 |
5 |
6 | class TokenExchange(nn.Module):
7 | def __init__(self):
8 | super(TokenExchange, self).__init__()
9 |
10 | def forward(self, x, mask, mask_threshold):
11 | # x: list of two token tensors, each [B, N, C]; mask: list of two per-token score tensors, each [B, N]
12 | x0, x1 = torch.zeros_like(x[0]), torch.zeros_like(x[1])
13 | x0[mask[0] >= mask_threshold] = x[0][mask[0] >= mask_threshold]
14 | x0[mask[0] < mask_threshold] = x[1][mask[0] < mask_threshold]
15 | x1[mask[1] >= mask_threshold] = x[1][mask[1] >= mask_threshold]
16 | x1[mask[1] < mask_threshold] = x[0][mask[1] < mask_threshold]
17 | return [x0, x1]
18 |
19 |
20 | class ChannelExchange(nn.Module):
21 | def __init__(self):
22 | super(ChannelExchange, self).__init__()
23 |
24 | def forward(self, x, lrnorm, lrnorm_threshold):
25 | lrnorm0, lrnorm1 = lrnorm[0].weight.abs(), lrnorm[1].weight.abs()
26 | x0, x1 = torch.zeros_like(x[0]), torch.zeros_like(x[1])
27 | x0[:, lrnorm0 >= lrnorm_threshold] = x[0][:, lrnorm0 >= lrnorm_threshold]
28 | x0[:, lrnorm0 < lrnorm_threshold] = x[1][:, lrnorm0 < lrnorm_threshold]
29 | x1[:, lrnorm1 >= lrnorm_threshold] = x[1][:, lrnorm1 >= lrnorm_threshold]
30 | x1[:, lrnorm1 < lrnorm_threshold] = x[0][:, lrnorm1 < lrnorm_threshold]
31 | return [x0, x1]
32 |
33 |
34 | class ModuleParallel(nn.Module):
35 | def __init__(self, module):
36 | super(ModuleParallel, self).__init__()
37 | self.module = module
38 |
39 | def forward(self, x_parallel):
40 | return [self.module(x) for x in x_parallel]
41 |
42 |
43 | class LayerNormParallel(nn.Module):
44 | def __init__(self, num_features):
45 | super(LayerNormParallel, self).__init__()
46 | for i in range(cfg.num_parallel):
47 | setattr(self, 'lrnorm_' + str(i), nn.LayerNorm(num_features, eps=1e-6))
48 |
49 | def forward(self, x_parallel):
50 | return [getattr(self, 'lrnorm_' + str(i))(x) for i, x in enumerate(x_parallel)]
51 |
--------------------------------------------------------------------------------
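To illustrate the token exchange defined in `modules.py` above, here is a minimal sketch (shapes and the threshold are arbitrary; assumed to be run from `image2image_translation/` so that `cfg` is importable):

```python
import torch
from models.modules import TokenExchange

B, N, C = 2, 16, 32
x = [torch.randn(B, N, C), torch.randn(B, N, C)]  # tokens of the two modality branches
score = [torch.rand(B, N), torch.rand(B, N)]      # per-token importance scores
out0, out1 = TokenExchange()(x, score, 0.5)
# Tokens whose score falls below 0.5 in one branch are replaced by the tokens
# at the same positions from the other branch; high-scoring tokens are kept.
```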
/image2image_translation/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset import *
2 | from .utils import *
3 | from .fid_kid import calculate_given_paths
4 |
--------------------------------------------------------------------------------
/image2image_translation/utils/dataset.py:
--------------------------------------------------------------------------------
1 | # Custom dataset
2 | from PIL import Image
3 | import torch.utils.data as data
4 | import numpy as np
5 | import os
6 | import random
7 |
8 | max_v = {'edge_texture': 11355, 'edge_occlusion': 11584, 'depth_euclidean': 11.1, 'depth_zbuffer': 11.1}
9 | min_v = {'edge_texture': 0, 'edge_occlusion': 0, 'depth_euclidean': 0, 'depth_zbuffer': 0}
10 |
11 |
12 | def line_to_path_fn(x, data_dir):
13 | path = x.decode('utf-8').strip('\n')
14 | return os.path.join(data_dir, path)
15 |
16 |
17 | class DatasetFromFolder(data.Dataset):
18 | def __init__(self, data_dir, data_file, img_types, transform=None,
19 | resize_scale=None, crop_size=None, fliplr=False, is_cls=False):
20 | super(DatasetFromFolder, self).__init__()
21 | with open(data_file, 'rb') as f:
22 | data_list = f.readlines()
23 | self.data_list = [line_to_path_fn(line, data_dir) for line in data_list]
24 | self.img_types = img_types
25 | self.transform = transform
26 | self.resize_scale = resize_scale
27 | self.crop_size = crop_size
28 | self.fliplr = fliplr
29 | self.is_cls = is_cls
30 |
31 | def __getitem__(self, index):
32 | # Load Image
33 | domain_path = self.data_list[index]
34 | if self.is_cls:
35 | img_types = self.img_types[:-1]
36 | cls_target = self.img_types[-1]
37 | else:
38 | img_types = self.img_types
39 | img_paths = [domain_path.replace('{domain}', img_type) for img_type in img_types]
40 | imgs = [Image.open(img_path) for img_path in img_paths]
41 |
42 | for l in range(len(imgs)):
43 | img = np.array(imgs[l])
44 | img_type = img_types[l]
45 | update = False
46 | if len(img.shape) == 2:
47 | img = img[:,:, np.newaxis]
48 | img = np.concatenate([img] * 3, 2)
49 | update = True
50 | if 'depth' in img_type:
51 | img = np.log(1 + img)
52 | update = True
53 | if img_type in max_v:
54 | img = (img - min_v[img_type]) * 255.0 / (max_v[img_type] - min_v[img_type])
55 | update = True
56 | if update:
57 | imgs[l] = Image.fromarray(img.astype('uint8'))
58 |
59 | if self.resize_scale:
60 | imgs = [img.resize((self.resize_scale, self.resize_scale), Image.BILINEAR) \
61 | for img in imgs]
62 | if self.crop_size:
63 | x = random.randint(0, self.resize_scale - self.crop_size + 1)
64 | y = random.randint(0, self.resize_scale - self.crop_size + 1)
65 | imgs = [img.crop((x, y, x + self.crop_size, y + self.crop_size)) for img in imgs]
66 | if self.fliplr:
67 | if random.random() < 0.5:
68 | imgs = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in imgs]
69 | if self.transform is not None:
70 | imgs = [self.transform(img) for img in imgs]
71 |
72 | if self.is_cls:
73 | inputs = imgs
74 | target = np.load(domain_path.replace('{domain}', cls_target).\
75 | replace('png', 'npy').replace('scene.npy', 'places.npy'))
76 | target = np.argmax(target)
77 | else:
78 | inputs, target = imgs[:-1], imgs[-1]
79 |
80 | return inputs, target
81 |
82 | def __len__(self):
83 | return len(self.data_list)
84 |
--------------------------------------------------------------------------------
/image2image_translation/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | plt.switch_backend('agg')
5 | import os
6 | import imageio
7 | import cfg
8 |
9 |
10 | def print_log(message):
11 | print(message, flush=True)
12 | if cfg.logger:
13 | cfg.logger.write(str(message) + '\n')
14 |
15 |
16 | # Plot losses
17 | def plot_loss(d_losses, g_losses, num_epochs, save=True, save_dir='results/', show=False):
18 | fig, ax = plt.subplots()
19 | ax_ = ax.twinx()
20 | ax.set_xlim(0, num_epochs)
21 | # ax.set_ylim(0, max(np.max(g_losses), np.max(d_losses)) * 1.1)
22 | plt.xlabel('# of Epochs')
23 | ax.set_ylabel('Generator loss values')
24 | ax_.set_ylabel('Discriminator loss values')
25 | ax.plot(g_losses, label='Generator')
26 | ax_.plot(d_losses, label='Discriminator')
27 | plt.legend()
28 |
29 | # save figure
30 | if save:
31 | if not os.path.exists(save_dir):
32 | os.mkdir(save_dir)
33 | save_fn = save_dir + 'Loss_values_epoch_{:d}'.format(num_epochs) + '.png'
34 | plt.savefig(save_fn)
35 |
36 | if show:
37 | plt.show()
38 | else:
39 | plt.close()
40 |
41 |
42 | def plot_test_result_single(image, image_idx, save_dir='results/', fig_size=(5, 5)):
43 | # fig_size = (target.size(2) / 100, target.size(3) / 100)
44 | # fig, ax = plt.subplots(figsize=fig_size)
45 | fig, ax = plt.subplots()
46 | img = image
47 |
48 | ax.axis('off')
49 | ax.set_adjustable('box')
50 | # Scale to 0-255
51 | img = (((img[0] - img[0].min()) * 255) / (img[0].max() - img[0].min()))\
52 | .numpy().transpose(1, 2, 0).astype(np.uint8)
53 | ax.imshow(img, cmap=None, aspect='equal')
54 | plt.subplots_adjust(wspace=0, hspace=0)
55 |
56 | save_path = save_dir
57 | os.makedirs(save_path, exist_ok=True)
58 | save_path = os.path.join(save_path, '%03d.png' % image_idx)
59 | fig.subplots_adjust(bottom=0)
60 | fig.subplots_adjust(top=1)
61 | fig.subplots_adjust(right=1)
62 | fig.subplots_adjust(left=0)
63 |
64 | foo_fig = plt.gcf()
65 | foo_fig.set_size_inches(5, 5)
66 | foo_fig.savefig(save_path, dpi=200, bbox_inches='tight')
67 | plt.savefig(save_path)
68 | plt.close()
69 |
70 |
71 | def plot_test_result(input, target, gen_image, image_idx, img_title, epoch, training=True,
72 | save=True, save_dir='results/', show=False, fig_size=(5, 5)):
73 | if input is not None:
74 | fig_size = (target.size(2) * 3 / 100, target.size(3) / 100)
75 | imgs = [input, gen_image, target]
76 | fig, axes = plt.subplots(1, 3, figsize=fig_size)
77 | else:
78 | fig_size = (target.size(2) * 2 / 100, target.size(3) / 100)
79 | imgs = [gen_image, target]
80 | fig, axes = plt.subplots(1, 2, figsize=fig_size)
81 |
82 | for ax, img in zip(axes.flatten(), imgs):
83 | ax.axis('off')
84 | ax.set_adjustable('box')
85 | # Scale to 0-255
86 | img = (((img[0] - img[0].min()) * 255) / (img[0].max() - img[0].min()))\
87 | .numpy().transpose(1, 2, 0).astype(np.uint8)
88 | ax.imshow(img, cmap=None, aspect='equal')
89 | plt.subplots_adjust(wspace=0, hspace=0)
90 |
91 | # save figure
92 | if save:
93 | # save_path = os.path.join(save_dir, str(image_idx))
94 | save_path = save_dir
95 | os.makedirs(save_path, exist_ok=True)
96 | if training:
97 | save_path = os.path.join(save_path, '%03d_%s.png' % (image_idx, img_title))
98 | else:
99 | save_path = os.path.join(save_path, 'Test_%03d_%s.png' % (image_idx, img_title))
100 | fig.subplots_adjust(bottom=0)
101 | fig.subplots_adjust(top=1)
102 | fig.subplots_adjust(right=1)
103 | fig.subplots_adjust(left=0)
104 | plt.savefig(save_path)
105 |
106 | if show:
107 | plt.show()
108 | else:
109 | plt.close()
110 |
111 |
112 | def maybe_download(model_name, model_url, model_dir=None, map_location=None):
113 | import os, sys
114 | from six.moves import urllib
115 | if model_dir is None:
116 | torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch'))
117 | model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models'))
118 | if not os.path.exists(model_dir):
119 | os.makedirs(model_dir)
120 | filename = '{}.pth.tar'.format(model_name)
121 | cached_file = os.path.join(model_dir, filename)
122 | if not os.path.exists(cached_file):
123 | url = model_url
124 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
125 | urllib.request.urlretrieve(url, cached_file)
126 | if '152' in cached_file:
127 | cached_file = '/home/anbang/.cache/torch/checkpoints/resnet152-b121ed2d.pth'
128 | return torch.load(cached_file, map_location=map_location)
129 |
130 |
131 | # Make gif
132 | def make_gif(dataset, num_epochs, save_dir='results/'):
133 | gen_image_plots = []
134 | for epoch in range(num_epochs):
135 | # plot for generating gif
136 | save_fn = save_dir + 'Result_epoch_{:d}'.format(epoch + 1) + '.png'
137 | gen_image_plots.append(imageio.imread(save_fn))
138 |
139 | imageio.mimsave(save_dir + dataset + '_pix2pix_epochs_{:d}'.format(num_epochs) \
140 | + '.gif', gen_image_plots, fps=5)
141 |
142 |
143 | def init_lists(length):
144 | lists = []
145 | for l in range(length):
146 | lists.append([])
147 | return lists
148 |
149 |
150 | class AverageMeter(object):
151 | """Computes and stores the average and current value
152 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
153 | """
154 | def __init__(self):
155 | self.reset()
156 |
157 | def reset(self):
158 | self.val = 0
159 | self.avg = 0
160 | self.sum = 0
161 | self.count = 0
162 |
163 | def update(self, val, n=1):
164 | self.val = val
165 | self.sum += val * n
166 | self.count += n
167 | self.avg = self.sum / self.count
168 |
169 |
170 | def accuracy(output, target, topk=(1,)):
171 | """Computes the precision@k for the specified values of k"""
172 | with torch.no_grad():
173 | maxk = max(topk)
174 | batch_size = target.size(0)
175 | _, pred = output.topk(maxk, 1, True, True)
176 | pred = pred.t()
177 | correct = pred.eq(target.view(1, -1).expand_as(pred))
178 | res = []
179 | for k in topk:
180 | correct_k = correct[:k].view(-1).float().sum(0)
181 | res.append(correct_k.mul_(100.0 / batch_size))
182 | return res
183 |
--------------------------------------------------------------------------------
/object-detection-3d/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .detector import GroupFreeDetector
2 | from .loss_helper import get_loss
3 | from .ap_helper import APCalculator, parse_predictions, parse_groundtruths
4 | from .dump_helper import dump_results
5 |
--------------------------------------------------------------------------------
/object-detection-3d/models/dump_helper.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy as np
7 | import torch
8 | import os
9 | import sys
10 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
11 | ROOT_DIR = os.path.dirname(BASE_DIR)
12 | sys.path.append(os.path.join(ROOT_DIR, 'utils'))
13 | import pc_util
14 |
15 | DUMP_CONF_THRESH = 0.5 # Dump boxes with obj prob larger than that.
16 |
17 | def softmax(x):
18 | ''' Numpy function for softmax'''
19 | shape = x.shape
20 | probs = np.exp(x - np.max(x, axis=len(shape)-1, keepdims=True))
21 | probs /= np.sum(probs, axis=len(shape)-1, keepdims=True)
22 | return probs
23 |
24 | def dump_results(end_points, dump_dir, config, inference_switch=False):
25 | ''' Dump results.
26 |
27 | Args:
28 | end_points: dict
29 | {..., pred_mask}
30 | pred_mask is a binary mask array of size (batch_size, num_proposal) computed by running NMS and empty box removal
31 | Returns:
32 | None
33 | '''
34 | if not os.path.exists(dump_dir):
35 | os.system('mkdir %s'%(dump_dir))
36 |
37 | # INPUT
38 | print(end_points.keys())
39 | point_clouds = end_points['point_clouds'].cpu().numpy()
40 | batch_size = point_clouds.shape[0]
41 |
42 | # NETWORK OUTPUTS
43 | seed_xyz = end_points['seed_xyz'].detach().cpu().numpy() # (B,num_seed,3)
44 | if 'vote_xyz' in end_points:
45 | aggregated_vote_xyz = end_points['aggregated_vote_xyz'].detach().cpu().numpy()
46 | vote_xyz = end_points['vote_xyz'].detach().cpu().numpy() # (B,num_seed,3)
47 | aggregated_vote_xyz = end_points['aggregated_vote_xyz'].detach().cpu().numpy()
48 | objectness_scores = end_points['objectness_scores'].detach().cpu().numpy() # (B,K,2)
49 | pred_center = end_points['center'].detach().cpu().numpy() # (B,K,3)
50 | pred_heading_class = torch.argmax(end_points['heading_scores'], -1) # B,num_proposal
51 | pred_heading_residual = torch.gather(end_points['heading_residuals'], 2, pred_heading_class.unsqueeze(-1)) # B,num_proposal,1
52 | pred_heading_class = pred_heading_class.detach().cpu().numpy() # B,num_proposal
53 | pred_heading_residual = pred_heading_residual.squeeze(2).detach().cpu().numpy() # B,num_proposal
54 | pred_size_class = torch.argmax(end_points['size_scores'], -1) # B,num_proposal
55 | pred_size_residual = torch.gather(end_points['size_residuals'], 2, pred_size_class.unsqueeze(-1).unsqueeze(-1).repeat(1,1,1,3)) # B,num_proposal,1,3
56 | pred_size_residual = pred_size_residual.squeeze(2).detach().cpu().numpy() # B,num_proposal,3
57 |
58 | # OTHERS
59 | pred_mask = end_points['pred_mask'] # B,num_proposal
60 | idx_beg = 0
61 |
62 | for i in range(batch_size):
63 | pc = point_clouds[i,:,:]
64 | objectness_prob = softmax(objectness_scores[i,:,:])[:,1] # (K,)
65 |
66 | # Dump various point clouds
67 | pc_util.write_ply(pc, os.path.join(dump_dir, '%06d_pc.ply'%(idx_beg+i)))
68 | pc_util.write_ply(seed_xyz[i,:,:], os.path.join(dump_dir, '%06d_seed_pc.ply'%(idx_beg+i)))
69 | if 'vote_xyz' in end_points:
70 | pc_util.write_ply(end_points['vote_xyz'][i,:,:], os.path.join(dump_dir, '%06d_vgen_pc.ply'%(idx_beg+i)))
71 | pc_util.write_ply(aggregated_vote_xyz[i,:,:], os.path.join(dump_dir, '%06d_aggregated_vote_pc.ply'%(idx_beg+i)))
72 | pc_util.write_ply(aggregated_vote_xyz[i,:,:], os.path.join(dump_dir, '%06d_aggregated_vote_pc.ply'%(idx_beg+i)))
73 | pc_util.write_ply(pred_center[i,:,0:3], os.path.join(dump_dir, '%06d_proposal_pc.ply'%(idx_beg+i)))
74 | if np.sum(objectness_prob>DUMP_CONF_THRESH)>0:
75 | pc_util.write_ply(pred_center[i,objectness_prob>DUMP_CONF_THRESH,0:3], os.path.join(dump_dir, '%06d_confident_proposal_pc.ply'%(idx_beg+i)))
76 |
77 | # Dump predicted bounding boxes
78 | if np.sum(objectness_prob>DUMP_CONF_THRESH)>0:
79 | num_proposal = pred_center.shape[1]
80 | obbs = []
81 | for j in range(num_proposal):
82 | obb = config.param2obb(pred_center[i,j,0:3], pred_heading_class[i,j], pred_heading_residual[i,j],
83 | pred_size_class[i,j], pred_size_residual[i,j])
84 | obbs.append(obb)
85 | if len(obbs)>0:
86 | obbs = np.vstack(tuple(obbs)) # (num_proposal, 7)
87 | pc_util.write_oriented_bbox(obbs[objectness_prob>DUMP_CONF_THRESH,:], os.path.join(dump_dir, '%06d_pred_confident_bbox.ply'%(idx_beg+i)))
88 | pc_util.write_oriented_bbox(obbs[np.logical_and(objectness_prob>DUMP_CONF_THRESH, pred_mask[i,:]==1),:], os.path.join(dump_dir, '%06d_pred_confident_nms_bbox.ply'%(idx_beg+i)))
89 | pc_util.write_oriented_bbox(obbs[pred_mask[i,:]==1,:], os.path.join(dump_dir, '%06d_pred_nms_bbox.ply'%(idx_beg+i)))
90 | pc_util.write_oriented_bbox(obbs, os.path.join(dump_dir, '%06d_pred_bbox.ply'%(idx_beg+i)))
91 |
92 | # Return if it is at inference time. No dumping of groundtruths
93 | if inference_switch:
94 | return
95 |
96 | # LABELS
97 | gt_center = end_points['center_label'].cpu().numpy() # (B,MAX_NUM_OBJ,3)
98 | gt_mask = end_points['box_label_mask'].cpu().numpy() # B,K2
99 | gt_heading_class = end_points['heading_class_label'].cpu().numpy() # B,K2
100 | gt_heading_residual = end_points['heading_residual_label'].cpu().numpy() # B,K2
101 | gt_size_class = end_points['size_class_label'].cpu().numpy() # B,K2
102 | gt_size_residual = end_points['size_residual_label'].cpu().numpy() # B,K2,3
103 | objectness_label = end_points['objectness_label'].detach().cpu().numpy() # (B,K,)
104 | objectness_mask = end_points['objectness_mask'].detach().cpu().numpy() # (B,K,)
105 |
106 | for i in range(batch_size):
107 | if np.sum(objectness_label[i,:])>0:
108 | pc_util.write_ply(pred_center[i,objectness_label[i,:]>0,0:3], os.path.join(dump_dir, '%06d_gt_positive_proposal_pc.ply'%(idx_beg+i)))
109 | if np.sum(objectness_mask[i,:])>0:
110 | pc_util.write_ply(pred_center[i,objectness_mask[i,:]>0,0:3], os.path.join(dump_dir, '%06d_gt_mask_proposal_pc.ply'%(idx_beg+i)))
111 | pc_util.write_ply(gt_center[i,:,0:3], os.path.join(dump_dir, '%06d_gt_centroid_pc.ply'%(idx_beg+i)))
112 | pc_util.write_ply_color(pred_center[i,:,0:3], objectness_label[i,:], os.path.join(dump_dir, '%06d_proposal_pc_objectness_label.obj'%(idx_beg+i)))
113 |
114 | # Dump GT bounding boxes
115 | obbs = []
116 | for j in range(gt_center.shape[1]):
117 | if gt_mask[i,j] == 0: continue
118 | obb = config.param2obb(gt_center[i,j,0:3], gt_heading_class[i,j], gt_heading_residual[i,j],
119 | gt_size_class[i,j], gt_size_residual[i,j])
120 | obbs.append(obb)
121 | if len(obbs)>0:
122 | obbs = np.vstack(tuple(obbs)) # (num_gt_objects, 7)
123 | pc_util.write_oriented_bbox(obbs, os.path.join(dump_dir, '%06d_gt_bbox.ply'%(idx_beg+i)))
124 |
125 | # OPTIONALLY, also dump prediction and gt details
126 | if 'batch_pred_map_cls' in end_points:
127 | for ii in range(batch_size):
128 | fout = open(os.path.join(dump_dir, '%06d_pred_map_cls.txt'%(ii)), 'w')
129 | for t in end_points['batch_pred_map_cls'][ii]:
130 | fout.write(str(t[0])+' ')
131 | fout.write(",".join([str(x) for x in list(t[1].flatten())]))
132 | fout.write(' '+str(t[2]))
133 | fout.write('\n')
134 | fout.close()
135 | if 'batch_gt_map_cls' in end_points:
136 | for ii in range(batch_size):
137 | fout = open(os.path.join(dump_dir, '%06d_gt_map_cls.txt'%(ii)), 'w')
138 | for t in end_points['batch_gt_map_cls'][ii]:
139 | fout.write(str(t[0])+' ')
140 | fout.write(",".join([str(x) for x in list(t[1].flatten())]))
141 | fout.write('\n')
142 | fout.close()
143 |
144 |
--------------------------------------------------------------------------------
/object-detection-3d/models/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .weight_init import trunc_normal_
2 | from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
3 | from .helper import to_ntuple, to_2tuple, to_3tuple, to_4tuple
4 |
--------------------------------------------------------------------------------
/object-detection-3d/models/layers/drop.py:
--------------------------------------------------------------------------------
1 | """ DropBlock, DropPath
2 |
3 | PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
4 |
5 | Papers:
6 | DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)
7 |
8 | Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)
9 |
10 | Code:
11 | DropBlock impl inspired by two Tensorflow impl that I liked:
12 | - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
13 | - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py
14 |
15 | Hacked together by / Copyright 2020 Ross Wightman
16 | """
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 |
21 |
22 | def drop_block_2d(
23 | x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0,
24 | with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
25 | """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
26 |
27 | DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
28 | runs with success, but needs further validation and possibly optimization for lower runtime impact.
29 | """
30 | B, C, H, W = x.shape
31 | total_size = W * H
32 | clipped_block_size = min(block_size, min(W, H))
33 | # seed_drop_rate, the gamma parameter
34 | gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
35 | (W - block_size + 1) * (H - block_size + 1))
36 |
37 | # Forces the block to be inside the feature map.
38 | w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device))
39 | valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
40 | ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
41 | valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)
42 |
43 | if batchwise:
44 | # one mask for whole batch, quite a bit faster
45 | uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
46 | else:
47 | uniform_noise = torch.rand_like(x)
48 | block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
49 | block_mask = -F.max_pool2d(
50 | -block_mask,
51 | kernel_size=clipped_block_size, # block_size,
52 | stride=1,
53 | padding=clipped_block_size // 2)
54 |
55 | if with_noise:
56 | normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
57 | if inplace:
58 | x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
59 | else:
60 | x = x * block_mask + normal_noise * (1 - block_mask)
61 | else:
62 | normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
63 | if inplace:
64 | x.mul_(block_mask * normalize_scale)
65 | else:
66 | x = x * block_mask * normalize_scale
67 | return x
68 |
69 |
70 | def drop_block_fast_2d(
71 | x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7,
72 | gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
73 | """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
74 |
75 | DropBlock with an experimental gaussian noise option. Simplified from above without concern for valid
76 | block mask at edges.
77 | """
78 | B, C, H, W = x.shape
79 | total_size = W * H
80 | clipped_block_size = min(block_size, min(W, H))
81 | gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
82 | (W - block_size + 1) * (H - block_size + 1))
83 |
84 | if batchwise:
85 | # one mask for whole batch, quite a bit faster
86 | block_mask = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) < gamma
87 | else:
88 | # mask per batch element
89 | block_mask = torch.rand_like(x) < gamma
90 | block_mask = F.max_pool2d(
91 | block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2)
92 |
93 | if with_noise:
94 | normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
95 | if inplace:
96 | x.mul_(1. - block_mask).add_(normal_noise * block_mask)
97 | else:
98 | x = x * (1. - block_mask) + normal_noise * block_mask
99 | else:
100 | block_mask = 1 - block_mask
101 | normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(dtype=x.dtype)
102 | if inplace:
103 | x.mul_(block_mask * normalize_scale)
104 | else:
105 | x = x * block_mask * normalize_scale
106 | return x
107 |
108 |
109 | class DropBlock2d(nn.Module):
110 | """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
111 | """
112 | def __init__(self,
113 | drop_prob=0.1,
114 | block_size=7,
115 | gamma_scale=1.0,
116 | with_noise=False,
117 | inplace=False,
118 | batchwise=False,
119 | fast=True):
120 | super(DropBlock2d, self).__init__()
121 | self.drop_prob = drop_prob
122 | self.gamma_scale = gamma_scale
123 | self.block_size = block_size
124 | self.with_noise = with_noise
125 | self.inplace = inplace
126 | self.batchwise = batchwise
127 | self.fast = fast # FIXME finish comparisons of fast vs not
128 |
129 | def forward(self, x):
130 | if not self.training or not self.drop_prob:
131 | return x
132 | if self.fast:
133 | return drop_block_fast_2d(
134 | x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
135 | else:
136 | return drop_block_2d(
137 | x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
138 |
139 |
140 | def drop_path(x, drop_prob: float = 0., training: bool = False):
141 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
142 |
143 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
144 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
145 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
146 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
147 | 'survival rate' as the argument.
148 |
149 | """
150 | if drop_prob == 0. or not training:
151 | return x
152 | keep_prob = 1 - drop_prob
153 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
154 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
155 | random_tensor.floor_() # binarize
156 | output = x.div(keep_prob) * random_tensor
157 | return output
158 |
159 |
160 | class DropPath(nn.Module):
161 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
162 | """
163 | def __init__(self, drop_prob=None):
164 | super(DropPath, self).__init__()
165 | self.drop_prob = drop_prob
166 |
167 | def forward(self, x):
168 | return drop_path(x, self.drop_prob, self.training)
169 |
--------------------------------------------------------------------------------
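As a quick check of the stochastic-depth behavior documented above, a minimal sketch (assuming `object-detection-3d/models` is on the Python path so the `layers` package imports without pulling in the full detector):

```python
import torch
from layers.drop import DropPath

dp = DropPath(drop_prob=0.2)
x = torch.randn(4, 16, 32)

dp.train()
y = dp(x)  # each sample is either zeroed out or rescaled by 1 / (1 - drop_prob)

dp.eval()
assert torch.equal(dp(x), x)  # identity at inference time
```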
/object-detection-3d/models/layers/helper.py:
--------------------------------------------------------------------------------
1 | """ Layer/Module Helpers
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | from itertools import repeat
6 | from torch._six import container_abcs
7 |
8 |
9 | # From PyTorch internals
10 | def _ntuple(n):
11 | def parse(x):
12 | if isinstance(x, container_abcs.Iterable):
13 | return x
14 | return tuple(repeat(x, n))
15 | return parse
16 |
17 |
18 | to_1tuple = _ntuple(1)
19 | to_2tuple = _ntuple(2)
20 | to_3tuple = _ntuple(3)
21 | to_4tuple = _ntuple(4)
22 | to_ntuple = _ntuple
23 |
24 |
25 |
26 |
27 |
28 |
--------------------------------------------------------------------------------
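For reference, these helpers expand a scalar into an n-tuple and pass iterables through unchanged (same path assumption as the sketch above):

```python
from layers.helper import to_2tuple, to_3tuple

print(to_2tuple(7))       # (7, 7)
print(to_3tuple(7))       # (7, 7, 7)
print(to_2tuple((3, 5)))  # (3, 5) -- iterables are returned as-is
```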
/object-detection-3d/models/layers/weight_init.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import warnings
4 |
5 |
6 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
7 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
8 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
9 | def norm_cdf(x):
10 | # Computes standard normal cumulative distribution function
11 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
12 |
13 | if (mean < a - 2 * std) or (mean > b + 2 * std):
14 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
15 | "The distribution of values may be incorrect.",
16 | stacklevel=2)
17 |
18 | with torch.no_grad():
19 | # Values are generated by using a truncated uniform distribution and
20 | # then using the inverse CDF for the normal distribution.
21 | # Get upper and lower cdf values
22 | l = norm_cdf((a - mean) / std)
23 | u = norm_cdf((b - mean) / std)
24 |
25 | # Uniformly fill tensor with values from [l, u], then translate to
26 | # [2l-1, 2u-1].
27 | tensor.uniform_(2 * l - 1, 2 * u - 1)
28 |
29 | # Use inverse cdf transform for normal distribution to get truncated
30 | # standard normal
31 | tensor.erfinv_()
32 |
33 | # Transform to proper mean, std
34 | tensor.mul_(std * math.sqrt(2.))
35 | tensor.add_(mean)
36 |
37 | # Clamp to ensure it's in the proper range
38 | tensor.clamp_(min=a, max=b)
39 | return tensor
40 |
41 |
42 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
43 | # type: (Tensor, float, float, float, float) -> Tensor
44 | r"""Fills the input Tensor with values drawn from a truncated
45 | normal distribution. The values are effectively drawn from the
46 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
47 | with values outside :math:`[a, b]` redrawn until they are within
48 | the bounds. The method used for generating the random values works
49 | best when :math:`a \leq \text{mean} \leq b`.
50 | Args:
51 | tensor: an n-dimensional `torch.Tensor`
52 | mean: the mean of the normal distribution
53 | std: the standard deviation of the normal distribution
54 | a: the minimum cutoff value
55 | b: the maximum cutoff value
56 | Examples:
57 | >>> w = torch.empty(3, 5)
58 | >>> nn.init.trunc_normal_(w)
59 | """
60 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
61 |
--------------------------------------------------------------------------------
/object-detection-3d/models/matcher.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | Modules to compute the matching cost and solve the corresponding LSAP.
4 | """
5 | import torch
6 | from scipy.optimize import linear_sum_assignment
7 | from torch import nn
8 |
9 | from utils.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
10 |
11 |
12 | class HungarianMatcher(nn.Module):
13 | """This class computes an assignment between the targets and the predictions of the network
14 |
15 | For efficiency reasons, the targets don't include the no_object. Because of this, in general,
16 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
17 | while the others are un-matched (and thus treated as non-objects).
18 | """
19 |
20 | def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1):
21 | """Creates the matcher
22 |
23 | Params:
24 | cost_class: This is the relative weight of the classification error in the matching cost
25 | cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
26 | cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
27 | """
28 | super().__init__()
29 | self.cost_class = cost_class
30 | self.cost_bbox = cost_bbox
31 | self.cost_giou = cost_giou
32 | assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs can't be 0"
33 |
34 | @torch.no_grad()
35 | def forward(self, outputs, targets):
36 | """ Performs the matching
37 |
38 | Params:
39 | outputs: This is a dict that contains at least these entries:
40 | "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
41 | "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
42 |
43 | targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
44 | "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
45 | objects in the target) containing the class labels
46 | "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
47 |
48 | Returns:
49 | A list of size batch_size, containing tuples of (index_i, index_j) where:
50 | - index_i is the indices of the selected predictions (in order)
51 | - index_j is the indices of the corresponding selected targets (in order)
52 | For each batch element, it holds:
53 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
54 | """
55 | bs, num_queries = outputs["pred_logits2d"].shape[:2]
56 |
57 | # We flatten to compute the cost matrices in a batch
58 | out_prob = outputs["pred_logits2d"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes]
59 | out_bbox = outputs["pred_boxes2d"].flatten(0, 1) # [batch_size * num_queries, 4]
60 |
61 | # Also concat the target labels and boxes
62 | tgt_ids = torch.cat([v["labels"] for v in targets])
63 | tgt_bbox = torch.cat([v["boxes"] for v in targets])
64 |
65 | # Compute the classification cost. Contrary to the loss, we don't use the NLL,
66 | # but approximate it in 1 - proba[target class].
67 |         # The 1 is a constant that doesn't change the matching, it can be omitted.
68 | cost_class = -out_prob[:, tgt_ids]
69 |
70 | # Compute the L1 cost between boxes
71 | cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
72 |
73 |         # Compute the giou cost between boxes
74 | cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
75 |
76 | # Final cost matrix
77 | C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
78 | C = C.view(bs, num_queries, -1).cpu()
79 |
80 | sizes = [len(v["boxes"]) for v in targets]
81 | indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
82 | return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
83 |
84 |
85 | def build_matcher(weight_dict):
86 | return HungarianMatcher(cost_class=weight_dict['loss_ce'], cost_bbox=weight_dict['loss_bbox'],
87 | cost_giou=weight_dict['loss_giou'])
88 |
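A minimal usage sketch for the matcher above (illustrative only, not part of the repository; the import path and cost weights are assumptions). It shows the input structure forward() expects, including the "pred_logits2d"/"pred_boxes2d" keys:

    import torch
    from models.matcher import HungarianMatcher  # assumed import path

    matcher = HungarianMatcher(cost_class=1.0, cost_bbox=5.0, cost_giou=2.0)
    outputs = {
        "pred_logits2d": torch.randn(2, 16, 11),     # [batch, num_queries, num_classes]
        "pred_boxes2d": torch.rand(2, 16, 4) * 0.5,  # normalized (cx, cy, w, h)
    }
    targets = [
        {"labels": torch.tensor([3, 7]), "boxes": torch.rand(2, 4) * 0.5},
        {"labels": torch.tensor([1]), "boxes": torch.rand(1, 4) * 0.5},
    ]
    indices = matcher(outputs, targets)  # one (pred_idx, tgt_idx) pair per batch element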
--------------------------------------------------------------------------------
/object-detection-3d/models/pointnet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import torch
7 | import torch.nn as nn
8 | import sys
9 | import os
10 |
11 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
12 | ROOT_DIR = os.path.dirname(BASE_DIR)
13 | sys.path.append(ROOT_DIR)
14 | sys.path.append(os.path.join(ROOT_DIR, 'utils'))
15 | sys.path.append(os.path.join(ROOT_DIR, 'pointnet2'))
16 | sys.path.append(os.path.join(ROOT_DIR, 'ops', 'pt_custom_ops'))
17 |
18 | from pointnet2_modules import PointnetSAModuleVotes, PointnetFPModule
19 |
20 |
21 | class Pointnet2Backbone(nn.Module):
22 | r"""
23 | Backbone network for point cloud feature learning.
24 | Based on Pointnet++ single-scale grouping network.
25 |
26 | Parameters
27 | ----------
28 | input_feature_dim: int
29 | Number of input channels in the feature descriptor for each point.
30 | e.g. 3 for RGB.
31 | """
32 |
33 | def __init__(self, input_feature_dim=0, width=1, depth=2):
34 | super().__init__()
35 | self.depth = depth
36 | self.width = width
37 |
38 | self.sa1 = PointnetSAModuleVotes(
39 | npoint=2048,
40 | radius=0.2,
41 | nsample=64,
42 | mlp=[input_feature_dim] + [64 * width for i in range(depth)] + [128 * width],
43 | use_xyz=True,
44 | normalize_xyz=True
45 | )
46 |
47 | self.sa2 = PointnetSAModuleVotes(
48 | npoint=1024,
49 | radius=0.4,
50 | nsample=32,
51 | mlp=[128 * width] + [128 * width for i in range(depth)] + [256 * width],
52 | use_xyz=True,
53 | normalize_xyz=True
54 | )
55 |
56 | self.sa3 = PointnetSAModuleVotes(
57 | npoint=512,
58 | radius=0.8,
59 | nsample=16,
60 | mlp=[256 * width] + [128 * width for i in range(depth)] + [256 * width],
61 | use_xyz=True,
62 | normalize_xyz=True
63 | )
64 |
65 | self.sa4 = PointnetSAModuleVotes(
66 | npoint=256,
67 | radius=1.2,
68 | nsample=16,
69 | mlp=[256 * width] + [128 * width for i in range(depth)] + [256 * width],
70 | use_xyz=True,
71 | normalize_xyz=True
72 | )
73 |
74 | self.fp1 = PointnetFPModule(mlp=[256 * width + 256 * width, 256 * width, 256 * width])
75 | self.fp2 = PointnetFPModule(mlp=[256 * width + 256 * width, 256 * width, 288])
76 |
77 | def _break_up_pc(self, pc):
78 | xyz = pc[..., 0:3].contiguous()
79 | features = (
80 | pc[..., 3:].transpose(1, 2).contiguous()
81 | if pc.size(-1) > 3 else None
82 | )
83 |
84 | return xyz, features
85 |
86 | def forward(self, pointcloud: torch.cuda.FloatTensor, end_points=None):
87 | r"""
88 | Forward pass of the network
89 |
90 | Parameters
91 | ----------
92 | pointcloud: Variable(torch.cuda.FloatTensor)
93 | (B, N, 3 + input_feature_dim) tensor
94 |             Point cloud to run predictions on
95 |             Each point in the point-cloud MUST
96 |             be formatted as (x, y, z, features...)
97 |
98 | Returns
99 | ----------
100 | end_points: {XXX_xyz, XXX_features, XXX_inds}
101 | XXX_xyz: float32 Tensor of shape (B,K,3)
102 | XXX_features: float32 Tensor of shape (B,K,D)
103 |             XXX_inds: int64 Tensor of shape (B,K) values in [0,N-1]
104 | """
105 | if not end_points: end_points = {}
106 | batch_size = pointcloud.shape[0]
107 |
108 | xyz, features = self._break_up_pc(pointcloud)
109 |
110 | # --------- 4 SET ABSTRACTION LAYERS ---------
111 | xyz, features, fps_inds = self.sa1(xyz, features)
112 | end_points['sa1_inds'] = fps_inds
113 | end_points['sa1_xyz'] = xyz
114 | end_points['sa1_features'] = features
115 |
116 | xyz, features, fps_inds = self.sa2(xyz, features) # this fps_inds is just 0,1,...,1023
117 | end_points['sa2_inds'] = fps_inds
118 | end_points['sa2_xyz'] = xyz
119 | end_points['sa2_features'] = features
120 |
121 | xyz, features, fps_inds = self.sa3(xyz, features) # this fps_inds is just 0,1,...,511
122 | end_points['sa3_xyz'] = xyz
123 | end_points['sa3_features'] = features
124 |
125 | xyz, features, fps_inds = self.sa4(xyz, features) # this fps_inds is just 0,1,...,255
126 | end_points['sa4_xyz'] = xyz
127 | end_points['sa4_features'] = features
128 |
129 | # --------- 2 FEATURE UPSAMPLING LAYERS --------
130 | features = self.fp1(end_points['sa3_xyz'], end_points['sa4_xyz'], end_points['sa3_features'],
131 | end_points['sa4_features'])
132 | features = self.fp2(end_points['sa2_xyz'], end_points['sa3_xyz'], end_points['sa2_features'], features)
133 | end_points['fp2_features'] = features
134 | end_points['fp2_xyz'] = end_points['sa2_xyz']
135 | num_seed = end_points['fp2_xyz'].shape[1]
136 | end_points['fp2_inds'] = end_points['sa1_inds'][:, 0:num_seed] # indices among the entire input point clouds
137 |
138 | return end_points
139 |
140 |
141 | if __name__ == '__main__':
142 | backbone_net = Pointnet2Backbone(input_feature_dim=3).cuda()
143 | print(backbone_net)
144 | backbone_net.eval()
145 | out = backbone_net(torch.rand(16, 20000, 6).cuda())
146 | for key in sorted(out.keys()):
147 | print(key, '\t', out[key].shape)
148 |
--------------------------------------------------------------------------------
/object-detection-3d/models/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | import torch.nn.functional as F
5 | from torch import Tensor
6 | from typing import Optional
7 | from multi_head_attention import MultiheadAttention
8 |
9 |
10 | class TransformerDecoderLayer(nn.Module):
11 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
12 | self_posembed=None, cross_posembed=None):
13 | super().__init__()
14 | self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
15 | self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
16 | # Implementation of Feedforward model
17 | self.linear1 = nn.Linear(d_model, dim_feedforward)
18 | self.dropout = nn.Dropout(dropout)
19 | self.linear2 = nn.Linear(dim_feedforward, d_model)
20 |
21 | self.norm1 = nn.LayerNorm(d_model)
22 | self.norm2 = nn.LayerNorm(d_model)
23 | self.norm3 = nn.LayerNorm(d_model)
24 | self.dropout1 = nn.Dropout(dropout)
25 | self.dropout2 = nn.Dropout(dropout)
26 | self.dropout3 = nn.Dropout(dropout)
27 |
28 | self.activation = _get_activation_fn(activation)
29 |
30 | self.self_posembed = self_posembed
31 | self.cross_posembed = cross_posembed
32 |
33 | def with_pos_embed(self, tensor, pos_embed: Optional[Tensor]):
34 | return tensor if pos_embed is None else tensor + pos_embed
35 |
36 | def forward(self, query, key, query_pos, key_pos):
37 | """
38 | :param query: B C Pq
39 | :param key: B C Pk
40 | :param query_pos: B Pq 3/6
41 | :param key_pos: B Pk 3/6
42 | :param value_pos: [B Pq 3/6]
43 |
44 | :return:
45 | """
46 | # NxCxP to PxNxC
47 | if self.self_posembed is not None:
48 | query_pos_embed = self.self_posembed(query_pos).permute(2, 0, 1)
49 | else:
50 | query_pos_embed = None
51 | if self.cross_posembed is not None:
52 | key_pos_embed = self.cross_posembed(key_pos).permute(2, 0, 1)
53 | else:
54 | key_pos_embed = None
55 |
56 | query = query.permute(2, 0, 1)
57 | key = key.permute(2, 0, 1)
58 |
59 | q = k = v = self.with_pos_embed(query, query_pos_embed)
60 | query2 = self.self_attn(q, k, value=v)[0]
61 | query = query + self.dropout1(query2)
62 | query = self.norm1(query)
63 |
64 | query2 = self.multihead_attn(query=self.with_pos_embed(query, query_pos_embed),
65 | key=self.with_pos_embed(key, key_pos_embed),
66 | value=self.with_pos_embed(key, key_pos_embed))[0]
67 | query = query + self.dropout2(query2)
68 | query = self.norm2(query)
69 |
70 | query2 = self.linear2(self.dropout(self.activation(self.linear1(query))))
71 | query = query + self.dropout3(query2)
72 | query = self.norm3(query)
73 |
74 |         # PxNxC back to NxCxP
75 | query = query.permute(1, 2, 0)
76 | return query
77 |
78 |
79 | def _get_activation_fn(activation):
80 | """Return an activation function given a string"""
81 | if activation == "relu":
82 | return F.relu
83 | if activation == "gelu":
84 | return F.gelu
85 | if activation == "glu":
86 | return F.glu
87 |     raise RuntimeError(F"activation should be relu/gelu/glu, not {activation}.")
88 |
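A quick shape check for TransformerDecoderLayer (illustrative sketch; the import path is an assumption and the position-embedding modules are left as None so the layer runs on raw features):

    import torch
    from models.transformer import TransformerDecoderLayer  # assumed import path

    layer = TransformerDecoderLayer(d_model=288, nhead=8, dim_feedforward=512,
                                    self_posembed=None, cross_posembed=None)
    query = torch.randn(2, 288, 256)   # B x C x Pq (object candidates)
    key = torch.randn(2, 288, 1024)    # B x C x Pk (scene points)
    out = layer(query, key, query_pos=None, key_pos=None)
    print(out.shape)                   # torch.Size([2, 288, 256])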
--------------------------------------------------------------------------------
/object-detection-3d/pointnet2/_ext_src/include/ball_query.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #pragma once
7 | #include <torch/extension.h>
8 |
9 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius,
10 | const int nsample);
11 |
--------------------------------------------------------------------------------
/object-detection-3d/pointnet2/_ext_src/include/cuda_utils.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #ifndef _CUDA_UTILS_H
7 | #define _CUDA_UTILS_H
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/cuda/CUDAContext.h>
11 | #include <cmath>
12 |
13 | #include <cuda.h>
14 | #include <cuda_runtime.h>
15 |
16 | #include <vector>
17 |
18 | #define TOTAL_THREADS 512
19 |
20 | inline int opt_n_threads(int work_size) {
21 |   const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
22 |
23 | return max(min(1 << pow_2, TOTAL_THREADS), 1);
24 | }
25 |
26 | inline dim3 opt_block_config(int x, int y) {
27 | const int x_threads = opt_n_threads(x);
28 | const int y_threads =
29 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1);
30 | dim3 block_config(x_threads, y_threads, 1);
31 |
32 | return block_config;
33 | }
34 |
35 | #define CUDA_CHECK_ERRORS() \
36 | do { \
37 | cudaError_t err = cudaGetLastError(); \
38 | if (cudaSuccess != err) { \
39 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \
40 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \
41 | __FILE__); \
42 | exit(-1); \
43 | } \
44 | } while (0)
45 |
46 | #endif
47 |
--------------------------------------------------------------------------------
/object-detection-3d/pointnet2/_ext_src/include/group_points.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #pragma once
7 | #include <torch/extension.h>
8 |
9 | at::Tensor group_points(at::Tensor points, at::Tensor idx);
10 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n);
11 |
--------------------------------------------------------------------------------
/object-detection-3d/pointnet2/_ext_src/include/interpolate.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #pragma once
7 |
8 | #include <torch/extension.h>
9 | #include <vector>
10 |
11 | std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows);
12 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx,
13 | at::Tensor weight);
14 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx,
15 | at::Tensor weight, const int m);
16 |
--------------------------------------------------------------------------------
/object-detection-3d/pointnet2/_ext_src/include/sampling.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #pragma once
7 | #include <torch/extension.h>
8 |
9 | at::Tensor gather_points(at::Tensor points, at::Tensor idx);
10 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n);
11 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples);
12 |
--------------------------------------------------------------------------------
/object-detection-3d/pointnet2/_ext_src/include/utils.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #pragma once
7 | #include <ATen/cuda/CUDAContext.h>
8 | #include <torch/extension.h>
9 |
10 | #define CHECK_CUDA(x) \
11 | do { \
12 | TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor"); \
13 | } while (0)
14 |
15 | #define CHECK_CONTIGUOUS(x) \
16 | do { \
17 | TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor"); \
18 | } while (0)
19 |
20 | #define CHECK_IS_INT(x) \
21 | do { \
22 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \
23 | #x " must be an int tensor"); \
24 | } while (0)
25 |
26 | #define CHECK_IS_FLOAT(x) \
27 | do { \
28 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, \
29 | #x " must be a float tensor"); \
30 | } while (0)
31 |
--------------------------------------------------------------------------------
/object-detection-3d/pointnet2/_ext_src/src/ball_query.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #include "ball_query.h"
7 | #include "utils.h"
8 |
9 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
10 | int nsample, const float *new_xyz,
11 | const float *xyz, int *idx);
12 |
13 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius,
14 | const int nsample) {
15 | CHECK_CONTIGUOUS(new_xyz);
16 | CHECK_CONTIGUOUS(xyz);
17 | CHECK_IS_FLOAT(new_xyz);
18 | CHECK_IS_FLOAT(xyz);
19 |
20 | if (new_xyz.device().is_cuda()) {
21 | CHECK_CUDA(xyz);
22 | }
23 |
24 | at::Tensor idx =
25 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample},
26 | at::device(new_xyz.device()).dtype(at::ScalarType::Int));
27 |
28 | if (new_xyz.device().is_cuda()) {
29 | query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1),
30 |                                     radius, nsample, new_xyz.data_ptr<float>(),
31 |                                     xyz.data_ptr<float>(), idx.data_ptr<int>());
32 | } else {
33 | TORCH_CHECK(false, "CPU not supported");
34 | }
35 |
36 | return idx;
37 | }
38 |
--------------------------------------------------------------------------------
/object-detection-3d/pointnet2/_ext_src/src/ball_query_gpu.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #include <math.h>
7 | #include <stdio.h>
8 | #include <stdlib.h>
9 |
10 | #include "cuda_utils.h"
11 |
12 | // input: new_xyz(b, m, 3) xyz(b, n, 3)
13 | // output: idx(b, m, nsample)
14 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius,
15 | int nsample,
16 | const float *__restrict__ new_xyz,
17 | const float *__restrict__ xyz,
18 | int *__restrict__ idx) {
19 | int batch_index = blockIdx.x;
20 | xyz += batch_index * n * 3;
21 | new_xyz += batch_index * m * 3;
22 | idx += m * nsample * batch_index;
23 |
24 | int index = threadIdx.x;
25 | int stride = blockDim.x;
26 |
27 | float radius2 = radius * radius;
28 | for (int j = index; j < m; j += stride) {
29 | float new_x = new_xyz[j * 3 + 0];
30 | float new_y = new_xyz[j * 3 + 1];
31 | float new_z = new_xyz[j * 3 + 2];
32 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) {
33 | float x = xyz[k * 3 + 0];
34 | float y = xyz[k * 3 + 1];
35 | float z = xyz[k * 3 + 2];
36 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
37 | (new_z - z) * (new_z - z);
38 | if (d2 < radius2) {
39 | if (cnt == 0) {
40 | for (int l = 0; l < nsample; ++l) {
41 | idx[j * nsample + l] = k;
42 | }
43 | }
44 | idx[j * nsample + cnt] = k;
45 | ++cnt;
46 | }
47 | }
48 | }
49 | }
50 |
51 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
52 | int nsample, const float *new_xyz,
53 | const float *xyz, int *idx) {
54 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
55 |   query_ball_point_kernel<<<b, opt_n_threads(m), 0, stream>>>(
56 | b, n, m, radius, nsample, new_xyz, xyz, idx);
57 |
58 | CUDA_CHECK_ERRORS();
59 | }
60 |
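For reference, the semantics of the ball-query kernel above written in plain PyTorch (an illustrative sketch for clarity, not code used by the extension):

    import torch

    def ball_query_reference(new_xyz, xyz, radius, nsample):
        # new_xyz: (B, M, 3) query centers, xyz: (B, N, 3) candidate points.
        # Like the kernel: take the first `nsample` points within `radius` of each
        # center (in index order) and pad the remaining slots with the first hit.
        B, M, _ = new_xyz.shape
        d2 = torch.cdist(new_xyz, xyz) ** 2      # (B, M, N) squared distances
        idx = torch.zeros(B, M, nsample, dtype=torch.long)
        for b in range(B):
            for j in range(M):
                hits = torch.nonzero(d2[b, j] < radius ** 2).flatten()
                if hits.numel() == 0:
                    continue                      # kernel leaves the zero-filled row
                idx[b, j, :] = hits[0]            # pad every slot with the first hit
                take = hits[:nsample]
                idx[b, j, :take.numel()] = take   # overwrite with the actual hits
        return idx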
--------------------------------------------------------------------------------
/object-detection-3d/pointnet2/_ext_src/src/bindings.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #include "ball_query.h"
7 | #include "group_points.h"
8 | #include "interpolate.h"
9 | #include "sampling.h"
10 |
11 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
12 | m.def("gather_points", &gather_points);
13 | m.def("gather_points_grad", &gather_points_grad);
14 | m.def("furthest_point_sampling", &furthest_point_sampling);
15 |
16 | m.def("three_nn", &three_nn);
17 | m.def("three_interpolate", &three_interpolate);
18 | m.def("three_interpolate_grad", &three_interpolate_grad);
19 |
20 | m.def("ball_query", &ball_query);
21 |
22 | m.def("group_points", &group_points);
23 | m.def("group_points_grad", &group_points_grad);
24 | }
25 |
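setup.py (further down) compiles these bindings into the `pointnet2._ext` module, which pointnet2_utils.py wraps for use from Python. A hedged sketch of calling the raw extension directly, assuming it has been built and a CUDA device is available:

    import torch
    import pointnet2._ext as _ext  # extension name taken from setup.py

    xyz = torch.rand(2, 1024, 3).cuda()           # (B, N, 3) point coordinates
    idx = _ext.furthest_point_sampling(xyz, 256)  # (B, 256) int32 sample indices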
--------------------------------------------------------------------------------
/object-detection-3d/pointnet2/_ext_src/src/group_points.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #include "group_points.h"
7 | #include "utils.h"
8 |
9 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample,
10 | const float *points, const int *idx,
11 | float *out);
12 |
13 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
14 | int nsample, const float *grad_out,
15 | const int *idx, float *grad_points);
16 |
17 | at::Tensor group_points(at::Tensor points, at::Tensor idx) {
18 | CHECK_CONTIGUOUS(points);
19 | CHECK_CONTIGUOUS(idx);
20 | CHECK_IS_FLOAT(points);
21 | CHECK_IS_INT(idx);
22 |
23 | if (points.device().is_cuda()) {
24 | CHECK_CUDA(idx);
25 | }
26 |
27 | at::Tensor output =
28 | torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)},
29 | at::device(points.device()).dtype(at::ScalarType::Float));
30 |
31 | if (points.device().is_cuda()) {
32 | group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2),
33 |                                 idx.size(1), idx.size(2), points.data_ptr<float>(),
34 |                                 idx.data_ptr<int>(), output.data_ptr<float>());
35 | } else {
36 | TORCH_CHECK(false, "CPU not supported");
37 | }
38 |
39 | return output;
40 | }
41 |
42 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) {
43 | CHECK_CONTIGUOUS(grad_out);
44 | CHECK_CONTIGUOUS(idx);
45 | CHECK_IS_FLOAT(grad_out);
46 | CHECK_IS_INT(idx);
47 |
48 | if (grad_out.device().is_cuda()) {
49 | CHECK_CUDA(idx);
50 | }
51 |
52 | at::Tensor output =
53 | torch::zeros({grad_out.size(0), grad_out.size(1), n},
54 | at::device(grad_out.device()).dtype(at::ScalarType::Float));
55 |
56 | if (grad_out.device().is_cuda()) {
57 | group_points_grad_kernel_wrapper(
58 | grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2),
59 |         grad_out.data_ptr<float>(), idx.data_ptr<int>(), output.data_ptr<float>());
60 | } else {
61 | TORCH_CHECK(false, "CPU not supported");
62 | }
63 |
64 | return output;
65 | }
66 |
--------------------------------------------------------------------------------
/object-detection-3d/pointnet2/_ext_src/src/group_points_gpu.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #include <stdio.h>
7 | #include <stdlib.h>
8 |
9 | #include "cuda_utils.h"
10 |
11 | // input: points(b, c, n) idx(b, npoints, nsample)
12 | // output: out(b, c, npoints, nsample)
13 | __global__ void group_points_kernel(int b, int c, int n, int npoints,
14 | int nsample,
15 | const float *__restrict__ points,
16 | const int *__restrict__ idx,
17 | float *__restrict__ out) {
18 | int batch_index = blockIdx.x;
19 | points += batch_index * n * c;
20 | idx += batch_index * npoints * nsample;
21 | out += batch_index * npoints * nsample * c;
22 |
23 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
24 | const int stride = blockDim.y * blockDim.x;
25 | for (int i = index; i < c * npoints; i += stride) {
26 | const int l = i / npoints;
27 | const int j = i % npoints;
28 | for (int k = 0; k < nsample; ++k) {
29 | int ii = idx[j * nsample + k];
30 | out[(l * npoints + j) * nsample + k] = points[l * n + ii];
31 | }
32 | }
33 | }
34 |
35 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample,
36 | const float *points, const int *idx,
37 | float *out) {
38 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
39 |
40 |   group_points_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>(
41 | b, c, n, npoints, nsample, points, idx, out);
42 |
43 | CUDA_CHECK_ERRORS();
44 | }
45 |
46 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample)
47 | // output: grad_points(b, c, n)
48 | __global__ void group_points_grad_kernel(int b, int c, int n, int npoints,
49 | int nsample,
50 | const float *__restrict__ grad_out,
51 | const int *__restrict__ idx,
52 | float *__restrict__ grad_points) {
53 | int batch_index = blockIdx.x;
54 | grad_out += batch_index * npoints * nsample * c;
55 | idx += batch_index * npoints * nsample;
56 | grad_points += batch_index * n * c;
57 |
58 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
59 | const int stride = blockDim.y * blockDim.x;
60 | for (int i = index; i < c * npoints; i += stride) {
61 | const int l = i / npoints;
62 | const int j = i % npoints;
63 | for (int k = 0; k < nsample; ++k) {
64 | int ii = idx[j * nsample + k];
65 | atomicAdd(grad_points + l * n + ii,
66 | grad_out[(l * npoints + j) * nsample + k]);
67 | }
68 | }
69 | }
70 |
71 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
72 | int nsample, const float *grad_out,
73 | const int *idx, float *grad_points) {
74 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
75 |
76 |   group_points_grad_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>(
77 | b, c, n, npoints, nsample, grad_out, idx, grad_points);
78 |
79 | CUDA_CHECK_ERRORS();
80 | }
81 |
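For readers less familiar with CUDA, the grouping above is a batched index-select; an equivalent plain-PyTorch reference (illustrative sketch only):

    import torch

    def group_points_reference(points, idx):
        # points: (B, C, N) per-point features, idx: (B, npoints, nsample) neighbor indices.
        # Computes out[b, c, j, k] = points[b, c, idx[b, j, k]].
        B, C, N = points.shape
        _, npoints, nsample = idx.shape
        flat_idx = idx.reshape(B, 1, npoints * nsample).expand(B, C, -1).long()
        out = torch.gather(points, 2, flat_idx)
        return out.reshape(B, C, npoints, nsample)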
--------------------------------------------------------------------------------
/object-detection-3d/pointnet2/_ext_src/src/interpolate.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #include "interpolate.h"
7 | #include "utils.h"
8 |
9 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
10 | const float *known, float *dist2, int *idx);
11 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
12 | const float *points, const int *idx,
13 | const float *weight, float *out);
14 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m,
15 | const float *grad_out,
16 | const int *idx, const float *weight,
17 | float *grad_points);
18 |
19 | std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows) {
20 | CHECK_CONTIGUOUS(unknowns);
21 | CHECK_CONTIGUOUS(knows);
22 | CHECK_IS_FLOAT(unknowns);
23 | CHECK_IS_FLOAT(knows);
24 |
25 | if (unknowns.device().is_cuda()) {
26 | CHECK_CUDA(knows);
27 | }
28 |
29 | at::Tensor idx =
30 | torch::zeros({unknowns.size(0), unknowns.size(1), 3},
31 | at::device(unknowns.device()).dtype(at::ScalarType::Int));
32 | at::Tensor dist2 =
33 | torch::zeros({unknowns.size(0), unknowns.size(1), 3},
34 | at::device(unknowns.device()).dtype(at::ScalarType::Float));
35 |
36 | if (unknowns.device().is_cuda()) {
37 | three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1),
38 |                             unknowns.data_ptr<float>(), knows.data_ptr<float>(),
39 |                             dist2.data_ptr<float>(), idx.data_ptr<int>());
40 | } else {
41 | TORCH_CHECK(false, "CPU not supported");
42 | }
43 |
44 | return {dist2, idx};
45 | }
46 |
47 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx,
48 | at::Tensor weight) {
49 | CHECK_CONTIGUOUS(points);
50 | CHECK_CONTIGUOUS(idx);
51 | CHECK_CONTIGUOUS(weight);
52 | CHECK_IS_FLOAT(points);
53 | CHECK_IS_INT(idx);
54 | CHECK_IS_FLOAT(weight);
55 |
56 | if (points.device().is_cuda()) {
57 | CHECK_CUDA(idx);
58 | CHECK_CUDA(weight);
59 | }
60 |
61 | at::Tensor output =
62 | torch::zeros({points.size(0), points.size(1), idx.size(1)},
63 | at::device(points.device()).dtype(at::ScalarType::Float));
64 |
65 | if (points.device().is_cuda()) {
66 | three_interpolate_kernel_wrapper(
67 | points.size(0), points.size(1), points.size(2), idx.size(1),
68 |         points.data_ptr<float>(), idx.data_ptr<int>(), weight.data_ptr<float>(),
69 |         output.data_ptr<float>());
70 | } else {
71 | TORCH_CHECK(false, "CPU not supported");
72 | }
73 |
74 | return output;
75 | }
76 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx,
77 | at::Tensor weight, const int m) {
78 | CHECK_CONTIGUOUS(grad_out);
79 | CHECK_CONTIGUOUS(idx);
80 | CHECK_CONTIGUOUS(weight);
81 | CHECK_IS_FLOAT(grad_out);
82 | CHECK_IS_INT(idx);
83 | CHECK_IS_FLOAT(weight);
84 |
85 | if (grad_out.device().is_cuda()) {
86 | CHECK_CUDA(idx);
87 | CHECK_CUDA(weight);
88 | }
89 |
90 | at::Tensor output =
91 | torch::zeros({grad_out.size(0), grad_out.size(1), m},
92 | at::device(grad_out.device()).dtype(at::ScalarType::Float));
93 |
94 | if (grad_out.device().is_cuda()) {
95 | three_interpolate_grad_kernel_wrapper(
96 | grad_out.size(0), grad_out.size(1), grad_out.size(2), m,
97 |         grad_out.data_ptr<float>(), idx.data_ptr<int>(), weight.data_ptr<float>(),
98 |         output.data_ptr<float>());
99 | } else {
100 | TORCH_CHECK(false, "CPU not supported");
101 | }
102 |
103 | return output;
104 | }
105 |
--------------------------------------------------------------------------------
/object-detection-3d/pointnet2/_ext_src/src/interpolate_gpu.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #include <math.h>
7 | #include <stdio.h>
8 | #include <stdlib.h>
9 |
10 | #include "cuda_utils.h"
11 |
12 | // input: unknown(b, n, 3) known(b, m, 3)
13 | // output: dist2(b, n, 3), idx(b, n, 3)
14 | __global__ void three_nn_kernel(int b, int n, int m,
15 | const float *__restrict__ unknown,
16 | const float *__restrict__ known,
17 | float *__restrict__ dist2,
18 | int *__restrict__ idx) {
19 | int batch_index = blockIdx.x;
20 | unknown += batch_index * n * 3;
21 | known += batch_index * m * 3;
22 | dist2 += batch_index * n * 3;
23 | idx += batch_index * n * 3;
24 |
25 | int index = threadIdx.x;
26 | int stride = blockDim.x;
27 | for (int j = index; j < n; j += stride) {
28 | float ux = unknown[j * 3 + 0];
29 | float uy = unknown[j * 3 + 1];
30 | float uz = unknown[j * 3 + 2];
31 |
32 | double best1 = 1e40, best2 = 1e40, best3 = 1e40;
33 | int besti1 = 0, besti2 = 0, besti3 = 0;
34 | for (int k = 0; k < m; ++k) {
35 | float x = known[k * 3 + 0];
36 | float y = known[k * 3 + 1];
37 | float z = known[k * 3 + 2];
38 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
39 | if (d < best1) {
40 | best3 = best2;
41 | besti3 = besti2;
42 | best2 = best1;
43 | besti2 = besti1;
44 | best1 = d;
45 | besti1 = k;
46 | } else if (d < best2) {
47 | best3 = best2;
48 | besti3 = besti2;
49 | best2 = d;
50 | besti2 = k;
51 | } else if (d < best3) {
52 | best3 = d;
53 | besti3 = k;
54 | }
55 | }
56 | dist2[j * 3 + 0] = best1;
57 | dist2[j * 3 + 1] = best2;
58 | dist2[j * 3 + 2] = best3;
59 |
60 | idx[j * 3 + 0] = besti1;
61 | idx[j * 3 + 1] = besti2;
62 | idx[j * 3 + 2] = besti3;
63 | }
64 | }
65 |
66 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
67 | const float *known, float *dist2, int *idx) {
68 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
69 |   three_nn_kernel<<<b, opt_n_threads(n), 0, stream>>>(b, n, m, unknown, known,
70 | dist2, idx);
71 |
72 | CUDA_CHECK_ERRORS();
73 | }
74 |
75 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3)
76 | // output: out(b, c, n)
77 | __global__ void three_interpolate_kernel(int b, int c, int m, int n,
78 | const float *__restrict__ points,
79 | const int *__restrict__ idx,
80 | const float *__restrict__ weight,
81 | float *__restrict__ out) {
82 | int batch_index = blockIdx.x;
83 | points += batch_index * m * c;
84 |
85 | idx += batch_index * n * 3;
86 | weight += batch_index * n * 3;
87 |
88 | out += batch_index * n * c;
89 |
90 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
91 | const int stride = blockDim.y * blockDim.x;
92 | for (int i = index; i < c * n; i += stride) {
93 | const int l = i / n;
94 | const int j = i % n;
95 | float w1 = weight[j * 3 + 0];
96 | float w2 = weight[j * 3 + 1];
97 | float w3 = weight[j * 3 + 2];
98 |
99 | int i1 = idx[j * 3 + 0];
100 | int i2 = idx[j * 3 + 1];
101 | int i3 = idx[j * 3 + 2];
102 |
103 | out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 +
104 | points[l * m + i3] * w3;
105 | }
106 | }
107 |
108 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
109 | const float *points, const int *idx,
110 | const float *weight, float *out) {
111 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
112 |   three_interpolate_kernel<<<b, opt_block_config(n, c), 0, stream>>>(
113 | b, c, m, n, points, idx, weight, out);
114 |
115 | CUDA_CHECK_ERRORS();
116 | }
117 |
118 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3)
119 | // output: grad_points(b, c, m)
120 |
121 | __global__ void three_interpolate_grad_kernel(
122 | int b, int c, int n, int m, const float *__restrict__ grad_out,
123 | const int *__restrict__ idx, const float *__restrict__ weight,
124 | float *__restrict__ grad_points) {
125 | int batch_index = blockIdx.x;
126 | grad_out += batch_index * n * c;
127 | idx += batch_index * n * 3;
128 | weight += batch_index * n * 3;
129 | grad_points += batch_index * m * c;
130 |
131 | const int index = threadIdx.y * blockDim.x + threadIdx.x;
132 | const int stride = blockDim.y * blockDim.x;
133 | for (int i = index; i < c * n; i += stride) {
134 | const int l = i / n;
135 | const int j = i % n;
136 | float w1 = weight[j * 3 + 0];
137 | float w2 = weight[j * 3 + 1];
138 | float w3 = weight[j * 3 + 2];
139 |
140 | int i1 = idx[j * 3 + 0];
141 | int i2 = idx[j * 3 + 1];
142 | int i3 = idx[j * 3 + 2];
143 |
144 | atomicAdd(grad_points + l * m + i1, grad_out[i] * w1);
145 | atomicAdd(grad_points + l * m + i2, grad_out[i] * w2);
146 | atomicAdd(grad_points + l * m + i3, grad_out[i] * w3);
147 | }
148 | }
149 |
150 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m,
151 | const float *grad_out,
152 | const int *idx, const float *weight,
153 | float *grad_points) {
154 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
155 |   three_interpolate_grad_kernel<<<b, opt_block_config(n, c), 0, stream>>>(
156 | b, c, n, m, grad_out, idx, weight, grad_points);
157 |
158 | CUDA_CHECK_ERRORS();
159 | }
160 |
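The interpolation kernel above computes a weighted sum of the three source features selected by idx; a plain-PyTorch reference of the same computation (illustrative sketch only):

    import torch

    def three_interpolate_reference(points, idx, weight):
        # points: (B, C, M) features, idx: (B, N, 3) three nearest sources per target,
        # weight: (B, N, 3) interpolation weights; returns (B, C, N).
        B, C, M = points.shape
        N = idx.shape[1]
        flat_idx = idx.reshape(B, 1, N * 3).expand(B, C, -1).long()
        gathered = torch.gather(points, 2, flat_idx).reshape(B, C, N, 3)
        return (gathered * weight.unsqueeze(1)).sum(-1)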
--------------------------------------------------------------------------------
/object-detection-3d/pointnet2/_ext_src/src/sampling.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #include "sampling.h"
7 | #include "utils.h"
8 |
9 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
10 | const float *points, const int *idx,
11 | float *out);
12 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
13 | const float *grad_out, const int *idx,
14 | float *grad_points);
15 |
16 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
17 | const float *dataset, float *temp,
18 | int *idxs);
19 |
20 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) {
21 | CHECK_CONTIGUOUS(points);
22 | CHECK_CONTIGUOUS(idx);
23 | CHECK_IS_FLOAT(points);
24 | CHECK_IS_INT(idx);
25 |
26 | if (points.device().is_cuda()) {
27 | CHECK_CUDA(idx);
28 | }
29 |
30 | at::Tensor output =
31 | torch::zeros({points.size(0), points.size(1), idx.size(1)},
32 | at::device(points.device()).dtype(at::ScalarType::Float));
33 |
34 | if (points.device().is_cuda()) {
35 | gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2),
36 |                                  idx.size(1), points.data_ptr<float>(),
37 |                                  idx.data_ptr<int>(), output.data_ptr<float>());
38 | } else {
39 | TORCH_CHECK(false, "CPU not supported");
40 | }
41 |
42 | return output;
43 | }
44 |
45 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx,
46 | const int n) {
47 | CHECK_CONTIGUOUS(grad_out);
48 | CHECK_CONTIGUOUS(idx);
49 | CHECK_IS_FLOAT(grad_out);
50 | CHECK_IS_INT(idx);
51 |
52 | if (grad_out.device().is_cuda()) {
53 | CHECK_CUDA(idx);
54 | }
55 |
56 | at::Tensor output =
57 | torch::zeros({grad_out.size(0), grad_out.size(1), n},
58 | at::device(grad_out.device()).dtype(at::ScalarType::Float));
59 |
60 | if (grad_out.device().is_cuda()) {
61 | gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n,
62 |                                       idx.size(1), grad_out.data_ptr<float>(),
63 |                                       idx.data_ptr<int>(), output.data_ptr<float>());
64 | } else {
65 | TORCH_CHECK(false, "CPU not supported");
66 | }
67 |
68 | return output;
69 | }
70 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) {
71 | CHECK_CONTIGUOUS(points);
72 | CHECK_IS_FLOAT(points);
73 |
74 | at::Tensor output =
75 | torch::zeros({points.size(0), nsamples},
76 | at::device(points.device()).dtype(at::ScalarType::Int));
77 |
78 | at::Tensor tmp =
79 | torch::full({points.size(0), points.size(1)}, 1e10,
80 | at::device(points.device()).dtype(at::ScalarType::Float));
81 |
82 | if (points.device().is_cuda()) {
83 | furthest_point_sampling_kernel_wrapper(
84 |         points.size(0), points.size(1), nsamples, points.data_ptr<float>(),
85 |         tmp.data_ptr<float>(), output.data_ptr<int>());
86 | } else {
87 | TORCH_CHECK(false, "CPU not supported");
88 | }
89 |
90 | return output;
91 | }
92 |
--------------------------------------------------------------------------------
/object-detection-3d/pointnet2/_ext_src/src/sampling_gpu.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates.
2 | //
3 | // This source code is licensed under the MIT license found in the
4 | // LICENSE file in the root directory of this source tree.
5 |
6 | #include <stdio.h>
7 | #include <stdlib.h>
8 |
9 | #include "cuda_utils.h"
10 |
11 | // input: points(b, c, n) idx(b, m)
12 | // output: out(b, c, m)
13 | __global__ void gather_points_kernel(int b, int c, int n, int m,
14 | const float *__restrict__ points,
15 | const int *__restrict__ idx,
16 | float *__restrict__ out) {
17 | for (int i = blockIdx.x; i < b; i += gridDim.x) {
18 | for (int l = blockIdx.y; l < c; l += gridDim.y) {
19 | for (int j = threadIdx.x; j < m; j += blockDim.x) {
20 | int a = idx[i * m + j];
21 | out[(i * c + l) * m + j] = points[(i * c + l) * n + a];
22 | }
23 | }
24 | }
25 | }
26 |
27 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
28 | const float *points, const int *idx,
29 | float *out) {
30 |   gather_points_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0,
31 |                          at::cuda::getCurrentCUDAStream()>>>(b, c, n, npoints,
32 |                                                              points, idx, out);
33 |
34 | CUDA_CHECK_ERRORS();
35 | }
36 |
37 | // input: grad_out(b, c, m) idx(b, m)
38 | // output: grad_points(b, c, n)
39 | __global__ void gather_points_grad_kernel(int b, int c, int n, int m,
40 | const float *__restrict__ grad_out,
41 | const int *__restrict__ idx,
42 | float *__restrict__ grad_points) {
43 | for (int i = blockIdx.x; i < b; i += gridDim.x) {
44 | for (int l = blockIdx.y; l < c; l += gridDim.y) {
45 | for (int j = threadIdx.x; j < m; j += blockDim.x) {
46 | int a = idx[i * m + j];
47 | atomicAdd(grad_points + (i * c + l) * n + a,
48 | grad_out[(i * c + l) * m + j]);
49 | }
50 | }
51 | }
52 | }
53 |
54 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
55 | const float *grad_out, const int *idx,
56 | float *grad_points) {
57 |   gather_points_grad_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0,
58 |                               at::cuda::getCurrentCUDAStream()>>>(
59 |       b, c, n, npoints, grad_out, idx, grad_points);
60 |
61 | CUDA_CHECK_ERRORS();
62 | }
63 |
64 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i,
65 | int idx1, int idx2) {
66 | const float v1 = dists[idx1], v2 = dists[idx2];
67 | const int i1 = dists_i[idx1], i2 = dists_i[idx2];
68 | dists[idx1] = max(v1, v2);
69 | dists_i[idx1] = v2 > v1 ? i2 : i1;
70 | }
71 |
72 | // Input dataset: (b, n, 3), tmp: (b, n)
73 | // Output idxs (b, m)
74 | template <unsigned int block_size>
75 | __global__ void furthest_point_sampling_kernel(
76 | int b, int n, int m, const float *__restrict__ dataset,
77 | float *__restrict__ temp, int *__restrict__ idxs) {
78 | if (m <= 0) return;
79 | __shared__ float dists[block_size];
80 | __shared__ int dists_i[block_size];
81 |
82 | int batch_index = blockIdx.x;
83 | dataset += batch_index * n * 3;
84 | temp += batch_index * n;
85 | idxs += batch_index * m;
86 |
87 | int tid = threadIdx.x;
88 | const int stride = block_size;
89 |
90 | int old = 0;
91 | if (threadIdx.x == 0) idxs[0] = old;
92 |
93 | __syncthreads();
94 | for (int j = 1; j < m; j++) {
95 | int besti = 0;
96 | float best = -1;
97 | float x1 = dataset[old * 3 + 0];
98 | float y1 = dataset[old * 3 + 1];
99 | float z1 = dataset[old * 3 + 2];
100 | for (int k = tid; k < n; k += stride) {
101 | float x2, y2, z2;
102 | x2 = dataset[k * 3 + 0];
103 | y2 = dataset[k * 3 + 1];
104 | z2 = dataset[k * 3 + 2];
105 | float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
106 | if (mag <= 1e-3) continue;
107 |
108 | float d =
109 | (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
110 |
111 | float d2 = min(d, temp[k]);
112 | temp[k] = d2;
113 | besti = d2 > best ? k : besti;
114 | best = d2 > best ? d2 : best;
115 | }
116 | dists[tid] = best;
117 | dists_i[tid] = besti;
118 | __syncthreads();
119 |
120 | if (block_size >= 512) {
121 | if (tid < 256) {
122 | __update(dists, dists_i, tid, tid + 256);
123 | }
124 | __syncthreads();
125 | }
126 | if (block_size >= 256) {
127 | if (tid < 128) {
128 | __update(dists, dists_i, tid, tid + 128);
129 | }
130 | __syncthreads();
131 | }
132 | if (block_size >= 128) {
133 | if (tid < 64) {
134 | __update(dists, dists_i, tid, tid + 64);
135 | }
136 | __syncthreads();
137 | }
138 | if (block_size >= 64) {
139 | if (tid < 32) {
140 | __update(dists, dists_i, tid, tid + 32);
141 | }
142 | __syncthreads();
143 | }
144 | if (block_size >= 32) {
145 | if (tid < 16) {
146 | __update(dists, dists_i, tid, tid + 16);
147 | }
148 | __syncthreads();
149 | }
150 | if (block_size >= 16) {
151 | if (tid < 8) {
152 | __update(dists, dists_i, tid, tid + 8);
153 | }
154 | __syncthreads();
155 | }
156 | if (block_size >= 8) {
157 | if (tid < 4) {
158 | __update(dists, dists_i, tid, tid + 4);
159 | }
160 | __syncthreads();
161 | }
162 | if (block_size >= 4) {
163 | if (tid < 2) {
164 | __update(dists, dists_i, tid, tid + 2);
165 | }
166 | __syncthreads();
167 | }
168 | if (block_size >= 2) {
169 | if (tid < 1) {
170 | __update(dists, dists_i, tid, tid + 1);
171 | }
172 | __syncthreads();
173 | }
174 |
175 | old = dists_i[0];
176 | if (tid == 0) idxs[j] = old;
177 | }
178 | }
179 |
180 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m,
181 | const float *dataset, float *temp,
182 | int *idxs) {
183 | unsigned int n_threads = opt_n_threads(n);
184 |
185 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
186 |
187 | switch (n_threads) {
188 | case 512:
189 | furthest_point_sampling_kernel<512>
190 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
191 | break;
192 | case 256:
193 | furthest_point_sampling_kernel<256>
194 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
195 | break;
196 | case 128:
197 | furthest_point_sampling_kernel<128>
198 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
199 | break;
200 | case 64:
201 | furthest_point_sampling_kernel<64>
202 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
203 | break;
204 | case 32:
205 | furthest_point_sampling_kernel<32>
206 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
207 | break;
208 | case 16:
209 | furthest_point_sampling_kernel<16>
210 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
211 | break;
212 | case 8:
213 | furthest_point_sampling_kernel<8>
214 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
215 | break;
216 | case 4:
217 | furthest_point_sampling_kernel<4>
218 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
219 | break;
220 | case 2:
221 | furthest_point_sampling_kernel<2>
222 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
223 | break;
224 | case 1:
225 | furthest_point_sampling_kernel<1>
226 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
227 | break;
228 | default:
229 | furthest_point_sampling_kernel<512>
230 |           <<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
231 | }
232 |
233 | CUDA_CHECK_ERRORS();
234 | }
235 |
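What the furthest-point-sampling kernel computes, as a plain-PyTorch reference (illustrative sketch; the CUDA version above is what the extension actually runs):

    import torch

    def furthest_point_sampling_reference(xyz, m):
        # xyz: (B, N, 3) coordinates; returns (B, m) indices of a greedily spread subset.
        # (The kernel additionally skips near-origin padding points.)
        B, N, _ = xyz.shape
        idxs = torch.zeros(B, m, dtype=torch.long)
        dist = torch.full((B, N), 1e10)
        farthest = torch.zeros(B, dtype=torch.long)  # the kernel also seeds with point 0
        for j in range(m):
            idxs[:, j] = farthest
            centroid = xyz[torch.arange(B), farthest].unsqueeze(1)  # (B, 1, 3)
            dist = torch.minimum(dist, ((xyz - centroid) ** 2).sum(-1))
            farthest = dist.argmax(-1)
        return idxs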
--------------------------------------------------------------------------------
/object-detection-3d/pointnet2/pointnet2_test.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | ''' Testing customized ops. '''
7 |
8 | import torch
9 | from torch.autograd import gradcheck
10 | import numpy as np
11 |
12 | import os
13 | import sys
14 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
15 | sys.path.append(BASE_DIR)
16 | import pointnet2_utils
17 |
18 | def test_interpolation_grad():
19 | batch_size = 1
20 | feat_dim = 2
21 | m = 4
22 | feats = torch.randn(batch_size, feat_dim, m, requires_grad=True).float().cuda()
23 |
24 | def interpolate_func(inputs):
25 | idx = torch.from_numpy(np.array([[[0,1,2],[1,2,3]]])).int().cuda()
26 | weight = torch.from_numpy(np.array([[[1,1,1],[2,2,2]]])).float().cuda()
27 | interpolated_feats = pointnet2_utils.three_interpolate(inputs, idx, weight)
28 | return interpolated_feats
29 |
30 | assert (gradcheck(interpolate_func, feats, atol=1e-1, rtol=1e-1))
31 |
32 | if __name__=='__main__':
33 | test_interpolation_grad()
34 |
--------------------------------------------------------------------------------
/object-detection-3d/pointnet2/pytorch_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | ''' Modified based on Ref: https://github.com/erikwijmans/Pointnet2_PyTorch '''
7 | import torch
8 | import torch.nn as nn
9 | from typing import List, Tuple
10 |
11 | class SharedMLP(nn.Sequential):
12 |
13 | def __init__(
14 | self,
15 | args: List[int],
16 | *,
17 | bn: bool = False,
18 | activation=nn.ReLU(inplace=True),
19 | preact: bool = False,
20 | first: bool = False,
21 | name: str = ""
22 | ):
23 | super().__init__()
24 |
25 | for i in range(len(args) - 1):
26 | self.add_module(
27 | name + 'layer{}'.format(i),
28 | Conv2d(
29 | args[i],
30 | args[i + 1],
31 | bn=(not first or not preact or (i != 0)) and bn,
32 | activation=activation
33 | if (not first or not preact or (i != 0)) else None,
34 | preact=preact
35 | )
36 | )
37 |
38 |
39 | class _BNBase(nn.Sequential):
40 |
41 | def __init__(self, in_size, batch_norm=None, name=""):
42 | super().__init__()
43 | self.add_module(name + "bn", batch_norm(in_size))
44 |
45 | nn.init.constant_(self[0].weight, 1.0)
46 | nn.init.constant_(self[0].bias, 0)
47 |
48 |
49 | class BatchNorm1d(_BNBase):
50 |
51 | def __init__(self, in_size: int, *, name: str = ""):
52 | super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name)
53 |
54 |
55 | class BatchNorm2d(_BNBase):
56 |
57 | def __init__(self, in_size: int, name: str = ""):
58 | super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name)
59 |
60 |
61 | class BatchNorm3d(_BNBase):
62 |
63 | def __init__(self, in_size: int, name: str = ""):
64 | super().__init__(in_size, batch_norm=nn.BatchNorm3d, name=name)
65 |
66 |
67 | class _ConvBase(nn.Sequential):
68 |
69 | def __init__(
70 | self,
71 | in_size,
72 | out_size,
73 | kernel_size,
74 | stride,
75 | padding,
76 | activation,
77 | bn,
78 | init,
79 | conv=None,
80 | batch_norm=None,
81 | bias=True,
82 | preact=False,
83 | name=""
84 | ):
85 | super().__init__()
86 |
87 | bias = bias and (not bn)
88 | conv_unit = conv(
89 | in_size,
90 | out_size,
91 | kernel_size=kernel_size,
92 | stride=stride,
93 | padding=padding,
94 | bias=bias
95 | )
96 | init(conv_unit.weight)
97 | if bias:
98 | nn.init.constant_(conv_unit.bias, 0)
99 |
100 | if bn:
101 | if not preact:
102 | bn_unit = batch_norm(out_size)
103 | else:
104 | bn_unit = batch_norm(in_size)
105 |
106 | if preact:
107 | if bn:
108 | self.add_module(name + 'bn', bn_unit)
109 |
110 | if activation is not None:
111 | self.add_module(name + 'activation', activation)
112 |
113 | self.add_module(name + 'conv', conv_unit)
114 |
115 | if not preact:
116 | if bn:
117 | self.add_module(name + 'bn', bn_unit)
118 |
119 | if activation is not None:
120 | self.add_module(name + 'activation', activation)
121 |
122 |
123 | class Conv1d(_ConvBase):
124 |
125 | def __init__(
126 | self,
127 | in_size: int,
128 | out_size: int,
129 | *,
130 | kernel_size: int = 1,
131 | stride: int = 1,
132 | padding: int = 0,
133 | activation=nn.ReLU(inplace=True),
134 | bn: bool = False,
135 | init=nn.init.kaiming_normal_,
136 | bias: bool = True,
137 | preact: bool = False,
138 | name: str = ""
139 | ):
140 | super().__init__(
141 | in_size,
142 | out_size,
143 | kernel_size,
144 | stride,
145 | padding,
146 | activation,
147 | bn,
148 | init,
149 | conv=nn.Conv1d,
150 | batch_norm=BatchNorm1d,
151 | bias=bias,
152 | preact=preact,
153 | name=name
154 | )
155 |
156 |
157 | class Conv2d(_ConvBase):
158 |
159 | def __init__(
160 | self,
161 | in_size: int,
162 | out_size: int,
163 | *,
164 | kernel_size: Tuple[int, int] = (1, 1),
165 | stride: Tuple[int, int] = (1, 1),
166 | padding: Tuple[int, int] = (0, 0),
167 | activation=nn.ReLU(inplace=True),
168 | bn: bool = False,
169 | init=nn.init.kaiming_normal_,
170 | bias: bool = True,
171 | preact: bool = False,
172 | name: str = ""
173 | ):
174 | super().__init__(
175 | in_size,
176 | out_size,
177 | kernel_size,
178 | stride,
179 | padding,
180 | activation,
181 | bn,
182 | init,
183 | conv=nn.Conv2d,
184 | batch_norm=BatchNorm2d,
185 | bias=bias,
186 | preact=preact,
187 | name=name
188 | )
189 |
190 |
191 | class Conv3d(_ConvBase):
192 |
193 | def __init__(
194 | self,
195 | in_size: int,
196 | out_size: int,
197 | *,
198 | kernel_size: Tuple[int, int, int] = (1, 1, 1),
199 | stride: Tuple[int, int, int] = (1, 1, 1),
200 | padding: Tuple[int, int, int] = (0, 0, 0),
201 | activation=nn.ReLU(inplace=True),
202 | bn: bool = False,
203 | init=nn.init.kaiming_normal_,
204 | bias: bool = True,
205 | preact: bool = False,
206 | name: str = ""
207 | ):
208 | super().__init__(
209 | in_size,
210 | out_size,
211 | kernel_size,
212 | stride,
213 | padding,
214 | activation,
215 | bn,
216 | init,
217 | conv=nn.Conv3d,
218 | batch_norm=BatchNorm3d,
219 | bias=bias,
220 | preact=preact,
221 | name=name
222 | )
223 |
224 |
225 | class FC(nn.Sequential):
226 |
227 | def __init__(
228 | self,
229 | in_size: int,
230 | out_size: int,
231 | *,
232 | activation=nn.ReLU(inplace=True),
233 | bn: bool = False,
234 | init=None,
235 | preact: bool = False,
236 | name: str = ""
237 | ):
238 | super().__init__()
239 |
240 | fc = nn.Linear(in_size, out_size, bias=not bn)
241 | if init is not None:
242 | init(fc.weight)
243 | if not bn:
244 | nn.init.constant_(fc.bias, 0)
245 |
246 | if preact:
247 | if bn:
248 | self.add_module(name + 'bn', BatchNorm1d(in_size))
249 |
250 | if activation is not None:
251 | self.add_module(name + 'activation', activation)
252 |
253 | self.add_module(name + 'fc', fc)
254 |
255 | if not preact:
256 | if bn:
257 | self.add_module(name + 'bn', BatchNorm1d(out_size))
258 |
259 | if activation is not None:
260 | self.add_module(name + 'activation', activation)
261 |
262 | def set_bn_momentum_default(bn_momentum):
263 |
264 | def fn(m):
265 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
266 | m.momentum = bn_momentum
267 |
268 | return fn
269 |
270 |
271 | class BNMomentumScheduler(object):
272 |
273 | def __init__(
274 | self, model, bn_lambda, last_epoch=-1,
275 | setter=set_bn_momentum_default
276 | ):
277 | if not isinstance(model, nn.Module):
278 | raise RuntimeError(
279 | "Class '{}' is not a PyTorch nn Module".format(
280 | type(model).__name__
281 | )
282 | )
283 |
284 | self.model = model
285 | self.setter = setter
286 | self.lmbd = bn_lambda
287 |
288 | self.step(last_epoch + 1)
289 | self.last_epoch = last_epoch
290 |
291 | def step(self, epoch=None):
292 | if epoch is None:
293 | epoch = self.last_epoch + 1
294 |
295 | self.last_epoch = epoch
296 | self.model.apply(self.setter(self.lmbd(epoch)))
297 |
298 |
299 |
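A small usage sketch for BNMomentumScheduler (illustrative; the decay schedule and import path are assumptions). It is stepped once per epoch, like a learning-rate scheduler, and rewrites the momentum of every BatchNorm layer in the model:

    import torch.nn as nn
    from pointnet2.pytorch_utils import BNMomentumScheduler  # assumed import path

    model = nn.Sequential(nn.Conv1d(3, 64, 1), nn.BatchNorm1d(64), nn.ReLU())
    bn_lambda = lambda epoch: max(0.5 * (0.5 ** (epoch // 10)), 0.01)  # halve every 10 epochs
    bn_scheduler = BNMomentumScheduler(model, bn_lambda)

    for epoch in range(30):
        # ... run one training epoch ...
        bn_scheduler.step(epoch)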
--------------------------------------------------------------------------------
/object-detection-3d/pointnet2/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from setuptools import setup
7 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
8 | import glob
9 | import os.path as osp
10 |
11 | _ext_src_root = "_ext_src"
12 | _ext_sources = glob.glob("{}/src/*.cpp".format(_ext_src_root)) + glob.glob(
13 | "{}/src/*.cu".format(_ext_src_root)
14 | )
15 | _ext_headers = glob.glob("{}/include/*".format(_ext_src_root))
16 | this_dir = osp.dirname(osp.abspath(__file__))
17 |
18 | setup(
19 | name='pointnet2',
20 | ext_modules=[
21 | CUDAExtension(
22 | name='pointnet2._ext',
23 | sources=_ext_sources,
24 | extra_compile_args={
25 | "cxx": ["-O2", "-I{}".format("{}/include".format(_ext_src_root))],
26 | "nvcc": ["-O2", "-I{}".format("{}/include".format(_ext_src_root))],
27 | },
28 | include_dirs=[osp.join(this_dir, _ext_src_root, "include")],
29 | )
30 | ],
31 | cmdclass={
32 | 'build_ext': BuildExtension
33 | }
34 | )
35 |
--------------------------------------------------------------------------------
/object-detection-3d/sunrgbd/model_util_sunrgbd.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy as np
7 | import sys
8 | import os
9 |
10 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
11 | sys.path.append(BASE_DIR)
12 | ROOT_DIR = os.path.dirname(BASE_DIR)
13 | sys.path.append(os.path.join(ROOT_DIR, 'utils'))
14 |
15 |
16 | class SunrgbdDatasetConfig(object):
17 | def __init__(self):
18 | self.num_class = 10
19 | self.num_heading_bin = 12
20 | self.num_size_cluster = 10
21 |
22 | self.type2class = {'bed': 0, 'table': 1, 'sofa': 2, 'chair': 3, 'toilet': 4, 'desk': 5, 'dresser': 6,
23 | 'night_stand': 7, 'bookshelf': 8, 'bathtub': 9}
24 | self.class2type = {self.type2class[t]: t for t in self.type2class}
25 | self.type2onehotclass = {'bed': 0, 'table': 1, 'sofa': 2, 'chair': 3, 'toilet': 4, 'desk': 5, 'dresser': 6,
26 | 'night_stand': 7, 'bookshelf': 8, 'bathtub': 9}
27 | self.type_mean_size = {'bathtub': np.array([0.765840, 1.398258, 0.472728]),
28 | 'bed': np.array([2.114256, 1.620300, 0.927272]),
29 | 'bookshelf': np.array([0.404671, 1.071108, 1.688889]),
30 | 'chair': np.array([0.591958, 0.552978, 0.827272]),
31 | 'desk': np.array([0.695190, 1.346299, 0.736364]),
32 | 'dresser': np.array([0.528526, 1.002642, 1.172878]),
33 | 'night_stand': np.array([0.500618, 0.632163, 0.683424]),
34 | 'sofa': np.array([0.923508, 1.867419, 0.845495]),
35 | 'table': np.array([0.791118, 1.279516, 0.718182]),
36 | 'toilet': np.array([0.699104, 0.454178, 0.756250])}
37 |
38 | self.mean_size_arr = np.zeros((self.num_size_cluster, 3))
39 | for i in range(self.num_size_cluster):
40 | self.mean_size_arr[i, :] = self.type_mean_size[self.class2type[i]]
41 |
42 | def size2class(self, size, type_name):
43 | ''' Convert 3D box size (l,w,h) to size class and size residual '''
44 | size_class = self.type2class[type_name]
45 | size_residual = size - self.type_mean_size[type_name]
46 | return size_class, size_residual
47 |
48 | def class2size(self, pred_cls, residual):
49 | ''' Inverse function to size2class '''
50 | mean_size = self.type_mean_size[self.class2type[pred_cls]]
51 | return mean_size + residual
52 |
53 | def angle2class(self, angle):
54 | ''' Convert continuous angle to discrete class
55 |         [optional] also returns a small regression number from
56 |         the class center angle to the current angle.
57 |
58 | angle is from 0-2pi (or -pi~pi), class center at 0, 1*(2pi/N), 2*(2pi/N) ... (N-1)*(2pi/N)
59 | return is class of int32 of 0,1,...,N-1 and a number such that
60 | class*(2pi/N) + number = angle
61 | '''
62 | num_class = self.num_heading_bin
63 | angle = angle % (2 * np.pi)
64 | assert (angle >= 0 and angle <= 2 * np.pi)
65 | angle_per_class = 2 * np.pi / float(num_class)
66 | shifted_angle = (angle + angle_per_class / 2) % (2 * np.pi)
67 | class_id = int(shifted_angle / angle_per_class)
68 | residual_angle = shifted_angle - (class_id * angle_per_class + angle_per_class / 2)
69 | return class_id, residual_angle
70 |
71 | def class2angle(self, pred_cls, residual, to_label_format=True):
72 | ''' Inverse function to angle2class '''
73 | num_class = self.num_heading_bin
74 | angle_per_class = 2 * np.pi / float(num_class)
75 | angle_center = pred_cls * angle_per_class
76 | angle = angle_center + residual
77 | if to_label_format and angle > np.pi:
78 | angle = angle - 2 * np.pi
79 | return angle
80 |
81 | def param2obb(self, center, heading_class, heading_residual, size_class, size_residual):
82 | heading_angle = self.class2angle(heading_class, heading_residual)
83 | box_size = self.class2size(int(size_class), size_residual)
84 | obb = np.zeros((7,))
85 | obb[0:3] = center
86 | obb[3:6] = box_size
87 | obb[6] = heading_angle * -1
88 | return obb
89 |
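A small round-trip check for the heading-angle binning above (illustrative sketch; the import path is an assumption). angle2class() splits [0, 2*pi) into num_heading_bin bins and class2angle() inverts it:

    from sunrgbd.model_util_sunrgbd import SunrgbdDatasetConfig  # assumed import path

    cfg = SunrgbdDatasetConfig()
    cls_id, res = cfg.angle2class(1.3)    # bin index (2 of 12) and residual inside the bin
    angle = cfg.class2angle(cls_id, res)  # recovers 1.3 up to floating point error
    assert abs(angle - 1.3) < 1e-6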
--------------------------------------------------------------------------------
/object-detection-3d/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .lr_scheduler import get_scheduler
2 | from .logger import setup_logger
--------------------------------------------------------------------------------
/object-detection-3d/utils/box_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | Utilities for bounding box manipulation and GIoU.
4 | """
5 | import torch
6 | from torchvision.ops.boxes import box_area
7 |
8 |
9 | def box_cxcywh_to_xyxy(x):
10 | x_c, y_c, w, h = x.unbind(-1)
11 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
12 | (x_c + 0.5 * w), (y_c + 0.5 * h)]
13 | return torch.stack(b, dim=-1)
14 |
15 |
16 | def box_xyxy_to_cxcywh(x):
17 | x0, y0, x1, y1 = x.unbind(-1)
18 | b = [(x0 + x1) / 2, (y0 + y1) / 2,
19 | (x1 - x0), (y1 - y0)]
20 | return torch.stack(b, dim=-1)
21 |
22 |
23 | # modified from torchvision to also return the union
24 | def box_iou(boxes1, boxes2):
25 | area1 = box_area(boxes1)
26 | area2 = box_area(boxes2)
27 |
28 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
29 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
30 |
31 | wh = (rb - lt).clamp(min=0) # [N,M,2]
32 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
33 |
34 | union = area1[:, None] + area2 - inter
35 |
36 | iou = inter / union
37 | return iou, union
38 |
39 |
40 | def generalized_box_iou(boxes1, boxes2):
41 | """
42 | Generalized IoU from https://giou.stanford.edu/
43 |
44 | The boxes should be in [x0, y0, x1, y1] format
45 |
46 | Returns a [N, M] pairwise matrix, where N = len(boxes1)
47 | and M = len(boxes2)
48 | """
49 | # degenerate boxes gives inf / nan results
50 | # so do an early check
51 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
52 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
53 | iou, union = box_iou(boxes1, boxes2)
54 |
55 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
56 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
57 |
58 | wh = (rb - lt).clamp(min=0) # [N,M,2]
59 | area = wh[:, :, 0] * wh[:, :, 1]
60 |
61 | return iou - (area - union) / area
62 |
63 |
64 | def masks_to_boxes(masks):
65 | """Compute the bounding boxes around the provided masks
66 |
67 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
68 |
69 | Returns a [N, 4] tensors, with the boxes in xyxy format
70 | """
71 | if masks.numel() == 0:
72 | return torch.zeros((0, 4), device=masks.device)
73 |
74 | h, w = masks.shape[-2:]
75 |
76 | y = torch.arange(0, h, dtype=torch.float)
77 | x = torch.arange(0, w, dtype=torch.float)
78 | y, x = torch.meshgrid(y, x)
79 |
80 | x_mask = (masks * x.unsqueeze(0))
81 | x_max = x_mask.flatten(1).max(-1)[0]
82 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
83 |
84 | y_mask = (masks * y.unsqueeze(0))
85 | y_max = y_mask.flatten(1).max(-1)[0]
86 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
87 |
88 | return torch.stack([x_min, y_min, x_max, y_max], 1)
89 |
--------------------------------------------------------------------------------
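Usage sketch for the helpers above (not repository code; it assumes this file is importable as box_ops and uses arbitrary toy boxes):

import torch
from box_ops import box_cxcywh_to_xyxy, generalized_box_iou

# two boxes given as (cx, cy, w, h), converted to (x0, y0, x1, y1)
boxes_a = box_cxcywh_to_xyxy(torch.tensor([[0.5, 0.5, 0.4, 0.4]]))
boxes_b = box_cxcywh_to_xyxy(torch.tensor([[0.6, 0.6, 0.4, 0.4]]))
print(generalized_box_iou(boxes_a, boxes_b))  # [1, 1] pairwise GIoU matrix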
/object-detection-3d/utils/box_util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | """ Helper functions for calculating 2D and 3D bounding box IoU.
7 |
8 | Collected and written by Charles R. Qi
9 | Last modified: Jul 2019
10 | """
11 | from __future__ import print_function
12 |
13 | import numpy as np
14 | from scipy.spatial import ConvexHull
15 |
16 |
17 | def polygon_clip(subjectPolygon, clipPolygon):
18 | """ Clip a polygon with another polygon.
19 |
20 | Ref: https://rosettacode.org/wiki/Sutherland-Hodgman_polygon_clipping#Python
21 |
22 | Args:
23 | subjectPolygon: a list of (x,y) 2d points, any polygon.
24 | clipPolygon: a list of (x,y) 2d points, has to be *convex*
25 | Note:
26 | **points have to be counter-clockwise ordered**
27 |
28 | Return:
29 | a list of (x,y) vertex point for the intersection polygon.
30 | """
31 |
32 | def inside(p):
33 | return (cp2[0] - cp1[0]) * (p[1] - cp1[1]) > (cp2[1] - cp1[1]) * (p[0] - cp1[0])
34 |
35 | def computeIntersection():
36 | dc = [cp1[0] - cp2[0], cp1[1] - cp2[1]]
37 | dp = [s[0] - e[0], s[1] - e[1]]
38 | n1 = cp1[0] * cp2[1] - cp1[1] * cp2[0]
39 | n2 = s[0] * e[1] - s[1] * e[0]
40 | n3 = 1.0 / (dc[0] * dp[1] - dc[1] * dp[0])
41 | return [(n1 * dp[0] - n2 * dc[0]) * n3, (n1 * dp[1] - n2 * dc[1]) * n3]
42 |
43 | outputList = subjectPolygon
44 | cp1 = clipPolygon[-1]
45 |
46 | for clipVertex in clipPolygon:
47 | cp2 = clipVertex
48 | inputList = outputList
49 | outputList = []
50 | s = inputList[-1]
51 |
52 | for subjectVertex in inputList:
53 | e = subjectVertex
54 | if inside(e):
55 | if not inside(s):
56 | outputList.append(computeIntersection())
57 | outputList.append(e)
58 | elif inside(s):
59 | outputList.append(computeIntersection())
60 | s = e
61 | cp1 = cp2
62 | if len(outputList) == 0:
63 | return None
64 | return (outputList)
65 |
66 |
67 | def poly_area(x, y):
68 | """ Ref: http://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates """
69 | return 0.5 * np.abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1)))
70 |
71 |
72 | def convex_hull_intersection(p1, p2):
73 | """ Compute area of two convex hull's intersection area.
74 | p1,p2 are a list of (x,y) tuples of hull vertices.
75 | return a list of (x,y) for the intersection and its volume
76 | """
77 | inter_p = polygon_clip(p1, p2)
78 | if inter_p is not None:
79 | hull_inter = ConvexHull(inter_p)
80 | return inter_p, hull_inter.volume
81 | else:
82 | return None, 0.0
83 |
84 |
85 | def box3d_vol(corners):
86 | ''' corners: (8,3) no assumption on axis direction '''
87 | a = np.sqrt(np.sum((corners[0, :] - corners[1, :]) ** 2))
88 | b = np.sqrt(np.sum((corners[1, :] - corners[2, :]) ** 2))
89 | c = np.sqrt(np.sum((corners[0, :] - corners[4, :]) ** 2))
90 | return a * b * c
91 |
92 |
93 | def is_clockwise(p):
94 | x = p[:, 0]
95 | y = p[:, 1]
96 | return np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1)) > 0
97 |
98 |
99 | def box3d_iou(corners1, corners2):
100 | ''' Compute 3D bounding box IoU.
101 |
102 | Input:
103 | corners1: numpy array (8,3), assume up direction is negative Y
104 | corners2: numpy array (8,3), assume up direction is negative Y
105 | Output:
106 | iou: 3D bounding box IoU
107 | iou_2d: bird's eye view 2D bounding box IoU
108 |
109 | todo (rqi): add more description on corner points' orders.
110 | '''
111 | # corner points are in counter clockwise order
112 | rect1 = [(corners1[i, 0], corners1[i, 2]) for i in range(3, -1, -1)]
113 | rect2 = [(corners2[i, 0], corners2[i, 2]) for i in range(3, -1, -1)]
114 | area1 = poly_area(np.array(rect1)[:, 0], np.array(rect1)[:, 1])
115 | area2 = poly_area(np.array(rect2)[:, 0], np.array(rect2)[:, 1])
116 | inter, inter_area = convex_hull_intersection(rect1, rect2)
117 | iou_2d = inter_area / (area1 + area2 - inter_area)
118 | ymax = min(corners1[0, 1], corners2[0, 1])
119 | ymin = max(corners1[4, 1], corners2[4, 1])
120 | inter_vol = inter_area * max(0.0, ymax - ymin)
121 | vol1 = box3d_vol(corners1)
122 | vol2 = box3d_vol(corners2)
123 | iou = inter_vol / (vol1 + vol2 - inter_vol)
124 | return iou, iou_2d
125 |
126 |
127 | def get_iou(bb1, bb2):
128 | """
129 | Calculate the Intersection over Union (IoU) of two 2D bounding boxes.
130 |
131 | Parameters
132 | ----------
133 | bb1 : dict
134 | Keys: {'x1', 'x2', 'y1', 'y2'}
135 | The (x1, y1) position is at the top left corner,
136 | the (x2, y2) position is at the bottom right corner
137 | bb2 : dict
138 | Keys: {'x1', 'x2', 'y1', 'y2'}
139 |         The (x1, y1) position is at the top left corner,
140 | the (x2, y2) position is at the bottom right corner
141 |
142 | Returns
143 | -------
144 | float
145 | in [0, 1]
146 | """
147 | assert bb1['x1'] < bb1['x2']
148 | assert bb1['y1'] < bb1['y2']
149 | assert bb2['x1'] < bb2['x2']
150 | assert bb2['y1'] < bb2['y2']
151 |
152 | # determine the coordinates of the intersection rectangle
153 | x_left = max(bb1['x1'], bb2['x1'])
154 | y_top = max(bb1['y1'], bb2['y1'])
155 | x_right = min(bb1['x2'], bb2['x2'])
156 | y_bottom = min(bb1['y2'], bb2['y2'])
157 |
158 | if x_right < x_left or y_bottom < y_top:
159 | return 0.0
160 |
161 | # The intersection of two axis-aligned bounding boxes is always an
162 | # axis-aligned bounding box
163 | intersection_area = (x_right - x_left) * (y_bottom - y_top)
164 |
165 | # compute the area of both AABBs
166 | bb1_area = (bb1['x2'] - bb1['x1']) * (bb1['y2'] - bb1['y1'])
167 | bb2_area = (bb2['x2'] - bb2['x1']) * (bb2['y2'] - bb2['y1'])
168 |
169 | # compute the intersection over union by taking the intersection
170 | # area and dividing it by the sum of prediction + ground-truth
171 |     # areas - the intersection area
172 | iou = intersection_area / float(bb1_area + bb2_area - intersection_area)
173 | assert iou >= 0.0
174 | assert iou <= 1.0
175 | return iou
176 |
177 |
178 | def box2d_iou(box1, box2):
179 | ''' Compute 2D bounding box IoU.
180 |
181 | Input:
182 | box1: tuple of (xmin,ymin,xmax,ymax)
183 | box2: tuple of (xmin,ymin,xmax,ymax)
184 | Output:
185 | iou: 2D IoU scalar
186 | '''
187 | return get_iou({'x1': box1[0], 'y1': box1[1], 'x2': box1[2], 'y2': box1[3]}, \
188 | {'x1': box2[0], 'y1': box2[1], 'x2': box2[2], 'y2': box2[3]})
189 |
190 |
191 | # -----------------------------------------------------------
192 | # Convert from box parameters to corners
193 | # -----------------------------------------------------------
194 | def roty(t):
195 | """Rotation about the y-axis."""
196 | c = np.cos(t)
197 | s = np.sin(t)
198 | return np.array([[c, 0, s],
199 | [0, 1, 0],
200 | [-s, 0, c]])
201 |
202 |
203 | def roty_batch(t):
204 | """Rotation about the y-axis.
205 | t: (x1,x2,...xn)
206 | return: (x1,x2,...,xn,3,3)
207 | """
208 | input_shape = t.shape
209 | output = np.zeros(tuple(list(input_shape) + [3, 3]))
210 | c = np.cos(t)
211 | s = np.sin(t)
212 | output[..., 0, 0] = c
213 | output[..., 0, 2] = s
214 | output[..., 1, 1] = 1
215 | output[..., 2, 0] = -s
216 | output[..., 2, 2] = c
217 | return output
218 |
219 |
220 | def get_3d_box(box_size, heading_angle, center):
221 |     ''' box_size is array(l,w,h), heading_angle is in radians (clockwise from the positive x axis), center is xyz of box center
222 |         output (8,3) array for 3D box corners
223 | Similar to utils/compute_orientation_3d
224 | '''
225 | R = roty(heading_angle)
226 | l, w, h = box_size
227 |     x_corners = [l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2]
228 |     y_corners = [h / 2, h / 2, h / 2, h / 2, -h / 2, -h / 2, -h / 2, -h / 2]
229 |     z_corners = [w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2]
230 |     corners_3d = np.dot(R, np.vstack([x_corners, y_corners, z_corners]))
231 |     corners_3d[0, :] = corners_3d[0, :] + center[0]
232 |     corners_3d[1, :] = corners_3d[1, :] + center[1]
233 |     corners_3d[2, :] = corners_3d[2, :] + center[2]
234 | corners_3d = np.transpose(corners_3d)
235 | return corners_3d
236 |
237 |
238 | def get_3d_box_batch(box_size, heading_angle, center):
239 | ''' box_size: [x1,x2,...,xn,3]
240 | heading_angle: [x1,x2,...,xn]
241 | center: [x1,x2,...,xn,3]
242 | Return:
243 |         [x1,x2,...,xn,8,3]
244 | '''
245 | input_shape = heading_angle.shape
246 | R = roty_batch(heading_angle)
247 | l = np.expand_dims(box_size[..., 0], -1) # [x1,...,xn,1]
248 | w = np.expand_dims(box_size[..., 1], -1)
249 | h = np.expand_dims(box_size[..., 2], -1)
250 | corners_3d = np.zeros(tuple(list(input_shape) + [8, 3]))
251 | corners_3d[..., :, 0] = np.concatenate((l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2), -1)
252 | corners_3d[..., :, 1] = np.concatenate((h / 2, h / 2, h / 2, h / 2, -h / 2, -h / 2, -h / 2, -h / 2), -1)
253 | corners_3d[..., :, 2] = np.concatenate((w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2), -1)
254 | tlist = [i for i in range(len(input_shape))]
255 | tlist += [len(input_shape) + 1, len(input_shape)]
256 | corners_3d = np.matmul(corners_3d, np.transpose(R, tuple(tlist)))
257 | corners_3d += np.expand_dims(center, -2)
258 | return corners_3d
259 |
260 |
261 | if __name__ == '__main__':
262 |
263 | # Function for polygon ploting
264 | import matplotlib
265 | from matplotlib.patches import Polygon
266 | from matplotlib.collections import PatchCollection
267 | import matplotlib.pyplot as plt
268 |
269 |
270 | def plot_polys(plist, scale=500.0):
271 | fig, ax = plt.subplots()
272 | patches = []
273 | for p in plist:
274 | poly = Polygon(np.array(p) / scale, True)
275 | patches.append(poly)
276 |
277 |
278 | pc = PatchCollection(patches, cmap=matplotlib.cm.jet, alpha=0.5)
279 | colors = 100 * np.random.rand(len(patches))
280 | pc.set_array(np.array(colors))
281 | ax.add_collection(pc)
282 | plt.show()
283 |
284 | # Demo on ConvexHull
285 | points = np.random.rand(30, 2) # 30 random points in 2-D
286 | hull = ConvexHull(points)
287 |     # **In 2D "volume" is area, "area" is perimeter
288 | print(('Hull area: ', hull.volume))
289 | for simplex in hull.simplices:
290 | print(simplex)
291 |
292 | # Demo on convex hull overlaps
293 | sub_poly = [(0, 0), (300, 0), (300, 300), (0, 300)]
294 | clip_poly = [(150, 150), (300, 300), (150, 450), (0, 300)]
295 | inter_poly = polygon_clip(sub_poly, clip_poly)
296 | print(poly_area(np.array(inter_poly)[:, 0], np.array(inter_poly)[:, 1]))
297 |
298 |     # Test convex hull intersection function
299 | rect1 = [(50, 0), (50, 300), (300, 300), (300, 0)]
300 | rect2 = [(150, 150), (300, 300), (150, 450), (0, 300)]
301 | plot_polys([rect1, rect2])
302 | inter, area = convex_hull_intersection(rect1, rect2)
303 | print((inter, area))
304 | if inter is not None:
305 | print(poly_area(np.array(inter)[:, 0], np.array(inter)[:, 1]))
306 |
307 | print('------------------')
308 | rect1 = [(0.30026005199835404, 8.9408694211408424), \
309 | (-1.1571105364358421, 9.4686676477075533), \
310 | (0.1777082043006144, 13.154404877812102), \
311 | (1.6350787927348105, 12.626606651245391)]
312 | rect1 = [rect1[0], rect1[3], rect1[2], rect1[1]]
313 | rect2 = [(0.23908745901608636, 8.8551095691132886), \
314 | (-1.2771419487733995, 9.4269062966181956), \
315 | (0.13138836963152717, 13.161896351296868), \
316 | (1.647617777421013, 12.590099623791961)]
317 | rect2 = [rect2[0], rect2[3], rect2[2], rect2[1]]
318 | plot_polys([rect1, rect2])
319 | inter, area = convex_hull_intersection(rect1, rect2)
320 | print((inter, area))
321 |
--------------------------------------------------------------------------------
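A small sanity-check sketch (an assumption, not repository code) for get_3d_box and box3d_iou; two unit boxes with zero heading, the second shifted 0.5 m along x, should give an IoU of roughly 1/3 both in 3D and in bird's eye view:

from box_util import get_3d_box, box3d_iou

corners1 = get_3d_box((1.0, 1.0, 1.0), 0.0, (0.0, 0.0, 0.0))  # (l, w, h), heading, center
corners2 = get_3d_box((1.0, 1.0, 1.0), 0.0, (0.5, 0.0, 0.0))
iou3d, iou_bev = box3d_iou(corners1, corners2)
print(iou3d, iou_bev)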
/object-detection-3d/utils/eval_det.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | """ Generic Code for Object Detection Evaluation
7 |
8 | Input:
9 | For each class:
10 | For each image:
11 | Predictions: box, score
12 | Groundtruths: box
13 |
14 | Output:
15 | For each class:
16 |         precision-recall and average precision
17 |
18 | Author: Charles R. Qi
19 |
20 | Ref: https://raw.githubusercontent.com/rbgirshick/py-faster-rcnn/master/lib/datasets/voc_eval.py
21 | """
22 | import numpy as np
23 |
24 |
25 | def voc_ap(rec, prec, use_07_metric=False):
26 | """ ap = voc_ap(rec, prec, [use_07_metric])
27 | Compute VOC AP given precision and recall.
28 | If use_07_metric is true, uses the
29 | VOC 07 11 point method (default:False).
30 | """
31 | if use_07_metric:
32 | # 11 point metric
33 | ap = 0.
34 | for t in np.arange(0., 1.1, 0.1):
35 | if np.sum(rec >= t) == 0:
36 | p = 0
37 | else:
38 | p = np.max(prec[rec >= t])
39 | ap = ap + p / 11.
40 | else:
41 | # correct AP calculation
42 | # first append sentinel values at the end
43 | mrec = np.concatenate(([0.], rec, [1.]))
44 | mpre = np.concatenate(([0.], prec, [0.]))
45 |
46 | # compute the precision envelope
47 | for i in range(mpre.size - 1, 0, -1):
48 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
49 |
50 | # to calculate area under PR curve, look for points
51 | # where X axis (recall) changes value
52 | i = np.where(mrec[1:] != mrec[:-1])[0]
53 |
54 | # and sum (\Delta recall) * prec
55 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
56 | return ap
57 |
58 |
59 | import os
60 | import sys
61 |
62 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)); sys.path.append(BASE_DIR)
63 | from metric_util import calc_iou  # axis-aligned 3D box IoU
64 |
65 |
66 | def get_iou(bb1, bb2):
67 | """ Compute IoU of two bounding boxes.
68 |     ** Define your box IoU function HERE **
69 | """
70 | # pass
71 | iou3d = calc_iou(bb1, bb2)
72 | return iou3d
73 |
74 |
75 | from box_util import box3d_iou
76 |
77 |
78 | def get_iou_obb(bb1, bb2):
79 | iou3d, iou2d = box3d_iou(bb1, bb2)
80 | return iou3d
81 |
82 |
83 | def get_iou_main(get_iou_func, args):
84 | return get_iou_func(*args)
85 |
86 |
87 | def eval_det_cls(pred, gt, ovthresh=0.25, use_07_metric=False, get_iou_func=get_iou):
88 |     """ Generic function to compute precision/recall for object detection
89 | for a single class.
90 | Input:
91 | pred: map of {img_id: [(bbox, score)]} where bbox is numpy array
92 | gt: map of {img_id: [bbox]}
93 | ovthresh: scalar, iou threshold
94 | use_07_metric: bool, if True use VOC07 11 point method
95 | Output:
96 | rec: numpy array of length nd
97 | prec: numpy array of length nd
98 | ap: scalar, average precision
99 | """
100 |
101 | # construct gt objects
102 | class_recs = {} # {img_id: {'bbox': bbox list, 'det': matched list}}
103 | npos = 0
104 | for img_id in gt.keys():
105 | bbox = np.array(gt[img_id])
106 | det = [False] * len(bbox)
107 | npos += len(bbox)
108 | class_recs[img_id] = {'bbox': bbox, 'det': det}
109 | # pad empty list to all other imgids
110 | for img_id in pred.keys():
111 | if img_id not in gt:
112 | class_recs[img_id] = {'bbox': np.array([]), 'det': []}
113 |
114 | # construct dets
115 | image_ids = []
116 | confidence = []
117 | BB = []
118 | for img_id in pred.keys():
119 | for box, score in pred[img_id]:
120 | image_ids.append(img_id)
121 | confidence.append(score)
122 | BB.append(box)
123 | confidence = np.array(confidence)
124 | BB = np.array(BB) # (nd,4 or 8,3 or 6)
125 |
126 | # sort by confidence
127 | sorted_ind = np.argsort(-confidence)
128 | sorted_scores = np.sort(-confidence)
129 | BB = BB[sorted_ind, ...]
130 | image_ids = [image_ids[x] for x in sorted_ind]
131 |
132 | # go down dets and mark TPs and FPs
133 | nd = len(image_ids)
134 | tp = np.zeros(nd)
135 | fp = np.zeros(nd)
136 | for d in range(nd):
137 | # if d%100==0: print(d)
138 | R = class_recs[image_ids[d]]
139 | bb = BB[d, ...].astype(float)
140 | ovmax = -np.inf
141 | BBGT = R['bbox'].astype(float)
142 |
143 | if BBGT.size > 0:
144 | # compute overlaps
145 | for j in range(BBGT.shape[0]):
146 | iou = get_iou_main(get_iou_func, (bb, BBGT[j, ...]))
147 | if iou > ovmax:
148 | ovmax = iou
149 | jmax = j
150 |
151 | # print d, ovmax
152 | if ovmax > ovthresh:
153 | if not R['det'][jmax]:
154 | tp[d] = 1.
155 | R['det'][jmax] = 1
156 | else:
157 | fp[d] = 1.
158 | else:
159 | fp[d] = 1.
160 |
161 | # compute precision recall
162 | fp = np.cumsum(fp)
163 | tp = np.cumsum(tp)
164 | rec = tp / float(npos)
165 | # print('NPOS: ', npos)
166 | # avoid divide by zero in case the first detection matches a difficult
167 | # ground truth
168 | prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
169 | ap = voc_ap(rec, prec, use_07_metric)
170 |
171 | return rec, prec, ap
172 |
173 |
174 | def eval_det_cls_wrapper(arguments):
175 | pred, gt, ovthresh, use_07_metric, get_iou_func = arguments
176 | rec, prec, ap = eval_det_cls(pred, gt, ovthresh, use_07_metric, get_iou_func)
177 | return (rec, prec, ap)
178 |
179 |
180 | def eval_det(pred_all, gt_all, ovthresh=0.25, use_07_metric=False, get_iou_func=get_iou):
181 |     """ Generic function to compute precision/recall for object detection
182 | for multiple classes.
183 | Input:
184 | pred_all: map of {img_id: [(classname, bbox, score)]}
185 | gt_all: map of {img_id: [(classname, bbox)]}
186 | ovthresh: scalar, iou threshold
187 | use_07_metric: bool, if true use VOC07 11 point method
188 | Output:
189 | rec: {classname: rec}
190 | prec: {classname: prec_all}
191 | ap: {classname: scalar}
192 | """
193 | pred = {} # map {classname: pred}
194 | gt = {} # map {classname: gt}
195 | for img_id in pred_all.keys():
196 | for classname, bbox, score in pred_all[img_id]:
197 | if classname not in pred: pred[classname] = {}
198 | if img_id not in pred[classname]:
199 | pred[classname][img_id] = []
200 | if classname not in gt: gt[classname] = {}
201 | if img_id not in gt[classname]:
202 | gt[classname][img_id] = []
203 | pred[classname][img_id].append((bbox, score))
204 | for img_id in gt_all.keys():
205 | for classname, bbox in gt_all[img_id]:
206 | if classname not in gt: gt[classname] = {}
207 | if img_id not in gt[classname]:
208 | gt[classname][img_id] = []
209 | gt[classname][img_id].append(bbox)
210 |
211 | rec = {}
212 | prec = {}
213 | ap = {}
214 | for classname in gt.keys():
215 | # print('Computing AP for class: ', classname)
216 | rec[classname], prec[classname], ap[classname] = eval_det_cls(pred[classname], gt[classname], ovthresh,
217 | use_07_metric, get_iou_func)
218 | # print(classname, ap[classname])
219 |
220 | return rec, prec, ap
221 |
222 |
223 | from multiprocessing import Pool
224 |
225 |
226 | def eval_det_multiprocessing(pred_all, gt_all, ovthresh=0.25, use_07_metric=False, get_iou_func=get_iou):
227 |     """ Generic function to compute precision/recall for object detection
228 | for multiple classes.
229 | Input:
230 | pred_all: map of {img_id: [(classname, bbox, score)]}
231 | gt_all: map of {img_id: [(classname, bbox)]}
232 | ovthresh: scalar, iou threshold
233 | use_07_metric: bool, if true use VOC07 11 point method
234 | Output:
235 | rec: {classname: rec}
236 | prec: {classname: prec_all}
237 | ap: {classname: scalar}
238 | """
239 | pred = {} # map {classname: pred}
240 | gt = {} # map {classname: gt}
241 | for img_id in pred_all.keys():
242 | for classname, bbox, score in pred_all[img_id]:
243 | if classname not in pred: pred[classname] = {}
244 | if img_id not in pred[classname]:
245 | pred[classname][img_id] = []
246 | if classname not in gt: gt[classname] = {}
247 | if img_id not in gt[classname]:
248 | gt[classname][img_id] = []
249 | pred[classname][img_id].append((bbox, score))
250 | for img_id in gt_all.keys():
251 | for classname, bbox in gt_all[img_id]:
252 | if classname not in gt: gt[classname] = {}
253 | if img_id not in gt[classname]:
254 | gt[classname][img_id] = []
255 | gt[classname][img_id].append(bbox)
256 | # print(pred)
257 | rec = {}
258 | prec = {}
259 | ap = {}
260 | p = Pool(processes=10)
261 | ret_values = p.map(eval_det_cls_wrapper,
262 | [(pred[classname], gt[classname], ovthresh, use_07_metric, get_iou_func) for classname in
263 | gt.keys() if classname in pred])
264 | p.close()
265 | for i, classname in enumerate(gt.keys()):
266 | if classname in pred:
267 | rec[classname], prec[classname], ap[classname] = ret_values[i]
268 | else:
269 | rec[classname] = 0
270 | prec[classname] = 0
271 | ap[classname] = 0
272 | # print(classname, ap[classname])
273 |
274 | return rec, prec, ap
275 |
--------------------------------------------------------------------------------
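A rough illustration of the expected input format (assuming trimesh and the sibling utils modules are importable so this file loads): eval_det takes per-image prediction and ground-truth maps and returns per-class recall, precision and AP:

from eval_det import eval_det, get_iou_obb
from box_util import get_3d_box

box = get_3d_box((1.0, 1.0, 1.0), 0.0, (0.0, 0.0, 0.0))  # (8, 3) corner array
pred_all = {'scene_0': [('chair', box, 0.9)]}  # img_id -> [(classname, bbox, score)]
gt_all = {'scene_0': [('chair', box)]}         # img_id -> [(classname, bbox)]
rec, prec, ap = eval_det(pred_all, gt_all, ovthresh=0.25, get_iou_func=get_iou_obb)
print(ap['chair'])  # 1.0 for this perfect match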
/object-detection-3d/utils/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import functools
3 | import logging
4 | import os
5 | import sys
6 | from termcolor import colored
7 |
8 |
9 | class _ColorfulFormatter(logging.Formatter):
10 | def __init__(self, *args, **kwargs):
11 | self._root_name = kwargs.pop("root_name") + "."
12 | self._abbrev_name = kwargs.pop("abbrev_name", "")
13 | if len(self._abbrev_name):
14 | self._abbrev_name = self._abbrev_name + "."
15 | super(_ColorfulFormatter, self).__init__(*args, **kwargs)
16 |
17 | def formatMessage(self, record):
18 | record.name = record.name.replace(self._root_name, self._abbrev_name)
19 | log = super(_ColorfulFormatter, self).formatMessage(record)
20 | if record.levelno == logging.WARNING:
21 | prefix = colored("WARNING", "red", attrs=["blink"])
22 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
23 | prefix = colored("ERROR", "red", attrs=["blink", "underline"])
24 | else:
25 | return log
26 | return prefix + " " + log
27 |
28 |
29 | # so that calling setup_logger multiple times won't add many handlers
30 | @functools.lru_cache()
31 | def setup_logger(
32 | output=None, distributed_rank=0, *, color=True, name="log", abbrev_name=None
33 | ):
34 | """
35 | Initialize the detectron2 logger and set its verbosity level to "INFO".
36 |
37 | Args:
38 | output (str): a file name or a directory to save log. If None, will not save log file.
39 | If ends with ".txt" or ".log", assumed to be a file name.
40 | Otherwise, logs will be saved to `output/log.txt`.
41 | name (str): the root module name of this logger
42 |
43 | Returns:
44 | logging.Logger: a logger
45 | """
46 | logger = logging.getLogger(name)
47 | logger.setLevel(logging.DEBUG)
48 | logger.propagate = False
49 |
50 | if abbrev_name is None:
51 | abbrev_name = name
52 |
53 | plain_formatter = logging.Formatter(
54 | "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
55 | )
56 | # stdout logging: master only
57 | if distributed_rank == 0:
58 | ch = logging.StreamHandler(stream=sys.stdout)
59 | ch.setLevel(logging.DEBUG)
60 | if color:
61 | formatter = _ColorfulFormatter(
62 | colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
63 | datefmt="%m/%d %H:%M:%S",
64 | root_name=name,
65 | abbrev_name=str(abbrev_name),
66 | )
67 | else:
68 | formatter = plain_formatter
69 | ch.setFormatter(formatter)
70 | logger.addHandler(ch)
71 |
72 | # file logging: all workers
73 | if output is not None:
74 | if output.endswith(".txt") or output.endswith(".log"):
75 | filename = output
76 | else:
77 | filename = os.path.join(output, "log.txt")
78 | if distributed_rank > 0:
79 | filename = filename + f".rank{distributed_rank}"
80 | os.makedirs(os.path.dirname(filename), exist_ok=True)
81 |
82 | fh = logging.StreamHandler(_cached_log_stream(filename))
83 | fh.setLevel(logging.DEBUG)
84 | fh.setFormatter(plain_formatter)
85 | logger.addHandler(fh)
86 |
87 | return logger
88 |
89 |
90 | # cache the opened file object, so that different calls to `setup_logger`
91 | # with the same file name can safely write to the same file.
92 | @functools.lru_cache(maxsize=None)
93 | def _cached_log_stream(filename):
94 | return open(filename, "a")
95 |
--------------------------------------------------------------------------------
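A brief usage sketch (assumed, not from the repository): rank 0 logs to stdout with colored warnings, and each rank appends to a file under the given directory:

from logger import setup_logger

logger = setup_logger(output='./log_dir', distributed_rank=0, name='detector')
logger.info('training started')
logger.warning('warnings get a colored prefix on the console')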
/object-detection-3d/utils/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # noinspection PyProtectedMember
2 | from torch.optim.lr_scheduler import _LRScheduler, MultiStepLR, CosineAnnealingLR
3 |
4 |
5 | # noinspection PyAttributeOutsideInit
6 | class GradualWarmupScheduler(_LRScheduler):
7 |     """ Gradually warm-up (increasing) learning rate in optimizer.
8 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
9 | Args:
10 | optimizer (Optimizer): Wrapped optimizer.
11 | multiplier: init learning rate = base lr / multiplier
12 | warmup_epoch: target learning rate is reached at warmup_epoch, gradually
13 |         after_scheduler: after warmup_epoch, use this scheduler (e.g. ReduceLROnPlateau)
14 | """
15 |
16 | def __init__(self, optimizer, multiplier, warmup_epoch, after_scheduler, last_epoch=-1):
17 | self.multiplier = multiplier
18 | if self.multiplier <= 1.:
19 | raise ValueError('multiplier should be greater than 1.')
20 | self.warmup_epoch = warmup_epoch
21 | self.after_scheduler = after_scheduler
22 | self.finished = False
23 | super().__init__(optimizer, last_epoch=last_epoch)
24 |
25 | def get_lr(self):
26 | if self.last_epoch > self.warmup_epoch:
27 | return self.after_scheduler.get_lr()
28 | else:
29 | return [base_lr / self.multiplier * ((self.multiplier - 1.) * self.last_epoch / self.warmup_epoch + 1.)
30 | for base_lr in self.base_lrs]
31 |
32 | def step(self, epoch=None):
33 | if epoch is None:
34 | epoch = self.last_epoch + 1
35 | self.last_epoch = epoch
36 | if epoch > self.warmup_epoch:
37 | self.after_scheduler.step(epoch - self.warmup_epoch)
38 | else:
39 | super(GradualWarmupScheduler, self).step(epoch)
40 |
41 | def state_dict(self):
42 | """Returns the state of the scheduler as a :class:`dict`.
43 |
44 | It contains an entry for every variable in self.__dict__ which
45 | is not the optimizer.
46 | """
47 |
48 | state = {key: value for key, value in self.__dict__.items() if key != 'optimizer' and key != 'after_scheduler'}
49 | state['after_scheduler'] = self.after_scheduler.state_dict()
50 | return state
51 |
52 | def load_state_dict(self, state_dict):
53 | """Loads the schedulers state.
54 |
55 | Arguments:
56 | state_dict (dict): scheduler state. Should be an object returned
57 | from a call to :meth:`state_dict`.
58 | """
59 |
60 | after_scheduler_state = state_dict.pop('after_scheduler')
61 | self.__dict__.update(state_dict)
62 | self.after_scheduler.load_state_dict(after_scheduler_state)
63 |
64 |
65 | def get_scheduler(optimizer, n_iter_per_epoch, args):
66 | if "cosine" in args.lr_scheduler:
67 | scheduler = CosineAnnealingLR(
68 | optimizer=optimizer,
69 | eta_min=0.000001,
70 | T_max=(args.max_epoch - args.warmup_epoch) * n_iter_per_epoch)
71 | elif "step" in args.lr_scheduler:
72 | if isinstance(args.lr_decay_epochs, int):
73 | args.lr_decay_epochs = [args.lr_decay_epochs]
74 | scheduler = MultiStepLR(
75 | optimizer=optimizer,
76 | gamma=args.lr_decay_rate,
77 | milestones=[(m - args.warmup_epoch) * n_iter_per_epoch for m in args.lr_decay_epochs])
78 | else:
79 | raise NotImplementedError(f"scheduler {args.lr_scheduler} not supported")
80 |
81 | if args.warmup_epoch > 0:
82 | scheduler = GradualWarmupScheduler(
83 | optimizer,
84 | multiplier=args.warmup_multiplier,
85 | after_scheduler=scheduler,
86 | warmup_epoch=args.warmup_epoch * n_iter_per_epoch)
87 | return scheduler
88 |
--------------------------------------------------------------------------------
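A hedged sketch of how get_scheduler could be wired up; the scheduler is stepped once per iteration, and the argument names below mirror what the code reads from args but are otherwise assumptions about the training script:

import torch
from argparse import Namespace
from lr_scheduler import get_scheduler

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
args = Namespace(lr_scheduler='cosine', max_epoch=10, warmup_epoch=1,
                 warmup_multiplier=100, lr_decay_epochs=[7], lr_decay_rate=0.1)
scheduler = get_scheduler(optimizer, n_iter_per_epoch=100, args=args)
for _ in range(300):  # 3 epochs of 100 iterations each
    optimizer.step()
    scheduler.step()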
/object-detection-3d/utils/metric_util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | """ Utility functions for metric evaluation.
7 |
8 | Author: Or Litany and Charles R. Qi
9 | """
10 |
11 | import os
12 | import sys
13 | import torch
14 |
15 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
16 | sys.path.append(BASE_DIR)
17 |
18 | import numpy as np
19 |
20 | # Mesh IO
21 | import trimesh
22 |
23 |
24 | # ----------------------------------------
25 | # Precision and Recall
26 | # ----------------------------------------
27 |
28 | def multi_scene_precision_recall(labels, pred, iou_thresh, conf_thresh, label_mask, pred_mask=None):
29 | '''
30 | Args:
31 | labels: (B, N, 6)
32 | pred: (B, M, 6)
33 | iou_thresh: scalar
34 | conf_thresh: scalar
35 | label_mask: (B, N,) with values in 0 or 1 to indicate which GT boxes to consider.
36 | pred_mask: (B, M,) with values in 0 or 1 to indicate which PRED boxes to consider.
37 | Returns:
38 | TP,FP,FN,Precision,Recall
39 | '''
40 |     # Make sure the masks are not Torch tensors, otherwise mask==1 returns a uint8 array instead
41 |     # of a True/False array as in numpy
42 | assert (not torch.is_tensor(label_mask))
43 | assert (not torch.is_tensor(pred_mask))
44 | TP, FP, FN = 0, 0, 0
45 | if label_mask is None: label_mask = np.ones((labels.shape[0], labels.shape[1]))
46 | if pred_mask is None: pred_mask = np.ones((pred.shape[0], pred.shape[1]))
47 | for batch_idx in range(labels.shape[0]):
48 | TP_i, FP_i, FN_i = single_scene_precision_recall(labels[batch_idx, label_mask[batch_idx, :] == 1, :],
49 | pred[batch_idx, pred_mask[batch_idx, :] == 1, :],
50 | iou_thresh, conf_thresh)
51 | TP += TP_i
52 | FP += FP_i
53 | FN += FN_i
54 |
55 | return TP, FP, FN, precision_recall(TP, FP, FN)
56 |
57 |
58 | def single_scene_precision_recall(labels, pred, iou_thresh, conf_thresh):
59 | """Compute P and R for predicted bounding boxes. Ignores classes!
60 | Args:
61 | labels: (N x bbox) ground-truth bounding boxes (6 dims)
62 | pred: (M x (bbox + conf)) predicted bboxes with confidence and maybe classification
63 | Returns:
64 | TP, FP, FN
65 | """
66 |
67 | # for each pred box with high conf (C), compute IoU with all gt boxes.
68 | # TP = number of times IoU > th ; FP = C - TP
69 | # FN - number of scene objects without good match
70 |
71 | gt_bboxes = labels[:, :6]
72 |
73 | num_scene_bboxes = gt_bboxes.shape[0]
74 | conf = pred[:, 6]
75 |
76 | conf_pred_bbox = pred[np.where(conf > conf_thresh)[0], :6]
77 | num_conf_pred_bboxes = conf_pred_bbox.shape[0]
78 |
79 | # init an array to keep iou between generated and scene bboxes
80 | iou_arr = np.zeros([num_conf_pred_bboxes, num_scene_bboxes])
81 | for g_idx in range(num_conf_pred_bboxes):
82 | for s_idx in range(num_scene_bboxes):
83 | iou_arr[g_idx, s_idx] = calc_iou(conf_pred_bbox[g_idx, :], gt_bboxes[s_idx, :])
84 |
85 | good_match_arr = (iou_arr >= iou_thresh)
86 |
87 | TP = good_match_arr.any(axis=1).sum()
88 | FP = num_conf_pred_bboxes - TP
89 | FN = num_scene_bboxes - good_match_arr.any(axis=0).sum()
90 |
91 | return TP, FP, FN
92 |
93 |
94 | def precision_recall(TP, FP, FN):
95 | Prec = 1.0 * TP / (TP + FP) if TP + FP > 0 else 0
96 | Rec = 1.0 * TP / (TP + FN)
97 | return Prec, Rec
98 |
99 |
100 | def calc_iou(box_a, box_b):
101 | """Computes IoU of two axis aligned bboxes.
102 | Args:
103 | box_a, box_b: 6D of center and lengths
104 | Returns:
105 | iou
106 | """
107 |
108 | max_a = box_a[0:3] + box_a[3:6] / 2
109 | max_b = box_b[0:3] + box_b[3:6] / 2
110 | min_max = np.array([max_a, max_b]).min(0)
111 |
112 | min_a = box_a[0:3] - box_a[3:6] / 2
113 | min_b = box_b[0:3] - box_b[3:6] / 2
114 | max_min = np.array([min_a, min_b]).max(0)
115 | if not ((min_max > max_min).all()):
116 | return 0.0
117 |
118 | intersection = (min_max - max_min).prod()
119 | vol_a = box_a[3:6].prod()
120 | vol_b = box_b[3:6].prod()
121 | union = vol_a + vol_b - intersection
122 | return 1.0 * intersection / union
123 |
124 |
125 | if __name__ == '__main__':
126 | print('running some tests')
127 |
128 | ############
129 | ## Test IoU
130 | ############
131 | box_a = np.array([0, 0, 0, 1, 1, 1])
132 | box_b = np.array([0, 0, 0, 2, 2, 2])
133 | expected_iou = 1.0 / 8
134 | pred_iou = calc_iou(box_a, box_b)
135 | assert expected_iou == pred_iou, 'function returned wrong IoU'
136 |
137 | box_a = np.array([0, 0, 0, 1, 1, 1])
138 | box_b = np.array([10, 10, 10, 2, 2, 2])
139 | expected_iou = 0.0
140 | pred_iou = calc_iou(box_a, box_b)
141 | assert expected_iou == pred_iou, 'function returned wrong IoU'
142 |
143 | print('IoU test -- PASSED')
144 |
145 | #########################
146 |     ## Test Precision Recall
147 | #########################
148 | gt_boxes = np.array([[0, 0, 0, 1, 1, 1], [3, 0, 1, 1, 10, 1]])
149 | detected_boxes = np.array([[0, 0, 0, 1, 1, 1, 1.0], [3, 0, 1, 1, 10, 1, 0.9]])
150 | TP, FP, FN = single_scene_precision_recall(gt_boxes, detected_boxes, 0.5, 0.5)
151 | assert TP == 2 and FP == 0 and FN == 0
152 | assert precision_recall(TP, FP, FN) == (1, 1)
153 |
154 | detected_boxes = np.array([[0, 0, 0, 1, 1, 1, 1.0]])
155 | TP, FP, FN = single_scene_precision_recall(gt_boxes, detected_boxes, 0.5, 0.5)
156 | assert TP == 1 and FP == 0 and FN == 1
157 | assert precision_recall(TP, FP, FN) == (1, 0.5)
158 |
159 | detected_boxes = np.array([[0, 0, 0, 1, 1, 1, 1.0], [-1, -1, 0, 0.1, 0.1, 1, 1.0]])
160 | TP, FP, FN = single_scene_precision_recall(gt_boxes, detected_boxes, 0.5, 0.5)
161 | assert TP == 1 and FP == 1 and FN == 1
162 | assert precision_recall(TP, FP, FN) == (0.5, 0.5)
163 |
164 | # wrong box has low confidence
165 | detected_boxes = np.array([[0, 0, 0, 1, 1, 1, 1.0], [-1, -1, 0, 0.1, 0.1, 1, 0.1]])
166 | TP, FP, FN = single_scene_precision_recall(gt_boxes, detected_boxes, 0.5, 0.5)
167 | assert TP == 1 and FP == 0 and FN == 1
168 | assert precision_recall(TP, FP, FN) == (1, 0.5)
169 |
170 |     print('Precision Recall test -- PASSED')
171 |
--------------------------------------------------------------------------------
/object-detection-3d/utils/nms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | import numpy as np
7 | from pc_util import bbox_corner_dist_measure
8 |
9 | # boxes are axis-aligned 2D boxes of shape (n,5) in FLOAT numbers with (x1,y1,x2,y2,score)
10 | ''' Ref: https://www.pyimagesearch.com/2015/02/16/faster-non-maximum-suppression-python/
11 | Ref: https://github.com/vickyboy47/nms-python/blob/master/nms.py
12 | '''
13 |
14 |
15 | def nms_2d(boxes, overlap_threshold):
16 | x1 = boxes[:, 0]
17 | y1 = boxes[:, 1]
18 | x2 = boxes[:, 2]
19 | y2 = boxes[:, 3]
20 | score = boxes[:, 4]
21 | area = (x2 - x1) * (y2 - y1)
22 |
23 | I = np.argsort(score)
24 | pick = []
25 | while (I.size != 0):
26 | last = I.size
27 | i = I[-1]
28 | pick.append(i)
29 | suppress = [last - 1]
30 | for pos in range(last - 1):
31 | j = I[pos]
32 | xx1 = max(x1[i], x1[j])
33 | yy1 = max(y1[i], y1[j])
34 | xx2 = min(x2[i], x2[j])
35 | yy2 = min(y2[i], y2[j])
36 | w = xx2 - xx1
37 | h = yy2 - yy1
38 | if (w > 0 and h > 0):
39 | o = w * h / area[j]
40 | print('Overlap is', o)
41 | if (o > overlap_threshold):
42 | suppress.append(pos)
43 | I = np.delete(I, suppress)
44 | return pick
45 |
46 |
47 | def nms_2d_faster(boxes, overlap_threshold, old_type=False):
48 | x1 = boxes[:, 0]
49 | y1 = boxes[:, 1]
50 | x2 = boxes[:, 2]
51 | y2 = boxes[:, 3]
52 | score = boxes[:, 4]
53 | area = (x2 - x1) * (y2 - y1)
54 |
55 | I = np.argsort(score)
56 | pick = []
57 | while (I.size != 0):
58 | last = I.size
59 | i = I[-1]
60 | pick.append(i)
61 |
62 | xx1 = np.maximum(x1[i], x1[I[:last - 1]])
63 | yy1 = np.maximum(y1[i], y1[I[:last - 1]])
64 | xx2 = np.minimum(x2[i], x2[I[:last - 1]])
65 | yy2 = np.minimum(y2[i], y2[I[:last - 1]])
66 |
67 | w = np.maximum(0, xx2 - xx1)
68 | h = np.maximum(0, yy2 - yy1)
69 |
70 | if old_type:
71 | o = (w * h) / area[I[:last - 1]]
72 | else:
73 | inter = w * h
74 | o = inter / (area[i] + area[I[:last - 1]] - inter)
75 |
76 | I = np.delete(I, np.concatenate(([last - 1], np.where(o > overlap_threshold)[0])))
77 |
78 | return pick
79 |
80 |
81 | def nms_3d_faster(boxes, overlap_threshold, old_type=False):
82 | x1 = boxes[:, 0]
83 | y1 = boxes[:, 1]
84 | z1 = boxes[:, 2]
85 | x2 = boxes[:, 3]
86 | y2 = boxes[:, 4]
87 | z2 = boxes[:, 5]
88 | score = boxes[:, 6]
89 | area = (x2 - x1) * (y2 - y1) * (z2 - z1)
90 |
91 | I = np.argsort(score)
92 | pick = []
93 | while (I.size != 0):
94 | last = I.size
95 | i = I[-1]
96 | pick.append(i)
97 |
98 | xx1 = np.maximum(x1[i], x1[I[:last - 1]])
99 | yy1 = np.maximum(y1[i], y1[I[:last - 1]])
100 | zz1 = np.maximum(z1[i], z1[I[:last - 1]])
101 | xx2 = np.minimum(x2[i], x2[I[:last - 1]])
102 | yy2 = np.minimum(y2[i], y2[I[:last - 1]])
103 | zz2 = np.minimum(z2[i], z2[I[:last - 1]])
104 |
105 | l = np.maximum(0, xx2 - xx1)
106 | w = np.maximum(0, yy2 - yy1)
107 | h = np.maximum(0, zz2 - zz1)
108 |
109 | if old_type:
110 | o = (l * w * h) / area[I[:last - 1]]
111 | else:
112 | inter = l * w * h
113 | o = inter / (area[i] + area[I[:last - 1]] - inter)
114 |
115 | I = np.delete(I, np.concatenate(([last - 1], np.where(o > overlap_threshold)[0])))
116 |
117 | return pick
118 |
119 |
120 | def nms_3d_faster_samecls(boxes, overlap_threshold, old_type=False):
121 | x1 = boxes[:, 0]
122 | y1 = boxes[:, 1]
123 | z1 = boxes[:, 2]
124 | x2 = boxes[:, 3]
125 | y2 = boxes[:, 4]
126 | z2 = boxes[:, 5]
127 | score = boxes[:, 6]
128 | cls = boxes[:, 7]
129 | area = (x2 - x1) * (y2 - y1) * (z2 - z1)
130 |
131 | I = np.argsort(score)
132 | pick = []
133 | while (I.size != 0):
134 | last = I.size
135 | i = I[-1]
136 | pick.append(i)
137 |
138 | xx1 = np.maximum(x1[i], x1[I[:last - 1]])
139 | yy1 = np.maximum(y1[i], y1[I[:last - 1]])
140 | zz1 = np.maximum(z1[i], z1[I[:last - 1]])
141 | xx2 = np.minimum(x2[i], x2[I[:last - 1]])
142 | yy2 = np.minimum(y2[i], y2[I[:last - 1]])
143 | zz2 = np.minimum(z2[i], z2[I[:last - 1]])
144 | cls1 = cls[i]
145 | cls2 = cls[I[:last - 1]]
146 |
147 | l = np.maximum(0, xx2 - xx1)
148 | w = np.maximum(0, yy2 - yy1)
149 | h = np.maximum(0, zz2 - zz1)
150 |
151 | if old_type:
152 | o = (l * w * h) / area[I[:last - 1]]
153 | else:
154 | inter = l * w * h
155 | o = inter / (area[i] + area[I[:last - 1]] - inter)
156 | o = o * (cls1 == cls2)
157 |
158 | I = np.delete(I, np.concatenate(([last - 1], np.where(o > overlap_threshold)[0])))
159 |
160 | return pick
161 |
162 |
163 | def nms_crnr_dist(boxes, conf, overlap_threshold):
164 | I = np.argsort(conf)
165 | pick = []
166 | while (I.size != 0):
167 | last = I.size
168 | i = I[-1]
169 | pick.append(i)
170 |
171 | scores = []
172 | for ind in I[:-1]:
173 | scores.append(bbox_corner_dist_measure(boxes[i, :], boxes[ind, :]))
174 |
175 | I = np.delete(I, np.concatenate(([last - 1], np.where(np.array(scores) > overlap_threshold)[0])))
176 |
177 | return pick
178 |
179 |
180 | if __name__ == '__main__':
181 | a = np.random.random((100, 5))
182 | print(nms_2d(a, 0.9))
183 | print(nms_2d_faster(a, 0.9))
184 |
--------------------------------------------------------------------------------
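A toy check of the 3D variant (the __main__ block above only exercises the 2D paths); boxes are (x1, y1, z1, x2, y2, z2, score), and importing this module assumes pc_util and its dependencies are available:

import numpy as np
from nms import nms_3d_faster

boxes = np.array([[0.00, 0, 0, 1.00, 1, 1, 0.9],   # kept: highest score
                  [0.05, 0, 0, 1.05, 1, 1, 0.8],   # suppressed: ~0.9 IoU with the first
                  [5.00, 5, 5, 6.00, 6, 6, 0.7]])  # kept: no overlap
print(nms_3d_faster(boxes, overlap_threshold=0.25))  # -> [0, 2]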
/object-detection-3d/utils/nn_distance.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | """ Chamfer distance in Pytorch.
7 | Author: Charles R. Qi
8 | """
9 |
10 | import torch
11 | import torch.nn as nn
12 | import numpy as np
13 |
14 |
15 | def huber_loss(error, delta=1.0):
16 | """
17 | Args:
18 | error: Torch tensor (d1,d2,...,dk)
19 | Returns:
20 | loss: Torch tensor (d1,d2,...,dk)
21 |
22 | x = error = pred - gt or dist(pred,gt)
23 | 0.5 * |x|^2 if |x|<=d
24 | 0.5 * d^2 + d * (|x|-d) if |x|>d
25 | Ref: https://github.com/charlesq34/frustum-pointnets/blob/master/models/model_util.py
26 | """
27 | abs_error = torch.abs(error)
28 | # quadratic = torch.min(abs_error, torch.FloatTensor([delta]))
29 | quadratic = torch.clamp(abs_error, max=delta)
30 | linear = (abs_error - quadratic)
31 | loss = 0.5 * quadratic ** 2 + delta * linear
32 | return loss
33 |
34 |
35 | def nn_distance(pc1, pc2, l1smooth=False, delta=1.0, l1=False):
36 | """
37 | Input:
38 | pc1: (B,N,C) torch tensor
39 | pc2: (B,M,C) torch tensor
40 | l1smooth: bool, whether to use l1smooth loss
41 | delta: scalar, the delta used in l1smooth loss
42 | Output:
43 | dist1: (B,N) torch float32 tensor
44 | idx1: (B,N) torch int64 tensor
45 | dist2: (B,M) torch float32 tensor
46 | idx2: (B,M) torch int64 tensor
47 | """
48 | N = pc1.shape[1]
49 | M = pc2.shape[1]
50 | pc1_expand_tile = pc1.unsqueeze(2).repeat(1, 1, M, 1)
51 | pc2_expand_tile = pc2.unsqueeze(1).repeat(1, N, 1, 1)
52 | pc_diff = pc1_expand_tile - pc2_expand_tile
53 |
54 | if l1smooth:
55 | pc_dist = torch.sum(huber_loss(pc_diff, delta), dim=-1) # (B,N,M)
56 | elif l1:
57 | pc_dist = torch.sum(torch.abs(pc_diff), dim=-1) # (B,N,M)
58 | else:
59 | pc_dist = torch.sum(pc_diff ** 2, dim=-1) # (B,N,M)
60 | dist1, idx1 = torch.min(pc_dist, dim=2) # (B,N)
61 | dist2, idx2 = torch.min(pc_dist, dim=1) # (B,M)
62 | return dist1, idx1, dist2, idx2
63 |
64 |
65 | def demo_nn_distance():
66 | np.random.seed(0)
67 | pc1arr = np.random.random((1, 5, 3))
68 | pc2arr = np.random.random((1, 6, 3))
69 | pc1 = torch.from_numpy(pc1arr.astype(np.float32))
70 | pc2 = torch.from_numpy(pc2arr.astype(np.float32))
71 | dist1, idx1, dist2, idx2 = nn_distance(pc1, pc2)
72 | print(dist1)
73 | print(idx1)
74 | dist = np.zeros((5, 6))
75 | for i in range(5):
76 | for j in range(6):
77 | dist[i, j] = np.sum((pc1arr[0, i, :] - pc2arr[0, j, :]) ** 2)
78 | print(dist)
79 | print('-' * 30)
80 | print('L1smooth dists:')
81 | dist1, idx1, dist2, idx2 = nn_distance(pc1, pc2, True)
82 | print(dist1)
83 | print(idx1)
84 | dist = np.zeros((5, 6))
85 | for i in range(5):
86 | for j in range(6):
87 | error = np.abs(pc1arr[0, i, :] - pc2arr[0, j, :])
88 | quad = np.minimum(error, 1.0)
89 | linear = error - quad
90 | loss = 0.5 * quad ** 2 + 1.0 * linear
91 | dist[i, j] = np.sum(loss)
92 | print(dist)
93 |
94 |
95 | if __name__ == '__main__':
96 | demo_nn_distance()
97 |
--------------------------------------------------------------------------------
/object-detection-3d/utils/visual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import os
4 | import torch.nn.functional as F
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | from pathlib import Path
8 | import cv2
9 | from PIL import Image
10 | from torchvision import transforms
11 | import torchvision
12 | from torch.utils.data import DataLoader
13 | import skimage
14 | import colorsys
15 | import random
16 | from skimage import io
17 | import argparse
18 | import datasets.transforms as T
19 | import copy
20 | import glob
21 | import re
22 | torch.set_grad_enabled(False)
23 |
24 | CLASSES = ['bed', 'table', 'sofa', 'chair', 'toilet', 'desk', 'dresser', 'night_stand', 'bookshelf', 'bathtub']
25 |
26 | def make_coco_transforms(image_set, args):
27 |
28 | normalize = T.Compose([
29 | T.ToTensor(),
30 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
31 | ])
32 |
33 | scales = [800]
34 |
35 | if image_set == 'val' or image_set == 'train':
36 | return T.Compose([
37 | T.RandomResize([scales[-1]], max_size=scales[-1] * 1333 // 800),
38 | normalize,
39 | ])
40 |
41 | raise ValueError(f'unknown {image_set}')
42 |
43 | def plot_gt(im, labels, bboxes_scaled, output_dir):
44 | tl = 3
45 | tf = max(tl-1, 1)
46 | tempimg = copy.deepcopy(im)
47 | color = [255,0,0]
48 | for label, (xmin, ymin, xmax, ymax) in zip(labels.tolist(), bboxes_scaled.tolist()):
49 | c1, c2 = (int(xmin), int(ymin)), (int(xmax), int(ymax))
50 | cv2.rectangle(tempimg, c1, c2, color, tl, cv2.LINE_AA)
51 | text = f'{CLASSES[label]}'
52 | t_size = cv2.getTextSize(text, 0, fontScale=tl / 3, thickness=tf)[0]
53 | c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
54 | cv2.rectangle(tempimg, c1, c2, color, -1, cv2.LINE_AA) # filled
55 | cv2.putText(tempimg, text, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
56 | fname = os.path.join(output_dir,'gt_img.png')
57 | cv2.imwrite(fname, tempimg)
58 | print(f"{fname} saved.")
59 |
60 | def box_cxcywh_to_xyxy(x):
61 | x_c, y_c, w, h = x.unbind(1)
62 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
63 | (x_c + 0.5 * w), (y_c + 0.5 * h)]
64 | return torch.stack(b, dim=1)
65 |
66 | def rescale_bboxes(out_bbox, size):
67 | img_w, img_h = size
68 | b = box_cxcywh_to_xyxy(out_bbox)
69 | b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
70 | return b
71 | def draw_bbox_in_img(fname, bbox_scaled, score, color=[0,255,0]):
72 | tl = 3
73 | tf = max(tl-1,1) # font thickness
74 | # color = [0,255,0]
75 | im = cv2.imread(fname)
76 | for p, (xmin, ymin, xmax, ymax) in zip(score, bbox_scaled.tolist()):
77 | c1, c2 = (int(xmin), int(ymin)), (int(xmax), int(ymax))
78 | cv2.rectangle(im, c1, c2, color, tl, cv2.LINE_AA)
79 | cl = p.argmax()
80 | text = f'{CLASSES[cl]}: {p[cl]:0.2f}'
81 | t_size = cv2.getTextSize(text, 0, fontScale=tl / 3, thickness=tf)[0]
82 | c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
83 | cv2.rectangle(im, c1, c2, color, -1, cv2.LINE_AA) # filled
84 | cv2.putText(im, text, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
85 | cv2.imwrite(fname, im)
86 |
87 | def plot_results(cv2_img, prob, boxes, output_dir):
88 | tl = 3 # thickness line
89 | tf = max(tl-1,1) # font thickness
90 | tempimg = copy.deepcopy(cv2_img)
91 | color = [0,0,255]
92 | for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes.tolist()):
93 | c1, c2 = (int(xmin), int(ymin)), (int(xmax), int(ymax))
94 | cv2.rectangle(tempimg, c1, c2, color, tl, cv2.LINE_AA)
95 | cl = p.argmax()
96 | text = f'{CLASSES[cl]}: {p[cl]:0.2f}'
97 | t_size = cv2.getTextSize(text, 0, fontScale=tl / 3, thickness=tf)[0]
98 | c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
99 | cv2.rectangle(tempimg, c1, c2, color, -1, cv2.LINE_AA) # filled
100 | cv2.putText(tempimg, text, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
101 | fname = os.path.join(output_dir,'pred_img.png')
102 | cv2.imwrite(fname, tempimg)
103 | print(f"{fname} saved.")
104 |
105 | def increment_path(path, exist_ok=False, sep='', mkdir=False):
106 | # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
107 | path = Path(path) # os-agnostic
108 | if path.exists() and not exist_ok:
109 | suffix = path.suffix
110 | path = path.with_suffix('')
111 | dirs = glob.glob(f"{path}{sep}*") # similar paths
112 | matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
113 | i = [int(m.groups()[0]) for m in matches if m] # indices
114 | n = max(i) + 1 if i else 2 # increment number
115 | path = Path(f"{path}{sep}{n}{suffix}") # update path
116 | dir = path if path.suffix == '' else path.parent # directory
117 | if not dir.exists() and mkdir:
118 | dir.mkdir(parents=True, exist_ok=True) # make directory
119 | return path
120 |
121 | def save_pred_fig(output_dir, output_dic, keep):
122 | # im = Image.open(os.path.join(output_dir, "img.png"))
123 | im = cv2.imread(os.path.join(output_dir, "img.png"))
124 | h, w = im.shape[:2]
125 | bboxes_scaled = rescale_bboxes(output_dic['pred_boxes2d'][0, keep].cpu(), (w,h))
126 | prob = output_dic['pred_logits2d'].softmax(-1)[0, :, :-1]
127 | scores = prob[keep]
128 | plot_results(im, scores, bboxes_scaled, output_dir)
129 |
130 | def save_gt_fig(output_dir, gt_anno):
131 | im = cv2.imread(os.path.join(output_dir, "img.png"))
132 | h, w = im.shape[:2]
133 | bboxes_scaled = rescale_bboxes(gt_anno['boxes'], (w,h))
134 | labels = gt_anno['labels']
135 | plot_gt(im, labels, bboxes_scaled, output_dir)
136 |
137 | def get_one_query_meanattn(vis_attn,h_featmap,w_featmap):
138 | mean_attentions = vis_attn.mean(0).reshape(h_featmap, w_featmap)
139 | mean_attentions = nn.functional.interpolate(mean_attentions.unsqueeze(0).unsqueeze(0), scale_factor=16, mode="nearest")[0].cpu().numpy()
140 | return mean_attentions
141 |
142 | def get_one_query_attn(vis_attn, h_featmap, w_featmap, nh):
143 | attentions = vis_attn.reshape(nh, h_featmap, w_featmap)
144 | # attentions = vis_attn.sum(0).reshape(h_featmap, w_featmap)
145 | attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=16, mode="nearest")[0].cpu().numpy()
146 | return attentions
147 |
--------------------------------------------------------------------------------
/semantic_segmentation/config.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | # DATASET PARAMETERS
4 | DATASET = 'nyudv2'
5 | TRAIN_DIR = './data/nyudv2' # 'Modify data path'
6 | VAL_DIR = TRAIN_DIR
7 | TRAIN_LIST = './data/nyudv2/train.txt'
8 | VAL_LIST = './data/nyudv2/val.txt'
9 |
10 |
11 | SHORTER_SIDE = 350
12 | CROP_SIZE = 500
13 | RESIZE_SIZE = None
14 |
15 | NORMALISE_PARAMS = [1./255, # Image SCALE
16 | np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3)), # Image MEAN
17 | np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3)), # Image STD
18 | 1./5000] # Depth SCALE
19 | BATCH_SIZE = 6
20 | NUM_WORKERS = 16
21 | NUM_CLASSES = 40
22 | LOW_SCALE = 0.5
23 | HIGH_SCALE = 2.0
24 | IGNORE_LABEL = 255
25 |
26 | # ENCODER PARAMETERS
27 | ENC = '101' # ResNet101
28 | ENC_PRETRAINED = True # pre-trained on ImageNet or randomly initialised
29 |
30 | # GENERAL
31 | FREEZE_BN = True
32 | NUM_SEGM_EPOCHS = [100] * 3 # [150] * 3 if using ResNet152 as backbone
33 | PRINT_EVERY = 10
34 | RANDOM_SEED = 42
35 | VAL_EVERY = 5 # how often to record validation scores
36 |
37 | # OPTIMISERS' PARAMETERS
38 | LR_ENC = [5e-4, 2.5e-4, 1e-4] # TO FREEZE, PUT 0
39 | LR_DEC = [3e-3, 1.5e-3, 7e-4]
40 | MOM_ENC = 0.9 # TO FREEZE, PUT 0
41 | MOM_DEC = 0.9
42 | WD_ENC = 1e-5 # TO FREEZE, PUT 0
43 | WD_DEC = 1e-5
44 | LAMDA = 1e-4 # slightly better
45 | BN_threshold = 2e-2 # slightly better
46 | OPTIM_DEC = 'sgd'
47 |
--------------------------------------------------------------------------------
/semantic_segmentation/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .mix_transformer import *
2 | from .segformer import WeTr
--------------------------------------------------------------------------------
/semantic_segmentation/models/modules.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 |
4 | num_parallel = 2
5 |
6 |
7 | class TokenExchange(nn.Module):
8 | def __init__(self):
9 | super(TokenExchange, self).__init__()
10 |
11 | def forward(self, x, mask, mask_threshold):
12 | # x: [B, N, C], mask: [B, N, 1]
13 | x0, x1 = torch.zeros_like(x[0]), torch.zeros_like(x[1])
14 | x0[mask[0] >= mask_threshold] = x[0][mask[0] >= mask_threshold]
15 | x0[mask[0] < mask_threshold] = x[1][mask[0] < mask_threshold]
16 | x1[mask[1] >= mask_threshold] = x[1][mask[1] >= mask_threshold]
17 | x1[mask[1] < mask_threshold] = x[0][mask[1] < mask_threshold]
18 | return [x0, x1]
19 |
20 |
21 | class ModuleParallel(nn.Module):
22 | def __init__(self, module):
23 | super(ModuleParallel, self).__init__()
24 | self.module = module
25 |
26 | def forward(self, x_parallel):
27 | return [self.module(x) for x in x_parallel]
28 |
29 |
30 | class LayerNormParallel(nn.Module):
31 | def __init__(self, num_features):
32 | super(LayerNormParallel, self).__init__()
33 | for i in range(num_parallel):
34 | setattr(self, 'ln_' + str(i), nn.LayerNorm(num_features, eps=1e-6))
35 |
36 | def forward(self, x_parallel):
37 | return [getattr(self, 'ln_' + str(i))(x) for i, x in enumerate(x_parallel)]
38 |
--------------------------------------------------------------------------------
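A minimal sketch of TokenExchange on random data (shapes are illustrative assumptions; in the model the encoder supplies the token tensors and masks): tokens whose score falls below the threshold are replaced by the corresponding tokens from the other branch.

import torch
from modules import TokenExchange, num_parallel

x = [torch.randn(2, 16, 32) for _ in range(num_parallel)]  # per-branch tokens [B, N, C]
mask = [torch.rand(2, 16) for _ in range(num_parallel)]    # per-token scores [B, N]
out = TokenExchange()(x, mask, mask_threshold=0.02)
print(out[0].shape)  # torch.Size([2, 16, 32])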
/semantic_segmentation/models/segformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from . import mix_transformer
5 | from mmcv.cnn import ConvModule
6 | from .modules import num_parallel
7 |
8 |
9 | class MLP(nn.Module):
10 | """
11 | Linear Embedding
12 | """
13 | def __init__(self, input_dim=2048, embed_dim=768):
14 | super().__init__()
15 | self.proj = nn.Linear(input_dim, embed_dim)
16 |
17 | def forward(self, x):
18 | x = x.flatten(2).transpose(1, 2)
19 | x = self.proj(x)
20 | return x
21 |
22 |
23 | class SegFormerHead(nn.Module):
24 | """
25 | SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers
26 | """
27 | def __init__(self, feature_strides=None, in_channels=128, embedding_dim=256, num_classes=20, **kwargs):
28 | super(SegFormerHead, self).__init__()
29 | self.in_channels = in_channels
30 | self.num_classes = num_classes
31 | assert len(feature_strides) == len(self.in_channels)
32 | assert min(feature_strides) == feature_strides[0]
33 | self.feature_strides = feature_strides
34 |
35 | c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels
36 |
37 | #decoder_params = kwargs['decoder_params']
38 | #embedding_dim = decoder_params['embed_dim']
39 |
40 | self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim)
41 | self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim)
42 | self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim)
43 | self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim)
44 | self.dropout = nn.Dropout2d(0.1)
45 |
46 | self.linear_fuse = ConvModule(
47 | in_channels=embedding_dim*4,
48 | out_channels=embedding_dim,
49 | kernel_size=1,
50 | norm_cfg=dict(type='BN', requires_grad=True)
51 | )
52 |
53 | self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1)
54 |
55 | def forward(self, x):
56 | c1, c2, c3, c4 = x
57 |
58 | ############## MLP decoder on C1-C4 ###########
59 | n, _, h, w = c4.shape
60 |
61 | _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3])
62 | _c4 = F.interpolate(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False)
63 |
64 | _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3])
65 | _c3 = F.interpolate(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False)
66 |
67 | _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3])
68 | _c2 = F.interpolate(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False)
69 |
70 | _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3])
71 |
72 | _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))
73 |
74 | x = self.dropout(_c)
75 | x = self.linear_pred(x)
76 |
77 | return x
78 |
79 |
80 | class WeTr(nn.Module):
81 | def __init__(self, backbone, num_classes=20, embedding_dim=256, pretrained=True):
82 | super().__init__()
83 | self.num_classes = num_classes
84 | self.embedding_dim = embedding_dim
85 | self.feature_strides = [4, 8, 16, 32]
86 | self.num_parallel = num_parallel
87 | #self.in_channels = [32, 64, 160, 256]
88 | #self.in_channels = [64, 128, 320, 512]
89 |
90 | self.encoder = getattr(mix_transformer, backbone)()
91 | self.in_channels = self.encoder.embed_dims
92 | ## initialize the encoder from the pretrained SegFormer checkpoint
93 | if pretrained:
94 | state_dict = torch.load('pretrained/' + backbone + '.pth')
95 | state_dict.pop('head.weight')
96 | state_dict.pop('head.bias')
97 | state_dict = expand_state_dict(self.encoder.state_dict(), state_dict, self.num_parallel)
98 | self.encoder.load_state_dict(state_dict, strict=True)
99 |
100 | self.decoder = SegFormerHead(feature_strides=self.feature_strides, in_channels=self.in_channels,
101 | embedding_dim=self.embedding_dim, num_classes=self.num_classes)
102 |
103 | self.alpha = nn.Parameter(torch.ones(self.num_parallel, requires_grad=True))
104 | self.register_parameter('alpha', self.alpha)
105 |
106 | def get_param_groups(self):
107 | param_groups = [[], [], []]
108 | for name, param in list(self.encoder.named_parameters()):
109 | if "norm" in name:
110 | param_groups[1].append(param)
111 | else:
112 | param_groups[0].append(param)
113 | for param in list(self.decoder.parameters()):
114 | param_groups[2].append(param)
115 | return param_groups
116 |
117 | def forward(self, x):
118 | x, masks = self.encoder(x)
119 | x = [self.decoder(x[0]), self.decoder(x[1])]
120 | ens = 0
121 | alpha_soft = F.softmax(self.alpha, dim=0)
122 | for l in range(self.num_parallel):
123 | ens += alpha_soft[l] * x[l].detach()
124 | x.append(ens)
125 | return x, masks
126 |
127 |
128 | def expand_state_dict(model_dict, state_dict, num_parallel):
129 | model_dict_keys = model_dict.keys()
130 | state_dict_keys = state_dict.keys()
131 | for model_dict_key in model_dict_keys:
132 | model_dict_key_re = model_dict_key.replace('module.', '')
133 | if model_dict_key_re in state_dict_keys:
134 | model_dict[model_dict_key] = state_dict[model_dict_key_re]
135 | for i in range(num_parallel):
136 | ln = '.ln_%d' % i
137 | replace = True if ln in model_dict_key_re else False
138 | model_dict_key_re = model_dict_key_re.replace(ln, '')
139 | if replace and model_dict_key_re in state_dict_keys:
140 | model_dict[model_dict_key] = state_dict[model_dict_key_re]
141 | return model_dict
142 |
143 |
144 | if __name__ == "__main__":
145 | # import torch.distributed as dist
146 | # dist.init_process_group('gloo', init_method='file:///temp/somefile', rank=0, world_size=1)
147 | pretrained_weights = torch.load('pretrained/mit_b1.pth')
148 | wetr = WeTr('mit_b1', num_classes=20, embedding_dim=256, pretrained=True).cuda()
149 | wetr.get_param_groups()
150 | dummy_input = torch.rand(2,3,512,512).cuda()
151 | wetr(dummy_input)
152 |
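153 | # Note (added for clarity, not part of the original file): expand_state_dict adapts a
154 | # single-stream SegFormer checkpoint to the parallel two-stream encoder by copying each
155 | # pretrained tensor onto the wrapped modules, e.g.
156 | #   'block1.0.attn.q.weight' -> 'block1.0.attn.q.module.weight'
157 | #   'block1.0.norm1.weight'  -> 'block1.0.norm1.ln_0.weight' and 'block1.0.norm1.ln_1.weight'
158 | # utils/load_state.txt records the key mismatch that this remapping resolves.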
--------------------------------------------------------------------------------
/semantic_segmentation/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .helpers import *
2 | from .meter import *
3 |
--------------------------------------------------------------------------------
/semantic_segmentation/utils/cmap.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yikaiw/TokenFusion/3834ccf7765bb0bd50ea729069ad5adbd6de288d/semantic_segmentation/utils/cmap.npy
--------------------------------------------------------------------------------
/semantic_segmentation/utils/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 |
7 |
8 | def line_to_paths_fn_nyudv2(x, input_names):
9 | return x.decode('utf-8').strip('\n').split('\t')
10 |
11 | line_to_paths_fn = {'nyudv2': line_to_paths_fn_nyudv2}
12 |
13 |
14 | class SegDataset(Dataset):
15 | """Multi-Modality Segmentation dataset.
16 |
17 | Works with any datasets that contain image
18 | and any number of 2D-annotations.
19 |
20 | Args:
21 | dataset (string): dataset name used to pick the line-to-paths function.
22 | data_file (string): Path to the data file with annotations.
23 | data_dir (string): Directory with all the images.
24 | input_names (list of strings): keys for each input modality (e.g., 'rgb', 'depth').
25 | input_mask_idxs (list of ints): positions of each input path and, last,
26 | of the mask path within a line of data_file.
27 | transform_trn (callable, optional): Optional transform
28 | to be applied on a sample during the training stage.
29 | transform_val (callable, optional): Optional transform
30 | to be applied on a sample during the validation stage.
31 | stage (str): initial stage of dataset - either 'train' or 'val'.
32 | ignore_label (int, optional): mask label value to be ignored.
33 | """
34 | def __init__(self, dataset, data_file, data_dir, input_names, input_mask_idxs,
35 | transform_trn=None, transform_val=None, stage='train', ignore_label=None):
36 | with open(data_file, 'rb') as f:
37 | datalist = f.readlines()
38 | self.datalist = [line_to_paths_fn[dataset](l, input_names) for l in datalist]
39 | self.root_dir = data_dir
40 | self.transform_trn = transform_trn
41 | self.transform_val = transform_val
42 | self.stage = stage
43 | self.input_names = input_names
44 | self.input_mask_idxs = input_mask_idxs
45 | self.ignore_label = ignore_label
46 |
47 | def set_stage(self, stage):
48 | """Define which set of transformation to use.
49 |
50 | Args:
51 | stage (str): either 'train' or 'val'
52 |
53 | """
54 | self.stage = stage
55 |
56 | def __len__(self):
57 | return len(self.datalist)
58 |
59 | def __getitem__(self, idx):
60 | idxs = self.input_mask_idxs
61 | names = [os.path.join(self.root_dir, rpath) for rpath in self.datalist[idx]]
62 | sample = {}
63 | for i, key in enumerate(self.input_names):
64 | sample[key] = self.read_image(names[idxs[i]], key)
65 | try:
66 | mask = np.array(Image.open(names[idxs[-1]]))
67 | except FileNotFoundError: # for sunrgbd
68 | path = names[idxs[-1]]
69 | num_idx = int(path[-10:-4]) + 5050
70 | path = path[:-10] + '%06d' % num_idx + path[-4:]
71 | mask = np.array(Image.open(path))
72 | assert len(mask.shape) == 2, 'Masks must be encoded without colourmap'
73 | sample['inputs'] = self.input_names
74 | sample['mask'] = mask
75 | if self.stage == 'train':
76 | if self.transform_trn:
77 | sample = self.transform_trn(sample)
78 | elif self.stage == 'val':
79 | if self.transform_val:
80 | sample = self.transform_val(sample)
81 | del sample['inputs']
82 | return sample
83 |
84 | @staticmethod
85 | def read_image_(x, key):
86 | img = cv2.imread(x)
87 | if key == 'depth':
88 | img = cv2.applyColorMap(cv2.convertScaleAbs(255 - img, alpha=1), cv2.COLORMAP_JET)
89 | return img
90 |
91 | @staticmethod
92 | def read_image(x, key):
93 | """Simple image reader
94 |
95 | Args:
96 | x (str): path to image.
97 |
98 | Returns image as `np.array`.
99 |
100 | """
101 | img_arr = np.array(Image.open(x))
102 | if len(img_arr.shape) == 2: # grayscale
103 | img_arr = np.tile(img_arr, [3, 1, 1]).transpose(1, 2, 0)
104 | return img_arr
105 |
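106 | # Illustrative usage sketch (added for clarity, not part of the original file; the
107 | # paths and index list are assumptions): each line of data_file holds tab-separated
108 | # relative paths, e.g. "rgb/0001.png\tdepth/0001.png\tlabel/0001.png", so with
109 | # input_names=['rgb', 'depth'] and input_mask_idxs=[0, 1, 2] a sample contains
110 | # 'rgb', 'depth' and 'mask' entries:
111 | #
112 | #   dataset = SegDataset('nyudv2', 'data/nyudv2/train.txt', '/path/to/nyudv2',
113 | #                        input_names=['rgb', 'depth'], input_mask_idxs=[0, 1, 2])
114 | #   sample = dataset[0]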
--------------------------------------------------------------------------------
/semantic_segmentation/utils/helpers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import matplotlib as mpl
4 | import matplotlib.cm as cm
5 | import PIL.Image as pil
6 | import cv2
7 | import os
8 |
9 | IMG_SCALE = 1./255
10 | IMG_MEAN = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
11 | IMG_STD = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))
12 | logger = None
13 |
14 |
15 | def print_log(message):
16 | print(message, flush=True)
17 | if logger:
18 | logger.write(str(message) + '\n')
19 |
20 |
21 | def maybe_download(model_name, model_url, model_dir=None, map_location=None):
22 | import os, sys
23 | from six.moves import urllib
24 | if model_dir is None:
25 | torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch'))
26 | model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models'))
27 | if not os.path.exists(model_dir):
28 | os.makedirs(model_dir)
29 | filename = '{}.pth.tar'.format(model_name)
30 | cached_file = os.path.join(model_dir, filename)
31 | if not os.path.exists(cached_file):
32 | url = model_url
33 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
34 | urllib.request.urlretrieve(url, cached_file)
35 | return torch.load(cached_file, map_location=map_location)
36 |
37 |
38 | def prepare_img(img):
39 | return (img * IMG_SCALE - IMG_MEAN) / IMG_STD
40 |
41 |
42 | def make_validation_img(img_, depth_, lab, pre):
43 | cmap = np.load('./utils/cmap.npy')
44 |
45 | img = np.array([i * IMG_STD.reshape((3, 1, 1)) + IMG_MEAN.reshape((3, 1, 1)) for i in img_])
46 | img *= 255
47 | img = img.astype(np.uint8)
48 | img = np.concatenate(img, axis=1)
49 |
50 | depth_ = depth_[0].transpose(1, 2, 0) / max(depth_.max(), 10)
51 | vmax = np.percentile(depth_, 95)
52 | normalizer = mpl.colors.Normalize(vmin=depth_.min(), vmax=vmax)
53 | mapper = cm.ScalarMappable(norm=normalizer, cmap='magma')
54 | depth = (mapper.to_rgba(depth_)[:,:,:3] * 255).astype(np.uint8)
55 | lab = np.concatenate(lab)
56 | lab = np.array([cmap[i.astype(np.uint8) + 1] for i in lab])
57 |
58 | pre = np.concatenate(pre)
59 | pre = np.array([cmap[i.astype(np.uint8) + 1] for i in pre])
60 | img = img.transpose(1, 2, 0)
61 |
62 | return np.concatenate([img, depth, lab, pre], 1)
63 |
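64 | # Illustrative usage sketch (added for clarity, not part of the original file; the file
65 | # name is hypothetical): prepare_img expects an H x W x 3 array with values in [0, 255]
66 | # and returns it rescaled to [0, 1] and normalised with the ImageNet mean/std above.
67 | #
68 | #   img = np.array(pil.open('example.png'))  # H x W x 3 uint8
69 | #   normed = prepare_img(img)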
--------------------------------------------------------------------------------
/semantic_segmentation/utils/load_state.txt:
--------------------------------------------------------------------------------
1 | RuntimeError: Error(s) in loading state_dict for mit_b1:
2 | Missing key(s) in state_dict: "patch_embed1.proj.module.weight", "patch_embed1.proj.module.bias", "patch_embed1.norm.ln_0.weight", "patch_embed1.norm.ln_0.bias", "patch_embed1.norm.ln_1.weight", "patch_embed1.norm.ln_1.bias", "patch_embed2.proj.module.weight", "patch_embed2.proj.module.bias", "patch_embed2.norm.ln_0.weight", "patch_embed2.norm.ln_0.bias", "patch_embed2.norm.ln_1.weight", "patch_embed2.norm.ln_1.bias", "patch_embed3.proj.module.weight", "patch_embed3.proj.module.bias", "patch_embed3.norm.ln_0.weight", "patch_embed3.norm.ln_0.bias", "patch_embed3.norm.ln_1.weight", "patch_embed3.norm.ln_1.bias", "patch_embed4.proj.module.weight", "patch_embed4.proj.module.bias", "patch_embed4.norm.ln_0.weight", "patch_embed4.norm.ln_0.bias", "patch_embed4.norm.ln_1.weight", "patch_embed4.norm.ln_1.bias", "block1.0.norm1.ln_0.weight", "block1.0.norm1.ln_0.bias", "block1.0.norm1.ln_1.weight", "block1.0.norm1.ln_1.bias", "block1.0.attn.q.module.weight", "block1.0.attn.q.module.bias", "block1.0.attn.kv.module.weight", "block1.0.attn.kv.module.bias", "block1.0.attn.proj.module.weight", "block1.0.attn.proj.module.bias", "block1.0.attn.sr.module.weight", "block1.0.attn.sr.module.bias", "block1.0.attn.norm.ln_0.weight", "block1.0.attn.norm.ln_0.bias", "block1.0.attn.norm.ln_1.weight", "block1.0.attn.norm.ln_1.bias", "block1.0.norm2.ln_0.weight", "block1.0.norm2.ln_0.bias", "block1.0.norm2.ln_1.weight", "block1.0.norm2.ln_1.bias", "block1.0.mlp.module.fc1.weight", "block1.0.mlp.module.fc1.bias", "block1.0.mlp.module.dwconv.dwconv.weight", "block1.0.mlp.module.dwconv.dwconv.bias", "block1.0.mlp.module.fc2.weight", "block1.0.mlp.module.fc2.bias", "block1.1.norm1.ln_0.weight", "block1.1.norm1.ln_0.bias", "block1.1.norm1.ln_1.weight", "block1.1.norm1.ln_1.bias", "block1.1.attn.q.module.weight", "block1.1.attn.q.module.bias", "block1.1.attn.kv.module.weight", "block1.1.attn.kv.module.bias", "block1.1.attn.proj.module.weight", "block1.1.attn.proj.module.bias", "block1.1.attn.sr.module.weight", "block1.1.attn.sr.module.bias", "block1.1.attn.norm.ln_0.weight", "block1.1.attn.norm.ln_0.bias", "block1.1.attn.norm.ln_1.weight", "block1.1.attn.norm.ln_1.bias", "block1.1.norm2.ln_0.weight", "block1.1.norm2.ln_0.bias", "block1.1.norm2.ln_1.weight", "block1.1.norm2.ln_1.bias", "block1.1.mlp.module.fc1.weight", "block1.1.mlp.module.fc1.bias", "block1.1.mlp.module.dwconv.dwconv.weight", "block1.1.mlp.module.dwconv.dwconv.bias", "block1.1.mlp.module.fc2.weight", "block1.1.mlp.module.fc2.bias", "norm1.ln_0.weight", "norm1.ln_0.bias", "norm1.ln_1.weight", "norm1.ln_1.bias", "block2.0.norm1.ln_0.weight", "block2.0.norm1.ln_0.bias", "block2.0.norm1.ln_1.weight", "block2.0.norm1.ln_1.bias", "block2.0.attn.q.module.weight", "block2.0.attn.q.module.bias", "block2.0.attn.kv.module.weight", "block2.0.attn.kv.module.bias", "block2.0.attn.proj.module.weight", "block2.0.attn.proj.module.bias", "block2.0.attn.sr.module.weight", "block2.0.attn.sr.module.bias", "block2.0.attn.norm.ln_0.weight", "block2.0.attn.norm.ln_0.bias", "block2.0.attn.norm.ln_1.weight", "block2.0.attn.norm.ln_1.bias", "block2.0.norm2.ln_0.weight", "block2.0.norm2.ln_0.bias", "block2.0.norm2.ln_1.weight", "block2.0.norm2.ln_1.bias", "block2.0.mlp.module.fc1.weight", "block2.0.mlp.module.fc1.bias", "block2.0.mlp.module.dwconv.dwconv.weight", "block2.0.mlp.module.dwconv.dwconv.bias", "block2.0.mlp.module.fc2.weight", "block2.0.mlp.module.fc2.bias", "block2.1.norm1.ln_0.weight", "block2.1.norm1.ln_0.bias", "block2.1.norm1.ln_1.weight", 
"block2.1.norm1.ln_1.bias", "block2.1.attn.q.module.weight", "block2.1.attn.q.module.bias", "block2.1.attn.kv.module.weight", "block2.1.attn.kv.module.bias", "block2.1.attn.proj.module.weight", "block2.1.attn.proj.module.bias", "block2.1.attn.sr.module.weight", "block2.1.attn.sr.module.bias", "block2.1.attn.norm.ln_0.weight", "block2.1.attn.norm.ln_0.bias", "block2.1.attn.norm.ln_1.weight", "block2.1.attn.norm.ln_1.bias", "block2.1.norm2.ln_0.weight", "block2.1.norm2.ln_0.bias", "block2.1.norm2.ln_1.weight", "block2.1.norm2.ln_1.bias", "block2.1.mlp.module.fc1.weight", "block2.1.mlp.module.fc1.bias", "block2.1.mlp.module.dwconv.dwconv.weight", "block2.1.mlp.module.dwconv.dwconv.bias", "block2.1.mlp.module.fc2.weight", "block2.1.mlp.module.fc2.bias", "norm2.ln_0.weight", "norm2.ln_0.bias", "norm2.ln_1.weight", "norm2.ln_1.bias", "block3.0.norm1.ln_0.weight", "block3.0.norm1.ln_0.bias", "block3.0.norm1.ln_1.weight", "block3.0.norm1.ln_1.bias", "block3.0.attn.q.module.weight", "block3.0.attn.q.module.bias", "block3.0.attn.kv.module.weight", "block3.0.attn.kv.module.bias", "block3.0.attn.proj.module.weight", "block3.0.attn.proj.module.bias", "block3.0.attn.sr.module.weight", "block3.0.attn.sr.module.bias", "block3.0.attn.norm.ln_0.weight", "block3.0.attn.norm.ln_0.bias", "block3.0.attn.norm.ln_1.weight", "block3.0.attn.norm.ln_1.bias", "block3.0.norm2.ln_0.weight", "block3.0.norm2.ln_0.bias", "block3.0.norm2.ln_1.weight", "block3.0.norm2.ln_1.bias", "block3.0.mlp.module.fc1.weight", "block3.0.mlp.module.fc1.bias", "block3.0.mlp.module.dwconv.dwconv.weight", "block3.0.mlp.module.dwconv.dwconv.bias", "block3.0.mlp.module.fc2.weight", "block3.0.mlp.module.fc2.bias", "block3.1.norm1.ln_0.weight", "block3.1.norm1.ln_0.bias", "block3.1.norm1.ln_1.weight", "block3.1.norm1.ln_1.bias", "block3.1.attn.q.module.weight", "block3.1.attn.q.module.bias", "block3.1.attn.kv.module.weight", "block3.1.attn.kv.module.bias", "block3.1.attn.proj.module.weight", "block3.1.attn.proj.module.bias", "block3.1.attn.sr.module.weight", "block3.1.attn.sr.module.bias", "block3.1.attn.norm.ln_0.weight", "block3.1.attn.norm.ln_0.bias", "block3.1.attn.norm.ln_1.weight", "block3.1.attn.norm.ln_1.bias", "block3.1.norm2.ln_0.weight", "block3.1.norm2.ln_0.bias", "block3.1.norm2.ln_1.weight", "block3.1.norm2.ln_1.bias", "block3.1.mlp.module.fc1.weight", "block3.1.mlp.module.fc1.bias", "block3.1.mlp.module.dwconv.dwconv.weight", "block3.1.mlp.module.dwconv.dwconv.bias", "block3.1.mlp.module.fc2.weight", "block3.1.mlp.module.fc2.bias", "norm3.ln_0.weight", "norm3.ln_0.bias", "norm3.ln_1.weight", "norm3.ln_1.bias", "block4.0.norm1.ln_0.weight", "block4.0.norm1.ln_0.bias", "block4.0.norm1.ln_1.weight", "block4.0.norm1.ln_1.bias", "block4.0.attn.q.module.weight", "block4.0.attn.q.module.bias", "block4.0.attn.kv.module.weight", "block4.0.attn.kv.module.bias", "block4.0.attn.proj.module.weight", "block4.0.attn.proj.module.bias", "block4.0.norm2.ln_0.weight", "block4.0.norm2.ln_0.bias", "block4.0.norm2.ln_1.weight", "block4.0.norm2.ln_1.bias", "block4.0.mlp.module.fc1.weight", "block4.0.mlp.module.fc1.bias", "block4.0.mlp.module.dwconv.dwconv.weight", "block4.0.mlp.module.dwconv.dwconv.bias", "block4.0.mlp.module.fc2.weight", "block4.0.mlp.module.fc2.bias", "block4.1.norm1.ln_0.weight", "block4.1.norm1.ln_0.bias", "block4.1.norm1.ln_1.weight", "block4.1.norm1.ln_1.bias", "block4.1.attn.q.module.weight", "block4.1.attn.q.module.bias", "block4.1.attn.kv.module.weight", "block4.1.attn.kv.module.bias", "block4.1.attn.proj.module.weight", 
"block4.1.attn.proj.module.bias", "block4.1.norm2.ln_0.weight", "block4.1.norm2.ln_0.bias", "block4.1.norm2.ln_1.weight", "block4.1.norm2.ln_1.bias", "block4.1.mlp.module.fc1.weight", "block4.1.mlp.module.fc1.bias", "block4.1.mlp.module.dwconv.dwconv.weight", "block4.1.mlp.module.dwconv.dwconv.bias", "block4.1.mlp.module.fc2.weight", "block4.1.mlp.module.fc2.bias", "norm4.ln_0.weight", "norm4.ln_0.bias", "norm4.ln_1.weight", "norm4.ln_1.bias".
3 | Unexpected key(s) in state_dict: "patch_embed1.proj.weight", "patch_embed1.proj.bias", "patch_embed1.norm.weight", "patch_embed1.norm.bias", "patch_embed2.proj.weight", "patch_embed2.proj.bias", "patch_embed2.norm.weight", "patch_embed2.norm.bias", "patch_embed3.proj.weight", "patch_embed3.proj.bias", "patch_embed3.norm.weight", "patch_embed3.norm.bias", "patch_embed4.proj.weight", "patch_embed4.proj.bias", "patch_embed4.norm.weight", "patch_embed4.norm.bias", "block1.0.norm1.weight", "block1.0.norm1.bias", "block1.0.attn.q.weight", "block1.0.attn.q.bias", "block1.0.attn.kv.weight", "block1.0.attn.kv.bias", "block1.0.attn.proj.weight", "block1.0.attn.proj.bias", "block1.0.attn.sr.weight", "block1.0.attn.sr.bias", "block1.0.attn.norm.weight", "block1.0.attn.norm.bias", "block1.0.norm2.weight", "block1.0.norm2.bias", "block1.0.mlp.fc1.weight", "block1.0.mlp.fc1.bias", "block1.0.mlp.dwconv.dwconv.weight", "block1.0.mlp.dwconv.dwconv.bias", "block1.0.mlp.fc2.weight", "block1.0.mlp.fc2.bias", "block1.1.norm1.weight", "block1.1.norm1.bias", "block1.1.attn.q.weight", "block1.1.attn.q.bias", "block1.1.attn.kv.weight", "block1.1.attn.kv.bias", "block1.1.attn.proj.weight", "block1.1.attn.proj.bias", "block1.1.attn.sr.weight", "block1.1.attn.sr.bias", "block1.1.attn.norm.weight", "block1.1.attn.norm.bias", "block1.1.norm2.weight", "block1.1.norm2.bias", "block1.1.mlp.fc1.weight", "block1.1.mlp.fc1.bias", "block1.1.mlp.dwconv.dwconv.weight", "block1.1.mlp.dwconv.dwconv.bias", "block1.1.mlp.fc2.weight", "block1.1.mlp.fc2.bias", "norm1.weight", "norm1.bias", "block2.0.norm1.weight", "block2.0.norm1.bias", "block2.0.attn.q.weight", "block2.0.attn.q.bias", "block2.0.attn.kv.weight", "block2.0.attn.kv.bias", "block2.0.attn.proj.weight", "block2.0.attn.proj.bias", "block2.0.attn.sr.weight", "block2.0.attn.sr.bias", "block2.0.attn.norm.weight", "block2.0.attn.norm.bias", "block2.0.norm2.weight", "block2.0.norm2.bias", "block2.0.mlp.fc1.weight", "block2.0.mlp.fc1.bias", "block2.0.mlp.dwconv.dwconv.weight", "block2.0.mlp.dwconv.dwconv.bias", "block2.0.mlp.fc2.weight", "block2.0.mlp.fc2.bias", "block2.1.norm1.weight", "block2.1.norm1.bias", "block2.1.attn.q.weight", "block2.1.attn.q.bias", "block2.1.attn.kv.weight", "block2.1.attn.kv.bias", "block2.1.attn.proj.weight", "block2.1.attn.proj.bias", "block2.1.attn.sr.weight", "block2.1.attn.sr.bias", "block2.1.attn.norm.weight", "block2.1.attn.norm.bias", "block2.1.norm2.weight", "block2.1.norm2.bias", "block2.1.mlp.fc1.weight", "block2.1.mlp.fc1.bias", "block2.1.mlp.dwconv.dwconv.weight", "block2.1.mlp.dwconv.dwconv.bias", "block2.1.mlp.fc2.weight", "block2.1.mlp.fc2.bias", "norm2.weight", "norm2.bias", "block3.0.norm1.weight", "block3.0.norm1.bias", "block3.0.attn.q.weight", "block3.0.attn.q.bias", "block3.0.attn.kv.weight", "block3.0.attn.kv.bias", "block3.0.attn.proj.weight", "block3.0.attn.proj.bias", "block3.0.attn.sr.weight", "block3.0.attn.sr.bias", "block3.0.attn.norm.weight", "block3.0.attn.norm.bias", "block3.0.norm2.weight", "block3.0.norm2.bias", "block3.0.mlp.fc1.weight", "block3.0.mlp.fc1.bias", "block3.0.mlp.dwconv.dwconv.weight", "block3.0.mlp.dwconv.dwconv.bias", "block3.0.mlp.fc2.weight", "block3.0.mlp.fc2.bias", "block3.1.norm1.weight", "block3.1.norm1.bias", "block3.1.attn.q.weight", "block3.1.attn.q.bias", "block3.1.attn.kv.weight", "block3.1.attn.kv.bias", "block3.1.attn.proj.weight", "block3.1.attn.proj.bias", "block3.1.attn.sr.weight", "block3.1.attn.sr.bias", "block3.1.attn.norm.weight", "block3.1.attn.norm.bias", 
"block3.1.norm2.weight", "block3.1.norm2.bias", "block3.1.mlp.fc1.weight", "block3.1.mlp.fc1.bias", "block3.1.mlp.dwconv.dwconv.weight", "block3.1.mlp.dwconv.dwconv.bias", "block3.1.mlp.fc2.weight", "block3.1.mlp.fc2.bias", "norm3.weight", "norm3.bias", "block4.0.norm1.weight", "block4.0.norm1.bias", "block4.0.attn.q.weight", "block4.0.attn.q.bias", "block4.0.attn.kv.weight", "block4.0.attn.kv.bias", "block4.0.attn.proj.weight", "block4.0.attn.proj.bias", "block4.0.norm2.weight", "block4.0.norm2.bias", "block4.0.mlp.fc1.weight", "block4.0.mlp.fc1.bias", "block4.0.mlp.dwconv.dwconv.weight", "block4.0.mlp.dwconv.dwconv.bias", "block4.0.mlp.fc2.weight", "block4.0.mlp.fc2.bias", "block4.1.norm1.weight", "block4.1.norm1.bias", "block4.1.attn.q.weight", "block4.1.attn.q.bias", "block4.1.attn.kv.weight", "block4.1.attn.kv.bias", "block4.1.attn.proj.weight", "block4.1.attn.proj.bias", "block4.1.norm2.weight", "block4.1.norm2.bias", "block4.1.mlp.fc1.weight", "block4.1.mlp.fc1.bias", "block4.1.mlp.dwconv.dwconv.weight", "block4.1.mlp.dwconv.dwconv.bias", "block4.1.mlp.fc2.weight", "block4.1.mlp.fc2.bias", "norm4.weight", "norm4.bias".
4 |
--------------------------------------------------------------------------------
/semantic_segmentation/utils/meter.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import torch
4 | import numpy as np
5 |
6 |
7 | def confusion_matrix(x, y, n, ignore_label=None, mask=None):
8 | if mask is None:
9 | mask = np.ones_like(x) == 1
10 | k = (x >= 0) & (y < n) & (x != ignore_label) & (mask.astype(bool))
11 | return np.bincount(n * x[k].astype(int) + y[k], minlength=n ** 2).reshape(n, n)
12 |
13 |
14 | def getScores(conf_matrix):
15 | if conf_matrix.sum() == 0:
16 | return 0, 0, 0
17 | with np.errstate(divide='ignore',invalid='ignore'):
18 | overall = np.diag(conf_matrix).sum() / float(conf_matrix.sum())
19 | perclass = np.diag(conf_matrix) / conf_matrix.sum(1).astype(float)
20 | IU = np.diag(conf_matrix) / (conf_matrix.sum(1) + conf_matrix.sum(0) \
21 | - np.diag(conf_matrix)).astype(float)
22 | return overall * 100., np.nanmean(perclass) * 100., np.nanmean(IU) * 100.
23 |
24 |
25 | def compute_params(model):
26 | """Compute number of parameters"""
27 | n_total_params = 0
28 | for name, m in model.named_parameters():
29 | n_elem = m.numel()
30 | n_total_params += n_elem
31 | return n_total_params
32 |
33 |
34 | # Adopted from https://raw.githubusercontent.com/pytorch/examples/master/imagenet/main.py
35 | class AverageMeter(object):
36 | """Computes and stores the average and current value"""
37 | def __init__(self):
38 | self.reset()
39 |
40 | def reset(self):
41 | self.val = 0
42 | self.avg = 0
43 | self.sum = 0
44 | self.count = 0
45 |
46 | def update(self, val, n=1):
47 | self.val = val
48 | self.sum += val * n
49 | self.count += n
50 | self.avg = self.sum / self.count
51 |
52 |
53 | class Saver():
54 | """Saver class for managing parameters"""
55 | def __init__(self, args, ckpt_dir, best_val=0, condition=lambda x, y: x > y):
56 | """
57 | Args:
58 | args (dict): dictionary with arguments.
59 | ckpt_dir (str): path to directory in which to store the checkpoint.
60 | best_val (float): initial best value.
61 | condition (function): how to decide whether to save the new checkpoint
62 | by comparing best value and new value (x,y).
63 |
64 | """
65 | if not os.path.exists(ckpt_dir):
66 | os.makedirs(ckpt_dir)
67 | with open('{}/args.json'.format(ckpt_dir), 'w') as f:
68 | json.dump({k: v for k, v in args.items() if isinstance(v, (int, float, str))}, f,
69 | sort_keys = True, indent = 4, ensure_ascii = False)
70 | self.ckpt_dir = ckpt_dir
71 | self.best_val = best_val
72 | self.condition = condition
73 | self._counter = 0
74 |
75 | def _do_save(self, new_val):
76 | """Check whether need to save"""
77 | return self.condition(new_val, self.best_val)
78 |
79 | def save(self, new_val, dict_to_save):
80 | """Save new checkpoint"""
81 | self._counter += 1
82 | if self._do_save(new_val):
83 | # print(' New best value {:.4f}, was {:.4f}'.format(new_val, self.best_val), flush=True)
84 | self.best_val = new_val
85 | dict_to_save['best_val'] = new_val
86 | torch.save(dict_to_save, '{}/model-best.pth.tar'.format(self.ckpt_dir))
87 | else:
88 | dict_to_save['best_val'] = new_val
89 | torch.save(dict_to_save, '{}/checkpoint.pth.tar'.format(self.ckpt_dir))
90 |
91 |
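92 | # Illustrative usage sketch (added for clarity, not part of the original file; the
93 | # variable names and ignore value are assumptions): rows of the confusion matrix index
94 | # ground-truth labels and columns index predictions.
95 | #
96 | #   conf_mat = np.zeros((num_classes, num_classes), dtype=np.int64)
97 | #   for gt, pred in zip(labels, predictions):  # each an H x W integer array
98 | #       conf_mat += confusion_matrix(gt.flatten(), pred.flatten(), num_classes,
99 | #                                    ignore_label=255)
100 | #   overall_acc, mean_acc, mean_iou = getScores(conf_mat)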
--------------------------------------------------------------------------------
/semantic_segmentation/utils/optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class PolyWarmupAdamW(torch.optim.AdamW):
4 |
5 | def __init__(self, params, lr, weight_decay, betas, warmup_iter=None, max_iter=None, warmup_ratio=None, power=None):
6 | super().__init__(params, lr=lr, betas=betas, weight_decay=weight_decay, eps=1e-8)
7 |
8 | self.global_step = 0
9 | self.warmup_iter = warmup_iter
10 | self.warmup_ratio = warmup_ratio
11 | self.max_iter = max_iter
12 | self.power = power
13 |
14 | self.__init_lr = [group['lr'] for group in self.param_groups]
15 |
16 | def step(self, closure=None):
17 | ## adjust lr
18 | if self.global_step < self.warmup_iter:
19 |
20 | lr_mult = 1 - (1 - self.global_step / self.warmup_iter) * (1 - self.warmup_ratio)
21 | for i in range(len(self.param_groups)):
22 | self.param_groups[i]['lr'] = self.__init_lr[i] * lr_mult
23 |
24 | elif self.global_step < self.max_iter:
25 |
26 | lr_mult = (1 - self.global_step / self.max_iter) ** self.power
27 | for i in range(len(self.param_groups)):
28 | self.param_groups[i]['lr'] = self.__init_lr[i] * lr_mult
29 |
30 | # step
31 | super().step(closure)
32 |
33 | self.global_step += 1
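34 |
35 | # Illustrative usage sketch (added for clarity, not part of the original file; the
36 | # hyper-parameter values are assumptions): the learning rate ramps linearly from
37 | # lr * warmup_ratio to lr over warmup_iter steps, then follows a polynomial decay
38 | # (1 - step / max_iter) ** power until max_iter.
39 | #
40 | #   optimizer = PolyWarmupAdamW(model.parameters(), lr=6e-5, weight_decay=0.01,
41 | #                               betas=(0.9, 0.999), warmup_iter=1500, max_iter=40000,
42 | #                               warmup_ratio=1e-6, power=1.0)
43 | #   loss.backward(); optimizer.step(); optimizer.zero_grad()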
--------------------------------------------------------------------------------
/semantic_segmentation/utils/transforms.py:
--------------------------------------------------------------------------------
1 | """RefineNet-LightWeight
2 |
3 | RefineNet-LightWeight PyTorch for non-commercial purposes
4 |
5 | Copyright (c) 2018, Vladimir Nekrasov (vladimir.nekrasov@adelaide.edu.au)
6 | All rights reserved.
7 |
8 | Redistribution and use in source and binary forms, with or without
9 | modification, are permitted provided that the following conditions are met:
10 |
11 | * Redistributions of source code must retain the above copyright notice, this
12 | list of conditions and the following disclaimer.
13 |
14 | * Redistributions in binary form must reproduce the above copyright notice,
15 | this list of conditions and the following disclaimer in the documentation
16 | and/or other materials provided with the distribution.
17 |
18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | """
29 |
30 |
31 | import cv2
32 | import numpy as np
33 | import torch
34 |
35 | # Usual dtypes for common modalities
36 | KEYS_TO_DTYPES = {
37 | 'rgb': torch.float,
38 | 'depth': torch.float,
39 | 'normals': torch.float,
40 | 'mask': torch.long,
41 | }
42 |
43 |
44 | class Pad(object):
45 | """Pad image and mask to the desired size.
46 |
47 | Args:
48 | size (int) : minimum length/width.
49 | img_val (array) : image padding value.
50 | msk_val (int) : mask padding value.
51 |
52 | """
53 | def __init__(self, size, img_val, msk_val):
54 | assert isinstance(size, int)
55 | self.size = size
56 | self.img_val = img_val
57 | self.msk_val = msk_val
58 |
59 | def __call__(self, sample):
60 | image = sample['rgb']
61 | h, w = image.shape[:2]
62 | h_pad = int(np.clip(((self.size - h) + 1) // 2, 0, 1e6))
63 | w_pad = int(np.clip(((self.size - w) + 1) // 2, 0, 1e6))
64 | pad = ((h_pad, h_pad), (w_pad, w_pad))
65 | for key in sample['inputs']:
66 | sample[key] = self.transform_input(sample[key], pad)
67 | sample['mask'] = np.pad(sample['mask'], pad, mode='constant', constant_values=self.msk_val)
68 | return sample
69 |
70 | def transform_input(self, input, pad):
71 | input = np.stack([
72 | np.pad(input[:, :, c], pad, mode='constant',
73 | constant_values=self.img_val[c]) for c in range(3)
74 | ], axis=2)
75 | return input
76 |
77 |
78 | class RandomCrop(object):
79 | """Crop randomly the image in a sample.
80 |
81 | Args:
82 | crop_size (int): Desired output size.
83 |
84 | """
85 | def __init__(self, crop_size):
86 | assert isinstance(crop_size, int)
87 | self.crop_size = crop_size
88 | if self.crop_size % 2 != 0:
89 | self.crop_size -= 1
90 |
91 | def __call__(self, sample):
92 | image = sample['rgb']
93 | h, w = image.shape[:2]
94 | new_h = min(h, self.crop_size)
95 | new_w = min(w, self.crop_size)
96 | top = np.random.randint(0, h - new_h + 1)
97 | left = np.random.randint(0, w - new_w + 1)
98 | for key in sample['inputs']:
99 | sample[key] = self.transform_input(sample[key], top, new_h, left, new_w)
100 | sample['mask'] = sample['mask'][top : top + new_h, left : left + new_w]
101 | return sample
102 |
103 | def transform_input(self, input, top, new_h, left, new_w):
104 | input = input[top : top + new_h, left : left + new_w]
105 | return input
106 |
107 |
108 | class ResizeAndScale(object):
109 | """Resize shorter/longer side to a given value and randomly scale.
110 |
111 | Args:
112 | side (int) : shorter / longer side value.
113 | low_scale (float) : lower scaling bound.
114 | high_scale (float) : upper scaling bound.
115 | shorter (bool) : whether to resize shorter / longer side.
116 |
117 | """
118 | def __init__(self, side, low_scale, high_scale, shorter=True):
119 | assert isinstance(side, int)
120 | assert isinstance(low_scale, float)
121 | assert isinstance(high_scale, float)
122 | self.side = side
123 | self.low_scale = low_scale
124 | self.high_scale = high_scale
125 | self.shorter = shorter
126 |
127 | def __call__(self, sample):
128 | image = sample['rgb']
129 | scale = np.random.uniform(self.low_scale, self.high_scale)
130 | if self.shorter:
131 | min_side = min(image.shape[:2])
132 | if min_side * scale < self.side:
133 | scale = (self.side * 1. / min_side)
134 | else:
135 | max_side = max(image.shape[:2])
136 | if max_side * scale > self.side:
137 | scale = (self.side * 1. / max_side)
138 | inters = {'rgb': cv2.INTER_CUBIC, 'depth': cv2.INTER_NEAREST}
139 | for key in sample['inputs']:
140 | inter = inters[key] if key in inters else cv2.INTER_CUBIC
141 | sample[key] = self.transform_input(sample[key], scale, inter)
142 | sample['mask'] = cv2.resize(sample['mask'], None, fx=scale, fy=scale,
143 | interpolation=cv2.INTER_NEAREST)
144 | return sample
145 |
146 | def transform_input(self, input, scale, inter):
147 | input = cv2.resize(input, None, fx=scale, fy=scale, interpolation=inter)
148 | return input
149 |
150 |
151 | class CropAlignToMask(object):
152 | """Crop inputs to the size of the mask."""
153 | def __call__(self, sample):
154 | mask_h, mask_w = sample['mask'].shape[:2]
155 | for key in sample['inputs']:
156 | sample[key] = self.transform_input(sample[key], mask_h, mask_w)
157 | return sample
158 |
159 | def transform_input(self, input, mask_h, mask_w):
160 | input_h, input_w = input.shape[:2]
161 | if (input_h, input_w) == (mask_h, mask_w):
162 | return input
163 | h, w = (input_h - mask_h) // 2, (input_w - mask_w) // 2
164 | del_h, del_w = (input_h - mask_h) % 2, (input_w - mask_w) % 2
165 | input = input[h: input_h - h - del_h, w: input_w - w - del_w]
166 | assert input.shape[:2] == (mask_h, mask_w)
167 | return input
168 |
169 |
170 | class ResizeAlignToMask(object):
171 | """Resize inputs to the size of the mask."""
172 | def __call__(self, sample):
173 | mask_h, mask_w = sample['mask'].shape[:2]
174 | assert mask_h == mask_w
175 | inters = {'rgb': cv2.INTER_CUBIC, 'depth': cv2.INTER_NEAREST}
176 | for key in sample['inputs']:
177 | inter = inters[key] if key in inters else cv2.INTER_CUBIC
178 | sample[key] = self.transform_input(sample[key], mask_h, inter)
179 | return sample
180 |
181 | def transform_input(self, input, mask_h, inter):
182 | input_h, input_w = input.shape[:2]
183 | assert input_h == input_w
184 | scale = mask_h / input_h
185 | input = cv2.resize(input, None, fx=scale, fy=scale, interpolation=inter)
186 | return input
187 |
188 |
189 | class ResizeInputs(object):
190 | def __init__(self, size):
191 | self.size = size
192 |
193 | def __call__(self, sample):
194 | # sample['rgb'] = sample['rgb'].numpy()
195 | if self.size is None:
196 | return sample
197 | size = sample['rgb'].shape[0]
198 | scale = self.size / size
199 | # print(sample['rgb'].shape, type(sample['rgb']))
200 | inters = {'rgb': cv2.INTER_CUBIC, 'depth': cv2.INTER_NEAREST}
201 | for key in sample['inputs']:
202 | inter = inters[key] if key in inters else cv2.INTER_CUBIC
203 | sample[key] = self.transform_input(sample[key], scale, inter)
204 | return sample
205 |
206 | def transform_input(self, input, scale, inter):
207 | input = cv2.resize(input, None, fx=scale, fy=scale, interpolation=inter)
208 | return input
209 |
210 |
211 | class ResizeInputsScale(object):
212 | def __init__(self, scale):
213 | self.scale = scale
214 |
215 | def __call__(self, sample):
216 | if self.scale is None:
217 | return sample
218 | inters = {'rgb': cv2.INTER_CUBIC, 'depth': cv2.INTER_NEAREST}
219 | for key in sample['inputs']:
220 | inter = inters[key] if key in inters else cv2.INTER_CUBIC
221 | sample[key] = self.transform_input(sample[key], self.scale, inter)
222 | return sample
223 |
224 | def transform_input(self, input, scale, inter):
225 | input = cv2.resize(input, None, fx=scale, fy=scale, interpolation=inter)
226 | return input
227 |
228 |
229 | class RandomMirror(object):
230 | """Randomly flip the image and the mask"""
231 | def __call__(self, sample):
232 | do_mirror = np.random.randint(2)
233 | if do_mirror:
234 | for key in sample['inputs']:
235 | sample[key] = cv2.flip(sample[key], 1)
236 | sample['mask'] = cv2.flip(sample['mask'], 1)
237 | return sample
238 |
239 |
240 | class Normalise(object):
241 | """Normalise a tensor image with mean and standard deviation.
242 | Given mean: (R, G, B) and std: (R, G, B),
243 | will normalise each channel of the torch.*Tensor, i.e.
244 | channel = (scale * channel - mean) / std
245 |
246 | Args:
247 | scale (float): Scaling constant.
248 | mean (sequence): Sequence of means for R,G,B channels respectively.
249 | std (sequence): Sequence of standard deviations for R,G,B channels
250 | respectively.
251 | depth_scale (float): Depth divisor for depth annotations.
252 |
253 | """
254 | def __init__(self, scale, mean, std, depth_scale=1.):
255 | self.scale = scale
256 | self.mean = mean
257 | self.std = std
258 | self.depth_scale = depth_scale
259 |
260 | def __call__(self, sample):
261 | for key in sample['inputs']:
262 | if key == 'depth':
263 | continue
264 | sample[key] = (self.scale * sample[key] - self.mean) / self.std
265 | if 'depth' in sample:
266 | # sample['depth'] = self.scale * sample['depth']
267 | # sample['depth'] = (self.scale * sample['depth'] - self.mean) / self.std
268 | if self.depth_scale > 0:
269 | sample['depth'] = self.depth_scale * sample['depth']
270 | elif self.depth_scale == -1: # taskonomy
271 | # sample['depth'] = np.log(1 + sample['depth']) / np.log(2.** 16.0)
272 | sample['depth'] = np.log(1 + sample['depth'])
273 | elif self.depth_scale == -2: # sunrgbd
274 | depth = sample['depth']
275 | sample['depth'] = (depth - depth.min()) * 255.0 / (depth.max() - depth.min())
276 | return sample
277 |
278 |
279 | class ToTensor(object):
280 | """Convert ndarrays in sample to Tensors."""
281 | def __call__(self, sample):
282 | # swap color axis because
283 | # numpy image: H x W x C
284 | # torch image: C X H X W
285 | for key in sample['inputs']:
286 | sample[key] = torch.from_numpy(
287 | sample[key].transpose((2, 0, 1))
288 | ).to(KEYS_TO_DTYPES[key] if key in KEYS_TO_DTYPES else KEYS_TO_DTYPES['rgb'])
289 | sample['mask'] = torch.from_numpy(sample['mask']).to(KEYS_TO_DTYPES['mask'])
290 | return sample
291 |
292 |
293 | def make_list(x):
294 | """Returns the given input as a list."""
295 | if isinstance(x, list):
296 | return x
297 | elif isinstance(x, tuple):
298 | return list(x)
299 | else:
300 | return [x]
301 |
302 |
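303 | # Illustrative usage sketch (added for clarity, not part of the original file; the crop
304 | # size, padding values and normalisation constants are assumptions - the real values
305 | # come from the training config):
306 | #
307 | #   from torchvision.transforms import Compose
308 | #   mean = np.array([0.485, 0.456, 0.406]).reshape(1, 1, 3)
309 | #   std = np.array([0.229, 0.224, 0.225]).reshape(1, 1, 3)
310 | #   transform_trn = Compose([ResizeAndScale(500, 0.5, 2.0),
311 | #                            Pad(500, [123.675, 116.28, 103.53], 255),
312 | #                            RandomMirror(), RandomCrop(500),
313 | #                            Normalise(1. / 255, mean, std), ToTensor()])
314 | #   sample = transform_trn({'inputs': ['rgb', 'depth'], 'rgb': rgb, 'depth': depth,
315 | #                           'mask': mask})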
--------------------------------------------------------------------------------