├── .gitattributes
├── LICENSE
├── README.MD
├── data
│   ├── __init__.py
│   ├── dataset.py
│   ├── util.py
│   └── voc_dataset.py
├── demo.ipynb
├── imgs
│   ├── faster-speed.jpg
│   ├── model_all.png
│   └── visdom-fasterrcnn.png
├── misc
│   ├── convert_caffe_pretrain.py
│   ├── demo.jpg
│   └── train_fast.py
├── model
│   ├── __init__.py
│   ├── faster_rcnn.py
│   ├── faster_rcnn_vgg16.py
│   ├── region_proposal_network.py
│   └── utils
│       ├── __init__.py
│       ├── bbox_tools.py
│       └── creator_tool.py
├── requirements.txt
├── train.py
├── trainer.py
└── utils
    ├── __init__.py
    ├── array_tool.py
    ├── config.py
    ├── eval_tool.py
    └── vis_tool.py

/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ipynb binary
2 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License
2 | 
3 | Copyright (c) 2017 Yun Chen
4 | 
5 | Original works by:
6 | --------------------------------------------------------
7 | chainer/chainercv
8 | Copyright (c) 2017 Yusuke Niitani
9 | Licensed under The MIT License
10 | https://github.com/chainer/chainercv/blob/master/LICENSE
11 | --------------------------------------------------------
12 | Faster R-CNN
13 | Copyright (c) 2015 Microsoft
14 | Licensed under The MIT License
15 | https://github.com/rbgirshick/py-faster-rcnn/blob/master/LICENSE
16 | --------------------------------------------------------
17 | 
18 | Permission is hereby granted, free of charge, to any person obtaining a copy
19 | of this software and associated documentation files (the "Software"), to deal
20 | in the Software without restriction, including without limitation the rights
21 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
22 | copies of the Software, and to permit persons to whom the Software is
23 | furnished to do so, subject to the following conditions:
24 | 
25 | The above copyright notice and this permission notice shall be included in
26 | all copies or substantial portions of the Software.
27 | 
28 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
29 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
30 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
31 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
32 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
33 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
34 | THE SOFTWARE.
--------------------------------------------------------------------------------
/README.MD:
--------------------------------------------------------------------------------
1 | # A Simple and Fast Implementation of Faster R-CNN
2 | 
3 | ## 1. Introduction
4 | 
5 | **[Update:]** I've further simplified the code and updated it to PyTorch 1.5 and torchvision 0.6, replacing the customized RoIPool and NMS ops with the ones from torchvision. If you want the old version of the code, please check out branch [v1.0](https://github.com/chenyuntc/simple-faster-rcnn-pytorch/tree/v1.0).
6 | 
7 | 
8 | 
9 | This project is a **Simplified** Faster R-CNN implementation based on [chainercv](https://github.com/chainer/chainercv) and other [projects](#acknowledgement). I hope it can serve as a starting point for those who want to understand the details of Faster R-CNN.
It aims to:
10 | 
11 | - Simplify the code (*Simple is better than complex*)
12 | - Make the code more straightforward (*Flat is better than nested*)
13 | - Match the performance reported in the [original paper](https://arxiv.org/abs/1506.01497) (*Speed Counts and mAP Matters*)
14 | 
15 | And it has the following features:
16 | - It can be run as pure Python code, with no build step required.
17 | - It's a minimal implementation, around 2000 lines of valid code, with plenty of comments and instructions (thanks to chainercv's excellent documentation).
18 | - It achieves higher mAP than the original implementation (0.712 vs. 0.699).
19 | - It achieves speed comparable to other implementations (6 fps for training and 14 fps for testing on a TITAN Xp).
20 | - It's memory-efficient (about 3 GB for VGG-16).
21 | 
22 | 
23 | ![img](imgs/faster-speed.jpg)
24 | 
25 | 
26 | 
27 | ## 2. Performance
28 | 
29 | ### 2.1 mAP
30 | 
31 | VGG-16, trained on the `trainval` split and tested on the `test` split.
32 | 
33 | **Note**: training shows considerable randomness; you may need a bit of luck and more epochs of training to reach the highest mAP. However, it should be easy to surpass the lower bound.
34 | 
35 | | Implementation | mAP |
36 | | :--------------------------------------: | :---------: |
37 | | [original paper](https://arxiv.org/abs/1506.01497) | 0.699 |
38 | | train with caffe-pretrained model | 0.700-0.712 |
39 | | train with torchvision-pretrained model | 0.685-0.701 |
40 | | model converted from [chainercv](https://github.com/chainer/chainercv/tree/master/examples/faster_rcnn) (reported 0.706) | 0.7053 |
41 | 
42 | ### 2.2 Speed
43 | 
44 | | Implementation | GPU | Inference | Training |
45 | | :--------------------------------------: | :------: | :-------: | :--------: |
46 | | [original paper](https://arxiv.org/abs/1506.01497) | K40 | 5 fps | NA |
47 | | This[1] | TITAN Xp | 14-15 fps | 6 fps |
48 | | [pytorch-faster-rcnn](https://github.com/ruotianluo/pytorch-faster-rcnn) | TITAN Xp | 15-17 fps | 6 fps |
49 | 
50 | [1]: make sure you install CuPy correctly and that only one program is running on the GPU. The training speed is sensitive to your GPU status. See [troubleshooting](#troubleshooting) for more info. Moreover, it is slow at the start of the program -- it needs time to warm up.
51 | 
52 | It could be made faster by removing visualization, logging, loss averaging, etc.
53 | ## 3. Install dependencies
54 | 
55 | 
56 | Here is an example of creating the environment **from scratch** with `anaconda`:
57 | 
58 | ```sh
59 | # create conda env
60 | conda create --name simp python=3.7
61 | conda activate simp
62 | # install pytorch
63 | conda install pytorch torchvision cudatoolkit=10.2 -c pytorch
64 | 
65 | # install other dependencies
66 | pip install visdom scikit-image tqdm fire ipdb pprint matplotlib torchnet
67 | 
68 | # start visdom
69 | nohup python -m visdom.server &
70 | 
71 | ```
72 | 
73 | If you don't use anaconda, then:
74 | 
75 | - install PyTorch with GPU support (the code is GPU-only); refer to the [official website](http://pytorch.org)
76 | 
77 | - install other dependencies: `pip install visdom scikit-image tqdm fire ipdb pprint matplotlib torchnet`
78 | 
79 | - start visdom for visualization
80 | 
81 | ```Bash
82 | nohup python -m visdom.server &
83 | ```
84 | 
85 | 
86 | 
87 | ## 4. Demo
88 | 
89 | Download the pretrained model from [Google Drive](https://drive.google.com/open?id=1cQ27LIn-Rig4-Uayzy_gH5-cW-NRGVzY) or [Baidu Netdisk (passwd: scxn)](https://pan.baidu.com/s/1o87RuXW).
90 | 
91 | 
92 | See [demo.ipynb](https://github.com/chenyuntc/simple-faster-rcnn-pytorch/blob/master/demo.ipynb) for more details.
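
For reference, the notebook boils down to loading a checkpoint and calling `predict`. The sketch below is illustrative only (the checkpoint path is a placeholder for wherever you saved the downloaded weights; run it from the repository root on a machine with a GPU):

```python
import torch as t
from data.util import read_image
from model import FasterRCNNVGG16
from trainer import FasterRCNNTrainer
from utils.config import opt

faster_rcnn = FasterRCNNVGG16()
trainer = FasterRCNNTrainer(faster_rcnn).cuda()
trainer.load('path/to/downloaded/checkpoint.pth')  # placeholder path
# opt.caffe_pretrain = True  # uncomment if the checkpoint was trained from the caffe-pretrained VGG

img = read_image('misc/demo.jpg')   # CHW, RGB, values in [0, 255]
img = t.from_numpy(img)[None]       # add a batch dimension

# per-image lists of (y_min, x_min, y_max, x_max) boxes, class labels and scores
bboxes, labels, scores = trainer.faster_rcnn.predict(img, visualize=True)
```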
93 | 
94 | ## 5. Train
95 | 
96 | ### 5.1 Prepare data
97 | 
98 | #### Pascal VOC2007
99 | 
100 | 1. Download the training, validation, and test data and the VOCdevkit
101 | 
102 | ```Bash
103 | wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar
104 | wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar
105 | wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCdevkit_08-Jun-2007.tar
106 | ```
107 | 
108 | 2. Extract all of these tars into one directory named `VOCdevkit`
109 | 
110 | ```Bash
111 | tar xvf VOCtrainval_06-Nov-2007.tar
112 | tar xvf VOCtest_06-Nov-2007.tar
113 | tar xvf VOCdevkit_08-Jun-2007.tar
114 | ```
115 | 
116 | 3. It should have this basic structure
117 | 
118 | ```Bash
119 | $VOCdevkit/                           # development kit
120 | $VOCdevkit/VOCcode/                   # VOC utility code
121 | $VOCdevkit/VOC2007                    # image sets, annotations, etc.
122 | # ... and several other directories ...
123 | ```
124 | 
125 | 4. Modify the `voc_data_dir` config item in `utils/config.py`, or pass it to the program with an argument like `--voc-data-dir=/path/to/VOCdevkit/VOC2007/`.
126 | 
127 | 
128 | ### 5.2 [Optional] Prepare caffe-pretrained vgg16
129 | 
130 | If you want to use the caffe-pretrained model as the initial weights, you can run the command below to get vgg16 weights converted from caffe, which is what the original paper uses.
131 | 
132 | ````Bash
133 | python misc/convert_caffe_pretrain.py
134 | ````
135 | 
136 | This script downloads the pretrained model and converts it to a format compatible with torchvision. If you are in China and cannot download the pretrained model, you may refer to [this issue](https://github.com/chenyuntc/simple-faster-rcnn-pytorch/issues/63).
137 | 
138 | Then you can specify where the caffe-pretrained model `vgg16_caffe.pth` is stored in `utils/config.py` by setting `caffe_pretrain_path`. The default path is OK.
139 | 
140 | If you want to use the pretrained model from torchvision, you may skip this step.
141 | 
142 | **NOTE**: the caffe-pretrained model has shown slightly better performance.
143 | 
144 | **NOTE**: the caffe model requires images in BGR with values in 0-255, while the torchvision model requires images in RGB with values in 0-1. See `data/dataset.py` for more detail.
145 | 
146 | ### 5.3 Begin training
147 | 
148 | 
149 | ```bash
150 | python train.py train --env='fasterrcnn' --plot-every=100
151 | ```
152 | 
153 | You may refer to `utils/config.py` for more arguments.
154 | 
155 | Some key arguments:
156 | 
157 | - `--caffe-pretrain=False`: use the pretrained model from caffe or torchvision (default: torchvision)
158 | - `--plot-every=n`: visualize predictions, loss, etc. every `n` batches.
159 | - `--env`: visdom env for visualization
160 | - `--voc_data_dir`: where the VOC data is stored
161 | - `--use-drop`: use dropout in the RoI head, default False
162 | - `--use-Adam`: use Adam instead of SGD, default SGD. (You need to set a very low `lr` for Adam.)
163 | - `--load-path`: pretrained model path, default `None`; if it's specified, it will be loaded.
164 | 
165 | You may open a browser, visit `http://<host-ip>:8097`, and see the visualization of the training procedure as below:
166 | 
167 | ![visdom](imgs/visdom-fasterrcnn.png)
168 | 
169 | ## Troubleshooting
170 | 
171 | - dataloader: `received 0 items of ancdata`
172 | 
173 |   See the [discussion](https://github.com/pytorch/pytorch/issues/973#issuecomment-346405667); it's already fixed in [train.py](https://github.com/chenyuntc/simple-faster-rcnn-pytorch/blob/master/train.py#L17-L22), so you should be free from this problem. A sketch of that kind of workaround is shown below.
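
  A minimal sketch of that kind of workaround, raising the process's open-file limit as the linked discussion suggests (the exact limit here is illustrative):

```python
# allow many dataloader workers without exhausting file descriptors
import resource

soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (20480, hard))
```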
174 | 175 | - Windows support 176 | 177 | I don't have windows machine with GPU to debug and test it. It's welcome if anyone could make a pull request and test it. 178 | 179 | 180 | 181 | ## Acknowledgement 182 | This work builds on many excellent works, which include: 183 | 184 | - [Yusuke Niitani's ChainerCV](https://github.com/chainer/chainercv) (mainly) 185 | - [Ruotian Luo's pytorch-faster-rcnn](https://github.com/ruotianluo/pytorch-faster-rcnn) which based on [Xinlei Chen's tf-faster-rcnn](https://github.com/endernewton/tf-faster-rcnn) 186 | - [faster-rcnn.pytorch by Jianwei Yang and Jiasen Lu](https://github.com/jwyang/faster-rcnn.pytorch).It mainly refer to [longcw's faster_rcnn_pytorch](https://github.com/longcw/faster_rcnn_pytorch) 187 | - All the above Repositories have referred to [py-faster-rcnn by Ross Girshick and Sean Bell](https://github.com/rbgirshick/py-faster-rcnn) either directly or indirectly. 188 | 189 | ## ^_^ 190 | Licensed under MIT, see the LICENSE for more detail. 191 | 192 | Contribution Welcome. 193 | 194 | If you encounter any problem, feel free to open an issue, but too busy lately. 195 | 196 | Correct me if anything is wrong or unclear. 197 | 198 | model structure 199 | ![img](imgs/model_all.png) 200 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenyuntc/simple-faster-rcnn-pytorch/367db367834efd8a2bc58ee0023b2b628a0e474d/data/__init__.py -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | import torch as t 4 | from data.voc_dataset import VOCBboxDataset 5 | from skimage import transform as sktsf 6 | from torchvision import transforms as tvtsf 7 | from data import util 8 | import numpy as np 9 | from utils.config import opt 10 | 11 | 12 | def inverse_normalize(img): 13 | if opt.caffe_pretrain: 14 | img = img + (np.array([122.7717, 115.9465, 102.9801]).reshape(3, 1, 1)) 15 | return img[::-1, :, :] 16 | # approximate un-normalize for visualize 17 | return (img * 0.225 + 0.45).clip(min=0, max=1) * 255 18 | 19 | 20 | def pytorch_normalze(img): 21 | """ 22 | https://github.com/pytorch/vision/issues/223 23 | return appr -1~1 RGB 24 | """ 25 | normalize = tvtsf.Normalize(mean=[0.485, 0.456, 0.406], 26 | std=[0.229, 0.224, 0.225]) 27 | img = normalize(t.from_numpy(img)) 28 | return img.numpy() 29 | 30 | 31 | def caffe_normalize(img): 32 | """ 33 | return appr -125-125 BGR 34 | """ 35 | img = img[[2, 1, 0], :, :] # RGB-BGR 36 | img = img * 255 37 | mean = np.array([122.7717, 115.9465, 102.9801]).reshape(3, 1, 1) 38 | img = (img - mean).astype(np.float32, copy=True) 39 | return img 40 | 41 | 42 | def preprocess(img, min_size=600, max_size=1000): 43 | """Preprocess an image for feature extraction. 44 | 45 | The length of the shorter edge is scaled to :obj:`self.min_size`. 46 | After the scaling, if the length of the longer edge is longer than 47 | :param min_size: 48 | :obj:`self.max_size`, the image is scaled to fit the longer edge 49 | to :obj:`self.max_size`. 50 | 51 | After resizing the image, the image is subtracted by a mean image value 52 | :obj:`self.mean`. 53 | 54 | Args: 55 | img (~numpy.ndarray): An image. This is in CHW and RGB format. 
56 | The range of its value is :math:`[0, 255]`. 57 | 58 | Returns: 59 | ~numpy.ndarray: A preprocessed image. 60 | 61 | """ 62 | C, H, W = img.shape 63 | scale1 = min_size / min(H, W) 64 | scale2 = max_size / max(H, W) 65 | scale = min(scale1, scale2) 66 | img = img / 255. 67 | img = sktsf.resize(img, (C, H * scale, W * scale), mode='reflect',anti_aliasing=False) 68 | # both the longer and shorter should be less than 69 | # max_size and min_size 70 | if opt.caffe_pretrain: 71 | normalize = caffe_normalize 72 | else: 73 | normalize = pytorch_normalze 74 | return normalize(img) 75 | 76 | 77 | class Transform(object): 78 | 79 | def __init__(self, min_size=600, max_size=1000): 80 | self.min_size = min_size 81 | self.max_size = max_size 82 | 83 | def __call__(self, in_data): 84 | img, bbox, label = in_data 85 | _, H, W = img.shape 86 | img = preprocess(img, self.min_size, self.max_size) 87 | _, o_H, o_W = img.shape 88 | scale = o_H / H 89 | bbox = util.resize_bbox(bbox, (H, W), (o_H, o_W)) 90 | 91 | # horizontally flip 92 | img, params = util.random_flip( 93 | img, x_random=True, return_param=True) 94 | bbox = util.flip_bbox( 95 | bbox, (o_H, o_W), x_flip=params['x_flip']) 96 | 97 | return img, bbox, label, scale 98 | 99 | 100 | class Dataset: 101 | def __init__(self, opt): 102 | self.opt = opt 103 | self.db = VOCBboxDataset(opt.voc_data_dir) 104 | self.tsf = Transform(opt.min_size, opt.max_size) 105 | 106 | def __getitem__(self, idx): 107 | ori_img, bbox, label, difficult = self.db.get_example(idx) 108 | 109 | img, bbox, label, scale = self.tsf((ori_img, bbox, label)) 110 | # TODO: check whose stride is negative to fix this instead copy all 111 | # some of the strides of a given numpy array are negative. 112 | return img.copy(), bbox.copy(), label.copy(), scale 113 | 114 | def __len__(self): 115 | return len(self.db) 116 | 117 | 118 | class TestDataset: 119 | def __init__(self, opt, split='test', use_difficult=True): 120 | self.opt = opt 121 | self.db = VOCBboxDataset(opt.voc_data_dir, split=split, use_difficult=use_difficult) 122 | 123 | def __getitem__(self, idx): 124 | ori_img, bbox, label, difficult = self.db.get_example(idx) 125 | img = preprocess(ori_img) 126 | return img, ori_img.shape[1:], bbox, label, difficult 127 | 128 | def __len__(self): 129 | return len(self.db) 130 | -------------------------------------------------------------------------------- /data/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import random 4 | 5 | 6 | def read_image(path, dtype=np.float32, color=True): 7 | """Read an image from a file. 8 | 9 | This function reads an image from given file. The image is CHW format and 10 | the range of its value is :math:`[0, 255]`. If :obj:`color = True`, the 11 | order of the channels is RGB. 12 | 13 | Args: 14 | path (str): A path of image file. 15 | dtype: The type of array. The default value is :obj:`~numpy.float32`. 16 | color (bool): This option determines the number of channels. 17 | If :obj:`True`, the number of channels is three. In this case, 18 | the order of the channels is RGB. This is the default behaviour. 19 | If :obj:`False`, this function returns a grayscale image. 20 | 21 | Returns: 22 | ~numpy.ndarray: An image. 
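    A minimal usage sketch (illustrative; ``misc/demo.jpg`` ships with this
    repository):

    .. code-block:: python

        img = read_image('misc/demo.jpg')
        # img.shape == (3, H, W), dtype float32, RGB, values in [0, 255]
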
23 | """ 24 | 25 | f = Image.open(path) 26 | try: 27 | if color: 28 | img = f.convert('RGB') 29 | else: 30 | img = f.convert('P') 31 | img = np.asarray(img, dtype=dtype) 32 | finally: 33 | if hasattr(f, 'close'): 34 | f.close() 35 | 36 | if img.ndim == 2: 37 | # reshape (H, W) -> (1, H, W) 38 | return img[np.newaxis] 39 | else: 40 | # transpose (H, W, C) -> (C, H, W) 41 | return img.transpose((2, 0, 1)) 42 | 43 | 44 | def resize_bbox(bbox, in_size, out_size): 45 | """Resize bounding boxes according to image resize. 46 | 47 | The bounding boxes are expected to be packed into a two dimensional 48 | tensor of shape :math:`(R, 4)`, where :math:`R` is the number of 49 | bounding boxes in the image. The second axis represents attributes of 50 | the bounding box. They are :math:`(y_{min}, x_{min}, y_{max}, x_{max})`, 51 | where the four attributes are coordinates of the top left and the 52 | bottom right vertices. 53 | 54 | Args: 55 | bbox (~numpy.ndarray): An array whose shape is :math:`(R, 4)`. 56 | :math:`R` is the number of bounding boxes. 57 | in_size (tuple): A tuple of length 2. The height and the width 58 | of the image before resized. 59 | out_size (tuple): A tuple of length 2. The height and the width 60 | of the image after resized. 61 | 62 | Returns: 63 | ~numpy.ndarray: 64 | Bounding boxes rescaled according to the given image shapes. 65 | 66 | """ 67 | bbox = bbox.copy() 68 | y_scale = float(out_size[0]) / in_size[0] 69 | x_scale = float(out_size[1]) / in_size[1] 70 | bbox[:, 0] = y_scale * bbox[:, 0] 71 | bbox[:, 2] = y_scale * bbox[:, 2] 72 | bbox[:, 1] = x_scale * bbox[:, 1] 73 | bbox[:, 3] = x_scale * bbox[:, 3] 74 | return bbox 75 | 76 | 77 | def flip_bbox(bbox, size, y_flip=False, x_flip=False): 78 | """Flip bounding boxes accordingly. 79 | 80 | The bounding boxes are expected to be packed into a two dimensional 81 | tensor of shape :math:`(R, 4)`, where :math:`R` is the number of 82 | bounding boxes in the image. The second axis represents attributes of 83 | the bounding box. They are :math:`(y_{min}, x_{min}, y_{max}, x_{max})`, 84 | where the four attributes are coordinates of the top left and the 85 | bottom right vertices. 86 | 87 | Args: 88 | bbox (~numpy.ndarray): An array whose shape is :math:`(R, 4)`. 89 | :math:`R` is the number of bounding boxes. 90 | size (tuple): A tuple of length 2. The height and the width 91 | of the image before resized. 92 | y_flip (bool): Flip bounding box according to a vertical flip of 93 | an image. 94 | x_flip (bool): Flip bounding box according to a horizontal flip of 95 | an image. 96 | 97 | Returns: 98 | ~numpy.ndarray: 99 | Bounding boxes flipped according to the given flips. 100 | 101 | """ 102 | H, W = size 103 | bbox = bbox.copy() 104 | if y_flip: 105 | y_max = H - bbox[:, 0] 106 | y_min = H - bbox[:, 2] 107 | bbox[:, 0] = y_min 108 | bbox[:, 2] = y_max 109 | if x_flip: 110 | x_max = W - bbox[:, 1] 111 | x_min = W - bbox[:, 3] 112 | bbox[:, 1] = x_min 113 | bbox[:, 3] = x_max 114 | return bbox 115 | 116 | 117 | def crop_bbox( 118 | bbox, y_slice=None, x_slice=None, 119 | allow_outside_center=True, return_param=False): 120 | """Translate bounding boxes to fit within the cropped area of an image. 121 | 122 | This method is mainly used together with image cropping. 123 | This method translates the coordinates of bounding boxes like 124 | :func:`data.util.translate_bbox`. In addition, 125 | this function truncates the bounding boxes to fit within the cropped area. 
126 | If a bounding box does not overlap with the cropped area, 127 | this bounding box will be removed. 128 | 129 | The bounding boxes are expected to be packed into a two dimensional 130 | tensor of shape :math:`(R, 4)`, where :math:`R` is the number of 131 | bounding boxes in the image. The second axis represents attributes of 132 | the bounding box. They are :math:`(y_{min}, x_{min}, y_{max}, x_{max})`, 133 | where the four attributes are coordinates of the top left and the 134 | bottom right vertices. 135 | 136 | Args: 137 | bbox (~numpy.ndarray): Bounding boxes to be transformed. The shape is 138 | :math:`(R, 4)`. :math:`R` is the number of bounding boxes. 139 | y_slice (slice): The slice of y axis. 140 | x_slice (slice): The slice of x axis. 141 | allow_outside_center (bool): If this argument is :obj:`False`, 142 | bounding boxes whose centers are outside of the cropped area 143 | are removed. The default value is :obj:`True`. 144 | return_param (bool): If :obj:`True`, this function returns 145 | indices of kept bounding boxes. 146 | 147 | Returns: 148 | ~numpy.ndarray or (~numpy.ndarray, dict): 149 | 150 | If :obj:`return_param = False`, returns an array :obj:`bbox`. 151 | 152 | If :obj:`return_param = True`, 153 | returns a tuple whose elements are :obj:`bbox, param`. 154 | :obj:`param` is a dictionary of intermediate parameters whose 155 | contents are listed below with key, value-type and the description 156 | of the value. 157 | 158 | * **index** (*numpy.ndarray*): An array holding indices of used \ 159 | bounding boxes. 160 | 161 | """ 162 | 163 | t, b = _slice_to_bounds(y_slice) 164 | l, r = _slice_to_bounds(x_slice) 165 | crop_bb = np.array((t, l, b, r)) 166 | 167 | if allow_outside_center: 168 | mask = np.ones(bbox.shape[0], dtype=bool) 169 | else: 170 | center = (bbox[:, :2] + bbox[:, 2:]) / 2.0 171 | mask = np.logical_and(crop_bb[:2] <= center, center < crop_bb[2:]) \ 172 | .all(axis=1) 173 | 174 | bbox = bbox.copy() 175 | bbox[:, :2] = np.maximum(bbox[:, :2], crop_bb[:2]) 176 | bbox[:, 2:] = np.minimum(bbox[:, 2:], crop_bb[2:]) 177 | bbox[:, :2] -= crop_bb[:2] 178 | bbox[:, 2:] -= crop_bb[:2] 179 | 180 | mask = np.logical_and(mask, (bbox[:, :2] < bbox[:, 2:]).all(axis=1)) 181 | bbox = bbox[mask] 182 | 183 | if return_param: 184 | return bbox, {'index': np.flatnonzero(mask)} 185 | else: 186 | return bbox 187 | 188 | 189 | def _slice_to_bounds(slice_): 190 | if slice_ is None: 191 | return 0, np.inf 192 | 193 | if slice_.start is None: 194 | l = 0 195 | else: 196 | l = slice_.start 197 | 198 | if slice_.stop is None: 199 | u = np.inf 200 | else: 201 | u = slice_.stop 202 | 203 | return l, u 204 | 205 | 206 | def translate_bbox(bbox, y_offset=0, x_offset=0): 207 | """Translate bounding boxes. 208 | 209 | This method is mainly used together with image transforms, such as padding 210 | and cropping, which translates the left top point of the image from 211 | coordinate :math:`(0, 0)` to coordinate 212 | :math:`(y, x) = (y_{offset}, x_{offset})`. 213 | 214 | The bounding boxes are expected to be packed into a two dimensional 215 | tensor of shape :math:`(R, 4)`, where :math:`R` is the number of 216 | bounding boxes in the image. The second axis represents attributes of 217 | the bounding box. They are :math:`(y_{min}, x_{min}, y_{max}, x_{max})`, 218 | where the four attributes are coordinates of the top left and the 219 | bottom right vertices. 220 | 221 | Args: 222 | bbox (~numpy.ndarray): Bounding boxes to be transformed. The shape is 223 | :math:`(R, 4)`. 
:math:`R` is the number of bounding boxes. 224 | y_offset (int or float): The offset along y axis. 225 | x_offset (int or float): The offset along x axis. 226 | 227 | Returns: 228 | ~numpy.ndarray: 229 | Bounding boxes translated according to the given offsets. 230 | 231 | """ 232 | 233 | out_bbox = bbox.copy() 234 | out_bbox[:, :2] += (y_offset, x_offset) 235 | out_bbox[:, 2:] += (y_offset, x_offset) 236 | 237 | return out_bbox 238 | 239 | 240 | def random_flip(img, y_random=False, x_random=False, 241 | return_param=False, copy=False): 242 | """Randomly flip an image in vertical or horizontal direction. 243 | 244 | Args: 245 | img (~numpy.ndarray): An array that gets flipped. This is in 246 | CHW format. 247 | y_random (bool): Randomly flip in vertical direction. 248 | x_random (bool): Randomly flip in horizontal direction. 249 | return_param (bool): Returns information of flip. 250 | copy (bool): If False, a view of :obj:`img` will be returned. 251 | 252 | Returns: 253 | ~numpy.ndarray or (~numpy.ndarray, dict): 254 | 255 | If :obj:`return_param = False`, 256 | returns an array :obj:`out_img` that is the result of flipping. 257 | 258 | If :obj:`return_param = True`, 259 | returns a tuple whose elements are :obj:`out_img, param`. 260 | :obj:`param` is a dictionary of intermediate parameters whose 261 | contents are listed below with key, value-type and the description 262 | of the value. 263 | 264 | * **y_flip** (*bool*): Whether the image was flipped in the\ 265 | vertical direction or not. 266 | * **x_flip** (*bool*): Whether the image was flipped in the\ 267 | horizontal direction or not. 268 | 269 | """ 270 | y_flip, x_flip = False, False 271 | if y_random: 272 | y_flip = random.choice([True, False]) 273 | if x_random: 274 | x_flip = random.choice([True, False]) 275 | 276 | if y_flip: 277 | img = img[:, ::-1, :] 278 | if x_flip: 279 | img = img[:, :, ::-1] 280 | 281 | if copy: 282 | img = img.copy() 283 | 284 | if return_param: 285 | return img, {'y_flip': y_flip, 'x_flip': x_flip} 286 | else: 287 | return img 288 | -------------------------------------------------------------------------------- /data/voc_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import xml.etree.ElementTree as ET 3 | 4 | import numpy as np 5 | 6 | from .util import read_image 7 | 8 | 9 | class VOCBboxDataset: 10 | """Bounding box dataset for PASCAL `VOC`_. 11 | 12 | .. _`VOC`: http://host.robots.ox.ac.uk/pascal/VOC/voc2012/ 13 | 14 | The index corresponds to each image. 15 | 16 | When queried by an index, if :obj:`return_difficult == False`, 17 | this dataset returns a corresponding 18 | :obj:`img, bbox, label`, a tuple of an image, bounding boxes and labels. 19 | This is the default behaviour. 20 | If :obj:`return_difficult == True`, this dataset returns corresponding 21 | :obj:`img, bbox, label, difficult`. :obj:`difficult` is a boolean array 22 | that indicates whether bounding boxes are labeled as difficult or not. 23 | 24 | The bounding boxes are packed into a two dimensional tensor of shape 25 | :math:`(R, 4)`, where :math:`R` is the number of bounding boxes in 26 | the image. The second axis represents attributes of the bounding box. 27 | They are :math:`(y_{min}, x_{min}, y_{max}, x_{max})`, where the 28 | four attributes are coordinates of the top left and the bottom right 29 | vertices. 30 | 31 | The labels are packed into a one dimensional tensor of shape :math:`(R,)`. 32 | :math:`R` is the number of bounding boxes in the image. 
33 | The class name of the label :math:`l` is :math:`l` th element of 34 | :obj:`VOC_BBOX_LABEL_NAMES`. 35 | 36 | The array :obj:`difficult` is a one dimensional boolean array of shape 37 | :math:`(R,)`. :math:`R` is the number of bounding boxes in the image. 38 | If :obj:`use_difficult` is :obj:`False`, this array is 39 | a boolean array with all :obj:`False`. 40 | 41 | The type of the image, the bounding boxes and the labels are as follows. 42 | 43 | * :obj:`img.dtype == numpy.float32` 44 | * :obj:`bbox.dtype == numpy.float32` 45 | * :obj:`label.dtype == numpy.int32` 46 | * :obj:`difficult.dtype == numpy.bool` 47 | 48 | Args: 49 | data_dir (string): Path to the root of the training data. 50 | i.e. "/data/image/voc/VOCdevkit/VOC2007/" 51 | split ({'train', 'val', 'trainval', 'test'}): Select a split of the 52 | dataset. :obj:`test` split is only available for 53 | 2007 dataset. 54 | year ({'2007', '2012'}): Use a dataset prepared for a challenge 55 | held in :obj:`year`. 56 | use_difficult (bool): If :obj:`True`, use images that are labeled as 57 | difficult in the original annotation. 58 | return_difficult (bool): If :obj:`True`, this dataset returns 59 | a boolean array 60 | that indicates whether bounding boxes are labeled as difficult 61 | or not. The default value is :obj:`False`. 62 | 63 | """ 64 | 65 | def __init__(self, data_dir, split='trainval', 66 | use_difficult=False, return_difficult=False, 67 | ): 68 | 69 | # if split not in ['train', 'trainval', 'val']: 70 | # if not (split == 'test' and year == '2007'): 71 | # warnings.warn( 72 | # 'please pick split from \'train\', \'trainval\', \'val\'' 73 | # 'for 2012 dataset. For 2007 dataset, you can pick \'test\'' 74 | # ' in addition to the above mentioned splits.' 75 | # ) 76 | id_list_file = os.path.join( 77 | data_dir, 'ImageSets/Main/{0}.txt'.format(split)) 78 | 79 | self.ids = [id_.strip() for id_ in open(id_list_file)] 80 | self.data_dir = data_dir 81 | self.use_difficult = use_difficult 82 | self.return_difficult = return_difficult 83 | self.label_names = VOC_BBOX_LABEL_NAMES 84 | 85 | def __len__(self): 86 | return len(self.ids) 87 | 88 | def get_example(self, i): 89 | """Returns the i-th example. 90 | 91 | Returns a color image and bounding boxes. The image is in CHW format. 92 | The returned image is RGB. 93 | 94 | Args: 95 | i (int): The index of the example. 96 | 97 | Returns: 98 | tuple of an image and bounding boxes 99 | 100 | """ 101 | id_ = self.ids[i] 102 | anno = ET.parse( 103 | os.path.join(self.data_dir, 'Annotations', id_ + '.xml')) 104 | bbox = list() 105 | label = list() 106 | difficult = list() 107 | for obj in anno.findall('object'): 108 | # when in not using difficult split, and the object is 109 | # difficult, skipt it. 110 | if not self.use_difficult and int(obj.find('difficult').text) == 1: 111 | continue 112 | 113 | difficult.append(int(obj.find('difficult').text)) 114 | bndbox_anno = obj.find('bndbox') 115 | # subtract 1 to make pixel indexes 0-based 116 | bbox.append([ 117 | int(bndbox_anno.find(tag).text) - 1 118 | for tag in ('ymin', 'xmin', 'ymax', 'xmax')]) 119 | name = obj.find('name').text.lower().strip() 120 | label.append(VOC_BBOX_LABEL_NAMES.index(name)) 121 | bbox = np.stack(bbox).astype(np.float32) 122 | label = np.stack(label).astype(np.int32) 123 | # When `use_difficult==False`, all elements in `difficult` are False. 
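        # NOTE: `np.bool` is deprecated and removed in recent NumPy releases;
        # the builtin `bool` (or `np.bool_`) can be used instead if the line
        # below raises an AttributeError.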
124 | difficult = np.array(difficult, dtype=np.bool).astype(np.uint8) # PyTorch don't support np.bool 125 | 126 | # Load a image 127 | img_file = os.path.join(self.data_dir, 'JPEGImages', id_ + '.jpg') 128 | img = read_image(img_file, color=True) 129 | 130 | # if self.return_difficult: 131 | # return img, bbox, label, difficult 132 | return img, bbox, label, difficult 133 | 134 | __getitem__ = get_example 135 | 136 | 137 | VOC_BBOX_LABEL_NAMES = ( 138 | 'aeroplane', 139 | 'bicycle', 140 | 'bird', 141 | 'boat', 142 | 'bottle', 143 | 'bus', 144 | 'car', 145 | 'cat', 146 | 'chair', 147 | 'cow', 148 | 'diningtable', 149 | 'dog', 150 | 'horse', 151 | 'motorbike', 152 | 'person', 153 | 'pottedplant', 154 | 'sheep', 155 | 'sofa', 156 | 'train', 157 | 'tvmonitor') 158 | -------------------------------------------------------------------------------- /imgs/faster-speed.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenyuntc/simple-faster-rcnn-pytorch/367db367834efd8a2bc58ee0023b2b628a0e474d/imgs/faster-speed.jpg -------------------------------------------------------------------------------- /imgs/model_all.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenyuntc/simple-faster-rcnn-pytorch/367db367834efd8a2bc58ee0023b2b628a0e474d/imgs/model_all.png -------------------------------------------------------------------------------- /imgs/visdom-fasterrcnn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenyuntc/simple-faster-rcnn-pytorch/367db367834efd8a2bc58ee0023b2b628a0e474d/imgs/visdom-fasterrcnn.png -------------------------------------------------------------------------------- /misc/convert_caffe_pretrain.py: -------------------------------------------------------------------------------- 1 | # code from ruotian luo 2 | # https://github.com/ruotianluo/pytorch-faster-rcnn 3 | import torch 4 | from torch.utils.model_zoo import load_url 5 | from torchvision import models 6 | 7 | sd = load_url("https://s3-us-west-2.amazonaws.com/jcjohns-models/vgg16-00b39a1b.pth") 8 | sd['classifier.0.weight'] = sd['classifier.1.weight'] 9 | sd['classifier.0.bias'] = sd['classifier.1.bias'] 10 | del sd['classifier.1.weight'] 11 | del sd['classifier.1.bias'] 12 | 13 | sd['classifier.3.weight'] = sd['classifier.4.weight'] 14 | sd['classifier.3.bias'] = sd['classifier.4.bias'] 15 | del sd['classifier.4.weight'] 16 | del sd['classifier.4.bias'] 17 | 18 | import os 19 | # speicify the path to save 20 | if not os.path.exists('checkpoints'): 21 | os.makedirs('checkpoints') 22 | torch.save(sd, "checkpoints/vgg16_caffe.pth") -------------------------------------------------------------------------------- /misc/demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenyuntc/simple-faster-rcnn-pytorch/367db367834efd8a2bc58ee0023b2b628a0e474d/misc/demo.jpg -------------------------------------------------------------------------------- /misc/train_fast.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import ipdb 4 | import matplotlib 5 | from tqdm import tqdm 6 | 7 | from utils.config import opt 8 | from data.dataset import Dataset, TestDataset 9 | from model import FasterRCNNVGG16 10 | from torch.utils import data as data_ 11 | from trainer import FasterRCNNTrainer 12 | from utils 
import array_tool as at 13 | from utils.vis_tool import visdom_bbox 14 | from utils.eval_tool import eval_detection_voc 15 | 16 | matplotlib.use('agg') 17 | 18 | def eval(dataloader, faster_rcnn, test_num=10000): 19 | pred_bboxes, pred_labels, pred_scores = list(), list(), list() 20 | gt_bboxes, gt_labels, gt_difficults = list(), list(), list() 21 | for ii, (imgs, sizes, gt_bboxes_, gt_labels_, gt_difficults_) in tqdm(enumerate(dataloader)): 22 | sizes = [sizes[0][0], sizes[1][0]] 23 | pred_bboxes_, pred_labels_, pred_scores_ = faster_rcnn.predict(imgs, [sizes]) 24 | gt_bboxes += list(gt_bboxes_.numpy()) 25 | gt_labels += list(gt_labels_.numpy()) 26 | gt_difficults += list(gt_difficults_.numpy()) 27 | pred_bboxes += pred_bboxes_ 28 | pred_labels += pred_labels_ 29 | pred_scores += pred_scores_ 30 | if ii == test_num: break 31 | 32 | result = eval_detection_voc( 33 | pred_bboxes, pred_labels, pred_scores, 34 | gt_bboxes, gt_labels, gt_difficults, 35 | use_07_metric=True) 36 | return result 37 | 38 | 39 | def train(**kwargs): 40 | opt._parse(kwargs) 41 | 42 | dataset = Dataset(opt) 43 | print('load data') 44 | dataloader = data_.DataLoader(dataset, \ 45 | batch_size=1, \ 46 | shuffle=True, \ 47 | # pin_memory=True, 48 | num_workers=opt.num_workers) 49 | testset = TestDataset(opt) 50 | test_dataloader = data_.DataLoader(testset, 51 | batch_size=1, 52 | num_workers=2, 53 | shuffle=False, \ 54 | # pin_memory=True 55 | ) 56 | faster_rcnn = FasterRCNNVGG16() 57 | print('model construct completed') 58 | trainer = FasterRCNNTrainer(faster_rcnn).cuda() 59 | if opt.load_path: 60 | trainer.load(opt.load_path) 61 | print('load pretrained model from %s' % opt.load_path) 62 | 63 | trainer.vis.text(dataset.db.label_names, win='labels') 64 | best_map = 0 65 | for epoch in range(7): 66 | trainer.reset_meters() 67 | for ii, (img, bbox_, label_, scale, ori_img) in tqdm(enumerate(dataloader)): 68 | scale = at.scalar(scale) 69 | img, bbox, label = img.cuda().float(), bbox_.cuda(), label_.cuda() 70 | losses = trainer.train_step(img, bbox, label, scale) 71 | 72 | if (ii + 1) % opt.plot_every == 0: 73 | if os.path.exists(opt.debug_file): 74 | ipdb.set_trace() 75 | 76 | # plot loss 77 | trainer.vis.plot_many(trainer.get_meter_data()) 78 | 79 | # plot groud truth bboxes 80 | ori_img_ = (img * 0.225 + 0.45).clamp(min=0, max=1) * 255 81 | gt_img = visdom_bbox(at.tonumpy(ori_img_)[0], 82 | at.tonumpy(bbox_)[0], 83 | label_[0].numpy()) 84 | trainer.vis.img('gt_img', gt_img) 85 | 86 | # plot predicti bboxes 87 | _bboxes, _labels, _scores = trainer.faster_rcnn.predict(ori_img,visualize=True) 88 | pred_img = visdom_bbox( at.tonumpy(ori_img[0]), 89 | at.tonumpy(_bboxes[0]), 90 | at.tonumpy(_labels[0]).reshape(-1), 91 | at.tonumpy(_scores[0])) 92 | trainer.vis.img('pred_img', pred_img) 93 | 94 | # rpn confusion matrix(meter) 95 | trainer.vis.text(str(trainer.rpn_cm.value().tolist()), win='rpn_cm') 96 | # roi confusion matrix 97 | trainer.vis.img('roi_cm', at.totensor(trainer.roi_cm.conf, False).float()) 98 | if epoch==4: 99 | trainer.faster_rcnn.scale_lr(opt.lr_decay) 100 | 101 | eval_result = eval(test_dataloader, faster_rcnn, test_num=1e100) 102 | print('eval_result') 103 | trainer.save(mAP=eval_result['map']) 104 | 105 | if __name__ == '__main__': 106 | import fire 107 | 108 | fire.Fire() 109 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from model.faster_rcnn_vgg16 import 
FasterRCNNVGG16 2 | -------------------------------------------------------------------------------- /model/faster_rcnn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | import torch as t 4 | import numpy as np 5 | from utils import array_tool as at 6 | from model.utils.bbox_tools import loc2bbox 7 | from torchvision.ops import nms 8 | # from model.utils.nms import non_maximum_suppression 9 | 10 | from torch import nn 11 | from data.dataset import preprocess 12 | from torch.nn import functional as F 13 | from utils.config import opt 14 | 15 | 16 | def nograd(f): 17 | def new_f(*args,**kwargs): 18 | with t.no_grad(): 19 | return f(*args,**kwargs) 20 | return new_f 21 | 22 | class FasterRCNN(nn.Module): 23 | """Base class for Faster R-CNN. 24 | 25 | This is a base class for Faster R-CNN links supporting object detection 26 | API [#]_. The following three stages constitute Faster R-CNN. 27 | 28 | 1. **Feature extraction**: Images are taken and their \ 29 | feature maps are calculated. 30 | 2. **Region Proposal Networks**: Given the feature maps calculated in \ 31 | the previous stage, produce set of RoIs around objects. 32 | 3. **Localization and Classification Heads**: Using feature maps that \ 33 | belong to the proposed RoIs, classify the categories of the objects \ 34 | in the RoIs and improve localizations. 35 | 36 | Each stage is carried out by one of the callable 37 | :class:`torch.nn.Module` objects :obj:`feature`, :obj:`rpn` and :obj:`head`. 38 | 39 | There are two functions :meth:`predict` and :meth:`__call__` to conduct 40 | object detection. 41 | :meth:`predict` takes images and returns bounding boxes that are converted 42 | to image coordinates. This will be useful for a scenario when 43 | Faster R-CNN is treated as a black box function, for instance. 44 | :meth:`__call__` is provided for a scnerario when intermediate outputs 45 | are needed, for instance, for training and debugging. 46 | 47 | Links that support obejct detection API have method :meth:`predict` with 48 | the same interface. Please refer to :meth:`predict` for 49 | further details. 50 | 51 | .. [#] Shaoqing Ren, Kaiming He, Ross Girshick, Jian Sun. \ 52 | Faster R-CNN: Towards Real-Time Object Detection with \ 53 | Region Proposal Networks. NIPS 2015. 54 | 55 | Args: 56 | extractor (nn.Module): A module that takes a BCHW image 57 | array and returns feature maps. 58 | rpn (nn.Module): A module that has the same interface as 59 | :class:`model.region_proposal_network.RegionProposalNetwork`. 60 | Please refer to the documentation found there. 61 | head (nn.Module): A module that takes 62 | a BCHW variable, RoIs and batch indices for RoIs. This returns class 63 | dependent localization paramters and class scores. 64 | loc_normalize_mean (tuple of four floats): Mean values of 65 | localization estimates. 66 | loc_normalize_std (tupler of four floats): Standard deviation 67 | of localization estimates. 
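
    A minimal inference sketch (illustrative only; it assumes trained
    weights, a CUDA device, and the VGG-16 subclass defined in
    :mod:`model.faster_rcnn_vgg16`):

    .. code-block:: python

        from data.util import read_image
        from model import FasterRCNNVGG16

        model = FasterRCNNVGG16().cuda()    # load trained weights before use
        img = read_image('misc/demo.jpg')   # CHW, RGB, values in [0, 255]
        bboxes, labels, scores = model.predict([img], visualize=True)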
68 | 69 | """ 70 | 71 | def __init__(self, extractor, rpn, head, 72 | loc_normalize_mean = (0., 0., 0., 0.), 73 | loc_normalize_std = (0.1, 0.1, 0.2, 0.2) 74 | ): 75 | super(FasterRCNN, self).__init__() 76 | self.extractor = extractor 77 | self.rpn = rpn 78 | self.head = head 79 | 80 | # mean and std 81 | self.loc_normalize_mean = loc_normalize_mean 82 | self.loc_normalize_std = loc_normalize_std 83 | self.use_preset('evaluate') 84 | 85 | @property 86 | def n_class(self): 87 | # Total number of classes including the background. 88 | return self.head.n_class 89 | 90 | def forward(self, x, scale=1.): 91 | """Forward Faster R-CNN. 92 | 93 | Scaling paramter :obj:`scale` is used by RPN to determine the 94 | threshold to select small objects, which are going to be 95 | rejected irrespective of their confidence scores. 96 | 97 | Here are notations used. 98 | 99 | * :math:`N` is the number of batch size 100 | * :math:`R'` is the total number of RoIs produced across batches. \ 101 | Given :math:`R_i` proposed RoIs from the :math:`i` th image, \ 102 | :math:`R' = \\sum _{i=1} ^ N R_i`. 103 | * :math:`L` is the number of classes excluding the background. 104 | 105 | Classes are ordered by the background, the first class, ..., and 106 | the :math:`L` th class. 107 | 108 | Args: 109 | x (autograd.Variable): 4D image variable. 110 | scale (float): Amount of scaling applied to the raw image 111 | during preprocessing. 112 | 113 | Returns: 114 | Variable, Variable, array, array: 115 | Returns tuple of four values listed below. 116 | 117 | * **roi_cls_locs**: Offsets and scalings for the proposed RoIs. \ 118 | Its shape is :math:`(R', (L + 1) \\times 4)`. 119 | * **roi_scores**: Class predictions for the proposed RoIs. \ 120 | Its shape is :math:`(R', L + 1)`. 121 | * **rois**: RoIs proposed by RPN. Its shape is \ 122 | :math:`(R', 4)`. 123 | * **roi_indices**: Batch indices of RoIs. Its shape is \ 124 | :math:`(R',)`. 125 | 126 | """ 127 | img_size = x.shape[2:] 128 | 129 | h = self.extractor(x) 130 | rpn_locs, rpn_scores, rois, roi_indices, anchor = \ 131 | self.rpn(h, img_size, scale) 132 | roi_cls_locs, roi_scores = self.head( 133 | h, rois, roi_indices) 134 | return roi_cls_locs, roi_scores, rois, roi_indices 135 | 136 | def use_preset(self, preset): 137 | """Use the given preset during prediction. 138 | 139 | This method changes values of :obj:`self.nms_thresh` and 140 | :obj:`self.score_thresh`. These values are a threshold value 141 | used for non maximum suppression and a threshold value 142 | to discard low confidence proposals in :meth:`predict`, 143 | respectively. 144 | 145 | If the attributes need to be changed to something 146 | other than the values provided in the presets, please modify 147 | them by directly accessing the public attributes. 148 | 149 | Args: 150 | preset ({'visualize', 'evaluate'): A string to determine the 151 | preset to use. 
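
    For illustration, the two presets set the thresholds used by
    :meth:`predict` as follows (values taken from the code below):

    .. code-block:: python

        model.use_preset('visualize')   # nms_thresh=0.3, score_thresh=0.7
        model.use_preset('evaluate')    # nms_thresh=0.3, score_thresh=0.05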
152 | 153 | """ 154 | if preset == 'visualize': 155 | self.nms_thresh = 0.3 156 | self.score_thresh = 0.7 157 | elif preset == 'evaluate': 158 | self.nms_thresh = 0.3 159 | self.score_thresh = 0.05 160 | else: 161 | raise ValueError('preset must be visualize or evaluate') 162 | 163 | def _suppress(self, raw_cls_bbox, raw_prob): 164 | bbox = list() 165 | label = list() 166 | score = list() 167 | # skip cls_id = 0 because it is the background class 168 | for l in range(1, self.n_class): 169 | cls_bbox_l = raw_cls_bbox.reshape((-1, self.n_class, 4))[:, l, :] 170 | prob_l = raw_prob[:, l] 171 | mask = prob_l > self.score_thresh 172 | cls_bbox_l = cls_bbox_l[mask] 173 | prob_l = prob_l[mask] 174 | keep = nms(cls_bbox_l, prob_l,self.nms_thresh) 175 | # import ipdb;ipdb.set_trace() 176 | # keep = cp.asnumpy(keep) 177 | bbox.append(cls_bbox_l[keep].cpu().numpy()) 178 | # The labels are in [0, self.n_class - 2]. 179 | label.append((l - 1) * np.ones((len(keep),))) 180 | score.append(prob_l[keep].cpu().numpy()) 181 | bbox = np.concatenate(bbox, axis=0).astype(np.float32) 182 | label = np.concatenate(label, axis=0).astype(np.int32) 183 | score = np.concatenate(score, axis=0).astype(np.float32) 184 | return bbox, label, score 185 | 186 | @nograd 187 | def predict(self, imgs,sizes=None,visualize=False): 188 | """Detect objects from images. 189 | 190 | This method predicts objects for each image. 191 | 192 | Args: 193 | imgs (iterable of numpy.ndarray): Arrays holding images. 194 | All images are in CHW and RGB format 195 | and the range of their value is :math:`[0, 255]`. 196 | 197 | Returns: 198 | tuple of lists: 199 | This method returns a tuple of three lists, 200 | :obj:`(bboxes, labels, scores)`. 201 | 202 | * **bboxes**: A list of float arrays of shape :math:`(R, 4)`, \ 203 | where :math:`R` is the number of bounding boxes in a image. \ 204 | Each bouding box is organized by \ 205 | :math:`(y_{min}, x_{min}, y_{max}, x_{max})` \ 206 | in the second axis. 207 | * **labels** : A list of integer arrays of shape :math:`(R,)`. \ 208 | Each value indicates the class of the bounding box. \ 209 | Values are in range :math:`[0, L - 1]`, where :math:`L` is the \ 210 | number of the foreground classes. 211 | * **scores** : A list of float arrays of shape :math:`(R,)`. \ 212 | Each value indicates how confident the prediction is. 213 | 214 | """ 215 | self.eval() 216 | if visualize: 217 | self.use_preset('visualize') 218 | prepared_imgs = list() 219 | sizes = list() 220 | for img in imgs: 221 | size = img.shape[1:] 222 | img = preprocess(at.tonumpy(img)) 223 | prepared_imgs.append(img) 224 | sizes.append(size) 225 | else: 226 | prepared_imgs = imgs 227 | bboxes = list() 228 | labels = list() 229 | scores = list() 230 | for img, size in zip(prepared_imgs, sizes): 231 | img = at.totensor(img[None]).float() 232 | scale = img.shape[3] / size[1] 233 | roi_cls_loc, roi_scores, rois, _ = self(img, scale=scale) 234 | # We are assuming that batch size is 1. 235 | roi_score = roi_scores.data 236 | roi_cls_loc = roi_cls_loc.data 237 | roi = at.totensor(rois) / scale 238 | 239 | # Convert predictions to bounding boxes in image coordinates. 240 | # Bounding boxes are scaled to the scale of the input images. 241 | mean = t.Tensor(self.loc_normalize_mean).cuda(). \ 242 | repeat(self.n_class)[None] 243 | std = t.Tensor(self.loc_normalize_std).cuda(). 
\ 244 | repeat(self.n_class)[None] 245 | 246 | roi_cls_loc = (roi_cls_loc * std + mean) 247 | roi_cls_loc = roi_cls_loc.view(-1, self.n_class, 4) 248 | roi = roi.view(-1, 1, 4).expand_as(roi_cls_loc) 249 | cls_bbox = loc2bbox(at.tonumpy(roi).reshape((-1, 4)), 250 | at.tonumpy(roi_cls_loc).reshape((-1, 4))) 251 | cls_bbox = at.totensor(cls_bbox) 252 | cls_bbox = cls_bbox.view(-1, self.n_class * 4) 253 | # clip bounding box 254 | cls_bbox[:, 0::2] = (cls_bbox[:, 0::2]).clamp(min=0, max=size[0]) 255 | cls_bbox[:, 1::2] = (cls_bbox[:, 1::2]).clamp(min=0, max=size[1]) 256 | 257 | prob = (F.softmax(at.totensor(roi_score), dim=1)) 258 | 259 | bbox, label, score = self._suppress(cls_bbox, prob) 260 | bboxes.append(bbox) 261 | labels.append(label) 262 | scores.append(score) 263 | 264 | self.use_preset('evaluate') 265 | self.train() 266 | return bboxes, labels, scores 267 | 268 | def get_optimizer(self): 269 | """ 270 | return optimizer, It could be overwriten if you want to specify 271 | special optimizer 272 | """ 273 | lr = opt.lr 274 | params = [] 275 | for key, value in dict(self.named_parameters()).items(): 276 | if value.requires_grad: 277 | if 'bias' in key: 278 | params += [{'params': [value], 'lr': lr * 2, 'weight_decay': 0}] 279 | else: 280 | params += [{'params': [value], 'lr': lr, 'weight_decay': opt.weight_decay}] 281 | if opt.use_adam: 282 | self.optimizer = t.optim.Adam(params) 283 | else: 284 | self.optimizer = t.optim.SGD(params, momentum=0.9) 285 | return self.optimizer 286 | 287 | def scale_lr(self, decay=0.1): 288 | for param_group in self.optimizer.param_groups: 289 | param_group['lr'] *= decay 290 | return self.optimizer 291 | 292 | 293 | 294 | 295 | -------------------------------------------------------------------------------- /model/faster_rcnn_vgg16.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import torch as t 3 | from torch import nn 4 | from torchvision.models import vgg16 5 | from torchvision.ops import RoIPool 6 | 7 | from model.region_proposal_network import RegionProposalNetwork 8 | from model.faster_rcnn import FasterRCNN 9 | from utils import array_tool as at 10 | from utils.config import opt 11 | 12 | 13 | def decom_vgg16(): 14 | # the 30th layer of features is relu of conv5_3 15 | if opt.caffe_pretrain: 16 | model = vgg16(pretrained=False) 17 | if not opt.load_path: 18 | model.load_state_dict(t.load(opt.caffe_pretrain_path)) 19 | else: 20 | model = vgg16(not opt.load_path) 21 | 22 | features = list(model.features)[:30] 23 | classifier = model.classifier 24 | 25 | classifier = list(classifier) 26 | del classifier[6] 27 | if not opt.use_drop: 28 | del classifier[5] 29 | del classifier[2] 30 | classifier = nn.Sequential(*classifier) 31 | 32 | # freeze top4 conv 33 | for layer in features[:10]: 34 | for p in layer.parameters(): 35 | p.requires_grad = False 36 | 37 | return nn.Sequential(*features), classifier 38 | 39 | 40 | class FasterRCNNVGG16(FasterRCNN): 41 | """Faster R-CNN based on VGG-16. 42 | For descriptions on the interface of this model, please refer to 43 | :class:`model.faster_rcnn.FasterRCNN`. 44 | 45 | Args: 46 | n_fg_class (int): The number of classes excluding the background. 47 | ratios (list of floats): This is ratios of width to height of 48 | the anchors. 49 | anchor_scales (list of numbers): This is areas of anchors. 
50 | Those areas will be the product of the square of an element in 51 | :obj:`anchor_scales` and the original area of the reference 52 | window. 53 | 54 | """ 55 | 56 | feat_stride = 16 # downsample 16x for output of conv5 in vgg16 57 | 58 | def __init__(self, 59 | n_fg_class=20, 60 | ratios=[0.5, 1, 2], 61 | anchor_scales=[8, 16, 32] 62 | ): 63 | 64 | extractor, classifier = decom_vgg16() 65 | 66 | rpn = RegionProposalNetwork( 67 | 512, 512, 68 | ratios=ratios, 69 | anchor_scales=anchor_scales, 70 | feat_stride=self.feat_stride, 71 | ) 72 | 73 | head = VGG16RoIHead( 74 | n_class=n_fg_class + 1, 75 | roi_size=7, 76 | spatial_scale=(1. / self.feat_stride), 77 | classifier=classifier 78 | ) 79 | 80 | super(FasterRCNNVGG16, self).__init__( 81 | extractor, 82 | rpn, 83 | head, 84 | ) 85 | 86 | 87 | class VGG16RoIHead(nn.Module): 88 | """Faster R-CNN Head for VGG-16 based implementation. 89 | This class is used as a head for Faster R-CNN. 90 | This outputs class-wise localizations and classification based on feature 91 | maps in the given RoIs. 92 | 93 | Args: 94 | n_class (int): The number of classes possibly including the background. 95 | roi_size (int): Height and width of the feature maps after RoI-pooling. 96 | spatial_scale (float): Scale of the roi is resized. 97 | classifier (nn.Module): Two layer Linear ported from vgg16 98 | 99 | """ 100 | 101 | def __init__(self, n_class, roi_size, spatial_scale, 102 | classifier): 103 | # n_class includes the background 104 | super(VGG16RoIHead, self).__init__() 105 | 106 | self.classifier = classifier 107 | self.cls_loc = nn.Linear(4096, n_class * 4) 108 | self.score = nn.Linear(4096, n_class) 109 | 110 | normal_init(self.cls_loc, 0, 0.001) 111 | normal_init(self.score, 0, 0.01) 112 | 113 | self.n_class = n_class 114 | self.roi_size = roi_size 115 | self.spatial_scale = spatial_scale 116 | self.roi = RoIPool( (self.roi_size, self.roi_size),self.spatial_scale) 117 | 118 | def forward(self, x, rois, roi_indices): 119 | """Forward the chain. 120 | 121 | We assume that there are :math:`N` batches. 122 | 123 | Args: 124 | x (Variable): 4D image variable. 125 | rois (Tensor): A bounding box array containing coordinates of 126 | proposal boxes. This is a concatenation of bounding box 127 | arrays from multiple images in the batch. 128 | Its shape is :math:`(R', 4)`. Given :math:`R_i` proposed 129 | RoIs from the :math:`i` th image, 130 | :math:`R' = \\sum _{i=1} ^ N R_i`. 131 | roi_indices (Tensor): An array containing indices of images to 132 | which bounding boxes correspond to. Its shape is :math:`(R',)`. 133 | 134 | """ 135 | # in case roi_indices is ndarray 136 | roi_indices = at.totensor(roi_indices).float() 137 | rois = at.totensor(rois).float() 138 | indices_and_rois = t.cat([roi_indices[:, None], rois], dim=1) 139 | # NOTE: important: yx->xy 140 | xy_indices_and_rois = indices_and_rois[:, [0, 2, 1, 4, 3]] 141 | indices_and_rois = xy_indices_and_rois.contiguous() 142 | 143 | pool = self.roi(x, indices_and_rois) 144 | pool = pool.view(pool.size(0), -1) 145 | fc7 = self.classifier(pool) 146 | roi_cls_locs = self.cls_loc(fc7) 147 | roi_scores = self.score(fc7) 148 | return roi_cls_locs, roi_scores 149 | 150 | 151 | def normal_init(m, mean, stddev, truncated=False): 152 | """ 153 | weight initalizer: truncated normal and random normal. 
154 | """ 155 | # x is a parameter 156 | if truncated: 157 | m.weight.data.normal_().fmod_(2).mul_(stddev).add_(mean) # not a perfect approximation 158 | else: 159 | m.weight.data.normal_(mean, stddev) 160 | m.bias.data.zero_() 161 | -------------------------------------------------------------------------------- /model/region_proposal_network.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.nn import functional as F 3 | import torch as t 4 | from torch import nn 5 | 6 | from model.utils.bbox_tools import generate_anchor_base 7 | from model.utils.creator_tool import ProposalCreator 8 | 9 | 10 | class RegionProposalNetwork(nn.Module): 11 | """Region Proposal Network introduced in Faster R-CNN. 12 | 13 | This is Region Proposal Network introduced in Faster R-CNN [#]_. 14 | This takes features extracted from images and propose 15 | class agnostic bounding boxes around "objects". 16 | 17 | .. [#] Shaoqing Ren, Kaiming He, Ross Girshick, Jian Sun. \ 18 | Faster R-CNN: Towards Real-Time Object Detection with \ 19 | Region Proposal Networks. NIPS 2015. 20 | 21 | Args: 22 | in_channels (int): The channel size of input. 23 | mid_channels (int): The channel size of the intermediate tensor. 24 | ratios (list of floats): This is ratios of width to height of 25 | the anchors. 26 | anchor_scales (list of numbers): This is areas of anchors. 27 | Those areas will be the product of the square of an element in 28 | :obj:`anchor_scales` and the original area of the reference 29 | window. 30 | feat_stride (int): Stride size after extracting features from an 31 | image. 32 | initialW (callable): Initial weight value. If :obj:`None` then this 33 | function uses Gaussian distribution scaled by 0.1 to 34 | initialize weight. 35 | May also be a callable that takes an array and edits its values. 36 | proposal_creator_params (dict): Key valued paramters for 37 | :class:`model.utils.creator_tools.ProposalCreator`. 38 | 39 | .. seealso:: 40 | :class:`~model.utils.creator_tools.ProposalCreator` 41 | 42 | """ 43 | 44 | def __init__( 45 | self, in_channels=512, mid_channels=512, ratios=[0.5, 1, 2], 46 | anchor_scales=[8, 16, 32], feat_stride=16, 47 | proposal_creator_params=dict(), 48 | ): 49 | super(RegionProposalNetwork, self).__init__() 50 | self.anchor_base = generate_anchor_base( 51 | anchor_scales=anchor_scales, ratios=ratios) 52 | self.feat_stride = feat_stride 53 | self.proposal_layer = ProposalCreator(self, **proposal_creator_params) 54 | n_anchor = self.anchor_base.shape[0] 55 | self.conv1 = nn.Conv2d(in_channels, mid_channels, 3, 1, 1) 56 | self.score = nn.Conv2d(mid_channels, n_anchor * 2, 1, 1, 0) 57 | self.loc = nn.Conv2d(mid_channels, n_anchor * 4, 1, 1, 0) 58 | normal_init(self.conv1, 0, 0.01) 59 | normal_init(self.score, 0, 0.01) 60 | normal_init(self.loc, 0, 0.01) 61 | 62 | def forward(self, x, img_size, scale=1.): 63 | """Forward Region Proposal Network. 64 | 65 | Here are notations. 66 | 67 | * :math:`N` is batch size. 68 | * :math:`C` channel size of the input. 69 | * :math:`H` and :math:`W` are height and witdh of the input feature. 70 | * :math:`A` is number of anchors assigned to each pixel. 71 | 72 | Args: 73 | x (~torch.autograd.Variable): The Features extracted from images. 74 | Its shape is :math:`(N, C, H, W)`. 75 | img_size (tuple of ints): A tuple :obj:`height, width`, 76 | which contains image size after scaling. 77 | scale (float): The amount of scaling done to the input images after 78 | reading them from files. 
79 | 80 | Returns: 81 | (~torch.autograd.Variable, ~torch.autograd.Variable, array, array, array): 82 | 83 | This is a tuple of five following values. 84 | 85 | * **rpn_locs**: Predicted bounding box offsets and scales for \ 86 | anchors. Its shape is :math:`(N, H W A, 4)`. 87 | * **rpn_scores**: Predicted foreground scores for \ 88 | anchors. Its shape is :math:`(N, H W A, 2)`. 89 | * **rois**: A bounding box array containing coordinates of \ 90 | proposal boxes. This is a concatenation of bounding box \ 91 | arrays from multiple images in the batch. \ 92 | Its shape is :math:`(R', 4)`. Given :math:`R_i` predicted \ 93 | bounding boxes from the :math:`i` th image, \ 94 | :math:`R' = \\sum _{i=1} ^ N R_i`. 95 | * **roi_indices**: An array containing indices of images to \ 96 | which RoIs correspond to. Its shape is :math:`(R',)`. 97 | * **anchor**: Coordinates of enumerated shifted anchors. \ 98 | Its shape is :math:`(H W A, 4)`. 99 | 100 | """ 101 | n, _, hh, ww = x.shape 102 | anchor = _enumerate_shifted_anchor( 103 | np.array(self.anchor_base), 104 | self.feat_stride, hh, ww) 105 | 106 | n_anchor = anchor.shape[0] // (hh * ww) 107 | h = F.relu(self.conv1(x)) 108 | 109 | rpn_locs = self.loc(h) 110 | # UNNOTE: check whether need contiguous 111 | # A: Yes 112 | rpn_locs = rpn_locs.permute(0, 2, 3, 1).contiguous().view(n, -1, 4) 113 | rpn_scores = self.score(h) 114 | rpn_scores = rpn_scores.permute(0, 2, 3, 1).contiguous() 115 | rpn_softmax_scores = F.softmax(rpn_scores.view(n, hh, ww, n_anchor, 2), dim=4) 116 | rpn_fg_scores = rpn_softmax_scores[:, :, :, :, 1].contiguous() 117 | rpn_fg_scores = rpn_fg_scores.view(n, -1) 118 | rpn_scores = rpn_scores.view(n, -1, 2) 119 | 120 | rois = list() 121 | roi_indices = list() 122 | for i in range(n): 123 | roi = self.proposal_layer( 124 | rpn_locs[i].cpu().data.numpy(), 125 | rpn_fg_scores[i].cpu().data.numpy(), 126 | anchor, img_size, 127 | scale=scale) 128 | batch_index = i * np.ones((len(roi),), dtype=np.int32) 129 | rois.append(roi) 130 | roi_indices.append(batch_index) 131 | 132 | rois = np.concatenate(rois, axis=0) 133 | roi_indices = np.concatenate(roi_indices, axis=0) 134 | return rpn_locs, rpn_scores, rois, roi_indices, anchor 135 | 136 | 137 | def _enumerate_shifted_anchor(anchor_base, feat_stride, height, width): 138 | # Enumerate all shifted anchors: 139 | # 140 | # add A anchors (1, A, 4) to 141 | # cell K shifts (K, 1, 4) to get 142 | # shift anchors (K, A, 4) 143 | # reshape to (K*A, 4) shifted anchors 144 | # return (K*A, 4) 145 | 146 | # !TODO: add support for torch.CudaTensor 147 | # xp = cuda.get_array_module(anchor_base) 148 | # it seems that it can't be boosed using GPU 149 | import numpy as xp 150 | shift_y = xp.arange(0, height * feat_stride, feat_stride) 151 | shift_x = xp.arange(0, width * feat_stride, feat_stride) 152 | shift_x, shift_y = xp.meshgrid(shift_x, shift_y) 153 | shift = xp.stack((shift_y.ravel(), shift_x.ravel(), 154 | shift_y.ravel(), shift_x.ravel()), axis=1) 155 | 156 | A = anchor_base.shape[0] 157 | K = shift.shape[0] 158 | anchor = anchor_base.reshape((1, A, 4)) + \ 159 | shift.reshape((1, K, 4)).transpose((1, 0, 2)) 160 | anchor = anchor.reshape((K * A, 4)).astype(np.float32) 161 | return anchor 162 | 163 | 164 | def _enumerate_shifted_anchor_torch(anchor_base, feat_stride, height, width): 165 | # Enumerate all shifted anchors: 166 | # 167 | # add A anchors (1, A, 4) to 168 | # cell K shifts (K, 1, 4) to get 169 | # shift anchors (K, A, 4) 170 | # reshape to (K*A, 4) shifted anchors 171 | # return (K*A, 4) 
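    # (Scale check: for a roughly 600x800 input with feat_stride=16 the feature
    #  map is about 37x50, so K = 1850 shifts and, with the default A = 9 anchor
    #  bases, about 16650 anchors in total; numbers are for illustration only.)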
172 | 173 | # !TODO: add support for torch.CudaTensor 174 | # xp = cuda.get_array_module(anchor_base) 175 | import torch as t 176 | shift_y = t.arange(0, height * feat_stride, feat_stride) 177 | shift_x = t.arange(0, width * feat_stride, feat_stride) 178 | shift_x, shift_y = xp.meshgrid(shift_x, shift_y) 179 | shift = xp.stack((shift_y.ravel(), shift_x.ravel(), 180 | shift_y.ravel(), shift_x.ravel()), axis=1) 181 | 182 | A = anchor_base.shape[0] 183 | K = shift.shape[0] 184 | anchor = anchor_base.reshape((1, A, 4)) + \ 185 | shift.reshape((1, K, 4)).transpose((1, 0, 2)) 186 | anchor = anchor.reshape((K * A, 4)).astype(np.float32) 187 | return anchor 188 | 189 | 190 | def normal_init(m, mean, stddev, truncated=False): 191 | """ 192 | weight initalizer: truncated normal and random normal. 193 | """ 194 | # x is a parameter 195 | if truncated: 196 | m.weight.data.normal_().fmod_(2).mul_(stddev).add_(mean) # not a perfect approximation 197 | else: 198 | m.weight.data.normal_(mean, stddev) 199 | m.bias.data.zero_() 200 | -------------------------------------------------------------------------------- /model/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenyuntc/simple-faster-rcnn-pytorch/367db367834efd8a2bc58ee0023b2b628a0e474d/model/utils/__init__.py -------------------------------------------------------------------------------- /model/utils/bbox_tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy as xp 3 | 4 | import six 5 | from six import __init__ 6 | 7 | 8 | def loc2bbox(src_bbox, loc): 9 | """Decode bounding boxes from bounding box offsets and scales. 10 | 11 | Given bounding box offsets and scales computed by 12 | :meth:`bbox2loc`, this function decodes the representation to 13 | coordinates in 2D image coordinates. 14 | 15 | Given scales and offsets :math:`t_y, t_x, t_h, t_w` and a bounding 16 | box whose center is :math:`(y, x) = p_y, p_x` and size :math:`p_h, p_w`, 17 | the decoded bounding box's center :math:`\\hat{g}_y`, :math:`\\hat{g}_x` 18 | and size :math:`\\hat{g}_h`, :math:`\\hat{g}_w` are calculated 19 | by the following formulas. 20 | 21 | * :math:`\\hat{g}_y = p_h t_y + p_y` 22 | * :math:`\\hat{g}_x = p_w t_x + p_x` 23 | * :math:`\\hat{g}_h = p_h \\exp(t_h)` 24 | * :math:`\\hat{g}_w = p_w \\exp(t_w)` 25 | 26 | The decoding formulas are used in works such as R-CNN [#]_. 27 | 28 | The output is same type as the type of the inputs. 29 | 30 | .. [#] Ross Girshick, Jeff Donahue, Trevor Darrell, Jitendra Malik. \ 31 | Rich feature hierarchies for accurate object detection and semantic \ 32 | segmentation. CVPR 2014. 33 | 34 | Args: 35 | src_bbox (array): A coordinates of bounding boxes. 36 | Its shape is :math:`(R, 4)`. These coordinates are 37 | :math:`p_{ymin}, p_{xmin}, p_{ymax}, p_{xmax}`. 38 | loc (array): An array with offsets and scales. 39 | The shapes of :obj:`src_bbox` and :obj:`loc` should be same. 40 | This contains values :math:`t_y, t_x, t_h, t_w`. 41 | 42 | Returns: 43 | array: 44 | Decoded bounding box coordinates. Its shape is :math:`(R, 4)`. \ 45 | The second axis contains four values \ 46 | :math:`\\hat{g}_{ymin}, \\hat{g}_{xmin}, 47 | \\hat{g}_{ymax}, \\hat{g}_{xmax}`. 
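    A small worked example (values chosen for illustration):

    .. code-block:: python

        src = np.array([[0., 0., 10., 10.]], dtype=np.float32)  # 10x10 box centred at (5, 5)
        loc = np.array([[0.1, 0.2, 0., 0.]], dtype=np.float32)  # move the centre by (+1, +2), keep the size
        loc2bbox(src, loc)                                      # -> [[1., 2., 11., 12.]]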
48 | 49 | """ 50 | 51 | if src_bbox.shape[0] == 0: 52 | return xp.zeros((0, 4), dtype=loc.dtype) 53 | 54 | src_bbox = src_bbox.astype(src_bbox.dtype, copy=False) 55 | 56 | src_height = src_bbox[:, 2] - src_bbox[:, 0] 57 | src_width = src_bbox[:, 3] - src_bbox[:, 1] 58 | src_ctr_y = src_bbox[:, 0] + 0.5 * src_height 59 | src_ctr_x = src_bbox[:, 1] + 0.5 * src_width 60 | 61 | dy = loc[:, 0::4] 62 | dx = loc[:, 1::4] 63 | dh = loc[:, 2::4] 64 | dw = loc[:, 3::4] 65 | 66 | ctr_y = dy * src_height[:, xp.newaxis] + src_ctr_y[:, xp.newaxis] 67 | ctr_x = dx * src_width[:, xp.newaxis] + src_ctr_x[:, xp.newaxis] 68 | h = xp.exp(dh) * src_height[:, xp.newaxis] 69 | w = xp.exp(dw) * src_width[:, xp.newaxis] 70 | 71 | dst_bbox = xp.zeros(loc.shape, dtype=loc.dtype) 72 | dst_bbox[:, 0::4] = ctr_y - 0.5 * h 73 | dst_bbox[:, 1::4] = ctr_x - 0.5 * w 74 | dst_bbox[:, 2::4] = ctr_y + 0.5 * h 75 | dst_bbox[:, 3::4] = ctr_x + 0.5 * w 76 | 77 | return dst_bbox 78 | 79 | 80 | def bbox2loc(src_bbox, dst_bbox): 81 | """Encodes the source and the destination bounding boxes to "loc". 82 | 83 | Given bounding boxes, this function computes offsets and scales 84 | to match the source bounding boxes to the target bounding boxes. 85 | Mathematcially, given a bounding box whose center is 86 | :math:`(y, x) = p_y, p_x` and 87 | size :math:`p_h, p_w` and the target bounding box whose center is 88 | :math:`g_y, g_x` and size :math:`g_h, g_w`, the offsets and scales 89 | :math:`t_y, t_x, t_h, t_w` can be computed by the following formulas. 90 | 91 | * :math:`t_y = \\frac{(g_y - p_y)} {p_h}` 92 | * :math:`t_x = \\frac{(g_x - p_x)} {p_w}` 93 | * :math:`t_h = \\log(\\frac{g_h} {p_h})` 94 | * :math:`t_w = \\log(\\frac{g_w} {p_w})` 95 | 96 | The output is same type as the type of the inputs. 97 | The encoding formulas are used in works such as R-CNN [#]_. 98 | 99 | .. [#] Ross Girshick, Jeff Donahue, Trevor Darrell, Jitendra Malik. \ 100 | Rich feature hierarchies for accurate object detection and semantic \ 101 | segmentation. CVPR 2014. 102 | 103 | Args: 104 | src_bbox (array): An image coordinate array whose shape is 105 | :math:`(R, 4)`. :math:`R` is the number of bounding boxes. 106 | These coordinates are 107 | :math:`p_{ymin}, p_{xmin}, p_{ymax}, p_{xmax}`. 108 | dst_bbox (array): An image coordinate array whose shape is 109 | :math:`(R, 4)`. 110 | These coordinates are 111 | :math:`g_{ymin}, g_{xmin}, g_{ymax}, g_{xmax}`. 112 | 113 | Returns: 114 | array: 115 | Bounding box offsets and scales from :obj:`src_bbox` \ 116 | to :obj:`dst_bbox`. \ 117 | This has shape :math:`(R, 4)`. 118 | The second axis contains four values :math:`t_y, t_x, t_h, t_w`. 
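    A small worked example, the inverse of the one given for :func:`loc2bbox`:

    .. code-block:: python

        src = np.array([[0., 0., 10., 10.]], dtype=np.float32)
        dst = np.array([[1., 2., 11., 12.]], dtype=np.float32)
        bbox2loc(src, dst)   # -> [[0.1, 0.2, 0., 0.]]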
119 | 120 | """ 121 | 122 | height = src_bbox[:, 2] - src_bbox[:, 0] 123 | width = src_bbox[:, 3] - src_bbox[:, 1] 124 | ctr_y = src_bbox[:, 0] + 0.5 * height 125 | ctr_x = src_bbox[:, 1] + 0.5 * width 126 | 127 | base_height = dst_bbox[:, 2] - dst_bbox[:, 0] 128 | base_width = dst_bbox[:, 3] - dst_bbox[:, 1] 129 | base_ctr_y = dst_bbox[:, 0] + 0.5 * base_height 130 | base_ctr_x = dst_bbox[:, 1] + 0.5 * base_width 131 | 132 | eps = xp.finfo(height.dtype).eps 133 | height = xp.maximum(height, eps) 134 | width = xp.maximum(width, eps) 135 | 136 | dy = (base_ctr_y - ctr_y) / height 137 | dx = (base_ctr_x - ctr_x) / width 138 | dh = xp.log(base_height / height) 139 | dw = xp.log(base_width / width) 140 | 141 | loc = xp.vstack((dy, dx, dh, dw)).transpose() 142 | return loc 143 | 144 | 145 | def bbox_iou(bbox_a, bbox_b): 146 | """Calculate the Intersection of Unions (IoUs) between bounding boxes. 147 | 148 | IoU is calculated as a ratio of area of the intersection 149 | and area of the union. 150 | 151 | This function accepts both :obj:`numpy.ndarray` and :obj:`cupy.ndarray` as 152 | inputs. Please note that both :obj:`bbox_a` and :obj:`bbox_b` need to be 153 | same type. 154 | The output is same type as the type of the inputs. 155 | 156 | Args: 157 | bbox_a (array): An array whose shape is :math:`(N, 4)`. 158 | :math:`N` is the number of bounding boxes. 159 | The dtype should be :obj:`numpy.float32`. 160 | bbox_b (array): An array similar to :obj:`bbox_a`, 161 | whose shape is :math:`(K, 4)`. 162 | The dtype should be :obj:`numpy.float32`. 163 | 164 | Returns: 165 | array: 166 | An array whose shape is :math:`(N, K)`. \ 167 | An element at index :math:`(n, k)` contains IoUs between \ 168 | :math:`n` th bounding box in :obj:`bbox_a` and :math:`k` th bounding \ 169 | box in :obj:`bbox_b`. 170 | 171 | """ 172 | if bbox_a.shape[1] != 4 or bbox_b.shape[1] != 4: 173 | raise IndexError 174 | 175 | # top left 176 | tl = xp.maximum(bbox_a[:, None, :2], bbox_b[:, :2]) 177 | # bottom right 178 | br = xp.minimum(bbox_a[:, None, 2:], bbox_b[:, 2:]) 179 | 180 | area_i = xp.prod(br - tl, axis=2) * (tl < br).all(axis=2) 181 | area_a = xp.prod(bbox_a[:, 2:] - bbox_a[:, :2], axis=1) 182 | area_b = xp.prod(bbox_b[:, 2:] - bbox_b[:, :2], axis=1) 183 | return area_i / (area_a[:, None] + area_b - area_i) 184 | 185 | 186 | def __test(): 187 | pass 188 | 189 | 190 | if __name__ == '__main__': 191 | __test() 192 | 193 | 194 | def generate_anchor_base(base_size=16, ratios=[0.5, 1, 2], 195 | anchor_scales=[8, 16, 32]): 196 | """Generate anchor base windows by enumerating aspect ratio and scales. 197 | 198 | Generate anchors that are scaled and modified to the given aspect ratios. 199 | Area of a scaled anchor is preserved when modifying to the given aspect 200 | ratio. 201 | 202 | :obj:`R = len(ratios) * len(anchor_scales)` anchors are generated by this 203 | function. 204 | The :obj:`i * len(anchor_scales) + j` th anchor corresponds to an anchor 205 | generated by :obj:`ratios[i]` and :obj:`anchor_scales[j]`. 206 | 207 | For example, if the scale is :math:`8` and the ratio is :math:`0.25`, 208 | the width and the height of the base window will be stretched by :math:`8`. 209 | For modifying the anchor to the given aspect ratio, 210 | the height is halved and the width is doubled. 211 | 212 | Args: 213 | base_size (number): The width and the height of the reference window. 214 | ratios (list of floats): This is ratios of width to height of 215 | the anchors. 216 | anchor_scales (list of numbers): This is areas of anchors. 
217 | Those areas will be the product of the square of an element in 218 | :obj:`anchor_scales` and the original area of the reference 219 | window. 220 | 221 | Returns: 222 | ~numpy.ndarray: 223 | An array of shape :math:`(R, 4)`. 224 | Each element is a set of coordinates of a bounding box. 225 | The second axis corresponds to 226 | :math:`(y_{min}, x_{min}, y_{max}, x_{max})` of a bounding box. 227 | 228 | """ 229 | py = base_size / 2. 230 | px = base_size / 2. 231 | 232 | anchor_base = np.zeros((len(ratios) * len(anchor_scales), 4), 233 | dtype=np.float32) 234 | for i in six.moves.range(len(ratios)): 235 | for j in six.moves.range(len(anchor_scales)): 236 | h = base_size * anchor_scales[j] * np.sqrt(ratios[i]) 237 | w = base_size * anchor_scales[j] * np.sqrt(1. / ratios[i]) 238 | 239 | index = i * len(anchor_scales) + j 240 | anchor_base[index, 0] = py - h / 2. 241 | anchor_base[index, 1] = px - w / 2. 242 | anchor_base[index, 2] = py + h / 2. 243 | anchor_base[index, 3] = px + w / 2. 244 | return anchor_base 245 | -------------------------------------------------------------------------------- /model/utils/creator_tool.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torchvision.ops import nms 4 | from model.utils.bbox_tools import bbox2loc, bbox_iou, loc2bbox 5 | 6 | 7 | class ProposalTargetCreator(object): 8 | """Assign ground truth bounding boxes to given RoIs. 9 | 10 | The :meth:`__call__` of this class generates training targets 11 | for each object proposal. 12 | This is used to train Faster RCNN [#]_. 13 | 14 | .. [#] Shaoqing Ren, Kaiming He, Ross Girshick, Jian Sun. \ 15 | Faster R-CNN: Towards Real-Time Object Detection with \ 16 | Region Proposal Networks. NIPS 2015. 17 | 18 | Args: 19 | n_sample (int): The number of sampled regions. 20 | pos_ratio (float): Fraction of regions that is labeled as a 21 | foreground. 22 | pos_iou_thresh (float): IoU threshold for a RoI to be considered as a 23 | foreground. 24 | neg_iou_thresh_hi (float): RoI is considered to be the background 25 | if IoU is in 26 | [:obj:`neg_iou_thresh_hi`, :obj:`neg_iou_thresh_hi`). 27 | neg_iou_thresh_lo (float): See above. 28 | 29 | """ 30 | 31 | def __init__(self, 32 | n_sample=128, 33 | pos_ratio=0.25, pos_iou_thresh=0.5, 34 | neg_iou_thresh_hi=0.5, neg_iou_thresh_lo=0.0 35 | ): 36 | self.n_sample = n_sample 37 | self.pos_ratio = pos_ratio 38 | self.pos_iou_thresh = pos_iou_thresh 39 | self.neg_iou_thresh_hi = neg_iou_thresh_hi 40 | self.neg_iou_thresh_lo = neg_iou_thresh_lo # NOTE:default 0.1 in py-faster-rcnn 41 | 42 | def __call__(self, roi, bbox, label, 43 | loc_normalize_mean=(0., 0., 0., 0.), 44 | loc_normalize_std=(0.1, 0.1, 0.2, 0.2)): 45 | """Assigns ground truth to sampled proposals. 46 | 47 | This function samples total of :obj:`self.n_sample` RoIs 48 | from the combination of :obj:`roi` and :obj:`bbox`. 49 | The RoIs are assigned with the ground truth class labels as well as 50 | bounding box offsets and scales to match the ground truth bounding 51 | boxes. As many as :obj:`pos_ratio * self.n_sample` RoIs are 52 | sampled as foregrounds. 53 | 54 | Offsets and scales of bounding boxes are calculated using 55 | :func:`model.utils.bbox_tools.bbox2loc`. 56 | Also, types of input arrays and output arrays are same. 57 | 58 | Here are notations. 59 | 60 | * :math:`S` is the total number of sampled RoIs, which equals \ 61 | :obj:`self.n_sample`. 
62 | * :math:`L` is number of object classes possibly including the \ 63 | background. 64 | 65 | Args: 66 | roi (array): Region of Interests (RoIs) from which we sample. 67 | Its shape is :math:`(R, 4)` 68 | bbox (array): The coordinates of ground truth bounding boxes. 69 | Its shape is :math:`(R', 4)`. 70 | label (array): Ground truth bounding box labels. Its shape 71 | is :math:`(R',)`. Its range is :math:`[0, L - 1]`, where 72 | :math:`L` is the number of foreground classes. 73 | loc_normalize_mean (tuple of four floats): Mean values to normalize 74 | coordinates of bouding boxes. 75 | loc_normalize_std (tupler of four floats): Standard deviation of 76 | the coordinates of bounding boxes. 77 | 78 | Returns: 79 | (array, array, array): 80 | 81 | * **sample_roi**: Regions of interests that are sampled. \ 82 | Its shape is :math:`(S, 4)`. 83 | * **gt_roi_loc**: Offsets and scales to match \ 84 | the sampled RoIs to the ground truth bounding boxes. \ 85 | Its shape is :math:`(S, 4)`. 86 | * **gt_roi_label**: Labels assigned to sampled RoIs. Its shape is \ 87 | :math:`(S,)`. Its range is :math:`[0, L]`. The label with \ 88 | value 0 is the background. 89 | 90 | """ 91 | n_bbox, _ = bbox.shape 92 | 93 | roi = np.concatenate((roi, bbox), axis=0) 94 | 95 | pos_roi_per_image = np.round(self.n_sample * self.pos_ratio) 96 | iou = bbox_iou(roi, bbox) 97 | gt_assignment = iou.argmax(axis=1) 98 | max_iou = iou.max(axis=1) 99 | # Offset range of classes from [0, n_fg_class - 1] to [1, n_fg_class]. 100 | # The label with value 0 is the background. 101 | gt_roi_label = label[gt_assignment] + 1 102 | 103 | # Select foreground RoIs as those with >= pos_iou_thresh IoU. 104 | pos_index = np.where(max_iou >= self.pos_iou_thresh)[0] 105 | pos_roi_per_this_image = int(min(pos_roi_per_image, pos_index.size)) 106 | if pos_index.size > 0: 107 | pos_index = np.random.choice( 108 | pos_index, size=pos_roi_per_this_image, replace=False) 109 | 110 | # Select background RoIs as those within 111 | # [neg_iou_thresh_lo, neg_iou_thresh_hi). 112 | neg_index = np.where((max_iou < self.neg_iou_thresh_hi) & 113 | (max_iou >= self.neg_iou_thresh_lo))[0] 114 | neg_roi_per_this_image = self.n_sample - pos_roi_per_this_image 115 | neg_roi_per_this_image = int(min(neg_roi_per_this_image, 116 | neg_index.size)) 117 | if neg_index.size > 0: 118 | neg_index = np.random.choice( 119 | neg_index, size=neg_roi_per_this_image, replace=False) 120 | 121 | # The indices that we're selecting (both positive and negative). 122 | keep_index = np.append(pos_index, neg_index) 123 | gt_roi_label = gt_roi_label[keep_index] 124 | gt_roi_label[pos_roi_per_this_image:] = 0 # negative labels --> 0 125 | sample_roi = roi[keep_index] 126 | 127 | # Compute offsets and scales to match sampled RoIs to the GTs. 128 | gt_roi_loc = bbox2loc(sample_roi, bbox[gt_assignment[keep_index]]) 129 | gt_roi_loc = ((gt_roi_loc - np.array(loc_normalize_mean, np.float32) 130 | ) / np.array(loc_normalize_std, np.float32)) 131 | 132 | return sample_roi, gt_roi_loc, gt_roi_label 133 | 134 | 135 | class AnchorTargetCreator(object): 136 | """Assign the ground truth bounding boxes to anchors. 137 | 138 | Assigns the ground truth bounding boxes to anchors for training Region 139 | Proposal Networks introduced in Faster R-CNN [#]_. 140 | 141 | Offsets and scales to match anchors to the ground truth are 142 | calculated using the encoding scheme of 143 | :func:`model.utils.bbox_tools.bbox2loc`. 144 | 145 | .. [#] Shaoqing Ren, Kaiming He, Ross Girshick, Jian Sun. 
\ 146 | Faster R-CNN: Towards Real-Time Object Detection with \ 147 | Region Proposal Networks. NIPS 2015. 148 | 149 | Args: 150 | n_sample (int): The number of regions to produce. 151 | pos_iou_thresh (float): Anchors with IoU above this 152 | threshold will be assigned as positive. 153 | neg_iou_thresh (float): Anchors with IoU below this 154 | threshold will be assigned as negative. 155 | pos_ratio (float): Ratio of positive regions in the 156 | sampled regions. 157 | 158 | """ 159 | 160 | def __init__(self, 161 | n_sample=256, 162 | pos_iou_thresh=0.7, neg_iou_thresh=0.3, 163 | pos_ratio=0.5): 164 | self.n_sample = n_sample 165 | self.pos_iou_thresh = pos_iou_thresh 166 | self.neg_iou_thresh = neg_iou_thresh 167 | self.pos_ratio = pos_ratio 168 | 169 | def __call__(self, bbox, anchor, img_size): 170 | """Assign ground truth supervision to sampled subset of anchors. 171 | 172 | Types of input arrays and output arrays are same. 173 | 174 | Here are notations. 175 | 176 | * :math:`S` is the number of anchors. 177 | * :math:`R` is the number of bounding boxes. 178 | 179 | Args: 180 | bbox (array): Coordinates of bounding boxes. Its shape is 181 | :math:`(R, 4)`. 182 | anchor (array): Coordinates of anchors. Its shape is 183 | :math:`(S, 4)`. 184 | img_size (tuple of ints): A tuple :obj:`H, W`, which 185 | is a tuple of height and width of an image. 186 | 187 | Returns: 188 | (array, array): 189 | 190 | #NOTE: it's scale not only offset 191 | * **loc**: Offsets and scales to match the anchors to \ 192 | the ground truth bounding boxes. Its shape is :math:`(S, 4)`. 193 | * **label**: Labels of anchors with values \ 194 | :obj:`(1=positive, 0=negative, -1=ignore)`. Its shape \ 195 | is :math:`(S,)`. 196 | 197 | """ 198 | 199 | img_H, img_W = img_size 200 | 201 | n_anchor = len(anchor) 202 | inside_index = _get_inside_index(anchor, img_H, img_W) 203 | anchor = anchor[inside_index] 204 | argmax_ious, label = self._create_label( 205 | inside_index, anchor, bbox) 206 | 207 | # compute bounding box regression targets 208 | loc = bbox2loc(anchor, bbox[argmax_ious]) 209 | 210 | # map up to original set of anchors 211 | label = _unmap(label, n_anchor, inside_index, fill=-1) 212 | loc = _unmap(loc, n_anchor, inside_index, fill=0) 213 | 214 | return loc, label 215 | 216 | def _create_label(self, inside_index, anchor, bbox): 217 | # label: 1 is positive, 0 is negative, -1 is dont care 218 | label = np.empty((len(inside_index),), dtype=np.int32) 219 | label.fill(-1) 220 | 221 | argmax_ious, max_ious, gt_argmax_ious = \ 222 | self._calc_ious(anchor, bbox, inside_index) 223 | 224 | # assign negative labels first so that positive labels can clobber them 225 | label[max_ious < self.neg_iou_thresh] = 0 226 | 227 | # positive label: for each gt, anchor with highest iou 228 | label[gt_argmax_ious] = 1 229 | 230 | # positive label: above threshold IOU 231 | label[max_ious >= self.pos_iou_thresh] = 1 232 | 233 | # subsample positive labels if we have too many 234 | n_pos = int(self.pos_ratio * self.n_sample) 235 | pos_index = np.where(label == 1)[0] 236 | if len(pos_index) > n_pos: 237 | disable_index = np.random.choice( 238 | pos_index, size=(len(pos_index) - n_pos), replace=False) 239 | label[disable_index] = -1 240 | 241 | # subsample negative labels if we have too many 242 | n_neg = self.n_sample - np.sum(label == 1) 243 | neg_index = np.where(label == 0)[0] 244 | if len(neg_index) > n_neg: 245 | disable_index = np.random.choice( 246 | neg_index, size=(len(neg_index) - n_neg), replace=False) 247 | 
label[disable_index] = -1 248 | 249 | return argmax_ious, label 250 | 251 | def _calc_ious(self, anchor, bbox, inside_index): 252 | # ious between the anchors and the gt boxes 253 | ious = bbox_iou(anchor, bbox) 254 | argmax_ious = ious.argmax(axis=1) 255 | max_ious = ious[np.arange(len(inside_index)), argmax_ious] 256 | gt_argmax_ious = ious.argmax(axis=0) 257 | gt_max_ious = ious[gt_argmax_ious, np.arange(ious.shape[1])] 258 | gt_argmax_ious = np.where(ious == gt_max_ious)[0] 259 | 260 | return argmax_ious, max_ious, gt_argmax_ious 261 | 262 | 263 | def _unmap(data, count, index, fill=0): 264 | # Unmap a subset of item (data) back to the original set of items (of 265 | # size count) 266 | 267 | if len(data.shape) == 1: 268 | ret = np.empty((count,), dtype=data.dtype) 269 | ret.fill(fill) 270 | ret[index] = data 271 | else: 272 | ret = np.empty((count,) + data.shape[1:], dtype=data.dtype) 273 | ret.fill(fill) 274 | ret[index, :] = data 275 | return ret 276 | 277 | 278 | def _get_inside_index(anchor, H, W): 279 | # Calc indicies of anchors which are located completely inside of the image 280 | # whose size is speficied. 281 | index_inside = np.where( 282 | (anchor[:, 0] >= 0) & 283 | (anchor[:, 1] >= 0) & 284 | (anchor[:, 2] <= H) & 285 | (anchor[:, 3] <= W) 286 | )[0] 287 | return index_inside 288 | 289 | 290 | class ProposalCreator: 291 | # unNOTE: I'll make it undifferential 292 | # unTODO: make sure it's ok 293 | # It's ok 294 | """Proposal regions are generated by calling this object. 295 | 296 | The :meth:`__call__` of this object outputs object detection proposals by 297 | applying estimated bounding box offsets 298 | to a set of anchors. 299 | 300 | This class takes parameters to control number of bounding boxes to 301 | pass to NMS and keep after NMS. 302 | If the paramters are negative, it uses all the bounding boxes supplied 303 | or keep all the bounding boxes returned by NMS. 304 | 305 | This class is used for Region Proposal Networks introduced in 306 | Faster R-CNN [#]_. 307 | 308 | .. [#] Shaoqing Ren, Kaiming He, Ross Girshick, Jian Sun. \ 309 | Faster R-CNN: Towards Real-Time Object Detection with \ 310 | Region Proposal Networks. NIPS 2015. 311 | 312 | Args: 313 | nms_thresh (float): Threshold value used when calling NMS. 314 | n_train_pre_nms (int): Number of top scored bounding boxes 315 | to keep before passing to NMS in train mode. 316 | n_train_post_nms (int): Number of top scored bounding boxes 317 | to keep after passing to NMS in train mode. 318 | n_test_pre_nms (int): Number of top scored bounding boxes 319 | to keep before passing to NMS in test mode. 320 | n_test_post_nms (int): Number of top scored bounding boxes 321 | to keep after passing to NMS in test mode. 322 | force_cpu_nms (bool): If this is :obj:`True`, 323 | always use NMS in CPU mode. If :obj:`False`, 324 | the NMS mode is selected based on the type of inputs. 325 | min_size (int): A paramter to determine the threshold on 326 | discarding bounding boxes based on their sizes. 
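    With the defaults used below, training keeps the 12000 highest scoring boxes
    before NMS and at most 2000 after it, while test time keeps 6000 and 300
    respectively. A minimal construction sketch (a GPU is assumed because NMS
    runs on CUDA tensors; :obj:`rpn`, :obj:`loc`, :obj:`score` and :obj:`anchor`
    stand in for an RPN module and its per-image numpy outputs):

    .. code-block:: python

        creator = ProposalCreator(parent_model=rpn)  # rpn: module whose .training flag is consulted
        rois = creator(loc, score, anchor, img_size=(600, 800), scale=1.)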
327 | 328 | """ 329 | 330 | def __init__(self, 331 | parent_model, 332 | nms_thresh=0.7, 333 | n_train_pre_nms=12000, 334 | n_train_post_nms=2000, 335 | n_test_pre_nms=6000, 336 | n_test_post_nms=300, 337 | min_size=16 338 | ): 339 | self.parent_model = parent_model 340 | self.nms_thresh = nms_thresh 341 | self.n_train_pre_nms = n_train_pre_nms 342 | self.n_train_post_nms = n_train_post_nms 343 | self.n_test_pre_nms = n_test_pre_nms 344 | self.n_test_post_nms = n_test_post_nms 345 | self.min_size = min_size 346 | 347 | def __call__(self, loc, score, 348 | anchor, img_size, scale=1.): 349 | """input should be ndarray 350 | Propose RoIs. 351 | 352 | Inputs :obj:`loc, score, anchor` refer to the same anchor when indexed 353 | by the same index. 354 | 355 | On notations, :math:`R` is the total number of anchors. This is equal 356 | to product of the height and the width of an image and the number of 357 | anchor bases per pixel. 358 | 359 | Type of the output is same as the inputs. 360 | 361 | Args: 362 | loc (array): Predicted offsets and scaling to anchors. 363 | Its shape is :math:`(R, 4)`. 364 | score (array): Predicted foreground probability for anchors. 365 | Its shape is :math:`(R,)`. 366 | anchor (array): Coordinates of anchors. Its shape is 367 | :math:`(R, 4)`. 368 | img_size (tuple of ints): A tuple :obj:`height, width`, 369 | which contains image size after scaling. 370 | scale (float): The scaling factor used to scale an image after 371 | reading it from a file. 372 | 373 | Returns: 374 | array: 375 | An array of coordinates of proposal boxes. 376 | Its shape is :math:`(S, 4)`. :math:`S` is less than 377 | :obj:`self.n_test_post_nms` in test time and less than 378 | :obj:`self.n_train_post_nms` in train time. :math:`S` depends on 379 | the size of the predicted bounding boxes and the number of 380 | bounding boxes discarded by NMS. 381 | 382 | """ 383 | # NOTE: when test, remember 384 | # faster_rcnn.eval() 385 | # to set self.traing = False 386 | if self.parent_model.training: 387 | n_pre_nms = self.n_train_pre_nms 388 | n_post_nms = self.n_train_post_nms 389 | else: 390 | n_pre_nms = self.n_test_pre_nms 391 | n_post_nms = self.n_test_post_nms 392 | 393 | # Convert anchors into proposal via bbox transformations. 394 | # roi = loc2bbox(anchor, loc) 395 | roi = loc2bbox(anchor, loc) 396 | 397 | # Clip predicted boxes to image. 398 | roi[:, slice(0, 4, 2)] = np.clip( 399 | roi[:, slice(0, 4, 2)], 0, img_size[0]) 400 | roi[:, slice(1, 4, 2)] = np.clip( 401 | roi[:, slice(1, 4, 2)], 0, img_size[1]) 402 | 403 | # Remove predicted boxes with either height or width < threshold. 404 | min_size = self.min_size * scale 405 | hs = roi[:, 2] - roi[:, 0] 406 | ws = roi[:, 3] - roi[:, 1] 407 | keep = np.where((hs >= min_size) & (ws >= min_size))[0] 408 | roi = roi[keep, :] 409 | score = score[keep] 410 | 411 | # Sort all (proposal, score) pairs by score from highest to lowest. 412 | # Take top pre_nms_topN (e.g. 6000). 413 | order = score.ravel().argsort()[::-1] 414 | if n_pre_nms > 0: 415 | order = order[:n_pre_nms] 416 | roi = roi[order, :] 417 | score = score[order] 418 | 419 | # Apply nms (e.g. threshold = 0.7). 420 | # Take after_nms_topN (e.g. 300). 421 | 422 | # unNOTE: somthing is wrong here! 
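        # NOTE: torchvision's nms expects boxes as (x1, y1, x2, y2), while `roi`
        # here is ordered (ymin, xmin, ymax, xmax). The swap is harmless: IoU,
        # and therefore the set of kept indices, is unchanged when the same
        # axis permutation is applied to every box.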
423 | # TODO: remove cuda.to_gpu 424 | keep = nms( 425 | torch.from_numpy(roi).cuda(), 426 | torch.from_numpy(score).cuda(), 427 | self.nms_thresh) 428 | if n_post_nms > 0: 429 | keep = keep[:n_post_nms] 430 | roi = roi[keep.cpu().numpy()] 431 | return roi 432 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | visdom 2 | torchvision 3 | scikit-image 4 | tqdm 5 | fire 6 | pprint 7 | matplotlib 8 | ipdb 9 | cython 10 | git+https://github.com/pytorch/tnt.git@master -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | 4 | import ipdb 5 | import matplotlib 6 | from tqdm import tqdm 7 | 8 | from utils.config import opt 9 | from data.dataset import Dataset, TestDataset, inverse_normalize 10 | from model import FasterRCNNVGG16 11 | from torch.utils import data as data_ 12 | from trainer import FasterRCNNTrainer 13 | from utils import array_tool as at 14 | from utils.vis_tool import visdom_bbox 15 | from utils.eval_tool import eval_detection_voc 16 | 17 | # fix for ulimit 18 | # https://github.com/pytorch/pytorch/issues/973#issuecomment-346405667 19 | import resource 20 | 21 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 22 | resource.setrlimit(resource.RLIMIT_NOFILE, (20480, rlimit[1])) 23 | 24 | matplotlib.use('agg') 25 | 26 | 27 | def eval(dataloader, faster_rcnn, test_num=10000): 28 | pred_bboxes, pred_labels, pred_scores = list(), list(), list() 29 | gt_bboxes, gt_labels, gt_difficults = list(), list(), list() 30 | for ii, (imgs, sizes, gt_bboxes_, gt_labels_, gt_difficults_) in tqdm(enumerate(dataloader)): 31 | sizes = [sizes[0][0].item(), sizes[1][0].item()] 32 | pred_bboxes_, pred_labels_, pred_scores_ = faster_rcnn.predict(imgs, [sizes]) 33 | gt_bboxes += list(gt_bboxes_.numpy()) 34 | gt_labels += list(gt_labels_.numpy()) 35 | gt_difficults += list(gt_difficults_.numpy()) 36 | pred_bboxes += pred_bboxes_ 37 | pred_labels += pred_labels_ 38 | pred_scores += pred_scores_ 39 | if ii == test_num: break 40 | 41 | result = eval_detection_voc( 42 | pred_bboxes, pred_labels, pred_scores, 43 | gt_bboxes, gt_labels, gt_difficults, 44 | use_07_metric=True) 45 | return result 46 | 47 | 48 | def train(**kwargs): 49 | opt._parse(kwargs) 50 | 51 | dataset = Dataset(opt) 52 | print('load data') 53 | dataloader = data_.DataLoader(dataset, \ 54 | batch_size=1, \ 55 | shuffle=True, \ 56 | # pin_memory=True, 57 | num_workers=opt.num_workers) 58 | testset = TestDataset(opt) 59 | test_dataloader = data_.DataLoader(testset, 60 | batch_size=1, 61 | num_workers=opt.test_num_workers, 62 | shuffle=False, \ 63 | pin_memory=True 64 | ) 65 | faster_rcnn = FasterRCNNVGG16() 66 | print('model construct completed') 67 | trainer = FasterRCNNTrainer(faster_rcnn).cuda() 68 | if opt.load_path: 69 | trainer.load(opt.load_path) 70 | print('load pretrained model from %s' % opt.load_path) 71 | trainer.vis.text(dataset.db.label_names, win='labels') 72 | best_map = 0 73 | lr_ = opt.lr 74 | for epoch in range(opt.epoch): 75 | trainer.reset_meters() 76 | for ii, (img, bbox_, label_, scale) in tqdm(enumerate(dataloader)): 77 | scale = at.scalar(scale) 78 | img, bbox, label = img.cuda().float(), bbox_.cuda(), label_.cuda() 79 | trainer.train_step(img, bbox, label, scale) 80 | 81 | if (ii + 1) % opt.plot_every == 0: 82 
| if os.path.exists(opt.debug_file): 83 | ipdb.set_trace() 84 | 85 | # plot loss 86 | trainer.vis.plot_many(trainer.get_meter_data()) 87 | 88 | # plot groud truth bboxes 89 | ori_img_ = inverse_normalize(at.tonumpy(img[0])) 90 | gt_img = visdom_bbox(ori_img_, 91 | at.tonumpy(bbox_[0]), 92 | at.tonumpy(label_[0])) 93 | trainer.vis.img('gt_img', gt_img) 94 | 95 | # plot predicti bboxes 96 | _bboxes, _labels, _scores = trainer.faster_rcnn.predict([ori_img_], visualize=True) 97 | pred_img = visdom_bbox(ori_img_, 98 | at.tonumpy(_bboxes[0]), 99 | at.tonumpy(_labels[0]).reshape(-1), 100 | at.tonumpy(_scores[0])) 101 | trainer.vis.img('pred_img', pred_img) 102 | 103 | # rpn confusion matrix(meter) 104 | trainer.vis.text(str(trainer.rpn_cm.value().tolist()), win='rpn_cm') 105 | # roi confusion matrix 106 | trainer.vis.img('roi_cm', at.totensor(trainer.roi_cm.conf, False).float()) 107 | eval_result = eval(test_dataloader, faster_rcnn, test_num=opt.test_num) 108 | trainer.vis.plot('test_map', eval_result['map']) 109 | lr_ = trainer.faster_rcnn.optimizer.param_groups[0]['lr'] 110 | log_info = 'lr:{}, map:{},loss:{}'.format(str(lr_), 111 | str(eval_result['map']), 112 | str(trainer.get_meter_data())) 113 | trainer.vis.log(log_info) 114 | 115 | if eval_result['map'] > best_map: 116 | best_map = eval_result['map'] 117 | best_path = trainer.save(best_map=best_map) 118 | if epoch == 9: 119 | trainer.load(best_path) 120 | trainer.faster_rcnn.scale_lr(opt.lr_decay) 121 | lr_ = lr_ * opt.lr_decay 122 | 123 | if epoch == 13: 124 | break 125 | 126 | 127 | if __name__ == '__main__': 128 | import fire 129 | 130 | fire.Fire() 131 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | from collections import namedtuple 4 | import time 5 | from torch.nn import functional as F 6 | from model.utils.creator_tool import AnchorTargetCreator, ProposalTargetCreator 7 | 8 | from torch import nn 9 | import torch as t 10 | from utils import array_tool as at 11 | from utils.vis_tool import Visualizer 12 | 13 | from utils.config import opt 14 | from torchnet.meter import ConfusionMeter, AverageValueMeter 15 | 16 | LossTuple = namedtuple('LossTuple', 17 | ['rpn_loc_loss', 18 | 'rpn_cls_loss', 19 | 'roi_loc_loss', 20 | 'roi_cls_loss', 21 | 'total_loss' 22 | ]) 23 | 24 | 25 | class FasterRCNNTrainer(nn.Module): 26 | """wrapper for conveniently training. return losses 27 | 28 | The losses include: 29 | 30 | * :obj:`rpn_loc_loss`: The localization loss for \ 31 | Region Proposal Network (RPN). 32 | * :obj:`rpn_cls_loss`: The classification loss for RPN. 33 | * :obj:`roi_loc_loss`: The localization loss for the head module. 34 | * :obj:`roi_cls_loss`: The classification loss for the head module. 35 | * :obj:`total_loss`: The sum of 4 loss above. 36 | 37 | Args: 38 | faster_rcnn (model.FasterRCNN): 39 | A Faster R-CNN model that is going to be trained. 40 | """ 41 | 42 | def __init__(self, faster_rcnn): 43 | super(FasterRCNNTrainer, self).__init__() 44 | 45 | self.faster_rcnn = faster_rcnn 46 | self.rpn_sigma = opt.rpn_sigma 47 | self.roi_sigma = opt.roi_sigma 48 | 49 | # target creator create gt_bbox gt_label etc as training targets. 
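        # (AnchorTargetCreator builds the anchor-wise targets for the RPN loss;
        #  ProposalTargetCreator samples 128 RoIs per image and builds the
        #  classification/regression targets for the head. Both live in
        #  model/utils/creator_tool.py.)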
50 | self.anchor_target_creator = AnchorTargetCreator() 51 | self.proposal_target_creator = ProposalTargetCreator() 52 | 53 | self.loc_normalize_mean = faster_rcnn.loc_normalize_mean 54 | self.loc_normalize_std = faster_rcnn.loc_normalize_std 55 | 56 | self.optimizer = self.faster_rcnn.get_optimizer() 57 | # visdom wrapper 58 | self.vis = Visualizer(env=opt.env) 59 | 60 | # indicators for training status 61 | self.rpn_cm = ConfusionMeter(2) 62 | self.roi_cm = ConfusionMeter(21) 63 | self.meters = {k: AverageValueMeter() for k in LossTuple._fields} # average loss 64 | 65 | def forward(self, imgs, bboxes, labels, scale): 66 | """Forward Faster R-CNN and calculate losses. 67 | 68 | Here are notations used. 69 | 70 | * :math:`N` is the batch size. 71 | * :math:`R` is the number of bounding boxes per image. 72 | 73 | Currently, only :math:`N=1` is supported. 74 | 75 | Args: 76 | imgs (~torch.autograd.Variable): A variable with a batch of images. 77 | bboxes (~torch.autograd.Variable): A batch of bounding boxes. 78 | Its shape is :math:`(N, R, 4)`. 79 | labels (~torch.autograd..Variable): A batch of labels. 80 | Its shape is :math:`(N, R)`. The background is excluded from 81 | the definition, which means that the range of the value 82 | is :math:`[0, L - 1]`. :math:`L` is the number of foreground 83 | classes. 84 | scale (float): Amount of scaling applied to 85 | the raw image during preprocessing. 86 | 87 | Returns: 88 | namedtuple of 5 losses 89 | """ 90 | n = bboxes.shape[0] 91 | if n != 1: 92 | raise ValueError('Currently only batch size 1 is supported.') 93 | 94 | _, _, H, W = imgs.shape 95 | img_size = (H, W) 96 | 97 | features = self.faster_rcnn.extractor(imgs) 98 | 99 | rpn_locs, rpn_scores, rois, roi_indices, anchor = \ 100 | self.faster_rcnn.rpn(features, img_size, scale) 101 | 102 | # Since batch size is one, convert variables to singular form 103 | bbox = bboxes[0] 104 | label = labels[0] 105 | rpn_score = rpn_scores[0] 106 | rpn_loc = rpn_locs[0] 107 | roi = rois 108 | 109 | # Sample RoIs and forward 110 | # it's fine to break the computation graph of rois, 111 | # consider them as constant input 112 | sample_roi, gt_roi_loc, gt_roi_label = self.proposal_target_creator( 113 | roi, 114 | at.tonumpy(bbox), 115 | at.tonumpy(label), 116 | self.loc_normalize_mean, 117 | self.loc_normalize_std) 118 | # NOTE it's all zero because now it only support for batch=1 now 119 | sample_roi_index = t.zeros(len(sample_roi)) 120 | roi_cls_loc, roi_score = self.faster_rcnn.head( 121 | features, 122 | sample_roi, 123 | sample_roi_index) 124 | 125 | # ------------------ RPN losses -------------------# 126 | gt_rpn_loc, gt_rpn_label = self.anchor_target_creator( 127 | at.tonumpy(bbox), 128 | anchor, 129 | img_size) 130 | gt_rpn_label = at.totensor(gt_rpn_label).long() 131 | gt_rpn_loc = at.totensor(gt_rpn_loc) 132 | rpn_loc_loss = _fast_rcnn_loc_loss( 133 | rpn_loc, 134 | gt_rpn_loc, 135 | gt_rpn_label.data, 136 | self.rpn_sigma) 137 | 138 | # NOTE: default value of ignore_index is -100 ... 
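        # AnchorTargetCreator marks anchors that should not contribute to the
        # loss with label -1, so those entries are skipped via ignore_index=-1.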
139 |         rpn_cls_loss = F.cross_entropy(rpn_score, gt_rpn_label.cuda(), ignore_index=-1)
140 |         _gt_rpn_label = gt_rpn_label[gt_rpn_label > -1]
141 |         _rpn_score = at.tonumpy(rpn_score)[at.tonumpy(gt_rpn_label) > -1]
142 |         self.rpn_cm.add(at.totensor(_rpn_score, False), _gt_rpn_label.data.long())
143 | 
144 |         # ------------------ ROI losses (fast rcnn loss) -------------------#
145 |         n_sample = roi_cls_loc.shape[0]
146 |         roi_cls_loc = roi_cls_loc.view(n_sample, -1, 4)
147 |         roi_loc = roi_cls_loc[t.arange(0, n_sample).long().cuda(), \
148 |                               at.totensor(gt_roi_label).long()]
149 |         gt_roi_label = at.totensor(gt_roi_label).long()
150 |         gt_roi_loc = at.totensor(gt_roi_loc)
151 | 
152 |         roi_loc_loss = _fast_rcnn_loc_loss(
153 |             roi_loc.contiguous(),
154 |             gt_roi_loc,
155 |             gt_roi_label.data,
156 |             self.roi_sigma)
157 | 
158 |         roi_cls_loss = nn.CrossEntropyLoss()(roi_score, gt_roi_label.cuda())
159 | 
160 |         self.roi_cm.add(at.totensor(roi_score, False), gt_roi_label.data.long())
161 | 
162 |         losses = [rpn_loc_loss, rpn_cls_loss, roi_loc_loss, roi_cls_loss]
163 |         losses = losses + [sum(losses)]
164 | 
165 |         return LossTuple(*losses)
166 | 
167 |     def train_step(self, imgs, bboxes, labels, scale):
168 |         self.optimizer.zero_grad()
169 |         losses = self.forward(imgs, bboxes, labels, scale)
170 |         losses.total_loss.backward()
171 |         self.optimizer.step()
172 |         self.update_meters(losses)
173 |         return losses
174 | 
175 |     def save(self, save_optimizer=False, save_path=None, **kwargs):
176 |         """Serialize the model, the optimizer state and other info, and
177 |         return the path where the model file is stored.
178 | 
179 |         Args:
180 |             save_optimizer (bool): whether to save optimizer.state_dict().
181 |             save_path (string): where to save the model. If it is None,
182 |                 save_path is generated from a time string and the values in kwargs.
183 | 
184 |         Returns:
185 |             save_path (str): the path where the model is saved.
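        A usage sketch (the file name is illustrative; it is built from the
        current time plus the extra keyword values):

        .. code-block:: python

            path = trainer.save(best_map=0.70)  # e.g. 'checkpoints/fasterrcnn_07041530_0.7'
            trainer.load(path)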
186 | """ 187 | save_dict = dict() 188 | 189 | save_dict['model'] = self.faster_rcnn.state_dict() 190 | save_dict['config'] = opt._state_dict() 191 | save_dict['other_info'] = kwargs 192 | save_dict['vis_info'] = self.vis.state_dict() 193 | 194 | if save_optimizer: 195 | save_dict['optimizer'] = self.optimizer.state_dict() 196 | 197 | if save_path is None: 198 | timestr = time.strftime('%m%d%H%M') 199 | save_path = 'checkpoints/fasterrcnn_%s' % timestr 200 | for k_, v_ in kwargs.items(): 201 | save_path += '_%s' % v_ 202 | 203 | save_dir = os.path.dirname(save_path) 204 | if not os.path.exists(save_dir): 205 | os.makedirs(save_dir) 206 | 207 | t.save(save_dict, save_path) 208 | self.vis.save([self.vis.env]) 209 | return save_path 210 | 211 | def load(self, path, load_optimizer=True, parse_opt=False, ): 212 | state_dict = t.load(path) 213 | if 'model' in state_dict: 214 | self.faster_rcnn.load_state_dict(state_dict['model']) 215 | else: # legacy way, for backward compatibility 216 | self.faster_rcnn.load_state_dict(state_dict) 217 | return self 218 | if parse_opt: 219 | opt._parse(state_dict['config']) 220 | if 'optimizer' in state_dict and load_optimizer: 221 | self.optimizer.load_state_dict(state_dict['optimizer']) 222 | return self 223 | 224 | def update_meters(self, losses): 225 | loss_d = {k: at.scalar(v) for k, v in losses._asdict().items()} 226 | for key, meter in self.meters.items(): 227 | meter.add(loss_d[key]) 228 | 229 | def reset_meters(self): 230 | for key, meter in self.meters.items(): 231 | meter.reset() 232 | self.roi_cm.reset() 233 | self.rpn_cm.reset() 234 | 235 | def get_meter_data(self): 236 | return {k: v.value()[0] for k, v in self.meters.items()} 237 | 238 | 239 | def _smooth_l1_loss(x, t, in_weight, sigma): 240 | sigma2 = sigma ** 2 241 | diff = in_weight * (x - t) 242 | abs_diff = diff.abs() 243 | flag = (abs_diff.data < (1. / sigma2)).float() 244 | y = (flag * (sigma2 / 2.) * (diff ** 2) + 245 | (1 - flag) * (abs_diff - 0.5 / sigma2)) 246 | return y.sum() 247 | 248 | 249 | def _fast_rcnn_loc_loss(pred_loc, gt_loc, gt_label, sigma): 250 | in_weight = t.zeros(gt_loc.shape).cuda() 251 | # Localization loss is calculated only for positive rois. 252 | # NOTE: unlike origin implementation, 253 | # we don't need inside_weight and outside_weight, they can calculate by gt_label 254 | in_weight[(gt_label > 0).view(-1, 1).expand_as(in_weight).cuda()] = 1 255 | loc_loss = _smooth_l1_loss(pred_loc, gt_loc, in_weight.detach(), sigma) 256 | # Normalize by total number of negtive and positive rois. 257 | loc_loss /= ((gt_label >= 0).sum().float()) # ignore gt_label==-1 for rpn_loss 258 | return loc_loss 259 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 cy 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | -------------------------------------------------------------------------------- /utils/array_tool.py: -------------------------------------------------------------------------------- 1 | """ 2 | tools to convert specified type 3 | """ 4 | import torch as t 5 | import numpy as np 6 | 7 | 8 | def tonumpy(data): 9 | if isinstance(data, np.ndarray): 10 | return data 11 | if isinstance(data, t.Tensor): 12 | return data.detach().cpu().numpy() 13 | 14 | 15 | def totensor(data, cuda=True): 16 | if isinstance(data, np.ndarray): 17 | tensor = t.from_numpy(data) 18 | if isinstance(data, t.Tensor): 19 | tensor = data.detach() 20 | if cuda: 21 | tensor = tensor.cuda() 22 | return tensor 23 | 24 | 25 | def scalar(data): 26 | if isinstance(data, np.ndarray): 27 | return data.reshape(1)[0] 28 | if isinstance(data, t.Tensor): 29 | return data.item() -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | from pprint import pprint 2 | 3 | 4 | # Default Configs for training 5 | # NOTE that, config items could be overwriten by passing argument through command line. 6 | # e.g. --voc-data-dir='./data/' 7 | 8 | class Config: 9 | # data 10 | voc_data_dir = '/dataset/PASCAL2007/VOC2007/' 11 | min_size = 600 # image resize 12 | max_size = 1000 # image resize 13 | num_workers = 8 14 | test_num_workers = 8 15 | 16 | # sigma for l1_smooth_loss 17 | rpn_sigma = 3. 18 | roi_sigma = 1. 19 | 20 | # param for optimizer 21 | # 0.0005 in origin paper but 0.0001 in tf-faster-rcnn 22 | weight_decay = 0.0005 23 | lr_decay = 0.1 # 1e-3 -> 1e-4 24 | lr = 1e-3 25 | 26 | 27 | # visualization 28 | env = 'faster-rcnn' # visdom env 29 | port = 8097 30 | plot_every = 40 # vis every N iter 31 | 32 | # preset 33 | data = 'voc' 34 | pretrained_model = 'vgg16' 35 | 36 | # training 37 | epoch = 14 38 | 39 | 40 | use_adam = False # Use Adam optimizer 41 | use_chainer = False # try match everything as chainer 42 | use_drop = False # use dropout in RoIHead 43 | # debug 44 | debug_file = '/tmp/debugf' 45 | 46 | test_num = 10000 47 | # model 48 | load_path = None 49 | 50 | caffe_pretrain = False # use caffe pretrained model instead of torchvision 51 | caffe_pretrain_path = 'checkpoints/vgg16_caffe.pth' 52 | 53 | def _parse(self, kwargs): 54 | state_dict = self._state_dict() 55 | for k, v in kwargs.items(): 56 | if k not in state_dict: 57 | raise ValueError('UnKnown Option: "--%s"' % k) 58 | setattr(self, k, v) 59 | 60 | print('======user config========') 61 | pprint(self._state_dict()) 62 | print('==========end============') 63 | 64 | def _state_dict(self): 65 | return {k: getattr(self, k) for k, _ in Config.__dict__.items() \ 66 | if not k.startswith('_')} 67 | 68 | 69 | opt = Config() 70 | -------------------------------------------------------------------------------- /utils/eval_tool.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | from collections import defaultdict 4 | import itertools 5 | import numpy as np 6 | import six 7 | 8 | from model.utils.bbox_tools import bbox_iou 9 | 10 | 11 | def eval_detection_voc( 12 | pred_bboxes, pred_labels, pred_scores, gt_bboxes, gt_labels, 13 | gt_difficults=None, 14 | iou_thresh=0.5, use_07_metric=False): 15 | """Calculate average precisions based on evaluation code of PASCAL VOC. 
16 | 17 | This function evaluates predicted bounding boxes obtained from a dataset 18 | which has :math:`N` images by using average precision for each class. 19 | The code is based on the evaluation code used in PASCAL VOC Challenge. 20 | 21 | Args: 22 | pred_bboxes (iterable of numpy.ndarray): An iterable of :math:`N` 23 | sets of bounding boxes. 24 | Its index corresponds to an index for the base dataset. 25 | Each element of :obj:`pred_bboxes` is a set of coordinates 26 | of bounding boxes. This is an array whose shape is :math:`(R, 4)`, 27 | where :math:`R` corresponds 28 | to the number of bounding boxes, which may vary among boxes. 29 | The second axis corresponds to 30 | :math:`y_{min}, x_{min}, y_{max}, x_{max}` of a bounding box. 31 | pred_labels (iterable of numpy.ndarray): An iterable of labels. 32 | Similar to :obj:`pred_bboxes`, its index corresponds to an 33 | index for the base dataset. Its length is :math:`N`. 34 | pred_scores (iterable of numpy.ndarray): An iterable of confidence 35 | scores for predicted bounding boxes. Similar to :obj:`pred_bboxes`, 36 | its index corresponds to an index for the base dataset. 37 | Its length is :math:`N`. 38 | gt_bboxes (iterable of numpy.ndarray): An iterable of ground truth 39 | bounding boxes 40 | whose length is :math:`N`. An element of :obj:`gt_bboxes` is a 41 | bounding box whose shape is :math:`(R, 4)`. Note that the number of 42 | bounding boxes in each image does not need to be same as the number 43 | of corresponding predicted boxes. 44 | gt_labels (iterable of numpy.ndarray): An iterable of ground truth 45 | labels which are organized similarly to :obj:`gt_bboxes`. 46 | gt_difficults (iterable of numpy.ndarray): An iterable of boolean 47 | arrays which is organized similarly to :obj:`gt_bboxes`. 48 | This tells whether the 49 | corresponding ground truth bounding box is difficult or not. 50 | By default, this is :obj:`None`. In that case, this function 51 | considers all bounding boxes to be not difficult. 52 | iou_thresh (float): A prediction is correct if its Intersection over 53 | Union with the ground truth is above this value. 54 | use_07_metric (bool): Whether to use PASCAL VOC 2007 evaluation metric 55 | for calculating average precision. The default value is 56 | :obj:`False`. 57 | 58 | Returns: 59 | dict: 60 | 61 | The keys, value-types and the description of the values are listed 62 | below. 63 | 64 | * **ap** (*numpy.ndarray*): An array of average precisions. \ 65 | The :math:`l`-th value corresponds to the average precision \ 66 | for class :math:`l`. If class :math:`l` does not exist in \ 67 | either :obj:`pred_labels` or :obj:`gt_labels`, the corresponding \ 68 | value is set to :obj:`numpy.nan`. 69 | * **map** (*float*): The average of Average Precisions over classes. 70 | 71 | """ 72 | 73 | prec, rec = calc_detection_voc_prec_rec( 74 | pred_bboxes, pred_labels, pred_scores, 75 | gt_bboxes, gt_labels, gt_difficults, 76 | iou_thresh=iou_thresh) 77 | 78 | ap = calc_detection_voc_ap(prec, rec, use_07_metric=use_07_metric) 79 | 80 | return {'ap': ap, 'map': np.nanmean(ap)} 81 | 82 | 83 | def calc_detection_voc_prec_rec( 84 | pred_bboxes, pred_labels, pred_scores, gt_bboxes, gt_labels, 85 | gt_difficults=None, 86 | iou_thresh=0.5): 87 | """Calculate precision and recall based on evaluation code of PASCAL VOC. 88 | 89 | This function calculates precision and recall of 90 | predicted bounding boxes obtained from a dataset which has :math:`N` 91 | images. 
92 | The code is based on the evaluation code used in PASCAL VOC Challenge. 93 | 94 | Args: 95 | pred_bboxes (iterable of numpy.ndarray): An iterable of :math:`N` 96 | sets of bounding boxes. 97 | Its index corresponds to an index for the base dataset. 98 | Each element of :obj:`pred_bboxes` is a set of coordinates 99 | of bounding boxes. This is an array whose shape is :math:`(R, 4)`, 100 | where :math:`R` corresponds 101 | to the number of bounding boxes, which may vary among boxes. 102 | The second axis corresponds to 103 | :math:`y_{min}, x_{min}, y_{max}, x_{max}` of a bounding box. 104 | pred_labels (iterable of numpy.ndarray): An iterable of labels. 105 | Similar to :obj:`pred_bboxes`, its index corresponds to an 106 | index for the base dataset. Its length is :math:`N`. 107 | pred_scores (iterable of numpy.ndarray): An iterable of confidence 108 | scores for predicted bounding boxes. Similar to :obj:`pred_bboxes`, 109 | its index corresponds to an index for the base dataset. 110 | Its length is :math:`N`. 111 | gt_bboxes (iterable of numpy.ndarray): An iterable of ground truth 112 | bounding boxes 113 | whose length is :math:`N`. An element of :obj:`gt_bboxes` is a 114 | bounding box whose shape is :math:`(R, 4)`. Note that the number of 115 | bounding boxes in each image does not need to be same as the number 116 | of corresponding predicted boxes. 117 | gt_labels (iterable of numpy.ndarray): An iterable of ground truth 118 | labels which are organized similarly to :obj:`gt_bboxes`. 119 | gt_difficults (iterable of numpy.ndarray): An iterable of boolean 120 | arrays which is organized similarly to :obj:`gt_bboxes`. 121 | This tells whether the 122 | corresponding ground truth bounding box is difficult or not. 123 | By default, this is :obj:`None`. In that case, this function 124 | considers all bounding boxes to be not difficult. 125 | iou_thresh (float): A prediction is correct if its Intersection over 126 | Union with the ground truth is above this value.. 127 | 128 | Returns: 129 | tuple of two lists: 130 | This function returns two lists: :obj:`prec` and :obj:`rec`. 131 | 132 | * :obj:`prec`: A list of arrays. :obj:`prec[l]` is precision \ 133 | for class :math:`l`. If class :math:`l` does not exist in \ 134 | either :obj:`pred_labels` or :obj:`gt_labels`, :obj:`prec[l]` is \ 135 | set to :obj:`None`. 136 | * :obj:`rec`: A list of arrays. :obj:`rec[l]` is recall \ 137 | for class :math:`l`. If class :math:`l` that is not marked as \ 138 | difficult does not exist in \ 139 | :obj:`gt_labels`, :obj:`rec[l]` is \ 140 | set to :obj:`None`. 
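    A toy example with one image, one predicted box and one ground truth box of
    the same class that overlap well above the default IoU threshold:

    .. code-block:: python

        prec, rec = calc_detection_voc_prec_rec(
            [np.array([[0., 0., 10., 10.]])], [np.array([0])], [np.array([0.9])],
            [np.array([[1., 1., 10., 10.]])], [np.array([0])])
        # prec[0] -> array([1.]), rec[0] -> array([1.])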
141 | 142 | """ 143 | 144 | pred_bboxes = iter(pred_bboxes) 145 | pred_labels = iter(pred_labels) 146 | pred_scores = iter(pred_scores) 147 | gt_bboxes = iter(gt_bboxes) 148 | gt_labels = iter(gt_labels) 149 | if gt_difficults is None: 150 | gt_difficults = itertools.repeat(None) 151 | else: 152 | gt_difficults = iter(gt_difficults) 153 | 154 | n_pos = defaultdict(int) 155 | score = defaultdict(list) 156 | match = defaultdict(list) 157 | 158 | for pred_bbox, pred_label, pred_score, gt_bbox, gt_label, gt_difficult in \ 159 | six.moves.zip( 160 | pred_bboxes, pred_labels, pred_scores, 161 | gt_bboxes, gt_labels, gt_difficults): 162 | 163 | if gt_difficult is None: 164 | gt_difficult = np.zeros(gt_bbox.shape[0], dtype=bool) 165 | 166 | for l in np.unique(np.concatenate((pred_label, gt_label)).astype(int)): 167 | pred_mask_l = pred_label == l 168 | pred_bbox_l = pred_bbox[pred_mask_l] 169 | pred_score_l = pred_score[pred_mask_l] 170 | # sort by score 171 | order = pred_score_l.argsort()[::-1] 172 | pred_bbox_l = pred_bbox_l[order] 173 | pred_score_l = pred_score_l[order] 174 | 175 | gt_mask_l = gt_label == l 176 | gt_bbox_l = gt_bbox[gt_mask_l] 177 | gt_difficult_l = gt_difficult[gt_mask_l] 178 | 179 | n_pos[l] += np.logical_not(gt_difficult_l).sum() 180 | score[l].extend(pred_score_l) 181 | 182 | if len(pred_bbox_l) == 0: 183 | continue 184 | if len(gt_bbox_l) == 0: 185 | match[l].extend((0,) * pred_bbox_l.shape[0]) 186 | continue 187 | 188 | # VOC evaluation follows integer typed bounding boxes. 189 | pred_bbox_l = pred_bbox_l.copy() 190 | pred_bbox_l[:, 2:] += 1 191 | gt_bbox_l = gt_bbox_l.copy() 192 | gt_bbox_l[:, 2:] += 1 193 | 194 | iou = bbox_iou(pred_bbox_l, gt_bbox_l) 195 | gt_index = iou.argmax(axis=1) 196 | # set -1 if there is no matching ground truth 197 | gt_index[iou.max(axis=1) < iou_thresh] = -1 198 | del iou 199 | 200 | selec = np.zeros(gt_bbox_l.shape[0], dtype=bool) 201 | for gt_idx in gt_index: 202 | if gt_idx >= 0: 203 | if gt_difficult_l[gt_idx]: 204 | match[l].append(-1) 205 | else: 206 | if not selec[gt_idx]: 207 | match[l].append(1) 208 | else: 209 | match[l].append(0) 210 | selec[gt_idx] = True 211 | else: 212 | match[l].append(0) 213 | 214 | for iter_ in ( 215 | pred_bboxes, pred_labels, pred_scores, 216 | gt_bboxes, gt_labels, gt_difficults): 217 | if next(iter_, None) is not None: 218 | raise ValueError('Length of input iterables need to be same.') 219 | 220 | n_fg_class = max(n_pos.keys()) + 1 221 | prec = [None] * n_fg_class 222 | rec = [None] * n_fg_class 223 | 224 | for l in n_pos.keys(): 225 | score_l = np.array(score[l]) 226 | match_l = np.array(match[l], dtype=np.int8) 227 | 228 | order = score_l.argsort()[::-1] 229 | match_l = match_l[order] 230 | 231 | tp = np.cumsum(match_l == 1) 232 | fp = np.cumsum(match_l == 0) 233 | 234 | # If an element of fp + tp is 0, 235 | # the corresponding element of prec[l] is nan. 236 | prec[l] = tp / (fp + tp) 237 | # If n_pos[l] is 0, rec[l] is None. 238 | if n_pos[l] > 0: 239 | rec[l] = tp / n_pos[l] 240 | 241 | return prec, rec 242 | 243 | 244 | def calc_detection_voc_ap(prec, rec, use_07_metric=False): 245 | """Calculate average precisions based on evaluation code of PASCAL VOC. 246 | 247 | This function calculates average precisions 248 | from given precisions and recalls. 249 | The code is based on the evaluation code used in PASCAL VOC Challenge. 250 | 251 | Args: 252 | prec (list of numpy.array): A list of arrays. 253 | :obj:`prec[l]` indicates precision for class :math:`l`. 
254 | If :obj:`prec[l]` is :obj:`None`, this function returns 255 | :obj:`numpy.nan` for class :math:`l`. 256 | rec (list of numpy.array): A list of arrays. 257 | :obj:`rec[l]` indicates recall for class :math:`l`. 258 | If :obj:`rec[l]` is :obj:`None`, this function returns 259 | :obj:`numpy.nan` for class :math:`l`. 260 | use_07_metric (bool): Whether to use PASCAL VOC 2007 evaluation metric 261 | for calculating average precision. The default value is 262 | :obj:`False`. 263 | 264 | Returns: 265 | ~numpy.ndarray: 266 | This function returns an array of average precisions. 267 | The :math:`l`-th value corresponds to the average precision 268 | for class :math:`l`. If :obj:`prec[l]` or :obj:`rec[l]` is 269 | :obj:`None`, the corresponding value is set to :obj:`numpy.nan`. 270 | 271 | """ 272 | 273 | n_fg_class = len(prec) 274 | ap = np.empty(n_fg_class) 275 | for l in six.moves.range(n_fg_class): 276 | if prec[l] is None or rec[l] is None: 277 | ap[l] = np.nan 278 | continue 279 | 280 | if use_07_metric: 281 | # 11 point metric 282 | ap[l] = 0 283 | for t in np.arange(0., 1.1, 0.1): 284 | if np.sum(rec[l] >= t) == 0: 285 | p = 0 286 | else: 287 | p = np.max(np.nan_to_num(prec[l])[rec[l] >= t]) 288 | ap[l] += p / 11 289 | else: 290 | # correct AP calculation 291 | # first append sentinel values at the end 292 | mpre = np.concatenate(([0], np.nan_to_num(prec[l]), [0])) 293 | mrec = np.concatenate(([0], rec[l], [1])) 294 | 295 | mpre = np.maximum.accumulate(mpre[::-1])[::-1] 296 | 297 | # to calculate area under PR curve, look for points 298 | # where X axis (recall) changes value 299 | i = np.where(mrec[1:] != mrec[:-1])[0] 300 | 301 | # and sum (\Delta recall) * prec 302 | ap[l] = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 303 | 304 | return ap 305 | -------------------------------------------------------------------------------- /utils/vis_tool.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | import matplotlib 5 | import torch as t 6 | import visdom 7 | 8 | matplotlib.use('Agg') 9 | from matplotlib import pyplot as plot 10 | 11 | # from data.voc_dataset import VOC_BBOX_LABEL_NAMES 12 | 13 | 14 | VOC_BBOX_LABEL_NAMES = ( 15 | 'fly', 16 | 'bike', 17 | 'bird', 18 | 'boat', 19 | 'pin', 20 | 'bus', 21 | 'c', 22 | 'cat', 23 | 'chair', 24 | 'cow', 25 | 'table', 26 | 'dog', 27 | 'horse', 28 | 'moto', 29 | 'p', 30 | 'plant', 31 | 'shep', 32 | 'sofa', 33 | 'train', 34 | 'tv', 35 | ) 36 | 37 | 38 | def vis_image(img, ax=None): 39 | """Visualize a color image. 40 | 41 | Args: 42 | img (~numpy.ndarray): An array of shape :math:`(3, height, width)`. 43 | This is in RGB format and the range of its value is 44 | :math:`[0, 255]`. 45 | ax (matplotlib.axes.Axis): The visualization is displayed on this 46 | axis. If this is :obj:`None` (default), a new axis is created. 47 | 48 | Returns: 49 | ~matploblib.axes.Axes: 50 | Returns the Axes object with the plot for further tweaking. 51 | 52 | """ 53 | 54 | if ax is None: 55 | fig = plot.figure() 56 | ax = fig.add_subplot(1, 1, 1) 57 | # CHW -> HWC 58 | img = img.transpose((1, 2, 0)) 59 | ax.imshow(img.astype(np.uint8)) 60 | return ax 61 | 62 | 63 | def vis_bbox(img, bbox, label=None, score=None, ax=None): 64 | """Visualize bounding boxes inside image. 65 | 66 | Args: 67 | img (~numpy.ndarray): An array of shape :math:`(3, height, width)`. 68 | This is in RGB format and the range of its value is 69 | :math:`[0, 255]`. 
70 |         bbox (~numpy.ndarray): An array of shape :math:`(R, 4)`, where
71 |             :math:`R` is the number of bounding boxes in the image.
72 |             Each element is organized
73 |             by :math:`(y_{min}, x_{min}, y_{max}, x_{max})` in the second axis.
74 |         label (~numpy.ndarray): An integer array of shape :math:`(R,)`.
75 |             The values correspond to ids for the label names stored in
76 |             :obj:`VOC_BBOX_LABEL_NAMES`. This is optional.
77 |         score (~numpy.ndarray): A float array of shape :math:`(R,)`.
78 |             Each value indicates how confident the prediction is.
79 |             This is optional.
80 |         label_names: not a parameter of this function; the label names are
81 |             taken from the module-level :obj:`VOC_BBOX_LABEL_NAMES` plus 'bg'.
82 |         ax (matplotlib.axes.Axis): The visualization is displayed on this
83 |             axis. If this is :obj:`None` (default), a new axis is created.
84 | 
85 |     Returns:
86 |         ~matplotlib.axes.Axes:
87 |             Returns the Axes object with the plot for further tweaking.
88 | 
89 |     """
90 | 
91 |     label_names = list(VOC_BBOX_LABEL_NAMES) + ['bg']
92 |     # 'bg' is appended so that label index `-1` maps to background
93 |     if label is not None and not len(bbox) == len(label):
94 |         raise ValueError('The length of label must be the same as that of bbox')
95 |     if score is not None and not len(bbox) == len(score):
96 |         raise ValueError('The length of score must be the same as that of bbox')
97 | 
98 |     # Returns a newly instantiated matplotlib.axes.Axes object if ax is None
99 |     ax = vis_image(img, ax=ax)
100 | 
101 |     # If there is no bounding box to display, visualize the image and exit.
102 |     if len(bbox) == 0:
103 |         return ax
104 | 
105 |     for i, bb in enumerate(bbox):
106 |         xy = (bb[1], bb[0])
107 |         height = bb[2] - bb[0]
108 |         width = bb[3] - bb[1]
109 |         ax.add_patch(plot.Rectangle(
110 |             xy, width, height, fill=False, edgecolor='red', linewidth=2))
111 | 
112 |         caption = list()
113 | 
114 |         if label is not None and label_names is not None:
115 |             lb = label[i]
116 |             if not (-1 <= lb < len(label_names)):  # modified here to allow -1 (background)
117 |                 raise ValueError('No corresponding name is given')
118 |             caption.append(label_names[lb])
119 |         if score is not None:
120 |             sc = score[i]
121 |             caption.append('{:.2f}'.format(sc))
122 | 
123 |         if len(caption) > 0:
124 |             ax.text(bb[1], bb[0],
125 |                     ': '.join(caption),
126 |                     style='italic',
127 |                     bbox={'facecolor': 'white', 'alpha': 0.5, 'pad': 0})
128 |     return ax
129 | 
130 | 
131 | def fig2data(fig):
132 |     """
133 |     Convert a Matplotlib figure to an RGBA numpy array of shape
134 |     (height, width, 4) and return it.
135 | 
136 |     @param fig: a matplotlib figure
137 |     @return a numpy 3D array of RGBA values
138 |     """
139 |     # draw the renderer
140 |     fig.canvas.draw()
141 | 
142 |     # Get the ARGB buffer from the figure
143 |     w, h = fig.canvas.get_width_height()
144 |     buf = np.frombuffer(fig.canvas.tostring_argb(), dtype=np.uint8)
145 |     buf = buf.reshape(h, w, 4)
146 | 
147 |     # canvas.tostring_argb gives the pixmap in ARGB order; roll the alpha channel to get RGBA
148 |     buf = np.roll(buf, 3, axis=2)
149 |     return buf
150 | 
151 | 
152 | def fig4vis(fig):
153 |     """
154 |     convert an Axes (as returned by vis_bbox) to a CHW ndarray in [0, 1]
155 |     """
156 |     fig_ = fig.get_figure()  # `fig` is actually an Axes; fetch its parent Figure
157 |     img_data = fig2data(fig_).astype(np.int32)
158 |     plot.close()
159 |     # HWC -> CHW
160 |     return img_data[:, :, :3].transpose((2, 0, 1)) / 255.
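# A minimal alternative sketch (not part of the original file): recent matplotlib
# releases deprecate `tostring_argb`, but `FigureCanvasAgg.buffer_rgba()` exposes
# the rendered canvas directly in RGBA order, so no channel roll is needed.
# Assumes matplotlib >= 3.1 with the Agg backend selected above.
def fig2data_rgba(fig):
    """Sketch of fig2data using the RGBA buffer; `fig` is a matplotlib Figure."""
    fig.canvas.draw()
    # buffer_rgba() returns an (H, W, 4) uint8 memoryview of the canvas
    return np.asarray(fig.canvas.buffer_rgba()).copy()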
161 | 
162 | 
163 | def visdom_bbox(*args, **kwargs):
164 |     fig = vis_bbox(*args, **kwargs)
165 |     data = fig4vis(fig)
166 |     return data
167 | 
168 | 
169 | class Visualizer(object):
170 |     """
171 |     wrapper for visdom
172 |     you can still access native visdom functions via
173 |     self.line, self.scatter, self._send, etc.
174 |     due to the implementation of `__getattr__`
175 |     """
176 | 
177 |     def __init__(self, env='default', **kwargs):
178 |         self.vis = visdom.Visdom('localhost', env=env, use_incoming_socket=False, **kwargs)
179 |         self._vis_kw = kwargs
180 | 
181 |         # e.g. ('loss', 23): the 23rd recorded value of loss
182 |         self.index = {}
183 |         self.log_text = ''
184 | 
185 |     def reinit(self, env='default', **kwargs):
186 |         """
187 |         change the config of visdom
188 |         """
189 |         self.vis = visdom.Visdom(env=env, **kwargs)
190 |         return self
191 | 
192 |     def plot_many(self, d):
193 |         """
194 |         plot multiple values
195 |         @params d: dict (name, value) i.e. ('loss', 0.11)
196 |         """
197 |         for k, v in d.items():
198 |             if v is not None:
199 |                 self.plot(k, v)
200 | 
201 |     def img_many(self, d):
202 |         for k, v in d.items():
203 |             self.img(k, v)
204 | 
205 |     def plot(self, name, y, **kwargs):
206 |         """
207 |         self.plot('loss', 1.00)
208 |         """
209 |         x = self.index.get(name, 0)
210 |         self.vis.line(Y=np.array([y]), X=np.array([x]),
211 |                       win=name,
212 |                       opts=dict(title=name),
213 |                       update=None if x == 0 else 'append',
214 |                       **kwargs
215 |                       )
216 |         self.index[name] = x + 1
217 | 
218 |     def img(self, name, img_, **kwargs):
219 |         """
220 |         self.img('input_img', t.Tensor(64, 64))
221 |         self.img('input_imgs', t.Tensor(3, 64, 64))
222 |         self.img('input_imgs', t.Tensor(100, 1, 64, 64))
223 |         self.img('input_imgs', t.Tensor(100, 3, 64, 64), nrows=10)
224 |         don't do: self.img('input_imgs', t.Tensor(100, 64, 64), nrows=10)
225 |         """
226 |         self.vis.images(t.Tensor(img_).cpu().numpy(),
227 |                         win=name,
228 |                         opts=dict(title=name),
229 |                         **kwargs
230 |                         )
231 | 
232 |     def log(self, info, win='log_text'):
233 |         """
234 |         self.log({'loss': 1, 'lr': 0.0001})
235 |         """
236 |         self.log_text += ('[{time}] {info} <br>'.format(
237 |             time=time.strftime('%m%d_%H%M%S'), \
238 |             info=info))
239 |         self.vis.text(self.log_text, win)
240 | 
241 |     def __getattr__(self, name):
242 |         return getattr(self.vis, name)
243 | 
244 |     def state_dict(self):
245 |         return {
246 |             'index': self.index,
247 |             'vis_kw': self._vis_kw,
248 |             'log_text': self.log_text,
249 |             'env': self.vis.env
250 |         }
251 | 
252 |     def load_state_dict(self, d):
253 |         self.vis = visdom.Visdom(env=d.get('env', self.vis.env), **(d.get('vis_kw')))
254 |         self.log_text = d.get('log_text', '')
255 |         self.index = d.get('index', dict())
256 |         return self
257 | 
--------------------------------------------------------------------------------
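# A minimal usage sketch of `Visualizer` (illustrative only, not part of the
# repository): it assumes a visdom server is already running
# (`python -m visdom.server`) and uses an arbitrary env name.
from utils.vis_tool import Visualizer

vis = Visualizer(env='faster-rcnn')
vis.plot('total_loss', 0.87)            # one point per call, appended to the curve
vis.log({'epoch': 1, 'lr': 1e-3})       # appended to a rolling HTML log window
vis.text('raw visdom call', win='raw')  # unknown names fall through __getattr__ to visdom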