├── .gitignore ├── LICENSE ├── README.md ├── assets ├── detection_anchors.png ├── detection_final.png ├── detection_masks.png ├── detection_refinement.png ├── park.png └── street.png ├── coco.py ├── config.py ├── convert_from_keras.py ├── demo.py ├── images ├── 1045023827_4ec3e8ba5c_z.jpg ├── 12283150_12d37e6389_z.jpg ├── 2383514521_1fc8d7b0de_z.jpg ├── 2502287818_41e4b0c4fb_z.jpg ├── 2516944023_d00345997d_z.jpg ├── 25691390_f9944f61b5_z.jpg ├── 262985539_1709e54576_z.jpg ├── 3132016470_c27baa00e8_z.jpg ├── 3627527276_6fe8cd9bfe_z.jpg ├── 3651581213_f81963d1dd_z.jpg ├── 3800883468_12af3c0b50_z.jpg ├── 3862500489_6fd195d183_z.jpg ├── 3878153025_8fde829928_z.jpg ├── 4410436637_7b0ca36ee7_z.jpg ├── 4782628554_668bc31826_z.jpg ├── 5951960966_d4e1cda5d0_z.jpg ├── 6584515005_fce9cec486_z.jpg ├── 6821351586_59aa0dc110_z.jpg ├── 7581246086_cf7bbb7255_z.jpg ├── 7933423348_c30bd9bd4e_z.jpg ├── 8053677163_d4c8f416be_z.jpg ├── 8239308689_efa6c11b08_z.jpg ├── 8433365521_9252889f9a_z.jpg ├── 8512296263_5fc5458e20_z.jpg ├── 8699757338_c3941051b6_z.jpg ├── 8734543718_37f6b8bd45_z.jpg ├── 8829708882_48f263491e_z.jpg ├── 9118579087_f9ffa19e63_z.jpg └── 9247489789_132c0d534a_z.jpg ├── model.py ├── nms ├── __init__.py ├── build.py ├── nms_wrapper.py ├── pth_nms.py └── src │ ├── cuda │ ├── nms_kernel.cu │ └── nms_kernel.h │ ├── nms.c │ ├── nms.h │ ├── nms_cuda.c │ └── nms_cuda.h ├── roialign ├── __init__.py └── roi_align │ ├── __init__.py │ ├── build.py │ ├── crop_and_resize.py │ ├── roi_align.py │ └── src │ ├── crop_and_resize.c │ ├── crop_and_resize.h │ ├── crop_and_resize_gpu.c │ ├── crop_and_resize_gpu.h │ └── cuda │ ├── crop_and_resize_kernel.cu │ └── crop_and_resize_kernel.h ├── utils.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | logs/ 3 | *.pth 4 | *.so 5 | *.pyc 6 | *.o 7 | pycocotools 8 | *_ext* 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Mask R-CNN 2 | 3 | The MIT License (MIT) 4 | 5 | Copyright (c) 2017 Matterport, Inc. 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in 15 | all copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 23 | THE SOFTWARE. 
24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-mask-rcnn 2 | 3 | 4 | This is a PyTorch implementation of [Mask R-CNN](https://arxiv.org/abs/1703.06870) that is largely based on Matterport's 5 | [Mask_RCNN](https://github.com/matterport/Mask_RCNN). Matterport's repository is an implementation in Keras and TensorFlow. 6 | The following parts of the README are excerpts from the Matterport README. Details on the requirements, training on MS COCO 7 | and detection results for this repository can be found at the end of the document. 8 | 9 | The Mask R-CNN model generates bounding boxes and segmentation masks for each instance of an object in the image. It's based 10 | on a Feature Pyramid Network (FPN) and a ResNet101 backbone. 11 | 12 | ![Instance Segmentation Sample](assets/street.png) 13 | 14 | The next four images visualize different stages in the detection pipeline: 15 | 16 | 17 | ##### 1. Anchor sorting and filtering 18 | The Region Proposal Network proposes bounding boxes that are likely to belong to an object. Positive and negative anchors, 19 | along with the anchor box refinement, are visualized. 20 | 21 | ![](assets/detection_anchors.png) 22 | 23 | 24 | ##### 2. Bounding Box Refinement 25 | This is an example of final detection boxes (dotted lines) and the refinement applied to them (solid lines) in the second stage. 26 | 27 | ![](assets/detection_refinement.png) 28 | 29 | 30 | ##### 3. Mask Generation 31 | Examples of generated masks. These are then scaled and placed on the image in the right location. 32 | 33 | ![](assets/detection_masks.png) 34 | 35 | 36 | ##### 4. Composing the different pieces into a final result 37 | 38 | ![](assets/detection_final.png) 39 | 40 | ## Requirements 41 | * Python 3 42 | * PyTorch 0.3 43 | * matplotlib, scipy, skimage, h5py 44 | 45 | ## Installation 46 | 1. Clone this repository. 47 | 48 | git clone https://github.com/multimodallearning/pytorch-mask-rcnn.git 49 | 50 | 51 | 2. We use functions from two more repositories that need to be built with the right `-arch` option for CUDA support (see the snippet below this list if you are unsure which value to pick). 52 | The two functions are Non-Maximum Suppression from ruotianluo's [pytorch-faster-rcnn](https://github.com/ruotianluo/pytorch-faster-rcnn) 53 | repository and longcw's [RoIAlign](https://github.com/longcw/RoIAlign.pytorch). 54 | 55 | | GPU | arch | 56 | | --- | --- | 57 | | TitanX | sm_52 | 58 | | GTX 960M | sm_50 | 59 | | GTX 1070 | sm_61 | 60 | | GTX 1080 (Ti) | sm_61 | 61 | 62 | cd nms/src/cuda/ 63 | nvcc -c -o nms_kernel.cu.o nms_kernel.cu -x cu -Xcompiler -fPIC -arch=[arch] 64 | cd ../../ 65 | python build.py 66 | cd ../ 67 | 68 | cd roialign/roi_align/src/cuda/ 69 | nvcc -c -o crop_and_resize_kernel.cu.o crop_and_resize_kernel.cu -x cu -Xcompiler -fPIC -arch=[arch] 70 | cd ../../ 71 | python build.py 72 | cd ../../ 73 | 74 | 3. As we use the [COCO dataset](http://cocodataset.org/#home), install the [Python COCO API](https://github.com/cocodataset/cocoapi) and 75 | create a symlink. 76 | 77 | ln -s /path/to/coco/cocoapi/PythonAPI/pycocotools/ pycocotools 78 | 79 | 4. Download the models pretrained on COCO and ImageNet from [Google Drive](https://drive.google.com/open?id=1LXUgC2IZUYNEoXr05tdqyKFZY0pZyPDc).
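
If you are unsure which `arch` value fits your GPU, you can query its compute capability through PyTorch. This is a minimal sketch, assuming a CUDA-enabled PyTorch build that exposes `torch.cuda.get_device_capability`:

        import torch
        # compute capability of the first GPU, e.g. (6, 1) for a GTX 1070/1080
        major, minor = torch.cuda.get_device_capability(0)
        print("-arch=sm_{}{}".format(major, minor))

After both extensions are built, running `python -c "from nms.nms_wrapper import nms"` from the repository root should succeed without errors if the compilation worked.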
80 | 81 | ## Demo 82 | 83 | To test your installation, simply run the demo with 84 | 85 | python demo.py 86 | 87 | It works on CPU or GPU and the result should look like this: 88 | 89 | ![](assets/park.png) 90 | 91 | ## Training on COCO 92 | Training and evaluation code is in coco.py. You can run it from the command 93 | line as such: 94 | 95 | # Train a new model starting from pre-trained COCO weights 96 | python coco.py train --dataset=/path/to/coco/ --model=coco 97 | 98 | # Train a new model starting from ImageNet weights 99 | python coco.py train --dataset=/path/to/coco/ --model=imagenet 100 | 101 | # Continue training a model that you had trained earlier 102 | python coco.py train --dataset=/path/to/coco/ --model=/path/to/weights.pth 103 | 104 | # Continue training the last model you trained. This will find 105 | # the last trained weights in the model directory. 106 | python coco.py train --dataset=/path/to/coco/ --model=last 107 | 108 | If you have not yet downloaded the COCO dataset, you should run the command 109 | with the download option set, e.g.: 110 | 111 | # Train a new model starting from pre-trained COCO weights 112 | python coco.py train --dataset=/path/to/coco/ --model=coco --download=true 113 | 114 | You can also run the COCO evaluation code with: 115 | 116 | # Run COCO evaluation on the last trained model 117 | python coco.py evaluate --dataset=/path/to/coco/ --model=last 118 | 119 | The training schedule, learning rate, and other parameters can be set in coco.py. 120 | 121 | ## Results 122 | 123 | COCO results for bounding box and segmentation are reported based on training 124 | with the default configuration and a backbone initialized with pretrained 125 | ImageNet weights. The metric used is AP at IoU=0.50:0.95. 126 | 127 | | | from scratch | converted from Keras | Matterport's Mask_RCNN | Mask R-CNN paper | 128 | | --- | --- | --- | --- | --- | 129 | | bbox | t.b.a. | 0.347 | 0.347 | 0.382 | 130 | | segm | t.b.a.
| 0.296 | 0.296 | 0.354 | 131 | 132 | 133 | -------------------------------------------------------------------------------- /assets/detection_anchors.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/assets/detection_anchors.png -------------------------------------------------------------------------------- /assets/detection_final.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/assets/detection_final.png -------------------------------------------------------------------------------- /assets/detection_masks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/assets/detection_masks.png -------------------------------------------------------------------------------- /assets/detection_refinement.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/assets/detection_refinement.png -------------------------------------------------------------------------------- /assets/park.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/assets/park.png -------------------------------------------------------------------------------- /assets/street.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/assets/street.png -------------------------------------------------------------------------------- /coco.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mask R-CNN 3 | Configurations and data loading code for MS COCO. 4 | 5 | Copyright (c) 2017 Matterport, Inc. 6 | Licensed under the MIT License (see LICENSE for details) 7 | Written by Waleed Abdulla 8 | 9 | ------------------------------------------------------------ 10 | 11 | Usage: import the module (see Jupyter notebooks for examples), or run from 12 | the command line as such: 13 | 14 | # Train a new model starting from pre-trained COCO weights 15 | python3 coco.py train --dataset=/path/to/coco/ --model=coco 16 | 17 | # Train a new model starting from ImageNet weights 18 | python3 coco.py train --dataset=/path/to/coco/ --model=imagenet 19 | 20 | # Continue training a model that you had trained earlier 21 | python3 coco.py train --dataset=/path/to/coco/ --model=/path/to/weights.h5 22 | 23 | # Continue training the last model you trained 24 | python3 coco.py train --dataset=/path/to/coco/ --model=last 25 | 26 | # Run COCO evaluatoin on the last model you trained 27 | python3 coco.py evaluate --dataset=/path/to/coco/ --model=last 28 | """ 29 | 30 | import os 31 | import time 32 | import numpy as np 33 | 34 | # Download and install the Python COCO tools from https://github.com/waleedka/coco 35 | # That's a fork from the original https://github.com/pdollar/coco with a bug 36 | # fix for Python 3. 
37 | # I submitted a pull request https://github.com/cocodataset/cocoapi/pull/50 38 | # If the PR is merged then use the original repo. 39 | # Note: Edit PythonAPI/Makefile and replace "python" with "python3". 40 | from pycocotools.coco import COCO 41 | from pycocotools.cocoeval import COCOeval 42 | from pycocotools import mask as maskUtils 43 | 44 | import zipfile 45 | import urllib.request 46 | import shutil 47 | 48 | from config import Config 49 | import utils 50 | import model as modellib 51 | 52 | import torch 53 | 54 | # Root directory of the project 55 | ROOT_DIR = os.getcwd() 56 | 57 | # Path to trained weights file 58 | COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.pth") 59 | 60 | # Directory to save logs and model checkpoints, if not provided 61 | # through the command line argument --logs 62 | DEFAULT_LOGS_DIR = os.path.join(ROOT_DIR, "logs") 63 | DEFAULT_DATASET_YEAR = "2014" 64 | 65 | ############################################################ 66 | # Configurations 67 | ############################################################ 68 | 69 | class CocoConfig(Config): 70 | """Configuration for training on MS COCO. 71 | Derives from the base Config class and overrides values specific 72 | to the COCO dataset. 73 | """ 74 | # Give the configuration a recognizable name 75 | NAME = "coco" 76 | 77 | # We use one GPU with 8GB memory, which can fit one image. 78 | # Adjust down if you use a smaller GPU. 79 | IMAGES_PER_GPU = 16 80 | 81 | # Uncomment to train on 8 GPUs (default is 1) 82 | # GPU_COUNT = 8 83 | 84 | # Number of classes (including background) 85 | NUM_CLASSES = 1 + 80 # COCO has 80 classes 86 | 87 | 88 | ############################################################ 89 | # Dataset 90 | ############################################################ 91 | 92 | class CocoDataset(utils.Dataset): 93 | def load_coco(self, dataset_dir, subset, year=DEFAULT_DATASET_YEAR, class_ids=None, 94 | class_map=None, return_coco=False, auto_download=False): 95 | """Load a subset of the COCO dataset. 96 | dataset_dir: The root directory of the COCO dataset. 97 | subset: What to load (train, val, minival, valminusminival) 98 | year: What dataset year to load (2014, 2017) as a string, not an integer 99 | class_ids: If provided, only loads images that have the given classes. 100 | class_map: TODO: Not implemented yet. Supports maping classes from 101 | different datasets to the same class ID. 102 | return_coco: If True, returns the COCO object. 103 | auto_download: Automatically download and unzip MS-COCO images and annotations 104 | """ 105 | 106 | if auto_download is True: 107 | self.auto_download(dataset_dir, subset, year) 108 | 109 | coco = COCO("{}/annotations/instances_{}{}.json".format(dataset_dir, subset, year)) 110 | if subset == "minival" or subset == "valminusminival": 111 | subset = "val" 112 | image_dir = "{}/{}{}".format(dataset_dir, subset, year) 113 | 114 | # Load all classes or a subset? 115 | if not class_ids: 116 | # All classes 117 | class_ids = sorted(coco.getCatIds()) 118 | 119 | # All images or a subset? 
120 | if class_ids: 121 | image_ids = [] 122 | for id in class_ids: 123 | image_ids.extend(list(coco.getImgIds(catIds=[id]))) 124 | # Remove duplicates 125 | image_ids = list(set(image_ids)) 126 | else: 127 | # All images 128 | image_ids = list(coco.imgs.keys()) 129 | 130 | # Add classes 131 | for i in class_ids: 132 | self.add_class("coco", i, coco.loadCats(i)[0]["name"]) 133 | 134 | # Add images 135 | for i in image_ids: 136 | self.add_image( 137 | "coco", image_id=i, 138 | path=os.path.join(image_dir, coco.imgs[i]['file_name']), 139 | width=coco.imgs[i]["width"], 140 | height=coco.imgs[i]["height"], 141 | annotations=coco.loadAnns(coco.getAnnIds( 142 | imgIds=[i], catIds=class_ids, iscrowd=None))) 143 | if return_coco: 144 | return coco 145 | 146 | def auto_download(self, dataDir, dataType, dataYear): 147 | """Download the COCO dataset/annotations if requested. 148 | dataDir: The root directory of the COCO dataset. 149 | dataType: What to load (train, val, minival, valminusminival) 150 | dataYear: What dataset year to load (2014, 2017) as a string, not an integer 151 | Note: 152 | For 2014, use "train", "val", "minival", or "valminusminival" 153 | For 2017, only "train" and "val" annotations are available 154 | """ 155 | 156 | # Setup paths and file names 157 | if dataType == "minival" or dataType == "valminusminival": 158 | imgDir = "{}/{}{}".format(dataDir, "val", dataYear) 159 | imgZipFile = "{}/{}{}.zip".format(dataDir, "val", dataYear) 160 | imgURL = "http://images.cocodataset.org/zips/{}{}.zip".format("val", dataYear) 161 | else: 162 | imgDir = "{}/{}{}".format(dataDir, dataType, dataYear) 163 | imgZipFile = "{}/{}{}.zip".format(dataDir, dataType, dataYear) 164 | imgURL = "http://images.cocodataset.org/zips/{}{}.zip".format(dataType, dataYear) 165 | # print("Image paths:"); print(imgDir); print(imgZipFile); print(imgURL) 166 | 167 | # Create main folder if it doesn't exist yet 168 | if not os.path.exists(dataDir): 169 | os.makedirs(dataDir) 170 | 171 | # Download images if not available locally 172 | if not os.path.exists(imgDir): 173 | os.makedirs(imgDir) 174 | print("Downloading images to " + imgZipFile + " ...") 175 | with urllib.request.urlopen(imgURL) as resp, open(imgZipFile, 'wb') as out: 176 | shutil.copyfileobj(resp, out) 177 | print("... done downloading.") 178 | print("Unzipping " + imgZipFile) 179 | with zipfile.ZipFile(imgZipFile, "r") as zip_ref: 180 | zip_ref.extractall(dataDir) 181 | print("... 
done unzipping") 182 | print("Will use images in " + imgDir) 183 | 184 | # Setup annotations data paths 185 | annDir = "{}/annotations".format(dataDir) 186 | if dataType == "minival": 187 | annZipFile = "{}/instances_minival2014.json.zip".format(dataDir) 188 | annFile = "{}/instances_minival2014.json".format(annDir) 189 | annURL = "https://dl.dropboxusercontent.com/s/o43o90bna78omob/instances_minival2014.json.zip?dl=0" 190 | unZipDir = annDir 191 | elif dataType == "valminusminival": 192 | annZipFile = "{}/instances_valminusminival2014.json.zip".format(dataDir) 193 | annFile = "{}/instances_valminusminival2014.json".format(annDir) 194 | annURL = "https://dl.dropboxusercontent.com/s/s3tw5zcg7395368/instances_valminusminival2014.json.zip?dl=0" 195 | unZipDir = annDir 196 | else: 197 | annZipFile = "{}/annotations_trainval{}.zip".format(dataDir, dataYear) 198 | annFile = "{}/instances_{}{}.json".format(annDir, dataType, dataYear) 199 | annURL = "http://images.cocodataset.org/annotations/annotations_trainval{}.zip".format(dataYear) 200 | unZipDir = dataDir 201 | # print("Annotations paths:"); print(annDir); print(annFile); print(annZipFile); print(annURL) 202 | 203 | # Download annotations if not available locally 204 | if not os.path.exists(annDir): 205 | os.makedirs(annDir) 206 | if not os.path.exists(annFile): 207 | if not os.path.exists(annZipFile): 208 | print("Downloading zipped annotations to " + annZipFile + " ...") 209 | with urllib.request.urlopen(annURL) as resp, open(annZipFile, 'wb') as out: 210 | shutil.copyfileobj(resp, out) 211 | print("... done downloading.") 212 | print("Unzipping " + annZipFile) 213 | with zipfile.ZipFile(annZipFile, "r") as zip_ref: 214 | zip_ref.extractall(unZipDir) 215 | print("... done unzipping") 216 | print("Will use annotations in " + annFile) 217 | 218 | def load_mask(self, image_id): 219 | """Load instance masks for the given image. 220 | 221 | Different datasets use different ways to store masks. This 222 | function converts the different mask format to one format 223 | in the form of a bitmap [height, width, instances]. 224 | 225 | Returns: 226 | masks: A bool array of shape [height, width, instance count] with 227 | one mask per instance. 228 | class_ids: a 1D array of class IDs of the instance masks. 229 | """ 230 | # If not a COCO image, delegate to parent class. 231 | image_info = self.image_info[image_id] 232 | if image_info["source"] != "coco": 233 | return super(CocoDataset, self).load_mask(image_id) 234 | 235 | instance_masks = [] 236 | class_ids = [] 237 | annotations = self.image_info[image_id]["annotations"] 238 | # Build mask of shape [height, width, instance_count] and list 239 | # of class IDs that correspond to each channel of the mask. 240 | for annotation in annotations: 241 | class_id = self.map_source_class_id( 242 | "coco.{}".format(annotation['category_id'])) 243 | if class_id: 244 | m = self.annToMask(annotation, image_info["height"], 245 | image_info["width"]) 246 | # Some objects are so small that they're less than 1 pixel area 247 | # and end up rounded out. Skip those objects. 248 | if m.max() < 1: 249 | continue 250 | # Is it a crowd? If so, use a negative class ID. 251 | if annotation['iscrowd']: 252 | # Use negative class ID for crowds 253 | class_id *= -1 254 | # For crowd masks, annToMask() sometimes returns a mask 255 | # smaller than the given dimensions. If so, resize it. 
256 | if m.shape[0] != image_info["height"] or m.shape[1] != image_info["width"]: 257 | m = np.ones([image_info["height"], image_info["width"]], dtype=bool) 258 | instance_masks.append(m) 259 | class_ids.append(class_id) 260 | 261 | # Pack instance masks into an array 262 | if class_ids: 263 | mask = np.stack(instance_masks, axis=2) 264 | class_ids = np.array(class_ids, dtype=np.int32) 265 | return mask, class_ids 266 | else: 267 | # Call super class to return an empty mask 268 | return super(CocoDataset, self).load_mask(image_id) 269 | 270 | def image_reference(self, image_id): 271 | """Return a link to the image in the COCO Website.""" 272 | info = self.image_info[image_id] 273 | if info["source"] == "coco": 274 | return "http://cocodataset.org/#explore?id={}".format(info["id"]) 275 | else: 276 | super(CocoDataset, self).image_reference(image_id) 277 | 278 | # The following two functions are from pycocotools with a few changes. 279 | 280 | def annToRLE(self, ann, height, width): 281 | """ 282 | Convert annotation which can be polygons, uncompressed RLE to RLE. 283 | :return: binary mask (numpy 2D array) 284 | """ 285 | segm = ann['segmentation'] 286 | if isinstance(segm, list): 287 | # polygon -- a single object might consist of multiple parts 288 | # we merge all parts into one mask rle code 289 | rles = maskUtils.frPyObjects(segm, height, width) 290 | rle = maskUtils.merge(rles) 291 | elif isinstance(segm['counts'], list): 292 | # uncompressed RLE 293 | rle = maskUtils.frPyObjects(segm, height, width) 294 | else: 295 | # rle 296 | rle = ann['segmentation'] 297 | return rle 298 | 299 | def annToMask(self, ann, height, width): 300 | """ 301 | Convert annotation which can be polygons, uncompressed RLE, or RLE to binary mask. 302 | :return: binary mask (numpy 2D array) 303 | """ 304 | rle = self.annToRLE(ann, height, width) 305 | m = maskUtils.decode(rle) 306 | return m 307 | 308 | 309 | ############################################################ 310 | # COCO Evaluation 311 | ############################################################ 312 | 313 | def build_coco_results(dataset, image_ids, rois, class_ids, scores, masks): 314 | """Arrange resutls to match COCO specs in http://cocodataset.org/#format 315 | """ 316 | # If no results, return an empty list 317 | if rois is None: 318 | return [] 319 | 320 | results = [] 321 | for image_id in image_ids: 322 | # Loop through detections 323 | for i in range(rois.shape[0]): 324 | class_id = class_ids[i] 325 | score = scores[i] 326 | bbox = np.around(rois[i], 1) 327 | mask = masks[:, :, i] 328 | 329 | result = { 330 | "image_id": image_id, 331 | "category_id": dataset.get_source_class_id(class_id, "coco"), 332 | "bbox": [bbox[1], bbox[0], bbox[3] - bbox[1], bbox[2] - bbox[0]], 333 | "score": score, 334 | "segmentation": maskUtils.encode(np.asfortranarray(mask)) 335 | } 336 | results.append(result) 337 | return results 338 | 339 | 340 | def evaluate_coco(model, dataset, coco, eval_type="bbox", limit=0, image_ids=None): 341 | """Runs official COCO evaluation. 342 | dataset: A Dataset object with valiadtion data 343 | eval_type: "bbox" or "segm" for bounding box or segmentation evaluation 344 | limit: if not 0, it's the number of images to use for evaluation 345 | """ 346 | # Pick COCO images from the dataset 347 | image_ids = image_ids or dataset.image_ids 348 | 349 | # Limit to a subset 350 | if limit: 351 | image_ids = image_ids[:limit] 352 | 353 | # Get corresponding COCO image IDs. 
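 # (dataset.image_ids are the dataset's own sequential indices; COCOeval expects the original COCO ids stored in image_info.)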
354 | coco_image_ids = [dataset.image_info[id]["id"] for id in image_ids] 355 | 356 | t_prediction = 0 357 | t_start = time.time() 358 | 359 | results = [] 360 | for i, image_id in enumerate(image_ids): 361 | # Load image 362 | image = dataset.load_image(image_id) 363 | 364 | # Run detection 365 | t = time.time() 366 | r = model.detect([image])[0] 367 | t_prediction += (time.time() - t) 368 | 369 | # Convert results to COCO format 370 | image_results = build_coco_results(dataset, coco_image_ids[i:i + 1], 371 | r["rois"], r["class_ids"], 372 | r["scores"], r["masks"]) 373 | results.extend(image_results) 374 | 375 | # Load results. This modifies results with additional attributes. 376 | coco_results = coco.loadRes(results) 377 | 378 | # Evaluate 379 | cocoEval = COCOeval(coco, coco_results, eval_type) 380 | cocoEval.params.imgIds = coco_image_ids 381 | cocoEval.evaluate() 382 | cocoEval.accumulate() 383 | cocoEval.summarize() 384 | 385 | print("Prediction time: {}. Average {}/image".format( 386 | t_prediction, t_prediction / len(image_ids))) 387 | print("Total time: ", time.time() - t_start) 388 | 389 | 390 | ############################################################ 391 | # Training 392 | ############################################################ 393 | 394 | 395 | if __name__ == '__main__': 396 | import argparse 397 | 398 | # Parse command line arguments 399 | parser = argparse.ArgumentParser( 400 | description='Train Mask R-CNN on MS COCO.') 401 | parser.add_argument("command", 402 | metavar="", 403 | help="'train' or 'evaluate' on MS COCO") 404 | parser.add_argument('--dataset', required=True, 405 | metavar="/path/to/coco/", 406 | help='Directory of the MS-COCO dataset') 407 | parser.add_argument('--year', required=False, 408 | default=DEFAULT_DATASET_YEAR, 409 | metavar="", 410 | help='Year of the MS-COCO dataset (2014 or 2017) (default=2014)') 411 | parser.add_argument('--model', required=False, 412 | metavar="/path/to/weights.pth", 413 | help="Path to weights .pth file or 'coco'") 414 | parser.add_argument('--logs', required=False, 415 | default=DEFAULT_LOGS_DIR, 416 | metavar="/path/to/logs/", 417 | help='Logs and checkpoints directory (default=logs/)') 418 | parser.add_argument('--limit', required=False, 419 | default=500, 420 | metavar="", 421 | help='Images to use for evaluation (default=500)') 422 | parser.add_argument('--download', required=False, 423 | default=False, 424 | metavar="", 425 | help='Automatically download and unzip MS-COCO files (default=False)', 426 | type=bool) 427 | args = parser.parse_args() 428 | print("Command: ", args.command) 429 | print("Model: ", args.model) 430 | print("Dataset: ", args.dataset) 431 | print("Year: ", args.year) 432 | print("Logs: ", args.logs) 433 | print("Auto Download: ", args.download) 434 | 435 | # Configurations 436 | if args.command == "train": 437 | config = CocoConfig() 438 | else: 439 | class InferenceConfig(CocoConfig): 440 | # Set batch size to 1 since we'll be running inference on 441 | # one image at a time. 
Batch size = GPU_COUNT * IMAGES_PER_GPU 442 | GPU_COUNT = 1 443 | IMAGES_PER_GPU = 1 444 | DETECTION_MIN_CONFIDENCE = 0 445 | config = InferenceConfig() 446 | config.display() 447 | 448 | # Create model 449 | if args.command == "train": 450 | model = modellib.MaskRCNN(config=config, 451 | model_dir=args.logs) 452 | else: 453 | model = modellib.MaskRCNN(config=config, 454 | model_dir=args.logs) 455 | if config.GPU_COUNT: 456 | model = model.cuda() 457 | 458 | # Select weights file to load 459 | if args.model: 460 | if args.model.lower() == "coco": 461 | model_path = COCO_MODEL_PATH 462 | elif args.model.lower() == "last": 463 | # Find last trained weights 464 | model_path = model.find_last()[1] 465 | elif args.model.lower() == "imagenet": 466 | # Start from ImageNet trained weights 467 | model_path = config.IMAGENET_MODEL_PATH 468 | else: 469 | model_path = args.model 470 | else: 471 | model_path = "" 472 | 473 | # Load weights 474 | print("Loading weights ", model_path) 475 | model.load_weights(model_path) 476 | 477 | # Train or evaluate 478 | if args.command == "train": 479 | # Training dataset. Use the training set and 35K from the 480 | # validation set, as as in the Mask RCNN paper. 481 | dataset_train = CocoDataset() 482 | dataset_train.load_coco(args.dataset, "train", year=args.year, auto_download=args.download) 483 | dataset_train.load_coco(args.dataset, "valminusminival", year=args.year, auto_download=args.download) 484 | dataset_train.prepare() 485 | 486 | # Validation dataset 487 | dataset_val = CocoDataset() 488 | dataset_val.load_coco(args.dataset, "minival", year=args.year, auto_download=args.download) 489 | dataset_val.prepare() 490 | 491 | # *** This training schedule is an example. Update to your needs *** 492 | 493 | # Training - Stage 1 494 | print("Training network heads") 495 | model.train_model(dataset_train, dataset_val, 496 | learning_rate=config.LEARNING_RATE, 497 | epochs=40, 498 | layers='heads') 499 | 500 | # Training - Stage 2 501 | # Finetune layers from ResNet stage 4 and up 502 | print("Fine tune Resnet stage 4 and up") 503 | model.train_model(dataset_train, dataset_val, 504 | learning_rate=config.LEARNING_RATE, 505 | epochs=120, 506 | layers='4+') 507 | 508 | # Training - Stage 3 509 | # Fine tune all layers 510 | print("Fine tune all layers") 511 | model.train_model(dataset_train, dataset_val, 512 | learning_rate=config.LEARNING_RATE / 10, 513 | epochs=160, 514 | layers='all') 515 | 516 | elif args.command == "evaluate": 517 | # Validation dataset 518 | dataset_val = CocoDataset() 519 | coco = dataset_val.load_coco(args.dataset, "minival", year=args.year, return_coco=True, auto_download=args.download) 520 | dataset_val.prepare() 521 | print("Running COCO evaluation on {} images.".format(args.limit)) 522 | evaluate_coco(model, dataset_val, coco, "bbox", limit=int(args.limit)) 523 | evaluate_coco(model, dataset_val, coco, "segm", limit=int(args.limit)) 524 | else: 525 | print("'{}' is not recognized. " 526 | "Use 'train' or 'evaluate'".format(args.command)) 527 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mask R-CNN 3 | Base Configurations class. 4 | 5 | Copyright (c) 2017 Matterport, Inc. 
6 | Licensed under the MIT License (see LICENSE for details) 7 | Written by Waleed Abdulla 8 | """ 9 | 10 | import math 11 | import numpy as np 12 | import os 13 | 14 | 15 | # Base Configuration Class 16 | # Don't use this class directly. Instead, sub-class it and override 17 | # the configurations you need to change. 18 | 19 | class Config(object): 20 | """Base configuration class. For custom configurations, create a 21 | sub-class that inherits from this one and override properties 22 | that need to be changed. 23 | """ 24 | # Name the configurations. For example, 'COCO', 'Experiment 3', ...etc. 25 | # Useful if your code needs to do things differently depending on which 26 | # experiment is running. 27 | NAME = None # Override in sub-classes 28 | 29 | # Path to pretrained imagenet model 30 | IMAGENET_MODEL_PATH = os.path.join(os.getcwd(), "resnet50_imagenet.pth") 31 | 32 | # NUMBER OF GPUs to use. For CPU use 0 33 | GPU_COUNT = 1 34 | 35 | # Number of images to train with on each GPU. A 12GB GPU can typically 36 | # handle 2 images of 1024x1024px. 37 | # Adjust based on your GPU memory and image sizes. Use the highest 38 | # number that your GPU can handle for best performance. 39 | IMAGES_PER_GPU = 1 40 | 41 | # Number of training steps per epoch 42 | # This doesn't need to match the size of the training set. Tensorboard 43 | # updates are saved at the end of each epoch, so setting this to a 44 | # smaller number means getting more frequent TensorBoard updates. 45 | # Validation stats are also calculated at each epoch end and they 46 | # might take a while, so don't set this too small to avoid spending 47 | # a lot of time on validation stats. 48 | STEPS_PER_EPOCH = 1000 49 | 50 | # Number of validation steps to run at the end of every training epoch. 51 | # A bigger number improves accuracy of validation stats, but slows 52 | # down the training. 53 | VALIDATION_STEPS = 50 54 | 55 | # The strides of each layer of the FPN Pyramid. These values 56 | # are based on a Resnet101 backbone. 57 | BACKBONE_STRIDES = [4, 8, 16, 32, 64] 58 | 59 | # Number of classification classes (including background) 60 | NUM_CLASSES = 1 # Override in sub-classes 61 | 62 | # Length of square anchor side in pixels 63 | RPN_ANCHOR_SCALES = (32, 64, 128, 256, 512) 64 | 65 | # Ratios of anchors at each cell (width/height) 66 | # A value of 1 represents a square anchor, and 0.5 is a wide anchor 67 | RPN_ANCHOR_RATIOS = [0.5, 1, 2] 68 | 69 | # Anchor stride 70 | # If 1 then anchors are created for each cell in the backbone feature map. 71 | # If 2, then anchors are created for every other cell, and so on. 72 | RPN_ANCHOR_STRIDE = 1 73 | 74 | # Non-max suppression threshold to filter RPN proposals. 75 | # You can reduce this during training to generate more propsals. 76 | RPN_NMS_THRESHOLD = 0.7 77 | 78 | # How many anchors per image to use for RPN training 79 | RPN_TRAIN_ANCHORS_PER_IMAGE = 256 80 | 81 | # ROIs kept after non-maximum supression (training and inference) 82 | POST_NMS_ROIS_TRAINING = 2000 83 | POST_NMS_ROIS_INFERENCE = 1000 84 | 85 | # If enabled, resizes instance masks to a smaller size to reduce 86 | # memory load. Recommended when using high-resolution images. 87 | USE_MINI_MASK = True 88 | MINI_MASK_SHAPE = (56, 56) # (height, width) of the mini-mask 89 | 90 | # Input image resing 91 | # Images are resized such that the smallest side is >= IMAGE_MIN_DIM and 92 | # the longest side is <= IMAGE_MAX_DIM. In case both conditions can't 93 | # be satisfied together the IMAGE_MAX_DIM is enforced. 
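 # Example: a 640x480 image is first scaled by 800/480 ~= 1.67, but since that would make the longer side 1067 > 1024, the scale is capped at 1024/640 = 1.6.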
94 | IMAGE_MIN_DIM = 800 95 | IMAGE_MAX_DIM = 1024 96 | # If True, pad images with zeros such that they're (max_dim by max_dim) 97 | IMAGE_PADDING = True # currently, the False option is not supported 98 | 99 | # Image mean (RGB) 100 | MEAN_PIXEL = np.array([123.7, 116.8, 103.9]) 101 | 102 | # Number of ROIs per image to feed to classifier/mask heads 103 | # The Mask RCNN paper uses 512 but often the RPN doesn't generate 104 | # enough positive proposals to fill this and keep a positive:negative 105 | # ratio of 1:3. You can increase the number of proposals by adjusting 106 | # the RPN NMS threshold. 107 | TRAIN_ROIS_PER_IMAGE = 200 108 | 109 | # Percent of positive ROIs used to train classifier/mask heads 110 | ROI_POSITIVE_RATIO = 0.33 111 | 112 | # Pooled ROIs 113 | POOL_SIZE = 7 114 | MASK_POOL_SIZE = 14 115 | MASK_SHAPE = [28, 28] 116 | 117 | # Maximum number of ground truth instances to use in one image 118 | MAX_GT_INSTANCES = 100 119 | 120 | # Bounding box refinement standard deviation for RPN and final detections. 121 | RPN_BBOX_STD_DEV = np.array([0.1, 0.1, 0.2, 0.2]) 122 | BBOX_STD_DEV = np.array([0.1, 0.1, 0.2, 0.2]) 123 | 124 | # Max number of final detections 125 | DETECTION_MAX_INSTANCES = 100 126 | 127 | # Minimum probability value to accept a detected instance 128 | # ROIs below this threshold are skipped 129 | DETECTION_MIN_CONFIDENCE = 0.7 130 | 131 | # Non-maximum suppression threshold for detection 132 | DETECTION_NMS_THRESHOLD = 0.3 133 | 134 | # Learning rate and momentum 135 | # The Mask RCNN paper uses lr=0.02, but on TensorFlow it causes 136 | # weights to explode. Likely due to differences in optimzer 137 | # implementation. 138 | LEARNING_RATE = 0.001 139 | LEARNING_MOMENTUM = 0.9 140 | 141 | # Weight decay regularization 142 | WEIGHT_DECAY = 0.0001 143 | 144 | # Use RPN ROIs or externally generated ROIs for training 145 | # Keep this True for most situations. Set to False if you want to train 146 | # the head branches on ROI generated by code rather than the ROIs from 147 | # the RPN. For example, to debug the classifier head without having to 148 | # train the RPN. 
149 | USE_RPN_ROIS = True 150 | 151 | def __init__(self): 152 | """Set values of computed attributes.""" 153 | # Effective batch size 154 | if self.GPU_COUNT > 0: 155 | self.BATCH_SIZE = self.IMAGES_PER_GPU * self.GPU_COUNT 156 | else: 157 | self.BATCH_SIZE = self.IMAGES_PER_GPU 158 | 159 | # Adjust step size based on batch size 160 | self.STEPS_PER_EPOCH = self.BATCH_SIZE * self.STEPS_PER_EPOCH 161 | 162 | # Input image size 163 | self.IMAGE_SHAPE = np.array( 164 | [self.IMAGE_MAX_DIM, self.IMAGE_MAX_DIM, 3]) 165 | 166 | # Compute backbone size from input image size 167 | self.BACKBONE_SHAPES = np.array( 168 | [[int(math.ceil(self.IMAGE_SHAPE[0] / stride)), 169 | int(math.ceil(self.IMAGE_SHAPE[1] / stride))] 170 | for stride in self.BACKBONE_STRIDES]) 171 | 172 | def display(self): 173 | """Display Configuration values.""" 174 | print("\nConfigurations:") 175 | for a in dir(self): 176 | if not a.startswith("__") and not callable(getattr(self, a)): 177 | print("{:30} {}".format(a, getattr(self, a))) 178 | print("\n") 179 | -------------------------------------------------------------------------------- /convert_from_keras.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import h5py 4 | import torch 5 | 6 | alphabet = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'] 7 | 8 | parser = argparse.ArgumentParser(description='Convert keras-mask-rcnn model to pytorch-mask-rcnn model') 9 | parser.add_argument('--keras_model', 10 | help='the path of the keras model', 11 | default=None, type=str) 12 | parser.add_argument('--pytorch_model', 13 | help='the path of the pytorch model', 14 | default=None, type=str) 15 | 16 | args = parser.parse_args() 17 | 18 | f = h5py.File(args.keras_model, mode='r') 19 | state_dict = collections.OrderedDict(); 20 | for group_name, group in f.items(): 21 | if len(group.items())!=0: 22 | for layer_name, layer in group.items(): 23 | for weight_name, weight in layer.items(): 24 | state_dict[layer_name+'.'+weight_name] = weight.value 25 | 26 | replace_dict = collections.OrderedDict([ 27 | ('beta:0', 'bias'), \ 28 | ('gamma:0', 'weight'), \ 29 | ('moving_mean:0', 'running_mean'),\ 30 | ('moving_variance:0', 'running_var'),\ 31 | ('bias:0', 'bias'), \ 32 | ('kernel:0', 'weight'), \ 33 | ('mrcnn_mask_', 'mask.'), \ 34 | ('mrcnn_mask', 'mask.conv5'), \ 35 | ('mrcnn_class_', 'classifier.'), \ 36 | ('logits', 'linear_class'), \ 37 | ('mrcnn_bbox_fc', 'classifier.linear_bbox'), \ 38 | ('rpn_', 'rpn.'), \ 39 | ('class_raw', 'conv_class'), \ 40 | ('bbox_pred', 'conv_bbox'), \ 41 | ('bn_conv1', 'fpn.C1.1'), \ 42 | ('bn2a_branch1', 'fpn.C2.0.downsample.1'), \ 43 | ('res2a_branch1', 'fpn.C2.0.downsample.0'), \ 44 | ('bn3a_branch1', 'fpn.C3.0.downsample.1'), \ 45 | ('res3a_branch1', 'fpn.C3.0.downsample.0'), \ 46 | ('bn4a_branch1', 'fpn.C4.0.downsample.1'), \ 47 | ('res4a_branch1', 'fpn.C4.0.downsample.0'), \ 48 | ('bn5a_branch1', 'fpn.C5.0.downsample.1'), \ 49 | ('res5a_branch1', 'fpn.C5.0.downsample.0'), \ 50 | ('fpn_c2p2', 'fpn.P2_conv1'), \ 51 | ('fpn_c3p3', 'fpn.P3_conv1'), \ 52 | ('fpn_c4p4', 'fpn.P4_conv1'), \ 53 | ('fpn_c5p5', 'fpn.P5_conv1'), \ 54 | ('fpn_p2', 'fpn.P2_conv2.1'), \ 55 | ('fpn_p3', 'fpn.P3_conv2.1'), \ 56 | ('fpn_p4', 'fpn.P4_conv2.1'), \ 57 | ('fpn_p5', 'fpn.P5_conv2.1'), \ 58 | ]) 59 | 60 | replace_exact_dict = collections.OrderedDict([ 61 | ('conv1.bias', 'fpn.C1.0.bias'), \ 62 | ('conv1.weight', 
'fpn.C1.0.weight'), \ 63 | ]) 64 | 65 | for block in range(3): 66 | for branch in range(3): 67 | replace_dict['bn2' + alphabet[block] + '_branch2' + alphabet[branch]] = 'fpn.C2.' + str(block) + '.bn' + str( 68 | branch+1) 69 | replace_dict['res2'+alphabet[block]+'_branch2'+alphabet[branch]] = 'fpn.C2.'+str(block)+'.conv'+str(branch+1) 70 | 71 | for block in range(4): 72 | for branch in range(3): 73 | replace_dict['bn3' + alphabet[block] + '_branch2' + alphabet[branch]] = 'fpn.C3.' + str(block) + '.bn' + str( 74 | branch+1) 75 | replace_dict['res3'+alphabet[block]+'_branch2'+alphabet[branch]] = 'fpn.C3.'+str(block)+'.conv'+str(branch+1) 76 | 77 | for block in range(23): 78 | for branch in range(3): 79 | replace_dict['bn4' + alphabet[block] + '_branch2' + alphabet[branch]] = 'fpn.C4.' + str(block) + '.bn' + str( 80 | branch+1) 81 | replace_dict['res4'+alphabet[block]+'_branch2'+alphabet[branch]] = 'fpn.C4.'+str(block)+'.conv'+str(branch+1) 82 | 83 | for block in range(3): 84 | for branch in range(3): 85 | replace_dict['bn5' + alphabet[block] + '_branch2' + alphabet[branch]] = 'fpn.C5.' + str(block) + '.bn' + str(branch+1) 86 | replace_dict['res5'+ alphabet[block] + '_branch2' + alphabet[branch]] = 'fpn.C5.' + str(block) + '.conv' + str(branch+1) 87 | 88 | 89 | for orig, repl in replace_dict.items(): 90 | for key in list(state_dict.keys()): 91 | if orig in key: 92 | state_dict[key.replace(orig, repl)] = state_dict[key] 93 | del state_dict[key] 94 | 95 | for orig, repl in replace_exact_dict.items(): 96 | for key in list(state_dict.keys()): 97 | if orig == key: 98 | state_dict[repl] = state_dict[key] 99 | del state_dict[key] 100 | 101 | for weight_name in list(state_dict.keys()): 102 | if state_dict[weight_name].ndim == 4: 103 | state_dict[weight_name] = state_dict[weight_name].transpose((3, 2, 0, 1)).copy(order='C') 104 | if state_dict[weight_name].ndim == 2: 105 | state_dict[weight_name] = state_dict[weight_name].transpose((1, 0)).copy(order='C') 106 | 107 | for weight_name in list(state_dict.keys()): 108 | state_dict[weight_name] = torch.from_numpy(state_dict[weight_name]) 109 | 110 | torch.save(state_dict, args.pytorch_model) -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import random 4 | import math 5 | import numpy as np 6 | import skimage.io 7 | import matplotlib 8 | import matplotlib.pyplot as plt 9 | 10 | import coco 11 | import utils 12 | import model as modellib 13 | import visualize 14 | 15 | import torch 16 | 17 | 18 | # Root directory of the project 19 | ROOT_DIR = os.getcwd() 20 | 21 | # Directory to save logs and trained model 22 | MODEL_DIR = os.path.join(ROOT_DIR, "logs") 23 | 24 | # Path to trained weights file 25 | # Download this file and place in the root of your 26 | # project (See README file for details) 27 | COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.pth") 28 | 29 | # Directory of images to run detection on 30 | IMAGE_DIR = os.path.join(ROOT_DIR, "images") 31 | 32 | class InferenceConfig(coco.CocoConfig): 33 | # Set batch size to 1 since we'll be running inference on 34 | # one image at a time. Batch size = GPU_COUNT * IMAGES_PER_GPU 35 | # GPU_COUNT = 0 for CPU 36 | GPU_COUNT = 1 37 | IMAGES_PER_GPU = 1 38 | 39 | config = InferenceConfig() 40 | config.display() 41 | 42 | # Create model object. 
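 # (MaskRCNN builds the ResNet101/FPN backbone, the RPN and the classifier/mask heads from `config`; the pretrained COCO weights are loaded further below via load_state_dict.)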
43 | model = modellib.MaskRCNN(model_dir=MODEL_DIR, config=config) 44 | if config.GPU_COUNT: 45 | model = model.cuda() 46 | 47 | # Load weights trained on MS-COCO 48 | model.load_state_dict(torch.load(COCO_MODEL_PATH)) 49 | 50 | # COCO Class names 51 | # Index of the class in the list is its ID. For example, to get ID of 52 | # the teddy bear class, use: class_names.index('teddy bear') 53 | class_names = ['BG', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 54 | 'bus', 'train', 'truck', 'boat', 'traffic light', 55 | 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 56 | 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 57 | 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 58 | 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 59 | 'kite', 'baseball bat', 'baseball glove', 'skateboard', 60 | 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 61 | 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 62 | 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 63 | 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 64 | 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 65 | 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 66 | 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 67 | 'teddy bear', 'hair drier', 'toothbrush'] 68 | 69 | # Load a random image from the images folder 70 | file_names = next(os.walk(IMAGE_DIR))[2] 71 | image = skimage.io.imread(os.path.join(IMAGE_DIR, random.choice(file_names))) 72 | 73 | # Run detection 74 | results = model.detect([image]) 75 | 76 | # Visualize results 77 | r = results[0] 78 | visualize.display_instances(image, r['rois'], r['masks'], r['class_ids'], 79 | class_names, r['scores']) 80 | plt.show() -------------------------------------------------------------------------------- /images/1045023827_4ec3e8ba5c_z.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/images/1045023827_4ec3e8ba5c_z.jpg -------------------------------------------------------------------------------- /images/12283150_12d37e6389_z.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/images/12283150_12d37e6389_z.jpg -------------------------------------------------------------------------------- /images/2383514521_1fc8d7b0de_z.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/images/2383514521_1fc8d7b0de_z.jpg -------------------------------------------------------------------------------- /images/2502287818_41e4b0c4fb_z.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/images/2502287818_41e4b0c4fb_z.jpg -------------------------------------------------------------------------------- /images/2516944023_d00345997d_z.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/images/2516944023_d00345997d_z.jpg 
-------------------------------------------------------------------------------- /images/25691390_f9944f61b5_z.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/images/25691390_f9944f61b5_z.jpg -------------------------------------------------------------------------------- /images/262985539_1709e54576_z.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/images/262985539_1709e54576_z.jpg -------------------------------------------------------------------------------- /images/3132016470_c27baa00e8_z.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/images/3132016470_c27baa00e8_z.jpg -------------------------------------------------------------------------------- /images/3627527276_6fe8cd9bfe_z.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/images/3627527276_6fe8cd9bfe_z.jpg -------------------------------------------------------------------------------- /images/3651581213_f81963d1dd_z.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/images/3651581213_f81963d1dd_z.jpg -------------------------------------------------------------------------------- /images/3800883468_12af3c0b50_z.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/images/3800883468_12af3c0b50_z.jpg -------------------------------------------------------------------------------- /images/3862500489_6fd195d183_z.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/images/3862500489_6fd195d183_z.jpg -------------------------------------------------------------------------------- /images/3878153025_8fde829928_z.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/images/3878153025_8fde829928_z.jpg -------------------------------------------------------------------------------- /images/4410436637_7b0ca36ee7_z.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/images/4410436637_7b0ca36ee7_z.jpg -------------------------------------------------------------------------------- /images/4782628554_668bc31826_z.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/images/4782628554_668bc31826_z.jpg -------------------------------------------------------------------------------- 
/images/5951960966_d4e1cda5d0_z.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/images/5951960966_d4e1cda5d0_z.jpg -------------------------------------------------------------------------------- /images/6584515005_fce9cec486_z.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/images/6584515005_fce9cec486_z.jpg -------------------------------------------------------------------------------- /images/6821351586_59aa0dc110_z.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/images/6821351586_59aa0dc110_z.jpg -------------------------------------------------------------------------------- /images/7581246086_cf7bbb7255_z.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/images/7581246086_cf7bbb7255_z.jpg -------------------------------------------------------------------------------- /images/7933423348_c30bd9bd4e_z.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/images/7933423348_c30bd9bd4e_z.jpg -------------------------------------------------------------------------------- /images/8053677163_d4c8f416be_z.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/images/8053677163_d4c8f416be_z.jpg -------------------------------------------------------------------------------- /images/8239308689_efa6c11b08_z.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/images/8239308689_efa6c11b08_z.jpg -------------------------------------------------------------------------------- /images/8433365521_9252889f9a_z.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/images/8433365521_9252889f9a_z.jpg -------------------------------------------------------------------------------- /images/8512296263_5fc5458e20_z.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/images/8512296263_5fc5458e20_z.jpg -------------------------------------------------------------------------------- /images/8699757338_c3941051b6_z.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/images/8699757338_c3941051b6_z.jpg -------------------------------------------------------------------------------- /images/8734543718_37f6b8bd45_z.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/images/8734543718_37f6b8bd45_z.jpg -------------------------------------------------------------------------------- /images/8829708882_48f263491e_z.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/images/8829708882_48f263491e_z.jpg -------------------------------------------------------------------------------- /images/9118579087_f9ffa19e63_z.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/images/9118579087_f9ffa19e63_z.jpg -------------------------------------------------------------------------------- /images/9247489789_132c0d534a_z.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/images/9247489789_132c0d534a_z.jpg -------------------------------------------------------------------------------- /nms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/nms/__init__.py -------------------------------------------------------------------------------- /nms/build.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.ffi import create_extension 4 | 5 | 6 | sources = ['src/nms.c'] 7 | headers = ['src/nms.h'] 8 | defines = [] 9 | with_cuda = False 10 | 11 | if torch.cuda.is_available(): 12 | print('Including CUDA code.') 13 | sources += ['src/nms_cuda.c'] 14 | headers += ['src/nms_cuda.h'] 15 | defines += [('WITH_CUDA', None)] 16 | with_cuda = True 17 | 18 | this_file = os.path.dirname(os.path.realpath(__file__)) 19 | print(this_file) 20 | extra_objects = ['src/cuda/nms_kernel.cu.o'] 21 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] 22 | 23 | ffi = create_extension( 24 | '_ext.nms', 25 | headers=headers, 26 | sources=sources, 27 | define_macros=defines, 28 | relative_to=__file__, 29 | with_cuda=with_cuda, 30 | extra_objects=extra_objects 31 | ) 32 | 33 | if __name__ == '__main__': 34 | ffi.build() 35 | -------------------------------------------------------------------------------- /nms/nms_wrapper.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fast R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ross Girshick 6 | # -------------------------------------------------------- 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | from nms.pth_nms import pth_nms 12 | 13 | 14 | def nms(dets, thresh): 15 | """Dispatch to either CPU or GPU NMS implementations. 
16 | Accept dets as tensor""" 17 | return pth_nms(dets, thresh) 18 | -------------------------------------------------------------------------------- /nms/pth_nms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ._ext import nms 3 | import numpy as np 4 | 5 | def pth_nms(dets, thresh): 6 | """ 7 | dets has to be a tensor 8 | """ 9 | if not dets.is_cuda: 10 | x1 = dets[:, 1] 11 | y1 = dets[:, 0] 12 | x2 = dets[:, 3] 13 | y2 = dets[:, 2] 14 | scores = dets[:, 4] 15 | 16 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 17 | order = scores.sort(0, descending=True)[1] 18 | # order = torch.from_numpy(np.ascontiguousarray(scores.numpy().argsort()[::-1])).long() 19 | 20 | keep = torch.LongTensor(dets.size(0)) 21 | num_out = torch.LongTensor(1) 22 | nms.cpu_nms(keep, num_out, dets, order, areas, thresh) 23 | 24 | return keep[:num_out[0]] 25 | else: 26 | x1 = dets[:, 1] 27 | y1 = dets[:, 0] 28 | x2 = dets[:, 3] 29 | y2 = dets[:, 2] 30 | scores = dets[:, 4] 31 | 32 | dets_temp = torch.FloatTensor(dets.size()).cuda() 33 | dets_temp[:, 0] = dets[:, 1] 34 | dets_temp[:, 1] = dets[:, 0] 35 | dets_temp[:, 2] = dets[:, 3] 36 | dets_temp[:, 3] = dets[:, 2] 37 | dets_temp[:, 4] = dets[:, 4] 38 | 39 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 40 | order = scores.sort(0, descending=True)[1] 41 | # order = torch.from_numpy(np.ascontiguousarray(scores.cpu().numpy().argsort()[::-1])).long().cuda() 42 | 43 | dets = dets[order].contiguous() 44 | 45 | keep = torch.LongTensor(dets.size(0)) 46 | num_out = torch.LongTensor(1) 47 | # keep = torch.cuda.LongTensor(dets.size(0)) 48 | # num_out = torch.cuda.LongTensor(1) 49 | nms.gpu_nms(keep, num_out, dets_temp, thresh) 50 | 51 | return order[keep[:num_out[0]].cuda()].contiguous() 52 | # return order[keep[:num_out[0]]].contiguous() 53 | 54 | -------------------------------------------------------------------------------- /nms/src/cuda/nms_kernel.cu: -------------------------------------------------------------------------------- 1 | // ------------------------------------------------------------------ 2 | // Faster R-CNN 3 | // Copyright (c) 2015 Microsoft 4 | // Licensed under The MIT License [see fast-rcnn/LICENSE for details] 5 | // Written by Shaoqing Ren 6 | // ------------------------------------------------------------------ 7 | #ifdef __cplusplus 8 | extern "C" { 9 | #endif 10 | 11 | #include 12 | #include 13 | #include 14 | #include "nms_kernel.h" 15 | 16 | __device__ inline float devIoU(float const * const a, float const * const b) { 17 | float left = fmaxf(a[0], b[0]), right = fminf(a[2], b[2]); 18 | float top = fmaxf(a[1], b[1]), bottom = fminf(a[3], b[3]); 19 | float width = fmaxf(right - left + 1, 0.f), height = fmaxf(bottom - top + 1, 0.f); 20 | float interS = width * height; 21 | float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1); 22 | float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1); 23 | return interS / (Sa + Sb - interS); 24 | } 25 | 26 | __global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh, 27 | const float *dev_boxes, unsigned long long *dev_mask) { 28 | const int row_start = blockIdx.y; 29 | const int col_start = blockIdx.x; 30 | 31 | // if (row_start > col_start) return; 32 | 33 | const int row_size = 34 | fminf(n_boxes - row_start * threadsPerBlock, threadsPerBlock); 35 | const int col_size = 36 | fminf(n_boxes - col_start * threadsPerBlock, threadsPerBlock); 37 | 38 | __shared__ float block_boxes[threadsPerBlock * 5]; 39 | if (threadIdx.x < col_size) { 40 | 
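    // Stage one candidate box (4 corner coordinates + score) per thread into shared memory;
    // each row box handled below is then compared against this tile of up to threadsPerBlock boxes.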
block_boxes[threadIdx.x * 5 + 0] = 41 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0]; 42 | block_boxes[threadIdx.x * 5 + 1] = 43 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1]; 44 | block_boxes[threadIdx.x * 5 + 2] = 45 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2]; 46 | block_boxes[threadIdx.x * 5 + 3] = 47 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3]; 48 | block_boxes[threadIdx.x * 5 + 4] = 49 | dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4]; 50 | } 51 | __syncthreads(); 52 | 53 | if (threadIdx.x < row_size) { 54 | const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x; 55 | const float *cur_box = dev_boxes + cur_box_idx * 5; 56 | int i = 0; 57 | unsigned long long t = 0; 58 | int start = 0; 59 | if (row_start == col_start) { 60 | start = threadIdx.x + 1; 61 | } 62 | for (i = start; i < col_size; i++) { 63 | if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) { 64 | t |= 1ULL << i; 65 | } 66 | } 67 | const int col_blocks = DIVUP(n_boxes, threadsPerBlock); 68 | dev_mask[cur_box_idx * col_blocks + col_start] = t; 69 | } 70 | } 71 | 72 | 73 | void _nms(int boxes_num, float * boxes_dev, 74 | unsigned long long * mask_dev, float nms_overlap_thresh) { 75 | 76 | dim3 blocks(DIVUP(boxes_num, threadsPerBlock), 77 | DIVUP(boxes_num, threadsPerBlock)); 78 | dim3 threads(threadsPerBlock); 79 | nms_kernel<<>>(boxes_num, 80 | nms_overlap_thresh, 81 | boxes_dev, 82 | mask_dev); 83 | } 84 | 85 | #ifdef __cplusplus 86 | } 87 | #endif 88 | -------------------------------------------------------------------------------- /nms/src/cuda/nms_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _NMS_KERNEL 2 | #define _NMS_KERNEL 3 | 4 | #ifdef __cplusplus 5 | extern "C" { 6 | #endif 7 | 8 | #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) 9 | int const threadsPerBlock = sizeof(unsigned long long) * 8; 10 | 11 | void _nms(int boxes_num, float * boxes_dev, 12 | unsigned long long * mask_dev, float nms_overlap_thresh); 13 | 14 | #ifdef __cplusplus 15 | } 16 | #endif 17 | 18 | #endif 19 | 20 | -------------------------------------------------------------------------------- /nms/src/nms.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | int cpu_nms(THLongTensor * keep_out, THLongTensor * num_out, THFloatTensor * boxes, THLongTensor * order, THFloatTensor * areas, float nms_overlap_thresh) { 5 | // boxes has to be sorted 6 | THArgCheck(THLongTensor_isContiguous(keep_out), 0, "keep_out must be contiguous"); 7 | THArgCheck(THLongTensor_isContiguous(boxes), 2, "boxes must be contiguous"); 8 | THArgCheck(THLongTensor_isContiguous(order), 3, "order must be contiguous"); 9 | THArgCheck(THLongTensor_isContiguous(areas), 4, "areas must be contiguous"); 10 | // Number of ROIs 11 | long boxes_num = THFloatTensor_size(boxes, 0); 12 | long boxes_dim = THFloatTensor_size(boxes, 1); 13 | 14 | long * keep_out_flat = THLongTensor_data(keep_out); 15 | float * boxes_flat = THFloatTensor_data(boxes); 16 | long * order_flat = THLongTensor_data(order); 17 | float * areas_flat = THFloatTensor_data(areas); 18 | 19 | THByteTensor* suppressed = THByteTensor_newWithSize1d(boxes_num); 20 | THByteTensor_fill(suppressed, 0); 21 | unsigned char * suppressed_flat = THByteTensor_data(suppressed); 22 | 23 | // nominal indices 24 | int i, j; 25 | // sorted indices 26 | int _i, _j; 27 | // temp variables for box i's (the box 
currently under consideration) 28 | float ix1, iy1, ix2, iy2, iarea; 29 | // variables for computing overlap with box j (lower scoring box) 30 | float xx1, yy1, xx2, yy2; 31 | float w, h; 32 | float inter, ovr; 33 | 34 | long num_to_keep = 0; 35 | for (_i=0; _i < boxes_num; ++_i) { 36 | i = order_flat[_i]; 37 | if (suppressed_flat[i] == 1) { 38 | continue; 39 | } 40 | keep_out_flat[num_to_keep++] = i; 41 | ix1 = boxes_flat[i * boxes_dim]; 42 | iy1 = boxes_flat[i * boxes_dim + 1]; 43 | ix2 = boxes_flat[i * boxes_dim + 2]; 44 | iy2 = boxes_flat[i * boxes_dim + 3]; 45 | iarea = areas_flat[i]; 46 | for (_j = _i + 1; _j < boxes_num; ++_j) { 47 | j = order_flat[_j]; 48 | if (suppressed_flat[j] == 1) { 49 | continue; 50 | } 51 | xx1 = fmaxf(ix1, boxes_flat[j * boxes_dim]); 52 | yy1 = fmaxf(iy1, boxes_flat[j * boxes_dim + 1]); 53 | xx2 = fminf(ix2, boxes_flat[j * boxes_dim + 2]); 54 | yy2 = fminf(iy2, boxes_flat[j * boxes_dim + 3]); 55 | w = fmaxf(0.0, xx2 - xx1 + 1); 56 | h = fmaxf(0.0, yy2 - yy1 + 1); 57 | inter = w * h; 58 | ovr = inter / (iarea + areas_flat[j] - inter); 59 | if (ovr >= nms_overlap_thresh) { 60 | suppressed_flat[j] = 1; 61 | } 62 | } 63 | } 64 | 65 | long *num_out_flat = THLongTensor_data(num_out); 66 | *num_out_flat = num_to_keep; 67 | THByteTensor_free(suppressed); 68 | return 1; 69 | } -------------------------------------------------------------------------------- /nms/src/nms.h: -------------------------------------------------------------------------------- 1 | int cpu_nms(THLongTensor * keep_out, THLongTensor * num_out, THFloatTensor * boxes, THLongTensor * order, THFloatTensor * areas, float nms_overlap_thresh); -------------------------------------------------------------------------------- /nms/src/nms_cuda.c: -------------------------------------------------------------------------------- 1 | // ------------------------------------------------------------------ 2 | // Faster R-CNN 3 | // Copyright (c) 2015 Microsoft 4 | // Licensed under The MIT License [see fast-rcnn/LICENSE for details] 5 | // Written by Shaoqing Ren 6 | // ------------------------------------------------------------------ 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include "cuda/nms_kernel.h" 13 | 14 | 15 | extern THCState *state; 16 | 17 | int gpu_nms(THLongTensor * keep, THLongTensor* num_out, THCudaTensor * boxes, float nms_overlap_thresh) { 18 | // boxes has to be sorted 19 | THArgCheck(THLongTensor_isContiguous(keep), 0, "boxes must be contiguous"); 20 | THArgCheck(THCudaTensor_isContiguous(state, boxes), 2, "boxes must be contiguous"); 21 | // Number of ROIs 22 | int boxes_num = THCudaTensor_size(state, boxes, 0); 23 | int boxes_dim = THCudaTensor_size(state, boxes, 1); 24 | 25 | float* boxes_flat = THCudaTensor_data(state, boxes); 26 | 27 | const int col_blocks = DIVUP(boxes_num, threadsPerBlock); 28 | THCudaLongTensor * mask = THCudaLongTensor_newWithSize2d(state, boxes_num, col_blocks); 29 | unsigned long long* mask_flat = THCudaLongTensor_data(state, mask); 30 | 31 | _nms(boxes_num, boxes_flat, mask_flat, nms_overlap_thresh); 32 | 33 | THLongTensor * mask_cpu = THLongTensor_newWithSize2d(boxes_num, col_blocks); 34 | THLongTensor_copyCuda(state, mask_cpu, mask); 35 | THCudaLongTensor_free(state, mask); 36 | 37 | unsigned long long * mask_cpu_flat = THLongTensor_data(mask_cpu); 38 | 39 | THLongTensor * remv_cpu = THLongTensor_newWithSize1d(col_blocks); 40 | unsigned long long* remv_cpu_flat = THLongTensor_data(remv_cpu); 41 | THLongTensor_fill(remv_cpu, 0); 42 | 43 | 
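  // Reduce the per-block suppression masks on the CPU: boxes are visited in (pre-sorted) score
  // order, a box is kept only if no previously kept box has set its bit in remv_cpu, and each
  // kept box then ORs its own mask row into remv_cpu.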
long * keep_flat = THLongTensor_data(keep); 44 | long num_to_keep = 0; 45 | 46 | int i, j; 47 | for (i = 0; i < boxes_num; i++) { 48 | int nblock = i / threadsPerBlock; 49 | int inblock = i % threadsPerBlock; 50 | 51 | if (!(remv_cpu_flat[nblock] & (1ULL << inblock))) { 52 | keep_flat[num_to_keep++] = i; 53 | unsigned long long *p = &mask_cpu_flat[0] + i * col_blocks; 54 | for (j = nblock; j < col_blocks; j++) { 55 | remv_cpu_flat[j] |= p[j]; 56 | } 57 | } 58 | } 59 | 60 | long * num_out_flat = THLongTensor_data(num_out); 61 | * num_out_flat = num_to_keep; 62 | 63 | THLongTensor_free(mask_cpu); 64 | THLongTensor_free(remv_cpu); 65 | 66 | return 1; 67 | } 68 | -------------------------------------------------------------------------------- /nms/src/nms_cuda.h: -------------------------------------------------------------------------------- 1 | int gpu_nms(THLongTensor * keep_out, THLongTensor* num_out, THCudaTensor * boxes, float nms_overlap_thresh); -------------------------------------------------------------------------------- /roialign/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/roialign/__init__.py -------------------------------------------------------------------------------- /roialign/roi_align/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/pytorch-mask-rcnn/809abba590db89779ac02c42286135f18ea08b53/roialign/roi_align/__init__.py -------------------------------------------------------------------------------- /roialign/roi_align/build.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.ffi import create_extension 4 | 5 | 6 | sources = ['src/crop_and_resize.c'] 7 | headers = ['src/crop_and_resize.h'] 8 | defines = [] 9 | with_cuda = False 10 | 11 | extra_objects = [] 12 | if torch.cuda.is_available(): 13 | print('Including CUDA code.') 14 | sources += ['src/crop_and_resize_gpu.c'] 15 | headers += ['src/crop_and_resize_gpu.h'] 16 | defines += [('WITH_CUDA', None)] 17 | extra_objects += ['src/cuda/crop_and_resize_kernel.cu.o'] 18 | with_cuda = True 19 | 20 | extra_compile_args = ['-fopenmp', '-std=c99'] 21 | 22 | this_file = os.path.dirname(os.path.realpath(__file__)) 23 | print(this_file) 24 | sources = [os.path.join(this_file, fname) for fname in sources] 25 | headers = [os.path.join(this_file, fname) for fname in headers] 26 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] 27 | 28 | ffi = create_extension( 29 | '_ext.crop_and_resize', 30 | headers=headers, 31 | sources=sources, 32 | define_macros=defines, 33 | relative_to=__file__, 34 | with_cuda=with_cuda, 35 | extra_objects=extra_objects, 36 | extra_compile_args=extra_compile_args 37 | ) 38 | 39 | if __name__ == '__main__': 40 | ffi.build() 41 | -------------------------------------------------------------------------------- /roialign/roi_align/crop_and_resize.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Function 6 | 7 | from ._ext import crop_and_resize as _backend 8 | 9 | 10 | class CropAndResizeFunction(Function): 11 | 12 | def __init__(self, crop_height, crop_width, extrapolation_value=0): 13 | 
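        # Store the static crop parameters; forward() dispatches to the C or CUDA extension
        # (built by build.py) depending on whether the input image lives on the GPU.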
self.crop_height = crop_height 14 | self.crop_width = crop_width 15 | self.extrapolation_value = extrapolation_value 16 | 17 | def forward(self, image, boxes, box_ind): 18 | crops = torch.zeros_like(image) 19 | 20 | if image.is_cuda: 21 | _backend.crop_and_resize_gpu_forward( 22 | image, boxes, box_ind, 23 | self.extrapolation_value, self.crop_height, self.crop_width, crops) 24 | else: 25 | _backend.crop_and_resize_forward( 26 | image, boxes, box_ind, 27 | self.extrapolation_value, self.crop_height, self.crop_width, crops) 28 | 29 | # save for backward 30 | self.im_size = image.size() 31 | self.save_for_backward(boxes, box_ind) 32 | 33 | return crops 34 | 35 | def backward(self, grad_outputs): 36 | boxes, box_ind = self.saved_tensors 37 | 38 | grad_outputs = grad_outputs.contiguous() 39 | grad_image = torch.zeros_like(grad_outputs).resize_(*self.im_size) 40 | 41 | if grad_outputs.is_cuda: 42 | _backend.crop_and_resize_gpu_backward( 43 | grad_outputs, boxes, box_ind, grad_image 44 | ) 45 | else: 46 | _backend.crop_and_resize_backward( 47 | grad_outputs, boxes, box_ind, grad_image 48 | ) 49 | 50 | return grad_image, None, None 51 | 52 | 53 | class CropAndResize(nn.Module): 54 | """ 55 | Crop and resize ported from tensorflow 56 | See more details on https://www.tensorflow.org/api_docs/python/tf/image/crop_and_resize 57 | """ 58 | 59 | def __init__(self, crop_height, crop_width, extrapolation_value=0): 60 | super(CropAndResize, self).__init__() 61 | 62 | self.crop_height = crop_height 63 | self.crop_width = crop_width 64 | self.extrapolation_value = extrapolation_value 65 | 66 | def forward(self, image, boxes, box_ind): 67 | return CropAndResizeFunction(self.crop_height, self.crop_width, self.extrapolation_value)(image, boxes, box_ind) 68 | -------------------------------------------------------------------------------- /roialign/roi_align/roi_align.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from .crop_and_resize import CropAndResizeFunction, CropAndResize 5 | 6 | 7 | class RoIAlign(nn.Module): 8 | 9 | def __init__(self, crop_height, crop_width, extrapolation_value=0, transform_fpcoor=True): 10 | super(RoIAlign, self).__init__() 11 | 12 | self.crop_height = crop_height 13 | self.crop_width = crop_width 14 | self.extrapolation_value = extrapolation_value 15 | self.transform_fpcoor = transform_fpcoor 16 | 17 | def forward(self, featuremap, boxes, box_ind): 18 | """ 19 | RoIAlign based on crop_and_resize. 
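        With transform_fpcoor=True the boxes are first shifted by the half-pixel ("fpcoor")
        offsets used in the tensorpack reference implementation before being normalized for
        crop_and_resize.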
20 | See more details on https://github.com/ppwwyyxx/tensorpack/blob/6d5ba6a970710eaaa14b89d24aace179eb8ee1af/examples/FasterRCNN/model.py#L301 21 | :param featuremap: NxCxHxW 22 | :param boxes: Mx4 float box with (x1, y1, x2, y2) **without normalization** 23 | :param box_ind: M 24 | :return: MxCxoHxoW 25 | """ 26 | x1, y1, x2, y2 = torch.split(boxes, 1, dim=1) 27 | image_height, image_width = featuremap.size()[2:4] 28 | 29 | if self.transform_fpcoor: 30 | spacing_w = (x2 - x1) / float(self.crop_width) 31 | spacing_h = (y2 - y1) / float(self.crop_height) 32 | 33 | nx0 = (x1 + spacing_w / 2 - 0.5) / float(image_width - 1) 34 | ny0 = (y1 + spacing_h / 2 - 0.5) / float(image_height - 1) 35 | nw = spacing_w * float(self.crop_width - 1) / float(image_width - 1) 36 | nh = spacing_h * float(self.crop_height - 1) / float(image_height - 1) 37 | 38 | boxes = torch.cat((ny0, nx0, ny0 + nh, nx0 + nw), 1) 39 | else: 40 | x1 = x1 / float(image_width - 1) 41 | x2 = x2 / float(image_width - 1) 42 | y1 = y1 / float(image_height - 1) 43 | y2 = y2 / float(image_height - 1) 44 | boxes = torch.cat((y1, x1, y2, x2), 1) 45 | 46 | boxes = boxes.detach().contiguous() 47 | box_ind = box_ind.detach() 48 | return CropAndResizeFunction(self.crop_height, self.crop_width, self.extrapolation_value)(featuremap, boxes, box_ind) 49 | -------------------------------------------------------------------------------- /roialign/roi_align/src/crop_and_resize.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | 6 | void CropAndResizePerBox( 7 | const float * image_data, 8 | const int batch_size, 9 | const int depth, 10 | const int image_height, 11 | const int image_width, 12 | 13 | const float * boxes_data, 14 | const int * box_index_data, 15 | const int start_box, 16 | const int limit_box, 17 | 18 | float * corps_data, 19 | const int crop_height, 20 | const int crop_width, 21 | const float extrapolation_value 22 | ) { 23 | const int image_channel_elements = image_height * image_width; 24 | const int image_elements = depth * image_channel_elements; 25 | 26 | const int channel_elements = crop_height * crop_width; 27 | const int crop_elements = depth * channel_elements; 28 | 29 | int b; 30 | #pragma omp parallel for 31 | for (b = start_box; b < limit_box; ++b) { 32 | const float * box = boxes_data + b * 4; 33 | const float y1 = box[0]; 34 | const float x1 = box[1]; 35 | const float y2 = box[2]; 36 | const float x2 = box[3]; 37 | 38 | const int b_in = box_index_data[b]; 39 | if (b_in < 0 || b_in >= batch_size) { 40 | printf("Error: batch_index %d out of range [0, %d)\n", b_in, batch_size); 41 | exit(-1); 42 | } 43 | 44 | const float height_scale = 45 | (crop_height > 1) 46 | ? (y2 - y1) * (image_height - 1) / (crop_height - 1) 47 | : 0; 48 | const float width_scale = 49 | (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1) 50 | : 0; 51 | 52 | for (int y = 0; y < crop_height; ++y) 53 | { 54 | const float in_y = (crop_height > 1) 55 | ? 
y1 * (image_height - 1) + y * height_scale 56 | : 0.5 * (y1 + y2) * (image_height - 1); 57 | 58 | if (in_y < 0 || in_y > image_height - 1) 59 | { 60 | for (int x = 0; x < crop_width; ++x) 61 | { 62 | for (int d = 0; d < depth; ++d) 63 | { 64 | // crops(b, y, x, d) = extrapolation_value; 65 | corps_data[crop_elements * b + channel_elements * d + y * crop_width + x] = extrapolation_value; 66 | } 67 | } 68 | continue; 69 | } 70 | 71 | const int top_y_index = floorf(in_y); 72 | const int bottom_y_index = ceilf(in_y); 73 | const float y_lerp = in_y - top_y_index; 74 | 75 | for (int x = 0; x < crop_width; ++x) 76 | { 77 | const float in_x = (crop_width > 1) 78 | ? x1 * (image_width - 1) + x * width_scale 79 | : 0.5 * (x1 + x2) * (image_width - 1); 80 | if (in_x < 0 || in_x > image_width - 1) 81 | { 82 | for (int d = 0; d < depth; ++d) 83 | { 84 | corps_data[crop_elements * b + channel_elements * d + y * crop_width + x] = extrapolation_value; 85 | } 86 | continue; 87 | } 88 | 89 | const int left_x_index = floorf(in_x); 90 | const int right_x_index = ceilf(in_x); 91 | const float x_lerp = in_x - left_x_index; 92 | 93 | for (int d = 0; d < depth; ++d) 94 | { 95 | const float *pimage = image_data + b_in * image_elements + d * image_channel_elements; 96 | 97 | const float top_left = pimage[top_y_index * image_width + left_x_index]; 98 | const float top_right = pimage[top_y_index * image_width + right_x_index]; 99 | const float bottom_left = pimage[bottom_y_index * image_width + left_x_index]; 100 | const float bottom_right = pimage[bottom_y_index * image_width + right_x_index]; 101 | 102 | const float top = top_left + (top_right - top_left) * x_lerp; 103 | const float bottom = 104 | bottom_left + (bottom_right - bottom_left) * x_lerp; 105 | 106 | corps_data[crop_elements * b + channel_elements * d + y * crop_width + x] = top + (bottom - top) * y_lerp; 107 | } 108 | } // end for x 109 | } // end for y 110 | } // end for b 111 | 112 | } 113 | 114 | 115 | void crop_and_resize_forward( 116 | THFloatTensor * image, 117 | THFloatTensor * boxes, // [y1, x1, y2, x2] 118 | THIntTensor * box_index, // range in [0, batch_size) 119 | const float extrapolation_value, 120 | const int crop_height, 121 | const int crop_width, 122 | THFloatTensor * crops 123 | ) { 124 | const int batch_size = image->size[0]; 125 | const int depth = image->size[1]; 126 | const int image_height = image->size[2]; 127 | const int image_width = image->size[3]; 128 | 129 | const int num_boxes = boxes->size[0]; 130 | 131 | // init output space 132 | THFloatTensor_resize4d(crops, num_boxes, depth, crop_height, crop_width); 133 | THFloatTensor_zero(crops); 134 | 135 | // crop_and_resize for each box 136 | CropAndResizePerBox( 137 | THFloatTensor_data(image), 138 | batch_size, 139 | depth, 140 | image_height, 141 | image_width, 142 | 143 | THFloatTensor_data(boxes), 144 | THIntTensor_data(box_index), 145 | 0, 146 | num_boxes, 147 | 148 | THFloatTensor_data(crops), 149 | crop_height, 150 | crop_width, 151 | extrapolation_value 152 | ); 153 | 154 | } 155 | 156 | 157 | void crop_and_resize_backward( 158 | THFloatTensor * grads, 159 | THFloatTensor * boxes, // [y1, x1, y2, x2] 160 | THIntTensor * box_index, // range in [0, batch_size) 161 | THFloatTensor * grads_image // resize to [bsize, c, hc, wc] 162 | ) 163 | { 164 | // shape 165 | const int batch_size = grads_image->size[0]; 166 | const int depth = grads_image->size[1]; 167 | const int image_height = grads_image->size[2]; 168 | const int image_width = grads_image->size[3]; 169 | 170 | const 
int num_boxes = grads->size[0]; 171 | const int crop_height = grads->size[2]; 172 | const int crop_width = grads->size[3]; 173 | 174 | // n_elements 175 | const int image_channel_elements = image_height * image_width; 176 | const int image_elements = depth * image_channel_elements; 177 | 178 | const int channel_elements = crop_height * crop_width; 179 | const int crop_elements = depth * channel_elements; 180 | 181 | // init output space 182 | THFloatTensor_zero(grads_image); 183 | 184 | // data pointer 185 | const float * grads_data = THFloatTensor_data(grads); 186 | const float * boxes_data = THFloatTensor_data(boxes); 187 | const int * box_index_data = THIntTensor_data(box_index); 188 | float * grads_image_data = THFloatTensor_data(grads_image); 189 | 190 | for (int b = 0; b < num_boxes; ++b) { 191 | const float * box = boxes_data + b * 4; 192 | const float y1 = box[0]; 193 | const float x1 = box[1]; 194 | const float y2 = box[2]; 195 | const float x2 = box[3]; 196 | 197 | const int b_in = box_index_data[b]; 198 | if (b_in < 0 || b_in >= batch_size) { 199 | printf("Error: batch_index %d out of range [0, %d)\n", b_in, batch_size); 200 | exit(-1); 201 | } 202 | 203 | const float height_scale = 204 | (crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1) 205 | : 0; 206 | const float width_scale = 207 | (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1) 208 | : 0; 209 | 210 | for (int y = 0; y < crop_height; ++y) 211 | { 212 | const float in_y = (crop_height > 1) 213 | ? y1 * (image_height - 1) + y * height_scale 214 | : 0.5 * (y1 + y2) * (image_height - 1); 215 | if (in_y < 0 || in_y > image_height - 1) 216 | { 217 | continue; 218 | } 219 | const int top_y_index = floorf(in_y); 220 | const int bottom_y_index = ceilf(in_y); 221 | const float y_lerp = in_y - top_y_index; 222 | 223 | for (int x = 0; x < crop_width; ++x) 224 | { 225 | const float in_x = (crop_width > 1) 226 | ? 
x1 * (image_width - 1) + x * width_scale 227 | : 0.5 * (x1 + x2) * (image_width - 1); 228 | if (in_x < 0 || in_x > image_width - 1) 229 | { 230 | continue; 231 | } 232 | const int left_x_index = floorf(in_x); 233 | const int right_x_index = ceilf(in_x); 234 | const float x_lerp = in_x - left_x_index; 235 | 236 | for (int d = 0; d < depth; ++d) 237 | { 238 | float *pimage = grads_image_data + b_in * image_elements + d * image_channel_elements; 239 | const float grad_val = grads_data[crop_elements * b + channel_elements * d + y * crop_width + x]; 240 | 241 | const float dtop = (1 - y_lerp) * grad_val; 242 | pimage[top_y_index * image_width + left_x_index] += (1 - x_lerp) * dtop; 243 | pimage[top_y_index * image_width + right_x_index] += x_lerp * dtop; 244 | 245 | const float dbottom = y_lerp * grad_val; 246 | pimage[bottom_y_index * image_width + left_x_index] += (1 - x_lerp) * dbottom; 247 | pimage[bottom_y_index * image_width + right_x_index] += x_lerp * dbottom; 248 | } // end d 249 | } // end x 250 | } // end y 251 | } // end b 252 | } -------------------------------------------------------------------------------- /roialign/roi_align/src/crop_and_resize.h: -------------------------------------------------------------------------------- 1 | void crop_and_resize_forward( 2 | THFloatTensor * image, 3 | THFloatTensor * boxes, // [y1, x1, y2, x2] 4 | THIntTensor * box_index, // range in [0, batch_size) 5 | const float extrapolation_value, 6 | const int crop_height, 7 | const int crop_width, 8 | THFloatTensor * crops 9 | ); 10 | 11 | void crop_and_resize_backward( 12 | THFloatTensor * grads, 13 | THFloatTensor * boxes, // [y1, x1, y2, x2] 14 | THIntTensor * box_index, // range in [0, batch_size) 15 | THFloatTensor * grads_image // resize to [bsize, c, hc, wc] 16 | ); -------------------------------------------------------------------------------- /roialign/roi_align/src/crop_and_resize_gpu.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include "cuda/crop_and_resize_kernel.h" 3 | 4 | extern THCState *state; 5 | 6 | 7 | void crop_and_resize_gpu_forward( 8 | THCudaTensor * image, 9 | THCudaTensor * boxes, // [y1, x1, y2, x2] 10 | THCudaIntTensor * box_index, // range in [0, batch_size) 11 | const float extrapolation_value, 12 | const int crop_height, 13 | const int crop_width, 14 | THCudaTensor * crops 15 | ) { 16 | const int batch_size = THCudaTensor_size(state, image, 0); 17 | const int depth = THCudaTensor_size(state, image, 1); 18 | const int image_height = THCudaTensor_size(state, image, 2); 19 | const int image_width = THCudaTensor_size(state, image, 3); 20 | 21 | const int num_boxes = THCudaTensor_size(state, boxes, 0); 22 | 23 | // init output space 24 | THCudaTensor_resize4d(state, crops, num_boxes, depth, crop_height, crop_width); 25 | THCudaTensor_zero(state, crops); 26 | 27 | cudaStream_t stream = THCState_getCurrentStream(state); 28 | CropAndResizeLaucher( 29 | THCudaTensor_data(state, image), 30 | THCudaTensor_data(state, boxes), 31 | THCudaIntTensor_data(state, box_index), 32 | num_boxes, batch_size, image_height, image_width, 33 | crop_height, crop_width, depth, extrapolation_value, 34 | THCudaTensor_data(state, crops), 35 | stream 36 | ); 37 | } 38 | 39 | 40 | void crop_and_resize_gpu_backward( 41 | THCudaTensor * grads, 42 | THCudaTensor * boxes, // [y1, x1, y2, x2] 43 | THCudaIntTensor * box_index, // range in [0, batch_size) 44 | THCudaTensor * grads_image // resize to [bsize, c, hc, wc] 45 | ) { 46 | // shape 47 | const 
int batch_size = THCudaTensor_size(state, grads_image, 0); 48 | const int depth = THCudaTensor_size(state, grads_image, 1); 49 | const int image_height = THCudaTensor_size(state, grads_image, 2); 50 | const int image_width = THCudaTensor_size(state, grads_image, 3); 51 | 52 | const int num_boxes = THCudaTensor_size(state, grads, 0); 53 | const int crop_height = THCudaTensor_size(state, grads, 2); 54 | const int crop_width = THCudaTensor_size(state, grads, 3); 55 | 56 | // init output space 57 | THCudaTensor_zero(state, grads_image); 58 | 59 | cudaStream_t stream = THCState_getCurrentStream(state); 60 | CropAndResizeBackpropImageLaucher( 61 | THCudaTensor_data(state, grads), 62 | THCudaTensor_data(state, boxes), 63 | THCudaIntTensor_data(state, box_index), 64 | num_boxes, batch_size, image_height, image_width, 65 | crop_height, crop_width, depth, 66 | THCudaTensor_data(state, grads_image), 67 | stream 68 | ); 69 | } -------------------------------------------------------------------------------- /roialign/roi_align/src/crop_and_resize_gpu.h: -------------------------------------------------------------------------------- 1 | void crop_and_resize_gpu_forward( 2 | THCudaTensor * image, 3 | THCudaTensor * boxes, // [y1, x1, y2, x2] 4 | THCudaIntTensor * box_index, // range in [0, batch_size) 5 | const float extrapolation_value, 6 | const int crop_height, 7 | const int crop_width, 8 | THCudaTensor * crops 9 | ); 10 | 11 | void crop_and_resize_gpu_backward( 12 | THCudaTensor * grads, 13 | THCudaTensor * boxes, // [y1, x1, y2, x2] 14 | THCudaIntTensor * box_index, // range in [0, batch_size) 15 | THCudaTensor * grads_image // resize to [bsize, c, hc, wc] 16 | ); -------------------------------------------------------------------------------- /roialign/roi_align/src/cuda/crop_and_resize_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "crop_and_resize_kernel.h" 4 | 5 | #define CUDA_1D_KERNEL_LOOP(i, n) \ 6 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ 7 | i += blockDim.x * gridDim.x) 8 | 9 | 10 | __global__ 11 | void CropAndResizeKernel( 12 | const int nthreads, const float *image_ptr, const float *boxes_ptr, 13 | const int *box_ind_ptr, int num_boxes, int batch, int image_height, 14 | int image_width, int crop_height, int crop_width, int depth, 15 | float extrapolation_value, float *crops_ptr) 16 | { 17 | CUDA_1D_KERNEL_LOOP(out_idx, nthreads) 18 | { 19 | // NHWC: out_idx = d + depth * (w + crop_width * (h + crop_height * b)) 20 | // NCHW: out_idx = w + crop_width * (h + crop_height * (d + depth * b)) 21 | int idx = out_idx; 22 | const int x = idx % crop_width; 23 | idx /= crop_width; 24 | const int y = idx % crop_height; 25 | idx /= crop_height; 26 | const int d = idx % depth; 27 | const int b = idx / depth; 28 | 29 | const float y1 = boxes_ptr[b * 4]; 30 | const float x1 = boxes_ptr[b * 4 + 1]; 31 | const float y2 = boxes_ptr[b * 4 + 2]; 32 | const float x2 = boxes_ptr[b * 4 + 3]; 33 | 34 | const int b_in = box_ind_ptr[b]; 35 | if (b_in < 0 || b_in >= batch) 36 | { 37 | continue; 38 | } 39 | 40 | const float height_scale = 41 | (crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1) 42 | : 0; 43 | const float width_scale = 44 | (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1) : 0; 45 | 46 | const float in_y = (crop_height > 1) 47 | ? 
y1 * (image_height - 1) + y * height_scale 48 | : 0.5 * (y1 + y2) * (image_height - 1); 49 | if (in_y < 0 || in_y > image_height - 1) 50 | { 51 | crops_ptr[out_idx] = extrapolation_value; 52 | continue; 53 | } 54 | 55 | const float in_x = (crop_width > 1) 56 | ? x1 * (image_width - 1) + x * width_scale 57 | : 0.5 * (x1 + x2) * (image_width - 1); 58 | if (in_x < 0 || in_x > image_width - 1) 59 | { 60 | crops_ptr[out_idx] = extrapolation_value; 61 | continue; 62 | } 63 | 64 | const int top_y_index = floorf(in_y); 65 | const int bottom_y_index = ceilf(in_y); 66 | const float y_lerp = in_y - top_y_index; 67 | 68 | const int left_x_index = floorf(in_x); 69 | const int right_x_index = ceilf(in_x); 70 | const float x_lerp = in_x - left_x_index; 71 | 72 | const float *pimage = image_ptr + (b_in * depth + d) * image_height * image_width; 73 | const float top_left = pimage[top_y_index * image_width + left_x_index]; 74 | const float top_right = pimage[top_y_index * image_width + right_x_index]; 75 | const float bottom_left = pimage[bottom_y_index * image_width + left_x_index]; 76 | const float bottom_right = pimage[bottom_y_index * image_width + right_x_index]; 77 | 78 | const float top = top_left + (top_right - top_left) * x_lerp; 79 | const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp; 80 | crops_ptr[out_idx] = top + (bottom - top) * y_lerp; 81 | } 82 | } 83 | 84 | __global__ 85 | void CropAndResizeBackpropImageKernel( 86 | const int nthreads, const float *grads_ptr, const float *boxes_ptr, 87 | const int *box_ind_ptr, int num_boxes, int batch, int image_height, 88 | int image_width, int crop_height, int crop_width, int depth, 89 | float *grads_image_ptr) 90 | { 91 | CUDA_1D_KERNEL_LOOP(out_idx, nthreads) 92 | { 93 | // NHWC: out_idx = d + depth * (w + crop_width * (h + crop_height * b)) 94 | // NCHW: out_idx = w + crop_width * (h + crop_height * (d + depth * b)) 95 | int idx = out_idx; 96 | const int x = idx % crop_width; 97 | idx /= crop_width; 98 | const int y = idx % crop_height; 99 | idx /= crop_height; 100 | const int d = idx % depth; 101 | const int b = idx / depth; 102 | 103 | const float y1 = boxes_ptr[b * 4]; 104 | const float x1 = boxes_ptr[b * 4 + 1]; 105 | const float y2 = boxes_ptr[b * 4 + 2]; 106 | const float x2 = boxes_ptr[b * 4 + 3]; 107 | 108 | const int b_in = box_ind_ptr[b]; 109 | if (b_in < 0 || b_in >= batch) 110 | { 111 | continue; 112 | } 113 | 114 | const float height_scale = 115 | (crop_height > 1) ? (y2 - y1) * (image_height - 1) / (crop_height - 1) 116 | : 0; 117 | const float width_scale = 118 | (crop_width > 1) ? (x2 - x1) * (image_width - 1) / (crop_width - 1) : 0; 119 | 120 | const float in_y = (crop_height > 1) 121 | ? y1 * (image_height - 1) + y * height_scale 122 | : 0.5 * (y1 + y2) * (image_height - 1); 123 | if (in_y < 0 || in_y > image_height - 1) 124 | { 125 | continue; 126 | } 127 | 128 | const float in_x = (crop_width > 1) 129 | ? 
x1 * (image_width - 1) + x * width_scale 130 | : 0.5 * (x1 + x2) * (image_width - 1); 131 | if (in_x < 0 || in_x > image_width - 1) 132 | { 133 | continue; 134 | } 135 | 136 | const int top_y_index = floorf(in_y); 137 | const int bottom_y_index = ceilf(in_y); 138 | const float y_lerp = in_y - top_y_index; 139 | 140 | const int left_x_index = floorf(in_x); 141 | const int right_x_index = ceilf(in_x); 142 | const float x_lerp = in_x - left_x_index; 143 | 144 | float *pimage = grads_image_ptr + (b_in * depth + d) * image_height * image_width; 145 | const float dtop = (1 - y_lerp) * grads_ptr[out_idx]; 146 | atomicAdd( 147 | pimage + top_y_index * image_width + left_x_index, 148 | (1 - x_lerp) * dtop 149 | ); 150 | atomicAdd( 151 | pimage + top_y_index * image_width + right_x_index, 152 | x_lerp * dtop 153 | ); 154 | 155 | const float dbottom = y_lerp * grads_ptr[out_idx]; 156 | atomicAdd( 157 | pimage + bottom_y_index * image_width + left_x_index, 158 | (1 - x_lerp) * dbottom 159 | ); 160 | atomicAdd( 161 | pimage + bottom_y_index * image_width + right_x_index, 162 | x_lerp * dbottom 163 | ); 164 | } 165 | } 166 | 167 | 168 | void CropAndResizeLaucher( 169 | const float *image_ptr, const float *boxes_ptr, 170 | const int *box_ind_ptr, int num_boxes, int batch, int image_height, 171 | int image_width, int crop_height, int crop_width, int depth, 172 | float extrapolation_value, float *crops_ptr, cudaStream_t stream) 173 | { 174 | const int total_count = num_boxes * crop_height * crop_width * depth; 175 | const int thread_per_block = 1024; 176 | const int block_count = (total_count + thread_per_block - 1) / thread_per_block; 177 | cudaError_t err; 178 | 179 | if (total_count > 0) 180 | { 181 | CropAndResizeKernel<<>>( 182 | total_count, image_ptr, boxes_ptr, 183 | box_ind_ptr, num_boxes, batch, image_height, image_width, 184 | crop_height, crop_width, depth, extrapolation_value, crops_ptr); 185 | 186 | err = cudaGetLastError(); 187 | if (cudaSuccess != err) 188 | { 189 | fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err)); 190 | exit(-1); 191 | } 192 | } 193 | } 194 | 195 | 196 | void CropAndResizeBackpropImageLaucher( 197 | const float *grads_ptr, const float *boxes_ptr, 198 | const int *box_ind_ptr, int num_boxes, int batch, int image_height, 199 | int image_width, int crop_height, int crop_width, int depth, 200 | float *grads_image_ptr, cudaStream_t stream) 201 | { 202 | const int total_count = num_boxes * crop_height * crop_width * depth; 203 | const int thread_per_block = 1024; 204 | const int block_count = (total_count + thread_per_block - 1) / thread_per_block; 205 | cudaError_t err; 206 | 207 | if (total_count > 0) 208 | { 209 | CropAndResizeBackpropImageKernel<<>>( 210 | total_count, grads_ptr, boxes_ptr, 211 | box_ind_ptr, num_boxes, batch, image_height, image_width, 212 | crop_height, crop_width, depth, grads_image_ptr); 213 | 214 | err = cudaGetLastError(); 215 | if (cudaSuccess != err) 216 | { 217 | fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err)); 218 | exit(-1); 219 | } 220 | } 221 | } -------------------------------------------------------------------------------- /roialign/roi_align/src/cuda/crop_and_resize_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _CropAndResize_Kernel 2 | #define _CropAndResize_Kernel 3 | 4 | #ifdef __cplusplus 5 | extern "C" { 6 | #endif 7 | 8 | void CropAndResizeLaucher( 9 | const float *image_ptr, const float *boxes_ptr, 10 | const int *box_ind_ptr, 
int num_boxes, int batch, int image_height, 11 | int image_width, int crop_height, int crop_width, int depth, 12 | float extrapolation_value, float *crops_ptr, cudaStream_t stream); 13 | 14 | void CropAndResizeBackpropImageLaucher( 15 | const float *grads_ptr, const float *boxes_ptr, 16 | const int *box_ind_ptr, int num_boxes, int batch, int image_height, 17 | int image_width, int crop_height, int crop_width, int depth, 18 | float *grads_image_ptr, cudaStream_t stream); 19 | 20 | #ifdef __cplusplus 21 | } 22 | #endif 23 | 24 | #endif -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mask R-CNN 3 | Common utility functions and classes. 4 | 5 | Copyright (c) 2017 Matterport, Inc. 6 | Licensed under the MIT License (see LICENSE for details) 7 | Written by Waleed Abdulla 8 | """ 9 | 10 | import sys 11 | import os 12 | import math 13 | import random 14 | import numpy as np 15 | import scipy.misc 16 | import scipy.ndimage 17 | import skimage.color 18 | import skimage.io 19 | import torch 20 | 21 | ############################################################ 22 | # Bounding Boxes 23 | ############################################################ 24 | 25 | def extract_bboxes(mask): 26 | """Compute bounding boxes from masks. 27 | mask: [height, width, num_instances]. Mask pixels are either 1 or 0. 28 | 29 | Returns: bbox array [num_instances, (y1, x1, y2, x2)]. 30 | """ 31 | boxes = np.zeros([mask.shape[-1], 4], dtype=np.int32) 32 | for i in range(mask.shape[-1]): 33 | m = mask[:, :, i] 34 | # Bounding box. 35 | horizontal_indicies = np.where(np.any(m, axis=0))[0] 36 | vertical_indicies = np.where(np.any(m, axis=1))[0] 37 | if horizontal_indicies.shape[0]: 38 | x1, x2 = horizontal_indicies[[0, -1]] 39 | y1, y2 = vertical_indicies[[0, -1]] 40 | # x2 and y2 should not be part of the box. Increment by 1. 41 | x2 += 1 42 | y2 += 1 43 | else: 44 | # No mask for this instance. Might happen due to 45 | # resizing or cropping. Set bbox to zeros 46 | x1, x2, y1, y2 = 0, 0, 0, 0 47 | boxes[i] = np.array([y1, x1, y2, x2]) 48 | return boxes.astype(np.int32) 49 | 50 | 51 | def compute_iou(box, boxes, box_area, boxes_area): 52 | """Calculates IoU of the given box with the array of the given boxes. 53 | box: 1D vector [y1, x1, y2, x2] 54 | boxes: [boxes_count, (y1, x1, y2, x2)] 55 | box_area: float. the area of 'box' 56 | boxes_area: array of length boxes_count. 57 | 58 | Note: the areas are passed in rather than calculated here for 59 | efficency. Calculate once in the caller to avoid duplicate work. 60 | """ 61 | # Calculate intersection areas 62 | y1 = np.maximum(box[0], boxes[:, 0]) 63 | y2 = np.minimum(box[2], boxes[:, 2]) 64 | x1 = np.maximum(box[1], boxes[:, 1]) 65 | x2 = np.minimum(box[3], boxes[:, 3]) 66 | intersection = np.maximum(x2 - x1, 0) * np.maximum(y2 - y1, 0) 67 | union = box_area + boxes_area[:] - intersection[:] 68 | iou = intersection / union 69 | return iou 70 | 71 | 72 | def compute_overlaps(boxes1, boxes2): 73 | """Computes IoU overlaps between two sets of boxes. 74 | boxes1, boxes2: [N, (y1, x1, y2, x2)]. 75 | 76 | For better performance, pass the largest set first and the smaller second. 
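    Returns an IoU matrix of shape [boxes1 count, boxes2 count], e.g.
        overlaps = compute_overlaps(anchors, gt_boxes)  # -> [num_anchors, num_gt]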
77 | """ 78 | # Areas of anchors and GT boxes 79 | area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1]) 80 | area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1]) 81 | 82 | # Compute overlaps to generate matrix [boxes1 count, boxes2 count] 83 | # Each cell contains the IoU value. 84 | overlaps = np.zeros((boxes1.shape[0], boxes2.shape[0])) 85 | for i in range(overlaps.shape[1]): 86 | box2 = boxes2[i] 87 | overlaps[:, i] = compute_iou(box2, boxes1, area2[i], area1) 88 | return overlaps 89 | 90 | def box_refinement(box, gt_box): 91 | """Compute refinement needed to transform box to gt_box. 92 | box and gt_box are [N, (y1, x1, y2, x2)] 93 | """ 94 | 95 | height = box[:, 2] - box[:, 0] 96 | width = box[:, 3] - box[:, 1] 97 | center_y = box[:, 0] + 0.5 * height 98 | center_x = box[:, 1] + 0.5 * width 99 | 100 | gt_height = gt_box[:, 2] - gt_box[:, 0] 101 | gt_width = gt_box[:, 3] - gt_box[:, 1] 102 | gt_center_y = gt_box[:, 0] + 0.5 * gt_height 103 | gt_center_x = gt_box[:, 1] + 0.5 * gt_width 104 | 105 | dy = (gt_center_y - center_y) / height 106 | dx = (gt_center_x - center_x) / width 107 | dh = torch.log(gt_height / height) 108 | dw = torch.log(gt_width / width) 109 | 110 | result = torch.stack([dy, dx, dh, dw], dim=1) 111 | return result 112 | 113 | 114 | ############################################################ 115 | # Dataset 116 | ############################################################ 117 | 118 | class Dataset(object): 119 | """The base class for dataset classes. 120 | To use it, create a new class that adds functions specific to the dataset 121 | you want to use. For example: 122 | 123 | class CatsAndDogsDataset(Dataset): 124 | def load_cats_and_dogs(self): 125 | ... 126 | def load_mask(self, image_id): 127 | ... 128 | def image_reference(self, image_id): 129 | ... 130 | 131 | See COCODataset and ShapesDataset as examples. 132 | """ 133 | 134 | def __init__(self, class_map=None): 135 | self._image_ids = [] 136 | self.image_info = [] 137 | # Background is always the first class 138 | self.class_info = [{"source": "", "id": 0, "name": "BG"}] 139 | self.source_class_ids = {} 140 | 141 | def add_class(self, source, class_id, class_name): 142 | assert "." not in source, "Source name cannot contain a dot" 143 | # Does the class exist already? 144 | for info in self.class_info: 145 | if info['source'] == source and info["id"] == class_id: 146 | # source.class_id combination already available, skip 147 | return 148 | # Add the class 149 | self.class_info.append({ 150 | "source": source, 151 | "id": class_id, 152 | "name": class_name, 153 | }) 154 | 155 | def add_image(self, source, image_id, path, **kwargs): 156 | image_info = { 157 | "id": image_id, 158 | "source": source, 159 | "path": path, 160 | } 161 | image_info.update(kwargs) 162 | self.image_info.append(image_info) 163 | 164 | def image_reference(self, image_id): 165 | """Return a link to the image in its source Website or details about 166 | the image that help looking it up or debugging it. 167 | 168 | Override for your dataset, but pass to this function 169 | if you encounter images not in your dataset. 170 | """ 171 | return "" 172 | 173 | def prepare(self, class_map=None): 174 | """Prepares the Dataset class for use. 175 | 176 | TODO: class map is not supported yet. When done, it should handle mapping 177 | classes from different datasets to the same class ID. 
178 | """ 179 | def clean_name(name): 180 | """Returns a shorter version of object names for cleaner display.""" 181 | return ",".join(name.split(",")[:1]) 182 | 183 | # Build (or rebuild) everything else from the info dicts. 184 | self.num_classes = len(self.class_info) 185 | self.class_ids = np.arange(self.num_classes) 186 | self.class_names = [clean_name(c["name"]) for c in self.class_info] 187 | self.num_images = len(self.image_info) 188 | self._image_ids = np.arange(self.num_images) 189 | 190 | self.class_from_source_map = {"{}.{}".format(info['source'], info['id']): id 191 | for info, id in zip(self.class_info, self.class_ids)} 192 | 193 | # Map sources to class_ids they support 194 | self.sources = list(set([i['source'] for i in self.class_info])) 195 | self.source_class_ids = {} 196 | # Loop over datasets 197 | for source in self.sources: 198 | self.source_class_ids[source] = [] 199 | # Find classes that belong to this dataset 200 | for i, info in enumerate(self.class_info): 201 | # Include BG class in all datasets 202 | if i == 0 or source == info['source']: 203 | self.source_class_ids[source].append(i) 204 | 205 | def map_source_class_id(self, source_class_id): 206 | """Takes a source class ID and returns the int class ID assigned to it. 207 | 208 | For example: 209 | dataset.map_source_class_id("coco.12") -> 23 210 | """ 211 | return self.class_from_source_map[source_class_id] 212 | 213 | def get_source_class_id(self, class_id, source): 214 | """Map an internal class ID to the corresponding class ID in the source dataset.""" 215 | info = self.class_info[class_id] 216 | assert info['source'] == source 217 | return info['id'] 218 | 219 | def append_data(self, class_info, image_info): 220 | self.external_to_class_id = {} 221 | for i, c in enumerate(self.class_info): 222 | for ds, id in c["map"]: 223 | self.external_to_class_id[ds + str(id)] = i 224 | 225 | # Map external image IDs to internal ones. 226 | self.external_to_image_id = {} 227 | for i, info in enumerate(self.image_info): 228 | self.external_to_image_id[info["ds"] + str(info["id"])] = i 229 | 230 | @property 231 | def image_ids(self): 232 | return self._image_ids 233 | 234 | def source_image_link(self, image_id): 235 | """Returns the path or URL to the image. 236 | Override this to return a URL to the image if it's availble online for easy 237 | debugging. 238 | """ 239 | return self.image_info[image_id]["path"] 240 | 241 | def load_image(self, image_id): 242 | """Load the specified image and return a [H,W,3] Numpy array. 243 | """ 244 | # Load image 245 | image = skimage.io.imread(self.image_info[image_id]['path']) 246 | # If grayscale. Convert to RGB for consistency. 247 | if image.ndim != 3: 248 | image = skimage.color.gray2rgb(image) 249 | return image 250 | 251 | def load_mask(self, image_id): 252 | """Load instance masks for the given image. 253 | 254 | Different datasets use different ways to store masks. Override this 255 | method to load instance masks and return them in the form of am 256 | array of binary masks of shape [height, width, instances]. 257 | 258 | Returns: 259 | masks: A bool array of shape [height, width, instance count] with 260 | a binary mask per instance. 261 | class_ids: a 1D array of class IDs of the instance masks. 262 | """ 263 | # Override this function to load a mask from your dataset. 264 | # Otherwise, it returns an empty mask. 
265 | mask = np.empty([0, 0, 0]) 266 | class_ids = np.empty([0], np.int32) 267 | return mask, class_ids 268 | 269 | 270 | def resize_image(image, min_dim=None, max_dim=None, padding=False): 271 | """ 272 | Resizes an image keeping the aspect ratio. 273 | 274 | min_dim: if provided, resizes the image such that it's smaller 275 | dimension == min_dim 276 | max_dim: if provided, ensures that the image longest side doesn't 277 | exceed this value. 278 | padding: If true, pads image with zeros so it's size is max_dim x max_dim 279 | 280 | Returns: 281 | image: the resized image 282 | window: (y1, x1, y2, x2). If max_dim is provided, padding might 283 | be inserted in the returned image. If so, this window is the 284 | coordinates of the image part of the full image (excluding 285 | the padding). The x2, y2 pixels are not included. 286 | scale: The scale factor used to resize the image 287 | padding: Padding added to the image [(top, bottom), (left, right), (0, 0)] 288 | """ 289 | # Default window (y1, x1, y2, x2) and default scale == 1. 290 | h, w = image.shape[:2] 291 | window = (0, 0, h, w) 292 | scale = 1 293 | 294 | # Scale? 295 | if min_dim: 296 | # Scale up but not down 297 | scale = max(1, min_dim / min(h, w)) 298 | # Does it exceed max dim? 299 | if max_dim: 300 | image_max = max(h, w) 301 | if round(image_max * scale) > max_dim: 302 | scale = max_dim / image_max 303 | # Resize image and mask 304 | if scale != 1: 305 | image = scipy.misc.imresize( 306 | image, (round(h * scale), round(w * scale))) 307 | # Need padding? 308 | if padding: 309 | # Get new height and width 310 | h, w = image.shape[:2] 311 | top_pad = (max_dim - h) // 2 312 | bottom_pad = max_dim - h - top_pad 313 | left_pad = (max_dim - w) // 2 314 | right_pad = max_dim - w - left_pad 315 | padding = [(top_pad, bottom_pad), (left_pad, right_pad), (0, 0)] 316 | image = np.pad(image, padding, mode='constant', constant_values=0) 317 | window = (top_pad, left_pad, h + top_pad, w + left_pad) 318 | return image, window, scale, padding 319 | 320 | 321 | def resize_mask(mask, scale, padding): 322 | """Resizes a mask using the given scale and padding. 323 | Typically, you get the scale and padding from resize_image() to 324 | ensure both, the image and the mask, are resized consistently. 325 | 326 | scale: mask scaling factor 327 | padding: Padding to add to the mask in the form 328 | [(top, bottom), (left, right), (0, 0)] 329 | """ 330 | h, w = mask.shape[:2] 331 | mask = scipy.ndimage.zoom(mask, zoom=[scale, scale, 1], order=0) 332 | mask = np.pad(mask, padding, mode='constant', constant_values=0) 333 | return mask 334 | 335 | 336 | def minimize_mask(bbox, mask, mini_shape): 337 | """Resize masks to a smaller version to cut memory load. 338 | Mini-masks can then resized back to image scale using expand_masks() 339 | 340 | See inspect_data.ipynb notebook for more details. 341 | """ 342 | mini_mask = np.zeros(mini_shape + (mask.shape[-1],), dtype=bool) 343 | for i in range(mask.shape[-1]): 344 | m = mask[:, :, i] 345 | y1, x1, y2, x2 = bbox[i][:4] 346 | m = m[y1:y2, x1:x2] 347 | if m.size == 0: 348 | raise Exception("Invalid bounding box with area of zero") 349 | m = scipy.misc.imresize(m.astype(float), mini_shape, interp='bilinear') 350 | mini_mask[:, :, i] = np.where(m >= 128, 1, 0) 351 | return mini_mask 352 | 353 | 354 | def expand_mask(bbox, mini_mask, image_shape): 355 | """Resizes mini masks back to image size. Reverses the change 356 | of minimize_mask(). 357 | 358 | See inspect_data.ipynb notebook for more details. 
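    bbox must be the same boxes that were passed to minimize_mask(): each mini mask is resized
    back to (y2 - y1, x2 - x1) and pasted into that window of the full-size mask.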
359 | """ 360 | mask = np.zeros(image_shape[:2] + (mini_mask.shape[-1],), dtype=bool) 361 | for i in range(mask.shape[-1]): 362 | m = mini_mask[:, :, i] 363 | y1, x1, y2, x2 = bbox[i][:4] 364 | h = y2 - y1 365 | w = x2 - x1 366 | m = scipy.misc.imresize(m.astype(float), (h, w), interp='bilinear') 367 | mask[y1:y2, x1:x2, i] = np.where(m >= 128, 1, 0) 368 | return mask 369 | 370 | 371 | # TODO: Build and use this function to reduce code duplication 372 | def mold_mask(mask, config): 373 | pass 374 | 375 | 376 | def unmold_mask(mask, bbox, image_shape): 377 | """Converts a mask generated by the neural network into a format similar 378 | to it's original shape. 379 | mask: [height, width] of type float. A small, typically 28x28 mask. 380 | bbox: [y1, x1, y2, x2]. The box to fit the mask in. 381 | 382 | Returns a binary mask with the same size as the original image. 383 | """ 384 | threshold = 0.5 385 | y1, x1, y2, x2 = bbox 386 | mask = scipy.misc.imresize( 387 | mask, (y2 - y1, x2 - x1), interp='bilinear').astype(np.float32) / 255.0 388 | mask = np.where(mask >= threshold, 1, 0).astype(np.uint8) 389 | 390 | # Put the mask in the right location. 391 | full_mask = np.zeros(image_shape[:2], dtype=np.uint8) 392 | full_mask[y1:y2, x1:x2] = mask 393 | return full_mask 394 | 395 | 396 | ############################################################ 397 | # Anchors 398 | ############################################################ 399 | 400 | def generate_anchors(scales, ratios, shape, feature_stride, anchor_stride): 401 | """ 402 | scales: 1D array of anchor sizes in pixels. Example: [32, 64, 128] 403 | ratios: 1D array of anchor ratios of width/height. Example: [0.5, 1, 2] 404 | shape: [height, width] spatial shape of the feature map over which 405 | to generate anchors. 406 | feature_stride: Stride of the feature map relative to the image in pixels. 407 | anchor_stride: Stride of anchors on the feature map. For example, if the 408 | value is 2 then generate anchors for every other feature map pixel. 409 | """ 410 | # Get all combinations of scales and ratios 411 | scales, ratios = np.meshgrid(np.array(scales), np.array(ratios)) 412 | scales = scales.flatten() 413 | ratios = ratios.flatten() 414 | 415 | # Enumerate heights and widths from scales and ratios 416 | heights = scales / np.sqrt(ratios) 417 | widths = scales * np.sqrt(ratios) 418 | 419 | # Enumerate shifts in feature space 420 | shifts_y = np.arange(0, shape[0], anchor_stride) * feature_stride 421 | shifts_x = np.arange(0, shape[1], anchor_stride) * feature_stride 422 | shifts_x, shifts_y = np.meshgrid(shifts_x, shifts_y) 423 | 424 | # Enumerate combinations of shifts, widths, and heights 425 | box_widths, box_centers_x = np.meshgrid(widths, shifts_x) 426 | box_heights, box_centers_y = np.meshgrid(heights, shifts_y) 427 | 428 | # Reshape to get a list of (y, x) and a list of (h, w) 429 | box_centers = np.stack( 430 | [box_centers_y, box_centers_x], axis=2).reshape([-1, 2]) 431 | box_sizes = np.stack([box_heights, box_widths], axis=2).reshape([-1, 2]) 432 | 433 | # Convert to corner coordinates (y1, x1, y2, x2) 434 | boxes = np.concatenate([box_centers - 0.5 * box_sizes, 435 | box_centers + 0.5 * box_sizes], axis=1) 436 | return boxes 437 | 438 | 439 | def generate_pyramid_anchors(scales, ratios, feature_shapes, feature_strides, 440 | anchor_stride): 441 | """Generate anchors at different levels of a feature pyramid. Each scale 442 | is associated with a level of the pyramid, but each ratio is used in 443 | all levels of the pyramid. 
444 | 445 | Returns: 446 | anchors: [N, (y1, x1, y2, x2)]. All generated anchors in one array. Sorted 447 | with the same order of the given scales. So, anchors of scale[0] come 448 | first, then anchors of scale[1], and so on. 449 | """ 450 | # Anchors 451 | # [anchor_count, (y1, x1, y2, x2)] 452 | anchors = [] 453 | for i in range(len(scales)): 454 | anchors.append(generate_anchors(scales[i], ratios, feature_shapes[i], 455 | feature_strides[i], anchor_stride)) 456 | return np.concatenate(anchors, axis=0) 457 | 458 | 459 | 460 | 461 | 462 | 463 | -------------------------------------------------------------------------------- /visualize.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mask R-CNN 3 | Display and Visualization Functions. 4 | 5 | Copyright (c) 2017 Matterport, Inc. 6 | Licensed under the MIT License (see LICENSE for details) 7 | Written by Waleed Abdulla 8 | """ 9 | 10 | import os 11 | import random 12 | import itertools 13 | import colorsys 14 | import numpy as np 15 | from skimage.measure import find_contours 16 | import matplotlib.pyplot as plt 17 | if "DISPLAY" not in os.environ: 18 | plt.switch_backend('agg') 19 | import matplotlib.patches as patches 20 | import matplotlib.lines as lines 21 | from matplotlib.patches import Polygon 22 | 23 | import utils 24 | 25 | 26 | ############################################################ 27 | # Visualization 28 | ############################################################ 29 | 30 | def display_images(images, titles=None, cols=4, cmap=None, norm=None, 31 | interpolation=None): 32 | """Display the given set of images, optionally with titles. 33 | images: list or array of image tensors in HWC format. 34 | titles: optional. A list of titles to display with each image. 35 | cols: number of images per row 36 | cmap: Optional. Color map to use. For example, "Blues". 37 | norm: Optional. A Normalize instance to map values to colors. 38 | interpolation: Optional. Image interporlation to use for display. 39 | """ 40 | titles = titles if titles is not None else [""] * len(images) 41 | rows = len(images) // cols + 1 42 | plt.figure(figsize=(14, 14 * rows // cols)) 43 | i = 1 44 | for image, title in zip(images, titles): 45 | plt.subplot(rows, cols, i) 46 | plt.title(title, fontsize=9) 47 | plt.axis('off') 48 | plt.imshow(image.astype(np.uint8), cmap=cmap, 49 | norm=norm, interpolation=interpolation) 50 | i += 1 51 | plt.show() 52 | 53 | 54 | def random_colors(N, bright=True): 55 | """ 56 | Generate random colors. 57 | To get visually distinct colors, generate them in HSV space then 58 | convert to RGB. 59 | """ 60 | brightness = 1.0 if bright else 0.7 61 | hsv = [(i / N, 1, brightness) for i in range(N)] 62 | colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv)) 63 | random.shuffle(colors) 64 | return colors 65 | 66 | 67 | def apply_mask(image, mask, color, alpha=0.5): 68 | """Apply the given mask to the image. 69 | """ 70 | for c in range(3): 71 | image[:, :, c] = np.where(mask == 1, 72 | image[:, :, c] * 73 | (1 - alpha) + alpha * color[c] * 255, 74 | image[:, :, c]) 75 | return image 76 | 77 | 78 | def display_instances(image, boxes, masks, class_ids, class_names, 79 | scores=None, title="", 80 | figsize=(16, 16), ax=None): 81 | """ 82 | boxes: [num_instance, (y1, x1, y2, x2, class_id)] in image coordinates. 
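    (Only (y1, x1, y2, x2) are read from boxes here; the class label of each instance is
    taken from class_ids.)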
83 | masks: [height, width, num_instances]
84 | class_ids: [num_instances]
85 | class_names: list of class names of the dataset
86 | scores: (optional) confidence scores for each box
87 | figsize: (optional) the size of the figure.
88 | """
89 | # Number of instances
90 | N = boxes.shape[0]
91 | if not N:
92 | print("\n*** No instances to display *** \n")
93 | else:
94 | assert boxes.shape[0] == masks.shape[-1] == class_ids.shape[0]
95 | 
96 | if not ax:
97 | _, ax = plt.subplots(1, figsize=figsize)
98 | 
99 | # Generate random colors
100 | colors = random_colors(N)
101 | 
102 | # Show area outside image boundaries.
103 | height, width = image.shape[:2]
104 | ax.set_ylim(height + 10, -10)
105 | ax.set_xlim(-10, width + 10)
106 | ax.axis('off')
107 | ax.set_title(title)
108 | 
109 | masked_image = image.astype(np.uint32).copy()
110 | for i in range(N):
111 | color = colors[i]
112 | 
113 | # Bounding box
114 | if not np.any(boxes[i]):
115 | # Skip this instance. Has no bbox. Likely lost in image cropping.
116 | continue
117 | y1, x1, y2, x2 = boxes[i]
118 | p = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2,
119 | alpha=0.7, linestyle="dashed",
120 | edgecolor=color, facecolor='none')
121 | ax.add_patch(p)
122 | 
123 | # Label
124 | class_id = class_ids[i]
125 | score = scores[i] if scores is not None else None
126 | label = class_names[class_id]
127 | x = random.randint(x1, (x1 + x2) // 2)
128 | caption = "{} {:.3f}".format(label, score) if score else label
129 | ax.text(x1, y1 + 8, caption,
130 | color='w', size=11, backgroundcolor="none")
131 | 
132 | # Mask
133 | mask = masks[:, :, i]
134 | masked_image = apply_mask(masked_image, mask, color)
135 | 
136 | # Mask Polygon
137 | # Pad to ensure proper polygons for masks that touch image edges.
138 | padded_mask = np.zeros(
139 | (mask.shape[0] + 2, mask.shape[1] + 2), dtype=np.uint8)
140 | padded_mask[1:-1, 1:-1] = mask
141 | contours = find_contours(padded_mask, 0.5)
142 | for verts in contours:
143 | # Subtract the padding and flip (y, x) to (x, y)
144 | verts = np.fliplr(verts) - 1
145 | p = Polygon(verts, facecolor="none", edgecolor=color)
146 | ax.add_patch(p)
147 | ax.imshow(masked_image.astype(np.uint8))
148 | plt.show()
149 | 
150 | 
151 | def draw_rois(image, rois, refined_rois, mask, class_ids, class_names, limit=10):
152 | """
153 | rois: [n, (y1, x1, y2, x2)] list of ROIs in image coordinates.
154 | refined_rois: [n, 4] the same ROIs but refined to fit objects better.
155 | """
156 | masked_image = image.copy()
157 | 
158 | # Pick random ROIs in case there are too many.
159 | ids = np.arange(rois.shape[0], dtype=np.int32)
160 | ids = np.random.choice(
161 | ids, limit, replace=False) if ids.shape[0] > limit else ids
162 | 
163 | fig, ax = plt.subplots(1, figsize=(12, 12))
164 | if rois.shape[0] > limit:
165 | plt.title("Showing {} random ROIs out of {}".format(
166 | len(ids), rois.shape[0]))
167 | else:
168 | plt.title("{} ROIs".format(len(ids)))
169 | 
170 | # Show area outside image boundaries.
171 | ax.set_ylim(image.shape[0] + 20, -20)
172 | ax.set_xlim(-50, image.shape[1] + 20)
173 | ax.axis('off')
174 | 
175 | for i, id in enumerate(ids):
176 | color = np.random.rand(3)
177 | class_id = class_ids[id]
178 | # ROI
179 | y1, x1, y2, x2 = rois[id]
180 | p = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2,
181 | edgecolor=color if class_id else "gray",
182 | facecolor='none', linestyle="dashed")
183 | ax.add_patch(p)
184 | # Refined ROI
185 | if class_id:
186 | ry1, rx1, ry2, rx2 = refined_rois[id]
187 | p = patches.Rectangle((rx1, ry1), rx2 - rx1, ry2 - ry1, linewidth=2,
188 | edgecolor=color, facecolor='none')
189 | ax.add_patch(p)
190 | # Connect the top-left corners of the anchor and proposal for easy visualization
191 | ax.add_line(lines.Line2D([x1, rx1], [y1, ry1], color=color))
192 | 
193 | # Label
194 | label = class_names[class_id]
195 | ax.text(rx1, ry1 + 8, "{}".format(label),
196 | color='w', size=11, backgroundcolor="none")
197 | 
198 | # Mask
199 | m = utils.unmold_mask(mask[id], rois[id]
200 | [:4].astype(np.int32), image.shape)
201 | masked_image = apply_mask(masked_image, m, color)
202 | 
203 | ax.imshow(masked_image)
204 | 
205 | # Print stats
206 | print("Positive ROIs: ", class_ids[class_ids > 0].shape[0])
207 | print("Negative ROIs: ", class_ids[class_ids == 0].shape[0])
208 | print("Positive Ratio: {:.2f}".format(
209 | class_ids[class_ids > 0].shape[0] / class_ids.shape[0]))
210 | 
211 | 
212 | # TODO: Replace with matplotlib equivalent?
213 | def draw_box(image, box, color):
214 | """Draw 2-pixel wide bounding boxes on the given image array.
215 | color: list of 3 int values for RGB.
216 | """
217 | y1, x1, y2, x2 = box
218 | image[y1:y1 + 2, x1:x2] = color
219 | image[y2:y2 + 2, x1:x2] = color
220 | image[y1:y2, x1:x1 + 2] = color
221 | image[y1:y2, x2:x2 + 2] = color
222 | return image
223 | 
224 | 
225 | def display_top_masks(image, mask, class_ids, class_names, limit=4):
226 | """Display the given image and the top few class masks."""
227 | to_display = []
228 | titles = []
229 | to_display.append(image)
230 | titles.append("H x W={}x{}".format(image.shape[0], image.shape[1]))
231 | # Pick top prominent classes in this image
232 | unique_class_ids = np.unique(class_ids)
233 | mask_area = [np.sum(mask[:, :, np.where(class_ids == i)[0]])
234 | for i in unique_class_ids]
235 | top_ids = [v[0] for v in sorted(zip(unique_class_ids, mask_area),
236 | key=lambda r: r[1], reverse=True) if v[1] > 0]
237 | # Generate images and titles
238 | for i in range(limit):
239 | class_id = top_ids[i] if i < len(top_ids) else -1
240 | # Pull masks of instances belonging to the same class.
241 | m = mask[:, :, np.where(class_ids == class_id)[0]]
242 | m = np.sum(m * np.arange(1, m.shape[-1] + 1), -1)
243 | to_display.append(m)
244 | titles.append(class_names[class_id] if class_id != -1 else "-")
245 | display_images(to_display, titles=titles, cols=limit + 1, cmap="Blues_r")
246 | 
247 | 
248 | def plot_precision_recall(AP, precisions, recalls):
249 | """Draw the precision-recall curve.
250 | 
251 | AP: Average precision at IoU >= 0.5
252 | precisions: list of precision values
253 | recalls: list of recall values
254 | """
255 | # Plot the Precision-Recall curve
256 | _, ax = plt.subplots(1)
257 | ax.set_title("Precision-Recall Curve. AP@50 = {:.3f}".format(AP))
258 | ax.set_ylim(0, 1.1)
259 | ax.set_xlim(0, 1.1)
260 | _ = ax.plot(recalls, precisions)
261 | 
262 | 
263 | def plot_overlaps(gt_class_ids, pred_class_ids, pred_scores,
264 | overlaps, class_names, threshold=0.5):
265 | """Draw a grid showing how ground truth objects are classified.
266 | gt_class_ids: [N] int. Ground truth class IDs
267 | pred_class_ids: [N] int. Predicted class IDs
268 | pred_scores: [N] float. The probability scores of predicted classes
269 | overlaps: [pred_boxes, gt_boxes] IoU overlaps of predictions and GT boxes.
270 | class_names: list of all class names in the dataset
271 | threshold: Float. The prediction probability required to predict a class
272 | """
273 | gt_class_ids = gt_class_ids[gt_class_ids != 0]
274 | pred_class_ids = pred_class_ids[pred_class_ids != 0]
275 | 
276 | plt.figure(figsize=(12, 10))
277 | plt.imshow(overlaps, interpolation='nearest', cmap=plt.cm.Blues)
278 | plt.yticks(np.arange(len(pred_class_ids)),
279 | ["{} ({:.2f})".format(class_names[int(id)], pred_scores[i])
280 | for i, id in enumerate(pred_class_ids)])
281 | plt.xticks(np.arange(len(gt_class_ids)),
282 | [class_names[int(id)] for id in gt_class_ids], rotation=90)
283 | 
284 | thresh = overlaps.max() / 2.
285 | for i, j in itertools.product(range(overlaps.shape[0]),
286 | range(overlaps.shape[1])):
287 | text = ""
288 | if overlaps[i, j] > threshold:
289 | text = "match" if gt_class_ids[j] == pred_class_ids[i] else "wrong"
290 | color = ("white" if overlaps[i, j] > thresh
291 | else "black" if overlaps[i, j] > 0
292 | else "grey")
293 | plt.text(j, i, "{:.3f}\n{}".format(overlaps[i, j], text),
294 | horizontalalignment="center", verticalalignment="center",
295 | fontsize=9, color=color)
296 | 
297 | plt.tight_layout()
298 | plt.xlabel("Ground Truth")
299 | plt.ylabel("Predictions")
300 | 
301 | 
302 | def draw_boxes(image, boxes=None, refined_boxes=None,
303 | masks=None, captions=None, visibilities=None,
304 | title="", ax=None):
305 | """Draw bounding boxes and segmentation masks with different
306 | customizations.
307 | 
308 | boxes: [N, (y1, x1, y2, x2)] in image coordinates.
309 | refined_boxes: Like boxes, but draw with solid lines to show
310 | that they're the result of refining 'boxes'.
311 | masks: [height, width, N]
312 | captions: List of N titles to display on each box
313 | visibilities: (optional) List of values of 0, 1, or 2. Determines how
314 | prominent each bounding box should be.
315 | title: An optional title to show over the image
316 | ax: (optional) Matplotlib axis to draw on.
317 | """
318 | # Number of boxes
319 | assert boxes is not None or refined_boxes is not None
320 | N = boxes.shape[0] if boxes is not None else refined_boxes.shape[0]
321 | 
322 | # Matplotlib Axis
323 | if not ax:
324 | _, ax = plt.subplots(1, figsize=(12, 12))
325 | 
326 | # Generate random colors
327 | colors = random_colors(N)
328 | 
329 | # Show area outside image boundaries.
330 | margin = image.shape[0] // 10 331 | ax.set_ylim(image.shape[0] + margin, -margin) 332 | ax.set_xlim(-margin, image.shape[1] + margin) 333 | ax.axis('off') 334 | 335 | ax.set_title(title) 336 | 337 | masked_image = image.astype(np.uint32).copy() 338 | for i in range(N): 339 | # Box visibility 340 | visibility = visibilities[i] if visibilities is not None else 1 341 | if visibility == 0: 342 | color = "gray" 343 | style = "dotted" 344 | alpha = 0.5 345 | elif visibility == 1: 346 | color = colors[i] 347 | style = "dotted" 348 | alpha = 1 349 | elif visibility == 2: 350 | color = colors[i] 351 | style = "solid" 352 | alpha = 1 353 | 354 | # Boxes 355 | if boxes is not None: 356 | if not np.any(boxes[i]): 357 | # Skip this instance. Has no bbox. Likely lost in cropping. 358 | continue 359 | y1, x1, y2, x2 = boxes[i] 360 | p = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, 361 | alpha=alpha, linestyle=style, 362 | edgecolor=color, facecolor='none') 363 | ax.add_patch(p) 364 | 365 | # Refined boxes 366 | if refined_boxes is not None and visibility > 0: 367 | ry1, rx1, ry2, rx2 = refined_boxes[i].astype(np.int32) 368 | p = patches.Rectangle((rx1, ry1), rx2 - rx1, ry2 - ry1, linewidth=2, 369 | edgecolor=color, facecolor='none') 370 | ax.add_patch(p) 371 | # Connect the top-left corners of the anchor and proposal 372 | if boxes is not None: 373 | ax.add_line(lines.Line2D([x1, rx1], [y1, ry1], color=color)) 374 | 375 | # Captions 376 | if captions is not None: 377 | caption = captions[i] 378 | # If there are refined boxes, display captions on them 379 | if refined_boxes is not None: 380 | y1, x1, y2, x2 = ry1, rx1, ry2, rx2 381 | x = random.randint(x1, (x1 + x2) // 2) 382 | ax.text(x1, y1, caption, size=11, verticalalignment='top', 383 | color='w', backgroundcolor="none", 384 | bbox={'facecolor': color, 'alpha': 0.5, 385 | 'pad': 2, 'edgecolor': 'none'}) 386 | 387 | # Masks 388 | if masks is not None: 389 | mask = masks[:, :, i] 390 | masked_image = apply_mask(masked_image, mask, color) 391 | # Mask Polygon 392 | # Pad to ensure proper polygons for masks that touch image edges. 
393 | padded_mask = np.zeros( 394 | (mask.shape[0] + 2, mask.shape[1] + 2), dtype=np.uint8) 395 | padded_mask[1:-1, 1:-1] = mask 396 | contours = find_contours(padded_mask, 0.5) 397 | for verts in contours: 398 | # Subtract the padding and flip (y, x) to (x, y) 399 | verts = np.fliplr(verts) - 1 400 | p = Polygon(verts, facecolor="none", edgecolor=color) 401 | ax.add_patch(p) 402 | ax.imshow(masked_image.astype(np.uint8)) 403 | 404 | def plot_loss(loss, val_loss, save=True, log_dir=None): 405 | loss = np.array(loss) 406 | val_loss = np.array(val_loss) 407 | 408 | plt.figure("loss") 409 | plt.gcf().clear() 410 | plt.plot(loss[:, 0], label='train') 411 | plt.plot(val_loss[:, 0], label='valid') 412 | plt.xlabel('epoch') 413 | plt.ylabel('loss') 414 | plt.legend() 415 | if save: 416 | save_path = os.path.join(log_dir, "loss.png") 417 | plt.savefig(save_path) 418 | else: 419 | plt.show(block=False) 420 | plt.pause(0.1) 421 | 422 | plt.figure("rpn_class_loss") 423 | plt.gcf().clear() 424 | plt.plot(loss[:, 1], label='train') 425 | plt.plot(val_loss[:, 1], label='valid') 426 | plt.xlabel('epoch') 427 | plt.ylabel('loss') 428 | plt.legend() 429 | if save: 430 | save_path = os.path.join(log_dir, "rpn_class_loss.png") 431 | plt.savefig(save_path) 432 | else: 433 | plt.show(block=False) 434 | plt.pause(0.1) 435 | 436 | plt.figure("rpn_bbox_loss") 437 | plt.gcf().clear() 438 | plt.plot(loss[:, 2], label='train') 439 | plt.plot(val_loss[:, 2], label='valid') 440 | plt.xlabel('epoch') 441 | plt.ylabel('loss') 442 | plt.legend() 443 | if save: 444 | save_path = os.path.join(log_dir, "rpn_bbox_loss.png") 445 | plt.savefig(save_path) 446 | else: 447 | plt.show(block=False) 448 | plt.pause(0.1) 449 | 450 | plt.figure("mrcnn_class_loss") 451 | plt.gcf().clear() 452 | plt.plot(loss[:, 3], label='train') 453 | plt.plot(val_loss[:, 3], label='valid') 454 | plt.xlabel('epoch') 455 | plt.ylabel('loss') 456 | plt.legend() 457 | if save: 458 | save_path = os.path.join(log_dir, "mrcnn_class_loss.png") 459 | plt.savefig(save_path) 460 | else: 461 | plt.show(block=False) 462 | plt.pause(0.1) 463 | 464 | plt.figure("mrcnn_bbox_loss") 465 | plt.gcf().clear() 466 | plt.plot(loss[:, 4], label='train') 467 | plt.plot(val_loss[:, 4], label='valid') 468 | plt.xlabel('epoch') 469 | plt.ylabel('loss') 470 | plt.legend() 471 | if save: 472 | save_path = os.path.join(log_dir, "mrcnn_bbox_loss.png") 473 | plt.savefig(save_path) 474 | else: 475 | plt.show(block=False) 476 | plt.pause(0.1) 477 | 478 | plt.figure("mrcnn_mask_loss") 479 | plt.gcf().clear() 480 | plt.plot(loss[:, 5], label='train') 481 | plt.plot(val_loss[:, 5], label='valid') 482 | plt.xlabel('epoch') 483 | plt.ylabel('loss') 484 | plt.legend() 485 | if save: 486 | save_path = os.path.join(log_dir, "mrcnn_mask_loss.png") 487 | plt.savefig(save_path) 488 | else: 489 | plt.show(block=False) 490 | plt.pause(0.1) 491 | 492 | 493 | --------------------------------------------------------------------------------
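
As a quick sanity check of the anchor utilities in utils.py, the sketch below (not part of the repository) builds a full anchor pyramid. The scales, ratios, strides and image size are hypothetical values that mirror common Mask R-CNN defaults; the values actually used for training come from config.py and may differ.

    import numpy as np
    import utils

    # Hypothetical pyramid settings for a 1024x1024 input (FPN levels P2-P6).
    image_shape = np.array([1024, 1024, 3])
    scales = (32, 64, 128, 256, 512)        # one anchor scale per pyramid level
    ratios = [0.5, 1, 2]                    # width/height ratios, shared by all levels
    feature_strides = [4, 8, 16, 32, 64]    # backbone stride of each level
    feature_shapes = [(int(np.ceil(image_shape[0] / s)),
                       int(np.ceil(image_shape[1] / s))) for s in feature_strides]

    anchors = utils.generate_pyramid_anchors(scales, ratios, feature_shapes,
                                             feature_strides, anchor_stride=1)
    # 3 anchors per feature cell over all 5 levels -> (261888, 4) for these values
    print(anchors.shape)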
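A similarly minimal sketch of utils.unmold_mask, using a fabricated 28x28 soft mask and box, shows how a mask-head output is resized into its box and thresholded into a full-image binary mask. It assumes the older scipy version (with scipy.misc.imresize) that this repository targets.

    import numpy as np
    import utils

    soft_mask = np.random.rand(28, 28).astype(np.float32)   # fake mask-head output
    bbox = np.array([10, 20, 74, 84], dtype=np.int32)        # y1, x1, y2, x2 (64x64 box)
    full_mask = utils.unmold_mask(soft_mask, bbox, image_shape=(128, 128, 3))
    print(full_mask.shape, full_mask.dtype)  # (128, 128) uint8, values 0 or 1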
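Finally, a toy call to visualize.display_instances with fabricated arrays illustrates the expected input shapes; in real use the image, boxes, masks, class IDs and scores come from the model's detections.

    import numpy as np
    import visualize

    # One fabricated instance, just to show the shapes display_instances expects.
    image = np.zeros((256, 256, 3), dtype=np.uint8)
    boxes = np.array([[50, 60, 180, 200]])           # [N, (y1, x1, y2, x2)]
    masks = np.zeros((256, 256, 1), dtype=np.uint8)  # [height, width, N]
    masks[60:170, 70:190, 0] = 1
    class_ids = np.array([1])
    class_names = ['BG', 'object']
    scores = np.array([0.95])

    visualize.display_instances(image, boxes, masks, class_ids, class_names, scores)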