├── .gitignore ├── LICENSE ├── README.md ├── config └── scannet │ ├── semaffinet_2cm.yaml │ └── semaffinet_5cm.yaml ├── dataset ├── augmentation.py ├── augmentation_2d.py ├── pregroup_2d_scannet.py ├── preprocess_3d_scannet.py ├── scanNet3D.py ├── scanNetCross.py ├── scannet │ ├── GL.py │ ├── scannet_names.txt │ ├── train_2D.txt │ └── val_2D.txt ├── voxelization_utils.py └── voxelizer.py ├── fig └── semaffinet.png ├── metrics └── iou.py ├── models ├── bpm.py ├── me_common.py ├── resnet_d.py ├── resnet_mink.py ├── semaffinet.py ├── shadownet_2d.py ├── shadownet_3d.py ├── transformer_utils │ ├── transformer.py │ └── transformer_predictor.py ├── unet_2d.py └── unet_3d.py ├── tool ├── test.py ├── test.sh ├── train.py └── train.sh └── util ├── config.py ├── criterion.py ├── solver.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | Exp/ 3 | extensions/ 4 | initmodel/ 5 | __pycache__ 6 | .DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Ziyi Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SemAffiNet: Semantic-Affine Transformation for Point Cloud Segmentation 2 | Created by [Ziyi Wang](https://wangzy22.github.io/), [Yongming Rao](https://raoyongming.github.io/), [Xumin Yu](https://yuxumin.github.io/), [Jie Zhou](https://scholar.google.com/citations?user=6a79aPwAAAAJ&hl=en&authuser=1), [Jiwen Lu](https://scholar.google.com/citations?user=TN8uDQoAAAAJ&hl=zh-CN) 3 | 4 | This repository is an official implementation of SemAffiNet (CVPR 2022). 
5 | 
6 | [[arxiv]](http://arxiv.org/abs/2205.13490)
7 | 
8 | ![intro](fig/semaffinet.png)
9 | 
10 | ## Installation
11 | 
12 | ### Prerequisites
13 | 
14 | - Python 3.8
15 | - PyTorch 1.8.1
16 | - MinkowskiEngine 0.5.4
17 | - timm
18 | - open3d
19 | - cv2, tensorboardX, imageio, SharedArray, scipy, tqdm, h5py
20 | 
21 | ```
22 | conda create -n semaffinet python=3.8
23 | conda activate semaffinet
24 | conda install pytorch==1.8.1 torchvision==0.9.1 cudatoolkit=10.2 -c pytorch
25 | 
26 | git clone https://github.com/NVIDIA/MinkowskiEngine.git
27 | cd MinkowskiEngine
28 | export CXX=g++-7
29 | conda install openblas
30 | python setup.py install --blas_include_dirs=${CONDA_PREFIX}/include --blas=openblas
31 | 
32 | pip install timm
33 | pip install open3d
34 | pip install opencv-python
35 | pip install tensorboardX imageio SharedArray plyfile tqdm
36 | ```
37 | 
38 | ## Usage
39 | 
40 | ### Data Preparation
41 | 
42 | - Download the official [ScanNetV2](https://github.com/ScanNet/ScanNet) dataset.
43 | 
44 | - Prepare ScanNetV2 2D data:
45 | Please follow the instructions in the [3DMV](https://github.com/angeladai/3DMV/tree/master/prepare_data) repo.
46 | ```
47 | python prepare_2d_data.py --scannet_path SCANNET_INPUT_PATH --output_path SCANNET_OUTPUT_PATH --export_label_images
48 | ```
49 | 
50 | - Prepare ScanNetV2 3D data:
51 | ```
52 | python dataset/preprocess_3d_scannet.py
53 | ```
54 | 
55 | - Group ScanNetV2 2D views: preprocess 2D data and group multiple views of one scene into several groups.
56 | You will need to install `pointnet2_ops` from the [PointNet++](https://github.com/erikwijmans/Pointnet2_PyTorch/tree/master/pointnet2_ops_lib) PyTorch repo to run the following command:
57 | ```
58 | python dataset/pregroup_2d_scannet.py
59 | ```
60 | You can also download our processed group results [here](https://drive.google.com/drive/folders/1qgOuSVjtH_gQZRobQ3p4iwesHDVV9nMu?usp=sharing).
61 | 
62 | - The data is expected to be in the following file structure:
63 | ```
64 | SemAffiNet/
65 | |-- data/
66 |     |-- 2D/
67 |         |-- scene0000_00/
68 |             |-- color/
69 |                 |-- 0.jpg
70 |             |-- depth/
71 |                 |-- 0.png
72 |             |-- label/
73 |                 |-- 0.png
74 |             |-- pose/
75 |                 |-- 0.txt
76 |     |-- 3D/
77 |         |-- train/
78 |             |-- scene0000_00_vh_clean_2.pth
79 |         |-- val/
80 |             |-- scene0011_00_vh_clean_2.pth
81 |         |-- test/
82 |             |-- scene0707_00_vh_clean_2.pth
83 |     |-- view_groups/
84 |         |-- view_groups_train.pth
85 |         |-- view_groups_val.pth
86 |         |-- view_groups_test.pth
87 | ```
88 | 
89 | ### Init model preparation
90 | Download the pre-trained resnet34d [weights](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34d_ra2-f8dcfcaf.pth) and place the file in the `initmodel` folder. The pre-trained weight is from the `timm` repository.
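As an optional sanity check before training (an illustrative sketch, not part of the original pipeline: the scene file name is only an example taken from the directory structure above, and treating the downloaded weight file as a flat `timm` state dict is our assumption), you can verify that the preprocessed data and the initialization weights load correctly:
```
import torch

# One preprocessed scene saved by dataset/preprocess_3d_scannet.py:
# a (coords, colors, labels) tuple of numpy arrays.
coords, colors, labels = torch.load('data/3D/val/scene0011_00_vh_clean_2.pth')
print(coords.shape, colors.shape, labels.shape)  # roughly (N, 3), (N, 3), (N,) with labels in {0..19, 255}

# The resnet34d initialization weights downloaded above.
weights = torch.load('initmodel/resnet34d_ra2-f8dcfcaf.pth', map_location='cpu')
print(type(weights), len(weights))
```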
91 | ### Train 92 | 93 | - ScanNetV2 **5cm** voxelization setting: 94 | ``` 95 | bash tool/train.sh SemAffiNet_5cm config/scannet/semaffinet_5cm.yaml scannet 2 96 | ``` 97 | - ScanNetV2 **2cm** voxelization setting: 98 | ``` 99 | bash tool/train.sh SemAffiNet_2cm config/scannet/semaffinet_2cm.yaml scannet 2 100 | ``` 101 | 102 | ### Test 103 | 104 | - ScanNetV2 **5cm** voxelization setting: 105 | ``` 106 | bash tool/test.sh SemAffiNet_5cm config/scannet/semaffinet_5cm.yaml scannet 2 107 | ``` 108 | - ScanNetV2 **2cm** voxelization setting: 109 | ``` 110 | bash tool/test.sh SemAffiNet_2cm config/scannet/semaffinet_2cm.yaml scannet 2 111 | ``` 112 | 113 | ## Results 114 | 115 | We provide pre-trained SemAffiNet models: 116 | | Dataset | URL | 3D mIoU | 2D mIoU | 117 | | ------- | --- | ------- | ------- | 118 | | ScanNetV2 5cm | [Google Drive](https://drive.google.com/file/d/16ghVzxbm05Sn4h8t7Ogr2LKJwqlZRgWI/view?usp=sharing) | 72.1 | 68.2 | 119 | | ScanNetV2 2cm | [Google Drive](https://drive.google.com/file/d/1rL_jVnJGRmGDcg4_0MRLug4s6qx0nF5O/view?usp=sharing) | 74.5 | 74.2 | 120 | 121 | Please rename the checkpoints as `model_best.pth.tar` and organize the directory as the following structure: 122 | ``` 123 | SemAffiNet/ 124 | |-- initmodel/ 125 | |-- resnet34d_ra2-f8dcfcaf.pth 126 | |-- Exp/ 127 | |-- scannet/ 128 | |-- SemAffiNet_2cm/ 129 | |-- model/ 130 | |-- model_best.pth.tar 131 | |-- SemAffiNet_5cm/ 132 | |-- model/ 133 | |-- model_best.pth.tar 134 | ``` 135 | 136 | ## Citation 137 | If you find our work useful in your research, please consider citing: 138 | ``` 139 | @inproceedings{wang2022semaff, 140 | title={SemAffiNet: Semantic-Affine Transformation for Point Cloud Segmentation}, 141 | author={Wang, Ziyi and Rao, Yongming and Yu, Xumin and Zhou, Jie and Lu, Jiwen}, 142 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 143 | year={2022} 144 | } 145 | ``` 146 | 147 | ## Acknowledgements 148 | 149 | Our code is inspired by [BPNet](https://github.com/wbhu/BPNet). Some of the data preprocessing codes for ScanNetV2 are inspired by [3DMV](https://github.com/angeladai/3DMV/tree/master/prepare_data). 
150 | -------------------------------------------------------------------------------- /config/scannet/semaffinet_2cm.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | data_name: scannet_cross 3 | data_root: ./data 4 | classes: 20 5 | aug: True 6 | voxelSize: 0.02 7 | 8 | MASKFORMER: 9 | hidden_dim: 128 10 | dropout: 0.1 11 | nheads: 8 12 | dim_feedforward: 1024 13 | enc_layers: 6 14 | dec_layers: 6 15 | num_point_batch: 800 16 | mask_dim: 128 17 | drop_rate: 0.1 18 | attn_drop_rate: 0.1 19 | drop_path: 0.1 20 | num_tokens_2d: 88 21 | 22 | TRAIN: 23 | viewNum: 3 24 | weight_2d: 0.1 25 | arch: semaffinet 26 | layers_2d: 34 27 | arch_3d: MinkUNet18A 28 | 29 | sync_bn_2d: True 30 | sync_bn_3d: True 31 | ignore_label: 255 32 | train_gpu: [0,1,2,3,4,5,6,7] 33 | workers: 16 # data loader workers 34 | batch_size: 16 # batch size for training 35 | batch_size_val: 16 # batch size for validation during training, memory and speed tradeoff 36 | base_lr: 0.02 37 | loop: 5 38 | epochs: 100 39 | empty_cache_epochs: 5 40 | start_epoch: 0 41 | power: 0.9 42 | momentum: 0.9 43 | weight_decay: 0.0001 44 | 45 | scheduler: SquaredLRWarmingUp 46 | step_size: 20000 47 | max_iter: 40000 48 | step_gamma: 0.1 49 | poly_power: 0.9 50 | exp_gamma: 0.95 51 | exp_step_size: 445 52 | 53 | manual_seed: 1463 54 | print_freq: 10 55 | save_freq: 1 56 | save_path: 57 | weight: # path to initial weight (default: none) 58 | resume: 59 | evaluate: True 60 | eval_freq: 1 61 | 62 | Distributed: 63 | dist_url: tcp://127.0.0.1:6787 64 | dist_backend: 'nccl' 65 | multiprocessing_distributed: True 66 | world_size: 1 67 | rank: 0 68 | 69 | 70 | TEST: 71 | split: val # split in [train, val and test] 72 | val_benchmark: True 73 | test_workers: 4 74 | test_gpu: [0,1,2,3] 75 | test_batch_size: 16 76 | model_path: 77 | save_folder: 78 | test_repeats: 7 -------------------------------------------------------------------------------- /config/scannet/semaffinet_5cm.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | data_name: scannet_cross 3 | data_root: ./data 4 | classes: 20 5 | aug: True 6 | voxelSize: 0.05 7 | 8 | MASKFORMER: 9 | hidden_dim: 128 10 | dropout: 0.1 11 | nheads: 8 12 | dim_feedforward: 1024 13 | enc_layers: 6 14 | dec_layers: 6 15 | num_point_batch: 150 16 | mask_dim: 128 17 | drop_rate: 0.1 18 | attn_drop_rate: 0.1 19 | drop_path: 0.1 20 | num_tokens_2d: 88 21 | 22 | TRAIN: 23 | viewNum: 5 24 | weight_2d: 0.1 25 | arch: semaffinet 26 | layers_2d: 34 27 | arch_3d: MinkUNet18A 28 | 29 | sync_bn_2d: True 30 | sync_bn_3d: True 31 | ignore_label: 255 32 | train_gpu: [0,1,2,3,4,5,6,7] 33 | workers: 16 # data loader workers 34 | batch_size: 16 # batch size for training 35 | batch_size_val: 16 # batch size for validation during training, memory and speed tradeoff 36 | base_lr: 0.02 37 | loop: 5 38 | epochs: 100 39 | start_epoch: 0 40 | power: 0.9 41 | momentum: 0.9 42 | weight_decay: 0.0001 43 | 44 | scheduler: SquaredLRWarmingUp 45 | step_size: 20000 46 | max_iter: 60000 47 | step_gamma: 0.1 48 | poly_power: 0.9 49 | exp_gamma: 0.95 50 | exp_step_size: 445 51 | 52 | manual_seed: 1463 53 | print_freq: 10 54 | save_freq: 1 55 | save_path: 56 | weight: # path to initial weight (default: none) 57 | resume: 58 | evaluate: True 59 | eval_freq: 1 60 | 61 | Distributed: 62 | dist_url: tcp://127.0.0.1:6787 63 | dist_backend: 'nccl' 64 | multiprocessing_distributed: True 65 | world_size: 1 66 | rank: 0 67 | 68 | 69 | TEST: 70 | split: val # 
split in [train, val and test] 71 | val_benchmark: True 72 | test_workers: 4 73 | test_gpu: [0,1,2,3] 74 | test_batch_size: 16 75 | model_path: 76 | save_folder: 77 | test_repeats: 7 -------------------------------------------------------------------------------- /dataset/augmentation.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import logging 4 | import numpy as np 5 | import scipy 6 | import scipy.ndimage 7 | import scipy.interpolate 8 | import torch 9 | 10 | 11 | # A sparse tensor consists of coordinates and associated features. 12 | # You must apply augmentation to both. 13 | # In 2D, flip, shear, scale, and rotation of images are coordinate transformation 14 | # color jitter, hue, etc., are feature transformations 15 | ############################## 16 | # Feature transformations 17 | ############################## 18 | class ChromaticTranslation(object): 19 | """Add random color to the image, input must be an array in [0,255] or a PIL image""" 20 | 21 | def __init__(self, trans_range_ratio=1e-1): 22 | """ 23 | trans_range_ratio: ratio of translation i.e. 255 * 2 * ratio * rand(-0.5, 0.5) 24 | """ 25 | self.trans_range_ratio = trans_range_ratio 26 | 27 | def __call__(self, coords, feats, labels): 28 | if random.random() < 0.95: 29 | tr = (np.random.rand(1, 3) - 0.5) * 255 * 2 * self.trans_range_ratio 30 | feats[:, :3] = np.clip(tr + feats[:, :3], 0, 255) 31 | return coords, feats, labels 32 | 33 | 34 | class ChromaticAutoContrast(object): 35 | 36 | def __init__(self, randomize_blend_factor=True, blend_factor=0.5): 37 | self.randomize_blend_factor = randomize_blend_factor 38 | self.blend_factor = blend_factor 39 | 40 | def __call__(self, coords, feats, labels): 41 | if random.random() < 0.2: 42 | # mean = np.mean(feats, 0, keepdims=True) 43 | # std = np.std(feats, 0, keepdims=True) 44 | # lo = mean - std 45 | # hi = mean + std 46 | lo = np.min(feats, 0, keepdims=True) 47 | hi = np.max(feats, 0, keepdims=True) 48 | 49 | scale = 255 / (hi - lo) 50 | 51 | contrast_feats = (feats - lo) * scale 52 | 53 | blend_factor = random.random() if self.randomize_blend_factor else self.blend_factor 54 | feats = (1 - blend_factor) * feats + blend_factor * contrast_feats 55 | return coords, feats, labels 56 | 57 | 58 | class ChromaticJitter(object): 59 | 60 | def __init__(self, std=0.01): 61 | self.std = std 62 | 63 | def __call__(self, coords, feats, labels): 64 | if random.random() < 0.95: 65 | noise = np.random.randn(feats.shape[0], 3) 66 | noise *= self.std * 255 67 | feats[:, :3] = np.clip(noise + feats[:, :3], 0, 255) 68 | return coords, feats, labels 69 | 70 | 71 | class HueSaturationTranslation(object): 72 | 73 | @staticmethod 74 | def rgb_to_hsv(rgb): 75 | # Translated from source of colorsys.rgb_to_hsv 76 | # r,g,b should be a numpy arrays with values between 0 and 255 77 | # rgb_to_hsv returns an array of floats between 0.0 and 1.0. 
78 | rgb = rgb.astype('float') 79 | hsv = np.zeros_like(rgb) 80 | # in case an RGBA array was passed, just copy the A channel 81 | hsv[..., 3:] = rgb[..., 3:] 82 | r, g, b = rgb[..., 0], rgb[..., 1], rgb[..., 2] 83 | maxc = np.max(rgb[..., :3], axis=-1) 84 | minc = np.min(rgb[..., :3], axis=-1) 85 | hsv[..., 2] = maxc 86 | mask = maxc != minc 87 | hsv[mask, 1] = (maxc - minc)[mask] / maxc[mask] 88 | rc = np.zeros_like(r) 89 | gc = np.zeros_like(g) 90 | bc = np.zeros_like(b) 91 | rc[mask] = (maxc - r)[mask] / (maxc - minc)[mask] 92 | gc[mask] = (maxc - g)[mask] / (maxc - minc)[mask] 93 | bc[mask] = (maxc - b)[mask] / (maxc - minc)[mask] 94 | hsv[..., 0] = np.select([r == maxc, g == maxc], [bc - gc, 2.0 + rc - bc], default=4.0 + gc - rc) 95 | hsv[..., 0] = (hsv[..., 0] / 6.0) % 1.0 96 | return hsv 97 | 98 | @staticmethod 99 | def hsv_to_rgb(hsv): 100 | # Translated from source of colorsys.hsv_to_rgb 101 | # h,s should be a numpy arrays with values between 0.0 and 1.0 102 | # v should be a numpy array with values between 0.0 and 255.0 103 | # hsv_to_rgb returns an array of uints between 0 and 255. 104 | rgb = np.empty_like(hsv) 105 | rgb[..., 3:] = hsv[..., 3:] 106 | h, s, v = hsv[..., 0], hsv[..., 1], hsv[..., 2] 107 | i = (h * 6.0).astype('uint8') 108 | f = (h * 6.0) - i 109 | p = v * (1.0 - s) 110 | q = v * (1.0 - s * f) 111 | t = v * (1.0 - s * (1.0 - f)) 112 | i = i % 6 113 | conditions = [s == 0.0, i == 1, i == 2, i == 3, i == 4, i == 5] 114 | rgb[..., 0] = np.select(conditions, [v, q, p, p, t, v], default=v) 115 | rgb[..., 1] = np.select(conditions, [v, v, v, q, p, p], default=t) 116 | rgb[..., 2] = np.select(conditions, [v, p, t, v, v, q], default=p) 117 | return rgb.astype('uint8') 118 | 119 | def __init__(self, hue_max, saturation_max): 120 | self.hue_max = hue_max 121 | self.saturation_max = saturation_max 122 | 123 | def __call__(self, coords, feats, labels): 124 | # Assume feat[:, :3] is rgb 125 | hsv = HueSaturationTranslation.rgb_to_hsv(feats[:, :3]) 126 | hue_val = (random.random() - 0.5) * 2 * self.hue_max 127 | sat_ratio = 1 + (random.random() - 0.5) * 2 * self.saturation_max 128 | hsv[..., 0] = np.remainder(hue_val + hsv[..., 0] + 1, 1) 129 | hsv[..., 1] = np.clip(sat_ratio * hsv[..., 1], 0, 1) 130 | feats[:, :3] = np.clip(HueSaturationTranslation.hsv_to_rgb(hsv), 0, 255) 131 | 132 | return coords, feats, labels 133 | 134 | 135 | ############################## 136 | # Coordinate transformations 137 | ############################## 138 | class RandomHorizontalFlip(object): 139 | 140 | def __init__(self, upright_axis, is_temporal): 141 | """ 142 | upright_axis: axis index among x,y,z, i.e. 2 for z 143 | """ 144 | self.is_temporal = is_temporal 145 | self.D = 4 if is_temporal else 3 146 | self.upright_axis = {'x': 0, 'y': 1, 'z': 2}[upright_axis.lower()] 147 | # Use the rest of axes for flipping. 148 | self.horz_axes = set(range(self.D)) - set([self.upright_axis]) 149 | 150 | def __call__(self, coords, feats, labels): 151 | if random.random() < 0.95: 152 | for curr_ax in self.horz_axes: 153 | if random.random() < 0.5: 154 | coord_max = np.max(coords[:, curr_ax]) 155 | coords[:, curr_ax] = coord_max - coords[:, curr_ax] 156 | return coords, feats, labels 157 | 158 | 159 | class ElasticDistortion: 160 | 161 | def __init__(self, distortion_params): 162 | self.distortion_params = distortion_params 163 | 164 | def elastic_distortion(self, coords, granularity, magnitude): 165 | """Apply elastic distortion on sparse coordinate space. 
166 | 167 | pointcloud: numpy array of (number of points, at least 3 spatial dims) 168 | granularity: size of the noise grid (in same scale[m/cm] as the voxel grid) 169 | magnitude: noise multiplier 170 | """ 171 | blurx = np.ones((3, 1, 1, 1)).astype('float32') / 3 172 | blury = np.ones((1, 3, 1, 1)).astype('float32') / 3 173 | blurz = np.ones((1, 1, 3, 1)).astype('float32') / 3 174 | coords_min = coords.min(0) 175 | 176 | # Create Gaussian noise tensor of the size given by granularity. 177 | noise_dim = ((coords - coords_min).max(0) // granularity).astype(int) + 3 178 | noise = np.random.randn(*noise_dim, 3).astype(np.float32) 179 | 180 | # Smoothing. 181 | for _ in range(2): 182 | noise = scipy.ndimage.filters.convolve(noise, blurx, mode='constant', cval=0) 183 | noise = scipy.ndimage.filters.convolve(noise, blury, mode='constant', cval=0) 184 | noise = scipy.ndimage.filters.convolve(noise, blurz, mode='constant', cval=0) 185 | 186 | # Trilinear interpolate noise filters for each spatial dimensions. 187 | ax = [ 188 | np.linspace(d_min, d_max, d) 189 | for d_min, d_max, d in zip(coords_min - granularity, coords_min + 190 | granularity * (noise_dim - 2), noise_dim) 191 | ] 192 | interp = scipy.interpolate.RegularGridInterpolator(ax, noise, bounds_error=0, fill_value=0) 193 | coords = coords + interp(coords) * magnitude 194 | return coords 195 | 196 | def __call__(self, pointcloud): 197 | if self.distortion_params is not None: 198 | if random.random() < 0.95: 199 | for granularity, magnitude in self.distortion_params: 200 | pointcloud = self.elastic_distortion(pointcloud, granularity, magnitude) 201 | return pointcloud 202 | 203 | 204 | class Compose(object): 205 | """Composes several transforms together.""" 206 | 207 | def __init__(self, transforms): 208 | self.transforms = transforms 209 | 210 | def __call__(self, *args): 211 | for t in self.transforms: 212 | args = t(*args) 213 | return args 214 | 215 | 216 | class cfl_collate_fn_factory: 217 | """Generates collate function for coords, feats, labels. 218 | 219 | Args: 220 | limit_numpoints: If 0 or False, does not alter batch size. If positive integer, limits batch 221 | size so that the number of input coordinates is below limit_numpoints. 222 | """ 223 | 224 | def __init__(self, limit_numpoints): 225 | self.limit_numpoints = limit_numpoints 226 | 227 | def __call__(self, list_data): 228 | coords, feats, labels = list(zip(*list_data)) 229 | coords_batch, feats_batch, labels_batch = [], [], [] 230 | 231 | batch_id = 0 232 | batch_num_points = 0 233 | for batch_id, _ in enumerate(coords): 234 | num_points = coords[batch_id].shape[0] 235 | batch_num_points += num_points 236 | if self.limit_numpoints and batch_num_points > self.limit_numpoints: 237 | num_full_points = sum(len(c) for c in coords) 238 | num_full_batch_size = len(coords) 239 | logging.warning( 240 | f'\t\tCannot fit {num_full_points} points into {self.limit_numpoints} points ' 241 | f'limit. Truncating batch size at {batch_id} out of {num_full_batch_size} with {batch_num_points - num_points}.' 
242 | ) 243 | break 244 | coords_batch.append( 245 | torch.cat((torch.from_numpy(coords[batch_id]).int(), 246 | torch.ones(num_points, 1).int() * batch_id), 1)) 247 | feats_batch.append(torch.from_numpy(feats[batch_id])) 248 | labels_batch.append(torch.from_numpy(labels[batch_id]).int()) 249 | 250 | batch_id += 1 251 | 252 | # Concatenate all lists 253 | coords_batch = torch.cat(coords_batch, 0).int() 254 | feats_batch = torch.cat(feats_batch, 0).float() 255 | labels_batch = torch.cat(labels_batch, 0).int() 256 | return coords_batch, feats_batch, labels_batch 257 | 258 | 259 | class cflt_collate_fn_factory: 260 | """Generates collate function for coords, feats, labels, point_clouds, transformations. 261 | 262 | Args: 263 | limit_numpoints: If 0 or False, does not alter batch size. If positive integer, limits batch 264 | size so that the number of input coordinates is below limit_numpoints. 265 | """ 266 | 267 | def __init__(self, limit_numpoints): 268 | self.limit_numpoints = limit_numpoints 269 | 270 | def __call__(self, list_data): 271 | coords, feats, labels, pointclouds, transformations = list(zip(*list_data)) 272 | cfl_collate_fn = cfl_collate_fn_factory(limit_numpoints=self.limit_numpoints) 273 | coords_batch, feats_batch, labels_batch = cfl_collate_fn(list(zip(coords, feats, labels))) 274 | num_truncated_batch = coords_batch[:, -1].max().item() + 1 275 | 276 | batch_id = 0 277 | pointclouds_batch, transformations_batch = [], [] 278 | for pointcloud, transformation in zip(pointclouds, transformations): 279 | if batch_id >= num_truncated_batch: 280 | break 281 | pointclouds_batch.append( 282 | torch.cat((torch.from_numpy(pointcloud), torch.ones(pointcloud.shape[0], 1) * batch_id), 283 | 1)) 284 | transformations_batch.append( 285 | torch.cat( 286 | (torch.from_numpy(transformation), torch.ones(transformation.shape[0], 1) * batch_id), 287 | 1)) 288 | batch_id += 1 289 | 290 | pointclouds_batch = torch.cat(pointclouds_batch, 0).float() 291 | transformations_batch = torch.cat(transformations_batch, 0).float() 292 | return coords_batch, feats_batch, labels_batch, pointclouds_batch, transformations_batch 293 | -------------------------------------------------------------------------------- /dataset/augmentation_2d.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | import numpy as np 4 | import numbers 5 | import collections 6 | import cv2 7 | 8 | import torch 9 | 10 | 11 | class Compose(object): 12 | # Composes segtransforms: segtransform.Compose([segtransform.RandScale([0.5, 2.0]), segtransform.ToTensor()]) 13 | def __init__(self, segtransform): 14 | self.segtransform = segtransform 15 | 16 | def __call__(self, image, label): 17 | for t in self.segtransform: 18 | image, label = t(image, label) 19 | return image, label 20 | 21 | 22 | class ToTensor(object): 23 | # Converts numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W). 
24 | def __call__(self, image, label): 25 | if not isinstance(image, np.ndarray) or not isinstance(label, np.ndarray): 26 | raise (RuntimeError("segtransform.ToTensor() only handle np.ndarray" 27 | "[eg: data readed by cv2.imread()].\n")) 28 | if len(image.shape) > 3 or len(image.shape) < 2: 29 | raise (RuntimeError("segtransform.ToTensor() only handle np.ndarray with 3 dims or 2 dims.\n")) 30 | if len(image.shape) == 2: 31 | image = np.expand_dims(image, axis=2) 32 | if not len(label.shape) == 2: 33 | raise (RuntimeError("segtransform.ToTensor() only handle np.ndarray labellabel with 2 dims.\n")) 34 | 35 | image = torch.from_numpy(image.transpose((2, 0, 1))) 36 | if not isinstance(image, torch.FloatTensor): 37 | image = image.float() 38 | label = torch.from_numpy(label) 39 | if not isinstance(label, torch.LongTensor): 40 | label = label.long() 41 | return image, label 42 | 43 | 44 | class Normalize(object): 45 | # Normalize tensor with mean and standard deviation along channel: channel = (channel - mean) / std 46 | def __init__(self, mean, std=None): 47 | if std is None: 48 | assert len(mean) > 0 49 | else: 50 | assert len(mean) == len(std) 51 | self.mean = mean 52 | self.std = std 53 | 54 | def __call__(self, image, label): 55 | if self.std is None: 56 | for t, m in zip(image, self.mean): 57 | t.sub_(m) 58 | else: 59 | for t, m, s in zip(image, self.mean, self.std): 60 | t.sub_(m).div_(s) 61 | return image, label 62 | 63 | 64 | class Resize(object): 65 | # Resize the input to the given size, 'size' is a 2-element tuple or list in the order of (h, w). 66 | def __init__(self, size): 67 | assert (isinstance(size, collections.Iterable) and len(size) == 2) 68 | self.size = size 69 | 70 | def __call__(self, image, label): 71 | image = cv2.resize(image, self.size[::-1], interpolation=cv2.INTER_LINEAR) 72 | label = cv2.resize(label, self.size[::-1], interpolation=cv2.INTER_NEAREST) 73 | return image, label 74 | 75 | 76 | class RandScale(object): 77 | # Randomly resize image & label with scale factor in [scale_min, scale_max] 78 | def __init__(self, scale, aspect_ratio=None): 79 | assert (isinstance(scale, collections.Iterable) and len(scale) == 2) 80 | if isinstance(scale, collections.Iterable) and len(scale) == 2 \ 81 | and isinstance(scale[0], numbers.Number) and isinstance(scale[1], numbers.Number) \ 82 | and 0 < scale[0] < scale[1]: 83 | self.scale = scale 84 | else: 85 | raise (RuntimeError("segtransform.RandScale() scale param error.\n")) 86 | if aspect_ratio is None: 87 | self.aspect_ratio = aspect_ratio 88 | elif isinstance(aspect_ratio, collections.Iterable) and len(aspect_ratio) == 2 \ 89 | and isinstance(aspect_ratio[0], numbers.Number) and isinstance(aspect_ratio[1], numbers.Number) \ 90 | and 0 < aspect_ratio[0] < aspect_ratio[1]: 91 | self.aspect_ratio = aspect_ratio 92 | else: 93 | raise (RuntimeError("segtransform.RandScale() aspect_ratio param error.\n")) 94 | 95 | def __call__(self, image, label): 96 | temp_scale = self.scale[0] + (self.scale[1] - self.scale[0]) * random.random() 97 | temp_aspect_ratio = 1.0 98 | if self.aspect_ratio is not None: 99 | temp_aspect_ratio = self.aspect_ratio[0] + (self.aspect_ratio[1] - self.aspect_ratio[0]) * random.random() 100 | temp_aspect_ratio = math.sqrt(temp_aspect_ratio) 101 | scale_factor_x = temp_scale * temp_aspect_ratio 102 | scale_factor_y = temp_scale / temp_aspect_ratio 103 | image = cv2.resize(image, None, fx=scale_factor_x, fy=scale_factor_y, interpolation=cv2.INTER_LINEAR) 104 | label = cv2.resize(label, None, fx=scale_factor_x, 
fy=scale_factor_y, interpolation=cv2.INTER_NEAREST) 105 | return image, label 106 | 107 | 108 | class Crop(object): 109 | """Crops the given ndarray image (H*W*C or H*W). 110 | Args: 111 | size (sequence or int): Desired output size of the crop. If size is an 112 | int instead of sequence like (h, w), a square crop (size, size) is made. 113 | """ 114 | 115 | def __init__(self, size, crop_type='center', padding=None, ignore_label=255): 116 | if isinstance(size, int): 117 | self.crop_h = size 118 | self.crop_w = size 119 | elif isinstance(size, collections.Iterable) and len(size) == 2 \ 120 | and isinstance(size[0], int) and isinstance(size[1], int) \ 121 | and size[0] > 0 and size[1] > 0: 122 | self.crop_h = size[0] 123 | self.crop_w = size[1] 124 | else: 125 | raise (RuntimeError("crop size error.\n")) 126 | if crop_type == 'center' or crop_type == 'rand': 127 | self.crop_type = crop_type 128 | else: 129 | raise (RuntimeError("crop type error: rand | center\n")) 130 | if padding is None: 131 | self.padding = padding 132 | elif isinstance(padding, list): 133 | if all(isinstance(i, numbers.Number) for i in padding): 134 | self.padding = padding 135 | else: 136 | raise (RuntimeError("padding in Crop() should be a number list\n")) 137 | if len(padding) != 3: 138 | raise (RuntimeError("padding channel is not equal with 3\n")) 139 | else: 140 | raise (RuntimeError("padding in Crop() should be a number list\n")) 141 | if isinstance(ignore_label, int): 142 | self.ignore_label = ignore_label 143 | else: 144 | raise (RuntimeError("ignore_label should be an integer number\n")) 145 | 146 | def __call__(self, image, label): 147 | h, w = label.shape 148 | pad_h = max(self.crop_h - h, 0) 149 | pad_w = max(self.crop_w - w, 0) 150 | pad_h_half = int(pad_h / 2) 151 | pad_w_half = int(pad_w / 2) 152 | if pad_h > 0 or pad_w > 0: 153 | if self.padding is None: 154 | raise (RuntimeError("segtransform.Crop() need padding while padding argument is None\n")) 155 | image = cv2.copyMakeBorder(image, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, 156 | cv2.BORDER_CONSTANT, value=self.padding) 157 | label = cv2.copyMakeBorder(label, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, 158 | cv2.BORDER_CONSTANT, value=self.ignore_label) 159 | h, w = label.shape 160 | if self.crop_type == 'rand': 161 | h_off = random.randint(0, h - self.crop_h) 162 | w_off = random.randint(0, w - self.crop_w) 163 | else: 164 | h_off = int((h - self.crop_h) / 2) 165 | w_off = int((w - self.crop_w) / 2) 166 | image = image[h_off:h_off + self.crop_h, w_off:w_off + self.crop_w] 167 | label = label[h_off:h_off + self.crop_h, w_off:w_off + self.crop_w] 168 | return image, label 169 | 170 | 171 | class RandRotate(object): 172 | # Randomly rotate image & label with rotate factor in [rotate_min, rotate_max] 173 | def __init__(self, rotate, padding, ignore_label=255, p=0.5): 174 | assert (isinstance(rotate, collections.Iterable) and len(rotate) == 2) 175 | if isinstance(rotate[0], numbers.Number) and isinstance(rotate[1], numbers.Number) and rotate[0] < rotate[1]: 176 | self.rotate = rotate 177 | else: 178 | raise (RuntimeError("segtransform.RandRotate() scale param error.\n")) 179 | assert padding is not None 180 | assert isinstance(padding, list) and len(padding) == 3 181 | if all(isinstance(i, numbers.Number) for i in padding): 182 | self.padding = padding 183 | else: 184 | raise (RuntimeError("padding in RandRotate() should be a number list\n")) 185 | assert isinstance(ignore_label, int) 186 | self.ignore_label = 
ignore_label 187 | self.p = p 188 | 189 | def __call__(self, image, label): 190 | if random.random() < self.p: 191 | angle = self.rotate[0] + (self.rotate[1] - self.rotate[0]) * random.random() 192 | h, w = label.shape 193 | matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1) 194 | image = cv2.warpAffine(image, matrix, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, 195 | borderValue=self.padding) 196 | label = cv2.warpAffine(label, matrix, (w, h), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, 197 | borderValue=self.ignore_label) 198 | return image, label 199 | 200 | 201 | class RandomHorizontalFlip(object): 202 | def __init__(self, p=0.5): 203 | self.p = p 204 | 205 | def __call__(self, image, label): 206 | if random.random() < self.p: 207 | image = cv2.flip(image, 1) 208 | label = cv2.flip(label, 1) 209 | return image, label 210 | 211 | 212 | class RandomVerticalFlip(object): 213 | def __init__(self, p=0.5): 214 | self.p = p 215 | 216 | def __call__(self, image, label): 217 | if random.random() < self.p: 218 | image = cv2.flip(image, 0) 219 | label = cv2.flip(label, 0) 220 | return image, label 221 | 222 | 223 | class RandomGaussianBlur(object): 224 | def __init__(self, radius=5): 225 | self.radius = radius 226 | 227 | def __call__(self, image, label): 228 | if random.random() < 0.5: 229 | image = cv2.GaussianBlur(image, (self.radius, self.radius), 0) 230 | return image, label 231 | 232 | 233 | class RGB2BGR(object): 234 | # Converts image from RGB order to BGR order, for model initialized from Caffe 235 | def __call__(self, image, label): 236 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 237 | return image, label 238 | 239 | 240 | class BGR2RGB(object): 241 | # Converts image from BGR order to RGB order, for model initialized from Pytorch 242 | def __call__(self, image, label): 243 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 244 | return image, label 245 | -------------------------------------------------------------------------------- /dataset/pregroup_2d_scannet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import imageio 4 | import numpy as np 5 | from glob import glob 6 | from tqdm import tqdm 7 | 8 | from pointnet2_ops import pointnet2_utils 9 | from dataset.scanNetCross import LinkCreator 10 | 11 | 12 | def main(root, split): 13 | linkCreator = LinkCreator(image_dim=(320, 240), voxelSize=0.05) 14 | data3d_paths = sorted(glob(os.path.join(root, '3D', split, '*.pth'))) 15 | data2d_paths = [] 16 | room_groups = {} 17 | for x in data3d_paths: 18 | ps = glob(os.path.join(x[:-15].replace(split, '2D'), 'color', '*.jpg')) 19 | ps.sort(key=lambda x: int(x.split('/')[-1].split('.')[0])) 20 | data2d_paths.append(ps) 21 | for room_id in tqdm(range(len(data3d_paths))): 22 | coords, _, _ = torch.load(data3d_paths[room_id]) 23 | quartile = (coords[:, -1].max() - coords[:, -1].min()) * 0.1 24 | qlower = coords[:, -1].min() + quartile 25 | qupper = coords[:, -1].max() - quartile 26 | coords_room = torch.from_numpy(coords[(coords[:, -1] > qlower) & (coords[:, -1] < qupper)]).unsqueeze(dim=0).cuda() 27 | fps_idx = pointnet2_utils.furthest_point_sample(coords_room, 5) 28 | fps_coords = pointnet2_utils.gather_operation(coords_room.transpose(1, 2).contiguous(), fps_idx).transpose(1, 2).contiguous().cpu() 29 | 30 | frames_path = data2d_paths[room_id] 31 | range_frames = [] 32 | centroid_frames = [] 33 | delete_frames = [] 34 | for f in frames_path: 35 | depth = 
imageio.imread(f.replace('color', 'depth').replace('jpg', 'png')) / 1000.0
36 |             posePath = f.replace('color', 'pose').replace('.jpg', '.txt')
37 |             pose = np.asarray(
38 |                 [[float(x[0]), float(x[1]), float(x[2]), float(x[3])] for x in
39 |                  (x.split(" ") for x in open(posePath).read().splitlines())]
40 |             )
41 |             H, W = depth.shape
42 |             link = linkCreator.computeLinking(pose, coords, depth)
43 |             coords_link = coords[link[:, 2] == 1]
44 |             if len(coords_link) < 10:
45 |                 delete_frames.append(f.split('/')[-1])
46 |                 continue
47 |             mu = coords_link.mean(axis=0, keepdims=True)
48 |             sigma = coords_link.std(axis=0, keepdims=True)
49 |             coords_link = coords_link[(coords_link < mu + 3 * sigma).all(axis=1) & (coords_link > mu - 3 * sigma).all(axis=1)]
50 |             centroid = coords_link.mean(axis=0)
51 |             centroid_frames.append(centroid)
52 |             range_min = coords_link.min(axis=0)
53 |             range_max = coords_link.max(axis=0)
54 |             range_xyz = np.stack([range_min, range_max], axis=0)
55 |             range_frames.append(range_xyz)
56 |         range_frames = torch.from_numpy(np.stack(range_frames, axis=0))
57 |         centroid_frames = torch.from_numpy(np.stack(centroid_frames, axis=0))
58 |         distance = torch.norm(fps_coords - centroid_frames.unsqueeze(dim=1), p=2, dim=-1)
59 |         group_nearest = distance.argmin(dim=1)
60 |         room_name = data3d_paths[room_id].split('/')[-1][:12]
61 |         room_groups[room_name] = {}
62 |         room_groups[room_name]['group_centroid'] = fps_coords
63 |         room_groups[room_name]['frames_group'] = group_nearest
64 |         room_groups[room_name]['frames_range'] = range_frames
65 |         room_groups[room_name]['frames_centroid'] = centroid_frames
66 |         room_groups[room_name]['frames_delete'] = delete_frames
67 |     torch.save(room_groups, 'data/view_groups/view_groups_' + split + '.pth')
68 | 
69 | def get_stats(group):
70 |     return {'coords_z': group.sum()}
71 | 
72 | 
73 | if __name__ == "__main__":
74 |     os.environ['CUDA_VISIBLE_DEVICES'] = '1'
75 |     root = 'data'
76 |     split = 'val'
77 |     main(root, split)
--------------------------------------------------------------------------------
/dataset/preprocess_3d_scannet.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import multiprocessing as mp
3 | import numpy as np
4 | import plyfile
5 | import torch
6 | 
7 | # Map relevant classes to {0,1,...,19}, and ignored classes to 255
8 | remapper = np.ones(150) * (255)
9 | for i, x in enumerate([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39]):
10 |     remapper[x] = i
11 | 
12 | 
13 | def f(fn):
14 |     fn2 = fn[:-3] + 'labels.ply'
15 |     a = plyfile.PlyData().read(fn)
16 |     v = np.array([list(x) for x in a.elements[0]])
17 |     coords = np.ascontiguousarray(v[:, :3])
18 |     colors = np.ascontiguousarray(v[:, 3:6]) / 127.5 - 1
19 |     a = plyfile.PlyData().read(fn2)
20 |     w = remapper[np.array(a.elements[0]['label'])]
21 |     torch.save((coords, colors, w), fn[:-4] + '.pth')
22 |     print(fn, fn2)
23 | 
24 | 
25 | files = sorted(glob.glob('PATH_OF_TRAIN/*_vh_clean_2.ply'))
26 | files2 = sorted(glob.glob('PATH_OF_TRAIN/*_vh_clean_2.labels.ply'))
27 | assert len(files) == len(files2)
28 | 
29 | p = mp.Pool(processes=mp.cpu_count())
30 | p.map(f, files)
31 | p.close()
32 | p.join()
33 | 
34 | files = sorted(glob.glob('PATH_OF_VAL/*_vh_clean_2.ply'))
35 | files2 = sorted(glob.glob('PATH_OF_VAL/*_vh_clean_2.labels.ply'))
36 | assert len(files) == len(files2)
37 | 
38 | p = mp.Pool(processes=mp.cpu_count())
39 | p.map(f, files)
40 | p.close()
41 | p.join()
42 | 
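For reference, the grouping file written by `dataset/pregroup_2d_scannet.py` above (and later read by `dataset/scanNetCross.py`) maps each 12-character room name to a small dictionary. A minimal, illustrative way to inspect it (not one of the repository scripts; shapes follow the dictionary built in `main()`):
```
import torch

# One entry per room, keyed by the 12-character scene name.
view_groups = torch.load('data/view_groups/view_groups_val.pth')
room_name, groups = next(iter(view_groups.items()))
print(room_name)
print(groups['group_centroid'].shape)  # (1, 5, 3): five FPS centroids of the room
print(groups['frames_group'].shape)    # (F,): group id assigned to each kept 2D frame
print(groups['frames_range'].shape)    # (F, 2, 3): per-frame min/max of the linked 3D points
print(len(groups['frames_delete']))    # frames skipped for having too few linked points
```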
-------------------------------------------------------------------------------- /dataset/scanNet3D.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import torch 3 | import numpy as np 4 | from os.path import join, exists 5 | from glob import glob 6 | import multiprocessing as mp 7 | import SharedArray as SA 8 | import dataset.augmentation as t 9 | from dataset.voxelizer import Voxelizer 10 | 11 | 12 | def sa_create(name, var): 13 | x = SA.create(name, var.shape, dtype=var.dtype) 14 | x[...] = var[...] 15 | x.flags.writeable = False 16 | return x 17 | 18 | 19 | def collation_fn(batch): 20 | """ 21 | :param batch: 22 | :return: coords_batch: N x 4 (x,y,z,batch) 23 | 24 | """ 25 | coords, feats, labels = list(zip(*batch)) 26 | 27 | for i in range(len(coords)): 28 | coords[i][:, 0] *= i 29 | 30 | return torch.cat(coords), torch.cat(feats), torch.cat(labels) 31 | 32 | 33 | def collation_fn_eval_all(batch): 34 | """ 35 | :param batch: 36 | :return: coords_batch: N x 4 (x,y,z,batch) 37 | 38 | """ 39 | coords, feats, labels, inds_recons = list(zip(*batch)) 40 | inds_recons = list(inds_recons) 41 | # pdb.set_trace() 42 | 43 | accmulate_points_num = 0 44 | for i in range(len(coords)): 45 | coords[i][:, 0] *= i 46 | inds_recons[i] = accmulate_points_num + inds_recons[i] 47 | accmulate_points_num += coords[i].shape[0] 48 | 49 | return torch.cat(coords), torch.cat(feats), torch.cat(labels), torch.cat(inds_recons) 50 | 51 | 52 | class ScanNet3D(data.Dataset): 53 | # Augmentation arguments 54 | SCALE_AUGMENTATION_BOUND = (0.9, 1.1) 55 | ROTATION_AUGMENTATION_BOUND = ((-np.pi / 64, np.pi / 64), (-np.pi / 64, np.pi / 64), (-np.pi, 56 | np.pi)) 57 | TRANSLATION_AUGMENTATION_RATIO_BOUND = ((-0.2, 0.2), (-0.2, 0.2), (0, 0)) 58 | ELASTIC_DISTORT_PARAMS = ((0.2, 0.4), (0.8, 1.6)) 59 | 60 | ROTATION_AXIS = 'z' 61 | LOCFEAT_IDX = 2 62 | 63 | def __init__(self, dataPathPrefix='Data', voxelSize=0.05, 64 | split='train', aug=False, memCacheInit=False, identifier=1247, loop=1, 65 | data_aug_color_trans_ratio=0.1, data_aug_color_jitter_std=0.05, data_aug_hue_max=0.5, 66 | data_aug_saturation_max=0.2, eval_all=False 67 | ): 68 | super(ScanNet3D, self).__init__() 69 | self.split = split 70 | self.identifier = identifier 71 | self.data_paths = sorted(glob(join(dataPathPrefix, '3D', split, '*.pth'))) 72 | self.voxelSize = voxelSize 73 | self.aug = aug 74 | self.loop = loop 75 | self.eval_all = eval_all 76 | 77 | self.voxelizer = Voxelizer( 78 | voxel_size=voxelSize, 79 | clip_bound=None, 80 | use_augmentation=True, 81 | scale_augmentation_bound=self.SCALE_AUGMENTATION_BOUND, 82 | rotation_augmentation_bound=self.ROTATION_AUGMENTATION_BOUND, 83 | translation_augmentation_ratio_bound=self.TRANSLATION_AUGMENTATION_RATIO_BOUND) 84 | 85 | if aug: 86 | prevoxel_transform_train = [t.ElasticDistortion(self.ELASTIC_DISTORT_PARAMS)] 87 | self.prevoxel_transforms = t.Compose(prevoxel_transform_train) 88 | input_transforms = [ 89 | t.RandomHorizontalFlip(self.ROTATION_AXIS, is_temporal=False), 90 | t.ChromaticAutoContrast(), 91 | t.ChromaticTranslation(data_aug_color_trans_ratio), 92 | t.ChromaticJitter(data_aug_color_jitter_std), 93 | t.HueSaturationTranslation(data_aug_hue_max, data_aug_saturation_max), 94 | ] 95 | self.input_transforms = t.Compose(input_transforms) 96 | 97 | if memCacheInit and (not exists("/dev/shm/wbhu_scannet_3d_%s_%06d_locs_%08d" % (split, identifier, 0))): 98 | print('[*] Starting shared memory init ...') 99 | for i, (locs, feats, 
labels) in enumerate(torch.utils.data.DataLoader( 100 | self.data_paths, collate_fn=lambda x: torch.load(x[0]), 101 | num_workers=min(8, mp.cpu_count()), shuffle=False)): 102 | labels[labels == -100] = 255 103 | labels = labels.astype(np.uint8) 104 | # Scale color to 0-255 105 | feats = (feats + 1.) * 127.5 106 | sa_create("shm://wbhu_scannet_3d_%s_%06d_locs_%08d" % (split, identifier, i), locs) 107 | sa_create("shm://wbhu_scannet_3d_%s_%06d_feats_%08d" % (split, identifier, i), feats) 108 | sa_create("shm://wbhu_scannet_3d_%s_%06d_labels_%08d" % (split, identifier, i), labels) 109 | 110 | print('[*] %s (%s) loading done (%d)! ' % (dataPathPrefix, split, len(self.data_paths))) 111 | 112 | def __getitem__(self, index_long): 113 | index = index_long % len(self.data_paths) 114 | locs_in = SA.attach("shm://wbhu_scannet_3d_%s_%06d_locs_%08d" % (self.split, self.identifier, index)).copy() 115 | feats_in = SA.attach("shm://wbhu_scannet_3d_%s_%06d_feats_%08d" % (self.split, self.identifier, index)).copy() 116 | labels_in = SA.attach("shm://wbhu_scannet_3d_%s_%06d_labels_%08d" % (self.split, self.identifier, index)).copy() 117 | 118 | locs = self.prevoxel_transforms(locs_in) if self.aug else locs_in 119 | locs, feats, labels, inds_reconstruct = self.voxelizer.voxelize(locs, feats_in, labels_in) 120 | if self.eval_all: 121 | labels = labels_in 122 | if self.aug: 123 | locs, feats, labels = self.input_transforms(locs, feats, labels) 124 | coords = torch.from_numpy(locs).int() 125 | coords = torch.cat((torch.ones(coords.shape[0], 1, dtype=torch.int), coords), dim=1) 126 | feats = torch.from_numpy(feats).float() / 127.5 - 1. 127 | labels = torch.from_numpy(labels).long() 128 | 129 | if self.eval_all: 130 | return coords, feats, labels, torch.from_numpy(inds_reconstruct).long() 131 | return coords, feats, labels 132 | 133 | def __len__(self): 134 | return len(self.data_paths) * self.loop 135 | 136 | 137 | if __name__ == '__main__': 138 | import time, random 139 | from tensorboardX import SummaryWriter 140 | 141 | data_root = '/research/dept6/wbhu/Dataset/ScanNet' 142 | train_data = ScanNet3D(dataPathPrefix=data_root, aug=True, split='train', memCacheInit=True, voxelSize=0.05) 143 | val_data = ScanNet3D(dataPathPrefix=data_root, aug=False, split='val', memCacheInit=True, voxelSize=0.05, 144 | eval_all=True) 145 | 146 | manual_seed = 123 147 | 148 | 149 | def worker_init_fn(worker_id): 150 | random.seed(manual_seed + worker_id) 151 | 152 | 153 | random.seed(manual_seed) 154 | np.random.seed(manual_seed) 155 | torch.manual_seed(manual_seed) 156 | torch.cuda.manual_seed_all(manual_seed) 157 | 158 | train_loader = torch.utils.data.DataLoader(train_data, batch_size=16, shuffle=True, num_workers=4, pin_memory=True, 159 | worker_init_fn=worker_init_fn, collate_fn=collation_fn) 160 | val_loader = torch.utils.data.DataLoader(val_data, batch_size=16, shuffle=False, num_workers=4, pin_memory=True, 161 | worker_init_fn=worker_init_fn, collate_fn=collation_fn_eval_all) 162 | trainLog = SummaryWriter('Exp/scannet/statistic/train') 163 | valLog = SummaryWriter('Exp/scannet/statistic/val') 164 | 165 | for idx in range(1): 166 | end = time.time() 167 | for step, (coords, feat, label) in enumerate(train_loader): 168 | print( 169 | 'time: {}/{}--{}'.format(step + 1, len(train_loader), time.time() - end)) 170 | trainLog.add_histogram('voxel_coord_x', coords[:, 0], global_step=step) 171 | trainLog.add_histogram('voxel_coord_y', coords[:, 1], global_step=step) 172 | trainLog.add_histogram('voxel_coord_z', coords[:, 2], 
global_step=step) 173 | trainLog.add_histogram('color', feat, global_step=step) 174 | # time.sleep(0.3) 175 | end = time.time() 176 | 177 | for step, (coords, feat, label, inds_reverse) in enumerate(val_loader): 178 | print( 179 | 'time: {}/{}--{}'.format(step + 1, len(val_loader), time.time() - end)) 180 | valLog.add_histogram('voxel_coord_x', coords[:, 0], global_step=step) 181 | valLog.add_histogram('voxel_coord_y', coords[:, 1], global_step=step) 182 | valLog.add_histogram('voxel_coord_z', coords[:, 2], global_step=step) 183 | valLog.add_histogram('color', feat, global_step=step) 184 | # time.sleep(0.3) 185 | end = time.time() 186 | 187 | trainLog.close() 188 | valLog.close() 189 | -------------------------------------------------------------------------------- /dataset/scanNetCross.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import random 4 | import torch 5 | import numpy as np 6 | from os.path import join 7 | from itertools import compress 8 | from glob import glob 9 | import SharedArray as SA 10 | import imageio 11 | 12 | import dataset.augmentation_2d as t_2d 13 | from dataset.scanNet3D import ScanNet3D 14 | 15 | 16 | # create camera intrinsics 17 | def make_intrinsic(fx, fy, mx, my): 18 | intrinsic = np.eye(4) 19 | intrinsic[0][0] = fx 20 | intrinsic[1][1] = fy 21 | intrinsic[0][2] = mx 22 | intrinsic[1][2] = my 23 | return intrinsic 24 | 25 | 26 | # create camera intrinsics 27 | def adjust_intrinsic(intrinsic, intrinsic_image_dim, image_dim): 28 | if intrinsic_image_dim == image_dim: 29 | return intrinsic 30 | resize_width = int(math.floor(image_dim[1] * float(intrinsic_image_dim[0]) / float(intrinsic_image_dim[1]))) 31 | intrinsic[0, 0] *= float(resize_width) / float(intrinsic_image_dim[0]) 32 | intrinsic[1, 1] *= float(image_dim[1]) / float(intrinsic_image_dim[1]) 33 | # account for cropping here 34 | intrinsic[0, 2] *= float(image_dim[0] - 1) / float(intrinsic_image_dim[0] - 1) 35 | intrinsic[1, 2] *= float(image_dim[1] - 1) / float(intrinsic_image_dim[1] - 1) 36 | return intrinsic 37 | 38 | 39 | class LinkCreator(object): 40 | def __init__(self, fx=577.870605, fy=577.870605, mx=319.5, my=239.5, image_dim=(320, 240), voxelSize=0.05): 41 | self.intricsic = make_intrinsic(fx=fx, fy=fy, mx=mx, my=my) 42 | self.intricsic = adjust_intrinsic(self.intricsic, intrinsic_image_dim=[640, 480], image_dim=image_dim) 43 | self.imageDim = image_dim 44 | self.voxel_size = voxelSize 45 | 46 | def computeLinking(self, camera_to_world, coords, depth): 47 | """ 48 | :param camera_to_world: 4 x 4 49 | :param coords: N x 3 format 50 | :param depth: H x W format 51 | :return: linking, N x 3 format, (H,W,mask) 52 | """ 53 | link = np.zeros((3, coords.shape[0]), dtype=np.int) 54 | coordsNew = np.concatenate([coords, np.ones([coords.shape[0], 1])], axis=1).T 55 | assert coordsNew.shape[0] == 4, "[!] 
Shape error" 56 | 57 | world_to_camera = np.linalg.inv(camera_to_world) 58 | p = np.matmul(world_to_camera, coordsNew) 59 | p[0] = (p[0] * self.intricsic[0][0]) / p[2] + self.intricsic[0][2] 60 | p[1] = (p[1] * self.intricsic[1][1]) / p[2] + self.intricsic[1][2] 61 | pi = np.round(p).astype(np.int) 62 | inside_mask = (pi[0] >= 0) * (pi[1] >= 0) \ 63 | * (pi[0] < self.imageDim[0]) * (pi[1] < self.imageDim[1]) 64 | occlusion_mask = np.abs(depth[pi[1][inside_mask], pi[0][inside_mask]] 65 | - p[2][inside_mask]) <= self.voxel_size 66 | inside_mask[inside_mask == True] = occlusion_mask 67 | link[0][inside_mask] = pi[1][inside_mask] 68 | link[1][inside_mask] = pi[0][inside_mask] 69 | link[2][inside_mask] = 1 70 | 71 | return link.T 72 | 73 | 74 | class ScanNetCross(ScanNet3D): 75 | IMG_DIM = (320, 240) 76 | 77 | def __init__(self, dataPathPrefix='Data', voxelSize=0.05, 78 | split='train', aug=False, memCacheInit=False, 79 | identifier=7439, loop=1, 80 | data_aug_color_trans_ratio=0.1, 81 | data_aug_color_jitter_std=0.05, data_aug_hue_max=0.5, 82 | data_aug_saturation_max=0.2, eval_all=False, 83 | val_benchmark=False, view_num=5 84 | ): 85 | super(ScanNetCross, self).__init__(dataPathPrefix=dataPathPrefix, voxelSize=voxelSize, 86 | split=split, aug=aug, memCacheInit=memCacheInit, 87 | identifier=identifier, loop=loop, 88 | data_aug_color_trans_ratio=data_aug_color_trans_ratio, 89 | data_aug_color_jitter_std=data_aug_color_jitter_std, 90 | data_aug_hue_max=data_aug_hue_max, 91 | data_aug_saturation_max=data_aug_saturation_max, 92 | eval_all=eval_all) 93 | self.VIEW_NUM = view_num 94 | self.split = split 95 | self.val_benchmark = val_benchmark 96 | if self.val_benchmark: 97 | self.offset = 0 98 | # Prepare for 2D 99 | self.data2D_paths = [] 100 | for x in self.data_paths: 101 | ps = glob(join(x[:-15].replace('3D/'+split, '2D'), 'color', '*.jpg')) 102 | assert len(ps) >= self.VIEW_NUM, '[!] 
%s has only %d frames, less than expected %d samples' % ( 103 | x, len(ps), self.VIEW_NUM) 104 | ps.sort(key=lambda x: int(x.split('/')[-1].split('.')[0])) 105 | if val_benchmark: 106 | ps = ps[::5] 107 | self.data2D_paths.append(ps) 108 | 109 | self.remapper = np.ones(256) * 255 110 | for i, x in enumerate([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39]): 111 | self.remapper[x] = i 112 | 113 | self.linkCreator = LinkCreator(image_dim=self.IMG_DIM, voxelSize=voxelSize) 114 | self.view_groups = torch.load(os.path.join(dataPathPrefix, 'view_groups', 'view_groups_'+split+'.pth')) 115 | # 2D AUG 116 | value_scale = 255 117 | mean = [0.485, 0.456, 0.406] 118 | mean = [item * value_scale for item in mean] 119 | std = [0.229, 0.224, 0.225] 120 | std = [item * value_scale for item in std] 121 | if self.aug: 122 | self.transform_2d = t_2d.Compose([ 123 | t_2d.RandomGaussianBlur(), 124 | t_2d.Crop([self.IMG_DIM[1] + 1, self.IMG_DIM[0] + 1], crop_type='rand', padding=mean, 125 | ignore_label=255), 126 | t_2d.ToTensor(), 127 | t_2d.Normalize(mean=mean, std=std)]) 128 | else: 129 | self.transform_2d = t_2d.Compose([ 130 | t_2d.Crop([self.IMG_DIM[1] + 1, self.IMG_DIM[0] + 1], crop_type='rand', padding=mean, 131 | ignore_label=255), 132 | t_2d.ToTensor(), 133 | t_2d.Normalize(mean=mean, std=std)]) 134 | 135 | def __getitem__(self, index_long): 136 | index = index_long % len(self.data_paths) 137 | locs_in = SA.attach("shm://wbhu_scannet_3d_%s_%06d_locs_%08d" % (self.split, self.identifier, index)) 138 | feats_in = SA.attach("shm://wbhu_scannet_3d_%s_%06d_feats_%08d" % (self.split, self.identifier, index)) 139 | labels_in = SA.attach("shm://wbhu_scannet_3d_%s_%06d_labels_%08d" % (self.split, self.identifier, index)) 140 | 141 | colors, labels_2d, links = self.get_2d_group(index, locs_in) 142 | 143 | locs = self.prevoxel_transforms(locs_in) if self.aug else locs_in 144 | locs, feats, labels, inds_reconstruct, links = self.voxelizer.voxelize(locs, feats_in, labels_in, link=links) 145 | if self.eval_all: 146 | labels = labels_in 147 | if self.aug: 148 | locs, feats, labels = self.input_transforms(locs, feats, labels) 149 | coords = torch.from_numpy(locs).int() 150 | coords = torch.cat((torch.ones(coords.shape[0], 1, dtype=torch.int), coords), dim=1) 151 | feats = torch.from_numpy(feats).float() / 127.5 - 1. 
152 | labels = torch.from_numpy(labels).long() 153 | 154 | if self.eval_all: 155 | return coords, feats, labels, colors, labels_2d, links, torch.from_numpy(inds_reconstruct).long() 156 | return coords, feats, labels, colors, labels_2d, links 157 | 158 | def get_2d_group(self, room_id, coords: np.ndarray): 159 | """ 160 | :param room_id: 161 | :param coords: Nx3 162 | :return: imgs: CxHxWxV Tensor 163 | labels: HxWxV Tensor 164 | links: Nx4xV(1,H,W,mask) Tensor 165 | """ 166 | room_name = self.data_paths[room_id].split('/')[-1][:12] 167 | frames_path = self.data2D_paths[room_id] 168 | groups = self.view_groups[room_name] 169 | if not self.val_benchmark and len(groups['frames_delete']) > 0: 170 | prefix = frames_path[0].split('/') 171 | for delete_frame in groups['frames_delete']: 172 | prefix[-1] = delete_frame 173 | delete_name = '/'.join(prefix) 174 | if delete_name in frames_path: 175 | frames_path.remove(delete_name) 176 | partial = int(len(frames_path) / 5) 177 | imgs, labels, links, names = [], [], [], [] 178 | if self.VIEW_NUM < 5: 179 | select = torch.randperm(5)[:self.VIEW_NUM] 180 | select = random.sample(range(0, 5), self.VIEW_NUM) 181 | else: 182 | select = range(self.VIEW_NUM) 183 | for v in select: 184 | if not self.val_benchmark and len(groups['frames_group']) == len(frames_path): 185 | group_list = list(compress(frames_path, (groups['frames_group'] == v))) 186 | if len(group_list) < 1: 187 | group_list = frames_path[v * partial:v * partial + partial] 188 | f = random.sample(group_list, k=1)[0] 189 | else: 190 | partial_benchmark = int(len(frames_path) / self.VIEW_NUM) 191 | select_id = (v * partial_benchmark+self.offset) % len(frames_path) 192 | f = frames_path[select_id] 193 | # pdb.set_trace() 194 | img = imageio.imread(f) 195 | label = imageio.imread(f.replace('color', 'label').replace('jpg', 'png')) 196 | label = self.remapper[label] 197 | depth = imageio.imread(f.replace('color', 'depth').replace('jpg', 'png')) / 1000.0 # convert to meter 198 | posePath = f.replace('color', 'pose').replace('.jpg', '.txt') 199 | pose = np.asarray( 200 | [[float(x[0]), float(x[1]), float(x[2]), float(x[3])] for x in 201 | (x.split(" ") for x in open(posePath).read().splitlines())] 202 | ) 203 | # pdb.set_trace() 204 | link = np.ones([coords.shape[0], 4], dtype=np.int) 205 | link[:, 1:4] = self.linkCreator.computeLinking(pose, coords, depth) 206 | img, label = self.transform_2d(img, label) 207 | imgs.append(img) 208 | labels.append(label) 209 | links.append(link) 210 | names.append(f) 211 | 212 | imgs = torch.stack(imgs, dim=-1) 213 | labels = torch.stack(labels, dim=-1) 214 | links = np.stack(links, axis=-1) 215 | links = torch.from_numpy(links) 216 | names = np.stack(names, axis=-1) 217 | return imgs, labels, links 218 | 219 | 220 | def collation_fn(batch): 221 | """ 222 | :param batch: 223 | :return: coords: N x 4 (batch,x,y,z) 224 | feats: N x 3 225 | labels: N 226 | colors: B x C x H x W x V 227 | labels_2d: B x H x W x V 228 | links: N x 4 x V (B,H,W,mask) 229 | 230 | """ 231 | coords, feats, labels, colors, labels_2d, links = list(zip(*batch)) 232 | # pdb.set_trace() 233 | 234 | for i in range(len(coords)): 235 | coords[i][:, 0] *= i 236 | links[i][:, 0, :] *= i 237 | 238 | return torch.cat(coords), torch.cat(feats), torch.cat(labels), \ 239 | torch.stack(colors), torch.stack(labels_2d), torch.cat(links) 240 | 241 | 242 | def collation_fn_eval_all(batch): 243 | """ 244 | :param batch: 245 | :return: coords: N x 4 (x,y,z,batch) 246 | feats: N x 3 247 | labels: N 248 | colors: B x C x 
H x W x V 249 | labels_2d: B x H x W x V 250 | links: N x 4 x V (B,H,W,mask) 251 | inds_recons:ON 252 | 253 | """ 254 | coords, feats, labels, colors, labels_2d, links, inds_recons = list(zip(*batch)) 255 | inds_recons = list(inds_recons) 256 | # pdb.set_trace() 257 | 258 | accmulate_points_num = 0 259 | for i in range(len(coords)): 260 | coords[i][:, 0] *= i 261 | links[i][:, 0, :] *= i 262 | inds_recons[i] = accmulate_points_num + inds_recons[i] 263 | accmulate_points_num += coords[i].shape[0] 264 | 265 | return torch.cat(coords), torch.cat(feats), torch.cat(labels), \ 266 | torch.stack(colors), torch.stack(labels_2d), torch.cat(links), torch.cat(inds_recons) 267 | 268 | 269 | if __name__ == '__main__': 270 | import time 271 | from tensorboardX import SummaryWriter 272 | 273 | data_root = '/research/dept6/wbhu/Dataset/ScanNet' 274 | train_data = ScanNetCross(dataPathPrefix=data_root, aug=True, split='train', memCacheInit=True, voxelSize=0.05) 275 | val_data = ScanNetCross(dataPathPrefix=data_root, aug=False, split='val', memCacheInit=True, voxelSize=0.05, 276 | eval_all=True) 277 | coords, feats, labels, colors, labels_2d, links = train_data.__getitem__(0) 278 | print(coords.shape, feats.shape, labels.shape, colors.shape, labels_2d.shape, links.shape) 279 | coords, feats, labels, colors, labels_2d, links, inds_recons = val_data.__getitem__(0) 280 | print(coords.shape, feats.shape, labels.shape, colors.shape, labels_2d.shape, links.shape, inds_recons.shape) 281 | exit(0) 282 | 283 | manual_seed = 123 284 | 285 | 286 | def worker_init_fn(worker_id): 287 | random.seed(manual_seed + worker_id) 288 | 289 | 290 | random.seed(manual_seed) 291 | np.random.seed(manual_seed) 292 | torch.manual_seed(manual_seed) 293 | torch.cuda.manual_seed_all(manual_seed) 294 | 295 | train_loader = torch.utils.data.DataLoader(train_data, batch_size=8, shuffle=True, num_workers=2, pin_memory=True, 296 | worker_init_fn=worker_init_fn, collate_fn=collation_fn) 297 | val_loader = torch.utils.data.DataLoader(val_data, batch_size=8, shuffle=False, num_workers=2, pin_memory=True, 298 | worker_init_fn=worker_init_fn, collate_fn=collation_fn_eval_all) 299 | # _ = iter(train_loader).__next__() 300 | trainLog = SummaryWriter('Exp/scannet/statistic_cross/train') 301 | valLog = SummaryWriter('Exp/scannet/statistic_cross/val') 302 | 303 | for idx in range(1): 304 | end = time.time() 305 | for step, (coords, feats, labels, colors, labels_2d, links) in enumerate(train_loader): 306 | print( 307 | 'time: {}/{}--{}'.format(step + 1, len(train_loader), time.time() - end)) 308 | trainLog.add_histogram('voxel_coord_x', coords[:, 0], global_step=step) 309 | trainLog.add_histogram('voxel_coord_y', coords[:, 1], global_step=step) 310 | trainLog.add_histogram('voxel_coord_z', coords[:, 2], global_step=step) 311 | trainLog.add_histogram('color', feats, global_step=step) 312 | trainLog.add_histogram('2D_image', colors, global_step=step) 313 | # time.sleep(0.3) 314 | end = time.time() 315 | 316 | for step, (coords, feats, labels, colors, labels_2d, links, inds_reverse) in enumerate(val_loader): 317 | print( 318 | 'time: {}/{}--{}'.format(step + 1, len(val_loader), time.time() - end)) 319 | valLog.add_histogram('voxel_coord_x', coords[:, 0], global_step=step) 320 | valLog.add_histogram('voxel_coord_y', coords[:, 1], global_step=step) 321 | valLog.add_histogram('voxel_coord_z', coords[:, 2], global_step=step) 322 | valLog.add_histogram('color', feats, global_step=step) 323 | valLog.add_histogram('2D_image', colors, global_step=step) 324 | # 
time.sleep(0.3) 325 | end = time.time() 326 | 327 | trainLog.close() 328 | valLog.close() 329 | -------------------------------------------------------------------------------- /dataset/scannet/GL.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | from os.path import join 3 | 4 | train_scenes = [x.split('/')[-1][:-15] for x in glob('train/*.pth')] 5 | 6 | with open('list/train_2D.txt','w') as f: 7 | for s in train_scenes: 8 | for im in glob(join('2D',s,'color','*.jpg')): 9 | f.write(im+' '+im.replace('color','label').replace('.jpg','.png')+'\n') 10 | 11 | 12 | val_scenes = [x.split('/')[-1][:-15] for x in glob('val/*.pth')] 13 | 14 | with open('list/val_2D.txt','w') as f: 15 | for s in val_scenes: 16 | for im in glob(join('2D',s,'color','*.jpg')): 17 | f.write(im+' '+im.replace('color','label').replace('.jpg','.png')+'\n') 18 | -------------------------------------------------------------------------------- /dataset/scannet/scannet_names.txt: -------------------------------------------------------------------------------- 1 | bathtub 2 | bed 3 | bookshelf 4 | cabinet 5 | chair 6 | counter 7 | curtain 8 | desk 9 | door 10 | floor 11 | otherfurniture 12 | picture 13 | refrigerator 14 | showercurtain 15 | sink 16 | sofa 17 | table 18 | toilet 19 | wall 20 | window 21 | -------------------------------------------------------------------------------- /dataset/voxelization_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu). 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 7 | # of the Software, and to permit persons to whom the Software is furnished to do 8 | # so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 20 | # 21 | # Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural 22 | # Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part 23 | # of the code. 
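# Usage sketch, not part of the original file: `sparse_quantize` below floors coordinates
# into grid cells, hashes each cell (FNV or ravel), and keeps one point per occupied cell.
# Assuming a hypothetical (N, 3) float array `pts` and a 5 cm grid:
#
#   import numpy as np
#   pts = np.random.rand(10000, 3) * 8.0                        # toy point cloud, in meters
#   inds, inds_reverse = sparse_quantize(pts, return_index=True, quantization_size=0.05)
#   voxel_coords = np.floor(pts / 0.05)[inds]                   # one representative per voxel
#   # inds_reverse maps every original point back to its voxel; Voxelizer.voxelize in
#   # dataset/voxelizer.py returns it as `inds_reconstruct`.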
24 | import torch 25 | import numpy as np 26 | from collections.abc import Sequence 27 | 28 | 29 | def fnv_hash_vec(arr): 30 | """ 31 | FNV64-1A 32 | """ 33 | assert arr.ndim == 2 34 | # Floor first for negative coordinates 35 | arr = arr.copy() 36 | arr = arr.astype(np.uint64, copy=False) 37 | hashed_arr = np.uint64(14695981039346656037) * \ 38 | np.ones(arr.shape[0], dtype=np.uint64) 39 | for j in range(arr.shape[1]): 40 | hashed_arr *= np.uint64(1099511628211) 41 | hashed_arr = np.bitwise_xor(hashed_arr, arr[:, j]) 42 | return hashed_arr 43 | 44 | 45 | def ravel_hash_vec(arr): 46 | """ 47 | Ravel the coordinates after subtracting the min coordinates. 48 | """ 49 | assert arr.ndim == 2 50 | arr = arr.copy() 51 | arr -= arr.min(0) 52 | arr = arr.astype(np.uint64, copy=False) 53 | arr_max = arr.max(0).astype(np.uint64) + 1 54 | 55 | keys = np.zeros(arr.shape[0], dtype=np.uint64) 56 | # Fortran style indexing 57 | for j in range(arr.shape[1] - 1): 58 | keys += arr[:, j] 59 | keys *= arr_max[j + 1] 60 | keys += arr[:, -1] 61 | return keys 62 | 63 | 64 | def sparse_quantize(coords, 65 | feats=None, 66 | labels=None, 67 | ignore_label=255, 68 | set_ignore_label_when_collision=False, 69 | return_index=False, 70 | hash_type='fnv', 71 | quantization_size=1): 72 | r"""Given coordinates, and features (optionally labels), the function 73 | generates quantized (voxelized) coordinates. 74 | 75 | Args: 76 | coords (:attr:`numpy.ndarray` or :attr:`torch.Tensor`): a matrix of size 77 | :math:`N \times D` where :math:`N` is the number of points in the 78 | :math:`D` dimensional space. 79 | 80 | feats (:attr:`numpy.ndarray` or :attr:`torch.Tensor`, optional): a matrix of size 81 | :math:`N \times D_F` where :math:`N` is the number of points and 82 | :math:`D_F` is the dimension of the features. 83 | 84 | labels (:attr:`numpy.ndarray`, optional): labels associated to each coordinate. 85 | 86 | ignore_label (:attr:`int`, optional): the int value of the IGNORE LABEL. 87 | 88 | set_ignore_label_when_collision (:attr:`bool`, optional): use the `ignore_label` 89 | when at least two points fall into the same cell. 90 | 91 | return_index (:attr:`bool`, optional): True if you want the indices of the 92 | quantized coordinates. False by default. 93 | 94 | hash_type (:attr:`str`, optional): Hash function used for quantization. Either 95 | `ravel` or `fnv`. `fnv` by default. 96 | 97 | quantization_size (:attr:`float`, :attr:`list`, or 98 | :attr:`numpy.ndarray`, optional): the length of each side of the 99 | hyperrectangle of the grid cell. 100 | 101 | .. note:: 102 | Please check `examples/indoor.py` for the usage. 103 | 104 | """ 105 | use_label = labels is not None 106 | use_feat = feats is not None 107 | if not use_label and not use_feat: 108 | return_index = True 109 | 110 | assert hash_type in [ 111 | 'ravel', 'fnv' 112 | ], "Invalid hash_type. Either ravel, or fnv allowed. You put hash_type=" + hash_type 113 | assert coords.ndim == 2, \ 114 | "The coordinates must be a 2D matrix. The shape of the input is " + str(coords.shape) 115 | if use_feat: 116 | assert feats.ndim == 2 117 | assert coords.shape[0] == feats.shape[0] 118 | if use_label: 119 | assert coords.shape[0] == len(labels) 120 | 121 | # Quantize the coordinates 122 | dimension = coords.shape[1] 123 | if isinstance(quantization_size, (Sequence, np.ndarray, torch.Tensor)): 124 | assert len( 125 | quantization_size 126 | ) == dimension, "Quantization size and coordinates size mismatch." 
127 | quantization_size = [i for i in quantization_size] 128 | elif np.isscalar(quantization_size): # Assume that it is a scalar 129 | quantization_size = [quantization_size for i in range(dimension)] 130 | else: 131 | raise ValueError('Not supported type for quantization_size.') 132 | discrete_coords = np.floor(coords / np.array(quantization_size)) 133 | 134 | # Hash function type 135 | if hash_type == 'ravel': 136 | key = ravel_hash_vec(discrete_coords) 137 | else: 138 | key = fnv_hash_vec(discrete_coords) 139 | 140 | if use_label: 141 | _, inds, counts = np.unique(key, return_index=True, return_counts=True) 142 | filtered_labels = labels[inds] 143 | if set_ignore_label_when_collision: 144 | filtered_labels[counts > 1] = ignore_label 145 | if return_index: 146 | return inds, filtered_labels 147 | else: 148 | return discrete_coords[inds], feats[inds], filtered_labels 149 | else: 150 | _, inds, inds_reverse = np.unique(key, return_index=True, return_inverse=True) 151 | if return_index: 152 | return inds, inds_reverse 153 | else: 154 | if use_feat: 155 | return discrete_coords[inds], feats[inds] 156 | else: 157 | return discrete_coords[inds] 158 | -------------------------------------------------------------------------------- /dataset/voxelizer.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | import numpy as np 4 | from dataset.voxelization_utils import sparse_quantize 5 | from scipy.linalg import expm, norm 6 | 7 | 8 | # Rotation matrix along axis with angle theta 9 | def M(axis, theta): 10 | return expm(np.cross(np.eye(3), axis / norm(axis) * theta)) 11 | 12 | 13 | class Voxelizer: 14 | 15 | def __init__(self, 16 | voxel_size=1, 17 | clip_bound=None, 18 | use_augmentation=False, 19 | scale_augmentation_bound=None, 20 | rotation_augmentation_bound=None, 21 | translation_augmentation_ratio_bound=None, 22 | ignore_label=255): 23 | """ 24 | Args: 25 | voxel_size: side length of a voxel 26 | clip_bound: boundary of the voxelizer. Points outside the bound will be deleted 27 | expects either None or an array like ((-100, 100), (-100, 100), (-100, 100)). 28 | scale_augmentation_bound: None or (0.9, 1.1) 29 | rotation_augmentation_bound: None or ((np.pi / 6, np.pi / 6), None, None) for 3 axis. 30 | Use random order of x, y, z to prevent bias. 31 | translation_augmentation_bound: ((-5, 5), (0, 0), (-10, 10)) 32 | ignore_label: label assigned for ignore (not a training label). 33 | """ 34 | self.voxel_size = voxel_size 35 | self.clip_bound = clip_bound 36 | self.ignore_label = ignore_label 37 | 38 | # Augmentation 39 | self.use_augmentation = use_augmentation 40 | self.scale_augmentation_bound = scale_augmentation_bound 41 | self.rotation_augmentation_bound = rotation_augmentation_bound 42 | self.translation_augmentation_ratio_bound = translation_augmentation_ratio_bound 43 | 44 | def get_transformation_matrix(self): 45 | voxelization_matrix, rotation_matrix = np.eye(4), np.eye(4) 46 | # Get clip boundary from config or pointcloud. 47 | # Get inner clip bound to crop from. 48 | 49 | # Transform pointcloud coordinate to voxel coordinate. 50 | # 1. 
Random rotation 51 | rot_mat = np.eye(3) 52 | if self.use_augmentation and self.rotation_augmentation_bound is not None: 53 | if isinstance(self.rotation_augmentation_bound, collections.Iterable): 54 | rot_mats = [] 55 | for axis_ind, rot_bound in enumerate(self.rotation_augmentation_bound): 56 | theta = 0 57 | axis = np.zeros(3) 58 | axis[axis_ind] = 1 59 | if rot_bound is not None: 60 | theta = np.random.uniform(*rot_bound) 61 | rot_mats.append(M(axis, theta)) 62 | # Use random order 63 | np.random.shuffle(rot_mats) 64 | rot_mat = rot_mats[0] @ rot_mats[1] @ rot_mats[2] 65 | else: 66 | raise ValueError() 67 | rotation_matrix[:3, :3] = rot_mat 68 | # 2. Scale and translate to the voxel space. 69 | scale = 1 / self.voxel_size 70 | if self.use_augmentation and self.scale_augmentation_bound is not None: 71 | scale *= np.random.uniform(*self.scale_augmentation_bound) 72 | np.fill_diagonal(voxelization_matrix[:3, :3], scale) 73 | # Get final transformation matrix. 74 | return voxelization_matrix, rotation_matrix 75 | 76 | def clip(self, coords, center=None, trans_aug_ratio=None): 77 | bound_min = np.min(coords, 0).astype(float) 78 | bound_max = np.max(coords, 0).astype(float) 79 | bound_size = bound_max - bound_min 80 | if center is None: 81 | center = bound_min + bound_size * 0.5 82 | lim = self.clip_bound 83 | if trans_aug_ratio is not None: 84 | trans = np.multiply(trans_aug_ratio, bound_size) 85 | center += trans 86 | # Clip points outside the limit 87 | clip_inds = ((coords[:, 0] >= (lim[0][0] + center[0])) & 88 | (coords[:, 0] < (lim[0][1] + center[0])) & 89 | (coords[:, 1] >= (lim[1][0] + center[1])) & 90 | (coords[:, 1] < (lim[1][1] + center[1])) & 91 | (coords[:, 2] >= (lim[2][0] + center[2])) & 92 | (coords[:, 2] < (lim[2][1] + center[2]))) 93 | return clip_inds 94 | 95 | def voxelize(self, coords, feats, labels, center=None, link=None): 96 | assert coords.shape[1] == 3 and coords.shape[0] == feats.shape[0] and coords.shape[0] 97 | if self.clip_bound is not None: 98 | trans_aug_ratio = np.zeros(3) 99 | if self.use_augmentation and self.translation_augmentation_ratio_bound is not None: 100 | for axis_ind, trans_ratio_bound in enumerate(self.translation_augmentation_ratio_bound): 101 | trans_aug_ratio[axis_ind] = np.random.uniform(*trans_ratio_bound) 102 | 103 | clip_inds = self.clip(coords, center, trans_aug_ratio) 104 | if clip_inds.sum(): 105 | coords, feats = coords[clip_inds], feats[clip_inds] 106 | if labels is not None: 107 | labels = labels[clip_inds] 108 | 109 | # Get rotation and scale 110 | M_v, M_r = self.get_transformation_matrix() 111 | # Apply transformations 112 | rigid_transformation = M_v 113 | if self.use_augmentation: 114 | rigid_transformation = M_r @ rigid_transformation 115 | 116 | homo_coords = np.hstack((coords, np.ones((coords.shape[0], 1), dtype=coords.dtype))) 117 | coords_aug = np.floor(homo_coords @ rigid_transformation.T[:, :3]) 118 | 119 | # Align all coordinates to the origin. 
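# The minimum voxel coordinate is subtracted so all indices become non-negative; the same
# translation is recorded in M_t below and folded into `rigid_transformation`.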
120 | min_coords = coords_aug.min(0) 121 | M_t = np.eye(4) 122 | M_t[:3, -1] = -min_coords 123 | rigid_transformation = M_t @ rigid_transformation 124 | coords_aug = np.floor(coords_aug - min_coords) 125 | 126 | inds, inds_reconstruct = sparse_quantize(coords_aug, return_index=True) 127 | coords_aug, feats, labels = coords_aug[inds], feats[inds], labels[inds] 128 | 129 | # Normal rotation 130 | if feats.shape[1] > 6: 131 | feats[:, 3:6] = feats[:, 3:6] @ (M_r[:3, :3].T) 132 | 133 | if link is not None: 134 | return coords_aug, feats, labels, np.array(inds_reconstruct), link[inds] 135 | 136 | return coords_aug, feats, labels, np.array(inds_reconstruct) 137 | -------------------------------------------------------------------------------- /fig/semaffinet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangzy22/SemAffiNet/a5e6d9f289713ec185a93b6d3c2a269cfdee603b/fig/semaffinet.png -------------------------------------------------------------------------------- /metrics/iou.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | 9 | CLASS_LABELS = ['wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf', 'picture', 10 | 'counter', 'desk', 'curtain', 'refrigerator', 'shower curtain', 'toilet', 'sink', 'bathtub', 11 | 'otherfurniture'] 12 | UNKNOWN_ID = 255 13 | N_CLASSES = len(CLASS_LABELS) 14 | 15 | 16 | def confusion_matrix(pred_ids, gt_ids): 17 | assert pred_ids.shape == gt_ids.shape, (pred_ids.shape, gt_ids.shape) 18 | idxs = gt_ids != UNKNOWN_ID 19 | return np.bincount(pred_ids[idxs] * 20 + gt_ids[idxs], minlength=400).reshape((20, 20)).astype(np.ulonglong) 20 | 21 | 22 | def get_iou(label_id, confusion): 23 | # true positives 24 | tp = np.longlong(confusion[label_id, label_id]) 25 | # false positives 26 | fp = np.longlong(confusion[label_id, :].sum()) - tp 27 | # false negatives 28 | fn = np.longlong(confusion[:, label_id].sum()) - tp 29 | 30 | denom = (tp + fp + fn) 31 | if denom == 0: 32 | return float('nan') 33 | return float(tp) / denom, tp, denom 34 | 35 | 36 | def evaluate(pred_ids, gt_ids, stdout=False): 37 | if stdout: 38 | print('evaluating', gt_ids.size, 'points...') 39 | confusion = confusion_matrix(pred_ids, gt_ids) 40 | class_ious = {} 41 | mean_iou = 0 42 | for i in range(N_CLASSES): 43 | label_name = CLASS_LABELS[i] 44 | class_ious[label_name] = get_iou(i, confusion) 45 | mean_iou += class_ious[label_name][0] / 20 46 | 47 | if stdout: 48 | print('classes IoU') 49 | print('----------------------------') 50 | for i in range(N_CLASSES): 51 | label_name = CLASS_LABELS[i] 52 | print('{0:<14s}: {1:>5.3f} ({2:>6d}/{3:<6d})'.format(label_name, class_ious[label_name][0], 53 | class_ious[label_name][1], 54 | class_ious[label_name][2])) 55 | print('mean IOU', mean_iou) 56 | return mean_iou 57 | -------------------------------------------------------------------------------- /models/bpm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import MinkowskiEngine as ME 4 | 5 | 6 | def get_coords_map(x, y): 7 | assert ( 8 | x.coordinate_manager == y.coordinate_manager 9 | ), "X and Y are using different CoordinateManagers. 
Y must be derived from X through strided conv/pool/etc." 10 | return x.coordinate_manager.stride_map(x.coordinate_map_key, y.coordinate_map_key) 11 | 12 | 13 | 14 | class Linking(nn.Module): 15 | def __init__(self, fea2d_dim, fea3d_dim, viewNum=3): 16 | super(Linking, self).__init__() 17 | self.viewNum = viewNum 18 | self.fea2d_dim = fea2d_dim 19 | 20 | self.view_fusion = nn.Sequential( 21 | ME.MinkowskiConvolution(fea2d_dim * viewNum, fea2d_dim, kernel_size=3, dimension=3), 22 | ME.MinkowskiBatchNorm(fea2d_dim), 23 | ME.MinkowskiReLU(inplace=True), 24 | ME.MinkowskiConvolution(fea2d_dim, fea3d_dim, kernel_size=3, dimension=3), 25 | ME.MinkowskiBatchNorm(fea3d_dim), 26 | ME.MinkowskiReLU(inplace=True) 27 | ) 28 | 29 | self.fuseTo3d = nn.Sequential( 30 | ME.MinkowskiConvolution(fea3d_dim * 2, fea3d_dim, kernel_size=3, dimension=3), 31 | ME.MinkowskiBatchNorm(fea3d_dim), 32 | ME.MinkowskiReLU(inplace=True) 33 | ) 34 | 35 | self.view_sep = nn.Sequential( 36 | ME.MinkowskiConvolution(fea3d_dim, fea2d_dim, kernel_size=3, dimension=3), 37 | ME.MinkowskiBatchNorm(fea2d_dim), 38 | ME.MinkowskiReLU(inplace=True) 39 | ) 40 | self.fuseTo2d = nn.Sequential( 41 | nn.Conv2d(fea2d_dim * 2, fea2d_dim, kernel_size=3, padding=1, bias=False), 42 | nn.BatchNorm2d(fea2d_dim), 43 | nn.ReLU(inplace=True) 44 | ) 45 | 46 | def forward(self, feat_2d_all, feat_3d, links, init_3d_data=None): 47 | """ 48 | :param feat_2d_all: V_B * C * H * WV 49 | :param feat_3d: SparseTensor, Feature of N*C 50 | :return: 51 | """ 52 | feat_3d_for_2d = self.view_sep(feat_3d).F 53 | V_B, C, H, W = feat_2d_all.shape 54 | feat_2d_all = feat_2d_all.view(self.viewNum, -1, C, H, W) 55 | 56 | # Link 57 | coords_map_in, coords_map_out = get_coords_map(init_3d_data, feat_3d) 58 | current_links = torch.zeros([feat_3d.shape[0], links.shape[1], links.shape[2]], dtype=torch.long).cuda() 59 | current_links[coords_map_out, :] = links[coords_map_in, :] 60 | 61 | feat_3d_to_2d = torch.zeros_like(feat_2d_all) 62 | feat_2d_to_3d = torch.zeros([feat_3d.F.shape[0], self.viewNum * self.fea2d_dim], dtype=torch.float).cuda() 63 | for v in range(self.viewNum): 64 | # pdb.set_trace() 65 | f = feat_2d_all[v, current_links[:, 0, v], :, current_links[:, 1, v], current_links[:, 2, v]] 66 | f *= current_links[:, 3, v].unsqueeze(dim=1).float() 67 | feat_2d_to_3d[:, v * self.fea2d_dim:(v + 1) * self.fea2d_dim] = f 68 | feat_3d_to_2d[v, current_links[:, 0, v], :, current_links[:, 1, v], current_links[:, 2, v]] = feat_3d_for_2d 69 | 70 | feat_3d_to_2d = feat_3d_to_2d.view(V_B, C, H, W) 71 | feat_2d_all = feat_2d_all.view(V_B, C, H, W) 72 | fused_2d = self.fuseTo2d(torch.cat([feat_2d_all, feat_3d_to_2d], dim=1)) 73 | 74 | feat_2d_to_3d = ME.SparseTensor(feat_2d_to_3d, feat_3d.C) 75 | feat_2d_to_3d = self.view_fusion(feat_2d_to_3d) 76 | # pdb.set_trace() 77 | feat_3d._F = torch.cat([feat_3d._F, feat_2d_to_3d._F], dim=-1) 78 | fused_3d = self.fuseTo3d(feat_3d) 79 | 80 | return fused_3d, fused_2d 81 | -------------------------------------------------------------------------------- /models/me_common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu). All Rights Reserved. 2 | # 3 | # Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural 4 | # Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part of 5 | # the code. 
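# Usage sketch, not part of the original file (channel sizes and the input tensor `x` are
# made up): the helpers below pick a Minkowski kernel shape via ConvType, a normalization
# layer via NormType, and an activation by name.
#
#   c = conv(32, 64, kernel_size=3, stride=2, conv_type=ConvType.SPATIAL_HYPERCUBE, D=3)
#   bn = get_norm(NormType.BATCH_NORM, 64, D=3)
#   out = get_nonlinearity_fn('ReLU', bn(c(x)))   # x is an ME.SparseTensor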
6 | import collections 7 | from enum import Enum 8 | 9 | # from lib.layers import MinkowskiSwitchNorm, MinkowskiLayerNorm 10 | 11 | import MinkowskiEngine as ME 12 | import MinkowskiEngine.MinkowskiFunctional as MEF 13 | 14 | 15 | class NormType(Enum): 16 | BATCH_NORM = 0 17 | SPARSE_LAYER_NORM = 1 18 | SPARSE_INSTANCE_NORM = 2 19 | SPARSE_SWITCH_NORM = 3 20 | 21 | 22 | class NonlinearityType(Enum): 23 | ReLU = 0 24 | LeakyReLU = 1 25 | PReLU = 2 26 | CELU = 3 27 | SELU = 4 28 | 29 | 30 | def get_norm(norm_type, n_channels, D, bn_momentum=0.1, affine=True): 31 | if norm_type == NormType.BATCH_NORM: 32 | return ME.MinkowskiBatchNorm(n_channels, momentum=bn_momentum, affine=affine) 33 | # elif norm_type == NormType.SPARSE_LAYER_NORM: 34 | # return MinkowskiLayerNorm(n_channels, D=D) 35 | elif norm_type == NormType.SPARSE_INSTANCE_NORM: 36 | return ME.MinkowskiInstanceNorm(n_channels) 37 | # elif norm_type == NormType.SPARSE_SWITCH_NORM: 38 | # return MinkowskiSwitchNorm(n_channels, D=D) 39 | else: 40 | raise ValueError(f'Norm type: {norm_type} not supported') 41 | 42 | 43 | str_to_nonlinearity_dict = {m.name: m for m in NonlinearityType} 44 | 45 | 46 | def get_nonlinearity_fn(nonlinearity_type, input, *args, **kwargs): 47 | nonlinearity_type = str_to_nonlinearity_dict[nonlinearity_type] 48 | if nonlinearity_type == NonlinearityType.ReLU: 49 | return MEF.relu(input, *args, **kwargs) 50 | elif nonlinearity_type == NonlinearityType.LeakyReLU: 51 | return MEF.leaky_relu(input, *args, **kwargs) 52 | elif nonlinearity_type == NonlinearityType.PReLU: 53 | return MEF.prelu(input, *args, **kwargs) 54 | elif nonlinearity_type == NonlinearityType.CELU: 55 | return MEF.celu(input, *args, **kwargs) 56 | elif nonlinearity_type == NonlinearityType.SELU: 57 | return MEF.selu(input, *args, **kwargs) 58 | else: 59 | raise ValueError(f'Nonlinearity type: {nonlinearity_type} not supported') 60 | 61 | 62 | class ConvType(Enum): 63 | """ 64 | Define the kernel region type 65 | """ 66 | HYPERCUBE = 0, 'HYPERCUBE' 67 | SPATIAL_HYPERCUBE = 1, 'SPATIAL_HYPERCUBE' 68 | SPATIO_TEMPORAL_HYPERCUBE = 2, 'SPATIO_TEMPORAL_HYPERCUBE' 69 | HYPERCROSS = 3, 'HYPERCROSS' 70 | SPATIAL_HYPERCROSS = 4, 'SPATIAL_HYPERCROSS' 71 | SPATIO_TEMPORAL_HYPERCROSS = 5, 'SPATIO_TEMPORAL_HYPERCROSS' 72 | SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS = 6, 'SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS ' 73 | 74 | def __new__(cls, value, name): 75 | member = object.__new__(cls) 76 | member._value_ = value 77 | member.fullname = name 78 | return member 79 | 80 | def __int__(self): 81 | return self.value 82 | 83 | 84 | # Convert the ConvType var to a RegionType var 85 | conv_to_region_type = { 86 | # kernel_size = [k, k, k, 1] 87 | ConvType.HYPERCUBE: ME.RegionType.HYPER_CUBE, 88 | ConvType.SPATIAL_HYPERCUBE: ME.RegionType.HYPER_CUBE, 89 | ConvType.SPATIO_TEMPORAL_HYPERCUBE: ME.RegionType.HYPER_CUBE, 90 | ConvType.HYPERCROSS: ME.RegionType.HYPER_CROSS, 91 | ConvType.SPATIAL_HYPERCROSS: ME.RegionType.HYPER_CROSS, 92 | ConvType.SPATIO_TEMPORAL_HYPERCROSS: ME.RegionType.HYPER_CROSS, 93 | ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS: ME.RegionType.CUSTOM 94 | } 95 | 96 | int_to_region_type = { 97 | 0: ME.RegionType.HYPER_CUBE, 98 | 1: ME.RegionType.HYPER_CROSS, 99 | 2: ME.RegionType.CUSTOM 100 | } 101 | 102 | 103 | def convert_region_type(region_type): 104 | """ 105 | Convert the integer region_type to the corresponding RegionType enum object. 
106 | """ 107 | return int_to_region_type[region_type] 108 | 109 | 110 | def convert_conv_type(conv_type, kernel_size, D): 111 | assert isinstance(conv_type, ConvType), "conv_type must be of ConvType" 112 | region_type = conv_to_region_type[conv_type] 113 | axis_types = None 114 | if conv_type == ConvType.SPATIAL_HYPERCUBE: 115 | # No temporal convolution 116 | if isinstance(kernel_size, collections.Sequence): 117 | kernel_size = kernel_size[:3] 118 | else: 119 | kernel_size = [ 120 | kernel_size, 121 | ] * 3 122 | if D == 4: 123 | kernel_size.append(1) 124 | elif conv_type == ConvType.SPATIO_TEMPORAL_HYPERCUBE: 125 | # conv_type conversion already handled 126 | assert D == 4 127 | elif conv_type == ConvType.HYPERCUBE: 128 | # conv_type conversion already handled 129 | pass 130 | elif conv_type == ConvType.SPATIAL_HYPERCROSS: 131 | if isinstance(kernel_size, collections.Sequence): 132 | kernel_size = kernel_size[:3] 133 | else: 134 | kernel_size = [ 135 | kernel_size, 136 | ] * 3 137 | if D == 4: 138 | kernel_size.append(1) 139 | elif conv_type == ConvType.HYPERCROSS: 140 | # conv_type conversion already handled 141 | pass 142 | elif conv_type == ConvType.SPATIO_TEMPORAL_HYPERCROSS: 143 | # conv_type conversion already handled 144 | assert D == 4 145 | elif conv_type == ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS: 146 | # Define the CUBIC conv kernel for spatial dims and CROSS conv for temp dim 147 | if D < 4: 148 | region_type = ME.RegionType.HYPER_CUBE 149 | else: 150 | axis_types = [ 151 | ME.RegionType.HYPER_CUBE, 152 | ] * 3 153 | if D == 4: 154 | axis_types.append(ME.RegionType.HYPER_CROSS) 155 | return region_type, axis_types, kernel_size 156 | 157 | 158 | def conv(in_planes, 159 | out_planes, 160 | kernel_size, 161 | stride=1, 162 | dilation=1, 163 | bias=False, 164 | conv_type=ConvType.HYPERCUBE, 165 | D=-1): 166 | assert D > 0, 'Dimension must be a positive integer' 167 | region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D) 168 | kernel_generator = ME.KernelGenerator( 169 | kernel_size, stride, dilation, region_type=region_type, axis_types=axis_types, dimension=D) 170 | 171 | return ME.MinkowskiConvolution( 172 | in_channels=in_planes, 173 | out_channels=out_planes, 174 | kernel_size=kernel_size, 175 | stride=stride, 176 | dilation=dilation, 177 | bias=bias, 178 | kernel_generator=kernel_generator, 179 | dimension=D) 180 | 181 | 182 | def conv_dw(in_planes, 183 | kernel_size, 184 | stride=1, 185 | dilation=1, 186 | bias=False, 187 | conv_type=ConvType.HYPERCUBE, 188 | D=-1): 189 | 190 | assert D > 0, 'Dimension must be a positive integer' 191 | region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D) 192 | kernel_generator = ME.KernelGenerator( 193 | kernel_size, stride, dilation, region_type=region_type, axis_types=axis_types, dimension=D) 194 | 195 | return ME.MinkowskiChannelwiseConvolution( 196 | in_channels=in_planes, 197 | kernel_size=kernel_size, 198 | stride=stride, 199 | dilation=dilation, 200 | bias=bias, 201 | kernel_generator=kernel_generator, 202 | dimension=D 203 | ) 204 | 205 | 206 | def conv_tr(in_planes, 207 | out_planes, 208 | kernel_size, 209 | upsample_stride=1, 210 | dilation=1, 211 | bias=False, 212 | conv_type=ConvType.HYPERCUBE, 213 | D=-1): 214 | assert D > 0, 'Dimension must be a positive integer' 215 | region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D) 216 | kernel_generator = ME.KernelGenerator( 217 | kernel_size, 218 | upsample_stride, 219 | dilation, 220 | 
region_type=region_type, 221 | axis_types=axis_types, 222 | dimension=D) 223 | 224 | return ME.MinkowskiConvolutionTranspose( 225 | in_channels=in_planes, 226 | out_channels=out_planes, 227 | kernel_size=kernel_size, 228 | stride=upsample_stride, 229 | dilation=dilation, 230 | bias=bias, 231 | kernel_generator=kernel_generator, 232 | dimension=D) 233 | 234 | 235 | def avg_pool(kernel_size, 236 | stride=1, 237 | dilation=1, 238 | conv_type=ConvType.HYPERCUBE, 239 | in_coords_key=None, 240 | D=-1): 241 | assert D > 0, 'Dimension must be a positive integer' 242 | region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D) 243 | kernel_generator = ME.KernelGenerator( 244 | kernel_size, stride, dilation, region_type=region_type, axis_types=axis_types, dimension=D) 245 | 246 | return ME.MinkowskiAvgPooling( 247 | kernel_size=kernel_size, 248 | stride=stride, 249 | dilation=dilation, 250 | kernel_generator=kernel_generator, 251 | dimension=D) 252 | 253 | 254 | def avg_unpool(kernel_size, stride=1, dilation=1, conv_type=ConvType.HYPERCUBE, D=-1): 255 | assert D > 0, 'Dimension must be a positive integer' 256 | region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D) 257 | kernel_generator = ME.KernelGenerator( 258 | kernel_size, stride, dilation, region_type=region_type, axis_types=axis_types, dimension=D) 259 | 260 | return ME.MinkowskiAvgUnpooling( 261 | kernel_size=kernel_size, 262 | stride=stride, 263 | dilation=dilation, 264 | kernel_generator=kernel_generator, 265 | dimension=D) 266 | 267 | 268 | def sum_pool(kernel_size, stride=1, dilation=1, conv_type=ConvType.HYPERCUBE, D=-1): 269 | assert D > 0, 'Dimension must be a positive integer' 270 | region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D) 271 | kernel_generator = ME.KernelGenerator( 272 | kernel_size, stride, dilation, region_type=region_type, axis_types=axis_types, dimension=D) 273 | 274 | return ME.MinkowskiSumPooling( 275 | kernel_size=kernel_size, 276 | stride=stride, 277 | dilation=dilation, 278 | kernel_generator=kernel_generator, 279 | dimension=D) -------------------------------------------------------------------------------- /models/resnet_d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import math 6 | 7 | from timm.models.layers import DropBlock2d, DropPath, AvgPool2dSame, create_attn, create_classifier 8 | 9 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 10 | 'resnet152', 'resnet34d'] 11 | 12 | 13 | class BasicBlock(nn.Module): 14 | expansion = 1 15 | 16 | def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, 17 | reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, 18 | attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): 19 | super(BasicBlock, self).__init__() 20 | 21 | assert cardinality == 1, 'BasicBlock only supports cardinality of 1' 22 | assert base_width == 64, 'BasicBlock does not support changing base width' 23 | first_planes = planes // reduce_first 24 | outplanes = planes * self.expansion 25 | first_dilation = first_dilation or dilation 26 | use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation) 27 | 28 | self.conv1 = nn.Conv2d( 29 | inplanes, first_planes, kernel_size=3, stride=1 if use_aa else stride, padding=first_dilation, 30 | dilation=first_dilation, 
bias=False) 31 | self.bn1 = norm_layer(first_planes) 32 | self.act1 = act_layer(inplace=True) 33 | self.aa = aa_layer(channels=first_planes, stride=stride) if use_aa else None 34 | 35 | self.conv2 = nn.Conv2d( 36 | first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False) 37 | self.bn2 = norm_layer(outplanes) 38 | 39 | self.se = create_attn(attn_layer, outplanes) 40 | 41 | self.act2 = act_layer(inplace=True) 42 | self.downsample = downsample 43 | self.stride = stride 44 | self.dilation = dilation 45 | self.drop_block = drop_block 46 | self.drop_path = drop_path 47 | 48 | def zero_init_last_bn(self): 49 | nn.init.zeros_(self.bn2.weight) 50 | 51 | def forward(self, x): 52 | shortcut = x 53 | 54 | x = self.conv1(x) 55 | x = self.bn1(x) 56 | if self.drop_block is not None: 57 | x = self.drop_block(x) 58 | x = self.act1(x) 59 | if self.aa is not None: 60 | x = self.aa(x) 61 | 62 | x = self.conv2(x) 63 | x = self.bn2(x) 64 | if self.drop_block is not None: 65 | x = self.drop_block(x) 66 | 67 | if self.se is not None: 68 | x = self.se(x) 69 | 70 | if self.drop_path is not None: 71 | x = self.drop_path(x) 72 | 73 | if self.downsample is not None: 74 | shortcut = self.downsample(shortcut) 75 | x += shortcut 76 | x = self.act2(x) 77 | 78 | return x 79 | 80 | 81 | class Bottleneck(nn.Module): 82 | expansion = 4 83 | 84 | def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, 85 | reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, 86 | attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): 87 | super(Bottleneck, self).__init__() 88 | 89 | width = int(math.floor(planes * (base_width / 64)) * cardinality) 90 | first_planes = width // reduce_first 91 | outplanes = planes * self.expansion 92 | first_dilation = first_dilation or dilation 93 | use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation) 94 | 95 | self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False) 96 | self.bn1 = norm_layer(first_planes) 97 | self.act1 = act_layer(inplace=True) 98 | 99 | self.conv2 = nn.Conv2d( 100 | first_planes, width, kernel_size=3, stride=1 if use_aa else stride, 101 | padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False) 102 | self.bn2 = norm_layer(width) 103 | self.act2 = act_layer(inplace=True) 104 | self.aa = aa_layer(channels=width, stride=stride) if use_aa else None 105 | 106 | self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) 107 | self.bn3 = norm_layer(outplanes) 108 | 109 | self.se = create_attn(attn_layer, outplanes) 110 | 111 | self.act3 = act_layer(inplace=True) 112 | self.downsample = downsample 113 | self.stride = stride 114 | self.dilation = dilation 115 | self.drop_block = drop_block 116 | self.drop_path = drop_path 117 | 118 | def zero_init_last_bn(self): 119 | nn.init.zeros_(self.bn3.weight) 120 | 121 | def forward(self, x): 122 | shortcut = x 123 | 124 | x = self.conv1(x) 125 | x = self.bn1(x) 126 | if self.drop_block is not None: 127 | x = self.drop_block(x) 128 | x = self.act1(x) 129 | 130 | x = self.conv2(x) 131 | x = self.bn2(x) 132 | if self.drop_block is not None: 133 | x = self.drop_block(x) 134 | x = self.act2(x) 135 | if self.aa is not None: 136 | x = self.aa(x) 137 | 138 | x = self.conv3(x) 139 | x = self.bn3(x) 140 | if self.drop_block is not None: 141 | x = self.drop_block(x) 142 | 143 | if self.se is not None: 144 | x = self.se(x) 145 | 146 | if self.drop_path is not None: 147 | x = 
self.drop_path(x) 148 | 149 | if self.downsample is not None: 150 | shortcut = self.downsample(shortcut) 151 | x += shortcut 152 | x = self.act3(x) 153 | 154 | return x 155 | 156 | 157 | def get_padding(kernel_size, stride, dilation=1): 158 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 159 | return padding 160 | 161 | 162 | def downsample_conv( 163 | in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None): 164 | norm_layer = norm_layer or nn.BatchNorm2d 165 | kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size 166 | first_dilation = (first_dilation or dilation) if kernel_size > 1 else 1 167 | p = get_padding(kernel_size, stride, first_dilation) 168 | 169 | return nn.Sequential(*[ 170 | nn.Conv2d( 171 | in_channels, out_channels, kernel_size, stride=stride, padding=p, dilation=first_dilation, bias=False), 172 | norm_layer(out_channels) 173 | ]) 174 | 175 | 176 | def downsample_avg( 177 | in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None): 178 | norm_layer = norm_layer or nn.BatchNorm2d 179 | avg_stride = stride if dilation == 1 else 1 180 | if stride == 1 and dilation == 1: 181 | pool = nn.Identity() 182 | else: 183 | avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d 184 | pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False) 185 | 186 | return nn.Sequential(*[ 187 | pool, 188 | nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False), 189 | norm_layer(out_channels) 190 | ]) 191 | 192 | 193 | def drop_blocks(drop_block_rate=0.): 194 | return [ 195 | None, None, 196 | DropBlock2d(drop_block_rate, 5, 0.25) if drop_block_rate else None, 197 | DropBlock2d(drop_block_rate, 3, 1.00) if drop_block_rate else None] 198 | 199 | 200 | def make_blocks( 201 | block_fn, channels, block_repeats, inplanes, reduce_first=1, output_stride=32, 202 | down_kernel_size=1, avg_down=False, drop_block_rate=0., drop_path_rate=0., **kwargs): 203 | stages = [] 204 | feature_info = [] 205 | net_num_blocks = sum(block_repeats) 206 | net_block_idx = 0 207 | net_stride = 4 208 | dilation = prev_dilation = 1 209 | for stage_idx, (planes, num_blocks, db) in enumerate(zip(channels, block_repeats, drop_blocks(drop_block_rate))): 210 | stage_name = f'layer{stage_idx + 1}' # never liked this name, but weight compat requires it 211 | stride = 1 if stage_idx == 0 else 2 212 | if net_stride >= output_stride: 213 | dilation *= stride 214 | stride = 1 215 | else: 216 | net_stride *= stride 217 | 218 | downsample = None 219 | if stride != 1 or inplanes != planes * block_fn.expansion: 220 | down_kwargs = dict( 221 | in_channels=inplanes, out_channels=planes * block_fn.expansion, kernel_size=down_kernel_size, 222 | stride=stride, dilation=dilation, first_dilation=prev_dilation, norm_layer=kwargs.get('norm_layer')) 223 | downsample = downsample_avg(**down_kwargs) if avg_down else downsample_conv(**down_kwargs) 224 | 225 | block_kwargs = dict(reduce_first=reduce_first, dilation=dilation, drop_block=db, **kwargs) 226 | blocks = [] 227 | for block_idx in range(num_blocks): 228 | downsample = downsample if block_idx == 0 else None 229 | stride = stride if block_idx == 0 else 1 230 | block_dpr = drop_path_rate * net_block_idx / (net_num_blocks - 1) # stochastic depth linear decay rule 231 | blocks.append(block_fn( 232 | inplanes, planes, stride, downsample, first_dilation=prev_dilation, 233 | drop_path=DropPath(block_dpr) if block_dpr > 0. 
else None, **block_kwargs)) 234 | prev_dilation = dilation 235 | inplanes = planes * block_fn.expansion 236 | net_block_idx += 1 237 | 238 | stages.append((stage_name, nn.Sequential(*blocks))) 239 | feature_info.append(dict(num_chs=inplanes, reduction=net_stride, module=stage_name)) 240 | 241 | return stages, feature_info 242 | 243 | 244 | class ResNet(nn.Module): 245 | """ResNet / ResNeXt / SE-ResNeXt / SE-Net 246 | This class implements all variants of ResNet, ResNeXt, SE-ResNeXt, and SENet that 247 | * have > 1 stride in the 3x3 conv layer of bottleneck 248 | * have conv-bn-act ordering 249 | This ResNet impl supports a number of stem and downsample options based on the v1c, v1d, v1e, and v1s 250 | variants included in the MXNet Gluon ResNetV1b model. The C and D variants are also discussed in the 251 | 'Bag of Tricks' paper: https://arxiv.org/pdf/1812.01187. The B variant is equivalent to torchvision default. 252 | ResNet variants (the same modifications can be used in SE/ResNeXt models as well): 253 | * normal, b - 7x7 stem, stem_width = 64, same as torchvision ResNet, NVIDIA ResNet 'v1.5', Gluon v1b 254 | * c - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64) 255 | * d - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64), average pool in downsample 256 | * e - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128), average pool in downsample 257 | * s - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128) 258 | * t - 3 layer deep 3x3 stem, stem width = 32 (24, 48, 64), average pool in downsample 259 | * tn - 3 layer deep 3x3 stem, stem width = 32 (24, 32, 64), average pool in downsample 260 | ResNeXt 261 | * normal - 7x7 stem, stem_width = 64, standard cardinality and base widths 262 | * same c,d, e, s variants as ResNet can be enabled 263 | SE-ResNeXt 264 | * normal - 7x7 stem, stem_width = 64 265 | * same c, d, e, s variants as ResNet can be enabled 266 | SENet-154 - 3 layer deep 3x3 stem (same as v1c-v1s), stem_width = 64, cardinality=64, 267 | reduction by 2 on width of first bottleneck convolution, 3x3 downsample convs after first block 268 | Parameters 269 | ---------- 270 | block : Block 271 | Class for the residual block. Options are BasicBlockGl, BottleneckGl. 272 | layers : list of int 273 | Numbers of layers in each block 274 | num_classes : int, default 1000 275 | Number of classification classes. 276 | in_chans : int, default 3 277 | Number of input (color) channels. 278 | cardinality : int, default 1 279 | Number of convolution groups for 3x3 conv in Bottleneck. 280 | base_width : int, default 64 281 | Factor determining bottleneck channels. `planes * base_width / 64 * cardinality` 282 | stem_width : int, default 64 283 | Number of channels in stem convolutions 284 | stem_type : str, default '' 285 | The type of stem: 286 | * '', default - a single 7x7 conv with a width of stem_width 287 | * 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2 288 | * 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2 289 | block_reduce_first: int, default 1 290 | Reduction factor for first convolution output width of residual blocks, 291 | 1 for all archs except senets, where 2 292 | down_kernel_size: int, default 1 293 | Kernel size of residual block downsampling path, 1x1 for most archs, 3x3 for senets 294 | avg_down : bool, default False 295 | Whether to use average pooling for projection skip connection between stages/downsample. 
296 | output_stride : int, default 32 297 | Set the output stride of the network, 32, 16, or 8. Typically used in segmentation. 298 | act_layer : nn.Module, activation layer 299 | norm_layer : nn.Module, normalization layer 300 | aa_layer : nn.Module, anti-aliasing layer 301 | drop_rate : float, default 0. 302 | Dropout probability before classifier, for training 303 | global_pool : str, default 'avg' 304 | Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' 305 | """ 306 | 307 | def __init__(self, block, layers, num_classes=1000, in_chans=3, 308 | cardinality=1, base_width=64, stem_width=64, stem_type='', replace_stem_pool=False, 309 | output_stride=32, block_reduce_first=1, down_kernel_size=1, avg_down=False, 310 | act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0., 311 | drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None): 312 | block_args = block_args or dict() 313 | assert output_stride in (8, 16, 32) 314 | self.num_classes = num_classes 315 | self.drop_rate = drop_rate 316 | super(ResNet, self).__init__() 317 | 318 | # Stem 319 | deep_stem = 'deep' in stem_type 320 | inplanes = stem_width * 2 if deep_stem else 64 321 | if deep_stem: 322 | stem_chs = (stem_width, stem_width) 323 | if 'tiered' in stem_type: 324 | stem_chs = (3 * (stem_width // 4), stem_width) 325 | self.conv1 = nn.Sequential(*[ 326 | nn.Conv2d(in_chans, stem_chs[0], 3, stride=2, padding=1, bias=False), 327 | norm_layer(stem_chs[0]), 328 | act_layer(inplace=True), 329 | nn.Conv2d(stem_chs[0], stem_chs[1], 3, stride=1, padding=1, bias=False), 330 | norm_layer(stem_chs[1]), 331 | act_layer(inplace=True), 332 | nn.Conv2d(stem_chs[1], inplanes, 3, stride=1, padding=1, bias=False)]) 333 | else: 334 | self.conv1 = nn.Conv2d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False) 335 | self.bn1 = norm_layer(inplanes) 336 | self.act1 = act_layer(inplace=True) 337 | self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')] 338 | 339 | # Stem Pooling 340 | if replace_stem_pool: 341 | self.maxpool = nn.Sequential(*filter(None, [ 342 | nn.Conv2d(inplanes, inplanes, 3, stride=1 if aa_layer else 2, padding=1, bias=False), 343 | aa_layer(channels=inplanes, stride=2) if aa_layer else None, 344 | norm_layer(inplanes), 345 | act_layer(inplace=True) 346 | ])) 347 | else: 348 | if aa_layer is not None: 349 | self.maxpool = nn.Sequential(*[ 350 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1), 351 | aa_layer(channels=inplanes, stride=2)]) 352 | else: 353 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 354 | 355 | # Feature Blocks 356 | channels = [64, 128, 256, 512] 357 | stage_modules, stage_feature_info = make_blocks( 358 | block, channels, layers, inplanes, cardinality=cardinality, base_width=base_width, 359 | output_stride=output_stride, reduce_first=block_reduce_first, avg_down=avg_down, 360 | down_kernel_size=down_kernel_size, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, 361 | drop_block_rate=drop_block_rate, drop_path_rate=drop_path_rate, **block_args) 362 | for stage in stage_modules: 363 | self.add_module(*stage) # layer1, layer2, etc 364 | self.feature_info.extend(stage_feature_info) 365 | 366 | # Head (Pooling and Classifier) 367 | self.num_features = 512 * block.expansion 368 | self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) 369 | 370 | self.init_weights(zero_init_last_bn=zero_init_last_bn) 371 | 372 | def init_weights(self, 
zero_init_last_bn=True): 373 | for n, m in self.named_modules(): 374 | if isinstance(m, nn.Conv2d): 375 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 376 | elif isinstance(m, nn.BatchNorm2d): 377 | nn.init.ones_(m.weight) 378 | nn.init.zeros_(m.bias) 379 | if zero_init_last_bn: 380 | for m in self.modules(): 381 | if hasattr(m, 'zero_init_last_bn'): 382 | m.zero_init_last_bn() 383 | 384 | def get_classifier(self): 385 | return self.fc 386 | 387 | def reset_classifier(self, num_classes, global_pool='avg'): 388 | self.num_classes = num_classes 389 | self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) 390 | 391 | def forward_features(self, x): 392 | x = self.conv1(x) 393 | x = self.bn1(x) 394 | x = self.act1(x) 395 | x = self.maxpool(x) 396 | 397 | x = self.layer1(x) 398 | x = self.layer2(x) 399 | x = self.layer3(x) 400 | x = self.layer4(x) 401 | return x 402 | 403 | def forward(self, x): 404 | x = self.forward_features(x) 405 | x = self.global_pool(x) 406 | if self.drop_rate: 407 | x = F.dropout(x, p=float(self.drop_rate), training=self.training) 408 | x = self.fc(x) 409 | return x 410 | 411 | 412 | def resnet18(pretrained=False, **kwargs): 413 | """Constructs a ResNet-18 model. 414 | 415 | Args: 416 | pretrained (bool): If True, returns a model pre-trained on ImageNet 417 | """ 418 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 419 | if pretrained: 420 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 421 | model_path = './initmodel/resnet18-5c106cde.pth' 422 | model.load_state_dict(torch.load(model_path), strict=False) 423 | return model 424 | 425 | 426 | def resnet34(pretrained=False, **kwargs): 427 | """Constructs a ResNet-34 model. 428 | 429 | Args: 430 | pretrained (bool): If True, returns a model pre-trained on ImageNet 431 | """ 432 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 433 | if pretrained: 434 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 435 | model_path = './initmodel/resnet34-333f7ec4.pth' 436 | model.load_state_dict(torch.load(model_path), strict=False) 437 | return model 438 | 439 | 440 | def resnet34d(pretrained=False, **kwargs): 441 | """Constructs a ResNet-34 model. 442 | 443 | Args: 444 | pretrained (bool): If True, returns a model pre-trained on ImageNet 445 | """ 446 | model = ResNet(block=BasicBlock, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) 447 | if pretrained: 448 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 449 | model_path = './initmodel/resnet34d_ra2-f8dcfcaf.pth' 450 | model.load_state_dict(torch.load(model_path), strict=True) 451 | return model 452 | 453 | 454 | def resnet50(pretrained=False, **kwargs): 455 | """Constructs a ResNet-50 model. 456 | 457 | Args: 458 | pretrained (bool): If True, returns a model pre-trained on ImageNet 459 | """ 460 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 461 | if pretrained: 462 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 463 | # model_path = './initmodel/resnet50_v2.pth' 464 | model_path = './initmodel/resnet50-19c8e357.pth' 465 | model.load_state_dict(torch.load(model_path), strict=False) 466 | return model 467 | 468 | 469 | def resnet101(pretrained=False, **kwargs): 470 | """Constructs a ResNet-101 model. 
471 | 472 | Args: 473 | pretrained (bool): If True, returns a model pre-trained on ImageNet 474 | """ 475 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 476 | if pretrained: 477 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 478 | model_path = './initmodel/resnet101_v2.pth' 479 | model.load_state_dict(torch.load(model_path), strict=False) 480 | return model 481 | 482 | 483 | def resnet152(pretrained=False, **kwargs): 484 | """Constructs a ResNet-152 model. 485 | 486 | Args: 487 | pretrained (bool): If True, returns a model pre-trained on ImageNet 488 | """ 489 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 490 | if pretrained: 491 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 492 | model_path = './initmodel/resnet152_v2.pth' 493 | model.load_state_dict(torch.load(model_path), strict=False) 494 | return model 495 | -------------------------------------------------------------------------------- /models/resnet_mink.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu). 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 7 | # of the Software, and to permit persons to whom the Software is furnished to do 8 | # so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 20 | # 21 | # Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural 22 | # Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part 23 | # of the code. 
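# Usage sketch, not part of the original file (coords_np and colors_np are hypothetical
# arrays: (N, 3) integer voxel coordinates and (N, 3) features): the sparse ResNets below
# consume an ME.SparseTensor and return one score row per sample after global pooling.
#
#   net = ResNet14(in_channels=3, out_channels=20, D=3)              # e.g. 20 ScanNet classes
#   coords, feats = ME.utils.sparse_collate([coords_np], [colors_np])
#   out = net(ME.SparseTensor(feats.float(), coords))                # SparseTensor, one row per sample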
24 | import torch.nn as nn 25 | 26 | import MinkowskiEngine as ME 27 | from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck 28 | 29 | 30 | class ResNetBase(nn.Module): 31 | BLOCK = None 32 | LAYERS = () 33 | INIT_DIM = 64 34 | PLANES = (64, 128, 256, 512) 35 | 36 | def __init__(self, in_channels, out_channels, D=3): 37 | nn.Module.__init__(self) 38 | self.D = D 39 | assert self.BLOCK is not None 40 | 41 | self.network_initialization(in_channels, out_channels, D) 42 | self.weight_initialization() 43 | 44 | def network_initialization(self, in_channels, out_channels, D): 45 | 46 | self.inplanes = self.INIT_DIM 47 | self.conv1 = ME.MinkowskiConvolution( 48 | in_channels, self.inplanes, kernel_size=5, stride=2, dimension=D) 49 | 50 | self.bn1 = ME.MinkowskiBatchNorm(self.inplanes) 51 | self.relu = ME.MinkowskiReLU(inplace=True) 52 | 53 | self.pool = ME.MinkowskiAvgPooling(kernel_size=2, stride=2, dimension=D) 54 | 55 | self.layer1 = self._make_layer( 56 | self.BLOCK, self.PLANES[0], self.LAYERS[0], stride=2) 57 | self.layer2 = self._make_layer( 58 | self.BLOCK, self.PLANES[1], self.LAYERS[1], stride=2) 59 | self.layer3 = self._make_layer( 60 | self.BLOCK, self.PLANES[2], self.LAYERS[2], stride=2) 61 | self.layer4 = self._make_layer( 62 | self.BLOCK, self.PLANES[3], self.LAYERS[3], stride=2) 63 | 64 | self.conv5 = ME.MinkowskiConvolution( 65 | self.inplanes, self.inplanes, kernel_size=3, stride=3, dimension=D) 66 | self.bn5 = ME.MinkowskiBatchNorm(self.inplanes) 67 | 68 | self.glob_avg = ME.MinkowskiGlobalMaxPooling(dimension=D) 69 | 70 | self.final = ME.MinkowskiLinear(self.inplanes, out_channels, bias=True) 71 | 72 | def weight_initialization(self): 73 | for m in self.modules(): 74 | if isinstance(m, ME.MinkowskiConvolution): 75 | ME.utils.kaiming_normal_(m.kernel, mode='fan_out', nonlinearity='relu') 76 | 77 | if isinstance(m, ME.MinkowskiBatchNorm) and m.bn.weight is not None and m.bn.bias is not None: 78 | nn.init.constant_(m.bn.weight, 1) 79 | nn.init.constant_(m.bn.bias, 0) 80 | 81 | def _make_layer(self, 82 | block, 83 | planes, 84 | blocks, 85 | stride=1, 86 | dilation=1, 87 | bn_momentum=0.1): 88 | downsample = None 89 | if stride != 1 or self.inplanes != planes * block.expansion: 90 | downsample = nn.Sequential( 91 | ME.MinkowskiConvolution( 92 | self.inplanes, 93 | planes * block.expansion, 94 | kernel_size=1, 95 | stride=stride, 96 | dimension=self.D), 97 | ME.MinkowskiBatchNorm(planes * block.expansion)) 98 | layers = [] 99 | layers.append( 100 | block( 101 | self.inplanes, 102 | planes, 103 | stride=stride, 104 | dilation=dilation, 105 | downsample=downsample, 106 | dimension=self.D)) 107 | self.inplanes = planes * block.expansion 108 | for i in range(1, blocks): 109 | layers.append( 110 | block( 111 | self.inplanes, 112 | planes, 113 | stride=1, 114 | dilation=dilation, 115 | dimension=self.D)) 116 | 117 | return nn.Sequential(*layers) 118 | 119 | def forward(self, x): 120 | x = self.conv1(x) 121 | x = self.bn1(x) 122 | x = self.relu(x) 123 | x = self.pool(x) 124 | 125 | x = self.layer1(x) 126 | x = self.layer2(x) 127 | x = self.layer3(x) 128 | x = self.layer4(x) 129 | 130 | x = self.conv5(x) 131 | x = self.bn5(x) 132 | x = self.relu(x) 133 | 134 | x = self.glob_avg(x) 135 | return self.final(x) 136 | 137 | 138 | class ResNet14(ResNetBase): 139 | BLOCK = BasicBlock 140 | LAYERS = (1, 1, 1, 1) 141 | 142 | 143 | class ResNet18(ResNetBase): 144 | BLOCK = BasicBlock 145 | LAYERS = (2, 2, 2, 2) 146 | 147 | 148 | class ResNet34(ResNetBase): 149 | BLOCK = BasicBlock 
150 | LAYERS = (3, 4, 6, 3) 151 | 152 | 153 | class ResNet50(ResNetBase): 154 | BLOCK = Bottleneck 155 | LAYERS = (3, 4, 6, 3) 156 | 157 | 158 | class ResNet101(ResNetBase): 159 | BLOCK = Bottleneck 160 | LAYERS = (3, 4, 23, 3) 161 | -------------------------------------------------------------------------------- /models/semaffinet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from collections import OrderedDict 3 | 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | import MinkowskiEngine as ME 9 | 10 | from models.unet_2d import ResUnet as model2D 11 | from models.unet_3d import mink_unet as model3D 12 | from models.bpm import Linking 13 | from models.transformer_utils.transformer_predictor import TransformerPredictor, MLP 14 | 15 | 16 | def state_dict_remove_moudle(state_dict): 17 | new_state_dict = OrderedDict() 18 | for k, v in state_dict.items(): 19 | # name = k[7:] # remove 'module.' of dataparallel 20 | name = k.replace('module.', '') 21 | new_state_dict[name] = v 22 | return new_state_dict 23 | 24 | 25 | def constructor3d(**kwargs): 26 | model = model3D(**kwargs) 27 | # model = model.cuda() 28 | return model 29 | 30 | 31 | def constructor2d(**kwargs): 32 | model = model2D(**kwargs) 33 | # model = model.cuda() 34 | return model 35 | 36 | 37 | class SemAffiNet(nn.Module): 38 | 39 | def __init__(self, cfg=None): 40 | super(SemAffiNet, self).__init__() 41 | self.viewNum = cfg.viewNum 42 | self.tr_dim = cfg.hidden_dim 43 | self.mask_dim = cfg.mask_dim 44 | # 2D 45 | net2d = constructor2d(layers=cfg.layers_2d, classes=cfg.classes, out_channels=self.mask_dim) 46 | self.layer0_2d = net2d.layer0 47 | self.layer1_2d = net2d.layer1 48 | self.layer2_2d = net2d.layer2 49 | self.layer3_2d = net2d.layer3 50 | self.layer4_2d = net2d.layer4 51 | 52 | self.class_convfeat_2d = net2d.class_convfeat 53 | self.relu_2d = net2d.relu 54 | 55 | self.up4_2d = net2d.up4 56 | self.class_conv4_2d = net2d.class_conv4 57 | self.bn4_2d = net2d.bn4 58 | self.sat_w4_2d = MLP(self.tr_dim, self.tr_dim, self.mask_dim, 5) 59 | self.sat_b4_2d = MLP(self.tr_dim, self.tr_dim, self.mask_dim, 5) 60 | self.delayer4_2d = net2d.delayer4 61 | 62 | self.up3_2d = net2d.up3 63 | self.class_conv3_2d = net2d.class_conv3 64 | self.bn3_2d = net2d.bn3 65 | self.sat_w3_2d = MLP(self.tr_dim, self.tr_dim, self.mask_dim, 5) 66 | self.sat_b3_2d = MLP(self.tr_dim, self.tr_dim, self.mask_dim, 5) 67 | self.delayer3_2d = net2d.delayer3 68 | 69 | self.up2_2d = net2d.up2 70 | self.class_conv2_2d = net2d.class_conv2 71 | self.bn2_2d = net2d.bn2 72 | self.sat_w2_2d = MLP(self.tr_dim, self.tr_dim, self.mask_dim, 5) 73 | self.sat_b2_2d = MLP(self.tr_dim, self.tr_dim, self.mask_dim, 5) 74 | self.delayer2_2d = net2d.delayer2 75 | 76 | self.cls_2d = net2d.cls 77 | self.num_layers_2d = cfg.layers_2d 78 | 79 | if cfg.layers_2d >= 50: 80 | self.down_layer1_2d = nn.Conv2d(256, 64, 1) 81 | self.down_layer2_2d = nn.Conv2d(512, 128, 1) 82 | self.down_layer3_2d = nn.Conv2d(1024, 256, 1) 83 | self.down_layer4_2d = nn.Conv2d(2048, 512, 1) 84 | 85 | # 3D 86 | net3d = constructor3d(in_channels=3, out_channels=cfg.mask_dim, D=3, arch=cfg.arch_3d) 87 | self.layer0_3d = nn.Sequential(net3d.conv0p1s1, net3d.bn0, net3d.relu) 88 | self.layer1_3d = nn.Sequential(net3d.conv1p1s2, net3d.bn1, net3d.relu, net3d.block1) 89 | self.layer2_3d = nn.Sequential(net3d.conv2p2s2, net3d.bn2, net3d.relu, net3d.block2) 90 | self.layer3_3d = nn.Sequential(net3d.conv3p4s2, net3d.bn3, 
net3d.relu, net3d.block3) 91 | self.layer4_3d = nn.Sequential(net3d.conv4p8s2, net3d.bn4, net3d.relu, net3d.block4) 92 | 93 | self.class_convfeat = net3d.class_convfeat 94 | self.relu_3d = net3d.relu 95 | out_channels = net3d.out_channels 96 | 97 | self.convtr4p16s2 = net3d.convtr4p16s2 98 | self.class_conv4 = net3d.class_conv4 99 | self.bntr4 = net3d.bntr4 100 | self.sat_w4 = MLP(self.tr_dim, self.tr_dim, out_channels, 5) 101 | self.sat_b4 = MLP(self.tr_dim, self.tr_dim, out_channels, 5) 102 | self.block5 = net3d.block5 103 | 104 | self.convtr5p8s2 = net3d.convtr5p8s2 105 | self.class_conv5 = net3d.class_conv5 106 | self.bntr5 = net3d.bntr5 107 | self.sat_w5 = MLP(self.tr_dim, self.tr_dim, out_channels, 5) 108 | self.sat_b5 = MLP(self.tr_dim, self.tr_dim, out_channels, 5) 109 | self.block6 = net3d.block6 110 | 111 | self.convtr6p4s2 = net3d.convtr6p4s2 112 | self.class_conv6 = net3d.class_conv6 113 | self.bntr6 = net3d.bntr6 114 | self.sat_w6 = MLP(self.tr_dim, self.tr_dim, out_channels, 5) 115 | self.sat_b6 = MLP(self.tr_dim, self.tr_dim, out_channels, 5) 116 | self.block7 = net3d.block7 117 | 118 | self.convtr7p2s2 = net3d.convtr7p2s2 119 | self.bntr7 = net3d.bntr7 120 | self.block8 = net3d.block8 121 | self.cls_3d = net3d.final 122 | 123 | self.predictor = TransformerPredictor(config=cfg, in_channels_3d=net3d.PLANES[3], in_channels_2d=512, 124 | num_classes=cfg.classes, D=3) 125 | 126 | # Linker 127 | self.linker_p2 = Linking(96, net3d.PLANES[6], viewNum=self.viewNum) 128 | self.linker_p3 = Linking(128, net3d.PLANES[5], viewNum=self.viewNum) 129 | self.linker_p4 = Linking(256, net3d.PLANES[4], viewNum=self.viewNum) 130 | self.linker_p5 = Linking(512, net3d.PLANES[3], viewNum=self.viewNum) 131 | 132 | def SAT(self, pred_class, coordinates, query_weight, query_bias): 133 | pred_class = torch.sigmoid(pred_class) 134 | query_weight = torch.abs(query_weight) 135 | weight_list = [] 136 | bias_list = [] 137 | for b in coordinates[:, 0].unique(): 138 | pred_class_b = pred_class[coordinates[:, 0] == b] 139 | weight = torch.einsum("nm,mc->nc", pred_class_b, query_weight[b]) 140 | bias = torch.einsum("nm,mc->nc", pred_class_b, query_bias[b]) 141 | weight_list.append(weight) 142 | bias_list.append(bias) 143 | 144 | return torch.cat(weight_list, dim=0), torch.cat(bias_list, dim=0) 145 | 146 | def mul_mask_voxel(self, mask_embed, voxel_embed): 147 | outputs_seg_lvl = [] 148 | outputs_coords_lvl = [] 149 | for i in range(len(voxel_embed.C[:, 0].unique())): 150 | idx_i = voxel_embed.C[:, 0] == i 151 | seg_mask = torch.einsum("qc,cn->qn", mask_embed[i], voxel_embed.F[idx_i].transpose(0, 1).contiguous()) 152 | outputs_seg_lvl.append(seg_mask.transpose(0, 1).contiguous()) 153 | outputs_coords_lvl.append(voxel_embed.C[idx_i]) 154 | outputs_seg_lvl = torch.cat(outputs_seg_lvl, dim=0) 155 | outputs_coords_lvl = torch.cat(outputs_coords_lvl, dim=0) 156 | assert torch.equal(voxel_embed.C, outputs_coords_lvl) 157 | return outputs_seg_lvl 158 | 159 | def forward(self, sparse_3d, images, links, targets_mid_3d=None): 160 | """ 161 | images:BCHWV 162 | """ 163 | ######################## 164 | ##### Encoder Part ##### 165 | ######################## 166 | 167 | ### 2D feature extract 168 | x_size = images.size() 169 | h, w = x_size[2], x_size[3] 170 | data_2d = images.permute(4, 0, 1, 2, 3).contiguous() # VBCHW 171 | data_2d = data_2d.view(x_size[0] * x_size[4], x_size[1], x_size[2], x_size[3]) 172 | x = self.layer0_2d(data_2d) # 1/4 173 | x2 = self.layer1_2d(x) # 1/4 174 | x3 = self.layer2_2d(x2) # 1/8 175 | x4 = 
self.layer3_2d(x3) # 1/16 176 | x5 = self.layer4_2d(x4) # 1/32 177 | 178 | if self.num_layers_2d >= 50: 179 | x2 = self.down_layer1_2d(x2) 180 | x3 = self.down_layer2_2d(x3) 181 | x4 = self.down_layer3_2d(x4) 182 | x5 = self.down_layer4_2d(x5) 183 | 184 | ### 3D feature extract 185 | out_p1 = self.layer0_3d(sparse_3d) 186 | out_b1p2 = self.layer1_3d(out_p1) 187 | out_b2p4 = self.layer2_3d(out_b1p2) 188 | out_b3p8 = self.layer3_3d(out_b2p4) 189 | out_b4p16 = self.layer4_3d(out_b3p8) # corresponding to FPN p5 190 | 191 | mask_embed_2d, mask_embed, hs_sat_2d, hs_sat = self.predictor(x5, out_b4p16) 192 | predictions, predictions_2d = {}, {} 193 | b, n_c, d = hs_sat[-1].shape 194 | b2, n_p, d = hs_sat_2d[-1].shape 195 | outputs_seg_list = [] 196 | outputs_seg_list_2d = [] 197 | 198 | ######################## 199 | ##### Decoder Part ##### 200 | ######################## 201 | 202 | ### Class prediction @ p5 203 | 204 | p5 = self.class_convfeat_2d(x5) 205 | seg_p5_2d = torch.einsum('bqc,bchw->bqhw', mask_embed_2d, p5) 206 | outputs_seg_list_2d.append(seg_p5_2d) 207 | 208 | outfeat = self.class_convfeat(out_b4p16) 209 | seg_feat_3d = self.mul_mask_voxel(mask_embed, outfeat) 210 | outputs_seg_list.append(seg_feat_3d) 211 | 212 | ### Linking @ p5 213 | V_B, C, H, W = x5.shape 214 | links_current_level = links.clone() 215 | links_current_level[:, 1:3, :] = ((H - 1.) / (h - 1.) * links_current_level[:, 1:3, :].float()).int() 216 | fused_3d_p5, fused_2d_p5 = self.linker_p5(x5, out_b4p16, links_current_level, init_3d_data=sparse_3d) 217 | 218 | ### p5->p4, Class prediction @ p4, AdaBN @ p4, Block @ p4 219 | 220 | p4 = self.up4_2d(F.interpolate(fused_2d_p5, x4.shape[-2:], mode='bilinear', align_corners=True)) 221 | p4 = self.class_conv4_2d(p4) 222 | seg_p4_2d = torch.einsum('bqc,bchw->bqhw', mask_embed_2d, p4) 223 | outputs_seg_list_2d.append(seg_p4_2d) 224 | 225 | p4_inst = self.bn4_2d(p4) 226 | hs_sat_w4_2d = self.sat_w4_2d(hs_sat_2d[-4].view(b2*n_p, d)).view(b2, n_p, -1) 227 | hs_sat_b4_2d = self.sat_b4_2d(hs_sat_2d[-4].view(b2*n_p, d)).view(b2, n_p, -1) 228 | weight4_2d = torch.einsum('bcd,bchw->bdhw', torch.abs(hs_sat_w4_2d), torch.sigmoid(seg_p4_2d)) 229 | bias4_2d = torch.einsum('bcd,bchw->bdhw', hs_sat_b4_2d, torch.sigmoid(seg_p4_2d)) 230 | p4_instnorm = p4_inst * weight4_2d + bias4_2d 231 | feat_p4 = self.relu_2d(p4_instnorm) 232 | p4 = torch.cat([feat_p4, x4], dim=1) 233 | p4 = self.delayer4_2d(p4) 234 | 235 | out4 = self.convtr4p16s2(fused_3d_p5) 236 | out4 = self.class_conv4(out4) 237 | seg_out4_3d = self.mul_mask_voxel(mask_embed, out4) 238 | outputs_seg_list.append(seg_out4_3d) 239 | 240 | out4_inst = self.bntr4(out4) 241 | hs_sat_w4 = self.sat_w4(hs_sat[-4].view(b*n_c, -1)).view(b, n_c, -1) 242 | hs_sat_b4 = self.sat_b4(hs_sat[-4].view(b*n_c, -1)).view(b, n_c, -1) 243 | weight, bias = self.SAT(seg_out4_3d, out4.C, hs_sat_w4, hs_sat_b4) 244 | out4_instnorm = out4_inst.F * weight + bias 245 | out4 = ME.SparseTensor(features=out4_instnorm, coordinate_map_key=out4.coordinate_map_key, coordinate_manager=out4.coordinate_manager) 246 | feat_out4 = self.relu_3d(out4) 247 | out_cat4 = ME.cat(feat_out4, out_b3p8) 248 | out4 = self.block5(out_cat4) 249 | 250 | ### Linking @ p4 251 | V_B, C, H, W = p4.shape 252 | links_current_level = links.clone() 253 | links_current_level[:, 1:3, :] = ((H - 1.) / (h - 1.) 
* links_current_level[:, 1:3, :].float()).int() 254 | fused_3d_p4, fused_2d_p4 = self.linker_p4(p4, out4, links_current_level, init_3d_data=sparse_3d) 255 | 256 | ### p4->p3, Class prediction @ p3, AdaBN @ p3, Block @ p3 257 | p3 = self.up3_2d(F.interpolate(fused_2d_p4, x3.shape[-2:], mode='bilinear', align_corners=True)) 258 | p3 = self.class_conv3_2d(p3) 259 | seg_p3_2d = torch.einsum('bqc,bchw->bqhw', mask_embed_2d, p3) 260 | outputs_seg_list_2d.append(seg_p3_2d) 261 | 262 | p3_inst = self.bn3_2d(p3) 263 | hs_sat_w3_2d = self.sat_w3_2d(hs_sat_2d[-3].view(b2*n_p, d)).view(b2, n_p, -1) 264 | hs_sat_b3_2d = self.sat_b3_2d(hs_sat_2d[-3].view(b2*n_p, d)).view(b2, n_p, -1) 265 | weight3_2d = torch.einsum('bcd,bchw->bdhw', torch.abs(hs_sat_w3_2d), torch.sigmoid(seg_p3_2d)) 266 | bias3_2d = torch.einsum('bcd,bchw->bdhw', hs_sat_b3_2d, torch.sigmoid(seg_p3_2d)) 267 | p3_instnorm = p3_inst * weight3_2d + bias3_2d 268 | feat_p3 = self.relu_2d(p3_instnorm) 269 | p3 = torch.cat([feat_p3, x3], dim=1) 270 | p3 = self.delayer3_2d(p3) 271 | 272 | out5 = self.convtr5p8s2(fused_3d_p4) 273 | out5 = self.class_conv5(out5) 274 | seg_out5_3d = self.mul_mask_voxel(mask_embed, out5) 275 | outputs_seg_list.append(seg_out5_3d) 276 | 277 | out5_inst = self.bntr5(out5) 278 | hs_sat_w5 = self.sat_w5(hs_sat[-3].view(b*n_c, -1)).view(b, n_c, -1) 279 | hs_sat_b5 = self.sat_b5(hs_sat[-3].view(b*n_c, -1)).view(b, n_c, -1) 280 | weight, bias = self.SAT(seg_out5_3d, out5.C, hs_sat_w5, hs_sat_b5) 281 | out5_instnorm = out5_inst.F * weight + bias 282 | out5 = ME.SparseTensor(features=out5_instnorm, coordinate_map_key=out5.coordinate_map_key, coordinate_manager=out5.coordinate_manager) 283 | feat_out5 = self.relu_3d(out5) 284 | out_cat5 = ME.cat(feat_out5, out_b2p4) 285 | out5 = self.block6(out_cat5) 286 | 287 | ### Linking @ p3 288 | V_B, C, H, W = p3.shape 289 | links_current_level = links.clone() 290 | links_current_level[:, 1:3, :] = ((H - 1.) / (h - 1.) 
* links_current_level[:, 1:3, :].float()).int() 291 | fused_3d_p3, fused_2d_p3 = self.linker_p3(p3, out5, links_current_level, init_3d_data=sparse_3d) 292 | 293 | ### p3->p2, Class prediction @ p2, AdaBN @ p2, Block @ p2 294 | p2 = self.up2_2d(F.interpolate(fused_2d_p3, x2.shape[-2:], mode='bilinear', align_corners=True)) 295 | p2 = self.class_conv2_2d(p2) 296 | seg_p2_2d = torch.einsum('bqc,bchw->bqhw', mask_embed_2d, p2) 297 | outputs_seg_list_2d.append(seg_p2_2d) 298 | 299 | p2_inst = self.bn2_2d(p2) 300 | hs_sat_w2_2d = self.sat_w2_2d(hs_sat_2d[-2].view(b2*n_p, d)).view(b2, n_p, -1) 301 | hs_sat_b2_2d = self.sat_b2_2d(hs_sat_2d[-2].view(b2*n_p, d)).view(b2, n_p, -1) 302 | weight2_2d = torch.einsum('bcd,bchw->bdhw', torch.abs(hs_sat_w2_2d), torch.sigmoid(seg_p2_2d)) 303 | bias2_2d = torch.einsum('bcd,bchw->bdhw', hs_sat_b2_2d, torch.sigmoid(seg_p2_2d)) 304 | p2_instnorm = p2_inst * weight2_2d + bias2_2d 305 | feat_p2 = self.relu_2d(p2_instnorm) 306 | p2 = torch.cat([feat_p2, x2], dim=1) 307 | p2 = self.delayer2_2d(p2) 308 | 309 | out6 = self.convtr6p4s2(fused_3d_p3) 310 | out6 = self.class_conv6(out6) 311 | seg_out6_3d = self.mul_mask_voxel(mask_embed, out6) 312 | outputs_seg_list.append(seg_out6_3d) 313 | 314 | out6_inst = self.bntr6(out6) 315 | hs_sat_w6 = self.sat_w6(hs_sat[-2].view(b*n_c, -1)).view(b, n_c, -1) 316 | hs_sat_b6 = self.sat_b6(hs_sat[-2].view(b*n_c, -1)).view(b, n_c, -1) 317 | weight, bias = self.SAT(seg_out6_3d, out6.C, hs_sat_w6, hs_sat_b6) 318 | out6_instnorm = out6_inst.F * weight + bias 319 | out6 = ME.SparseTensor(features=out6_instnorm, coordinate_map_key=out6.coordinate_map_key, coordinate_manager=out6.coordinate_manager) 320 | feat_out6 = self.relu_3d(out6) 321 | out_cat6 = ME.cat(feat_out6, out_b1p2) 322 | out6 = self.block7(out_cat6) 323 | 324 | # Linking @ p2 325 | V_B, C, H, W = p2.shape 326 | links_current_level = links.clone() 327 | links_current_level[:, 1:3, :] = ((H - 1.) / (h - 1.) 
* links_current_level[:, 1:3, :].float()).int() 328 | fused_3d_p2, fused_2d_p2 = self.linker_p2(p2, out6, links_current_level, init_3d_data=sparse_3d) 329 | 330 | # feat_3d = self.layer8_3d(ME.cat(fused_3d_p2, out_b1p2)) 331 | out7 = self.convtr7p2s2(fused_3d_p2) 332 | out7 = self.bntr7(out7) 333 | feat_out7 = self.relu_3d(out7) 334 | out_cat7 = ME.cat(feat_out7, out_p1) 335 | out7 = self.block8(out_cat7) 336 | 337 | # Res 338 | # pdb.set_trace() 339 | res_2d = self.cls_2d(fused_2d_p2) 340 | res_2d = F.interpolate(res_2d, size=(h, w), mode='bilinear', align_corners=True) 341 | outputs_seg8_2d = torch.einsum('bqc,bchw->bqhw', mask_embed_2d, res_2d) 342 | outputs_seg_list_2d.append(outputs_seg8_2d) 343 | 344 | res_3d = self.cls_3d(out7) 345 | outputs_seg8 = self.mul_mask_voxel(mask_embed, res_3d) 346 | outputs_seg_list.append(outputs_seg8) 347 | 348 | predictions['pred_masks'] = outputs_seg_list[-1] 349 | V_B, C, H, W = outputs_seg8_2d.shape 350 | predictions_2d['pred_masks'] = outputs_seg_list_2d[-1].view(self.viewNum, b, C, H, W).permute(1, 2, 3, 4, 0) 351 | predictions_2d['aux_pred'] = outputs_seg_list_2d[:-1] 352 | 353 | if targets_mid_3d is not None: 354 | predictions["aux_pred"] = [] 355 | predictions["aux_gt"] = [] 356 | outputs_coords_list = [out_b4p16.C, out_b3p8.C, out_b2p4.C, out_b1p2.C] 357 | for i in range(len(targets_mid_3d)): 358 | outputs_segfeat_rearrange, targetsfeat_rearrange = self.get_rearrange_targets(outputs_seg_list[i], outputs_coords_list[i], targets_mid_3d[i]) 359 | predictions["aux_pred"].append(outputs_segfeat_rearrange) 360 | predictions["aux_gt"].append(targetsfeat_rearrange) 361 | 362 | return predictions, predictions_2d 363 | 364 | def get_rearrange_targets(self, seg_F, seg_C, target_class): 365 | assert len(seg_F) == len(target_class) 366 | base = seg_C.max() + 1 367 | base2dec_gt = target_class.C[:, 0] * (base ** 3) + target_class.C[:, 1] * (base ** 2) + target_class.C[:, 2] * base + target_class.C[:, 3] 368 | base2dec_pred = seg_C[:, 0] * (base ** 3) + seg_C[:, 1] * (base ** 2) + seg_C[:, 2] * base + seg_C[:, 3] 369 | 370 | _, idx_gt = torch.sort(base2dec_gt) 371 | _, idx_pred = torch.sort(base2dec_pred) 372 | 373 | seg_rearrange_F = seg_F[idx_pred] 374 | seg_rearrange_C = seg_C[idx_pred] 375 | target_class_F = target_class.F[idx_gt] 376 | target_class_C = target_class.C[idx_gt] 377 | assert torch.equal(target_class_C, seg_rearrange_C) 378 | 379 | return seg_rearrange_F, target_class_F 380 | -------------------------------------------------------------------------------- /models/shadownet_2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ResNetShadow(nn.Module): 6 | def __init__(self, channels): 7 | super(ResNetShadow, self).__init__() 8 | self.conv0 = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1, groups=channels, bias=False) 9 | self.maxpool0 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 10 | 11 | self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1, groups=channels, bias=False) 12 | self.conv3 = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1, groups=channels, bias=False) 13 | self.conv4 = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1, groups=channels, bias=False) 14 | 15 | self.weight_initialization() 16 | 17 | def weight_initialization(self): 18 | for m in self.modules(): 19 | if isinstance(m, nn.Conv2d): 20 | m.weight = nn.Parameter(torch.ones_like(m.weight)) 21 | 22 | 
def forward(self, x): 23 | x0 = self.conv0(x) 24 | x0 = self.maxpool0(x0) 25 | 26 | x2 = self.conv2(x0) 27 | x3 = self.conv3(x2) 28 | x4 = self.conv4(x3) 29 | 30 | return [x4, x3, x2, x0] 31 | -------------------------------------------------------------------------------- /models/shadownet_3d.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu). All Rights Reserved. 2 | # 3 | # Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural 4 | # Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part of 5 | # the code. 6 | from models.resnet_mink import ResNetBase 7 | from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck 8 | from models.me_common import ConvType, conv_dw 9 | import MinkowskiEngine as ME 10 | 11 | import torch 12 | from torch import nn as nn 13 | 14 | 15 | class Res16UNetShadow(ResNetBase): 16 | BLOCK = None 17 | PLANES = (32, 64, 128, 256, 256, 256, 256, 256) 18 | DILATIONS = (1, 1, 1, 1, 1, 1, 1, 1) 19 | LAYERS = (2, 2, 2, 2, 2, 2, 2, 2) 20 | INIT_DIM = 32 21 | OUT_PIXEL_DIST = 1 22 | NON_BLOCK_CONV_TYPE = ConvType.SPATIAL_HYPERCUBE 23 | CONV_TYPE = ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS 24 | 25 | # To use the model, must call initialize_coords before forward pass. 26 | # Once data is processed, call clear to reset the model before calling initialize_coords 27 | def __init__(self, in_channels, D=3, **kwargs): 28 | super(Res16UNetShadow, self).__init__(in_channels, in_channels, D) 29 | 30 | def weight_initialization(self): 31 | for m in self.modules(): 32 | if isinstance(m, ME.MinkowskiChannelwiseConvolution): 33 | m.kernel = nn.Parameter(torch.ones_like(m.kernel)) 34 | 35 | def network_initialization(self, in_channels, out_channels, D): 36 | # Setup net_metadata 37 | 38 | def space_n_time_m(n, m): 39 | return n if D == 3 else [n, n, n, m] 40 | 41 | if D == 4: 42 | self.OUT_PIXEL_DIST = space_n_time_m(self.OUT_PIXEL_DIST, 1) 43 | 44 | # Output of the first conv concated to conv6 45 | self.inplanes = self.INIT_DIM 46 | 47 | self.conv1p1s2 = conv_dw( 48 | in_channels, 49 | kernel_size=space_n_time_m(2, 1), 50 | stride=space_n_time_m(2, 1), 51 | dilation=1, 52 | conv_type=self.NON_BLOCK_CONV_TYPE, 53 | D=D) 54 | 55 | self.conv2p2s2 = conv_dw( 56 | in_channels, 57 | kernel_size=space_n_time_m(2, 1), 58 | stride=space_n_time_m(2, 1), 59 | dilation=1, 60 | conv_type=self.NON_BLOCK_CONV_TYPE, 61 | D=D) 62 | 63 | self.conv3p4s2 = conv_dw( 64 | in_channels, 65 | kernel_size=space_n_time_m(2, 1), 66 | stride=space_n_time_m(2, 1), 67 | dilation=1, 68 | conv_type=self.NON_BLOCK_CONV_TYPE, 69 | D=D) 70 | 71 | self.conv4p8s2 = conv_dw( 72 | in_channels, 73 | kernel_size=space_n_time_m(2, 1), 74 | stride=space_n_time_m(2, 1), 75 | dilation=1, 76 | conv_type=self.NON_BLOCK_CONV_TYPE, 77 | D=D) 78 | 79 | def forward(self, x): 80 | out_p1s2 = self.conv1p1s2(x) 81 | out_p2s2 = self.conv2p2s2(out_p1s2) 82 | out_p4s2 = self.conv3p4s2(out_p2s2) 83 | out_p8s2 = self.conv4p8s2(out_p4s2) 84 | 85 | return [out_p8s2, out_p4s2, out_p2s2, out_p1s2] 86 | 87 | 88 | class Res16UNet14(Res16UNetShadow): 89 | BLOCK = BasicBlock 90 | LAYERS = (1, 1, 1, 1, 1, 1, 1, 1) 91 | 92 | 93 | class Res16UNet18(Res16UNetShadow): 94 | BLOCK = BasicBlock 95 | LAYERS = (2, 2, 2, 2, 2, 2, 2, 2) 96 | 97 | 98 | class Res16UNet34(Res16UNetShadow): 99 | BLOCK = BasicBlock 100 | LAYERS = (2, 3, 4, 6, 2, 2, 2, 2) 101 | 102 | 103 | class Res16UNet50(Res16UNetShadow): 104 | BLOCK = 
Bottleneck 105 | LAYERS = (2, 3, 4, 6, 2, 2, 2, 2) 106 | 107 | 108 | class Res16UNet101(Res16UNetShadow): 109 | BLOCK = Bottleneck 110 | LAYERS = (2, 3, 4, 23, 2, 2, 2, 2) 111 | 112 | 113 | class Res16UNet14A(Res16UNet14): 114 | PLANES = (32, 64, 128, 256, 128, 128, 96, 96) 115 | 116 | 117 | class Res16UNet14A2(Res16UNet14A): 118 | LAYERS = (1, 1, 1, 1, 2, 2, 2, 2) 119 | 120 | 121 | class Res16UNet14B(Res16UNet14): 122 | PLANES = (32, 64, 128, 256, 128, 128, 128, 128) 123 | 124 | 125 | class Res16UNet14B2(Res16UNet14B): 126 | LAYERS = (1, 1, 1, 1, 2, 2, 2, 2) 127 | 128 | 129 | class Res16UNet14B3(Res16UNet14B): 130 | LAYERS = (2, 2, 2, 2, 1, 1, 1, 1) 131 | 132 | 133 | class Res16UNet14C(Res16UNet14): 134 | PLANES = (32, 64, 128, 256, 192, 192, 128, 128) 135 | 136 | 137 | class Res16UNet14D(Res16UNet14): 138 | PLANES = (32, 64, 128, 256, 384, 384, 384, 384) 139 | 140 | 141 | class Res16UNet18A(Res16UNet18): 142 | PLANES = (32, 64, 128, 256, 128, 128, 96, 96) 143 | 144 | 145 | class Res16UNet18B(Res16UNet18): 146 | PLANES = (32, 64, 128, 256, 128, 128, 128, 128) 147 | 148 | 149 | class Res16UNet18D(Res16UNet18): 150 | PLANES = (32, 64, 128, 256, 384, 384, 384, 384) 151 | 152 | 153 | class Res16UNet18E(Res16UNet18): 154 | PLANES = (32, 64, 128, 256, 256, 128, 64, 64) 155 | 156 | 157 | class Res16UNet18F(Res16UNet18): 158 | PLANES = (32, 64, 128, 256, 256, 128, 64, 32) 159 | 160 | 161 | class Res16UNet34A(Res16UNet34): 162 | PLANES = (32, 64, 128, 256, 256, 128, 64, 64) 163 | 164 | 165 | class Res16UNet34B(Res16UNet34): 166 | PLANES = (32, 64, 128, 256, 256, 128, 64, 32) 167 | 168 | 169 | class Res16UNet34C(Res16UNet34): 170 | PLANES = (32, 64, 128, 256, 256, 128, 96, 96) -------------------------------------------------------------------------------- /models/transformer_utils/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def drop_path(x, drop_prob: float = 0., training: bool = False): 6 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 7 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 8 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 9 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 10 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 11 | 'survival rate' as the argument. 12 | """ 13 | if drop_prob == 0. or not training: 14 | return x 15 | keep_prob = 1 - drop_prob 16 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 17 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 18 | random_tensor.floor_() # binarize 19 | output = x.div(keep_prob) * random_tensor 20 | return output 21 | 22 | 23 | class DropPath(nn.Module): 24 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
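    Example with illustrative numbers: for drop_prob = 0.2, each sample's residual
    branch is zeroed with probability 0.2 during training, and the surviving samples
    are rescaled by 1 / 0.8 so the expected output keeps the same scale as the input.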
25 | """ 26 | def __init__(self, drop_prob=None): 27 | super(DropPath, self).__init__() 28 | self.drop_prob = drop_prob 29 | 30 | def forward(self, x): 31 | return drop_path(x, self.drop_prob, self.training) 32 | 33 | 34 | class Mlp(nn.Module): 35 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 36 | super().__init__() 37 | out_features = out_features or in_features 38 | hidden_features = hidden_features or in_features 39 | self.fc1 = nn.Linear(in_features, hidden_features) 40 | self.act = act_layer() 41 | self.fc2 = nn.Linear(hidden_features, out_features) 42 | self.drop = nn.Dropout(drop) 43 | 44 | def forward(self, x): 45 | x = self.fc1(x) 46 | x = self.act(x) 47 | x = self.drop(x) 48 | x = self.fc2(x) 49 | x = self.drop(x) 50 | return x 51 | 52 | 53 | class Attention(nn.Module): 54 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 55 | super().__init__() 56 | self.num_heads = num_heads 57 | head_dim = dim // num_heads 58 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 59 | self.scale = qk_scale or head_dim ** -0.5 60 | 61 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 62 | self.attn_drop = nn.Dropout(attn_drop) 63 | self.proj = nn.Linear(dim, dim) 64 | self.proj_drop = nn.Dropout(proj_drop) 65 | 66 | def forward(self, x): 67 | B, N, C = x.shape 68 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 69 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 70 | 71 | attn = (q @ k.transpose(-2, -1)) * self.scale 72 | attn = attn.softmax(dim=-1) 73 | attn = self.attn_drop(attn) 74 | 75 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 76 | x = self.proj(x) 77 | x = self.proj_drop(x) 78 | return x 79 | 80 | 81 | class CrossAttention(nn.Module): 82 | def __init__(self, dim, out_dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 83 | super().__init__() 84 | self.num_heads = num_heads 85 | self.dim = dim 86 | self.out_dim = out_dim 87 | head_dim = out_dim // num_heads 88 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 89 | self.scale = qk_scale or head_dim ** -0.5 90 | 91 | self.q_map = nn.Linear(dim, out_dim, bias=qkv_bias) 92 | self.k_map = nn.Linear(dim, out_dim, bias=qkv_bias) 93 | self.v_map = nn.Linear(dim, out_dim, bias=qkv_bias) 94 | self.attn_drop = nn.Dropout(attn_drop) 95 | 96 | self.proj = nn.Linear(out_dim, out_dim) 97 | self.proj_drop = nn.Dropout(proj_drop) 98 | 99 | def forward(self, q, v): 100 | B, N, _ = q.shape 101 | C = self.out_dim 102 | k = v 103 | NK = k.size(1) 104 | 105 | q = self.q_map(q).view(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 106 | k = self.k_map(k).view(B, NK, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 107 | v = self.v_map(v).view(B, NK, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 108 | 109 | attn = (q @ k.transpose(-2, -1)) * self.scale 110 | attn = attn.softmax(dim=-1) 111 | attn = self.attn_drop(attn) 112 | 113 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 114 | x = self.proj(x) 115 | x = self.proj_drop(x) 116 | return x 117 | 118 | 119 | class EncoderBlock(nn.Module): 120 | 121 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 122 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 123 | super().__init__() 124 | 
self.norm1 = norm_layer(dim) 125 | self.attn = Attention( 126 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 127 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 128 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 129 | self.norm2 = norm_layer(dim) 130 | mlp_hidden_dim = int(dim * mlp_ratio) 131 | 132 | self.knn_map = nn.Sequential( 133 | nn.Linear(dim * 2, dim), 134 | nn.LeakyReLU(negative_slope=0.2) 135 | ) 136 | 137 | self.merge_map = nn.Linear(dim * 2, dim) 138 | 139 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 140 | 141 | def forward(self, x): 142 | # x = x + self.drop_path(self.attn(self.norm1(x))) 143 | norm_x = self.norm1(x) 144 | x_1 = self.attn(norm_x) 145 | x = x + self.drop_path(x_1) 146 | x = x + self.drop_path(self.mlp(self.norm2(x))) 147 | return x 148 | 149 | 150 | class DecoderBlock(nn.Module): 151 | def __init__(self, dim, num_heads, dim_q=None, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 152 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 153 | super().__init__() 154 | self.norm1 = norm_layer(dim) 155 | self.self_attn = Attention( 156 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 157 | dim_q = dim_q or dim 158 | self.norm_q = norm_layer(dim_q) 159 | self.norm_v = norm_layer(dim) 160 | self.attn = CrossAttention( 161 | dim, dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 162 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 163 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 164 | self.norm2 = norm_layer(dim) 165 | mlp_hidden_dim = int(dim * mlp_ratio) 166 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 167 | 168 | self.knn_map = nn.Sequential( 169 | nn.Linear(dim * 2, dim), 170 | nn.LeakyReLU(negative_slope=0.2) 171 | ) 172 | 173 | self.merge_map = nn.Linear(dim * 2, dim) 174 | 175 | self.knn_map_cross = nn.Sequential( 176 | nn.Linear(dim * 2, dim), 177 | nn.LeakyReLU(negative_slope=0.2) 178 | ) 179 | 180 | self.merge_map_cross = nn.Linear(dim * 2, dim) 181 | 182 | def forward(self, q, v): 183 | # q = q + self.drop_path(self.self_attn(self.norm1(q))) 184 | norm_q = self.norm1(q) 185 | q_1 = self.self_attn(norm_q) 186 | q = q + self.drop_path(q_1) 187 | 188 | norm_q = self.norm_q(q) 189 | norm_v = self.norm_v(v) 190 | q_2 = self.attn(norm_q, norm_v) 191 | q = q + self.drop_path(q_2) 192 | 193 | # q = q + self.drop_path(self.attn(self.norm_q(q), self.norm_v(v))) 194 | q = q + self.drop_path(self.mlp(self.norm2(q))) 195 | return q 196 | 197 | 198 | class TransformerEncoder(nn.Module): 199 | def __init__(self, embed_dim=384, depth=3, num_heads=6, 200 | mlp_ratio=0.2, qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path=0.): 201 | super(TransformerEncoder, self).__init__() 202 | self.encoder = nn.ModuleList([ 203 | EncoderBlock( 204 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 205 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path) 206 | for _ in range(depth)]) 207 | 208 | def forward(self, x, pos): 209 | for i, blk in enumerate(self.encoder): 210 | x = blk(x + pos) 211 | 212 | return x 213 | 214 | 215 | class TransformerDecoder(nn.Module): 216 | def __init__(self, embed_dim=384, depth=4, num_heads=6, 217 | mlp_ratio=0.2, qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path=0.): 218 | super(TransformerDecoder, self).__init__() 219 | self.decoder = nn.ModuleList([ 220 | DecoderBlock( 221 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 222 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path) 223 | for _ in range(depth)]) 224 | 225 | def forward(self, x, q): 226 | qs = [] 227 | for blk in self.decoder: 228 | q = blk(q, x) 229 | qs.append(q) 230 | return qs -------------------------------------------------------------------------------- /models/transformer_utils/transformer_predictor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import MinkowskiEngine as ME 5 | 6 | from models.transformer_utils.transformer import TransformerEncoder, TransformerDecoder 7 | from timm.models.layers.weight_init import trunc_normal_ 8 | 9 | 10 | class TransformerPredictor(ME.MinkowskiNetwork): 11 | def __init__(self, config, in_channels_3d, in_channels_2d, num_classes, D=3): 12 | super(TransformerPredictor, self).__init__(D) 13 | 14 | self.num_classes = num_classes 15 | self.num_view = config.viewNum 16 | self.npb = config.num_point_batch 17 | self.hidden_dim = config.hidden_dim 18 | 19 | self.pos_embed_2d = nn.Parameter(torch.zeros(1, config.num_tokens_2d, self.hidden_dim)) 20 | self.pe_layer_3d = MLP(3, in_channels_3d, self.hidden_dim, num_layers=3) 21 | trunc_normal_(self.pos_embed_2d, std=.02) 22 | 23 | self.enc_share = TransformerEncoder( 24 | embed_dim=config.hidden_dim, 25 | depth=config.enc_layers, 26 | num_heads=config.nheads, 
27 | drop_rate=config.drop_rate, 28 | attn_drop_rate=config.attn_drop_rate, 29 | drop_path=config.drop_path 30 | ) 31 | 32 | self.dec_3d = TransformerDecoder( 33 | embed_dim=config.hidden_dim, 34 | depth=config.dec_layers, 35 | num_heads=config.nheads, 36 | drop_rate=config.drop_rate, 37 | attn_drop_rate=config.attn_drop_rate, 38 | drop_path=config.drop_path 39 | ) 40 | 41 | self.dec_2d = TransformerDecoder( 42 | embed_dim=config.hidden_dim, 43 | depth=config.dec_layers, 44 | num_heads=config.nheads, 45 | drop_rate=config.drop_rate, 46 | attn_drop_rate=config.attn_drop_rate, 47 | drop_path=config.drop_path 48 | ) 49 | 50 | self.query_embed_2d = nn.Embedding(num_classes, self.hidden_dim) 51 | self.query_embed_3d = nn.Embedding(num_classes, self.hidden_dim) 52 | 53 | self.input_proj_3d = nn.Conv1d(in_channels_3d, self.hidden_dim, kernel_size=1) 54 | c2_xavier_fill(self.input_proj_3d) 55 | 56 | self.input_proj_2d = nn.Conv1d(in_channels_2d, self.hidden_dim, kernel_size=1) 57 | c2_xavier_fill(self.input_proj_2d) 58 | 59 | self.mask_embed_2d = MLP(self.hidden_dim, self.hidden_dim, config.mask_dim, 3) 60 | self.mask_embed_3d = MLP(self.hidden_dim, self.hidden_dim, config.mask_dim, 3) 61 | 62 | def forward(self, x_2d, x_3d): 63 | ## 3d pre-processing ## 64 | srcs_3d = [] 65 | pos_batch = [] 66 | for i in range(len(x_3d.C[:, 0].unique())): 67 | mask_i = (x_3d.C[:, 0] == i) 68 | src = x_3d.F[mask_i] 69 | pos = self.pe_layer_3d(x_3d.C[mask_i][:, 1:].type(torch.float)) 70 | if len(src) > self.npb: 71 | r = torch.randint(0, len(src), (self.npb,)) 72 | src = src[r] 73 | pos = pos[r] 74 | elif len(src) < self.npb: 75 | r = torch.randint(0, len(src), (self.npb - len(src),)) 76 | src_repeat = src[r] 77 | pos_repeat = pos[r] 78 | src = torch.cat([src, src_repeat], dim=0) 79 | pos = torch.cat([pos, pos_repeat], dim=0) 80 | srcs_3d.append(src) 81 | pos_batch.append(pos) 82 | srcs_3d = torch.stack(srcs_3d, dim=0).transpose(1, 2).contiguous() 83 | trans_input_3d = self.input_proj_3d(srcs_3d).transpose(1, 2).contiguous() 84 | pe_3d = torch.stack(pos_batch, dim=0) 85 | 86 | ## 2d pre-processing ## 87 | VB, C, H, W = x_2d.size() 88 | B = int(VB // self.num_view) 89 | N_2d = self.num_view*H*W 90 | assert(trans_input_3d.shape[0] == B) 91 | srcs_2d = x_2d.view(VB, C, H*W) 92 | trans_input_2d = self.input_proj_2d(srcs_2d).transpose(1, 2).contiguous() 93 | trans_input_2d = trans_input_2d.view(self.num_view, B, H*W, self.hidden_dim).transpose(0, 1).contiguous().view(B, N_2d, self.hidden_dim) 94 | 95 | trans_input = torch.cat([trans_input_2d, trans_input_3d], dim=1) 96 | pos_input = torch.cat([self.pos_embed_2d.expand(VB, -1, -1).view(self.num_view, B, H*W, self.hidden_dim).transpose(0, 1).contiguous().view(B, N_2d, self.hidden_dim), 97 | pe_3d], dim=1) 98 | trans_feat = self.enc_share(trans_input, pos_input) 99 | trans_feat_2d = trans_feat[:, :N_2d, :].view(B, self.num_view, H*W, self.hidden_dim).transpose(0, 1).contiguous().view(VB, H*W, self.hidden_dim) 100 | trans_feat_3d = trans_feat[:, N_2d:, :] 101 | 102 | hs_adain_2d = self.dec_2d(trans_feat_2d, self.query_embed_2d.weight.unsqueeze(dim=0).expand(VB, -1, -1)) 103 | hs_adain_3d = self.dec_3d(trans_feat_3d, self.query_embed_3d.weight.unsqueeze(dim=0).expand(B, -1, -1)) 104 | 105 | mask_embed_2d = self.mask_embed_2d(hs_adain_2d[-1]) 106 | mask_embed_3d = self.mask_embed_3d(hs_adain_3d[-1]) 107 | 108 | return mask_embed_2d, mask_embed_3d, hs_adain_2d, hs_adain_3d 109 | 110 | 111 | class Conv2d(torch.nn.Conv2d): 112 | """ 113 | A wrapper around 
:class:`torch.nn.Conv2d` to support empty inputs and more features. 114 | """ 115 | 116 | def __init__(self, *args, **kwargs): 117 | """ 118 | Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`: 119 | Args: 120 | norm (nn.Module, optional): a normalization layer 121 | activation (callable(Tensor) -> Tensor): a callable activation function 122 | It assumes that norm layer is used before activation. 123 | """ 124 | norm = kwargs.pop("norm", None) 125 | activation = kwargs.pop("activation", None) 126 | super().__init__(*args, **kwargs) 127 | 128 | self.norm = norm 129 | self.activation = activation 130 | 131 | def forward(self, x): 132 | # torchscript does not support SyncBatchNorm yet 133 | # https://github.com/pytorch/pytorch/issues/40507 134 | # and we skip these codes in torchscript since: 135 | # 1. currently we only support torchscript in evaluation mode 136 | # 2. features needed by exporting module to torchscript are added in PyTorch 1.6 or 137 | # later version, `Conv2d` in these PyTorch versions has already supported empty inputs. 138 | if not torch.jit.is_scripting(): 139 | if x.numel() == 0 and self.training: 140 | # https://github.com/pytorch/pytorch/issues/12013 141 | assert not isinstance( 142 | self.norm, torch.nn.SyncBatchNorm 143 | ), "SyncBatchNorm does not support empty inputs!" 144 | 145 | x = F.conv2d( 146 | x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups 147 | ) 148 | if self.norm is not None: 149 | x = self.norm(x) 150 | if self.activation is not None: 151 | x = self.activation(x) 152 | return x 153 | 154 | 155 | def c2_xavier_fill(module: nn.Module) -> None: 156 | """ 157 | Initialize `module.weight` using the "XavierFill" implemented in Caffe2. 158 | Also initializes `module.bias` to 0. 159 | Args: 160 | module (torch.nn.Module): module to initialize. 161 | """ 162 | # Caffe2 implementation of XavierFill in fact 163 | # corresponds to kaiming_uniform_ in PyTorch 164 | nn.init.kaiming_uniform_(module.weight, a=1) 165 | if module.bias is not None: 166 | # pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[nn.Module, 167 | # torch.Tensor]`. 
168 | nn.init.constant_(module.bias, 0) 169 | 170 | 171 | class MLP(nn.Module): 172 | """Very simple multi-layer perceptron (also called FFN)""" 173 | 174 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers): 175 | super().__init__() 176 | self.num_layers = num_layers 177 | h = [hidden_dim] * (num_layers - 1) 178 | self.layers = nn.ModuleList( 179 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 180 | ) 181 | 182 | def forward(self, x): 183 | for i, layer in enumerate(self.layers): 184 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 185 | return x -------------------------------------------------------------------------------- /models/unet_2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | import models.resnet_d as models 6 | 7 | 8 | class ResUnet(nn.Module): 9 | def __init__(self, layers=18, classes=2, BatchNorm=nn.BatchNorm2d, pretrained=True, out_channels=128): 10 | super(ResUnet, self).__init__() 11 | assert classes > 1 12 | models.BatchNorm = BatchNorm 13 | if layers == 18: 14 | resnet = models.resnet18(pretrained=True, deep_base=False) 15 | block = models.BasicBlock 16 | layers = [2, 2, 2, 2] 17 | elif layers == 34: 18 | resnet = models.resnet34d(pretrained=True) 19 | block = models.BasicBlock 20 | layers = [3, 4, 6, 3] 21 | elif layers == 50: 22 | resnet = models.resnet50(pretrained=True, deep_base=False) 23 | block = models.BasicBlock 24 | layers = [3, 4, 6, 3] 25 | self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.maxpool) 26 | 27 | self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4 28 | self.relu = nn.ReLU() 29 | 30 | # Decoder 31 | # self.up4 = nn.Sequential(nn.ConvTranspose2d(512,256,kernel_size=2,stride=2),BatchNorm(256),nn.ReLU()) 32 | # self.up4 = nn.Sequential(nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1), BatchNorm(256), nn.ReLU()) 33 | self.class_convfeat = nn.Conv2d(512, out_channels, kernel_size=1) 34 | self.up4 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1) 35 | self.class_conv4 = nn.Conv2d(256, out_channels, kernel_size=1) 36 | self.bn4 = nn.BatchNorm2d(out_channels, affine=False) 37 | inplanes = out_channels + 256 38 | self.delayer4, _ = models.make_blocks(block, [256], [layers[-1]], inplanes) 39 | self.delayer4 = self.delayer4[0][1] 40 | 41 | # self.up3 = nn.Sequential(nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1), BatchNorm(128), nn.ReLU()) 42 | self.up3 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1) 43 | self.class_conv3 = nn.Conv2d(128, out_channels, kernel_size=1) 44 | self.bn3 = nn.BatchNorm2d(out_channels, affine=False) 45 | inplanes = out_channels + 128 46 | self.delayer3, _ = models.make_blocks(block, [128], [layers[-2]], inplanes) 47 | self.delayer3 = self.delayer3[0][1] 48 | 49 | # self.up2 = nn.Sequential(nn.Conv2d(128, 96, kernel_size=3, stride=1, padding=1), BatchNorm(96), nn.ReLU()) 50 | self.up2 = nn.Conv2d(128, 96, kernel_size=3, stride=1, padding=1) 51 | self.class_conv2 = nn.Conv2d(96, out_channels, kernel_size=1) 52 | self.bn2 = nn.BatchNorm2d(out_channels, affine=False) 53 | inplanes = out_channels + 64 54 | self.delayer2, _ = models.make_blocks(block, [96], [layers[-3]], inplanes) 55 | self.delayer2 = self.delayer2[0][1] 56 | 57 | self.cls = nn.Sequential( 58 | nn.Conv2d(96, 256, kernel_size=3, padding=1, bias=False), 59 | BatchNorm(256), 60 | 
nn.ReLU(inplace=True), 61 | nn.Conv2d(256, out_channels, kernel_size=1) 62 | ) 63 | if self.training: 64 | self.aux = nn.Sequential(nn.Conv2d(256, 256, kernel_size=1, bias=False), 65 | BatchNorm(256), nn.ReLU(inplace=True), 66 | nn.Conv2d(256, classes, kernel_size=1)) 67 | 68 | def forward(self, x): 69 | _, _, h, w = x.shape 70 | x = self.layer0(x) # 1/4 71 | x2 = self.layer1(x) # 1/4 72 | x3 = self.layer2(x2) # 1/8 73 | x4 = self.layer3(x3) # 1/16 74 | x5 = self.layer4(x4) # 1/32 75 | p4 = self.up4(F.interpolate(x5, x4.shape[-2:], mode='bilinear', align_corners=True)) 76 | p4 = torch.cat([p4, x4], dim=1) 77 | p4 = self.delayer4(p4) 78 | p3 = self.up3(F.interpolate(p4, x3.shape[-2:], mode='bilinear', align_corners=True)) 79 | p3 = torch.cat([p3, x3], dim=1) 80 | p3 = self.delayer3(p3) 81 | p2 = self.up2(F.interpolate(p3, x2.shape[-2:], mode='bilinear', align_corners=True)) 82 | p2 = torch.cat([p2, x2], dim=1) 83 | p2 = self.delayer2(p2) 84 | x = self.cls(p2) 85 | x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True) 86 | if self.training: 87 | aux = self.aux(x4) 88 | aux = F.interpolate(aux, size=(h, w), mode='bilinear', align_corners=True) 89 | return x, aux 90 | else: 91 | return x 92 | -------------------------------------------------------------------------------- /models/unet_3d.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu). 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 7 | # of the Software, and to permit persons to whom the Software is furnished to do 8 | # so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 20 | # 21 | # Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural 22 | # Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part 23 | # of the code. 24 | import MinkowskiEngine as ME 25 | from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck 26 | from models.resnet_mink import ResNetBase 27 | 28 | 29 | class MinkUNetBase(ResNetBase): 30 | BLOCK = None 31 | PLANES = None 32 | DILATIONS = (1, 1, 1, 1, 1, 1, 1, 1) 33 | LAYERS = (2, 2, 2, 2, 2, 2, 2, 2) 34 | INIT_DIM = 32 35 | OUT_TENSOR_STRIDE = 1 36 | 37 | # To use the model, must call initialize_coords before forward pass. 
38 | # Once data is processed, call clear to reset the model before calling 39 | # initialize_coords 40 | def __init__(self, in_channels, out_channels, D=3): 41 | ResNetBase.__init__(self, in_channels, out_channels, D) 42 | 43 | def network_initialization(self, in_channels, out_channels, D): 44 | # Output of the first conv concated to conv6 45 | self.inplanes = self.INIT_DIM 46 | self.out_channels = out_channels 47 | self.conv0p1s1 = ME.MinkowskiConvolution( 48 | in_channels, self.inplanes, kernel_size=5, dimension=D) 49 | 50 | self.bn0 = ME.MinkowskiBatchNorm(self.inplanes) 51 | 52 | self.conv1p1s2 = ME.MinkowskiConvolution( 53 | self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) 54 | self.bn1 = ME.MinkowskiBatchNorm(self.inplanes) 55 | 56 | self.block1 = self._make_layer(self.BLOCK, self.PLANES[0], 57 | self.LAYERS[0]) 58 | 59 | self.conv2p2s2 = ME.MinkowskiConvolution( 60 | self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) 61 | self.bn2 = ME.MinkowskiBatchNorm(self.inplanes) 62 | 63 | self.block2 = self._make_layer(self.BLOCK, self.PLANES[1], 64 | self.LAYERS[1]) 65 | 66 | self.conv3p4s2 = ME.MinkowskiConvolution( 67 | self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) 68 | 69 | self.bn3 = ME.MinkowskiBatchNorm(self.inplanes) 70 | self.block3 = self._make_layer(self.BLOCK, self.PLANES[2], 71 | self.LAYERS[2]) 72 | 73 | self.conv4p8s2 = ME.MinkowskiConvolution( 74 | self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) 75 | self.bn4 = ME.MinkowskiBatchNorm(self.inplanes) 76 | self.block4 = self._make_layer(self.BLOCK, self.PLANES[3], 77 | self.LAYERS[3]) 78 | 79 | self.class_convfeat = ME.MinkowskiConvolution( 80 | self.inplanes, out_channels, kernel_size=1, stride=1, bias=False, dimension=D) 81 | 82 | self.convtr4p16s2 = ME.MinkowskiConvolutionTranspose( 83 | self.inplanes, self.PLANES[4], kernel_size=2, stride=2, dimension=D) 84 | self.class_conv4 = ME.MinkowskiConvolution( 85 | self.PLANES[4], out_channels, kernel_size=1, stride=1, bias=False, dimension=D) 86 | self.bntr4 = ME.MinkowskiBatchNorm(out_channels, affine=False) 87 | 88 | self.inplanes = out_channels + self.PLANES[2] * self.BLOCK.expansion 89 | self.block5 = self._make_layer(self.BLOCK, self.PLANES[4], 90 | self.LAYERS[4]) 91 | self.convtr5p8s2 = ME.MinkowskiConvolutionTranspose( 92 | self.inplanes, self.PLANES[5], kernel_size=2, stride=2, dimension=D) 93 | self.class_conv5 = ME.MinkowskiConvolution( 94 | self.PLANES[5], out_channels, kernel_size=1, stride=1, bias=False, dimension=D) 95 | self.bntr5 = ME.MinkowskiBatchNorm(out_channels, affine=False) 96 | 97 | self.inplanes = out_channels + self.PLANES[1] * self.BLOCK.expansion 98 | self.block6 = self._make_layer(self.BLOCK, self.PLANES[5], 99 | self.LAYERS[5]) 100 | self.convtr6p4s2 = ME.MinkowskiConvolutionTranspose( 101 | self.inplanes, self.PLANES[6], kernel_size=2, stride=2, dimension=D) 102 | self.class_conv6 = ME.MinkowskiConvolution( 103 | self.PLANES[6], out_channels, kernel_size=1, stride=1, bias=False, dimension=D) 104 | self.bntr6 = ME.MinkowskiBatchNorm(out_channels, affine=False) 105 | 106 | self.inplanes = out_channels + self.PLANES[0] * self.BLOCK.expansion 107 | self.block7 = self._make_layer(self.BLOCK, self.PLANES[6], 108 | self.LAYERS[6]) 109 | self.convtr7p2s2 = ME.MinkowskiConvolutionTranspose( 110 | self.inplanes, self.PLANES[7], kernel_size=2, stride=2, dimension=D) 111 | self.bntr7 = ME.MinkowskiBatchNorm(self.PLANES[7]) 112 | 113 | self.inplanes = self.PLANES[7] + self.INIT_DIM 114 | 
self.block8 = self._make_layer(self.BLOCK, self.PLANES[7], 115 | self.LAYERS[7]) 116 | 117 | self.final = ME.MinkowskiConvolution( 118 | self.PLANES[7], 119 | out_channels, 120 | kernel_size=1, 121 | bias=True, 122 | dimension=D) 123 | self.relu = ME.MinkowskiReLU(inplace=True) 124 | 125 | def forward(self, x): 126 | out = self.conv0p1s1(x) 127 | out = self.bn0(out) 128 | out_p1 = self.relu(out) 129 | 130 | out = self.conv1p1s2(out_p1) 131 | out = self.bn1(out) 132 | out = self.relu(out) 133 | out_b1p2 = self.block1(out) 134 | 135 | out = self.conv2p2s2(out_b1p2) 136 | out = self.bn2(out) 137 | out = self.relu(out) 138 | out_b2p4 = self.block2(out) 139 | 140 | out = self.conv3p4s2(out_b2p4) 141 | out = self.bn3(out) 142 | out = self.relu(out) 143 | out_b3p8 = self.block3(out) 144 | 145 | # tensor_stride=16 146 | out = self.conv4p8s2(out_b3p8) 147 | out = self.bn4(out) 148 | out = self.relu(out) 149 | out = self.block4(out) 150 | 151 | # tensor_stride=8 152 | out = self.convtr4p16s2(out) 153 | out = self.bntr4(out) 154 | out = self.relu(out) 155 | 156 | out = ME.cat(out, out_b3p8) 157 | out = self.block5(out) 158 | 159 | # tensor_stride=4 160 | out = self.convtr5p8s2(out) 161 | out = self.bntr5(out) 162 | out = self.relu(out) 163 | 164 | out = ME.cat(out, out_b2p4) 165 | out = self.block6(out) 166 | 167 | # tensor_stride=2 168 | out = self.convtr6p4s2(out) 169 | out = self.bntr6(out) 170 | out = self.relu(out) 171 | 172 | out = ME.cat(out, out_b1p2) 173 | out = self.block7(out) 174 | 175 | # tensor_stride=1 176 | out = self.convtr7p2s2(out) 177 | out = self.bntr7(out) 178 | out = self.relu(out) 179 | 180 | out = ME.cat(out, out_p1) 181 | out = self.block8(out) 182 | 183 | return self.final(out).F 184 | 185 | 186 | class MinkUNet14(MinkUNetBase): 187 | BLOCK = BasicBlock 188 | LAYERS = (1, 1, 1, 1, 1, 1, 1, 1) 189 | 190 | 191 | class MinkUNet18(MinkUNetBase): 192 | BLOCK = BasicBlock 193 | LAYERS = (2, 2, 2, 2, 2, 2, 2, 2) 194 | 195 | 196 | class MinkUNet34(MinkUNetBase): 197 | BLOCK = BasicBlock 198 | LAYERS = (2, 3, 4, 6, 2, 2, 2, 2) 199 | 200 | 201 | class MinkUNet50(MinkUNetBase): 202 | BLOCK = Bottleneck 203 | LAYERS = (2, 3, 4, 6, 2, 2, 2, 2) 204 | 205 | 206 | class MinkUNet101(MinkUNetBase): 207 | BLOCK = Bottleneck 208 | LAYERS = (2, 3, 4, 23, 2, 2, 2, 2) 209 | 210 | 211 | class MinkUNet14A(MinkUNet14): 212 | PLANES = (32, 64, 128, 256, 128, 128, 96, 96) 213 | 214 | 215 | class MinkUNet14B(MinkUNet14): 216 | PLANES = (32, 64, 128, 256, 128, 128, 128, 128) 217 | 218 | 219 | class MinkUNet14C(MinkUNet14): 220 | PLANES = (32, 64, 128, 256, 192, 192, 128, 128) 221 | 222 | 223 | class MinkUNet14D(MinkUNet14): 224 | PLANES = (32, 64, 128, 256, 384, 384, 384, 384) 225 | 226 | 227 | class MinkUNet18A(MinkUNet18): 228 | PLANES = (32, 64, 128, 256, 128, 128, 96, 96) 229 | 230 | 231 | class MinkUNet18B(MinkUNet18): 232 | PLANES = (32, 64, 128, 256, 128, 128, 128, 128) 233 | 234 | 235 | class MinkUNet18D(MinkUNet18): 236 | PLANES = (32, 64, 128, 256, 384, 384, 384, 384) 237 | 238 | 239 | class MinkUNet34A(MinkUNet34): 240 | PLANES = (32, 64, 128, 256, 256, 128, 64, 64) 241 | 242 | 243 | class MinkUNet34B(MinkUNet34): 244 | PLANES = (32, 64, 128, 256, 256, 128, 64, 32) 245 | 246 | 247 | class MinkUNet34C(MinkUNet34): 248 | PLANES = (32, 64, 128, 256, 256, 128, 96, 96) 249 | 250 | 251 | def mink_unet(in_channels=3, out_channels=20, D=3, arch='MinkUNet18A'): 252 | if arch == 'MinkUNet18A': 253 | return MinkUNet18A(in_channels, out_channels, D) 254 | elif arch == 'MinkUNet18B': 255 | return 
MinkUNet18B(in_channels, out_channels, D) 256 | elif arch == 'MinkUNet18D': 257 | return MinkUNet18D(in_channels, out_channels, D) 258 | elif arch == 'MinkUNet34A': 259 | return MinkUNet34A(in_channels, out_channels, D) 260 | elif arch == 'MinkUNet34B': 261 | return MinkUNet34B(in_channels, out_channels, D) 262 | elif arch == 'MinkUNet34C': 263 | return MinkUNet34C(in_channels, out_channels, D) 264 | elif arch == 'MinkUNet14A': 265 | return MinkUNet14A(in_channels, out_channels, D) 266 | elif arch == 'MinkUNet14B': 267 | return MinkUNet14B(in_channels, out_channels, D) 268 | elif arch == 'MinkUNet14C': 269 | return MinkUNet14C(in_channels, out_channels, D) 270 | elif arch == 'MinkUNet14D': 271 | return MinkUNet14D(in_channels, out_channels, D) 272 | else: 273 | raise Exception('architecture not supported yet'.format(arch)) 274 | -------------------------------------------------------------------------------- /tool/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from typing import OrderedDict 4 | import numpy as np 5 | import logging 6 | import argparse 7 | 8 | import cv2 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | import torch.nn.parallel 12 | import torch.optim 13 | import torch.utils.data 14 | import torch.multiprocessing as mp 15 | import torch.distributed as dist 16 | from os.path import join 17 | from metrics import iou 18 | 19 | from MinkowskiEngine import SparseTensor, CoordsManager 20 | from util import config 21 | from util.util import AverageMeter, intersectionAndUnionGPU 22 | from tqdm import tqdm 23 | from tool.train import get_model 24 | 25 | cv2.ocl.setUseOpenCL(False) 26 | cv2.setNumThreads(0) 27 | 28 | 29 | def worker_init_fn(worker_id): 30 | random.seed(1463 + worker_id) 31 | np.random.seed(1463 + worker_id) 32 | torch.manual_seed(1463 + worker_id) 33 | 34 | 35 | def get_parser(): 36 | parser = argparse.ArgumentParser(description='SemAffiNet') 37 | parser.add_argument('--config', type=str, default='config/scannet/SemAffiNet_5cm.yaml', help='config file') 38 | parser.add_argument('opts', help='see config/scannet/SemAffiNet_5cm.yaml for all options', default=None, 39 | nargs=argparse.REMAINDER) 40 | args = parser.parse_args() 41 | assert args.config is not None 42 | cfg = config.load_cfg_from_cfg_file(args.config) 43 | if args.opts is not None: 44 | cfg = config.merge_cfg_from_list(cfg, args.opts) 45 | return cfg 46 | 47 | 48 | def get_logger(): 49 | logger_name = "main-logger" 50 | logger = logging.getLogger(logger_name) 51 | logger.setLevel(logging.INFO) 52 | handler = logging.StreamHandler() 53 | fmt = "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s" 54 | handler.setFormatter(logging.Formatter(fmt)) 55 | logger.addHandler(handler) 56 | return logger 57 | 58 | 59 | def main_process(): 60 | return not args.multiprocessing_distributed or ( 61 | args.multiprocessing_distributed and args.rank % args.ngpus_per_node == 0) 62 | 63 | 64 | def main(): 65 | args = get_parser() 66 | os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(x) for x in args.train_gpu) 67 | cudnn.benchmark = True 68 | # for cudnn bug at https://github.com/pytorch/pytorch/issues/4107 69 | # https://github.com/Microsoft/human-pose-estimation.pytorch/issues/8 70 | # https://discuss.pytorch.org/t/training-performance-degrades-with-distributeddataparallel/47152/7 71 | # torch.backends.cudnn.enabled = False 72 | 73 | if args.manual_seed is not None: 74 | random.seed(args.manual_seed) 75 
| np.random.seed(args.manual_seed) 76 | torch.manual_seed(args.manual_seed) 77 | torch.cuda.manual_seed(args.manual_seed) 78 | torch.cuda.manual_seed_all(args.manual_seed) 79 | # cudnn.benchmark = False 80 | # cudnn.deterministic = True 81 | 82 | print( 83 | 'torch.__version__:%s\ntorch.version.cuda:%s\ntorch.backends.cudnn.version:%s\ntorch.backends.cudnn.enabled:%s' % ( 84 | torch.__version__, torch.version.cuda, torch.backends.cudnn.version(), torch.backends.cudnn.enabled)) 85 | 86 | if args.dist_url == "env://" and args.world_size == -1: 87 | args.world_size = int(os.environ["WORLD_SIZE"]) 88 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 89 | args.ngpus_per_node = len(args.test_gpu) 90 | if len(args.test_gpu) == 1: 91 | args.sync_bn = False 92 | args.distributed = False 93 | args.multiprocessing_distributed = False 94 | args.use_apex = False 95 | 96 | # Following code is for caching dataset into memory 97 | if args.data_name == 'scannet_3d_mink': 98 | from dataset.scanNet3D import ScanNet3D, collation_fn_eval_all 99 | _ = ScanNet3D(dataPathPrefix=args.data_root, voxelSize=args.voxelSize, split='val', aug=False, 100 | memCacheInit=True, eval_all=True, identifier=6738) 101 | elif args.data_name == 'scannet_cross': 102 | from dataset.scanNetCross import ScanNetCross, collation_fn, collation_fn_eval_all 103 | _ = ScanNetCross(dataPathPrefix=args.data_root, voxelSize=args.voxelSize, split='val', aug=False, 104 | memCacheInit=True, eval_all=True, identifier=6738, val_benchmark=args.val_benchmark) 105 | 106 | if args.multiprocessing_distributed: 107 | args.world_size = args.ngpus_per_node * args.world_size 108 | mp.spawn(main_worker, nprocs=args.ngpus_per_node, args=(args.ngpus_per_node, args)) 109 | else: 110 | main_worker(args.test_gpu, args.ngpus_per_node, args) 111 | 112 | 113 | def main_worker(gpu, ngpus_per_node, argss): 114 | global args 115 | args = argss 116 | if args.distributed: 117 | if args.dist_url == "env://" and args.rank == -1: 118 | args.rank = int(os.environ["RANK"]) 119 | if args.multiprocessing_distributed: 120 | args.rank = args.rank * ngpus_per_node + gpu 121 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, 122 | rank=args.rank) 123 | 124 | model = get_model(args) 125 | if main_process(): 126 | global logger 127 | logger = get_logger() 128 | logger.info(args) 129 | logger.info("=> creating model ...") 130 | logger.info("Classes: {}".format(args.classes)) 131 | logger.info(model) 132 | 133 | if args.distributed: 134 | torch.cuda.set_device(gpu) 135 | args.test_batch_size = int(args.test_batch_size / ngpus_per_node) 136 | args.test_workers = int(args.test_workers / ngpus_per_node) 137 | model = torch.nn.parallel.DistributedDataParallel(model.cuda(), device_ids=[gpu]) 138 | else: 139 | model = model.cuda() 140 | 141 | if os.path.isfile(args.model_path): 142 | if main_process(): 143 | logger.info("=> loading checkpoint '{}'".format(args.model_path)) 144 | checkpoint = torch.load(args.model_path, map_location=lambda storage, loc: storage.cuda()) 145 | if not args.distributed: 146 | state_dict = {} 147 | for key in checkpoint['state_dict'].keys(): 148 | state_dict[key[7:]] = checkpoint['state_dict'][key] 149 | state_dict = OrderedDict(state_dict) 150 | else: 151 | state_dict = checkpoint['state_dict'] 152 | model.load_state_dict(state_dict, strict=True) 153 | if main_process(): 154 | logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.model_path, checkpoint['epoch'])) 155 | 
else: 156 | raise RuntimeError("=> no checkpoint found at '{}'".format(args.model_path)) 157 | 158 | # ####################### Data Loader ####################### # 159 | if args.data_name == 'scannet_3d_mink': 160 | from dataset.scanNet3D import ScanNet3D, collation_fn_eval_all 161 | val_data = ScanNet3D(dataPathPrefix=args.data_root, voxelSize=args.voxelSize, split='val', aug=False, 162 | memCacheInit=True, eval_all=True, identifier=6738) 163 | val_sampler = None 164 | val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.test_batch_size, 165 | shuffle=False, num_workers=args.test_workers, pin_memory=True, 166 | drop_last=False, collate_fn=collation_fn_eval_all, 167 | sampler=val_sampler) 168 | elif args.data_name == 'scannet_cross': 169 | from dataset.scanNetCross import ScanNetCross, collation_fn_eval_all 170 | val_data = ScanNetCross(dataPathPrefix=args.data_root, voxelSize=args.voxelSize, split='val', aug=False, 171 | memCacheInit=True, eval_all=True, identifier=6738, val_benchmark=args.val_benchmark, 172 | view_num=args.viewNum) 173 | val_sampler = None 174 | val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.test_batch_size, 175 | shuffle=False, num_workers=args.test_workers, pin_memory=True, 176 | drop_last=False, collate_fn=collation_fn_eval_all, 177 | sampler=val_sampler) 178 | else: 179 | raise Exception('Dataset not supported yet'.format(args.data_name)) 180 | 181 | # ####################### Test ####################### # 182 | if args.data_name == 'scannet_3d_mink': 183 | validate(model, val_loader) 184 | elif args.data_name == 'scannet_cross': 185 | test_cross_3d(model, val_loader) 186 | 187 | 188 | def validate(model, val_loader): 189 | torch.backends.cudnn.enabled = False # for cudnn bug at https://github.com/pytorch/pytorch/issues/4107 190 | model.eval() 191 | with torch.no_grad(): 192 | store = 0.0 193 | for rep_i in range(args.test_repeats): 194 | preds = [] 195 | gts = [] 196 | for i, (coords, feat, label, inds_reverse) in enumerate(tqdm(val_loader)): 197 | sinput = SparseTensor(feat.cuda(non_blocking=True), coords.cuda(non_blocking=True)) 198 | predictions = model(sinput) 199 | predictions_enlarge = predictions[inds_reverse, :] 200 | if args.multiprocessing_distributed: 201 | dist.all_reduce(predictions_enlarge) 202 | preds.append(predictions_enlarge.detach_().cpu()) 203 | gts.append(label.cpu()) 204 | gt = torch.cat(gts) 205 | pred = torch.cat(preds) 206 | current_iou = iou.evaluate(pred.max(1)[1].numpy(), gt.numpy()) 207 | if rep_i == 0 and main_process(): 208 | np.save(join(args.save_folder, 'gt.npy'), gt.numpy()) 209 | store = pred + store 210 | accumu_iou = iou.evaluate(store.max(1)[1].numpy(), gt.numpy()) 211 | if main_process(): 212 | np.save(join(args.save_folder, 'pred.npy'), store.max(1)[1].numpy()) 213 | 214 | 215 | def test_cross_3d(model, val_data_loader): 216 | torch.backends.cudnn.enabled = False # for cudnn bug at https://github.com/pytorch/pytorch/issues/4107 217 | intersection_meter = AverageMeter() 218 | union_meter = AverageMeter() 219 | target_meter = AverageMeter() 220 | 221 | with torch.no_grad(): 222 | model.eval() 223 | store = 0.0 224 | for rep_i in range(args.test_repeats): 225 | preds, gts = [], [] 226 | val_data_loader.dataset.offset = rep_i 227 | if main_process(): 228 | pbar = tqdm(total=len(val_data_loader)) 229 | for i, (coords, feat, label_3d, color, label_2d, link, inds_reverse) in enumerate(val_data_loader): 230 | if main_process(): 231 | pbar.update(1) 232 | sinput = 
SparseTensor(feat.cuda(non_blocking=True), coords.cuda(non_blocking=True)) 233 | color, link = color.cuda(non_blocking=True), link.cuda(non_blocking=True) 234 | label_3d, label_2d = label_3d.cuda(non_blocking=True), label_2d.cuda(non_blocking=True) 235 | predictions_3d, predictions_2d = model(sinput, color, link) 236 | output_3d = predictions_3d['pred_masks'] 237 | output_2d = predictions_2d['pred_masks'] 238 | output_2d = output_2d.contiguous() 239 | output_3d = output_3d[inds_reverse, :] 240 | if args.multiprocessing_distributed: 241 | dist.all_reduce(output_3d) 242 | 243 | output_2d = output_2d.detach().max(1)[1] 244 | intersection, union, target = intersectionAndUnionGPU(output_2d, label_2d.detach(), args.classes, 245 | args.ignore_label) 246 | intersection, union, target = intersection.cpu().numpy(), union.cpu().numpy(), target.cpu().numpy() 247 | intersection_meter.update(intersection), union_meter.update(union), target_meter.update(target) 248 | 249 | preds.append(output_3d.detach_().cpu().numpy().astype(np.half)) 250 | gts.append(label_3d.cpu().numpy()) 251 | torch.cuda.empty_cache() 252 | if main_process(): 253 | pbar.close() 254 | gt = np.concatenate(gts) 255 | pred = np.concatenate(preds) 256 | if rep_i == 0: 257 | np.save(join(args.save_folder, 'gt.npy'), gt) 258 | store = pred + store 259 | pred_id = np.argmax(store, axis=1) 260 | mIou_3d = iou.evaluate(pred_id, gt) 261 | np.save(join(args.save_folder, 'pred.npy'), pred_id) 262 | 263 | iou_class = intersection_meter.sum / (union_meter.sum + 1e-10) 264 | # accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10) 265 | mIoU_2d = np.mean(iou_class) 266 | # mAcc = np.mean(accuracy_class) 267 | # allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10) 268 | if main_process(): 269 | print("2D: ", mIoU_2d, ", 3D: ", mIou_3d) 270 | 271 | 272 | if __name__ == '__main__': 273 | main() 274 | -------------------------------------------------------------------------------- /tool/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # LOG 3 | # shellcheck disable=SC2230 4 | # shellcheck disable=SC2086 5 | set -x 6 | # Exit script when a command returns nonzero state 7 | set -e 8 | #set -o pipefail 9 | 10 | exp_name=$1 11 | config=$2 12 | dataset=$3 13 | T=$4 14 | 15 | export OPENBLAS_NUM_THREADS=${T} 16 | export GOTO_NUM_THREADS=${T} 17 | export OMP_NUM_THREADS=${T} 18 | export KMP_INIT_AT_FORK=FALSE 19 | 20 | PYTHON=python 21 | TRAIN_CODE=train.py 22 | TEST_CODE=test.py 23 | 24 | 25 | exp_dir=Exp/${dataset}/${exp_name} 26 | model_dir=${exp_dir}/model 27 | result_dir=${exp_dir}/result 28 | 29 | now=$(date +"%Y%m%d_%H%M%S") 30 | 31 | cp tool/test.sh tool/${TEST_CODE} ${exp_dir} 32 | mkdir -p ${result_dir}/last 33 | mkdir -p ${result_dir}/best 34 | 35 | export PYTHONPATH=. 
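# PYTHONPATH=. lets the copied ${TEST_CODE} import the repo-local dataset/, util/ and model packages when the script is launched from the repository root.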
36 | #rm -rf /dev/shm/wbhu* 37 | echo $OMP_NUM_THREADS | tee -a ${exp_dir}/test_last-$now.log 38 | nvidia-smi | tee -a ${exp_dir}/test_last-$now.log 39 | which pip | tee -a ${exp_dir}/test_last-$now.log 40 | 41 | # TEST 42 | #rm -rf /dev/shm/wbhu* 43 | now=$(date +"%Y%m%d_%H%M%S") 44 | 45 | $PYTHON -u ${exp_dir}/${TEST_CODE} \ 46 | --config=${config} \ 47 | save_folder ${result_dir}/best \ 48 | model_path ${model_dir}/model_best.pth.tar \ 49 | 2>&1 | tee -a ${exp_dir}/test_best-$now.log 50 | 51 | $PYTHON -u ${exp_dir}/${TEST_CODE} \ 52 | --config=${config} \ 53 | save_folder ${result_dir}/last \ 54 | model_path ${model_dir}/model_last.pth.tar \ 55 | 2>&1 | tee -a ${exp_dir}/test_last-$now.log -------------------------------------------------------------------------------- /tool/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # LOG 3 | # shellcheck disable=SC2230 4 | # shellcheck disable=SC2086 5 | set -x 6 | # Exit script when a command returns nonzero state 7 | set -e 8 | #set -o pipefail 9 | 10 | exp_name=$1 11 | config=$2 12 | dataset=$3 13 | T=$4 14 | 15 | export OPENBLAS_NUM_THREADS=${T} 16 | export GOTO_NUM_THREADS=${T} 17 | export OMP_NUM_THREADS=${T} 18 | export KMP_INIT_AT_FORK=FALSE 19 | 20 | PYTHON=python 21 | TRAIN_CODE=train.py 22 | TEST_CODE=test.py 23 | 24 | 25 | exp_dir=Exp/${dataset}/${exp_name} 26 | model_dir=${exp_dir}/model 27 | result_dir=${exp_dir}/result 28 | 29 | now=$(date +"%Y%m%d_%H%M%S") 30 | 31 | mkdir -p ${model_dir} ${result_dir} 32 | mkdir -p ${result_dir}/last 33 | mkdir -p ${result_dir}/best 34 | cp tool/train.sh tool/${TRAIN_CODE} ${config} tool/test.sh tool/${TEST_CODE} ${exp_dir} 35 | 36 | export PYTHONPATH=. 37 | #rm -rf /dev/shm/wbhu* 38 | echo $OMP_NUM_THREADS | tee -a ${exp_dir}/train-$now.log 39 | nvidia-smi | tee -a ${exp_dir}/train-$now.log 40 | which pip | tee -a ${exp_dir}/train-$now.log 41 | 42 | now=$(date +"%Y%m%d_%H%M%S") 43 | $PYTHON -u ${exp_dir}/${TRAIN_CODE} \ 44 | --config=${config} \ 45 | save_path ${exp_dir} \ 46 | 2>&1 | tee -a ${exp_dir}/train-$now.log 47 | 48 | # TEST 49 | #rm -rf /dev/shm/wbhu* 50 | now=$(date +"%Y%m%d_%H%M%S") 51 | 52 | $PYTHON -u ${exp_dir}/${TEST_CODE} \ 53 | --config=${config} \ 54 | save_folder ${result_dir}/best \ 55 | model_path ${model_dir}/model_best.pth.tar \ 56 | 2>&1 | tee -a ${exp_dir}/test_best-$now.log 57 | 58 | $PYTHON -u ${exp_dir}/${TEST_CODE} \ 59 | --config=${config} \ 60 | save_folder ${result_dir}/last \ 61 | model_path ${model_dir}/model_last.pth.tar \ 62 | 2>&1 | tee -a ${exp_dir}/test_last-$now.log 63 | -------------------------------------------------------------------------------- /util/config.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------- 2 | # Functions for parsing args 3 | # ----------------------------------------------------------------------------- 4 | import yaml 5 | import os 6 | from ast import literal_eval 7 | import copy 8 | 9 | 10 | class CfgNode(dict): 11 | """ 12 | CfgNode represents an internal node in the configuration tree. It's a simple 13 | dict-like container that allows for attribute-based access to keys. 
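    Attribute and item access are interchangeable once a YAML config has been
    loaded; e.g. (illustrative values) ``CfgNode({'classes': 20}).classes``
    returns the same entry as ``CfgNode({'classes': 20})['classes']``.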
14 | """ 15 | 16 | def __init__(self, init_dict=None, key_list=None, new_allowed=False): 17 | # Recursively convert nested dictionaries in init_dict into CfgNodes 18 | init_dict = {} if init_dict is None else init_dict 19 | key_list = [] if key_list is None else key_list 20 | for k, v in init_dict.items(): 21 | if type(v) is dict: 22 | # Convert dict to CfgNode 23 | init_dict[k] = CfgNode(v, key_list=key_list + [k]) 24 | super(CfgNode, self).__init__(init_dict) 25 | 26 | def __getattr__(self, name): 27 | if name in self: 28 | return self[name] 29 | else: 30 | raise AttributeError(name) 31 | 32 | def __setattr__(self, name, value): 33 | self[name] = value 34 | 35 | def __str__(self): 36 | def _indent(s_, num_spaces): 37 | s = s_.split("\n") 38 | if len(s) == 1: 39 | return s_ 40 | first = s.pop(0) 41 | s = [(num_spaces * " ") + line for line in s] 42 | s = "\n".join(s) 43 | s = first + "\n" + s 44 | return s 45 | 46 | r = "" 47 | s = [] 48 | for k, v in sorted(self.items()): 49 | seperator = "\n" if isinstance(v, CfgNode) else " " 50 | attr_str = "{}:{}{}".format(str(k), seperator, str(v)) 51 | attr_str = _indent(attr_str, 2) 52 | s.append(attr_str) 53 | r += "\n".join(s) 54 | return r 55 | 56 | def __repr__(self): 57 | return "{}({})".format(self.__class__.__name__, super(CfgNode, self).__repr__()) 58 | 59 | 60 | def load_cfg_from_cfg_file(file): 61 | cfg = {} 62 | assert os.path.isfile(file) and file.endswith('.yaml'), \ 63 | '{} is not a yaml file'.format(file) 64 | 65 | with open(file, 'r') as f: 66 | cfg_from_file = yaml.safe_load(f) 67 | 68 | for key in cfg_from_file: 69 | for k, v in cfg_from_file[key].items(): 70 | cfg[k] = v 71 | 72 | cfg = CfgNode(cfg) 73 | return cfg 74 | 75 | 76 | def merge_cfg_from_list(cfg, cfg_list): 77 | new_cfg = copy.deepcopy(cfg) 78 | assert len(cfg_list) % 2 == 0 79 | for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]): 80 | subkey = full_key.split('.')[-1] 81 | assert subkey in cfg, 'Non-existent key: {}'.format(full_key) 82 | value = _decode_cfg_value(v) 83 | value = _check_and_coerce_cfg_value_type( 84 | value, cfg[subkey], subkey, full_key 85 | ) 86 | setattr(new_cfg, subkey, value) 87 | 88 | return new_cfg 89 | 90 | 91 | def _decode_cfg_value(v): 92 | """Decodes a raw config value (e.g., from a yaml config files or command 93 | line argument) into a Python object. 94 | """ 95 | # All remaining processing is only applied to strings 96 | if not isinstance(v, str): 97 | return v 98 | # Try to interpret `v` as a: 99 | # string, number, tuple, list, dict, boolean, or None 100 | try: 101 | v = literal_eval(v) 102 | # The following two excepts allow v to pass through when it represents a 103 | # string. 104 | # 105 | # Longer explanation: 106 | # The type of v is always a string (before calling literal_eval), but 107 | # sometimes it *represents* a string and other times a data structure, like 108 | # a list. In the case that v represents a string, what we got back from the 109 | # yaml parser is 'foo' *without quotes* (so, not '"foo"'). literal_eval is 110 | # ok with '"foo"', but will raise a ValueError if given 'foo'. In other 111 | # cases, like paths (v = 'foo/bar' and not v = '"foo/bar"'), literal_eval 112 | # will raise a SyntaxError. 113 | except ValueError: 114 | pass 115 | except SyntaxError: 116 | pass 117 | return v 118 | 119 | 120 | def _check_and_coerce_cfg_value_type(replacement, original, key, full_key): 121 | """Checks that `replacement`, which is intended to replace `original` is of 122 | the right type. 
The type is correct if it matches exactly or is one of a few 123 | cases in which the type can be easily coerced. 124 | """ 125 | original_type = type(original) 126 | replacement_type = type(replacement) 127 | 128 | # The types must match (with some exceptions) 129 | if replacement_type == original_type or original is None: 130 | return replacement 131 | 132 | # Cast replacement from from_type to to_type if the replacement and original 133 | # types match from_type and to_type 134 | def conditional_cast(from_type, to_type): 135 | if replacement_type == from_type and original_type == to_type: 136 | return True, to_type(replacement) 137 | else: 138 | return False, None 139 | 140 | # Conditionally casts 141 | # list <-> tuple 142 | casts = [(tuple, list), (list, tuple)] 143 | # For py2: allow converting from str (bytes) to a unicode string 144 | # try: 145 | # casts.append((str, unicode)) # noqa: F821 146 | # except Exception: 147 | # pass 148 | 149 | for (from_type, to_type) in casts: 150 | converted, converted_value = conditional_cast(from_type, to_type) 151 | if converted: 152 | return converted_value 153 | 154 | raise ValueError( 155 | "Type mismatch ({} vs. {}) with values ({} vs. {}) for config " 156 | "key: {}".format( 157 | original_type, replacement_type, original, replacement, full_key 158 | ) 159 | ) 160 | 161 | -------------------------------------------------------------------------------- /util/criterion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def loss_omni(preds, targets, mode='bce'): 6 | num_class = preds[0].shape[1] 7 | loss = [] 8 | for lvl in range(len(preds)): 9 | pred = preds[lvl] 10 | target = targets[lvl][:, :num_class] 11 | if mode == 'bce': 12 | # weight = target / target.max(dim=1, keepdim=True)[0].clamp(min=1) 13 | l = F.binary_cross_entropy_with_logits(pred, target.clamp(0, 1).detach()) 14 | elif mode == 'focal': 15 | l = focal_loss(pred, target) 16 | elif mode == 'dice': 17 | l = dice_loss(pred, target) 18 | else: 19 | raise NotImplementedError('Invalid mode!') 20 | loss.append(l) 21 | loss = torch.stack(loss).mean() 22 | return loss 23 | 24 | 25 | def focal_loss(inputs, targets, alpha=0.25, gamma=2): 26 | prob = inputs.sigmoid() 27 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 28 | p_t = prob * targets + (1 - prob) * (1 - targets) 29 | loss = ce_loss * ((1 - p_t) ** gamma) 30 | 31 | if alpha >= 0: 32 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 33 | loss = alpha_t * loss 34 | 35 | return loss.mean() 36 | 37 | 38 | def dice_loss(inputs, targets): 39 | """ 40 | Compute the DICE loss, similar to generalized IOU for masks 41 | Args: 42 | inputs: A float tensor of arbitrary shape. 43 | The predictions for each example. 44 | targets: A float tensor with the same shape as inputs. Stores the binary 45 | classification label for each element in inputs 46 | (0 for the negative class and 1 for the positive class). 47 | """ 48 | inputs = inputs.sigmoid() 49 | inputs = inputs.flatten(1) 50 | numerator = 2 * (inputs * targets).sum(-1) 51 | denominator = inputs.sum(-1) + targets.sum(-1) 52 | loss = 1 - (numerator + 1) / (denominator + 1) 53 | return loss.mean() 54 | -------------------------------------------------------------------------------- /util/solver.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu). 
All Rights Reserved. 2 | # 3 | # Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural 4 | # Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part of 5 | # the code. 6 | import torch 7 | import logging 8 | 9 | from torch.optim import SGD, Adam, AdamW 10 | from torch.optim.lr_scheduler import LambdaLR, StepLR 11 | 12 | 13 | class LambdaStepLR(LambdaLR): 14 | 15 | def __init__(self, optimizer, lr_lambda, last_step=-1): 16 | super(LambdaStepLR, self).__init__(optimizer, lr_lambda, last_step) 17 | 18 | @property 19 | def last_step(self): 20 | """Use last_epoch for the step counter""" 21 | return self.last_epoch 22 | 23 | @last_step.setter 24 | def last_step(self, v): 25 | self.last_epoch = v 26 | 27 | 28 | class PolyLR(LambdaStepLR): 29 | """DeepLab learning rate policy""" 30 | 31 | def __init__(self, optimizer, max_iter, power=0.9, last_step=-1): 32 | super(PolyLR, self).__init__(optimizer, lambda s: (1 - s / (max_iter + 1))**power, last_step) 33 | 34 | 35 | class SquaredLR(LambdaStepLR): 36 | """ Used for SGD Lars""" 37 | 38 | def __init__(self, optimizer, max_iter, last_step=-1): 39 | super(SquaredLR, self).__init__(optimizer, lambda s: (1 - s / (max_iter + 1))**2, last_step) 40 | 41 | 42 | class SquaredLRWarmingUp(LambdaStepLR): 43 | """ Used for SGD Lars""" 44 | 45 | def __init__(self, optimizer, max_iter, warmingup_iter=500, last_step=-1): 46 | lambda_s = lambda s: (1 - s / (max_iter + 1))**2 if s > warmingup_iter else s / warmingup_iter * ((1 - s / (max_iter + 1))**2 - 1e-6) + 1e-6 47 | super(SquaredLRWarmingUp, self).__init__(optimizer, lambda_s, last_step) 48 | 49 | 50 | class ExpLR(LambdaStepLR): 51 | 52 | def __init__(self, optimizer, step_size, gamma=0.9, last_step=-1): 53 | # (0.9 ** 21.854) = 0.1, (0.95 ** 44.8906) = 0.1 54 | # To get 0.1 every N using gamma 0.9, N * log(0.9)/log(0.1) = 0.04575749 N 55 | # To get 0.1 every N using gamma g, g ** N = 0.1 -> N * log(g) = log(0.1) -> g = np.exp(log(0.1) / N) 56 | super(ExpLR, self).__init__(optimizer, lambda s: gamma**(s / step_size), last_step) 57 | 58 | 59 | class SGDLars(SGD): 60 | """Lars Optimizer (https://arxiv.org/pdf/1708.03888.pdf)""" 61 | 62 | def step(self, closure=None): 63 | """Performs a single optimization step. 64 | 65 | Arguments: 66 | closure (callable, optional): A closure that reevaluates the model 67 | and returns the loss. 68 | 69 | .. note:: 70 | The implementation of SGD with Momentum/Nesterov subtly differs from 71 | Sutskever et. al. and implementations in some other frameworks. 72 | 73 | Considering the specific case of Momentum, the update can be written as 74 | 75 | .. math:: 76 | v = \rho * v + g \\ 77 | p = p - lr * v 78 | 79 | where p, g, v and :math:`\rho` denote the parameters, gradient, 80 | velocity, and momentum respectively. 81 | 82 | This is in contrast to Sutskever et. al. and 83 | other frameworks which employ an update of the form 84 | 85 | .. math:: 86 | v = \rho * v + lr * g \\ 87 | p = p - v 88 | 89 | The Nesterov version is analogously modified. 
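        Note: compared with plain SGD, this implementation first rescales each
        parameter's gradient by a layer-wise trust ratio
        ``||w|| / (||w|| + ||grad||)`` (the LARS block below) before weight
        decay and the momentum update are applied.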
90 | """ 91 | loss = None 92 | if closure is not None: 93 | loss = closure() 94 | 95 | for group in self.param_groups: 96 | weight_decay = group['weight_decay'] 97 | momentum = group['momentum'] 98 | dampening = group['dampening'] 99 | nesterov = group['nesterov'] 100 | 101 | for p in group['params']: 102 | if p.grad is None: 103 | continue 104 | d_p = p.grad.data 105 | # LARS 106 | w_norm = torch.norm(p.data) 107 | lamb = w_norm / (w_norm + torch.norm(d_p)) 108 | d_p.mul_(lamb) 109 | if weight_decay != 0: 110 | d_p.add_(weight_decay, p.data) 111 | if momentum != 0: 112 | param_state = self.state[p] 113 | if 'momentum_buffer' not in param_state: 114 | buf = param_state['momentum_buffer'] = torch.zeros_like(p.data) 115 | buf.mul_(momentum).add_(d_p) 116 | else: 117 | buf = param_state['momentum_buffer'] 118 | buf.mul_(momentum).add_(1 - dampening, d_p) 119 | if nesterov: 120 | d_p = d_p.add(momentum, buf) 121 | else: 122 | d_p = buf 123 | 124 | p.data.add_(-group['lr'], d_p) 125 | 126 | return loss 127 | 128 | 129 | def initialize_optimizer(params, config, mode): 130 | if mode == 'b': 131 | config_optimizer = config.optimizer_b 132 | config_lr = config.lr_b 133 | elif mode == 'p': 134 | config_optimizer = config.optimizer_p 135 | config_lr = config.lr_p 136 | else: 137 | raise ValueError('Invalid mode!') 138 | assert config_optimizer in ['SGD', 'Adagrad', 'Adam', 'RMSProp', 'Rprop', 'SGDLars', 'AdamW'] 139 | 140 | if config_optimizer == 'SGD': 141 | return SGD( 142 | params, 143 | lr=config_lr, 144 | momentum=config.sgd_momentum, 145 | dampening=config.sgd_dampening, 146 | weight_decay=1e-4) 147 | if config_optimizer == 'SGDLars': 148 | return SGDLars( 149 | params, 150 | lr=config_lr, 151 | momentum=config.sgd_momentum, 152 | dampening=config.sgd_dampening, 153 | weight_decay=config.weight_decay) 154 | elif config_optimizer == 'Adam': 155 | return Adam( 156 | params, 157 | lr=config_lr, 158 | betas=(config.adam_beta1, config.adam_beta2), 159 | weight_decay=config.weight_decay) 160 | elif config_optimizer == 'AdamW': 161 | return AdamW( 162 | params, 163 | lr=config_lr, 164 | betas=(config.adam_beta1, config.adam_beta2), 165 | weight_decay=config.weight_decay) 166 | else: 167 | logging.error('Optimizer type not supported') 168 | raise ValueError('Optimizer type not supported') 169 | 170 | 171 | def initialize_scheduler(optimizer, config, last_step=-1): 172 | if config.scheduler == 'StepLR': 173 | return StepLR( 174 | optimizer, step_size=config.step_size, gamma=config.step_gamma, last_epoch=last_step) 175 | elif config.scheduler == 'PolyLR': 176 | return PolyLR(optimizer, max_iter=config.max_iter, power=config.poly_power, last_step=last_step) 177 | elif config.scheduler == 'SquaredLR': 178 | return SquaredLR(optimizer, max_iter=config.max_iter, last_step=last_step) 179 | elif config.scheduler == 'ExpLR': 180 | return ExpLR( 181 | optimizer, step_size=config.exp_step_size, gamma=config.exp_gamma, last_step=last_step) 182 | elif config.scheduler == 'SquaredLRWarmingUp': 183 | return SquaredLRWarmingUp(optimizer, max_iter=config.max_iter, last_step=last_step) 184 | else: 185 | logging.error('Scheduler not supported') 186 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from os.path import join 4 | 5 | import numpy as np 6 | from PIL import Image 7 | 8 | import torch 9 | from torch import nn 10 | import torch.nn.init as initer 11 | 
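# Shared training/evaluation helpers: checkpoint saving, running-average
# meters, step/poly learning-rate schedules, and intersection-and-union
# computation (NumPy and GPU tensor variants) for IoU evaluation.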
12 | 13 | def save_checkpoint(state, is_best, sav_path, filename='model_last.pth.tar'): 14 | filename = join(sav_path, filename) 15 | torch.save(state, filename) 16 | if is_best: 17 | shutil.copyfile(filename, join(sav_path, 'model_best.pth.tar')) 18 | 19 | 20 | class AverageMeter(object): 21 | """Computes and stores the average and current value""" 22 | 23 | def __init__(self): 24 | self.reset() 25 | 26 | def reset(self): 27 | self.val = 0 28 | self.avg = 0 29 | self.sum = 0 30 | self.count = 0 31 | 32 | def update(self, val, n=1): 33 | self.val = val 34 | self.sum += val * n 35 | self.count += n 36 | self.avg = self.sum / self.count 37 | 38 | 39 | def step_learning_rate(base_lr, epoch, step_epoch, multiplier=0.1): 40 | """Sets the learning rate to the base LR decayed by 10 every step epochs""" 41 | lr = base_lr * (multiplier ** (epoch // step_epoch)) 42 | return lr 43 | 44 | 45 | def poly_learning_rate(base_lr, curr_iter, max_iter, power=0.9): 46 | """poly learning rate policy""" 47 | lr = base_lr * (1 - float(curr_iter) / max_iter) ** power 48 | return lr 49 | 50 | 51 | def intersectionAndUnion(output, target, K, ignore_index=255): 52 | # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1. 53 | assert (output.ndim in [1, 2, 3, 4]) 54 | assert output.shape == target.shape 55 | output = output.reshape(output.size).copy() 56 | target = target.reshape(target.size) 57 | output[np.where(target == ignore_index)[0]] = ignore_index 58 | intersection = output[np.where(output == target)[0]] 59 | area_intersection, _ = np.histogram(intersection, bins=np.arange(K + 1)) 60 | area_output, _ = np.histogram(output, bins=np.arange(K + 1)) 61 | area_target, _ = np.histogram(target, bins=np.arange(K + 1)) 62 | area_union = area_output + area_target - area_intersection 63 | return area_intersection, area_union, area_target 64 | 65 | 66 | def intersectionAndUnionGPU(output, target, K, ignore_index=255): 67 | # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1. 
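    # Mirrors intersectionAndUnion above, but on tensors: positions where target
    # equals ignore_index are overwritten with ignore_index so they fall outside
    # the [0, K-1] histogram range and are dropped by torch.histc (assuming
    # ignore_index >= K); per-class areas then give union = output + target - intersection.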
68 | assert (output.dim() in [1, 2, 3, 4]) 69 | assert output.shape == target.shape 70 | output = output.view(-1) 71 | target = target.view(-1) 72 | output[target == ignore_index] = ignore_index 73 | intersection = output[output == target] 74 | # https://github.com/pytorch/pytorch/issues/1382 75 | area_intersection = torch.histc(intersection.float().cpu(), bins=K, min=0, max=K - 1) 76 | area_output = torch.histc(output.float().cpu(), bins=K, min=0, max=K - 1) 77 | area_target = torch.histc(target.float().cpu(), bins=K, min=0, max=K - 1) 78 | area_union = area_output + area_target - area_intersection 79 | return area_intersection.cuda(), area_union.cuda(), area_target.cuda() 80 | 81 | 82 | def check_mkdir(dir_name): 83 | if not os.path.exists(dir_name): 84 | os.mkdir(dir_name) 85 | 86 | 87 | def check_makedirs(dir_name): 88 | if not os.path.exists(dir_name): 89 | os.makedirs(dir_name) 90 | 91 | 92 | # def init_weights(model, conv='kaiming', batchnorm='normal', linear='kaiming', lstm='kaiming'): 93 | # """ 94 | # :param model: Pytorch Model which is nn.Module 95 | # :param conv: 'kaiming' or 'xavier' 96 | # :param batchnorm: 'normal' or 'constant' 97 | # :param linear: 'kaiming' or 'xavier' 98 | # :param lstm: 'kaiming' or 'xavier' 99 | # """ 100 | # for m in model.modules(): 101 | # if isinstance(m, (nn.modules.conv._ConvNd)): 102 | # if conv == 'kaiming': 103 | # initer.kaiming_normal_(m.weight) 104 | # elif conv == 'xavier': 105 | # initer.xavier_normal_(m.weight) 106 | # else: 107 | # raise ValueError("init type of conv error.\n") 108 | # if m.bias is not None: 109 | # initer.constant_(m.bias, 0) 110 | # 111 | # elif isinstance(m, (nn.modules.batchnorm._BatchNorm)): 112 | # if batchnorm == 'normal': 113 | # initer.normal_(m.weight, 1.0, 0.02) 114 | # elif batchnorm == 'constant': 115 | # initer.constant_(m.weight, 1.0) 116 | # else: 117 | # raise ValueError("init type of batchnorm error.\n") 118 | # initer.constant_(m.bias, 0.0) 119 | # 120 | # elif isinstance(m, nn.Linear): 121 | # if linear == 'kaiming': 122 | # initer.kaiming_normal_(m.weight) 123 | # elif linear == 'xavier': 124 | # initer.xavier_normal_(m.weight) 125 | # else: 126 | # raise ValueError("init type of linear error.\n") 127 | # if m.bias is not None: 128 | # initer.constant_(m.bias, 0) 129 | # 130 | # elif isinstance(m, nn.LSTM): 131 | # for name, param in m.named_parameters(): 132 | # if 'weight' in name: 133 | # if lstm == 'kaiming': 134 | # initer.kaiming_normal_(param) 135 | # elif lstm == 'xavier': 136 | # initer.xavier_normal_(param) 137 | # else: 138 | # raise ValueError("init type of lstm error.\n") 139 | # elif 'bias' in name: 140 | # initer.constant_(param, 0) 141 | # 142 | # 143 | # def group_weight(weight_group, module, lr): 144 | # group_decay = [] 145 | # group_no_decay = [] 146 | # for m in module.modules(): 147 | # if isinstance(m, nn.Linear): 148 | # group_decay.append(m.weight) 149 | # if m.bias is not None: 150 | # group_no_decay.append(m.bias) 151 | # elif isinstance(m, nn.modules.conv._ConvNd): 152 | # group_decay.append(m.weight) 153 | # if m.bias is not None: 154 | # group_no_decay.append(m.bias) 155 | # elif isinstance(m, nn.modules.batchnorm._BatchNorm): 156 | # if m.weight is not None: 157 | # group_no_decay.append(m.weight) 158 | # if m.bias is not None: 159 | # group_no_decay.append(m.bias) 160 | # assert len(list(module.parameters())) == len(group_decay) + len(group_no_decay) 161 | # weight_group.append(dict(params=group_decay, lr=lr)) 162 | # weight_group.append(dict(params=group_no_decay, 
weight_decay=.0, lr=lr)) 163 | # return weight_group 164 | # 165 | # 166 | # def convert_to_syncbn(model): 167 | # def recursive_set(cur_module, name, module): 168 | # if len(name.split('.')) > 1: 169 | # recursive_set(getattr(cur_module, name[:name.find('.')]), name[name.find('.') + 1:], module) 170 | # else: 171 | # setattr(cur_module, name, module) 172 | # 173 | # from lib.sync_bn import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 174 | # for name, m in model.named_modules(): 175 | # if isinstance(m, nn.BatchNorm1d): 176 | # recursive_set(model, name, SynchronizedBatchNorm1d(m.num_features, m.eps, m.momentum, m.affine)) 177 | # elif isinstance(m, nn.BatchNorm2d): 178 | # recursive_set(model, name, SynchronizedBatchNorm2d(m.num_features, m.eps, m.momentum, m.affine)) 179 | # elif isinstance(m, nn.BatchNorm3d): 180 | # recursive_set(model, name, SynchronizedBatchNorm3d(m.num_features, m.eps, m.momentum, m.affine)) 181 | # 182 | # 183 | # def colorize(gray, palette): 184 | # # gray: numpy array of the label and 1*3N size list palette 185 | # color = Image.fromarray(gray.astype(np.uint8)).convert('P') 186 | # color.putpalette(palette) 187 | # return color 188 | --------------------------------------------------------------------------------