├── DATA └── Notes.txt ├── LICENSE ├── Network.png ├── README.md ├── ckpt_files_OCID └── pretrained │ └── Note.txt ├── grasp_det_seg ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── _version.cpython-36.pyc ├── _version.py ├── algos │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ └── detection.cpython-36.pyc │ ├── detection.py │ ├── fpn.py │ ├── rpn.py │ └── semantic_seg.py ├── config │ ├── __init__.py │ ├── config.py │ └── defaults │ │ └── det_seg_OCID.ini ├── data_OCID │ ├── OCID_class_dict.py │ ├── __init__.py │ ├── dataset.py │ ├── misc.py │ ├── sampler.py │ └── transform.py ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ └── resnet.cpython-36.pyc │ ├── det_seg.py │ └── resnet.py ├── modules │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── losses.cpython-36.pyc │ │ ├── misc.cpython-36.pyc │ │ └── residual.cpython-36.pyc │ ├── fpn.py │ ├── heads │ │ ├── __init__.py │ │ ├── fpn.py │ │ └── rpn.py │ ├── losses.py │ ├── misc.py │ └── residual.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── coco_ap.cpython-36.pyc │ ├── misc.cpython-36.pyc │ └── scheduler.cpython-36.pyc │ ├── bbx │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ └── bbx.cpython-36.pyc │ ├── _backend.pyi │ └── bbx.py │ ├── logging.py │ ├── meters.py │ ├── misc.py │ ├── nms │ ├── __init__.py │ ├── _backend.pyi │ └── nms.py │ ├── parallel │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── data_parallel.cpython-36.pyc │ │ ├── packed_sequence.cpython-36.pyc │ │ └── scatter_gather.cpython-36.pyc │ ├── data_parallel.py │ ├── packed_sequence.py │ └── scatter_gather.py │ ├── roi_sampling │ ├── __init__.py │ ├── _backend.pyi │ └── functions.py │ ├── scheduler.py │ ├── sequence.py │ └── snapshot.py ├── include ├── bbx.h ├── nms.h ├── roi_sampling.h └── utils │ ├── checks.h │ ├── common.h │ └── cuda.cuh ├── requirements.txt ├── sample.png ├── scripts ├── evaluate_det_seg_OCID.py ├── test_det_seg_OCID.py └── train_det_seg_OCID.py ├── setup.cfg ├── setup.py ├── src ├── bbx │ ├── bbx.cpp │ ├── bbx_cpu.cpp │ └── bbx_cuda.cu ├── nms │ ├── nms.cpp │ ├── nms_cpu.cpp │ └── nms_cuda.cu └── roi_sampling │ ├── roi_sampling.cpp │ ├── roi_sampling_cpu.cpp │ └── roi_sampling_cuda.cu └── weights_pretrained └── Note.txt /DATA/Notes.txt: -------------------------------------------------------------------------------- 1 | Unzip OCID_grasp.zip in this folder. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, mapillary 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 
19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /Network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-ainetter/grasp_det_seg_cnn/6ff96464f8906fb555d0a2f5a8b86c7f1330f108/Network.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # End-to-end Trainable Deep Neural Network for Robotic Grasp Detection and Semantic Segmentation from RGB 2 | 3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/end-to-end-trainable-deep-neural-network-for/robotic-grasping-on-cornell-grasp-dataset-1)](https://paperswithcode.com/sota/robotic-grasping-on-cornell-grasp-dataset-1?p=end-to-end-trainable-deep-neural-network-for) 4 | 5 | 6 |

7 | 8 | 9 | [arXiv link to the paper] 10 |

11 | 12 | This repository contains the code for the ICRA21 paper "End-to-end Trainable Deep Neural Network for Robotic Grasp Detection 13 | and Semantic Segmentation from RGB". 14 | It provides training and testing code for our proposed method on the OCID_grasp dataset. 15 | 16 | If you use our method or dataset extension for your research, please cite: 17 | ```bibtex 18 | @InProceedings{ainetter2021end, 19 | title={End-to-end Trainable Deep Neural Network for Robotic Grasp Detection and Semantic Segmentation from RGB}, 20 | author={Ainetter, Stefan and Fraundorfer, Friedrich}, 21 | booktitle={IEEE International Conference on Robotics and Automation (ICRA)}, 22 | pages={13452--13458}, 23 | year={2021} 24 | } 25 | ``` 26 | 27 | ## Requirements and setup 28 | 29 | Main system requirements: 30 | * CUDA 10.1 31 | * Linux with GCC 7 or 8 32 | * PyTorch v1.1.0 33 | 34 | **IMPORTANT NOTE**: These requirements are not necessarily stringent, e.g. it might be possible to compile with older 35 | versions of CUDA, or under Windows. However, we have only tested the code under the above settings and cannot provide support for other setups. 36 | 37 | To install PyTorch, please refer to https://github.com/pytorch/pytorch#installation. 38 | 39 | To install all other dependencies using pip: 40 | ```bash 41 | pip install -r requirements.txt 42 | ``` 43 | 44 | ### Setup 45 | 46 | Our code is split into two main components: a library containing implementations of the various network modules, 47 | algorithms and utilities, and a set of scripts to train / test the networks. 48 | 49 | The library, called `grasp_det_seg`, can be installed with: 50 | ```bash 51 | git clone https://github.com/stefan-ainetter/grasp_det_seg_cnn.git 52 | cd grasp_det_seg_cnn 53 | python setup.py install 54 | ``` 55 | 56 | ## Trained models 57 | 58 | The model files provided are made available under the [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/) license. 59 | 60 | A trained model for the OCID_grasp dataset can be downloaded [here](https://cloud.tugraz.at/index.php/s/NA7icqiJ5SeNSA6/download?path=%2FGrasp_det_seg_cnn%2FOCID_pretrained&files=model_last.pth.tar). 61 | Copy the downloaded weights into the `ckpt_files_OCID/pretrained` folder. 62 | 63 | For re-training the network on OCID_grasp, you need to download weights pretrained on ImageNet 64 | [here](https://cloud.tugraz.at/index.php/s/NA7icqiJ5SeNSA6?path=%2FGrasp_det_seg_cnn%2FImageNet_weights) and copy them 65 | into the `weights_pretrained` folder. 66 | 67 | ### Training 68 | 69 | Training involves three main steps: preparing the dataset, creating a configuration file, and running the training 70 | script. 71 | 72 | To prepare the dataset: 73 | 1) Download the OCID_grasp dataset [here](https://cloud.tugraz.at/index.php/s/NA7icqiJ5SeNSA6/download?path=%2FGrasp_det_seg_cnn%2FOCID_grasp&files=OCID_grasp.zip). 74 | Unpack the downloaded `OCID_grasp.zip` file into the `DATA` folder. 75 | 2) The configuration file is a simple text file in `ini` format. 76 | The default value of each configuration parameter, as well as a short description of what it does, is available in 77 | [grasp_det_seg/config/defaults](grasp_det_seg/config/defaults). 78 | **Note** that these are just an indication of what a "reasonable" value for each parameter could be, and are not 79 | meant as a way to reproduce any of the results from our paper. 
80 | 81 | 3) To launch the training: 82 | ```bash 83 | cd scripts 84 | python3 -m torch.distributed.launch --nproc_per_node=1 train_det_seg_OCID.py \ 85 | --log_dir=LOG_DIR CONFIG DATA_DIR 86 | ``` 87 | Training logs, both in text and Tensorboard formats, as well as the trained network parameters, will be written 88 | in `LOG_DIR` (e.g. `ckpt_files_OCID`). 89 | The file `CONFIG` contains the network configuration, e.g. `grasp_det_seg/config/defaults/det_seg_OCID.ini`, 90 | and `DATA_DIR` points to the previously downloaded OCID_grasp splits, e.g. `DATA/OCID_grasp/data_split`. 91 | 92 | Note that, for now, our code **must** be launched in "distributed" mode using PyTorch's `torch.distributed.launch` 93 | utility. 94 | 95 | ### Running inference 96 | 97 | Given a trained network, inference can be run on any set of images using 98 | [scripts/test_det_seg_OCID.py](scripts/test_det_seg_OCID.py): 99 | ```bash 100 | cd scripts 101 | python3 -m torch.distributed.launch --nproc_per_node=1 test_det_seg_OCID.py \ 102 | --log_dir=LOG_DIR CONFIG MODEL_PARAMS DATA_DIR OUTPUT_DIR 103 | 104 | ``` 105 | Predictions will be written to `OUTPUT_DIR`, e.g. the `output` folder. `MODEL_PARAMS` are the pre-trained weights, e.g. `ckpt_files_OCID/pretrained/model_last.pth.tar`, 106 | and `DATA_DIR` points to the dataset splits used, e.g. `DATA/OCID_grasp/data_split`. 107 | 108 | ## OCID_grasp dataset 109 | The OCID_grasp dataset can be downloaded [here](https://cloud.tugraz.at/index.php/s/NA7icqiJ5SeNSA6/download?path=%2FGrasp_det_seg_cnn%2FOCID_grasp&files=OCID_grasp.zip). 110 | OCID_grasp consists of 1763 selected RGB-D images of the OCID dataset, with over 11.4k segmented object masks and more than 75k hand-annotated 111 | grasp candidates. Additionally, each object is classified into one of 31 object classes. 112 | ## Related Work 113 | OCID_grasp is a dataset extension of the [OCID dataset](https://www.acin.tuwien.ac.at/en/vision-for-robotics/software-tools/object-clutter-indoor-dataset/). 114 | If you decide to use OCID_grasp for your research, please also cite the OCID paper: 115 | ```bibtex 116 | @inproceedings{suchi2019easylabel, 117 | title={EasyLabel: a semi-automatic pixel-wise object annotation tool for creating robotic RGB-D datasets}, 118 | author={Suchi, Markus and Patten, Timothy and Fischinger, David and Vincze, Markus}, 119 | booktitle={2019 International Conference on Robotics and Automation (ICRA)}, 120 | pages={6678--6684}, 121 | year={2019}, 122 | organization={IEEE} 123 | } 124 | ``` 125 | Our framework is based on the architecture from [Seamless Scene Segmentation](https://github.com/mapillary/seamseg): 126 | ```bibtex 127 | @InProceedings{Porzi_2019_CVPR, 128 | author = {Porzi, Lorenzo and Rota Bul\`o, Samuel and Colovic, Aleksander and Kontschieder, Peter}, 129 | title = {Seamless Scene Segmentation}, 130 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 131 | month = {June}, 132 | year = {2019} 133 | } 134 | ``` 135 | --- 136 | ## About our latest Research 137 | ### Our paper 'Depth-aware Object Segmentation and Grasp Detection for Robotic Picking Tasks' was accepted at BMVC21 138 | In our latest work, we implemented a method for joint grasp detection and class-agnostic object instance segmentation, 139 | which was published at BMVC21. 140 | More information can be found [here](https://arxiv.org/pdf/2111.11114). 
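---
## OCID_grasp annotation format (example)

Each image in OCID_grasp ships with a plain-text annotation file (under `Annotations/<image>.txt`) that lists one `x y` corner per line, with every four consecutive corners forming one oriented grasp rectangle; this is the format read by `grasp_det_seg/data_OCID/dataset.py` and `grasp_det_seg/data_OCID/misc.py`. The snippet below is a minimal, self-contained sketch of that parsing step; the helper name `load_grasp_rectangles` and the example path are illustrative only and not part of the library.

```python
import numpy as np

def load_grasp_rectangles(anno_path):
    """Read an OCID_grasp annotation file: one 'x y' corner per line,
    four consecutive corners per oriented grasp rectangle."""
    corners, rectangles = [], []
    with open(anno_path, "r") as f:
        for line in f:
            # Each line holds the two coordinates of one rectangle corner.
            x, y = line.rstrip().split(' ')
            corners.append((float(x), float(y)))
            if len(corners) == 4:
                rectangles.append(corners)
                corners = []
    # Shape: (num_grasps, 4, 2), matching the box array built in dataset.py
    return np.asarray(rectangles)

# Example usage (path is illustrative):
# boxes = load_grasp_rectangles("DATA/OCID_grasp/<sequence>/Annotations/<image>.txt")
```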
141 | -------------------------------------------------------------------------------- /ckpt_files_OCID/pretrained/Note.txt: -------------------------------------------------------------------------------- 1 | Add pre-trained weights here -------------------------------------------------------------------------------- /grasp_det_seg/__init__.py: -------------------------------------------------------------------------------- 1 | from ._version import version as __version__ 2 | -------------------------------------------------------------------------------- /grasp_det_seg/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-ainetter/grasp_det_seg_cnn/6ff96464f8906fb555d0a2f5a8b86c7f1330f108/grasp_det_seg/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /grasp_det_seg/__pycache__/_version.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-ainetter/grasp_det_seg_cnn/6ff96464f8906fb555d0a2f5a8b86c7f1330f108/grasp_det_seg/__pycache__/_version.cpython-36.pyc -------------------------------------------------------------------------------- /grasp_det_seg/_version.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # file generated by setuptools_scm 3 | # don't change, don't track in version control 4 | version = '0.1.dev0' 5 | version_tuple = (0, 1, 'dev0') 6 | -------------------------------------------------------------------------------- /grasp_det_seg/algos/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-ainetter/grasp_det_seg_cnn/6ff96464f8906fb555d0a2f5a8b86c7f1330f108/grasp_det_seg/algos/__init__.py -------------------------------------------------------------------------------- /grasp_det_seg/algos/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-ainetter/grasp_det_seg_cnn/6ff96464f8906fb555d0a2f5a8b86c7f1330f108/grasp_det_seg/algos/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /grasp_det_seg/algos/__pycache__/detection.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-ainetter/grasp_det_seg_cnn/6ff96464f8906fb555d0a2f5a8b86c7f1330f108/grasp_det_seg/algos/__pycache__/detection.cpython-36.pyc -------------------------------------------------------------------------------- /grasp_det_seg/algos/fpn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from inplace_abn import active_group, set_active_group 3 | 4 | from grasp_det_seg.utils.bbx import shift_boxes 5 | from grasp_det_seg.utils.misc import Empty 6 | from grasp_det_seg.utils.parallel import PackedSequence 7 | from grasp_det_seg.utils.roi_sampling import roi_sampling 8 | from .detection import DetectionAlgo 9 | from .rpn import RPNAlgo 10 | 11 | 12 | class RPNAlgoFPN(RPNAlgo): 13 | """RPN algorithm for FPN-based region proposal networks 14 | 15 | Parameters 16 | ---------- 17 | proposal_generator : RPNProposalGenerator 18 | anchor_matcher : RPNAnchorMatcher 19 | loss : RPNLoss 20 | anchor_scale : float 21 | Anchor scale 
factor, this is multiplied by the RPN stride at each level to determine the actual anchor sizes 22 | anchor_ratios : sequence of float 23 | Anchor aspect ratios 24 | anchor_strides: sequence of int 25 | Effective strides of the RPN outputs at each FPN level 26 | min_level : int 27 | First FPN level to work on 28 | levels : int 29 | Number of FPN levels to work on 30 | """ 31 | 32 | def __init__(self, 33 | proposal_generator, 34 | anchor_matcher, 35 | loss, 36 | anchor_scale, 37 | anchor_ratios, 38 | anchor_strides, 39 | min_level, 40 | levels): 41 | super(RPNAlgoFPN, self).__init__((anchor_scale,), anchor_ratios) 42 | self.proposal_generator = proposal_generator 43 | self.anchor_matcher = anchor_matcher 44 | self.loss = loss 45 | self.min_level = min_level 46 | self.levels = levels 47 | 48 | # Cache per-cell anchors 49 | self.anchor_strides = anchor_strides[min_level:min_level + levels] 50 | self.anchors = [self._base_anchors(stride) for stride in self.anchor_strides] 51 | 52 | @staticmethod 53 | def _get_logits(head, x): 54 | obj_logits, bbx_logits, h, w = [], [], [], [] 55 | for x_i in x: 56 | obj_logits_i, bbx_logits_i = head(x_i) 57 | h_i, w_i = (int(s) for s in obj_logits_i.shape[-2:]) 58 | 59 | obj_logits_i = obj_logits_i.permute(0, 2, 3, 1).contiguous().view(obj_logits_i.size(0), -1) 60 | bbx_logits_i = bbx_logits_i.permute(0, 2, 3, 1).contiguous().view(bbx_logits_i.size(0), -1, 4) 61 | 62 | obj_logits.append(obj_logits_i) 63 | bbx_logits.append(bbx_logits_i) 64 | h.append(h_i) 65 | w.append(w_i) 66 | 67 | return torch.cat(obj_logits, dim=1), torch.cat(bbx_logits, dim=1), h, w 68 | 69 | def _inference(self, obj_logits, bbx_logits, anchors, valid_size, training): 70 | # Compute shifted boxes 71 | boxes = shift_boxes(anchors, bbx_logits) 72 | 73 | # Clip boxes to their image sizes 74 | for i, (height, width) in enumerate(valid_size): 75 | boxes[i, :, [0, 2]] = boxes[i, :, [0, 2]].clamp(min=0, max=height) 76 | boxes[i, :, [1, 3]] = boxes[i, :, [1, 3]].clamp(min=0, max=width) 77 | 78 | return self.proposal_generator(boxes, obj_logits, training) 79 | 80 | def training(self, head, x, bbx, iscrowd, valid_size, training=True, do_inference=False): 81 | # Calculate logits for the levels that we need 82 | x = x[self.min_level:self.min_level + self.levels] 83 | obj_logits, bbx_logits, h, w = self._get_logits(head, x) 84 | 85 | with torch.no_grad(): 86 | # Compute anchors for each scale and merge them 87 | anchors = [] 88 | for h_i, w_i, stride_i, anchors_i in zip(h, w, self.anchor_strides, self.anchors): 89 | anchors.append(self._shifted_anchors( 90 | anchors_i, stride_i, h_i, w_i, bbx_logits.dtype, bbx_logits.device)) 91 | anchors = torch.cat(anchors, dim=0) 92 | # obj_lbl: binary class label for each anchor (being an object or not) 93 | # bbx_lbl: coordinates for each bbx with pos object_lbl 94 | match = self.anchor_matcher(anchors, bbx, iscrowd, valid_size) 95 | obj_lbl, bbx_lbl = self._match_to_lbl(anchors, bbx, match) 96 | 97 | # Compute losses 98 | obj_loss, bbx_loss = self.loss(obj_logits, bbx_logits, obj_lbl, bbx_lbl) 99 | 100 | # Optionally, also run inference 101 | if do_inference: 102 | with torch.no_grad(): 103 | proposals = self._inference(obj_logits, bbx_logits, anchors, valid_size, training) 104 | else: 105 | proposals = None 106 | 107 | return obj_loss, bbx_loss, proposals 108 | 109 | def inference(self, head, x, valid_size, training): 110 | # Calculate logits for the levels that we need 111 | x = x[self.min_level:self.min_level + self.levels] 112 | obj_logits, bbx_logits, h, w 
= self._get_logits(head, x) 113 | 114 | # Compute anchors for each scale and merge them 115 | anchors = [] 116 | for h_i, w_i, stride_i, anchors_i in zip(h, w, self.anchor_strides, self.anchors): 117 | anchors.append(self._shifted_anchors( 118 | anchors_i, stride_i, h_i, w_i, bbx_logits.dtype, bbx_logits.device)) 119 | anchors = torch.cat(anchors, dim=0) 120 | 121 | return self._inference(obj_logits, bbx_logits, anchors, valid_size, training) 122 | 123 | 124 | class DetectionAlgoFPN(DetectionAlgo): 125 | """Detection algorithm for FPN networks 126 | """ 127 | 128 | def __init__(self, 129 | prediction_generator, 130 | proposal_matcher, 131 | loss, 132 | classes, 133 | bbx_reg_weights, 134 | canonical_scale, 135 | canonical_level, 136 | roi_size, 137 | min_level, 138 | levels): 139 | super(DetectionAlgoFPN, self).__init__(classes, bbx_reg_weights) 140 | self.prediction_generator = prediction_generator 141 | self.proposal_matcher = proposal_matcher 142 | self.loss = loss 143 | self.canonical_scale = canonical_scale 144 | self.canonical_level = canonical_level 145 | self.roi_size = roi_size 146 | self.min_level = min_level 147 | self.levels = levels 148 | 149 | def _target_level(self, boxes): 150 | scales = (boxes[:, 2:] - boxes[:, :2]).prod(dim=-1).sqrt() 151 | target_level = torch.floor(self.canonical_level + torch.log2(scales / self.canonical_scale + 1e-6)) 152 | return target_level.clamp(min=self.min_level, max=self.min_level + self.levels - 1) 153 | 154 | def _rois(self, x, proposals, proposals_idx, img_size): 155 | stride = proposals.new([fs / os for fs, os in zip(x.shape[-2:], img_size)]) 156 | proposals = (proposals - 0.5) * stride.repeat(2) + 0.5 157 | return roi_sampling(x, proposals, proposals_idx, self.roi_size) 158 | 159 | def _head(self, head, x, proposals, proposals_idx, img_size): 160 | # Find target levels 161 | target_level = self._target_level(proposals) 162 | 163 | # Sample rois 164 | rois = x[0].new_zeros(proposals.size(0), x[0].size(1), self.roi_size[0], self.roi_size[1]) 165 | for level_i, x_i in enumerate(x): 166 | idx = target_level == (level_i + self.min_level) 167 | if idx.any().item(): 168 | rois[idx] = self._rois(x_i, proposals[idx], proposals_idx[idx], img_size) 169 | 170 | # Run head 171 | return head(rois) 172 | 173 | def training(self, head, x, proposals, bbx, cat, iscrowd, img_size): 174 | x = x[self.min_level:self.min_level + self.levels] 175 | 176 | try: 177 | if proposals.all_none: 178 | raise Empty 179 | 180 | with torch.no_grad(): 181 | # Match proposals to ground truth 182 | proposals, match = self.proposal_matcher(proposals, bbx, cat, iscrowd) 183 | cls_lbl, bbx_lbl = self._match_to_lbl(proposals, bbx, cat, match) 184 | 185 | if proposals.all_none: 186 | raise Empty 187 | 188 | # Run head 189 | set_active_group(head, active_group(True)) 190 | proposals, proposals_idx = proposals.contiguous 191 | cls_logits, bbx_logits = self._head(head, x, proposals, proposals_idx, img_size) 192 | 193 | # Calculate loss 194 | cls_loss, bbx_loss = self.loss(cls_logits, bbx_logits, cls_lbl, bbx_lbl) 195 | except Empty: 196 | active_group(False) 197 | cls_loss = bbx_loss = sum(x_i.sum() for x_i in x) * 0 198 | 199 | return cls_loss, bbx_loss 200 | 201 | def inference(self, head, x, proposals, valid_size, img_size): 202 | x = x[self.min_level:self.min_level + self.levels] 203 | 204 | if not proposals.all_none: 205 | # Run head on the given proposals 206 | proposals, proposals_idx = proposals.contiguous 207 | cls_logits, bbx_logits = self._head(head, x, proposals, 
proposals_idx, img_size) 208 | 209 | # Shift the proposals according to the logits 210 | bbx_reg_weights = x[0].new(self.bbx_reg_weights) 211 | boxes = shift_boxes(proposals.unsqueeze(1), bbx_logits / bbx_reg_weights) 212 | scores = torch.softmax(cls_logits, dim=1) 213 | 214 | # Split boxes and scores by image, clip to valid size 215 | boxes, scores = self._split_and_clip(boxes, scores, proposals_idx, valid_size) 216 | 217 | bbx_pred, cls_pred, obj_pred = self.prediction_generator(boxes, scores) 218 | else: 219 | bbx_pred = PackedSequence([None for _ in range(x[0].size(0))]) 220 | cls_pred = PackedSequence([None for _ in range(x[0].size(0))]) 221 | obj_pred = PackedSequence([None for _ in range(x[0].size(0))]) 222 | 223 | return bbx_pred, cls_pred, obj_pred -------------------------------------------------------------------------------- /grasp_det_seg/algos/rpn.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as functional 6 | 7 | from grasp_det_seg.modules.losses import smooth_l1 8 | from grasp_det_seg.utils.bbx import ious, calculate_shift 9 | from grasp_det_seg.utils.misc import Empty 10 | from grasp_det_seg.utils.nms import nms 11 | from grasp_det_seg.utils.parallel import PackedSequence 12 | 13 | CHUNK_SIZE = 16 14 | 15 | 16 | class ProposalGenerator: 17 | """Perform NMS-based selection of proposals 18 | 19 | Parameters 20 | ---------- 21 | nms_threshold : float 22 | Intersection over union threshold for the NMS 23 | num_pre_nms_train : int 24 | Number of top-scoring proposals to feed to NMS, training mode 25 | num_post_nms_train : int 26 | Number of top-scoring proposal to keep after NMS, training mode 27 | num_pre_nms_val : int 28 | Number of top-scoring proposals to feed to NMS, validation mode 29 | num_post_nms_val : int 30 | Number of top-scoring proposal to keep after NMS, validation mode 31 | min_size : int 32 | Minimum size for proposals, discard anything with a side smaller than this 33 | """ 34 | 35 | def __init__(self, 36 | nms_threshold=0.7, 37 | num_pre_nms_train=12000, 38 | num_post_nms_train=2000, 39 | num_pre_nms_val=6000, 40 | num_post_nms_val=300, 41 | min_size=0): 42 | super(ProposalGenerator, self).__init__() 43 | self.nms_threshold = nms_threshold 44 | self.num_pre_nms_train = num_pre_nms_train 45 | self.num_post_nms_train = num_post_nms_train 46 | self.num_pre_nms_val = num_pre_nms_val 47 | self.num_post_nms_val = num_post_nms_val 48 | self.min_size = min_size 49 | 50 | def __call__(self, boxes, scores, training): 51 | """Perform NMS-based selection of proposals 52 | """ 53 | if training: 54 | num_pre_nms = self.num_pre_nms_train 55 | num_post_nms = self.num_post_nms_train 56 | else: 57 | num_pre_nms = self.num_pre_nms_val 58 | num_post_nms = self.num_post_nms_val 59 | 60 | proposals = [] 61 | for bbx_i, obj_i in zip(boxes, scores): 62 | try: 63 | # Optional size pre-selection 64 | if self.min_size > 0: 65 | bbx_size = bbx_i[:, 2:] - bbx_i[:, :2] 66 | valid = (bbx_size[:, 0] >= self.min_size) & (bbx_size[:, 1] >= self.min_size) 67 | 68 | if valid.any().item(): 69 | bbx_i, obj_i = bbx_i[valid], obj_i[valid] 70 | else: 71 | raise Empty 72 | 73 | # Score pre-selection 74 | obj_i, idx = obj_i.topk(min(obj_i.size(0), num_pre_nms)) 75 | bbx_i = bbx_i[idx] 76 | 77 | # NMS 78 | idx = nms(bbx_i, obj_i, self.nms_threshold, num_post_nms) 79 | if idx.numel() == 0: 80 | raise Empty 81 | bbx_i = bbx_i[idx] 82 | 83 | proposals.append(bbx_i) 
84 | except Empty: 85 | proposals.append(None) 86 | 87 | return PackedSequence(proposals) 88 | 89 | 90 | class AnchorMatcher: 91 | """Match anchors to ground truth boxes 92 | """ 93 | 94 | def __init__(self, 95 | num_samples=256, 96 | pos_ratio=.5, 97 | pos_threshold=.7, 98 | neg_threshold=.3, 99 | void_threshold=0.): 100 | self.num_samples = num_samples 101 | self.pos_ratio = pos_ratio 102 | self.pos_threshold = pos_threshold 103 | self.neg_threshold = neg_threshold 104 | self.void_threshold = void_threshold 105 | 106 | def _subsample(self, match): 107 | num_pos = int(self.num_samples * self.pos_ratio) 108 | pos_idx = torch.nonzero(match >= 0).view(-1) 109 | if pos_idx.numel() > num_pos: 110 | rand_selection = torch.randperm(pos_idx.numel(), dtype=torch.long, device=match.device)[num_pos:] 111 | match[pos_idx[rand_selection]] = -2 112 | else: 113 | num_pos = pos_idx.numel() 114 | 115 | num_neg = self.num_samples - num_pos 116 | neg_idx = torch.nonzero(match == -1).view(-1) 117 | if neg_idx.numel() > num_neg: 118 | rand_selection = torch.randperm(neg_idx.numel(), dtype=torch.long, device=match.device)[num_neg:] 119 | match[neg_idx[rand_selection]] = -2 120 | 121 | @staticmethod 122 | def _is_inside(bbx, valid_size): 123 | p0y, p0x, p1y, p1x = bbx[:, 0], bbx[:, 1], bbx[:, 2], bbx[:, 3] 124 | return (p0y >= 0) & (p0x >= 0) & (p1y <= valid_size[0]) & (p1x <= valid_size[1]) 125 | 126 | def __call__(self, anchors, bbx, iscrowd, valid_size): 127 | """Match anchors to ground truth boxes 128 | """ 129 | match = [] 130 | for bbx_i_, valid_size_i in zip(bbx, valid_size): 131 | bbx_i = bbx_i_[:,[0,1,3,4]] 132 | 133 | # Default labels: everything is void 134 | match_i = anchors.new_full((anchors.size(0),), -2, dtype=torch.long) 135 | 136 | try: 137 | # Find anchors that are entirely within the original image area 138 | valid = self._is_inside(anchors, valid_size_i) 139 | 140 | if not valid.any().item(): 141 | raise Empty 142 | 143 | valid_anchors = anchors[valid] 144 | 145 | if bbx_i is not None: 146 | max_a2g_iou = bbx_i.new_zeros(valid_anchors.size(0)) 147 | max_a2g_idx = bbx_i.new_full((valid_anchors.size(0),), -1, dtype=torch.long) 148 | max_g2a_iou = [] 149 | max_g2a_idx = [] 150 | 151 | # Calculate assignments iteratively to save memory 152 | for j, bbx_i_j in enumerate(torch.split(bbx_i, CHUNK_SIZE, dim=0)): 153 | iou = ious(valid_anchors, bbx_i_j) 154 | 155 | # Anchor -> GT 156 | iou_max, iou_idx = iou.max(dim=1) 157 | replace_idx = iou_max > max_a2g_iou 158 | 159 | max_a2g_idx[replace_idx] = iou_idx[replace_idx] + j * CHUNK_SIZE 160 | max_a2g_iou[replace_idx] = iou_max[replace_idx] 161 | 162 | # GT -> Anchor 163 | max_g2a_iou_j, max_g2a_idx_j = iou.transpose(0, 1).max(dim=1) 164 | max_g2a_iou.append(max_g2a_iou_j) 165 | max_g2a_idx.append(max_g2a_idx_j) 166 | 167 | del iou 168 | 169 | max_g2a_iou = torch.cat(max_g2a_iou, dim=0) 170 | max_g2a_idx = torch.cat(max_g2a_idx, dim=0) 171 | 172 | a2g_pos = max_a2g_iou >= self.pos_threshold 173 | a2g_neg = max_a2g_iou < self.neg_threshold 174 | g2a_pos = max_g2a_iou > 0 175 | 176 | valid_match = valid_anchors.new_full((valid_anchors.size(0),), -2, dtype=torch.long) 177 | valid_match[a2g_pos] = max_a2g_idx[a2g_pos] 178 | valid_match[a2g_neg] = -1 179 | valid_match[max_g2a_idx[g2a_pos]] = g2a_pos.nonzero().squeeze() 180 | else: 181 | # No ground truth boxes for this image: everything that is not void is negative 182 | valid_match = valid_anchors.new_full((valid_anchors.size(0),), -1, dtype=torch.long) 183 | 184 | # Subsample positives and negatives 185 
| self._subsample(valid_match) 186 | 187 | match_i[valid] = valid_match 188 | except Empty: 189 | pass 190 | 191 | match.append(match_i) 192 | 193 | return torch.stack(match, dim=0) 194 | 195 | 196 | class RPNLoss: 197 | """RPN loss function 198 | 199 | Parameters 200 | ---------- 201 | sigma : float 202 | "bandwidth" parameter of the smooth-L1 loss used for bounding box regression 203 | """ 204 | 205 | def __init__(self, sigma): 206 | self.sigma = sigma 207 | 208 | def bbx_loss(self, bbx_logits, bbx_lbl, num_non_void): 209 | bbx_logits = bbx_logits.view(-1, 4) 210 | bbx_lbl = bbx_lbl.view(-1, 4) 211 | 212 | bbx_loss = smooth_l1(bbx_logits, bbx_lbl, self.sigma).sum(dim=-1).sum() 213 | bbx_loss *= torch.clamp(1 / num_non_void, max=1.) 214 | return bbx_loss 215 | 216 | def __call__(self, obj_logits, bbx_logits, obj_lbl, bbx_lbl): 217 | """RPN loss function 218 | """ 219 | # Get contiguous view of the labels 220 | positives = obj_lbl == 1 221 | non_void = obj_lbl != -1 222 | num_non_void = non_void.float().sum() 223 | 224 | # Objectness loss 225 | obj_loss = functional.binary_cross_entropy_with_logits( 226 | obj_logits, positives.float(), non_void.float(), reduction="sum") 227 | obj_loss *= torch.clamp(1. / num_non_void, max=1.) 228 | 229 | # Bounding box regression loss 230 | if positives.any().item(): 231 | bbx_logits = bbx_logits[positives.unsqueeze(-1).expand_as(bbx_logits)] 232 | bbx_lbl = bbx_lbl[positives.unsqueeze(-1).expand_as(bbx_lbl)] 233 | bbx_loss = self.bbx_loss(bbx_logits, bbx_lbl, num_non_void) 234 | else: 235 | bbx_loss = bbx_logits.sum() * 0 236 | 237 | return obj_loss.mean(), bbx_loss.mean() 238 | 239 | 240 | class RPNAlgo: 241 | """Base class for RPN algorithms 242 | 243 | Parameters 244 | ---------- 245 | anchor_scales : sequence of float 246 | Anchor scale factors, these will be multiplied by the RPN stride to determine the actual anchor sizes 247 | anchor_ratios : sequence of float 248 | Anchor aspect ratios 249 | """ 250 | 251 | def __init__(self, anchor_scales, anchor_ratios): 252 | self.anchor_scales = anchor_scales 253 | self.anchor_ratios = anchor_ratios 254 | 255 | def _base_anchors(self, stride): 256 | # Pre-generate per-cell anchors 257 | anchors = [] 258 | center = stride / 2. 259 | for scale in self.anchor_scales: 260 | for ratio in self.anchor_ratios: 261 | h = stride * scale * sqrt(ratio) 262 | w = stride * scale * sqrt(1. / ratio) 263 | 264 | anchor = ( 265 | center - h / 2., 266 | center - w / 2., 267 | center + h / 2., 268 | center + w / 2. 
269 | ) 270 | anchors.append(anchor) 271 | 272 | return anchors 273 | 274 | @staticmethod 275 | def _shifted_anchors(anchors, stride, height, width, dtype=torch.float32, device="cpu"): 276 | grid_y = torch.arange(0, stride * height, stride, dtype=dtype, device=device) 277 | grid_x = torch.arange(0, stride * width, stride, dtype=dtype, device=device) 278 | grid = torch.stack([grid_y.view(-1, 1).repeat(1, width), grid_x.view(1, -1).repeat(height, 1)], dim=-1) 279 | 280 | anchors = torch.tensor(anchors, dtype=dtype, device=device) 281 | shifted_anchors = anchors.view(1, 1, -1, 4) + grid.repeat(1, 1, 2).unsqueeze(2) 282 | return shifted_anchors.view(-1, 4) 283 | 284 | @staticmethod 285 | def _match_to_lbl(anchors, bbx, match): 286 | pos, neg = match >= 0, match == -1 287 | 288 | # Objectness labels from matching tensor 289 | obj_lbl = torch.full_like(match, -1) 290 | obj_lbl[neg] = 0 291 | obj_lbl[pos] = 1 292 | 293 | # Bounding box regression labels from matching tensor 294 | bbx_lbl = anchors.new_zeros(len(bbx), anchors.size(0), anchors.size(1)) 295 | for i, (pos_i, bbx_i_, match_i) in enumerate(zip(pos, bbx, match)): 296 | bbx_i = bbx_i_[:,[0,1,3,4]] 297 | if pos_i.any(): 298 | bbx_lbl[i, pos_i] = calculate_shift(anchors[pos_i], bbx_i[match_i[pos_i]]) 299 | 300 | return obj_lbl, bbx_lbl 301 | 302 | def training(self, head, x, bbx, iscrowd, valid_size, training=True, do_inference=False): 303 | """Given input features and ground truth compute losses and, optionally, predictions 304 | """ 305 | raise NotImplementedError() 306 | 307 | def inference(self, head, x, valid_size, training): 308 | """Given input features compute object proposals 309 | """ 310 | raise NotImplementedError() 311 | -------------------------------------------------------------------------------- /grasp_det_seg/algos/semantic_seg.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | 3 | import torch 4 | import torch.nn.functional as functional 5 | 6 | from grasp_det_seg.utils.parallel import PackedSequence 7 | from grasp_det_seg.utils.sequence import pack_padded_images 8 | 9 | 10 | class SemanticSegLoss: 11 | """Semantic segmentation loss 12 | 13 | Parameters 14 | ---------- 15 | ohem : float or None 16 | Online hard example mining fraction, or `None` to disable OHEM 17 | ignore_index : int 18 | Index of the void class 19 | """ 20 | 21 | def __init__(self, ohem=None, ignore_index=255): 22 | if ohem is not None and (ohem <= 0 or ohem > 1): 23 | raise ValueError("ohem should be in (0, 1]") 24 | self.ohem = ohem 25 | self.ignore_index = ignore_index 26 | 27 | def __call__(self, sem_logits, sem): 28 | """Compute the semantic segmentation loss 29 | """ 30 | sem_loss = [] 31 | for sem_logits_i, sem_i in zip(sem_logits, sem): 32 | sem_loss_i = functional.cross_entropy( 33 | sem_logits_i.unsqueeze(0), sem_i.unsqueeze(0), ignore_index=self.ignore_index, reduction="none") 34 | sem_loss_i = sem_loss_i.view(-1) 35 | 36 | if self.ohem is not None and self.ohem != 1: 37 | top_k = int(ceil(sem_loss_i.numel() * self.ohem)) 38 | if top_k != sem_loss_i.numel(): 39 | sem_loss_i, _ = sem_loss_i.topk(top_k) 40 | 41 | sem_loss.append(sem_loss_i.mean()) 42 | 43 | return sum(sem_loss) / len(sem_logits) 44 | 45 | 46 | class SemanticSegAlgo: 47 | """Semantic segmentation algorithm 48 | """ 49 | 50 | def __init__(self, loss, num_classes, ignore_index=255): 51 | self.loss = loss 52 | self.num_classes = num_classes 53 | self.ignore_index = ignore_index 54 | 55 | @staticmethod 56 | def 
_pack_logits(sem_logits, valid_size, img_size): 57 | sem_logits = functional.interpolate(sem_logits, size=img_size, mode="bilinear", align_corners=False) 58 | return pack_padded_images(sem_logits, valid_size) 59 | 60 | def _confusion_matrix(self, sem_pred, sem): 61 | confmat = sem[0].new_zeros(self.num_classes * self.num_classes, dtype=torch.float) 62 | 63 | for sem_pred_i, sem_i in zip(sem_pred, sem): 64 | valid = sem_i != self.ignore_index 65 | if valid.any(): 66 | sem_pred_i = sem_pred_i[valid] 67 | sem_i = sem_i[valid] 68 | 69 | confmat.index_add_( 70 | 0, sem_i.view(-1) * self.num_classes + sem_pred_i.view(-1), confmat.new_ones(sem_i.numel())) 71 | 72 | return confmat.view(self.num_classes, self.num_classes) 73 | 74 | @staticmethod 75 | def _logits(head, x, valid_size, img_size): 76 | sem_logits, sem_feats = head(x) 77 | return sem_logits,SemanticSegAlgo._pack_logits(sem_logits, valid_size, img_size), sem_feats 78 | 79 | def training(self, head, x, sem, valid_size, img_size): 80 | """Given input features and ground truth compute semantic segmentation loss, confusion matrix and prediction 81 | """ 82 | # Compute logits and prediction 83 | sem_logits_low_res, sem_logits, sem_feats = self._logits(head, x, valid_size, img_size) 84 | sem_pred = PackedSequence([sem_logits_i.max(dim=0)[1] for sem_logits_i in sem_logits]) 85 | sem_pred_low_res = PackedSequence([sem_logits_low_res_i.max(dim=0)[1].float() for sem_logits_low_res_i in sem_logits_low_res]) 86 | 87 | # Compute loss and confusion matrix 88 | sem_loss = self.loss(sem_logits, sem) 89 | conf_mat = self._confusion_matrix(sem_pred, sem) 90 | 91 | return sem_loss, conf_mat, sem_pred,sem_logits,sem_logits_low_res,sem_pred_low_res,sem_feats 92 | 93 | def inference(self, head, x, valid_size, img_size): 94 | """Given input features compute semantic segmentation prediction 95 | """ 96 | sem_logits_low_res, sem_logits, sem_feats = self._logits(head, x, valid_size, img_size) 97 | sem_pred = PackedSequence([sem_logits_i.max(dim=0)[1] for sem_logits_i in sem_logits]) 98 | sem_pred_low_res = PackedSequence([sem_logits_low_res_i.max(dim=0)[1].float() for sem_logits_low_res_i in sem_logits_low_res]) 99 | 100 | return sem_pred, sem_feats, sem_pred_low_res 101 | 102 | 103 | def confusion_matrix(sem_pred, sem, num_classes, ignore_index=255): 104 | confmat = sem_pred.new_zeros(num_classes * num_classes, dtype=torch.float) 105 | 106 | valid = sem != ignore_index 107 | if valid.any(): 108 | sem_pred = sem_pred[valid] 109 | sem = sem[valid] 110 | 111 | confmat.index_add_(0, sem.view(-1) * num_classes + sem_pred.view(-1), confmat.new_ones(sem.numel())) 112 | 113 | return confmat.view(num_classes, num_classes) 114 | -------------------------------------------------------------------------------- /grasp_det_seg/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import load_config 2 | -------------------------------------------------------------------------------- /grasp_det_seg/config/config.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import configparser 3 | from os import path, listdir 4 | 5 | _CONVERTERS = { 6 | "struct": ast.literal_eval 7 | } 8 | 9 | def load_config(config_file, defaults_file): 10 | parser = configparser.ConfigParser(allow_no_value=True, converters=_CONVERTERS) 11 | parser.read([defaults_file, config_file]) 12 | return parser 13 | -------------------------------------------------------------------------------- 
/grasp_det_seg/config/defaults/det_seg_OCID.ini: -------------------------------------------------------------------------------- 1 | # GENERAL NOTE: the fields denoted as meta-info are not actual configuration parameters. Instead, they are used to 2 | # describe some characteristic of a network module that needs to be accessible from some other module but is hard to 3 | # determine in a generic way from within the code. A typical example is the total output stride of the network body. 4 | # These should be properly configured by the user to match the actual properties of the network. 5 | 6 | [general] 7 | # Number of epochs between validations 8 | val_interval = 25 9 | # Number of steps before outputting a log entry 10 | log_interval = 10 11 | cudnn_benchmark = no 12 | num_classes = 18 13 | num_stuff = 0 14 | num_things = 18 15 | # 0 - 31 16 | num_semantic = 32 17 | 18 | 19 | [body] 20 | # Architecture for the body 21 | body = resnet101 22 | # Path to pre-trained ImageNet weights 23 | weights = ./GraspDetSeg_CNN/weights_pretrained/resnet101 24 | # Normalization mode: 25 | # -- bn: in-place batch norm everywhere 26 | # -- syncbn: synchronized in-place batch norm everywhere 27 | # -- syncbn+bn: synchronized in-place batch norm in the static part of the network, in-place batch norm everywhere else 28 | # -- gn: group norm everywhere 29 | # -- syncbn+gn: synchronized in-place batch norm in the static part of the network, group norm everywhere else 30 | # -- off: do not normalize activations (scale and bias are kept) 31 | normalization_mode = syncbn 32 | # Activation: 'leaky_relu' or 'elu' 33 | activation = leaky_relu 34 | activation_slope = 0.01 35 | # Group norm parameters 36 | gn_groups = 16 37 | # Additional parameters for the body 38 | body_params = {} 39 | # Number of frozen modules: in [1, 5] 40 | num_frozen = 2 41 | # Wether to freeze BN modules 42 | bn_frozen = yes 43 | # Meta-info 44 | out_channels = {"mod1": 64, "mod2": 256, "mod3": 512, "mod4": 1024, "mod5": 2048} 45 | out_strides = {"mod1": 4, "mod2": 4, "mod3": 8, "mod4": 16, "mod5": 32} 46 | 47 | [fpn] 48 | out_channels = 256 49 | extra_scales = 0 50 | interpolation = nearest 51 | # Input settings 52 | inputs = ["mod2", "mod3", "mod4", "mod5"] 53 | # Meta-info 54 | out_strides = (4, 8, 16, 32) 55 | 56 | [rpn] 57 | hidden_channels = 256 58 | stride = 1 59 | # Anchor settings 60 | anchor_ratios = (1., 0.1, 0.4, 0.7, 1.2) 61 | anchor_scale = 2 62 | # Proposal settings 63 | nms_threshold = 0.7 64 | num_pre_nms_train = 12000 65 | num_post_nms_train = 2000 66 | num_pre_nms_val = 6000 67 | num_post_nms_val = 300 68 | min_size = 16 69 | # Anchor matcher settings 70 | num_samples = 256 71 | pos_ratio = .5 72 | pos_threshold = .7 73 | neg_threshold = .3 74 | void_threshold = 0.7 75 | # FPN-specific settings 76 | fpn_min_level = 0 77 | fpn_levels = 3 78 | # Loss settings 79 | sigma = 3. 80 | 81 | [roi] 82 | roi_size = (14, 14) 83 | # Matcher settings 84 | num_samples = 128 85 | pos_ratio = .25 86 | pos_threshold = .5 87 | neg_threshold_hi = .5 88 | neg_threshold_lo = 0. 89 | void_threshold = 0.7 90 | void_is_background = no 91 | # Prediction generator settings 92 | nms_threshold = 0.3 93 | score_threshold = 0.05 94 | max_predictions = 100 95 | # FPN-specific settings 96 | fpn_min_level = 0 97 | fpn_levels = 4 98 | fpn_canonical_scale = 224 99 | fpn_canonical_level = 2 100 | # Loss settings 101 | sigma = 1. 102 | bbx_reg_weights = (10., 10., 5., 5.) 
103 | 104 | [sem] 105 | fpn_min_level = 0 106 | fpn_levels = 4 107 | pooling_size = (64, 64) 108 | # Loss settings 109 | ohem = .25 110 | 111 | [optimizer] 112 | lr = 0.03 113 | weight_decay = 0.0001 114 | weight_decay_norm = yes 115 | momentum = 0.9 116 | nesterov = yes 117 | # obj, bbx, roi_cls, roi_bbx, sem 118 | loss_weights = (1., 1., 1., 1.,.75) 119 | 120 | [scheduler] 121 | epochs = 800 122 | # Scheduler type: 'linear', 'step', 'poly' or 'multistep' 123 | type = poly 124 | # When to update the learning rate: 'batch', 'epoch' 125 | update_mode = batch 126 | # Additional parameters for the scheduler 127 | # -- linear 128 | # from: initial lr multiplier 129 | # to: final lr multiplier 130 | # -- step 131 | # step_size: number of steps between lr decreases 132 | # gamma: multiplicative factor 133 | # -- poly 134 | # gamma: exponent of the polynomial 135 | # -- multistep 136 | # milestones: step indicies where the lr decreases will be triggered 137 | params = {"gamma": 0.9} 138 | burn_in_steps = 500 139 | burn_in_start = 0.333 140 | 141 | [dataloader] 142 | # Absolute path to the project 143 | root_path = ./GraspDetSeg_CNN 144 | # Image size parameters 145 | shortest_size = 480 146 | longest_max_size = 640 147 | # Batch size 148 | train_batch_size = 10 149 | val_batch_size = 1 150 | # Augmentation parameters 151 | rgb_mean = (0.485, 0.456, 0.406) 152 | rgb_std = (0.229, 0.224, 0.225) 153 | random_flip = no 154 | random_scale = None 155 | rotate_and_scale = True 156 | # Number of worker threads 157 | num_workers = 6 158 | # Subsets 159 | train_set = training_0 160 | val_set = validation_0 161 | test_set = validation_0 -------------------------------------------------------------------------------- /grasp_det_seg/data_OCID/OCID_class_dict.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | cls_names = { 4 | 'background' : '0', 5 | 'apple' : '1', 6 | 'ball' : '2', 7 | 'banana' : '3', 8 | 'bell_pepper' : '4', 9 | 'binder' : '5', 10 | 'bowl' : '6', 11 | 'cereal_box' : '7', 12 | 'coffee_mug' : '8', 13 | 'flashlight' : '9', 14 | 'food_bag' : '10', 15 | 'food_box' : '11', 16 | 'food_can' : '12', 17 | 'glue_stick' : '13', 18 | 'hand_towel' : '14', 19 | 'instant_noodles' : '15', 20 | 'keyboard' : '16', 21 | 'kleenex' : '17', 22 | 'lemon' : '18', 23 | 'lime' : '19', 24 | 'marker' : '20', 25 | 'orange' : '21', 26 | 'peach' : '22', 27 | 'pear' : '23', 28 | 'potato' : '24', 29 | 'shampoo' : '25', 30 | 'soda_can' : '26', 31 | 'sponge' : '27', 32 | 'stapler' : '28', 33 | 'tomato' : '29', 34 | 'toothpaste' : '30', 35 | 'unknown' : '31' 36 | } 37 | 38 | colors = { 39 | '0': np.array([0, 0, 0]), 40 | '1': np.array([ 211, 47, 47 ]), 41 | '2': np.array([ 0, 255, 0]), 42 | '3': np.array([123, 31, 162]), 43 | '4': np.array([ 81, 45, 168 ]), 44 | '5': np.array([ 48, 63, 159 ]), 45 | '6': np.array([25, 118, 210]), 46 | '7': np.array([ 2, 136, 209 ]), 47 | '8': np.array([ 153, 51, 102 ]), 48 | '9': np.array([ 0, 121, 107 ]), 49 | '10': np.array([ 56, 142, 60 ]), 50 | '11': np.array([ 104, 159, 56 ]), 51 | '12': np.array([ 175, 180, 43 ]), 52 | '13': np.array([ 251, 192, 45 ]), 53 | '14': np.array([ 255, 160, 0 ]), 54 | '15': np.array([ 245, 124, 0 ]), 55 | '16': np.array([ 230, 74, 25 ]), 56 | '17': np.array([ 93, 64, 55 ]), 57 | '18': np.array([ 97, 97, 97 ]), 58 | '19': np.array([ 84, 110, 122 ]), 59 | '20': np.array([ 255, 255, 102]), 60 | '21': np.array([ 0, 151, 167 ]), 61 | '22': np.array([ 153, 255, 102 ]), 62 | '23': np.array([ 51, 255, 102 
]), 63 | '24': np.array([ 0, 255, 255 ]), 64 | '25': np.array([ 255, 255, 255 ]), 65 | '26': np.array([ 255, 204, 204 ]), 66 | '27': np.array([ 153, 102, 0 ]), 67 | '28': np.array([ 204, 255, 204 ]), 68 | '29': np.array([ 204, 255, 0 ]), 69 | '30': np.array([ 255, 0, 255 ]), 70 | '31': np.array([ 194, 24, 91 ]), 71 | } 72 | 73 | colors_list = list(colors.values()) 74 | cls_list = list(cls_names.keys()) -------------------------------------------------------------------------------- /grasp_det_seg/data_OCID/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import OCIDDataset, OCIDTestDataset 2 | from .misc import iss_collate_fn, read_boxes_from_file, prepare_frcnn_format 3 | from .transform import OCIDTransform, OCIDTestTransform -------------------------------------------------------------------------------- /grasp_det_seg/data_OCID/dataset.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | import cv2 3 | import numpy as np 4 | import torch.utils.data as data 5 | import os 6 | from PIL import Image 7 | 8 | 9 | class OCIDDataset(data.Dataset): 10 | """OCID_grasp dataset for grasp detection and semantic segmentation 11 | """ 12 | 13 | def __init__(self, data_path, root_dir, split_name, transform): 14 | super(OCIDDataset, self).__init__() 15 | self.data_path = data_path 16 | self.root_dir = root_dir 17 | self.split_name = split_name 18 | self.transform = transform 19 | 20 | self._images = self._load_split() 21 | 22 | def _load_split(self): 23 | with open(path.join(self.data_path, self.split_name + ".txt"), "r") as fid: 24 | images = [x.strip() for x in fid.readlines()] 25 | 26 | return images 27 | 28 | def _load_item(self, item): 29 | seq_path, im_name = item.split(',') 30 | sample_path = os.path.join(self.root_dir, seq_path) 31 | img_path = os.path.join(sample_path, 'rgb', im_name) 32 | mask_path = os.path.join(sample_path, 'seg_mask_labeled_combi', im_name) 33 | anno_path = os.path.join(sample_path, 'Annotations', im_name[:-4] + '.txt') 34 | img_bgr = cv2.imread(img_path) 35 | img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) 36 | 37 | with open(anno_path, "r") as f: 38 | points_list = [] 39 | boxes_list = [] 40 | for count, line in enumerate(f): 41 | line = line.rstrip() 42 | [x, y] = line.split(' ') 43 | 44 | x = float(x) 45 | y = float(y) 46 | 47 | pt = (x, y) 48 | points_list.append(pt) 49 | 50 | if len(points_list) == 4: 51 | boxes_list.append(points_list) 52 | points_list = [] 53 | 54 | msk = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED) 55 | box_arry = np.asarray(boxes_list) 56 | return img, msk, box_arry 57 | 58 | @property 59 | def categories(self): 60 | """Category names""" 61 | return self._meta["categories"] 62 | 63 | @property 64 | def num_categories(self): 65 | """Number of categories""" 66 | return len(self.categories) 67 | 68 | @property 69 | def num_stuff(self): 70 | """Number of "stuff" categories""" 71 | return self._meta["num_stuff"] 72 | 73 | @property 74 | def num_thing(self): 75 | """Number of "thing" categories""" 76 | return self.num_categories - self.num_stuff 77 | 78 | @property 79 | def original_ids(self): 80 | """Original class id of each category""" 81 | return self._meta["original_ids"] 82 | 83 | @property 84 | def palette(self): 85 | """Default palette to be used when color-coding semantic labels""" 86 | return np.array(self._meta["palette"], dtype=np.uint8) 87 | 88 | @property 89 | def img_sizes(self): 90 | """Size of each image of the 
dataset""" 91 | return [img_desc["size"] for img_desc in self._images] 92 | 93 | @property 94 | def img_categories(self): 95 | """Categories present in each image of the dataset""" 96 | return [img_desc["cat"] for img_desc in self._images] 97 | 98 | @property 99 | def get_images(self): 100 | """Categories present in each image of the dataset""" 101 | return self._images 102 | 103 | def __len__(self): 104 | return len(self._images) 105 | 106 | def __getitem__(self, item): 107 | im_rgb, msk, bbox_infos = self._load_item(item) 108 | 109 | rec, im_size = self.transform(im_rgb, msk, bbox_infos) 110 | 111 | rec["abs_path"] = item 112 | rec["root_path"] = self.root_dir 113 | rec["im_size"] = im_size 114 | return rec 115 | 116 | def get_raw_image(self, idx): 117 | """Load a single, unmodified image with given id from the dataset""" 118 | img_file = path.join(self._img_dir, idx) 119 | if path.exists(img_file + ".png"): 120 | img_file = img_file + ".png" 121 | elif path.exists(img_file + ".jpg"): 122 | img_file = img_file + ".jpg" 123 | else: 124 | raise IOError("Cannot find any image for id {} in {}".format(idx, self._img_dir)) 125 | 126 | return Image.open(img_file) 127 | 128 | def get_image_desc(self, idx): 129 | """Look up an image descriptor given the id""" 130 | matching = [img_desc for img_desc in self._images if img_desc["id"] == idx] 131 | if len(matching) == 1: 132 | return matching[0] 133 | else: 134 | raise ValueError("No image found with id %s" % idx) 135 | 136 | 137 | class OCIDTestDataset(data.Dataset): 138 | 139 | def __init__(self, data_path, root_dir, split_name, transform): 140 | super(OCIDTestDataset, self).__init__() 141 | self.data_path = data_path 142 | self.root_dir = root_dir 143 | self.split_name = split_name 144 | self.transform = transform 145 | 146 | self._images = self._load_split() 147 | 148 | def _load_split(self): 149 | with open(path.join(self.data_path, self.split_name + ".txt"), "r") as fid: 150 | images = [x.strip() for x in fid.readlines()] 151 | return images 152 | 153 | @property 154 | def img_sizes(self): 155 | """Size of each image of the dataset""" 156 | return [img_desc["size"] for img_desc in self._images] 157 | 158 | @property 159 | def get_images(self): 160 | """Categories present in each image of the dataset""" 161 | return self._images 162 | 163 | def __len__(self): 164 | return len(self._images) 165 | 166 | def __getitem__(self, item): 167 | seq_path, im_name = item.split(',') 168 | sample_path = os.path.join(self.root_dir, seq_path) 169 | img_path = os.path.join(sample_path, 'rgb', im_name) 170 | img_bgr = cv2.imread(img_path) 171 | im_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) 172 | 173 | img_, im_size = self.transform(im_rgb) 174 | 175 | return {"img": img_, 176 | "root_path": self.root_dir, 177 | "abs_path": item, 178 | "im_size": im_size 179 | } 180 | -------------------------------------------------------------------------------- /grasp_det_seg/data_OCID/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from grasp_det_seg.utils.parallel import PackedSequence 5 | 6 | 7 | def iss_collate_fn(items): 8 | """Collate function for ISS batches""" 9 | out = {} 10 | if len(items) > 0: 11 | for key in items[0]: 12 | out[key] = [item[key] for item in items] 13 | if isinstance(items[0][key], torch.Tensor): 14 | out[key] = PackedSequence(out[key]) 15 | return out 16 | 17 | def prepare_frcnn_format(boxes,im_size): 18 | boxes_ary = np.asarray(boxes) 19 | 20 | boxes_ary 
= np.swapaxes(boxes_ary, 1, 2) 21 | xy_ctr = np.sum(boxes_ary, axis=2) / 4 22 | x_ctr = xy_ctr[:, 0] 23 | y_ctr = xy_ctr[:, 1] 24 | width = np.sqrt(np.sum((boxes_ary[:, :, 0] - boxes_ary[:, :, 1]) ** 2, axis=1)) 25 | height = np.sqrt(np.sum((boxes_ary[:, :, 1] - boxes_ary[:, :, 2]) ** 2, axis=1)) 26 | 27 | theta = np.zeros((boxes_ary.shape[0]), dtype=np.int) 28 | theta = np.arctan((boxes_ary[:, 1, 1] - boxes_ary[:, 1, 0]) / (boxes_ary[:, 0, 0] - boxes_ary[:, 0, 1])) 29 | b = np.arctan((boxes_ary[:, 1, 0] - boxes_ary[:, 1, 1]) / (boxes_ary[:, 0, 1] - boxes_ary[:, 0, 0])) 30 | theta[np.where(boxes_ary[:, 0, 0] <= boxes_ary[:, 0, 1])] = b[np.where(boxes_ary[:, 0, 0] <= boxes_ary[:, 0, 1])] 31 | 32 | # used for fasterrcnn loss 33 | x_min = x_ctr - width / 2 34 | x_max = x_ctr + width / 2 35 | y_min = y_ctr - height / 2 36 | y_max = y_ctr + height / 2 37 | 38 | x_coords = np.vstack((x_min, x_max)) 39 | y_coords = np.vstack((y_min, y_max)) 40 | 41 | mat = np.asarray((np.all(x_coords > im_size[1], axis=0), np.all(x_coords < 0, axis=0), 42 | np.all(y_coords > im_size[0], axis=0), np.all(y_coords < 0, axis=0))) 43 | 44 | fail = np.any(mat, axis=0) 45 | correct_idx = np.where(fail == False) 46 | theta_deg = np.rad2deg(theta) + 90 47 | cls = (np.round((theta_deg) / (180 / 18))).astype(int) 48 | cls[np.where(cls == 18)] = 0 49 | 50 | ret_value = (boxes_ary[correct_idx], theta_deg[correct_idx],cls[correct_idx]) 51 | return ret_value 52 | 53 | def read_boxes_from_file(gt_path,delta_xy): 54 | with open(gt_path)as f: 55 | points_list = [] 56 | box_list = [] 57 | for count, line in enumerate(f): 58 | line = line.rstrip() 59 | [x, y] = line.split(' ') 60 | x = float(x) - int(delta_xy[0]) 61 | y = float(y) - int(delta_xy[1]) 62 | 63 | pt = (x, y) 64 | points_list.append(pt) 65 | 66 | if len(points_list) == 4: 67 | box_list.append(points_list) 68 | points_list = [] 69 | return box_list 70 | -------------------------------------------------------------------------------- /grasp_det_seg/data_OCID/sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import distributed 5 | from torch.utils.data.sampler import Sampler 6 | 7 | 8 | class ARBatchSampler(Sampler): 9 | def __init__(self, data_source, batch_size, drop_last=False, epoch=0): 10 | super(ARBatchSampler, self).__init__(data_source) 11 | self.data_source = data_source 12 | self.batch_size = batch_size 13 | self.drop_last = drop_last 14 | self._epoch = epoch 15 | 16 | # Split images by orientation 17 | self.img_sets = self.data_source.get_images 18 | 19 | def _split_images(self, indices): 20 | # returns lists of [im_id, aspect_ratio] 21 | 22 | img_sizes = self.data_source.img_sizes 23 | img_sets = [[], []] 24 | for img_id in indices: 25 | aspect_ratio = img_sizes[img_id][0] / img_sizes[img_id][1] 26 | if aspect_ratio < 1: 27 | img_sets[0].append({"id": img_id, "ar": aspect_ratio}) 28 | else: 29 | img_sets[1].append({"id": img_id, "ar": aspect_ratio}) 30 | 31 | return img_sets 32 | 33 | def _generate_batches(self): 34 | g = torch.Generator() 35 | g.manual_seed(self._epoch) 36 | 37 | self.img_sets = [self.img_sets[i] for i in torch.randperm(len(self.img_sets), generator=g)] 38 | 39 | batches = [] 40 | leftover = [] 41 | batch = [] 42 | for img in self.img_sets: 43 | batch.append(img) 44 | if len(batch) == self.batch_size: 45 | batches.append(batch) 46 | batch = [] 47 | leftover += batch 48 | 49 | if not self.drop_last: 50 | batch = [] 51 | for img in leftover: 52 | 
batch.append(img) 53 | if len(batch) == self.batch_size: 54 | batches.append(batch) 55 | batch = [] 56 | 57 | if len(batch) != 0: 58 | batches.append(batch) 59 | 60 | return batches 61 | 62 | def set_epoch(self, epoch): 63 | self._epoch = epoch 64 | 65 | def __len__(self): 66 | if self.drop_last: 67 | return len(self.img_sets) // self.batch_size 68 | else: 69 | return (len(self.img_sets) + self.batch_size - 1) // self.batch_size 70 | 71 | 72 | def __iter__(self): 73 | batches = self._generate_batches() 74 | for batch in batches: 75 | batch = sorted(batch, key=lambda i: i["ar"]) 76 | batch = [i["id"] for i in batch] 77 | yield batch 78 | 79 | 80 | class DistributedARBatchSampler(ARBatchSampler): 81 | def __init__(self, data_source, batch_size, num_replicas=None, rank=None, drop_last=False, epoch=0): 82 | super(DistributedARBatchSampler, self).__init__(data_source, batch_size, drop_last, epoch) 83 | 84 | # Automatically get world size and rank if not provided 85 | if num_replicas is None: 86 | num_replicas = distributed.get_world_size() 87 | if rank is None: 88 | rank = distributed.get_rank() 89 | 90 | self.num_replicas = num_replicas 91 | self.rank = rank 92 | 93 | tot_batches = super(DistributedARBatchSampler, self).__len__() 94 | self.num_batches = int(math.ceil(tot_batches / self.num_replicas)) 95 | 96 | def __len__(self): 97 | return self.num_batches 98 | 99 | def __iter__(self): 100 | batches = self._generate_batches() 101 | 102 | g = torch.Generator() 103 | g.manual_seed(self._epoch) 104 | indices = list(torch.randperm(len(batches), generator=g)) 105 | 106 | # add extra samples to make it evenly divisible 107 | indices += indices[:(self.num_batches * self.num_replicas - len(indices))] 108 | assert len(indices) == self.num_batches * self.num_replicas 109 | 110 | # subsample 111 | offset = self.num_batches * self.rank 112 | indices = indices[offset:offset + self.num_batches] 113 | assert len(indices) == self.num_batches 114 | 115 | for idx in indices: 116 | yield batches[idx] 117 | -------------------------------------------------------------------------------- /grasp_det_seg/data_OCID/transform.py: -------------------------------------------------------------------------------- 1 | import random 2 | import scipy 3 | import numpy as np 4 | import torch 5 | from PIL import Image 6 | import cv2 7 | from torchvision.transforms import functional as tfn 8 | 9 | 10 | class OCIDTransform: 11 | """Transformer function for OCID_grasp dataset 12 | """ 13 | 14 | def __init__(self, 15 | shortest_size, 16 | longest_max_size, 17 | rgb_mean=None, 18 | rgb_std=None, 19 | random_flip=False, 20 | random_scale=None, 21 | rotate_and_scale=False): 22 | self.shortest_size = shortest_size 23 | self.longest_max_size = longest_max_size 24 | self.rgb_mean = rgb_mean 25 | self.rgb_std = rgb_std 26 | self.random_flip = random_flip 27 | self.random_scale = random_scale 28 | self.rotate_and_scale = rotate_and_scale 29 | 30 | def _adjusted_scale(self, in_width, in_height, target_size): 31 | min_size = min(in_width, in_height) 32 | max_size = max(in_width, in_height) 33 | scale = target_size / min_size 34 | 35 | if int(max_size * scale) > self.longest_max_size: 36 | scale = self.longest_max_size / max_size 37 | 38 | return scale 39 | 40 | @staticmethod 41 | def _random_flip(img, msk): 42 | if random.random() < 0.5: 43 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 44 | msk = [m.transpose(Image.FLIP_LEFT_RIGHT) for m in msk] 45 | return img, msk 46 | else: 47 | return img, msk 48 | 49 | def _random_target_size(self): 50 
| if len(self.random_scale) == 2: 51 | target_size = random.uniform(self.shortest_size * self.random_scale[0], 52 | self.shortest_size * self.random_scale[1]) 53 | else: 54 | target_sizes = [self.shortest_size * scale for scale in self.random_scale] 55 | target_size = random.choice(target_sizes) 56 | return int(target_size) 57 | 58 | def _normalize_image(self, img): 59 | if self.rgb_mean is not None: 60 | img.sub_(img.new(self.rgb_mean).view(-1, 1, 1)) 61 | if self.rgb_std is not None: 62 | img.div_(img.new(self.rgb_std).view(-1, 1, 1)) 63 | return img 64 | 65 | @staticmethod 66 | def _Rotate2D(pts, cnt, ang): 67 | ang = np.deg2rad(ang) 68 | return scipy.dot(pts - cnt, 69 | scipy.array([[scipy.cos(ang), scipy.sin(ang)], [-scipy.sin(ang), scipy.cos(ang)]])) + cnt 70 | 71 | @staticmethod 72 | def _prepare_frcnn_format(boxes, im_size): 73 | A = boxes 74 | xy_ctr = np.sum(A, axis=2) / 4 75 | x_ctr = xy_ctr[:, 0] 76 | y_ctr = xy_ctr[:, 1] 77 | width = np.sqrt(np.sum((A[:, :, 0] - A[:, :, 1]) ** 2, axis=1)) 78 | height = np.sqrt(np.sum((A[:, :, 1] - A[:, :, 2]) ** 2, axis=1)) 79 | 80 | theta = np.zeros((A.shape[0]), dtype=np.int) 81 | 82 | theta = np.arctan((A[:, 1, 1] - A[:, 1, 0]) / (A[:, 0, 0] - A[:, 0, 1])) 83 | b = np.arctan((A[:, 1, 0] - A[:, 1, 1]) / (A[:, 0, 1] - A[:, 0, 0])) 84 | theta[np.where(A[:, 0, 0] <= A[:, 0, 1])] = b[np.where(A[:, 0, 0] <= A[:, 0, 1])] 85 | 86 | # used for fasterrcnn loss 87 | x_min = x_ctr - width / 2 88 | x_max = x_ctr + width / 2 89 | y_min = y_ctr - height / 2 90 | y_max = y_ctr + height / 2 91 | 92 | x_coords = np.vstack((x_min, x_max)) 93 | y_coords = np.vstack((y_min, y_max)) 94 | 95 | mat = np.asarray((np.all(x_coords > im_size[1], axis=0), np.all(x_coords < 0, axis=0), 96 | np.all(y_coords > im_size[0], axis=0), np.all(y_coords < 0, axis=0))) 97 | 98 | fail = np.any(mat, axis=0) 99 | correct_idx = np.where(fail == False) 100 | theta_deg = np.rad2deg(theta) + 90 101 | cls = (np.round((theta_deg) / (180 / 18))).astype(int) 102 | cls[np.where(cls == 18)] = 0 103 | 104 | if np.any(cls) > 17: 105 | assert False 106 | 107 | ret_value = ( 108 | x_min[correct_idx], y_min[correct_idx], theta_deg[correct_idx], x_max[correct_idx], y_max[correct_idx], 109 | cls[correct_idx]) 110 | return ret_value 111 | 112 | def _rotateAndScale(self, img, msk, all_boxes): 113 | im_size = [self.shortest_size, self.longest_max_size] 114 | 115 | img_pad = cv2.copyMakeBorder(img, 200, 200, 200, 200, borderType=cv2.BORDER_REPLICATE) 116 | msk_pad = cv2.copyMakeBorder(msk, 200, 200, 200, 200, borderType=cv2.BORDER_CONSTANT) 117 | 118 | (oldY, oldX, chan) = img_pad.shape # note: numpy uses (y,x) convention but most OpenCV functions use (x,y) 119 | 120 | theta = float(np.random.randint(360) - 1) 121 | dx = np.random.randint(101) - 51 122 | dy = np.random.randint(101) - 51 123 | 124 | M = cv2.getRotationMatrix2D(center=(oldX / 2, oldY / 2), angle=theta, 125 | scale=1.0) # rotate about center of image. 126 | 127 | # choose a new image size. 128 | newX, newY = oldX, oldY 129 | # include this if you want to prevent corners being cut off 130 | r = np.deg2rad(theta) 131 | newX, newY = (abs(np.sin(r) * newY) + abs(np.cos(r) * newX), abs(np.sin(r) * newX) + abs(np.cos(r) * newY)) 132 | 133 | # Find the translation that moves the result to the center of that region. 
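        # The 2x3 matrix M returned by cv2.getRotationMatrix2D rotates about the centre of the
        # padded image; since the enclosing canvas grows from (oldX, oldY) to (newX, newY) to keep
        # the rotated corners visible, the translation column of M is shifted below by half of that
        # size difference, so the content stays centred when warpAffine is applied to both the
        # image and the mask.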
134 | (tx, ty) = ((newX - oldX) / 2, (newY - oldY) / 2) 135 | M[0, 2] += tx 136 | M[1, 2] += ty 137 | 138 | imgRotate = cv2.warpAffine(img_pad, M, dsize=(int(newX), int(newY))) 139 | mskRotate = cv2.warpAffine(msk_pad, M, dsize=(int(newX), int(newY))) 140 | 141 | imgRotateCrop = imgRotate[ 142 | int(imgRotate.shape[0] / 2 - (im_size[0] / 2)) - dx:int( 143 | imgRotate.shape[0] / 2 + (im_size[0] / 2)) - dx, 144 | int(imgRotate.shape[1] / 2 - (im_size[1] / 2)) - dy:int( 145 | imgRotate.shape[1] / 2 + (im_size[1] / 2)) - dy, :] 146 | mskRotateCrop = mskRotate[ 147 | int(mskRotate.shape[0] / 2 - (im_size[0] / 2)) - dx:int( 148 | mskRotate.shape[0] / 2 + (im_size[0] / 2)) - dx, 149 | int(mskRotate.shape[1] / 2 - (im_size[1] / 2)) - dy:int( 150 | mskRotate.shape[1] / 2 + (im_size[1] / 2)) - dy] 151 | 152 | bbsInShift = np.zeros_like(all_boxes) 153 | bbsInShift[:, 0, :] = all_boxes[:, 0, :] - (im_size[1] / 2) 154 | bbsInShift[:, 1, :] = all_boxes[:, 1, :] - (im_size[0] / 2) 155 | R = np.array([[np.cos(theta / 180 * np.pi), -np.sin(theta / 180 * np.pi)], 156 | [np.sin(theta / 180 * np.pi), np.cos(theta / 180 * np.pi)]]) 157 | R_all = np.expand_dims(R, axis=0) # 158 | R_all = np.repeat(R_all, all_boxes.shape[0], axis=0) 159 | bbsInShift = np.swapaxes(bbsInShift, 1, 2) 160 | 161 | bbsRotated = np.dot(bbsInShift, R_all.T) 162 | bbsRotated = bbsRotated[:, :, :, 0] 163 | bbsRotated = np.swapaxes(bbsRotated, 1, 2) 164 | bbsInShiftBack = np.asarray(bbsRotated) 165 | bbsInShiftBack[:, 0, :] = (bbsRotated[:, 0, :] + (im_size[1] / 2) + dy) 166 | bbsInShiftBack[:, 1, :] = (bbsRotated[:, 1, :] + (im_size[0] / 2) + dx) 167 | 168 | return imgRotateCrop, mskRotateCrop, bbsInShiftBack 169 | 170 | def __call__(self, img_, msk_, bbox_infos_): 171 | im_size = [self.shortest_size, self.longest_max_size] 172 | bbox_infos_ = np.swapaxes(bbox_infos_, 1, 2) 173 | 174 | x_min = int(img_.shape[0] / 2 - int(im_size[0] / 2)) 175 | x_max = int(img_.shape[0] / 2 + int(im_size[0] / 2)) 176 | y_min = int(img_.shape[1] / 2 - int(im_size[1] / 2)) 177 | y_max = int(img_.shape[1] / 2 + int(im_size[1] / 2)) 178 | 179 | new_origin = np.array([[y_min], [x_min]]) 180 | 181 | img = img_[x_min:x_max, y_min:y_max, :] 182 | 183 | msk = msk_[x_min:x_max, y_min:y_max] 184 | 185 | bbox_infos_ = bbox_infos_ - new_origin 186 | bbox_infos = np.copy(bbox_infos_) 187 | 188 | if self.rotate_and_scale: 189 | img, msk, bbox_transformed = self._rotateAndScale(img, msk, bbox_infos_) 190 | bbox_infos = bbox_transformed 191 | # Random flip 192 | if self.random_flip: 193 | img, msk = self._random_flip(img, msk) 194 | 195 | # Adjust scale, possibly at random 196 | if self.random_scale is not None: 197 | target_size = self._random_target_size() 198 | else: 199 | target_size = self.shortest_size 200 | 201 | ret = self._prepare_frcnn_format(bbox_infos, im_size) 202 | (x1, y1, theta, x2, y2, cls) = ret 203 | if len(cls) == 0: 204 | print('NO valid boxes after augmentation, switch to gt values') 205 | ret = self._prepare_frcnn_format(bbox_infos_, im_size) 206 | img = img_[x_min:x_max, y_min:y_max, :] 207 | 208 | msk = msk_[x_min:x_max, y_min:y_max] 209 | 210 | bbox_infos = np.asarray(ret).T 211 | bbox_infos = bbox_infos.astype(np.float32) 212 | 213 | # Image transformations 214 | img = tfn.to_tensor(img) 215 | img = self._normalize_image(img) 216 | 217 | # Label transformations 218 | msk = np.stack([np.array(m, dtype=np.int32, copy=False) for m in msk], axis=0) 219 | 220 | # Convert labels to torch and extract bounding boxes 221 | msk = 
torch.from_numpy(msk.astype(np.long)) 222 | 223 | bbx = torch.from_numpy(np.asarray(bbox_infos)).contiguous() 224 | if bbox_infos.shape[1] != 6: 225 | assert False 226 | 227 | return dict(img=img, msk=msk, bbx=bbx), im_size 228 | 229 | 230 | class OCIDTestTransform: 231 | """Transformer function for OCID_grasp dataset, used at test time 232 | """ 233 | 234 | def __init__(self, 235 | shortest_size, 236 | longest_max_size, 237 | rgb_mean=None, 238 | rgb_std=None): 239 | self.longest_max_size = longest_max_size 240 | self.shortest_size = shortest_size 241 | self.rgb_mean = rgb_mean 242 | self.rgb_std = rgb_std 243 | 244 | def _adjusted_scale(self, in_width, in_height): 245 | min_size = min(in_width, in_height) 246 | scale = self.shortest_size / min_size 247 | return scale 248 | 249 | def _normalize_image(self, img): 250 | if self.rgb_mean is not None: 251 | img.sub_(img.new(self.rgb_mean).view(-1, 1, 1)) 252 | if self.rgb_std is not None: 253 | img.div_(img.new(self.rgb_std).view(-1, 1, 1)) 254 | return img 255 | 256 | def __call__(self, img): 257 | im_size = [self.shortest_size, self.longest_max_size] 258 | 259 | x_min = int(img.shape[0] / 2 - int(im_size[0] / 2)) 260 | x_max = int(img.shape[0] / 2 + int(im_size[0] / 2)) 261 | y_min = int(img.shape[1] / 2 - int(im_size[1] / 2)) 262 | y_max = int(img.shape[1] / 2 + int(im_size[1] / 2)) 263 | 264 | img = img[x_min:x_max, y_min:y_max, :] 265 | 266 | # Image transformations 267 | img = tfn.to_tensor(img) 268 | img = self._normalize_image(img) 269 | 270 | return img, im_size 271 | -------------------------------------------------------------------------------- /grasp_det_seg/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | -------------------------------------------------------------------------------- /grasp_det_seg/models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-ainetter/grasp_det_seg_cnn/6ff96464f8906fb555d0a2f5a8b86c7f1330f108/grasp_det_seg/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /grasp_det_seg/models/__pycache__/resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-ainetter/grasp_det_seg_cnn/6ff96464f8906fb555d0a2f5a8b86c7f1330f108/grasp_det_seg/models/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /grasp_det_seg/models/det_seg.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from grasp_det_seg.utils.sequence import pad_packed_images 7 | 8 | NETWORK_INPUTS = ["img", "msk", "bbx"] 9 | 10 | class DetSegNet(nn.Module): 11 | def __init__(self, 12 | body, 13 | rpn_head, 14 | roi_head, 15 | sem_head, 16 | rpn_algo, 17 | detection_algo, 18 | semantic_seg_algo, 19 | classes): 20 | super(DetSegNet, self).__init__() 21 | self.num_stuff = classes["stuff"] 22 | 23 | # Modules 24 | self.body = body 25 | self.rpn_head = rpn_head 26 | self.roi_head = roi_head 27 | self.sem_head = sem_head 28 | 29 | # Algorithms 30 | self.rpn_algo = rpn_algo 31 | self.detection_algo = detection_algo 32 | self.semantic_seg_algo = semantic_seg_algo 33 | 34 | def _prepare_inputs(self, msk, cat, iscrowd, bbx): 35 | 
cat_out, iscrowd_out, bbx_out, ids_out, sem_out = [], [], [], [], [] 36 | for msk_i, cat_i, iscrowd_i, bbx_i in zip(msk, cat, iscrowd, bbx): 37 | msk_i = msk_i.squeeze(0) 38 | thing = (cat_i >= self.num_stuff) & (cat_i != 255) 39 | valid = thing & ~iscrowd_i 40 | 41 | if valid.any().item(): 42 | cat_out.append(cat_i[valid]) 43 | bbx_out.append(bbx_i[valid]) 44 | ids_out.append(torch.nonzero(valid)) 45 | else: 46 | cat_out.append(None) 47 | bbx_out.append(None) 48 | ids_out.append(None) 49 | 50 | if iscrowd_i.any().item(): 51 | iscrowd_i = iscrowd_i & thing 52 | iscrowd_out.append(iscrowd_i[msk_i]) 53 | else: 54 | iscrowd_out.append(None) 55 | 56 | sem_out.append(cat_i[msk_i]) 57 | 58 | return cat_out, iscrowd_out, bbx_out, ids_out, sem_out 59 | 60 | def forward(self, img, msk=None, cat=None, iscrowd=None, bbx=None, do_loss=False, do_prediction=True): 61 | # Pad the input images 62 | img, valid_size = pad_packed_images(img) 63 | img_size = img.shape[-2:] 64 | 65 | # Convert ground truth to the internal format 66 | if do_loss: 67 | sem, _ = pad_packed_images(msk) 68 | msk, _ = pad_packed_images(msk) 69 | 70 | # Run network body 71 | x = self.body(img) 72 | 73 | # RPN part 74 | if do_loss: 75 | obj_loss, bbx_loss, proposals = self.rpn_algo.training( 76 | self.rpn_head, x, bbx, iscrowd, valid_size, training=self.training, do_inference=True) 77 | elif do_prediction: 78 | proposals = self.rpn_algo.inference(self.rpn_head, x, valid_size, self.training) 79 | obj_loss, bbx_loss = None, None 80 | else: 81 | obj_loss, bbx_loss, proposals = None, None, None 82 | 83 | # ROI part 84 | if do_loss: 85 | roi_cls_loss, roi_bbx_loss = self.detection_algo.training( 86 | self.roi_head, x, proposals, bbx, cat, iscrowd, img_size) 87 | else: 88 | roi_cls_loss, roi_bbx_loss = None, None 89 | if do_prediction: 90 | bbx_pred, cls_pred, obj_pred = self.detection_algo.inference( 91 | self.roi_head, x, proposals, valid_size, img_size) 92 | else: 93 | bbx_pred, cls_pred, obj_pred = None, None, None 94 | 95 | # Segmentation part 96 | if do_loss: 97 | sem_loss, conf_mat, sem_pred,sem_logits,sem_logits_low_res, sem_pred_low_res, sem_feats =\ 98 | self.semantic_seg_algo.training(self.sem_head, x, sem, valid_size, img_size) 99 | elif do_prediction: 100 | sem_pred,sem_feats,_ = self.semantic_seg_algo.inference(self.sem_head, x, valid_size, img_size) 101 | sem_loss, conf_mat = None, None 102 | else: 103 | sem_loss, conf_mat, sem_pred, sem_feats = None, None, None, None 104 | 105 | # Prepare outputs 106 | loss = OrderedDict([ 107 | ("obj_loss", obj_loss), 108 | ("bbx_loss", bbx_loss), 109 | ("roi_cls_loss", roi_cls_loss), 110 | ("roi_bbx_loss", roi_bbx_loss), 111 | ("sem_loss", sem_loss) 112 | ]) 113 | pred = OrderedDict([ 114 | ("bbx_pred", bbx_pred), 115 | ("cls_pred", cls_pred), 116 | ("obj_pred", obj_pred), 117 | ("sem_pred", sem_pred) 118 | ]) 119 | conf = OrderedDict([ 120 | ("sem_conf", conf_mat) 121 | ]) 122 | return loss, pred, conf 123 | -------------------------------------------------------------------------------- /grasp_det_seg/models/resnet.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from collections import OrderedDict 3 | from functools import partial 4 | 5 | import torch.nn as nn 6 | from inplace_abn import ABN 7 | 8 | from grasp_det_seg.modules.misc import GlobalAvgPool2d 9 | from grasp_det_seg.modules.residual import ResidualBlock 10 | from grasp_det_seg.utils.misc import try_index 11 | 12 | 13 | class ResNet(nn.Module): 14 | """Standard residual 
network 15 | 16 | Parameters 17 | ---------- 18 | structure : list of int 19 | Number of residual blocks in each of the four modules of the network 20 | bottleneck : bool 21 | If `True` use "bottleneck" residual blocks with 3 convolutions, otherwise use standard blocks 22 | norm_act : callable or list of callable 23 | Function to create normalization / activation Module. If a list is passed it should have four elements, one for 24 | each module of the network 25 | classes : int 26 | If not `0` also include global average pooling and a fully-connected layer with `classes` outputs at the end 27 | of the network 28 | dilation : int or list of int 29 | List of dilation factors for the four modules of the network, or `1` to ignore dilation 30 | dropout : list of float or None 31 | If present, specifies the amount of dropout to apply in the blocks of each of the four modules of the network 32 | caffe_mode : bool 33 | If `True`, use bias in the first convolution for compatibility with the Caffe pretrained models 34 | """ 35 | 36 | def __init__(self, 37 | structure, 38 | bottleneck, 39 | norm_act=ABN, 40 | classes=0, 41 | dilation=1, 42 | dropout=None, 43 | caffe_mode=False): 44 | super(ResNet, self).__init__() 45 | self.structure = structure 46 | self.bottleneck = bottleneck 47 | self.dilation = dilation 48 | self.dropout = dropout 49 | self.caffe_mode = caffe_mode 50 | 51 | if len(structure) != 4: 52 | raise ValueError("Expected a structure with four values") 53 | if dilation != 1 and len(dilation) != 4: 54 | raise ValueError("If dilation is not 1 it must contain four values") 55 | 56 | # Initial layers 57 | layers = [ 58 | ("conv1", nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=caffe_mode)), 59 | ("bn1", try_index(norm_act, 0)(64)) 60 | ] 61 | if try_index(dilation, 0) == 1: 62 | layers.append(("pool1", nn.MaxPool2d(3, stride=2, padding=1))) 63 | self.mod1 = nn.Sequential(OrderedDict(layers)) 64 | 65 | # Groups of residual blocks 66 | in_channels = 64 67 | if self.bottleneck: 68 | channels = (64, 64, 256) 69 | else: 70 | channels = (64, 64) 71 | for mod_id, num in enumerate(structure): 72 | mod_dropout = None 73 | if self.dropout is not None: 74 | if self.dropout[mod_id] is not None: 75 | mod_dropout = partial(nn.Dropout, p=self.dropout[mod_id]) 76 | 77 | # Create blocks for module 78 | blocks = [] 79 | for block_id in range(num): 80 | stride, dil = self._stride_dilation(dilation, mod_id, block_id) 81 | blocks.append(( 82 | "block%d" % (block_id + 1), 83 | ResidualBlock(in_channels, channels, norm_act=try_index(norm_act, mod_id), 84 | stride=stride, dilation=dil, dropout=mod_dropout) 85 | )) 86 | 87 | # Update channels and p_keep 88 | in_channels = channels[-1] 89 | 90 | # Create module 91 | self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks))) 92 | 93 | # Double the number of channels for the next module 94 | channels = [c * 2 for c in channels] 95 | 96 | # Pooling and predictor 97 | if classes != 0: 98 | self.classifier = nn.Sequential(OrderedDict([ 99 | ("avg_pool", GlobalAvgPool2d()), 100 | ("fc", nn.Linear(in_channels, classes)) 101 | ])) 102 | 103 | @staticmethod 104 | def _stride_dilation(dilation, mod_id, block_id): 105 | d = try_index(dilation, mod_id) 106 | s = 2 if d == 1 and block_id == 0 and mod_id > 0 else 1 107 | return s, d 108 | 109 | def forward(self, x): 110 | outs = OrderedDict() 111 | 112 | outs["mod1"] = self.mod1(x) 113 | outs["mod2"] = self.mod2(outs["mod1"]) 114 | outs["mod3"] = self.mod3(outs["mod2"]) 115 | outs["mod4"] = self.mod4(outs["mod3"]) 
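        # Each "mod" entry is a feature map at progressively lower resolution; FPNBody later
        # selects the entries named in its fpn_inputs from this OrderedDict and passes them to
        # the FPN, so these keys form the backbone's interface to the rest of the network.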
116 | outs["mod5"] = self.mod5(outs["mod4"]) 117 | 118 | if hasattr(self, "classifier"): 119 | outs["classifier"] = self.classifier(outs["mod5"]) 120 | 121 | return outs 122 | 123 | 124 | _NETS = { 125 | "18": {"structure": [2, 2, 2, 2], "bottleneck": False}, 126 | "34": {"structure": [3, 4, 6, 3], "bottleneck": False}, 127 | "50": {"structure": [3, 4, 6, 3], "bottleneck": True}, 128 | "101": {"structure": [3, 4, 23, 3], "bottleneck": True}, 129 | "152": {"structure": [3, 8, 36, 3], "bottleneck": True}, 130 | } 131 | 132 | __all__ = [] 133 | for name, params in _NETS.items(): 134 | net_name = "net_resnet" + name 135 | setattr(sys.modules[__name__], net_name, partial(ResNet, **params)) 136 | __all__.append(net_name) 137 | -------------------------------------------------------------------------------- /grasp_det_seg/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-ainetter/grasp_det_seg_cnn/6ff96464f8906fb555d0a2f5a8b86c7f1330f108/grasp_det_seg/modules/__init__.py -------------------------------------------------------------------------------- /grasp_det_seg/modules/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-ainetter/grasp_det_seg_cnn/6ff96464f8906fb555d0a2f5a8b86c7f1330f108/grasp_det_seg/modules/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /grasp_det_seg/modules/__pycache__/losses.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-ainetter/grasp_det_seg_cnn/6ff96464f8906fb555d0a2f5a8b86c7f1330f108/grasp_det_seg/modules/__pycache__/losses.cpython-36.pyc -------------------------------------------------------------------------------- /grasp_det_seg/modules/__pycache__/misc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-ainetter/grasp_det_seg_cnn/6ff96464f8906fb555d0a2f5a8b86c7f1330f108/grasp_det_seg/modules/__pycache__/misc.cpython-36.pyc -------------------------------------------------------------------------------- /grasp_det_seg/modules/__pycache__/residual.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-ainetter/grasp_det_seg_cnn/6ff96464f8906fb555d0a2f5a8b86c7f1330f108/grasp_det_seg/modules/__pycache__/residual.cpython-36.pyc -------------------------------------------------------------------------------- /grasp_det_seg/modules/fpn.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as functional 5 | from inplace_abn import ABN 6 | 7 | 8 | class FPN(nn.Module): 9 | """Feature Pyramid Network module 10 | 11 | Parameters 12 | ---------- 13 | in_channels : sequence of int 14 | Number of feature channels in each of the input feature levels 15 | out_channels : int 16 | Number of output feature channels (same for each level) 17 | extra_scales : int 18 | Number of extra low-resolution scales 19 | norm_act : callable 20 | Function to create normalization + activation modules 21 | interpolation : str 22 | Interpolation mode to use when up-sampling, see `torch.nn.functional.interpolate` 23 | """ 24 | 25 | def __init__(self, 
in_channels, out_channels=256, extra_scales=0, norm_act=ABN, interpolation="nearest"): 26 | super(FPN, self).__init__() 27 | self.interpolation = interpolation 28 | 29 | # Lateral connections and output convolutions 30 | self.lateral = nn.ModuleList([ 31 | self._make_lateral(channels, out_channels, norm_act) for channels in in_channels 32 | ]) 33 | self.output = nn.ModuleList([ 34 | self._make_output(out_channels, norm_act) for _ in in_channels 35 | ]) 36 | 37 | if extra_scales > 0: 38 | self.extra = nn.ModuleList([ 39 | self._make_extra(in_channels[-1] if i == 0 else out_channels, out_channels, norm_act) 40 | for i in range(extra_scales) 41 | ]) 42 | 43 | self.reset_parameters() 44 | 45 | def reset_parameters(self): 46 | gain = nn.init.calculate_gain(self.lateral[0].bn.activation, self.lateral[0].bn.activation_param) 47 | for mod in self.modules(): 48 | if isinstance(mod, nn.Conv2d): 49 | nn.init.xavier_normal_(mod.weight, gain) 50 | elif isinstance(mod, ABN): 51 | nn.init.constant_(mod.weight, 1.) 52 | if hasattr(mod, "bias") and mod.bias is not None: 53 | nn.init.constant_(mod.bias, 0.) 54 | 55 | @staticmethod 56 | def _make_lateral(input_channels, hidden_channels, norm_act): 57 | return nn.Sequential(OrderedDict([ 58 | ("conv", nn.Conv2d(input_channels, hidden_channels, 1, bias=False)), 59 | ("bn", norm_act(hidden_channels)) 60 | ])) 61 | 62 | @staticmethod 63 | def _make_output(channels, norm_act): 64 | return nn.Sequential(OrderedDict([ 65 | ("conv", nn.Conv2d(channels, channels, 3, padding=1, bias=False)), 66 | ("bn", norm_act(channels)) 67 | ])) 68 | 69 | @staticmethod 70 | def _make_extra(input_channels, out_channels, norm_act): 71 | return nn.Sequential(OrderedDict([ 72 | ("conv", nn.Conv2d(input_channels, out_channels, 3, stride=2, padding=1, bias=False)), 73 | ("bn", norm_act(out_channels)) 74 | ])) 75 | 76 | def forward(self, xs): 77 | """Feature Pyramid Network module 78 | 79 | Parameters 80 | ---------- 81 | xs : sequence of torch.Tensor 82 | The input feature maps, tensors with shapes N x C_i x H_i x W_i 83 | 84 | Returns 85 | ------- 86 | ys : sequence of torch.Tensor 87 | The output feature maps, tensors with shapes N x K x H_i x W_i 88 | """ 89 | ys = [] 90 | interp_params = {"mode": self.interpolation} 91 | if self.interpolation == "bilinear": 92 | interp_params["align_corners"] = False 93 | 94 | # Build pyramid 95 | for x_i, lateral_i in zip(xs[::-1], self.lateral[::-1]): 96 | x_i = lateral_i(x_i) 97 | if len(ys) > 0: 98 | x_i = x_i + functional.interpolate(ys[0], size=x_i.shape[-2:], **interp_params) 99 | ys.insert(0, x_i) 100 | 101 | # Compute outputs 102 | ys = [output_i(y_i) for y_i, output_i in zip(ys, self.output)] 103 | 104 | # Compute extra outputs if necessary 105 | if hasattr(self, "extra"): 106 | y = xs[-1] 107 | for extra_i in self.extra: 108 | y = extra_i(y) 109 | ys.append(y) 110 | 111 | return ys 112 | 113 | 114 | class FPNBody(nn.Module): 115 | """Wrapper for a backbone network and an FPN module 116 | 117 | Parameters 118 | ---------- 119 | backbone : torch.nn.Module 120 | Backbone network, which takes a batch of images and produces a dictionary of intermediate features 121 | fpn : torch.nn.Module 122 | FPN module, which takes a list of intermediate features and produces a list of outputs 123 | fpn_inputs : iterable 124 | An iterable producing the names of the intermediate features to take from the backbone's output and pass 125 | to the FPN 126 | """ 127 | 128 | def __init__(self, backbone, fpn, fpn_inputs=()): 129 | super(FPNBody, self).__init__() 130 
| self.fpn_inputs = fpn_inputs 131 | 132 | self.backbone = backbone 133 | self.fpn = fpn 134 | 135 | def forward(self, x): 136 | x = self.backbone(x) 137 | xs = [x[fpn_input] for fpn_input in self.fpn_inputs] 138 | return self.fpn(xs) 139 | -------------------------------------------------------------------------------- /grasp_det_seg/modules/heads/__init__.py: -------------------------------------------------------------------------------- 1 | from .fpn import FPNROIHead, FPNSemanticHeadDeeplab 2 | from .rpn import RPNHead 3 | -------------------------------------------------------------------------------- /grasp_det_seg/modules/heads/fpn.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as functional 6 | from inplace_abn import ABN 7 | from grasp_det_seg.utils.misc import try_index 8 | 9 | class FPNROIHead(nn.Module): 10 | """ROI head module for FPN 11 | """ 12 | 13 | def __init__(self, in_channels, classes, roi_size, hidden_channels=1024, norm_act=ABN): 14 | super(FPNROIHead, self).__init__() 15 | 16 | self.fc = nn.Sequential(OrderedDict([ 17 | ("fc1", nn.Linear(int(roi_size[0] * roi_size[1] * in_channels / 4), hidden_channels, bias=False)), 18 | ("bn1", norm_act(hidden_channels)), 19 | ("fc2", nn.Linear(hidden_channels, hidden_channels, bias=False)), 20 | ("bn2", norm_act(hidden_channels)) 21 | ])) 22 | self.roi_cls = nn.Linear(hidden_channels, classes["thing"] + 1) 23 | self.roi_bbx = nn.Linear(hidden_channels, classes["thing"] * 4) 24 | 25 | self.reset_parameters() 26 | 27 | def reset_parameters(self): 28 | gain = nn.init.calculate_gain(self.fc.bn1.activation, self.fc.bn1.activation_param) 29 | 30 | for name, mod in self.named_modules(): 31 | if isinstance(mod, nn.Linear): 32 | if "roi_cls" in name: 33 | nn.init.xavier_normal_(mod.weight, .01) 34 | elif "roi_bbx" in name: 35 | nn.init.xavier_normal_(mod.weight, .001) 36 | else: 37 | nn.init.xavier_normal_(mod.weight, gain) 38 | elif isinstance(mod, ABN): 39 | nn.init.constant_(mod.weight, 1.) 40 | 41 | if hasattr(mod, "bias") and mod.bias is not None: 42 | nn.init.constant_(mod.bias, 0.) 
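    # Shape walk-through with illustrative numbers (not taken from the shipped config): with
    # in_channels=256 and roi_size=(14, 14), forward() halves the ROI grid via avg_pool2d(x, 2)
    # to 7 x 7, flattens it to 14 * 14 * 256 / 4 = 12544 features (matching fc1 above), maps it
    # to hidden_channels=1024, and finally emits per-class scores and class-wise box regressions.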
43 | 44 | def forward(self, x): 45 | """ROI head module for FPN 46 | """ 47 | x = functional.avg_pool2d(x, 2) 48 | 49 | # Run head 50 | x = self.fc(x.view(x.size(0), -1)) 51 | return self.roi_cls(x), self.roi_bbx(x).view(x.size(0), -1, 4) 52 | 53 | class FPNSemanticHeadDeeplab(nn.Module): 54 | """Semantic segmentation head for FPN-style networks, extending Deeplab v3 for FPN bodies""" 55 | 56 | class _MiniDL(nn.Module): 57 | def __init__(self, in_channels, out_channels, dilation, pooling_size, norm_act): 58 | super(FPNSemanticHeadDeeplab._MiniDL, self).__init__() 59 | self.pooling_size = pooling_size 60 | 61 | self.conv1_3x3 = nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False) 62 | self.conv1_dil = nn.Conv2d(in_channels, out_channels, 3, dilation=dilation, padding=dilation, bias=False) 63 | self.conv1_glb = nn.Conv2d(in_channels, out_channels, 1, bias=False) 64 | self.bn1 = norm_act(out_channels * 3) 65 | 66 | self.conv2 = nn.Conv2d(out_channels * 3, out_channels, 1, bias=False) 67 | self.bn2 = norm_act(out_channels) 68 | 69 | def _global_pooling(self, x): 70 | pooling_size = (min(try_index(self.pooling_size, 0), x.shape[2]), 71 | min(try_index(self.pooling_size, 1), x.shape[3])) 72 | padding = ( 73 | (pooling_size[1] - 1) // 2, 74 | (pooling_size[1] - 1) // 2 if pooling_size[1] % 2 == 1 else (pooling_size[1] - 1) // 2 + 1, 75 | (pooling_size[0] - 1) // 2, 76 | (pooling_size[0] - 1) // 2 if pooling_size[0] % 2 == 1 else (pooling_size[0] - 1) // 2 + 1 77 | ) 78 | 79 | pool = functional.avg_pool2d(x, pooling_size, stride=1) 80 | pool = functional.pad(pool, pad=padding, mode="replicate") 81 | return pool 82 | 83 | def forward(self, x): 84 | x = torch.cat([ 85 | self.conv1_3x3(x), 86 | self.conv1_dil(x), 87 | self.conv1_glb(self._global_pooling(x)), 88 | ], dim=1) 89 | x = self.bn1(x) 90 | x = self.conv2(x) 91 | x = self.bn2(x) 92 | return x 93 | 94 | def __init__(self, 95 | in_channels, 96 | min_level, 97 | levels, 98 | num_classes, 99 | hidden_channels=128, 100 | dilation=6, 101 | pooling_size=(64, 64), 102 | norm_act=ABN, 103 | interpolation="bilinear"): 104 | super(FPNSemanticHeadDeeplab, self).__init__() 105 | self.min_level = min_level 106 | self.levels = levels 107 | self.interpolation = interpolation 108 | 109 | self.output = nn.ModuleList([ 110 | self._MiniDL(in_channels, hidden_channels, dilation, pooling_size, norm_act) for _ in range(levels) 111 | ]) 112 | self.conv_sem = nn.Conv2d(hidden_channels * levels, num_classes, 1) 113 | 114 | self.reset_parameters() 115 | 116 | def reset_parameters(self): 117 | gain = nn.init.calculate_gain(self.output[0].bn1.activation, self.output[0].bn1.activation_param) 118 | for name, mod in self.named_modules(): 119 | if isinstance(mod, nn.Conv2d): 120 | if "conv_sem" not in name: 121 | nn.init.xavier_normal_(mod.weight, gain) 122 | else: 123 | nn.init.xavier_normal_(mod.weight, .1) 124 | elif isinstance(mod, ABN): 125 | nn.init.constant_(mod.weight, 1.) 126 | if hasattr(mod, "bias") and mod.bias is not None: 127 | nn.init.constant_(mod.bias, 0.) 
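    # forward() below runs one _MiniDL block (parallel 3x3, dilated 3x3 and pooled 1x1 branches)
    # on each of the `levels` FPN maps starting at `min_level`, upsamples the coarser results to
    # the finest level's resolution, concatenates them channel-wise, and classifies the fused
    # features with the 1x1 conv_sem layer, returning both the logits and the fused feature map.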
128 | 129 | def forward(self, xs): 130 | xs = xs[self.min_level:self.min_level + self.levels] 131 | 132 | ref_size = xs[0].shape[-2:] 133 | interp_params = {"mode": self.interpolation} 134 | if self.interpolation == "bilinear": 135 | interp_params["align_corners"] = False 136 | 137 | for i, output in enumerate(self.output): 138 | xs[i] = output(xs[i]) 139 | if i > 0: 140 | xs[i] = functional.interpolate(xs[i], size=ref_size, **interp_params) 141 | 142 | xs_feats = torch.cat(xs, dim=1) 143 | xs = self.conv_sem(xs_feats) 144 | 145 | return xs,xs_feats 146 | -------------------------------------------------------------------------------- /grasp_det_seg/modules/heads/rpn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from inplace_abn import ABN 4 | 5 | class RPNHead(nn.Module): 6 | """RPN head module 7 | 8 | Parameters 9 | ---------- 10 | in_channels : int 11 | Number of channels in the input feature map 12 | num_anchors : int 13 | Number of anchors predicted at each spatial location 14 | stride : int 15 | Stride of the internal convolutions 16 | hidden_channels : int 17 | Number of channels in the internal intermediate feature map 18 | norm_act : callable 19 | Function to create normalization + activation modules 20 | """ 21 | 22 | def __init__(self, in_channels, num_anchors, stride=1, hidden_channels=255, norm_act=ABN): 23 | super(RPNHead, self).__init__() 24 | 25 | self.conv1 = nn.Conv2d(in_channels, hidden_channels, 3, padding=1, stride=stride, bias=False) 26 | self.bn1 = norm_act(hidden_channels) 27 | self.conv_obj = nn.Conv2d(hidden_channels, num_anchors, 1) 28 | self.conv_bbx = nn.Conv2d(hidden_channels, num_anchors * 4, 1) 29 | 30 | self.reset_parameters() 31 | 32 | def reset_parameters(self): 33 | activation = self.bn1.activation 34 | activation_param = self.bn1.activation_param 35 | 36 | # Hidden convolution 37 | gain = nn.init.calculate_gain(activation, activation_param) 38 | nn.init.xavier_normal_(self.conv1.weight, gain) 39 | self.bn1.reset_parameters() 40 | 41 | # Classifiers 42 | for m in [self.conv_obj, self.conv_bbx]: 43 | nn.init.xavier_normal_(m.weight, .01) 44 | nn.init.constant_(m.bias, 0) 45 | 46 | def forward(self, x): 47 | """RPN head module 48 | """ 49 | x = self.conv1(x) 50 | x = self.bn1(x) 51 | return self.conv_obj(x), self.conv_bbx(x) 52 | -------------------------------------------------------------------------------- /grasp_det_seg/modules/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from grasp_det_seg.utils.parallel import PackedSequence 4 | 5 | 6 | def smooth_l1(x1, x2, sigma): 7 | """Smooth L1 loss""" 8 | sigma2 = sigma ** 2 9 | 10 | diff = x1 - x2 11 | abs_diff = diff.abs() 12 | 13 | mask = (abs_diff.detach() < (1. / sigma2)).float() 14 | return mask * (sigma2 / 2.) 
* diff ** 2 + (1 - mask) * (abs_diff - 0.5 / sigma2) 15 | 16 | 17 | def ohem_loss(loss, ohem=None): 18 | if isinstance(loss, torch.Tensor): 19 | loss = loss.view(loss.size(0), -1) 20 | if ohem is None: 21 | return loss.mean() 22 | 23 | top_k = min(max(int(ohem * loss.size(1)), 1), loss.size(1)) 24 | if top_k != loss.size(1): 25 | loss, _ = loss.topk(top_k, dim=1) 26 | 27 | return loss.mean() 28 | elif isinstance(loss, PackedSequence): 29 | if ohem is None: 30 | return sum(loss_i.mean() for loss_i in loss) / len(loss) 31 | 32 | loss_out = loss.data.new_zeros(()) 33 | for loss_i in loss: 34 | loss_i = loss_i.view(-1) 35 | 36 | top_k = min(max(int(ohem * loss_i.numel()), 1), loss_i.numel()) 37 | if top_k != loss_i.numel(): 38 | loss_i, _ = loss_i.topk(top_k, dim=0) 39 | 40 | loss_out += loss_i.mean() 41 | 42 | return loss_out / len(loss) 43 | -------------------------------------------------------------------------------- /grasp_det_seg/modules/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as functional 4 | 5 | from inplace_abn import ABN 6 | 7 | 8 | class GlobalAvgPool2d(nn.Module): 9 | """Global average pooling over the input's spatial dimensions""" 10 | 11 | def __init__(self): 12 | super(GlobalAvgPool2d, self).__init__() 13 | 14 | def forward(self, inputs): 15 | in_size = inputs.size() 16 | return inputs.view((in_size[0], in_size[1], -1)).mean(dim=2) 17 | 18 | 19 | class Interpolate(nn.Module): 20 | """nn.Module wrapper to nn.functional.interpolate""" 21 | 22 | def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None): 23 | super(Interpolate, self).__init__() 24 | self.size = size 25 | self.scale_factor = scale_factor 26 | self.mode = mode 27 | self.align_corners = align_corners 28 | 29 | def forward(self, x): 30 | return functional.interpolate(x, self.size, self.scale_factor, self.mode, self.align_corners) 31 | 32 | 33 | class ActivatedAffine(ABN): 34 | """Drop-in replacement for ABN which performs inference-mode BN + activation""" 35 | 36 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", 37 | activation_param=0.01): 38 | super(ActivatedAffine, self).__init__(num_features, eps, momentum, affine, activation, activation_param) 39 | 40 | @staticmethod 41 | def _broadcast_shape(x): 42 | out_size = [] 43 | for i, s in enumerate(x.size()): 44 | if i != 1: 45 | out_size.append(1) 46 | else: 47 | out_size.append(s) 48 | return out_size 49 | 50 | def forward(self, x): 51 | inv_var = torch.rsqrt(self.running_var + self.eps) 52 | if self.affine: 53 | alpha = self.weight * inv_var 54 | beta = self.bias - self.running_mean * alpha 55 | else: 56 | alpha = inv_var 57 | beta = - self.running_mean * alpha 58 | 59 | x.mul_(alpha.view(self._broadcast_shape(x))) 60 | x.add_(beta.view(self._broadcast_shape(x))) 61 | 62 | if self.activation == "relu": 63 | return functional.relu(x, inplace=True) 64 | elif self.activation == "leaky_relu": 65 | return functional.leaky_relu(x, negative_slope=self.activation_param, inplace=True) 66 | elif self.activation == "elu": 67 | return functional.elu(x, alpha=self.activation_param, inplace=True) 68 | elif self.activation == "identity": 69 | return x 70 | else: 71 | raise RuntimeError("Unknown activation function {}".format(self.activation)) 72 | 73 | 74 | class ActivatedGroupNorm(ABN): 75 | """GroupNorm + activation function compatible with the ABN interface""" 76 | 77 | def 
__init__(self, num_channels, num_groups, eps=1e-5, affine=True, activation="leaky_relu", activation_param=0.01): 78 | super(ActivatedGroupNorm, self).__init__(num_channels, eps, affine=affine, activation=activation, 79 | activation_param=activation_param) 80 | self.num_groups = num_groups 81 | 82 | # Delete running mean and var since they are not used here 83 | delattr(self, "running_mean") 84 | delattr(self, "running_var") 85 | 86 | def reset_parameters(self): 87 | if self.affine: 88 | nn.init.constant_(self.weight, 1) 89 | nn.init.constant_(self.bias, 0) 90 | 91 | def forward(self, x): 92 | x = functional.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) 93 | 94 | if self.activation == "relu": 95 | return functional.relu(x, inplace=True) 96 | elif self.activation == "leaky_relu": 97 | return functional.leaky_relu(x, negative_slope=self.activation_param, inplace=True) 98 | elif self.activation == "elu": 99 | return functional.elu(x, alpha=self.activation_param, inplace=True) 100 | elif self.activation == "identity": 101 | return x 102 | else: 103 | raise RuntimeError("Unknown activation function {}".format(self.activation)) 104 | -------------------------------------------------------------------------------- /grasp_det_seg/modules/residual.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as functional 5 | from inplace_abn import ABN 6 | 7 | 8 | class ResidualBlock(nn.Module): 9 | """Configurable residual block 10 | 11 | Parameters 12 | ---------- 13 | in_channels : int 14 | Number of input channels. 15 | channels : list of int 16 | Number of channels in the internal feature maps. Can either have two or three elements: if three construct 17 | a residual block with two `3 x 3` convolutions, otherwise construct a bottleneck block with `1 x 1`, then 18 | `3 x 3` then `1 x 1` convolutions. 19 | stride : int 20 | Stride of the first `3 x 3` convolution 21 | dilation : int 22 | Dilation to apply to the `3 x 3` convolutions. 23 | groups : int 24 | Number of convolution groups. This is used to create ResNeXt-style blocks and is only compatible with 25 | bottleneck blocks. 26 | norm_act : callable 27 | Function to create normalization / activation Module. 28 | dropout: callable 29 | Function to create Dropout Module. 
30 | """ 31 | 32 | def __init__(self, 33 | in_channels, 34 | channels, 35 | stride=1, 36 | dilation=1, 37 | groups=1, 38 | norm_act=ABN, 39 | dropout=None): 40 | super(ResidualBlock, self).__init__() 41 | 42 | # Check parameters for inconsistencies 43 | if len(channels) != 2 and len(channels) != 3: 44 | raise ValueError("channels must contain either two or three values") 45 | if len(channels) == 2 and groups != 1: 46 | raise ValueError("groups > 1 are only valid if len(channels) == 3") 47 | 48 | is_bottleneck = len(channels) == 3 49 | need_proj_conv = stride != 1 or in_channels != channels[-1] 50 | 51 | if not is_bottleneck: 52 | bn2 = norm_act(channels[1]) 53 | bn2.activation = "identity" 54 | layers = [ 55 | ("conv1", nn.Conv2d(in_channels, channels[0], 3, stride=stride, padding=dilation, bias=False, 56 | dilation=dilation)), 57 | ("bn1", norm_act(channels[0])), 58 | ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False, 59 | dilation=dilation)), 60 | ("bn2", bn2) 61 | ] 62 | if dropout is not None: 63 | layers = layers[0:2] + [("dropout", dropout())] + layers[2:] 64 | else: 65 | bn3 = norm_act(channels[2]) 66 | bn3.activation = "identity" 67 | layers = [ 68 | ("conv1", nn.Conv2d(in_channels, channels[0], 1, stride=1, padding=0, bias=False)), 69 | ("bn1", norm_act(channels[0])), 70 | ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=stride, padding=dilation, bias=False, 71 | groups=groups, dilation=dilation)), 72 | ("bn2", norm_act(channels[1])), 73 | ("conv3", nn.Conv2d(channels[1], channels[2], 1, stride=1, padding=0, bias=False)), 74 | ("bn3", bn3) 75 | ] 76 | if dropout is not None: 77 | layers = layers[0:4] + [("dropout", dropout())] + layers[4:] 78 | self.convs = nn.Sequential(OrderedDict(layers)) 79 | 80 | if need_proj_conv: 81 | self.proj_conv = nn.Conv2d(in_channels, channels[-1], 1, stride=stride, padding=0, bias=False) 82 | self.proj_bn = norm_act(channels[-1]) 83 | self.proj_bn.activation = "identity" 84 | 85 | def forward(self, x): 86 | if hasattr(self, "proj_conv"): 87 | residual = self.proj_conv(x) 88 | residual = self.proj_bn(residual) 89 | else: 90 | residual = x 91 | 92 | x = self.convs(x) + residual 93 | 94 | if self.convs.bn1.activation == "relu": 95 | return functional.relu(x, inplace=True) 96 | elif self.convs.bn1.activation == "leaky_relu": 97 | return functional.leaky_relu(x, negative_slope=self.convs.bn1.activation_param, inplace=True) 98 | elif self.convs.bn1.activation == "elu": 99 | return functional.elu(x, alpha=self.convs.bn1.activation_param, inplace=True) 100 | elif self.convs.bn1.activation == "identity": 101 | return x 102 | else: 103 | raise RuntimeError("Unknown activation function {}".format(self.activation)) 104 | -------------------------------------------------------------------------------- /grasp_det_seg/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-ainetter/grasp_det_seg_cnn/6ff96464f8906fb555d0a2f5a8b86c7f1330f108/grasp_det_seg/utils/__init__.py -------------------------------------------------------------------------------- /grasp_det_seg/utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-ainetter/grasp_det_seg_cnn/6ff96464f8906fb555d0a2f5a8b86c7f1330f108/grasp_det_seg/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- 
/grasp_det_seg/utils/__pycache__/coco_ap.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-ainetter/grasp_det_seg_cnn/6ff96464f8906fb555d0a2f5a8b86c7f1330f108/grasp_det_seg/utils/__pycache__/coco_ap.cpython-36.pyc -------------------------------------------------------------------------------- /grasp_det_seg/utils/__pycache__/misc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-ainetter/grasp_det_seg_cnn/6ff96464f8906fb555d0a2f5a8b86c7f1330f108/grasp_det_seg/utils/__pycache__/misc.cpython-36.pyc -------------------------------------------------------------------------------- /grasp_det_seg/utils/__pycache__/scheduler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-ainetter/grasp_det_seg_cnn/6ff96464f8906fb555d0a2f5a8b86c7f1330f108/grasp_det_seg/utils/__pycache__/scheduler.cpython-36.pyc -------------------------------------------------------------------------------- /grasp_det_seg/utils/bbx/__init__.py: -------------------------------------------------------------------------------- 1 | from .bbx import * 2 | -------------------------------------------------------------------------------- /grasp_det_seg/utils/bbx/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-ainetter/grasp_det_seg_cnn/6ff96464f8906fb555d0a2f5a8b86c7f1330f108/grasp_det_seg/utils/bbx/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /grasp_det_seg/utils/bbx/__pycache__/bbx.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-ainetter/grasp_det_seg_cnn/6ff96464f8906fb555d0a2f5a8b86c7f1330f108/grasp_det_seg/utils/bbx/__pycache__/bbx.cpython-36.pyc -------------------------------------------------------------------------------- /grasp_det_seg/utils/bbx/_backend.pyi: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def extract_boxes(mask: torch.Tensor, n_instances: int) -> torch.Tensor: ... 5 | 6 | 7 | def mask_count(bbx: torch.Tensor, int_mask: torch.Tensor) -> torch.Tensor: ... 8 | -------------------------------------------------------------------------------- /grasp_det_seg/utils/bbx/bbx.py: -------------------------------------------------------------------------------- 1 | from math import log 2 | import math 3 | import numpy as np 4 | import torch 5 | 6 | from . 
import _backend 7 | 8 | __all__ = [ 9 | "extract_boxes", 10 | "shift_boxes", 11 | "shift_boxes_rotation", 12 | "calculate_shift", 13 | "calculate_shift_rotation", 14 | "corners_to_center_scale", 15 | "center_scale_to_corners", 16 | "invert_roi_bbx", 17 | "ious", 18 | "mask_overlap", 19 | "bbx_overlap" 20 | ] 21 | 22 | 23 | def extract_boxes(mask, num_instances): 24 | """Calculate bounding boxes from instance segmentation mask 25 | 26 | Parameters 27 | ---------- 28 | mask : torch.Tensor 29 | A tensor with shape H x W containing an instance segmentation mask 30 | num_instances : int 31 | The number of instances to look for 32 | 33 | Returns 34 | ------- 35 | bbx : torch.Tensor 36 | A tensor with shape `num_instances` x 4 containing the coordinates of the bounding boxes in "corners" form 37 | 38 | """ 39 | if mask.ndimension() == 2: 40 | mask = mask.unsqueeze(0) 41 | return _backend.extract_boxes(mask, num_instances) 42 | 43 | 44 | def shift_boxes(bbx, shift, dim=-1, scale_clip=log(1000. / 16.)): 45 | """Shift bounding boxes using the faster r-CNN formulas 46 | 47 | Each 4-vector of `bbx` and `shift` contain, respectively, bounding box coordiantes in "corners" form and shifts 48 | in the form `(dy, dx, dh, dw)`. The output is calculated according to the Faster r-CNN formulas: 49 | 50 | y_out = y_in + h_in * dy 51 | x_out = x_in + w_in * dx 52 | h_out = h_in * exp(dh) 53 | w_out = w_in * exp(dw) 54 | 55 | Parameters 56 | ---------- 57 | bbx : torch.Tensor 58 | A tensor of bounding boxes with shape N_0 x ... x N_i = 4 x ... x N_n 59 | shift : torch.Tensor 60 | A tensor of shifts with shape N_0 x ... x N_i = 4 x ... x N_n 61 | dim : int 62 | The dimension i of the input tensors which contains the bounding box coordinates and the shifts 63 | scale_clip : float 64 | Maximum scale shift value to avoid exp overflow 65 | 66 | Returns 67 | ------- 68 | bbx_out : torch.Tensor 69 | A tensor of shifted bounding boxes with shape N_0 x ... x N_i = 4 x ... x N_n 70 | 71 | """ 72 | yx_in, hw_in = corners_to_center_scale(*bbx.split(2, dim=dim)) 73 | dyx, dhw = shift.split(2, dim=dim) 74 | 75 | yx_out = yx_in + hw_in * dyx 76 | hw_out = hw_in * dhw.clamp(max=scale_clip).exp() 77 | 78 | return torch.cat(center_scale_to_corners(yx_out, hw_out), dim=dim) 79 | 80 | def shift_boxes_rotation(bbx,theta, shift, dim=-1, scale_clip=log(1000. / 16.)): 81 | """Shift bounding boxes using the faster r-CNN formulas 82 | 83 | Each 4-vector of `bbx` and `shift` contain, respectively, bounding box coordiantes in "corners" form and shifts 84 | in the form `(dy, dx, dh, dw)`. The output is calculated according to the Faster r-CNN formulas: 85 | 86 | y_out = y_in + h_in * dy 87 | x_out = x_in + w_in * dx 88 | h_out = h_in * exp(dh) 89 | w_out = w_in * exp(dw) 90 | 91 | Parameters 92 | ---------- 93 | bbx : torch.Tensor 94 | A tensor of bounding boxes with shape N_0 x ... x N_i = 4 x ... x N_n 95 | shift : torch.Tensor 96 | A tensor of shifts with shape N_0 x ... x N_i = 4 x ... x N_n 97 | dim : int 98 | The dimension i of the input tensors which contains the bounding box coordinates and the shifts 99 | scale_clip : float 100 | Maximum scale shift value to avoid exp overflow 101 | 102 | Returns 103 | ------- 104 | bbx_out : torch.Tensor 105 | A tensor of shifted bounding boxes with shape N_0 x ... x N_i = 4 x ... x N_n 106 | 107 | """ 108 | # convert degree to rad 109 | theta_ = (theta * torch.Tensor([math.pi]).float().to('cuda:0')) / 180. 
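    # Rotated-box decoding below: the (dx, dy) offsets are rotated by the prior angle theta_
    # before being added to the anchor centre, width and height use the usual clamped exp()
    # scaling, and the angle is recovered as pi * dtheta + theta_, wrapped with fmod into
    # (-pi, pi) and converted back to degrees.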
110 | 111 | 112 | yx_in, hw_in = corners_to_center_scale(*bbx.split(2, dim=dim)) 113 | y_in,x_in = yx_in.split(1,dim=dim) 114 | h_in,w_in = hw_in.split(1,dim=dim) 115 | dyx, dhw,_ = shift.split((2,2,1), dim=dim) 116 | 117 | dy, dx, dh,dw, dtheta = shift.split((1,1,1,1,1), dim=dim) 118 | 119 | pred_ctr_x = dx * w_in * torch.cos(theta_.unsqueeze(1)) - dy * h_in * torch.sin(theta_.unsqueeze(1)) + x_in 120 | pred_ctr_y = dx * w_in * torch.sin(theta_.unsqueeze(1)) + dy * h_in * torch.cos(theta_.unsqueeze(1)) + y_in 121 | pred_w = torch.exp(dw.clamp(max=scale_clip)) * w_in 122 | pred_h = torch.exp(dh.clamp(max=scale_clip)) * h_in 123 | 124 | pred_angle = (torch.Tensor([math.pi]).float().to('cuda:0')) * dtheta + theta_.unsqueeze(1)#[:, np.newaxis] 125 | #pred_angle = pred_angle % (torch.Tensor([math.pi]).float().to('cuda:0')) 126 | pred_angle = torch.fmod(pred_angle,torch.Tensor([math.pi]).float().to('cuda:0')) * (180./torch.Tensor([math.pi]).float().to('cuda:0')) 127 | #torch.fmod(theta_gt - cls_pred_i, torch.Tensor([math.pi]).float().to('cuda:0')) 128 | yx_out_ = yx_in + hw_in * dyx 129 | hw_out_ = hw_in * dhw.clamp(max=scale_clip).exp() 130 | yx_out = torch.cat((pred_ctr_y,pred_ctr_x),dim=dim) 131 | hw_out = torch.cat((pred_h,pred_w),dim=dim) 132 | 133 | return torch.cat(center_scale_to_corners(yx_out, hw_out), dim=dim),pred_angle 134 | 135 | 136 | def calculate_shift(bbx0, bbx1, dim=-1, eps=1e-5): 137 | """Calculate shift parameters between bounding boxes using the faster r-CNN formulas 138 | 139 | Each 4-vector of `bbx0` and `bbx1` contains bounding box coordiantes in "corners" form. The output is calculated 140 | according to the Faster r-CNN formulas: 141 | 142 | dy = (y1 - y0) / h0 143 | dx = (x1 - x0) / w0 144 | dh = log(h1 / h0) 145 | dw = log(w1 / w0) 146 | 147 | Parameters 148 | ---------- 149 | bbx0 : torch.Tensor 150 | A tensor of source bounding boxes with shape N_0 x ... x N_i = 4 x ... x N_n 151 | bbx1 : torch.Tensor 152 | A tensor of target bounding boxes with shape N_0 x ... x N_i = 4 x ... x N_n 153 | dim : int 154 | The dimension `i` of the input tensors which contains the bounding box coordinates 155 | eps : float 156 | Small number used to avoid overflow 157 | 158 | Returns 159 | ------- 160 | shift : torch.Tensor 161 | A tensor of calculated shifts from `bbx0` to `bbx1` with shape N_0 x ... x N_i = 4 x ... x N_n 162 | 163 | """ 164 | # 0 -> anchor ; 1 -> gt 165 | yx0, hw0 = corners_to_center_scale(*bbx0.split(2, dim=dim)) 166 | yx1, hw1 = corners_to_center_scale(*bbx1.split(2, dim=dim)) 167 | 168 | hw0 = hw0.clamp(min=eps) 169 | 170 | dyx = (yx1 - yx0) / hw0 171 | dhw = (hw1 / hw0).log() 172 | 173 | return torch.cat([dyx, dhw], dim=dim) 174 | 175 | def calculate_shift_rotation(bbx0, bbx1,cls_pred_i,theta_gt, dim=-1, eps=1e-5): 176 | """Calculate shift parameters between bounding boxes using the faster r-CNN formulas 177 | 178 | Each 4-vector of `bbx0` and `bbx1` contains bounding box coordiantes in "corners" form. The output is calculated 179 | according to the Faster r-CNN formulas: 180 | 181 | dy = (y1 - y0) / h0 182 | dx = (x1 - x0) / w0 183 | dh = log(h1 / h0) 184 | dw = log(w1 / w0) 185 | 186 | Parameters 187 | ---------- 188 | bbx0 : torch.Tensor 189 | A tensor of source bounding boxes with shape N_0 x ... x N_i = 4 x ... x N_n 190 | bbx1 : torch.Tensor 191 | A tensor of target bounding boxes with shape N_0 x ... x N_i = 4 x ... 
x N_n 192 | dim : int 193 | The dimension `i` of the input tensors which contains the bounding box coordinates 194 | eps : float 195 | Small number used to avoid overflow 196 | 197 | Returns 198 | ------- 199 | shift : torch.Tensor 200 | A tensor of calculated shifts from `bbx0` to `bbx1` with shape N_0 x ... x N_i = 4 x ... x N_n 201 | 202 | """ 203 | # 0 -> anchor ; 1 -> gt 204 | yx0, hw0 = corners_to_center_scale(*bbx0.split(2, dim=dim)) 205 | yx1, hw1 = corners_to_center_scale(*bbx1.split(2, dim=dim)) 206 | 207 | hw0 = hw0.clamp(min=eps) 208 | 209 | # convert degree to rad 210 | cls_pred_i_ = (cls_pred_i * torch.Tensor([math.pi]).float().to('cuda:0')) / 180. 211 | theta_gt_ = (theta_gt * torch.Tensor([math.pi]).float().to('cuda:0')) / 180. 212 | #cls_pred_i_ = cls_pred_i 213 | #theta_gt_ = theta_gt 214 | 215 | 216 | #dyx = (yx1 - yx0) / hw0 217 | #dyx = (yx1 - yx0) 218 | #tx_mat = [torch.cos(cls_pred_i), torch.sin(cls_pred_i)] 219 | #ty_mat = [torch.cos(cls_pred_i), -torch.sin(cls_pred_i)] 220 | dx = (1/hw0[:,1]) * ((yx1[:,1] - yx0[:,1]) * torch.cos(cls_pred_i_) + (yx1[:,0] - yx0[:,0]) * torch.sin(cls_pred_i_)) 221 | dy = (1/hw0[:,0]) * ((yx1[:,0] - yx0[:,0]) * torch.cos(cls_pred_i_) - (yx1[:,1] - yx0[:,1]) * torch.sin(cls_pred_i_)) 222 | #t_theta = torch.Tensor([1/2*math.pi]).float().to('cuda:0') * \ 223 | # torch.fmod(theta_gt-cls_pred_i,torch.Tensor([2*math.pi]).float().to('cuda:0')) 224 | t_theta = torch.Tensor([1/math.pi]).float().to('cuda:0') * \ 225 | torch.fmod(theta_gt_-cls_pred_i_,torch.Tensor([math.pi]).float().to('cuda:0')) 226 | dhw = (hw1 / hw0).log() 227 | dyx = torch.cat([dy.unsqueeze(1),dx.unsqueeze(1)],dim =dim) 228 | 229 | return torch.cat([dyx, dhw,t_theta.unsqueeze(1)], dim=dim) 230 | 231 | def corners_to_center_scale(p0, p1): 232 | """Convert bounding boxes from "corners" form to "center+scale" form""" 233 | yx = 0.5 * (p0 + p1) 234 | hw = p1 - p0 235 | return yx, hw 236 | 237 | 238 | def center_scale_to_corners(yx, hw): 239 | """Convert bounding boxes from "center+scale" form to "corners" form""" 240 | hw_half = 0.5 * hw 241 | p0 = yx - hw_half 242 | p1 = yx + hw_half 243 | return p0, p1 244 | 245 | 246 | def invert_roi_bbx(bbx, roi_size, img_size): 247 | """Compute bbx coordinates to perform inverse roi sampling""" 248 | bbx_size = bbx[:, 2:] - bbx[:, :2] 249 | return torch.cat([ 250 | -bbx.new(roi_size) * bbx[:, :2] / bbx_size, 251 | bbx.new(roi_size) * (bbx.new(img_size) - bbx[:, :2]) / bbx_size 252 | ], dim=1) 253 | 254 | 255 | def ious(bbx0, bbx1): 256 | """Calculate intersection over union between sets of bounding boxes 257 | 258 | Parameters 259 | ---------- 260 | bbx0 : torch.Tensor 261 | A tensor of bounding boxes in "corners" form with shape N x 4 262 | bbx1 : torch.Tensor 263 | A tensor of bounding boxes in "corners" form with shape M x 4 264 | 265 | Returns 266 | ------- 267 | iou : torch.Tensor 268 | A tensor with shape N x M containing the IoUs between all pairs of bounding boxes in bbx0 and bbx1 269 | """ 270 | bbx0_tl, bbx0_br = bbx0.unsqueeze(dim=1).split(2, -1) 271 | bbx1_tl, bbx1_br = bbx1.unsqueeze(dim=0).split(2, -1) 272 | 273 | # Intersection coordinates 274 | int_tl = torch.max(bbx0_tl, bbx1_tl) 275 | int_br = torch.min(bbx0_br, bbx1_br) 276 | 277 | intersection = (int_br - int_tl).clamp(min=0).prod(dim=-1) 278 | bbx0_area = (bbx0_br - bbx0_tl).prod(dim=-1) 279 | bbx1_area = (bbx1_br - bbx1_tl).prod(dim=-1) 280 | return intersection / (bbx0_area + bbx1_area - intersection) 281 | 282 | 283 | def mask_overlap(bbx, mask): 284 | """Calculate 
overlap between a set of bounding boxes and a mask 285 | 286 | Parameters 287 | ---------- 288 | bbx : torch.Tensor 289 | A tensor of bounding boxes in "corners" form with shape N x 4 290 | mask : torch.Tensor 291 | A binary tensor with shape H x W 292 | 293 | Returns 294 | ------- 295 | overlap : torch.Tensor 296 | A tensor with shape N containing the proportion of non-zero pixels in each box 297 | """ 298 | # Compute integral image of the mask 299 | int_mask = bbx.new_zeros((mask.size(0) + 1, mask.size(1) + 1)) 300 | int_mask[1:, 1:] = mask > 0 301 | int_mask = int_mask.cumsum(0).cumsum(1) 302 | 303 | count = _backend.mask_count(bbx, int_mask) 304 | area = (bbx[:, 2:] - bbx[:, :2]).prod(dim=1) 305 | 306 | return count / area 307 | 308 | 309 | def bbx_overlap(bbx0, bbx1): 310 | """Calculate intersection over area between two sets of bounding boxes 311 | 312 | Intersection over area is defined as: 313 | area(inter(bbx0, bbx1)) / area(bbx0) 314 | 315 | Parameters 316 | ---------- 317 | bbx0 : torch.Tensor 318 | A tensor of bounding boxes in "corners" form with shape N x 4 319 | bbx1 : torch.Tensor 320 | A tensor of bounding boxes in "corners" form with shape M x 4 321 | 322 | Returns 323 | ------- 324 | ratios : torch.Tensor 325 | A tensor with shape N x M containing the intersection over areas between all pairs of bounding boxes 326 | """ 327 | bbx0_tl, bbx0_br = bbx0.unsqueeze(dim=1).split(2, -1) 328 | bbx1_tl, bbx1_br = bbx1.unsqueeze(dim=0).split(2, -1) 329 | 330 | # Intersection coordinates 331 | int_tl = torch.max(bbx0_tl, bbx1_tl) 332 | int_br = torch.min(bbx0_br, bbx1_br) 333 | 334 | intersection = (int_br - int_tl).clamp(min=0).prod(dim=-1) 335 | bbx0_area = (bbx0_br - bbx0_tl).prod(dim=-1) 336 | 337 | return intersection / bbx0_area 338 | -------------------------------------------------------------------------------- /grasp_det_seg/utils/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from math import log10 3 | from os import path 4 | 5 | from .meters import AverageMeter 6 | 7 | _NAME = "GraspDetSeg_CNN" 8 | 9 | 10 | def _current_total_formatter(current, total): 11 | width = int(log10(total)) + 1 12 | return ("[{:" + str(width) + "}/{:" + str(width) + "}]").format(current, total) 13 | 14 | 15 | def init(log_dir, name): 16 | logger = logging.getLogger(_NAME) 17 | logger.setLevel(logging.DEBUG) 18 | 19 | # Set console logging 20 | console_handler = logging.StreamHandler() 21 | console_formatter = logging.Formatter(fmt="%(asctime)s %(message)s", datefmt="%H:%M:%S") 22 | console_handler.setFormatter(console_formatter) 23 | console_handler.setLevel(logging.DEBUG) 24 | logger.addHandler(console_handler) 25 | 26 | # Setup file logging 27 | file_handler = logging.FileHandler(path.join(log_dir, name + ".log"), mode="w") 28 | file_formatter = logging.Formatter(fmt="%(levelname).1s %(asctime)s %(message)s", datefmt="%y-%m-%d %H:%M:%S") 29 | file_handler.setFormatter(file_formatter) 30 | file_handler.setLevel(logging.INFO) 31 | logger.addHandler(file_handler) 32 | 33 | 34 | def get_logger(): 35 | return logging.getLogger(_NAME) 36 | 37 | 38 | def iteration(summary, phase, global_step, epoch, num_epochs, step, num_steps, values, multiple_lines=False): 39 | logger = get_logger() 40 | 41 | # Build message and write summary 42 | msg = _current_total_formatter(epoch, num_epochs) + " " + _current_total_formatter(step, num_steps) 43 | for k, v in values.items(): 44 | if isinstance(v, AverageMeter): 45 | msg += "\n" if 
multiple_lines else "" + "\t{}={:.3f} ({:.3f})".format(k, v.value.item(), v.mean.item()) 46 | if summary is not None: 47 | summary.add_scalar("{}/{}".format(phase, k), v.value.item(), global_step) 48 | else: 49 | msg += "\n" if multiple_lines else "" + "\t{}={:.3f}".format(k, v) 50 | if summary is not None: 51 | summary.add_scalar("{}/{}".format(phase, k), v, global_step) 52 | 53 | # Write log 54 | logger.info(msg) 55 | -------------------------------------------------------------------------------- /grasp_det_seg/utils/meters.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | class Meter: 7 | def __init__(self): 8 | self._states = OrderedDict() 9 | 10 | def register_state(self, name, tensor): 11 | if name not in self._states and isinstance(tensor, torch.Tensor): 12 | self._states[name] = tensor 13 | 14 | def __getattr__(self, item): 15 | if "_states" in self.__dict__: 16 | _states = self.__dict__["_states"] 17 | if item in _states: 18 | return _states[item] 19 | return self.__dict__[item] 20 | 21 | def reset(self): 22 | for state in self._states.values(): 23 | state.zero_() 24 | 25 | def state_dict(self): 26 | return dict(self._states) 27 | 28 | def load_state_dict(self, state_dict): 29 | for k, v in state_dict.items(): 30 | if k in self._states: 31 | self._states[k].copy_(v) 32 | else: 33 | raise KeyError("Unexpected key {} in state dict when loading {} from state dict" 34 | .format(k, self.__class__.__name__)) 35 | 36 | 37 | class ConstantMeter(Meter): 38 | def __init__(self, shape): 39 | super(ConstantMeter, self).__init__() 40 | self.register_state("last", torch.zeros(shape, dtype=torch.float32)) 41 | 42 | def update(self, value): 43 | self.last.copy_(value) 44 | 45 | @property 46 | def value(self): 47 | return self.last 48 | 49 | 50 | class AverageMeter(ConstantMeter): 51 | def __init__(self, shape, momentum=1.): 52 | super(AverageMeter, self).__init__(shape) 53 | self.register_state("sum", torch.zeros(shape, dtype=torch.float32)) 54 | self.register_state("count", torch.tensor(0, dtype=torch.float32)) 55 | self.momentum = momentum 56 | 57 | def update(self, value): 58 | super(AverageMeter, self).update(value) 59 | self.sum.mul_(self.momentum).add_(value) 60 | self.count.mul_(self.momentum).add_(1.) 61 | 62 | @property 63 | def mean(self): 64 | if self.count.item() == 0: 65 | return torch.tensor(0.) 66 | else: 67 | return self.sum / self.count.clamp(min=1) 68 | 69 | 70 | class ConfusionMatrixMeter(AverageMeter): 71 | def __init__(self, num_classes, momentum=1.): 72 | super(ConfusionMatrixMeter, self).__init__((num_classes, num_classes), momentum) 73 | 74 | @property 75 | def iou(self): 76 | mean_conf = self.mean 77 | return mean_conf.diag() / (mean_conf.sum(dim=0) + mean_conf.sum(dim=1) - mean_conf.diag()) 78 | 79 | @property 80 | def precision(self): 81 | return self.mean.diag() * torch.clamp(1. / self.mean.sum(dim=0), max=1.) 82 | 83 | @property 84 | def recall(self): 85 | return self.mean.diag() * torch.clamp(1. / self.mean.sum(dim=1), max=1.) 
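# --- Illustrative usage sketch (not part of the original meters.py) ------------
# A minimal, hypothetical example of how the meters defined above are driven from
# a training loop; the loss values and the 3 x 3 confusion matrix are invented
# purely for demonstration.
#
# import torch
# from grasp_det_seg.utils.meters import AverageMeter, ConfusionMatrixMeter
#
# loss_meter = AverageMeter(shape=(), momentum=0.99)      # scalar running average
# for step in range(3):
#     loss = torch.tensor(0.5 / (step + 1))               # stand-in for a real loss
#     loss_meter.update(loss)
# print(loss_meter.value.item(), loss_meter.mean.item())  # last value, running mean
#
# conf_meter = ConfusionMatrixMeter(num_classes=3)
# conf_meter.update(torch.tensor([[5., 1., 0.],
#                                 [0., 7., 2.],
#                                 [1., 0., 4.]]))          # rows: ground truth, cols: prediction (assumed)
# print(conf_meter.iou)                                    # per-class IoU from the accumulated matrix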
86 | 87 | 88 | class PanopticMeter(AverageMeter): 89 | def panoptic(self): 90 | return None if self.sum is None else \ 91 | self.sum[0] / (self.sum[1] + 0.5 * self.sum[2] + 0.5 * self.sum[3]) 92 | 93 | @property 94 | def avg(self): 95 | panoptic = self.panoptic() 96 | return 0 if panoptic is None else panoptic.mean() 97 | -------------------------------------------------------------------------------- /grasp_det_seg/utils/misc.py: -------------------------------------------------------------------------------- 1 | import io 2 | from collections import OrderedDict 3 | from functools import partial 4 | 5 | import torch 6 | import torch.distributed as dist 7 | import torch.nn as nn 8 | from inplace_abn import InPlaceABN, InPlaceABNSync, ABN 9 | 10 | from grasp_det_seg.modules.misc import ActivatedAffine, ActivatedGroupNorm 11 | from . import scheduler as lr_scheduler 12 | 13 | NORM_LAYERS = [ABN, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm] 14 | OTHER_LAYERS = [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d] 15 | 16 | 17 | class Empty(Exception): 18 | """Exception to facilitate handling of empty predictions, annotations etc.""" 19 | pass 20 | 21 | 22 | def try_index(scalar_or_list, i): 23 | try: 24 | return scalar_or_list[i] 25 | except TypeError: 26 | return scalar_or_list 27 | 28 | 29 | def config_to_string(config): 30 | with io.StringIO() as sio: 31 | config.write(sio) 32 | config_str = sio.getvalue() 33 | return config_str 34 | 35 | 36 | def scheduler_from_config(scheduler_config, optimizer, epoch_length): 37 | assert scheduler_config["type"] in ("linear", "step", "poly", "multistep") 38 | 39 | params = scheduler_config.getstruct("params") 40 | 41 | if scheduler_config["type"] == "linear": 42 | if scheduler_config["update_mode"] == "batch": 43 | count = epoch_length * scheduler_config.getint("epochs") 44 | else: 45 | count = scheduler_config.getint("epochs") 46 | 47 | beta = float(params["from"]) 48 | alpha = float(params["to"] - beta) / count 49 | 50 | scheduler = lr_scheduler.LambdaLR(optimizer, lambda it: it * alpha + beta) 51 | elif scheduler_config["type"] == "step": 52 | scheduler = lr_scheduler.StepLR(optimizer, params["step_size"], params["gamma"]) 53 | elif scheduler_config["type"] == "poly": 54 | if scheduler_config["update_mode"] == "batch": 55 | count = epoch_length * scheduler_config.getint("epochs") 56 | else: 57 | count = scheduler_config.getint("epochs") 58 | scheduler = lr_scheduler.LambdaLR(optimizer, lambda it: (1 - float(it) / count) ** params["gamma"]) 59 | elif scheduler_config["type"] == "multistep": 60 | scheduler = lr_scheduler.MultiStepLR(optimizer, params["milestones"], params["gamma"]) 61 | else: 62 | raise ValueError("Unrecognized scheduler type {}, valid options: 'linear', 'step', 'poly', 'multistep'" 63 | .format(scheduler_config["type"])) 64 | 65 | if scheduler_config.getint("burn_in_steps") != 0: 66 | scheduler = lr_scheduler.BurnInLR(scheduler, 67 | scheduler_config.getint("burn_in_steps"), 68 | scheduler_config.getfloat("burn_in_start")) 69 | 70 | return scheduler 71 | 72 | 73 | def norm_act_from_config(body_config): 74 | """Make normalization + activation function from configuration 75 | 76 | Available normalization modes are: 77 | - `bn`: Standard In-Place Batch Normalization 78 | - `syncbn`: Synchronized In-Place Batch Normalization 79 | - `syncbn+bn`: Synchronized In-Place Batch Normalization in the "static" part of the network, Standard In-Place 80 | Batch Normalization in the 
"dynamic" parts 81 | - `gn`: Group Normalization 82 | - `syncbn+gn`: Synchronized In-Place Batch Normalization in the "static" part of the network, Group Normalization 83 | in the "dynamic" parts 84 | - `off`: No normalization (preserve scale and bias parameters) 85 | 86 | The "static" part of the network includes the backbone, FPN and semantic segmentation components, while the 87 | "dynamic" part of the network includes the RPN, detection and instance segmentation components. Note that this 88 | distinction is due to historical reasons and for back-compatibility with the CVPR2019 pre-trained models. 89 | 90 | Parameters 91 | ---------- 92 | body_config 93 | Configuration object containing the following fields: `normalization_mode`, `activation`, `activation_slope` 94 | and `gn_groups` 95 | 96 | Returns 97 | ------- 98 | norm_act_static : callable 99 | Function that returns norm_act modules for the static parts of the network 100 | norm_act_dynamic : callable 101 | Function that returns norm_act modules for the dynamic parts of the network 102 | """ 103 | mode = body_config["normalization_mode"] 104 | activation = body_config["activation"] 105 | slope = body_config.getfloat("activation_slope") 106 | groups = body_config.getint("gn_groups") 107 | 108 | if mode == "bn": 109 | norm_act_static = norm_act_dynamic = partial(InPlaceABN, activation=activation, activation_param=slope) 110 | elif mode == "syncbn": 111 | norm_act_static = norm_act_dynamic = partial(InPlaceABNSync, activation=activation, activation_param=slope) 112 | elif mode == "syncbn+bn": 113 | norm_act_static = partial(InPlaceABNSync, activation=activation, activation_param=slope) 114 | norm_act_dynamic = partial(InPlaceABN, activation=activation, activation_param=slope) 115 | elif mode == "gn": 116 | norm_act_static = norm_act_dynamic = partial( 117 | ActivatedGroupNorm, num_groups=groups, activation=activation, activation_param=slope) 118 | elif mode == "syncbn+gn": 119 | norm_act_static = partial(InPlaceABNSync, activation=activation, activation_param=slope) 120 | norm_act_dynamic = partial(ActivatedGroupNorm, num_groups=groups, activation=activation, activation_param=slope) 121 | elif mode == "off": 122 | norm_act_static = norm_act_dynamic = partial(ActivatedAffine, activation=activation, activation_param=slope) 123 | else: 124 | raise ValueError("Unrecognized normalization_mode {}, valid options: 'bn', 'syncbn', 'syncbn+bn', 'gn', " 125 | "'syncbn+gn', 'off'".format(mode)) 126 | 127 | return norm_act_static, norm_act_dynamic 128 | 129 | 130 | def freeze_params(module): 131 | """Freeze all parameters of the given module""" 132 | for p in module.parameters(): 133 | p.requires_grad_(False) 134 | 135 | 136 | def all_reduce_losses(losses): 137 | """Coalesced mean all reduce over a dictionary of 0-dimensional tensors""" 138 | names, values = [], [] 139 | for k, v in losses.items(): 140 | names.append(k) 141 | values.append(v) 142 | 143 | # Peform the actual coalesced all_reduce 144 | values = torch.cat([v.view(1) for v in values], dim=0) 145 | dist.all_reduce(values, dist.ReduceOp.SUM) 146 | values.div_(dist.get_world_size()) 147 | values = torch.chunk(values, values.size(0), dim=0) 148 | 149 | # Reconstruct the dictionary 150 | return OrderedDict((k, v.view(())) for k, v in zip(names, values)) 151 | -------------------------------------------------------------------------------- /grasp_det_seg/utils/nms/__init__.py: -------------------------------------------------------------------------------- 1 | from .nms import nms 2 | 
-------------------------------------------------------------------------------- /grasp_det_seg/utils/nms/_backend.pyi: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def nms(bbx: torch.Tensor, scores: torch.Tensor, threshold: float, n_max: int) -> torch.Tensor: ... 5 | -------------------------------------------------------------------------------- /grasp_det_seg/utils/nms/nms.py: -------------------------------------------------------------------------------- 1 | from . import _backend 2 | 3 | 4 | def nms(bbx, scores, threshold=0.5, n_max=-1): 5 | """Perform non-maxima suppression 6 | 7 | Select up to n_max bounding boxes from bbx, giving priorities to bounding boxes with greater scores. Each selected 8 | bounding box suppresses all other not yet selected boxes that intersect it by more than the given threshold. 9 | 10 | Parameters 11 | ---------- 12 | bbx : torch.Tensor 13 | A tensor of bounding boxes with shape N x 4 14 | scores : torch.Tensor 15 | A tensor of bounding box scores with shape N 16 | threshold : float 17 | The minimum iou value for a pair of bounding boxes to be considered a match 18 | n_max : int 19 | Maximum number of bounding boxes to select. If n_max <= 0, keep all surviving boxes 20 | 21 | Returns 22 | ------- 23 | selection : torch.Tensor 24 | A tensor with the indices of the selected boxes 25 | 26 | """ 27 | selection = _backend.nms(bbx, scores, threshold, n_max) 28 | return selection.to(device=bbx.device) 29 | -------------------------------------------------------------------------------- /grasp_det_seg/utils/parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_parallel import DistributedDataParallel 2 | from .packed_sequence import PackedSequence 3 | -------------------------------------------------------------------------------- /grasp_det_seg/utils/parallel/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-ainetter/grasp_det_seg_cnn/6ff96464f8906fb555d0a2f5a8b86c7f1330f108/grasp_det_seg/utils/parallel/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /grasp_det_seg/utils/parallel/__pycache__/data_parallel.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-ainetter/grasp_det_seg_cnn/6ff96464f8906fb555d0a2f5a8b86c7f1330f108/grasp_det_seg/utils/parallel/__pycache__/data_parallel.cpython-36.pyc -------------------------------------------------------------------------------- /grasp_det_seg/utils/parallel/__pycache__/packed_sequence.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-ainetter/grasp_det_seg_cnn/6ff96464f8906fb555d0a2f5a8b86c7f1330f108/grasp_det_seg/utils/parallel/__pycache__/packed_sequence.cpython-36.pyc -------------------------------------------------------------------------------- /grasp_det_seg/utils/parallel/__pycache__/scatter_gather.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-ainetter/grasp_det_seg_cnn/6ff96464f8906fb555d0a2f5a8b86c7f1330f108/grasp_det_seg/utils/parallel/__pycache__/scatter_gather.cpython-36.pyc 
-------------------------------------------------------------------------------- /grasp_det_seg/utils/parallel/data_parallel.py: -------------------------------------------------------------------------------- 1 | from torch.nn.parallel import DistributedDataParallel as TorchDistributedDataParallel 2 | 3 | from .scatter_gather import scatter_kwargs, gather 4 | 5 | 6 | class DistributedDataParallel(TorchDistributedDataParallel): 7 | """`nn.parallel.DistributedDataParallel` extension which can handle `PackedSequence`s""" 8 | 9 | def scatter(self, inputs, kwargs, device_ids): 10 | return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) 11 | 12 | def gather(self, outputs, output_device): 13 | return gather(outputs, output_device, dim=self.dim) 14 | -------------------------------------------------------------------------------- /grasp_det_seg/utils/parallel/packed_sequence.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def _all_same(lst): 5 | return not lst or lst.count(lst[0]) == len(lst) 6 | 7 | 8 | class PackedSequence: 9 | def __init__(self, *args): 10 | if len(args) == 1 and isinstance(args[0], list): 11 | tensors = args[0] 12 | else: 13 | tensors = args 14 | 15 | # Check if all input are tensors of the same type and device 16 | for tensor in tensors: 17 | if tensor is not None and not isinstance(tensor, torch.Tensor): 18 | raise TypeError("All args must be tensors") 19 | if not _all_same([tensor.dtype for tensor in tensors if tensor is not None]): 20 | raise TypeError("All tensors must have the same type") 21 | if not _all_same([tensor.device for tensor in tensors if tensor is not None]): 22 | raise TypeError("All tensors must reside on the same device") 23 | self._tensors = tensors 24 | 25 | # Check useful properties of the sequence 26 | self._compatible = _all_same([tensor.shape[1:] for tensor in self._tensors if tensor is not None]) 27 | self._all_none = all([tensor is None for tensor in self._tensors]) 28 | 29 | def __add__(self, other): 30 | if not isinstance(other, PackedSequence): 31 | raise TypeError("other must be a PackedSequence") 32 | return PackedSequence(self._tensors + other._tensors) 33 | 34 | def __iadd__(self, other): 35 | if not isinstance(other, PackedSequence): 36 | raise TypeError("other must be a PackedSequence") 37 | self._tensors += other._tensors 38 | return self 39 | 40 | def __len__(self): 41 | return self._tensors.__len__() 42 | 43 | def __getitem__(self, item): 44 | if isinstance(item, slice): 45 | return PackedSequence(*self._tensors.__getitem__(item)) 46 | else: 47 | return self._tensors.__getitem__(item) 48 | 49 | def __iter__(self): 50 | return self._tensors.__iter__() 51 | 52 | def cuda(self, device=None, non_blocking=False): 53 | self._tensors = [ 54 | tensor.cuda(device, non_blocking) if tensor is not None else None 55 | for tensor in self._tensors 56 | ] 57 | return self 58 | 59 | def cpu(self): 60 | self._tensors = [ 61 | tensor.cpu() if tensor is not None else None 62 | for tensor in self._tensors 63 | ] 64 | return self 65 | 66 | @property 67 | def all_none(self): 68 | return self._all_none 69 | 70 | @property 71 | def dtype(self): 72 | if self.all_none: 73 | return None 74 | return next(tensor.dtype for tensor in self._tensors if tensor is not None) 75 | 76 | @property 77 | def device(self): 78 | if self.all_none: 79 | return None 80 | return next(tensor.device for tensor in self._tensors if tensor is not None) 81 | 82 | @property 83 | def contiguous(self): 84 | if not 
self._compatible: 85 | raise ValueError("The tensors in the sequence are not compatible for contiguous view") 86 | if self.all_none: 87 | return None, None 88 | 89 | packed_tensors = [] 90 | packed_idx = [] 91 | for i, tensor in enumerate(self._tensors): 92 | if tensor is not None: 93 | packed_tensors.append(tensor) 94 | packed_idx.append(tensor.new_full((tensor.size(0),), i, dtype=torch.long)) 95 | 96 | return torch.cat(packed_tensors, dim=0), torch.cat(packed_idx, dim=0) 97 | -------------------------------------------------------------------------------- /grasp_det_seg/utils/parallel/scatter_gather.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.parallel._functions import Scatter, Gather 3 | 4 | from .packed_sequence import PackedSequence 5 | 6 | 7 | def scatter(inputs, target_gpus, dim=0): 8 | r""" 9 | Slices tensors into approximately equal chunks and 10 | distributes them across given GPUs. Duplicates 11 | references to objects that are not tensors. 12 | """ 13 | 14 | def scatter_map(obj): 15 | if isinstance(obj, torch.Tensor): 16 | return Scatter.apply(target_gpus, None, dim, obj) 17 | if isinstance(obj, tuple) and len(obj) > 0: 18 | return list(zip(*map(scatter_map, obj))) 19 | if isinstance(obj, list) and len(obj) > 0: 20 | return list(map(list, zip(*map(scatter_map, obj)))) 21 | if isinstance(obj, dict) and len(obj) > 0: 22 | return list(map(type(obj), zip(*map(scatter_map, obj.items())))) 23 | if isinstance(obj, PackedSequence): 24 | return packed_sequence_scatter(obj, target_gpus) 25 | return [obj for _ in target_gpus] 26 | 27 | # After scatter_map is called, a scatter_map cell will exist. This cell 28 | # has a reference to the actual function scatter_map, which has references 29 | # to a closure that has a reference to the scatter_map cell (because the 30 | # fn is recursive). To avoid this reference cycle, we set the function to 31 | # None, clearing the cell 32 | try: 33 | return scatter_map(inputs) 34 | finally: 35 | scatter_map = None 36 | 37 | 38 | def scatter_kwargs(inputs, kwargs, target_gpus, dim=0): 39 | r"""Scatter with support for kwargs dictionary""" 40 | inputs = scatter(inputs, target_gpus, dim) if inputs else [] 41 | kwargs = scatter(kwargs, target_gpus, dim) if kwargs else [] 42 | if len(inputs) < len(kwargs): 43 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) 44 | elif len(kwargs) < len(inputs): 45 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) 46 | inputs = tuple(inputs) 47 | kwargs = tuple(kwargs) 48 | return inputs, kwargs 49 | 50 | 51 | def gather(outputs, target_device, dim=0): 52 | r""" 53 | Gathers tensors from different GPUs on a specified device 54 | (-1 means the CPU). 55 | """ 56 | 57 | def gather_map(outputs): 58 | out = outputs[0] 59 | if isinstance(out, torch.Tensor): 60 | return Gather.apply(target_device, dim, *outputs) 61 | if out is None: 62 | return None 63 | if isinstance(out, dict): 64 | if not all((len(out) == len(d) for d in outputs)): 65 | raise ValueError('All dicts must have the same number of keys') 66 | return type(out)(((k, gather_map([d[k] for d in outputs])) 67 | for k in out)) 68 | if isinstance(out, PackedSequence): 69 | return packed_sequence_gather(outputs, target_device) 70 | return type(out)(map(gather_map, zip(*outputs))) 71 | 72 | # Recursive function calls like this create reference cycles. 73 | # Setting the function to None clears the refcycle. 
74 | try: 75 | return gather_map(outputs) 76 | finally: 77 | gather_map = None 78 | 79 | 80 | def packed_sequence_scatter(seq, target_gpus): 81 | # Find chunks 82 | k, m = divmod(len(seq), len(target_gpus)) 83 | limits = [(i * k + min(i, m), (i + 1) * k + min(i + 1, m)) for i in range(len(target_gpus))] 84 | outs = [] 85 | for device, (i, j) in zip(target_gpus, limits): 86 | outs.append(seq[i:j].cuda(device)) 87 | return outs 88 | 89 | 90 | def packed_sequence_gather(seqs, target_device): 91 | out = seqs[0].cuda(target_device) 92 | for i in range(1, len(seqs)): 93 | out += seqs[i].cuda(target_device) 94 | return out 95 | -------------------------------------------------------------------------------- /grasp_det_seg/utils/roi_sampling/__init__.py: -------------------------------------------------------------------------------- 1 | from .functions import roi_sampling 2 | -------------------------------------------------------------------------------- /grasp_det_seg/utils/roi_sampling/_backend.pyi: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | 5 | 6 | class PaddingMode: 7 | Zero = ... 8 | Border = ... 9 | 10 | 11 | class Interpolation: 12 | Bilinear = ... 13 | Nearest = ... 14 | 15 | 16 | def roi_sampling_forward( 17 | x: torch.Tensor, bbx: torch.Tensor, idx: torch.Tensor, out_size: Tuple[int, int], 18 | interpolation: Interpolation, padding: PaddingMode, valid_mask: bool) -> Tuple[torch.Tensor, torch.Tensor]: ... 19 | 20 | 21 | def roi_sampling_backward( 22 | dy: torch.Tensor, bbx: torch.Tensor, idx: torch.Tensor, in_size: Tuple[int, int, int], 23 | interpolation: Interpolation, padding: PaddingMode) -> torch.Tensor: ... 24 | -------------------------------------------------------------------------------- /grasp_det_seg/utils/roi_sampling/functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd as autograd 3 | from torch.autograd.function import once_differentiable 4 | 5 | from . 
import _backend 6 | 7 | _INTERPOLATION = {"bilinear": _backend.Interpolation.Bilinear, "nearest": _backend.Interpolation.Nearest} 8 | _PADDING = {"zero": _backend.PaddingMode.Zero, "border": _backend.PaddingMode.Border} 9 | 10 | 11 | class ROISampling(autograd.Function): 12 | @staticmethod 13 | def forward(ctx, x, bbx, idx, roi_size, interpolation, padding, valid_mask): 14 | ctx.save_for_backward(bbx, idx) 15 | ctx.input_shape = (x.size(0), x.size(2), x.size(3)) 16 | ctx.valid_mask = valid_mask 17 | 18 | try: 19 | ctx.interpolation = _INTERPOLATION[interpolation] 20 | except KeyError: 21 | raise ValueError("Unknown interpolation {}".format(interpolation)) 22 | try: 23 | ctx.padding = _PADDING[padding] 24 | except KeyError: 25 | raise ValueError("Unknown padding {}".format(padding)) 26 | 27 | y, mask = _backend.roi_sampling_forward(x, bbx, idx, roi_size, ctx.interpolation, ctx.padding, valid_mask) 28 | 29 | if not torch.is_floating_point(x): 30 | ctx.mark_non_differentiable(y) 31 | if valid_mask: 32 | ctx.mark_non_differentiable(mask) 33 | return y, mask 34 | else: 35 | return y 36 | 37 | @staticmethod 38 | @once_differentiable 39 | def backward(ctx, *args): 40 | if ctx.valid_mask: 41 | dy, _ = args 42 | else: 43 | dy = args[0] 44 | 45 | assert torch.is_floating_point(dy), "ROISampling.backward is only defined for floating point types" 46 | bbx, idx = ctx.saved_tensors 47 | 48 | dx = _backend.roi_sampling_backward(dy, bbx, idx, ctx.input_shape, ctx.interpolation, ctx.padding) 49 | return dx, None, None, None, None, None, None 50 | 51 | 52 | def roi_sampling(x, bbx, idx, roi_size, interpolation="bilinear", padding="border", valid_mask=False): 53 | """Sample ROIs from a batch of images using bi-linear interpolation 54 | 55 | ROIs are sampled from the input by bi-linear interpolation, using the following equations to transform from 56 | ROI coordinates to image coordinates: 57 | 58 | y_img = y0 + y_roi / h_roi * (y1 - y0), for y_roi in range(0, h_roi) 59 | x_img = x0 + x_roi / w_roi * (x1 - x0), for x_roi in range(0, w_roi) 60 | 61 | where `(h_roi, w_roi)` is the shape of the ROI and `(y0, x0, y1, x1)` are its bounding box coordinates on the image 62 | 63 | Parameters 64 | ---------- 65 | x : torch.Tensor 66 | A tensor with shape N x C x H x W containing a batch of images to sample from 67 | bbx : torch.Tensor 68 | A tensor with shape K x 4 containing the bounding box coordinates of the ROIs in "corners" format 69 | idx : torch.Tensor 70 | A tensor with shape K containing the batch indices of the image each ROI should be sampled from 71 | roi_size : tuple of int 72 | The size `(h_roi, w_roi)` of the output ROIs 73 | interpolation : str 74 | Sampling mode, one of "bilinear" or "nearest" 75 | padding : str 76 | Padding mode, one of "border" or "zero" 77 | valid_mask : bool 78 | If `True` also return a mask tensor that indicates which points of the outputs where sampled from within the 79 | valid region of the input 80 | 81 | Returns 82 | ------- 83 | y : torch.Tensor 84 | A tensor with shape K x C x h_roi x w_roi containing the sampled ROIs 85 | mask : torch.Tensor 86 | Optional output returned only when valid_mask is `True`: a mask tensor with shape K x h_roi x w_roi, whose 87 | entries are `!= 0` where the corresponding location in `y` was sampled from within the limits of the input image 88 | """ 89 | return ROISampling.apply(x, bbx, idx, roi_size, interpolation, padding, valid_mask) 90 | 91 | 92 | __all__ = ["roi_sampling"] 93 | 
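# --- Illustrative usage sketch (not part of the original functions.py) ---------
# A minimal, hypothetical example of the `roi_sampling` function defined above,
# assuming the compiled `_backend` extension is available. Two ROIs are cropped
# from a single 1 x 3 x 64 x 64 feature map and resampled to a fixed 7 x 7 grid;
# `bbx` is in "corners" form (y0, x0, y1, x1) and `idx` maps each ROI to its
# image in the batch.
#
# import torch
# from grasp_det_seg.utils.roi_sampling import roi_sampling
#
# x = torch.randn(1, 3, 64, 64)                       # N x C x H x W feature map
# bbx = torch.tensor([[ 0.,  0., 32., 32.],
#                     [16., 16., 48., 48.]])          # K x 4 ROIs, corners form
# idx = torch.zeros(2, dtype=torch.long)              # both ROIs come from image 0
# rois = roi_sampling(x, bbx, idx, roi_size=(7, 7))   # -> K x C x 7 x 7 tensor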
-------------------------------------------------------------------------------- /grasp_det_seg/utils/scheduler.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains copies of the main LR schedulers from Pytorch 1.0, as well as some additional schedulers 3 | and utility code. This is mostly intended as a work-around for the bugs and general issues introduced in Pytorch 1.1 4 | and should be reworked as soon as a proper (and stable) scheduler interface is introduced in Pytorch. 5 | """ 6 | import types 7 | from bisect import bisect_right 8 | 9 | from torch.optim import Optimizer 10 | 11 | 12 | class _LRScheduler(object): 13 | def __init__(self, optimizer, last_epoch=-1): 14 | if not isinstance(optimizer, Optimizer): 15 | raise TypeError('{} is not an Optimizer'.format( 16 | type(optimizer).__name__)) 17 | self.optimizer = optimizer 18 | if last_epoch == -1: 19 | for group in optimizer.param_groups: 20 | group.setdefault('initial_lr', group['lr']) 21 | else: 22 | for i, group in enumerate(optimizer.param_groups): 23 | if 'initial_lr' not in group: 24 | raise KeyError("param 'initial_lr' is not specified " 25 | "in param_groups[{}] when resuming an optimizer".format(i)) 26 | self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) 27 | self.step(last_epoch + 1) 28 | self.last_epoch = last_epoch 29 | 30 | def state_dict(self): 31 | """Returns the state of the scheduler as a :class:`dict`. 32 | It contains an entry for every variable in self.__dict__ which 33 | is not the optimizer. 34 | """ 35 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 36 | 37 | def load_state_dict(self, state_dict): 38 | """Loads the schedulers state. 39 | Arguments: 40 | state_dict (dict): scheduler state. Should be an object returned 41 | from a call to :meth:`state_dict`. 42 | """ 43 | self.__dict__.update(state_dict) 44 | 45 | def get_lr(self): 46 | raise NotImplementedError 47 | 48 | def step(self, epoch=None): 49 | if epoch is None: 50 | epoch = self.last_epoch + 1 51 | self.last_epoch = epoch 52 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 53 | param_group['lr'] = lr 54 | 55 | 56 | class LambdaLR(_LRScheduler): 57 | """Sets the learning rate of each parameter group to the initial lr 58 | times a given function. When last_epoch=-1, sets initial lr as lr. 59 | Args: 60 | optimizer (Optimizer): Wrapped optimizer. 61 | lr_lambda (function or list): A function which computes a multiplicative 62 | factor given an integer parameter epoch, or a list of such 63 | functions, one for each group in optimizer.param_groups. 64 | last_epoch (int): The index of last epoch. Default: -1. 65 | Example: 66 | >>> # Assuming optimizer has two groups. 67 | >>> lambda1 = lambda epoch: epoch // 30 68 | >>> lambda2 = lambda epoch: 0.95 ** epoch 69 | >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2]) 70 | >>> for epoch in range(100): 71 | >>> scheduler.step() 72 | >>> train(...) 73 | >>> validate(...) 
74 | """ 75 | 76 | def __init__(self, optimizer, lr_lambda, last_epoch=-1): 77 | self.optimizer = optimizer 78 | if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple): 79 | self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) 80 | else: 81 | if len(lr_lambda) != len(optimizer.param_groups): 82 | raise ValueError("Expected {} lr_lambdas, but got {}".format( 83 | len(optimizer.param_groups), len(lr_lambda))) 84 | self.lr_lambdas = list(lr_lambda) 85 | self.last_epoch = last_epoch 86 | super(LambdaLR, self).__init__(optimizer, last_epoch) 87 | 88 | def state_dict(self): 89 | """Returns the state of the scheduler as a :class:`dict`. 90 | It contains an entry for every variable in self.__dict__ which 91 | is not the optimizer. 92 | The learning rate lambda functions will only be saved if they are callable objects 93 | and not if they are functions or lambdas. 94 | """ 95 | state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')} 96 | state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas) 97 | 98 | for idx, fn in enumerate(self.lr_lambdas): 99 | if not isinstance(fn, types.FunctionType): 100 | state_dict['lr_lambdas'][idx] = fn.__dict__.copy() 101 | 102 | return state_dict 103 | 104 | def load_state_dict(self, state_dict): 105 | """Loads the schedulers state. 106 | Arguments: 107 | state_dict (dict): scheduler state. Should be an object returned 108 | from a call to :meth:`state_dict`. 109 | """ 110 | lr_lambdas = state_dict.pop('lr_lambdas') 111 | self.__dict__.update(state_dict) 112 | 113 | for idx, fn in enumerate(lr_lambdas): 114 | if fn is not None: 115 | self.lr_lambdas[idx].__dict__.update(fn) 116 | 117 | def get_lr(self): 118 | return [base_lr * lmbda(self.last_epoch) 119 | for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)] 120 | 121 | 122 | class StepLR(_LRScheduler): 123 | """Sets the learning rate of each parameter group to the initial lr 124 | decayed by gamma every step_size epochs. When last_epoch=-1, sets 125 | initial lr as lr. 126 | Args: 127 | optimizer (Optimizer): Wrapped optimizer. 128 | step_size (int): Period of learning rate decay. 129 | gamma (float): Multiplicative factor of learning rate decay. 130 | Default: 0.1. 131 | last_epoch (int): The index of last epoch. Default: -1. 132 | Example: 133 | >>> # Assuming optimizer uses lr = 0.05 for all groups 134 | >>> # lr = 0.05 if epoch < 30 135 | >>> # lr = 0.005 if 30 <= epoch < 60 136 | >>> # lr = 0.0005 if 60 <= epoch < 90 137 | >>> # ... 138 | >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1) 139 | >>> for epoch in range(100): 140 | >>> scheduler.step() 141 | >>> train(...) 142 | >>> validate(...) 143 | """ 144 | 145 | def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1): 146 | self.step_size = step_size 147 | self.gamma = gamma 148 | super(StepLR, self).__init__(optimizer, last_epoch) 149 | 150 | def get_lr(self): 151 | return [base_lr * self.gamma ** (self.last_epoch // self.step_size) 152 | for base_lr in self.base_lrs] 153 | 154 | 155 | class MultiStepLR(_LRScheduler): 156 | """Set the learning rate of each parameter group to the initial lr decayed 157 | by gamma once the number of epoch reaches one of the milestones. When 158 | last_epoch=-1, sets initial lr as lr. 159 | Args: 160 | optimizer (Optimizer): Wrapped optimizer. 161 | milestones (list): List of epoch indices. Must be increasing. 162 | gamma (float): Multiplicative factor of learning rate decay. 163 | Default: 0.1. 
164 | last_epoch (int): The index of last epoch. Default: -1. 165 | Example: 166 | >>> # Assuming optimizer uses lr = 0.05 for all groups 167 | >>> # lr = 0.05 if epoch < 30 168 | >>> # lr = 0.005 if 30 <= epoch < 80 169 | >>> # lr = 0.0005 if epoch >= 80 170 | >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1) 171 | >>> for epoch in range(100): 172 | >>> scheduler.step() 173 | >>> train(...) 174 | >>> validate(...) 175 | """ 176 | 177 | def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1): 178 | if not list(milestones) == sorted(milestones): 179 | raise ValueError('Milestones should be a list of' 180 | ' increasing integers. Got {}', milestones) 181 | self.milestones = milestones 182 | self.gamma = gamma 183 | super(MultiStepLR, self).__init__(optimizer, last_epoch) 184 | 185 | def get_lr(self): 186 | return [base_lr * self.gamma ** bisect_right(self.milestones, self.last_epoch) 187 | for base_lr in self.base_lrs] 188 | 189 | 190 | class BurnInLR(_LRScheduler): 191 | def __init__(self, base, steps, start): 192 | self.base = base 193 | self.steps = steps 194 | self.start = start 195 | super(BurnInLR, self).__init__(base.optimizer, base.last_epoch) 196 | 197 | def step(self, epoch=None): 198 | super(BurnInLR, self).step(epoch) 199 | 200 | # Also update epoch for the wrapped scheduler 201 | if epoch is None: 202 | epoch = self.base.last_epoch + 1 203 | self.base.last_epoch = epoch 204 | 205 | def get_lr(self): 206 | beta = self.start 207 | alpha = (1. - beta) / self.steps 208 | if self.last_epoch <= self.steps: 209 | return [base_lr * (self.last_epoch * alpha + beta) for base_lr in self.base_lrs] 210 | else: 211 | return self.base.get_lr() 212 | -------------------------------------------------------------------------------- /grasp_det_seg/utils/sequence.py: -------------------------------------------------------------------------------- 1 | from .parallel import PackedSequence 2 | 3 | 4 | def pad_packed_images(packed_images, pad_value=0., snap_size_to=None): 5 | """Assemble a padded tensor for a `PackedSequence` of images with different spatial sizes 6 | 7 | This method allows any standard convnet to operate on a `PackedSequence` of images as a batch 8 | 9 | Parameters 10 | ---------- 11 | packed_images : PackedSequence 12 | A PackedSequence containing N tensors with different spatial sizes H_i, W_i. The tensors can be either 2D or 3D. 13 | If they are 3D, they must all have the same number of channels C. 
14 | pad_value : float or int 15 | Value used to fill the padded areas 16 | snap_size_to : int or None 17 | If not None, chose the spatial sizes of the padded tensor to be multiples of this 18 | 19 | Returns 20 | ------- 21 | padded_images : torch.Tensor 22 | A tensor with shape N x C x H x W or N x H x W, where `H = max_i H_i` and `W = max_i W_i` containing the images 23 | of the sequence aligned to the top left corner and padded with `pad_value` 24 | sizes : list of tuple of int 25 | A list with the original spatial sizes of the input images 26 | """ 27 | if packed_images.all_none: 28 | raise ValueError("at least one image in packed_images should be non-None") 29 | 30 | reference_img = next(img for img in packed_images if img is not None) 31 | max_size = reference_img.shape[-2:] 32 | ndims = len(reference_img.shape) 33 | chn = reference_img.shape[0] if ndims == 3 else 0 34 | 35 | # Check the shapes and find maximum spatial size 36 | for img in packed_images: 37 | if img is not None: 38 | if len(img.shape) != 3 and len(img.shape) != 2: 39 | raise ValueError("The input sequence must contain 2D or 3D tensors") 40 | if len(img.shape) != ndims: 41 | raise ValueError("All tensors in the input sequence must have the same number of dimensions") 42 | if ndims == 3 and img.shape[0] != chn: 43 | raise ValueError("3D tensors must all have the same number of channels") 44 | max_size = [max(s1, s2) for s1, s2 in zip(max_size, img.shape[-2:])] 45 | 46 | # Optional size snapping 47 | if snap_size_to is not None: 48 | max_size = [(s + snap_size_to - 1) // snap_size_to * snap_size_to for s in max_size] 49 | 50 | if ndims == 3: 51 | padded_images = reference_img.new_full([len(packed_images), chn] + max_size, pad_value) 52 | else: 53 | padded_images = reference_img.new_full([len(packed_images)] + max_size, pad_value) 54 | 55 | sizes = [] 56 | for i, tensor in enumerate(packed_images): 57 | if tensor is not None: 58 | if ndims == 3: 59 | padded_images[i, :, :tensor.shape[1], :tensor.shape[2]] = tensor 60 | sizes.append(tensor.shape[1:]) 61 | else: 62 | padded_images[i, :tensor.shape[0], :tensor.shape[1]] = tensor 63 | sizes.append(tensor.shape) 64 | else: 65 | sizes.append((0, 0)) 66 | 67 | return padded_images, sizes 68 | 69 | 70 | def pack_padded_images(padded_images, sizes): 71 | """Inverse function of `pad_packed_images`, refer to that for details""" 72 | images = [] 73 | for img, size in zip(padded_images, sizes): 74 | if img.dim() == 2: 75 | images.append(img[:int(size[0]), :int(size[1])]) 76 | else: 77 | images.append(img[:, :int(size[0]), :int(size[1])]) 78 | 79 | return PackedSequence([img.contiguous() for img in images]) 80 | -------------------------------------------------------------------------------- /grasp_det_seg/utils/snapshot.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .misc import config_to_string 4 | 5 | 6 | def save_snapshot(file, config, epoch, last_score, best_score, global_step, **kwargs): 7 | data = { 8 | "config": config_to_string(config), 9 | "state_dict": dict(kwargs), 10 | "training_meta": { 11 | "epoch": epoch, 12 | "last_score": last_score, 13 | "best_score": best_score, 14 | "global_step": global_step 15 | } 16 | } 17 | torch.save(data, file) 18 | 19 | 20 | def pre_train_from_snapshots(model, snapshots, modules): 21 | for snapshot in snapshots: 22 | if ":" in snapshot: 23 | module_name, snapshot = snapshot.split(":") 24 | else: 25 | module_name = None 26 | 27 | snapshot = torch.load(snapshot, 
map_location="cpu") 28 | state_dict = snapshot["state_dict"] 29 | 30 | if module_name is None: 31 | for module_name in modules: 32 | if module_name in state_dict: 33 | _load_pretraining_dict(getattr(model, module_name), state_dict[module_name]) 34 | else: 35 | if module_name in modules: 36 | _load_pretraining_dict(getattr(model, module_name), state_dict[module_name]) 37 | else: 38 | raise ValueError("Unrecognized network module {}".format(module_name)) 39 | 40 | 41 | def resume_from_snapshot(model, snapshot, modules): 42 | snapshot = torch.load(snapshot, map_location="cpu") 43 | state_dict = snapshot["state_dict"] 44 | 45 | for module in modules: 46 | if module in state_dict: 47 | _load_pretraining_dict(getattr(model, module), state_dict[module]) 48 | else: 49 | raise KeyError("The given snapshot does not contain a state_dict for module '{}'".format(module)) 50 | 51 | return snapshot 52 | 53 | 54 | def _load_pretraining_dict(model, state_dict): 55 | """Load state dictionary from a pre-training snapshot 56 | 57 | This is an even less strict version of `model.load_state_dict(..., False)`, which also ignores parameters from 58 | `state_dict` that don't have the same shapes as the corresponding ones in `model`. This is useful when loading 59 | from pre-trained models that are trained on different datasets. 60 | 61 | Parameters 62 | ---------- 63 | model : torch.nn.Model 64 | Target model 65 | state_dict : dict 66 | Dictionary of model parameters 67 | """ 68 | model_sd = model.state_dict() 69 | 70 | for k, v in model_sd.items(): 71 | if k in state_dict: 72 | if v.shape != state_dict[k].shape: 73 | del state_dict[k] 74 | 75 | model.load_state_dict(state_dict, False) 76 | -------------------------------------------------------------------------------- /include/bbx.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | at::Tensor mask_count_cpu(const at::Tensor& bbx, const at::Tensor& int_mask); 6 | at::Tensor mask_count_cuda(const at::Tensor& bbx, const at::Tensor& int_mask); 7 | -------------------------------------------------------------------------------- /include/nms.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | const int64_t THREADS_PER_BLOCK = sizeof(int64_t) * 8; 6 | 7 | at::Tensor comp_mat_cpu(const at::Tensor& bbx, float threshold); 8 | at::Tensor comp_mat_cuda(const at::Tensor& bbx, float threshold); 9 | 10 | at::Tensor nms_cpu(const at::Tensor& comp_mat, const at::Tensor& scores, int n_max); -------------------------------------------------------------------------------- /include/roi_sampling.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include 7 | 8 | #include "utils/common.h" 9 | 10 | // ENUMS 11 | 12 | enum class PaddingMode { Zero, Border }; 13 | enum class Interpolation { Bilinear, Nearest }; 14 | 15 | // PROTOTYPES 16 | 17 | std::tuple roi_sampling_forward_cpu( 18 | const at::Tensor& x, const at::Tensor& bbx, const at::Tensor& idx, std::tuple out_size, 19 | Interpolation interpolation, PaddingMode padding, bool valid_mask); 20 | std::tuple roi_sampling_forward_cuda( 21 | const at::Tensor& x, const at::Tensor& bbx, const at::Tensor& idx, std::tuple out_size, 22 | Interpolation interpolation, PaddingMode padding, bool valid_mask); 23 | 24 | at::Tensor roi_sampling_backward_cpu( 25 | const at::Tensor& dy, const at::Tensor& bbx, const 
at::Tensor& idx, std::tuple in_size, 26 | Interpolation interpolation, PaddingMode padding); 27 | at::Tensor roi_sampling_backward_cuda( 28 | const at::Tensor& dy, const at::Tensor& bbx, const at::Tensor& idx, std::tuple in_size, 29 | Interpolation interpolation, PaddingMode padding); 30 | 31 | /* CONVENTIONS 32 | * 33 | * Integer indexes are i (vertical), j (horizontal) and k (generic) 34 | * Continuous coordinates are y (vertical), x (horizontal) and s (generic) 35 | * 36 | * The relation between the two is: y = i + 0.5, x = j + 0.5 37 | */ 38 | 39 | // SAMPLER 40 | 41 | template 42 | struct Sampler { 43 | Sampler(Indexer indexer, Interpolator interpolator) : _indexer(indexer), _interpolator(interpolator) {} 44 | 45 | template 46 | HOST_DEVICE scalar_t forward(coord_t y, coord_t x, Accessor accessor) const { 47 | // Step 1: find the four indices of the points to read from the input and their offsets 48 | index_t i_l, i_h, j_l, j_h; 49 | coord_t delta_y, delta_x; 50 | _neighbors(y, i_l, i_h, delta_y); 51 | _neighbors(x, j_l, j_h, delta_x); 52 | 53 | // Step 2: read the four points 54 | scalar_t p_ll = _indexer.get(accessor, i_l, j_l), 55 | p_lh = _indexer.get(accessor, i_l, j_h), 56 | p_hl = _indexer.get(accessor, i_h, j_l), 57 | p_hh = _indexer.get(accessor, i_h, j_h); 58 | 59 | // Step 3: get the interpolated value 60 | return _interpolator.get(delta_y, delta_x, p_ll, p_lh, p_hl, p_hh); 61 | } 62 | 63 | template 64 | HOST_DEVICE void backward(coord_t y, coord_t x, scalar_t grad, Accessor accessor) const { 65 | // Step 1: find the four indices of the points to read from the input and their offsets 66 | index_t i_l, i_h, j_l, j_h; 67 | coord_t delta_y, delta_x; 68 | _neighbors(y, i_l, i_h, delta_y); 69 | _neighbors(x, j_l, j_h, delta_x); 70 | 71 | // Step 2: reverse-interpolation 72 | scalar_t p_ll, p_lh, p_hl, p_hh; 73 | _interpolator.set(delta_y, delta_x, grad, p_ll, p_lh, p_hl, p_hh); 74 | 75 | // Step 3: accumulate 76 | _indexer.set(accessor, i_l, j_l, p_ll); 77 | _indexer.set(accessor, i_l, j_h, p_lh); 78 | _indexer.set(accessor, i_h, j_l, p_hl); 79 | _indexer.set(accessor, i_h, j_h, p_hh); 80 | } 81 | 82 | private: 83 | INLINE_HOST_DEVICE void _neighbors(coord_t s, index_t &k_l, index_t &k_h, coord_t &delta) const { 84 | k_l = static_cast(FLOOR(s - 0.5)); 85 | k_h = k_l + 1; 86 | delta = s - (static_cast(k_l) + 0.5); 87 | } 88 | 89 | private: 90 | Indexer _indexer; 91 | Interpolator _interpolator; 92 | }; 93 | 94 | // INDEXER 95 | 96 | template 97 | struct IndexerBase { 98 | IndexerBase(index_t height, index_t width) : _height(height), _width(width) {}; 99 | 100 | index_t _height; 101 | index_t _width; 102 | }; 103 | 104 | template 105 | struct Indexer; 106 | 107 | template 108 | struct Indexer : IndexerBase { 109 | using IndexerBase::IndexerBase; 110 | 111 | template 112 | INLINE_HOST_DEVICE scalar_t get(Accessor accessor, index_t i, index_t j) const { 113 | return _in_bounds(i, this->_height) && _in_bounds(j, this->_width) ? 
accessor[i][j] : 0; 114 | } 115 | 116 | template 117 | INLINE_HOST_DEVICE void set(Accessor accessor, index_t i, index_t j, scalar_t value) const { 118 | if (_in_bounds(i, this->_height) && _in_bounds(j, this->_width)) { 119 | ACCUM_BLOCK(accessor[i][j], value); 120 | } 121 | } 122 | 123 | private: 124 | INLINE_HOST_DEVICE bool _in_bounds(index_t k, index_t size) const { 125 | return k >= 0 && k < size; 126 | } 127 | }; 128 | 129 | template 130 | struct Indexer : IndexerBase { 131 | using IndexerBase::IndexerBase; 132 | 133 | template 134 | INLINE_HOST_DEVICE scalar_t get(Accessor accessor, index_t i, index_t j) const { 135 | _clamp(i, j); 136 | return accessor[i][j]; 137 | } 138 | 139 | template 140 | INLINE_HOST_DEVICE void set(Accessor accessor, index_t i, index_t j, scalar_t value) const { 141 | _clamp(i, j); 142 | ACCUM_BLOCK(accessor[i][j], value); 143 | } 144 | 145 | private: 146 | INLINE_HOST_DEVICE void _clamp(index_t &i, index_t &j) const { 147 | i = i >= 0 ? i : 0; 148 | i = i < this->_height ? i : this->_height - 1; 149 | j = j >= 0 ? j : 0; 150 | j = j < this->_width ? j : this->_width - 1; 151 | } 152 | }; 153 | 154 | // INTERPOLATORS 155 | 156 | template 157 | struct Interpolator; 158 | 159 | template 160 | struct Interpolator { 161 | INLINE_HOST_DEVICE scalar_t get( 162 | coord_t delta_y, coord_t delta_x, scalar_t p_ll, scalar_t p_lh, scalar_t p_hl, scalar_t p_hh) const { 163 | scalar_t hor_int_l = (1 - delta_x) * p_ll + delta_x * p_lh; 164 | scalar_t hor_int_h = (1 - delta_x) * p_hl + delta_x * p_hh; 165 | return (1 - delta_y) * hor_int_l + delta_y * hor_int_h; 166 | } 167 | 168 | INLINE_HOST_DEVICE void set( 169 | coord_t delta_y, coord_t delta_x, scalar_t value, 170 | scalar_t &p_ll, scalar_t &p_lh, scalar_t &p_hl, scalar_t &p_hh) const { 171 | p_ll = (1 - delta_x) * (1 - delta_y) * value; 172 | p_lh = delta_x * (1 - delta_y) * value; 173 | p_hl = (1 - delta_x) * delta_y * value; 174 | p_hh = delta_x * delta_y * value; 175 | } 176 | }; 177 | 178 | template 179 | struct Interpolator { 180 | INLINE_HOST_DEVICE scalar_t get( 181 | coord_t delta_y, coord_t delta_x, scalar_t p_ll, scalar_t p_lh, scalar_t p_hl, scalar_t p_hh) const { 182 | return p_ll * static_cast(delta_y < 0.5 && delta_x < 0.5) + 183 | p_lh * static_cast(delta_y < 0.5 && delta_x >= 0.5) + 184 | p_hl * static_cast(delta_y >= 0.5 && delta_x < 0.5) + 185 | p_hh * static_cast(delta_y >= 0.5 && delta_x >= 0.5); 186 | } 187 | 188 | INLINE_HOST_DEVICE void set( 189 | coord_t delta_y, coord_t delta_x, scalar_t value, 190 | scalar_t &p_ll, scalar_t &p_lh, scalar_t &p_hl, scalar_t &p_hh) const { 191 | p_ll = static_cast(delta_y < 0.5 && delta_x < 0.5) * value; 192 | p_lh = static_cast(delta_y < 0.5 && delta_x >= 0.5) * value; 193 | p_hl = static_cast(delta_y >= 0.5 && delta_x < 0.5) * value; 194 | p_hh = static_cast(delta_y >= 0.5 && delta_x >= 0.5) * value; 195 | } 196 | }; 197 | 198 | // UTILITY FUNCTIONS AND MACROS 199 | 200 | template 201 | INLINE_HOST_DEVICE coord_t roi_to_img(coord_t s_roi, coord_t s0_img, coord_t s1_img, coord_t roi_size) { 202 | return s_roi / roi_size * (s1_img - s0_img) + s0_img; 203 | } 204 | 205 | template 206 | INLINE_HOST_DEVICE coord_t img_to_img(coord_t s, coord_t size_in, coord_t size_out) { 207 | return s / size_in * size_out; 208 | } 209 | 210 | #define INTERPOLATION_PADDING_DEFINES(INTERPOLATION, PADDING) \ 211 | using indexer_t = Indexer; \ 212 | using interpolator_t = Interpolator; \ 213 | using sampler_t = Sampler; 214 | 215 | #define 
DISPATCH_INTERPOLATION_PADDING_MODES(INTERPOLATION, PADDING, ...) \ 216 | [&] { \ 217 | switch (INTERPOLATION) { \ 218 | case Interpolation::Bilinear: \ 219 | AT_CHECK(!std::is_integral::value, \ 220 | "Bilinear interpolation is not available for integral types"); \ 221 | switch (PADDING) { \ 222 | case PaddingMode::Zero: { \ 223 | INTERPOLATION_PADDING_DEFINES(Interpolation::Bilinear, PaddingMode::Zero) \ 224 | return __VA_ARGS__(); \ 225 | } \ 226 | case PaddingMode::Border: { \ 227 | INTERPOLATION_PADDING_DEFINES(Interpolation::Bilinear, PaddingMode::Border)\ 228 | return __VA_ARGS__(); \ 229 | }} \ 230 | case Interpolation::Nearest: \ 231 | switch (PADDING) { \ 232 | case PaddingMode::Zero: { \ 233 | INTERPOLATION_PADDING_DEFINES(Interpolation::Nearest, PaddingMode::Zero) \ 234 | return __VA_ARGS__(); \ 235 | } \ 236 | case PaddingMode::Border: { \ 237 | INTERPOLATION_PADDING_DEFINES(Interpolation::Nearest, PaddingMode::Border) \ 238 | return __VA_ARGS__(); \ 239 | }} \ 240 | } \ 241 | }() 242 | -------------------------------------------------------------------------------- /include/utils/checks.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | // Define AT_CHECK for old version of ATen where the same function was called AT_ASSERT 6 | #ifndef AT_CHECK 7 | #define AT_CHECK AT_ASSERT 8 | #endif 9 | 10 | #define CHECK_CUDA(x) AT_CHECK((x).type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CPU(x) AT_CHECK(!(x).type().is_cuda(), #x " must be a CPU tensor") 12 | #define CHECK_CONTIGUOUS(x) AT_CHECK((x).is_contiguous(), #x " must be contiguous") 13 | 14 | #define CHECK_CUDA_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 15 | #define CHECK_CPU_INPUT(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x) -------------------------------------------------------------------------------- /include/utils/common.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | /* 7 | * Functions to share code between CPU and GPU 8 | */ 9 | 10 | #ifdef __CUDACC__ 11 | // CUDA versions 12 | 13 | #define HOST_DEVICE __host__ __device__ 14 | #define INLINE_HOST_DEVICE __host__ __device__ inline 15 | #define FLOOR(x) floor(x) 16 | 17 | #if __CUDA_ARCH__ >= 600 18 | // Recent compute capabilities have both grid-level and block-level atomicAdd for all data types, so we use those 19 | #define ACCUM_BLOCK(x,y) atomicAdd_block(&(x),(y)) 20 | #define ACCUM(x, y) atomicAdd(&(x),(y)) 21 | #else 22 | // Older architectures don't have block-level atomicAdd, nor atomicAdd for doubles, so we defer to atomicAdd for float 23 | // and use the known atomicCAS-based implementation for double 24 | template 25 | __device__ inline data_t atomic_add(data_t *address, data_t val) { 26 | return atomicAdd(address, val); 27 | } 28 | 29 | template<> 30 | __device__ inline double atomic_add(double *address, double val) { 31 | unsigned long long int* address_as_ull = (unsigned long long int*)address; 32 | unsigned long long int old = *address_as_ull, assumed; 33 | do { 34 | assumed = old; 35 | old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed))); 36 | } while (assumed != old); 37 | return __longlong_as_double(old); 38 | } 39 | 40 | #define ACCUM_BLOCK(x,y) atomic_add(&(x),(y)) 41 | #define ACCUM(x,y) atomic_add(&(x),(y)) 42 | #endif // #if __CUDA_ARCH__ >= 600 43 | 44 | #else 45 | // CPU versions 46 | 47 | #define HOST_DEVICE 48 | #define 
INLINE_HOST_DEVICE inline 49 | #define FLOOR(x) std::floor(x) 50 | #define ACCUM_BLOCK(x,y) (x) += (y) 51 | #define ACCUM(x,y) (x) += (y) 52 | 53 | #endif // #ifdef __CUDACC__ 54 | 55 | /* 56 | * Other utility functions 57 | */ 58 | template 59 | INLINE_HOST_DEVICE void ind2sub(T i, T *sizes, T &i_n) { 60 | static_assert(dim == 1, "dim must be 1"); 61 | i_n = i % sizes[0]; 62 | } 63 | 64 | template 65 | INLINE_HOST_DEVICE void ind2sub(T i, T *sizes, T &i_n, Indices&...args) { 66 | static_assert(dim == sizeof...(args) + 1, "dim must equal the number of args"); 67 | i_n = i % sizes[dim - 1]; 68 | ind2sub(i / sizes[dim - 1], sizes, args...); 69 | } 70 | 71 | template inline T div_up(T x, T y) { 72 | static_assert(std::is_integral::value, "div_up is only defined for integral types"); 73 | return x / y + (x % y > 0); 74 | } -------------------------------------------------------------------------------- /include/utils/cuda.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | /* 4 | * General settings and functions 5 | */ 6 | const int WARP_SIZE = 32; 7 | const int MAX_BLOCK_SIZE = 1024; 8 | 9 | static int getNumThreads(int nElem) { 10 | int threadSizes[6] = {32, 64, 128, 256, 512, MAX_BLOCK_SIZE}; 11 | for (int i = 0; i < 6; ++i) { 12 | if (nElem <= threadSizes[i]) { 13 | return threadSizes[i]; 14 | } 15 | } 16 | return MAX_BLOCK_SIZE; 17 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/mapillary/inplace_abn.git 2 | numpy 3 | opencv-contrib-python 4 | Pillow 5 | scikit-image 6 | scipy 7 | Shapely==1.7.0 8 | torch==1.1.0 9 | torchvision==0.3.0 10 | umsgpack==0.1.0 11 | future==0.18.2 12 | tensorboard==1.14.0 13 | -------------------------------------------------------------------------------- /sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-ainetter/grasp_det_seg_cnn/6ff96464f8906fb555d0a2f5a8b86c7f1330f108/sample.png -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | license_files = LICENSE -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from os import path, listdir 2 | import setuptools 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | 6 | def find_sources(root_dir): 7 | sources = [] 8 | for file in listdir(root_dir): 9 | _, ext = path.splitext(file) 10 | if ext in [".cpp", ".cu"]: 11 | sources.append(path.join(root_dir, file)) 12 | 13 | return sources 14 | 15 | 16 | def make_extension(name, package): 17 | return CUDAExtension( 18 | name="{}.{}._backend".format(package, name), 19 | sources=find_sources(path.join("src", name)), 20 | extra_compile_args={ 21 | "cxx": ["-O3"], 22 | "nvcc": ["--expt-extended-lambda"], 23 | }, 24 | include_dirs=["include/"], 25 | ) 26 | 27 | 28 | here = path.abspath(path.dirname(__file__)) 29 | 30 | with open(path.join(here, "README.md"), encoding="utf-8") as f: 31 | long_description = f.read() 32 | 33 | setuptools.setup( 34 | # Meta-data 35 | name="GraspDetSeg_CNN", 36 | author="Stefan Ainetter", 37 | author_email="stefan.ainetter@icg.tugraz.at", 38 | 
description="Grasp Detection and Segmentation for Pytorch, code based on Seamless Scene Segmentation (https://github.com/mapillary/seamseg).", 39 | long_description_content_type="text/markdown", 40 | url="", 41 | classifiers=[ 42 | "Programming Language :: Python :: 3", 43 | "Programming Language :: Python :: 3.4", 44 | "Programming Language :: Python :: 3.5", 45 | "Programming Language :: Python :: 3.6", 46 | "Programming Language :: Python :: 3.7", 47 | ], 48 | 49 | # Versioning 50 | use_scm_version={"root": ".", "relative_to": __file__, "write_to": "grasp_det_seg/_version.py"}, 51 | 52 | # Requirements 53 | setup_requires=["setuptools_scm"], 54 | python_requires=">=3, <4", 55 | 56 | # Package description 57 | packages=[ 58 | "grasp_det_seg", 59 | "grasp_det_seg.algos", 60 | "grasp_det_seg.config", 61 | "grasp_det_seg.data_OCID", 62 | "grasp_det_seg.models", 63 | "grasp_det_seg.modules", 64 | "grasp_det_seg.modules.heads", 65 | "grasp_det_seg.utils", 66 | "grasp_det_seg.utils.bbx", 67 | "grasp_det_seg.utils.nms", 68 | "grasp_det_seg.utils.parallel", 69 | "grasp_det_seg.utils.roi_sampling", 70 | ], 71 | ext_modules=[ 72 | make_extension("nms", "grasp_det_seg.utils"), 73 | make_extension("bbx", "grasp_det_seg.utils"), 74 | make_extension("roi_sampling", "grasp_det_seg.utils") 75 | ], 76 | cmdclass={"build_ext": BuildExtension}, 77 | include_package_data=True, 78 | ) 79 | -------------------------------------------------------------------------------- /src/bbx/bbx.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "bbx.h" 4 | #include "utils/checks.h" 5 | 6 | at::Tensor extract_boxes(const at::Tensor& mask, int n_instances){ 7 | AT_CHECK(mask.ndimension() == 3, "Input mask should be 3D"); 8 | 9 | at::Tensor bbx = at::full({n_instances, 4}, -1, mask.options().dtype(at::kFloat)); 10 | 11 | AT_DISPATCH_ALL_TYPES(mask.scalar_type(), "extract_boxes", ([&]{ 12 | auto _mask = mask.accessor(); 13 | auto _bbx = bbx.accessor(); 14 | 15 | for (int c = 0; c < _mask.size(0); ++c) { 16 | for (int i = 0; i < _mask.size(1); ++i) { 17 | for (int j = 0; j < _mask.size(2); ++j) { 18 | int64_t id = static_cast(_mask[c][i][j]); 19 | if (id < n_instances) { 20 | if (_bbx[id][0] < 0 || _bbx[id][0] > i) _bbx[id][0] = i; 21 | if (_bbx[id][1] < 0 || _bbx[id][1] > j) _bbx[id][1] = j; 22 | if (_bbx[id][2] < 0 || _bbx[id][2] <= i) _bbx[id][2] = i + 1; 23 | if (_bbx[id][3] < 0 || _bbx[id][3] <= j) _bbx[id][3] = j + 1; 24 | } 25 | } 26 | } 27 | } 28 | })); 29 | 30 | return bbx; 31 | } 32 | 33 | at::Tensor mask_count(const at::Tensor& bbx, const at::Tensor& int_mask) { 34 | AT_CHECK(bbx.ndimension() == 2, "Input bbx should be 2D"); 35 | AT_CHECK(bbx.size(1) == 4, "Input bbx must be N x 4"); 36 | AT_CHECK(int_mask.ndimension() == 2, "Input mask should be 2D"); 37 | 38 | if (bbx.is_cuda()) { 39 | return mask_count_cuda(bbx, int_mask); 40 | } else { 41 | return mask_count_cpu(bbx, int_mask); 42 | } 43 | } 44 | 45 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 46 | m.def("extract_boxes", &extract_boxes, "Extract bounding boxes from image of instance IDs"); 47 | m.def("mask_count", &mask_count, "Count the number of non-zero entries in different regions of a mask"); 48 | } 49 | 50 | -------------------------------------------------------------------------------- /src/bbx/bbx_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "bbx.h" 4 | 5 | template 6 | inline T clamp(T x, T a, T b) { 7 | return 
std::max(a, std::min(b, x)); 8 | } 9 | 10 | at::Tensor mask_count_cpu(const at::Tensor& bbx, const at::Tensor& int_mask) { 11 | // Get dimensions 12 | auto num = bbx.size(0), height = int_mask.size(0), width = int_mask.size(1); 13 | 14 | // Create output 15 | auto count = at::zeros({num}, bbx.options()); 16 | 17 | AT_DISPATCH_FLOATING_TYPES(bbx.scalar_type(), "mask_count_cpu", ([&] { 18 | auto _bbx = bbx.accessor(); 19 | auto _int_mask = int_mask.accessor(); 20 | auto _count = count.accessor(); 21 | 22 | for (int64_t n = 0; n < num; ++n) { 23 | auto i0 = clamp(static_cast(_bbx[n][0]), int64_t(0), int64_t(height - 1)), 24 | j0 = clamp(static_cast(_bbx[n][1]), int64_t(0), int64_t(width - 1)), 25 | i1 = clamp(static_cast(_bbx[n][2]), int64_t(0), int64_t(height - 1)), 26 | j1 = clamp(static_cast(_bbx[n][3]), int64_t(0), int64_t(width - 1)); 27 | 28 | _count[n] = _int_mask[i1][j1] - _int_mask[i0][j1] - _int_mask[i1][j0] + _int_mask[i0][j0]; 29 | } 30 | })); 31 | 32 | return count; 33 | } 34 | -------------------------------------------------------------------------------- /src/bbx/bbx_cuda.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "bbx.h" 6 | #include "utils/cuda.cuh" 7 | 8 | template 9 | __device__ inline T clamp(T x, T a, T b) { 10 | return max(a, min(b, x)); 11 | } 12 | 13 | template 14 | __global__ void mask_count_kernel(const at::PackedTensorAccessor bbx, 15 | const at::PackedTensorAccessor int_mask, 16 | at::PackedTensorAccessor count) { 17 | index_t num = bbx.size(0), height = int_mask.size(0), width = int_mask.size(1); 18 | index_t n = blockIdx.x * blockDim.x + threadIdx.x; 19 | if (n < num) { 20 | auto _bbx = bbx[n]; 21 | 22 | int i0 = clamp(static_cast(_bbx[0]), index_t(0), height - 1), 23 | j0 = clamp(static_cast(_bbx[1]), index_t(0), width - 1), 24 | i1 = clamp(static_cast(_bbx[2]), index_t(0), height - 1), 25 | j1 = clamp(static_cast(_bbx[3]), index_t(0), width - 1); 26 | 27 | count[n] = int_mask[i1][j1] - int_mask[i0][j1] - int_mask[i1][j0] + int_mask[i0][j0]; 28 | } 29 | } 30 | 31 | at::Tensor mask_count_cuda(const at::Tensor& bbx, const at::Tensor& int_mask) { 32 | // Get dimensions 33 | auto num = bbx.size(0); 34 | 35 | // Create output 36 | auto count = at::zeros({num}, bbx.options()); 37 | 38 | // Run kernel 39 | dim3 threads(getNumThreads(num)); 40 | dim3 blocks((num + threads.x - 1) / threads.x); 41 | auto stream = at::cuda::getCurrentCUDAStream().stream(); 42 | AT_DISPATCH_FLOATING_TYPES(bbx.scalar_type(), "mask_count_cuda", ([&] { 43 | if (at::cuda::detail::canUse32BitIndexMath(int_mask)) { 44 | auto _bbx = bbx.packed_accessor(); 45 | auto _int_mask = int_mask.packed_accessor(); 46 | auto _count = count.packed_accessor(); 47 | 48 | mask_count_kernel<<>>(_bbx, _int_mask, _count); 49 | } else { 50 | auto _bbx = bbx.packed_accessor(); 51 | auto _int_mask = int_mask.packed_accessor(); 52 | auto _count = count.packed_accessor(); 53 | 54 | mask_count_kernel<<>>(_bbx, _int_mask, _count); 55 | } 56 | })); 57 | 58 | return count; 59 | } -------------------------------------------------------------------------------- /src/nms/nms.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "nms.h" 4 | #include "utils/checks.h" 5 | 6 | at::Tensor nms(const at::Tensor& bbx, const at::Tensor& scores, float threshold, int n_max) { 7 | // Check inputs 8 | AT_CHECK(bbx.scalar_type() == scores.scalar_type(), "bbx and scores must have the 
same type"); 9 | AT_CHECK(bbx.size(0) == scores.size(0), "bbx and scores must have the same length"); 10 | AT_CHECK(bbx.size(1) == 4 && bbx.ndimension() == 2, "bbx must be an N x 4 tensor"); 11 | AT_CHECK(bbx.is_contiguous(), "bbx must be a contiguous tensor"); 12 | 13 | at::Tensor comp_mat; 14 | if (bbx.is_cuda()) { 15 | comp_mat = comp_mat_cuda(bbx, threshold); 16 | comp_mat = comp_mat.toBackend(at::Backend::CPU); 17 | } else { 18 | comp_mat = comp_mat_cpu(bbx, threshold); 19 | } 20 | 21 | // Sort scores 22 | auto sorted_and_idx = scores.sort(0, true); 23 | auto idx = std::get<1>(sorted_and_idx); 24 | 25 | // Run actual non-maxima suppression on CPU 26 | return nms_cpu(comp_mat, idx.toBackend(at::Backend::CPU), n_max); 27 | } 28 | 29 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 30 | m.def("nms", &nms, "Perform non-maxima suppression, always return result as CPU Tensor"); 31 | } 32 | -------------------------------------------------------------------------------- /src/nms/nms_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include 6 | 7 | #include "nms.h" 8 | #include "utils/common.h" 9 | 10 | template 11 | inline T area(T tl0, T tl1, T br0, T br1) { 12 | return std::max(br0 - tl0, T(0)) * std::max(br1 - tl1, T(0)); 13 | } 14 | 15 | template 16 | inline T iou(at::TensorAccessor &bbx0, at::TensorAccessor &bbx1) { 17 | auto ptl0 = std::max(bbx0[0], bbx1[0]); 18 | auto ptl1 = std::max(bbx0[1], bbx1[1]); 19 | auto pbr0 = std::min(bbx0[2], bbx1[2]); 20 | auto pbr1 = std::min(bbx0[3], bbx1[3]); 21 | auto intersection = area(ptl0, ptl1, pbr0, pbr1); 22 | auto area0 = area(bbx0[0], bbx0[1], bbx0[2], bbx0[3]); 23 | auto area1 = area(bbx1[0], bbx1[1], bbx1[2], bbx1[3]); 24 | return intersection / (area0 + area1 - intersection); 25 | } 26 | 27 | at::Tensor comp_mat_cpu(const at::Tensor& bbx, float threshold) { 28 | int64_t num = bbx.size(0); 29 | int64_t blocks = div_up(num, THREADS_PER_BLOCK); 30 | 31 | auto comp_mat = at::zeros({num, blocks}, bbx.options().dtype(at::ScalarType::Long)); 32 | 33 | AT_DISPATCH_FLOATING_TYPES(bbx.scalar_type(), "comp_mat_cpu", ([&] { 34 | auto _bbx = bbx.accessor(); 35 | auto _comp_mat = comp_mat.accessor(); 36 | 37 | for (int64_t i = 0; i < num; ++i) { 38 | auto _bbx_i = _bbx[i]; 39 | auto _comp_mat_i = _comp_mat[i]; 40 | 41 | for (int64_t j = i + 1; j < num; ++j) { 42 | auto _bbx_j = _bbx[j]; 43 | auto iou_ij = iou(_bbx_i, _bbx_j); 44 | 45 | if (iou_ij >= threshold) { 46 | int64_t block_idx = j / THREADS_PER_BLOCK; 47 | int64_t bit_idx = j % THREADS_PER_BLOCK; 48 | 49 | _comp_mat_i[block_idx] |= int64_t(1) << bit_idx; 50 | } 51 | } 52 | } 53 | })); 54 | 55 | return comp_mat; 56 | } 57 | 58 | at::Tensor nms_cpu(const at::Tensor& comp_mat, const at::Tensor& idx, int n_max) { 59 | int64_t num = comp_mat.size(0); 60 | 61 | auto _comp_mat = comp_mat.accessor(); 62 | auto _idx = idx.data(); 63 | 64 | // Copy to C++ data structures 65 | std::list candidates; 66 | std::copy(_idx, _idx + num, std::back_inserter(candidates)); 67 | 68 | std::vector selection; 69 | size_t n_max_ = n_max > 0 ? 
n_max : num; 70 | 71 | // Run actual nms 72 | while (!candidates.empty() && selection.size() < n_max_) { 73 | // Select first element 74 | auto i = candidates.front(); 75 | selection.push_back(i); 76 | candidates.pop_front(); 77 | 78 | // Remove conflicts 79 | candidates.remove_if([&_comp_mat,&i] (const int64_t &j) { 80 | auto ii = std::min(i, j), jj = std::max(i, j); 81 | 82 | auto block_idx = jj / THREADS_PER_BLOCK; 83 | auto bit_idx = jj % THREADS_PER_BLOCK; 84 | return _comp_mat[ii][block_idx] & (int64_t(1) << bit_idx); 85 | }); 86 | } 87 | 88 | // Copy to output 89 | auto selection_tensor = at::zeros(selection.size(), comp_mat.options()); 90 | std::copy(selection.begin(), selection.end(), selection_tensor.data()); 91 | 92 | return selection_tensor; 93 | } -------------------------------------------------------------------------------- /src/nms/nms_cuda.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "nms.h" 5 | #include "utils/common.h" 6 | #include "utils/cuda.cuh" 7 | 8 | template struct VectType; 9 | template<> struct VectType { 10 | typedef float4 value; 11 | typedef float4* ptr; 12 | typedef const float4* const_ptr; 13 | }; 14 | template<> struct VectType { 15 | typedef double4 value; 16 | typedef double4* ptr; 17 | typedef const double4* const_ptr; 18 | }; 19 | 20 | template 21 | __device__ inline T area(T tl0, T tl1, T br0, T br1) { 22 | return max(br0 - tl0, T(0)) * max(br1 - tl1, T(0)); 23 | } 24 | 25 | template 26 | __device__ inline T iou(typename VectType::value bbx0, typename VectType::value bbx1) { 27 | auto ptl0 = max(bbx0.x, bbx1.x); 28 | auto ptl1 = max(bbx0.y, bbx1.y); 29 | auto pbr0 = min(bbx0.z, bbx1.z); 30 | auto pbr1 = min(bbx0.w, bbx1.w); 31 | auto intersection = area(ptl0, ptl1, pbr0, pbr1); 32 | auto area0 = area(bbx0.x, bbx0.y, bbx0.z, bbx0.w); 33 | auto area1 = area(bbx1.x, bbx1.y, bbx1.z, bbx1.w); 34 | return intersection / (area0 + area1 - intersection); 35 | } 36 | 37 | template 38 | __global__ void comp_mat_kernel(const int64_t num, const int64_t blocks, const float threshold, 39 | const T* __restrict__ bbx, int64_t* __restrict__ comp_mat) { 40 | // Find position in grid 41 | const int row_start = blockIdx.y; 42 | const int col_start = blockIdx.x; 43 | const int row_size = min(num - row_start * THREADS_PER_BLOCK, THREADS_PER_BLOCK); 44 | const int col_size = min(num - col_start * THREADS_PER_BLOCK, THREADS_PER_BLOCK); 45 | 46 | auto _bbx = reinterpret_cast::const_ptr>(bbx); 47 | 48 | // Load data to block storage 49 | __shared__ typename VectType::value block_bbx[THREADS_PER_BLOCK]; 50 | if (threadIdx.x < col_size) { 51 | block_bbx[threadIdx.x] = _bbx[THREADS_PER_BLOCK * col_start + threadIdx.x]; 52 | } 53 | __syncthreads(); 54 | 55 | // Perform actual computation 56 | if (threadIdx.x < row_size) { 57 | const int cur_box_idx = THREADS_PER_BLOCK * row_start + threadIdx.x; 58 | const auto cur_box = _bbx[cur_box_idx]; 59 | 60 | int start = 0; 61 | if (row_start == col_start) { 62 | start = threadIdx.x + 1; 63 | } 64 | 65 | int64_t t = 0; 66 | for (int i = start; i < col_size; ++i) { 67 | if (iou(cur_box, block_bbx[i]) >= threshold) { 68 | t |= int64_t(1) << i; 69 | } 70 | } 71 | comp_mat[cur_box_idx * blocks + col_start] = t; 72 | } 73 | } 74 | 75 | at::Tensor comp_mat_cuda(const at::Tensor& bbx, float threshold) { 76 | int64_t num = bbx.size(0); 77 | int64_t blocks = div_up(num, THREADS_PER_BLOCK); 78 | 79 | auto comp_mat = at::zeros({num, blocks}, 
bbx.options().dtype(at::kLong)); 80 | 81 | dim3 blk(blocks, blocks, 1); 82 | dim3 thd(THREADS_PER_BLOCK, 1, 1); 83 | auto stream = at::cuda::getCurrentCUDAStream().stream(); 84 | AT_DISPATCH_FLOATING_TYPES(bbx.scalar_type(), "comp_mat_cuda", ([&] { 85 | comp_mat_kernel<scalar_t><<<blk, thd, 0, stream>>>( 86 | num, blocks, threshold, bbx.data<scalar_t>(), comp_mat.data<int64_t>()); 87 | })); 88 | 89 | return comp_mat; 90 | } 91 | -------------------------------------------------------------------------------- /src/roi_sampling/roi_sampling.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #include "utils/checks.h" 6 | #include "roi_sampling.h" 7 | 8 | std::tuple<at::Tensor, at::Tensor> roi_sampling_forward( 9 | const at::Tensor& x, const at::Tensor& bbx, const at::Tensor& idx, std::tuple<int, int> out_size, 10 | Interpolation interpolation, PaddingMode padding, bool valid_mask) { 11 | // Check dimensions 12 | AT_CHECK(x.ndimension() == 4, "x must be a 4-dimensional tensor"); 13 | AT_CHECK(bbx.ndimension() == 2, "bbx must be a 2-dimensional tensor"); 14 | AT_CHECK(idx.ndimension() == 1, "idx must be a 1-dimensional tensor"); 15 | AT_CHECK(bbx.size(0) == idx.size(0), "idx and bbx must have the same size in the first dimension"); 16 | AT_CHECK(bbx.size(1) == 4, "bbx must be N x 4"); 17 | 18 | // Check types 19 | AT_CHECK(bbx.scalar_type() == at::ScalarType::Float, "bbx must have type float32"); 20 | AT_CHECK(idx.scalar_type() == at::ScalarType::Long, "idx must have type long"); 21 | 22 | if (x.is_cuda()) { 23 | CHECK_CUDA(bbx); 24 | CHECK_CUDA(idx); 25 | 26 | return roi_sampling_forward_cuda(x, bbx, idx, out_size, interpolation, padding, valid_mask); 27 | } else { 28 | CHECK_CPU(bbx); 29 | CHECK_CPU(idx); 30 | 31 | return roi_sampling_forward_cpu(x, bbx, idx, out_size, interpolation, padding, valid_mask); 32 | } 33 | } 34 | 35 | at::Tensor roi_sampling_backward( 36 | const at::Tensor& dy, const at::Tensor& bbx, const at::Tensor& idx, std::tuple<int, int, int> in_size, 37 | Interpolation interpolation, PaddingMode padding) { 38 | // Check dimensions 39 | AT_CHECK(dy.ndimension() == 4, "dy must be a 4-dimensional tensor"); 40 | AT_CHECK(bbx.ndimension() == 2, "bbx must be a 2-dimensional tensor"); 41 | AT_CHECK(idx.ndimension() == 1, "idx must be a 1-dimensional tensor"); 42 | AT_CHECK(bbx.size(0) == idx.size(0), "idx and bbx must have the same size in the first dimension"); 43 | AT_CHECK(bbx.size(1) == 4, "bbx must be N x 4"); 44 | 45 | // Check types 46 | AT_CHECK(bbx.scalar_type() == at::ScalarType::Float, "bbx must have type float32"); 47 | AT_CHECK(idx.scalar_type() == at::ScalarType::Long, "idx must have type long"); 48 | 49 | if (dy.is_cuda()) { 50 | CHECK_CUDA(bbx); 51 | CHECK_CUDA(idx); 52 | 53 | return roi_sampling_backward_cuda(dy, bbx, idx, in_size, interpolation, padding); 54 | } else { 55 | CHECK_CPU(bbx); 56 | CHECK_CPU(idx); 57 | 58 | return roi_sampling_backward_cpu(dy, bbx, idx, in_size, interpolation, padding); 59 | } 60 | } 61 | 62 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 63 | pybind11::enum_<PaddingMode>(m, "PaddingMode") 64 | .value("Zero", PaddingMode::Zero) 65 | .value("Border", PaddingMode::Border); 66 | 67 | pybind11::enum_<Interpolation>(m, "Interpolation") 68 | .value("Bilinear", Interpolation::Bilinear) 69 | .value("Nearest", Interpolation::Nearest); 70 | 71 | m.def("roi_sampling_forward", &roi_sampling_forward, "ROI sampling forward"); 72 | m.def("roi_sampling_backward", &roi_sampling_backward, "ROI sampling backward"); 73 | } 74 | --------------------------------------------------------------------------------
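The bindings in src/roi_sampling/roi_sampling.cpp above are compiled by setup.py into the module grasp_det_seg.utils.roi_sampling._backend and are normally consumed through the wrapper in grasp_det_seg/utils/roi_sampling/functions.py rather than called directly. Purely as an illustration of the bound signatures (the tensor shapes and values below are invented, and the direct backend call is shown only for clarity), a forward call could look roughly like this:

```python
import torch

# Hypothetical example; assumes the extension has already been built and installed
# (e.g. via "python setup.py install"), so the compiled backend is importable.
from grasp_det_seg.utils.roi_sampling import _backend

x = torch.randn(2, 256, 64, 80)                    # features: batch x channels x height x width
bbx = torch.tensor([[ 4.,  6., 40., 60.],
                    [10., 12., 56., 72.]])         # N x 4 boxes as (i0, j0, i1, j1), float32
idx = torch.tensor([0, 1])                         # image index of each box, int64 (long)

# Sample a fixed 7x7 grid from every box; valid_mask=True also returns an in-bounds mask.
y, mask = _backend.roi_sampling_forward(
    x, bbx, idx, (7, 7),
    _backend.Interpolation.Bilinear, _backend.PaddingMode.Zero, True)

print(y.shape)     # torch.Size([2, 256, 7, 7])  -- one ROI per entry of idx
print(mask.shape)  # torch.Size([2, 7, 7])
```

The AT_CHECK calls in roi_sampling_forward above spell out the expected inputs: x must be 4-dimensional, bbx must be an N x 4 float32 tensor, and idx must be a 1-dimensional long tensor of matching length.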
/src/roi_sampling/roi_sampling_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "roi_sampling.h" 4 | 5 | template 6 | void roi_sampling_forward_impl( 7 | at::TensorAccessor x, 8 | at::TensorAccessor bbx, 9 | at::TensorAccessor idx, 10 | at::TensorAccessor y, 11 | at::TensorAccessor mask, 12 | bool valid_mask, 13 | Sampler sampler) { 14 | auto roi_height = static_cast(y.size(2)), 15 | roi_width = static_cast(y.size(3)); 16 | auto img_height = static_cast(x.size(2)), 17 | img_width = static_cast(x.size(3)); 18 | 19 | for (int64_t n = 0; n < idx.size(0); ++n) { 20 | auto img_idx = idx[n]; 21 | auto i0 = bbx[n][0], j0 = bbx[n][1], i1 = bbx[n][2], j1 = bbx[n][3]; 22 | 23 | for (int64_t c = 0; c < x.size(1); ++c) { 24 | // Create indexer for this plane and image 25 | auto accessor = x[img_idx][c]; 26 | 27 | for (int64_t i_roi = 0; i_roi < y.size(2); ++i_roi) { 28 | auto y_img = roi_to_img(static_cast(i_roi) + coord_t(0.5), i0, i1, roi_height); 29 | 30 | for (int64_t j_roi = 0; j_roi < y.size(3); ++j_roi) { 31 | auto x_img = roi_to_img(static_cast(j_roi) + coord_t(0.5), j0, j1, roi_width); 32 | 33 | y[n][c][i_roi][j_roi] = sampler.forward(y_img, x_img, accessor); 34 | 35 | // Optionally write to mask 36 | if (valid_mask) { 37 | mask[n][i_roi][j_roi] = y_img >= 0 && y_img < img_height && x_img >= 0 && x_img < img_width; 38 | } 39 | } 40 | } 41 | } 42 | } 43 | } 44 | 45 | std::tuple roi_sampling_forward_cpu( 46 | const at::Tensor& x, const at::Tensor& bbx, const at::Tensor& idx, std::tuple out_size, 47 | Interpolation interpolation, PaddingMode padding, bool valid_mask) { 48 | 49 | // Prepare outputs 50 | auto y = at::empty({idx.size(0), x.size(1), std::get<0>(out_size), std::get<1>(out_size)}, x.options()); 51 | auto mask = valid_mask 52 | ? 
at::zeros({idx.size(0), std::get<0>(out_size), std::get<1>(out_size)}, x.options().dtype(at::kByte)) 53 | : at::zeros({1, 1, 1}, x.options().dtype(at::kByte)); 54 | 55 | AT_DISPATCH_ALL_TYPES(x.scalar_type(), "roi_sampling_forward_cpu", ([&] { 56 | using coord_t = float; 57 | using index_t = int64_t; 58 | 59 | auto _x = x.accessor(); 60 | auto _bbx = bbx.accessor(); 61 | auto _idx = idx.accessor(); 62 | auto _y = y.accessor(); 63 | auto _mask = mask.accessor(); 64 | 65 | DISPATCH_INTERPOLATION_PADDING_MODES(interpolation, padding, ([&] { 66 | indexer_t indexer(x.size(2), x.size(3)); 67 | interpolator_t interpolator; 68 | sampler_t sampler(indexer, interpolator); 69 | 70 | roi_sampling_forward_impl(_x, _bbx, _idx, _y, _mask, valid_mask, sampler); 71 | })); 72 | })); 73 | 74 | return std::make_tuple(y, mask); 75 | } 76 | 77 | template 78 | void roi_sampling_backward_impl( 79 | at::TensorAccessor dy, 80 | at::TensorAccessor bbx, 81 | at::TensorAccessor idx, 82 | at::TensorAccessor dx, 83 | Sampler sampler) { 84 | auto roi_height = static_cast(dy.size(2)), 85 | roi_width = static_cast(dy.size(3)); 86 | 87 | for (int64_t n = 0; n < idx.size(0); ++n) { 88 | auto img_idx = idx[n]; 89 | auto i0 = bbx[n][0], j0 = bbx[n][1], i1 = bbx[n][2], j1 = bbx[n][3]; 90 | 91 | for (int64_t c = 0; c < dy.size(1); ++c) { 92 | // Create indexer for this plane and image 93 | auto accessor = dx[img_idx][c]; 94 | 95 | for (int64_t i_roi = 0; i_roi < dy.size(2); ++i_roi) { 96 | auto y_img = roi_to_img(static_cast(i_roi) + coord_t(0.5), i0, i1, roi_height); 97 | 98 | for (int64_t j_roi = 0; j_roi < dy.size(3); ++j_roi) { 99 | auto x_img = roi_to_img(static_cast(j_roi) + coord_t(0.5), j0, j1, roi_width); 100 | 101 | sampler.backward(y_img, x_img, dy[n][c][i_roi][j_roi], accessor); 102 | } 103 | } 104 | } 105 | } 106 | } 107 | 108 | at::Tensor roi_sampling_backward_cpu( 109 | const at::Tensor& dy, const at::Tensor& bbx, const at::Tensor& idx, std::tuple in_size, 110 | Interpolation interpolation, PaddingMode padding) { 111 | 112 | // Prepare output 113 | auto dx = at::zeros({std::get<0>(in_size), dy.size(1), std::get<1>(in_size), std::get<2>(in_size)}, dy.options()); 114 | 115 | AT_DISPATCH_ALL_TYPES(dy.scalar_type(), "roi_sampling_backward_cpu", ([&] { 116 | using coord_t = float; 117 | using index_t = int64_t; 118 | 119 | auto _dy = dy.accessor(); 120 | auto _bbx = bbx.accessor(); 121 | auto _idx = idx.accessor(); 122 | auto _dx = dx.accessor(); 123 | 124 | DISPATCH_INTERPOLATION_PADDING_MODES(interpolation, padding, ([&] { 125 | indexer_t indexer(dx.size(2), dx.size(3)); 126 | interpolator_t interpolator; 127 | sampler_t sampler(indexer, interpolator); 128 | 129 | roi_sampling_backward_impl(_dy, _bbx, _idx, _dx, sampler); 130 | })); 131 | })); 132 | 133 | return dx; 134 | } 135 | -------------------------------------------------------------------------------- /src/roi_sampling/roi_sampling_cuda.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "utils/checks.h" 8 | #include "utils/cuda.cuh" 9 | #include "utils/common.h" 10 | #include "roi_sampling.h" 11 | 12 | 13 | template 14 | __global__ void roi_sampling_forward_kernel( 15 | const at::PackedTensorAccessor x, 16 | const at::PackedTensorAccessor bbx, 17 | const at::PackedTensorAccessor idx, 18 | at::PackedTensorAccessor y, 19 | at::PackedTensorAccessor mask, 20 | bool valid_mask, 21 | Sampler sampler) { 22 | 23 | // Dimensions 24 | auto chn = x.size(1), 
img_height = x.size(2), img_width = x.size(3); 25 | auto roi_height = y.size(2), roi_width = y.size(3); 26 | index_t sizes[3] = {chn, roi_height, roi_width}; 27 | index_t out_size = chn * roi_height * roi_width; 28 | 29 | index_t n = blockIdx.x; 30 | 31 | // Get bounding box coordinates and image index 32 | auto i0 = bbx[n][0], j0 = bbx[n][1], i1 = bbx[n][2], j1 = bbx[n][3]; 33 | auto img_idx = idx[n]; 34 | 35 | auto x_n = x[img_idx], y_n = y[n]; 36 | 37 | for (int iter = threadIdx.x; iter < out_size; iter += blockDim.x) { 38 | // Find current indices 39 | index_t c, i, j; 40 | ind2sub(iter, sizes, j, i, c); 41 | 42 | auto y_img = roi_to_img(static_cast(i) + coord_t(0.5), i0, i1, static_cast(roi_height)); 43 | auto x_img = roi_to_img(static_cast(j) + coord_t(0.5), j0, j1, static_cast(roi_width)); 44 | 45 | y_n[c][i][j] = sampler.forward(y_img, x_img, x_n[c]); 46 | 47 | if (valid_mask) { 48 | mask[n][i][j] = 49 | y_img >= 0 && y_img < static_cast(img_height) && 50 | x_img >= 0 && x_img < static_cast(img_width); 51 | } 52 | } 53 | } 54 | 55 | template 56 | void roi_sampling_forward_template( 57 | const at::Tensor& x, const at::Tensor& bbx, const at::Tensor& idx, at::Tensor& y, at::Tensor& mask, 58 | Interpolation interpolation, PaddingMode padding, bool valid_mask) { 59 | // Create accessors 60 | auto x_accessor = x.packed_accessor(); 61 | auto bbx_accessor = bbx.packed_accessor(); 62 | auto idx_accessor = idx.packed_accessor(); 63 | auto y_accessor = y.packed_accessor(); 64 | auto mask_accessor = mask.packed_accessor(); 65 | 66 | dim3 blocks(y.size(0)); 67 | dim3 threads(getNumThreads(y.size(1) * y.size(2) * y.size(3))); 68 | auto stream = at::cuda::getCurrentCUDAStream().stream(); 69 | 70 | // Run kernel 71 | DISPATCH_INTERPOLATION_PADDING_MODES(interpolation, padding, ([&] { 72 | indexer_t indexer(x.size(2), x.size(3)); 73 | interpolator_t interpolator; 74 | sampler_t sampler(indexer, interpolator); 75 | 76 | roi_sampling_forward_kernel<<>>( 77 | x_accessor, bbx_accessor, idx_accessor, y_accessor, mask_accessor, valid_mask, sampler); 78 | })); 79 | } 80 | 81 | std::tuple roi_sampling_forward_cuda( 82 | const at::Tensor& x, const at::Tensor& bbx, const at::Tensor& idx, std::tuple out_size, 83 | Interpolation interpolation, PaddingMode padding, bool valid_mask) { 84 | 85 | // Prepare outputs 86 | auto y = at::empty({idx.size(0), x.size(1), std::get<0>(out_size), std::get<1>(out_size)}, x.options()); 87 | auto mask = valid_mask 88 | ? 
at::zeros({idx.size(0), std::get<0>(out_size), std::get<1>(out_size)}, x.options().dtype(at::kByte)) 89 | : at::zeros({1, 1, 1}, x.options().dtype(at::kByte)); 90 | 91 | AT_DISPATCH_ALL_TYPES(x.scalar_type(), "roi_sampling_forward_cuda", ([&] { 92 | if (at::cuda::detail::canUse32BitIndexMath(x) && at::cuda::detail::canUse32BitIndexMath(y)) { 93 | roi_sampling_forward_template( 94 | x, bbx, idx, y, mask, interpolation, padding, valid_mask); 95 | } else { 96 | roi_sampling_forward_template( 97 | x, bbx, idx, y, mask, interpolation, padding, valid_mask); 98 | } 99 | })); 100 | 101 | return std::make_tuple(y, mask); 102 | } 103 | 104 | template 105 | __global__ void roi_sampling_backward_kernel( 106 | const at::PackedTensorAccessor dy, 107 | const at::PackedTensorAccessor bbx, 108 | const at::PackedTensorAccessor idx, 109 | at::PackedTensorAccessor dx, 110 | Sampler sampler) { 111 | 112 | // Dimensions 113 | auto num = dy.size(0), roi_height = dy.size(2), roi_width = dy.size(3); 114 | auto img_height = dx.size(2), img_width = dx.size(3); 115 | index_t iter_sizes[3] = {num, roi_height, roi_width}; 116 | index_t iter_size = num * roi_height * roi_width; 117 | 118 | // Local indices 119 | index_t c = blockIdx.x; 120 | 121 | for (int iter = threadIdx.x; iter < iter_size; iter += blockDim.x) { 122 | // Find current indices 123 | index_t n, i, j; 124 | ind2sub(iter, iter_sizes, j, i, n); 125 | 126 | // Get bounding box coordinates and image index 127 | // Get bounding box coordinates and image index 128 | auto i0 = bbx[n][0], j0 = bbx[n][1], i1 = bbx[n][2], j1 = bbx[n][3]; 129 | auto img_idx = idx[n]; 130 | 131 | auto y_img = roi_to_img(static_cast(i) + coord_t(0.5), i0, i1, static_cast(roi_height)); 132 | auto x_img = roi_to_img(static_cast(j) + coord_t(0.5), j0, j1, static_cast(roi_width)); 133 | 134 | sampler.backward(y_img, x_img, dy[n][c][i][j], dx[img_idx][c]); 135 | } 136 | } 137 | 138 | template 139 | void roi_sampling_backward_template( 140 | const at::Tensor& dy, const at::Tensor& bbx, const at::Tensor& idx, at::Tensor& dx, 141 | Interpolation interpolation, PaddingMode padding) { 142 | // Create accessors 143 | auto dy_accessor = dy.packed_accessor(); 144 | auto bbx_accessor = bbx.packed_accessor(); 145 | auto idx_accessor = idx.packed_accessor(); 146 | auto dx_accessor = dx.packed_accessor(); 147 | 148 | dim3 blocks(dy.size(1)); 149 | dim3 threads(getNumThreads(dy.size(0) * dy.size(2) * dy.size(3))); 150 | auto stream = at::cuda::getCurrentCUDAStream().stream(); 151 | 152 | // Run kernel 153 | DISPATCH_INTERPOLATION_PADDING_MODES(interpolation, padding, ([&] { 154 | indexer_t indexer(dx.size(2), dx.size(3)); 155 | interpolator_t interpolator; 156 | sampler_t sampler(indexer, interpolator); 157 | 158 | roi_sampling_backward_kernel<<>>( 159 | dy_accessor, bbx_accessor, idx_accessor, dx_accessor, sampler); 160 | })); 161 | } 162 | 163 | at::Tensor roi_sampling_backward_cuda( 164 | const at::Tensor& dy, const at::Tensor& bbx, const at::Tensor& idx, std::tuple in_size, 165 | Interpolation interpolation, PaddingMode padding) { 166 | 167 | // Prepare output 168 | auto dx = at::zeros({std::get<0>(in_size), dy.size(1), std::get<1>(in_size), std::get<2>(in_size)}, dy.options()); 169 | 170 | AT_DISPATCH_FLOATING_TYPES(dy.scalar_type(), "roi_sampling_backward_cuda", ([&] { 171 | if (at::cuda::detail::canUse32BitIndexMath(dy) && at::cuda::detail::canUse32BitIndexMath(dx)) { 172 | roi_sampling_backward_template( 173 | dy, bbx, idx, dx, interpolation, padding); 174 | } else { 175 | 
roi_sampling_backward_template<scalar_t, int64_t>( 176 | dy, bbx, idx, dx, interpolation, padding); 177 | } 178 | })); 179 | 180 | return dx; 181 | } -------------------------------------------------------------------------------- /weights_pretrained/Note.txt: -------------------------------------------------------------------------------- 1 | Add weights_pretrained here --------------------------------------------------------------------------------
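The three native extensions defined in setup.py (bbx, nms, roi_sampling) are built together with the Python package, for example with `python setup.py install` from the repository root inside the pinned requirements.txt environment. As a final, purely illustrative sketch (the box and score values below are made up, and in practice the repository goes through grasp_det_seg/utils/nms/nms.py rather than the raw backend), the compiled NMS binding shown in src/nms/nms.cpp can be exercised like this:

```python
import torch

# Assumes the extensions were built and installed as described above; values are illustrative.
from grasp_det_seg.utils.nms import _backend as nms_backend

boxes = torch.tensor([[ 0.,  0., 10., 10.],
                      [ 1.,  1., 11., 11.],
                      [20., 20., 30., 30.]])   # N x 4, contiguous, same dtype as scores
scores = torch.tensor([0.9, 0.8, 0.7])

# Keep at most 100 boxes, suppressing any box whose IoU with an already selected,
# higher-scoring box is >= 0.5; the result is always returned as a CPU tensor of
# kept indices (see src/nms/nms.cpp).
keep = nms_backend.nms(boxes, scores, 0.5, 100)
print(keep)   # expected: tensor([0, 2]) -- box 1 overlaps box 0 strongly and is dropped
```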