├── LICENSE ├── README.md ├── config └── config.cfg ├── data ├── vis_tp.png └── whole_figure.png ├── dataset ├── __pycache__ │ └── dataset_vd.cpython-36.pyc ├── cocodataset.py └── dataset_vd.py ├── docker └── Dockerfile ├── models ├── shap_loss_1.py ├── yolo_layer.py └── yolov3_shap.py ├── requirements ├── download_weights.sh ├── getcoco.sh └── requirements.txt ├── test.ipynb ├── train_vd_reg.py └── utils ├── cocoapi_evaluator.py ├── parse_yolo_weights.py ├── resize_gt.py ├── utils.py ├── vd_evaluator.py └── vis_bbox.py /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2021 Hiroki Kawauchi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, and/or sublicense 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software; and 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 16 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 17 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 18 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 19 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 20 | TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE 21 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 22 | 23 | 24 | This software uses some portions from the following software under its license: 25 | 26 | PyTorch_YOLOv3 27 | 28 | The MIT License 29 | 30 | Copyright (c) 2018 DeNA Co., Ltd. 31 | 32 | Permission is hereby granted, free of charge, to any person obtaining a copy 33 | of this software and associated documentation files (the "Software"), to deal 34 | in the Software without restriction, including without limitation the rights 35 | to use, copy, modify, merge, publish, distribute, and/or sublicense 36 | copies of the Software, and to permit persons to whom the Software is 37 | furnished to do so, subject to the following conditions: 38 | 39 | The above copyright notice and this permission notice shall be included in all 40 | copies or substantial portions of the Software; and 41 | 42 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 43 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 44 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 45 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 46 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 47 | TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE 48 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 49 | 50 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SHAP-Based Interpretable Object Detection Method for Satellite Imagery 2 | This is the author implementation of [SHAP-Based Interpretable Object Detection Method for Satellite Imagery](https://www.mdpi.com/2072-4292/14/9/1970). 
The implementation of the object detection model (YOLOv3) is based on [PyTorch_YOLOv3](https://github.com/DeNA/PyTorch_YOLOv3). The framework of the proposed method can be applied to any differentiable object detection model. 3 | 4 |
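In brief, the method treats a single element of the detector's raw output grid (the objectness or class score at one scale, anchor, and grid cell) as the explanation target and attributes it back to the input pixels with Captum's `GradientShap`. The snippet below is a minimal sketch of that idea, mirroring the wrapper used in `test.ipynb` and `models/shap_loss_1.py`; the loaded `model`, the preprocessed `img` tensor, and the `output_id` index are placeholders standing in for the loading and preprocessing steps shown in the notebook.

```python
import torch
from captum.attr import GradientShap

# Assumptions: `model` is a YOLOv3 network loaded with trained weights as in
# test.ipynb, and `img` is a preprocessed (1, 3, H, W) float tensor in [0, 1].
model.eval()

# Index of the detector output to explain: (yolo layer, anchor, grid x, grid y).
# In inference mode models/yolo_layer.py appends these four indices to every
# detection, so the index of a chosen (e.g. true-positive) box can be read off.
output_id = [0, 0, 7, 7]   # hypothetical example index
target = 4                 # 4 = objectness score, 5 = first class score

def yolo_wrapper(inp, output_id):
    """Return the sigmoid-activated YOLO outputs at one grid location."""
    layer_num, anchor_num, x, y = output_id
    outputs = model(inp, shap=True)   # list of (N, anchors, fsize, fsize, 5 + n_classes)
    return outputs[layer_num][:, anchor_num, y, x]

gs = GradientShap(yolo_wrapper, multiply_by_inputs=True)
attr = gs.attribute(img,
                    baselines=img * 0,                  # all-black baseline, as in the notebook
                    additional_forward_args=(output_id,),
                    n_samples=5, stdevs=0.1,
                    target=target)                      # (1, 3, H, W) attribution map
```

Because only the wrapper depends on YOLOv3's output layout, the same procedure carries over to other differentiable detectors by swapping the indexing logic inside `yolo_wrapper`.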
![Overview of the proposed method](data/whole_figure.png)
5 | 6 | ## Performance 7 | 8 | #### Visualization 9 |
![Example SHAP attribution for a true-positive detection](data/vis_tp.png)
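Attribution maps like this can be rendered with Captum's visualization utilities and then reduced to the in-box / out-of-box scores used for evaluation, SHAP regularization, and data selection. The snippet below is a condensed sketch of what `test.ipynb` and `bb_mask` in `models/shap_loss_1.py` compute (vectorized instead of per-pixel loops), assuming `attr` and `img` are the attribution and input tensors obtained as in the notebook and `bboxes` is a list of ground-truth boxes in `[x1, y1, x2, y2]` pixel coordinates.

```python
import numpy as np
from captum.attr import visualization as viz

# Convert to (H, W, C) arrays for Captum's plotting utilities.
attr_np = np.transpose(attr.squeeze().cpu().detach().numpy(), (1, 2, 0))
orig_np = np.transpose(img.squeeze().cpu().detach().numpy(), (1, 2, 0))

# Positive and negative contributions as separate heat maps.
viz.visualize_image_attr(attr_np, orig_np, method="heat_map", sign="positive",
                         cmap="Reds", show_colorbar=True, fig_size=(8, 6))
viz.visualize_image_attr(attr_np, orig_np, method="heat_map", sign="negative",
                         cmap="Blues", show_colorbar=True, fig_size=(8, 6))

# In-box / out-of-box scores (a vectorized sketch of bb_mask in models/shap_loss_1.py):
# sum over channels, z-score across the image, then average positive and negative
# attributions separately inside and outside the ground-truth boxes.
pixel_attr = attr_np.sum(axis=2)
pixel_attr = (pixel_attr - pixel_attr.mean()) / pixel_attr.std()

mask = np.zeros(pixel_attr.shape, dtype=bool)
for x1, y1, x2, y2 in bboxes:
    mask[int(y1):int(y2), int(x1):int(x2)] = True

in_area, out_area = max(mask.sum(), 1), max((~mask).sum(), 1)
in_pos  = pixel_attr[mask & (pixel_attr >= 0)].sum() / in_area
in_neg  = pixel_attr[mask & (pixel_attr < 0)].sum() / in_area
out_pos = pixel_attr[~mask & (pixel_attr >= 0)].sum() / out_area
out_neg = pixel_attr[~mask & (pixel_attr < 0)].sum() / out_area
```

During training, `shaploss` in `models/shap_loss_1.py` combines these terms as `alpha * (-in_neg) + beta * out_pos` over the images that have a true-positive detection (rescaled for images without one) and adds the result to the standard YOLOv3 loss in `train_vd_reg.py`.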
10 | 11 | Please see the paper for details on the results of the evaluation, regularization, and data selection methods. 12 | 13 | ## Installation 14 | #### Requirements 15 | 16 | - Python 3.6.3+ 17 | - Numpy 18 | - OpenCV 19 | - Matplotlib 20 | - Pytorch 1.2+ 21 | - Cython 22 | - Cuda (verified as operable: v10.2) 23 | - Captum (verified as operable: v0.4.1) 24 | 25 | optional: 26 | - tensorboard 27 | - [tensorboardX](https://github.com/lanpa/tensorboardX) 28 | - CuDNN 29 | 30 | #### Download the original YOLOv3 weights 31 | download the pretrained file from the author's project page: 32 | 33 | ```bash 34 | $ mkdir weights 35 | $ cd weights/ 36 | $ bash ../requirements/download_weights.sh 37 | ``` 38 | 39 | ## Usage 40 | 41 | Please see the test.ipynb 42 | 43 | 44 | ## Paper 45 | ### SHAP-based Methods for Interpretable Object Detection in Satellite Imagery 46 | _Hiroki Kawauchi, Takashi Fuse_
47 | 48 | [[Paper]](https://www.mdpi.com/2072-4292/14/9/1970) [[Original Implementation]](https://github.com/hiroki-kawauchi/SHAPObjectDetection.git) 49 | 50 | 51 | -------------------------------------------------------------------------------- /config/config.cfg: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: YOLOv3 3 | BACKBONE: darknet53 4 | ANCHORS: [[10, 13], [16, 30], [33, 23], 5 | [30, 61], [62, 45], [59, 119], 6 | [116, 90], [156, 198], [373, 326]] 7 | ANCH_MASK: [[6, 7, 8], [3, 4, 5], [0, 1, 2]] 8 | N_CLASSES: 1 9 | TRAIN: 10 | TRAIN_DIR: ../train 11 | VAL_DIR: ../val 12 | LR: 0.001 13 | MOMENTUM: 0.9 14 | DECAY: 0.0005 15 | BURN_IN: 100 16 | MAXITER: 3000 17 | STEPS: (1200, 1500) 18 | BATCHSIZE: 4 19 | SUBDIVISION: 16 20 | IMGSIZE: 416 21 | LOSSTYPE: l2 22 | IGNORETHRE: 0.7 23 | ATTENTION_ALPHA: 0.0 24 | ATTENTION_BETA: 1.0 25 | N_SAMPLES: 1 26 | AUGMENTATION: 27 | RANDRESIZE: False 28 | JITTER: 0 29 | RANDOM_PLACING: True 30 | HUE: 0.1 31 | SATURATION: 1.5 32 | EXPOSURE: 1.5 33 | LRFLIP: True 34 | RANDOM_DISTORT: True 35 | TEST: 36 | CONFTHRE: 0.50 37 | NMSTHRE: 0.50 38 | IMGSIZE: 416 39 | NUM_GPUS: 1 40 | -------------------------------------------------------------------------------- /data/vis_tp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hiroki-kawauchi/SHAPObjectDetection/5ea90f743227ac22fa07fbeb9b0e1ea821e5defa/data/vis_tp.png -------------------------------------------------------------------------------- /data/whole_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hiroki-kawauchi/SHAPObjectDetection/5ea90f743227ac22fa07fbeb9b0e1ea821e5defa/data/whole_figure.png -------------------------------------------------------------------------------- /dataset/__pycache__/dataset_vd.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hiroki-kawauchi/SHAPObjectDetection/5ea90f743227ac22fa07fbeb9b0e1ea821e5defa/dataset/__pycache__/dataset_vd.cpython-36.pyc -------------------------------------------------------------------------------- /dataset/cocodataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | import torch 5 | from torch.utils.data import Dataset 6 | import cv2 7 | from pycocotools.coco import COCO 8 | 9 | from utils.utils import * 10 | 11 | 12 | class COCODataset(Dataset): 13 | """ 14 | COCO dataset class. 15 | """ 16 | def __init__(self, model_type, data_dir='COCO', json_file='instances_train2017.json', 17 | name='train2017', img_size=416, 18 | augmentation=None, min_size=1, debug=False): 19 | """ 20 | COCO dataset initialization. Annotation data are read into memory by COCO API. 21 | Args: 22 | model_type (str): model name specified in config file 23 | data_dir (str): dataset root directory 24 | json_file (str): COCO json file name 25 | name (str): COCO data name (e.g. 
'train2017' or 'val2017') 26 | img_size (int): target image size after pre-processing 27 | min_size (int): bounding boxes smaller than this are ignored 28 | debug (bool): if True, only one data id is selected from the dataset 29 | """ 30 | self.data_dir = data_dir 31 | self.json_file = json_file 32 | self.model_type = model_type 33 | self.coco = COCO(self.data_dir+'annotations/'+self.json_file) 34 | self.ids = self.coco.getImgIds() 35 | if debug: 36 | self.ids = self.ids[1:2] 37 | print("debug mode...", self.ids) 38 | self.class_ids = sorted(self.coco.getCatIds()) 39 | self.name = name 40 | self.max_labels = 50 41 | self.img_size = img_size 42 | self.min_size = min_size 43 | self.lrflip = augmentation['LRFLIP'] 44 | self.jitter = augmentation['JITTER'] 45 | self.random_placing = augmentation['RANDOM_PLACING'] 46 | self.hue = augmentation['HUE'] 47 | self.saturation = augmentation['SATURATION'] 48 | self.exposure = augmentation['EXPOSURE'] 49 | self.random_distort = augmentation['RANDOM_DISTORT'] 50 | 51 | 52 | def __len__(self): 53 | return len(self.ids) 54 | 55 | def __getitem__(self, index): 56 | """ 57 | One image / label pair for the given index is picked up \ 58 | and pre-processed. 59 | Args: 60 | index (int): data index 61 | Returns: 62 | img (numpy.ndarray): pre-processed image 63 | padded_labels (torch.Tensor): pre-processed label data. \ 64 | The shape is :math:`[self.max_labels, 5]`. \ 65 | each label consists of [class, xc, yc, w, h]: 66 | class (float): class index. 67 | xc, yc (float) : center of bbox whose values range from 0 to 1. 68 | w, h (float) : size of bbox whose values range from 0 to 1. 69 | info_img : tuple of h, w, nh, nw, dx, dy. 70 | h, w (int): original shape of the image 71 | nh, nw (int): shape of the resized image without padding 72 | dx, dy (int): pad size 73 | id_ (int): same as the input index. Used for evaluation. 
74 | """ 75 | id_ = self.ids[index] 76 | 77 | anno_ids = self.coco.getAnnIds(imgIds=[int(id_)], iscrowd=None) 78 | annotations = self.coco.loadAnns(anno_ids) 79 | 80 | lrflip = False 81 | if np.random.rand() > 0.5 and self.lrflip == True: 82 | lrflip = True 83 | 84 | # load image and preprocess 85 | img_file = os.path.join(self.data_dir, self.name, 86 | '{:012}'.format(id_) + '.jpg') 87 | img = cv2.imread(img_file) 88 | 89 | if self.json_file == 'instances_val5k.json' and img is None: 90 | img_file = os.path.join(self.data_dir, 'train2017', 91 | '{:012}'.format(id_) + '.jpg') 92 | img = cv2.imread(img_file) 93 | assert img is not None 94 | 95 | img, info_img = preprocess(img, self.img_size, jitter=self.jitter, 96 | random_placing=self.random_placing) 97 | 98 | if self.random_distort: 99 | img = random_distort(img, self.hue, self.saturation, self.exposure) 100 | 101 | img = np.transpose(img / 255., (2, 0, 1)) 102 | 103 | if lrflip: 104 | img = np.flip(img, axis=2).copy() 105 | 106 | # load labels 107 | labels = [] 108 | for anno in annotations: 109 | if anno['bbox'][2] > self.min_size and anno['bbox'][3] > self.min_size: 110 | labels.append([]) 111 | labels[-1].append(self.class_ids.index(anno['category_id'])) 112 | labels[-1].extend(anno['bbox']) 113 | 114 | padded_labels = np.zeros((self.max_labels, 5)) 115 | if len(labels) > 0: 116 | labels = np.stack(labels) 117 | if 'YOLO' in self.model_type: 118 | labels = label2yolobox(labels, info_img, self.img_size, lrflip) 119 | padded_labels[range(len(labels))[:self.max_labels] 120 | ] = labels[:self.max_labels] 121 | padded_labels = torch.from_numpy(padded_labels) 122 | 123 | return img, padded_labels, info_img, id_ 124 | -------------------------------------------------------------------------------- /dataset/dataset_vd.py: -------------------------------------------------------------------------------- 1 | # reference: 2 | # https://github.com/eriklindernoren/PyTorch-YOLOv3/blob/master/utils/datasets.py 3 | import json 4 | import glob 5 | import random 6 | import os 7 | import sys 8 | import numpy as np 9 | import cv2 10 | import torch 11 | import torch.nn.functional as F 12 | 13 | 14 | from torch.utils.data import Dataset 15 | import torchvision.transforms as transforms 16 | 17 | from utils.utils import * 18 | 19 | 20 | def load_classes(path): 21 | """ 22 | Loads class labels at 'path' 23 | """ 24 | fp = open(path, "r") 25 | names = fp.read().split("\n")[:-1] 26 | return names 27 | ''' 28 | def pad_to_square(img, pad_value): 29 | c, h, w = img.shape 30 | dim_diff = np.abs(h - w) 31 | # (upper / left) padding and (lower / right) padding 32 | pad1, pad2 = dim_diff // 2, dim_diff - dim_diff // 2 33 | # Determine padding 34 | pad = (0, 0, pad1, pad2) if h <= w else (pad1, pad2, 0, 0) 35 | # Add padding 36 | img = F.pad(img, pad, "constant", value=pad_value) 37 | 38 | return img, pad 39 | 40 | 41 | def resize(image, size): 42 | image = F.interpolate(image.unsqueeze(0), size=size, mode="nearest").squeeze(0) 43 | return image 44 | 45 | 46 | def random_resize(images, min_size=288, max_size=448): 47 | new_size = random.sample(list(range(min_size, max_size + 1, 32)), 1)[0] 48 | images = F.interpolate(images, size=new_size, mode="nearest") 49 | return images 50 | 51 | 52 | class ImageFolder(Dataset): 53 | def __init__(self, folder_path, img_size=416): 54 | self.files = sorted(glob.glob("%s/*.*" % folder_path)) 55 | self.img_size = img_size 56 | 57 | def __getitem__(self, index): 58 | img_path = self.files[index % len(self.files)] 59 | # Extract image as 
PyTorch tensor 60 | img = transforms.ToTensor()(Image.open(img_path)) 61 | # Pad to square resolution 62 | img, _ = pad_to_square(img, 0) 63 | # Resize 64 | img = resize(img, self.img_size) 65 | 66 | return img_path, img 67 | 68 | def __len__(self): 69 | return len(self.files) 70 | ''' 71 | 72 | class ListDataset(Dataset): 73 | def __init__(self, model_type, data_dir, json_file='anno_data.json', 74 | img_size=416, 75 | augmentation=None,min_size=1): 76 | """ 77 | Vehicle detection dataset initialization. Annotation data are read into memory by COCO API. 78 | Args: 79 | model_type (str): model name specified in config file 80 | list_path (str): dataset list textfile path 81 | img_size (int): target image size after pre-processing 82 | min_size (int): bounding boxes smaller than this are ignored 83 | """ 84 | self.model_type = model_type 85 | self.img_size = img_size 86 | self.max_labels = 100 87 | self.min_size = min_size 88 | self.data_dir = data_dir 89 | self.json_file = json_file 90 | self.lrflip = augmentation['LRFLIP'] 91 | self.jitter = augmentation['JITTER'] 92 | self.random_placing = augmentation['RANDOM_PLACING'] 93 | self.hue = augmentation['HUE'] 94 | self.saturation = augmentation['SATURATION'] 95 | self.exposure = augmentation['EXPOSURE'] 96 | self.random_distort = augmentation['RANDOM_DISTORT'] 97 | 98 | self.img_list = glob.glob(os.path.join(self.data_dir, '*.jpg')) 99 | self.img_list.extend(glob.glob(os.path.join(self.data_dir,'*.png'))) 100 | 101 | def __getitem__(self, index): 102 | """ 103 | One image / label pair for the given index is picked up \ 104 | and pre-processed. 105 | Args: 106 | index (int): data index 107 | Returns: 108 | img (numpy.ndarray): pre-processed image 109 | padded_labels (torch.Tensor): pre-processed label data. \ 110 | The shape is :math:`[self.max_labels, 5]`. \ 111 | each label consists of [class, xc, yc, w, h]: 112 | class (float): class index. 113 | xc, yc (float) : center of bbox whose values range from 0 to 1. 114 | w, h (float) : size of bbox whose values range from 0 to 1. 115 | info_img : tuple of h, w, nh, nw, dx, dy. 116 | h, w (int): original shape of the image 117 | nh, nw (int): shape of the resized image without padding 118 | dx, dy (int): pad size 119 | id_ (int): same as the input index. Used for evaluation. 
120 | """ 121 | 122 | # load image and preprocess 123 | img_path = self.img_list[index % len(self.img_list)] 124 | img = cv2.imread(img_path) 125 | assert img is not None 126 | img, info_img = preprocess(img, self.img_size, jitter=self.jitter, 127 | random_placing=self.random_placing) 128 | if self.random_distort: 129 | img = random_distort(img, self.hue, self.saturation, self.exposure) 130 | 131 | img = np.transpose(img / 255., (2, 0, 1)) 132 | 133 | lrflip = False 134 | if np.random.rand() > 0.5 and self.lrflip == True: 135 | lrflip = True 136 | 137 | if lrflip: 138 | img = np.flip(img, axis=2).copy() 139 | 140 | # load labels 141 | json_open = open(os.path.join(self.data_dir, self.json_file), 'r') 142 | json_load = json.load(json_open) 143 | annotations = json_load[os.path.basename(img_path)]['regions'] 144 | 145 | labels = [] 146 | for anno in annotations: 147 | if anno['bb'][2] > self.min_size and anno['bb'][3] > self.min_size: 148 | labels.append([]) 149 | labels[-1].append(anno['class_id']) 150 | labels[-1].extend(anno['bb']) 151 | 152 | padded_labels = np.zeros((self.max_labels, 5)) 153 | if len(labels) > 0: 154 | labels = np.stack(labels).astype(np.float64) 155 | if 'YOLO' in self.model_type: 156 | labels = label2yolobox(labels, info_img, self.img_size, lrflip) 157 | padded_labels[range(len(labels))[:self.max_labels] 158 | ] = labels[:self.max_labels] 159 | padded_labels = torch.from_numpy(padded_labels) 160 | 161 | return img, padded_labels, info_img, index 162 | 163 | 164 | def __len__(self): 165 | return len(self.img_list) -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:9.0-cudnn7-runtime-ubuntu16.04 2 | 3 | RUN apt-get update && apt-get install -y --no-install-recommends \ 4 | sudo \ 5 | git \ 6 | zip \ 7 | libopencv-dev \ 8 | build-essential libssl-dev libbz2-dev libreadline-dev libsqlite3-dev curl \ 9 | wget && \ 10 | rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* 11 | 12 | ARG UID 13 | RUN useradd docker -l -u $UID -G sudo -s /bin/bash -m 14 | RUN echo 'Defaults visiblepw' >> /etc/sudoers 15 | RUN echo 'docker ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers 16 | 17 | USER docker 18 | 19 | ENV PYENV_ROOT /home/docker/.pyenv 20 | ENV PATH $PYENV_ROOT/shims:$PYENV_ROOT/bin:$PATH 21 | RUN curl -L https://raw.githubusercontent.com/yyuu/pyenv-installer/master/bin/pyenv-installer | bash 22 | 23 | ENV PYTHON_VERSION 3.6.8 24 | RUN pyenv install ${PYTHON_VERSION} && pyenv global ${PYTHON_VERSION} 25 | 26 | RUN pip install -U pip setuptools 27 | # for pycocotools 28 | RUN pip install Cython==0.29.1 numpy==1.15.4 29 | 30 | COPY requirements/requirements.txt /tmp/requirements.txt 31 | RUN pip install -r /tmp/requirements.txt 32 | 33 | # mount YOLOv3-in-PyTorch to /work 34 | WORKDIR /work 35 | 36 | ENTRYPOINT ["/bin/bash"] 37 | -------------------------------------------------------------------------------- /models/shap_loss_1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from models.yolov3_shap import * 4 | from utils.utils import * 5 | import sys 6 | sys.path.append('/home/linuxserver01/packages') 7 | from captum.attr import GradientShap 8 | from captum.attr import visualization as viz 9 | 10 | import numpy as np 11 | 12 | def zscore(x, axis = None): 13 | xmean = x.mean(axis=axis, keepdims=True) 14 | xstd = np.std(x, axis=axis, 
keepdims=True) 15 | zscore = (x-xmean)/xstd 16 | return zscore 17 | 18 | def bb_mask(bboxes, attr, img_H=608, img_W=608, std=True): 19 | pixel_attr = np.sum(attr, axis=2) 20 | if std: 21 | pixel_attr = zscore(pixel_attr) 22 | 23 | pixel_mask = np.zeros((img_H,img_W)) 24 | in_area = 0 25 | for [x1,y1,x2,y2] in bboxes: 26 | for i in range(int(y1), int(y2)): 27 | for j in range(int(x1),int(x2)): 28 | pixel_mask[i,j] = 1 29 | in_area += 1 30 | out_area = img_H*img_W - in_area 31 | 32 | in_pos = 0 33 | in_neg = 0 34 | out_pos = 0 35 | out_neg = 0 36 | for i in range(img_H): 37 | for j in range(img_W): 38 | if pixel_mask[i,j]>0: 39 | if pixel_attr[i,j]>=0: 40 | in_pos += pixel_attr[i,j] 41 | else: 42 | in_neg += pixel_attr[i,j] 43 | else: 44 | if pixel_attr[i,j]>=0: 45 | out_pos += pixel_attr[i,j] 46 | else: 47 | out_neg += pixel_attr[i,j] 48 | if in_area >0: 49 | in_pos = in_pos/in_area 50 | in_neg = in_neg/in_area 51 | if out_area>0: 52 | out_pos = out_pos/out_area 53 | out_neg = out_neg/out_area 54 | 55 | return in_pos, in_neg, out_pos, out_neg 56 | 57 | 58 | def shaploss(imgs, labels, model, num_classes, confthre, nmsthre, stdevs=0.1, target_y='cls', 59 | multiply_by_inputs=True, n_samples=5, alpha=1.0, beta=1.0): 60 | """ 61 | Calculating SHAP-loss 62 | Args: 63 | imgs (torch.Tensor) : input data whose shape is :math:`(N, C, H, W)`, \ 64 | where N, C are batchsize and num. of channels. 65 | labels (torch.Tensor) : label array whose shape is :math:`(N, 100, 5)` 66 | each label consists of [class, xc, yc, w, h]: 67 | class (float): class index. 68 | xc, yc (float) : center of bbox whose values range from 0 to 1. 69 | w, h (float) : size of bbox whose values range from 0 to 1. 70 | info_imgs : 71 | info_img : tuple of h, w, nh, nw, dx, dy. 72 | h, w (int): original shape of the image 73 | nh, nw (int): shape of the resized image without padding 74 | dx, dy (int): pad size 75 | Returns: 76 | loss (torch.Tensor) 77 | """ 78 | # Find TP output 79 | model.eval() 80 | model(imgs) 81 | with torch.no_grad(): 82 | outputss = model(imgs) 83 | outputss = postprocess(outputss, num_classes, confthre, nmsthre) 84 | 85 | labels = labels.cpu().data 86 | nlabel = (labels.sum(dim=2) > 0).sum(dim=1)# numbers of gt-objects in each image 87 | 88 | gt_bboxes_tp = [None for _ in range(len(nlabel))] 89 | output_id_tp = [None for _ in range(len(nlabel))] 90 | for outputs, label, i in zip(outputss, labels, range(len(nlabel))): 91 | gt_bboxes = [] 92 | yolo_bboxes = [] 93 | # gt bbox 94 | H = imgs.size(2) 95 | W = imgs.size(3) 96 | for j in range(nlabel[i]): 97 | _, xc, yc, w, h = label[j] # [class, xc, yc, w, h] 98 | xyxy_label = [(xc-w/2)*W, (yc-h/2)*H, (xc+w/2)*W, (yc+h/2)*H] # [x1,y1,x2,y2] 99 | gt_bboxes.append([i,xyxy_label]) 100 | 101 | if outputss[0] is not None: 102 | outputs = outputss[0].cpu().data 103 | for output in outputs: 104 | x1 = output[0].data.item() 105 | y1 = float(output[1].data.item()) 106 | x2 = float(output[2].data.item()) 107 | y2 = float(output[3].data.item()) 108 | score = float(output[4].data.item() * output[5].data.item()) 109 | label = int(output[6].data.item()) 110 | yolo_bbox = [i,[x1,y1,x2,y2], score, label, False, 111 | int(output[7].data.item()),int(output[8].data.item()), int(output[9].data.item()), int(output[10].data.item())]#layer_num, anchor_num, x,y 112 | yolo_bboxes.append(yolo_bbox) 113 | 114 | # judge TP or FP 115 | # score sort 116 | yolo_bboxes = sorted(yolo_bboxes, key=lambda x: x[2]) 117 | 118 | for k in range(len(yolo_bboxes)): 119 | a = None 120 | t = 0 121 | for 
gt_bbox in gt_bboxes: 122 | iou = np_bboxes_iou(np.array(yolo_bboxes[k][1]), np.array(gt_bbox[1]).reshape(1,4)) 123 | if iou > max(0.5, t): 124 | a = gt_bbox 125 | t = iou 126 | if a != None: 127 | gt_bboxes_tp[i] = a[1] 128 | gt_bboxes.remove(a) 129 | yolo_bboxes[k][4] = True 130 | output_id_tp[i] = [yolo_bboxes[k][5], yolo_bboxes[k][6], yolo_bboxes[k][7], yolo_bboxes[k][8]] 131 | break 132 | 133 | def yolo_wrapper(inp, output_id): 134 | layer_num, anchor_num, x, y = output_id 135 | output = model(inp, shap=True) 136 | return output[layer_num][:,anchor_num,y,x] 137 | if target_y == 'obj': 138 | target_y = 4 139 | elif target_y == 'cls': 140 | target_y = 5 141 | 142 | num_no_tp = 0 143 | inside_bb_sum = 0 144 | outside_bb_sum = 0 145 | 146 | for gt_bbox, output_id, img in zip(gt_bboxes_tp, output_id_tp, imgs): 147 | if gt_bbox == None: 148 | num_no_tp += 1 149 | continue 150 | 151 | img = img.reshape(1,3,img.size(1),img.size(2)) 152 | 153 | baselines = img * 0 154 | with torch.no_grad(): 155 | gs = GradientShap(yolo_wrapper, multiply_by_inputs=multiply_by_inputs) 156 | attr = gs.attribute(img, additional_forward_args=output_id, 157 | n_samples=n_samples, stdevs=stdevs, 158 | baselines=baselines, target=target_y, 159 | return_convergence_delta=False) 160 | attr = np.transpose(attr.squeeze().cpu().detach().numpy(), (1,2,0)) 161 | in_pos, in_neg, out_pos, out_neg = bb_mask([gt_bbox], attr, img_H=imgs.size(2), img_W=imgs.size(3)) 162 | inside_bb_sum -= in_neg 163 | outside_bb_sum += out_pos 164 | loss = alpha * inside_bb_sum + beta * outside_bb_sum 165 | if len(imgs)>num_no_tp: 166 | loss = loss*len(imgs)/(len(imgs) - num_no_tp) 167 | return torch.tensor(loss, device="cuda:0",dtype=torch.float) -------------------------------------------------------------------------------- /models/yolo_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from utils.utils import bboxes_iou 5 | 6 | 7 | class YOLOLayer(nn.Module): 8 | """ 9 | detection layer corresponding to yolo_layer.c of darknet 10 | """ 11 | def __init__(self, config_model, layer_no, in_ch, ignore_thre=0.7): 12 | """ 13 | Args: 14 | config_model (dict) : model configuration. 15 | ANCHORS (list of tuples) : 16 | ANCH_MASK: (list of int list): index indicating the anchors to be 17 | used in YOLO layers. One of the mask group is picked from the list. 18 | N_CLASSES (int): number of classes 19 | layer_no (int): YOLO layer number - one from (0, 1, 2). 20 | in_ch (int): number of input channels. 21 | ignore_thre (float): threshold of IoU above which objectness training is ignored. 
22 | """ 23 | 24 | super(YOLOLayer, self).__init__() 25 | strides = [32, 16, 8] # fixed 26 | self.anchors = config_model['ANCHORS'] 27 | self.anch_mask = config_model['ANCH_MASK'][layer_no] 28 | self.n_anchors = len(self.anch_mask) 29 | self.n_classes = config_model['N_CLASSES'] 30 | self.ignore_thre = ignore_thre 31 | self.l2_loss = nn.MSELoss(size_average=False) 32 | self.bce_loss = nn.BCELoss(size_average=False) 33 | self.stride = strides[layer_no] 34 | self.all_anchors_grid = [(w / self.stride, h / self.stride) 35 | for w, h in self.anchors] 36 | self.masked_anchors = [self.all_anchors_grid[i] 37 | for i in self.anch_mask] 38 | self.ref_anchors = np.zeros((len(self.all_anchors_grid), 4)) 39 | self.ref_anchors[:, 2:] = np.array(self.all_anchors_grid) 40 | self.ref_anchors = torch.FloatTensor(self.ref_anchors) 41 | self.conv = nn.Conv2d(in_channels=in_ch, 42 | out_channels=self.n_anchors * (self.n_classes + 5), 43 | kernel_size=1, stride=1, padding=0) 44 | self.layer_no = layer_no 45 | def forward(self, xin, labels=None, shap=False): 46 | """ 47 | In this 48 | Args: 49 | xin (torch.Tensor): input feature map whose size is :math:`(N, C, H, W)`, \ 50 | where N, C, H, W denote batchsize, channel width, height, width respectively. 51 | labels (torch.Tensor): label data whose size is :math:`(N, K, 5)`. \ 52 | N and K denote batchsize and number of labels. 53 | Each label consists of [class, xc, yc, w, h]: 54 | class (float): class index. 55 | xc, yc (float) : center of bbox whose values range from 0 to 1. 56 | w, h (float) : size of bbox whose values range from 0 to 1. 57 | Returns: 58 | loss (torch.Tensor): total loss - the target of backprop. 59 | loss_xy (torch.Tensor): x, y loss - calculated by binary cross entropy (BCE) \ 60 | with boxsize-dependent weights. 61 | loss_wh (torch.Tensor): w, h loss - calculated by l2 without size averaging and \ 62 | with boxsize-dependent weights. 63 | loss_obj (torch.Tensor): objectness loss - calculated by BCE. 64 | loss_cls (torch.Tensor): classification loss - calculated by BCE for each class. 65 | loss_l2 (torch.Tensor): total l2 loss - only for logging. 
66 | """ 67 | output = self.conv(xin) 68 | 69 | batchsize = output.shape[0] 70 | fsize = output.shape[2] 71 | n_ch = 5 + self.n_classes 72 | dtype = torch.cuda.FloatTensor if xin.is_cuda else torch.FloatTensor 73 | 74 | output = output.reshape(batchsize, self.n_anchors, n_ch, fsize, fsize) 75 | output = output.permute(0, 1, 3, 4, 2) # .contiguous() 76 | 77 | # logistic activation for xy, obj, cls 78 | output[..., np.r_[:2, 4:n_ch]] = torch.sigmoid( 79 | output[..., np.r_[:2, 4:n_ch]]) 80 | 81 | if shap: 82 | return output 83 | 84 | # calculate pred - xywh obj cls 85 | 86 | to_xshift = np.broadcast_to( 87 | np.arange(fsize, dtype=np.float32), output.shape[:4]) 88 | to_yshift = np.broadcast_to( 89 | np.arange(fsize, dtype=np.float32).reshape(fsize, 1), output.shape[:4]) 90 | to_xshift.flags.writeable = True 91 | to_yshift.flags.writeable = True 92 | 93 | x_shift = dtype(to_xshift) 94 | y_shift = dtype(to_yshift) 95 | 96 | 97 | masked_anchors = np.array(self.masked_anchors) 98 | 99 | to_wanchors = np.broadcast_to(np.reshape( 100 | masked_anchors[:, 0], (1, self.n_anchors, 1, 1)), output.shape[:4]) 101 | to_hanchors = np.broadcast_to(np.reshape( 102 | masked_anchors[:, 1], (1, self.n_anchors, 1, 1)), output.shape[:4]) 103 | to_wanchors.flags.writeable = True 104 | to_hanchors.flags.writeable = True 105 | 106 | w_anchors = dtype(to_wanchors) 107 | h_anchors = dtype(to_hanchors) 108 | 109 | if labels is None: # inference mode 110 | anchor_num = np.broadcast_to( 111 | np.arange(self.n_anchors, dtype=np.float32).reshape(1, self.n_anchors, 1, 1), output.shape[:4]) 112 | anchor_num.flags.writeable = True 113 | anchor_num = dtype(anchor_num) 114 | 115 | output = torch.cat([output,torch.full((batchsize, self.n_anchors, fsize, fsize,1), 116 | fill_value=float(self.layer_no), device=torch.device('cuda'))],dim=4) 117 | output = torch.cat([output,anchor_num.reshape(batchsize, self.n_anchors, fsize, fsize,1)],dim=4) 118 | output = torch.cat([output,x_shift.reshape(batchsize, self.n_anchors, fsize, fsize,1)],dim=4) 119 | output = torch.cat([output,y_shift.reshape(batchsize, self.n_anchors, fsize, fsize,1)],dim=4) 120 | 121 | pred = output.clone() 122 | pred[..., 0] += x_shift 123 | pred[..., 1] += y_shift 124 | pred[..., 2] = torch.exp(pred[..., 2]) * w_anchors 125 | pred[..., 3] = torch.exp(pred[..., 3]) * h_anchors 126 | 127 | if labels is None: # inference mode 128 | pred[..., :4] *= self.stride 129 | return pred.reshape(batchsize, -1, n_ch+4).data 130 | 131 | pred = pred[..., :4].data 132 | 133 | # target assignment 134 | 135 | tgt_mask = torch.zeros(batchsize, self.n_anchors, 136 | fsize, fsize, 4 + self.n_classes).type(dtype) 137 | obj_mask = torch.ones(batchsize, self.n_anchors, 138 | fsize, fsize).type(dtype) 139 | tgt_scale = torch.zeros(batchsize, self.n_anchors, 140 | fsize, fsize, 2).type(dtype) 141 | 142 | target = torch.zeros(batchsize, self.n_anchors, 143 | fsize, fsize, n_ch).type(dtype) 144 | 145 | labels = labels.cpu().data 146 | nlabel = (labels.sum(dim=2) > 0).sum(dim=1) # number of objects 147 | 148 | truth_x_all = labels[:, :, 1] * fsize 149 | truth_y_all = labels[:, :, 2] * fsize 150 | truth_w_all = labels[:, :, 3] * fsize 151 | truth_h_all = labels[:, :, 4] * fsize 152 | truth_i_all = truth_x_all.to(torch.int16).numpy() 153 | truth_j_all = truth_y_all.to(torch.int16).numpy() 154 | 155 | for b in range(batchsize): 156 | n = int(nlabel[b]) 157 | if n == 0: 158 | continue 159 | truth_box = dtype(np.zeros((n, 4))) 160 | truth_box[:n, 2] = truth_w_all[b, :n] 161 | truth_box[:n, 3] = 
truth_h_all[b, :n] 162 | truth_i = truth_i_all[b, :n] 163 | truth_j = truth_j_all[b, :n] 164 | 165 | # calculate iou between truth and reference anchors 166 | anchor_ious_all = bboxes_iou(truth_box.cpu(), self.ref_anchors) 167 | best_n_all = np.argmax(anchor_ious_all, axis=1) 168 | best_n = best_n_all % 3 169 | best_n_mask = ((best_n_all == self.anch_mask[0]) | ( 170 | best_n_all == self.anch_mask[1]) | (best_n_all == self.anch_mask[2])) 171 | 172 | truth_box[:n, 0] = truth_x_all[b, :n] 173 | truth_box[:n, 1] = truth_y_all[b, :n] 174 | 175 | pred_ious = bboxes_iou( 176 | pred[b].reshape(-1, 4), truth_box, xyxy=False) 177 | pred_best_iou, _ = pred_ious.max(dim=1) 178 | pred_best_iou = (pred_best_iou > self.ignore_thre) 179 | pred_best_iou = pred_best_iou.reshape(pred[b].shape[:3]) 180 | # set mask to zero (ignore) if pred matches truth 181 | obj_mask[b] = 1 - pred_best_iou.int() 182 | 183 | if sum(best_n_mask) == 0: 184 | continue 185 | 186 | for ti in range(best_n.shape[0]): 187 | if best_n_mask[ti] == 1: 188 | i, j = truth_i[ti], truth_j[ti] 189 | a = best_n[ti] 190 | obj_mask[b, a, j, i] = 1 191 | tgt_mask[b, a, j, i, :] = 1 192 | target[b, a, j, i, 0] = truth_x_all[b, ti] - \ 193 | truth_x_all[b, ti].to(torch.int16).to(torch.float) 194 | target[b, a, j, i, 1] = truth_y_all[b, ti] - \ 195 | truth_y_all[b, ti].to(torch.int16).to(torch.float) 196 | target[b, a, j, i, 2] = torch.log( 197 | truth_w_all[b, ti] / torch.Tensor(self.masked_anchors)[best_n[ti], 0] + 1e-16) 198 | target[b, a, j, i, 3] = torch.log( 199 | truth_h_all[b, ti] / torch.Tensor(self.masked_anchors)[best_n[ti], 1] + 1e-16) 200 | target[b, a, j, i, 4] = 1 201 | target[b, a, j, i, 5 + labels[b, ti, 202 | 0].to(torch.int16).numpy()] = 1 203 | tgt_scale[b, a, j, i, :] = torch.sqrt( 204 | 2 - truth_w_all[b, ti] * truth_h_all[b, ti] / fsize / fsize) 205 | 206 | # loss calculation 207 | 208 | output[..., 4] *= obj_mask 209 | output[..., np.r_[0:4, 5:n_ch]] *= tgt_mask 210 | output[..., 2:4] *= tgt_scale 211 | 212 | target[..., 4] *= obj_mask 213 | target[..., np.r_[0:4, 5:n_ch]] *= tgt_mask 214 | target[..., 2:4] *= tgt_scale 215 | 216 | bceloss = nn.BCELoss(weight=tgt_scale*tgt_scale, 217 | size_average=False) # weighted BCEloss 218 | loss_xy = bceloss(output[..., :2], target[..., :2]) 219 | loss_wh = self.l2_loss(output[..., 2:4], target[..., 2:4]) / 2 220 | loss_obj = self.bce_loss(output[..., 4], target[..., 4]) 221 | loss_cls = self.bce_loss(output[..., 5:], target[..., 5:]) 222 | loss_l2 = self.l2_loss(output, target) 223 | 224 | loss = loss_xy + loss_wh + loss_obj + loss_cls 225 | 226 | return loss, loss_xy, loss_wh, loss_obj, loss_cls, loss_l2 227 | -------------------------------------------------------------------------------- /models/yolov3_shap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | def add_conv(in_ch, out_ch, ksize, stride): 6 | """ 7 | Add a conv2d / batchnorm / leaky ReLU block. 8 | Args: 9 | in_ch (int): number of input channels of the convolution layer. 10 | out_ch (int): number of output channels of the convolution layer. 11 | ksize (int): kernel size of the convolution layer. 12 | stride (int): stride of the convolution layer. 13 | Returns: 14 | stage (Sequential) : Sequential layers composing a convolution block. 
15 | """ 16 | stage = nn.Sequential() 17 | pad = (ksize - 1) // 2 18 | stage.add_module('conv', nn.Conv2d(in_channels=in_ch, 19 | out_channels=out_ch, kernel_size=ksize, stride=stride, 20 | padding=pad, bias=False)) 21 | stage.add_module('batch_norm', nn.BatchNorm2d(out_ch)) 22 | stage.add_module('leaky', nn.LeakyReLU(0.1)) 23 | return stage 24 | 25 | 26 | class resblock(nn.Module): 27 | """ 28 | Sequential residual blocks each of which consists of \ 29 | two convolution layers. 30 | Args: 31 | ch (int): number of input and output channels. 32 | nblocks (int): number of residual blocks. 33 | shortcut (bool): if True, residual tensor addition is enabled. 34 | """ 35 | def __init__(self, ch, nblocks=1, shortcut=True): 36 | super().__init__() 37 | self.shortcut = shortcut 38 | self.module_list = nn.ModuleList() 39 | for i in range(nblocks): 40 | resblock_one = nn.ModuleList() 41 | resblock_one.append(add_conv(ch, ch//2, 1, 1)) 42 | resblock_one.append(add_conv(ch//2, ch, 3, 1)) 43 | self.module_list.append(resblock_one) 44 | 45 | def forward(self, x): 46 | for module in self.module_list: 47 | h = x 48 | for res in module: 49 | h = res(h) 50 | x = x + h if self.shortcut else h 51 | return x 52 | 53 | 54 | 55 | class YOLOLayer(nn.Module): 56 | """ 57 | Detection Layer 58 | """ 59 | def __init__(self, in_ch, n_anchors, n_classes): 60 | super(YOLOLayer, self).__init__() 61 | self.n_anchors = n_anchors 62 | self.n_classes = n_classes 63 | self.conv = nn.Conv2d(in_channels=in_ch, 64 | out_channels=self.n_anchors * (self.n_classes + 5), 65 | kernel_size=1, stride=1, padding=0) 66 | 67 | def forward(self, x, targets=None): 68 | output = self.conv(x) 69 | batchsize = output.shape[0] 70 | fsize = output.shape[2] 71 | n_ch = 5 + self.n_classes 72 | dtype = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor 73 | 74 | output = output.view(batchsize, self.n_anchors, n_ch, fsize, fsize) 75 | output = output.permute(0, 1, 3, 4, 2) # .contiguous() 76 | 77 | # logistic activation for xy, obj, cls 78 | output[..., np.r_[:2, 4:n_ch]] = torch.sigmoid( 79 | output[..., np.r_[:2, 4:n_ch]]) 80 | 81 | return output 82 | 83 | 84 | 85 | 86 | class YOLOv3SHAP(nn.Module): 87 | """ 88 | YOLOv3 model module for calculating SHAP 89 | """ 90 | def __init__(self, n_classes): 91 | """ 92 | Initialization of YOLOv3 class. 
93 | """ 94 | super(YOLOv3SHAP, self).__init__() 95 | self.n_classes = n_classes 96 | self.module_list = nn.ModuleList() 97 | # DarkNet53 98 | self.module_list.append(add_conv(in_ch=3, out_ch=32, ksize=3, stride=1)) 99 | self.module_list.append(add_conv(in_ch=32, out_ch=64, ksize=3, stride=2)) 100 | self.module_list.append(resblock(ch=64)) 101 | self.module_list.append(add_conv(in_ch=64, out_ch=128, ksize=3, stride=2)) 102 | self.module_list.append(resblock(ch=128, nblocks=2)) 103 | self.module_list.append(add_conv(in_ch=128, out_ch=256, ksize=3, stride=2)) 104 | self.module_list.append(resblock(ch=256, nblocks=8)) # shortcut 1 from here 105 | self.module_list.append(add_conv(in_ch=256, out_ch=512, ksize=3, stride=2)) 106 | self.module_list.append(resblock(ch=512, nblocks=8)) # shortcut 2 from here 107 | self.module_list.append(add_conv(in_ch=512, out_ch=1024, ksize=3, stride=2)) 108 | self.module_list.append(resblock(ch=1024, nblocks=4)) 109 | 110 | # YOLOv3 111 | self.module_list.append(resblock(ch=1024, nblocks=2, shortcut=False)) 112 | self.module_list.append(add_conv(in_ch=1024, out_ch=512, ksize=1, stride=1)) 113 | # 1st yolo branch 114 | self.module_list.append(add_conv(in_ch=512, out_ch=1024, ksize=3, stride=1)) 115 | self.module_list.append( 116 | YOLOLayer(in_ch=1024, n_anchors=3, n_classes=self.n_classes)) 117 | 118 | self.module_list.append(add_conv(in_ch=512, out_ch=256, ksize=1, stride=1)) 119 | self.module_list.append(nn.Upsample(scale_factor=2, mode='nearest')) 120 | self.module_list.append(add_conv(in_ch=768, out_ch=256, ksize=1, stride=1)) 121 | self.module_list.append(add_conv(in_ch=256, out_ch=512, ksize=3, stride=1)) 122 | self.module_list.append(resblock(ch=512, nblocks=1, shortcut=False)) 123 | self.module_list.append(add_conv(in_ch=512, out_ch=256, ksize=1, stride=1)) 124 | # 2nd yolo branch 125 | self.module_list.append(add_conv(in_ch=256, out_ch=512, ksize=3, stride=1)) 126 | self.module_list.append( 127 | YOLOLayer(in_ch=512, n_anchors=3, n_classes=self.n_classes)) 128 | self.module_list.append(add_conv(in_ch=256, out_ch=128, ksize=1, stride=1)) 129 | self.module_list.append(nn.Upsample(scale_factor=2, mode='nearest')) 130 | self.module_list.append(add_conv(in_ch=384, out_ch=128, ksize=1, stride=1)) 131 | self.module_list.append(add_conv(in_ch=128, out_ch=256, ksize=3, stride=1)) 132 | self.module_list.append(resblock(ch=256, nblocks=2, shortcut=False)) 133 | self.module_list.append( 134 | YOLOLayer(in_ch=256, n_anchors=3, n_classes=self.n_classes)) 135 | 136 | def forward(self, x): 137 | """ 138 | Forward path of YOLOv3. 139 | Args: 140 | x (torch.Tensor) : input data whose shape is :math:`(N, C, H, W)`, \ 141 | where N, C are batchsize and num. of channels. 142 | targets (torch.Tensor) : label array whose shape is :math:`(N, 50, 5)` 143 | 144 | Returns: 145 | training: 146 | output (torch.Tensor): loss tensor for backpropagation. 147 | test: 148 | output (torch.Tensor): concatenated detection results. 
149 | """ 150 | output = [] 151 | route_layers = [] 152 | for i, module in enumerate(self.module_list): 153 | # yolo layers 154 | if i in [14, 22, 28]: 155 | x = module(x) 156 | output.append(x) 157 | else: 158 | x = module(x) 159 | 160 | # route layers = shortcut 161 | if i in [6, 8, 12, 20]: 162 | route_layers.append(x) 163 | if i == 14: 164 | x = route_layers[2] 165 | if i == 22: # yolo 2nd 166 | x = route_layers[3] 167 | if i == 16: 168 | x = torch.cat((x, route_layers[1]), 1) 169 | if i == 24: 170 | x = torch.cat((x, route_layers[0]), 1) 171 | 172 | return output -------------------------------------------------------------------------------- /requirements/download_weights.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget https://pjreddie.com/media/files/yolov3.weights 4 | wget https://pjreddie.com/media/files/darknet53.conv.74 -------------------------------------------------------------------------------- /requirements/getcoco.sh: -------------------------------------------------------------------------------- 1 | mkdir COCO 2 | cd COCO 3 | 4 | wget http://images.cocodataset.org/zips/train2017.zip 5 | wget http://images.cocodataset.org/zips/val2017.zip 6 | wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip 7 | 8 | unzip train2017.zip 9 | unzip val2017.zip 10 | unzip annotations_trainval2017.zip 11 | 12 | rm -f train2017.zip 13 | rm -f val2017.zip 14 | rm -f annotations_trainval2017.zip 15 | -------------------------------------------------------------------------------- /requirements/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.0.0 2 | numpy==1.15.2 3 | matplotlib==3.0.2 4 | opencv_python==3.4.4.19 5 | tensorboardX==1.4 6 | PyYAML>=4.2b1 7 | pycocotools==2.0.0 8 | 9 | -------------------------------------------------------------------------------- /test.ipynb: -------------------------------------------------------------------------------- 1 | {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"test.ipynb","provenance":[],"collapsed_sections":[],"authorship_tag":"ABX9TyMXZB5s0UXMxzx0femqD6Fc"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","source":["import yaml\n","\n","import cv2\n","import torch\n","from torch.autograd import Variable\n","\n","from models.yolov3 import *\n","from utils.utils import *\n","from utils.parse_yolo_weights import parse_yolo_weights\n","\n","from captum.attr import GradientShap\n","\n","import os\n","import glob\n","\n","from utils.vis_bbox import vis_bbox\n","import matplotlib.pyplot as plt\n","\n","import json"],"metadata":{"id":"fmuZNzVD-U2-"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Visualization\n","You need to input the image, the information about your model, and the output index."],"metadata":{"id":"Yk1vsi8Bz7Sl"}},{"cell_type":"code","source":["cfg = '' #config file path\n","gpu = 0\n","image = '' #image file path\n","baselines = None\n","target_y == 'cls'\n","output_id = [0,0,0,0] #output index (layer anchor, x, y)\n","weights_path = ''\n","ckpt = ''\n","\n","with open(cfg, 'r') as f:\n"," cfg = yaml.load(f)\n","imgsize = cfg['TEST']['IMGSIZE']\n","model = YOLOv3(cfg['MODEL'])\n","num_classes = cfg['MODEL']['N_CLASSES']\n","\n","confthre = cfg['TEST']['CONFTHRE']\n","nmsthre = cfg['TEST']['NMSTHRE']\n","if gpu >= 0:\n"," model.cuda(gpu) \n","\n","\n","assert weights_path or 
ckpt, 'One of --weights_path and --ckpt must be specified'\n","\n","if weights_path:\n"," print(\"loading yolo weights %s\" % (weights_path))\n"," parse_yolo_weights(model, weights_path)\n","elif ckpt:\n"," print(\"loading checkpoint %s\" % (ckpt))\n"," state = torch.load(ckpt)\n"," if 'model_state_dict' in state.keys():\n"," model.load_state_dict(state['model_state_dict'])\n"," else:\n"," model.load_state_dict(state)\n","\n","model.eval()\n","\n","img = cv2.imread(image_path)\n","img_raw = img.copy()[:, :, ::-1].transpose((2, 0, 1))\n","img, info_img = preprocess(img, imgsize, jitter=0) # info = (h, w, nh, nw, dx, dy)\n","img = np.transpose(img / 255., (2, 0, 1))\n","img = torch.from_numpy(img).float().unsqueeze(0)\n","\n","if gpu >= 0:\n"," img = Variable(img.type(torch.cuda.FloatTensor))\n","else:\n"," img = Variable(img.type(torch.FloatTensor))\n","\n","if baselines==None:\n"," baselines = img * 0 \n"," \n","if gpu >= 0:\n"," img = Variable(img.type(torch.cuda.FloatTensor))\n"," baselines = Variable(baselines.type(torch.cuda.FloatTensor))\n","else:\n"," img = Variable(img.type(torch.FloatTensor))\n"," baselines = Variable(baselines.type(torch.FloatTensor))\n","\n","if target_y == 'obj':\n"," target_y = 4\n","elif target_y == 'cls':\n"," target_y = 5\n","\n","\n","# setting wrapper\n","def yolo_wrapper(inp, output_id):\n"," layer_num, anchor_num, x, y = output_id\n"," output = model(inp, shap=True)\n"," return output[layer_num][:,anchor_num,y,x]\n"," \n","with torch.no_grad():\n"," gs = GradientShap(yolo_wrapper,multiply_by_inputs=multiply_by_inputs)\n"," \n","with torch.no_grad():\n"," attr, delta = gs.attribute(img, additional_forward_args=output_id, n_samples=n_samples, stdevs=stdevs, baselines=baselines, target=target_y, return_convergence_delta=True)\n","# postprocessing of attribution\n","attr = np.transpose(attr.squeeze().cpu().detach().numpy(), (1,2,0))\n","original_image = np.transpose(img.squeeze().cpu().detach().numpy(), (1,2,0))\n","# visualization of attribution\n","pos_fig, pos_axis = viz.visualize_image_attr(attr,\n"," original_image,\n"," \"heat_map\",\n"," \"positive\",\n"," cmap=\"Reds\",\n"," show_colorbar=True,\n"," fig_size=(8,6))\n","neg_fig, neg_axis = viz.visualize_image_attr(attr,\n"," original_image,\n"," \"heat_map\",\n"," \"positive\",\n"," cmap=\"Blues\",\n"," show_colorbar=True,\n"," fig_size=(8,6))\n"],"metadata":{"id":"wcXejiMB0BS5"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Evaluation/Data Selection with SHAP"],"metadata":{"id":"Wyti2QWG4Tdv"}},{"cell_type":"code","source":["cfg = '' #config file path\n","gpu = 0\n","image = '' #image file path\n","baselines = None\n","target_y == 'cls'\n","output_id = [0,0,0,0] #output index (layer anchor, x, y)\n","weights_path = ''\n","ckpt = ''\n","bboxes = [] #bounding box index list\n","\n","with open(cfg, 'r') as f:\n"," cfg = yaml.load(f)\n","imgsize = cfg['TEST']['IMGSIZE']\n","model = YOLOv3(cfg['MODEL'])\n","num_classes = cfg['MODEL']['N_CLASSES']\n","\n","confthre = cfg['TEST']['CONFTHRE']\n","nmsthre = cfg['TEST']['NMSTHRE']\n","if gpu >= 0:\n"," model.cuda(gpu) \n","\n","\n","assert weights_path or ckpt, 'One of --weights_path and --ckpt must be specified'\n","\n","if weights_path:\n"," print(\"loading yolo weights %s\" % (weights_path))\n"," parse_yolo_weights(model, weights_path)\n","elif ckpt:\n"," print(\"loading checkpoint %s\" % (ckpt))\n"," state = torch.load(ckpt)\n"," if 'model_state_dict' in state.keys():\n"," model.load_state_dict(state['model_state_dict'])\n"," 
else:\n"," model.load_state_dict(state)\n","\n","model.eval()\n","\n","img = cv2.imread(image_path)\n","img_raw = img.copy()[:, :, ::-1].transpose((2, 0, 1))\n","img, info_img = preprocess(img, imgsize, jitter=0) # info = (h, w, nh, nw, dx, dy)\n","img = np.transpose(img / 255., (2, 0, 1))\n","img = torch.from_numpy(img).float().unsqueeze(0)\n","\n","if gpu >= 0:\n"," img = Variable(img.type(torch.cuda.FloatTensor))\n","else:\n"," img = Variable(img.type(torch.FloatTensor))\n","\n","if baselines==None:\n"," baselines = img * 0 \n"," \n","if gpu >= 0:\n"," img = Variable(img.type(torch.cuda.FloatTensor))\n"," baselines = Variable(baselines.type(torch.cuda.FloatTensor))\n","else:\n"," img = Variable(img.type(torch.FloatTensor))\n"," baselines = Variable(baselines.type(torch.FloatTensor))\n","\n","if target_y == 'obj':\n"," target_y = 4\n","elif target_y == 'cls':\n"," target_y = 5\n","\n","\n","# setting wrapper\n","def yolo_wrapper(inp, output_id):\n"," layer_num, anchor_num, x, y = output_id\n"," output = model(inp, shap=True)\n"," return output[layer_num][:,anchor_num,y,x]\n"," \n","with torch.no_grad():\n"," gs = GradientShap(yolo_wrapper,multiply_by_inputs=multiply_by_inputs)\n"," \n","with torch.no_grad():\n"," attr, delta = gs.attribute(img, additional_forward_args=output_id, n_samples=n_samples, stdevs=stdevs, baselines=baselines, target=target_y, return_convergence_delta=True)\n","\n","img_H = 416\n","img_W = 416\n","\n","def zscore(x, axis = None):\n"," xmean = x.mean(axis=axis, keepdims=True)\n"," xstd = np.std(x, axis=axis, keepdims=True)\n"," zscore = (x-xmean)/xstd\n"," return zscore\n","\n","pixel_attr = np.sum(attr, axis=2)\n","pixel_attr = zscore(pixel_attr)\n","pixel_mask = np.zeros((img_H,img_W))\n","in_area = 0\n","for [x1,y1,x2,y2] in bboxes:\n"," for i in range(int(y1), int(y2)):\n"," for j in range(int(x1),int(x2)):\n"," pixel_mask[i,j] = 1\n"," in_area += 1\n","out_area = img_H*img_W - in_area\n","\n","in_pos = 0\n","in_neg = 0\n","out_pos = 0\n","out_neg = 0\n","l_in = []\n","for i in range(img_H):\n"," for j in range(img_W):\n"," if pixel_mask[i,j]>0:\n"," if pixel_attr[i,j]>=0:\n"," in_pos += pixel_attr[i,j]\n"," l_in.append(pixel_attr[i,j])\n"," else:\n"," in_neg += pixel_attr[i,j]\n"," \n"," else:\n"," if pixel_attr[i,j]>=0:\n"," out_pos += pixel_attr[i,j]\n"," else:\n"," out_neg += pixel_attr[i,j]\n"," \n","if in_area >0:\n"," in_pos = in_pos/in_area\n"," in_neg = in_neg/in_area\n","if out_area>0:\n"," out_pos = out_pos/out_area\n"," out_neg = out_neg/out_area\n"],"metadata":{"id":"uQiX1hTj4SYU"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Training with SHAP-Regularization\n","We used annotation json files which contain dict format data like\n","* {\"image_file_name\":\n"," * {\"regions\":\n"," * [ {\"class_id\": 0, \"bb\":[0,0,0,0]},..."],"metadata":{"id":"u8Sblu4L4icq"}},{"cell_type":"code","source":["!python train_vd_reg.py --cfg config/config.cfg --weights_path weights/darknet53.conv.74 --checkpoint_interval 100 --checkpoint_dir checkpoints --anno_file anno_data.json --shap_interval 10 --eval_interval 10"],"metadata":{"id":"QE-XT6a-4kkb"},"execution_count":null,"outputs":[]}]} -------------------------------------------------------------------------------- /train_vd_reg.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | from utils.utils import * 4 | from utils.vd_evaluator import VDEvaluator 5 | from utils.parse_yolo_weights import parse_yolo_weights 
6 | from models.yolov3 import * 7 | from models.shap_loss_1 import * 8 | from dataset.dataset_vd import * 9 | 10 | import os 11 | import argparse 12 | import yaml 13 | import random 14 | 15 | import torch 16 | from torch.autograd import Variable 17 | import torch.optim as optim 18 | 19 | import numpy as np 20 | import pandas as pd 21 | 22 | import time 23 | 24 | 25 | def parse_args(): 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--cfg', type=str, default='config/yolov3_vd.cfg', 28 | help='config file. see readme') 29 | parser.add_argument('--weights_path', type=str, 30 | default=None, help='darknet weights file') 31 | parser.add_argument('--n_cpu', type=int, default=0, 32 | help='number of workers') 33 | parser.add_argument('--checkpoint_interval', type=int, 34 | default=1000, help='interval between saving checkpoints') 35 | parser.add_argument('--eval_interval', type=int, 36 | default=4000, help='interval between evaluations') 37 | parser.add_argument('--checkpoint', type=str, 38 | help='pytorch checkpoint file path') 39 | parser.add_argument('--checkpoint_dir', type=str, 40 | default='checkpoints', 41 | help='directory where checkpoint files are saved') 42 | parser.add_argument('--use_cuda', type=bool, default=True) 43 | parser.add_argument('--debug', action='store_true', default=False, 44 | help='debug mode where only one image is trained') 45 | parser.add_argument( 46 | '--tfboard', help='tensorboard path for logging', type=str, default=None) 47 | parser.add_argument('--anno_file', type=str, 48 | default='anno_data.json', help='annotation data json file name') 49 | parser.add_argument('--shap_interval', type=int, 50 | default=None, help='interval between updating shaploss') 51 | return parser.parse_args() 52 | 53 | 54 | def main(): 55 | """ 56 | SHAP-regularized YOLOv3 trainer. 57 | 58 | """ 59 | args = parse_args() 60 | print("Setting Arguments.. 
: ", args) 61 | 62 | cuda = torch.cuda.is_available() and args.use_cuda 63 | os.makedirs(args.checkpoint_dir, exist_ok=True) 64 | 65 | # Parse config settings 66 | with open(args.cfg, 'r') as f: 67 | cfg = yaml.load(f) 68 | 69 | print("successfully loaded config file: ", cfg) 70 | 71 | momentum = cfg['TRAIN']['MOMENTUM'] 72 | decay = cfg['TRAIN']['DECAY'] 73 | burn_in = cfg['TRAIN']['BURN_IN'] 74 | iter_size = cfg['TRAIN']['MAXITER'] 75 | steps = eval(cfg['TRAIN']['STEPS']) 76 | batch_size = cfg['TRAIN']['BATCHSIZE'] 77 | subdivision = cfg['TRAIN']['SUBDIVISION'] 78 | ignore_thre = cfg['TRAIN']['IGNORETHRE'] 79 | random_resize = cfg['AUGMENTATION']['RANDRESIZE'] 80 | base_lr = cfg['TRAIN']['LR'] / batch_size / subdivision 81 | at_alpha = cfg['TRAIN']['ATTENTION_ALPHA'] 82 | at_beta = cfg['TRAIN']['ATTENTION_BETA'] 83 | 84 | print('effective_batch_size = batch_size * iter_size = %d * %d' % 85 | (batch_size, subdivision)) 86 | 87 | # Learning rate setup 88 | def burnin_schedule(i): 89 | if i < burn_in: 90 | factor = pow(i / burn_in, 4)# pow(x, y):x^y 91 | elif i < steps[0]: 92 | factor = 1.0 93 | elif i < steps[1]: 94 | factor = 0.1 95 | else: 96 | factor = 0.01 97 | return factor 98 | 99 | # Initiate model 100 | model = YOLOv3(cfg['MODEL'], ignore_thre=ignore_thre) 101 | 102 | if args.weights_path: 103 | print("loading darknet weights....", args.weights_path) 104 | parse_yolo_weights(model, args.weights_path) 105 | elif args.checkpoint: 106 | print("loading pytorch ckpt...", args.checkpoint) 107 | state = torch.load(args.checkpoint) 108 | if 'model_state_dict' in state.keys(): 109 | model.load_state_dict(state['model_state_dict']) 110 | else: 111 | model.load_state_dict(state) 112 | 113 | if cuda: 114 | print("using cuda") 115 | model = model.cuda() 116 | 117 | if args.tfboard: 118 | print("using tfboard") 119 | from tensorboardX import SummaryWriter 120 | tblogger = SummaryWriter(args.tfboard) 121 | 122 | model.train() 123 | 124 | imgsize = cfg['TRAIN']['IMGSIZE'] 125 | 126 | dataset = ListDataset(model_type=cfg['MODEL']['TYPE'], 127 | data_dir=cfg['TRAIN']['TRAIN_DIR'], 128 | json_file=args.anno_file, 129 | img_size=imgsize, 130 | augmentation=cfg['AUGMENTATION']) 131 | 132 | dataloader = torch.utils.data.DataLoader( 133 | dataset, 134 | batch_size=batch_size, 135 | shuffle=True, 136 | num_workers=args.n_cpu) 137 | 138 | dataiterator = iter(dataloader) 139 | 140 | evaluator = VDEvaluator(data_dir=cfg['TRAIN']['VAL_DIR'], 141 | json_file=args.anno_file, 142 | img_size=cfg['TEST']['IMGSIZE'], 143 | confthre=cfg['TEST']['CONFTHRE'], 144 | nmsthre=cfg['TEST']['NMSTHRE']) 145 | 146 | dtype = torch.cuda.FloatTensor if cuda else torch.FloatTensor 147 | 148 | # optimizer setup 149 | # set weight decay only on conv.weight 150 | params_dict = dict(model.named_parameters()) 151 | params = [] 152 | for key, value in params_dict.items(): 153 | if 'conv.weight' in key: 154 | params += [{'params':value, 'weight_decay':decay * batch_size * subdivision}] 155 | else: 156 | params += [{'params':value, 'weight_decay':0.0}] 157 | optimizer = optim.SGD(params, lr=base_lr, momentum=momentum, 158 | dampening=0, weight_decay=decay * batch_size * subdivision) 159 | 160 | iter_state = 0 161 | 162 | if args.checkpoint: 163 | if 'optimizer_state_dict' in state.keys(): 164 | optimizer.load_state_dict(state['optimizer_state_dict']) 165 | iter_state = state['iter'] + 1 166 | 167 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, burnin_schedule) 168 | 169 | # start training loop 170 | log_col = ['time(min)', 
'iter','lr','xy', 'wh', 171 | 'conf', 'cls', 'shap', 'l2', 'imgsize', 172 | 'ap50', 'precision50', 'recall50', 'F_measure'] 173 | log = [] 174 | ap50 = np.nan 175 | precision50 = np.nan 176 | recall50 = np.nan 177 | F_measure = np.nan 178 | shap_loss = torch.tensor(float('nan'),dtype=torch.float32) 179 | t_0 = time.time() 180 | for iter_i in range(iter_state, iter_size + 1): 181 | 182 | # VD evaluation 183 | 184 | if iter_i % args.eval_interval == 0 and iter_i > 0: 185 | ap50, precision50, recall50, F_measure = evaluator.evaluate(model) 186 | model.train() 187 | if args.tfboard: 188 | tblogger.add_scalar('val/COCOAP50', ap50, iter_i) 189 | 190 | print('[Iter {}/{}]:AP50:{}'.format(iter_i, iter_size,ap50)) 191 | 192 | # subdivision loop 193 | optimizer.zero_grad() 194 | for inner_iter_i in range(subdivision): 195 | try: 196 | imgs, targets, _, _ = next(dataiterator) # load a batch 197 | except StopIteration: 198 | dataiterator = iter(dataloader) 199 | imgs, targets, _, _ = next(dataiterator) # load a batch 200 | imgs = Variable(imgs.type(dtype)) 201 | targets = Variable(targets.type(dtype), requires_grad=False) 202 | loss = model(imgs, targets) 203 | loss_dict = model.loss_dict 204 | # adding SHAP-based loss 205 | if args.shap_interval is not None: 206 | if inner_iter_i % args.shap_interval == 0: 207 | shap_loss_ = shaploss(imgs, targets, model, 208 | num_classes=cfg['MODEL']['N_CLASSES'], 209 | confthre=cfg['TEST']['CONFTHRE'], 210 | nmsthre=cfg['TEST']['NMSTHRE'], 211 | n_samples=cfg['TRAIN']['N_SAMPLES'], 212 | alpha=at_alpha, beta=at_beta) 213 | if shap_loss_ != 0 and shap_loss != torch.tensor(float('nan'),dtype=torch.float32): 214 | shap_loss = shap_loss_ 215 | 216 | model.train() 217 | loss += shap_loss 218 | loss.backward() 219 | 220 | optimizer.step() 221 | scheduler.step() 222 | 223 | if iter_i % 10 == 0: 224 | # logging 225 | current_lr = scheduler.get_lr()[0] * batch_size * subdivision 226 | t = (time.time() - t_0)//60 227 | print('[Time %d] [Iter %d/%d] [lr %f] ' 228 | '[Losses: xy %f, wh %f, conf %f, cls %f, att %f, total %f, imgsize %d, ap %f, precision %f, recall %f, F %f]' 229 | % (t, iter_i, iter_size, current_lr, 230 | loss_dict['xy'], loss_dict['wh'], 231 | loss_dict['conf'], loss_dict['cls'], shap_loss, 232 | loss_dict['l2'], imgsize, ap50, precision50, recall50, F_measure), 233 | flush=True) 234 | log.append([t, iter_i, current_lr, 235 | np.atleast_1d(loss_dict['xy'].to('cpu').detach().numpy().copy())[0], 236 | np.atleast_1d(loss_dict['wh'].to('cpu').detach().numpy().copy())[0], 237 | np.atleast_1d(loss_dict['conf'].to('cpu').detach().numpy().copy())[0], 238 | np.atleast_1d(loss_dict['cls'].to('cpu').detach().numpy().copy())[0], 239 | np.atleast_1d(shap_loss.to('cpu').detach().numpy().copy())[0], 240 | np.atleast_1d(loss_dict['l2'].to('cpu').detach().numpy().copy())[0], 241 | imgsize, ap50, precision50, recall50, F_measure]) 242 | ap50 = np.nan 243 | precision50 = np.nan 244 | recall50 = np.nan 245 | F_measure = np.nan 246 | 247 | if args.tfboard: 248 | tblogger.add_scalar('train/total_loss', model.loss_dict['l2'], iter_i) 249 | 250 | # random resizing 251 | if random_resize: 252 | imgsize = (random.randint(0, 9) % 10 + 10) * 32 253 | dataset.img_shape = (imgsize, imgsize) 254 | dataset.img_size = imgsize 255 | dataloader = torch.utils.data.DataLoader( 256 | dataset, batch_size=batch_size, shuffle=True, num_workers=args.n_cpu) 257 | dataiterator = iter(dataloader) 258 | 259 | # save checkpoint 260 | #if iter_i > 0 and (iter_i % args.checkpoint_interval == 0): 261 | if 
(0 0: 96 | cocoGt = self.dataset.coco 97 | # workaround: temporarily write data to json file because pycocotools can't process dict in py36. 98 | _, tmp = tempfile.mkstemp() 99 | json.dump(data_dict, open(tmp, 'w')) 100 | cocoDt = cocoGt.loadRes(tmp) 101 | cocoEval = COCOeval(self.dataset.coco, cocoDt, annType[1]) 102 | cocoEval.params.imgIds = ids 103 | cocoEval.evaluate() 104 | cocoEval.accumulate() 105 | cocoEval.summarize() 106 | return cocoEval.stats[0], cocoEval.stats[1] 107 | else: 108 | return 0, 0 109 | 110 | -------------------------------------------------------------------------------- /utils/parse_yolo_weights.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import division 3 | import torch 4 | import numpy as np 5 | 6 | 7 | def parse_conv_block(m, weights, offset, initflag): 8 | """ 9 | Initialization of conv layers with batchnorm 10 | Args: 11 | m (Sequential): sequence of layers 12 | weights (numpy.ndarray): pretrained weights data 13 | offset (int): current position in the weights file 14 | initflag (bool): if True, the layers are not covered by the weights file. \ 15 | They are initialized using darknet-style initialization. 16 | Returns: 17 | offset (int): current position in the weights file 18 | weights (numpy.ndarray): pretrained weights data 19 | """ 20 | conv_model = m[0] 21 | bn_model = m[1] 22 | param_length = m[1].bias.numel() 23 | 24 | # batchnorm 25 | for pname in ['bias', 'weight', 'running_mean', 'running_var']: 26 | layerparam = getattr(bn_model, pname) 27 | 28 | if initflag: # yolo initialization - scale to one, bias to zero 29 | if pname == 'weight': 30 | weights = np.append(weights, np.ones(param_length)) 31 | else: 32 | weights = np.append(weights, np.zeros(param_length)) 33 | 34 | param = torch.from_numpy(weights[offset:offset + param_length]).view_as(layerparam) 35 | layerparam.data.copy_(param) 36 | offset += param_length 37 | 38 | param_length = conv_model.weight.numel() 39 | 40 | # conv 41 | if initflag: # yolo initialization 42 | n, c, k, _ = conv_model.weight.shape 43 | scale = np.sqrt(2 / (k * k * c)) 44 | weights = np.append(weights, scale * np.random.normal(size=param_length)) 45 | 46 | param = torch.from_numpy( 47 | weights[offset:offset + param_length]).view_as(conv_model.weight) 48 | conv_model.weight.data.copy_(param) 49 | offset += param_length 50 | 51 | return offset, weights 52 | 53 | def parse_yolo_block(m, weights, offset, initflag): 54 | """ 55 | YOLO Layer (one conv with bias) Initialization 56 | Args: 57 | m (Sequential): sequence of layers 58 | weights (numpy.ndarray): pretrained weights data 59 | offset (int): current position in the weights file 60 | initflag (bool): if True, the layers are not covered by the weights file. \ 61 | They are initialized using darknet-style initialization. 
62 | Returns: 63 | offset (int): current position in the weights file 64 | weights (numpy.ndarray): pretrained weights data 65 | """ 66 | conv_model = m._modules['conv'] 67 | param_length = conv_model.bias.numel() 68 | 69 | if initflag: # yolo initialization - bias to zero 70 | weights = np.append(weights, np.zeros(param_length)) 71 | 72 | param = torch.from_numpy( 73 | weights[offset:offset + param_length]).view_as(conv_model.bias) 74 | conv_model.bias.data.copy_(param) 75 | offset += param_length 76 | 77 | param_length = conv_model.weight.numel() 78 | 79 | if initflag: # yolo initialization 80 | n, c, k, _ = conv_model.weight.shape 81 | scale = np.sqrt(2 / (k * k * c)) 82 | weights = np.append(weights, scale * np.random.normal(size=param_length)) 83 | 84 | param = torch.from_numpy( 85 | weights[offset:offset + param_length]).view_as(conv_model.weight) 86 | conv_model.weight.data.copy_(param) 87 | offset += param_length 88 | 89 | return offset, weights 90 | 91 | def parse_yolo_weights(model, weights_path): 92 | """ 93 | Parse YOLO (darknet) pre-trained weights data onto the pytorch model 94 | Args: 95 | model : pytorch model object 96 | weights_path (str): path to the YOLO (darknet) pre-trained weights file 97 | """ 98 | fp = open(weights_path, "rb") 99 | 100 | # skip the header 101 | header = np.fromfile(fp, dtype=np.int32, count=5) # not used 102 | # read weights 103 | weights = np.fromfile(fp, dtype=np.float32) 104 | fp.close() 105 | 106 | offset = 0 107 | initflag = False #whole yolo weights : False, darknet weights : True 108 | 109 | for m in model.module_list: 110 | 111 | if m._get_name() == 'Sequential': 112 | # normal conv block 113 | offset, weights = parse_conv_block(m, weights, offset, initflag) 114 | 115 | elif m._get_name() == 'resblock': 116 | # residual block 117 | for modu in m._modules['module_list']: 118 | for blk in modu: 119 | offset, weights = parse_conv_block(blk, weights, offset, initflag) 120 | 121 | elif m._get_name() == 'YOLOLayer': 122 | # YOLO Layer (one conv with bias) Initialization 123 | offset, weights = parse_yolo_block(m, weights, offset, initflag) 124 | 125 | initflag = (offset >= len(weights)) # the end of the weights file. turn the flag on 126 | -------------------------------------------------------------------------------- /utils/resize_gt.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import glob 3 | import json 4 | import os 5 | import shutil 6 | 7 | def area(a, b): # returns None if rectangles don't intersect 8 | """ 9 | calculating intersection area 10 | Args: 11 | a, b(list) : [y1,x1,y2,x2] 12 | Returns: 13 | is_box: 重複したbox 14 | """ 15 | is_box = [max(a[0], b[0]), 16 | max(a[1], b[1]), 17 | min(a[2], b[2]), 18 | min(a[3], b[3])] # [y1,x1,y2,x2] 19 | dy = is_box[2] - is_box[0] 20 | dx = is_box[3] - is_box[1] 21 | 22 | if dx>0 and dy>0: 23 | return dx*dy, is_box 24 | else: 25 | return 0, [] 26 | 27 | def resize_gt(input_dir, imgsize): 28 | """ 29 | Resizing input image and anno_data.json. 30 | Resized images do not have padding. 
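    Illustrative example (made-up sizes, not from the original docs): a 1000 x 1500 image split with
    imgsize=416 gives div_h x div_w = 3 x 4 tiles, since each count is side//imgsize + min(side%imgsize, 1).
    Tiles spanning less than 100 px in either direction are skipped, and a ground-truth box is assigned
    to a tile only when at least half of the box area falls inside that tile.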
31 | Args: 32 | input_dir(str): path of GT-image dir 33 | imgsize (int): target image size after resizing(length of longer side) 34 | Returns: 35 | These returns are put to new dir named('input_dir name'_resized_'resized size') 36 | imgs(jpg): 37 | anno_data.json(json): 38 | """ 39 | img_list = glob.glob(os.path.join(input_dir, '*.jpg')) 40 | img_list.extend(glob.glob(os.path.join(input_dir,'*.png'))) 41 | 42 | json_open = open(os.path.join(input_dir, 'anno_data.json'), 'r') 43 | json_load = json.load(json_open) 44 | class_id = 0 45 | output_list = [] 46 | 47 | output_dir = input_dir+'_'+'resized'+'_'+str(imgsize) 48 | os.makedirs(output_dir, exist_ok=True) 49 | os.chdir(output_dir) 50 | 51 | for image in img_list: 52 | img = cv2.imread(image) 53 | h, w, _ = img.shape 54 | annotations = json_load[os.path.basename(image)]['regions'] 55 | bboxes = [] 56 | if len(annotations) > 0: 57 | for anno in annotations: 58 | box = [anno['bb'][1], anno['bb'][0], 59 | anno['bb'][1] + anno['bb'][3], anno['bb'][0] + anno['bb'][2]] 60 | # [y1,x1,y2,x2] 61 | bboxes.append(box) 62 | 63 | div_h = h//imgsize + min(h%imgsize,1) # 高さ方向の分割回数 64 | div_w = w//imgsize + min(w%imgsize,1) 65 | 66 | for i in range(div_h): 67 | for j in range(div_w): 68 | div_y1 = imgsize*i 69 | div_y2 = min(imgsize*(i+1)-1,h-1) 70 | div_x1 = imgsize*j 71 | div_x2 = min(imgsize*(j+1)-1,w-1) 72 | 73 | if (div_y2-div_y1)>=100 and (div_x2-div_x1)>=100: 74 | 75 | div_yxyx = [div_y1,div_x1,div_y2,div_x2] 76 | 77 | div_img = img[div_y1:div_y2+1, div_x1:div_x2+1] 78 | div_imagename = os.path.splitext(os.path.basename(image))[0] + '_'+str(i)+ '_'+str(j)+'.jpg' 79 | cv2.imwrite(div_imagename,div_img) 80 | 81 | l_regions = [] 82 | regions_dict = {} 83 | for box in bboxes: 84 | iou, is_box = area(div_yxyx,box) 85 | 86 | if iou/((box[3]-box[1])*(box[2]-box[0]))>=0.5: 87 | l_regions.append(dict((['class_id', class_id], 88 | ['bb', [is_box[1]-div_x1, is_box[0]-div_y1, is_box[3] - is_box[1], is_box[2] - is_box[0]]]))) 89 | 90 | 91 | regions_dict['regions']=l_regions 92 | output_list.append([div_imagename, regions_dict]) 93 | 94 | with open('anno_data.json', 'w') as f: 95 | json.dump(dict(output_list), f, ensure_ascii=False) 96 | 97 | return 98 | 99 | def remove_nogt(input_dir): 100 | """ 101 | Removing images without gt object. 102 | Outputs are in the directory of 'out_dir' 103 | Args: 104 | input_dir(str): path of GT-image dir (inculding no-object images) 105 | """ 106 | img_list = glob.glob(os.path.join(input_dir, '*.jpg')) 107 | img_list.extend(glob.glob(os.path.join(input_dir,'*.png'))) 108 | 109 | json_open = open(os.path.join(input_dir, 'anno_data.json'), 'r') 110 | json_load = json.load(json_open) 111 | 112 | output_dir = input_dir+'_'+'rm' 113 | os.makedirs(output_dir, exist_ok=True) 114 | 115 | for image in img_list: 116 | annotations = json_load[os.path.basename(image)]['regions'] 117 | if len(annotations) > 0: 118 | shutil.copyfile(image, os.path.join(output_dir, os.path.basename(image))) 119 | 120 | shutil.copyfile(os.path.join(input_dir, 'anno_data.json'),os.path.join(output_dir, 'anno_data.json')) 121 | 122 | return -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import numpy as np 4 | import cv2 5 | 6 | 7 | def nms(bbox, thresh, score=None, limit=None): 8 | """Suppress bounding boxes according to their IoUs and confidence scores. 
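    Example (an illustrative sketch; the boxes below are made-up corner pairs, written here as [y1, x1, y2, x2]):
        >>> bbox = np.array([[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]], dtype=np.float32)
        >>> score = np.array([0.9, 0.8, 0.7], dtype=np.float32)
        >>> nms(bbox, thresh=0.5, score=score)
        array([0, 2], dtype=int32)
    The second box is dropped because its IoU with the higher-scoring first box exceeds the threshold.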
9 | Args: 10 | bbox (array): Bounding boxes to be transformed. The shape is 11 | :math:`(R, 4)`. :math:`R` is the number of bounding boxes. 12 | thresh (float): Threshold of IoUs. 13 | score (array): An array of confidences whose shape is :math:`(R,)`. 14 | limit (int): The upper bound of the number of the output bounding 15 | boxes. If it is not specified, this method selects as many 16 | bounding boxes as possible. 17 | Returns: 18 | array: 19 | An array with indices of bounding boxes that are selected. \ 20 | They are sorted by the scores of bounding boxes in descending \ 21 | order. \ 22 | The shape of this array is :math:`(K,)` and its dtype is\ 23 | :obj:`numpy.int32`. Note that :math:`K \\leq R`. 24 | 25 | from: https://github.com/chainer/chainercv 26 | """ 27 | 28 | if len(bbox) == 0: 29 | return np.zeros((0,), dtype=np.int32) 30 | 31 | if score is not None: 32 | order = score.argsort()[::-1] 33 | bbox = bbox[order] 34 | bbox_area = np.prod(bbox[:, 2:] - bbox[:, :2], axis=1) + 1 35 | selec = np.zeros(bbox.shape[0], dtype=bool) 36 | for i, b in enumerate(bbox): 37 | tl = np.maximum(b[:2], bbox[selec, :2]) 38 | br = np.minimum(b[2:], bbox[selec, 2:]) 39 | area = np.prod(br - tl, axis=1) * (tl < br).all(axis=1) 40 | 41 | sum_area = bbox_area[i] + bbox_area[selec] - area + 1e-16 42 | iou = area / sum_area 43 | ''' 44 | iou = np.zeros_like(area) 45 | if sum_area>0: 46 | iou = area / sum_area 47 | ''' 48 | if (iou >= thresh).any(): 49 | continue 50 | 51 | selec[i] = True 52 | if limit is not None and np.count_nonzero(selec) >= limit: 53 | break 54 | 55 | selec = np.where(selec)[0] 56 | if score is not None: 57 | selec = order[selec] 58 | return selec.astype(np.int32) 59 | 60 | def postfilter(a): 61 | for i in range(len(a)): 62 | for j in range(len(a[i])): 63 | for k in range(len(a[i,j])): 64 | if a[i,j,k]>1 or a[i,j,k]<0: 65 | a[i,j,k] = torch.tensor([1.]) 66 | else: 67 | a[i,j,k] = a[i,j,k] 68 | return a 69 | 70 | 71 | def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45): 72 | """ 73 | Postprocess for the output of YOLO model 74 | perform box transformation, specify the class for each detection, 75 | and perform class-wise non-maximum suppression. 76 | Args: 77 | prediction (torch tensor): The shape is :math:`(N, B, 8)`. 78 | :math:`N` is the number of predictions, 79 | :math:`B` the number of boxes. The last axis consists of 80 | :math:`xc, yc, w, h` where `xc` and `yc` represent a center 81 | of a bounding box. 82 | num_classes (int): 83 | number of dataset classes. 84 | conf_thre (float): 85 | confidence threshold ranging from 0 to 1, 86 | which is defined in the config file. 87 | nms_thre (float): 88 | IoU threshold of non-max suppression ranging from 0 to 1. 
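        Note: each row of a returned per-image tensor is laid out as
        (x1, y1, x2, y2, obj_conf, class_conf, class_idx); class-wise NMS uses
        obj_conf * class_conf as the score, and entries stay None for images with
        no detection above the threshold.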
89 | 90 | Returns: 91 | output (list of torch tensor): 92 | 93 | """ 94 | box_corner = prediction.new(prediction.shape) 95 | box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2 96 | box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2 97 | box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2 98 | box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2 99 | prediction[:, :, :4] = box_corner[:, :, :4] 100 | 101 | output = [None for _ in range(len(prediction))] 102 | for i, image_pred in enumerate(prediction): 103 | # Filter out confidence scores below threshold 104 | class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1) 105 | class_pred = class_pred[0] 106 | conf_mask = (image_pred[:, 4] * class_pred >= conf_thre).squeeze() 107 | image_pred = image_pred[conf_mask] 108 | 109 | # If none are remaining => process next image 110 | if not image_pred.size(0): 111 | continue 112 | # Get detections with higher confidence scores than the threshold 113 | ind = (image_pred[:, 5:5 + num_classes] * image_pred[:, 4][:, None] >= conf_thre).nonzero(as_tuple=False) 114 | # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred) 115 | detections = torch.cat(( 116 | image_pred[ind[:, 0], :5], 117 | image_pred[ind[:, 0], 5 + ind[:, 1]].unsqueeze(1), 118 | ind[:, 1].float().unsqueeze(1), 119 | image_pred[ind[:, 0], 5 + num_classes:] 120 | ), 1) 121 | # Iterate through all predicted classes 122 | unique_labels = detections[:, -5].cpu().unique() 123 | if prediction.is_cuda: 124 | unique_labels = unique_labels.cuda() 125 | for c in unique_labels: 126 | # Get the detections with the particular class 127 | detections_class = detections[detections[:, -5] == c] 128 | nms_in = detections_class.cpu().numpy() 129 | nms_out_index = nms( 130 | nms_in[:, :4], nms_thre, score=nms_in[:, 4]*nms_in[:, 5]) 131 | detections_class = detections_class[nms_out_index] 132 | if output[i] is None: 133 | output[i] = detections_class 134 | else: 135 | output[i] = torch.cat((output[i], detections_class)) 136 | 137 | return output 138 | 139 | 140 | def bboxes_iou(bboxes_a, bboxes_b, xyxy=True): 141 | """Calculate the Intersection of Unions (IoUs) between bounding boxes. 142 | IoU is calculated as a ratio of area of the intersection 143 | and area of the union. 144 | 145 | Args: 146 | bbox_a (array): An array whose shape is :math:`(N, 4)`. 147 | :math:`N` is the number of bounding boxes. 148 | The dtype should be :obj:`numpy.float32`. 149 | bbox_b (array): An array similar to :obj:`bbox_a`, 150 | whose shape is :math:`(K, 4)`. 151 | The dtype should be :obj:`numpy.float32`. 152 | Returns: 153 | array: 154 | An array whose shape is :math:`(N, K)`. \ 155 | An element at index :math:`(n, k)` contains IoUs between \ 156 | :math:`n` th bounding box in :obj:`bbox_a` and :math:`k` th bounding \ 157 | box in :obj:`bbox_b`. 
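        Example (illustrative sketch with xyxy=True):
            >>> a = torch.tensor([[0., 0., 10., 10.]])
            >>> b = torch.tensor([[0., 0., 10., 10.], [5., 5., 15., 15.]])
            >>> bboxes_iou(a, b, xyxy=True)
            tensor([[1.0000, 0.1429]])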
158 | 159 | from: https://github.com/chainer/chainercv 160 | """ 161 | if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4: 162 | raise IndexError 163 | 164 | # top left 165 | if xyxy: 166 | tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2]) 167 | # bottom right 168 | br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:]) 169 | area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1) 170 | area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1) 171 | else: 172 | tl = torch.max((bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2), 173 | (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2)) 174 | # bottom right 175 | br = torch.min((bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2), 176 | (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2)) 177 | 178 | area_a = torch.prod(bboxes_a[:, 2:], 1) 179 | area_b = torch.prod(bboxes_b[:, 2:], 1) 180 | en = (tl < br).type(tl.type()).prod(dim=2) 181 | area_i = torch.prod(br - tl, 2) * en # * ((tl < br).all()) 182 | return area_i / (area_a[:, None] + area_b - area_i) 183 | 184 | 185 | def np_bboxes_iou(a_bbox, b_bboxes): 186 | """Calculate intersection over union (IOU). 187 | 188 | Args: 189 | a_bbox (array-like): 1-D Array with shape (4,) representing bounding box. 190 | b_bboxes (array-like): 2-D Array with shape (NumBoxes, 4) representing bounding boxes. 191 | 192 | Returns: 193 | numpy.ndarray: 1-D array with shape (NumBoxes,) containing the IoU of a_bbox with each box in b_bboxes. 194 | 195 | from 196 | https://github.com/nekobean/pascalvoc_metrics/blob/master/pascalvoc_metrics.py 197 | """ 198 | # Compute the intersection of a_bbox with each rectangle in b_bboxes. 199 | xmin = np.maximum(a_bbox[0], b_bboxes[:, 0]) 200 | ymin = np.maximum(a_bbox[1], b_bboxes[:, 1]) 201 | xmax = np.minimum(a_bbox[2], b_bboxes[:, 2]) 202 | ymax = np.minimum(a_bbox[3], b_bboxes[:, 3]) 203 | i_bboxes = np.column_stack([xmin, ymin, xmax, ymax]) 204 | 205 | # Compute the areas of the rectangles. 206 | a_area = calc_area(a_bbox) 207 | b_area = np.apply_along_axis(calc_area, 1, b_bboxes) 208 | i_area = np.apply_along_axis(calc_area, 1, i_bboxes) 209 | 210 | # Compute the IoU. 211 | iou = i_area / (a_area + b_area - i_area) 212 | 213 | return iou 214 | 215 | 216 | def calc_area(bbox): 217 | """Calculate area of bounding box. 218 | 219 | Args: 220 | bbox (array-like): 1-D Array with shape (4,) representing bounding box. 221 | 222 | Returns: 223 | float: Area of the bounding box. 224 | """ 225 | # Compute the area of the rectangle. 226 | # If the rectangles do not intersect, the width or height becomes negative; clamp it to 0 in that case. 227 | width = max(0, bbox[2] - bbox[0] + 1) 228 | height = max(0, bbox[3] - bbox[1] + 1) 229 | 230 | return width * height 231 | 232 | 233 | def label2yolobox(labels, info_img, maxsize, lrflip): 234 | """ 235 | Transform coco labels to yolo box labels 236 | Args: 237 | labels (numpy.ndarray): label data whose shape is :math:`(N, 5)`. 238 | Each label consists of [class, x, y, w, h] where \ 239 | class (float): class index. 240 | x, y, w, h (float) : coordinates of \ 241 | left-top points, width, and height of a bounding box. 242 | Values range from 0 to width or height of the image. 243 | info_img : tuple of h, w, nh, nw, dx, dy. 244 | h, w (int): original shape of the image 245 | nh, nw (int): shape of the resized image without padding 246 | dx, dy (int): pad size 247 | maxsize (int): target image size after pre-processing 248 | lrflip (bool): horizontal flip flag 249 | 250 | Returns: 251 | labels: label data whose size is :math:`(N, 5)`. 252 | Each label consists of [class, xc, yc, w, h] where 253 | class (float): class index. 254 | xc, yc (float) : center of bbox whose values range from 0 to 1. 255 | w, h (float) : size of bbox whose values range from 0 to 1.
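        Example (a hand-worked illustration): with info_img = (400, 600, 277, 416, 0, 69),
        maxsize = 416 and lrflip = False, the label [0, 150, 100, 300, 200]
        (class, x, y, w, h in original pixels) is mapped to approximately
        [0, 0.5, 0.499, 0.5, 0.333].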
256 | """ 257 | h, w, nh, nw, dx, dy = info_img 258 | x1 = labels[:, 1] / w 259 | y1 = labels[:, 2] / h 260 | x2 = (labels[:, 1] + labels[:, 3]) / w 261 | y2 = (labels[:, 2] + labels[:, 4]) / h 262 | labels[:, 1] = (((x1 + x2) / 2) * nw + dx) / maxsize 263 | labels[:, 2] = (((y1 + y2) / 2) * nh + dy) / maxsize 264 | labels[:, 3] *= nw / w / maxsize 265 | labels[:, 4] *= nh / h / maxsize 266 | if lrflip: 267 | labels[:, 1] = 1 - labels[:, 1] 268 | return labels 269 | 270 | 271 | def yolobox2label(box, info_img): 272 | """ 273 | Transform yolo box labels to yxyx box labels. 274 | Args: 275 | box (list): box data with the format of [yc, xc, w, h] 276 | in the coordinate system after pre-processing. 277 | info_img : tuple of h, w, nh, nw, dx, dy. 278 | h, w (int): original shape of the image 279 | nh, nw (int): shape of the resized image without padding 280 | dx, dy (int): pad size 281 | Returns: 282 | label (list): box data with the format of [y1, x1, y2, x2] 283 | in the coordinate system of the input image. 284 | """ 285 | h, w, nh, nw, dx, dy = info_img 286 | y1, x1, y2, x2 = box 287 | box_h = ((y2 - y1) / nh) * h 288 | box_w = ((x2 - x1) / nw) * w 289 | y1 = ((y1 - dy) / nh) * h 290 | x1 = ((x1 - dx) / nw) * w 291 | label = [y1, x1, y1 + box_h, x1 + box_w] 292 | return label 293 | 294 | 295 | def preprocess(img, imgsize, jitter, random_placing=False): 296 | """ 297 | Image preprocess for yolo input 298 | Pad the shorter side of the image and resize to (imgsize, imgsize) 299 | Args: 300 | img (numpy.ndarray): input image whose shape is :math:`(H, W, C)`. 301 | Values range from 0 to 255. 302 | imgsize (int): target image size after pre-processing 303 | jitter (float): amplitude of jitter for resizing 304 | random_placing (bool): if True, place the image at random position 305 | 306 | Returns: 307 | img (numpy.ndarray): input image whose shape is :math:`(C, imgsize, imgsize)`. 308 | Values range from 0 to 1. 309 | info_img : tuple of h, w, nh, nw, dx, dy. 310 | h, w (int): original shape of the image 311 | nh, nw (int): shape of the resized image without padding 312 | dx, dy (int): pad size 313 | """ 314 | h, w, _ = img.shape 315 | img = img[:, :, ::-1] 316 | assert img is not None 317 | 318 | if jitter > 0: 319 | # add jitter 320 | dw = jitter * w 321 | dh = jitter * h 322 | new_ar = (w + np.random.uniform(low=-dw, high=dw))\ 323 | / (h + np.random.uniform(low=-dh, high=dh)) 324 | else: 325 | new_ar = w / h 326 | 327 | if new_ar < 1: 328 | nh = imgsize 329 | nw = nh * new_ar 330 | else: 331 | nw = imgsize 332 | nh = nw / new_ar 333 | nw, nh = int(nw), int(nh) 334 | 335 | if random_placing: 336 | dx = int(np.random.uniform(imgsize - nw)) 337 | dy = int(np.random.uniform(imgsize - nh)) 338 | else: 339 | dx = (imgsize - nw) // 2 340 | dy = (imgsize - nh) // 2 341 | 342 | img = cv2.resize(img, (nw, nh)) 343 | sized = np.ones((imgsize, imgsize, 3), dtype=np.uint8) * 127 344 | sized[dy:dy+nh, dx:dx+nw, :] = img 345 | 346 | info_img = (h, w, nh, nw, dx, dy) 347 | return sized, info_img 348 | 349 | def rand_scale(s): 350 | """ 351 | calculate random scaling factor 352 | Args: 353 | s (float): range of the random scale. 354 | Returns: 355 | random scaling factor (float) whose range is 356 | from 1 / s to s . 357 | """ 358 | scale = np.random.uniform(low=1, high=s) 359 | if np.random.rand() > 0.5: 360 | return scale 361 | return 1 / scale 362 | 363 | def random_distort(img, hue, saturation, exposure): 364 | """ 365 | perform random distortion in the HSV color space. 
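    Concretely, the hue channel is shifted by a value drawn uniformly from [-hue, hue]
    (with wrap-around), while saturation and exposure (the HSV value channel) are each
    scaled by a random factor in [1/s, s] drawn via rand_scale.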
366 | Args: 367 | img (numpy.ndarray): input image whose shape is :math:`(H, W, C)`. 368 | Values range from 0 to 255. 369 | hue (float): random distortion parameter. 370 | saturation (float): random distortion parameter. 371 | exposure (float): random distortion parameter. 372 | Returns: 373 | img (numpy.ndarray) 374 | """ 375 | dhue = np.random.uniform(low=-hue, high=hue) 376 | dsat = rand_scale(saturation) 377 | dexp = rand_scale(exposure) 378 | 379 | img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) 380 | img = np.asarray(img, dtype=np.float32) / 255. 381 | img[:, :, 1] *= dsat 382 | img[:, :, 2] *= dexp 383 | H = img[:, :, 0] + dhue 384 | 385 | if dhue > 0: 386 | H[H > 1.0] -= 1.0 387 | else: 388 | H[H < 0.0] += 1.0 389 | 390 | img[:, :, 0] = H 391 | img = (img * 255).clip(0, 255).astype(np.uint8) 392 | img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB) 393 | img = np.asarray(img, dtype=np.float32) 394 | 395 | return img 396 | 397 | 398 | def get_coco_label_names(): 399 | """ 400 | COCO label names and correspondence between the model's class index and COCO class index. 401 | Returns: 402 | coco_label_names (tuple of str) : all the COCO label names including background class. 403 | coco_class_ids (list of int) : index of 80 classes that are used in 'instance' annotations 404 | coco_cls_colors (np.ndarray) : randomly generated color vectors used for box visualization 405 | 406 | """ 407 | coco_label_names = ('background', # class zero 408 | 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 409 | 'boat', 'traffic light', 'fire hydrant', 'street sign', 'stop sign', 410 | 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 411 | 'elephant', 'bear', 'zebra', 'giraffe', 'hat', 'backpack', 'umbrella', 412 | 'shoe', 'eye glasses', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 413 | 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 414 | 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'plate', 'wine glass', 415 | 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 416 | 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 417 | 'couch', 'potted plant', 'bed', 'mirror', 'dining table', 'window', 'desk', 418 | 'toilet', 'door', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 419 | 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'blender', 'book', 420 | 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush' 421 | ) 422 | coco_class_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 423 | 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 424 | 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 425 | 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90] 426 | 427 | coco_cls_colors = np.random.randint(128, 255, size=(80, 3)) 428 | 429 | return coco_label_names, coco_class_ids, coco_cls_colors 430 | 431 | def get_vd_label_names(): 432 | """ 433 | COCO label names and correspondence between the model's class index and COCO class index. 434 | Returns: 435 | coco_label_names (tuple of str) : all the COCO label names including background class. 
436 | coco_class_ids (list of int) : index of 80 classes that are used in 'instance' annotations 437 | coco_cls_colors (np.ndarray) : randomly generated color vectors used for box visualization 438 | 439 | """ 440 | label_names = ('vehicle',) # trailing comma keeps this a one-element tuple rather than a plain string 441 | class_ids = [0] 442 | 443 | cls_colors = np.random.randint(128, 255, size=(1, 3)) 444 | 445 | return label_names, class_ids, cls_colors 446 | 447 | def get_batch_statistics(outputs, targets, iou_threshold): 448 | """ Compute true positives, predicted scores and predicted labels per sample """ 449 | batch_metrics = [] 450 | for sample_i in range(len(outputs)): 451 | 452 | if outputs[sample_i] is None: 453 | continue 454 | 455 | output = outputs[sample_i] 456 | pred_boxes = output[:, :4] 457 | pred_scores = output[:, 4] 458 | pred_labels = output[:, -1] 459 | 460 | true_positives = np.zeros(pred_boxes.shape[0]) 461 | 462 | annotations = targets[targets[:, 0] == sample_i][:, 1:] 463 | target_labels = annotations[:, 0] if len(annotations) else [] 464 | if len(annotations): 465 | detected_boxes = [] 466 | target_boxes = annotations[:, 1:] 467 | 468 | for pred_i, (pred_box, pred_label) in enumerate(zip(pred_boxes, pred_labels)): 469 | 470 | # Stop early once every target has been matched 471 | if len(detected_boxes) == len(annotations): 472 | break 473 | 474 | # Ignore if label is not one of the target labels 475 | if pred_label not in target_labels: 476 | continue 477 | 478 | iou, box_index = bboxes_iou(pred_box.unsqueeze(0), target_boxes).max(0) 479 | if iou >= iou_threshold and box_index not in detected_boxes: 480 | true_positives[pred_i] = 1 481 | detected_boxes += [box_index] 482 | batch_metrics.append([true_positives, pred_scores, pred_labels]) 483 | return batch_metrics 484 | 485 | 486 | 487 | def resized_bbox(box, info_img): 488 | """ 489 | Transform gt-size bbox to resized bbox 490 | Args: 491 | box (list): [x1,y1,x2,y2] 492 | info_img : tuple of h, w, nh, nw, dx, dy.
493 | h, w (int): original shape of the image 494 | nh, nw (int): shape of the resized image without padding 495 | dx, dy (int): pad size 496 | maxsize (int): target image size after pre-processing 497 | 498 | Returns: 499 | r_box (list): [x1,y1,x2,y2](resized) 500 | """ 501 | h, w, nh, nw, dx, dy = info_img 502 | 503 | x_1 = (box[0] / w) * nw + dx 504 | y_1 = (box[1] / h) * nh + dy 505 | x_2 = (box[2] / w) * nw + dx 506 | y_2 = (box[3] / h) * nh + dy 507 | 508 | return [x_1,y_1,x_2,y_2] 509 | 510 | 511 | 512 | from torch.autograd import Variable 513 | import os 514 | import glob 515 | import json 516 | import torch 517 | import cv2 518 | 519 | def make_data_dict(model, imgsize, input_dir, json_file='anno_data.json', 520 | min_size=1,confthre=0.005, nmsthre=0.45, gpu=0): 521 | """ 522 | returning lists of bounding box [x_1,y_1,x_2,y_2] of ground truth and yolo outputs 523 | no-skipping no object image 524 | 525 | Returns: 526 | pred_dict: 527 | list of dict below 528 | {"image_id": id_, "image_name":jpg/png, "category_id": label, "pred_bbox": bbox, 529 | "obj_score":, "class_score":, "score": score} 530 | gt_dict: 531 | list of dict below 532 | {"image_id": id_, "image_name":jpg/png, "gt_bbox":, "category_id": label} 533 | """ 534 | img_list = glob.glob(os.path.join(input_dir,'*.jpg')) 535 | img_list.extend(glob.glob(os.path.join(input_dir,'*.png'))) 536 | 537 | json_open = open(os.path.join(input_dir, json_file), 'r') 538 | json_load = json.load(json_open) 539 | 540 | pred_dict = [] 541 | gt_dict = [] 542 | 543 | for id_ in range(len(img_list)): 544 | image = img_list[id_] 545 | img = cv2.imread(image) 546 | img, info_img = preprocess(img, imgsize, jitter=0) # info = (h, w, nh, nw, dx, dy) 547 | img = np.transpose(img / 255., (2, 0, 1)) 548 | img = torch.from_numpy(img).float().unsqueeze(0) 549 | 550 | if gpu >= 0: 551 | img = Variable(img.type(torch.cuda.FloatTensor)) 552 | else: 553 | img = Variable(img.type(torch.FloatTensor)) 554 | 555 | # gt bbox 556 | annotations = json_load[os.path.basename(image)]['regions'] 557 | 558 | if len(annotations) > 0: 559 | for anno in annotations: 560 | if anno['bb'][2] > min_size and anno['bb'][3] > min_size: 561 | box = [anno['bb'][1], anno['bb'][0], 562 | anno['bb'][1] + anno['bb'][3], anno['bb'][0] + anno['bb'][2]] 563 | box = resized_bbox(box, info_img) 564 | gt_dict.append({"image_id": id_, "image_name":os.path.basename(image), 565 | "gt_bbox":box}) 566 | classes.append(anno['class_id']) 567 | 568 | # yolo bbox 569 | with torch.no_grad(): 570 | outputs = model(img) 571 | outputs = postprocess(outputs, 1, confthre, nmsthre) 572 | 573 | bboxes = list() 574 | 575 | if outputs[0] is None: 576 | print("No Objects Deteted!!") 577 | 578 | else: 579 | 580 | 581 | #classes = list() 582 | #colors = list() 583 | 584 | for x1, y1, x2, y2, conf, cls_conf, cls_pred in outputs[0]: 585 | 586 | #cls_id = class_ids[int(cls_pred)] 587 | #print(int(x1), int(y1), int(x2), int(y2), float(conf), int(cls_pred)) 588 | #print('\t+ Label: %s, Conf: %.5f' % 589 | # (class_names[cls_id], cls_conf.item())) 590 | # box = yolobox2label([y1, x1, y2, x2], info_img) 591 | bboxes.append([x1,y1,x2,y2]) 592 | #classes.append(cls_id) 593 | #colors.append(class_colors[int(cls_pred)]) 594 | 595 | 596 | #vis_bbox( 597 | # img_raw, bboxes, #label=classes, label_names=class_names, 598 | # instance_colors=colors, linewidth=2) 599 | #plt.show() 600 | 601 | #plt.savefig(os.path.join(folder_path, 'yolo_' + os.path.basename(image))) 602 | 603 | yolo_bboxes.append([os.path.basename(image), 
bboxes]) 604 | 605 | 606 | 607 | def bb_list(input_dir, cfg, 608 | min_size = 1, img_size=416, gpu=0, 609 | weights_path=None, ckpt=None): 610 | """ 611 | returning lists of bounding box [x_1,y_1,x_2,y_2] of ground truth and yolo outputs 612 | no-skipping no object image 613 | 614 | Returns: 615 | gt_bboxes: 616 | [...[image_name,[list of bounding box]]...] 617 | yolo_bboxes: 618 | """ 619 | with open(cfg, 'r') as f: 620 | cfg = yaml.load(f) 621 | 622 | img_list = glob.glob(os.path.join(input_dir,'*.jpg')) 623 | img_list.extend(glob.glob(os.path.join(input_dir,'*.png'))) 624 | 625 | imgsize = cfg['TEST']['IMGSIZE'] 626 | model = YOLOv3(cfg['MODEL']) 627 | num_classes = cfg['MODEL']['N_CLASSES'] 628 | 629 | confthre = cfg['TEST']['CONFTHRE'] 630 | nmsthre = cfg['TEST']['NMSTHRE'] 631 | 632 | if gpu >= 0: 633 | model.cuda(gpu) 634 | 635 | assert weights_path or ckpt, 'One of --weights_path and --ckpt must be specified' 636 | 637 | if weights_path: 638 | print("loading yolo weights %s" % (weights_path)) 639 | parse_yolo_weights(model, weights_path) 640 | elif ckpt: 641 | print("loading checkpoint %s" % (ckpt)) 642 | state = torch.load(ckpt) 643 | if 'model_state_dict' in state.keys(): 644 | model.load_state_dict(state['model_state_dict']) 645 | else: 646 | model.load_state_dict(state) 647 | 648 | model.eval() 649 | 650 | json_open = open(os.path.join(input_dir, 'anno_data.json'), 'r') 651 | json_load = json.load(json_open) 652 | 653 | gt_bboxes = [] 654 | yolo_bboxes = [] 655 | 656 | for image in img_list: 657 | 658 | #folder_name = os.path.splitext(os.path.basename(image))[0] 659 | #folder_path = os.path.join(output_dir, folder_name) 660 | #os.makedirs(folder_path, exist_ok=True) 661 | 662 | img = cv2.imread(image) 663 | # output raw image 664 | #cv2.imwrite(os.path.join(folder_path, os.path.basename(image)), img) 665 | 666 | 667 | img_raw = img.copy()[:, :, ::-1].transpose((2, 0, 1)) 668 | 669 | img, info_img = preprocess(img, imgsize, jitter=0) # info = (h, w, nh, nw, dx, dy) 670 | img = np.transpose(img / 255., (2, 0, 1)) 671 | img = torch.from_numpy(img).float().unsqueeze(0) 672 | 673 | if gpu >= 0: 674 | img = Variable(img.type(torch.cuda.FloatTensor)) 675 | else: 676 | img = Variable(img.type(torch.FloatTensor)) 677 | 678 | 679 | # gt bbox 680 | annotations = json_load[os.path.basename(image)]['regions'] 681 | 682 | bboxes = list() 683 | 684 | if len(annotations) == 0: 685 | print("No Objects Exist!!") 686 | 687 | else: 688 | 689 | #classes = list() 690 | #colors = list() 691 | 692 | for anno in annotations: 693 | if anno['bb'][2] > min_size and anno['bb'][3] > min_size: 694 | box = [anno['bb'][1], anno['bb'][0], 695 | anno['bb'][1] + anno['bb'][3], anno['bb'][0] + anno['bb'][2]] 696 | box = resized_bbox(box, info_img) 697 | bboxes.append(box) 698 | #classes.append(anno['class_id']) 699 | #colors.append(class_colors[anno['class_id']]) 700 | 701 | #vis_bbox( 702 | # img_raw, bboxes, #label=classes, label_names=class_names, 703 | # instance_colors=colors, linewidth=2) 704 | #plt.show() 705 | 706 | 707 | #plt.savefig(os.path.join(folder_path, 'gt_' + os.path.basename(image))) 708 | 709 | 710 | gt_bboxes.append([os.path.basename(image), bboxes]) 711 | 712 | 713 | # yolo bbox 714 | 715 | with torch.no_grad(): 716 | outputs = model(img) 717 | outputs = postprocess(outputs, num_classes, confthre, nmsthre) 718 | 719 | bboxes = list() 720 | 721 | if outputs[0] is None: 722 | print("No Objects Deteted!!") 723 | 724 | else: 725 | 726 | 727 | #classes = list() 728 | #colors = list() 729 | 730 | 
for x1, y1, x2, y2, conf, cls_conf, cls_pred in outputs[0]: 731 | 732 | #cls_id = class_ids[int(cls_pred)] 733 | #print(int(x1), int(y1), int(x2), int(y2), float(conf), int(cls_pred)) 734 | #print('\t+ Label: %s, Conf: %.5f' % 735 | # (class_names[cls_id], cls_conf.item())) 736 | # box = yolobox2label([y1, x1, y2, x2], info_img) 737 | bboxes.append([x1,y1,x2,y2]) 738 | #classes.append(cls_id) 739 | #colors.append(class_colors[int(cls_pred)]) 740 | 741 | 742 | #vis_bbox( 743 | # img_raw, bboxes, #label=classes, label_names=class_names, 744 | # instance_colors=colors, linewidth=2) 745 | #plt.show() 746 | 747 | #plt.savefig(os.path.join(folder_path, 'yolo_' + os.path.basename(image))) 748 | 749 | yolo_bboxes.append([os.path.basename(image), bboxes]) 750 | 751 | return gt_bboxes, yolo_bboxes 752 | -------------------------------------------------------------------------------- /utils/vd_evaluator.py: -------------------------------------------------------------------------------- 1 | def vd_eval(l_annotations, l_outputs, fp_out=False): 2 | """ 3 | calculating the number of TP,FP,FN, precision and recall 4 | Args: 5 | l_annotations (list): 6 | l_outputs (list): 7 | fp_out (bool): 8 | default:False 9 | Returns: 10 | l_evals (list): 11 | [TP,FP,FN,precision, recall] 12 | l_fpbb (list): 13 | If fp_out is True 14 | """ 15 | if fp_out: 16 | return l_evals, l_fpbb 17 | else: 18 | return l_evals 19 | 20 | 21 | import glob 22 | import json 23 | import os 24 | 25 | import cv2 26 | import numpy as np 27 | #from pycocotools.cocoeval import COCOeval 28 | from torch.autograd import Variable 29 | 30 | #from dataset.cocodataset import * 31 | from utils.utils import * 32 | 33 | 34 | class VDEvaluator(): 35 | """ 36 | Vehicle Detection AP Evaluation class. 37 | All the data in the validation dataset are processed \ 38 | and evaluated. 39 | """ 40 | def __init__(self, data_dir, json_file, img_size, confthre, nmsthre, min_size=1): 41 | """ 42 | Args: 43 | data_dir (str): dataset root directory 44 | img_size (int): image size after preprocess. images are resized \ 45 | to squares whose shape is (img_size, img_size). 46 | confthre (float): 47 | confidence threshold ranging from 0 to 1, \ 48 | which is defined in the config file. 49 | nmsthre (float): 50 | IoU threshold of non-max supression ranging from 0 to 1. 51 | """ 52 | self.data_dir = data_dir 53 | self.json_file = json_file 54 | self.img_size = img_size 55 | self.confthre = confthre # 0.005 from darknet 56 | self.nmsthre = nmsthre # 0.45 (darknet) 57 | self.min_size = min_size 58 | 59 | def evaluate(self, model): 60 | """ 61 | VD average precision (AP) Evaluation. Iterate inference on the val dataset 62 | and the results are evaluated. 
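        All detections judged as TP/FP over the validation set define a cumulative
        precision-recall curve: ap50 is the area under the interpolated curve,
        precision50 and recall50 are the final cumulative values, and F_measure is
        their harmonic mean. Illustrative example: with 2 ground-truth boxes and
        detections judged [TP, FP, TP], precision = [1, 0.5, 0.67] and
        recall = [0.5, 0.5, 1.0], giving ap50 ≈ 0.83, precision50 ≈ 0.67,
        recall50 = 1.0 and F_measure = 0.8.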
63 | Args: 64 | model : model object 65 | Returns: 66 | ap50, precision50, recall50, F_measure (float) : AP, precision, recall and F-measure of the detections at IoU=0.5 67 | """ 68 | 69 | model.eval() 70 | cuda = torch.cuda.is_available() 71 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 72 | 73 | img_list = glob.glob(os.path.join(self.data_dir, '*.jpg')) 74 | img_list.extend(glob.glob(os.path.join(self.data_dir,'*.png'))) 75 | 76 | json_open = open(os.path.join(self.data_dir, self.json_file), 'r') 77 | json_load = json.load(json_open) 78 | 79 | all_gt_bboxes = [] 80 | all_yolo_bboxes = [] 81 | 82 | for image in img_list: 83 | gt_bboxes = [] 84 | yolo_bboxes = [] 85 | 86 | #predict bbox 87 | img = cv2.imread(image) 88 | if img is None: 89 | print('read image error') 90 | img, info_img = preprocess(img, self.img_size, jitter=0) # info = (h, w, nh, nw, dx, dy) 91 | img = np.transpose(img / 255., (2, 0, 1)) 92 | img = torch.from_numpy(img).float().unsqueeze(0) 93 | 94 | # gt bbox 95 | annotations = json_load[os.path.basename(image)]['regions'] 96 | if len(annotations) > 0: 97 | for anno in annotations: 98 | if anno['bb'][2] > self.min_size and anno['bb'][3] > self.min_size: 99 | box = [anno['bb'][0], anno['bb'][1], 100 | anno['bb'][0] + anno['bb'][2], anno['bb'][1] + anno['bb'][3]] 101 | box = resized_bbox(box, info_img) 102 | gt_bboxes.append([os.path.basename(image),box]) 103 | #gt_bboxes.append({"image":os.path.basename(image), "bbox":box}) 104 | all_gt_bboxes.extend(gt_bboxes) 105 | 106 | 107 | 108 | with torch.no_grad(): 109 | img = Variable(img.type(Tensor)) 110 | outputs = model(img) 111 | # delete outputs with inf 112 | 113 | #outputs= postfilter(outputs) 114 | outputs = postprocess(outputs, 1, self.confthre, self.nmsthre) 115 | if outputs[0] is not None: 116 | outputs = outputs[0].cpu().data 117 | for output in outputs: 118 | x1 = float(output[0]) 119 | y1 = float(output[1]) 120 | x2 = float(output[2]) 121 | y2 = float(output[3]) 122 | score = float(output[4].data.item() * output[5].data.item()) 123 | label = int(output[6]) 124 | yolo_bboxes.append([os.path.basename(image),[x1,y1,x2,y2], score, label, False]) 125 | #yolo_bboxes.append({"image":os.path.basename(image), "category_id": label, 126 | # "bbox": [x1,y1,x2,y2],"score": score}) 127 | 128 | # judge TP or FP 129 | # sort by score (ascending) 130 | yolo_bboxes = sorted(yolo_bboxes, key=lambda x: x[2]) 131 | 132 | for j in range(len(yolo_bboxes)): 133 | a = None 134 | t = 0 135 | for gt_bbox in gt_bboxes: 136 | iou = np_bboxes_iou(np.array(yolo_bboxes[j][1]), np.array(gt_bbox[1]).reshape(1,4)) 137 | if iou > max(0.5, t): 138 | a = gt_bbox 139 | t = iou 140 | if a is not None: 141 | gt_bboxes.remove(a) 142 | yolo_bboxes[j][4] = True 143 | 144 | all_yolo_bboxes.extend(yolo_bboxes) 145 | # calculating AP 146 | ap50 = 0 147 | precision50 = 0 148 | recall50 = 0 149 | F_measure = 0 150 | if len(all_yolo_bboxes)==0: 151 | print('no pred') 152 | if len(all_gt_bboxes)==0: 153 | print('no gt') 154 | if len(all_yolo_bboxes) > 0 and len(all_gt_bboxes) > 0: 155 | tp_or_fp = [yolo_bbox[4] for yolo_bbox in all_yolo_bboxes] 156 | acc_tp = np.cumsum(tp_or_fp) 157 | acc_fp = np.cumsum(np.logical_not(tp_or_fp)) 158 | 159 | precision = acc_tp /(acc_tp + acc_fp) 160 | recall = acc_tp / len(all_gt_bboxes) 161 | 162 | modified_recall = np.concatenate([[0], recall, [1]]) 163 | modified_precision = np.concatenate([[0], precision, [0]]) 164 | 165 | # Take the cumulative maximum from the end (interpolated precision). 166 | modified_precision = np.maximum.accumulate(modified_precision[::-1])[::-1] 167 | 168 | # Compute AP50. 169 | ap50 = (np.diff(modified_recall) * 
modified_precision[1:]).sum() 170 | 171 | precision50 = precision[-1] 172 | recall50 = recall[-1] 173 | if (precision50+recall50)>0: 174 | F_measure = 2*(precision50*recall50)/(precision50+recall50) 175 | 176 | outputs = 0 177 | 178 | return ap50, precision50, recall50, F_measure 179 | -------------------------------------------------------------------------------- /utils/vis_bbox.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | def vis_bbox(img, bbox, label=None, score=None, label_names=None, 5 | instance_colors=None, alpha=1., linewidth=3., ax=None): 6 | """Visualize bounding boxes inside the image. 7 | Args: 8 | img (~numpy.ndarray): An array of shape :math:`(3, height, width)`. 9 | This is in RGB format and the range of its value is 10 | :math:`[0, 255]`. If this is :obj:`None`, no image is displayed. 11 | bbox (~numpy.ndarray): An array of shape :math:`(R, 4)`, where 12 | :math:`R` is the number of bounding boxes in the image. 13 | Each element is organized 14 | by :math:`(y_{min}, x_{min}, y_{max}, x_{max})` in the second axis. 15 | label (~numpy.ndarray): An integer array of shape :math:`(R,)`. 16 | The values correspond to id for label names stored in 17 | :obj:`label_names`. This is optional. 18 | score (~numpy.ndarray): A float array of shape :math:`(R,)`. 19 | Each value indicates how confident the prediction is. 20 | This is optional. 21 | label_names (iterable of strings): Name of labels ordered according 22 | to label ids. If this is :obj:`None`, labels will be skipped. 23 | instance_colors (iterable of tuples): List of colors. 24 | Each color is RGB format and the range of its values is 25 | :math:`[0, 255]`. The :obj:`i`-th element is the color used 26 | to visualize the :obj:`i`-th instance. 27 | If :obj:`instance_colors` is :obj:`None`, the red is used for 28 | all boxes. 29 | alpha (float): The value which determines transparency of the 30 | bounding boxes. The range of this value is :math:`[0, 1]`. 31 | linewidth (float): The thickness of the edges of the bounding boxes. 32 | ax (matplotlib.axes.Axis): The visualization is displayed on this 33 | axis. If this is :obj:`None` (default), a new axis is created. 34 | Returns: 35 | ~matploblib.axes.Axes: 36 | Returns the Axes object with the plot for further tweaking. 37 | 38 | from: https://github.com/chainer/chainercv 39 | """ 40 | 41 | if label is not None and not len(bbox) == len(label): 42 | raise ValueError('The length of label must be same as that of bbox') 43 | if score is not None and not len(bbox) == len(score): 44 | raise ValueError('The length of score must be same as that of bbox') 45 | 46 | # Returns newly instantiated matplotlib.axes.Axes object if ax is None 47 | if ax is None: 48 | fig = plt.figure() 49 | ax = fig.add_subplot(1, 1, 1) 50 | ax.imshow(img.transpose((1, 2, 0)).astype(np.uint8)) 51 | # If there is no bounding box to display, visualize the image and exit. 
52 | if len(bbox) == 0: 53 | return ax 54 | 55 | if instance_colors is None: 56 | # Red 57 | instance_colors = np.zeros((len(bbox), 3), dtype=np.float32) 58 | instance_colors[:, 0] = 255 59 | instance_colors = np.array(instance_colors) 60 | 61 | for i, bb in enumerate(bbox): 62 | xy = (bb[1], bb[0]) 63 | height = bb[2] - bb[0] 64 | width = bb[3] - bb[1] 65 | color = instance_colors[i % len(instance_colors)] / 255 66 | ax.add_patch(plt.Rectangle( 67 | xy, width, height, fill=False, 68 | edgecolor=color, linewidth=linewidth, alpha=alpha)) 69 | 70 | caption = [] 71 | 72 | if label is not None and label_names is not None: 73 | lb = label[i] 74 | if not (0 <= lb < len(label_names)): 75 | raise ValueError('No corresponding name is given') 76 | caption.append(label_names[lb]) 77 | if score is not None: 78 | sc = score[i] 79 | caption.append('{:.2f}'.format(sc)) 80 | 81 | if len(caption) > 0: 82 | ax.text(bb[1], bb[0], 83 | ': '.join(caption), 84 | style='italic', 85 | bbox={'facecolor': 'white', 'alpha': 0.7, 'pad': 10}) 86 | return ax 87 | --------------------------------------------------------------------------------
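# Illustrative usage sketch for utils/vis_bbox.py (not part of the repository; the file
# name 'sample.jpg' and the box values are hypothetical): load an image, convert it to
# the (3, H, W) RGB layout that vis_bbox expects, and draw one labelled box.
import cv2
import numpy as np
import matplotlib.pyplot as plt
from utils.vis_bbox import vis_bbox

img = cv2.imread('sample.jpg')[:, :, ::-1].transpose((2, 0, 1))  # HWC BGR -> CHW RGB
bbox = np.array([[30, 50, 120, 200]])  # one box as (y_min, x_min, y_max, x_max)
vis_bbox(img, bbox, label=np.array([0]), score=np.array([0.92]),
         label_names=('vehicle',), linewidth=2)
plt.show()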