├── LICENSE
├── README.md
├── config
│   └── config.cfg
├── data
│   ├── vis_tp.png
│   └── whole_figure.png
├── dataset
│   ├── __pycache__
│   │   └── dataset_vd.cpython-36.pyc
│   ├── cocodataset.py
│   └── dataset_vd.py
├── docker
│   └── Dockerfile
├── models
│   ├── shap_loss_1.py
│   ├── yolo_layer.py
│   └── yolov3_shap.py
├── requirements
│   ├── download_weights.sh
│   ├── getcoco.sh
│   └── requirements.txt
├── test.ipynb
├── train_vd_reg.py
└── utils
    ├── cocoapi_evaluator.py
    ├── parse_yolo_weights.py
    ├── resize_gt.py
    ├── utils.py
    ├── vd_evaluator.py
    └── vis_bbox.py
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 | 
3 | Copyright (c) 2021 Hiroki Kawauchi
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, and/or sublicense
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software; and
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
16 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
17 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
18 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
19 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
20 | TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
21 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
22 | 
23 | 
24 | This software uses some portions from the following software under its license:
25 | 
26 | PyTorch_YOLOv3
27 | 
28 | The MIT License
29 | 
30 | Copyright (c) 2018 DeNA Co., Ltd.
31 | 
32 | Permission is hereby granted, free of charge, to any person obtaining a copy
33 | of this software and associated documentation files (the "Software"), to deal
34 | in the Software without restriction, including without limitation the rights
35 | to use, copy, modify, merge, publish, distribute, and/or sublicense
36 | copies of the Software, and to permit persons to whom the Software is
37 | furnished to do so, subject to the following conditions:
38 | 
39 | The above copyright notice and this permission notice shall be included in all
40 | copies or substantial portions of the Software; and
41 | 
42 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
43 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
44 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
45 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
46 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
47 | TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
48 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
49 | 
50 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SHAP-Based Interpretable Object Detection Method for Satellite Imagery
2 | This is the author implementation of [SHAP-Based Interpretable Object Detection Method for Satellite Imagery](https://www.mdpi.com/2072-4292/14/9/1970). The implementation of the object detection model (YOLOv3) is based on [Pytorch_YOLOv3](https://github.com/DeNA/PyTorch_YOLOv3). The framework of the proposed method can be applied to any differentiable object detection model.
3 | 
4 | 
5 |
6 | ## Performance
7 |
8 | #### Visualization
9 |
10 |
11 | Please see the paper for details on the results of the evaluation, regularization, and data selection methods.
12 |
13 | ## Installation
14 | #### Requirements
15 |
16 | - Python 3.6.3+
17 | - Numpy
18 | - OpenCV
19 | - Matplotlib
20 | - Pytorch 1.2+
21 | - Cython
22 | - Cuda (verified as operable: v10.2)
23 | - Captum (verified as operable: v0.4.1)
24 |
25 | optional:
26 | - tensorboard
27 | - [tensorboardX](https://github.com/lanpa/tensorboardX)
28 | - CuDNN
29 |
30 | #### Download the original YOLOv3 weights
31 | Download the pretrained weight files from the YOLOv3 author's project page:
32 |
33 | ```bash
34 | $ mkdir weights
35 | $ cd weights/
36 | $ bash ../requirements/download_weights.sh
37 | ```
38 |
39 | ## Usage
40 |
41 | Please see `test.ipynb` for the full workflow. A minimal detection example is sketched below.
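A minimal sketch of loading a model and running detection (the `YOLOv3`, `preprocess`, `postprocess` and `parse_yolo_weights` helpers are the ones imported in `test.ipynb` and `train_vd_reg.py`; the image path is a placeholder):

```python
import yaml
import cv2
import numpy as np
import torch

from models.yolov3 import *                       # YOLOv3 detection model (as imported in test.ipynb)
from utils.utils import *                         # preprocess, postprocess
from utils.parse_yolo_weights import parse_yolo_weights

with open('config/config.cfg', 'r') as f:
    cfg = yaml.load(f)                            # add Loader=yaml.SafeLoader on newer PyYAML

model = YOLOv3(cfg['MODEL'])
parse_yolo_weights(model, 'weights/yolov3.weights')   # or load a checkpoint with torch.load()
model = model.cuda().eval()

img = cv2.imread('path/to/image.png')             # placeholder path
img, info_img = preprocess(img, cfg['TEST']['IMGSIZE'], jitter=0)
img = torch.from_numpy(np.transpose(img / 255., (2, 0, 1))).float().unsqueeze(0).cuda()

with torch.no_grad():
    outputs = model(img)
    outputs = postprocess(outputs, cfg['MODEL']['N_CLASSES'],
                          cfg['TEST']['CONFTHRE'], cfg['TEST']['NMSTHRE'])
# outputs[0] is None if nothing passes CONFTHRE; otherwise each row is
# [x1, y1, x2, y2, obj_conf, cls_conf, cls_id, layer, anchor, grid_x, grid_y]
print(outputs[0])
```

The SHAP attribution and visualization steps that follow this are shown in `test.ipynb`.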
42 |
43 |
44 | ## Paper
45 | ### SHAP-Based Interpretable Object Detection Method for Satellite Imagery
46 | _Hiroki Kawauchi, Takashi Fuse_
47 |
48 | [[Paper]](https://www.mdpi.com/2072-4292/14/9/1970) [[Original Implementation]](https://github.com/hiroki-kawauchi/SHAPObjectDetection.git)
49 |
50 |
51 |
--------------------------------------------------------------------------------
/config/config.cfg:
--------------------------------------------------------------------------------
1 | MODEL:
2 |   TYPE: YOLOv3
3 |   BACKBONE: darknet53
4 |   ANCHORS: [[10, 13], [16, 30], [33, 23],
5 |             [30, 61], [62, 45], [59, 119],
6 |             [116, 90], [156, 198], [373, 326]]
7 |   ANCH_MASK: [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
8 |   N_CLASSES: 1
9 | TRAIN:
10 |   TRAIN_DIR: ../train
11 |   VAL_DIR: ../val
12 |   LR: 0.001
13 |   MOMENTUM: 0.9
14 |   DECAY: 0.0005
15 |   BURN_IN: 100
16 |   MAXITER: 3000
17 |   STEPS: (1200, 1500)
18 |   BATCHSIZE: 4
19 |   SUBDIVISION: 16
20 |   IMGSIZE: 416
21 |   LOSSTYPE: l2
22 |   IGNORETHRE: 0.7
23 |   ATTENTION_ALPHA: 0.0
24 |   ATTENTION_BETA: 1.0
25 |   N_SAMPLES: 1
26 | AUGMENTATION:
27 |   RANDRESIZE: False
28 |   JITTER: 0
29 |   RANDOM_PLACING: True
30 |   HUE: 0.1
31 |   SATURATION: 1.5
32 |   EXPOSURE: 1.5
33 |   LRFLIP: True
34 |   RANDOM_DISTORT: True
35 | TEST:
36 |   CONFTHRE: 0.50
37 |   NMSTHRE: 0.50
38 |   IMGSIZE: 416
39 | NUM_GPUS: 1
40 | 
--------------------------------------------------------------------------------
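For reference, a sketch of how `train_vd_reg.py` consumes this file: the YAML is loaded as nested dicts, the effective batch size is `BATCHSIZE * SUBDIVISION`, and the configured `LR` is divided by that product before being handed to SGD.

```python
import yaml

with open('config/config.cfg', 'r') as f:
    cfg = yaml.load(f)            # train_vd_reg.py uses yaml.load(f); pass Loader=yaml.SafeLoader on PyYAML >= 5

train = cfg['TRAIN']
effective_batch_size = train['BATCHSIZE'] * train['SUBDIVISION']    # 4 * 16 = 64
base_lr = train['LR'] / train['BATCHSIZE'] / train['SUBDIVISION']   # LR passed to optim.SGD
steps = eval(train['STEPS'])                                        # "(1200, 1500)" -> (1200, 1500), as in train_vd_reg.py
print(effective_batch_size, base_lr, steps)
```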
/data/vis_tp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiroki-kawauchi/SHAPObjectDetection/5ea90f743227ac22fa07fbeb9b0e1ea821e5defa/data/vis_tp.png
--------------------------------------------------------------------------------
/data/whole_figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiroki-kawauchi/SHAPObjectDetection/5ea90f743227ac22fa07fbeb9b0e1ea821e5defa/data/whole_figure.png
--------------------------------------------------------------------------------
/dataset/__pycache__/dataset_vd.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiroki-kawauchi/SHAPObjectDetection/5ea90f743227ac22fa07fbeb9b0e1ea821e5defa/dataset/__pycache__/dataset_vd.cpython-36.pyc
--------------------------------------------------------------------------------
/dataset/cocodataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 | import torch
5 | from torch.utils.data import Dataset
6 | import cv2
7 | from pycocotools.coco import COCO
8 |
9 | from utils.utils import *
10 |
11 |
12 | class COCODataset(Dataset):
13 | """
14 | COCO dataset class.
15 | """
16 | def __init__(self, model_type, data_dir='COCO', json_file='instances_train2017.json',
17 | name='train2017', img_size=416,
18 | augmentation=None, min_size=1, debug=False):
19 | """
20 | COCO dataset initialization. Annotation data are read into memory by COCO API.
21 | Args:
22 | model_type (str): model name specified in config file
23 | data_dir (str): dataset root directory
24 | json_file (str): COCO json file name
25 | name (str): COCO data name (e.g. 'train2017' or 'val2017')
26 | img_size (int): target image size after pre-processing
27 | min_size (int): bounding boxes smaller than this are ignored
28 | debug (bool): if True, only one data id is selected from the dataset
29 | """
30 | self.data_dir = data_dir
31 | self.json_file = json_file
32 | self.model_type = model_type
33 | self.coco = COCO(self.data_dir+'annotations/'+self.json_file)
34 | self.ids = self.coco.getImgIds()
35 | if debug:
36 | self.ids = self.ids[1:2]
37 | print("debug mode...", self.ids)
38 | self.class_ids = sorted(self.coco.getCatIds())
39 | self.name = name
40 | self.max_labels = 50
41 | self.img_size = img_size
42 | self.min_size = min_size
43 | self.lrflip = augmentation['LRFLIP']
44 | self.jitter = augmentation['JITTER']
45 | self.random_placing = augmentation['RANDOM_PLACING']
46 | self.hue = augmentation['HUE']
47 | self.saturation = augmentation['SATURATION']
48 | self.exposure = augmentation['EXPOSURE']
49 | self.random_distort = augmentation['RANDOM_DISTORT']
50 |
51 |
52 | def __len__(self):
53 | return len(self.ids)
54 |
55 | def __getitem__(self, index):
56 | """
57 | One image / label pair for the given index is picked up \
58 | and pre-processed.
59 | Args:
60 | index (int): data index
61 | Returns:
62 | img (numpy.ndarray): pre-processed image
63 | padded_labels (torch.Tensor): pre-processed label data. \
64 | The shape is :math:`[self.max_labels, 5]`. \
65 | each label consists of [class, xc, yc, w, h]:
66 | class (float): class index.
67 | xc, yc (float) : center of bbox whose values range from 0 to 1.
68 | w, h (float) : size of bbox whose values range from 0 to 1.
69 | info_img : tuple of h, w, nh, nw, dx, dy.
70 | h, w (int): original shape of the image
71 | nh, nw (int): shape of the resized image without padding
72 | dx, dy (int): pad size
73 | id_ (int): same as the input index. Used for evaluation.
74 | """
75 | id_ = self.ids[index]
76 |
77 | anno_ids = self.coco.getAnnIds(imgIds=[int(id_)], iscrowd=None)
78 | annotations = self.coco.loadAnns(anno_ids)
79 |
80 | lrflip = False
81 | if np.random.rand() > 0.5 and self.lrflip == True:
82 | lrflip = True
83 |
84 | # load image and preprocess
85 | img_file = os.path.join(self.data_dir, self.name,
86 | '{:012}'.format(id_) + '.jpg')
87 | img = cv2.imread(img_file)
88 |
89 | if self.json_file == 'instances_val5k.json' and img is None:
90 | img_file = os.path.join(self.data_dir, 'train2017',
91 | '{:012}'.format(id_) + '.jpg')
92 | img = cv2.imread(img_file)
93 | assert img is not None
94 |
95 | img, info_img = preprocess(img, self.img_size, jitter=self.jitter,
96 | random_placing=self.random_placing)
97 |
98 | if self.random_distort:
99 | img = random_distort(img, self.hue, self.saturation, self.exposure)
100 |
101 | img = np.transpose(img / 255., (2, 0, 1))
102 |
103 | if lrflip:
104 | img = np.flip(img, axis=2).copy()
105 |
106 | # load labels
107 | labels = []
108 | for anno in annotations:
109 | if anno['bbox'][2] > self.min_size and anno['bbox'][3] > self.min_size:
110 | labels.append([])
111 | labels[-1].append(self.class_ids.index(anno['category_id']))
112 | labels[-1].extend(anno['bbox'])
113 |
114 | padded_labels = np.zeros((self.max_labels, 5))
115 | if len(labels) > 0:
116 | labels = np.stack(labels)
117 | if 'YOLO' in self.model_type:
118 | labels = label2yolobox(labels, info_img, self.img_size, lrflip)
119 | padded_labels[range(len(labels))[:self.max_labels]
120 | ] = labels[:self.max_labels]
121 | padded_labels = torch.from_numpy(padded_labels)
122 |
123 | return img, padded_labels, info_img, id_
124 |
--------------------------------------------------------------------------------
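A hedged usage sketch for `COCODataset` (assumes the COCO layout produced by `requirements/getcoco.sh`; the augmentation dict mirrors the `AUGMENTATION` block of `config/config.cfg`):

```python
import torch
from dataset.cocodataset import COCODataset

augmentation = {'LRFLIP': True, 'JITTER': 0.3, 'RANDOM_PLACING': True,
                'HUE': 0.1, 'SATURATION': 1.5, 'EXPOSURE': 1.5,
                'RANDOM_DISTORT': True}

dataset = COCODataset(model_type='YOLOv3',
                      data_dir='COCO/',          # trailing slash: the path is concatenated with 'annotations/'
                      json_file='instances_train2017.json',
                      name='train2017', img_size=416,
                      augmentation=augmentation)

loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
imgs, labels, info_img, ids = next(iter(loader))
print(imgs.shape, labels.shape)                  # (4, 3, 416, 416), (4, 50, 5)
```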
/dataset/dataset_vd.py:
--------------------------------------------------------------------------------
1 | # reference:
2 | # https://github.com/eriklindernoren/PyTorch-YOLOv3/blob/master/utils/datasets.py
3 | import json
4 | import glob
5 | import random
6 | import os
7 | import sys
8 | import numpy as np
9 | import cv2
10 | import torch
11 | import torch.nn.functional as F
12 |
13 |
14 | from torch.utils.data import Dataset
15 | import torchvision.transforms as transforms
16 |
17 | from utils.utils import *
18 |
19 |
20 | def load_classes(path):
21 | """
22 | Loads class labels at 'path'
23 | """
24 | fp = open(path, "r")
25 | names = fp.read().split("\n")[:-1]
26 | return names
27 | '''
28 | def pad_to_square(img, pad_value):
29 | c, h, w = img.shape
30 | dim_diff = np.abs(h - w)
31 | # (upper / left) padding and (lower / right) padding
32 | pad1, pad2 = dim_diff // 2, dim_diff - dim_diff // 2
33 | # Determine padding
34 | pad = (0, 0, pad1, pad2) if h <= w else (pad1, pad2, 0, 0)
35 | # Add padding
36 | img = F.pad(img, pad, "constant", value=pad_value)
37 |
38 | return img, pad
39 |
40 |
41 | def resize(image, size):
42 | image = F.interpolate(image.unsqueeze(0), size=size, mode="nearest").squeeze(0)
43 | return image
44 |
45 |
46 | def random_resize(images, min_size=288, max_size=448):
47 | new_size = random.sample(list(range(min_size, max_size + 1, 32)), 1)[0]
48 | images = F.interpolate(images, size=new_size, mode="nearest")
49 | return images
50 |
51 |
52 | class ImageFolder(Dataset):
53 | def __init__(self, folder_path, img_size=416):
54 | self.files = sorted(glob.glob("%s/*.*" % folder_path))
55 | self.img_size = img_size
56 |
57 | def __getitem__(self, index):
58 | img_path = self.files[index % len(self.files)]
59 | # Extract image as PyTorch tensor
60 | img = transforms.ToTensor()(Image.open(img_path))
61 | # Pad to square resolution
62 | img, _ = pad_to_square(img, 0)
63 | # Resize
64 | img = resize(img, self.img_size)
65 |
66 | return img_path, img
67 |
68 | def __len__(self):
69 | return len(self.files)
70 | '''
71 |
72 | class ListDataset(Dataset):
73 | def __init__(self, model_type, data_dir, json_file='anno_data.json',
74 | img_size=416,
75 | augmentation=None,min_size=1):
76 | """
77 | Vehicle detection dataset initialization. Images are listed from ``data_dir`` and annotations are read from a JSON file (``json_file``) in the same directory.
78 | Args:
79 | model_type (str): model name specified in config file
80 | data_dir (str): directory containing the images and the annotation JSON
81 | img_size (int): target image size after pre-processing
82 | min_size (int): bounding boxes smaller than this are ignored
83 | """
84 | self.model_type = model_type
85 | self.img_size = img_size
86 | self.max_labels = 100
87 | self.min_size = min_size
88 | self.data_dir = data_dir
89 | self.json_file = json_file
90 | self.lrflip = augmentation['LRFLIP']
91 | self.jitter = augmentation['JITTER']
92 | self.random_placing = augmentation['RANDOM_PLACING']
93 | self.hue = augmentation['HUE']
94 | self.saturation = augmentation['SATURATION']
95 | self.exposure = augmentation['EXPOSURE']
96 | self.random_distort = augmentation['RANDOM_DISTORT']
97 |
98 | self.img_list = glob.glob(os.path.join(self.data_dir, '*.jpg'))
99 | self.img_list.extend(glob.glob(os.path.join(self.data_dir,'*.png')))
100 |
101 | def __getitem__(self, index):
102 | """
103 | One image / label pair for the given index is picked up \
104 | and pre-processed.
105 | Args:
106 | index (int): data index
107 | Returns:
108 | img (numpy.ndarray): pre-processed image
109 | padded_labels (torch.Tensor): pre-processed label data. \
110 | The shape is :math:`[self.max_labels, 5]`. \
111 | each label consists of [class, xc, yc, w, h]:
112 | class (float): class index.
113 | xc, yc (float) : center of bbox whose values range from 0 to 1.
114 | w, h (float) : size of bbox whose values range from 0 to 1.
115 | info_img : tuple of h, w, nh, nw, dx, dy.
116 | h, w (int): original shape of the image
117 | nh, nw (int): shape of the resized image without padding
118 | dx, dy (int): pad size
119 | id_ (int): same as the input index. Used for evaluation.
120 | """
121 |
122 | # load image and preprocess
123 | img_path = self.img_list[index % len(self.img_list)]
124 | img = cv2.imread(img_path)
125 | assert img is not None
126 | img, info_img = preprocess(img, self.img_size, jitter=self.jitter,
127 | random_placing=self.random_placing)
128 | if self.random_distort:
129 | img = random_distort(img, self.hue, self.saturation, self.exposure)
130 |
131 | img = np.transpose(img / 255., (2, 0, 1))
132 |
133 | lrflip = False
134 | if np.random.rand() > 0.5 and self.lrflip == True:
135 | lrflip = True
136 |
137 | if lrflip:
138 | img = np.flip(img, axis=2).copy()
139 |
140 | # load labels
141 | json_open = open(os.path.join(self.data_dir, self.json_file), 'r')
142 | json_load = json.load(json_open)
143 | annotations = json_load[os.path.basename(img_path)]['regions']
144 |
145 | labels = []
146 | for anno in annotations:
147 | if anno['bb'][2] > self.min_size and anno['bb'][3] > self.min_size:
148 | labels.append([])
149 | labels[-1].append(anno['class_id'])
150 | labels[-1].extend(anno['bb'])
151 |
152 | padded_labels = np.zeros((self.max_labels, 5))
153 | if len(labels) > 0:
154 | labels = np.stack(labels).astype(np.float64)
155 | if 'YOLO' in self.model_type:
156 | labels = label2yolobox(labels, info_img, self.img_size, lrflip)
157 | padded_labels[range(len(labels))[:self.max_labels]
158 | ] = labels[:self.max_labels]
159 | padded_labels = torch.from_numpy(padded_labels)
160 |
161 | return img, padded_labels, info_img, index
162 |
163 |
164 | def __len__(self):
165 | return len(self.img_list)
--------------------------------------------------------------------------------
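The annotation JSON that `ListDataset` reads is keyed by image file name and holds a `regions` list of `class_id`/`bb` entries, as also described in `test.ipynb`. A minimal sketch that writes such a file (the file name and box values are made up; `bb` is presumably `[x_min, y_min, width, height]` in pixels, the same layout as a COCO `bbox`, since both datasets feed `label2yolobox`):

```python
import json

annotations = {
    "tile_000.png": {                               # must match an image file name in data_dir
        "regions": [
            {"class_id": 0, "bb": [120, 84, 18, 32]},
            {"class_id": 0, "bb": [300, 10, 20, 35]},
        ]
    }
}

with open("anno_data.json", "w") as f:              # default --anno_file name in train_vd_reg.py
    json.dump(annotations, f, indent=2)
```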
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:9.0-cudnn7-runtime-ubuntu16.04
2 |
3 | RUN apt-get update && apt-get install -y --no-install-recommends \
4 | sudo \
5 | git \
6 | zip \
7 | libopencv-dev \
8 | build-essential libssl-dev libbz2-dev libreadline-dev libsqlite3-dev curl \
9 | wget && \
10 | rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
11 |
12 | ARG UID
13 | RUN useradd docker -l -u $UID -G sudo -s /bin/bash -m
14 | RUN echo 'Defaults visiblepw' >> /etc/sudoers
15 | RUN echo 'docker ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers
16 |
17 | USER docker
18 |
19 | ENV PYENV_ROOT /home/docker/.pyenv
20 | ENV PATH $PYENV_ROOT/shims:$PYENV_ROOT/bin:$PATH
21 | RUN curl -L https://raw.githubusercontent.com/yyuu/pyenv-installer/master/bin/pyenv-installer | bash
22 |
23 | ENV PYTHON_VERSION 3.6.8
24 | RUN pyenv install ${PYTHON_VERSION} && pyenv global ${PYTHON_VERSION}
25 |
26 | RUN pip install -U pip setuptools
27 | # for pycocotools
28 | RUN pip install Cython==0.29.1 numpy==1.15.4
29 |
30 | COPY requirements/requirements.txt /tmp/requirements.txt
31 | RUN pip install -r /tmp/requirements.txt
32 |
33 | # mount this repository (SHAPObjectDetection) to /work
34 | WORKDIR /work
35 |
36 | ENTRYPOINT ["/bin/bash"]
37 |
--------------------------------------------------------------------------------
/models/shap_loss_1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | from models.yolov3_shap import *
4 | from utils.utils import *
5 | import sys
6 | sys.path.append('/home/linuxserver01/packages')
7 | from captum.attr import GradientShap
8 | from captum.attr import visualization as viz
9 |
10 | import numpy as np
11 |
12 | def zscore(x, axis = None):
13 | xmean = x.mean(axis=axis, keepdims=True)
14 | xstd = np.std(x, axis=axis, keepdims=True)
15 | zscore = (x-xmean)/xstd
16 | return zscore
17 |
18 | def bb_mask(bboxes, attr, img_H=608, img_W=608, std=True):
19 | pixel_attr = np.sum(attr, axis=2)
20 | if std:
21 | pixel_attr = zscore(pixel_attr)
22 |
23 | pixel_mask = np.zeros((img_H,img_W))
24 | in_area = 0
25 | for [x1,y1,x2,y2] in bboxes:
26 | for i in range(int(y1), int(y2)):
27 | for j in range(int(x1),int(x2)):
28 | pixel_mask[i,j] = 1
29 | in_area += 1
30 | out_area = img_H*img_W - in_area
31 |
32 | in_pos = 0
33 | in_neg = 0
34 | out_pos = 0
35 | out_neg = 0
36 | for i in range(img_H):
37 | for j in range(img_W):
38 | if pixel_mask[i,j]>0:
39 | if pixel_attr[i,j]>=0:
40 | in_pos += pixel_attr[i,j]
41 | else:
42 | in_neg += pixel_attr[i,j]
43 | else:
44 | if pixel_attr[i,j]>=0:
45 | out_pos += pixel_attr[i,j]
46 | else:
47 | out_neg += pixel_attr[i,j]
48 | if in_area >0:
49 | in_pos = in_pos/in_area
50 | in_neg = in_neg/in_area
51 | if out_area>0:
52 | out_pos = out_pos/out_area
53 | out_neg = out_neg/out_area
54 |
55 | return in_pos, in_neg, out_pos, out_neg
56 |
57 |
58 | def shaploss(imgs, labels, model, num_classes, confthre, nmsthre, stdevs=0.1, target_y='cls',
59 | multiply_by_inputs=True, n_samples=5, alpha=1.0, beta=1.0):
60 | """
61 | Calculating SHAP-loss
62 | Args:
63 | imgs (torch.Tensor) : input data whose shape is :math:`(N, C, H, W)`, \
64 | where N, C are batchsize and num. of channels.
65 | labels (torch.Tensor) : label array whose shape is :math:`(N, 100, 5)`
66 | each label consists of [class, xc, yc, w, h]:
67 | class (float): class index.
68 | xc, yc (float) : center of bbox whose values range from 0 to 1.
69 | w, h (float) : size of bbox whose values range from 0 to 1.
70 | info_imgs :
71 | info_img : tuple of h, w, nh, nw, dx, dy.
72 | h, w (int): original shape of the image
73 | nh, nw (int): shape of the resized image without padding
74 | dx, dy (int): pad size
75 | Returns:
76 | loss (torch.Tensor)
77 | """
78 | # Find TP output
79 | model.eval()
80 | model(imgs)
81 | with torch.no_grad():
82 | outputss = model(imgs)
83 | outputss = postprocess(outputss, num_classes, confthre, nmsthre)
84 |
85 | labels = labels.cpu().data
86 | nlabel = (labels.sum(dim=2) > 0).sum(dim=1)# numbers of gt-objects in each image
87 |
88 | gt_bboxes_tp = [None for _ in range(len(nlabel))]
89 | output_id_tp = [None for _ in range(len(nlabel))]
90 | for outputs, label, i in zip(outputss, labels, range(len(nlabel))):
91 | gt_bboxes = []
92 | yolo_bboxes = []
93 | # gt bbox
94 | H = imgs.size(2)
95 | W = imgs.size(3)
96 | for j in range(nlabel[i]):
97 | _, xc, yc, w, h = label[j] # [class, xc, yc, w, h]
98 | xyxy_label = [(xc-w/2)*W, (yc-h/2)*H, (xc+w/2)*W, (yc+h/2)*H] # [x1,y1,x2,y2]
99 | gt_bboxes.append([i,xyxy_label])
100 |
101 | if outputs is not None:  # per-image detections (indexing outputss[0] here would always reuse image 0's results)
102 | outputs = outputs.cpu().data
103 | for output in outputs:
104 | x1 = output[0].data.item()
105 | y1 = float(output[1].data.item())
106 | x2 = float(output[2].data.item())
107 | y2 = float(output[3].data.item())
108 | score = float(output[4].data.item() * output[5].data.item())
109 | label = int(output[6].data.item())
110 | yolo_bbox = [i,[x1,y1,x2,y2], score, label, False,
111 | int(output[7].data.item()),int(output[8].data.item()), int(output[9].data.item()), int(output[10].data.item())]#layer_num, anchor_num, x,y
112 | yolo_bboxes.append(yolo_bbox)
113 |
114 | # judge TP or FP
115 | # score sort
116 | yolo_bboxes = sorted(yolo_bboxes, key=lambda x: x[2])
117 |
118 | for k in range(len(yolo_bboxes)):
119 | a = None
120 | t = 0
121 | for gt_bbox in gt_bboxes:
122 | iou = np_bboxes_iou(np.array(yolo_bboxes[k][1]), np.array(gt_bbox[1]).reshape(1,4))
123 | if iou > max(0.5, t):
124 | a = gt_bbox
125 | t = iou
126 | if a != None:
127 | gt_bboxes_tp[i] = a[1]
128 | gt_bboxes.remove(a)
129 | yolo_bboxes[k][4] = True
130 | output_id_tp[i] = [yolo_bboxes[k][5], yolo_bboxes[k][6], yolo_bboxes[k][7], yolo_bboxes[k][8]]
131 | break
132 |
133 | def yolo_wrapper(inp, output_id):
134 | layer_num, anchor_num, x, y = output_id
135 | output = model(inp, shap=True)
136 | return output[layer_num][:,anchor_num,y,x]
137 | if target_y == 'obj':
138 | target_y = 4
139 | elif target_y == 'cls':
140 | target_y = 5
141 |
142 | num_no_tp = 0
143 | inside_bb_sum = 0
144 | outside_bb_sum = 0
145 |
146 | for gt_bbox, output_id, img in zip(gt_bboxes_tp, output_id_tp, imgs):
147 | if gt_bbox == None:
148 | num_no_tp += 1
149 | continue
150 |
151 | img = img.reshape(1,3,img.size(1),img.size(2))
152 |
153 | baselines = img * 0
154 | with torch.no_grad():
155 | gs = GradientShap(yolo_wrapper, multiply_by_inputs=multiply_by_inputs)
156 | attr = gs.attribute(img, additional_forward_args=output_id,
157 | n_samples=n_samples, stdevs=stdevs,
158 | baselines=baselines, target=target_y,
159 | return_convergence_delta=False)
160 | attr = np.transpose(attr.squeeze().cpu().detach().numpy(), (1,2,0))
161 | in_pos, in_neg, out_pos, out_neg = bb_mask([gt_bbox], attr, img_H=imgs.size(2), img_W=imgs.size(3))
162 | inside_bb_sum -= in_neg
163 | outside_bb_sum += out_pos
164 | loss = alpha * inside_bb_sum + beta * outside_bb_sum
165 | if len(imgs)>num_no_tp:
166 | loss = loss*len(imgs)/(len(imgs) - num_no_tp)
167 | return torch.tensor(loss, device="cuda:0",dtype=torch.float)
--------------------------------------------------------------------------------
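The per-pixel Python loops in `bb_mask` scale with H·W; below is a NumPy-vectorized sketch (not part of the original code) that returns the same four normalized sums, assuming the boxes do not overlap, which holds for the single TP box passed by `shaploss`:

```python
import numpy as np

def bb_mask_vectorized(bboxes, attr, img_H=608, img_W=608, std=True):
    """Vectorized equivalent of bb_mask() above (illustrative sketch)."""
    pixel_attr = attr.sum(axis=2)                              # sum attributions over channels
    if std:
        pixel_attr = (pixel_attr - pixel_attr.mean()) / pixel_attr.std()
    mask = np.zeros((img_H, img_W), dtype=bool)
    for x1, y1, x2, y2 in bboxes:                              # mark pixels inside the ground-truth boxes
        mask[int(y1):int(y2), int(x1):int(x2)] = True
    in_area = int(mask.sum())
    out_area = img_H * img_W - in_area
    pos = np.where(pixel_attr >= 0, pixel_attr, 0.0)           # positive attributions
    neg = np.where(pixel_attr < 0, pixel_attr, 0.0)            # negative attributions
    in_pos, in_neg = pos[mask].sum(), neg[mask].sum()
    out_pos, out_neg = pos[~mask].sum(), neg[~mask].sum()
    if in_area > 0:
        in_pos, in_neg = in_pos / in_area, in_neg / in_area
    if out_area > 0:
        out_pos, out_neg = out_pos / out_area, out_neg / out_area
    return in_pos, in_neg, out_pos, out_neg
```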
/models/yolo_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from utils.utils import bboxes_iou
5 |
6 |
7 | class YOLOLayer(nn.Module):
8 | """
9 | detection layer corresponding to yolo_layer.c of darknet
10 | """
11 | def __init__(self, config_model, layer_no, in_ch, ignore_thre=0.7):
12 | """
13 | Args:
14 | config_model (dict) : model configuration.
15 | ANCHORS (list of tuples) :
16 | ANCH_MASK: (list of int list): index indicating the anchors to be
17 | used in YOLO layers. One of the mask group is picked from the list.
18 | N_CLASSES (int): number of classes
19 | layer_no (int): YOLO layer number - one from (0, 1, 2).
20 | in_ch (int): number of input channels.
21 | ignore_thre (float): threshold of IoU above which objectness training is ignored.
22 | """
23 |
24 | super(YOLOLayer, self).__init__()
25 | strides = [32, 16, 8] # fixed
26 | self.anchors = config_model['ANCHORS']
27 | self.anch_mask = config_model['ANCH_MASK'][layer_no]
28 | self.n_anchors = len(self.anch_mask)
29 | self.n_classes = config_model['N_CLASSES']
30 | self.ignore_thre = ignore_thre
31 | self.l2_loss = nn.MSELoss(size_average=False)
32 | self.bce_loss = nn.BCELoss(size_average=False)
33 | self.stride = strides[layer_no]
34 | self.all_anchors_grid = [(w / self.stride, h / self.stride)
35 | for w, h in self.anchors]
36 | self.masked_anchors = [self.all_anchors_grid[i]
37 | for i in self.anch_mask]
38 | self.ref_anchors = np.zeros((len(self.all_anchors_grid), 4))
39 | self.ref_anchors[:, 2:] = np.array(self.all_anchors_grid)
40 | self.ref_anchors = torch.FloatTensor(self.ref_anchors)
41 | self.conv = nn.Conv2d(in_channels=in_ch,
42 | out_channels=self.n_anchors * (self.n_classes + 5),
43 | kernel_size=1, stride=1, padding=0)
44 | self.layer_no = layer_no
45 | def forward(self, xin, labels=None, shap=False):
46 | """
47 | Forward path of the YOLO layer. When ``labels`` is None, decoded predictions are returned; when ``shap`` is True, the sigmoid-activated raw output is returned for attribution; otherwise the training losses are computed.
48 | Args:
49 | xin (torch.Tensor): input feature map whose size is :math:`(N, C, H, W)`, \
50 | where N, C, H, W denote batchsize, number of channels, height and width, respectively.
51 | labels (torch.Tensor): label data whose size is :math:`(N, K, 5)`. \
52 | N and K denote batchsize and number of labels.
53 | Each label consists of [class, xc, yc, w, h]:
54 | class (float): class index.
55 | xc, yc (float) : center of bbox whose values range from 0 to 1.
56 | w, h (float) : size of bbox whose values range from 0 to 1.
57 | Returns:
58 | loss (torch.Tensor): total loss - the target of backprop.
59 | loss_xy (torch.Tensor): x, y loss - calculated by binary cross entropy (BCE) \
60 | with boxsize-dependent weights.
61 | loss_wh (torch.Tensor): w, h loss - calculated by l2 without size averaging and \
62 | with boxsize-dependent weights.
63 | loss_obj (torch.Tensor): objectness loss - calculated by BCE.
64 | loss_cls (torch.Tensor): classification loss - calculated by BCE for each class.
65 | loss_l2 (torch.Tensor): total l2 loss - only for logging.
66 | """
67 | output = self.conv(xin)
68 |
69 | batchsize = output.shape[0]
70 | fsize = output.shape[2]
71 | n_ch = 5 + self.n_classes
72 | dtype = torch.cuda.FloatTensor if xin.is_cuda else torch.FloatTensor
73 |
74 | output = output.reshape(batchsize, self.n_anchors, n_ch, fsize, fsize)
75 | output = output.permute(0, 1, 3, 4, 2) # .contiguous()
76 |
77 | # logistic activation for xy, obj, cls
78 | output[..., np.r_[:2, 4:n_ch]] = torch.sigmoid(
79 | output[..., np.r_[:2, 4:n_ch]])
80 |
81 | if shap:
82 | return output
83 |
84 | # calculate pred - xywh obj cls
85 |
86 | to_xshift = np.broadcast_to(
87 | np.arange(fsize, dtype=np.float32), output.shape[:4])
88 | to_yshift = np.broadcast_to(
89 | np.arange(fsize, dtype=np.float32).reshape(fsize, 1), output.shape[:4])
90 | to_xshift.flags.writeable = True
91 | to_yshift.flags.writeable = True
92 |
93 | x_shift = dtype(to_xshift)
94 | y_shift = dtype(to_yshift)
95 |
96 |
97 | masked_anchors = np.array(self.masked_anchors)
98 |
99 | to_wanchors = np.broadcast_to(np.reshape(
100 | masked_anchors[:, 0], (1, self.n_anchors, 1, 1)), output.shape[:4])
101 | to_hanchors = np.broadcast_to(np.reshape(
102 | masked_anchors[:, 1], (1, self.n_anchors, 1, 1)), output.shape[:4])
103 | to_wanchors.flags.writeable = True
104 | to_hanchors.flags.writeable = True
105 |
106 | w_anchors = dtype(to_wanchors)
107 | h_anchors = dtype(to_hanchors)
108 |
109 | if labels is None: # inference mode
110 | anchor_num = np.broadcast_to(
111 | np.arange(self.n_anchors, dtype=np.float32).reshape(1, self.n_anchors, 1, 1), output.shape[:4])
112 | anchor_num.flags.writeable = True
113 | anchor_num = dtype(anchor_num)
114 |
115 | output = torch.cat([output,torch.full((batchsize, self.n_anchors, fsize, fsize,1),
116 | fill_value=float(self.layer_no), device=torch.device('cuda'))],dim=4)
117 | output = torch.cat([output,anchor_num.reshape(batchsize, self.n_anchors, fsize, fsize,1)],dim=4)
118 | output = torch.cat([output,x_shift.reshape(batchsize, self.n_anchors, fsize, fsize,1)],dim=4)
119 | output = torch.cat([output,y_shift.reshape(batchsize, self.n_anchors, fsize, fsize,1)],dim=4)
120 |
121 | pred = output.clone()
122 | pred[..., 0] += x_shift
123 | pred[..., 1] += y_shift
124 | pred[..., 2] = torch.exp(pred[..., 2]) * w_anchors
125 | pred[..., 3] = torch.exp(pred[..., 3]) * h_anchors
126 |
127 | if labels is None: # inference mode
128 | pred[..., :4] *= self.stride
129 | return pred.reshape(batchsize, -1, n_ch+4).data
130 |
131 | pred = pred[..., :4].data
132 |
133 | # target assignment
134 |
135 | tgt_mask = torch.zeros(batchsize, self.n_anchors,
136 | fsize, fsize, 4 + self.n_classes).type(dtype)
137 | obj_mask = torch.ones(batchsize, self.n_anchors,
138 | fsize, fsize).type(dtype)
139 | tgt_scale = torch.zeros(batchsize, self.n_anchors,
140 | fsize, fsize, 2).type(dtype)
141 |
142 | target = torch.zeros(batchsize, self.n_anchors,
143 | fsize, fsize, n_ch).type(dtype)
144 |
145 | labels = labels.cpu().data
146 | nlabel = (labels.sum(dim=2) > 0).sum(dim=1) # number of objects
147 |
148 | truth_x_all = labels[:, :, 1] * fsize
149 | truth_y_all = labels[:, :, 2] * fsize
150 | truth_w_all = labels[:, :, 3] * fsize
151 | truth_h_all = labels[:, :, 4] * fsize
152 | truth_i_all = truth_x_all.to(torch.int16).numpy()
153 | truth_j_all = truth_y_all.to(torch.int16).numpy()
154 |
155 | for b in range(batchsize):
156 | n = int(nlabel[b])
157 | if n == 0:
158 | continue
159 | truth_box = dtype(np.zeros((n, 4)))
160 | truth_box[:n, 2] = truth_w_all[b, :n]
161 | truth_box[:n, 3] = truth_h_all[b, :n]
162 | truth_i = truth_i_all[b, :n]
163 | truth_j = truth_j_all[b, :n]
164 |
165 | # calculate iou between truth and reference anchors
166 | anchor_ious_all = bboxes_iou(truth_box.cpu(), self.ref_anchors)
167 | best_n_all = np.argmax(anchor_ious_all, axis=1)
168 | best_n = best_n_all % 3
169 | best_n_mask = ((best_n_all == self.anch_mask[0]) | (
170 | best_n_all == self.anch_mask[1]) | (best_n_all == self.anch_mask[2]))
171 |
172 | truth_box[:n, 0] = truth_x_all[b, :n]
173 | truth_box[:n, 1] = truth_y_all[b, :n]
174 |
175 | pred_ious = bboxes_iou(
176 | pred[b].reshape(-1, 4), truth_box, xyxy=False)
177 | pred_best_iou, _ = pred_ious.max(dim=1)
178 | pred_best_iou = (pred_best_iou > self.ignore_thre)
179 | pred_best_iou = pred_best_iou.reshape(pred[b].shape[:3])
180 | # set mask to zero (ignore) if pred matches truth
181 | obj_mask[b] = 1 - pred_best_iou.int()
182 |
183 | if sum(best_n_mask) == 0:
184 | continue
185 |
186 | for ti in range(best_n.shape[0]):
187 | if best_n_mask[ti] == 1:
188 | i, j = truth_i[ti], truth_j[ti]
189 | a = best_n[ti]
190 | obj_mask[b, a, j, i] = 1
191 | tgt_mask[b, a, j, i, :] = 1
192 | target[b, a, j, i, 0] = truth_x_all[b, ti] - \
193 | truth_x_all[b, ti].to(torch.int16).to(torch.float)
194 | target[b, a, j, i, 1] = truth_y_all[b, ti] - \
195 | truth_y_all[b, ti].to(torch.int16).to(torch.float)
196 | target[b, a, j, i, 2] = torch.log(
197 | truth_w_all[b, ti] / torch.Tensor(self.masked_anchors)[best_n[ti], 0] + 1e-16)
198 | target[b, a, j, i, 3] = torch.log(
199 | truth_h_all[b, ti] / torch.Tensor(self.masked_anchors)[best_n[ti], 1] + 1e-16)
200 | target[b, a, j, i, 4] = 1
201 | target[b, a, j, i, 5 + labels[b, ti,
202 | 0].to(torch.int16).numpy()] = 1
203 | tgt_scale[b, a, j, i, :] = torch.sqrt(
204 | 2 - truth_w_all[b, ti] * truth_h_all[b, ti] / fsize / fsize)
205 |
206 | # loss calculation
207 |
208 | output[..., 4] *= obj_mask
209 | output[..., np.r_[0:4, 5:n_ch]] *= tgt_mask
210 | output[..., 2:4] *= tgt_scale
211 |
212 | target[..., 4] *= obj_mask
213 | target[..., np.r_[0:4, 5:n_ch]] *= tgt_mask
214 | target[..., 2:4] *= tgt_scale
215 |
216 | bceloss = nn.BCELoss(weight=tgt_scale*tgt_scale,
217 | size_average=False) # weighted BCEloss
218 | loss_xy = bceloss(output[..., :2], target[..., :2])
219 | loss_wh = self.l2_loss(output[..., 2:4], target[..., 2:4]) / 2
220 | loss_obj = self.bce_loss(output[..., 4], target[..., 4])
221 | loss_cls = self.bce_loss(output[..., 5:], target[..., 5:])
222 | loss_l2 = self.l2_loss(output, target)
223 |
224 | loss = loss_xy + loss_wh + loss_obj + loss_cls
225 |
226 | return loss, loss_xy, loss_wh, loss_obj, loss_cls, loss_l2
227 |
--------------------------------------------------------------------------------
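For reference, the inference-mode decoding above is the standard YOLOv3 transform applied per grid cell and anchor; a small sketch with illustrative names (not functions from this repo):

```python
import numpy as np

def decode_cell(tx, ty, tw, th, grid_x, grid_y, anchor_w, anchor_h, stride):
    """Mirrors YOLOLayer.forward's inference path for one cell/anchor.
    anchor_w/anchor_h are in grid units (pixel anchor size / stride), as in masked_anchors."""
    sigmoid = lambda v: 1.0 / (1.0 + np.exp(-v))
    bx = (sigmoid(tx) + grid_x) * stride       # box center x in input-image pixels
    by = (sigmoid(ty) + grid_y) * stride       # box center y
    bw = anchor_w * np.exp(tw) * stride        # box width
    bh = anchor_h * np.exp(th) * stride        # box height
    return bx, by, bw, bh
```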
/models/yolov3_shap.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 | def add_conv(in_ch, out_ch, ksize, stride):
6 | """
7 | Add a conv2d / batchnorm / leaky ReLU block.
8 | Args:
9 | in_ch (int): number of input channels of the convolution layer.
10 | out_ch (int): number of output channels of the convolution layer.
11 | ksize (int): kernel size of the convolution layer.
12 | stride (int): stride of the convolution layer.
13 | Returns:
14 | stage (Sequential) : Sequential layers composing a convolution block.
15 | """
16 | stage = nn.Sequential()
17 | pad = (ksize - 1) // 2
18 | stage.add_module('conv', nn.Conv2d(in_channels=in_ch,
19 | out_channels=out_ch, kernel_size=ksize, stride=stride,
20 | padding=pad, bias=False))
21 | stage.add_module('batch_norm', nn.BatchNorm2d(out_ch))
22 | stage.add_module('leaky', nn.LeakyReLU(0.1))
23 | return stage
24 |
25 |
26 | class resblock(nn.Module):
27 | """
28 | Sequential residual blocks each of which consists of \
29 | two convolution layers.
30 | Args:
31 | ch (int): number of input and output channels.
32 | nblocks (int): number of residual blocks.
33 | shortcut (bool): if True, residual tensor addition is enabled.
34 | """
35 | def __init__(self, ch, nblocks=1, shortcut=True):
36 | super().__init__()
37 | self.shortcut = shortcut
38 | self.module_list = nn.ModuleList()
39 | for i in range(nblocks):
40 | resblock_one = nn.ModuleList()
41 | resblock_one.append(add_conv(ch, ch//2, 1, 1))
42 | resblock_one.append(add_conv(ch//2, ch, 3, 1))
43 | self.module_list.append(resblock_one)
44 |
45 | def forward(self, x):
46 | for module in self.module_list:
47 | h = x
48 | for res in module:
49 | h = res(h)
50 | x = x + h if self.shortcut else h
51 | return x
52 |
53 |
54 |
55 | class YOLOLayer(nn.Module):
56 | """
57 | Detection Layer
58 | """
59 | def __init__(self, in_ch, n_anchors, n_classes):
60 | super(YOLOLayer, self).__init__()
61 | self.n_anchors = n_anchors
62 | self.n_classes = n_classes
63 | self.conv = nn.Conv2d(in_channels=in_ch,
64 | out_channels=self.n_anchors * (self.n_classes + 5),
65 | kernel_size=1, stride=1, padding=0)
66 |
67 | def forward(self, x, targets=None):
68 | output = self.conv(x)
69 | batchsize = output.shape[0]
70 | fsize = output.shape[2]
71 | n_ch = 5 + self.n_classes
72 | dtype = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor
73 |
74 | output = output.view(batchsize, self.n_anchors, n_ch, fsize, fsize)
75 | output = output.permute(0, 1, 3, 4, 2) # .contiguous()
76 |
77 | # logistic activation for xy, obj, cls
78 | output[..., np.r_[:2, 4:n_ch]] = torch.sigmoid(
79 | output[..., np.r_[:2, 4:n_ch]])
80 |
81 | return output
82 |
83 |
84 |
85 |
86 | class YOLOv3SHAP(nn.Module):
87 | """
88 | YOLOv3 model module for calculating SHAP
89 | """
90 | def __init__(self, n_classes):
91 | """
92 | Initialization of YOLOv3 class.
93 | """
94 | super(YOLOv3SHAP, self).__init__()
95 | self.n_classes = n_classes
96 | self.module_list = nn.ModuleList()
97 | # DarkNet53
98 | self.module_list.append(add_conv(in_ch=3, out_ch=32, ksize=3, stride=1))
99 | self.module_list.append(add_conv(in_ch=32, out_ch=64, ksize=3, stride=2))
100 | self.module_list.append(resblock(ch=64))
101 | self.module_list.append(add_conv(in_ch=64, out_ch=128, ksize=3, stride=2))
102 | self.module_list.append(resblock(ch=128, nblocks=2))
103 | self.module_list.append(add_conv(in_ch=128, out_ch=256, ksize=3, stride=2))
104 | self.module_list.append(resblock(ch=256, nblocks=8)) # shortcut 1 from here
105 | self.module_list.append(add_conv(in_ch=256, out_ch=512, ksize=3, stride=2))
106 | self.module_list.append(resblock(ch=512, nblocks=8)) # shortcut 2 from here
107 | self.module_list.append(add_conv(in_ch=512, out_ch=1024, ksize=3, stride=2))
108 | self.module_list.append(resblock(ch=1024, nblocks=4))
109 |
110 | # YOLOv3
111 | self.module_list.append(resblock(ch=1024, nblocks=2, shortcut=False))
112 | self.module_list.append(add_conv(in_ch=1024, out_ch=512, ksize=1, stride=1))
113 | # 1st yolo branch
114 | self.module_list.append(add_conv(in_ch=512, out_ch=1024, ksize=3, stride=1))
115 | self.module_list.append(
116 | YOLOLayer(in_ch=1024, n_anchors=3, n_classes=self.n_classes))
117 |
118 | self.module_list.append(add_conv(in_ch=512, out_ch=256, ksize=1, stride=1))
119 | self.module_list.append(nn.Upsample(scale_factor=2, mode='nearest'))
120 | self.module_list.append(add_conv(in_ch=768, out_ch=256, ksize=1, stride=1))
121 | self.module_list.append(add_conv(in_ch=256, out_ch=512, ksize=3, stride=1))
122 | self.module_list.append(resblock(ch=512, nblocks=1, shortcut=False))
123 | self.module_list.append(add_conv(in_ch=512, out_ch=256, ksize=1, stride=1))
124 | # 2nd yolo branch
125 | self.module_list.append(add_conv(in_ch=256, out_ch=512, ksize=3, stride=1))
126 | self.module_list.append(
127 | YOLOLayer(in_ch=512, n_anchors=3, n_classes=self.n_classes))
128 | self.module_list.append(add_conv(in_ch=256, out_ch=128, ksize=1, stride=1))
129 | self.module_list.append(nn.Upsample(scale_factor=2, mode='nearest'))
130 | self.module_list.append(add_conv(in_ch=384, out_ch=128, ksize=1, stride=1))
131 | self.module_list.append(add_conv(in_ch=128, out_ch=256, ksize=3, stride=1))
132 | self.module_list.append(resblock(ch=256, nblocks=2, shortcut=False))
133 | self.module_list.append(
134 | YOLOLayer(in_ch=256, n_anchors=3, n_classes=self.n_classes))
135 |
136 | def forward(self, x):
137 | """
138 | Forward path of YOLOv3.
139 | Args:
140 | x (torch.Tensor) : input data whose shape is :math:`(N, C, H, W)`, \
141 | where N, C are batchsize and num. of channels.
142 | 
143 | Returns:
144 | output (list of torch.Tensor): raw outputs of the three YOLO branches,
145 | each of shape :math:`(N, n_anchors, fsize, fsize, 5 + n_classes)`,
146 | ordered from the coarsest grid (stride 32) to the finest (stride 8),
147 | as consumed by the GradientShap wrapper in models/shap_loss_1.py and test.ipynb.
148 | 
149 | """
150 | output = []
151 | route_layers = []
152 | for i, module in enumerate(self.module_list):
153 | # yolo layers
154 | if i in [14, 22, 28]:
155 | x = module(x)
156 | output.append(x)
157 | else:
158 | x = module(x)
159 |
160 | # route layers = shortcut
161 | if i in [6, 8, 12, 20]:
162 | route_layers.append(x)
163 | if i == 14:
164 | x = route_layers[2]
165 | if i == 22: # yolo 2nd
166 | x = route_layers[3]
167 | if i == 16:
168 | x = torch.cat((x, route_layers[1]), 1)
169 | if i == 24:
170 | x = torch.cat((x, route_layers[0]), 1)
171 |
172 | return output
--------------------------------------------------------------------------------
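A quick sanity-check sketch of what `YOLOv3SHAP.forward` returns for a 416×416 input; the shapes follow from the fixed strides 32/16/8 (run from the repository root so `models` is importable):

```python
import torch
from models.yolov3_shap import YOLOv3SHAP

model = YOLOv3SHAP(n_classes=1).eval()
x = torch.zeros(1, 3, 416, 416)
with torch.no_grad():
    outs = model(x)
for o in outs:
    print(tuple(o.shape))
# expected: (1, 3, 13, 13, 6), (1, 3, 26, 26, 6), (1, 3, 52, 52, 6)
# i.e. (N, n_anchors, fsize, fsize, 5 + n_classes) per YOLO branch
```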
/requirements/download_weights.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | wget https://pjreddie.com/media/files/yolov3.weights
4 | wget https://pjreddie.com/media/files/darknet53.conv.74
--------------------------------------------------------------------------------
/requirements/getcoco.sh:
--------------------------------------------------------------------------------
1 | mkdir COCO
2 | cd COCO
3 |
4 | wget http://images.cocodataset.org/zips/train2017.zip
5 | wget http://images.cocodataset.org/zips/val2017.zip
6 | wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
7 |
8 | unzip train2017.zip
9 | unzip val2017.zip
10 | unzip annotations_trainval2017.zip
11 |
12 | rm -f train2017.zip
13 | rm -f val2017.zip
14 | rm -f annotations_trainval2017.zip
15 |
--------------------------------------------------------------------------------
/requirements/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.0.0
2 | numpy==1.15.2
3 | matplotlib==3.0.2
4 | opencv_python==3.4.4.19
5 | tensorboardX==1.4
6 | PyYAML>=4.2b1
7 | pycocotools==2.0.0
8 |
9 |
--------------------------------------------------------------------------------
/test.ipynb:
--------------------------------------------------------------------------------
1 | {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"test.ipynb","provenance":[],"collapsed_sections":[],"authorship_tag":"ABX9TyMXZB5s0UXMxzx0femqD6Fc"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","source":["import yaml\n","\n","import cv2\n","import torch\n","from torch.autograd import Variable\n","\n","from models.yolov3 import *\n","from utils.utils import *\n","from utils.parse_yolo_weights import parse_yolo_weights\n","\n","from captum.attr import GradientShap\n","\n","import os\n","import glob\n","\n","from utils.vis_bbox import vis_bbox\n","import matplotlib.pyplot as plt\n","\n","import json"],"metadata":{"id":"fmuZNzVD-U2-"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Visualization\n","You need to input the image, the information about your model, and the output index."],"metadata":{"id":"Yk1vsi8Bz7Sl"}},{"cell_type":"code","source":["cfg = '' #config file path\n","gpu = 0\n","image = '' #image file path\n","baselines = None\n","target_y == 'cls'\n","output_id = [0,0,0,0] #output index (layer anchor, x, y)\n","weights_path = ''\n","ckpt = ''\n","\n","with open(cfg, 'r') as f:\n"," cfg = yaml.load(f)\n","imgsize = cfg['TEST']['IMGSIZE']\n","model = YOLOv3(cfg['MODEL'])\n","num_classes = cfg['MODEL']['N_CLASSES']\n","\n","confthre = cfg['TEST']['CONFTHRE']\n","nmsthre = cfg['TEST']['NMSTHRE']\n","if gpu >= 0:\n"," model.cuda(gpu) \n","\n","\n","assert weights_path or ckpt, 'One of --weights_path and --ckpt must be specified'\n","\n","if weights_path:\n"," print(\"loading yolo weights %s\" % (weights_path))\n"," parse_yolo_weights(model, weights_path)\n","elif ckpt:\n"," print(\"loading checkpoint %s\" % (ckpt))\n"," state = torch.load(ckpt)\n"," if 'model_state_dict' in state.keys():\n"," model.load_state_dict(state['model_state_dict'])\n"," else:\n"," model.load_state_dict(state)\n","\n","model.eval()\n","\n","img = cv2.imread(image_path)\n","img_raw = img.copy()[:, :, ::-1].transpose((2, 0, 1))\n","img, info_img = preprocess(img, imgsize, jitter=0) # info = (h, w, nh, nw, dx, dy)\n","img = np.transpose(img / 255., (2, 0, 1))\n","img = torch.from_numpy(img).float().unsqueeze(0)\n","\n","if gpu >= 0:\n"," img = Variable(img.type(torch.cuda.FloatTensor))\n","else:\n"," img = Variable(img.type(torch.FloatTensor))\n","\n","if baselines==None:\n"," baselines = img * 0 \n"," \n","if gpu >= 0:\n"," img = Variable(img.type(torch.cuda.FloatTensor))\n"," baselines = Variable(baselines.type(torch.cuda.FloatTensor))\n","else:\n"," img = Variable(img.type(torch.FloatTensor))\n"," baselines = Variable(baselines.type(torch.FloatTensor))\n","\n","if target_y == 'obj':\n"," target_y = 4\n","elif target_y == 'cls':\n"," target_y = 5\n","\n","\n","# setting wrapper\n","def yolo_wrapper(inp, output_id):\n"," layer_num, anchor_num, x, y = output_id\n"," output = model(inp, shap=True)\n"," return output[layer_num][:,anchor_num,y,x]\n"," \n","with torch.no_grad():\n"," gs = GradientShap(yolo_wrapper,multiply_by_inputs=multiply_by_inputs)\n"," \n","with torch.no_grad():\n"," attr, delta = gs.attribute(img, additional_forward_args=output_id, n_samples=n_samples, stdevs=stdevs, baselines=baselines, target=target_y, return_convergence_delta=True)\n","# postprocessing of attribution\n","attr = np.transpose(attr.squeeze().cpu().detach().numpy(), (1,2,0))\n","original_image = np.transpose(img.squeeze().cpu().detach().numpy(), (1,2,0))\n","# visualization of attribution\n","pos_fig, pos_axis 
= viz.visualize_image_attr(attr,\n"," original_image,\n"," \"heat_map\",\n"," \"positive\",\n"," cmap=\"Reds\",\n"," show_colorbar=True,\n"," fig_size=(8,6))\n","neg_fig, neg_axis = viz.visualize_image_attr(attr,\n"," original_image,\n"," \"heat_map\",\n"," \"positive\",\n"," cmap=\"Blues\",\n"," show_colorbar=True,\n"," fig_size=(8,6))\n"],"metadata":{"id":"wcXejiMB0BS5"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Evaluation/Data Selection with SHAP"],"metadata":{"id":"Wyti2QWG4Tdv"}},{"cell_type":"code","source":["cfg = '' #config file path\n","gpu = 0\n","image = '' #image file path\n","baselines = None\n","target_y == 'cls'\n","output_id = [0,0,0,0] #output index (layer anchor, x, y)\n","weights_path = ''\n","ckpt = ''\n","bboxes = [] #bounding box index list\n","\n","with open(cfg, 'r') as f:\n"," cfg = yaml.load(f)\n","imgsize = cfg['TEST']['IMGSIZE']\n","model = YOLOv3(cfg['MODEL'])\n","num_classes = cfg['MODEL']['N_CLASSES']\n","\n","confthre = cfg['TEST']['CONFTHRE']\n","nmsthre = cfg['TEST']['NMSTHRE']\n","if gpu >= 0:\n"," model.cuda(gpu) \n","\n","\n","assert weights_path or ckpt, 'One of --weights_path and --ckpt must be specified'\n","\n","if weights_path:\n"," print(\"loading yolo weights %s\" % (weights_path))\n"," parse_yolo_weights(model, weights_path)\n","elif ckpt:\n"," print(\"loading checkpoint %s\" % (ckpt))\n"," state = torch.load(ckpt)\n"," if 'model_state_dict' in state.keys():\n"," model.load_state_dict(state['model_state_dict'])\n"," else:\n"," model.load_state_dict(state)\n","\n","model.eval()\n","\n","img = cv2.imread(image_path)\n","img_raw = img.copy()[:, :, ::-1].transpose((2, 0, 1))\n","img, info_img = preprocess(img, imgsize, jitter=0) # info = (h, w, nh, nw, dx, dy)\n","img = np.transpose(img / 255., (2, 0, 1))\n","img = torch.from_numpy(img).float().unsqueeze(0)\n","\n","if gpu >= 0:\n"," img = Variable(img.type(torch.cuda.FloatTensor))\n","else:\n"," img = Variable(img.type(torch.FloatTensor))\n","\n","if baselines==None:\n"," baselines = img * 0 \n"," \n","if gpu >= 0:\n"," img = Variable(img.type(torch.cuda.FloatTensor))\n"," baselines = Variable(baselines.type(torch.cuda.FloatTensor))\n","else:\n"," img = Variable(img.type(torch.FloatTensor))\n"," baselines = Variable(baselines.type(torch.FloatTensor))\n","\n","if target_y == 'obj':\n"," target_y = 4\n","elif target_y == 'cls':\n"," target_y = 5\n","\n","\n","# setting wrapper\n","def yolo_wrapper(inp, output_id):\n"," layer_num, anchor_num, x, y = output_id\n"," output = model(inp, shap=True)\n"," return output[layer_num][:,anchor_num,y,x]\n"," \n","with torch.no_grad():\n"," gs = GradientShap(yolo_wrapper,multiply_by_inputs=multiply_by_inputs)\n"," \n","with torch.no_grad():\n"," attr, delta = gs.attribute(img, additional_forward_args=output_id, n_samples=n_samples, stdevs=stdevs, baselines=baselines, target=target_y, return_convergence_delta=True)\n","\n","img_H = 416\n","img_W = 416\n","\n","def zscore(x, axis = None):\n"," xmean = x.mean(axis=axis, keepdims=True)\n"," xstd = np.std(x, axis=axis, keepdims=True)\n"," zscore = (x-xmean)/xstd\n"," return zscore\n","\n","pixel_attr = np.sum(attr, axis=2)\n","pixel_attr = zscore(pixel_attr)\n","pixel_mask = np.zeros((img_H,img_W))\n","in_area = 0\n","for [x1,y1,x2,y2] in bboxes:\n"," for i in range(int(y1), int(y2)):\n"," for j in range(int(x1),int(x2)):\n"," pixel_mask[i,j] = 1\n"," in_area += 1\n","out_area = img_H*img_W - in_area\n","\n","in_pos = 0\n","in_neg = 0\n","out_pos = 0\n","out_neg = 0\n","l_in = 
[]\n","for i in range(img_H):\n"," for j in range(img_W):\n"," if pixel_mask[i,j]>0:\n"," if pixel_attr[i,j]>=0:\n"," in_pos += pixel_attr[i,j]\n"," l_in.append(pixel_attr[i,j])\n"," else:\n"," in_neg += pixel_attr[i,j]\n"," \n"," else:\n"," if pixel_attr[i,j]>=0:\n"," out_pos += pixel_attr[i,j]\n"," else:\n"," out_neg += pixel_attr[i,j]\n"," \n","if in_area >0:\n"," in_pos = in_pos/in_area\n"," in_neg = in_neg/in_area\n","if out_area>0:\n"," out_pos = out_pos/out_area\n"," out_neg = out_neg/out_area\n"],"metadata":{"id":"uQiX1hTj4SYU"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Training with SHAP-Regularization\n","We used annotation json files which contain dict format data like\n","* {\"image_file_name\":\n"," * {\"regions\":\n"," * [ {\"class_id\": 0, \"bb\":[0,0,0,0]},..."],"metadata":{"id":"u8Sblu4L4icq"}},{"cell_type":"code","source":["!python train_vd_reg.py --cfg config/config.cfg --weights_path weights/darknet53.conv.74 --checkpoint_interval 100 --checkpoint_dir checkpoints --anno_file anno_data.json --shap_interval 10 --eval_interval 10"],"metadata":{"id":"QE-XT6a-4kkb"},"execution_count":null,"outputs":[]}]}
--------------------------------------------------------------------------------
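The visualization cell above has a few small slips: `target_y == 'cls'` where an assignment is intended, `image` vs. `image_path`, `viz`, `n_samples`, `stdevs` and `multiply_by_inputs` never defined, and the blue map is passed sign `"positive"`. A corrected sketch of just the attribution/visualization step, assuming `model` and `img` have already been prepared as in the cell (the concrete `output_id` and parameter values are illustrative):

```python
import numpy as np
from captum.attr import GradientShap
from captum.attr import visualization as viz      # needed for viz.visualize_image_attr

target_y = 'cls'                                  # assignment, not comparison
output_id = [0, 0, 0, 0]                          # [layer, anchor, grid_x, grid_y] of the detection to explain
n_samples, stdevs, multiply_by_inputs = 5, 0.1, True
target_y = 4 if target_y == 'obj' else 5

def yolo_wrapper(inp, output_id):
    layer_num, anchor_num, x, y = output_id
    return model(inp, shap=True)[layer_num][:, anchor_num, y, x]

baselines = img * 0
gs = GradientShap(yolo_wrapper, multiply_by_inputs=multiply_by_inputs)
attr = gs.attribute(img, additional_forward_args=output_id, n_samples=n_samples,
                    stdevs=stdevs, baselines=baselines, target=target_y)

attr_np = np.transpose(attr.squeeze().cpu().detach().numpy(), (1, 2, 0))
orig_np = np.transpose(img.squeeze().cpu().detach().numpy(), (1, 2, 0))
viz.visualize_image_attr(attr_np, orig_np, "heat_map", "positive", cmap="Reds", show_colorbar=True)
viz.visualize_image_attr(attr_np, orig_np, "heat_map", "negative", cmap="Blues", show_colorbar=True)
```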
/train_vd_reg.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 |
3 | from utils.utils import *
4 | from utils.vd_evaluator import VDEvaluator
5 | from utils.parse_yolo_weights import parse_yolo_weights
6 | from models.yolov3 import *
7 | from models.shap_loss_1 import *
8 | from dataset.dataset_vd import *
9 |
10 | import os
11 | import argparse
12 | import yaml
13 | import random
14 |
15 | import torch
16 | from torch.autograd import Variable
17 | import torch.optim as optim
18 |
19 | import numpy as np
20 | import pandas as pd
21 |
22 | import time
23 |
24 |
25 | def parse_args():
26 | parser = argparse.ArgumentParser()
27 | parser.add_argument('--cfg', type=str, default='config/yolov3_vd.cfg',
28 | help='config file. see readme')
29 | parser.add_argument('--weights_path', type=str,
30 | default=None, help='darknet weights file')
31 | parser.add_argument('--n_cpu', type=int, default=0,
32 | help='number of workers')
33 | parser.add_argument('--checkpoint_interval', type=int,
34 | default=1000, help='interval between saving checkpoints')
35 | parser.add_argument('--eval_interval', type=int,
36 | default=4000, help='interval between evaluations')
37 | parser.add_argument('--checkpoint', type=str,
38 | help='pytorch checkpoint file path')
39 | parser.add_argument('--checkpoint_dir', type=str,
40 | default='checkpoints',
41 | help='directory where checkpoint files are saved')
42 | parser.add_argument('--use_cuda', type=bool, default=True)
43 | parser.add_argument('--debug', action='store_true', default=False,
44 | help='debug mode where only one image is trained')
45 | parser.add_argument(
46 | '--tfboard', help='tensorboard path for logging', type=str, default=None)
47 | parser.add_argument('--anno_file', type=str,
48 | default='anno_data.json', help='annotation data json file name')
49 | parser.add_argument('--shap_interval', type=int,
50 | default=None, help='interval between updating shaploss')
51 | return parser.parse_args()
52 |
53 |
54 | def main():
55 | """
56 | SHAP-regularized YOLOv3 trainer.
57 |
58 | """
59 | args = parse_args()
60 | print("Setting Arguments.. : ", args)
61 |
62 | cuda = torch.cuda.is_available() and args.use_cuda
63 | os.makedirs(args.checkpoint_dir, exist_ok=True)
64 |
65 | # Parse config settings
66 | with open(args.cfg, 'r') as f:
67 | cfg = yaml.load(f)
68 |
69 | print("successfully loaded config file: ", cfg)
70 |
71 | momentum = cfg['TRAIN']['MOMENTUM']
72 | decay = cfg['TRAIN']['DECAY']
73 | burn_in = cfg['TRAIN']['BURN_IN']
74 | iter_size = cfg['TRAIN']['MAXITER']
75 | steps = eval(cfg['TRAIN']['STEPS'])
76 | batch_size = cfg['TRAIN']['BATCHSIZE']
77 | subdivision = cfg['TRAIN']['SUBDIVISION']
78 | ignore_thre = cfg['TRAIN']['IGNORETHRE']
79 | random_resize = cfg['AUGMENTATION']['RANDRESIZE']
80 | base_lr = cfg['TRAIN']['LR'] / batch_size / subdivision
81 | at_alpha = cfg['TRAIN']['ATTENTION_ALPHA']
82 | at_beta = cfg['TRAIN']['ATTENTION_BETA']
83 |
84 | print('effective_batch_size = batch_size * subdivision = %d * %d' %
85 | (batch_size, subdivision))
86 |
87 | # Learning rate setup
88 | def burnin_schedule(i):
89 | if i < burn_in:
90 | factor = pow(i / burn_in, 4)# pow(x, y):x^y
91 | elif i < steps[0]:
92 | factor = 1.0
93 | elif i < steps[1]:
94 | factor = 0.1
95 | else:
96 | factor = 0.01
97 | return factor
98 |
99 | # Initiate model
100 | model = YOLOv3(cfg['MODEL'], ignore_thre=ignore_thre)
101 |
102 | if args.weights_path:
103 | print("loading darknet weights....", args.weights_path)
104 | parse_yolo_weights(model, args.weights_path)
105 | elif args.checkpoint:
106 | print("loading pytorch ckpt...", args.checkpoint)
107 | state = torch.load(args.checkpoint)
108 | if 'model_state_dict' in state.keys():
109 | model.load_state_dict(state['model_state_dict'])
110 | else:
111 | model.load_state_dict(state)
112 |
113 | if cuda:
114 | print("using cuda")
115 | model = model.cuda()
116 |
117 | if args.tfboard:
118 | print("using tfboard")
119 | from tensorboardX import SummaryWriter
120 | tblogger = SummaryWriter(args.tfboard)
121 |
122 | model.train()
123 |
124 | imgsize = cfg['TRAIN']['IMGSIZE']
125 |
126 | dataset = ListDataset(model_type=cfg['MODEL']['TYPE'],
127 | data_dir=cfg['TRAIN']['TRAIN_DIR'],
128 | json_file=args.anno_file,
129 | img_size=imgsize,
130 | augmentation=cfg['AUGMENTATION'])
131 |
132 | dataloader = torch.utils.data.DataLoader(
133 | dataset,
134 | batch_size=batch_size,
135 | shuffle=True,
136 | num_workers=args.n_cpu)
137 |
138 | dataiterator = iter(dataloader)
139 |
140 | evaluator = VDEvaluator(data_dir=cfg['TRAIN']['VAL_DIR'],
141 | json_file=args.anno_file,
142 | img_size=cfg['TEST']['IMGSIZE'],
143 | confthre=cfg['TEST']['CONFTHRE'],
144 | nmsthre=cfg['TEST']['NMSTHRE'])
145 |
146 | dtype = torch.cuda.FloatTensor if cuda else torch.FloatTensor
147 |
148 | # optimizer setup
149 | # set weight decay only on conv.weight
150 | params_dict = dict(model.named_parameters())
151 | params = []
152 | for key, value in params_dict.items():
153 | if 'conv.weight' in key:
154 | params += [{'params':value, 'weight_decay':decay * batch_size * subdivision}]
155 | else:
156 | params += [{'params':value, 'weight_decay':0.0}]
157 | optimizer = optim.SGD(params, lr=base_lr, momentum=momentum,
158 | dampening=0, weight_decay=decay * batch_size * subdivision)
159 |
160 | iter_state = 0
161 |
162 | if args.checkpoint:
163 | if 'optimizer_state_dict' in state.keys():
164 | optimizer.load_state_dict(state['optimizer_state_dict'])
165 | iter_state = state['iter'] + 1
166 |
167 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, burnin_schedule)
168 |
169 | # start training loop
170 | log_col = ['time(min)', 'iter','lr','xy', 'wh',
171 | 'conf', 'cls', 'shap', 'l2', 'imgsize',
172 | 'ap50', 'precision50', 'recall50', 'F_measure']
173 | log = []
174 | ap50 = np.nan
175 | precision50 = np.nan
176 | recall50 = np.nan
177 | F_measure = np.nan
178 | shap_loss = torch.tensor(float('nan'),dtype=torch.float32)
179 | t_0 = time.time()
180 | for iter_i in range(iter_state, iter_size + 1):
181 |
182 | # VD evaluation
183 |
184 | if iter_i % args.eval_interval == 0 and iter_i > 0:
185 | ap50, precision50, recall50, F_measure = evaluator.evaluate(model)
186 | model.train()
187 | if args.tfboard:
188 | tblogger.add_scalar('val/COCOAP50', ap50, iter_i)
189 |
190 | print('[Iter {}/{}]:AP50:{}'.format(iter_i, iter_size,ap50))
191 |
192 | # subdivision loop
193 | optimizer.zero_grad()
194 | for inner_iter_i in range(subdivision):
195 | try:
196 | imgs, targets, _, _ = next(dataiterator) # load a batch
197 | except StopIteration:
198 | dataiterator = iter(dataloader)
199 | imgs, targets, _, _ = next(dataiterator) # load a batch
200 | imgs = Variable(imgs.type(dtype))
201 | targets = Variable(targets.type(dtype), requires_grad=False)
202 | loss = model(imgs, targets)
203 | loss_dict = model.loss_dict
204 | # adding SHAP-based loss
205 | if args.shap_interval is not None:
206 | if inner_iter_i % args.shap_interval == 0:
207 | shap_loss_ = shaploss(imgs, targets, model,
208 | num_classes=cfg['MODEL']['N_CLASSES'],
209 | confthre=cfg['TEST']['CONFTHRE'],
210 | nmsthre=cfg['TEST']['NMSTHRE'],
211 | n_samples=cfg['TRAIN']['N_SAMPLES'],
212 | alpha=at_alpha, beta=at_beta)
213 | if shap_loss_ != 0 and not torch.isnan(shap_loss_):  # only adopt the new SHAP loss when it is non-zero and not NaN
214 | shap_loss = shap_loss_
215 |
216 | model.train()
217 | loss += shap_loss
218 | loss.backward()
219 |
220 | optimizer.step()
221 | scheduler.step()
222 |
223 | if iter_i % 10 == 0:
224 | # logging
225 | current_lr = scheduler.get_lr()[0] * batch_size * subdivision
226 | t = (time.time() - t_0)//60
227 | print('[Time %d] [Iter %d/%d] [lr %f] '
228 | '[Losses: xy %f, wh %f, conf %f, cls %f, shap %f, l2 %f, imgsize %d, ap %f, precision %f, recall %f, F %f]'
229 | % (t, iter_i, iter_size, current_lr,
230 | loss_dict['xy'], loss_dict['wh'],
231 | loss_dict['conf'], loss_dict['cls'], shap_loss,
232 | loss_dict['l2'], imgsize, ap50, precision50, recall50, F_measure),
233 | flush=True)
234 | log.append([t, iter_i, current_lr,
235 | np.atleast_1d(loss_dict['xy'].to('cpu').detach().numpy().copy())[0],
236 | np.atleast_1d(loss_dict['wh'].to('cpu').detach().numpy().copy())[0],
237 | np.atleast_1d(loss_dict['conf'].to('cpu').detach().numpy().copy())[0],
238 | np.atleast_1d(loss_dict['cls'].to('cpu').detach().numpy().copy())[0],
239 | np.atleast_1d(shap_loss.to('cpu').detach().numpy().copy())[0],
240 | np.atleast_1d(loss_dict['l2'].to('cpu').detach().numpy().copy())[0],
241 | imgsize, ap50, precision50, recall50, F_measure])
242 | ap50 = np.nan
243 | precision50 = np.nan
244 | recall50 = np.nan
245 | F_measure = np.nan
246 |
247 | if args.tfboard:
248 | tblogger.add_scalar('train/total_loss', model.loss_dict['l2'], iter_i)
249 |
250 | # random resizing
251 | if random_resize:
252 | imgsize = (random.randint(0, 9) % 10 + 10) * 32
253 | dataset.img_shape = (imgsize, imgsize)
254 | dataset.img_size = imgsize
255 | dataloader = torch.utils.data.DataLoader(
256 | dataset, batch_size=batch_size, shuffle=True, num_workers=args.n_cpu)
257 | dataiterator = iter(dataloader)
258 |
259 | # save checkpoint
260 | #if iter_i > 0 and (iter_i % args.checkpoint_interval == 0):
261 | if (0