├── .gitignore ├── LICENSE ├── README.md ├── figs ├── framework.png ├── heterogeneous.png └── homogeneous.png ├── image2image_translation ├── cfg.py ├── data │ ├── train_domain.txt │ └── val_domain.txt ├── eval.py ├── main.py ├── models │ ├── model_cfg.py │ ├── model_pruning.py │ └── modules.py └── utils │ ├── __init__.py │ ├── dataset.py │ ├── fid_kid.py │ ├── inception.py │ └── utils.py ├── object-detection-3d ├── models │ ├── __init__.py │ ├── ap_helper.py │ ├── detector.py │ ├── dump_helper.py │ ├── layers │ │ ├── __init__.py │ │ ├── drop.py │ │ ├── helper.py │ │ └── weight_init.py │ ├── loss_helper.py │ ├── losses.py │ ├── matcher.py │ ├── modules.py │ ├── multi_head_attention.py │ ├── pointnet.py │ ├── transformer.py │ └── vit.py ├── pointnet2 │ ├── _ext_src │ │ ├── include │ │ │ ├── ball_query.h │ │ │ ├── cuda_utils.h │ │ │ ├── group_points.h │ │ │ ├── interpolate.h │ │ │ ├── sampling.h │ │ │ └── utils.h │ │ └── src │ │ │ ├── ball_query.cpp │ │ │ ├── ball_query_gpu.cu │ │ │ ├── bindings.cpp │ │ │ ├── group_points.cpp │ │ │ ├── group_points_gpu.cu │ │ │ ├── interpolate.cpp │ │ │ ├── interpolate_gpu.cu │ │ │ ├── sampling.cpp │ │ │ └── sampling_gpu.cu │ ├── pointnet2_modules.py │ ├── pointnet2_test.py │ ├── pointnet2_utils.py │ ├── pytorch_utils.py │ └── setup.py ├── sunrgbd │ ├── model_util_sunrgbd.py │ ├── sunrgbd_data.py │ ├── sunrgbd_detection_dataset.py │ └── sunrgbd_utils.py ├── train_dist.py └── utils │ ├── __init__.py │ ├── box_ops.py │ ├── box_util.py │ ├── eval_det.py │ ├── logger.py │ ├── lr_scheduler.py │ ├── metric_util.py │ ├── misc.py │ ├── nms.py │ ├── nn_distance.py │ ├── pc_util.py │ └── visual.py └── semantic_segmentation ├── config.py ├── data └── nyudv2 │ ├── train.txt │ └── val.txt ├── main.py ├── models ├── __init__.py ├── mix_transformer.py ├── modules.py └── segformer.py └── utils ├── __init__.py ├── cmap.npy ├── datasets.py ├── helpers.py ├── load_state.txt ├── meter.py ├── optimizer.py └── transforms.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | __pycache__ 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Yikai Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multimodal Token Fusion for Vision Transformers 2 | 3 | By Yikai Wang, Xinghao Chen, Lele Cao, Wenbing Huang, Fuchun Sun, Yunhe Wang. 4 | 5 | [**[Paper]**](https://arxiv.org/pdf/2204.08721.pdf) 6 | 7 | This repository is a PyTorch implementation of "Multimodal Token Fusion for Vision Transformers", published at CVPR 2022. 8 | 9 |
10 | 11 |
12 | 13 | Homogeneous predictions, 14 |
15 | 16 |
17 | 18 | Heterogeneous predictions, 19 |
20 | 21 |
22 | 23 | 24 | ## Datasets 25 | 26 | For the semantic segmentation task on NYUDv2 ([official dataset](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html)), we provide a link to download the dataset [here](https://drive.google.com/drive/folders/1mXmOXVsd5l9-gYHk92Wpn6AcKAbE0m3X?usp=sharing). The provided dataset was originally preprocessed in this [repository](https://github.com/DrSleep/light-weight-refinenet), and we add depth data to it. 27 | 28 | For the image-to-image translation task, we use the sample dataset of [Taskonomy](http://taskonomy.stanford.edu/), which can be downloaded [here](https://github.com/alexsax/taskonomy-sample-model-1.git). 29 | 30 | Please modify the data paths in the code at the locations marked with the comment 'Modify data path'. 31 | 32 | 33 | ## Dependencies 34 | ``` 35 | python==3.6 36 | pytorch==1.7.1 37 | torchvision==0.8.2 38 | numpy==1.19.2 39 | ``` 40 | 41 | 42 | ## Semantic Segmentation 43 | 44 | 45 | First, 46 | ``` 47 | cd semantic_segmentation 48 | ``` 49 | 50 | Download the [SegFormer](https://github.com/NVlabs/SegFormer) model pretrained on ImageNet from [weights](https://drive.google.com/drive/folders/1b7bwrInTW4VLEm27YawHOAMSMikga2Ia), e.g., mit_b3.pth, and move it to the folder 'pretrained'. 51 | 52 | Training script for segmentation with RGB and Depth input, 53 | ``` 54 | python main.py --backbone mit_b3 -c exp_name --lamda 1e-6 --gpu 0 1 2 55 | ``` 56 | 57 | Evaluation script, 58 | ``` 59 | python main.py --gpu 0 --resume path_to_pth --evaluate # optionally use --save-img to visualize results 60 | ``` 61 | 62 | Checkpoint models, training logs, mask ratios, and the **single-scale** performance on NYUDv2 are provided as follows: 63 | 64 | | Method | Backbone | Pixel Acc. (%) | Mean Acc. (%) | Mean IoU (%) | Download | 65 | |:-----------:|:-----------:|:-----------:|:-----------:|:-----------:|:-----------:| 66 | |[CEN](https://github.com/yikaiw/CEN)| ResNet101 | 76.2 | 62.8 | 51.1 | [Google Drive](https://drive.google.com/drive/folders/1wim_cBG-HW0bdipwA1UbnGeDwjldPIwV?usp=sharing)| 67 | |[CEN](https://github.com/yikaiw/CEN)| ResNet152 | 77.0 | 64.4 | 51.6 | [Google Drive](https://drive.google.com/drive/folders/1DGF6vHLDgBgLrdUNJOLYdoXCuEKbIuRs?usp=sharing)| 68 | |Ours| SegFormer-B3 | 78.7 | 67.5 | 54.8 | [Google Drive](https://drive.google.com/drive/folders/14fi8aABFYqGF7LYKHkiJazHA58OBW1AW?usp=sharing)| 69 | 70 | 71 | A MindSpore implementation is available at: https://gitee.com/mindspore/models/tree/master/research/cv/TokenFusion 72 | 73 | ## Image-to-Image Translation 74 | 75 | First, 76 | ``` 77 | cd image2image_translation 78 | ``` 79 | Training script, from Shade and Texture to RGB, 80 | ``` 81 | python main.py --gpu 0 -c exp_name 82 | ``` 83 | This script automatically evaluates on the validation dataset every 5 training epochs.
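The periodic evaluation relies on the FID/KID utilities bundled with the repository (`utils/fid_kid.py`). Below is a minimal sketch of how the same metrics could be recomputed offline; it assumes the script is run from `image2image_translation/` after the result folders listed further down have been written, and that the experiment name is `exp_name`.

```
from utils.fid_kid import calculate_given_paths  # same helper used by eval.py

# Compare the ensemble outputs (fake2) against the ground-truth images (real);
# both folders are written automatically during training.
paths = ['ckpt/exp_name/results/fake2', 'ckpt/exp_name/results/real']
fid, kid = calculate_given_paths(paths, batch_size=50, cuda=True, dims=2048)
print('FID: %.3f  KID: %.3f' % (fid, kid))
```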
84 | 85 | Predicted images will be automatically saved during training, in the following folder structure: 86 | 87 | ``` 88 | code_root/ckpt/exp_name/results 89 | ├── input0 # 1st modality input 90 | ├── input1 # 2nd modality input 91 | ├── fake0 # 1st branch output 92 | ├── fake1 # 2nd branch output 93 | ├── fake2 # ensemble output 94 | ├── best # current best output 95 | │ ├── fake0 96 | │ ├── fake1 97 | │ └── fake2 98 | └── real # ground truth output 99 | ``` 100 | 101 | Checkpoint models: 102 | 103 | | Method | Task | FID | KID | Download | 104 | |:-----------:|:-----------:|:-----------:|:-----------:|:-----------:| 105 | | [CEN](https://github.com/yikaiw/CEN) |Texture+Shade->RGB | 62.6 | 1.65 | - | 106 | | Ours | Texture+Shade->RGB | 45.5 | 1.00 | [Google Drive](https://drive.google.com/drive/folders/1vkcDv5bHKXZKxCg4dC7R56ts6nLLt6lh?usp=sharing)| 107 | 108 | ## 3D Object Detection (under construction) 109 | 110 | Data preparation, environments, and training scripts follow [Group-Free](https://github.com/zeliu98/Group-Free-3D) and [ImVoteNet](https://github.com/facebookresearch/imvotenet). 111 | 112 | E.g., 113 | ``` 114 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --master_port 2229 --nproc_per_node 4 train_dist.py --max_epoch 600 --val_freq 25 --save_freq 25 --lr_decay_epochs 420 480 540 --num_point 20000 --num_decoder_layers 6 --size_cls_agnostic --size_delta 0.0625 --heading_delta 0.04 --center_delta 0.1111111111111 --weight_decay 0.00000001 --query_points_generator_loss_coef 0.2 --obj_loss_coef 0.4 --dataset sunrgbd --data_root . --use_img --log_dir log/exp_name 115 | ``` 116 | 117 | ## Citation 118 | 119 | If you find our work useful for your research, please consider citing the following paper. 120 | ``` 121 | @inproceedings{wang2022tokenfusion, 122 | title={Multimodal Token Fusion for Vision Transformers}, 123 | author={Wang, Yikai and Chen, Xinghao and Cao, Lele and Huang, Wenbing and Sun, Fuchun and Wang, Yunhe}, 124 | booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 125 | year={2022} 126 | } 127 | ``` 128 | 129 | 130 | -------------------------------------------------------------------------------- /figs/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yikaiw/TokenFusion/3834ccf7765bb0bd50ea729069ad5adbd6de288d/figs/framework.png -------------------------------------------------------------------------------- /figs/heterogeneous.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yikaiw/TokenFusion/3834ccf7765bb0bd50ea729069ad5adbd6de288d/figs/heterogeneous.png -------------------------------------------------------------------------------- /figs/homogeneous.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yikaiw/TokenFusion/3834ccf7765bb0bd50ea729069ad5adbd6de288d/figs/homogeneous.png -------------------------------------------------------------------------------- /image2image_translation/cfg.py: -------------------------------------------------------------------------------- 1 | num_parallel = None 2 | use_exchange = None 3 | mask_threshold = None 4 | lrnorm_threshold = None 5 | logger = None -------------------------------------------------------------------------------- /image2image_translation/eval.py: -------------------------------------------------------------------------------- 1 | import os, 
argparse 2 | import numpy as np 3 | import torch 4 | from tqdm import tqdm 5 | from torchvision import transforms 6 | from models.model_cfg import gen_b2, dis_b2 7 | import cfg 8 | from utils import * 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--batch-size', type=int, default=1, 12 | help='train batch size') 13 | parser.add_argument('--ngf', type=int, default=64) 14 | parser.add_argument('--ndf', type=int, default=64) 15 | parser.add_argument('--input-size', type=int, default=256, 16 | help='input size') 17 | parser.add_argument('--resize-scale', type=int, default=286, 18 | help='resize scale (0 is false)') 19 | parser.add_argument('--crop-size', type=int, default=256, 20 | help='crop size (0 is false)') 21 | parser.add_argument('--fliplr', type=bool, default=True, 22 | help='random fliplr True of False') 23 | parser.add_argument('--num-epochs', type=int, default=300, 24 | help='number of train epochs') 25 | parser.add_argument('--val-every', type=int, default=5, 26 | help='how often to validate current architecture') 27 | parser.add_argument('--lrG', type=float, default=0.0002, 28 | help='learning rate for generator, default=0.0002') 29 | parser.add_argument('--lrD', type=float, default=0.0002, 30 | help='learning rate for discriminator, default=0.0002') 31 | parser.add_argument('--gama', type=float, default=100, 32 | help='gama for L1 loss') 33 | parser.add_argument('--beta1', type=float, default=0.5, 34 | help='beta1 for Adam optimizer') 35 | parser.add_argument('--beta2', type=float, default=0.999, 36 | help='beta2 for Adam optimizer') 37 | parser.add_argument('--print-loss', action='store_true', default=False, 38 | help='whether print losses during training') 39 | parser.add_argument('--gpu', type=int, nargs='+', default=[0], 40 | help='select gpu.') 41 | parser.add_argument('-c', '--ckpt', default='model', type=str, metavar='PATH', 42 | help='path to save checkpoint (default: model)') 43 | parser.add_argument('-i', '--img-types', default=[2, 7, 0], type=int, nargs='+', 44 | help='image types, last image is target, others are inputs') 45 | parser.add_argument('--exchange', type=int, default=1, 46 | help='whether use feature exchange') 47 | parser.add_argument('-l', '--lamda', type=float, default=1e-3, 48 | help='lamda for L1 norm on BN scales.') 49 | parser.add_argument('-t', '--insnorm-threshold', type=float, default=1e-2, 50 | help='threshold for slimming BNs') 51 | parser.add_argument('--enc', default=[0], type=int, nargs='+') 52 | parser.add_argument('--dec', default=[0], type=int, nargs='+') 53 | params = parser.parse_args() 54 | 55 | # Directories for loading data and saving results 56 | data_dir = '/mnt/beegfs/ssd_pool/docker/user/hadoop-automl/yikai/upload/taskonomy-sample-model-1' # 'Modify data path' 57 | # data_dir = '/data/wyk/datasets/taskonomy-sample-model-1' 58 | # data_dir = '/home1/wyk/data/taskonomy-sample-model-1' 59 | model_dir = os.path.join('ckpt', params.ckpt) 60 | save_dir = os.path.join(model_dir, 'results') 61 | save_dir_best = os.path.join(save_dir, 'best') 62 | os.makedirs(save_dir_best, exist_ok=True) 63 | os.makedirs(os.path.join(model_dir, 'insnorm_params'), exist_ok=True) 64 | os.system('cp -r *py models utils data %s' % model_dir) 65 | cfg.logger = open(os.path.join(model_dir, 'log.txt'), 'w+') 66 | print_log(params) 67 | 68 | train_file = './data/train_domain.txt' 69 | val_file = './data/val_domain.txt' 70 | 71 | domain_dicts = {0: 'rgb', 1: 'normal', 2: 'reshading', 3: 'depth_euclidean', 4: 'depth_zbuffer', 72 | 5: 
'principal_curvature', 6: 'edge_occlusion', 7: 'edge_texture', 73 | 8: 'segment_unsup2d', 9: 'segment_unsup25d'} 74 | params.img_types = [domain_dicts[img_type] for img_type in params.img_types] 75 | print_log('\n' + ', '.join(params.img_types[:-1]) + ' -> ' + params.img_types[-1]) 76 | num_parallel = len(params.img_types) - 1 77 | 78 | cfg.num_parallel = num_parallel 79 | cfg.use_exchange = params.exchange == 1 80 | cfg.insnorm_threshold = params.insnorm_threshold 81 | cfg.enc, cfg.dec = params.enc, params.dec 82 | 83 | # Data pre-processing 84 | transform = transforms.Compose([transforms.Resize(params.input_size), 85 | transforms.ToTensor(), 86 | transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]) 87 | 88 | # Train data 89 | train_data = DatasetFromFolder(data_dir, train_file, params.img_types, transform=transform, 90 | resize_scale=params.resize_scale, crop_size=params.crop_size, 91 | fliplr=params.fliplr) 92 | train_data_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=params.batch_size, 93 | shuffle=True, drop_last=False) 94 | 95 | # Test data 96 | test_data = DatasetFromFolder(data_dir, val_file, params.img_types, transform=transform) 97 | test_data_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=params.batch_size, 98 | shuffle=False, drop_last=False) 99 | # test_input, test_target = test_data_loader.__iter__().__next__() 100 | 101 | # Models 102 | torch.cuda.set_device(params.gpu[0]) 103 | # G = Generator(3, params.ngf, 3) 104 | G = gen_b2() 105 | # D = Discriminator(6, params.ndf, 1) 106 | D = dis_b2(img_size=256, patch_size=4) 107 | G.cuda() 108 | G = torch.nn.DataParallel(G, params.gpu) 109 | D.cuda() 110 | D = torch.nn.DataParallel(D, params.gpu) 111 | state_dict = torch.load('./checkpoint-gen-399.pkl') 112 | G.load_state_dict(state_dict, strict=True) 113 | 114 | BCE_loss = torch.nn.BCELoss().cuda() 115 | L2_loss = torch.nn.MSELoss().cuda() 116 | L1_loss = torch.nn.L1Loss().cuda() 117 | 118 | 119 | def evaluate(G, epoch, training): 120 | num_parallel_ = 1 if num_parallel == 1 else num_parallel + 1 121 | l1_losses = init_lists(num_parallel_) 122 | l2_losses = init_lists(num_parallel_) 123 | fids = init_lists(num_parallel_) 124 | kids = init_lists(num_parallel_) 125 | for i, (test_inputs, test_target) in tqdm(enumerate(test_data_loader), miniters=25, total=len(test_data_loader)): 126 | # for i, (test_inputs, test_target) in enumerate(test_data_loader): 127 | # Show result for test image 128 | test_inputs_cuda = [test_input.cuda() for test_input in test_inputs] 129 | gen_images, alpha_soft, _ = G(test_inputs_cuda) 130 | test_target_cuda = test_target.cuda() 131 | for l, gen_image in enumerate(gen_images): 132 | if l < num_parallel or num_parallel > 1: 133 | l1_losses[l].append(L1_loss(gen_image, test_target_cuda).item()) 134 | l2_losses[l].append(L2_loss(gen_image, test_target_cuda).item()) 135 | gen_image = gen_image.cpu().data 136 | save_dir_ = os.path.join(save_dir, 'fake%d' % l) 137 | plot_test_result_single(gen_image, i, save_dir=save_dir_) 138 | if l < num_parallel: 139 | save_dir_ = os.path.join(save_dir, 'input%d' % l) 140 | if not os.path.exists(os.path.join(save_dir_, '%03d.png' % i)): 141 | plot_test_result_single(test_inputs[l], i, save_dir=save_dir_) 142 | save_dir_ = os.path.join(save_dir, 'real') 143 | if not os.path.exists(os.path.join(save_dir_, '%03d.png' % i)): 144 | plot_test_result_single(test_target, i, save_dir=save_dir_) 145 | # break 146 | 147 | for l in range(num_parallel_): 148 | paths = 
[os.path.join(save_dir, 'fake%d' % l), os.path.join(save_dir, 'real')] 149 | fid, kid = calculate_given_paths(paths, batch_size=50, cuda=True, dims=2048) 150 | fids[l], kids[l] = fid, kid 151 | 152 | l1_avg_losses = [torch.mean(torch.FloatTensor(l1_losses_)) for l1_losses_ in l1_losses] 153 | l2_avg_losses = [torch.mean(torch.FloatTensor(l2_losses_)) for l2_losses_ in l2_losses] 154 | return l1_avg_losses, l2_avg_losses, fids, kids 155 | 156 | 157 | l1_avg_losses, l2_avg_losses, fids, kids = evaluate(G, epoch=0, training=True) 158 | for l in range(len(l1_avg_losses)): 159 | l1_avg_loss, rl2_avg_loss = l1_avg_losses[l], l2_avg_losses[l]** 0.5 160 | fid, kid = fids[l], kids[l] 161 | if l < num_parallel: 162 | img_type_str = '(%s)' % params.img_types[l][:10] 163 | else: 164 | img_type_str = '(ens)' 165 | print_log('Epoch %3d %-15s l1_avg_loss: %.5f rl2_avg_loss: %.5f fid: %.3f kid: %.3f' % \ 166 | (0, img_type_str, l1_avg_loss, rl2_avg_loss, fid, kid)) 167 | 168 | cfg.logger.close() 169 | -------------------------------------------------------------------------------- /image2image_translation/models/model_cfg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .model_pruning import MixVisionTransformerGen, MixVisionTransformerDis 3 | from .modules import LayerNormParallel 4 | from functools import partial 5 | 6 | 7 | class gen_b0(MixVisionTransformerGen): 8 | def __init__(self, **kwargs): 9 | super(gen_b0, self).__init__( 10 | patch_size=4, embed_dims=[32, 64, 128, 256, 512], num_heads=[1, 2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4, 4], 11 | qkv_bias=True, depths=[2, 2, 2, 2, 2], sr_ratios=[16, 8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1) 12 | 13 | 14 | class gen_b1(MixVisionTransformerGen): 15 | def __init__(self, **kwargs): 16 | super(gen_b1, self).__init__( 17 | patch_size=4, embed_dims=[64, 128, 256, 512, 1024], num_heads=[1, 2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4, 4], 18 | qkv_bias=True, depths=[2, 2, 2, 2, 2], sr_ratios=[16, 8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1) 19 | 20 | 21 | class gen_b2(MixVisionTransformerGen): 22 | def __init__(self, **kwargs): 23 | super(gen_b2, self).__init__( 24 | patch_size=4, embed_dims=[64, 128, 256, 512, 1024], num_heads=[1, 2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4, 4], 25 | qkv_bias=True, depths=[3, 4, 4, 6, 3], sr_ratios=[16, 8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1) 26 | 27 | 28 | class gen_b3(MixVisionTransformerGen): 29 | def __init__(self, **kwargs): 30 | super(gen_b3, self).__init__( 31 | patch_size=4, embed_dims=[64, 128, 256, 512, 1024], num_heads=[1, 2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4, 4], 32 | qkv_bias=True, depths=[3, 4, 4, 18, 3], sr_ratios=[16, 8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1) 33 | 34 | 35 | class gen_b4(MixVisionTransformerGen): 36 | def __init__(self, **kwargs): 37 | super(gen_b4, self).__init__( 38 | patch_size=4, embed_dims=[64, 128, 256, 512, 1024], num_heads=[1, 2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4, 4], 39 | qkv_bias=True, depths=[3, 8, 8, 27, 3], sr_ratios=[16, 8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1) 40 | 41 | 42 | class gen_b5(MixVisionTransformerGen): 43 | def __init__(self, **kwargs): 44 | super(gen_b5, self).__init__( 45 | patch_size=4, embed_dims=[64, 128, 256, 512, 1024], num_heads=[1, 2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4, 4], 46 | qkv_bias=True, depths=[3, 6, 6, 40, 3], sr_ratios=[16, 8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1) 47 | 48 | 49 | class dis_b0(MixVisionTransformerDis): 50 | def __init__(self, **kwargs): 51 | 
super(dis_b0, self).__init__( 52 | patch_size=4, embed_dims=[32, 64, 128, 256, 512], num_heads=[1, 2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4, 4], 53 | qkv_bias=True, depths=[2, 2, 2, 2, 2], sr_ratios=[16, 8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1) 54 | 55 | 56 | class dis_b1(MixVisionTransformerDis): 57 | def __init__(self, **kwargs): 58 | super(dis_b1, self).__init__( 59 | patch_size=4, embed_dims=[64, 128, 256, 512, 1024], num_heads=[1, 2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4, 4], 60 | qkv_bias=True, depths=[2, 2, 2, 2, 2], sr_ratios=[16, 8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1) 61 | 62 | 63 | class dis_b2(MixVisionTransformerDis): 64 | def __init__(self, **kwargs): 65 | super(dis_b2, self).__init__( 66 | patch_size=4, embed_dims=[64, 128, 256, 512, 1024], num_heads=[1, 2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4, 4], 67 | qkv_bias=True, depths=[3, 4, 4, 6, 3], sr_ratios=[16, 8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1) 68 | 69 | 70 | class dis_b3(MixVisionTransformerDis): 71 | def __init__(self, **kwargs): 72 | super(dis_b3, self).__init__( 73 | patch_size=4, embed_dims=[64, 128, 256, 512, 1024], num_heads=[1, 2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4, 4], 74 | qkv_bias=True, depths=[3, 4, 4, 18, 3], sr_ratios=[16, 8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1) 75 | 76 | 77 | class dis_b4(MixVisionTransformerDis): 78 | def __init__(self, **kwargs): 79 | super(dis_b4, self).__init__( 80 | patch_size=4, embed_dims=[64, 128, 256, 512, 1024], num_heads=[1, 2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4, 4], 81 | qkv_bias=True, depths=[3, 8, 8, 27, 3], sr_ratios=[16, 8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1) 82 | 83 | 84 | class dis_b5(MixVisionTransformerDis): 85 | def __init__(self, **kwargs): 86 | super(dis_b5, self).__init__( 87 | patch_size=4, embed_dims=[64, 128, 256, 512, 1024], num_heads=[1, 2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4, 4], 88 | qkv_bias=True, depths=[3, 6, 6, 40, 3], sr_ratios=[16, 8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1) 89 | -------------------------------------------------------------------------------- /image2image_translation/models/modules.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import cfg 3 | import torch 4 | 5 | 6 | class TokenExchange(nn.Module): 7 | def __init__(self): 8 | super(TokenExchange, self).__init__() 9 | 10 | def forward(self, x, mask, mask_threshold): 11 | # x: [B, N, C], mask: [B, N, 2] 12 | x0, x1 = torch.zeros_like(x[0]), torch.zeros_like(x[1]) 13 | x0[mask[0] >= mask_threshold] = x[0][mask[0] >= mask_threshold] 14 | x0[mask[0] < mask_threshold] = x[1][mask[0] < mask_threshold] 15 | x1[mask[1] >= mask_threshold] = x[1][mask[1] >= mask_threshold] 16 | x1[mask[1] < mask_threshold] = x[0][mask[1] < mask_threshold] 17 | return [x0, x1] 18 | 19 | 20 | class ChannelExchange(nn.Module): 21 | def __init__(self): 22 | super(ChannelExchange, self).__init__() 23 | 24 | def forward(self, x, lrnorm, lrnorm_threshold): 25 | lrnorm0, lrnorm1 = lrnorm[0].weight.abs(), lrnorm[1].weight.abs() 26 | x0, x1 = torch.zeros_like(x[0]), torch.zeros_like(x[1]) 27 | x0[:, lrnorm0 >= lrnorm_threshold] = x[0][:, lrnorm0 >= lrnorm_threshold] 28 | x0[:, lrnorm0 < lrnorm_threshold] = x[1][:, lrnorm0 < lrnorm_threshold] 29 | x1[:, lrnorm1 >= lrnorm_threshold] = x[1][:, lrnorm1 >= lrnorm_threshold] 30 | x1[:, lrnorm1 < lrnorm_threshold] = x[0][:, lrnorm1 < lrnorm_threshold] 31 | return [x0, x1] 32 | 33 | 34 | class ModuleParallel(nn.Module): 35 | def __init__(self, module): 36 | super(ModuleParallel, 
self).__init__() 37 | self.module = module 38 | 39 | def forward(self, x_parallel): 40 | return [self.module(x) for x in x_parallel] 41 | 42 | 43 | class LayerNormParallel(nn.Module): 44 | def __init__(self, num_features): 45 | super(LayerNormParallel, self).__init__() 46 | for i in range(cfg.num_parallel): 47 | setattr(self, 'lrnorm_' + str(i), nn.LayerNorm(num_features, eps=1e-6)) 48 | 49 | def forward(self, x_parallel): 50 | return [getattr(self, 'lrnorm_' + str(i))(x) for i, x in enumerate(x_parallel)] 51 | -------------------------------------------------------------------------------- /image2image_translation/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import * 2 | from .utils import * 3 | from .fid_kid import calculate_given_paths 4 | -------------------------------------------------------------------------------- /image2image_translation/utils/dataset.py: -------------------------------------------------------------------------------- 1 | # Custom dataset 2 | from PIL import Image 3 | import torch.utils.data as data 4 | import numpy as np 5 | import os 6 | import random 7 | 8 | max_v = {'edge_texture': 11355, 'edge_occlusion': 11584, 'depth_euclidean': 11.1, 'depth_zbuffer': 11.1} 9 | min_v = {'edge_texture': 0, 'edge_occlusion': 0, 'depth_euclidean': 0, 'depth_zbuffer': 0} 10 | 11 | 12 | def line_to_path_fn(x, data_dir): 13 | path = x.decode('utf-8').strip('\n') 14 | return os.path.join(data_dir, path) 15 | 16 | 17 | class DatasetFromFolder(data.Dataset): 18 | def __init__(self, data_dir, data_file, img_types, transform=None, 19 | resize_scale=None, crop_size=None, fliplr=False, is_cls=False): 20 | super(DatasetFromFolder, self).__init__() 21 | with open(data_file, 'rb') as f: 22 | data_list = f.readlines() 23 | self.data_list = [line_to_path_fn(line, data_dir) for line in data_list] 24 | self.img_types = img_types 25 | self.transform = transform 26 | self.resize_scale = resize_scale 27 | self.crop_size = crop_size 28 | self.fliplr = fliplr 29 | self.is_cls = is_cls 30 | 31 | def __getitem__(self, index): 32 | # Load Image 33 | domain_path = self.data_list[index] 34 | if self.is_cls: 35 | img_types = self.img_types[:-1] 36 | cls_target = self.img_types[-1] 37 | else: 38 | img_types = self.img_types 39 | img_paths = [domain_path.replace('{domain}', img_type) for img_type in img_types] 40 | imgs = [Image.open(img_path) for img_path in img_paths] 41 | 42 | for l in range(len(imgs)): 43 | img = np.array(imgs[l]) 44 | img_type = img_types[l] 45 | update = False 46 | if len(img.shape) == 2: 47 | img = img[:,:, np.newaxis] 48 | img = np.concatenate([img] * 3, 2) 49 | update = True 50 | if 'depth' in img_type: 51 | img = np.log(1 + img) 52 | update = True 53 | if img_type in max_v: 54 | img = (img - min_v[img_type]) * 255.0 / (max_v[img_type] - min_v[img_type]) 55 | update = True 56 | if update: 57 | imgs[l] = Image.fromarray(img.astype('uint8')) 58 | 59 | if self.resize_scale: 60 | imgs = [img.resize((self.resize_scale, self.resize_scale), Image.BILINEAR) \ 61 | for img in imgs] 62 | if self.crop_size: 63 | x = random.randint(0, self.resize_scale - self.crop_size + 1) 64 | y = random.randint(0, self.resize_scale - self.crop_size + 1) 65 | imgs = [img.crop((x, y, x + self.crop_size, y + self.crop_size)) for img in imgs] 66 | if self.fliplr: 67 | if random.random() < 0.5: 68 | imgs = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in imgs] 69 | if self.transform is not None: 70 | imgs = [self.transform(img) for img in 
imgs] 71 | 72 | if self.is_cls: 73 | inputs = imgs 74 | target = np.load(domain_path.replace('{domain}', cls_target).\ 75 | replace('png', 'npy').replace('scene.npy', 'places.npy')) 76 | target = np.argmax(target) 77 | else: 78 | inputs, target = imgs[:-1], imgs[-1] 79 | 80 | return inputs, target 81 | 82 | def __len__(self): 83 | return len(self.data_list) 84 | -------------------------------------------------------------------------------- /image2image_translation/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | plt.switch_backend('agg') 5 | import os 6 | import imageio 7 | import cfg 8 | 9 | 10 | def print_log(message): 11 | print(message, flush=True) 12 | if cfg.logger: 13 | cfg.logger.write(str(message) + '\n') 14 | 15 | 16 | # Plot losses 17 | def plot_loss(d_losses, g_losses, num_epochs, save=True, save_dir='results/', show=False): 18 | fig, ax = plt.subplots() 19 | ax_ = ax.twinx() 20 | ax.set_xlim(0, num_epochs) 21 | # ax.set_ylim(0, max(np.max(g_losses), np.max(d_losses)) * 1.1) 22 | plt.xlabel('# of Epochs') 23 | ax.set_ylabel('Generator loss values') 24 | ax_.set_ylabel('Discriminator loss values') 25 | ax.plot(g_losses, label='Generator') 26 | ax_.plot(d_losses, label='Discriminator') 27 | plt.legend() 28 | 29 | # save figure 30 | if save: 31 | if not os.path.exists(save_dir): 32 | os.mkdir(save_dir) 33 | save_fn = save_dir + 'Loss_values_epoch_{:d}'.format(num_epochs) + '.png' 34 | plt.savefig(save_fn) 35 | 36 | if show: 37 | plt.show() 38 | else: 39 | plt.close() 40 | 41 | 42 | def plot_test_result_single(image, image_idx, save_dir='results/', fig_size=(5, 5)): 43 | # fig_size = (target.size(2) / 100, target.size(3) / 100) 44 | # fig, ax = plt.subplots(figsize=fig_size) 45 | fig, ax = plt.subplots() 46 | img = image 47 | 48 | ax.axis('off') 49 | ax.set_adjustable('box') 50 | # Scale to 0-255 51 | img = (((img[0] - img[0].min()) * 255) / (img[0].max() - img[0].min()))\ 52 | .numpy().transpose(1, 2, 0).astype(np.uint8) 53 | ax.imshow(img, cmap=None, aspect='equal') 54 | plt.subplots_adjust(wspace=0, hspace=0) 55 | 56 | save_path = save_dir 57 | os.makedirs(save_path, exist_ok=True) 58 | save_path = os.path.join(save_path, '%03d.png' % image_idx) 59 | fig.subplots_adjust(bottom=0) 60 | fig.subplots_adjust(top=1) 61 | fig.subplots_adjust(right=1) 62 | fig.subplots_adjust(left=0) 63 | 64 | foo_fig = plt.gcf() 65 | foo_fig.set_size_inches(5, 5) 66 | foo_fig.savefig(save_path, dpi=200, bbox_inches='tight') 67 | plt.savefig(save_path) 68 | plt.close() 69 | 70 | 71 | def plot_test_result(input, target, gen_image, image_idx, img_title, epoch, training=True, 72 | save=True, save_dir='results/', show=False, fig_size=(5, 5)): 73 | if input is not None: 74 | fig_size = (target.size(2) * 3 / 100, target.size(3) / 100) 75 | imgs = [input, gen_image, target] 76 | fig, axes = plt.subplots(1, 3, figsize=fig_size) 77 | else: 78 | fig_size = (target.size(2) * 2 / 100, target.size(3) / 100) 79 | imgs = [gen_image, target] 80 | fig, axes = plt.subplots(1, 2, figsize=fig_size) 81 | 82 | for ax, img in zip(axes.flatten(), imgs): 83 | ax.axis('off') 84 | ax.set_adjustable('box') 85 | # Scale to 0-255 86 | img = (((img[0] - img[0].min()) * 255) / (img[0].max() - img[0].min()))\ 87 | .numpy().transpose(1, 2, 0).astype(np.uint8) 88 | ax.imshow(img, cmap=None, aspect='equal') 89 | plt.subplots_adjust(wspace=0, hspace=0) 90 | 91 | # save figure 92 | if save: 93 | # save_path = 
os.path.join(save_dir, str(image_idx)) 94 | save_path = save_dir 95 | os.makedirs(save_path, exist_ok=True) 96 | if training: 97 | save_path = os.path.join(save_path, '%03d_%s.png' % (image_idx, img_title)) 98 | else: 99 | save_path = os.path.join(save_path, 'Test_%03d_%s.png' % (image_idx, img_title)) 100 | fig.subplots_adjust(bottom=0) 101 | fig.subplots_adjust(top=1) 102 | fig.subplots_adjust(right=1) 103 | fig.subplots_adjust(left=0) 104 | plt.savefig(save_path) 105 | 106 | if show: 107 | plt.show() 108 | else: 109 | plt.close() 110 | 111 | 112 | def maybe_download(model_name, model_url, model_dir=None, map_location=None): 113 | import os, sys 114 | from six.moves import urllib 115 | if model_dir is None: 116 | torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch')) 117 | model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models')) 118 | if not os.path.exists(model_dir): 119 | os.makedirs(model_dir) 120 | filename = '{}.pth.tar'.format(model_name) 121 | cached_file = os.path.join(model_dir, filename) 122 | if not os.path.exists(cached_file): 123 | url = model_url 124 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 125 | urllib.request.urlretrieve(url, cached_file) 126 | if '152' in cached_file: 127 | cached_file = '/home/anbang/.cache/torch/checkpoints/resnet152-b121ed2d.pth' 128 | return torch.load(cached_file, map_location=map_location) 129 | 130 | 131 | # Make gif 132 | def make_gif(dataset, num_epochs, save_dir='results/'): 133 | gen_image_plots = [] 134 | for epoch in range(num_epochs): 135 | # plot for generating gif 136 | save_fn = save_dir + 'Result_epoch_{:d}'.format(epoch + 1) + '.png' 137 | gen_image_plots.append(imageio.imread(save_fn)) 138 | 139 | imageio.mimsave(save_dir + dataset + '_pix2pix_epochs_{:d}'.format(num_epochs) \ 140 | + '.gif', gen_image_plots, fps=5) 141 | 142 | 143 | def init_lists(length): 144 | lists = [] 145 | for l in range(length): 146 | lists.append([]) 147 | return lists 148 | 149 | 150 | class AverageMeter(object): 151 | """Computes and stores the average and current value 152 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 153 | """ 154 | def __init__(self): 155 | self.reset() 156 | 157 | def reset(self): 158 | self.val = 0 159 | self.avg = 0 160 | self.sum = 0 161 | self.count = 0 162 | 163 | def update(self, val, n=1): 164 | self.val = val 165 | self.sum += val * n 166 | self.count += n 167 | self.avg = self.sum / self.count 168 | 169 | 170 | def accuracy(output, target, topk=(1,)): 171 | """Computes the precision@k for the specified values of k""" 172 | with torch.no_grad(): 173 | maxk = max(topk) 174 | batch_size = target.size(0) 175 | _, pred = output.topk(maxk, 1, True, True) 176 | pred = pred.t() 177 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 178 | res = [] 179 | for k in topk: 180 | correct_k = correct[:k].view(-1).float().sum(0) 181 | res.append(correct_k.mul_(100.0 / batch_size)) 182 | return res 183 | -------------------------------------------------------------------------------- /object-detection-3d/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .detector import GroupFreeDetector 2 | from .loss_helper import get_loss 3 | from .ap_helper import APCalculator, parse_predictions, parse_groundtruths 4 | from .dump_helper import dump_results 5 | -------------------------------------------------------------------------------- /object-detection-3d/models/dump_helper.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | import torch 8 | import os 9 | import sys 10 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 11 | ROOT_DIR = os.path.dirname(BASE_DIR) 12 | sys.path.append(os.path.join(ROOT_DIR, 'utils')) 13 | import pc_util 14 | 15 | DUMP_CONF_THRESH = 0.5 # Dump boxes with obj prob larger than that. 16 | 17 | def softmax(x): 18 | ''' Numpy function for softmax''' 19 | shape = x.shape 20 | probs = np.exp(x - np.max(x, axis=len(shape)-1, keepdims=True)) 21 | probs /= np.sum(probs, axis=len(shape)-1, keepdims=True) 22 | return probs 23 | 24 | def dump_results(end_points, dump_dir, config, inference_switch=False): 25 | ''' Dump results. 26 | 27 | Args: 28 | end_points: dict 29 | {..., pred_mask} 30 | pred_mask is a binary mask array of size (batch_size, num_proposal) computed by running NMS and empty box removal 31 | Returns: 32 | None 33 | ''' 34 | if not os.path.exists(dump_dir): 35 | os.system('mkdir %s'%(dump_dir)) 36 | 37 | # INPUT 38 | print(end_points.keys()) 39 | point_clouds = end_points['point_clouds'].cpu().numpy() 40 | batch_size = point_clouds.shape[0] 41 | 42 | # NETWORK OUTPUTS 43 | seed_xyz = end_points['seed_xyz'].detach().cpu().numpy() # (B,num_seed,3) 44 | if 'vote_xyz' in end_points: 45 | aggregated_vote_xyz = end_points['aggregated_vote_xyz'].detach().cpu().numpy() 46 | vote_xyz = end_points['vote_xyz'].detach().cpu().numpy() # (B,num_seed,3) 47 | aggregated_vote_xyz = end_points['aggregated_vote_xyz'].detach().cpu().numpy() 48 | objectness_scores = end_points['objectness_scores'].detach().cpu().numpy() # (B,K,2) 49 | pred_center = end_points['center'].detach().cpu().numpy() # (B,K,3) 50 | pred_heading_class = torch.argmax(end_points['heading_scores'], -1) # B,num_proposal 51 | pred_heading_residual = torch.gather(end_points['heading_residuals'], 2, pred_heading_class.unsqueeze(-1)) # B,num_proposal,1 52 | pred_heading_class = pred_heading_class.detach().cpu().numpy() # B,num_proposal 53 | pred_heading_residual = pred_heading_residual.squeeze(2).detach().cpu().numpy() # B,num_proposal 54 | pred_size_class = torch.argmax(end_points['size_scores'], -1) # B,num_proposal 55 | pred_size_residual = torch.gather(end_points['size_residuals'], 2, pred_size_class.unsqueeze(-1).unsqueeze(-1).repeat(1,1,1,3)) # B,num_proposal,1,3 56 | pred_size_residual = pred_size_residual.squeeze(2).detach().cpu().numpy() # B,num_proposal,3 57 | 58 | # OTHERS 59 | pred_mask = end_points['pred_mask'] # B,num_proposal 60 | idx_beg = 0 61 | 62 | for i in range(batch_size): 63 | pc = point_clouds[i,:,:] 64 | objectness_prob = softmax(objectness_scores[i,:,:])[:,1] # (K,) 65 | 66 | # Dump various point clouds 67 | pc_util.write_ply(pc, os.path.join(dump_dir, '%06d_pc.ply'%(idx_beg+i))) 68 | pc_util.write_ply(seed_xyz[i,:,:], os.path.join(dump_dir, '%06d_seed_pc.ply'%(idx_beg+i))) 69 | if 'vote_xyz' in end_points: 70 | pc_util.write_ply(end_points['vote_xyz'][i,:,:], os.path.join(dump_dir, '%06d_vgen_pc.ply'%(idx_beg+i))) 71 | pc_util.write_ply(aggregated_vote_xyz[i,:,:], os.path.join(dump_dir, '%06d_aggregated_vote_pc.ply'%(idx_beg+i))) 72 | pc_util.write_ply(aggregated_vote_xyz[i,:,:], os.path.join(dump_dir, '%06d_aggregated_vote_pc.ply'%(idx_beg+i))) 73 | pc_util.write_ply(pred_center[i,:,0:3], 
os.path.join(dump_dir, '%06d_proposal_pc.ply'%(idx_beg+i))) 74 | if np.sum(objectness_prob>DUMP_CONF_THRESH)>0: 75 | pc_util.write_ply(pred_center[i,objectness_prob>DUMP_CONF_THRESH,0:3], os.path.join(dump_dir, '%06d_confident_proposal_pc.ply'%(idx_beg+i))) 76 | 77 | # Dump predicted bounding boxes 78 | if np.sum(objectness_prob>DUMP_CONF_THRESH)>0: 79 | num_proposal = pred_center.shape[1] 80 | obbs = [] 81 | for j in range(num_proposal): 82 | obb = config.param2obb(pred_center[i,j,0:3], pred_heading_class[i,j], pred_heading_residual[i,j], 83 | pred_size_class[i,j], pred_size_residual[i,j]) 84 | obbs.append(obb) 85 | if len(obbs)>0: 86 | obbs = np.vstack(tuple(obbs)) # (num_proposal, 7) 87 | pc_util.write_oriented_bbox(obbs[objectness_prob>DUMP_CONF_THRESH,:], os.path.join(dump_dir, '%06d_pred_confident_bbox.ply'%(idx_beg+i))) 88 | pc_util.write_oriented_bbox(obbs[np.logical_and(objectness_prob>DUMP_CONF_THRESH, pred_mask[i,:]==1),:], os.path.join(dump_dir, '%06d_pred_confident_nms_bbox.ply'%(idx_beg+i))) 89 | pc_util.write_oriented_bbox(obbs[pred_mask[i,:]==1,:], os.path.join(dump_dir, '%06d_pred_nms_bbox.ply'%(idx_beg+i))) 90 | pc_util.write_oriented_bbox(obbs, os.path.join(dump_dir, '%06d_pred_bbox.ply'%(idx_beg+i))) 91 | 92 | # Return if it is at inference time. No dumping of groundtruths 93 | if inference_switch: 94 | return 95 | 96 | # LABELS 97 | gt_center = end_points['center_label'].cpu().numpy() # (B,MAX_NUM_OBJ,3) 98 | gt_mask = end_points['box_label_mask'].cpu().numpy() # B,K2 99 | gt_heading_class = end_points['heading_class_label'].cpu().numpy() # B,K2 100 | gt_heading_residual = end_points['heading_residual_label'].cpu().numpy() # B,K2 101 | gt_size_class = end_points['size_class_label'].cpu().numpy() # B,K2 102 | gt_size_residual = end_points['size_residual_label'].cpu().numpy() # B,K2,3 103 | objectness_label = end_points['objectness_label'].detach().cpu().numpy() # (B,K,) 104 | objectness_mask = end_points['objectness_mask'].detach().cpu().numpy() # (B,K,) 105 | 106 | for i in range(batch_size): 107 | if np.sum(objectness_label[i,:])>0: 108 | pc_util.write_ply(pred_center[i,objectness_label[i,:]>0,0:3], os.path.join(dump_dir, '%06d_gt_positive_proposal_pc.ply'%(idx_beg+i))) 109 | if np.sum(objectness_mask[i,:])>0: 110 | pc_util.write_ply(pred_center[i,objectness_mask[i,:]>0,0:3], os.path.join(dump_dir, '%06d_gt_mask_proposal_pc.ply'%(idx_beg+i))) 111 | pc_util.write_ply(gt_center[i,:,0:3], os.path.join(dump_dir, '%06d_gt_centroid_pc.ply'%(idx_beg+i))) 112 | pc_util.write_ply_color(pred_center[i,:,0:3], objectness_label[i,:], os.path.join(dump_dir, '%06d_proposal_pc_objectness_label.obj'%(idx_beg+i))) 113 | 114 | # Dump GT bounding boxes 115 | obbs = [] 116 | for j in range(gt_center.shape[1]): 117 | if gt_mask[i,j] == 0: continue 118 | obb = config.param2obb(gt_center[i,j,0:3], gt_heading_class[i,j], gt_heading_residual[i,j], 119 | gt_size_class[i,j], gt_size_residual[i,j]) 120 | obbs.append(obb) 121 | if len(obbs)>0: 122 | obbs = np.vstack(tuple(obbs)) # (num_gt_objects, 7) 123 | pc_util.write_oriented_bbox(obbs, os.path.join(dump_dir, '%06d_gt_bbox.ply'%(idx_beg+i))) 124 | 125 | # OPTIONALL, also dump prediction and gt details 126 | if 'batch_pred_map_cls' in end_points: 127 | for ii in range(batch_size): 128 | fout = open(os.path.join(dump_dir, '%06d_pred_map_cls.txt'%(ii)), 'w') 129 | for t in end_points['batch_pred_map_cls'][ii]: 130 | fout.write(str(t[0])+' ') 131 | fout.write(",".join([str(x) for x in list(t[1].flatten())])) 132 | fout.write(' '+str(t[2])) 133 | 
fout.write('\n') 134 | fout.close() 135 | if 'batch_gt_map_cls' in end_points: 136 | for ii in range(batch_size): 137 | fout = open(os.path.join(dump_dir, '%06d_gt_map_cls.txt'%(ii)), 'w') 138 | for t in end_points['batch_gt_map_cls'][ii]: 139 | fout.write(str(t[0])+' ') 140 | fout.write(",".join([str(x) for x in list(t[1].flatten())])) 141 | fout.write('\n') 142 | fout.close() 143 | 144 | -------------------------------------------------------------------------------- /object-detection-3d/models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .weight_init import trunc_normal_ 2 | from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path 3 | from .helper import to_ntuple, to_2tuple, to_3tuple, to_4tuple 4 | -------------------------------------------------------------------------------- /object-detection-3d/models/layers/drop.py: -------------------------------------------------------------------------------- 1 | """ DropBlock, DropPath 2 | 3 | PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers. 4 | 5 | Papers: 6 | DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890) 7 | 8 | Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382) 9 | 10 | Code: 11 | DropBlock impl inspired by two Tensorflow impl that I liked: 12 | - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74 13 | - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py 14 | 15 | Hacked together by / Copyright 2020 Ross Wightman 16 | """ 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | 22 | def drop_block_2d( 23 | x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0, 24 | with_noise: bool = False, inplace: bool = False, batchwise: bool = False): 25 | """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf 26 | 27 | DropBlock with an experimental gaussian noise option. This layer has been tested on a few training 28 | runs with success, but needs further validation and possibly optimization for lower runtime impact. 29 | """ 30 | B, C, H, W = x.shape 31 | total_size = W * H 32 | clipped_block_size = min(block_size, min(W, H)) 33 | # seed_drop_rate, the gamma parameter 34 | gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( 35 | (W - block_size + 1) * (H - block_size + 1)) 36 | 37 | # Forces the block to be inside the feature map. 
38 | w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device)) 39 | valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \ 40 | ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2)) 41 | valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype) 42 | 43 | if batchwise: 44 | # one mask for whole batch, quite a bit faster 45 | uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) 46 | else: 47 | uniform_noise = torch.rand_like(x) 48 | block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype) 49 | block_mask = -F.max_pool2d( 50 | -block_mask, 51 | kernel_size=clipped_block_size, # block_size, 52 | stride=1, 53 | padding=clipped_block_size // 2) 54 | 55 | if with_noise: 56 | normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x) 57 | if inplace: 58 | x.mul_(block_mask).add_(normal_noise * (1 - block_mask)) 59 | else: 60 | x = x * block_mask + normal_noise * (1 - block_mask) 61 | else: 62 | normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype) 63 | if inplace: 64 | x.mul_(block_mask * normalize_scale) 65 | else: 66 | x = x * block_mask * normalize_scale 67 | return x 68 | 69 | 70 | def drop_block_fast_2d( 71 | x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7, 72 | gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, batchwise: bool = False): 73 | """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf 74 | 75 | DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid 76 | block mask at edges. 77 | """ 78 | B, C, H, W = x.shape 79 | total_size = W * H 80 | clipped_block_size = min(block_size, min(W, H)) 81 | gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( 82 | (W - block_size + 1) * (H - block_size + 1)) 83 | 84 | if batchwise: 85 | # one mask for whole batch, quite a bit faster 86 | block_mask = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) < gamma 87 | else: 88 | # mask per batch element 89 | block_mask = torch.rand_like(x) < gamma 90 | block_mask = F.max_pool2d( 91 | block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2) 92 | 93 | if with_noise: 94 | normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x) 95 | if inplace: 96 | x.mul_(1. - block_mask).add_(normal_noise * block_mask) 97 | else: 98 | x = x * (1. - block_mask) + normal_noise * block_mask 99 | else: 100 | block_mask = 1 - block_mask 101 | normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(dtype=x.dtype) 102 | if inplace: 103 | x.mul_(block_mask * normalize_scale) 104 | else: 105 | x = x * block_mask * normalize_scale 106 | return x 107 | 108 | 109 | class DropBlock2d(nn.Module): 110 | """ DropBlock. 
See https://arxiv.org/pdf/1810.12890.pdf 111 | """ 112 | def __init__(self, 113 | drop_prob=0.1, 114 | block_size=7, 115 | gamma_scale=1.0, 116 | with_noise=False, 117 | inplace=False, 118 | batchwise=False, 119 | fast=True): 120 | super(DropBlock2d, self).__init__() 121 | self.drop_prob = drop_prob 122 | self.gamma_scale = gamma_scale 123 | self.block_size = block_size 124 | self.with_noise = with_noise 125 | self.inplace = inplace 126 | self.batchwise = batchwise 127 | self.fast = fast # FIXME finish comparisons of fast vs not 128 | 129 | def forward(self, x): 130 | if not self.training or not self.drop_prob: 131 | return x 132 | if self.fast: 133 | return drop_block_fast_2d( 134 | x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise) 135 | else: 136 | return drop_block_2d( 137 | x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise) 138 | 139 | 140 | def drop_path(x, drop_prob: float = 0., training: bool = False): 141 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 142 | 143 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 144 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 145 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 146 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 147 | 'survival rate' as the argument. 148 | 149 | """ 150 | if drop_prob == 0. or not training: 151 | return x 152 | keep_prob = 1 - drop_prob 153 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 154 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 155 | random_tensor.floor_() # binarize 156 | output = x.div(keep_prob) * random_tensor 157 | return output 158 | 159 | 160 | class DropPath(nn.Module): 161 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
162 | """ 163 | def __init__(self, drop_prob=None): 164 | super(DropPath, self).__init__() 165 | self.drop_prob = drop_prob 166 | 167 | def forward(self, x): 168 | return drop_path(x, self.drop_prob, self.training) 169 | -------------------------------------------------------------------------------- /object-detection-3d/models/layers/helper.py: -------------------------------------------------------------------------------- 1 | """ Layer/Module Helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from itertools import repeat 6 | from torch._six import container_abcs 7 | 8 | 9 | # From PyTorch internals 10 | def _ntuple(n): 11 | def parse(x): 12 | if isinstance(x, container_abcs.Iterable): 13 | return x 14 | return tuple(repeat(x, n)) 15 | return parse 16 | 17 | 18 | to_1tuple = _ntuple(1) 19 | to_2tuple = _ntuple(2) 20 | to_3tuple = _ntuple(3) 21 | to_4tuple = _ntuple(4) 22 | to_ntuple = _ntuple 23 | 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /object-detection-3d/models/layers/weight_init.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import warnings 4 | 5 | 6 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 7 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 8 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 9 | def norm_cdf(x): 10 | # Computes standard normal cumulative distribution function 11 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 12 | 13 | if (mean < a - 2 * std) or (mean > b + 2 * std): 14 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 15 | "The distribution of values may be incorrect.", 16 | stacklevel=2) 17 | 18 | with torch.no_grad(): 19 | # Values are generated by using a truncated uniform distribution and 20 | # then using the inverse CDF for the normal distribution. 21 | # Get upper and lower cdf values 22 | l = norm_cdf((a - mean) / std) 23 | u = norm_cdf((b - mean) / std) 24 | 25 | # Uniformly fill tensor with values from [l, u], then translate to 26 | # [2l-1, 2u-1]. 27 | tensor.uniform_(2 * l - 1, 2 * u - 1) 28 | 29 | # Use inverse cdf transform for normal distribution to get truncated 30 | # standard normal 31 | tensor.erfinv_() 32 | 33 | # Transform to proper mean, std 34 | tensor.mul_(std * math.sqrt(2.)) 35 | tensor.add_(mean) 36 | 37 | # Clamp to ensure it's in the proper range 38 | tensor.clamp_(min=a, max=b) 39 | return tensor 40 | 41 | 42 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 43 | # type: (Tensor, float, float, float, float) -> Tensor 44 | r"""Fills the input Tensor with values drawn from a truncated 45 | normal distribution. The values are effectively drawn from the 46 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 47 | with values outside :math:`[a, b]` redrawn until they are within 48 | the bounds. The method used for generating the random values works 49 | best when :math:`a \leq \text{mean} \leq b`. 
50 | Args: 51 | tensor: an n-dimensional `torch.Tensor` 52 | mean: the mean of the normal distribution 53 | std: the standard deviation of the normal distribution 54 | a: the minimum cutoff value 55 | b: the maximum cutoff value 56 | Examples: 57 | >>> w = torch.empty(3, 5) 58 | >>> nn.init.trunc_normal_(w) 59 | """ 60 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 61 | -------------------------------------------------------------------------------- /object-detection-3d/models/matcher.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Modules to compute the matching cost and solve the corresponding LSAP. 4 | """ 5 | import torch 6 | from scipy.optimize import linear_sum_assignment 7 | from torch import nn 8 | 9 | from utils.box_ops import box_cxcywh_to_xyxy, generalized_box_iou 10 | 11 | 12 | class HungarianMatcher(nn.Module): 13 | """This class computes an assignment between the targets and the predictions of the network 14 | 15 | For efficiency reasons, the targets don't include the no_object. Because of this, in general, 16 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, 17 | while the others are un-matched (and thus treated as non-objects). 18 | """ 19 | 20 | def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1): 21 | """Creates the matcher 22 | 23 | Params: 24 | cost_class: This is the relative weight of the classification error in the matching cost 25 | cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost 26 | cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost 27 | """ 28 | super().__init__() 29 | self.cost_class = cost_class 30 | self.cost_bbox = cost_bbox 31 | self.cost_giou = cost_giou 32 | assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0" 33 | 34 | @torch.no_grad() 35 | def forward(self, outputs, targets): 36 | """ Performs the matching 37 | 38 | Params: 39 | outputs: This is a dict that contains at least these entries: 40 | "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits 41 | "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates 42 | 43 | targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: 44 | "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth 45 | objects in the target) containing the class labels 46 | "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates 47 | 48 | Returns: 49 | A list of size batch_size, containing tuples of (index_i, index_j) where: 50 | - index_i is the indices of the selected predictions (in order) 51 | - index_j is the indices of the corresponding selected targets (in order) 52 | For each batch element, it holds: 53 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes) 54 | """ 55 | bs, num_queries = outputs["pred_logits2d"].shape[:2] 56 | 57 | # We flatten to compute the cost matrices in a batch 58 | out_prob = outputs["pred_logits2d"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes] 59 | out_bbox = outputs["pred_boxes2d"].flatten(0, 1) # [batch_size * num_queries, 4] 60 | 61 | # Also concat the target labels and boxes 62 | tgt_ids = torch.cat([v["labels"] for v 
in targets]) 63 | tgt_bbox = torch.cat([v["boxes"] for v in targets]) 64 | 65 | # Compute the classification cost. Contrary to the loss, we don't use the NLL, 66 | # but approximate it in 1 - proba[target class]. 67 | # The 1 is a constant that doesn't change the matching, it can be ommitted. 68 | cost_class = -out_prob[:, tgt_ids] 69 | 70 | # Compute the L1 cost between boxes 71 | cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) 72 | 73 | # Compute the giou cost betwen boxes 74 | cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)) 75 | 76 | # Final cost matrix 77 | C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou 78 | C = C.view(bs, num_queries, -1).cpu() 79 | 80 | sizes = [len(v["boxes"]) for v in targets] 81 | indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] 82 | return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] 83 | 84 | 85 | def build_matcher(weight_dict): 86 | return HungarianMatcher(cost_class=weight_dict['loss_ce'], cost_bbox=weight_dict['loss_bbox'], 87 | cost_giou=weight_dict['loss_giou']) 88 | -------------------------------------------------------------------------------- /object-detection-3d/models/pointnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import torch.nn as nn 8 | import sys 9 | import os 10 | 11 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 12 | ROOT_DIR = os.path.dirname(BASE_DIR) 13 | sys.path.append(ROOT_DIR) 14 | sys.path.append(os.path.join(ROOT_DIR, 'utils')) 15 | sys.path.append(os.path.join(ROOT_DIR, 'pointnet2')) 16 | sys.path.append(os.path.join(ROOT_DIR, 'ops', 'pt_custom_ops')) 17 | 18 | from pointnet2_modules import PointnetSAModuleVotes, PointnetFPModule 19 | 20 | 21 | class Pointnet2Backbone(nn.Module): 22 | r""" 23 | Backbone network for point cloud feature learning. 24 | Based on Pointnet++ single-scale grouping network. 25 | 26 | Parameters 27 | ---------- 28 | input_feature_dim: int 29 | Number of input channels in the feature descriptor for each point. 30 | e.g. 3 for RGB. 
31 | """ 32 | 33 | def __init__(self, input_feature_dim=0, width=1, depth=2): 34 | super().__init__() 35 | self.depth = depth 36 | self.width = width 37 | 38 | self.sa1 = PointnetSAModuleVotes( 39 | npoint=2048, 40 | radius=0.2, 41 | nsample=64, 42 | mlp=[input_feature_dim] + [64 * width for i in range(depth)] + [128 * width], 43 | use_xyz=True, 44 | normalize_xyz=True 45 | ) 46 | 47 | self.sa2 = PointnetSAModuleVotes( 48 | npoint=1024, 49 | radius=0.4, 50 | nsample=32, 51 | mlp=[128 * width] + [128 * width for i in range(depth)] + [256 * width], 52 | use_xyz=True, 53 | normalize_xyz=True 54 | ) 55 | 56 | self.sa3 = PointnetSAModuleVotes( 57 | npoint=512, 58 | radius=0.8, 59 | nsample=16, 60 | mlp=[256 * width] + [128 * width for i in range(depth)] + [256 * width], 61 | use_xyz=True, 62 | normalize_xyz=True 63 | ) 64 | 65 | self.sa4 = PointnetSAModuleVotes( 66 | npoint=256, 67 | radius=1.2, 68 | nsample=16, 69 | mlp=[256 * width] + [128 * width for i in range(depth)] + [256 * width], 70 | use_xyz=True, 71 | normalize_xyz=True 72 | ) 73 | 74 | self.fp1 = PointnetFPModule(mlp=[256 * width + 256 * width, 256 * width, 256 * width]) 75 | self.fp2 = PointnetFPModule(mlp=[256 * width + 256 * width, 256 * width, 288]) 76 | 77 | def _break_up_pc(self, pc): 78 | xyz = pc[..., 0:3].contiguous() 79 | features = ( 80 | pc[..., 3:].transpose(1, 2).contiguous() 81 | if pc.size(-1) > 3 else None 82 | ) 83 | 84 | return xyz, features 85 | 86 | def forward(self, pointcloud: torch.cuda.FloatTensor, end_points=None): 87 | r""" 88 | Forward pass of the network 89 | 90 | Parameters 91 | ---------- 92 | pointcloud: Variable(torch.cuda.FloatTensor) 93 | (B, N, 3 + input_feature_dim) tensor 94 | Point cloud to run predicts on 95 | Each point in the point-cloud MUST 96 | be formated as (x, y, z, features...) 
97 | 98 | Returns 99 | ---------- 100 | end_points: {XXX_xyz, XXX_features, XXX_inds} 101 | XXX_xyz: float32 Tensor of shape (B,K,3) 102 | XXX_features: float32 Tensor of shape (B,K,D) 103 | XXX-inds: int64 Tensor of shape (B,K) values in [0,N-1] 104 | """ 105 | if not end_points: end_points = {} 106 | batch_size = pointcloud.shape[0] 107 | 108 | xyz, features = self._break_up_pc(pointcloud) 109 | 110 | # --------- 4 SET ABSTRACTION LAYERS --------- 111 | xyz, features, fps_inds = self.sa1(xyz, features) 112 | end_points['sa1_inds'] = fps_inds 113 | end_points['sa1_xyz'] = xyz 114 | end_points['sa1_features'] = features 115 | 116 | xyz, features, fps_inds = self.sa2(xyz, features) # this fps_inds is just 0,1,...,1023 117 | end_points['sa2_inds'] = fps_inds 118 | end_points['sa2_xyz'] = xyz 119 | end_points['sa2_features'] = features 120 | 121 | xyz, features, fps_inds = self.sa3(xyz, features) # this fps_inds is just 0,1,...,511 122 | end_points['sa3_xyz'] = xyz 123 | end_points['sa3_features'] = features 124 | 125 | xyz, features, fps_inds = self.sa4(xyz, features) # this fps_inds is just 0,1,...,255 126 | end_points['sa4_xyz'] = xyz 127 | end_points['sa4_features'] = features 128 | 129 | # --------- 2 FEATURE UPSAMPLING LAYERS -------- 130 | features = self.fp1(end_points['sa3_xyz'], end_points['sa4_xyz'], end_points['sa3_features'], 131 | end_points['sa4_features']) 132 | features = self.fp2(end_points['sa2_xyz'], end_points['sa3_xyz'], end_points['sa2_features'], features) 133 | end_points['fp2_features'] = features 134 | end_points['fp2_xyz'] = end_points['sa2_xyz'] 135 | num_seed = end_points['fp2_xyz'].shape[1] 136 | end_points['fp2_inds'] = end_points['sa1_inds'][:, 0:num_seed] # indices among the entire input point clouds 137 | 138 | return end_points 139 | 140 | 141 | if __name__ == '__main__': 142 | backbone_net = Pointnet2Backbone(input_feature_dim=3).cuda() 143 | print(backbone_net) 144 | backbone_net.eval() 145 | out = backbone_net(torch.rand(16, 20000, 6).cuda()) 146 | for key in sorted(out.keys()): 147 | print(key, '\t', out[key].shape) 148 | -------------------------------------------------------------------------------- /object-detection-3d/models/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import torch.nn.functional as F 5 | from torch import Tensor 6 | from typing import Optional 7 | from multi_head_attention import MultiheadAttention 8 | 9 | 10 | class TransformerDecoderLayer(nn.Module): 11 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", 12 | self_posembed=None, cross_posembed=None): 13 | super().__init__() 14 | self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout) 15 | self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout) 16 | # Implementation of Feedforward model 17 | self.linear1 = nn.Linear(d_model, dim_feedforward) 18 | self.dropout = nn.Dropout(dropout) 19 | self.linear2 = nn.Linear(dim_feedforward, d_model) 20 | 21 | self.norm1 = nn.LayerNorm(d_model) 22 | self.norm2 = nn.LayerNorm(d_model) 23 | self.norm3 = nn.LayerNorm(d_model) 24 | self.dropout1 = nn.Dropout(dropout) 25 | self.dropout2 = nn.Dropout(dropout) 26 | self.dropout3 = nn.Dropout(dropout) 27 | 28 | self.activation = _get_activation_fn(activation) 29 | 30 | self.self_posembed = self_posembed 31 | self.cross_posembed = cross_posembed 32 | 33 | def with_pos_embed(self, tensor, pos_embed: Optional[Tensor]): 34 | return tensor 
if pos_embed is None else tensor + pos_embed 35 | 36 | def forward(self, query, key, query_pos, key_pos): 37 | """ 38 | :param query: B C Pq 39 | :param key: B C Pk 40 | :param query_pos: B Pq 3/6 41 | :param key_pos: B Pk 3/6 42 | :param value_pos: [B Pq 3/6] 43 | 44 | :return: 45 | """ 46 | # NxCxP to PxNxC 47 | if self.self_posembed is not None: 48 | query_pos_embed = self.self_posembed(query_pos).permute(2, 0, 1) 49 | else: 50 | query_pos_embed = None 51 | if self.cross_posembed is not None: 52 | key_pos_embed = self.cross_posembed(key_pos).permute(2, 0, 1) 53 | else: 54 | key_pos_embed = None 55 | 56 | query = query.permute(2, 0, 1) 57 | key = key.permute(2, 0, 1) 58 | 59 | q = k = v = self.with_pos_embed(query, query_pos_embed) 60 | query2 = self.self_attn(q, k, value=v)[0] 61 | query = query + self.dropout1(query2) 62 | query = self.norm1(query) 63 | 64 | query2 = self.multihead_attn(query=self.with_pos_embed(query, query_pos_embed), 65 | key=self.with_pos_embed(key, key_pos_embed), 66 | value=self.with_pos_embed(key, key_pos_embed))[0] 67 | query = query + self.dropout2(query2) 68 | query = self.norm2(query) 69 | 70 | query2 = self.linear2(self.dropout(self.activation(self.linear1(query)))) 71 | query = query + self.dropout3(query2) 72 | query = self.norm3(query) 73 | 74 | # NxCxP to PxNxC 75 | query = query.permute(1, 2, 0) 76 | return query 77 | 78 | 79 | def _get_activation_fn(activation): 80 | """Return an activation function given a string""" 81 | if activation == "relu": 82 | return F.relu 83 | if activation == "gelu": 84 | return F.gelu 85 | if activation == "glu": 86 | return F.glu 87 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 88 | -------------------------------------------------------------------------------- /object-detection-3d/pointnet2/_ext_src/include/ball_query.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | 9 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 10 | const int nsample); 11 | -------------------------------------------------------------------------------- /object-detection-3d/pointnet2/_ext_src/include/cuda_utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
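A minimal usage sketch of the `TransformerDecoderLayer` above, assuming `object-detection-3d/models` is on `sys.path` so that `transformer.py` and its `MultiheadAttention` dependency import cleanly; `d_model=288` and the tensor sizes are illustrative only, and both positional-embedding modules are left as `None` so `query_pos`/`key_pos` go unused.

```python
import torch
from transformer import TransformerDecoderLayer  # assumes models/ is importable

layer = TransformerDecoderLayer(d_model=288, nhead=8, dim_feedforward=2048,
                                self_posembed=None, cross_posembed=None)

query = torch.rand(2, 288, 64)    # B x C x Pq (object queries)
key = torch.rand(2, 288, 1024)    # B x C x Pk (scene features)

# Inputs are permuted to P x B x C internally and permuted back on return.
out = layer(query, key, query_pos=None, key_pos=None)
print(out.shape)                  # torch.Size([2, 288, 64])
```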
5 | 6 | #ifndef _CUDA_UTILS_H 7 | #define _CUDA_UTILS_H 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | #include 17 | 18 | #define TOTAL_THREADS 512 19 | 20 | inline int opt_n_threads(int work_size) { 21 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 22 | 23 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 24 | } 25 | 26 | inline dim3 opt_block_config(int x, int y) { 27 | const int x_threads = opt_n_threads(x); 28 | const int y_threads = 29 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 30 | dim3 block_config(x_threads, y_threads, 1); 31 | 32 | return block_config; 33 | } 34 | 35 | #define CUDA_CHECK_ERRORS() \ 36 | do { \ 37 | cudaError_t err = cudaGetLastError(); \ 38 | if (cudaSuccess != err) { \ 39 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ 40 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ 41 | __FILE__); \ 42 | exit(-1); \ 43 | } \ 44 | } while (0) 45 | 46 | #endif 47 | -------------------------------------------------------------------------------- /object-detection-3d/pointnet2/_ext_src/include/group_points.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | 9 | at::Tensor group_points(at::Tensor points, at::Tensor idx); 10 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 11 | -------------------------------------------------------------------------------- /object-detection-3d/pointnet2/_ext_src/include/interpolate.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | 8 | #include 9 | #include 10 | 11 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows); 12 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 13 | at::Tensor weight); 14 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 15 | at::Tensor weight, const int m); 16 | -------------------------------------------------------------------------------- /object-detection-3d/pointnet2/_ext_src/include/sampling.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #pragma once 7 | #include 8 | 9 | at::Tensor gather_points(at::Tensor points, at::Tensor idx); 10 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 11 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples); 12 | -------------------------------------------------------------------------------- /object-detection-3d/pointnet2/_ext_src/include/utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
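The thread-count heuristic in `cuda_utils.h` rounds the work size down to a power of two and clamps it to `TOTAL_THREADS`. A rough Python equivalent, for intuition only:

```python
import math

TOTAL_THREADS = 512  # mirrors the constant in cuda_utils.h

def opt_n_threads(work_size: int) -> int:
    """Largest power of two <= work_size, clamped to [1, TOTAL_THREADS]."""
    pow_2 = int(math.log(work_size) / math.log(2.0))
    return max(min(1 << pow_2, TOTAL_THREADS), 1)

print([opt_n_threads(n) for n in (1, 20, 500, 20000)])  # [1, 16, 256, 512]
```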
5 | 6 | #pragma once 7 | #include 8 | #include 9 | 10 | #define CHECK_CUDA(x) \ 11 | do { \ 12 | TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor"); \ 13 | } while (0) 14 | 15 | #define CHECK_CONTIGUOUS(x) \ 16 | do { \ 17 | TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor"); \ 18 | } while (0) 19 | 20 | #define CHECK_IS_INT(x) \ 21 | do { \ 22 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \ 23 | #x " must be an int tensor"); \ 24 | } while (0) 25 | 26 | #define CHECK_IS_FLOAT(x) \ 27 | do { \ 28 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, \ 29 | #x " must be a float tensor"); \ 30 | } while (0) 31 | -------------------------------------------------------------------------------- /object-detection-3d/pointnet2/_ext_src/src/ball_query.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include "ball_query.h" 7 | #include "utils.h" 8 | 9 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 10 | int nsample, const float *new_xyz, 11 | const float *xyz, int *idx); 12 | 13 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 14 | const int nsample) { 15 | CHECK_CONTIGUOUS(new_xyz); 16 | CHECK_CONTIGUOUS(xyz); 17 | CHECK_IS_FLOAT(new_xyz); 18 | CHECK_IS_FLOAT(xyz); 19 | 20 | if (new_xyz.device().is_cuda()) { 21 | CHECK_CUDA(xyz); 22 | } 23 | 24 | at::Tensor idx = 25 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, 26 | at::device(new_xyz.device()).dtype(at::ScalarType::Int)); 27 | 28 | if (new_xyz.device().is_cuda()) { 29 | query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), 30 | radius, nsample, new_xyz.data_ptr(), 31 | xyz.data_ptr(), idx.data_ptr()); 32 | } else { 33 | TORCH_CHECK(false, "CPU not supported"); 34 | } 35 | 36 | return idx; 37 | } 38 | -------------------------------------------------------------------------------- /object-detection-3d/pointnet2/_ext_src/src/ball_query_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #include "cuda_utils.h" 11 | 12 | // input: new_xyz(b, m, 3) xyz(b, n, 3) 13 | // output: idx(b, m, nsample) 14 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius, 15 | int nsample, 16 | const float *__restrict__ new_xyz, 17 | const float *__restrict__ xyz, 18 | int *__restrict__ idx) { 19 | int batch_index = blockIdx.x; 20 | xyz += batch_index * n * 3; 21 | new_xyz += batch_index * m * 3; 22 | idx += m * nsample * batch_index; 23 | 24 | int index = threadIdx.x; 25 | int stride = blockDim.x; 26 | 27 | float radius2 = radius * radius; 28 | for (int j = index; j < m; j += stride) { 29 | float new_x = new_xyz[j * 3 + 0]; 30 | float new_y = new_xyz[j * 3 + 1]; 31 | float new_z = new_xyz[j * 3 + 2]; 32 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) { 33 | float x = xyz[k * 3 + 0]; 34 | float y = xyz[k * 3 + 1]; 35 | float z = xyz[k * 3 + 2]; 36 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + 37 | (new_z - z) * (new_z - z); 38 | if (d2 < radius2) { 39 | if (cnt == 0) { 40 | for (int l = 0; l < nsample; ++l) { 41 | idx[j * nsample + l] = k; 42 | } 43 | } 44 | idx[j * nsample + cnt] = k; 45 | ++cnt; 46 | } 47 | } 48 | } 49 | } 50 | 51 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 52 | int nsample, const float *new_xyz, 53 | const float *xyz, int *idx) { 54 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 55 | query_ball_point_kernel<<>>( 56 | b, n, m, radius, nsample, new_xyz, xyz, idx); 57 | 58 | CUDA_CHECK_ERRORS(); 59 | } 60 | -------------------------------------------------------------------------------- /object-detection-3d/pointnet2/_ext_src/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include "ball_query.h" 7 | #include "group_points.h" 8 | #include "interpolate.h" 9 | #include "sampling.h" 10 | 11 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 12 | m.def("gather_points", &gather_points); 13 | m.def("gather_points_grad", &gather_points_grad); 14 | m.def("furthest_point_sampling", &furthest_point_sampling); 15 | 16 | m.def("three_nn", &three_nn); 17 | m.def("three_interpolate", &three_interpolate); 18 | m.def("three_interpolate_grad", &three_interpolate_grad); 19 | 20 | m.def("ball_query", &ball_query); 21 | 22 | m.def("group_points", &group_points); 23 | m.def("group_points_grad", &group_points_grad); 24 | } 25 | -------------------------------------------------------------------------------- /object-detection-3d/pointnet2/_ext_src/src/group_points.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
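The ball-query kernel above returns, for every query center, a fixed-size list of indices of points within `radius`, padding unused slots with the first neighbor found. A NumPy reference sketch of the same behavior for a single batch element:

```python
import numpy as np

def ball_query_ref(new_xyz, xyz, radius, nsample):
    """new_xyz: (m, 3) query centers, xyz: (n, 3) points.
    Returns (m, nsample) int indices; slots beyond the found neighbors are
    padded with the first neighbor, mirroring the CUDA kernel."""
    m = new_xyz.shape[0]
    idx = np.zeros((m, nsample), dtype=np.int64)
    for j in range(m):
        d2 = np.sum((xyz - new_xyz[j]) ** 2, axis=1)
        neighbors = np.flatnonzero(d2 < radius ** 2)[:nsample]
        if len(neighbors) > 0:
            idx[j, :] = neighbors[0]           # pad with the first hit
            idx[j, :len(neighbors)] = neighbors
    return idx

pts = np.random.rand(128, 3).astype(np.float32)
centers = pts[:8]
print(ball_query_ref(centers, pts, radius=0.2, nsample=16).shape)  # (8, 16)
```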
5 | 6 | #include "group_points.h" 7 | #include "utils.h" 8 | 9 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 10 | const float *points, const int *idx, 11 | float *out); 12 | 13 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 14 | int nsample, const float *grad_out, 15 | const int *idx, float *grad_points); 16 | 17 | at::Tensor group_points(at::Tensor points, at::Tensor idx) { 18 | CHECK_CONTIGUOUS(points); 19 | CHECK_CONTIGUOUS(idx); 20 | CHECK_IS_FLOAT(points); 21 | CHECK_IS_INT(idx); 22 | 23 | if (points.device().is_cuda()) { 24 | CHECK_CUDA(idx); 25 | } 26 | 27 | at::Tensor output = 28 | torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)}, 29 | at::device(points.device()).dtype(at::ScalarType::Float)); 30 | 31 | if (points.device().is_cuda()) { 32 | group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 33 | idx.size(1), idx.size(2), points.data_ptr(), 34 | idx.data_ptr(), output.data_ptr()); 35 | } else { 36 | TORCH_CHECK(false, "CPU not supported"); 37 | } 38 | 39 | return output; 40 | } 41 | 42 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) { 43 | CHECK_CONTIGUOUS(grad_out); 44 | CHECK_CONTIGUOUS(idx); 45 | CHECK_IS_FLOAT(grad_out); 46 | CHECK_IS_INT(idx); 47 | 48 | if (grad_out.device().is_cuda()) { 49 | CHECK_CUDA(idx); 50 | } 51 | 52 | at::Tensor output = 53 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 54 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 55 | 56 | if (grad_out.device().is_cuda()) { 57 | group_points_grad_kernel_wrapper( 58 | grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2), 59 | grad_out.data_ptr(), idx.data_ptr(), output.data_ptr()); 60 | } else { 61 | TORCH_CHECK(false, "CPU not supported"); 62 | } 63 | 64 | return output; 65 | } 66 | -------------------------------------------------------------------------------- /object-detection-3d/pointnet2/_ext_src/src/group_points_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
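`group_points` is a gather along the point dimension: `out[b, c, j, k] = points[b, c, idx[b, j, k]]`. A NumPy sketch of the same indexing, with made-up sizes:

```python
import numpy as np

def group_points_ref(points, idx):
    """points: (b, c, n) features, idx: (b, npoints, nsample) int indices.
    Returns out: (b, c, npoints, nsample)."""
    batch = np.arange(points.shape[0])[:, None, None, None]
    chan = np.arange(points.shape[1])[None, :, None, None]
    # idx is broadcast across the channel axis.
    return points[batch, chan, idx[:, None, :, :]]

points = np.random.rand(2, 8, 100).astype(np.float32)
idx = np.random.randint(0, 100, size=(2, 16, 32))
print(group_points_ref(points, idx).shape)  # (2, 8, 16, 32)
```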
5 | 6 | #include 7 | #include 8 | 9 | #include "cuda_utils.h" 10 | 11 | // input: points(b, c, n) idx(b, npoints, nsample) 12 | // output: out(b, c, npoints, nsample) 13 | __global__ void group_points_kernel(int b, int c, int n, int npoints, 14 | int nsample, 15 | const float *__restrict__ points, 16 | const int *__restrict__ idx, 17 | float *__restrict__ out) { 18 | int batch_index = blockIdx.x; 19 | points += batch_index * n * c; 20 | idx += batch_index * npoints * nsample; 21 | out += batch_index * npoints * nsample * c; 22 | 23 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 24 | const int stride = blockDim.y * blockDim.x; 25 | for (int i = index; i < c * npoints; i += stride) { 26 | const int l = i / npoints; 27 | const int j = i % npoints; 28 | for (int k = 0; k < nsample; ++k) { 29 | int ii = idx[j * nsample + k]; 30 | out[(l * npoints + j) * nsample + k] = points[l * n + ii]; 31 | } 32 | } 33 | } 34 | 35 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 36 | const float *points, const int *idx, 37 | float *out) { 38 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 39 | 40 | group_points_kernel<<>>( 41 | b, c, n, npoints, nsample, points, idx, out); 42 | 43 | CUDA_CHECK_ERRORS(); 44 | } 45 | 46 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample) 47 | // output: grad_points(b, c, n) 48 | __global__ void group_points_grad_kernel(int b, int c, int n, int npoints, 49 | int nsample, 50 | const float *__restrict__ grad_out, 51 | const int *__restrict__ idx, 52 | float *__restrict__ grad_points) { 53 | int batch_index = blockIdx.x; 54 | grad_out += batch_index * npoints * nsample * c; 55 | idx += batch_index * npoints * nsample; 56 | grad_points += batch_index * n * c; 57 | 58 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 59 | const int stride = blockDim.y * blockDim.x; 60 | for (int i = index; i < c * npoints; i += stride) { 61 | const int l = i / npoints; 62 | const int j = i % npoints; 63 | for (int k = 0; k < nsample; ++k) { 64 | int ii = idx[j * nsample + k]; 65 | atomicAdd(grad_points + l * n + ii, 66 | grad_out[(l * npoints + j) * nsample + k]); 67 | } 68 | } 69 | } 70 | 71 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 72 | int nsample, const float *grad_out, 73 | const int *idx, float *grad_points) { 74 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 75 | 76 | group_points_grad_kernel<<>>( 77 | b, c, n, npoints, nsample, grad_out, idx, grad_points); 78 | 79 | CUDA_CHECK_ERRORS(); 80 | } 81 | -------------------------------------------------------------------------------- /object-detection-3d/pointnet2/_ext_src/src/interpolate.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
5 | 6 | #include "interpolate.h" 7 | #include "utils.h" 8 | 9 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 10 | const float *known, float *dist2, int *idx); 11 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 12 | const float *points, const int *idx, 13 | const float *weight, float *out); 14 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 15 | const float *grad_out, 16 | const int *idx, const float *weight, 17 | float *grad_points); 18 | 19 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows) { 20 | CHECK_CONTIGUOUS(unknowns); 21 | CHECK_CONTIGUOUS(knows); 22 | CHECK_IS_FLOAT(unknowns); 23 | CHECK_IS_FLOAT(knows); 24 | 25 | if (unknowns.device().is_cuda()) { 26 | CHECK_CUDA(knows); 27 | } 28 | 29 | at::Tensor idx = 30 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 31 | at::device(unknowns.device()).dtype(at::ScalarType::Int)); 32 | at::Tensor dist2 = 33 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 34 | at::device(unknowns.device()).dtype(at::ScalarType::Float)); 35 | 36 | if (unknowns.device().is_cuda()) { 37 | three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1), 38 | unknowns.data_ptr(), knows.data_ptr(), 39 | dist2.data_ptr(), idx.data_ptr()); 40 | } else { 41 | TORCH_CHECK(false, "CPU not supported"); 42 | } 43 | 44 | return {dist2, idx}; 45 | } 46 | 47 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 48 | at::Tensor weight) { 49 | CHECK_CONTIGUOUS(points); 50 | CHECK_CONTIGUOUS(idx); 51 | CHECK_CONTIGUOUS(weight); 52 | CHECK_IS_FLOAT(points); 53 | CHECK_IS_INT(idx); 54 | CHECK_IS_FLOAT(weight); 55 | 56 | if (points.device().is_cuda()) { 57 | CHECK_CUDA(idx); 58 | CHECK_CUDA(weight); 59 | } 60 | 61 | at::Tensor output = 62 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 63 | at::device(points.device()).dtype(at::ScalarType::Float)); 64 | 65 | if (points.device().is_cuda()) { 66 | three_interpolate_kernel_wrapper( 67 | points.size(0), points.size(1), points.size(2), idx.size(1), 68 | points.data_ptr(), idx.data_ptr(), weight.data_ptr(), 69 | output.data_ptr()); 70 | } else { 71 | TORCH_CHECK(false, "CPU not supported"); 72 | } 73 | 74 | return output; 75 | } 76 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 77 | at::Tensor weight, const int m) { 78 | CHECK_CONTIGUOUS(grad_out); 79 | CHECK_CONTIGUOUS(idx); 80 | CHECK_CONTIGUOUS(weight); 81 | CHECK_IS_FLOAT(grad_out); 82 | CHECK_IS_INT(idx); 83 | CHECK_IS_FLOAT(weight); 84 | 85 | if (grad_out.device().is_cuda()) { 86 | CHECK_CUDA(idx); 87 | CHECK_CUDA(weight); 88 | } 89 | 90 | at::Tensor output = 91 | torch::zeros({grad_out.size(0), grad_out.size(1), m}, 92 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 93 | 94 | if (grad_out.device().is_cuda()) { 95 | three_interpolate_grad_kernel_wrapper( 96 | grad_out.size(0), grad_out.size(1), grad_out.size(2), m, 97 | grad_out.data_ptr(), idx.data_ptr(), weight.data_ptr(), 98 | output.data_ptr()); 99 | } else { 100 | TORCH_CHECK(false, "CPU not supported"); 101 | } 102 | 103 | return output; 104 | } 105 | -------------------------------------------------------------------------------- /object-detection-3d/pointnet2/_ext_src/src/interpolate_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 
2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #include "cuda_utils.h" 11 | 12 | // input: unknown(b, n, 3) known(b, m, 3) 13 | // output: dist2(b, n, 3), idx(b, n, 3) 14 | __global__ void three_nn_kernel(int b, int n, int m, 15 | const float *__restrict__ unknown, 16 | const float *__restrict__ known, 17 | float *__restrict__ dist2, 18 | int *__restrict__ idx) { 19 | int batch_index = blockIdx.x; 20 | unknown += batch_index * n * 3; 21 | known += batch_index * m * 3; 22 | dist2 += batch_index * n * 3; 23 | idx += batch_index * n * 3; 24 | 25 | int index = threadIdx.x; 26 | int stride = blockDim.x; 27 | for (int j = index; j < n; j += stride) { 28 | float ux = unknown[j * 3 + 0]; 29 | float uy = unknown[j * 3 + 1]; 30 | float uz = unknown[j * 3 + 2]; 31 | 32 | double best1 = 1e40, best2 = 1e40, best3 = 1e40; 33 | int besti1 = 0, besti2 = 0, besti3 = 0; 34 | for (int k = 0; k < m; ++k) { 35 | float x = known[k * 3 + 0]; 36 | float y = known[k * 3 + 1]; 37 | float z = known[k * 3 + 2]; 38 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); 39 | if (d < best1) { 40 | best3 = best2; 41 | besti3 = besti2; 42 | best2 = best1; 43 | besti2 = besti1; 44 | best1 = d; 45 | besti1 = k; 46 | } else if (d < best2) { 47 | best3 = best2; 48 | besti3 = besti2; 49 | best2 = d; 50 | besti2 = k; 51 | } else if (d < best3) { 52 | best3 = d; 53 | besti3 = k; 54 | } 55 | } 56 | dist2[j * 3 + 0] = best1; 57 | dist2[j * 3 + 1] = best2; 58 | dist2[j * 3 + 2] = best3; 59 | 60 | idx[j * 3 + 0] = besti1; 61 | idx[j * 3 + 1] = besti2; 62 | idx[j * 3 + 2] = besti3; 63 | } 64 | } 65 | 66 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 67 | const float *known, float *dist2, int *idx) { 68 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 69 | three_nn_kernel<<>>(b, n, m, unknown, known, 70 | dist2, idx); 71 | 72 | CUDA_CHECK_ERRORS(); 73 | } 74 | 75 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3) 76 | // output: out(b, c, n) 77 | __global__ void three_interpolate_kernel(int b, int c, int m, int n, 78 | const float *__restrict__ points, 79 | const int *__restrict__ idx, 80 | const float *__restrict__ weight, 81 | float *__restrict__ out) { 82 | int batch_index = blockIdx.x; 83 | points += batch_index * m * c; 84 | 85 | idx += batch_index * n * 3; 86 | weight += batch_index * n * 3; 87 | 88 | out += batch_index * n * c; 89 | 90 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 91 | const int stride = blockDim.y * blockDim.x; 92 | for (int i = index; i < c * n; i += stride) { 93 | const int l = i / n; 94 | const int j = i % n; 95 | float w1 = weight[j * 3 + 0]; 96 | float w2 = weight[j * 3 + 1]; 97 | float w3 = weight[j * 3 + 2]; 98 | 99 | int i1 = idx[j * 3 + 0]; 100 | int i2 = idx[j * 3 + 1]; 101 | int i3 = idx[j * 3 + 2]; 102 | 103 | out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 + 104 | points[l * m + i3] * w3; 105 | } 106 | } 107 | 108 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 109 | const float *points, const int *idx, 110 | const float *weight, float *out) { 111 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 112 | three_interpolate_kernel<<>>( 113 | b, c, m, n, points, idx, weight, out); 114 | 115 | CUDA_CHECK_ERRORS(); 116 | } 117 | 118 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3) 119 | // output: grad_points(b, c, m) 120 | 121 | 
__global__ void three_interpolate_grad_kernel( 122 | int b, int c, int n, int m, const float *__restrict__ grad_out, 123 | const int *__restrict__ idx, const float *__restrict__ weight, 124 | float *__restrict__ grad_points) { 125 | int batch_index = blockIdx.x; 126 | grad_out += batch_index * n * c; 127 | idx += batch_index * n * 3; 128 | weight += batch_index * n * 3; 129 | grad_points += batch_index * m * c; 130 | 131 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 132 | const int stride = blockDim.y * blockDim.x; 133 | for (int i = index; i < c * n; i += stride) { 134 | const int l = i / n; 135 | const int j = i % n; 136 | float w1 = weight[j * 3 + 0]; 137 | float w2 = weight[j * 3 + 1]; 138 | float w3 = weight[j * 3 + 2]; 139 | 140 | int i1 = idx[j * 3 + 0]; 141 | int i2 = idx[j * 3 + 1]; 142 | int i3 = idx[j * 3 + 2]; 143 | 144 | atomicAdd(grad_points + l * m + i1, grad_out[i] * w1); 145 | atomicAdd(grad_points + l * m + i2, grad_out[i] * w2); 146 | atomicAdd(grad_points + l * m + i3, grad_out[i] * w3); 147 | } 148 | } 149 | 150 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 151 | const float *grad_out, 152 | const int *idx, const float *weight, 153 | float *grad_points) { 154 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 155 | three_interpolate_grad_kernel<<>>( 156 | b, c, n, m, grad_out, idx, weight, grad_points); 157 | 158 | CUDA_CHECK_ERRORS(); 159 | } 160 | -------------------------------------------------------------------------------- /object-detection-3d/pointnet2/_ext_src/src/sampling.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
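Together, `three_nn` and `three_interpolate` implement feature propagation: each unknown point takes a weighted sum of features from its three nearest known points. A NumPy reference sketch for a single batch element; the inverse-distance weighting at the end is illustrative of how the weights are typically formed, not lifted from this file.

```python
import numpy as np

def three_nn_ref(unknown, known):
    """unknown: (n, 3), known: (m, 3) -> squared distances (n, 3) and
    indices (n, 3) of the 3 nearest known points per unknown point."""
    d2 = np.sum((unknown[:, None, :] - known[None, :, :]) ** 2, axis=-1)  # (n, m)
    idx = np.argsort(d2, axis=1)[:, :3]
    return np.take_along_axis(d2, idx, axis=1), idx

def three_interpolate_ref(feats, idx, weight):
    """feats: (c, m), idx/weight: (n, 3) -> out: (c, n),
    out[:, j] = sum_k feats[:, idx[j, k]] * weight[j, k]."""
    return np.einsum('cnk,nk->cn', feats[:, idx], weight)

known = np.random.rand(16, 3)
unknown = np.random.rand(64, 3)
feats = np.random.rand(32, 16)
dist2, idx = three_nn_ref(unknown, known)
w = 1.0 / (dist2 + 1e-8)
w = w / w.sum(axis=1, keepdims=True)   # inverse-distance weights (illustrative)
print(three_interpolate_ref(feats, idx, w).shape)  # (32, 64)
```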
5 | 6 | #include "sampling.h" 7 | #include "utils.h" 8 | 9 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 10 | const float *points, const int *idx, 11 | float *out); 12 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 13 | const float *grad_out, const int *idx, 14 | float *grad_points); 15 | 16 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 17 | const float *dataset, float *temp, 18 | int *idxs); 19 | 20 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) { 21 | CHECK_CONTIGUOUS(points); 22 | CHECK_CONTIGUOUS(idx); 23 | CHECK_IS_FLOAT(points); 24 | CHECK_IS_INT(idx); 25 | 26 | if (points.device().is_cuda()) { 27 | CHECK_CUDA(idx); 28 | } 29 | 30 | at::Tensor output = 31 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 32 | at::device(points.device()).dtype(at::ScalarType::Float)); 33 | 34 | if (points.device().is_cuda()) { 35 | gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 36 | idx.size(1), points.data_ptr(), 37 | idx.data_ptr(), output.data_ptr()); 38 | } else { 39 | TORCH_CHECK(false, "CPU not supported"); 40 | } 41 | 42 | return output; 43 | } 44 | 45 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, 46 | const int n) { 47 | CHECK_CONTIGUOUS(grad_out); 48 | CHECK_CONTIGUOUS(idx); 49 | CHECK_IS_FLOAT(grad_out); 50 | CHECK_IS_INT(idx); 51 | 52 | if (grad_out.device().is_cuda()) { 53 | CHECK_CUDA(idx); 54 | } 55 | 56 | at::Tensor output = 57 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 58 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 59 | 60 | if (grad_out.device().is_cuda()) { 61 | gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n, 62 | idx.size(1), grad_out.data_ptr(), 63 | idx.data_ptr(), output.data_ptr()); 64 | } else { 65 | TORCH_CHECK(false, "CPU not supported"); 66 | } 67 | 68 | return output; 69 | } 70 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) { 71 | CHECK_CONTIGUOUS(points); 72 | CHECK_IS_FLOAT(points); 73 | 74 | at::Tensor output = 75 | torch::zeros({points.size(0), nsamples}, 76 | at::device(points.device()).dtype(at::ScalarType::Int)); 77 | 78 | at::Tensor tmp = 79 | torch::full({points.size(0), points.size(1)}, 1e10, 80 | at::device(points.device()).dtype(at::ScalarType::Float)); 81 | 82 | if (points.device().is_cuda()) { 83 | furthest_point_sampling_kernel_wrapper( 84 | points.size(0), points.size(1), nsamples, points.data_ptr(), 85 | tmp.data_ptr(), output.data_ptr()); 86 | } else { 87 | TORCH_CHECK(false, "CPU not supported"); 88 | } 89 | 90 | return output; 91 | } 92 | -------------------------------------------------------------------------------- /object-detection-3d/pointnet2/_ext_src/src/sampling_gpu.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 2 | // 3 | // This source code is licensed under the MIT license found in the 4 | // LICENSE file in the root directory of this source tree. 
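`furthest_point_sampling` greedily picks indices, starting from index 0, each time taking the point farthest from the already-selected set. A NumPy reference for one point cloud (the CUDA kernel additionally skips points very close to the origin, which is omitted here):

```python
import numpy as np

def furthest_point_sampling_ref(points, nsamples):
    """points: (n, 3) -> (nsamples,) indices chosen by greedy max-min distance."""
    n = points.shape[0]
    idxs = np.zeros(nsamples, dtype=np.int64)   # idxs[0] = 0, as in the kernel
    min_d2 = np.full(n, 1e10)                   # distance to the selected set so far
    last = 0
    for j in range(1, nsamples):
        d2 = np.sum((points - points[last]) ** 2, axis=1)
        min_d2 = np.minimum(min_d2, d2)
        last = int(np.argmax(min_d2))
        idxs[j] = last
    return idxs

cloud = np.random.rand(2048, 3).astype(np.float32)
print(furthest_point_sampling_ref(cloud, 16))
```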
5 | 6 | #include 7 | #include 8 | 9 | #include "cuda_utils.h" 10 | 11 | // input: points(b, c, n) idx(b, m) 12 | // output: out(b, c, m) 13 | __global__ void gather_points_kernel(int b, int c, int n, int m, 14 | const float *__restrict__ points, 15 | const int *__restrict__ idx, 16 | float *__restrict__ out) { 17 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 18 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 19 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 20 | int a = idx[i * m + j]; 21 | out[(i * c + l) * m + j] = points[(i * c + l) * n + a]; 22 | } 23 | } 24 | } 25 | } 26 | 27 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 28 | const float *points, const int *idx, 29 | float *out) { 30 | gather_points_kernel<<>>(b, c, n, npoints, 32 | points, idx, out); 33 | 34 | CUDA_CHECK_ERRORS(); 35 | } 36 | 37 | // input: grad_out(b, c, m) idx(b, m) 38 | // output: grad_points(b, c, n) 39 | __global__ void gather_points_grad_kernel(int b, int c, int n, int m, 40 | const float *__restrict__ grad_out, 41 | const int *__restrict__ idx, 42 | float *__restrict__ grad_points) { 43 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 44 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 45 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 46 | int a = idx[i * m + j]; 47 | atomicAdd(grad_points + (i * c + l) * n + a, 48 | grad_out[(i * c + l) * m + j]); 49 | } 50 | } 51 | } 52 | } 53 | 54 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 55 | const float *grad_out, const int *idx, 56 | float *grad_points) { 57 | gather_points_grad_kernel<<>>( 59 | b, c, n, npoints, grad_out, idx, grad_points); 60 | 61 | CUDA_CHECK_ERRORS(); 62 | } 63 | 64 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, 65 | int idx1, int idx2) { 66 | const float v1 = dists[idx1], v2 = dists[idx2]; 67 | const int i1 = dists_i[idx1], i2 = dists_i[idx2]; 68 | dists[idx1] = max(v1, v2); 69 | dists_i[idx1] = v2 > v1 ? i2 : i1; 70 | } 71 | 72 | // Input dataset: (b, n, 3), tmp: (b, n) 73 | // Ouput idxs (b, m) 74 | template 75 | __global__ void furthest_point_sampling_kernel( 76 | int b, int n, int m, const float *__restrict__ dataset, 77 | float *__restrict__ temp, int *__restrict__ idxs) { 78 | if (m <= 0) return; 79 | __shared__ float dists[block_size]; 80 | __shared__ int dists_i[block_size]; 81 | 82 | int batch_index = blockIdx.x; 83 | dataset += batch_index * n * 3; 84 | temp += batch_index * n; 85 | idxs += batch_index * m; 86 | 87 | int tid = threadIdx.x; 88 | const int stride = block_size; 89 | 90 | int old = 0; 91 | if (threadIdx.x == 0) idxs[0] = old; 92 | 93 | __syncthreads(); 94 | for (int j = 1; j < m; j++) { 95 | int besti = 0; 96 | float best = -1; 97 | float x1 = dataset[old * 3 + 0]; 98 | float y1 = dataset[old * 3 + 1]; 99 | float z1 = dataset[old * 3 + 2]; 100 | for (int k = tid; k < n; k += stride) { 101 | float x2, y2, z2; 102 | x2 = dataset[k * 3 + 0]; 103 | y2 = dataset[k * 3 + 1]; 104 | z2 = dataset[k * 3 + 2]; 105 | float mag = (x2 * x2) + (y2 * y2) + (z2 * z2); 106 | if (mag <= 1e-3) continue; 107 | 108 | float d = 109 | (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); 110 | 111 | float d2 = min(d, temp[k]); 112 | temp[k] = d2; 113 | besti = d2 > best ? k : besti; 114 | best = d2 > best ? 
d2 : best; 115 | } 116 | dists[tid] = best; 117 | dists_i[tid] = besti; 118 | __syncthreads(); 119 | 120 | if (block_size >= 512) { 121 | if (tid < 256) { 122 | __update(dists, dists_i, tid, tid + 256); 123 | } 124 | __syncthreads(); 125 | } 126 | if (block_size >= 256) { 127 | if (tid < 128) { 128 | __update(dists, dists_i, tid, tid + 128); 129 | } 130 | __syncthreads(); 131 | } 132 | if (block_size >= 128) { 133 | if (tid < 64) { 134 | __update(dists, dists_i, tid, tid + 64); 135 | } 136 | __syncthreads(); 137 | } 138 | if (block_size >= 64) { 139 | if (tid < 32) { 140 | __update(dists, dists_i, tid, tid + 32); 141 | } 142 | __syncthreads(); 143 | } 144 | if (block_size >= 32) { 145 | if (tid < 16) { 146 | __update(dists, dists_i, tid, tid + 16); 147 | } 148 | __syncthreads(); 149 | } 150 | if (block_size >= 16) { 151 | if (tid < 8) { 152 | __update(dists, dists_i, tid, tid + 8); 153 | } 154 | __syncthreads(); 155 | } 156 | if (block_size >= 8) { 157 | if (tid < 4) { 158 | __update(dists, dists_i, tid, tid + 4); 159 | } 160 | __syncthreads(); 161 | } 162 | if (block_size >= 4) { 163 | if (tid < 2) { 164 | __update(dists, dists_i, tid, tid + 2); 165 | } 166 | __syncthreads(); 167 | } 168 | if (block_size >= 2) { 169 | if (tid < 1) { 170 | __update(dists, dists_i, tid, tid + 1); 171 | } 172 | __syncthreads(); 173 | } 174 | 175 | old = dists_i[0]; 176 | if (tid == 0) idxs[j] = old; 177 | } 178 | } 179 | 180 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 181 | const float *dataset, float *temp, 182 | int *idxs) { 183 | unsigned int n_threads = opt_n_threads(n); 184 | 185 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 186 | 187 | switch (n_threads) { 188 | case 512: 189 | furthest_point_sampling_kernel<512> 190 | <<>>(b, n, m, dataset, temp, idxs); 191 | break; 192 | case 256: 193 | furthest_point_sampling_kernel<256> 194 | <<>>(b, n, m, dataset, temp, idxs); 195 | break; 196 | case 128: 197 | furthest_point_sampling_kernel<128> 198 | <<>>(b, n, m, dataset, temp, idxs); 199 | break; 200 | case 64: 201 | furthest_point_sampling_kernel<64> 202 | <<>>(b, n, m, dataset, temp, idxs); 203 | break; 204 | case 32: 205 | furthest_point_sampling_kernel<32> 206 | <<>>(b, n, m, dataset, temp, idxs); 207 | break; 208 | case 16: 209 | furthest_point_sampling_kernel<16> 210 | <<>>(b, n, m, dataset, temp, idxs); 211 | break; 212 | case 8: 213 | furthest_point_sampling_kernel<8> 214 | <<>>(b, n, m, dataset, temp, idxs); 215 | break; 216 | case 4: 217 | furthest_point_sampling_kernel<4> 218 | <<>>(b, n, m, dataset, temp, idxs); 219 | break; 220 | case 2: 221 | furthest_point_sampling_kernel<2> 222 | <<>>(b, n, m, dataset, temp, idxs); 223 | break; 224 | case 1: 225 | furthest_point_sampling_kernel<1> 226 | <<>>(b, n, m, dataset, temp, idxs); 227 | break; 228 | default: 229 | furthest_point_sampling_kernel<512> 230 | <<>>(b, n, m, dataset, temp, idxs); 231 | } 232 | 233 | CUDA_CHECK_ERRORS(); 234 | } 235 | -------------------------------------------------------------------------------- /object-detection-3d/pointnet2/pointnet2_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | ''' Testing customized ops. 
''' 7 | 8 | import torch 9 | from torch.autograd import gradcheck 10 | import numpy as np 11 | 12 | import os 13 | import sys 14 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 15 | sys.path.append(BASE_DIR) 16 | import pointnet2_utils 17 | 18 | def test_interpolation_grad(): 19 | batch_size = 1 20 | feat_dim = 2 21 | m = 4 22 | feats = torch.randn(batch_size, feat_dim, m, requires_grad=True).float().cuda() 23 | 24 | def interpolate_func(inputs): 25 | idx = torch.from_numpy(np.array([[[0,1,2],[1,2,3]]])).int().cuda() 26 | weight = torch.from_numpy(np.array([[[1,1,1],[2,2,2]]])).float().cuda() 27 | interpolated_feats = pointnet2_utils.three_interpolate(inputs, idx, weight) 28 | return interpolated_feats 29 | 30 | assert (gradcheck(interpolate_func, feats, atol=1e-1, rtol=1e-1)) 31 | 32 | if __name__=='__main__': 33 | test_interpolation_grad() 34 | -------------------------------------------------------------------------------- /object-detection-3d/pointnet2/pytorch_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | ''' Modified based on Ref: https://github.com/erikwijmans/Pointnet2_PyTorch ''' 7 | import torch 8 | import torch.nn as nn 9 | from typing import List, Tuple 10 | 11 | class SharedMLP(nn.Sequential): 12 | 13 | def __init__( 14 | self, 15 | args: List[int], 16 | *, 17 | bn: bool = False, 18 | activation=nn.ReLU(inplace=True), 19 | preact: bool = False, 20 | first: bool = False, 21 | name: str = "" 22 | ): 23 | super().__init__() 24 | 25 | for i in range(len(args) - 1): 26 | self.add_module( 27 | name + 'layer{}'.format(i), 28 | Conv2d( 29 | args[i], 30 | args[i + 1], 31 | bn=(not first or not preact or (i != 0)) and bn, 32 | activation=activation 33 | if (not first or not preact or (i != 0)) else None, 34 | preact=preact 35 | ) 36 | ) 37 | 38 | 39 | class _BNBase(nn.Sequential): 40 | 41 | def __init__(self, in_size, batch_norm=None, name=""): 42 | super().__init__() 43 | self.add_module(name + "bn", batch_norm(in_size)) 44 | 45 | nn.init.constant_(self[0].weight, 1.0) 46 | nn.init.constant_(self[0].bias, 0) 47 | 48 | 49 | class BatchNorm1d(_BNBase): 50 | 51 | def __init__(self, in_size: int, *, name: str = ""): 52 | super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name) 53 | 54 | 55 | class BatchNorm2d(_BNBase): 56 | 57 | def __init__(self, in_size: int, name: str = ""): 58 | super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name) 59 | 60 | 61 | class BatchNorm3d(_BNBase): 62 | 63 | def __init__(self, in_size: int, name: str = ""): 64 | super().__init__(in_size, batch_norm=nn.BatchNorm3d, name=name) 65 | 66 | 67 | class _ConvBase(nn.Sequential): 68 | 69 | def __init__( 70 | self, 71 | in_size, 72 | out_size, 73 | kernel_size, 74 | stride, 75 | padding, 76 | activation, 77 | bn, 78 | init, 79 | conv=None, 80 | batch_norm=None, 81 | bias=True, 82 | preact=False, 83 | name="" 84 | ): 85 | super().__init__() 86 | 87 | bias = bias and (not bn) 88 | conv_unit = conv( 89 | in_size, 90 | out_size, 91 | kernel_size=kernel_size, 92 | stride=stride, 93 | padding=padding, 94 | bias=bias 95 | ) 96 | init(conv_unit.weight) 97 | if bias: 98 | nn.init.constant_(conv_unit.bias, 0) 99 | 100 | if bn: 101 | if not preact: 102 | bn_unit = batch_norm(out_size) 103 | else: 104 | bn_unit = batch_norm(in_size) 105 | 106 | if preact: 107 | if 
bn: 108 | self.add_module(name + 'bn', bn_unit) 109 | 110 | if activation is not None: 111 | self.add_module(name + 'activation', activation) 112 | 113 | self.add_module(name + 'conv', conv_unit) 114 | 115 | if not preact: 116 | if bn: 117 | self.add_module(name + 'bn', bn_unit) 118 | 119 | if activation is not None: 120 | self.add_module(name + 'activation', activation) 121 | 122 | 123 | class Conv1d(_ConvBase): 124 | 125 | def __init__( 126 | self, 127 | in_size: int, 128 | out_size: int, 129 | *, 130 | kernel_size: int = 1, 131 | stride: int = 1, 132 | padding: int = 0, 133 | activation=nn.ReLU(inplace=True), 134 | bn: bool = False, 135 | init=nn.init.kaiming_normal_, 136 | bias: bool = True, 137 | preact: bool = False, 138 | name: str = "" 139 | ): 140 | super().__init__( 141 | in_size, 142 | out_size, 143 | kernel_size, 144 | stride, 145 | padding, 146 | activation, 147 | bn, 148 | init, 149 | conv=nn.Conv1d, 150 | batch_norm=BatchNorm1d, 151 | bias=bias, 152 | preact=preact, 153 | name=name 154 | ) 155 | 156 | 157 | class Conv2d(_ConvBase): 158 | 159 | def __init__( 160 | self, 161 | in_size: int, 162 | out_size: int, 163 | *, 164 | kernel_size: Tuple[int, int] = (1, 1), 165 | stride: Tuple[int, int] = (1, 1), 166 | padding: Tuple[int, int] = (0, 0), 167 | activation=nn.ReLU(inplace=True), 168 | bn: bool = False, 169 | init=nn.init.kaiming_normal_, 170 | bias: bool = True, 171 | preact: bool = False, 172 | name: str = "" 173 | ): 174 | super().__init__( 175 | in_size, 176 | out_size, 177 | kernel_size, 178 | stride, 179 | padding, 180 | activation, 181 | bn, 182 | init, 183 | conv=nn.Conv2d, 184 | batch_norm=BatchNorm2d, 185 | bias=bias, 186 | preact=preact, 187 | name=name 188 | ) 189 | 190 | 191 | class Conv3d(_ConvBase): 192 | 193 | def __init__( 194 | self, 195 | in_size: int, 196 | out_size: int, 197 | *, 198 | kernel_size: Tuple[int, int, int] = (1, 1, 1), 199 | stride: Tuple[int, int, int] = (1, 1, 1), 200 | padding: Tuple[int, int, int] = (0, 0, 0), 201 | activation=nn.ReLU(inplace=True), 202 | bn: bool = False, 203 | init=nn.init.kaiming_normal_, 204 | bias: bool = True, 205 | preact: bool = False, 206 | name: str = "" 207 | ): 208 | super().__init__( 209 | in_size, 210 | out_size, 211 | kernel_size, 212 | stride, 213 | padding, 214 | activation, 215 | bn, 216 | init, 217 | conv=nn.Conv3d, 218 | batch_norm=BatchNorm3d, 219 | bias=bias, 220 | preact=preact, 221 | name=name 222 | ) 223 | 224 | 225 | class FC(nn.Sequential): 226 | 227 | def __init__( 228 | self, 229 | in_size: int, 230 | out_size: int, 231 | *, 232 | activation=nn.ReLU(inplace=True), 233 | bn: bool = False, 234 | init=None, 235 | preact: bool = False, 236 | name: str = "" 237 | ): 238 | super().__init__() 239 | 240 | fc = nn.Linear(in_size, out_size, bias=not bn) 241 | if init is not None: 242 | init(fc.weight) 243 | if not bn: 244 | nn.init.constant_(fc.bias, 0) 245 | 246 | if preact: 247 | if bn: 248 | self.add_module(name + 'bn', BatchNorm1d(in_size)) 249 | 250 | if activation is not None: 251 | self.add_module(name + 'activation', activation) 252 | 253 | self.add_module(name + 'fc', fc) 254 | 255 | if not preact: 256 | if bn: 257 | self.add_module(name + 'bn', BatchNorm1d(out_size)) 258 | 259 | if activation is not None: 260 | self.add_module(name + 'activation', activation) 261 | 262 | def set_bn_momentum_default(bn_momentum): 263 | 264 | def fn(m): 265 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 266 | m.momentum = bn_momentum 267 | 268 | return fn 269 | 270 | 271 | class 
BNMomentumScheduler(object): 272 | 273 | def __init__( 274 | self, model, bn_lambda, last_epoch=-1, 275 | setter=set_bn_momentum_default 276 | ): 277 | if not isinstance(model, nn.Module): 278 | raise RuntimeError( 279 | "Class '{}' is not a PyTorch nn Module".format( 280 | type(model).__name__ 281 | ) 282 | ) 283 | 284 | self.model = model 285 | self.setter = setter 286 | self.lmbd = bn_lambda 287 | 288 | self.step(last_epoch + 1) 289 | self.last_epoch = last_epoch 290 | 291 | def step(self, epoch=None): 292 | if epoch is None: 293 | epoch = self.last_epoch + 1 294 | 295 | self.last_epoch = epoch 296 | self.model.apply(self.setter(self.lmbd(epoch))) 297 | 298 | 299 | -------------------------------------------------------------------------------- /object-detection-3d/pointnet2/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from setuptools import setup 7 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 8 | import glob 9 | import os.path as osp 10 | 11 | _ext_src_root = "_ext_src" 12 | _ext_sources = glob.glob("{}/src/*.cpp".format(_ext_src_root)) + glob.glob( 13 | "{}/src/*.cu".format(_ext_src_root) 14 | ) 15 | _ext_headers = glob.glob("{}/include/*".format(_ext_src_root)) 16 | this_dir = osp.dirname(osp.abspath(__file__)) 17 | 18 | setup( 19 | name='pointnet2', 20 | ext_modules=[ 21 | CUDAExtension( 22 | name='pointnet2._ext', 23 | sources=_ext_sources, 24 | extra_compile_args={ 25 | "cxx": ["-O2", "-I{}".format("{}/include".format(_ext_src_root))], 26 | "nvcc": ["-O2", "-I{}".format("{}/include".format(_ext_src_root))], 27 | }, 28 | include_dirs=[osp.join(this_dir, _ext_src_root, "include")], 29 | ) 30 | ], 31 | cmdclass={ 32 | 'build_ext': BuildExtension 33 | } 34 | ) 35 | -------------------------------------------------------------------------------- /object-detection-3d/sunrgbd/model_util_sunrgbd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
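A usage sketch for `BNMomentumScheduler`: it applies `bn_lambda(epoch)` as the momentum of every BatchNorm layer in the model once per epoch. The decay constants and the toy model below are made up, and the import assumes the `pointnet2` directory is on `sys.path`.

```python
import torch.nn as nn
from pytorch_utils import BNMomentumScheduler  # assumes pointnet2/ is importable

model = nn.Sequential(nn.Conv1d(3, 64, 1), nn.BatchNorm1d(64), nn.ReLU())

# Halve the BatchNorm momentum every 20 epochs, never going below 0.01.
bn_lambda = lambda epoch: max(0.5 * (0.5 ** (epoch // 20)), 0.01)
bnm_scheduler = BNMomentumScheduler(model, bn_lambda=bn_lambda)

for epoch in range(3):          # inside the usual training loop
    bnm_scheduler.step(epoch)   # sets m.momentum on every BatchNorm layer
    # ... train one epoch ...
```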
5 | 6 | import numpy as np 7 | import sys 8 | import os 9 | 10 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 11 | sys.path.append(BASE_DIR) 12 | ROOT_DIR = os.path.dirname(BASE_DIR) 13 | sys.path.append(os.path.join(ROOT_DIR, 'utils')) 14 | 15 | 16 | class SunrgbdDatasetConfig(object): 17 | def __init__(self): 18 | self.num_class = 10 19 | self.num_heading_bin = 12 20 | self.num_size_cluster = 10 21 | 22 | self.type2class = {'bed': 0, 'table': 1, 'sofa': 2, 'chair': 3, 'toilet': 4, 'desk': 5, 'dresser': 6, 23 | 'night_stand': 7, 'bookshelf': 8, 'bathtub': 9} 24 | self.class2type = {self.type2class[t]: t for t in self.type2class} 25 | self.type2onehotclass = {'bed': 0, 'table': 1, 'sofa': 2, 'chair': 3, 'toilet': 4, 'desk': 5, 'dresser': 6, 26 | 'night_stand': 7, 'bookshelf': 8, 'bathtub': 9} 27 | self.type_mean_size = {'bathtub': np.array([0.765840, 1.398258, 0.472728]), 28 | 'bed': np.array([2.114256, 1.620300, 0.927272]), 29 | 'bookshelf': np.array([0.404671, 1.071108, 1.688889]), 30 | 'chair': np.array([0.591958, 0.552978, 0.827272]), 31 | 'desk': np.array([0.695190, 1.346299, 0.736364]), 32 | 'dresser': np.array([0.528526, 1.002642, 1.172878]), 33 | 'night_stand': np.array([0.500618, 0.632163, 0.683424]), 34 | 'sofa': np.array([0.923508, 1.867419, 0.845495]), 35 | 'table': np.array([0.791118, 1.279516, 0.718182]), 36 | 'toilet': np.array([0.699104, 0.454178, 0.756250])} 37 | 38 | self.mean_size_arr = np.zeros((self.num_size_cluster, 3)) 39 | for i in range(self.num_size_cluster): 40 | self.mean_size_arr[i, :] = self.type_mean_size[self.class2type[i]] 41 | 42 | def size2class(self, size, type_name): 43 | ''' Convert 3D box size (l,w,h) to size class and size residual ''' 44 | size_class = self.type2class[type_name] 45 | size_residual = size - self.type_mean_size[type_name] 46 | return size_class, size_residual 47 | 48 | def class2size(self, pred_cls, residual): 49 | ''' Inverse function to size2class ''' 50 | mean_size = self.type_mean_size[self.class2type[pred_cls]] 51 | return mean_size + residual 52 | 53 | def angle2class(self, angle): 54 | ''' Convert continuous angle to discrete class 55 | [optinal] also small regression number from 56 | class center angle to current angle. 57 | 58 | angle is from 0-2pi (or -pi~pi), class center at 0, 1*(2pi/N), 2*(2pi/N) ... 
(N-1)*(2pi/N) 59 | return is class of int32 of 0,1,...,N-1 and a number such that 60 | class*(2pi/N) + number = angle 61 | ''' 62 | num_class = self.num_heading_bin 63 | angle = angle % (2 * np.pi) 64 | assert (angle >= 0 and angle <= 2 * np.pi) 65 | angle_per_class = 2 * np.pi / float(num_class) 66 | shifted_angle = (angle + angle_per_class / 2) % (2 * np.pi) 67 | class_id = int(shifted_angle / angle_per_class) 68 | residual_angle = shifted_angle - (class_id * angle_per_class + angle_per_class / 2) 69 | return class_id, residual_angle 70 | 71 | def class2angle(self, pred_cls, residual, to_label_format=True): 72 | ''' Inverse function to angle2class ''' 73 | num_class = self.num_heading_bin 74 | angle_per_class = 2 * np.pi / float(num_class) 75 | angle_center = pred_cls * angle_per_class 76 | angle = angle_center + residual 77 | if to_label_format and angle > np.pi: 78 | angle = angle - 2 * np.pi 79 | return angle 80 | 81 | def param2obb(self, center, heading_class, heading_residual, size_class, size_residual): 82 | heading_angle = self.class2angle(heading_class, heading_residual) 83 | box_size = self.class2size(int(size_class), size_residual) 84 | obb = np.zeros((7,)) 85 | obb[0:3] = center 86 | obb[3:6] = box_size 87 | obb[6] = heading_angle * -1 88 | return obb 89 | -------------------------------------------------------------------------------- /object-detection-3d/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .lr_scheduler import get_scheduler 2 | from .logger import setup_logger -------------------------------------------------------------------------------- /object-detection-3d/utils/box_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Utilities for bounding box manipulation and GIoU. 
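A worked numeric check of the heading discretization used by `angle2class`/`class2angle`, written out inline with `num_heading_bin = 12` so it runs standalone; the example angle is arbitrary.

```python
import numpy as np

num_heading_bin = 12                       # as in SunrgbdDatasetConfig
angle_per_class = 2 * np.pi / num_heading_bin

angle = 1.9                                # some heading in [0, 2*pi)
shifted = (angle + angle_per_class / 2) % (2 * np.pi)
class_id = int(shifted / angle_per_class)
residual = shifted - (class_id * angle_per_class + angle_per_class / 2)
print(class_id, residual)                  # 4 and a small (here negative) residual

# The inverse mapping recovers the original angle:
recovered = class_id * angle_per_class + residual
print(np.isclose(recovered % (2 * np.pi), angle))   # True
```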
4 | """ 5 | import torch 6 | from torchvision.ops.boxes import box_area 7 | 8 | 9 | def box_cxcywh_to_xyxy(x): 10 | x_c, y_c, w, h = x.unbind(-1) 11 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 12 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 13 | return torch.stack(b, dim=-1) 14 | 15 | 16 | def box_xyxy_to_cxcywh(x): 17 | x0, y0, x1, y1 = x.unbind(-1) 18 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 19 | (x1 - x0), (y1 - y0)] 20 | return torch.stack(b, dim=-1) 21 | 22 | 23 | # modified from torchvision to also return the union 24 | def box_iou(boxes1, boxes2): 25 | area1 = box_area(boxes1) 26 | area2 = box_area(boxes2) 27 | 28 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 29 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 30 | 31 | wh = (rb - lt).clamp(min=0) # [N,M,2] 32 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 33 | 34 | union = area1[:, None] + area2 - inter 35 | 36 | iou = inter / union 37 | return iou, union 38 | 39 | 40 | def generalized_box_iou(boxes1, boxes2): 41 | """ 42 | Generalized IoU from https://giou.stanford.edu/ 43 | 44 | The boxes should be in [x0, y0, x1, y1] format 45 | 46 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 47 | and M = len(boxes2) 48 | """ 49 | # degenerate boxes gives inf / nan results 50 | # so do an early check 51 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 52 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 53 | iou, union = box_iou(boxes1, boxes2) 54 | 55 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 56 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 57 | 58 | wh = (rb - lt).clamp(min=0) # [N,M,2] 59 | area = wh[:, :, 0] * wh[:, :, 1] 60 | 61 | return iou - (area - union) / area 62 | 63 | 64 | def masks_to_boxes(masks): 65 | """Compute the bounding boxes around the provided masks 66 | 67 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 68 | 69 | Returns a [N, 4] tensors, with the boxes in xyxy format 70 | """ 71 | if masks.numel() == 0: 72 | return torch.zeros((0, 4), device=masks.device) 73 | 74 | h, w = masks.shape[-2:] 75 | 76 | y = torch.arange(0, h, dtype=torch.float) 77 | x = torch.arange(0, w, dtype=torch.float) 78 | y, x = torch.meshgrid(y, x) 79 | 80 | x_mask = (masks * x.unsqueeze(0)) 81 | x_max = x_mask.flatten(1).max(-1)[0] 82 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 83 | 84 | y_mask = (masks * y.unsqueeze(0)) 85 | y_max = y_mask.flatten(1).max(-1)[0] 86 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 87 | 88 | return torch.stack([x_min, y_min, x_max, y_max], 1) 89 | -------------------------------------------------------------------------------- /object-detection-3d/utils/box_util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ Helper functions for calculating 2D and 3D bounding box IoU. 7 | 8 | Collected and written by Charles R. Qi 9 | Last modified: Jul 2019 10 | """ 11 | from __future__ import print_function 12 | 13 | import numpy as np 14 | from scipy.spatial import ConvexHull 15 | 16 | 17 | def polygon_clip(subjectPolygon, clipPolygon): 18 | """ Clip a polygon with another polygon. 19 | 20 | Ref: https://rosettacode.org/wiki/Sutherland-Hodgman_polygon_clipping#Python 21 | 22 | Args: 23 | subjectPolygon: a list of (x,y) 2d points, any polygon. 
24 | clipPolygon: a list of (x,y) 2d points, has to be *convex* 25 | Note: 26 | **points have to be counter-clockwise ordered** 27 | 28 | Return: 29 | a list of (x,y) vertex point for the intersection polygon. 30 | """ 31 | 32 | def inside(p): 33 | return (cp2[0] - cp1[0]) * (p[1] - cp1[1]) > (cp2[1] - cp1[1]) * (p[0] - cp1[0]) 34 | 35 | def computeIntersection(): 36 | dc = [cp1[0] - cp2[0], cp1[1] - cp2[1]] 37 | dp = [s[0] - e[0], s[1] - e[1]] 38 | n1 = cp1[0] * cp2[1] - cp1[1] * cp2[0] 39 | n2 = s[0] * e[1] - s[1] * e[0] 40 | n3 = 1.0 / (dc[0] * dp[1] - dc[1] * dp[0]) 41 | return [(n1 * dp[0] - n2 * dc[0]) * n3, (n1 * dp[1] - n2 * dc[1]) * n3] 42 | 43 | outputList = subjectPolygon 44 | cp1 = clipPolygon[-1] 45 | 46 | for clipVertex in clipPolygon: 47 | cp2 = clipVertex 48 | inputList = outputList 49 | outputList = [] 50 | s = inputList[-1] 51 | 52 | for subjectVertex in inputList: 53 | e = subjectVertex 54 | if inside(e): 55 | if not inside(s): 56 | outputList.append(computeIntersection()) 57 | outputList.append(e) 58 | elif inside(s): 59 | outputList.append(computeIntersection()) 60 | s = e 61 | cp1 = cp2 62 | if len(outputList) == 0: 63 | return None 64 | return (outputList) 65 | 66 | 67 | def poly_area(x, y): 68 | """ Ref: http://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates """ 69 | return 0.5 * np.abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1))) 70 | 71 | 72 | def convex_hull_intersection(p1, p2): 73 | """ Compute area of two convex hull's intersection area. 74 | p1,p2 are a list of (x,y) tuples of hull vertices. 75 | return a list of (x,y) for the intersection and its volume 76 | """ 77 | inter_p = polygon_clip(p1, p2) 78 | if inter_p is not None: 79 | hull_inter = ConvexHull(inter_p) 80 | return inter_p, hull_inter.volume 81 | else: 82 | return None, 0.0 83 | 84 | 85 | def box3d_vol(corners): 86 | ''' corners: (8,3) no assumption on axis direction ''' 87 | a = np.sqrt(np.sum((corners[0, :] - corners[1, :]) ** 2)) 88 | b = np.sqrt(np.sum((corners[1, :] - corners[2, :]) ** 2)) 89 | c = np.sqrt(np.sum((corners[0, :] - corners[4, :]) ** 2)) 90 | return a * b * c 91 | 92 | 93 | def is_clockwise(p): 94 | x = p[:, 0] 95 | y = p[:, 1] 96 | return np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1)) > 0 97 | 98 | 99 | def box3d_iou(corners1, corners2): 100 | ''' Compute 3D bounding box IoU. 101 | 102 | Input: 103 | corners1: numpy array (8,3), assume up direction is negative Y 104 | corners2: numpy array (8,3), assume up direction is negative Y 105 | Output: 106 | iou: 3D bounding box IoU 107 | iou_2d: bird's eye view 2D bounding box IoU 108 | 109 | todo (rqi): add more description on corner points' orders. 
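A minimal usage sketch, assuming the corner ordering produced by get_3d_box defined later in this file (the sizes and centers below are illustrative only):
        corners1 = get_3d_box((1.0, 1.0, 1.0), 0.0, (0.0, 0.0, 0.0))
        corners2 = get_3d_box((1.0, 1.0, 1.0), 0.0, (0.5, 0.0, 0.0))
        iou_3d, iou_bev = box3d_iou(corners1, corners2)  # both roughly 1/3 for this half-overlap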
110 | ''' 111 | # corner points are in counter clockwise order 112 | rect1 = [(corners1[i, 0], corners1[i, 2]) for i in range(3, -1, -1)] 113 | rect2 = [(corners2[i, 0], corners2[i, 2]) for i in range(3, -1, -1)] 114 | area1 = poly_area(np.array(rect1)[:, 0], np.array(rect1)[:, 1]) 115 | area2 = poly_area(np.array(rect2)[:, 0], np.array(rect2)[:, 1]) 116 | inter, inter_area = convex_hull_intersection(rect1, rect2) 117 | iou_2d = inter_area / (area1 + area2 - inter_area) 118 | ymax = min(corners1[0, 1], corners2[0, 1]) 119 | ymin = max(corners1[4, 1], corners2[4, 1]) 120 | inter_vol = inter_area * max(0.0, ymax - ymin) 121 | vol1 = box3d_vol(corners1) 122 | vol2 = box3d_vol(corners2) 123 | iou = inter_vol / (vol1 + vol2 - inter_vol) 124 | return iou, iou_2d 125 | 126 | 127 | def get_iou(bb1, bb2): 128 | """ 129 | Calculate the Intersection over Union (IoU) of two 2D bounding boxes. 130 | 131 | Parameters 132 | ---------- 133 | bb1 : dict 134 | Keys: {'x1', 'x2', 'y1', 'y2'} 135 | The (x1, y1) position is at the top left corner, 136 | the (x2, y2) position is at the bottom right corner 137 | bb2 : dict 138 | Keys: {'x1', 'x2', 'y1', 'y2'} 139 | The (x, y) position is at the top left corner, 140 | the (x2, y2) position is at the bottom right corner 141 | 142 | Returns 143 | ------- 144 | float 145 | in [0, 1] 146 | """ 147 | assert bb1['x1'] < bb1['x2'] 148 | assert bb1['y1'] < bb1['y2'] 149 | assert bb2['x1'] < bb2['x2'] 150 | assert bb2['y1'] < bb2['y2'] 151 | 152 | # determine the coordinates of the intersection rectangle 153 | x_left = max(bb1['x1'], bb2['x1']) 154 | y_top = max(bb1['y1'], bb2['y1']) 155 | x_right = min(bb1['x2'], bb2['x2']) 156 | y_bottom = min(bb1['y2'], bb2['y2']) 157 | 158 | if x_right < x_left or y_bottom < y_top: 159 | return 0.0 160 | 161 | # The intersection of two axis-aligned bounding boxes is always an 162 | # axis-aligned bounding box 163 | intersection_area = (x_right - x_left) * (y_bottom - y_top) 164 | 165 | # compute the area of both AABBs 166 | bb1_area = (bb1['x2'] - bb1['x1']) * (bb1['y2'] - bb1['y1']) 167 | bb2_area = (bb2['x2'] - bb2['x1']) * (bb2['y2'] - bb2['y1']) 168 | 169 | # compute the intersection over union by taking the intersection 170 | # area and dividing it by the sum of prediction + ground-truth 171 | # areas - the interesection area 172 | iou = intersection_area / float(bb1_area + bb2_area - intersection_area) 173 | assert iou >= 0.0 174 | assert iou <= 1.0 175 | return iou 176 | 177 | 178 | def box2d_iou(box1, box2): 179 | ''' Compute 2D bounding box IoU. 180 | 181 | Input: 182 | box1: tuple of (xmin,ymin,xmax,ymax) 183 | box2: tuple of (xmin,ymin,xmax,ymax) 184 | Output: 185 | iou: 2D IoU scalar 186 | ''' 187 | return get_iou({'x1': box1[0], 'y1': box1[1], 'x2': box1[2], 'y2': box1[3]}, \ 188 | {'x1': box2[0], 'y1': box2[1], 'x2': box2[2], 'y2': box2[3]}) 189 | 190 | 191 | # ----------------------------------------------------------- 192 | # Convert from box parameters to 193 | # ----------------------------------------------------------- 194 | def roty(t): 195 | """Rotation about the y-axis.""" 196 | c = np.cos(t) 197 | s = np.sin(t) 198 | return np.array([[c, 0, s], 199 | [0, 1, 0], 200 | [-s, 0, c]]) 201 | 202 | 203 | def roty_batch(t): 204 | """Rotation about the y-axis. 
205 | t: (x1,x2,...xn) 206 | return: (x1,x2,...,xn,3,3) 207 | """ 208 | input_shape = t.shape 209 | output = np.zeros(tuple(list(input_shape) + [3, 3])) 210 | c = np.cos(t) 211 | s = np.sin(t) 212 | output[..., 0, 0] = c 213 | output[..., 0, 2] = s 214 | output[..., 1, 1] = 1 215 | output[..., 2, 0] = -s 216 | output[..., 2, 2] = c 217 | return output 218 | 219 | 220 | def get_3d_box(box_size, heading_angle, center): 221 | ''' box_size is array(l,w,h), heading_angle is in radians, clockwise from pos x axis, center is xyz of box center 222 | output (8,3) array for 3D box corners 223 | Similar to utils/compute_orientation_3d 224 | ''' 225 | R = roty(heading_angle) 226 | l, w, h = box_size 227 | x_corners = [l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2]; 228 | y_corners = [h / 2, h / 2, h / 2, h / 2, -h / 2, -h / 2, -h / 2, -h / 2]; 229 | z_corners = [w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2]; 230 | corners_3d = np.dot(R, np.vstack([x_corners, y_corners, z_corners])) 231 | corners_3d[0, :] = corners_3d[0, :] + center[0]; 232 | corners_3d[1, :] = corners_3d[1, :] + center[1]; 233 | corners_3d[2, :] = corners_3d[2, :] + center[2]; 234 | corners_3d = np.transpose(corners_3d) 235 | return corners_3d 236 | 237 | 238 | def get_3d_box_batch(box_size, heading_angle, center): 239 | ''' box_size: [x1,x2,...,xn,3] 240 | heading_angle: [x1,x2,...,xn] 241 | center: [x1,x2,...,xn,3] 242 | Return: 243 | [x1,x2,...,xn,8,3] 244 | ''' 245 | input_shape = heading_angle.shape 246 | R = roty_batch(heading_angle) 247 | l = np.expand_dims(box_size[..., 0], -1) # [x1,...,xn,1] 248 | w = np.expand_dims(box_size[..., 1], -1) 249 | h = np.expand_dims(box_size[..., 2], -1) 250 | corners_3d = np.zeros(tuple(list(input_shape) + [8, 3])) 251 | corners_3d[..., :, 0] = np.concatenate((l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2), -1) 252 | corners_3d[..., :, 1] = np.concatenate((h / 2, h / 2, h / 2, h / 2, -h / 2, -h / 2, -h / 2, -h / 2), -1) 253 | corners_3d[..., :, 2] = np.concatenate((w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2), -1) 254 | tlist = [i for i in range(len(input_shape))] 255 | tlist += [len(input_shape) + 1, len(input_shape)] 256 | corners_3d = np.matmul(corners_3d, np.transpose(R, tuple(tlist))) 257 | corners_3d += np.expand_dims(center, -2) 258 | return corners_3d 259 | 260 | 261 | if __name__ == '__main__': 262 | 263 | # Function for polygon plotting 264 | import matplotlib 265 | from matplotlib.patches import Polygon 266 | from matplotlib.collections import PatchCollection 267 | import matplotlib.pyplot as plt 268 | 269 | 270 | def plot_polys(plist, scale=500.0): 271 | fig, ax = plt.subplots() 272 | patches = [] 273 | for p in plist: 274 | poly = Polygon(np.array(p) / scale, True) 275 | patches.append(poly) 276 | 277 | 278 | pc = PatchCollection(patches, cmap=matplotlib.cm.jet, alpha=0.5) 279 | colors = 100 * np.random.rand(len(patches)) 280 | pc.set_array(np.array(colors)) 281 | ax.add_collection(pc) 282 | plt.show() 283 | 284 | # Demo on ConvexHull 285 | points = np.random.rand(30, 2) # 30 random points in 2-D 286 | hull = ConvexHull(points) 287 | # **In 2D "volume" is area, "area" is perimeter 288 | print(('Hull area: ', hull.volume)) 289 | for simplex in hull.simplices: 290 | print(simplex) 291 | 292 | # Demo on convex hull overlaps 293 | sub_poly = [(0, 0), (300, 0), (300, 300), (0, 300)] 294 | clip_poly = [(150, 150), (300, 300), (150, 450), (0, 300)] 295 | inter_poly = polygon_clip(sub_poly, clip_poly) 296 |
print(poly_area(np.array(inter_poly)[:, 0], np.array(inter_poly)[:, 1])) 297 | 298 | # Test convex hull intersection function 299 | rect1 = [(50, 0), (50, 300), (300, 300), (300, 0)] 300 | rect2 = [(150, 150), (300, 300), (150, 450), (0, 300)] 301 | plot_polys([rect1, rect2]) 302 | inter, area = convex_hull_intersection(rect1, rect2) 303 | print((inter, area)) 304 | if inter is not None: 305 | print(poly_area(np.array(inter)[:, 0], np.array(inter)[:, 1])) 306 | 307 | print('------------------') 308 | rect1 = [(0.30026005199835404, 8.9408694211408424), \ 309 | (-1.1571105364358421, 9.4686676477075533), \ 310 | (0.1777082043006144, 13.154404877812102), \ 311 | (1.6350787927348105, 12.626606651245391)] 312 | rect1 = [rect1[0], rect1[3], rect1[2], rect1[1]] 313 | rect2 = [(0.23908745901608636, 8.8551095691132886), \ 314 | (-1.2771419487733995, 9.4269062966181956), \ 315 | (0.13138836963152717, 13.161896351296868), \ 316 | (1.647617777421013, 12.590099623791961)] 317 | rect2 = [rect2[0], rect2[3], rect2[2], rect2[1]] 318 | plot_polys([rect1, rect2]) 319 | inter, area = convex_hull_intersection(rect1, rect2) 320 | print((inter, area)) 321 | -------------------------------------------------------------------------------- /object-detection-3d/utils/eval_det.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ Generic Code for Object Detection Evaluation 7 | 8 | Input: 9 | For each class: 10 | For each image: 11 | Predictions: box, score 12 | Groundtruths: box 13 | 14 | Output: 15 | For each class: 16 | precision-recall and average precision 17 | 18 | Author: Charles R. Qi 19 | 20 | Ref: https://raw.githubusercontent.com/rbgirshick/py-faster-rcnn/master/lib/datasets/voc_eval.py 21 | """ 22 | import numpy as np 23 | 24 | 25 | def voc_ap(rec, prec, use_07_metric=False): 26 | """ ap = voc_ap(rec, prec, [use_07_metric]) 27 | Compute VOC AP given precision and recall. 28 | If use_07_metric is true, uses the 29 | VOC 07 11 point method (default:False). 30 | """ 31 | if use_07_metric: 32 | # 11 point metric 33 | ap = 0. 34 | for t in np.arange(0., 1.1, 0.1): 35 | if np.sum(rec >= t) == 0: 36 | p = 0 37 | else: 38 | p = np.max(prec[rec >= t]) 39 | ap = ap + p / 11. 40 | else: 41 | # correct AP calculation 42 | # first append sentinel values at the end 43 | mrec = np.concatenate(([0.], rec, [1.])) 44 | mpre = np.concatenate(([0.], prec, [0.])) 45 | 46 | # compute the precision envelope 47 | for i in range(mpre.size - 1, 0, -1): 48 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 49 | 50 | # to calculate area under PR curve, look for points 51 | # where X axis (recall) changes value 52 | i = np.where(mrec[1:] != mrec[:-1])[0] 53 | 54 | # and sum (\Delta recall) * prec 55 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 56 | return ap 57 | 58 | 59 | import os 60 | import sys 61 | 62 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 63 | from metric_util import calc_iou # axis-aligned 3D box IoU 64 | 65 | 66 | def get_iou(bb1, bb2): 67 | """ Compute IoU of two bounding boxes.
68 | ** Define your box IoU function HERE ** 69 | """ 70 | # pass 71 | iou3d = calc_iou(bb1, bb2) 72 | return iou3d 73 | 74 | 75 | from box_util import box3d_iou 76 | 77 | 78 | def get_iou_obb(bb1, bb2): 79 | iou3d, iou2d = box3d_iou(bb1, bb2) 80 | return iou3d 81 | 82 | 83 | def get_iou_main(get_iou_func, args): 84 | return get_iou_func(*args) 85 | 86 | 87 | def eval_det_cls(pred, gt, ovthresh=0.25, use_07_metric=False, get_iou_func=get_iou): 88 | """ Generic functions to compute precision/recall for object detection 89 | for a single class. 90 | Input: 91 | pred: map of {img_id: [(bbox, score)]} where bbox is numpy array 92 | gt: map of {img_id: [bbox]} 93 | ovthresh: scalar, iou threshold 94 | use_07_metric: bool, if True use VOC07 11 point method 95 | Output: 96 | rec: numpy array of length nd 97 | prec: numpy array of length nd 98 | ap: scalar, average precision 99 | """ 100 | 101 | # construct gt objects 102 | class_recs = {} # {img_id: {'bbox': bbox list, 'det': matched list}} 103 | npos = 0 104 | for img_id in gt.keys(): 105 | bbox = np.array(gt[img_id]) 106 | det = [False] * len(bbox) 107 | npos += len(bbox) 108 | class_recs[img_id] = {'bbox': bbox, 'det': det} 109 | # pad empty list to all other imgids 110 | for img_id in pred.keys(): 111 | if img_id not in gt: 112 | class_recs[img_id] = {'bbox': np.array([]), 'det': []} 113 | 114 | # construct dets 115 | image_ids = [] 116 | confidence = [] 117 | BB = [] 118 | for img_id in pred.keys(): 119 | for box, score in pred[img_id]: 120 | image_ids.append(img_id) 121 | confidence.append(score) 122 | BB.append(box) 123 | confidence = np.array(confidence) 124 | BB = np.array(BB) # (nd,4 or 8,3 or 6) 125 | 126 | # sort by confidence 127 | sorted_ind = np.argsort(-confidence) 128 | sorted_scores = np.sort(-confidence) 129 | BB = BB[sorted_ind, ...] 130 | image_ids = [image_ids[x] for x in sorted_ind] 131 | 132 | # go down dets and mark TPs and FPs 133 | nd = len(image_ids) 134 | tp = np.zeros(nd) 135 | fp = np.zeros(nd) 136 | for d in range(nd): 137 | # if d%100==0: print(d) 138 | R = class_recs[image_ids[d]] 139 | bb = BB[d, ...].astype(float) 140 | ovmax = -np.inf 141 | BBGT = R['bbox'].astype(float) 142 | 143 | if BBGT.size > 0: 144 | # compute overlaps 145 | for j in range(BBGT.shape[0]): 146 | iou = get_iou_main(get_iou_func, (bb, BBGT[j, ...])) 147 | if iou > ovmax: 148 | ovmax = iou 149 | jmax = j 150 | 151 | # print d, ovmax 152 | if ovmax > ovthresh: 153 | if not R['det'][jmax]: 154 | tp[d] = 1. 155 | R['det'][jmax] = 1 156 | else: 157 | fp[d] = 1. 158 | else: 159 | fp[d] = 1. 160 | 161 | # compute precision recall 162 | fp = np.cumsum(fp) 163 | tp = np.cumsum(tp) 164 | rec = tp / float(npos) 165 | # print('NPOS: ', npos) 166 | # avoid divide by zero in case the first detection matches a difficult 167 | # ground truth 168 | prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps) 169 | ap = voc_ap(rec, prec, use_07_metric) 170 | 171 | return rec, prec, ap 172 | 173 | 174 | def eval_det_cls_wrapper(arguments): 175 | pred, gt, ovthresh, use_07_metric, get_iou_func = arguments 176 | rec, prec, ap = eval_det_cls(pred, gt, ovthresh, use_07_metric, get_iou_func) 177 | return (rec, prec, ap) 178 | 179 | 180 | def eval_det(pred_all, gt_all, ovthresh=0.25, use_07_metric=False, get_iou_func=get_iou): 181 | """ Generic functions to compute precision/recall for object detection 182 | for multiple classes.
183 | Input: 184 | pred_all: map of {img_id: [(classname, bbox, score)]} 185 | gt_all: map of {img_id: [(classname, bbox)]} 186 | ovthresh: scalar, iou threshold 187 | use_07_metric: bool, if true use VOC07 11 point method 188 | Output: 189 | rec: {classname: rec} 190 | prec: {classname: prec_all} 191 | ap: {classname: scalar} 192 | """ 193 | pred = {} # map {classname: pred} 194 | gt = {} # map {classname: gt} 195 | for img_id in pred_all.keys(): 196 | for classname, bbox, score in pred_all[img_id]: 197 | if classname not in pred: pred[classname] = {} 198 | if img_id not in pred[classname]: 199 | pred[classname][img_id] = [] 200 | if classname not in gt: gt[classname] = {} 201 | if img_id not in gt[classname]: 202 | gt[classname][img_id] = [] 203 | pred[classname][img_id].append((bbox, score)) 204 | for img_id in gt_all.keys(): 205 | for classname, bbox in gt_all[img_id]: 206 | if classname not in gt: gt[classname] = {} 207 | if img_id not in gt[classname]: 208 | gt[classname][img_id] = [] 209 | gt[classname][img_id].append(bbox) 210 | 211 | rec = {} 212 | prec = {} 213 | ap = {} 214 | for classname in gt.keys(): 215 | # print('Computing AP for class: ', classname) 216 | rec[classname], prec[classname], ap[classname] = eval_det_cls(pred[classname], gt[classname], ovthresh, 217 | use_07_metric, get_iou_func) 218 | # print(classname, ap[classname]) 219 | 220 | return rec, prec, ap 221 | 222 | 223 | from multiprocessing import Pool 224 | 225 | 226 | def eval_det_multiprocessing(pred_all, gt_all, ovthresh=0.25, use_07_metric=False, get_iou_func=get_iou): 227 | """ Generic functions to compute precision/recall for object detection 228 | for multiple classes. 229 | Input: 230 | pred_all: map of {img_id: [(classname, bbox, score)]} 231 | gt_all: map of {img_id: [(classname, bbox)]} 232 | ovthresh: scalar, iou threshold 233 | use_07_metric: bool, if true use VOC07 11 point method 234 | Output: 235 | rec: {classname: rec} 236 | prec: {classname: prec_all} 237 | ap: {classname: scalar} 238 | """ 239 | pred = {} # map {classname: pred} 240 | gt = {} # map {classname: gt} 241 | for img_id in pred_all.keys(): 242 | for classname, bbox, score in pred_all[img_id]: 243 | if classname not in pred: pred[classname] = {} 244 | if img_id not in pred[classname]: 245 | pred[classname][img_id] = [] 246 | if classname not in gt: gt[classname] = {} 247 | if img_id not in gt[classname]: 248 | gt[classname][img_id] = [] 249 | pred[classname][img_id].append((bbox, score)) 250 | for img_id in gt_all.keys(): 251 | for classname, bbox in gt_all[img_id]: 252 | if classname not in gt: gt[classname] = {} 253 | if img_id not in gt[classname]: 254 | gt[classname][img_id] = [] 255 | gt[classname][img_id].append(bbox) 256 | # print(pred) 257 | rec = {} 258 | prec = {} 259 | ap = {} 260 | p = Pool(processes=10) 261 | ret_values = p.map(eval_det_cls_wrapper, 262 | [(pred[classname], gt[classname], ovthresh, use_07_metric, get_iou_func) for classname in 263 | gt.keys() if classname in pred]) 264 | p.close() 265 | for i, classname in enumerate(gt.keys()): 266 | if classname in pred: 267 | rec[classname], prec[classname], ap[classname] = ret_values[i] 268 | else: 269 | rec[classname] = 0 270 | prec[classname] = 0 271 | ap[classname] = 0 272 | # print(classname, ap[classname]) 273 | 274 | return rec, prec, ap 275 | -------------------------------------------------------------------------------- /object-detection-3d/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 
Facebook, Inc. and its affiliates. All Rights Reserved 2 | import functools 3 | import logging 4 | import os 5 | import sys 6 | from termcolor import colored 7 | 8 | 9 | class _ColorfulFormatter(logging.Formatter): 10 | def __init__(self, *args, **kwargs): 11 | self._root_name = kwargs.pop("root_name") + "." 12 | self._abbrev_name = kwargs.pop("abbrev_name", "") 13 | if len(self._abbrev_name): 14 | self._abbrev_name = self._abbrev_name + "." 15 | super(_ColorfulFormatter, self).__init__(*args, **kwargs) 16 | 17 | def formatMessage(self, record): 18 | record.name = record.name.replace(self._root_name, self._abbrev_name) 19 | log = super(_ColorfulFormatter, self).formatMessage(record) 20 | if record.levelno == logging.WARNING: 21 | prefix = colored("WARNING", "red", attrs=["blink"]) 22 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: 23 | prefix = colored("ERROR", "red", attrs=["blink", "underline"]) 24 | else: 25 | return log 26 | return prefix + " " + log 27 | 28 | 29 | # so that calling setup_logger multiple times won't add many handlers 30 | @functools.lru_cache() 31 | def setup_logger( 32 | output=None, distributed_rank=0, *, color=True, name="log", abbrev_name=None 33 | ): 34 | """ 35 | Initialize the detectron2 logger and set its verbosity level to "INFO". 36 | 37 | Args: 38 | output (str): a file name or a directory to save log. If None, will not save log file. 39 | If ends with ".txt" or ".log", assumed to be a file name. 40 | Otherwise, logs will be saved to `output/log.txt`. 41 | name (str): the root module name of this logger 42 | 43 | Returns: 44 | logging.Logger: a logger 45 | """ 46 | logger = logging.getLogger(name) 47 | logger.setLevel(logging.DEBUG) 48 | logger.propagate = False 49 | 50 | if abbrev_name is None: 51 | abbrev_name = name 52 | 53 | plain_formatter = logging.Formatter( 54 | "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S" 55 | ) 56 | # stdout logging: master only 57 | if distributed_rank == 0: 58 | ch = logging.StreamHandler(stream=sys.stdout) 59 | ch.setLevel(logging.DEBUG) 60 | if color: 61 | formatter = _ColorfulFormatter( 62 | colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", 63 | datefmt="%m/%d %H:%M:%S", 64 | root_name=name, 65 | abbrev_name=str(abbrev_name), 66 | ) 67 | else: 68 | formatter = plain_formatter 69 | ch.setFormatter(formatter) 70 | logger.addHandler(ch) 71 | 72 | # file logging: all workers 73 | if output is not None: 74 | if output.endswith(".txt") or output.endswith(".log"): 75 | filename = output 76 | else: 77 | filename = os.path.join(output, "log.txt") 78 | if distributed_rank > 0: 79 | filename = filename + f".rank{distributed_rank}" 80 | os.makedirs(os.path.dirname(filename), exist_ok=True) 81 | 82 | fh = logging.StreamHandler(_cached_log_stream(filename)) 83 | fh.setLevel(logging.DEBUG) 84 | fh.setFormatter(plain_formatter) 85 | logger.addHandler(fh) 86 | 87 | return logger 88 | 89 | 90 | # cache the opened file object, so that different calls to `setup_logger` 91 | # with the same file name can safely write to the same file. 
92 | @functools.lru_cache(maxsize=None) 93 | def _cached_log_stream(filename): 94 | return open(filename, "a") 95 | -------------------------------------------------------------------------------- /object-detection-3d/utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # noinspection PyProtectedMember 2 | from torch.optim.lr_scheduler import _LRScheduler, MultiStepLR, CosineAnnealingLR 3 | 4 | 5 | # noinspection PyAttributeOutsideInit 6 | class GradualWarmupScheduler(_LRScheduler): 7 | """ Gradually warm-up(increasing) learning rate in optimizer. 8 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 9 | Args: 10 | optimizer (Optimizer): Wrapped optimizer. 11 | multiplier: init learning rate = base lr / multiplier 12 | warmup_epoch: target learning rate is reached at warmup_epoch, gradually 13 | after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau) 14 | """ 15 | 16 | def __init__(self, optimizer, multiplier, warmup_epoch, after_scheduler, last_epoch=-1): 17 | self.multiplier = multiplier 18 | if self.multiplier <= 1.: 19 | raise ValueError('multiplier should be greater than 1.') 20 | self.warmup_epoch = warmup_epoch 21 | self.after_scheduler = after_scheduler 22 | self.finished = False 23 | super().__init__(optimizer, last_epoch=last_epoch) 24 | 25 | def get_lr(self): 26 | if self.last_epoch > self.warmup_epoch: 27 | return self.after_scheduler.get_lr() 28 | else: 29 | return [base_lr / self.multiplier * ((self.multiplier - 1.) * self.last_epoch / self.warmup_epoch + 1.) 30 | for base_lr in self.base_lrs] 31 | 32 | def step(self, epoch=None): 33 | if epoch is None: 34 | epoch = self.last_epoch + 1 35 | self.last_epoch = epoch 36 | if epoch > self.warmup_epoch: 37 | self.after_scheduler.step(epoch - self.warmup_epoch) 38 | else: 39 | super(GradualWarmupScheduler, self).step(epoch) 40 | 41 | def state_dict(self): 42 | """Returns the state of the scheduler as a :class:`dict`. 43 | 44 | It contains an entry for every variable in self.__dict__ which 45 | is not the optimizer. 46 | """ 47 | 48 | state = {key: value for key, value in self.__dict__.items() if key != 'optimizer' and key != 'after_scheduler'} 49 | state['after_scheduler'] = self.after_scheduler.state_dict() 50 | return state 51 | 52 | def load_state_dict(self, state_dict): 53 | """Loads the schedulers state. 54 | 55 | Arguments: 56 | state_dict (dict): scheduler state. Should be an object returned 57 | from a call to :meth:`state_dict`. 
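A minimal save/restore sketch, assuming `optimizer`, `n_iter_per_epoch` and `args` are provided by the training script:
            scheduler = get_scheduler(optimizer, n_iter_per_epoch, args)
            state = scheduler.state_dict()   # includes the wrapped after_scheduler state
            scheduler = get_scheduler(optimizer, n_iter_per_epoch, args)
            scheduler.load_state_dict(state)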
58 | """ 59 | 60 | after_scheduler_state = state_dict.pop('after_scheduler') 61 | self.__dict__.update(state_dict) 62 | self.after_scheduler.load_state_dict(after_scheduler_state) 63 | 64 | 65 | def get_scheduler(optimizer, n_iter_per_epoch, args): 66 | if "cosine" in args.lr_scheduler: 67 | scheduler = CosineAnnealingLR( 68 | optimizer=optimizer, 69 | eta_min=0.000001, 70 | T_max=(args.max_epoch - args.warmup_epoch) * n_iter_per_epoch) 71 | elif "step" in args.lr_scheduler: 72 | if isinstance(args.lr_decay_epochs, int): 73 | args.lr_decay_epochs = [args.lr_decay_epochs] 74 | scheduler = MultiStepLR( 75 | optimizer=optimizer, 76 | gamma=args.lr_decay_rate, 77 | milestones=[(m - args.warmup_epoch) * n_iter_per_epoch for m in args.lr_decay_epochs]) 78 | else: 79 | raise NotImplementedError(f"scheduler {args.lr_scheduler} not supported") 80 | 81 | if args.warmup_epoch > 0: 82 | scheduler = GradualWarmupScheduler( 83 | optimizer, 84 | multiplier=args.warmup_multiplier, 85 | after_scheduler=scheduler, 86 | warmup_epoch=args.warmup_epoch * n_iter_per_epoch) 87 | return scheduler 88 | -------------------------------------------------------------------------------- /object-detection-3d/utils/metric_util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ Utility functions for metric evaluation. 7 | 8 | Author: Or Litany and Charles R. Qi 9 | """ 10 | 11 | import os 12 | import sys 13 | import torch 14 | 15 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 16 | sys.path.append(BASE_DIR) 17 | 18 | import numpy as np 19 | 20 | # Mesh IO 21 | import trimesh 22 | 23 | 24 | # ---------------------------------------- 25 | # Precision and Recall 26 | # ---------------------------------------- 27 | 28 | def multi_scene_precision_recall(labels, pred, iou_thresh, conf_thresh, label_mask, pred_mask=None): 29 | ''' 30 | Args: 31 | labels: (B, N, 6) 32 | pred: (B, M, 6) 33 | iou_thresh: scalar 34 | conf_thresh: scalar 35 | label_mask: (B, N,) with values in 0 or 1 to indicate which GT boxes to consider. 36 | pred_mask: (B, M,) with values in 0 or 1 to indicate which PRED boxes to consider. 37 | Returns: 38 | TP,FP,FN,Precision,Recall 39 | ''' 40 | # Make sure the masks are not Torch tensor, otherwise the mask==1 returns uint8 array instead 41 | # of True/False array as in numpy 42 | assert (not torch.is_tensor(label_mask)) 43 | assert (not torch.is_tensor(pred_mask)) 44 | TP, FP, FN = 0, 0, 0 45 | if label_mask is None: label_mask = np.ones((labels.shape[0], labels.shape[1])) 46 | if pred_mask is None: pred_mask = np.ones((pred.shape[0], pred.shape[1])) 47 | for batch_idx in range(labels.shape[0]): 48 | TP_i, FP_i, FN_i = single_scene_precision_recall(labels[batch_idx, label_mask[batch_idx, :] == 1, :], 49 | pred[batch_idx, pred_mask[batch_idx, :] == 1, :], 50 | iou_thresh, conf_thresh) 51 | TP += TP_i 52 | FP += FP_i 53 | FN += FN_i 54 | 55 | return TP, FP, FN, precision_recall(TP, FP, FN) 56 | 57 | 58 | def single_scene_precision_recall(labels, pred, iou_thresh, conf_thresh): 59 | """Compute P and R for predicted bounding boxes. Ignores classes! 
60 | Args: 61 | labels: (N x bbox) ground-truth bounding boxes (6 dims) 62 | pred: (M x (bbox + conf)) predicted bboxes with confidence and maybe classification 63 | Returns: 64 | TP, FP, FN 65 | """ 66 | 67 | # for each pred box with high conf (C), compute IoU with all gt boxes. 68 | # TP = number of times IoU > th ; FP = C - TP 69 | # FN - number of scene objects without good match 70 | 71 | gt_bboxes = labels[:, :6] 72 | 73 | num_scene_bboxes = gt_bboxes.shape[0] 74 | conf = pred[:, 6] 75 | 76 | conf_pred_bbox = pred[np.where(conf > conf_thresh)[0], :6] 77 | num_conf_pred_bboxes = conf_pred_bbox.shape[0] 78 | 79 | # init an array to keep iou between generated and scene bboxes 80 | iou_arr = np.zeros([num_conf_pred_bboxes, num_scene_bboxes]) 81 | for g_idx in range(num_conf_pred_bboxes): 82 | for s_idx in range(num_scene_bboxes): 83 | iou_arr[g_idx, s_idx] = calc_iou(conf_pred_bbox[g_idx, :], gt_bboxes[s_idx, :]) 84 | 85 | good_match_arr = (iou_arr >= iou_thresh) 86 | 87 | TP = good_match_arr.any(axis=1).sum() 88 | FP = num_conf_pred_bboxes - TP 89 | FN = num_scene_bboxes - good_match_arr.any(axis=0).sum() 90 | 91 | return TP, FP, FN 92 | 93 | 94 | def precision_recall(TP, FP, FN): 95 | Prec = 1.0 * TP / (TP + FP) if TP + FP > 0 else 0 96 | Rec = 1.0 * TP / (TP + FN) 97 | return Prec, Rec 98 | 99 | 100 | def calc_iou(box_a, box_b): 101 | """Computes IoU of two axis aligned bboxes. 102 | Args: 103 | box_a, box_b: 6D of center and lengths 104 | Returns: 105 | iou 106 | """ 107 | 108 | max_a = box_a[0:3] + box_a[3:6] / 2 109 | max_b = box_b[0:3] + box_b[3:6] / 2 110 | min_max = np.array([max_a, max_b]).min(0) 111 | 112 | min_a = box_a[0:3] - box_a[3:6] / 2 113 | min_b = box_b[0:3] - box_b[3:6] / 2 114 | max_min = np.array([min_a, min_b]).max(0) 115 | if not ((min_max > max_min).all()): 116 | return 0.0 117 | 118 | intersection = (min_max - max_min).prod() 119 | vol_a = box_a[3:6].prod() 120 | vol_b = box_b[3:6].prod() 121 | union = vol_a + vol_b - intersection 122 | return 1.0 * intersection / union 123 | 124 | 125 | if __name__ == '__main__': 126 | print('running some tests') 127 | 128 | ############ 129 | ## Test IoU 130 | ############ 131 | box_a = np.array([0, 0, 0, 1, 1, 1]) 132 | box_b = np.array([0, 0, 0, 2, 2, 2]) 133 | expected_iou = 1.0 / 8 134 | pred_iou = calc_iou(box_a, box_b) 135 | assert expected_iou == pred_iou, 'function returned wrong IoU' 136 | 137 | box_a = np.array([0, 0, 0, 1, 1, 1]) 138 | box_b = np.array([10, 10, 10, 2, 2, 2]) 139 | expected_iou = 0.0 140 | pred_iou = calc_iou(box_a, box_b) 141 | assert expected_iou == pred_iou, 'function returned wrong IoU' 142 | 143 | print('IoU test -- PASSED') 144 | 145 | ######################### 146 | ## Test Precision Recall 147 | ######################### 148 | gt_boxes = np.array([[0, 0, 0, 1, 1, 1], [3, 0, 1, 1, 10, 1]]) 149 | detected_boxes = np.array([[0, 0, 0, 1, 1, 1, 1.0], [3, 0, 1, 1, 10, 1, 0.9]]) 150 | TP, FP, FN = single_scene_precision_recall(gt_boxes, detected_boxes, 0.5, 0.5) 151 | assert TP == 2 and FP == 0 and FN == 0 152 | assert precision_recall(TP, FP, FN) == (1, 1) 153 | 154 | detected_boxes = np.array([[0, 0, 0, 1, 1, 1, 1.0]]) 155 | TP, FP, FN = single_scene_precision_recall(gt_boxes, detected_boxes, 0.5, 0.5) 156 | assert TP == 1 and FP == 0 and FN == 1 157 | assert precision_recall(TP, FP, FN) == (1, 0.5) 158 | 159 | detected_boxes = np.array([[0, 0, 0, 1, 1, 1, 1.0], [-1, -1, 0, 0.1, 0.1, 1, 1.0]]) 160 | TP, FP, FN = single_scene_precision_recall(gt_boxes, detected_boxes, 0.5, 0.5) 161 | assert TP ==
1 and FP == 1 and FN == 1 162 | assert precision_recall(TP, FP, FN) == (0.5, 0.5) 163 | 164 | # wrong box has low confidence 165 | detected_boxes = np.array([[0, 0, 0, 1, 1, 1, 1.0], [-1, -1, 0, 0.1, 0.1, 1, 0.1]]) 166 | TP, FP, FN = single_scene_precision_recall(gt_boxes, detected_boxes, 0.5, 0.5) 167 | assert TP == 1 and FP == 0 and FN == 1 168 | assert precision_recall(TP, FP, FN) == (1, 0.5) 169 | 170 | print('Precision Recall test -- PASSED') 171 | -------------------------------------------------------------------------------- /object-detection-3d/utils/nms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | from pc_util import bbox_corner_dist_measure 8 | 9 | # boxes are axis-aligned 2D boxes of shape (n,5) in FLOAT numbers with (x1,y1,x2,y2,score) 10 | ''' Ref: https://www.pyimagesearch.com/2015/02/16/faster-non-maximum-suppression-python/ 11 | Ref: https://github.com/vickyboy47/nms-python/blob/master/nms.py 12 | ''' 13 | 14 | 15 | def nms_2d(boxes, overlap_threshold): 16 | x1 = boxes[:, 0] 17 | y1 = boxes[:, 1] 18 | x2 = boxes[:, 2] 19 | y2 = boxes[:, 3] 20 | score = boxes[:, 4] 21 | area = (x2 - x1) * (y2 - y1) 22 | 23 | I = np.argsort(score) 24 | pick = [] 25 | while (I.size != 0): 26 | last = I.size 27 | i = I[-1] 28 | pick.append(i) 29 | suppress = [last - 1] 30 | for pos in range(last - 1): 31 | j = I[pos] 32 | xx1 = max(x1[i], x1[j]) 33 | yy1 = max(y1[i], y1[j]) 34 | xx2 = min(x2[i], x2[j]) 35 | yy2 = min(y2[i], y2[j]) 36 | w = xx2 - xx1 37 | h = yy2 - yy1 38 | if (w > 0 and h > 0): 39 | o = w * h / area[j] 40 | print('Overlap is', o) 41 | if (o > overlap_threshold): 42 | suppress.append(pos) 43 | I = np.delete(I, suppress) 44 | return pick 45 | 46 | 47 | def nms_2d_faster(boxes, overlap_threshold, old_type=False): 48 | x1 = boxes[:, 0] 49 | y1 = boxes[:, 1] 50 | x2 = boxes[:, 2] 51 | y2 = boxes[:, 3] 52 | score = boxes[:, 4] 53 | area = (x2 - x1) * (y2 - y1) 54 | 55 | I = np.argsort(score) 56 | pick = [] 57 | while (I.size != 0): 58 | last = I.size 59 | i = I[-1] 60 | pick.append(i) 61 | 62 | xx1 = np.maximum(x1[i], x1[I[:last - 1]]) 63 | yy1 = np.maximum(y1[i], y1[I[:last - 1]]) 64 | xx2 = np.minimum(x2[i], x2[I[:last - 1]]) 65 | yy2 = np.minimum(y2[i], y2[I[:last - 1]]) 66 | 67 | w = np.maximum(0, xx2 - xx1) 68 | h = np.maximum(0, yy2 - yy1) 69 | 70 | if old_type: 71 | o = (w * h) / area[I[:last - 1]] 72 | else: 73 | inter = w * h 74 | o = inter / (area[i] + area[I[:last - 1]] - inter) 75 | 76 | I = np.delete(I, np.concatenate(([last - 1], np.where(o > overlap_threshold)[0]))) 77 | 78 | return pick 79 | 80 | 81 | def nms_3d_faster(boxes, overlap_threshold, old_type=False): 82 | x1 = boxes[:, 0] 83 | y1 = boxes[:, 1] 84 | z1 = boxes[:, 2] 85 | x2 = boxes[:, 3] 86 | y2 = boxes[:, 4] 87 | z2 = boxes[:, 5] 88 | score = boxes[:, 6] 89 | area = (x2 - x1) * (y2 - y1) * (z2 - z1) 90 | 91 | I = np.argsort(score) 92 | pick = [] 93 | while (I.size != 0): 94 | last = I.size 95 | i = I[-1] 96 | pick.append(i) 97 | 98 | xx1 = np.maximum(x1[i], x1[I[:last - 1]]) 99 | yy1 = np.maximum(y1[i], y1[I[:last - 1]]) 100 | zz1 = np.maximum(z1[i], z1[I[:last - 1]]) 101 | xx2 = np.minimum(x2[i], x2[I[:last - 1]]) 102 | yy2 = np.minimum(y2[i], y2[I[:last - 1]]) 103 | zz2 = np.minimum(z2[i], z2[I[:last - 1]]) 104 | 105 | l = np.maximum(0, xx2 -
xx1) 106 | w = np.maximum(0, yy2 - yy1) 107 | h = np.maximum(0, zz2 - zz1) 108 | 109 | if old_type: 110 | o = (l * w * h) / area[I[:last - 1]] 111 | else: 112 | inter = l * w * h 113 | o = inter / (area[i] + area[I[:last - 1]] - inter) 114 | 115 | I = np.delete(I, np.concatenate(([last - 1], np.where(o > overlap_threshold)[0]))) 116 | 117 | return pick 118 | 119 | 120 | def nms_3d_faster_samecls(boxes, overlap_threshold, old_type=False): 121 | x1 = boxes[:, 0] 122 | y1 = boxes[:, 1] 123 | z1 = boxes[:, 2] 124 | x2 = boxes[:, 3] 125 | y2 = boxes[:, 4] 126 | z2 = boxes[:, 5] 127 | score = boxes[:, 6] 128 | cls = boxes[:, 7] 129 | area = (x2 - x1) * (y2 - y1) * (z2 - z1) 130 | 131 | I = np.argsort(score) 132 | pick = [] 133 | while (I.size != 0): 134 | last = I.size 135 | i = I[-1] 136 | pick.append(i) 137 | 138 | xx1 = np.maximum(x1[i], x1[I[:last - 1]]) 139 | yy1 = np.maximum(y1[i], y1[I[:last - 1]]) 140 | zz1 = np.maximum(z1[i], z1[I[:last - 1]]) 141 | xx2 = np.minimum(x2[i], x2[I[:last - 1]]) 142 | yy2 = np.minimum(y2[i], y2[I[:last - 1]]) 143 | zz2 = np.minimum(z2[i], z2[I[:last - 1]]) 144 | cls1 = cls[i] 145 | cls2 = cls[I[:last - 1]] 146 | 147 | l = np.maximum(0, xx2 - xx1) 148 | w = np.maximum(0, yy2 - yy1) 149 | h = np.maximum(0, zz2 - zz1) 150 | 151 | if old_type: 152 | o = (l * w * h) / area[I[:last - 1]] 153 | else: 154 | inter = l * w * h 155 | o = inter / (area[i] + area[I[:last - 1]] - inter) 156 | o = o * (cls1 == cls2) 157 | 158 | I = np.delete(I, np.concatenate(([last - 1], np.where(o > overlap_threshold)[0]))) 159 | 160 | return pick 161 | 162 | 163 | def nms_crnr_dist(boxes, conf, overlap_threshold): 164 | I = np.argsort(conf) 165 | pick = [] 166 | while (I.size != 0): 167 | last = I.size 168 | i = I[-1] 169 | pick.append(i) 170 | 171 | scores = [] 172 | for ind in I[:-1]: 173 | scores.append(bbox_corner_dist_measure(boxes[i, :], boxes[ind, :])) 174 | 175 | I = np.delete(I, np.concatenate(([last - 1], np.where(np.array(scores) > overlap_threshold)[0]))) 176 | 177 | return pick 178 | 179 | 180 | if __name__ == '__main__': 181 | a = np.random.random((100, 5)) 182 | print(nms_2d(a, 0.9)) 183 | print(nms_2d_faster(a, 0.9)) 184 | -------------------------------------------------------------------------------- /object-detection-3d/utils/nn_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ Chamfer distance in Pytorch. 7 | Author: Charles R. 
Qi 8 | """ 9 | 10 | import torch 11 | import torch.nn as nn 12 | import numpy as np 13 | 14 | 15 | def huber_loss(error, delta=1.0): 16 | """ 17 | Args: 18 | error: Torch tensor (d1,d2,...,dk) 19 | Returns: 20 | loss: Torch tensor (d1,d2,...,dk) 21 | 22 | x = error = pred - gt or dist(pred,gt) 23 | 0.5 * |x|^2 if |x|<=d 24 | 0.5 * d^2 + d * (|x|-d) if |x|>d 25 | Ref: https://github.com/charlesq34/frustum-pointnets/blob/master/models/model_util.py 26 | """ 27 | abs_error = torch.abs(error) 28 | # quadratic = torch.min(abs_error, torch.FloatTensor([delta])) 29 | quadratic = torch.clamp(abs_error, max=delta) 30 | linear = (abs_error - quadratic) 31 | loss = 0.5 * quadratic ** 2 + delta * linear 32 | return loss 33 | 34 | 35 | def nn_distance(pc1, pc2, l1smooth=False, delta=1.0, l1=False): 36 | """ 37 | Input: 38 | pc1: (B,N,C) torch tensor 39 | pc2: (B,M,C) torch tensor 40 | l1smooth: bool, whether to use l1smooth loss 41 | delta: scalar, the delta used in l1smooth loss 42 | Output: 43 | dist1: (B,N) torch float32 tensor 44 | idx1: (B,N) torch int64 tensor 45 | dist2: (B,M) torch float32 tensor 46 | idx2: (B,M) torch int64 tensor 47 | """ 48 | N = pc1.shape[1] 49 | M = pc2.shape[1] 50 | pc1_expand_tile = pc1.unsqueeze(2).repeat(1, 1, M, 1) 51 | pc2_expand_tile = pc2.unsqueeze(1).repeat(1, N, 1, 1) 52 | pc_diff = pc1_expand_tile - pc2_expand_tile 53 | 54 | if l1smooth: 55 | pc_dist = torch.sum(huber_loss(pc_diff, delta), dim=-1) # (B,N,M) 56 | elif l1: 57 | pc_dist = torch.sum(torch.abs(pc_diff), dim=-1) # (B,N,M) 58 | else: 59 | pc_dist = torch.sum(pc_diff ** 2, dim=-1) # (B,N,M) 60 | dist1, idx1 = torch.min(pc_dist, dim=2) # (B,N) 61 | dist2, idx2 = torch.min(pc_dist, dim=1) # (B,M) 62 | return dist1, idx1, dist2, idx2 63 | 64 | 65 | def demo_nn_distance(): 66 | np.random.seed(0) 67 | pc1arr = np.random.random((1, 5, 3)) 68 | pc2arr = np.random.random((1, 6, 3)) 69 | pc1 = torch.from_numpy(pc1arr.astype(np.float32)) 70 | pc2 = torch.from_numpy(pc2arr.astype(np.float32)) 71 | dist1, idx1, dist2, idx2 = nn_distance(pc1, pc2) 72 | print(dist1) 73 | print(idx1) 74 | dist = np.zeros((5, 6)) 75 | for i in range(5): 76 | for j in range(6): 77 | dist[i, j] = np.sum((pc1arr[0, i, :] - pc2arr[0, j, :]) ** 2) 78 | print(dist) 79 | print('-' * 30) 80 | print('L1smooth dists:') 81 | dist1, idx1, dist2, idx2 = nn_distance(pc1, pc2, True) 82 | print(dist1) 83 | print(idx1) 84 | dist = np.zeros((5, 6)) 85 | for i in range(5): 86 | for j in range(6): 87 | error = np.abs(pc1arr[0, i, :] - pc2arr[0, j, :]) 88 | quad = np.minimum(error, 1.0) 89 | linear = error - quad 90 | loss = 0.5 * quad ** 2 + 1.0 * linear 91 | dist[i, j] = np.sum(loss) 92 | print(dist) 93 | 94 | 95 | if __name__ == '__main__': 96 | demo_nn_distance() 97 | -------------------------------------------------------------------------------- /object-detection-3d/utils/visual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | import torch.nn.functional as F 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | from pathlib import Path 8 | import cv2 9 | from PIL import Image 10 | from torchvision import transforms 11 | import torchvision 12 | from torch.utils.data import DataLoader 13 | import skimage 14 | import colorsys 15 | import random 16 | from skimage import io 17 | import argparse 18 | import datasets.transforms as T 19 | import copy 20 | import glob 21 | import re 22 | torch.set_grad_enabled(False) 23 | 24 | CLASSES = ['bed', 'table', 
'sofa', 'chair', 'toilet', 'desk', 'dresser', 'night_stand', 'bookshelf', 'bathtub'] 25 | 26 | def make_coco_transforms(image_set, args): 27 | 28 | normalize = T.Compose([ 29 | T.ToTensor(), 30 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 31 | ]) 32 | 33 | scales = [800] 34 | 35 | if image_set == 'val' or image_set == 'train': 36 | return T.Compose([ 37 | T.RandomResize([scales[-1]], max_size=scales[-1] * 1333 // 800), 38 | normalize, 39 | ]) 40 | 41 | raise ValueError(f'unknown {image_set}') 42 | 43 | def plot_gt(im, labels, bboxes_scaled, output_dir): 44 | tl = 3 45 | tf = max(tl-1, 1) 46 | tempimg = copy.deepcopy(im) 47 | color = [255,0,0] 48 | for label, (xmin, ymin, xmax, ymax) in zip(labels.tolist(), bboxes_scaled.tolist()): 49 | c1, c2 = (int(xmin), int(ymin)), (int(xmax), int(ymax)) 50 | cv2.rectangle(tempimg, c1, c2, color, tl, cv2.LINE_AA) 51 | text = f'{CLASSES[label]}' 52 | t_size = cv2.getTextSize(text, 0, fontScale=tl / 3, thickness=tf)[0] 53 | c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3 54 | cv2.rectangle(tempimg, c1, c2, color, -1, cv2.LINE_AA) # filled 55 | cv2.putText(tempimg, text, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA) 56 | fname = os.path.join(output_dir,'gt_img.png') 57 | cv2.imwrite(fname, tempimg) 58 | print(f"{fname} saved.") 59 | 60 | def box_cxcywh_to_xyxy(x): 61 | x_c, y_c, w, h = x.unbind(1) 62 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 63 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 64 | return torch.stack(b, dim=1) 65 | 66 | def rescale_bboxes(out_bbox, size): 67 | img_w, img_h = size 68 | b = box_cxcywh_to_xyxy(out_bbox) 69 | b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32) 70 | return b 71 | def draw_bbox_in_img(fname, bbox_scaled, score, color=[0,255,0]): 72 | tl = 3 73 | tf = max(tl-1,1) # font thickness 74 | # color = [0,255,0] 75 | im = cv2.imread(fname) 76 | for p, (xmin, ymin, xmax, ymax) in zip(score, bbox_scaled.tolist()): 77 | c1, c2 = (int(xmin), int(ymin)), (int(xmax), int(ymax)) 78 | cv2.rectangle(im, c1, c2, color, tl, cv2.LINE_AA) 79 | cl = p.argmax() 80 | text = f'{CLASSES[cl]}: {p[cl]:0.2f}' 81 | t_size = cv2.getTextSize(text, 0, fontScale=tl / 3, thickness=tf)[0] 82 | c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3 83 | cv2.rectangle(im, c1, c2, color, -1, cv2.LINE_AA) # filled 84 | cv2.putText(im, text, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA) 85 | cv2.imwrite(fname, im) 86 | 87 | def plot_results(cv2_img, prob, boxes, output_dir): 88 | tl = 3 # thickness line 89 | tf = max(tl-1,1) # font thickness 90 | tempimg = copy.deepcopy(cv2_img) 91 | color = [0,0,255] 92 | for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes.tolist()): 93 | c1, c2 = (int(xmin), int(ymin)), (int(xmax), int(ymax)) 94 | cv2.rectangle(tempimg, c1, c2, color, tl, cv2.LINE_AA) 95 | cl = p.argmax() 96 | text = f'{CLASSES[cl]}: {p[cl]:0.2f}' 97 | t_size = cv2.getTextSize(text, 0, fontScale=tl / 3, thickness=tf)[0] 98 | c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3 99 | cv2.rectangle(tempimg, c1, c2, color, -1, cv2.LINE_AA) # filled 100 | cv2.putText(tempimg, text, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA) 101 | fname = os.path.join(output_dir,'pred_img.png') 102 | cv2.imwrite(fname, tempimg) 103 | print(f"{fname} saved.") 104 | 105 | def increment_path(path, exist_ok=False, sep='', mkdir=False): 106 | # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc. 
107 | path = Path(path) # os-agnostic 108 | if path.exists() and not exist_ok: 109 | suffix = path.suffix 110 | path = path.with_suffix('') 111 | dirs = glob.glob(f"{path}{sep}*") # similar paths 112 | matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs] 113 | i = [int(m.groups()[0]) for m in matches if m] # indices 114 | n = max(i) + 1 if i else 2 # increment number 115 | path = Path(f"{path}{sep}{n}{suffix}") # update path 116 | dir = path if path.suffix == '' else path.parent # directory 117 | if not dir.exists() and mkdir: 118 | dir.mkdir(parents=True, exist_ok=True) # make directory 119 | return path 120 | 121 | def save_pred_fig(output_dir, output_dic, keep): 122 | # im = Image.open(os.path.join(output_dir, "img.png")) 123 | im = cv2.imread(os.path.join(output_dir, "img.png")) 124 | h, w = im.shape[:2] 125 | bboxes_scaled = rescale_bboxes(output_dic['pred_boxes2d'][0, keep].cpu(), (w,h)) 126 | prob = output_dic['pred_logits2d'].softmax(-1)[0, :, :-1] 127 | scores = prob[keep] 128 | plot_results(im, scores, bboxes_scaled, output_dir) 129 | 130 | def save_gt_fig(output_dir, gt_anno): 131 | im = cv2.imread(os.path.join(output_dir, "img.png")) 132 | h, w = im.shape[:2] 133 | bboxes_scaled = rescale_bboxes(gt_anno['boxes'], (w,h)) 134 | labels = gt_anno['labels'] 135 | plot_gt(im, labels, bboxes_scaled, output_dir) 136 | 137 | def get_one_query_meanattn(vis_attn,h_featmap,w_featmap): 138 | mean_attentions = vis_attn.mean(0).reshape(h_featmap, w_featmap) 139 | mean_attentions = nn.functional.interpolate(mean_attentions.unsqueeze(0).unsqueeze(0), scale_factor=16, mode="nearest")[0].cpu().numpy() 140 | return mean_attentions 141 | 142 | def get_one_query_attn(vis_attn, h_featmap, w_featmap, nh): 143 | attentions = vis_attn.reshape(nh, h_featmap, w_featmap) 144 | # attentions = vis_attn.sum(0).reshape(h_featmap, w_featmap) 145 | attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=16, mode="nearest")[0].cpu().numpy() 146 | return attentions 147 | -------------------------------------------------------------------------------- /semantic_segmentation/config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # DATASET PARAMETERS 4 | DATASET = 'nyudv2' 5 | TRAIN_DIR = './data/nyudv2' # 'Modify data path' 6 | VAL_DIR = TRAIN_DIR 7 | TRAIN_LIST = './data/nyudv2/train.txt' 8 | VAL_LIST = './data/nyudv2/val.txt' 9 | 10 | 11 | SHORTER_SIDE = 350 12 | CROP_SIZE = 500 13 | RESIZE_SIZE = None 14 | 15 | NORMALISE_PARAMS = [1./255, # Image SCALE 16 | np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3)), # Image MEAN 17 | np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3)), # Image STD 18 | 1./5000] # Depth SCALE 19 | BATCH_SIZE = 6 20 | NUM_WORKERS = 16 21 | NUM_CLASSES = 40 22 | LOW_SCALE = 0.5 23 | HIGH_SCALE = 2.0 24 | IGNORE_LABEL = 255 25 | 26 | # ENCODER PARAMETERS 27 | ENC = '101' # ResNet101 28 | ENC_PRETRAINED = True # pre-trained on ImageNet or randomly initialised 29 | 30 | # GENERAL 31 | FREEZE_BN = True 32 | NUM_SEGM_EPOCHS = [100] * 3 # [150] * 3 if using ResNet152 as backbone 33 | PRINT_EVERY = 10 34 | RANDOM_SEED = 42 35 | VAL_EVERY = 5 # how often to record validation scores 36 | 37 | # OPTIMISERS' PARAMETERS 38 | LR_ENC = [5e-4, 2.5e-4, 1e-4] # TO FREEZE, PUT 0 39 | LR_DEC = [3e-3, 1.5e-3, 7e-4] 40 | MOM_ENC = 0.9 # TO FREEZE, PUT 0 41 | MOM_DEC = 0.9 42 | WD_ENC = 1e-5 # TO FREEZE, PUT 0 43 | WD_DEC = 1e-5 44 | LAMDA = 1e-4 # slightly better 45 | BN_threshold = 2e-2 # slightly better 46 
| OPTIM_DEC = 'sgd' 47 | -------------------------------------------------------------------------------- /semantic_segmentation/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .mix_transformer import * 2 | from .segformer import WeTr -------------------------------------------------------------------------------- /semantic_segmentation/models/modules.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | num_parallel = 2 5 | 6 | 7 | class TokenExchange(nn.Module): 8 | def __init__(self): 9 | super(TokenExchange, self).__init__() 10 | 11 | def forward(self, x, mask, mask_threshold): 12 | # x: [B, N, C], mask: [B, N, 1] 13 | x0, x1 = torch.zeros_like(x[0]), torch.zeros_like(x[1]) 14 | x0[mask[0] >= mask_threshold] = x[0][mask[0] >= mask_threshold] 15 | x0[mask[0] < mask_threshold] = x[1][mask[0] < mask_threshold] 16 | x1[mask[1] >= mask_threshold] = x[1][mask[1] >= mask_threshold] 17 | x1[mask[1] < mask_threshold] = x[0][mask[1] < mask_threshold] 18 | return [x0, x1] 19 | 20 | 21 | class ModuleParallel(nn.Module): 22 | def __init__(self, module): 23 | super(ModuleParallel, self).__init__() 24 | self.module = module 25 | 26 | def forward(self, x_parallel): 27 | return [self.module(x) for x in x_parallel] 28 | 29 | 30 | class LayerNormParallel(nn.Module): 31 | def __init__(self, num_features): 32 | super(LayerNormParallel, self).__init__() 33 | for i in range(num_parallel): 34 | setattr(self, 'ln_' + str(i), nn.LayerNorm(num_features, eps=1e-6)) 35 | 36 | def forward(self, x_parallel): 37 | return [getattr(self, 'ln_' + str(i))(x) for i, x in enumerate(x_parallel)] 38 | -------------------------------------------------------------------------------- /semantic_segmentation/models/segformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from . 
import mix_transformer 5 | from mmcv.cnn import ConvModule 6 | from .modules import num_parallel 7 | 8 | 9 | class MLP(nn.Module): 10 | """ 11 | Linear Embedding 12 | """ 13 | def __init__(self, input_dim=2048, embed_dim=768): 14 | super().__init__() 15 | self.proj = nn.Linear(input_dim, embed_dim) 16 | 17 | def forward(self, x): 18 | x = x.flatten(2).transpose(1, 2) 19 | x = self.proj(x) 20 | return x 21 | 22 | 23 | class SegFormerHead(nn.Module): 24 | """ 25 | SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers 26 | """ 27 | def __init__(self, feature_strides=None, in_channels=128, embedding_dim=256, num_classes=20, **kwargs): 28 | super(SegFormerHead, self).__init__() 29 | self.in_channels = in_channels 30 | self.num_classes = num_classes 31 | assert len(feature_strides) == len(self.in_channels) 32 | assert min(feature_strides) == feature_strides[0] 33 | self.feature_strides = feature_strides 34 | 35 | c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels 36 | 37 | #decoder_params = kwargs['decoder_params'] 38 | #embedding_dim = decoder_params['embed_dim'] 39 | 40 | self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim) 41 | self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim) 42 | self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim) 43 | self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim) 44 | self.dropout = nn.Dropout2d(0.1) 45 | 46 | self.linear_fuse = ConvModule( 47 | in_channels=embedding_dim*4, 48 | out_channels=embedding_dim, 49 | kernel_size=1, 50 | norm_cfg=dict(type='BN', requires_grad=True) 51 | ) 52 | 53 | self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1) 54 | 55 | def forward(self, x): 56 | c1, c2, c3, c4 = x 57 | 58 | ############## MLP decoder on C1-C4 ########### 59 | n, _, h, w = c4.shape 60 | 61 | _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3]) 62 | _c4 = F.interpolate(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False) 63 | 64 | _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3]) 65 | _c3 = F.interpolate(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False) 66 | 67 | _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3]) 68 | _c2 = F.interpolate(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False) 69 | 70 | _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3]) 71 | 72 | _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1)) 73 | 74 | x = self.dropout(_c) 75 | x = self.linear_pred(x) 76 | 77 | return x 78 | 79 | 80 | class WeTr(nn.Module): 81 | def __init__(self, backbone, num_classes=20, embedding_dim=256, pretrained=True): 82 | super().__init__() 83 | self.num_classes = num_classes 84 | self.embedding_dim = embedding_dim 85 | self.feature_strides = [4, 8, 16, 32] 86 | self.num_parallel = num_parallel 87 | #self.in_channels = [32, 64, 160, 256] 88 | #self.in_channels = [64, 128, 320, 512] 89 | 90 | self.encoder = getattr(mix_transformer, backbone)() 91 | self.in_channels = self.encoder.embed_dims 92 | ## initialize encoder 93 | if pretrained: 94 | state_dict = torch.load('pretrained/' + backbone + '.pth') 95 | state_dict.pop('head.weight') 96 | state_dict.pop('head.bias') 97 | state_dict = expand_state_dict(self.encoder.state_dict(), state_dict, self.num_parallel) 98 | self.encoder.load_state_dict(state_dict, strict=True) 99 | 100 | self.decoder =
SegFormerHead(feature_strides=self.feature_strides, in_channels=self.in_channels, 101 | embedding_dim=self.embedding_dim, num_classes=self.num_classes) 102 | 103 | self.alpha = nn.Parameter(torch.ones(self.num_parallel, requires_grad=True)) 104 | self.register_parameter('alpha', self.alpha) 105 | 106 | def get_param_groups(self): 107 | param_groups = [[], [], []] 108 | for name, param in list(self.encoder.named_parameters()): 109 | if "norm" in name: 110 | param_groups[1].append(param) 111 | else: 112 | param_groups[0].append(param) 113 | for param in list(self.decoder.parameters()): 114 | param_groups[2].append(param) 115 | return param_groups 116 | 117 | def forward(self, x): 118 | x, masks = self.encoder(x) 119 | x = [self.decoder(x[0]), self.decoder(x[1])] 120 | ens = 0 121 | alpha_soft = F.softmax(self.alpha, dim=0) 122 | for l in range(self.num_parallel): 123 | ens += alpha_soft[l] * x[l].detach() 124 | x.append(ens) 125 | return x, masks 126 | 127 | 128 | def expand_state_dict(model_dict, state_dict, num_parallel): 129 | model_dict_keys = model_dict.keys() 130 | state_dict_keys = state_dict.keys() 131 | for model_dict_key in model_dict_keys: 132 | model_dict_key_re = model_dict_key.replace('module.', '') 133 | if model_dict_key_re in state_dict_keys: 134 | model_dict[model_dict_key] = state_dict[model_dict_key_re] 135 | for i in range(num_parallel): 136 | ln = '.ln_%d' % i 137 | replace = True if ln in model_dict_key_re else False 138 | model_dict_key_re = model_dict_key_re.replace(ln, '') 139 | if replace and model_dict_key_re in state_dict_keys: 140 | model_dict[model_dict_key] = state_dict[model_dict_key_re] 141 | return model_dict 142 | 143 | 144 | if __name__=="__main__": 145 | # import torch.distributed as dist 146 | # dist.init_process_group('gloo', init_method='file:///temp/somefile', rank=0, world_size=1) 147 | pretrained_weights = torch.load('pretrained/mit_b1.pth') 148 | wetr = WeTr('mit_b1', num_classes=20, embedding_dim=256, pretrained=True).cuda() 149 | wetr.get_param_groups() 150 | dummy_input = torch.rand(2,3,512,512).cuda() 151 | wetr(dummy_input) 152 | -------------------------------------------------------------------------------- /semantic_segmentation/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .helpers import * 2 | from .meter import * 3 | -------------------------------------------------------------------------------- /semantic_segmentation/utils/cmap.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yikaiw/TokenFusion/3834ccf7765bb0bd50ea729069ad5adbd6de288d/semantic_segmentation/utils/cmap.npy -------------------------------------------------------------------------------- /semantic_segmentation/utils/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | 7 | 8 | def line_to_paths_fn_nyudv2(x, input_names): 9 | return x.decode('utf-8').strip('\n').split('\t') 10 | 11 | line_to_paths_fn = {'nyudv2': line_to_paths_fn_nyudv2} 12 | 13 | 14 | class SegDataset(Dataset): 15 | """Multi-Modality Segmentation dataset. 16 | 17 | Works with any datasets that contain image 18 | and any number of 2D-annotations. 19 | 20 | Args: 21 | data_file (string): Path to the data file with annotations. 22 | data_dir (string): Directory with all the images.
23 | line_to_paths_fn (callable): function to convert a line of data_file 24 | into paths (img_relpath, msk_relpath, ...). 25 | masks_names (list of strings): keys for each annotation mask 26 | (e.g., 'segm', 'depth'). 27 | transform_trn (callable, optional): Optional transform 28 | to be applied on a sample during the training stage. 29 | transform_val (callable, optional): Optional transform 30 | to be applied on a sample during the validation stage. 31 | stage (str): initial stage of dataset - either 'train' or 'val'. 32 | 33 | """ 34 | def __init__(self, dataset, data_file, data_dir, input_names, input_mask_idxs, 35 | transform_trn=None, transform_val=None, stage='train', ignore_label=None): 36 | with open(data_file, 'rb') as f: 37 | datalist = f.readlines() 38 | self.datalist = [line_to_paths_fn[dataset](l, input_names) for l in datalist] 39 | self.root_dir = data_dir 40 | self.transform_trn = transform_trn 41 | self.transform_val = transform_val 42 | self.stage = stage 43 | self.input_names = input_names 44 | self.input_mask_idxs = input_mask_idxs 45 | self.ignore_label = ignore_label 46 | 47 | def set_stage(self, stage): 48 | """Define which set of transformation to use. 49 | 50 | Args: 51 | stage (str): either 'train' or 'val' 52 | 53 | """ 54 | self.stage = stage 55 | 56 | def __len__(self): 57 | return len(self.datalist) 58 | 59 | def __getitem__(self, idx): 60 | idxs = self.input_mask_idxs 61 | names = [os.path.join(self.root_dir, rpath) for rpath in self.datalist[idx]] 62 | sample = {} 63 | for i, key in enumerate(self.input_names): 64 | sample[key] = self.read_image(names[idxs[i]], key) 65 | try: 66 | mask = np.array(Image.open(names[idxs[-1]])) 67 | except FileNotFoundError: # for sunrgbd 68 | path = names[idxs[-1]] 69 | num_idx = int(path[-10:-4]) + 5050 70 | path = path[:-10] + '%06d' % num_idx + path[-4:] 71 | mask = np.array(Image.open(path)) 72 | assert len(mask.shape) == 2, 'Masks must be encoded without colourmap' 73 | sample['inputs'] = self.input_names 74 | sample['mask'] = mask 75 | if self.stage == 'train': 76 | if self.transform_trn: 77 | sample = self.transform_trn(sample) 78 | elif self.stage == 'val': 79 | if self.transform_val: 80 | sample = self.transform_val(sample) 81 | del sample['inputs'] 82 | return sample 83 | 84 | @staticmethod 85 | def read_image_(x, key): 86 | img = cv2.imread(x) 87 | if key == 'depth': 88 | img = cv2.applyColorMap(cv2.convertScaleAbs(255 - img, alpha=1), cv2.COLORMAP_JET) 89 | return img 90 | 91 | @staticmethod 92 | def read_image(x, key): 93 | """Simple image reader 94 | 95 | Args: 96 | x (str): path to image. 97 | 98 | Returns image as `np.array`. 
99 | 100 | """ 101 | img_arr = np.array(Image.open(x)) 102 | if len(img_arr.shape) == 2: # grayscale 103 | img_arr = np.tile(img_arr, [3, 1, 1]).transpose(1, 2, 0) 104 | return img_arr 105 | -------------------------------------------------------------------------------- /semantic_segmentation/utils/helpers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import matplotlib as mpl 4 | import matplotlib.cm as cm 5 | import PIL.Image as pil 6 | import cv2 7 | import os 8 | 9 | IMG_SCALE = 1./255 10 | IMG_MEAN = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3)) 11 | IMG_STD = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3)) 12 | logger = None 13 | 14 | 15 | def print_log(message): 16 | print(message, flush=True) 17 | if logger: 18 | logger.write(str(message) + '\n') 19 | 20 | 21 | def maybe_download(model_name, model_url, model_dir=None, map_location=None): 22 | import os, sys 23 | from six.moves import urllib 24 | if model_dir is None: 25 | torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch')) 26 | model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models')) 27 | if not os.path.exists(model_dir): 28 | os.makedirs(model_dir) 29 | filename = '{}.pth.tar'.format(model_name) 30 | cached_file = os.path.join(model_dir, filename) 31 | if not os.path.exists(cached_file): 32 | url = model_url 33 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 34 | urllib.request.urlretrieve(url, cached_file) 35 | return torch.load(cached_file, map_location=map_location) 36 | 37 | 38 | def prepare_img(img): 39 | return (img * IMG_SCALE - IMG_MEAN) / IMG_STD 40 | 41 | 42 | def make_validation_img(img_, depth_, lab, pre): 43 | cmap = np.load('./utils/cmap.npy') 44 | 45 | img = np.array([i * IMG_STD.reshape((3, 1, 1)) + IMG_MEAN.reshape((3, 1, 1)) for i in img_]) 46 | img *= 255 47 | img = img.astype(np.uint8) 48 | img = np.concatenate(img, axis=1) 49 | 50 | depth_ = depth_[0].transpose(1, 2, 0) / max(depth_.max(), 10) 51 | vmax = np.percentile(depth_, 95) 52 | normalizer = mpl.colors.Normalize(vmin=depth_.min(), vmax=vmax) 53 | mapper = cm.ScalarMappable(norm=normalizer, cmap='magma') 54 | depth = (mapper.to_rgba(depth_)[:,:,:3] * 255).astype(np.uint8) 55 | lab = np.concatenate(lab) 56 | lab = np.array([cmap[i.astype(np.uint8) + 1] for i in lab]) 57 | 58 | pre = np.concatenate(pre) 59 | pre = np.array([cmap[i.astype(np.uint8) + 1] for i in pre]) 60 | img = img.transpose(1, 2, 0) 61 | 62 | return np.concatenate([img, depth, lab, pre], 1) 63 | -------------------------------------------------------------------------------- /semantic_segmentation/utils/load_state.txt: -------------------------------------------------------------------------------- 1 | RuntimeError: Error(s) in loading state_dict for mit_b1: 2 | Missing key(s) in state_dict: "patch_embed1.proj.module.weight", "patch_embed1.proj.module.bias", "patch_embed1.norm.ln_0.weight", "patch_embed1.norm.ln_0.bias", "patch_embed1.norm.ln_1.weight", "patch_embed1.norm.ln_1.bias", "patch_embed2.proj.module.weight", "patch_embed2.proj.module.bias", "patch_embed2.norm.ln_0.weight", "patch_embed2.norm.ln_0.bias", "patch_embed2.norm.ln_1.weight", "patch_embed2.norm.ln_1.bias", "patch_embed3.proj.module.weight", "patch_embed3.proj.module.bias", "patch_embed3.norm.ln_0.weight", "patch_embed3.norm.ln_0.bias", "patch_embed3.norm.ln_1.weight", "patch_embed3.norm.ln_1.bias", "patch_embed4.proj.module.weight", "patch_embed4.proj.module.bias", 
"patch_embed4.norm.ln_0.weight", "patch_embed4.norm.ln_0.bias", "patch_embed4.norm.ln_1.weight", "patch_embed4.norm.ln_1.bias", "block1.0.norm1.ln_0.weight", "block1.0.norm1.ln_0.bias", "block1.0.norm1.ln_1.weight", "block1.0.norm1.ln_1.bias", "block1.0.attn.q.module.weight", "block1.0.attn.q.module.bias", "block1.0.attn.kv.module.weight", "block1.0.attn.kv.module.bias", "block1.0.attn.proj.module.weight", "block1.0.attn.proj.module.bias", "block1.0.attn.sr.module.weight", "block1.0.attn.sr.module.bias", "block1.0.attn.norm.ln_0.weight", "block1.0.attn.norm.ln_0.bias", "block1.0.attn.norm.ln_1.weight", "block1.0.attn.norm.ln_1.bias", "block1.0.norm2.ln_0.weight", "block1.0.norm2.ln_0.bias", "block1.0.norm2.ln_1.weight", "block1.0.norm2.ln_1.bias", "block1.0.mlp.module.fc1.weight", "block1.0.mlp.module.fc1.bias", "block1.0.mlp.module.dwconv.dwconv.weight", "block1.0.mlp.module.dwconv.dwconv.bias", "block1.0.mlp.module.fc2.weight", "block1.0.mlp.module.fc2.bias", "block1.1.norm1.ln_0.weight", "block1.1.norm1.ln_0.bias", "block1.1.norm1.ln_1.weight", "block1.1.norm1.ln_1.bias", "block1.1.attn.q.module.weight", "block1.1.attn.q.module.bias", "block1.1.attn.kv.module.weight", "block1.1.attn.kv.module.bias", "block1.1.attn.proj.module.weight", "block1.1.attn.proj.module.bias", "block1.1.attn.sr.module.weight", "block1.1.attn.sr.module.bias", "block1.1.attn.norm.ln_0.weight", "block1.1.attn.norm.ln_0.bias", "block1.1.attn.norm.ln_1.weight", "block1.1.attn.norm.ln_1.bias", "block1.1.norm2.ln_0.weight", "block1.1.norm2.ln_0.bias", "block1.1.norm2.ln_1.weight", "block1.1.norm2.ln_1.bias", "block1.1.mlp.module.fc1.weight", "block1.1.mlp.module.fc1.bias", "block1.1.mlp.module.dwconv.dwconv.weight", "block1.1.mlp.module.dwconv.dwconv.bias", "block1.1.mlp.module.fc2.weight", "block1.1.mlp.module.fc2.bias", "norm1.ln_0.weight", "norm1.ln_0.bias", "norm1.ln_1.weight", "norm1.ln_1.bias", "block2.0.norm1.ln_0.weight", "block2.0.norm1.ln_0.bias", "block2.0.norm1.ln_1.weight", "block2.0.norm1.ln_1.bias", "block2.0.attn.q.module.weight", "block2.0.attn.q.module.bias", "block2.0.attn.kv.module.weight", "block2.0.attn.kv.module.bias", "block2.0.attn.proj.module.weight", "block2.0.attn.proj.module.bias", "block2.0.attn.sr.module.weight", "block2.0.attn.sr.module.bias", "block2.0.attn.norm.ln_0.weight", "block2.0.attn.norm.ln_0.bias", "block2.0.attn.norm.ln_1.weight", "block2.0.attn.norm.ln_1.bias", "block2.0.norm2.ln_0.weight", "block2.0.norm2.ln_0.bias", "block2.0.norm2.ln_1.weight", "block2.0.norm2.ln_1.bias", "block2.0.mlp.module.fc1.weight", "block2.0.mlp.module.fc1.bias", "block2.0.mlp.module.dwconv.dwconv.weight", "block2.0.mlp.module.dwconv.dwconv.bias", "block2.0.mlp.module.fc2.weight", "block2.0.mlp.module.fc2.bias", "block2.1.norm1.ln_0.weight", "block2.1.norm1.ln_0.bias", "block2.1.norm1.ln_1.weight", "block2.1.norm1.ln_1.bias", "block2.1.attn.q.module.weight", "block2.1.attn.q.module.bias", "block2.1.attn.kv.module.weight", "block2.1.attn.kv.module.bias", "block2.1.attn.proj.module.weight", "block2.1.attn.proj.module.bias", "block2.1.attn.sr.module.weight", "block2.1.attn.sr.module.bias", "block2.1.attn.norm.ln_0.weight", "block2.1.attn.norm.ln_0.bias", "block2.1.attn.norm.ln_1.weight", "block2.1.attn.norm.ln_1.bias", "block2.1.norm2.ln_0.weight", "block2.1.norm2.ln_0.bias", "block2.1.norm2.ln_1.weight", "block2.1.norm2.ln_1.bias", "block2.1.mlp.module.fc1.weight", "block2.1.mlp.module.fc1.bias", "block2.1.mlp.module.dwconv.dwconv.weight", "block2.1.mlp.module.dwconv.dwconv.bias", 
"block2.1.mlp.module.fc2.weight", "block2.1.mlp.module.fc2.bias", "norm2.ln_0.weight", "norm2.ln_0.bias", "norm2.ln_1.weight", "norm2.ln_1.bias", "block3.0.norm1.ln_0.weight", "block3.0.norm1.ln_0.bias", "block3.0.norm1.ln_1.weight", "block3.0.norm1.ln_1.bias", "block3.0.attn.q.module.weight", "block3.0.attn.q.module.bias", "block3.0.attn.kv.module.weight", "block3.0.attn.kv.module.bias", "block3.0.attn.proj.module.weight", "block3.0.attn.proj.module.bias", "block3.0.attn.sr.module.weight", "block3.0.attn.sr.module.bias", "block3.0.attn.norm.ln_0.weight", "block3.0.attn.norm.ln_0.bias", "block3.0.attn.norm.ln_1.weight", "block3.0.attn.norm.ln_1.bias", "block3.0.norm2.ln_0.weight", "block3.0.norm2.ln_0.bias", "block3.0.norm2.ln_1.weight", "block3.0.norm2.ln_1.bias", "block3.0.mlp.module.fc1.weight", "block3.0.mlp.module.fc1.bias", "block3.0.mlp.module.dwconv.dwconv.weight", "block3.0.mlp.module.dwconv.dwconv.bias", "block3.0.mlp.module.fc2.weight", "block3.0.mlp.module.fc2.bias", "block3.1.norm1.ln_0.weight", "block3.1.norm1.ln_0.bias", "block3.1.norm1.ln_1.weight", "block3.1.norm1.ln_1.bias", "block3.1.attn.q.module.weight", "block3.1.attn.q.module.bias", "block3.1.attn.kv.module.weight", "block3.1.attn.kv.module.bias", "block3.1.attn.proj.module.weight", "block3.1.attn.proj.module.bias", "block3.1.attn.sr.module.weight", "block3.1.attn.sr.module.bias", "block3.1.attn.norm.ln_0.weight", "block3.1.attn.norm.ln_0.bias", "block3.1.attn.norm.ln_1.weight", "block3.1.attn.norm.ln_1.bias", "block3.1.norm2.ln_0.weight", "block3.1.norm2.ln_0.bias", "block3.1.norm2.ln_1.weight", "block3.1.norm2.ln_1.bias", "block3.1.mlp.module.fc1.weight", "block3.1.mlp.module.fc1.bias", "block3.1.mlp.module.dwconv.dwconv.weight", "block3.1.mlp.module.dwconv.dwconv.bias", "block3.1.mlp.module.fc2.weight", "block3.1.mlp.module.fc2.bias", "norm3.ln_0.weight", "norm3.ln_0.bias", "norm3.ln_1.weight", "norm3.ln_1.bias", "block4.0.norm1.ln_0.weight", "block4.0.norm1.ln_0.bias", "block4.0.norm1.ln_1.weight", "block4.0.norm1.ln_1.bias", "block4.0.attn.q.module.weight", "block4.0.attn.q.module.bias", "block4.0.attn.kv.module.weight", "block4.0.attn.kv.module.bias", "block4.0.attn.proj.module.weight", "block4.0.attn.proj.module.bias", "block4.0.norm2.ln_0.weight", "block4.0.norm2.ln_0.bias", "block4.0.norm2.ln_1.weight", "block4.0.norm2.ln_1.bias", "block4.0.mlp.module.fc1.weight", "block4.0.mlp.module.fc1.bias", "block4.0.mlp.module.dwconv.dwconv.weight", "block4.0.mlp.module.dwconv.dwconv.bias", "block4.0.mlp.module.fc2.weight", "block4.0.mlp.module.fc2.bias", "block4.1.norm1.ln_0.weight", "block4.1.norm1.ln_0.bias", "block4.1.norm1.ln_1.weight", "block4.1.norm1.ln_1.bias", "block4.1.attn.q.module.weight", "block4.1.attn.q.module.bias", "block4.1.attn.kv.module.weight", "block4.1.attn.kv.module.bias", "block4.1.attn.proj.module.weight", "block4.1.attn.proj.module.bias", "block4.1.norm2.ln_0.weight", "block4.1.norm2.ln_0.bias", "block4.1.norm2.ln_1.weight", "block4.1.norm2.ln_1.bias", "block4.1.mlp.module.fc1.weight", "block4.1.mlp.module.fc1.bias", "block4.1.mlp.module.dwconv.dwconv.weight", "block4.1.mlp.module.dwconv.dwconv.bias", "block4.1.mlp.module.fc2.weight", "block4.1.mlp.module.fc2.bias", "norm4.ln_0.weight", "norm4.ln_0.bias", "norm4.ln_1.weight", "norm4.ln_1.bias". 
3 | Unexpected key(s) in state_dict: "patch_embed1.proj.weight", "patch_embed1.proj.bias", "patch_embed1.norm.weight", "patch_embed1.norm.bias", "patch_embed2.proj.weight", "patch_embed2.proj.bias", "patch_embed2.norm.weight", "patch_embed2.norm.bias", "patch_embed3.proj.weight", "patch_embed3.proj.bias", "patch_embed3.norm.weight", "patch_embed3.norm.bias", "patch_embed4.proj.weight", "patch_embed4.proj.bias", "patch_embed4.norm.weight", "patch_embed4.norm.bias", "block1.0.norm1.weight", "block1.0.norm1.bias", "block1.0.attn.q.weight", "block1.0.attn.q.bias", "block1.0.attn.kv.weight", "block1.0.attn.kv.bias", "block1.0.attn.proj.weight", "block1.0.attn.proj.bias", "block1.0.attn.sr.weight", "block1.0.attn.sr.bias", "block1.0.attn.norm.weight", "block1.0.attn.norm.bias", "block1.0.norm2.weight", "block1.0.norm2.bias", "block1.0.mlp.fc1.weight", "block1.0.mlp.fc1.bias", "block1.0.mlp.dwconv.dwconv.weight", "block1.0.mlp.dwconv.dwconv.bias", "block1.0.mlp.fc2.weight", "block1.0.mlp.fc2.bias", "block1.1.norm1.weight", "block1.1.norm1.bias", "block1.1.attn.q.weight", "block1.1.attn.q.bias", "block1.1.attn.kv.weight", "block1.1.attn.kv.bias", "block1.1.attn.proj.weight", "block1.1.attn.proj.bias", "block1.1.attn.sr.weight", "block1.1.attn.sr.bias", "block1.1.attn.norm.weight", "block1.1.attn.norm.bias", "block1.1.norm2.weight", "block1.1.norm2.bias", "block1.1.mlp.fc1.weight", "block1.1.mlp.fc1.bias", "block1.1.mlp.dwconv.dwconv.weight", "block1.1.mlp.dwconv.dwconv.bias", "block1.1.mlp.fc2.weight", "block1.1.mlp.fc2.bias", "norm1.weight", "norm1.bias", "block2.0.norm1.weight", "block2.0.norm1.bias", "block2.0.attn.q.weight", "block2.0.attn.q.bias", "block2.0.attn.kv.weight", "block2.0.attn.kv.bias", "block2.0.attn.proj.weight", "block2.0.attn.proj.bias", "block2.0.attn.sr.weight", "block2.0.attn.sr.bias", "block2.0.attn.norm.weight", "block2.0.attn.norm.bias", "block2.0.norm2.weight", "block2.0.norm2.bias", "block2.0.mlp.fc1.weight", "block2.0.mlp.fc1.bias", "block2.0.mlp.dwconv.dwconv.weight", "block2.0.mlp.dwconv.dwconv.bias", "block2.0.mlp.fc2.weight", "block2.0.mlp.fc2.bias", "block2.1.norm1.weight", "block2.1.norm1.bias", "block2.1.attn.q.weight", "block2.1.attn.q.bias", "block2.1.attn.kv.weight", "block2.1.attn.kv.bias", "block2.1.attn.proj.weight", "block2.1.attn.proj.bias", "block2.1.attn.sr.weight", "block2.1.attn.sr.bias", "block2.1.attn.norm.weight", "block2.1.attn.norm.bias", "block2.1.norm2.weight", "block2.1.norm2.bias", "block2.1.mlp.fc1.weight", "block2.1.mlp.fc1.bias", "block2.1.mlp.dwconv.dwconv.weight", "block2.1.mlp.dwconv.dwconv.bias", "block2.1.mlp.fc2.weight", "block2.1.mlp.fc2.bias", "norm2.weight", "norm2.bias", "block3.0.norm1.weight", "block3.0.norm1.bias", "block3.0.attn.q.weight", "block3.0.attn.q.bias", "block3.0.attn.kv.weight", "block3.0.attn.kv.bias", "block3.0.attn.proj.weight", "block3.0.attn.proj.bias", "block3.0.attn.sr.weight", "block3.0.attn.sr.bias", "block3.0.attn.norm.weight", "block3.0.attn.norm.bias", "block3.0.norm2.weight", "block3.0.norm2.bias", "block3.0.mlp.fc1.weight", "block3.0.mlp.fc1.bias", "block3.0.mlp.dwconv.dwconv.weight", "block3.0.mlp.dwconv.dwconv.bias", "block3.0.mlp.fc2.weight", "block3.0.mlp.fc2.bias", "block3.1.norm1.weight", "block3.1.norm1.bias", "block3.1.attn.q.weight", "block3.1.attn.q.bias", "block3.1.attn.kv.weight", "block3.1.attn.kv.bias", "block3.1.attn.proj.weight", "block3.1.attn.proj.bias", "block3.1.attn.sr.weight", "block3.1.attn.sr.bias", "block3.1.attn.norm.weight", "block3.1.attn.norm.bias", 
"block3.1.norm2.weight", "block3.1.norm2.bias", "block3.1.mlp.fc1.weight", "block3.1.mlp.fc1.bias", "block3.1.mlp.dwconv.dwconv.weight", "block3.1.mlp.dwconv.dwconv.bias", "block3.1.mlp.fc2.weight", "block3.1.mlp.fc2.bias", "norm3.weight", "norm3.bias", "block4.0.norm1.weight", "block4.0.norm1.bias", "block4.0.attn.q.weight", "block4.0.attn.q.bias", "block4.0.attn.kv.weight", "block4.0.attn.kv.bias", "block4.0.attn.proj.weight", "block4.0.attn.proj.bias", "block4.0.norm2.weight", "block4.0.norm2.bias", "block4.0.mlp.fc1.weight", "block4.0.mlp.fc1.bias", "block4.0.mlp.dwconv.dwconv.weight", "block4.0.mlp.dwconv.dwconv.bias", "block4.0.mlp.fc2.weight", "block4.0.mlp.fc2.bias", "block4.1.norm1.weight", "block4.1.norm1.bias", "block4.1.attn.q.weight", "block4.1.attn.q.bias", "block4.1.attn.kv.weight", "block4.1.attn.kv.bias", "block4.1.attn.proj.weight", "block4.1.attn.proj.bias", "block4.1.norm2.weight", "block4.1.norm2.bias", "block4.1.mlp.fc1.weight", "block4.1.mlp.fc1.bias", "block4.1.mlp.dwconv.dwconv.weight", "block4.1.mlp.dwconv.dwconv.bias", "block4.1.mlp.fc2.weight", "block4.1.mlp.fc2.bias", "norm4.weight", "norm4.bias". 4 | -------------------------------------------------------------------------------- /semantic_segmentation/utils/meter.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import torch 4 | import numpy as np 5 | 6 | 7 | def confusion_matrix(x, y, n, ignore_label=None, mask=None): 8 | if mask is None: 9 | mask = np.ones_like(x) == 1 10 | k = (x >= 0) & (y < n) & (x != ignore_label) & (mask.astype(np.bool)) 11 | return np.bincount(n * x[k].astype(int) + y[k], minlength=n ** 2).reshape(n, n) 12 | 13 | 14 | def getScores(conf_matrix): 15 | if conf_matrix.sum() == 0: 16 | return 0, 0, 0 17 | with np.errstate(divide='ignore',invalid='ignore'): 18 | overall = np.diag(conf_matrix).sum() / np.float(conf_matrix.sum()) 19 | perclass = np.diag(conf_matrix) / conf_matrix.sum(1).astype(np.float) 20 | IU = np.diag(conf_matrix) / (conf_matrix.sum(1) + conf_matrix.sum(0) \ 21 | - np.diag(conf_matrix)).astype(np.float) 22 | return overall * 100., np.nanmean(perclass) * 100., np.nanmean(IU) * 100. 23 | 24 | 25 | def compute_params(model): 26 | """Compute number of parameters""" 27 | n_total_params = 0 28 | for name, m in model.named_parameters(): 29 | n_elem = m.numel() 30 | n_total_params += n_elem 31 | return n_total_params 32 | 33 | 34 | # Adopted from https://raw.githubusercontent.com/pytorch/examples/master/imagenet/main.py 35 | class AverageMeter(object): 36 | """Computes and stores the average and current value""" 37 | def __init__(self): 38 | self.reset() 39 | 40 | def reset(self): 41 | self.val = 0 42 | self.avg = 0 43 | self.sum = 0 44 | self.count = 0 45 | 46 | def update(self, val, n=1): 47 | self.val = val 48 | self.sum += val * n 49 | self.count += n 50 | self.avg = self.sum / self.count 51 | 52 | 53 | class Saver(): 54 | """Saver class for managing parameters""" 55 | def __init__(self, args, ckpt_dir, best_val=0, condition=lambda x, y: x > y): 56 | """ 57 | Args: 58 | args (dict): dictionary with arguments. 59 | ckpt_dir (str): path to directory in which to store the checkpoint. 60 | best_val (float): initial best value. 61 | condition (function): how to decide whether to save the new checkpoint 62 | by comparing best value and new value (x,y). 
63 | 64 | """ 65 | if not os.path.exists(ckpt_dir): 66 | os.makedirs(ckpt_dir) 67 | with open('{}/args.json'.format(ckpt_dir), 'w') as f: 68 | json.dump({k: v for k, v in args.items() if isinstance(v, (int, float, str))}, f, 69 | sort_keys = True, indent = 4, ensure_ascii = False) 70 | self.ckpt_dir = ckpt_dir 71 | self.best_val = best_val 72 | self.condition = condition 73 | self._counter = 0 74 | 75 | def _do_save(self, new_val): 76 | """Check whether need to save""" 77 | return self.condition(new_val, self.best_val) 78 | 79 | def save(self, new_val, dict_to_save): 80 | """Save new checkpoint""" 81 | self._counter += 1 82 | if self._do_save(new_val): 83 | # print(' New best value {:.4f}, was {:.4f}'.format(new_val, self.best_val), flush=True) 84 | self.best_val = new_val 85 | dict_to_save['best_val'] = new_val 86 | torch.save(dict_to_save, '{}/model-best.pth.tar'.format(self.ckpt_dir)) 87 | else: 88 | dict_to_save['best_val'] = new_val 89 | torch.save(dict_to_save, '{}/checkpoint.pth.tar'.format(self.ckpt_dir)) 90 | 91 | -------------------------------------------------------------------------------- /semantic_segmentation/utils/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class PolyWarmupAdamW(torch.optim.AdamW): 4 | 5 | def __init__(self, params, lr, weight_decay, betas, warmup_iter=None, max_iter=None, warmup_ratio=None, power=None): 6 | super().__init__(params, lr=lr, betas=betas,weight_decay=weight_decay, eps=1e-8) 7 | 8 | self.global_step = 0 9 | self.warmup_iter = warmup_iter 10 | self.warmup_ratio = warmup_ratio 11 | self.max_iter = max_iter 12 | self.power = power 13 | 14 | self.__init_lr = [group['lr'] for group in self.param_groups] 15 | 16 | def step(self, closure=None): 17 | ## adjust lr 18 | if self.global_step < self.warmup_iter: 19 | 20 | lr_mult = 1 - (1 - self.global_step / self.warmup_iter) * (1 - self.warmup_ratio) 21 | for i in range(len(self.param_groups)): 22 | self.param_groups[i]['lr'] = self.__init_lr[i] * lr_mult 23 | 24 | elif self.global_step < self.max_iter: 25 | 26 | lr_mult = (1 - self.global_step / self.max_iter) ** self.power 27 | for i in range(len(self.param_groups)): 28 | self.param_groups[i]['lr'] = self.__init_lr[i] * lr_mult 29 | 30 | # step 31 | super().step(closure) 32 | 33 | self.global_step += 1 -------------------------------------------------------------------------------- /semantic_segmentation/utils/transforms.py: -------------------------------------------------------------------------------- 1 | """RefineNet-LightWeight 2 | 3 | RefineNet-LigthWeight PyTorch for non-commercial purposes 4 | 5 | Copyright (c) 2018, Vladimir Nekrasov (vladimir.nekrasov@adelaide.edu.au) 6 | All rights reserved. 7 | 8 | Redistribution and use in source and binary forms, with or without 9 | modification, are permitted provided that the following conditions are met: 10 | 11 | * Redistributions of source code must retain the above copyright notice, this 12 | list of conditions and the following disclaimer. 13 | 14 | * Redistributions in binary form must reproduce the above copyright notice, 15 | this list of conditions and the following disclaimer in the documentation 16 | and/or other materials provided with the distribution. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | """ 29 | 30 | 31 | import cv2 32 | import numpy as np 33 | import torch 34 | 35 | # Usual dtypes for common modalities 36 | KEYS_TO_DTYPES = { 37 | 'rgb': torch.float, 38 | 'depth': torch.float, 39 | 'normals': torch.float, 40 | 'mask': torch.long, 41 | } 42 | 43 | 44 | class Pad(object): 45 | """Pad image and mask to the desired size. 46 | 47 | Args: 48 | size (int) : minimum length/width. 49 | img_val (array) : image padding value. 50 | msk_val (int) : mask padding value. 51 | 52 | """ 53 | def __init__(self, size, img_val, msk_val): 54 | assert isinstance(size, int) 55 | self.size = size 56 | self.img_val = img_val 57 | self.msk_val = msk_val 58 | 59 | def __call__(self, sample): 60 | image = sample['rgb'] 61 | h, w = image.shape[:2] 62 | h_pad = int(np.clip(((self.size - h) + 1) // 2, 0, 1e6)) 63 | w_pad = int(np.clip(((self.size - w) + 1) // 2, 0, 1e6)) 64 | pad = ((h_pad, h_pad), (w_pad, w_pad)) 65 | for key in sample['inputs']: 66 | sample[key] = self.transform_input(sample[key], pad) 67 | sample['mask'] = np.pad(sample['mask'], pad, mode='constant', constant_values=self.msk_val) 68 | return sample 69 | 70 | def transform_input(self, input, pad): 71 | input = np.stack([ 72 | np.pad(input[:, :, c], pad, mode='constant', 73 | constant_values=self.img_val[c]) for c in range(3) 74 | ], axis=2) 75 | return input 76 | 77 | 78 | class RandomCrop(object): 79 | """Crop randomly the image in a sample. 80 | 81 | Args: 82 | crop_size (int): Desired output size. 83 | 84 | """ 85 | def __init__(self, crop_size): 86 | assert isinstance(crop_size, int) 87 | self.crop_size = crop_size 88 | if self.crop_size % 2 != 0: 89 | self.crop_size -= 1 90 | 91 | def __call__(self, sample): 92 | image = sample['rgb'] 93 | h, w = image.shape[:2] 94 | new_h = min(h, self.crop_size) 95 | new_w = min(w, self.crop_size) 96 | top = np.random.randint(0, h - new_h + 1) 97 | left = np.random.randint(0, w - new_w + 1) 98 | for key in sample['inputs']: 99 | sample[key] = self.transform_input(sample[key], top, new_h, left, new_w) 100 | sample['mask'] = sample['mask'][top : top + new_h, left : left + new_w] 101 | return sample 102 | 103 | def transform_input(self, input, top, new_h, left, new_w): 104 | input = input[top : top + new_h, left : left + new_w] 105 | return input 106 | 107 | 108 | class ResizeAndScale(object): 109 | """Resize shorter/longer side to a given value and randomly scale. 110 | 111 | Args: 112 | side (int) : shorter / longer side value. 113 | low_scale (float) : lower scaling bound. 114 | high_scale (float) : upper scaling bound. 115 | shorter (bool) : whether to resize shorter / longer side. 
116 | 117 | """ 118 | def __init__(self, side, low_scale, high_scale, shorter=True): 119 | assert isinstance(side, int) 120 | assert isinstance(low_scale, float) 121 | assert isinstance(high_scale, float) 122 | self.side = side 123 | self.low_scale = low_scale 124 | self.high_scale = high_scale 125 | self.shorter = shorter 126 | 127 | def __call__(self, sample): 128 | image = sample['rgb'] 129 | scale = np.random.uniform(self.low_scale, self.high_scale) 130 | if self.shorter: 131 | min_side = min(image.shape[:2]) 132 | if min_side * scale < self.side: 133 | scale = (self.side * 1. / min_side) 134 | else: 135 | max_side = max(image.shape[:2]) 136 | if max_side * scale > self.side: 137 | scale = (self.side * 1. / max_side) 138 | inters = {'rgb': cv2.INTER_CUBIC, 'depth': cv2.INTER_NEAREST} 139 | for key in sample['inputs']: 140 | inter = inters[key] if key in inters else cv2.INTER_CUBIC 141 | sample[key] = self.transform_input(sample[key], scale, inter) 142 | sample['mask'] = cv2.resize(sample['mask'], None, fx=scale, fy=scale, 143 | interpolation=cv2.INTER_NEAREST) 144 | return sample 145 | 146 | def transform_input(self, input, scale, inter): 147 | input = cv2.resize(input, None, fx=scale, fy=scale, interpolation=inter) 148 | return input 149 | 150 | 151 | class CropAlignToMask(object): 152 | """Crop inputs to the size of the mask.""" 153 | def __call__(self, sample): 154 | mask_h, mask_w = sample['mask'].shape[:2] 155 | for key in sample['inputs']: 156 | sample[key] = self.transform_input(sample[key], mask_h, mask_w) 157 | return sample 158 | 159 | def transform_input(self, input, mask_h, mask_w): 160 | input_h, input_w = input.shape[:2] 161 | if (input_h, input_w) == (mask_h, mask_w): 162 | return input 163 | h, w = (input_h - mask_h) // 2, (input_w - mask_w) // 2 164 | del_h, del_w = (input_h - mask_h) % 2, (input_w - mask_w) % 2 165 | input = input[h: input_h - h - del_h, w: input_w - w - del_w] 166 | assert input.shape[:2] == (mask_h, mask_w) 167 | return input 168 | 169 | 170 | class ResizeAlignToMask(object): 171 | """Resize inputs to the size of the mask.""" 172 | def __call__(self, sample): 173 | mask_h, mask_w = sample['mask'].shape[:2] 174 | assert mask_h == mask_w 175 | inters = {'rgb': cv2.INTER_CUBIC, 'depth': cv2.INTER_NEAREST} 176 | for key in sample['inputs']: 177 | inter = inters[key] if key in inters else cv2.INTER_CUBIC 178 | sample[key] = self.transform_input(sample[key], mask_h, inter) 179 | return sample 180 | 181 | def transform_input(self, input, mask_h, inter): 182 | input_h, input_w = input.shape[:2] 183 | assert input_h == input_w 184 | scale = mask_h / input_h 185 | input = cv2.resize(input, None, fx=scale, fy=scale, interpolation=inter) 186 | return input 187 | 188 | 189 | class ResizeInputs(object): 190 | def __init__(self, size): 191 | self.size = size 192 | 193 | def __call__(self, sample): 194 | # sample['rgb'] = sample['rgb'].numpy() 195 | if self.size is None: 196 | return sample 197 | size = sample['rgb'].shape[0] 198 | scale = self.size / size 199 | # print(sample['rgb'].shape, type(sample['rgb'])) 200 | inters = {'rgb': cv2.INTER_CUBIC, 'depth': cv2.INTER_NEAREST} 201 | for key in sample['inputs']: 202 | inter = inters[key] if key in inters else cv2.INTER_CUBIC 203 | sample[key] = self.transform_input(sample[key], scale, inter) 204 | return sample 205 | 206 | def transform_input(self, input, scale, inter): 207 | input = cv2.resize(input, None, fx=scale, fy=scale, interpolation=inter) 208 | return input 209 | 210 | 211 | class 
ResizeInputsScale(object): 212 | def __init__(self, scale): 213 | self.scale = scale 214 | 215 | def __call__(self, sample): 216 | if self.scale is None: 217 | return sample 218 | inters = {'rgb': cv2.INTER_CUBIC, 'depth': cv2.INTER_NEAREST} 219 | for key in sample['inputs']: 220 | inter = inters[key] if key in inters else cv2.INTER_CUBIC 221 | sample[key] = self.transform_input(sample[key], self.scale, inter) 222 | return sample 223 | 224 | def transform_input(self, input, scale, inter): 225 | input = cv2.resize(input, None, fx=scale, fy=scale, interpolation=inter) 226 | return input 227 | 228 | 229 | class RandomMirror(object): 230 | """Randomly flip the image and the mask""" 231 | def __call__(self, sample): 232 | do_mirror = np.random.randint(2) 233 | if do_mirror: 234 | for key in sample['inputs']: 235 | sample[key] = cv2.flip(sample[key], 1) 236 | sample['mask'] = cv2.flip(sample['mask'], 1) 237 | return sample 238 | 239 | 240 | class Normalise(object): 241 | """Normalise a tensor image with mean and standard deviation. 242 | Given mean: (R, G, B) and std: (R, G, B), 243 | will normalise each channel of the torch.*Tensor, i.e. 244 | channel = (scale * channel - mean) / std 245 | 246 | Args: 247 | scale (float): Scaling constant. 248 | mean (sequence): Sequence of means for R,G,B channels respectively. 249 | std (sequence): Sequence of standard deviations for R,G,B channels 250 | respectively. 251 | depth_scale (float): Depth divisor for depth annotations. 252 | 253 | """ 254 | def __init__(self, scale, mean, std, depth_scale=1.): 255 | self.scale = scale 256 | self.mean = mean 257 | self.std = std 258 | self.depth_scale = depth_scale 259 | 260 | def __call__(self, sample): 261 | for key in sample['inputs']: 262 | if key == 'depth': 263 | continue 264 | sample[key] = (self.scale * sample[key] - self.mean) / self.std 265 | if 'depth' in sample: 266 | # sample['depth'] = self.scale * sample['depth'] 267 | # sample['depth'] = (self.scale * sample['depth'] - self.mean) / self.std 268 | if self.depth_scale > 0: 269 | sample['depth'] = self.depth_scale * sample['depth'] 270 | elif self.depth_scale == -1: # taskonomy 271 | # sample['depth'] = np.log(1 + sample['depth']) / np.log(2.** 16.0) 272 | sample['depth'] = np.log(1 + sample['depth']) 273 | elif self.depth_scale == -2: # sunrgbd 274 | depth = sample['depth'] 275 | sample['depth'] = (depth - depth.min()) * 255.0 / (depth.max() - depth.min()) 276 | return sample 277 | 278 | 279 | class ToTensor(object): 280 | """Convert ndarrays in sample to Tensors.""" 281 | def __call__(self, sample): 282 | # swap color axis because 283 | # numpy image: H x W x C 284 | # torch image: C X H X W 285 | for key in sample['inputs']: 286 | sample[key] = torch.from_numpy( 287 | sample[key].transpose((2, 0, 1)) 288 | ).to(KEYS_TO_DTYPES[key] if key in KEYS_TO_DTYPES else KEYS_TO_DTYPES['rgb']) 289 | sample['mask'] = torch.from_numpy(sample['mask']).to(KEYS_TO_DTYPES['mask']) 290 | return sample 291 | 292 | 293 | def make_list(x): 294 | """Returns the given input as a list.""" 295 | if isinstance(x, list): 296 | return x 297 | elif isinstance(x, tuple): 298 | return list(x) 299 | else: 300 | return [x] 301 | 302 | --------------------------------------------------------------------------------
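
The sketch below is an orientation aid, not repository code: it shows one way the utilities above compose into an NYUDv2 RGB-D training pipeline, with the transform classes chained (here via torchvision's Compose) and handed to SegDataset. The crop size, padding values, ignore label, dataset root, and the column order assumed for train.txt are illustrative assumptions; the repository's actual settings live in semantic_segmentation/config.py and main.py, which are outside this excerpt. The normalisation constants mirror IMG_SCALE / IMG_MEAN / IMG_STD from utils/helpers.py.

```python
# Illustrative sketch only; values marked "assumed" are not taken from the repository.
# Assumes the working directory is semantic_segmentation/.
import numpy as np
import torch
from torchvision import transforms

from utils.datasets import SegDataset
from utils.transforms import (CropAlignToMask, ResizeAndScale, Pad, RandomMirror,
                              RandomCrop, Normalise, ToTensor)

CROP_SIZE = 500  # assumed training crop size
IMG_MEAN = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))  # as in utils/helpers.py
IMG_STD = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))

transform_trn = transforms.Compose([
    CropAlignToMask(),                    # crop rgb/depth to the label-mask resolution
    ResizeAndScale(CROP_SIZE, 0.5, 2.0),  # random rescale, keeping the shorter side >= CROP_SIZE
    Pad(CROP_SIZE, [123.675, 116.28, 103.53], 255),  # assumed padding values (ImageNet mean, ignore label)
    RandomMirror(),
    RandomCrop(CROP_SIZE),
    Normalise(1. / 255, IMG_MEAN, IMG_STD, depth_scale=1.),
    ToTensor(),
])

dataset = SegDataset(
    dataset='nyudv2',
    data_file='data/nyudv2/train.txt',
    data_dir='/path/to/nyudv2/',   # assumed dataset root
    input_names=['rgb', 'depth'],
    input_mask_idxs=[0, 1, 2],     # assumed column order in train.txt: rgb, depth, mask
    transform_trn=transform_trn,
    stage='train',
    ignore_label=255,              # assumed ignore label
)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)

sample = next(iter(loader))
print(sample['rgb'].shape, sample['depth'].shape, sample['mask'].shape)
# e.g. torch.Size([2, 3, 500, 500]), torch.Size([2, 3, 500, 500]), torch.Size([2, 500, 500])
```

Batches of this form are roughly what main.py (outside this excerpt) feeds to WeTr in models/segformer.py, whose forward pass returns the two per-modality predictions plus their softmax(alpha)-weighted ensemble, together with the token-pruning masks produced by the encoder.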