├── .gitignore
├── LICENSE
├── README.md
├── images
│   ├── 000001.jpg
│   ├── 000025.jpg
│   ├── 000028.jpg
│   ├── 000064.jpg
│   └── 000082.jpg
├── keras_frcnn
│   ├── FixedBatchNormalization.py
│   ├── RoiPoolingConv.py
│   ├── __init__.py
│   ├── config.py
│   ├── data_augment.py
│   ├── data_generators.py
│   ├── inception_resnet_v2.py
│   ├── losses.py
│   ├── pascal_voc_parser.py
│   ├── resnet.py
│   ├── roi_helpers.py
│   ├── simple_parser.py
│   ├── vgg.py
│   └── xception.py
├── measure_map.py
├── requirements.txt
├── results_imgs
│   ├── 0.png
│   ├── 1.png
│   ├── 2.png
│   ├── 3.png
│   └── 4.png
├── test_frcnn.py
├── train_frcnn.py
└── transfer
    ├── export_imagenet.py
    └── inception_resnet_v2.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.jpg
3 | *.png
4 | *.txt
5 | *.h5
6 | *.hdf5
7 | *.pickle
8 | *.swp
9 | logs/
10 | .DS_Store
11 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 YoungJin Kim
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Keras-FasterRCNN
2 | Keras implementation of Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks.
3 | cloned from [https://github.com/yhenon/keras-frcnn/](https://github.com/yhenon/keras-frcnn/)
4 |
5 | ## UPDATE:
6 | - added support for inception_resnet_v2
7 | - to use inception_resnet_v2 from keras.applications as the feature extractor, first create a new inception_resnet_v2 model file using transfer/export_imagenet.py
8 | - if you use the original inception_resnet_v2 model as the feature extractor, the weight parameters cannot be loaded into faster-rcnn
9 |
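For reference, the exported weight file is expected at `keras_frcnn/weights/inception_resnet_v2.h5` (this path comes from `get_weight_path()` in `keras_frcnn/inception_resnet_v2.py`). A minimal workflow, assuming `transfer/export_imagenet.py` needs no extra arguments, is to run `python transfer/export_imagenet.py` and then place the resulting `.h5` file at that location before training.
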
10 | ## USAGE:
11 | - Both theano and tensorflow backends are supported. However, compile times are very high with theano, so tensorflow is highly recommended.
12 | - `train_frcnn.py` can be used to train a model. To train on Pascal VOC data, simply do:
13 | `python train_frcnn.py -p /path/to/pascalvoc/`.
14 | - the Pascal VOC data set (images and annotations for bounding boxes around the classified objects) can be obtained from: http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
15 | - simple_parser.py provides an alternative way to input data, using a text file. Simply provide a text file, with each
16 | line containing:
17 |
18 | `filepath,x1,y1,x2,y2,class_name`
19 |
20 | For example:
21 |
22 | /data/imgs/img_001.jpg,837,346,981,456,cow
23 |
24 | /data/imgs/img_002.jpg,215,312,279,391,cat
25 |
26 | The classes will be inferred from the file. To use the simple parser instead of the default pascal voc style parser,
27 | use the command line option `-o simple`. For example `python train_frcnn.py -o simple -p my_data.txt`.
28 |
29 | - Running `train_frcnn.py` will write weights to disk as an hdf5 file, as well as all the settings of the training run to a `pickle` file. These
30 | settings can then be loaded by `test_frcnn.py` for any testing.
31 |
32 | - test_frcnn.py can be used to perform inference, given pretrained weights and a config file. Specify a path to the folder containing
33 | images:
34 | `python test_frcnn.py -p /path/to/test_data/`
35 | - Data augmentation can be applied by specifying `--hf` for horizontal flips, `--vf` for vertical flips and `--rot` for 90 degree rotations.
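  For example, to train with all three augmentations enabled (assuming these options are simple on/off switches):
  `python train_frcnn.py -p /path/to/pascalvoc/ --hf --vf --rot`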
36 |
37 |
38 |
39 | ## NOTES:
40 | - config.py contains all settings for the train or test run. The default settings match those in the original Faster-RCNN
41 | paper. The anchor box sizes are [128, 256, 512] and the ratios are [1:1, 1:2, 2:1].
42 | - The theano backend by default uses a 7x7 pooling region, instead of 14x14 as in the frcnn paper. This cuts down compiling time slightly.
43 | - The tensorflow backend performs a resize on the pooling region, instead of max pooling. This is much more efficient and has little impact on results.
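  Roughly, the tensorflow ROI pooling in `keras_frcnn/RoiPoolingConv.py` does the following for every ROI (simplified sketch of that layer, not a complete implementation):

      # crop the ROI from the feature map and resize it to pool_size x pool_size
      rs = tf.image.resize_images(img[:, y:y+h, x:x+w, :], (pool_size, pool_size))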
44 |
45 |
46 | ## Example output:
47 |
48 | 
49 | 
50 | 
51 | 
52 |
53 | ## ISSUES:
54 |
55 | - If you get this error:
56 | `ValueError: There is a negative shape in the graph!`
57 | then update keras to the newest version
58 |
59 | - Make sure to use `python2`, not `python3`. If you get this error:
60 | `TypeError: unorderable types: dict() < dict()`, you are using python3
61 |
62 | - If you run out of memory, try reducing the number of ROIs that are processed simultaneously. Try passing a lower `-n` to `train_frcnn.py`. Alternatively, try reducing the image size from the default value of 600 (this setting is found in `config.py`).
63 |
64 | ## Reference
65 | [1] [Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks, 2015](https://arxiv.org/pdf/1506.01497.pdf)
66 | [2] [Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning, 2016](https://arxiv.org/pdf/1602.07261.pdf)
67 | [3] [https://github.com/yhenon/keras-frcnn/](https://github.com/yhenon/keras-frcnn/)
68 |
--------------------------------------------------------------------------------
/images/000001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/you359/Keras-FasterRCNN/eb67ad5d946581344f614faa1e3ee7902f429ce3/images/000001.jpg
--------------------------------------------------------------------------------
/images/000025.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/you359/Keras-FasterRCNN/eb67ad5d946581344f614faa1e3ee7902f429ce3/images/000025.jpg
--------------------------------------------------------------------------------
/images/000028.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/you359/Keras-FasterRCNN/eb67ad5d946581344f614faa1e3ee7902f429ce3/images/000028.jpg
--------------------------------------------------------------------------------
/images/000064.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/you359/Keras-FasterRCNN/eb67ad5d946581344f614faa1e3ee7902f429ce3/images/000064.jpg
--------------------------------------------------------------------------------
/images/000082.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/you359/Keras-FasterRCNN/eb67ad5d946581344f614faa1e3ee7902f429ce3/images/000082.jpg
--------------------------------------------------------------------------------
/keras_frcnn/FixedBatchNormalization.py:
--------------------------------------------------------------------------------
1 | from keras.engine import Layer, InputSpec
2 | from keras import initializers, regularizers
3 | from keras import backend as K
4 |
5 |
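# Note (added for clarity): this layer applies batch normalisation with frozen statistics.
# gamma, beta and the running mean/std below are all created with trainable=False, so they
# are only ever set by loading pretrained weights; this is the usual way BN layers are
# handled when fine-tuning a Faster R-CNN base network.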
6 | class FixedBatchNormalization(Layer):
7 |
8 | def __init__(self, epsilon=1e-3, axis=-1,
9 | weights=None, beta_init='zero', gamma_init='one',
10 | gamma_regularizer=None, beta_regularizer=None, **kwargs):
11 |
12 | self.supports_masking = True
13 | self.beta_init = initializers.get(beta_init)
14 | self.gamma_init = initializers.get(gamma_init)
15 | self.epsilon = epsilon
16 | self.axis = axis
17 | self.gamma_regularizer = regularizers.get(gamma_regularizer)
18 | self.beta_regularizer = regularizers.get(beta_regularizer)
19 | self.initial_weights = weights
20 | super(FixedBatchNormalization, self).__init__(**kwargs)
21 |
22 | def build(self, input_shape):
23 | self.input_spec = [InputSpec(shape=input_shape)]
24 | shape = (input_shape[self.axis],)
25 |
26 | self.gamma = self.add_weight(shape,
27 | initializer=self.gamma_init,
28 | regularizer=self.gamma_regularizer,
29 | name='{}_gamma'.format(self.name),
30 | trainable=False)
31 | self.beta = self.add_weight(shape,
32 | initializer=self.beta_init,
33 | regularizer=self.beta_regularizer,
34 | name='{}_beta'.format(self.name),
35 | trainable=False)
36 | self.running_mean = self.add_weight(shape, initializer='zero',
37 | name='{}_running_mean'.format(self.name),
38 | trainable=False)
39 | self.running_std = self.add_weight(shape, initializer='one',
40 | name='{}_running_std'.format(self.name),
41 | trainable=False)
42 |
43 | if self.initial_weights is not None:
44 | self.set_weights(self.initial_weights)
45 | del self.initial_weights
46 |
47 | self.built = True
48 |
49 | def call(self, x, mask=None):
50 |
51 | assert self.built, 'Layer must be built before being called'
52 | input_shape = K.int_shape(x)
53 |
54 | reduction_axes = list(range(len(input_shape)))
55 | del reduction_axes[self.axis]
56 | broadcast_shape = [1] * len(input_shape)
57 | broadcast_shape[self.axis] = input_shape[self.axis]
58 |
59 | if sorted(reduction_axes) == range(K.ndim(x))[:-1]:
60 | x_normed = K.batch_normalization(
61 | x, self.running_mean, self.running_std,
62 | self.beta, self.gamma,
63 | epsilon=self.epsilon)
64 | else:
65 | # need broadcasting
66 | broadcast_running_mean = K.reshape(self.running_mean, broadcast_shape)
67 | broadcast_running_std = K.reshape(self.running_std, broadcast_shape)
68 | broadcast_beta = K.reshape(self.beta, broadcast_shape)
69 | broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
70 | x_normed = K.batch_normalization(
71 | x, broadcast_running_mean, broadcast_running_std,
72 | broadcast_beta, broadcast_gamma,
73 | epsilon=self.epsilon)
74 |
75 | return x_normed
76 |
77 | def get_config(self):
78 | config = {'epsilon': self.epsilon,
79 | 'axis': self.axis,
80 | 'gamma_regularizer': self.gamma_regularizer.get_config() if self.gamma_regularizer else None,
81 | 'beta_regularizer': self.beta_regularizer.get_config() if self.beta_regularizer else None}
82 | base_config = super(FixedBatchNormalization, self).get_config()
83 | return dict(list(base_config.items()) + list(config.items()))
--------------------------------------------------------------------------------
/keras_frcnn/RoiPoolingConv.py:
--------------------------------------------------------------------------------
1 | from keras.engine.topology import Layer
2 | import keras.backend as K
3 |
4 | if K.backend() == 'tensorflow':
5 | import tensorflow as tf
6 |
7 |
8 | class RoiPoolingConv(Layer):
9 | '''
10 | ROI pooling layer for 2D inputs.
11 | See Spatial Pyramid Pooling in Deep Convolutional Networks for Visual Recognition,
12 | K. He, X. Zhang, S. Ren, J. Sun
13 | # Arguments
14 | pool_size: int
15 | Size of pooling region to use. pool_size = 7 will result in a 7x7 region.
16 | num_rois: number of regions of interest to be used
17 | # Input shape
18 | list of two 4D tensors [X_img,X_roi] with shape:
19 | X_img:
20 | `(1, channels, rows, cols)` if dim_ordering='th'
21 | or 4D tensor with shape:
22 | `(1, rows, cols, channels)` if dim_ordering='tf'.
23 | X_roi:
24 | `(1,num_rois,4)` list of rois, with ordering (x,y,w,h)
25 | # Output shape
26 | 3D tensor with shape:
27 | `(1, num_rois, channels, pool_size, pool_size)`
28 | '''
29 |
30 | def __init__(self, pool_size, num_rois, **kwargs):
31 |
32 | self.dim_ordering = K.image_dim_ordering()
33 | assert self.dim_ordering in {'tf', 'th'}, 'dim_ordering must be in {tf, th}'
34 |
35 | self.pool_size = pool_size
36 | self.num_rois = num_rois
37 |
38 | super(RoiPoolingConv, self).__init__(**kwargs)
39 |
40 | def build(self, input_shape):
41 | if self.dim_ordering == 'th':
42 | self.nb_channels = input_shape[0][1]
43 | elif self.dim_ordering == 'tf':
44 | self.nb_channels = input_shape[0][3]
45 |
46 | def compute_output_shape(self, input_shape):
47 | if self.dim_ordering == 'th':
48 | return None, self.num_rois, self.nb_channels, self.pool_size, self.pool_size
49 | else:
50 | return None, self.num_rois, self.pool_size, self.pool_size, self.nb_channels
51 |
52 | def call(self, x, mask=None):
53 |
54 | assert(len(x) == 2)
55 |
56 | img = x[0]
57 | rois = x[1]
58 |
59 | input_shape = K.shape(img)
60 |
61 | outputs = []
62 |
63 | for roi_idx in range(self.num_rois):
64 |
65 | x = rois[0, roi_idx, 0]
66 | y = rois[0, roi_idx, 1]
67 | w = rois[0, roi_idx, 2]
68 | h = rois[0, roi_idx, 3]
69 |
70 | row_length = w / float(self.pool_size)
71 | col_length = h / float(self.pool_size)
72 |
73 | num_pool_regions = self.pool_size
74 |
75 | #NOTE: the RoiPooling implementation differs between theano and tensorflow due to the lack of a resize op
76 | # in theano. The theano implementation is much less efficient and leads to long compile times
77 |
78 | if self.dim_ordering == 'th':
79 | for jy in range(num_pool_regions):
80 | for ix in range(num_pool_regions):
81 | x1 = x + ix * row_length
82 | x2 = x1 + row_length
83 | y1 = y + jy * col_length
84 | y2 = y1 + col_length
85 |
86 | x1 = K.cast(x1, 'int32')
87 | x2 = K.cast(x2, 'int32')
88 | y1 = K.cast(y1, 'int32')
89 | y2 = K.cast(y2, 'int32')
90 |
91 | x2 = x1 + K.maximum(1,x2-x1)
92 | y2 = y1 + K.maximum(1,y2-y1)
93 |
94 | new_shape = [input_shape[0], input_shape[1],
95 | y2 - y1, x2 - x1]
96 |
97 | x_crop = img[:, :, y1:y2, x1:x2]
98 | xm = K.reshape(x_crop, new_shape)
99 | pooled_val = K.max(xm, axis=(2, 3))
100 | outputs.append(pooled_val)
101 |
102 | elif self.dim_ordering == 'tf':
103 | x = K.cast(x, 'int32')
104 | y = K.cast(y, 'int32')
105 | w = K.cast(w, 'int32')
106 | h = K.cast(h, 'int32')
107 |
108 | rs = tf.image.resize_images(img[:, y:y+h, x:x+w, :], (self.pool_size, self.pool_size))
109 | outputs.append(rs)
110 |
111 | final_output = K.concatenate(outputs, axis=0)
112 | final_output = K.reshape(final_output, (1, self.num_rois, self.pool_size, self.pool_size, self.nb_channels))
113 |
114 | if self.dim_ordering == 'th':
115 | final_output = K.permute_dimensions(final_output, (0, 1, 4, 2, 3))
116 | else:
117 | final_output = K.permute_dimensions(final_output, (0, 1, 2, 3, 4))
118 |
119 | return final_output
120 |
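# Minimal usage sketch (assumed shapes, tensorflow backend):
#   feature_map: (1, rows, cols, channels) output of the base network
#   rois:        (1, num_rois, 4) tensor of ROIs in (x, y, w, h) feature-map coordinates
# pooled = RoiPoolingConv(pool_size=14, num_rois=32)([feature_map, rois])
# 'pooled' then has shape (1, 32, 14, 14, channels), matching compute_output_shape above.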
--------------------------------------------------------------------------------
/keras_frcnn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/you359/Keras-FasterRCNN/eb67ad5d946581344f614faa1e3ee7902f429ce3/keras_frcnn/__init__.py
--------------------------------------------------------------------------------
/keras_frcnn/config.py:
--------------------------------------------------------------------------------
1 | from keras import backend as K
2 |
3 |
4 | class Config:
5 |
6 | def __init__(self):
7 |
8 | self.verbose = True
9 |
10 | # base CNN model
11 | self.network = 'inception_resnet_v2'
12 |
13 | # setting for data augmentation
14 | self.use_horizontal_flips = False
15 | self.use_vertical_flips = False
16 | self.rot_90 = False
17 |
18 | # anchor box scales
19 | self.anchor_box_scales = [128, 256, 512]
20 |
21 | # anchor box ratios
22 | self.anchor_box_ratios = [[1, 1], [1, 2], [2, 1]]
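        # with 3 scales x 3 ratios this gives 9 anchors per feature-map position; e.g. the
        # 128 scale produces boxes of roughly 128x128, 128x256 and 256x128 pixels
        # (see calc_rpn in data_generators.py, where anchor_x = scale * ratio[0])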
23 |
24 | # size to resize the smallest side of the image
25 | self.im_size = 600
26 |
27 | # image channel-wise mean to subtract
28 | self.img_channel_mean = [103.939, 116.779, 123.68]
29 | self.img_scaling_factor = 1.0
30 |
31 | # number of ROIs at once
32 | self.num_rois = 300
33 |
34 | # stride at the RPN (this depends on the network configuration)
35 | self.rpn_stride = 16
36 |
37 | self.balanced_classes = False
38 |
39 | # scaling the stdev
40 | self.std_scaling = 4.0
41 | self.classifier_regr_std = [8.0, 8.0, 4.0, 4.0]
42 |
43 | # overlaps for RPN
44 | self.rpn_min_overlap = 0.3
45 | self.rpn_max_overlap = 0.7
46 |
47 | # overlaps for classifier ROIs
48 | self.classifier_min_overlap = 0.1
49 | self.classifier_max_overlap = 0.5
50 |
51 | # placeholder for the class mapping, automatically generated by the parser
52 | self.class_mapping = None
53 |
54 | # location of pretrained weights for the base network
55 | # weight files can be found at:
56 | # https://github.com/fchollet/deep-learning-models/releases/download/v0.2/resnet50_weights_th_dim_ordering_th_kernels_notop.h5
57 |
58 | self.model_path = 'model_frcnn.{}.hdf5'.format(self.network)
59 |
--------------------------------------------------------------------------------
/keras_frcnn/data_augment.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import copy
4 |
5 |
6 | def augment(img_data, config, augment=True):
7 | assert 'filepath' in img_data
8 | assert 'bboxes' in img_data
9 | assert 'width' in img_data
10 | assert 'height' in img_data
11 |
12 | img_data_aug = copy.deepcopy(img_data)
13 |
14 | img = cv2.imread(img_data_aug['filepath'])
15 |
16 | if augment:
17 | rows, cols = img.shape[:2]
18 |
19 | if config.use_horizontal_flips and np.random.randint(0, 2) == 0:
20 | img = cv2.flip(img, 1)
21 | for bbox in img_data_aug['bboxes']:
22 | x1 = bbox['x1']
23 | x2 = bbox['x2']
24 | bbox['x2'] = cols - x1
25 | bbox['x1'] = cols - x2
26 |
27 | if config.use_vertical_flips and np.random.randint(0, 2) == 0:
28 | img = cv2.flip(img, 0)
29 | for bbox in img_data_aug['bboxes']:
30 | y1 = bbox['y1']
31 | y2 = bbox['y2']
32 | bbox['y2'] = rows - y1
33 | bbox['y1'] = rows - y2
34 |
35 | if config.rot_90:
36 | angle = np.random.choice([0,90,180,270],1)[0]
37 | if angle == 270:
38 | img = np.transpose(img, (1,0,2))
39 | img = cv2.flip(img, 0)
40 | elif angle == 180:
41 | img = cv2.flip(img, -1)
42 | elif angle == 90:
43 | img = np.transpose(img, (1,0,2))
44 | img = cv2.flip(img, 1)
45 | elif angle == 0:
46 | pass
47 |
48 | for bbox in img_data_aug['bboxes']:
49 | x1 = bbox['x1']
50 | x2 = bbox['x2']
51 | y1 = bbox['y1']
52 | y2 = bbox['y2']
53 | if angle == 270:
54 | bbox['x1'] = y1
55 | bbox['x2'] = y2
56 | bbox['y1'] = cols - x2
57 | bbox['y2'] = cols - x1
58 | elif angle == 180:
59 | bbox['x2'] = cols - x1
60 | bbox['x1'] = cols - x2
61 | bbox['y2'] = rows - y1
62 | bbox['y1'] = rows - y2
63 | elif angle == 90:
64 | bbox['x1'] = rows - y2
65 | bbox['x2'] = rows - y1
66 | bbox['y1'] = x1
67 | bbox['y2'] = x2
68 | elif angle == 0:
69 | pass
70 |
71 | img_data_aug['width'] = img.shape[1]
72 | img_data_aug['height'] = img.shape[0]
73 | return img_data_aug, img
74 |
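# Example (illustrative): augment(img_data, config) returns the possibly transformed image
# together with a deep copy of img_data whose bounding boxes have been remapped, e.g. a
# horizontal flip maps (x1, x2) to (cols - x2, cols - x1) as in the loop above.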
--------------------------------------------------------------------------------
/keras_frcnn/data_generators.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import numpy as np
3 | import cv2
4 | import random
5 | import copy
6 | from . import data_augment
7 | import threading
8 | import itertools
9 |
10 |
11 | def union(au, bu, area_intersection):
12 | area_a = (au[2] - au[0]) * (au[3] - au[1])
13 | area_b = (bu[2] - bu[0]) * (bu[3] - bu[1])
14 | area_union = area_a + area_b - area_intersection
15 | return area_union
16 |
17 |
18 | def intersection(ai, bi):
19 | x = max(ai[0], bi[0])
20 | y = max(ai[1], bi[1])
21 | w = min(ai[2], bi[2]) - x
22 | h = min(ai[3], bi[3]) - y
23 | if w < 0 or h < 0:
24 | return 0
25 | return w*h
26 |
27 |
28 | # Intersection over Union (IoU)
29 | def iou(a, b):
30 | # a and b should be (x1,y1,x2,y2)
31 |
32 | if a[0] >= a[2] or a[1] >= a[3] or b[0] >= b[2] or b[1] >= b[3]:
33 | return 0.0
34 |
35 | area_i = intersection(a, b)
36 | area_u = union(a, b, area_i)
37 |
38 | return float(area_i) / float(area_u + 1e-6)
39 |
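# Worked example (for illustration): iou([0, 0, 10, 10], [5, 5, 15, 15]) returns
# 25 / 175 ~= 0.143, since the boxes share a 5x5 intersection and their union covers
# 100 + 100 - 25 = 175 pixels.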
40 |
41 | # image resize
42 | def get_new_img_size(width, height, img_min_side=600):
43 | if width <= height:
44 | f = float(img_min_side) / width
45 | resized_height = int(f * height)
46 | resized_width = int(img_min_side)
47 | else:
48 | f = float(img_min_side) / height
49 | resized_width = int(f * width)
50 | resized_height = int(img_min_side)
51 |
52 | return resized_width, resized_height
53 |
54 |
55 | # for balanced classes
56 | class SampleSelector:
57 | def __init__(self, class_count):
58 | # ignore classes that have zero samples
59 | self.classes = [b for b in class_count.keys() if class_count[b] > 0]
60 | self.class_cycle = itertools.cycle(self.classes)
61 | self.curr_class = next(self.class_cycle)
62 |
63 | def skip_sample_for_balanced_class(self, img_data):
64 |
65 | class_in_img = False
66 |
67 | for bbox in img_data['bboxes']:
68 |
69 | cls_name = bbox['class']
70 |
71 | if cls_name == self.curr_class:
72 | class_in_img = True
73 | self.curr_class = next(self.class_cycle)
74 | break
75 |
76 | if class_in_img:
77 | return False
78 | else:
79 | return True
80 |
81 |
82 | def calc_rpn(C, img_data, width, height, resized_width, resized_height, img_length_calc_function):
83 |
84 | downscale = float(C.rpn_stride)
85 | anchor_sizes = C.anchor_box_scales
86 | anchor_ratios = C.anchor_box_ratios
87 | num_anchors = len(anchor_sizes) * len(anchor_ratios)
88 |
89 | # calculate the output map size based on the network architecture
90 | (output_width, output_height) = img_length_calc_function(resized_width, resized_height)
91 |
92 | n_anchratios = len(anchor_ratios)
93 |
94 | # initialise empty output objectives
95 | y_rpn_overlap = np.zeros((output_height, output_width, num_anchors))
96 | y_is_box_valid = np.zeros((output_height, output_width, num_anchors))
97 | y_rpn_regr = np.zeros((output_height, output_width, num_anchors * 4))
98 |
99 | num_bboxes = len(img_data['bboxes'])
100 |
101 | num_anchors_for_bbox = np.zeros(num_bboxes).astype(int)
102 | best_anchor_for_bbox = -1 * np.ones((num_bboxes, 4)).astype(int)
103 | best_iou_for_bbox = np.zeros(num_bboxes).astype(np.float32)
104 | best_x_for_bbox = np.zeros((num_bboxes, 4)).astype(int)
105 | best_dx_for_bbox = np.zeros((num_bboxes, 4)).astype(np.float32)
106 |
107 | # get the GT box coordinates, and resize to account for image resizing
108 | gta = np.zeros((num_bboxes, 4))
109 | for bbox_num, bbox in enumerate(img_data['bboxes']):
110 | # get the GT box coordinates, and resize to account for image resizing
111 | gta[bbox_num, 0] = bbox['x1'] * (resized_width / float(width))
112 | gta[bbox_num, 1] = bbox['x2'] * (resized_width / float(width))
113 | gta[bbox_num, 2] = bbox['y1'] * (resized_height / float(height))
114 | gta[bbox_num, 3] = bbox['y2'] * (resized_height / float(height))
115 |
116 | # rpn ground truth
117 | for anchor_size_idx in range(len(anchor_sizes)):
118 | for anchor_ratio_idx in range(n_anchratios):
119 | anchor_x = anchor_sizes[anchor_size_idx] * anchor_ratios[anchor_ratio_idx][0]
120 | anchor_y = anchor_sizes[anchor_size_idx] * anchor_ratios[anchor_ratio_idx][1]
121 |
122 | for ix in range(output_width):
123 | # x-coordinates of the current anchor box
124 | x1_anc = downscale * (ix + 0.5) - anchor_x / 2
125 | x2_anc = downscale * (ix + 0.5) + anchor_x / 2
126 |
127 | # ignore boxes that go across image boundaries
128 | if x1_anc < 0 or x2_anc > resized_width:
129 | continue
130 |
131 | for jy in range(output_height):
132 |
133 | # y-coordinates of the current anchor box
134 | y1_anc = downscale * (jy + 0.5) - anchor_y / 2
135 | y2_anc = downscale * (jy + 0.5) + anchor_y / 2
136 |
137 | # ignore boxes that go across image boundaries
138 | if y1_anc < 0 or y2_anc > resized_height:
139 | continue
140 |
141 | # bbox_type indicates whether an anchor should be a target
142 | bbox_type = 'neg'
143 |
144 | # this is the best IOU for the (x,y) coord and the current anchor
145 | # note that this is different from the best IOU for a GT bbox
146 | best_iou_for_loc = 0.0
147 |
148 | for bbox_num in range(num_bboxes):
149 |
150 | # get IOU of the current GT box and the current anchor box
151 | curr_iou = iou([gta[bbox_num, 0], gta[bbox_num, 2], gta[bbox_num, 1], gta[bbox_num, 3]], [x1_anc, y1_anc, x2_anc, y2_anc])
152 | # calculate the regression targets if they will be needed
153 | if curr_iou > best_iou_for_bbox[bbox_num] or curr_iou > C.rpn_max_overlap:
154 | cx = (gta[bbox_num, 0] + gta[bbox_num, 1]) / 2.0
155 | cy = (gta[bbox_num, 2] + gta[bbox_num, 3]) / 2.0
156 | cxa = (x1_anc + x2_anc)/2.0
157 | cya = (y1_anc + y2_anc)/2.0
158 |
159 | tx = (cx - cxa) / (x2_anc - x1_anc)
160 | ty = (cy - cya) / (y2_anc - y1_anc)
161 | tw = np.log((gta[bbox_num, 1] - gta[bbox_num, 0]) / (x2_anc - x1_anc))
162 | th = np.log((gta[bbox_num, 3] - gta[bbox_num, 2]) / (y2_anc - y1_anc))
163 |
164 | if img_data['bboxes'][bbox_num]['class'] != 'bg':
165 |
166 | # all GT boxes should be mapped to an anchor box, so we keep track of which anchor box was best
167 | if curr_iou > best_iou_for_bbox[bbox_num]:
168 | best_anchor_for_bbox[bbox_num] = [jy, ix, anchor_ratio_idx, anchor_size_idx]
169 | best_iou_for_bbox[bbox_num] = curr_iou
170 | best_x_for_bbox[bbox_num,:] = [x1_anc, x2_anc, y1_anc, y2_anc]
171 | best_dx_for_bbox[bbox_num,:] = [tx, ty, tw, th]
172 |
173 | # we set the anchor to positive if the IOU is >0.7 (it does not matter if there was another better box, it just indicates overlap)
174 | if curr_iou > C.rpn_max_overlap:
175 | bbox_type = 'pos'
176 | num_anchors_for_bbox[bbox_num] += 1
177 | # we update the regression layer target if this IOU is the best for the current (x,y) and anchor position
178 | if curr_iou > best_iou_for_loc:
179 | best_iou_for_loc = curr_iou
180 | best_regr = (tx, ty, tw, th)
181 |
182 | # if the IOU is >0.3 and <0.7, it is ambiguous and not included in the objective
183 | if C.rpn_min_overlap < curr_iou < C.rpn_max_overlap:
184 | # gray zone between neg and pos
185 | if bbox_type != 'pos':
186 | bbox_type = 'neutral'
187 |
188 | # turn on or off outputs depending on IOUs
189 | if bbox_type == 'neg':
190 | y_is_box_valid[jy, ix, anchor_ratio_idx + n_anchratios * anchor_size_idx] = 1
191 | y_rpn_overlap[jy, ix, anchor_ratio_idx + n_anchratios * anchor_size_idx] = 0
192 | elif bbox_type == 'neutral':
193 | y_is_box_valid[jy, ix, anchor_ratio_idx + n_anchratios * anchor_size_idx] = 0
194 | y_rpn_overlap[jy, ix, anchor_ratio_idx + n_anchratios * anchor_size_idx] = 0
195 | elif bbox_type == 'pos':
196 | y_is_box_valid[jy, ix, anchor_ratio_idx + n_anchratios * anchor_size_idx] = 1
197 | y_rpn_overlap[jy, ix, anchor_ratio_idx + n_anchratios * anchor_size_idx] = 1
198 | start = 4 * (anchor_ratio_idx + n_anchratios * anchor_size_idx)
199 | y_rpn_regr[jy, ix, start:start+4] = best_regr
200 |
201 | # we ensure that every bbox has at least one positive RPN region
202 |
203 | for idx in range(num_anchors_for_bbox.shape[0]):
204 | if num_anchors_for_bbox[idx] == 0:
205 | # no box with an IOU greater than zero ...
206 | if best_anchor_for_bbox[idx, 0] == -1:
207 | continue
208 | y_is_box_valid[
209 | best_anchor_for_bbox[idx,0], best_anchor_for_bbox[idx,1], best_anchor_for_bbox[idx,2] + n_anchratios *
210 | best_anchor_for_bbox[idx,3]] = 1
211 | y_rpn_overlap[
212 | best_anchor_for_bbox[idx,0], best_anchor_for_bbox[idx,1], best_anchor_for_bbox[idx,2] + n_anchratios *
213 | best_anchor_for_bbox[idx,3]] = 1
214 | start = 4 * (best_anchor_for_bbox[idx,2] + n_anchratios * best_anchor_for_bbox[idx,3])
215 | y_rpn_regr[
216 | best_anchor_for_bbox[idx,0], best_anchor_for_bbox[idx,1], start:start+4] = best_dx_for_bbox[idx, :]
217 |
218 | y_rpn_overlap = np.transpose(y_rpn_overlap, (2, 0, 1))
219 | y_rpn_overlap = np.expand_dims(y_rpn_overlap, axis=0)
220 |
221 | y_is_box_valid = np.transpose(y_is_box_valid, (2, 0, 1))
222 | y_is_box_valid = np.expand_dims(y_is_box_valid, axis=0)
223 |
224 | y_rpn_regr = np.transpose(y_rpn_regr, (2, 0, 1))
225 | y_rpn_regr = np.expand_dims(y_rpn_regr, axis=0)
226 |
227 | pos_locs = np.where(np.logical_and(y_rpn_overlap[0, :, :, :] == 1, y_is_box_valid[0, :, :, :] == 1))
228 | neg_locs = np.where(np.logical_and(y_rpn_overlap[0, :, :, :] == 0, y_is_box_valid[0, :, :, :] == 1))
229 |
230 | num_pos = len(pos_locs[0])
231 |
232 | # one issue is that the RPN has many more negative than positive regions, so we turn off some of the negative
233 | # regions. We also limit it to 256 regions.
234 | num_regions = 256
235 |
236 | if len(pos_locs[0]) > num_regions/2:
237 | val_locs = random.sample(range(len(pos_locs[0])), len(pos_locs[0]) - num_regions/2)
238 | y_is_box_valid[0, pos_locs[0][val_locs], pos_locs[1][val_locs], pos_locs[2][val_locs]] = 0
239 | num_pos = num_regions/2
240 |
241 | if len(neg_locs[0]) + num_pos > num_regions:
242 | val_locs = random.sample(range(len(neg_locs[0])), len(neg_locs[0]) - num_pos)
243 | y_is_box_valid[0, neg_locs[0][val_locs], neg_locs[1][val_locs], neg_locs[2][val_locs]] = 0
244 |
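    # Layout note (added for clarity): y_rpn_cls stacks the "anchor is used in the loss" mask
    # on top of the objectness labels, and y_rpn_regr stacks a repeated positive-anchor mask
    # on top of the regression targets, so the losses in keras_frcnn/losses.py can slice the
    # first half of each tensor to mask out invalid or negative anchors.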
245 | y_rpn_cls = np.concatenate([y_is_box_valid, y_rpn_overlap], axis=1)
246 | y_rpn_regr = np.concatenate([np.repeat(y_rpn_overlap, 4, axis=1), y_rpn_regr], axis=1)
247 |
248 | return np.copy(y_rpn_cls), np.copy(y_rpn_regr)
249 |
250 |
251 | class threadsafe_iter:
252 | """Takes an iterator/generator and makes it thread-safe by
253 | serializing call to the `next` method of given iterator/generator.
254 | """
255 | def __init__(self, it):
256 | self.it = it
257 | self.lock = threading.Lock()
258 |
259 | def __iter__(self):
260 | return self
261 |
262 | def next(self):
263 | with self.lock:
264 | return next(self.it)
265 |
266 |
267 | def threadsafe_generator(f):
268 | """A decorator that takes a generator function and makes it thread-safe.
269 | """
270 | def g(*a, **kw):
271 | return threadsafe_iter(f(*a, **kw))
272 | return g
273 |
274 |
275 | def get_anchor_gt(all_img_data, class_count, C, img_length_calc_function, backend, mode='train'):
276 |
277 | # The following line is not useful with Python 3.5; it is kept for legacy reasons
278 | # all_img_data = sorted(all_img_data)
279 |
280 | sample_selector = SampleSelector(class_count)
281 |
282 | while True:
283 | if mode == 'train':
284 | random.shuffle(all_img_data)
285 |
286 | for img_data in all_img_data:
287 | try:
288 |
289 | if C.balanced_classes and sample_selector.skip_sample_for_balanced_class(img_data):
290 | continue
291 |
292 | # read in image, and optionally add augmentation
293 |
294 | if mode == 'train':
295 | img_data_aug, x_img = data_augment.augment(img_data, C, augment=True)
296 | else:
297 | img_data_aug, x_img = data_augment.augment(img_data, C, augment=False)
298 |
299 | (width, height) = (img_data_aug['width'], img_data_aug['height'])
300 | (rows, cols, _) = x_img.shape
301 |
302 | assert cols == width
303 | assert rows == height
304 |
305 | # get image dimensions for resizing
306 | (resized_width, resized_height) = get_new_img_size(width, height, C.im_size)
307 |
308 | # resize the image so that the smallest side has length C.im_size (600px by default)
309 | x_img = cv2.resize(x_img, (resized_width, resized_height), interpolation=cv2.INTER_CUBIC)
310 |
311 | try:
312 | # rpn ground-truth cls, reg
313 | y_rpn_cls, y_rpn_regr = calc_rpn(C, img_data_aug, width, height, resized_width, resized_height, img_length_calc_function)
314 | except:
315 | continue
316 |
317 | # Zero-center by mean pixel, and preprocess image
318 |
319 | x_img = x_img[:, :, (2, 1, 0)] # BGR -> RGB
320 | x_img = x_img.astype(np.float32)
321 | x_img[:, :, 0] -= C.img_channel_mean[0]
322 | x_img[:, :, 1] -= C.img_channel_mean[1]
323 | x_img[:, :, 2] -= C.img_channel_mean[2]
324 | x_img /= C.img_scaling_factor
325 |
326 | x_img = np.transpose(x_img, (2, 0, 1))
327 | x_img = np.expand_dims(x_img, axis=0)
328 |
329 | y_rpn_regr[:, y_rpn_regr.shape[1]//2:, :, :] *= C.std_scaling
330 |
331 | if backend == 'tf':
332 | x_img = np.transpose(x_img, (0, 2, 3, 1))
333 | y_rpn_cls = np.transpose(y_rpn_cls, (0, 2, 3, 1))
334 | y_rpn_regr = np.transpose(y_rpn_regr, (0, 2, 3, 1))
335 |
336 | yield np.copy(x_img), [np.copy(y_rpn_cls), np.copy(y_rpn_regr)], img_data_aug
337 |
338 | except Exception as e:
339 | print(e)
340 | continue
341 |
--------------------------------------------------------------------------------
/keras_frcnn/inception_resnet_v2.py:
--------------------------------------------------------------------------------
1 | """Inception-ResNet-v2 model for Keras,
2 | adapted for use as the base network (feature extractor) and classifier
3 | head of this Faster R-CNN implementation.
4 | Do note that the input image format for this model is different than for
5 | the VGG16 and ResNet models (299x299 instead of 224x224),
6 | and that the input preprocessing function
7 | is also different (same as Inception V3).
8 | # Reference
9 | - [Inception-v4, Inception-ResNet and the Impact of Residual Connections on Learning](
10 | https://arxiv.org/abs/1602.07261)
11 | """
12 |
13 | from __future__ import absolute_import
14 | from __future__ import division
15 | from __future__ import print_function
16 | import os
17 | from keras.layers import Input, Dense, Activation, Flatten, Conv2D, MaxPooling2D, BatchNormalization, GlobalAveragePooling2D, AveragePooling2D, TimeDistributed, Concatenate, Lambda
18 |
19 | from keras import backend as K
20 |
21 | from keras_frcnn.RoiPoolingConv import RoiPoolingConv
22 | from keras_frcnn.FixedBatchNormalization import FixedBatchNormalization
23 |
24 |
25 | def get_weight_path():
26 | return os.path.join('keras_frcnn', 'weights','inception_resnet_v2.h5')
27 |
28 |
29 | def get_img_output_length(width, height):
30 | def get_output_length(input_length):
31 | # filter_sizes = [3, 3, 3, 3, 3, 3, 3]
32 | # strides = [2, 1, 2, 1, 2, 2, 2]
33 | filter_sizes = [3, 3, 3, 3, 3, 3]
34 | strides = [2, 1, 2, 1, 2, 2]
35 |
36 | assert len(filter_sizes) == len(strides)
37 |
38 | for i in range(len(filter_sizes)):
39 | input_length = (input_length - filter_sizes[i]) // strides[i] + 1
40 |
41 | return input_length
42 |
43 | return get_output_length(width), get_output_length(height)
44 |
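# Sanity check (worked example): with filter_sizes [3, 3, 3, 3, 3, 3] and strides
# [2, 1, 2, 1, 2, 2], a 600px input side maps to a 35-position feature map
# (600 -> 299 -> 297 -> 148 -> 146 -> 72 -> 35).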
45 |
46 | def conv2d_bn(x,
47 | filters,
48 | kernel_size,
49 | strides=1,
50 | padding='same',
51 | activation='relu',
52 | use_bias=False,
53 | name=None):
54 | """Utility function to apply conv + BN.
55 | # Arguments
56 | x: input tensor.
57 | filters: filters in `Conv2D`.
58 | kernel_size: kernel size as in `Conv2D`.
59 | strides: strides in `Conv2D`.
60 | padding: padding mode in `Conv2D`.
61 | activation: activation in `Conv2D`.
62 | use_bias: whether to use a bias in `Conv2D`.
63 | name: name of the ops; will become `name + '_ac'` for the activation
64 | and `name + '_bn'` for the batch norm layer.
65 | # Returns
66 | Output tensor after applying `Conv2D` and `BatchNormalization`.
67 | """
68 | x = Conv2D(filters,
69 | kernel_size,
70 | strides=strides,
71 | padding=padding,
72 | use_bias=use_bias,
73 | name=name)(x)
74 | if not use_bias:
75 | bn_axis = 1 if K.image_data_format() == 'channels_first' else 3
76 | bn_name = None if name is None else name + '_bn'
77 | x = BatchNormalization(axis=bn_axis,
78 | scale=False,
79 | name=bn_name)(x)
80 | if activation is not None:
81 | ac_name = None if name is None else name + '_ac'
82 | x = Activation(activation, name=ac_name)(x)
83 | return x
84 |
85 |
86 | def conv2d_bn_td(x,
87 | filters,
88 | kernel_size,
89 | strides=1,
90 | padding='same',
91 | activation='relu',
92 | use_bias=False,
93 | name=None):
94 | """Utility function to apply conv + BN.
95 | # Arguments
96 | x: input tensor.
97 | filters: filters in `Conv2D`.
98 | kernel_size: kernel size as in `Conv2D`.
99 | strides: strides in `Conv2D`.
100 | padding: padding mode in `Conv2D`.
101 | activation: activation in `Conv2D`.
102 | use_bias: whether to use a bias in `Conv2D`.
103 | name: name of the ops; will become `name + '_ac'` for the activation
104 | and `name + '_bn'` for the batch norm layer.
105 | # Returns
106 | Output tensor after applying `Conv2D` and `BatchNormalization`.
107 | """
108 | x = TimeDistributed(Conv2D(filters,
109 | kernel_size,
110 | strides=strides,
111 | padding=padding,
112 | use_bias=use_bias),
113 | name=name)(x)
114 | if not use_bias:
115 | bn_axis = 1 if K.image_data_format() == 'channels_first' else 3
116 | bn_name = None if name is None else name + '_bn'
117 | x = TimeDistributed(BatchNormalization(axis=bn_axis,
118 | scale=False),
119 | name=bn_name)(x)
120 | if activation is not None:
121 | ac_name = None if name is None else name + '_ac'
122 | x = Activation(activation, name=ac_name)(x)
123 | return x
124 |
125 |
126 | def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'):
127 | """Adds an Inception-ResNet block.
128 | This function builds 3 types of Inception-ResNet blocks mentioned
129 | in the paper, controlled by the `block_type` argument (which is the
130 | block name used in the official TF-slim implementation):
131 | - Inception-ResNet-A: `block_type='block35'`
132 | - Inception-ResNet-B: `block_type='block17'`
133 | - Inception-ResNet-C: `block_type='block8'`
134 | # Arguments
135 | x: input tensor.
136 | scale: scaling factor to scale the residuals (i.e., the output of
137 | passing `x` through an inception module) before adding them
138 | to the shortcut branch.
139 | Let `r` be the output from the residual branch,
140 | the output of this block will be `x + scale * r`.
141 | block_type: `'block35'`, `'block17'` or `'block8'`, determines
142 | the network structure in the residual branch.
143 | block_idx: an `int` used for generating layer names.
144 | The Inception-ResNet blocks
145 | are repeated many times in this network.
146 | We use `block_idx` to identify
147 | each of the repetitions. For example,
148 | the first Inception-ResNet-A block
149 | will have `block_type='block35', block_idx=0`,
150 | and the layer names will have
151 | a common prefix `'block35_0'`.
152 | activation: activation function to use at the end of the block
153 | (see [activations](../activations.md)).
154 | When `activation=None`, no activation is applied
155 | (i.e., "linear" activation: `a(x) = x`).
156 | # Returns
157 | Output tensor for the block.
158 | # Raises
159 | ValueError: if `block_type` is not one of `'block35'`,
160 | `'block17'` or `'block8'`.
161 | """
162 | block_name = block_type + '_' + str(block_idx)
163 |
164 | if block_type == 'block35':
165 | branch_0 = conv2d_bn(x, 32, 1, name=block_name + '_conv1')
166 | branch_1 = conv2d_bn(x, 32, 1, name=block_name + '_conv2')
167 | branch_1 = conv2d_bn(branch_1, 32, 3, name=block_name + '_conv3')
168 | branch_2 = conv2d_bn(x, 32, 1, name=block_name + '_conv4')
169 | branch_2 = conv2d_bn(branch_2, 48, 3, name=block_name + '_conv5')
170 | branch_2 = conv2d_bn(branch_2, 64, 3, name=block_name + '_conv6')
171 | branches = [branch_0, branch_1, branch_2]
172 | elif block_type == 'block17':
173 | branch_0 = conv2d_bn(x, 192, 1, name=block_name + '_conv1')
174 | branch_1 = conv2d_bn(x, 128, 1, name=block_name + '_conv2')
175 | branch_1 = conv2d_bn(branch_1, 160, [1, 7], name=block_name + '_conv3')
176 | branch_1 = conv2d_bn(branch_1, 192, [7, 1], name=block_name + '_conv4')
177 | branches = [branch_0, branch_1]
178 | elif block_type == 'block8':
179 | branch_0 = conv2d_bn(x, 192, 1, name=block_name + '_conv1')
180 | branch_1 = conv2d_bn(x, 192, 1, name=block_name + '_conv2')
181 | branch_1 = conv2d_bn(branch_1, 224, [1, 3], name=block_name + '_conv3')
182 | branch_1 = conv2d_bn(branch_1, 256, [3, 1], name=block_name + '_conv4')
183 | branches = [branch_0, branch_1]
184 | else:
185 | raise ValueError('Unknown Inception-ResNet block type. '
186 | 'Expects "block35", "block17" or "block8", '
187 | 'but got: ' + str(block_type))
188 |
189 | channel_axis = 1 if K.image_data_format() == 'channels_first' else 3
190 | mixed = Concatenate(
191 | axis=channel_axis, name=block_name + '_mixed')(branches)
192 | up = conv2d_bn(mixed,
193 | K.int_shape(x)[channel_axis],
194 | 1,
195 | activation=None,
196 | use_bias=True,
197 | name=block_name + '_conv')
198 |
199 | x = Lambda(lambda inputs, scale: inputs[0] + inputs[1] * scale,
200 | output_shape=K.int_shape(x)[1:],
201 | arguments={'scale': scale},
202 | name=block_name)([x, up])
203 | if activation is not None:
204 | x = Activation(activation, name=block_name + '_ac')(x)
205 | return x
206 |
207 |
208 | def inception_resnet_block_td(x, scale, block_type, block_idx, activation='relu'):
209 | """Adds an Inception-ResNet block (TimeDistributed variant for the classifier head).
210 | This function builds 3 types of Inception-ResNet blocks mentioned
211 | in the paper, controlled by the `block_type` argument (which is the
212 | block name used in the official TF-slim implementation):
213 | - Inception-ResNet-A: `block_type='block35'`
214 | - Inception-ResNet-B: `block_type='block17'`
215 | - Inception-ResNet-C: `block_type='block8'`
216 | # Arguments
217 | x: input tensor.
218 | scale: scaling factor to scale the residuals (i.e., the output of
219 | passing `x` through an inception module) before adding them
220 | to the shortcut branch.
221 | Let `r` be the output from the residual branch,
222 | the output of this block will be `x + scale * r`.
223 | block_type: `'block35'`, `'block17'` or `'block8'`, determines
224 | the network structure in the residual branch.
225 | block_idx: an `int` used for generating layer names.
226 | The Inception-ResNet blocks
227 | are repeated many times in this network.
228 | We use `block_idx` to identify
229 | each of the repetitions. For example,
230 | the first Inception-ResNet-A block
231 | will have `block_type='block35', block_idx=0`,
232 | and the layer names will have
233 | a common prefix `'block35_0'`.
234 | activation: activation function to use at the end of the block
235 | (see [activations](../activations.md)).
236 | When `activation=None`, no activation is applied
237 | (i.e., "linear" activation: `a(x) = x`).
238 | # Returns
239 | Output tensor for the block.
240 | # Raises
241 | ValueError: if `block_type` is not one of `'block35'`,
242 | `'block17'` or `'block8'`.
243 | """
244 | block_name = block_type + '_' + str(block_idx)
245 |
246 | if block_type == 'block35':
247 | branch_0 = conv2d_bn_td(x, 32, 1, name=block_name + '_conv1')
248 | branch_1 = conv2d_bn_td(x, 32, 1, name=block_name + '_conv2')
249 | branch_1 = conv2d_bn_td(branch_1, 32, 3, name=block_name + '_conv3')
250 | branch_2 = conv2d_bn_td(x, 32, 1, name=block_name + '_conv4')
251 | branch_2 = conv2d_bn_td(branch_2, 48, 3, name=block_name + '_conv5')
252 | branch_2 = conv2d_bn_td(branch_2, 64, 3, name=block_name + '_conv6')
253 | branches = [branch_0, branch_1, branch_2]
254 | elif block_type == 'block17':
255 | branch_0 = conv2d_bn_td(x, 192, 1, name=block_name + '_conv1')
256 | branch_1 = conv2d_bn_td(x, 128, 1, name=block_name + '_conv2')
257 | branch_1 = conv2d_bn_td(branch_1, 160, [1, 7], name=block_name + '_conv3')
258 | branch_1 = conv2d_bn_td(branch_1, 192, [7, 1], name=block_name + '_conv4')
259 | branches = [branch_0, branch_1]
260 | elif block_type == 'block8':
261 | branch_0 = conv2d_bn_td(x, 192, 1, name=block_name + '_conv1')
262 | branch_1 = conv2d_bn_td(x, 192, 1, name=block_name + '_conv2')
263 | branch_1 = conv2d_bn_td(branch_1, 224, [1, 3], name=block_name + '_conv3')
264 | branch_1 = conv2d_bn_td(branch_1, 256, [3, 1], name=block_name + '_conv4')
265 | branches = [branch_0, branch_1]
266 | else:
267 | raise ValueError('Unknown Inception-ResNet block type. '
268 | 'Expects "block35", "block17" or "block8", '
269 | 'but got: ' + str(block_type))
270 |
271 | channel_axis = 1 if K.image_data_format() == 'channels_first' else 4
272 | mixed = Concatenate(
273 | axis=channel_axis, name=block_name + '_mixed')(branches)
274 | up = conv2d_bn_td(mixed,
275 | K.int_shape(x)[channel_axis],
276 | 1,
277 | activation=None,
278 | use_bias=True,
279 | name=block_name + '_conv')
280 |
281 | x = Lambda(lambda inputs, scale: inputs[0] + inputs[1] * scale,
282 | output_shape=K.int_shape(x)[1:],
283 | arguments={'scale': scale},
284 | name=block_name)([x, up])
285 | if activation is not None:
286 | x = Activation(activation, name=block_name + '_ac')(x)
287 | return x
288 |
289 |
290 | def nn_base(input_tensor=None, trainable=False):
291 |
292 | # Determine proper input shape
293 | if K.image_dim_ordering() == 'th':
294 | input_shape = (3, None, None)
295 | else:
296 | input_shape = (None, None, 3)
297 |
298 | if input_tensor is None:
299 | img_input = Input(shape=input_shape)
300 | else:
301 | if not K.is_keras_tensor(input_tensor):
302 | img_input = Input(tensor=input_tensor, shape=input_shape)
303 | else:
304 | img_input = input_tensor
305 |
306 | if K.image_dim_ordering() == 'tf':
307 | bn_axis = 3
308 | else:
309 | bn_axis = 1
310 |
311 | # Stem block: 35 x 35 x 192
312 | x = conv2d_bn(img_input, 32, 3, strides=2, padding='valid', name='Stem_block' + '_conv1')
313 | x = conv2d_bn(x, 32, 3, padding='valid', name='Stem_block' + '_conv2')
314 | x = conv2d_bn(x, 64, 3, name='Stem_block' + '_conv3')
315 | x = MaxPooling2D(3, strides=2)(x)
316 | x = conv2d_bn(x, 80, 1, padding='valid', name='Stem_block' + '_conv4')
317 | x = conv2d_bn(x, 192, 3, padding='valid', name='Stem_block' + '_conv5')
318 | x = MaxPooling2D(3, strides=2)(x)
319 |
320 | # Mixed 5b (Inception-A block): 35 x 35 x 320
321 | branch_0 = conv2d_bn(x, 96, 1, name='Inception_A_block' + '_conv1')
322 | branch_1 = conv2d_bn(x, 48, 1, name='Inception_A_block' + '_conv2')
323 | branch_1 = conv2d_bn(branch_1, 64, 5, name='Inception_A_block' + '_conv3')
324 | branch_2 = conv2d_bn(x, 64, 1, name='Inception_A_block' + '_conv4')
325 | branch_2 = conv2d_bn(branch_2, 96, 3, name='Inception_A_block' + '_conv5')
326 | branch_2 = conv2d_bn(branch_2, 96, 3, name='Inception_A_block' + '_conv6')
327 | branch_pool = AveragePooling2D(3, strides=1, padding='same')(x)
328 | branch_pool = conv2d_bn(branch_pool, 64, 1, name='Inception_A_block' + '_conv7')
329 | branches = [branch_0, branch_1, branch_2, branch_pool]
330 | channel_axis = 1 if K.image_data_format() == 'channels_first' else 3
331 | x = Concatenate(axis=channel_axis, name='mixed_5b')(branches)
332 |
333 | # 10x block35 (Inception-ResNet-A block): 35 x 35 x 320
334 | for block_idx in range(1, 11):
335 | x = inception_resnet_block(x,
336 | scale=0.17,
337 | block_type='block35',
338 | block_idx=block_idx)
339 |
340 | # Mixed 6a (Reduction-A block): 17 x 17 x 1088
341 | branch_0 = conv2d_bn(x, 384, 3, strides=2, padding='valid', name='Reduction_A_block' + '_conv1')
342 | branch_1 = conv2d_bn(x, 256, 1, name='Reduction_A_block' + '_conv2')
343 | branch_1 = conv2d_bn(branch_1, 256, 3, name='Reduction_A_block' + '_conv3')
344 | branch_1 = conv2d_bn(branch_1, 384, 3, strides=2, padding='valid', name='Reduction_A_block' + '_conv4')
345 | branch_pool = MaxPooling2D(3, strides=2, padding='valid')(x)
346 | branches = [branch_0, branch_1, branch_pool]
347 | x = Concatenate(axis=channel_axis, name='mixed_6a')(branches)
348 |
349 | # 20x block17 (Inception-ResNet-B block): 17 x 17 x 1088
350 | for block_idx in range(1, 21):
351 | x = inception_resnet_block(x,
352 | scale=0.1,
353 | block_type='block17',
354 | block_idx=block_idx)
355 |
356 | return x
357 |
358 |
359 | def classifier_layers(x, input_shape, trainable=False):
360 |
361 | # compile times on theano tend to be very high, so we use smaller ROI pooling regions to workaround
362 | # (hence a smaller stride in the region that follows the ROI pool)
363 |
364 | channel_axis = 1 if K.image_data_format() == 'channels_first' else 4
365 |
366 | # Mixed 7a (Reduction-B block): 8 x 8 x 2080
367 | branch_0 = conv2d_bn_td(x, 256, 1, name='Reduction_B_block' + '_conv1')
368 | branch_0 = conv2d_bn_td(branch_0, 384, 3, strides=2, padding='valid', name='Reduction_B_block' + '_conv2')
369 | branch_1 = conv2d_bn_td(x, 256, 1, name='Reduction_B_block' + '_conv3')
370 | branch_1 = conv2d_bn_td(branch_1, 288, 3, strides=2, padding='valid', name='Reduction_B_block' + '_conv4')
371 | branch_2 = conv2d_bn_td(x, 256, 1, name='Reduction_B_block' + '_conv5')
372 | branch_2 = conv2d_bn_td(branch_2, 288, 3, name='Reduction_B_block' + '_conv6')
373 | branch_2 = conv2d_bn_td(branch_2, 320, 3, strides=2, padding='valid', name='Reduction_B_block' + '_conv7')
374 | branch_pool = TimeDistributed(MaxPooling2D(3, strides=2, padding='valid'))(x)
375 | branches = [branch_0, branch_1, branch_2, branch_pool]
376 | x = Concatenate(axis=channel_axis, name='mixed_7a')(branches)
377 |
378 | # 10x block8 (Inception-ResNet-C block): 8 x 8 x 2080
379 | for block_idx in range(1, 10):
380 | x = inception_resnet_block_td(x,
381 | scale=0.2,
382 | block_type='block8',
383 | block_idx=block_idx)
384 | x = inception_resnet_block_td(x,
385 | scale=1.,
386 | activation=None,
387 | block_type='block8',
388 | block_idx=10)
389 |
390 | # Final convolution block: 8 x 8 x 1536
391 | x = conv2d_bn_td(x, 1536, 1, name='conv_7b')
392 |
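    # Note (added for clarity): the result of the following call is not assigned, so the global
    # average pooling output is discarded and classifier() below flattens the convolutional
    # feature map directly.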
393 | TimeDistributed(GlobalAveragePooling2D(), name='avg_pool')(x)
394 |
395 | return x
396 |
397 |
398 | def rpn(base_layers, num_anchors):
399 |
400 | x = Conv2D(512, (3, 3), padding='same', activation='relu', kernel_initializer='normal', name='rpn_conv1')(base_layers)
401 |
402 | x_class = Conv2D(num_anchors, (1, 1), activation='sigmoid', kernel_initializer='uniform', name='rpn_out_class')(x)
403 | x_regr = Conv2D(num_anchors * 4, (1, 1), activation='linear', kernel_initializer='zero', name='rpn_out_regress')(x)
404 |
405 | return [x_class, x_regr, base_layers]
406 |
407 |
408 | def classifier(base_layers, input_rois, num_rois, nb_classes=21, trainable=False):
409 |
410 | # compile times on theano tend to be very high, so we use smaller ROI pooling regions to workaround
411 |
412 | if K.backend() == 'tensorflow':
413 | pooling_regions = 14
414 | # Changed the input shape to 1088 from 1024 because of nn_base's output being 1088. Not sure if this is correct
415 | input_shape = (num_rois, 14, 14, 1088)
416 | elif K.backend() == 'theano':
417 | pooling_regions = 7
418 | input_shape = (num_rois, 1024, 7, 7)
419 |
420 | out_roi_pool = RoiPoolingConv(pooling_regions, num_rois)([base_layers, input_rois])
421 | out = classifier_layers(out_roi_pool, input_shape=input_shape, trainable=True)
422 |
423 | out = TimeDistributed(Flatten())(out)
424 |
425 | out_class = TimeDistributed(Dense(nb_classes, activation='softmax', kernel_initializer='zero'), name='dense_class_{}'.format(nb_classes))(out)
426 | # note: no regression target for bg class
427 | out_regr = TimeDistributed(Dense(4 * (nb_classes-1), activation='linear', kernel_initializer='zero'), name='dense_regress_{}'.format(nb_classes))(out)
428 | return [out_class, out_regr]
429 |
430 |
--------------------------------------------------------------------------------
/keras_frcnn/losses.py:
--------------------------------------------------------------------------------
1 | from keras import backend as K
2 | from keras.objectives import categorical_crossentropy
3 |
4 | if K.image_dim_ordering() == 'tf':
5 | import tensorflow as tf
6 |
7 | lambda_rpn_regr = 1.0
8 | lambda_rpn_class = 1.0
9 |
10 | lambda_cls_regr = 1.0
11 | lambda_cls_class = 1.0
12 |
13 | epsilon = 1e-4
14 |
15 |
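# The regression losses below implement the smooth L1 (robust/Huber) loss used in the
# Faster R-CNN paper, applied elementwise to x = y_true - y_pred:
#   smooth_L1(x) = 0.5 * x^2    if |x| <= 1
#                = |x| - 0.5    otherwise
# and masked by the first half of y_true so that only valid anchors / positive samples
# contribute to the sum.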
16 | def rpn_loss_regr(num_anchors):
17 | def rpn_loss_regr_fixed_num(y_true, y_pred):
18 | if K.image_dim_ordering() == 'th':
19 | x = y_true[:, 4 * num_anchors:, :, :] - y_pred
20 | x_abs = K.abs(x)
21 | x_bool = K.less_equal(x_abs, 1.0)
22 | return lambda_rpn_regr * K.sum(
23 | y_true[:, :4 * num_anchors, :, :] * (x_bool * (0.5 * x * x) + (1 - x_bool) * (x_abs - 0.5))) / K.sum(epsilon + y_true[:, :4 * num_anchors, :, :])
24 | else:
25 | x = y_true[:, :, :, 4 * num_anchors:] - y_pred
26 | x_abs = K.abs(x)
27 | x_bool = K.cast(K.less_equal(x_abs, 1.0), tf.float32)
28 |
29 | return lambda_rpn_regr * K.sum(
30 | y_true[:, :, :, :4 * num_anchors] * (x_bool * (0.5 * x * x) + (1 - x_bool) * (x_abs - 0.5))) / K.sum(epsilon + y_true[:, :, :, :4 * num_anchors])
31 |
32 | return rpn_loss_regr_fixed_num
33 |
34 |
35 | def rpn_loss_cls(num_anchors):
36 | def rpn_loss_cls_fixed_num(y_true, y_pred):
37 | if K.image_dim_ordering() == 'tf':
38 | return lambda_rpn_class * K.sum(y_true[:, :, :, :num_anchors] * K.binary_crossentropy(y_pred[:, :, :, :], y_true[:, :, :, num_anchors:])) / K.sum(epsilon + y_true[:, :, :, :num_anchors])
39 | else:
40 | return lambda_rpn_class * K.sum(y_true[:, :num_anchors, :, :] * K.binary_crossentropy(y_pred[:, :, :, :], y_true[:, num_anchors:, :, :])) / K.sum(epsilon + y_true[:, :num_anchors, :, :])
41 |
42 | return rpn_loss_cls_fixed_num
43 |
44 |
45 | def class_loss_regr(num_classes):
46 | def class_loss_regr_fixed_num(y_true, y_pred):
47 | x = y_true[:, :, 4*num_classes:] - y_pred
48 | x_abs = K.abs(x)
49 | x_bool = K.cast(K.less_equal(x_abs, 1.0), 'float32')
50 | return lambda_cls_regr * K.sum(y_true[:, :, :4*num_classes] * (x_bool * (0.5 * x * x) + (1 - x_bool) * (x_abs - 0.5))) / K.sum(epsilon + y_true[:, :, :4*num_classes])
51 | return class_loss_regr_fixed_num
52 |
53 |
54 | def class_loss_cls(y_true, y_pred):
55 | return lambda_cls_class * K.mean(categorical_crossentropy(y_true[0, :, :], y_pred[0, :, :]))
56 |
--------------------------------------------------------------------------------
/keras_frcnn/pascal_voc_parser.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import xml.etree.ElementTree as ET
4 | from tqdm import tqdm
5 |
6 |
7 | def get_data(input_path):
8 | all_imgs = []
9 | classes_count = {}
10 | class_mapping = {}
11 |
12 | # flag for visualising the parsed annotations
13 | visualise = False
14 |
15 | # pascal voc directory + 2012
16 | data_paths = [os.path.join(input_path, 'VOC2012')]
17 |
18 | print('Parsing annotation files')
19 | for data_path in data_paths:
20 |
21 | annot_path = os.path.join(data_path, 'Annotations')
22 | imgs_path = os.path.join(data_path, 'JPEGImages')
23 |
24 | # the 4 image set files in the ImageSets/Main directory (train, val, trainval, test)
25 | imgsets_path_trainval = os.path.join(data_path, 'ImageSets', 'Main', 'trainval.txt')
26 | imgsets_path_train = os.path.join(data_path, 'ImageSets', 'Main', 'train.txt')
27 | imgsets_path_val = os.path.join(data_path, 'ImageSets', 'Main', 'val.txt')
28 | imgsets_path_test = os.path.join(data_path, 'ImageSets', 'Main', 'test.txt')
29 |
30 | trainval_files = []
31 | train_files = []
32 | val_files = []
33 | test_files = []
34 |
35 | with open(imgsets_path_trainval) as f:
36 | for line in f:
37 | trainval_files.append(line.strip() + '.jpg')
38 |
39 | with open(imgsets_path_train) as f:
40 | for line in f:
41 | train_files.append(line.strip() + '.jpg')
42 |
43 | with open(imgsets_path_val) as f:
44 | for line in f:
45 | val_files.append(line.strip() + '.jpg')
46 |
47 | # test-set not included in pascal VOC 2012
48 | if os.path.isfile(imgsets_path_test):
49 | with open(imgsets_path_test) as f:
50 | for line in f:
51 | test_files.append(line.strip() + '.jpg')
52 |
53 | # exception handling for reading the image set txt files (kept commented out for reference)
54 | # try:
55 | # with open(imgsets_path_trainval) as f:
56 | # for line in f:
57 | # trainval_files.append(line.strip() + '.jpg')
58 | # except Exception as e:
59 | # print(e)
60 | #
61 | # try:
62 | # with open(imgsets_path_test) as f:
63 | # for line in f:
64 | # test_files.append(line.strip() + '.jpg')
65 | # except Exception as e:
66 | # if data_path[-7:] == 'VOC2012':
67 | # # this is expected, most pascal voc distributions don't have the test.txt file
68 | # pass
69 | # else:
70 | # print(e)
71 |
72 | # read the annotation files
73 | annots = [os.path.join(annot_path, s) for s in os.listdir(annot_path)]
74 | idx = 0
75 |
76 | annots = tqdm(annots)
77 | for annot in annots:
78 | # try:
79 | exist_flag = False
80 | idx += 1
81 | annots.set_description("Processing %s" % annot.split(os.sep)[-1])
82 |
83 | et = ET.parse(annot)
84 | element = et.getroot()
85 |
86 | element_objs = element.findall('object')
87 | # element_filename = element.find('filename').text + '.jpg'
88 | element_filename = element.find('filename').text
89 | element_width = int(element.find('size').find('width').text)
90 | element_height = int(element.find('size').find('height').text)
91 |
92 | if len(element_objs) > 0:
93 | annotation_data = {'filepath': os.path.join(imgs_path, element_filename), 'width': element_width,
94 | 'height': element_height, 'bboxes': []}
95 |
96 | annotation_data['image_id'] = idx
97 |
98 | if element_filename in trainval_files:
99 | annotation_data['imageset'] = 'trainval'
100 | exist_flag = True
101 |
102 | if element_filename in train_files:
103 | annotation_data['imageset'] = 'train'
104 | exist_flag = True
105 |
106 | if element_filename in val_files:
107 | annotation_data['imageset'] = 'val'
108 | exist_flag = True
109 |
110 | if len(test_files) > 0:
111 | if element_filename in test_files:
112 | annotation_data['imageset'] = 'test'
113 | exist_flag = True
114 |
115 | # if element_filename in trainval_files:
116 | # annotation_data['imageset'] = 'trainval'
117 | # elif element_filename in test_files:
118 | # annotation_data['imageset'] = 'test'
119 | # else:
120 | # annotation_data['imageset'] = 'trainval'
121 |
122 |                 # skip annotations whose image does not appear in any ImageSet split
123 | if not exist_flag:
124 | continue
125 |
126 | for element_obj in element_objs:
127 | class_name = element_obj.find('name').text
128 | if class_name not in classes_count:
129 | classes_count[class_name] = 1
130 | else:
131 | classes_count[class_name] += 1
132 |
133 |                     # add the class to the class mapping
134 | if class_name not in class_mapping:
135 |                         class_mapping[class_name] = len(class_mapping)  # assign the next available index
136 |
137 | obj_bbox = element_obj.find('bndbox')
138 | x1 = int(round(float(obj_bbox.find('xmin').text)))
139 | y1 = int(round(float(obj_bbox.find('ymin').text)))
140 | x2 = int(round(float(obj_bbox.find('xmax').text)))
141 | y2 = int(round(float(obj_bbox.find('ymax').text)))
142 | difficulty = int(element_obj.find('difficult').text) == 1
143 | annotation_data['bboxes'].append(
144 | {'class': class_name, 'x1': x1, 'x2': x2, 'y1': y1, 'y2': y2, 'difficult': difficulty})
145 | all_imgs.append(annotation_data)
146 |
147 | if visualise:
148 | img = cv2.imread(annotation_data['filepath'])
149 | for bbox in annotation_data['bboxes']:
150 | cv2.rectangle(img, (bbox['x1'], bbox['y1']), (bbox['x2'], bbox['y2']), (0, 0, 255))
151 | cv2.imshow('img', img)
152 | print(annotation_data['imageset'])
153 | cv2.waitKey(0)
154 |
155 | # except Exception as e:
156 | # print(e)
157 | # continue
158 | return all_imgs, classes_count, class_mapping
159 |
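A minimal usage sketch of get_data above, assuming keras_frcnn is importable and a standard VOCdevkit layout exists at /data/VOCdevkit (an illustrative path; get_data appends 'VOC2012' to whatever root it receives). The sample dictionary values are placeholders, not real annotations.

    from keras_frcnn.pascal_voc_parser import get_data

    # expects <root>/VOC2012/{Annotations, JPEGImages, ImageSets/Main}
    all_imgs, classes_count, class_mapping = get_data('/data/VOCdevkit')

    # each entry of all_imgs looks like:
    # {'filepath': '.../JPEGImages/2012_000003.jpg', 'width': 500, 'height': 375,
    #  'image_id': 1, 'imageset': 'trainval',
    #  'bboxes': [{'class': 'person', 'x1': 46, 'y1': 11, 'x2': 324, 'y2': 290,
    #              'difficult': False}]}
    print(len(all_imgs), classes_count, class_mapping)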
--------------------------------------------------------------------------------
/keras_frcnn/resnet.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | '''ResNet50 model for Keras.
3 | # Reference:
4 | - [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385)
5 | Adapted from code contributed by BigMoyan.
6 | '''
7 |
8 | from __future__ import print_function
9 | from __future__ import absolute_import
10 |
11 | from keras.layers import Input, Add, Dense, Activation, Flatten, Convolution2D, MaxPooling2D, ZeroPadding2D, \
12 | AveragePooling2D, TimeDistributed
13 |
14 | from keras import backend as K
15 |
16 | from keras_frcnn.RoiPoolingConv import RoiPoolingConv
17 | from keras_frcnn.FixedBatchNormalization import FixedBatchNormalization
18 |
19 |
20 | def get_weight_path():
21 | if K.image_dim_ordering() == 'th':
22 | return 'resnet50_weights_th_dim_ordering_th_kernels_notop.h5'
23 | else:
24 | return 'resnet50_weights_tf_dim_ordering_tf_kernels.h5'
25 |
26 |
27 | def get_img_output_length(width, height):
28 | def get_output_length(input_length):
29 | # zero_pad
30 | input_length += 6
31 | # apply 4 strided convolutions
32 | filter_sizes = [7, 3, 1, 1]
33 | stride = 2
34 | for filter_size in filter_sizes:
35 | input_length = (input_length - filter_size + stride) // stride
36 | return input_length
37 |
38 | return get_output_length(width), get_output_length(height)
39 |
40 |
41 | def identity_block(input_tensor, kernel_size, filters, stage, block, trainable=True):
42 |
43 | nb_filter1, nb_filter2, nb_filter3 = filters
44 |
45 | if K.image_dim_ordering() == 'tf':
46 | bn_axis = 3
47 | else:
48 | bn_axis = 1
49 |
50 | conv_name_base = 'res' + str(stage) + block + '_branch'
51 | bn_name_base = 'bn' + str(stage) + block + '_branch'
52 |
53 | x = Convolution2D(nb_filter1, (1, 1), name=conv_name_base + '2a', trainable=trainable)(input_tensor)
54 | x = FixedBatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
55 | x = Activation('relu')(x)
56 |
57 | x = Convolution2D(nb_filter2, (kernel_size, kernel_size), padding='same', name=conv_name_base + '2b', trainable=trainable)(x)
58 | x = FixedBatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
59 | x = Activation('relu')(x)
60 |
61 | x = Convolution2D(nb_filter3, (1, 1), name=conv_name_base + '2c', trainable=trainable)(x)
62 | x = FixedBatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)
63 |
64 | x = Add()([x, input_tensor])
65 | x = Activation('relu')(x)
66 | return x
67 |
68 |
69 | def identity_block_td(input_tensor, kernel_size, filters, stage, block, trainable=True):
70 |
71 | # identity block time distributed
72 |
73 | nb_filter1, nb_filter2, nb_filter3 = filters
74 | if K.image_dim_ordering() == 'tf':
75 | bn_axis = 3
76 | else:
77 | bn_axis = 1
78 |
79 | conv_name_base = 'res' + str(stage) + block + '_branch'
80 | bn_name_base = 'bn' + str(stage) + block + '_branch'
81 |
82 | x = TimeDistributed(Convolution2D(nb_filter1, (1, 1), trainable=trainable, kernel_initializer='normal'), name=conv_name_base + '2a')(input_tensor)
83 | x = TimeDistributed(FixedBatchNormalization(axis=bn_axis), name=bn_name_base + '2a')(x)
84 | x = Activation('relu')(x)
85 |
86 | x = TimeDistributed(Convolution2D(nb_filter2, (kernel_size, kernel_size), trainable=trainable, kernel_initializer='normal',padding='same'), name=conv_name_base + '2b')(x)
87 | x = TimeDistributed(FixedBatchNormalization(axis=bn_axis), name=bn_name_base + '2b')(x)
88 | x = Activation('relu')(x)
89 |
90 | x = TimeDistributed(Convolution2D(nb_filter3, (1, 1), trainable=trainable, kernel_initializer='normal'), name=conv_name_base + '2c')(x)
91 | x = TimeDistributed(FixedBatchNormalization(axis=bn_axis), name=bn_name_base + '2c')(x)
92 |
93 | x = Add()([x, input_tensor])
94 | x = Activation('relu')(x)
95 |
96 | return x
97 |
98 |
99 | def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2), trainable=True):
100 |
101 | nb_filter1, nb_filter2, nb_filter3 = filters
102 | if K.image_dim_ordering() == 'tf':
103 | bn_axis = 3
104 | else:
105 | bn_axis = 1
106 |
107 | conv_name_base = 'res' + str(stage) + block + '_branch'
108 | bn_name_base = 'bn' + str(stage) + block + '_branch'
109 |
110 | x = Convolution2D(nb_filter1, (1, 1), strides=strides, name=conv_name_base + '2a', trainable=trainable)(input_tensor)
111 | x = FixedBatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
112 | x = Activation('relu')(x)
113 |
114 | x = Convolution2D(nb_filter2, (kernel_size, kernel_size), padding='same', name=conv_name_base + '2b', trainable=trainable)(x)
115 | x = FixedBatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
116 | x = Activation('relu')(x)
117 |
118 | x = Convolution2D(nb_filter3, (1, 1), name=conv_name_base + '2c', trainable=trainable)(x)
119 | x = FixedBatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)
120 |
121 | shortcut = Convolution2D(nb_filter3, (1, 1), strides=strides, name=conv_name_base + '1', trainable=trainable)(input_tensor)
122 | shortcut = FixedBatchNormalization(axis=bn_axis, name=bn_name_base + '1')(shortcut)
123 |
124 | x = Add()([x, shortcut])
125 | x = Activation('relu')(x)
126 | return x
127 |
128 |
129 | def conv_block_td(input_tensor, kernel_size, filters, stage, block, input_shape, strides=(2, 2), trainable=True):
130 |
131 | # conv block time distributed
132 |
133 | nb_filter1, nb_filter2, nb_filter3 = filters
134 | if K.image_dim_ordering() == 'tf':
135 | bn_axis = 3
136 | else:
137 | bn_axis = 1
138 |
139 | conv_name_base = 'res' + str(stage) + block + '_branch'
140 | bn_name_base = 'bn' + str(stage) + block + '_branch'
141 |
142 | x = TimeDistributed(Convolution2D(nb_filter1, (1, 1), strides=strides, trainable=trainable, kernel_initializer='normal'), input_shape=input_shape, name=conv_name_base + '2a')(input_tensor)
143 | x = TimeDistributed(FixedBatchNormalization(axis=bn_axis), name=bn_name_base + '2a')(x)
144 | x = Activation('relu')(x)
145 |
146 | x = TimeDistributed(Convolution2D(nb_filter2, (kernel_size, kernel_size), padding='same', trainable=trainable, kernel_initializer='normal'), name=conv_name_base + '2b')(x)
147 | x = TimeDistributed(FixedBatchNormalization(axis=bn_axis), name=bn_name_base + '2b')(x)
148 | x = Activation('relu')(x)
149 |
150 | x = TimeDistributed(Convolution2D(nb_filter3, (1, 1), kernel_initializer='normal'), name=conv_name_base + '2c', trainable=trainable)(x)
151 | x = TimeDistributed(FixedBatchNormalization(axis=bn_axis), name=bn_name_base + '2c')(x)
152 |
153 | shortcut = TimeDistributed(Convolution2D(nb_filter3, (1, 1), strides=strides, trainable=trainable, kernel_initializer='normal'), name=conv_name_base + '1')(input_tensor)
154 | shortcut = TimeDistributed(FixedBatchNormalization(axis=bn_axis), name=bn_name_base + '1')(shortcut)
155 |
156 | x = Add()([x, shortcut])
157 | x = Activation('relu')(x)
158 | return x
159 |
160 |
161 | def nn_base(input_tensor=None, trainable=False):
162 |
163 | # Determine proper input shape
164 | if K.image_dim_ordering() == 'th':
165 | input_shape = (3, None, None)
166 | else:
167 | input_shape = (None, None, 3)
168 |
169 | if input_tensor is None:
170 | img_input = Input(shape=input_shape)
171 | else:
172 | if not K.is_keras_tensor(input_tensor):
173 | img_input = Input(tensor=input_tensor, shape=input_shape)
174 | else:
175 | img_input = input_tensor
176 |
177 | if K.image_dim_ordering() == 'tf':
178 | bn_axis = 3
179 | else:
180 | bn_axis = 1
181 |
182 | x = ZeroPadding2D((3, 3))(img_input)
183 |
184 | x = Convolution2D(64, (7, 7), strides=(2, 2), name='conv1', trainable = trainable)(x)
185 | x = FixedBatchNormalization(axis=bn_axis, name='bn_conv1')(x)
186 | x = Activation('relu')(x)
187 | x = MaxPooling2D((3, 3), strides=(2, 2))(x)
188 |
189 | x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1), trainable = trainable)
190 | x = identity_block(x, 3, [64, 64, 256], stage=2, block='b', trainable = trainable)
191 | x = identity_block(x, 3, [64, 64, 256], stage=2, block='c', trainable = trainable)
192 |
193 | x = conv_block(x, 3, [128, 128, 512], stage=3, block='a', trainable = trainable)
194 | x = identity_block(x, 3, [128, 128, 512], stage=3, block='b', trainable = trainable)
195 | x = identity_block(x, 3, [128, 128, 512], stage=3, block='c', trainable = trainable)
196 | x = identity_block(x, 3, [128, 128, 512], stage=3, block='d', trainable = trainable)
197 |
198 | x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a', trainable = trainable)
199 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b', trainable = trainable)
200 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c', trainable = trainable)
201 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d', trainable = trainable)
202 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e', trainable = trainable)
203 | x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f', trainable = trainable)
204 |
205 | return x
206 |
207 |
208 | def classifier_layers(x, input_shape, trainable=False):
209 |
210 | # compile times on theano tend to be very high, so we use smaller ROI pooling regions to workaround
211 | # (hence a smaller stride in the region that follows the ROI pool)
212 | if K.backend() == 'tensorflow':
213 | x = conv_block_td(x, 3, [512, 512, 2048], stage=5, block='a', input_shape=input_shape, strides=(2, 2), trainable=trainable)
214 | elif K.backend() == 'theano':
215 | x = conv_block_td(x, 3, [512, 512, 2048], stage=5, block='a', input_shape=input_shape, strides=(1, 1), trainable=trainable)
216 |
217 | x = identity_block_td(x, 3, [512, 512, 2048], stage=5, block='b', trainable=trainable)
218 | x = identity_block_td(x, 3, [512, 512, 2048], stage=5, block='c', trainable=trainable)
219 | x = TimeDistributed(AveragePooling2D((7, 7)), name='avg_pool')(x)
220 |
221 | return x
222 |
223 |
224 | def rpn(base_layers, num_anchors):
225 |
226 | x = Convolution2D(512, (3, 3), padding='same', activation='relu', kernel_initializer='normal', name='rpn_conv1')(base_layers)
227 |
228 | x_class = Convolution2D(num_anchors, (1, 1), activation='sigmoid', kernel_initializer='uniform', name='rpn_out_class')(x)
229 | x_regr = Convolution2D(num_anchors * 4, (1, 1), activation='linear', kernel_initializer='zero', name='rpn_out_regress')(x)
230 |
231 | return [x_class, x_regr, base_layers]
232 |
233 |
234 | def classifier(base_layers, input_rois, num_rois, nb_classes=21, trainable=False):
235 |
236 | # compile times on theano tend to be very high, so we use smaller ROI pooling regions to workaround
237 |
238 | if K.backend() == 'tensorflow':
239 | pooling_regions = 14
240 | input_shape = (num_rois, 14, 14, 1024)
241 | elif K.backend() == 'theano':
242 | pooling_regions = 7
243 | input_shape = (num_rois, 1024, 7, 7)
244 |
245 | out_roi_pool = RoiPoolingConv(pooling_regions, num_rois)([base_layers, input_rois])
246 | out = classifier_layers(out_roi_pool, input_shape=input_shape, trainable=True)
247 |
248 | out = TimeDistributed(Flatten())(out)
249 |
250 | out_class = TimeDistributed(Dense(nb_classes, activation='softmax', kernel_initializer='zero'), name='dense_class_{}'.format(nb_classes))(out)
251 | # note: no regression target for bg class
252 | out_regr = TimeDistributed(Dense(4 * (nb_classes-1), activation='linear', kernel_initializer='zero'), name='dense_regress_{}'.format(nb_classes))(out)
253 | return [out_class, out_regr]
254 |
255 |
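A short sketch of how the helpers above fit together, assuming the TensorFlow backend with channels-last ordering; the 800x600 input and the 9 anchors (3 scales x 3 ratios) are illustrative choices.

    from keras.layers import Input
    from keras.models import Model
    from keras_frcnn import resnet

    # feature-map size implied by the ResNet-50 base's effective stride of 16
    print(resnet.get_img_output_length(800, 600))   # -> (50, 38)

    # attach the RPN heads to the shared base
    img_input = Input(shape=(None, None, 3))
    shared_layers = resnet.nn_base(img_input, trainable=True)
    rpn_class, rpn_regr, _ = resnet.rpn(shared_layers, num_anchors=9)
    model_rpn = Model(img_input, [rpn_class, rpn_regr])
    model_rpn.summary()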
--------------------------------------------------------------------------------
/keras_frcnn/roi_helpers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pdb
3 | import math
4 | from . import data_generators
5 | import copy
6 | import time
7 |
8 |
9 | def calc_iou(R, img_data, C, class_mapping):
10 |
11 | bboxes = img_data['bboxes']
12 | (width, height) = (img_data['width'], img_data['height'])
13 | # get image dimensions for resizing
14 | (resized_width, resized_height) = data_generators.get_new_img_size(width, height, C.im_size)
15 |
16 | gta = np.zeros((len(bboxes), 4))
17 |
18 | for bbox_num, bbox in enumerate(bboxes):
19 | # get the GT box coordinates, and resize to account for image resizing
20 | gta[bbox_num, 0] = int(round(bbox['x1'] * (resized_width / float(width))/C.rpn_stride))
21 | gta[bbox_num, 1] = int(round(bbox['x2'] * (resized_width / float(width))/C.rpn_stride))
22 | gta[bbox_num, 2] = int(round(bbox['y1'] * (resized_height / float(height))/C.rpn_stride))
23 | gta[bbox_num, 3] = int(round(bbox['y2'] * (resized_height / float(height))/C.rpn_stride))
24 |
25 | x_roi = []
26 | y_class_num = []
27 | y_class_regr_coords = []
28 | y_class_regr_label = []
29 | IoUs = [] # for debugging only
30 |
31 | for ix in range(R.shape[0]):
32 | (x1, y1, x2, y2) = R[ix, :]
33 | x1 = int(round(x1))
34 | y1 = int(round(y1))
35 | x2 = int(round(x2))
36 | y2 = int(round(y2))
37 |
38 | best_iou = 0.0
39 | best_bbox = -1
40 | for bbox_num in range(len(bboxes)):
41 | curr_iou = data_generators.iou([gta[bbox_num, 0], gta[bbox_num, 2], gta[bbox_num, 1], gta[bbox_num, 3]], [x1, y1, x2, y2])
42 | if curr_iou > best_iou:
43 | best_iou = curr_iou
44 | best_bbox = bbox_num
45 |
46 | if best_iou < C.classifier_min_overlap:
47 | continue
48 | else:
49 | w = x2 - x1
50 | h = y2 - y1
51 | x_roi.append([x1, y1, w, h])
52 | IoUs.append(best_iou)
53 |
54 | if C.classifier_min_overlap <= best_iou < C.classifier_max_overlap:
55 | # hard negative example
56 | cls_name = 'bg'
57 | elif C.classifier_max_overlap <= best_iou:
58 | cls_name = bboxes[best_bbox]['class']
59 | cxg = (gta[best_bbox, 0] + gta[best_bbox, 1]) / 2.0
60 | cyg = (gta[best_bbox, 2] + gta[best_bbox, 3]) / 2.0
61 |
62 | cx = x1 + w / 2.0
63 | cy = y1 + h / 2.0
64 |
65 | tx = (cxg - cx) / float(w)
66 | ty = (cyg - cy) / float(h)
67 | tw = np.log((gta[best_bbox, 1] - gta[best_bbox, 0]) / float(w))
68 | th = np.log((gta[best_bbox, 3] - gta[best_bbox, 2]) / float(h))
69 | else:
70 | print('roi = {}'.format(best_iou))
71 | raise RuntimeError
72 |
73 | class_num = class_mapping[cls_name]
74 | class_label = len(class_mapping) * [0]
75 | class_label[class_num] = 1
76 | y_class_num.append(copy.deepcopy(class_label))
77 | coords = [0] * 4 * (len(class_mapping) - 1)
78 | labels = [0] * 4 * (len(class_mapping) - 1)
79 | if cls_name != 'bg':
80 | label_pos = 4 * class_num
81 | sx, sy, sw, sh = C.classifier_regr_std
82 | coords[label_pos:4+label_pos] = [sx*tx, sy*ty, sw*tw, sh*th]
83 | labels[label_pos:4+label_pos] = [1, 1, 1, 1]
84 | y_class_regr_coords.append(copy.deepcopy(coords))
85 | y_class_regr_label.append(copy.deepcopy(labels))
86 | else:
87 | y_class_regr_coords.append(copy.deepcopy(coords))
88 | y_class_regr_label.append(copy.deepcopy(labels))
89 |
90 | if len(x_roi) == 0:
91 | return None, None, None, None
92 |
93 | X = np.array(x_roi)
94 | Y1 = np.array(y_class_num)
95 | Y2 = np.concatenate([np.array(y_class_regr_label),np.array(y_class_regr_coords)],axis=1)
96 |
97 | return np.expand_dims(X, axis=0), np.expand_dims(Y1, axis=0), np.expand_dims(Y2, axis=0), IoUs
98 |
99 |
100 | def apply_regr(x, y, w, h, tx, ty, tw, th):
101 | try:
102 | cx = x + w/2.
103 | cy = y + h/2.
104 | cx1 = tx * w + cx
105 | cy1 = ty * h + cy
106 | w1 = math.exp(tw) * w
107 | h1 = math.exp(th) * h
108 | x1 = cx1 - w1/2.
109 | y1 = cy1 - h1/2.
110 | x1 = int(round(x1))
111 | y1 = int(round(y1))
112 | w1 = int(round(w1))
113 | h1 = int(round(h1))
114 |
115 | return x1, y1, w1, h1
116 |
117 | except ValueError:
118 | return x, y, w, h
119 | except OverflowError:
120 | return x, y, w, h
121 | except Exception as e:
122 | print(e)
123 | return x, y, w, h
124 |
125 |
126 | def apply_regr_np(X, T):
127 | try:
128 | x = X[0, :, :]
129 | y = X[1, :, :]
130 | w = X[2, :, :]
131 | h = X[3, :, :]
132 |
133 | tx = T[0, :, :]
134 | ty = T[1, :, :]
135 | tw = T[2, :, :]
136 | th = T[3, :, :]
137 |
138 | cx = x + w/2.
139 | cy = y + h/2.
140 | cx1 = tx * w + cx
141 | cy1 = ty * h + cy
142 |
143 | w1 = np.exp(tw.astype(np.float64)) * w
144 | h1 = np.exp(th.astype(np.float64)) * h
145 | x1 = cx1 - w1/2.
146 | y1 = cy1 - h1/2.
147 |
148 | x1 = np.round(x1)
149 | y1 = np.round(y1)
150 | w1 = np.round(w1)
151 | h1 = np.round(h1)
152 | return np.stack([x1, y1, w1, h1])
153 | except Exception as e:
154 | print(e)
155 | return X
156 |
157 |
158 | def non_max_suppression_fast(boxes, probs, overlap_thresh=0.9, max_boxes=300):
159 | # code used from here: http://www.pyimagesearch.com/2015/02/16/faster-non-maximum-suppression-python/
160 | # if there are no boxes, return an empty list
161 | if len(boxes) == 0:
162 | return []
163 |
164 | # grab the coordinates of the bounding boxes
165 | x1 = boxes[:, 0]
166 | y1 = boxes[:, 1]
167 | x2 = boxes[:, 2]
168 | y2 = boxes[:, 3]
169 |
170 | np.testing.assert_array_less(x1, x2)
171 | np.testing.assert_array_less(y1, y2)
172 |
173 | 	# if the bounding boxes are integers, convert them to floats --
174 | # this is important since we'll be doing a bunch of divisions
175 | if boxes.dtype.kind == "i":
176 | boxes = boxes.astype("float")
177 |
178 | # initialize the list of picked indexes
179 | pick = []
180 |
181 | # calculate the areas
182 | area = (x2 - x1) * (y2 - y1)
183 |
184 | # sort the bounding boxes
185 | idxs = np.argsort(probs)
186 |
187 | # keep looping while some indexes still remain in the indexes
188 | # list
189 | while len(idxs) > 0:
190 | # grab the last index in the indexes list and add the
191 | # index value to the list of picked indexes
192 | last = len(idxs) - 1
193 | i = idxs[last]
194 | pick.append(i)
195 |
196 | # find the intersection
197 |
198 | xx1_int = np.maximum(x1[i], x1[idxs[:last]])
199 | yy1_int = np.maximum(y1[i], y1[idxs[:last]])
200 | xx2_int = np.minimum(x2[i], x2[idxs[:last]])
201 | yy2_int = np.minimum(y2[i], y2[idxs[:last]])
202 |
203 | ww_int = np.maximum(0, xx2_int - xx1_int)
204 | hh_int = np.maximum(0, yy2_int - yy1_int)
205 |
206 | area_int = ww_int * hh_int
207 |
208 | # find the union
209 | area_union = area[i] + area[idxs[:last]] - area_int
210 |
211 | # compute the ratio of overlap
212 | overlap = area_int/(area_union + 1e-6)
213 |
214 | 		# delete all indexes from the index list that have overlap above the threshold
215 | idxs = np.delete(idxs, np.concatenate(([last],
216 | np.where(overlap > overlap_thresh)[0])))
217 |
218 | if len(pick) >= max_boxes:
219 | break
220 |
221 | # return only the bounding boxes that were picked using the integer data type
222 | boxes = boxes[pick].astype("int")
223 | probs = probs[pick]
224 | return boxes, probs
225 |
226 |
227 | def rpn_to_roi(rpn_layer, regr_layer, C, dim_ordering, use_regr=True, max_boxes=300,overlap_thresh=0.9):
228 |
229 | regr_layer = regr_layer / C.std_scaling
230 |
231 | anchor_sizes = C.anchor_box_scales
232 | anchor_ratios = C.anchor_box_ratios
233 |
234 | assert rpn_layer.shape[0] == 1
235 |
236 | if dim_ordering == 'th':
237 | (rows,cols) = rpn_layer.shape[2:]
238 |
239 | elif dim_ordering == 'tf':
240 | (rows, cols) = rpn_layer.shape[1:3]
241 |
242 | curr_layer = 0
243 | if dim_ordering == 'tf':
244 | A = np.zeros((4, rpn_layer.shape[1], rpn_layer.shape[2], rpn_layer.shape[3]))
245 | elif dim_ordering == 'th':
246 | A = np.zeros((4, rpn_layer.shape[2], rpn_layer.shape[3], rpn_layer.shape[1]))
247 |
248 | for anchor_size in anchor_sizes:
249 | for anchor_ratio in anchor_ratios:
250 |
251 | anchor_x = (anchor_size * anchor_ratio[0])/C.rpn_stride
252 | anchor_y = (anchor_size * anchor_ratio[1])/C.rpn_stride
253 | if dim_ordering == 'th':
254 | regr = regr_layer[0, 4 * curr_layer:4 * curr_layer + 4, :, :]
255 | else:
256 | regr = regr_layer[0, :, :, 4 * curr_layer:4 * curr_layer + 4]
257 | regr = np.transpose(regr, (2, 0, 1))
258 |
259 | 			X, Y = np.meshgrid(np.arange(cols), np.arange(rows))
260 |
261 | A[0, :, :, curr_layer] = X - anchor_x/2
262 | A[1, :, :, curr_layer] = Y - anchor_y/2
263 | A[2, :, :, curr_layer] = anchor_x
264 | A[3, :, :, curr_layer] = anchor_y
265 |
266 | if use_regr:
267 | A[:, :, :, curr_layer] = apply_regr_np(A[:, :, :, curr_layer], regr)
268 |
269 | A[2, :, :, curr_layer] = np.maximum(1, A[2, :, :, curr_layer])
270 | A[3, :, :, curr_layer] = np.maximum(1, A[3, :, :, curr_layer])
271 | A[2, :, :, curr_layer] += A[0, :, :, curr_layer]
272 | A[3, :, :, curr_layer] += A[1, :, :, curr_layer]
273 |
274 | A[0, :, :, curr_layer] = np.maximum(0, A[0, :, :, curr_layer])
275 | A[1, :, :, curr_layer] = np.maximum(0, A[1, :, :, curr_layer])
276 | A[2, :, :, curr_layer] = np.minimum(cols-1, A[2, :, :, curr_layer])
277 | A[3, :, :, curr_layer] = np.minimum(rows-1, A[3, :, :, curr_layer])
278 |
279 | curr_layer += 1
280 |
281 | all_boxes = np.reshape(A.transpose((0, 3, 1,2)), (4, -1)).transpose((1, 0))
282 | all_probs = rpn_layer.transpose((0, 3, 1, 2)).reshape((-1))
283 |
284 | x1 = all_boxes[:, 0]
285 | y1 = all_boxes[:, 1]
286 | x2 = all_boxes[:, 2]
287 | y2 = all_boxes[:, 3]
288 |
289 | idxs = np.where((x1 - x2 >= 0) | (y1 - y2 >= 0))
290 |
291 | all_boxes = np.delete(all_boxes, idxs, 0)
292 | all_probs = np.delete(all_probs, idxs, 0)
293 |
294 | result = non_max_suppression_fast(all_boxes, all_probs, overlap_thresh=overlap_thresh, max_boxes=max_boxes)[0]
295 |
296 | return result
297 |
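A toy example of non_max_suppression_fast above, with invented boxes in (x1, y1, x2, y2) form: the second box heavily overlaps the first and is suppressed, while the third survives.

    import numpy as np
    from keras_frcnn import roi_helpers

    boxes = np.array([[10, 10, 60, 60],
                      [12, 12, 58, 58],
                      [100, 100, 150, 150]])
    probs = np.array([0.9, 0.8, 0.7])

    kept_boxes, kept_probs = roi_helpers.non_max_suppression_fast(
        boxes, probs, overlap_thresh=0.5, max_boxes=10)
    print(kept_boxes)   # boxes 0 and 2 remain; the near-duplicate box 1 is dropped
    print(kept_probs)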
--------------------------------------------------------------------------------
/keras_frcnn/simple_parser.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 | def get_data(input_path):
5 | found_bg = False
6 | all_imgs = {}
7 |
8 | classes_count = {}
9 |
10 | class_mapping = {}
11 |
12 | visualise = True
13 |
14 | with open(input_path,'r') as f:
15 |
16 | print('Parsing annotation files')
17 |
18 | for line in f:
19 | line_split = line.strip().split(',')
20 | (filename,x1,y1,x2,y2,class_name) = line_split
21 |
22 | if class_name not in classes_count:
23 | classes_count[class_name] = 1
24 | else:
25 | classes_count[class_name] += 1
26 |
27 | if class_name not in class_mapping:
28 | 			if class_name == 'bg' and not found_bg:
29 | print('Found class name with special name bg. Will be treated as a background region (this is usually for hard negative mining).')
30 | found_bg = True
31 | class_mapping[class_name] = len(class_mapping)
32 |
33 | if filename not in all_imgs:
34 | all_imgs[filename] = {}
35 |
36 | img = cv2.imread(filename)
37 | (rows,cols) = img.shape[:2]
38 | all_imgs[filename]['filepath'] = filename
39 | all_imgs[filename]['width'] = cols
40 | all_imgs[filename]['height'] = rows
41 | all_imgs[filename]['bboxes'] = []
42 | if np.random.randint(0,6) > 0:
43 | all_imgs[filename]['imageset'] = 'trainval'
44 | else:
45 | all_imgs[filename]['imageset'] = 'test'
46 |
47 | all_imgs[filename]['bboxes'].append({'class': class_name, 'x1': int(x1), 'x2': int(x2), 'y1': int(y1), 'y2': int(y2)})
48 |
49 |
50 | all_data = []
51 | for key in all_imgs:
52 | all_data.append(all_imgs[key])
53 |
54 | # make sure the bg class is last in the list
55 | if found_bg:
56 | if class_mapping['bg'] != len(class_mapping) - 1:
57 | key_to_switch = [key for key in class_mapping.keys() if class_mapping[key] == len(class_mapping)-1][0]
58 | val_to_switch = class_mapping['bg']
59 | class_mapping['bg'] = len(class_mapping) - 1
60 | class_mapping[key_to_switch] = val_to_switch
61 |
62 | return all_data, classes_count, class_mapping
63 |
64 |
65 |
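A sketch of the comma-separated annotation format get_data above expects, one box per line as filename,x1,y1,x2,y2,class_name. The image paths below are placeholders and must point at readable images, since get_data calls cv2.imread on each file to record its width and height.

    from keras_frcnn.simple_parser import get_data

    with open('annotations.txt', 'w') as f:
        f.write('/data/images/img_001.jpg,83,104,247,316,dog\n')
        f.write('/data/images/img_002.jpg,12,30,98,110,cat\n')

    all_data, classes_count, class_mapping = get_data('annotations.txt')
    print(classes_count)   # e.g. {'dog': 1, 'cat': 1}
    print(class_mapping)   # e.g. {'dog': 0, 'cat': 1}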
--------------------------------------------------------------------------------
/keras_frcnn/vgg.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """VGG16 model for Keras.
3 | # Reference
4 | - [Very Deep Convolutional Networks for Large-Scale Image Recognition](https://arxiv.org/abs/1409.1556)
5 | """
6 | from __future__ import print_function
7 | from __future__ import absolute_import
8 |
9 | import warnings
10 |
11 | from keras.models import Model
12 | from keras.layers import Flatten, Dense, Input, Conv2D, MaxPooling2D
13 | from keras.layers import GlobalAveragePooling2D, GlobalMaxPooling2D, TimeDistributed
14 | from keras.engine.topology import get_source_inputs
15 | from keras.utils import layer_utils
16 | from keras.utils.data_utils import get_file
17 | from keras import backend as K
18 | from keras_frcnn.RoiPoolingConv import RoiPoolingConv
19 |
20 |
21 | def get_weight_path():
22 | if K.image_dim_ordering() == 'th':
23 | print('pretrained weights not available for VGG with theano backend')
24 | return
25 | else:
26 | return 'vgg16_weights_tf_dim_ordering_tf_kernels.h5'
27 |
28 |
29 | def get_img_output_length(width, height):
30 | def get_output_length(input_length):
31 |         return input_length // 16  # the VGG16 base downsamples by 16 (four 2x2 max-pools)
32 |
33 | return get_output_length(width), get_output_length(height)
34 |
35 |
36 | def nn_base(input_tensor=None, trainable=False):
37 |
38 |
39 | # Determine proper input shape
40 | if K.image_dim_ordering() == 'th':
41 | input_shape = (3, None, None)
42 | else:
43 | input_shape = (None, None, 3)
44 |
45 | if input_tensor is None:
46 | img_input = Input(shape=input_shape)
47 | else:
48 | if not K.is_keras_tensor(input_tensor):
49 | img_input = Input(tensor=input_tensor, shape=input_shape)
50 | else:
51 | img_input = input_tensor
52 |
53 | if K.image_dim_ordering() == 'tf':
54 | bn_axis = 3
55 | else:
56 | bn_axis = 1
57 |
58 | # Block 1
59 | x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1')(img_input)
60 | x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2')(x)
61 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)
62 |
63 | # Block 2
64 | x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1')(x)
65 | x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2')(x)
66 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)
67 |
68 | # Block 3
69 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1')(x)
70 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2')(x)
71 | x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3')(x)
72 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)
73 |
74 | # Block 4
75 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1')(x)
76 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2')(x)
77 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3')(x)
78 | x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)
79 |
80 | # Block 5
81 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1')(x)
82 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2')(x)
83 | x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3')(x)
84 | # x = MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x)
85 |
86 | return x
87 |
88 |
89 | def rpn(base_layers, num_anchors):
90 |
91 | x = Conv2D(256, (3, 3), padding='same', activation='relu', kernel_initializer='normal', name='rpn_conv1')(base_layers)
92 |
93 | x_class = Conv2D(num_anchors, (1, 1), activation='sigmoid', kernel_initializer='uniform', name='rpn_out_class')(x)
94 | x_regr = Conv2D(num_anchors * 4, (1, 1), activation='linear', kernel_initializer='zero', name='rpn_out_regress')(x)
95 |
96 | return [x_class, x_regr, base_layers]
97 |
98 |
99 | def classifier(base_layers, input_rois, num_rois, nb_classes = 21, trainable=False):
100 |
101 | # compile times on theano tend to be very high, so we use smaller ROI pooling regions to workaround
102 |
103 | if K.backend() == 'tensorflow':
104 | pooling_regions = 7
105 | input_shape = (num_rois, 7, 7, 512)
106 | elif K.backend() == 'theano':
107 | pooling_regions = 7
108 | input_shape = (num_rois, 512, 7, 7)
109 |
110 | out_roi_pool = RoiPoolingConv(pooling_regions, num_rois)([base_layers, input_rois])
111 |
112 | out = TimeDistributed(Flatten(name='flatten'))(out_roi_pool)
113 | out = TimeDistributed(Dense(4096, activation='relu', name='fc1'))(out)
114 | out = TimeDistributed(Dense(4096, activation='relu', name='fc2'))(out)
115 |
116 | out_class = TimeDistributed(Dense(nb_classes, activation='softmax', kernel_initializer='zero'), name='dense_class_{}'.format(nb_classes))(out)
117 | # note: no regression target for bg class
118 | out_regr = TimeDistributed(Dense(4 * (nb_classes-1), activation='linear', kernel_initializer='zero'), name='dense_regress_{}'.format(nb_classes))(out)
119 |
120 | return [out_class, out_regr]
121 |
122 |
123 |
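For illustration, the downsampling implied by get_img_output_length above (the final block5 pool is commented out, so the VGG16 base reduces each spatial dimension by 16); the 800x600 input is an arbitrary example.

    from keras_frcnn import vgg

    print(vgg.get_img_output_length(800, 600))   # -> (50, 37)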
--------------------------------------------------------------------------------
/keras_frcnn/xception.py:
--------------------------------------------------------------------------------
1 | """Xception V1 model for Keras.
2 | On ImageNet, this model gets to a top-1 validation accuracy of 0.790
3 | and a top-5 validation accuracy of 0.945.
4 | Do note that the input image format for this model is different than for
5 | the VGG16 and ResNet models (299x299 instead of 224x224),
6 | and that the input preprocessing function
7 | is also different (same as Inception V3).
8 | # Reference
9 | - [Xception: Deep Learning with Depthwise Separable Convolutions](
10 | https://arxiv.org/abs/1610.02357)
11 | """
12 |
13 | from __future__ import absolute_import
14 | from __future__ import division
15 | from __future__ import print_function
16 |
17 | from keras.layers import Input, add, Dense, Activation, Flatten, Conv2D, MaxPooling2D, SeparableConv2D, BatchNormalization, GlobalAveragePooling2D, AveragePooling2D, TimeDistributed
18 |
19 | from keras import backend as K
20 |
21 | from keras_frcnn.RoiPoolingConv import RoiPoolingConv
22 | from keras_frcnn.FixedBatchNormalization import FixedBatchNormalization
23 |
24 |
25 | def get_weight_path():
26 | if K.image_dim_ordering() == 'th':
27 | return 'xception_weights_tf_dim_ordering_tf_kernels_notop.h5'
28 | else:
29 | return 'xception_weights_tf_dim_ordering_tf_kernels.h5'
30 |
31 |
32 | def get_img_output_length(width, height):
33 | def get_output_length(input_length):
34 | filter_sizes = [3, 3, 1, 1, 1, 1]
35 | strides = [2, 1, 2, 2, 2, 2]
36 |
37 | assert len(filter_sizes) == len(strides)
38 |
39 | for i in range(len(filter_sizes)):
40 | input_length = (input_length - filter_sizes[i]) // strides[i] + 1
41 |
42 | return input_length
43 |
44 | return get_output_length(width), get_output_length(height)
45 |
46 |
47 | def nn_base(input_tensor=None, trainable=False):
48 |
49 | # Determine proper input shape
50 | if K.image_dim_ordering() == 'th':
51 | input_shape = (3, None, None)
52 | else:
53 | input_shape = (None, None, 3)
54 |
55 | if input_tensor is None:
56 | img_input = Input(shape=input_shape)
57 | else:
58 | if not K.is_keras_tensor(input_tensor):
59 | img_input = Input(tensor=input_tensor, shape=input_shape)
60 | else:
61 | img_input = input_tensor
62 |
63 | if K.image_dim_ordering() == 'tf':
64 | bn_axis = 3
65 | else:
66 | bn_axis = 1
67 |
68 | x = Conv2D(32, (3, 3),
69 | strides=(2, 2),
70 | use_bias=False,
71 | name='block1_conv1')(img_input)
72 | x = BatchNormalization(name='block1_conv1_bn')(x)
73 | x = Activation('relu', name='block1_conv1_act')(x)
74 | x = Conv2D(64, (3, 3), use_bias=False, name='block1_conv2')(x)
75 | x = BatchNormalization(name='block1_conv2_bn')(x)
76 | x = Activation('relu', name='block1_conv2_act')(x)
77 |
78 | residual = Conv2D(128, (1, 1),
79 | strides=(2, 2),
80 | padding='same',
81 | use_bias=False)(x)
82 | residual = BatchNormalization()(residual)
83 |
84 | x = SeparableConv2D(128, (3, 3),
85 | padding='same',
86 | use_bias=False,
87 | name='block2_sepconv1')(x)
88 | x = BatchNormalization(name='block2_sepconv1_bn')(x)
89 | x = Activation('relu', name='block2_sepconv2_act')(x)
90 | x = SeparableConv2D(128, (3, 3),
91 | padding='same',
92 | use_bias=False,
93 | name='block2_sepconv2')(x)
94 | x = BatchNormalization(name='block2_sepconv2_bn')(x)
95 |
96 | x = MaxPooling2D((3, 3),
97 | strides=(2, 2),
98 | padding='same',
99 | name='block2_pool')(x)
100 | x = add([x, residual])
101 |
102 | residual = Conv2D(256, (1, 1), strides=(2, 2),
103 | padding='same', use_bias=False)(x)
104 | residual = BatchNormalization()(residual)
105 |
106 | x = Activation('relu', name='block3_sepconv1_act')(x)
107 | x = SeparableConv2D(256, (3, 3),
108 | padding='same',
109 | use_bias=False,
110 | name='block3_sepconv1')(x)
111 | x = BatchNormalization(name='block3_sepconv1_bn')(x)
112 | x = Activation('relu', name='block3_sepconv2_act')(x)
113 | x = SeparableConv2D(256, (3, 3),
114 | padding='same',
115 | use_bias=False,
116 | name='block3_sepconv2')(x)
117 | x = BatchNormalization(name='block3_sepconv2_bn')(x)
118 |
119 | x = MaxPooling2D((3, 3), strides=(2, 2),
120 | padding='same',
121 | name='block3_pool')(x)
122 | x = add([x, residual])
123 |
124 | residual = Conv2D(728, (1, 1),
125 | strides=(2, 2),
126 | padding='same',
127 | use_bias=False)(x)
128 | residual = BatchNormalization()(residual)
129 |
130 | x = Activation('relu', name='block4_sepconv1_act')(x)
131 | x = SeparableConv2D(728, (3, 3),
132 | padding='same',
133 | use_bias=False,
134 | name='block4_sepconv1')(x)
135 | x = BatchNormalization(name='block4_sepconv1_bn')(x)
136 | x = Activation('relu', name='block4_sepconv2_act')(x)
137 | x = SeparableConv2D(728, (3, 3),
138 | padding='same',
139 | use_bias=False,
140 | name='block4_sepconv2')(x)
141 | x = BatchNormalization(name='block4_sepconv2_bn')(x)
142 |
143 | x = MaxPooling2D((3, 3), strides=(2, 2),
144 | padding='same',
145 | name='block4_pool')(x)
146 | x = add([x, residual])
147 |
148 | for i in range(8):
149 | residual = x
150 | prefix = 'block' + str(i + 5)
151 |
152 | x = Activation('relu', name=prefix + '_sepconv1_act')(x)
153 | x = SeparableConv2D(728, (3, 3),
154 | padding='same',
155 | use_bias=False,
156 | name=prefix + '_sepconv1')(x)
157 | x = BatchNormalization(name=prefix + '_sepconv1_bn')(x)
158 | x = Activation('relu', name=prefix + '_sepconv2_act')(x)
159 | x = SeparableConv2D(728, (3, 3),
160 | padding='same',
161 | use_bias=False,
162 | name=prefix + '_sepconv2')(x)
163 | x = BatchNormalization(name=prefix + '_sepconv2_bn')(x)
164 | x = Activation('relu', name=prefix + '_sepconv3_act')(x)
165 | x = SeparableConv2D(728, (3, 3),
166 | padding='same',
167 | use_bias=False,
168 | name=prefix + '_sepconv3')(x)
169 | x = BatchNormalization(name=prefix + '_sepconv3_bn')(x)
170 |
171 | x = add([x, residual])
172 |
173 | residual = Conv2D(1024, (1, 1), strides=(2, 2),
174 | padding='same', use_bias=False)(x)
175 | residual = BatchNormalization()(residual)
176 |
177 | x = Activation('relu', name='block13_sepconv1_act')(x)
178 | x = SeparableConv2D(728, (3, 3),
179 | padding='same',
180 | use_bias=False,
181 | name='block13_sepconv1')(x)
182 | x = BatchNormalization(name='block13_sepconv1_bn')(x)
183 | x = Activation('relu', name='block13_sepconv2_act')(x)
184 | x = SeparableConv2D(1024, (3, 3),
185 | padding='same',
186 | use_bias=False,
187 | name='block13_sepconv2')(x)
188 | x = BatchNormalization(name='block13_sepconv2_bn')(x)
189 |
190 | x = MaxPooling2D((3, 3),
191 | strides=(2, 2),
192 | padding='same',
193 | name='block13_pool')(x)
194 | x = add([x, residual])
195 |
196 | return x
197 |
198 |
199 | def classifier_layers(x, input_shape, trainable=False):
200 |
201 | # compile times on theano tend to be very high, so we use smaller ROI pooling regions to workaround
202 | # (hence a smaller stride in the region that follows the ROI pool)
203 | x = TimeDistributed(SeparableConv2D(1536, (3, 3),
204 | padding='same',
205 | use_bias=False),
206 | name='block14_sepconv1')(x)
207 | x = TimeDistributed(BatchNormalization(), name='block14_sepconv1_bn')(x)
208 | x = Activation('relu', name='block14_sepconv1_act')(x)
209 |
210 | x = TimeDistributed(SeparableConv2D(2048, (3, 3),
211 | padding='same',
212 | use_bias=False),
213 | name='block14_sepconv2')(x)
214 | x = TimeDistributed(BatchNormalization(), name='block14_sepconv2_bn')(x)
215 | x = Activation('relu', name='block14_sepconv2_act')(x)
216 |
217 |     x = TimeDistributed(GlobalAveragePooling2D(), name='avg_pool')(x)  # pool each ROI's feature map before the dense heads
218 |
219 | return x
220 |
221 |
222 | def rpn(base_layers, num_anchors):
223 |
224 | x = Conv2D(1024, (3, 3), padding='same', activation='relu', kernel_initializer='normal', name='rpn_conv1')(base_layers)
225 |
226 | x_class = Conv2D(num_anchors, (1, 1), activation='sigmoid', kernel_initializer='uniform', name='rpn_out_class')(x)
227 | x_regr = Conv2D(num_anchors * 4, (1, 1), activation='linear', kernel_initializer='zero', name='rpn_out_regress')(x)
228 |
229 | return [x_class, x_regr, base_layers]
230 |
231 |
232 | def classifier(base_layers, input_rois, num_rois, nb_classes=21, trainable=False):
233 |
234 | # compile times on theano tend to be very high, so we use smaller ROI pooling regions to workaround
235 |
236 | if K.backend() == 'tensorflow':
237 | pooling_regions = 14
238 | input_shape = (num_rois, 14, 14, 1024)
239 | elif K.backend() == 'theano':
240 | pooling_regions = 7
241 | input_shape = (num_rois, 1024, 7, 7)
242 |
243 | out_roi_pool = RoiPoolingConv(pooling_regions, num_rois)([base_layers, input_rois])
244 | out = classifier_layers(out_roi_pool, input_shape=input_shape, trainable=True)
245 |
246 | out = TimeDistributed(Flatten())(out)
247 |
248 | out_class = TimeDistributed(Dense(nb_classes, activation='softmax', kernel_initializer='zero'), name='dense_class_{}'.format(nb_classes))(out)
249 | # note: no regression target for bg class
250 | out_regr = TimeDistributed(Dense(4 * (nb_classes-1), activation='linear', kernel_initializer='zero'), name='dense_regress_{}'.format(nb_classes))(out)
251 | return [out_class, out_regr]
252 |
253 |
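For comparison with the ResNet-50 and VGG bases, the Xception base above has five stride-2 stages and therefore an effective stride of 32; the input size below is an arbitrary example.

    from keras_frcnn import xception

    print(xception.get_img_output_length(800, 600))   # -> (25, 19)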
--------------------------------------------------------------------------------
/measure_map.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | import sys
5 | import pickle
6 | from optparse import OptionParser
7 | import time
8 | from keras_frcnn import config
9 | import keras_frcnn.resnet as nn
10 | from keras import backend as K
11 | from keras.layers import Input
12 | from keras.models import Model
13 | from keras_frcnn import roi_helpers
14 | from keras_frcnn import data_generators
15 | from sklearn.metrics import average_precision_score
16 |
17 |
18 | def get_map(pred, gt, f):
19 | T = {}
20 | P = {}
21 | fx, fy = f
22 |
23 | for bbox in gt:
24 | bbox['bbox_matched'] = False
25 |
26 | pred_probs = np.array([s['prob'] for s in pred])
27 | box_idx_sorted_by_prob = np.argsort(pred_probs)[::-1]
28 |
29 | for box_idx in box_idx_sorted_by_prob:
30 | pred_box = pred[box_idx]
31 | pred_class = pred_box['class']
32 | pred_x1 = pred_box['x1']
33 | pred_x2 = pred_box['x2']
34 | pred_y1 = pred_box['y1']
35 | pred_y2 = pred_box['y2']
36 | pred_prob = pred_box['prob']
37 | if pred_class not in P:
38 | P[pred_class] = []
39 | T[pred_class] = []
40 | P[pred_class].append(pred_prob)
41 | found_match = False
42 |
43 | for gt_box in gt:
44 | gt_class = gt_box['class']
45 | gt_x1 = gt_box['x1']/fx
46 | gt_x2 = gt_box['x2']/fx
47 | gt_y1 = gt_box['y1']/fy
48 | gt_y2 = gt_box['y2']/fy
49 | gt_seen = gt_box['bbox_matched']
50 | if gt_class != pred_class:
51 | continue
52 | if gt_seen:
53 | continue
54 | iou = data_generators.iou((pred_x1, pred_y1, pred_x2, pred_y2), (gt_x1, gt_y1, gt_x2, gt_y2))
55 | if iou >= 0.5:
56 | found_match = True
57 | gt_box['bbox_matched'] = True
58 | break
59 | else:
60 | continue
61 |
62 | T[pred_class].append(int(found_match))
63 |
64 | for gt_box in gt:
65 | if not gt_box['bbox_matched'] and not gt_box['difficult']:
66 | if gt_box['class'] not in P:
67 | P[gt_box['class']] = []
68 | T[gt_box['class']] = []
69 |
70 | T[gt_box['class']].append(1)
71 | P[gt_box['class']].append(0)
72 |
73 | #import pdb
74 | #pdb.set_trace()
75 | return T, P
76 |
77 | sys.setrecursionlimit(40000)
78 |
79 | parser = OptionParser()
80 |
81 | parser.add_option("-p", "--path", dest="test_path", help="Path to test data.")
82 | parser.add_option("-n", "--num_rois", dest="num_rois",
83 | help="Number of ROIs per iteration. Higher means more memory use.", default=32)
84 | parser.add_option("--config_filename", dest="config_filename", help=
85 | "Location to read the metadata related to the training (generated when training).",
86 | default="config.pickle")
87 | parser.add_option("-o", "--parser", dest="parser", help="Parser to use. One of simple or pascal_voc",
88 | 				  default="pascal_voc")
89 |
90 | (options, args) = parser.parse_args()
91 |
92 | if not options.test_path: # if filename is not given
93 | parser.error('Error: path to test data must be specified. Pass --path to command line')
94 |
95 |
96 | if options.parser == 'pascal_voc':
97 | from keras_frcnn.pascal_voc_parser import get_data
98 | elif options.parser == 'simple':
99 | from keras_frcnn.simple_parser import get_data
100 | else:
101 | raise ValueError("Command line option parser must be one of 'pascal_voc' or 'simple'")
102 |
103 | config_output_filename = options.config_filename
104 |
105 | with open(config_output_filename, 'rb') as f_in:
106 | C = pickle.load(f_in)
107 |
108 | # turn off any data augmentation at test time
109 | C.use_horizontal_flips = False
110 | C.use_vertical_flips = False
111 | C.rot_90 = False
112 |
113 | img_path = options.test_path
114 |
115 |
116 | def format_img(img, C):
117 | img_min_side = float(C.im_size)
118 | (height, width, _) = img.shape
119 |
120 | if width <= height:
121 | f = float(img_min_side) / width
122 | new_height = int(f * height)
123 | new_width = int(img_min_side)
124 | else:
125 | f = float(img_min_side) / height
126 | new_width = int(f * width)
127 | new_height = int(img_min_side)
128 |
129 | fx = width / float(new_width)
130 | fy = height / float(new_height)
131 |
132 | img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_CUBIC)
133 | img = img[:, :, (2, 1, 0)]
134 | img = img.astype(np.float32)
135 | img[:, :, 0] -= C.img_channel_mean[0]
136 | img[:, :, 1] -= C.img_channel_mean[1]
137 | img[:, :, 2] -= C.img_channel_mean[2]
138 | img /= C.img_scaling_factor
139 | img = np.transpose(img, (2, 0, 1))
140 | img = np.expand_dims(img, axis=0)
141 | return img, fx, fy
142 |
143 |
144 | class_mapping = C.class_mapping
145 |
146 | if 'bg' not in class_mapping:
147 | class_mapping['bg'] = len(class_mapping)
148 |
149 | class_mapping = {v: k for k, v in class_mapping.items()}
150 | print(class_mapping)
151 |
152 | class_to_color = {class_mapping[v]: np.random.randint(0, 255, 3) for v in class_mapping}
153 | C.num_rois = int(options.num_rois)
154 |
155 | if K.image_dim_ordering() == 'th':
156 | input_shape_img = (3, None, None)
157 | input_shape_features = (1024, None, None)
158 | else:
159 | input_shape_img = (None, None, 3)
160 | input_shape_features = (None, None, 1024)
161 |
162 |
163 | # define the input placeholders
164 | img_input = Input(shape=input_shape_img)
165 | roi_input = Input(shape=(C.num_rois, 4))
166 | feature_map_input = Input(shape=input_shape_features)  # feature map from the shared base network, fed to the classifier head
167 |
168 | # define the base network (resnet here, can be VGG, Inception, etc)
169 | shared_layers = nn.nn_base(img_input, trainable=True)
170 |
171 | # define the RPN, built on the base layers
172 | num_anchors = len(C.anchor_box_scales) * len(C.anchor_box_ratios)
173 | rpn_layers = nn.rpn(shared_layers, num_anchors)
174 |
175 | classifier = nn.classifier(feature_map_input, roi_input, C.num_rois, nb_classes=len(class_mapping), trainable=True)
176 |
177 | model_rpn = Model(img_input, rpn_layers)
178 | model_classifier_only = Model([feature_map_input, roi_input], classifier)
179 |
180 | model_classifier = Model([feature_map_input, roi_input], classifier)
181 |
182 | model_rpn.load_weights(C.model_path, by_name=True)
183 | model_classifier.load_weights(C.model_path, by_name=True)
184 |
185 | model_rpn.compile(optimizer='sgd', loss='mse')
186 | model_classifier.compile(optimizer='sgd', loss='mse')
187 |
188 | all_imgs, _, _ = get_data(options.test_path)
189 | test_imgs = [s for s in all_imgs if s['imageset'] == 'test']
190 |
191 |
192 | T = {}
193 | P = {}
194 | for idx, img_data in enumerate(test_imgs):
195 | print('{}/{}'.format(idx, len(test_imgs)))
196 | st = time.time()
197 | filepath = img_data['filepath']
198 |
199 | # read image
200 | img = cv2.imread(filepath)
201 | X, fx, fy = format_img(img, C)
202 |
203 | if K.image_dim_ordering() == 'tf':
204 | X = np.transpose(X, (0, 2, 3, 1))
205 |
206 | # get the feature maps and output from the RPN
207 | [Y1, Y2, F] = model_rpn.predict(X)
208 |
209 | R = roi_helpers.rpn_to_roi(Y1, Y2, C, K.image_dim_ordering(), overlap_thresh=0.7)
210 |
211 | # convert from (x1,y1,x2,y2) to (x,y,w,h)
212 | R[:, 2] -= R[:, 0]
213 | R[:, 3] -= R[:, 1]
214 |
215 | # apply the spatial pyramid pooling to the proposed regions
216 | bboxes = {}
217 | probs = {}
218 |
219 | for jk in range(R.shape[0] // C.num_rois + 1):
220 | ROIs = np.expand_dims(R[C.num_rois * jk:C.num_rois * (jk + 1), :], axis=0)
221 | if ROIs.shape[1] == 0:
222 | break
223 |
224 | if jk == R.shape[0] // C.num_rois:
225 | # pad R
226 | curr_shape = ROIs.shape
227 | target_shape = (curr_shape[0], C.num_rois, curr_shape[2])
228 | ROIs_padded = np.zeros(target_shape).astype(ROIs.dtype)
229 | ROIs_padded[:, :curr_shape[1], :] = ROIs
230 | ROIs_padded[0, curr_shape[1]:, :] = ROIs[0, 0, :]
231 | ROIs = ROIs_padded
232 |
233 | [P_cls, P_regr] = model_classifier_only.predict([F, ROIs])
234 |
235 | for ii in range(P_cls.shape[1]):
236 |
237 | if np.argmax(P_cls[0, ii, :]) == (P_cls.shape[2] - 1):
238 | continue
239 |
240 | cls_name = class_mapping[np.argmax(P_cls[0, ii, :])]
241 |
242 | if cls_name not in bboxes:
243 | bboxes[cls_name] = []
244 | probs[cls_name] = []
245 |
246 | (x, y, w, h) = ROIs[0, ii, :]
247 |
248 | cls_num = np.argmax(P_cls[0, ii, :])
249 | try:
250 | (tx, ty, tw, th) = P_regr[0, ii, 4 * cls_num:4 * (cls_num + 1)]
251 | tx /= C.classifier_regr_std[0]
252 | ty /= C.classifier_regr_std[1]
253 | tw /= C.classifier_regr_std[2]
254 | th /= C.classifier_regr_std[3]
255 | x, y, w, h = roi_helpers.apply_regr(x, y, w, h, tx, ty, tw, th)
256 |             except Exception:
257 | pass
258 |             bboxes[cls_name].append([C.rpn_stride * x, C.rpn_stride * y, C.rpn_stride * (x + w), C.rpn_stride * (y + h)])
259 | probs[cls_name].append(np.max(P_cls[0, ii, :]))
260 |
261 | all_dets = []
262 |
263 | for key in bboxes:
264 | bbox = np.array(bboxes[key])
265 |
266 | new_boxes, new_probs = roi_helpers.non_max_suppression_fast(bbox, np.array(probs[key]), overlap_thresh=0.5)
267 | for jk in range(new_boxes.shape[0]):
268 | (x1, y1, x2, y2) = new_boxes[jk, :]
269 | det = {'x1': x1, 'x2': x2, 'y1': y1, 'y2': y2, 'class': key, 'prob': new_probs[jk]}
270 | all_dets.append(det)
271 |
272 |
273 | print('Elapsed time = {}'.format(time.time() - st))
274 | t, p = get_map(all_dets, img_data['bboxes'], (fx, fy))
275 | for key in t.keys():
276 | if key not in T:
277 | T[key] = []
278 | P[key] = []
279 | T[key].extend(t[key])
280 | P[key].extend(p[key])
281 | all_aps = []
282 | for key in T.keys():
283 | ap = average_precision_score(T[key], P[key])
284 | print('{} AP: {}'.format(key, ap))
285 | all_aps.append(ap)
286 | print('mAP = {}'.format(np.mean(np.array(all_aps))))
287 | #print(T)
288 | #print(P)
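The script above is driven from the command line, e.g. `python measure_map.py -p /path/to/VOCdevkit --parser pascal_voc`, with the config.pickle produced during training. Below is a small sketch, using invented numbers, of the T/P bookkeeping that feeds sklearn's average_precision_score: per class, T holds 1 for a matched detection or a missed non-difficult ground-truth box and 0 for a false positive, while P holds the detection confidence (0.0 for missed ground truth).

    import numpy as np
    from sklearn.metrics import average_precision_score

    T = {'dog': [1, 0, 1], 'cat': [1, 0]}             # match labels
    P = {'dog': [0.9, 0.75, 0.0], 'cat': [0.8, 0.3]}  # confidences

    all_aps = [average_precision_score(T[key], P[key]) for key in T]
    print('mAP = {}'.format(np.mean(all_aps)))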
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.8.0
2 | astor==0.8.0
3 | backports.weakref==1.0.post1
4 | enum34==1.1.6
5 | funcsigs==1.0.2
6 | futures==3.3.0
7 | gast==0.2.2
8 | google-pasta==0.1.7
9 | grpcio==1.23.0
10 | h5py==2.9.0
11 | Keras==2.0.3
12 | Keras-Applications==1.0.8
13 | Keras-Preprocessing==1.1.0
14 | Markdown==3.1.1
15 | mock==3.0.5
16 | numpy==1.16.5
17 | opencv-python==4.1.0.25
18 | protobuf==3.9.1
19 | PyYAML==5.1.2
20 | scipy==1.2.2
21 | six==1.12.0
22 | tensorboard==1.14.0
23 | tensorflow==1.14.0
24 | tensorflow-estimator==1.14.0
25 | termcolor==1.1.0
26 | Theano==1.0.4
27 | Werkzeug==0.15.5
28 | wrapt==1.11.2
29 |
--------------------------------------------------------------------------------
/results_imgs/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/you359/Keras-FasterRCNN/eb67ad5d946581344f614faa1e3ee7902f429ce3/results_imgs/0.png
--------------------------------------------------------------------------------
/results_imgs/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/you359/Keras-FasterRCNN/eb67ad5d946581344f614faa1e3ee7902f429ce3/results_imgs/1.png
--------------------------------------------------------------------------------
/results_imgs/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/you359/Keras-FasterRCNN/eb67ad5d946581344f614faa1e3ee7902f429ce3/results_imgs/2.png
--------------------------------------------------------------------------------
/results_imgs/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/you359/Keras-FasterRCNN/eb67ad5d946581344f614faa1e3ee7902f429ce3/results_imgs/3.png
--------------------------------------------------------------------------------
/results_imgs/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/you359/Keras-FasterRCNN/eb67ad5d946581344f614faa1e3ee7902f429ce3/results_imgs/4.png
--------------------------------------------------------------------------------
/test_frcnn.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import os
3 | import cv2
4 | import numpy as np
5 | import sys
6 | import pickle
7 | from optparse import OptionParser
8 | import time
9 | from keras_frcnn import config
10 | from keras import backend as K
11 | from keras.layers import Input
12 | from keras.models import Model
13 | from keras_frcnn import roi_helpers
14 |
15 | # Set learning phase to 0 for model.predict. Set to 1 for training
16 | K.set_learning_phase(0)
17 |
18 |
19 | sys.setrecursionlimit(40000)
20 |
21 | parser = OptionParser()
22 |
23 | parser.add_option("-p", "--path", dest="test_path", help="Path to test data.")
24 | parser.add_option("-n", "--num_rois", dest="num_rois",
25 | help="Number of ROIs per iteration. Higher means more memory use.", default=32)
26 | parser.add_option("--config_filename", dest="config_filename", help="Location to read the metadata related to the training (generated when training).",
27 | default="config.pickle")
28 | parser.add_option("--network", dest="network", help="Base network to use. Supports vgg, resnet50, xception or inception_resnet_v2.", default='resnet50')
29 |
30 | (options, args) = parser.parse_args()
31 |
32 | if not options.test_path: # if filename is not given
33 | parser.error('Error: path to test data must be specified. Pass --path to command line')
34 |
35 |
36 | config_output_filename = options.config_filename
37 |
38 | with open(config_output_filename, 'rb') as f_in:
39 | C = pickle.load(f_in)
40 |
41 | if C.network == 'resnet50':
42 | import keras_frcnn.resnet as nn
43 | elif C.network == 'xception':
44 | import keras_frcnn.xception as nn
45 | elif C.network == 'inception_resnet_v2':
46 | import keras_frcnn.inception_resnet_v2 as nn
47 | elif C.network == 'vgg':
48 | import keras_frcnn.vgg as nn
49 |
50 | # turn off any data augmentation at test time
51 | C.use_horizontal_flips = False
52 | C.use_vertical_flips = False
53 | C.rot_90 = False
54 |
55 | img_path = options.test_path
56 |
57 | def format_img_size(img, C):
58 | """ formats the image size based on config """
59 | img_min_side = float(C.im_size)
60 | (height, width ,_) = img.shape
61 |
62 | if width <= height:
63 | ratio = img_min_side/width
64 | new_height = int(ratio * height)
65 | new_width = int(img_min_side)
66 | else:
67 | ratio = img_min_side/height
68 | new_width = int(ratio * width)
69 | new_height = int(img_min_side)
70 | img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_CUBIC)
71 | return img, ratio
72 |
73 | def format_img_channels(img, C):
74 | """ formats the image channels based on config """
75 | img = img[:, :, (2, 1, 0)]
76 | img = img.astype(np.float32)
77 | img[:, :, 0] -= C.img_channel_mean[0]
78 | img[:, :, 1] -= C.img_channel_mean[1]
79 | img[:, :, 2] -= C.img_channel_mean[2]
80 | img /= C.img_scaling_factor
81 | img = np.transpose(img, (2, 0, 1))
82 | img = np.expand_dims(img, axis=0)
83 | return img
84 |
85 | def format_img(img, C):
86 | """ formats an image for model prediction based on config """
87 | img, ratio = format_img_size(img, C)
88 | img = format_img_channels(img, C)
89 | return img, ratio
90 |
91 | # Method to transform the coordinates of the bounding box to its original size
92 | def get_real_coordinates(ratio, x1, y1, x2, y2):
93 |
94 | real_x1 = int(round(x1 // ratio))
95 | real_y1 = int(round(y1 // ratio))
96 | real_x2 = int(round(x2 // ratio))
97 | real_y2 = int(round(y2 // ratio))
98 |
99 | return (real_x1, real_y1, real_x2 ,real_y2)
100 |
101 | class_mapping = C.class_mapping
102 |
103 | if 'bg' not in class_mapping:
104 | class_mapping['bg'] = len(class_mapping)
105 |
106 | class_mapping = {v: k for k, v in class_mapping.items()}
107 | print(class_mapping)
108 | class_to_color = {class_mapping[v]: np.random.randint(0, 255, 3) for v in class_mapping}
109 | C.num_rois = int(options.num_rois)
110 |
111 | if C.network == 'resnet50':
112 | num_features = 1024
113 | elif C.network == 'xception':
114 | num_features = 1024
115 | elif C.network == 'inception_resnet_v2':
116 | num_features = 1088
117 | elif C.network == 'vgg':
118 | num_features = 512
119 |
120 | if K.image_dim_ordering() == 'th':
121 | input_shape_img = (3, None, None)
122 | input_shape_features = (num_features, None, None)
123 | else:
124 | input_shape_img = (None, None, 3)
125 | input_shape_features = (None, None, num_features)
126 |
127 |
128 | img_input = Input(shape=input_shape_img)
129 | roi_input = Input(shape=(C.num_rois, 4))
130 | feature_map_input = Input(shape=input_shape_features)
131 |
132 | # define the base network (resnet here, can be VGG, Inception, etc)
133 | shared_layers = nn.nn_base(img_input, trainable=True)
134 |
135 | # define the RPN, built on the base layers
136 | num_anchors = len(C.anchor_box_scales) * len(C.anchor_box_ratios)
137 | rpn_layers = nn.rpn(shared_layers, num_anchors)
138 |
139 | classifier = nn.classifier(feature_map_input, roi_input, C.num_rois, nb_classes=len(class_mapping), trainable=True)
140 |
141 | model_rpn = Model(img_input, rpn_layers)
142 |
143 | model_classifier = Model([feature_map_input, roi_input], classifier)
144 |
145 | print('Loading weights from {}'.format(C.model_path))
146 | model_rpn.load_weights(C.model_path, by_name=True)
147 | model_classifier.load_weights(C.model_path, by_name=True)
148 |
149 | model_rpn.compile(optimizer='sgd', loss='mse')
150 | model_classifier.compile(optimizer='sgd', loss='mse')
151 |
152 | all_imgs = []
153 |
154 | classes = {}
155 |
156 | bbox_threshold = 0.8
157 |
158 | visualise = True
159 |
160 | for idx, img_name in enumerate(sorted(os.listdir(img_path))):
161 | if not img_name.lower().endswith(('.bmp', '.jpeg', '.jpg', '.png', '.tif', '.tiff')):
162 | continue
163 | print(img_name)
164 | st = time.time()
165 | filepath = os.path.join(img_path,img_name)
166 |
167 | img = cv2.imread(filepath)
168 |
169 | X, ratio = format_img(img, C)
170 |
171 | if K.image_dim_ordering() == 'tf':
172 | X = np.transpose(X, (0, 2, 3, 1))
173 |
174 | # get the feature maps and output from the RPN
175 | [Y1, Y2, F] = model_rpn.predict(X)
176 |
177 |
178 | R = roi_helpers.rpn_to_roi(Y1, Y2, C, K.image_dim_ordering(), overlap_thresh=0.7)
179 |
180 | # convert from (x1,y1,x2,y2) to (x,y,w,h)
181 | R[:, 2] -= R[:, 0]
182 | R[:, 3] -= R[:, 1]
183 |
184 | # apply RoI pooling and the classifier head to the proposed regions
185 | bboxes = {}
186 | probs = {}
187 |
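# the proposals are passed through the classifier in batches of C.num_rois;
# the final partial batch is padded by repeating its first RoI (see below)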
188 | for jk in range(R.shape[0]//C.num_rois + 1):
189 | ROIs = np.expand_dims(R[C.num_rois*jk:C.num_rois*(jk+1), :], axis=0)
190 | if ROIs.shape[1] == 0:
191 | break
192 |
193 | if jk == R.shape[0]//C.num_rois:
194 | #pad R
195 | curr_shape = ROIs.shape
196 | target_shape = (curr_shape[0],C.num_rois,curr_shape[2])
197 | ROIs_padded = np.zeros(target_shape).astype(ROIs.dtype)
198 | ROIs_padded[:, :curr_shape[1], :] = ROIs
199 | ROIs_padded[0, curr_shape[1]:, :] = ROIs[0, 0, :]
200 | ROIs = ROIs_padded
201 |
202 | [P_cls, P_regr] = model_classifier.predict([F, ROIs])
203 |
204 | for ii in range(P_cls.shape[1]):
205 |
206 | if np.max(P_cls[0, ii, :]) < bbox_threshold or np.argmax(P_cls[0, ii, :]) == (P_cls.shape[2] - 1):
207 | continue
208 |
209 | cls_name = class_mapping[np.argmax(P_cls[0, ii, :])]
210 |
211 | if cls_name not in bboxes:
212 | bboxes[cls_name] = []
213 | probs[cls_name] = []
214 |
215 | (x, y, w, h) = ROIs[0, ii, :]
216 |
217 | cls_num = np.argmax(P_cls[0, ii, :])
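# the regression targets were scaled by C.classifier_regr_std during training,
# so divide the predictions by the same factors before refining the RoI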
218 | try:
219 | (tx, ty, tw, th) = P_regr[0, ii, 4*cls_num:4*(cls_num+1)]
220 | tx /= C.classifier_regr_std[0]
221 | ty /= C.classifier_regr_std[1]
222 | tw /= C.classifier_regr_std[2]
223 | th /= C.classifier_regr_std[3]
224 | x, y, w, h = roi_helpers.apply_regr(x, y, w, h, tx, ty, tw, th)
225 | except:
226 | pass
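# the RoI coordinates are in feature-map units; multiplying by C.rpn_stride
# maps them back onto the resized input image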
227 | bboxes[cls_name].append([C.rpn_stride*x, C.rpn_stride*y, C.rpn_stride*(x+w), C.rpn_stride*(y+h)])
228 | probs[cls_name].append(np.max(P_cls[0, ii, :]))
229 |
230 | all_dets = []
231 |
232 | for key in bboxes:
233 | bbox = np.array(bboxes[key])
234 |
235 | new_boxes, new_probs = roi_helpers.non_max_suppression_fast(bbox, np.array(probs[key]), overlap_thresh=0.5)
236 | for jk in range(new_boxes.shape[0]):
237 | (x1, y1, x2, y2) = new_boxes[jk,:]
238 |
239 | (real_x1, real_y1, real_x2, real_y2) = get_real_coordinates(ratio, x1, y1, x2, y2)
240 |
241 | cv2.rectangle(img,(real_x1, real_y1), (real_x2, real_y2), (int(class_to_color[key][0]), int(class_to_color[key][1]), int(class_to_color[key][2])),2)
242 |
243 | textLabel = '{}: {}'.format(key,int(100*new_probs[jk]))
244 | all_dets.append((key,100*new_probs[jk]))
245 |
246 | (retval,baseLine) = cv2.getTextSize(textLabel,cv2.FONT_HERSHEY_COMPLEX,1,1)
247 | textOrg = (real_x1, real_y1)
248 |
249 | cv2.rectangle(img, (textOrg[0] - 5, textOrg[1]+baseLine - 5), (textOrg[0]+retval[0] + 5, textOrg[1]-retval[1] - 5), (0, 0, 0), 2)
250 | cv2.rectangle(img, (textOrg[0] - 5,textOrg[1]+baseLine - 5), (textOrg[0]+retval[0] + 5, textOrg[1]-retval[1] - 5), (255, 255, 255), -1)
251 | cv2.putText(img, textLabel, textOrg, cv2.FONT_HERSHEY_DUPLEX, 1, (0, 0, 0), 1)
252 |
253 | print('Elapsed time = {}'.format(time.time() - st))
254 | print(all_dets)
255 | #cv2.imshow('img', img)
256 | #cv2.waitKey(0)
257 | cv2.imwrite('./results_imgs/{}.png'.format(idx),img)
258 |
--------------------------------------------------------------------------------
/train_frcnn.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import division
3 | import random
4 | import pprint
5 | import sys
6 | import time
7 | import numpy as np
8 | from optparse import OptionParser
9 | import pickle
10 | import os
11 |
12 | import tensorflow as tf
13 | from keras import backend as K
14 | from keras.optimizers import Adam, SGD, RMSprop
15 | from keras.layers import Input
16 | from keras.models import Model
17 | from keras_frcnn import config, data_generators
18 | from keras_frcnn import losses as losses
19 | import keras_frcnn.roi_helpers as roi_helpers
20 | from keras.utils import generic_utils
21 | from keras.callbacks import TensorBoard
22 |
23 |
24 | # helper for writing TensorBoard scalar summaries
25 | def write_log(callback, names, logs, batch_no):
26 | for name, value in zip(names, logs):
27 | summary = tf.Summary()
28 | summary_value = summary.value.add()
29 | summary_value.simple_value = value
30 | summary_value.tag = name
31 | callback.writer.add_summary(summary, batch_no)
32 | callback.writer.flush()
33 |
34 | sys.setrecursionlimit(40000)
35 |
36 | parser = OptionParser()
37 |
38 | parser.add_option("-p", "--path", dest="train_path", help="Path to training data.")
39 | parser.add_option("-o", "--parser", dest="parser", help="Parser to use. One of simple or pascal_voc",
40 | default="pascal_voc")
41 | parser.add_option("-n", "--num_rois", dest="num_rois", help="Number of RoIs to process at once.", default=32)
42 | parser.add_option("--network", dest="network", help="Base network to use. Supports vgg, xception, inception_resnet_v2 or resnet50.", default='resnet50')
43 | parser.add_option("--hf", dest="horizontal_flips", help="Augment with horizontal flips in training. (Default=false).", action="store_true", default=False)
44 | parser.add_option("--vf", dest="vertical_flips", help="Augment with vertical flips in training. (Default=false).", action="store_true", default=False)
45 | parser.add_option("--rot", "--rot_90", dest="rot_90", help="Augment with 90 degree rotations in training. (Default=false).",
46 | action="store_true", default=False)
47 | parser.add_option("--num_epochs", dest="num_epochs", help="Number of epochs.", default=2000)
48 | parser.add_option("--config_filename", dest="config_filename",
49 | help="Location to store all the metadata related to the training (to be used when testing).",
50 | default="config.pickle")
51 | parser.add_option("--output_weight_path", dest="output_weight_path", help="Output path for weights.", default='./model_frcnn.hdf5')
52 | parser.add_option("--input_weight_path", dest="input_weight_path", help="Input path for weights. If not specified, will try to load default weights provided by keras.")
53 |
54 | (options, args) = parser.parse_args()
55 |
56 | if not options.train_path: # if filename is not given
57 | parser.error('Error: path to training data must be specified. Pass --path to command line')
58 |
59 | if options.parser == 'pascal_voc':
60 | from keras_frcnn.pascal_voc_parser import get_data
61 | elif options.parser == 'simple':
62 | from keras_frcnn.simple_parser import get_data
63 | else:
64 | raise ValueError("Command line option parser must be one of 'pascal_voc' or 'simple'")
65 |
66 | # pass the settings from the command line, and persist them in the config object
67 | C = config.Config()
68 |
69 | C.use_horizontal_flips = bool(options.horizontal_flips)
70 | C.use_vertical_flips = bool(options.vertical_flips)
71 | C.rot_90 = bool(options.rot_90)
72 |
73 | C.model_path = options.output_weight_path
74 | C.num_rois = int(options.num_rois)
75 |
76 | if options.network == 'vgg':
77 | C.network = 'vgg'
78 | from keras_frcnn import vgg as nn
79 | elif options.network == 'resnet50':
80 | from keras_frcnn import resnet as nn
81 | C.network = 'resnet50'
82 | elif options.network == 'xception':
83 | from keras_frcnn import xception as nn
84 | C.network = 'xception'
85 | elif options.network == 'inception_resnet_v2':
86 | from keras_frcnn import inception_resnet_v2 as nn
87 | C.network = 'inception_resnet_v2'
88 | else:
89 | 	print('Not a valid model')
90 | 	raise ValueError('Not a valid model: {}'.format(options.network))
91 |
92 | # check if weight path was passed via command line
93 | if options.input_weight_path:
94 | C.base_net_weights = options.input_weight_path
95 | else:
96 | # set the path to weights based on backend and model
97 | C.base_net_weights = nn.get_weight_path()
98 |
99 | # get the images, class counts and class mapping from the parser
100 | all_imgs, classes_count, class_mapping = get_data(options.train_path)
101 |
102 | # add the background ('bg') class
103 | if 'bg' not in classes_count:
104 | classes_count['bg'] = 0
105 | class_mapping['bg'] = len(class_mapping)
106 |
107 | C.class_mapping = class_mapping
108 |
109 | inv_map = {v: k for k, v in class_mapping.items()}
110 |
111 | print('Training images per class:')
112 | pprint.pprint(classes_count)
113 | print('Num classes (including bg) = {}'.format(len(classes_count)))
114 |
115 | config_output_filename = options.config_filename
116 |
117 | with open(config_output_filename, 'wb') as config_f:
118 | pickle.dump(C, config_f)
119 | print('Config has been written to {}, and can be loaded when testing to ensure correct results'.format(config_output_filename))
120 |
121 | random.shuffle(all_imgs)
122 |
123 | num_imgs = len(all_imgs)
124 |
125 | train_imgs = [s for s in all_imgs if s['imageset'] == 'trainval']
126 | test_imgs = [s for s in all_imgs if s['imageset'] == 'test']
127 |
128 | print('Num train samples {}'.format(len(train_imgs)))
129 | print('Num test samples {}'.format(len(test_imgs)))
130 |
131 | # build the ground-truth anchor data generators
132 | data_gen_train = data_generators.get_anchor_gt(train_imgs, classes_count, C, nn.get_img_output_length, K.image_dim_ordering(), mode='train')
133 | data_gen_test = data_generators.get_anchor_gt(test_imgs, classes_count, C, nn.get_img_output_length, K.image_dim_ordering(), mode='test')
134 |
135 | if K.image_dim_ordering() == 'th':
136 | input_shape_img = (3, None, None)
137 | else:
138 | input_shape_img = (None, None, 3)
139 |
140 | # define the input placeholders
141 | img_input = Input(shape=input_shape_img)
142 | roi_input = Input(shape=(None, 4))
143 |
144 | # define the base network / feature extractor (ResNet, VGG, Xception, Inception-ResNet-v2, etc.)
145 | shared_layers = nn.nn_base(img_input, trainable=True)
146 |
147 | # define the RPN, built on the base layers
148 | # (the RPN predicts an objectness score and box-regression offsets for every anchor at each feature-map position)
149 | num_anchors = len(C.anchor_box_scales) * len(C.anchor_box_ratios)
150 | rpn = nn.rpn(shared_layers, num_anchors)
151 |
152 | # define the detection network (classifier head)
153 | classifier = nn.classifier(shared_layers, roi_input, C.num_rois, nb_classes=len(classes_count), trainable=True)
154 |
155 | model_rpn = Model(img_input, rpn[:2])
156 | model_classifier = Model([img_input, roi_input], classifier)
157 |
158 | # this is a model that holds both the RPN and the classifier, used to load/save weights for the models
159 | model_all = Model([img_input, roi_input], rpn[:2] + classifier)
160 |
161 | try:
162 | 	# load weights by name
163 | 	# some keras.applications models do not give every layer a name;
164 | 	# such models must be re-built with explicit layer names before their weights can be loaded here
165 | print('loading weights from {}'.format(C.base_net_weights))
166 | model_rpn.load_weights(C.base_net_weights, by_name=True)
167 | model_classifier.load_weights(C.base_net_weights, by_name=True)
168 | except:
169 | print('Could not load pretrained model weights. Weights can be found in the keras application folder \
170 | https://github.com/fchollet/keras/tree/master/keras/applications')
171 |
172 | optimizer = Adam(lr=1e-5)
173 | optimizer_classifier = Adam(lr=1e-5)
174 | model_rpn.compile(optimizer=optimizer, loss=[losses.rpn_loss_cls(num_anchors), losses.rpn_loss_regr(num_anchors)])
175 | model_classifier.compile(optimizer=optimizer_classifier, loss=[losses.class_loss_cls, losses.class_loss_regr(len(classes_count)-1)], metrics={'dense_class_{}'.format(len(classes_count)): 'accuracy'})
176 | model_all.compile(optimizer='sgd', loss='mae')
177 |
178 | # create the TensorBoard log directory
179 | log_path = './logs'
180 | if not os.path.isdir(log_path):
181 | os.mkdir(log_path)
182 |
183 | # attach the model to the TensorBoard callback
184 | callback = TensorBoard(log_path)
185 | callback.set_model(model_all)
186 |
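# epoch_length is the number of training iterations per "epoch" here, independent of the dataset size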
187 | epoch_length = 1000
188 | num_epochs = int(options.num_epochs)
189 | iter_num = 0
190 | train_step = 0
191 |
192 | losses = np.zeros((epoch_length, 5))
193 | rpn_accuracy_rpn_monitor = []
194 | rpn_accuracy_for_epoch = []
195 | start_time = time.time()
196 |
197 | best_loss = np.Inf
198 |
199 | class_mapping_inv = {v: k for k, v in class_mapping.items()}
200 | print('Starting training')
201 |
202 | # vis = True
203 |
204 | for epoch_num in range(num_epochs):
205 |
206 | 	progbar = generic_utils.Progbar(epoch_length) # use the Keras progress bar
207 | print('Epoch {}/{}'.format(epoch_num + 1, num_epochs))
208 |
209 | while True:
210 | # try:
211 | 		# report the mean number of overlapping bboxes
212 | if len(rpn_accuracy_rpn_monitor) == epoch_length and C.verbose:
213 | mean_overlapping_bboxes = float(sum(rpn_accuracy_rpn_monitor))/len(rpn_accuracy_rpn_monitor)
214 | rpn_accuracy_rpn_monitor = []
215 | print('Average number of overlapping bounding boxes from RPN = {} for {} previous iterations'.format(mean_overlapping_bboxes, epoch_length))
216 | if mean_overlapping_bboxes == 0:
217 | print('RPN is not producing bounding boxes that overlap the ground truth boxes. Check RPN settings or keep training.')
218 |
219 | 		# get X, Y and the image data from the data generator
220 | X, Y, img_data = next(data_gen_train)
221 |
222 | loss_rpn = model_rpn.train_on_batch(X, Y)
223 | write_log(callback, ['rpn_cls_loss', 'rpn_reg_loss'], loss_rpn, train_step)
224 |
225 | P_rpn = model_rpn.predict_on_batch(X)
226 |
227 | R = roi_helpers.rpn_to_roi(P_rpn[0], P_rpn[1], C, K.image_dim_ordering(), use_regr=True, overlap_thresh=0.7, max_boxes=300)
228 | # note: calc_iou converts from (x1,y1,x2,y2) to (x,y,w,h) format
229 | X2, Y1, Y2, IouS = roi_helpers.calc_iou(R, img_data, C, class_mapping)
230 |
231 | if X2 is None:
232 | rpn_accuracy_rpn_monitor.append(0)
233 | rpn_accuracy_for_epoch.append(0)
234 | continue
235 |
236 | # sampling positive/negative samples
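# Y1's one-hot class vector has 'bg' as its last entry, so the last column
# separates negative (background) RoIs from positive ones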
237 | neg_samples = np.where(Y1[0, :, -1] == 1)
238 | pos_samples = np.where(Y1[0, :, -1] == 0)
239 |
240 | 		if len(neg_samples[0]) > 0:
241 | neg_samples = neg_samples[0]
242 | else:
243 | neg_samples = []
244 |
245 | 		if len(pos_samples[0]) > 0:
246 | pos_samples = pos_samples[0]
247 | else:
248 | pos_samples = []
249 |
250 | rpn_accuracy_rpn_monitor.append(len(pos_samples))
251 | rpn_accuracy_for_epoch.append((len(pos_samples)))
252 |
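# build a mini-batch of C.num_rois RoIs, aiming for roughly half positives and half
# negatives; negatives are sampled with replacement if there are too few of them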
253 | if C.num_rois > 1:
254 | if len(pos_samples) < C.num_rois//2:
255 | selected_pos_samples = pos_samples.tolist()
256 | else:
257 | if len(pos_samples) > 0:
258 | selected_pos_samples = np.random.choice(pos_samples, C.num_rois//2, replace=False).tolist()
259 | else:
260 | selected_pos_samples = []
261 | try:
262 | if len(neg_samples) > 0:
263 | selected_neg_samples = np.random.choice(neg_samples, C.num_rois - len(selected_pos_samples), replace=False).tolist()
264 | else:
265 | selected_neg_samples = []
266 | except:
267 | if len(neg_samples) > 0:
268 | selected_neg_samples = np.random.choice(neg_samples, C.num_rois - len(selected_pos_samples), replace=True).tolist()
269 | else:
270 | selected_neg_samples = []
271 |
272 | sel_samples = selected_pos_samples + selected_neg_samples
273 | else:
274 | # in the extreme case where num_rois = 1, we pick a random pos or neg sample
275 | selected_pos_samples = pos_samples.tolist()
276 | selected_neg_samples = neg_samples.tolist()
277 | if np.random.randint(0, 2):
278 | sel_samples = random.choice(neg_samples)
279 | else:
280 | sel_samples = random.choice(pos_samples)
281 |
282 | loss_class = model_classifier.train_on_batch([X, X2[:, sel_samples, :]], [Y1[:, sel_samples, :], Y2[:, sel_samples, :]])
283 | write_log(callback, ['detection_cls_loss', 'detection_reg_loss', 'detection_acc'], loss_class, train_step)
284 | train_step += 1
285 |
286 | losses[iter_num, 0] = loss_rpn[1]
287 | losses[iter_num, 1] = loss_rpn[2]
288 |
289 | losses[iter_num, 2] = loss_class[1]
290 | losses[iter_num, 3] = loss_class[2]
291 | losses[iter_num, 4] = loss_class[3]
292 |
293 | iter_num += 1
294 |
295 | progbar.update(iter_num, [('rpn_cls', np.mean(losses[:iter_num, 0])), ('rpn_regr', np.mean(losses[:iter_num, 1])),
296 | ('detector_cls', np.mean(losses[:iter_num, 2])), ('detector_regr', np.mean(losses[:iter_num, 3]))])
297 |
298 | if iter_num == epoch_length:
299 | loss_rpn_cls = np.mean(losses[:, 0])
300 | loss_rpn_regr = np.mean(losses[:, 1])
301 | loss_class_cls = np.mean(losses[:, 2])
302 | loss_class_regr = np.mean(losses[:, 3])
303 | class_acc = np.mean(losses[:, 4])
304 |
305 | mean_overlapping_bboxes = float(sum(rpn_accuracy_for_epoch)) / len(rpn_accuracy_for_epoch)
306 | rpn_accuracy_for_epoch = []
307 |
308 | if C.verbose:
309 | print('Mean number of bounding boxes from RPN overlapping ground truth boxes: {}'.format(mean_overlapping_bboxes))
310 | print('Classifier accuracy for bounding boxes from RPN: {}'.format(class_acc))
311 | print('Loss RPN classifier: {}'.format(loss_rpn_cls))
312 | print('Loss RPN regression: {}'.format(loss_rpn_regr))
313 | print('Loss Detector classifier: {}'.format(loss_class_cls))
314 | print('Loss Detector regression: {}'.format(loss_class_regr))
315 | print('Elapsed time: {}'.format(time.time() - start_time))
316 |
317 | 			curr_loss = loss_rpn_cls + loss_rpn_regr + loss_class_cls + loss_class_regr
318 | 			iter_num = 0
319 | 
320 | 			write_log(callback,
321 | 					['Elapsed_time', 'mean_overlapping_bboxes', 'mean_rpn_cls_loss', 'mean_rpn_reg_loss',
322 | 					'mean_detection_cls_loss', 'mean_detection_reg_loss', 'mean_detection_acc', 'total_loss'],
323 | 					[time.time() - start_time, mean_overlapping_bboxes, loss_rpn_cls, loss_rpn_regr,
324 | 					loss_class_cls, loss_class_regr, class_acc, curr_loss],
325 | 					epoch_num)
326 | 			# reset the epoch timer only after logging, so Elapsed_time reflects the full epoch
327 | 			start_time = time.time()
328 | if curr_loss < best_loss:
329 | if C.verbose:
330 | print('Total loss decreased from {} to {}, saving weights'.format(best_loss,curr_loss))
331 | best_loss = curr_loss
332 | model_all.save_weights(C.model_path)
333 |
334 | break
335 |
336 | # except Exception as e:
337 | # print('Exception: {}'.format(e))
338 | # continue
339 |
340 | print('Training complete, exiting.')
341 |
--------------------------------------------------------------------------------
/transfer/export_imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | # from keras.applications import *
5 | from inception_resnet_v2 import InceptionResnetV2_model
6 | from keras import backend as k
7 | from keras.layers import *
8 | from keras.models import Model
9 | from keras.optimizers import *
10 |
11 | # hyper parameters for model
12 | nb_classes = 1 # number of classes
13 | # change based on the shape/structure of your images
14 | img_width, img_height = 299, 299
15 |
16 |
17 | def export(model_path):
18 |     # pre-trained CNN model with ImageNet weights
19 | # base_model = InceptionResnetV2_model(input_shape=(img_width, img_height, 3), weights='imagenet', include_top=False)
20 | base_model = InceptionResnetV2_model(input_shape=(
21 | img_width, img_height, 3), weights='imagenet', include_top=False)
22 |
23 | x = base_model.output
24 | x = GlobalAveragePooling2D()(x)
25 | predictions = Dense(nb_classes, activation='softmax')(x)
26 |
27 | # add your top layer block to your base model
28 | model = Model(base_model.input, predictions)
29 | print(model.summary())
30 |
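# every weighted layer in this model has an explicit name, so the saved .h5 can be
# loaded by the Faster R-CNN scripts with load_weights(..., by_name=True)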
31 | model.save(model_path + '/inception_resnet_v2.h5')
32 | # save model
33 | # model_json = model.to_json()
34 | # with open(os.path.join(os.path.abspath(model_path), 'model.json'), 'w') as json_file:
35 | # json_file.write(model_json)
36 |
37 |
38 | if __name__ == '__main__':
39 |     export(os.getcwd())  # export the model to the current directory
40 |
41 | # release memory
42 | k.clear_session()
43 |
--------------------------------------------------------------------------------
/transfer/inception_resnet_v2.py:
--------------------------------------------------------------------------------
1 | from keras import utils
2 | from keras import backend, layers, models
3 |
4 | def conv2d_bn(x,
5 | filters,
6 | kernel_size,
7 | strides=1,
8 | padding='same',
9 | activation='relu',
10 | use_bias=False,
11 | name=None):
12 | """Utility function to apply conv + BN.
13 | # Arguments
14 | x: input tensor.
15 | filters: filters in `Conv2D`.
16 | kernel_size: kernel size as in `Conv2D`.
17 | strides: strides in `Conv2D`.
18 | padding: padding mode in `Conv2D`.
19 | activation: activation in `Conv2D`.
20 | use_bias: whether to use a bias in `Conv2D`.
21 | name: name of the ops; will become `name + '_ac'` for the activation
22 | and `name + '_bn'` for the batch norm layer.
23 | # Returns
24 | Output tensor after applying `Conv2D` and `BatchNormalization`.
25 | """
26 | if name is None:
27 | print('None!!')
28 |
29 | x = layers.Conv2D(filters,
30 | kernel_size,
31 | strides=strides,
32 | padding=padding,
33 | use_bias=use_bias,
34 | name=name)(x)
35 | if not use_bias:
36 | bn_axis = 1 if backend.image_data_format() == 'channels_first' else 3
37 | bn_name = None if name is None else name + '_bn'
38 | x = layers.BatchNormalization(axis=bn_axis,
39 | scale=False,
40 | name=bn_name)(x)
41 | if activation is not None:
42 | ac_name = None if name is None else name + '_ac'
43 | x = layers.Activation(activation, name=ac_name)(x)
44 | return x
45 |
46 |
47 | def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'):
48 | """Adds a Inception-ResNet block.
49 | This function builds 3 types of Inception-ResNet blocks mentioned
50 | in the paper, controlled by the `block_type` argument (which is the
51 | block name used in the official TF-slim implementation):
52 | - Inception-ResNet-A: `block_type='block35'`
53 | - Inception-ResNet-B: `block_type='block17'`
54 | - Inception-ResNet-C: `block_type='block8'`
55 | # Arguments
56 | x: input tensor.
57 | scale: scaling factor to scale the residuals (i.e., the output of
58 | passing `x` through an inception module) before adding them
59 | to the shortcut branch.
60 | Let `r` be the output from the residual branch,
61 | the output of this block will be `x + scale * r`.
62 | block_type: `'block35'`, `'block17'` or `'block8'`, determines
63 | the network structure in the residual branch.
64 | block_idx: an `int` used for generating layer names.
65 | The Inception-ResNet blocks
66 | are repeated many times in this network.
67 | We use `block_idx` to identify
68 | each of the repetitions. For example,
69 | the first Inception-ResNet-A block
70 | will have `block_type='block35', block_idx=0`,
71 | and the layer names will have
72 | a common prefix `'block35_0'`.
73 | activation: activation function to use at the end of the block
74 | (see [activations](../activations.md)).
75 | When `activation=None`, no activation is applied
76 | (i.e., "linear" activation: `a(x) = x`).
77 | # Returns
78 | Output tensor for the block.
79 | # Raises
80 | ValueError: if `block_type` is not one of `'block35'`,
81 | `'block17'` or `'block8'`.
82 | """
83 | block_name = block_type + '_' + str(block_idx)
84 |
85 | if block_type == 'block35':
86 | branch_0 = conv2d_bn(x, 32, 1, name=block_name + '_conv1')
87 | branch_1 = conv2d_bn(x, 32, 1, name=block_name + '_conv2')
88 | branch_1 = conv2d_bn(branch_1, 32, 3, name=block_name + '_conv3')
89 | branch_2 = conv2d_bn(x, 32, 1, name=block_name + '_conv4')
90 | branch_2 = conv2d_bn(branch_2, 48, 3, name=block_name + '_conv5')
91 | branch_2 = conv2d_bn(branch_2, 64, 3, name=block_name + '_conv6')
92 | branches = [branch_0, branch_1, branch_2]
93 | elif block_type == 'block17':
94 | branch_0 = conv2d_bn(x, 192, 1, name=block_name + '_conv1')
95 | branch_1 = conv2d_bn(x, 128, 1, name=block_name + '_conv2')
96 | branch_1 = conv2d_bn(branch_1, 160, [1, 7], name=block_name + '_conv3')
97 | branch_1 = conv2d_bn(branch_1, 192, [7, 1], name=block_name + '_conv4')
98 | branches = [branch_0, branch_1]
99 | elif block_type == 'block8':
100 | branch_0 = conv2d_bn(x, 192, 1, name=block_name + '_conv1')
101 | branch_1 = conv2d_bn(x, 192, 1, name=block_name + '_conv2')
102 | branch_1 = conv2d_bn(branch_1, 224, [1, 3], name=block_name + '_conv3')
103 | branch_1 = conv2d_bn(branch_1, 256, [3, 1], name=block_name + '_conv4')
104 | branches = [branch_0, branch_1]
105 | else:
106 | raise ValueError('Unknown Inception-ResNet block type. '
107 | 'Expects "block35", "block17" or "block8", '
108 | 'but got: ' + str(block_type))
109 |
110 | channel_axis = 1 if backend.image_data_format() == 'channels_first' else 3
111 | mixed = layers.Concatenate(
112 | axis=channel_axis, name=block_name + '_mixed')(branches)
113 | up = conv2d_bn(mixed,
114 | backend.int_shape(x)[channel_axis],
115 | 1,
116 | activation=None,
117 | use_bias=True,
118 | name=block_name + '_conv')
119 |
120 | x = layers.Lambda(lambda inputs, scale: inputs[0] + inputs[1] * scale,
121 | output_shape=backend.int_shape(x)[1:],
122 | arguments={'scale': scale},
123 | name=block_name)([x, up])
124 | if activation is not None:
125 | x = layers.Activation(activation, name=block_name + '_ac')(x)
126 | return x
127 |
128 |
129 | def InceptionResnetV2_model(input_shape, include_top=False, input_tensor=None, weights='imagenet', pooling=None):
130 | if input_tensor is None:
131 | img_input = layers.Input(shape=input_shape)
132 | else:
133 | if not backend.is_keras_tensor(input_tensor):
134 | img_input = layers.Input(tensor=input_tensor, shape=input_shape)
135 | else:
136 | img_input = input_tensor
137 |
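# all convolution and batch-norm layers below are given explicit names so their
# weights can later be loaded with load_weights(..., by_name=True)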
138 | # Stem block: 35 x 35 x 192
139 | x = conv2d_bn(img_input, 32, 3, strides=2, padding='valid', name='Stem_block' + '_conv1')
140 | x = conv2d_bn(x, 32, 3, padding='valid', name='Stem_block' + '_conv2')
141 | x = conv2d_bn(x, 64, 3, name='Stem_block' + '_conv3')
142 | x = layers.MaxPooling2D(3, strides=2)(x)
143 | x = conv2d_bn(x, 80, 1, padding='valid', name='Stem_block' + '_conv4')
144 | x = conv2d_bn(x, 192, 3, padding='valid', name='Stem_block' + '_conv5')
145 | x = layers.MaxPooling2D(3, strides=2)(x)
146 |
147 | # Mixed 5b (Inception-A block): 35 x 35 x 320
148 | branch_0 = conv2d_bn(x, 96, 1, name='Inception_A_block' + '_conv1')
149 | branch_1 = conv2d_bn(x, 48, 1, name='Inception_A_block' + '_conv2')
150 | branch_1 = conv2d_bn(branch_1, 64, 5, name='Inception_A_block' + '_conv3')
151 | branch_2 = conv2d_bn(x, 64, 1, name='Inception_A_block' + '_conv4')
152 | branch_2 = conv2d_bn(branch_2, 96, 3, name='Inception_A_block' + '_conv5')
153 | branch_2 = conv2d_bn(branch_2, 96, 3, name='Inception_A_block' + '_conv6')
154 | branch_pool = layers.AveragePooling2D(3, strides=1, padding='same')(x)
155 | branch_pool = conv2d_bn(branch_pool, 64, 1, name='Inception_A_block' + '_conv7')
156 | branches = [branch_0, branch_1, branch_2, branch_pool]
157 | channel_axis = 1 if backend.image_data_format() == 'channels_first' else 3
158 | x = layers.Concatenate(axis=channel_axis, name='mixed_5b')(branches)
159 |
160 | # 10x block35 (Inception-ResNet-A block): 35 x 35 x 320
161 | for block_idx in range(1, 11):
162 | x = inception_resnet_block(x,
163 | scale=0.17,
164 | block_type='block35',
165 | block_idx=block_idx)
166 |
167 | # Mixed 6a (Reduction-A block): 17 x 17 x 1088
168 | branch_0 = conv2d_bn(x, 384, 3, strides=2, padding='valid', name='Reduction_A_block' + '_conv1')
169 | branch_1 = conv2d_bn(x, 256, 1, name='Reduction_A_block' + '_conv2')
170 | branch_1 = conv2d_bn(branch_1, 256, 3, name='Reduction_A_block' + '_conv3')
171 | branch_1 = conv2d_bn(branch_1, 384, 3, strides=2, padding='valid', name='Reduction_A_block' + '_conv4')
172 | branch_pool = layers.MaxPooling2D(3, strides=2, padding='valid')(x)
173 | branches = [branch_0, branch_1, branch_pool]
174 | x = layers.Concatenate(axis=channel_axis, name='mixed_6a')(branches)
175 |
176 | # 20x block17 (Inception-ResNet-B block): 17 x 17 x 1088
177 | for block_idx in range(1, 21):
178 | x = inception_resnet_block(x,
179 | scale=0.1,
180 | block_type='block17',
181 | block_idx=block_idx)
182 |
183 | # Mixed 7a (Reduction-B block): 8 x 8 x 2080
184 | branch_0 = conv2d_bn(x, 256, 1, name='Reduction_B_block' + '_conv1')
185 | branch_0 = conv2d_bn(branch_0, 384, 3, strides=2, padding='valid', name='Reduction_B_block' + '_conv2')
186 | branch_1 = conv2d_bn(x, 256, 1, name='Reduction_B_block' + '_conv3')
187 | branch_1 = conv2d_bn(branch_1, 288, 3, strides=2, padding='valid', name='Reduction_B_block' + '_conv4')
188 | branch_2 = conv2d_bn(x, 256, 1, name='Reduction_B_block' + '_conv5')
189 | branch_2 = conv2d_bn(branch_2, 288, 3, name='Reduction_B_block' + '_conv6')
190 | branch_2 = conv2d_bn(branch_2, 320, 3, strides=2, padding='valid', name='Reduction_B_block' + '_conv7')
191 | branch_pool = layers.MaxPooling2D(3, strides=2, padding='valid')(x)
192 | branches = [branch_0, branch_1, branch_2, branch_pool]
193 | x = layers.Concatenate(axis=channel_axis, name='mixed_7a')(branches)
194 |
195 | # 10x block8 (Inception-ResNet-C block): 8 x 8 x 2080
196 | for block_idx in range(1, 10):
197 | x = inception_resnet_block(x,
198 | scale=0.2,
199 | block_type='block8',
200 | block_idx=block_idx)
201 | x = inception_resnet_block(x,
202 | scale=1.,
203 | activation=None,
204 | block_type='block8',
205 | block_idx=10)
206 |
207 | # Final convolution block: 8 x 8 x 1536
208 | x = conv2d_bn(x, 1536, 1, name='conv_7b')
209 |
210 |
211 | if include_top:
212 | # Classification block
213 | x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
214 | x = layers.Dense(1000, activation='softmax', name='predictions')(x)
215 | else:
216 | if pooling == 'avg':
217 | x = layers.GlobalAveragePooling2D()(x)
218 | elif pooling == 'max':
219 | x = layers.GlobalMaxPooling2D()(x)
220 |
221 | inputs = img_input
222 |
223 | # Create model.
224 | model = models.Model(inputs, x, name='inception_resnet_v2')
225 |
226 | BASE_WEIGHT_URL = ('https://github.com/fchollet/deep-learning-models/'
227 | 'releases/download/v0.7/')
228 |
229 | # Load weights.
230 | if weights == 'imagenet':
231 | if include_top:
232 | fname = 'inception_resnet_v2_weights_tf_dim_ordering_tf_kernels.h5'
233 | weights_path = utils.get_file(
234 | fname,
235 | BASE_WEIGHT_URL + fname,
236 | cache_subdir='models',
237 | file_hash='e693bd0210a403b3192acc6073ad2e96')
238 | else:
239 | fname = ('inception_resnet_v2_weights_'
240 | 'tf_dim_ordering_tf_kernels_notop.h5')
241 | weights_path = utils.get_file(
242 | fname,
243 | BASE_WEIGHT_URL + fname,
244 | cache_subdir='models',
245 | file_hash='d19885ff4a710c122648d3b5c3b684e4')
246 | model.load_weights(weights_path)
247 | elif weights is not None:
248 | model.load_weights(weights)
249 |
250 | return model
--------------------------------------------------------------------------------