├── README.md
├── datasets
│   ├── __init__.py
│   ├── __init__.pyc
│   ├── dataset_utils.py
│   ├── dataset_utils.pyc
│   ├── synthtext_to_tfrecords_self.py
│   ├── sythtextprovider.py
│   ├── sythtextprovider.pyc
│   ├── xml_to_tfrecords.py
│   └── xml_to_tfrecords.pyc
├── demo.py
├── demo
│   ├── example
│   │   ├── image0.txt
│   │   ├── image0.xml
│   │   └── standard.xml
│   └── img_1.jpg
├── deployment
│   ├── __init__.py
│   ├── __init__.pyc
│   ├── model_deploy.py
│   ├── model_deploy.pyc
│   └── model_deploy_test.py
├── eval.py
├── eval_result.py
├── gene_tfrecords.py
├── logs
│   └── train_xml.txt
├── nets
│   ├── __init__.py
│   ├── __init__.pyc
│   ├── __pycache__
│   │   ├── __init__.cpython-35.pyc
│   │   ├── custom_layers.cpython-35.pyc
│   │   ├── np_methods.cpython-35.pyc
│   │   ├── textbox_common.cpython-35.pyc
│   │   ├── txtbox_384.cpython-35.pyc
│   │   └── txtbox_768.cpython-35.pyc
│   ├── custom_layers.py
│   ├── custom_layers.pyc
│   ├── np_methods.py
│   ├── np_methods.pyc
│   ├── textbox_common.py
│   ├── textbox_common.pyc
│   ├── txtbox_384.py
│   ├── txtbox_384.pyc
│   ├── txtbox_768.py
│   └── txtbox_768.pyc
├── processing
│   ├── __init__.py
│   ├── __init__.pyc
│   ├── __pycache__
│   │   ├── __init__.cpython-35.pyc
│   │   ├── ssd_vgg_preprocessing.cpython-35.pyc
│   │   └── tf_image.cpython-35.pyc
│   ├── ssd_vgg_preprocessing.py
│   ├── ssd_vgg_preprocessing.pyc
│   ├── tf_image.py
│   └── tf_image.pyc
├── test.py
├── tf_extended
│   ├── __init__.py
│   ├── __init__.pyc
│   ├── __pycache__
│   │   ├── __init__.cpython-35.pyc
│   │   ├── bboxes.cpython-35.pyc
│   │   ├── image.cpython-35.pyc
│   │   ├── math.cpython-35.pyc
│   │   ├── metrics.cpython-35.pyc
│   │   └── tensors.cpython-35.pyc
│   ├── bboxes.py
│   ├── bboxes.pyc
│   ├── image.py
│   ├── image.pyc
│   ├── math.py
│   ├── math.pyc
│   ├── metrics.py
│   ├── metrics.pyc
│   ├── tensors.py
│   ├── tensors.pyc
│   ├── tf_utils.py
│   └── tf_utils.pyc
├── tools
│   ├── _init_paths.py
│   ├── convert_xml_format.py
│   ├── gen_xml.py
│   └── test_dataset.py
└── train.py
/README.md:
--------------------------------------------------------------------------------
1 | # TextBoxes++-TensorFlow
2 | A TextBoxes++ re-implementation in TensorFlow.
3 | This project is greatly inspired by the [slim project](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim),
4 | and many functions are adapted from the [SSD-Tensorflow project](https://github.com/balancap/SSD-Tensorflow).
5 |
6 | Author:
7 | Zhisheng Zou zzsshun13@gmail.com
8 |
9 | # pretrained model
10 | 1. [Google drive](https://drive.google.com/open?id=1kkRyVrx9iFtwEar6OJBKWNVyTLSYsF28)
11 |
12 | # environment
13 | `python2.7 / python3.5`
14 |
15 | `tensorflow-gpu 1.8.0`
16 |
17 | at least one GPU
18 |
19 | # how to use
20 |
21 | 1. Get an XML file like this [example xml](./demo/example/image0.xml) and put it in the same directory as its image, because we need the format shown in [standard xml](./demo/example/standard.xml)
22 |    1. image format: *.png or *.PNG
23 | 2. Convert the XMLs and set the flags
24 |    Make sure each XML file is in the same directory as its corresponding image, then run [convert_xml_format.py](./tools/convert_xml_format.py):
25 |    1. `python tools/convert_xml_format.py -i in_dir -s split_flag -l save_logs -o output_dir`
26 |    2. in_dir is the absolute directory containing the images and XMLs
27 |    3. split_flag controls whether to split the dataset
28 |    4. save_logs controls whether to save train_xml.txt
29 |    5. output_dir is where the converted XMLs are saved
30 | 3. Generate the tfrecords
31 |    1. `python gene_tfrecords.py --xml_img_txt_path=./logs/train_xml.txt --output_dir=tfrecords`
32 |    2. xml_img_txt_path is a txt file like [train xml](./logs/train_xml.txt), with one "img_path,xml_path" pair per line
33 |    3. output_dir is where the tfrecords are saved
34 | 4. Training
35 |    1. `python train.py --train_dir=some_path --dataset_dir=some_path --checkpoint_path=some_path`
36 |    2. train_dir stores the checkpoints produced during training
37 |    3. dataset_dir holds the tfrecords used for training
38 |    4. checkpoint_path points to the model to be fine-tuned
39 | 5. Testing
40 |    1. `python test.py -m /home/model.ckpt-858 -o test`
41 |    2. -m is the model checkpoint
42 |    3. -o is the output result directory
43 |    4. -i is the test image directory
44 |    5. -c is the device on which to run the test
45 |    6. -n is the NMS threshold
46 |    7. -s is the score threshold
47 |
48 |
49 |
50 | # Note:
51 |
52 | 1. While the model is training, you can run eval_result.py to evaluate checkpoints periodically and save the results
53 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/datasets/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/datasets/__init__.pyc
--------------------------------------------------------------------------------
/datasets/dataset_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains utilities for downloading and converting datasets."""
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import os
21 | import sys
22 | import tarfile
23 |
24 | from six.moves import urllib
25 | import tensorflow as tf
26 |
27 | LABELS_FILENAME = 'labels.txt'
28 | def norm(x):
29 |     """Clamp x to the range [0, 1]."""
30 |     if x < 0:
31 |         x = 0
32 |     elif x > 1:
33 |         x = 1
34 |     return x
35 |
36 | def int64_feature(value):
37 | """Wrapper for inserting int64 features into Example proto.
38 | """
39 | if not isinstance(value, list):
40 | value = [value]
41 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
42 |
43 |
44 | def float_feature(value):
45 | """Wrapper for inserting float features into Example proto.
46 | """
47 | if not isinstance(value, list):
48 | value = [value]
49 | return tf.train.Feature(float_list=tf.train.FloatList(value=value))
50 |
51 |
52 | def bytes_feature(value):
53 | """Wrapper for inserting bytes features into Example proto.
54 | """
55 | if not isinstance(value, list):
56 | value = [value]
57 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
58 |
59 |
60 | def image_to_tfexample(image_data, image_format, height, width, class_id):
61 | return tf.train.Example(features=tf.train.Features(feature={
62 | 'image/encoded': bytes_feature(image_data),
63 | 'image/format': bytes_feature(image_format),
64 | 'image/class/label': int64_feature(class_id),
65 | 'image/height': int64_feature(height),
66 | 'image/width': int64_feature(width),
67 | }))
68 |
69 |
70 | def download_and_uncompress_tarball(tarball_url, dataset_dir):
71 | """Downloads the `tarball_url` and uncompresses it locally.
72 |
73 | Args:
74 | tarball_url: The URL of a tarball file.
75 | dataset_dir: The directory where the temporary files are stored.
76 | """
77 | filename = tarball_url.split('/')[-1]
78 | filepath = os.path.join(dataset_dir, filename)
79 |
80 | def _progress(count, block_size, total_size):
81 | sys.stdout.write('\r>> Downloading %s %.1f%%' % (
82 | filename, float(count * block_size) / float(total_size) * 100.0))
83 | sys.stdout.flush()
84 | filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress)
85 | statinfo = os.stat(filepath)
86 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
87 | tarfile.open(filepath, 'r:gz').extractall(dataset_dir)
88 |
89 |
90 | def write_label_file(labels_to_class_names, dataset_dir,
91 | filename=LABELS_FILENAME):
92 | """Writes a file with the list of class names.
93 |
94 | Args:
95 | labels_to_class_names: A map of (integer) labels to class names.
96 | dataset_dir: The directory in which the labels file should be written.
97 | filename: The filename where the class names are written.
98 | """
99 | labels_filename = os.path.join(dataset_dir, filename)
100 | with tf.gfile.Open(labels_filename, 'w') as f:
101 | for label in labels_to_class_names:
102 | class_name = labels_to_class_names[label]
103 | f.write('%d:%s\n' % (label, class_name))
104 |
105 |
106 | def has_labels(dataset_dir, filename=LABELS_FILENAME):
107 | """Specifies whether or not the dataset directory contains a label map file.
108 |
109 | Args:
110 | dataset_dir: The directory in which the labels file is found.
111 | filename: The filename where the class names are written.
112 |
113 | Returns:
114 | `True` if the labels file exists and `False` otherwise.
115 | """
116 | return tf.gfile.Exists(os.path.join(dataset_dir, filename))
117 |
118 |
119 | def read_label_file(dataset_dir, filename=LABELS_FILENAME):
120 | """Reads the labels file and returns a mapping from ID to class name.
121 |
122 | Args:
123 | dataset_dir: The directory in which the labels file is found.
124 | filename: The filename where the class names are written.
125 |
126 | Returns:
127 | A map from a label (integer) to class name.
128 | """
129 | labels_filename = os.path.join(dataset_dir, filename)
130 | with tf.gfile.Open(labels_filename, 'rb') as f:
131 | lines = f.read()
132 | lines = lines.split(b'\n')
133 | lines = filter(None, lines)
134 |
135 | labels_to_class_names = {}
136 | for line in lines:
137 | index = line.index(b':')
138 | labels_to_class_names[int(line[:index])] = line[index+1:]
139 | return labels_to_class_names
140 |
141 |
142 | class ImageCoder(object):
143 | """Helper class that provides TensorFlow image coding utilities."""
144 |
145 | def __init__(self):
146 | # Create a single Session to run all image coding calls.
147 | self._sess = tf.Session()
148 |
149 | # Initializes function that converts PNG to JPEG data.
150 | self._png_data = tf.placeholder(dtype=tf.string)
151 | image = tf.image.decode_png(self._png_data, channels=3)
152 | self._png_to_jpeg = tf.image.encode_jpeg(image, format='rgb', quality=100)
153 |
154 | # Initializes function that converts CMYK JPEG data to RGB JPEG data.
155 | self._cmyk_data = tf.placeholder(dtype=tf.string)
156 | image = tf.image.decode_jpeg(self._cmyk_data, channels=0)
157 | self._cmyk_to_rgb = tf.image.encode_jpeg(image, format='rgb', quality=100)
158 |
159 | # Initializes function that decodes RGB JPEG data.
160 | self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
161 | self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)
162 |
163 | def png_to_jpeg(self, image_data):
164 | return self._sess.run(self._png_to_jpeg,
165 | feed_dict={self._png_data: image_data})
166 |
167 | def cmyk_to_rgb(self, image_data):
168 | return self._sess.run(self._cmyk_to_rgb,
169 | feed_dict={self._cmyk_data: image_data})
170 |
171 | def decode_jpeg(self, image_data):
172 | image = self._sess.run(self._decode_jpeg,
173 | feed_dict={self._decode_jpeg_data: image_data})
174 | assert len(image.shape) == 3
175 | assert image.shape[2] == 3
176 | return image
177 |
--------------------------------------------------------------------------------
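
The feature wrappers and `image_to_tfexample` above are the building blocks for every TFRecord writer in this repo. A minimal sketch of serializing one image with them, using the same TF 1.x APIs as the rest of the code (the file names here are hypothetical):

```python
import tensorflow as tf

from datasets.dataset_utils import image_to_tfexample

# Hypothetical input/output paths, for illustration only.
with tf.gfile.FastGFile('some_image.jpg', 'rb') as f:
    image_data = f.read()

# Wrap the raw JPEG bytes plus metadata into an Example proto.
example = image_to_tfexample(image_data, b'jpg', height=384, width=384, class_id=1)

# Append the serialized proto to a TFRecord file.
with tf.python_io.TFRecordWriter('sample.tfrecord') as writer:
    writer.write(example.SerializeToString())
```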
/datasets/dataset_utils.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/datasets/dataset_utils.pyc
--------------------------------------------------------------------------------
/datasets/synthtext_to_tfrecords_self.py:
--------------------------------------------------------------------------------
1 | #encoding=utf-8
2 | import numpy as np
3 | import tensorflow as tf
4 | import time
5 | import tensorflow.contrib.slim as slim
6 | import util
7 |
8 |
9 |
10 | def int64_feature(value):
11 | """Wrapper for inserting int64 features into Example proto.
12 | """
13 | if not isinstance(value, list):
14 | value = [value]
15 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
16 |
17 |
18 | def float_feature(value):
19 | """Wrapper for inserting float features into Example proto.
20 | """
21 | if not isinstance(value, list):
22 | value = [value]
23 | return tf.train.Feature(float_list=tf.train.FloatList(value=value))
24 |
25 |
26 | def bytes_feature(value):
27 | """Wrapper for inserting bytes features into Example proto.
28 | """
29 | if not isinstance(value, list):
30 | value = [value]
31 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
32 |
33 |
34 | def image_to_tfexample(image_data, image_format, height, width, class_id):
35 | return tf.train.Example(features=tf.train.Features(feature={
36 | 'image/encoded': bytes_feature(image_data),
37 | 'image/format': bytes_feature(image_format),
38 | 'image/class/label': int64_feature(class_id),
39 | 'image/height': int64_feature(height),
40 | 'image/width': int64_feature(width),
41 | }))
42 |
43 |
44 | def convert_to_example(image_data, filename, labels, ignored, labels_text, bboxes, oriented_bboxes, shape):
45 | """Build an Example proto for an image example.
46 | Args:
47 | image_data: string, JPEG encoding of RGB image
48 | labels: list of integers, identifier for the ground truth
49 | labels_text: list of strings, human-readable labels
50 | oriented_bboxes: list of bounding oriented boxes each box is a list of floats in [0, 1]
51 | specifying [x1, y1, x2, y2, x3, y3, x4, y4]
52 | bboxes: list of bbox in rectangle, [xmin, ymin, xmax, ymax]
53 | Returns:
54 | Example proto
55 | """
56 |
57 | image_format = b'JPEG'
58 | oriented_bboxes = np.asarray(oriented_bboxes)
59 | bboxes = np.asarray(bboxes)
60 | example = tf.train.Example(features=tf.train.Features(feature={
61 | 'image/shape': int64_feature(list(shape)),
62 | 'image/object/bbox/xmin': float_feature(list(bboxes[:, 0])),
63 | 'image/object/bbox/ymin': float_feature(list(bboxes[:, 1])),
64 | 'image/object/bbox/xmax': float_feature(list(bboxes[:, 2])),
65 | 'image/object/bbox/ymax': float_feature(list(bboxes[:, 3])),
66 | 'image/object/bbox/x1': float_feature(list(oriented_bboxes[:, 0])),
67 | 'image/object/bbox/y1': float_feature(list(oriented_bboxes[:, 1])),
68 | 'image/object/bbox/x2': float_feature(list(oriented_bboxes[:, 2])),
69 | 'image/object/bbox/y2': float_feature(list(oriented_bboxes[:, 3])),
70 | 'image/object/bbox/x3': float_feature(list(oriented_bboxes[:, 4])),
71 | 'image/object/bbox/y3': float_feature(list(oriented_bboxes[:, 5])),
72 | 'image/object/bbox/x4': float_feature(list(oriented_bboxes[:, 6])),
73 | 'image/object/bbox/y4': float_feature(list(oriented_bboxes[:, 7])),
74 | 'image/object/bbox/label': int64_feature(labels),
75 | 'image/object/bbox/label_text': bytes_feature(labels_text),
76 | 'image/object/bbox/ignored': int64_feature(ignored),
77 | 'image/format': bytes_feature(image_format),
78 | 'image/filename': bytes_feature(filename),
79 | 'image/encoded': bytes_feature(image_data)}))
80 | return example
81 |
82 |
83 |
84 | def get_split(split_name, dataset_dir, file_pattern, num_samples, reader=None):
85 | dataset_dir = util.io.get_absolute_path(dataset_dir)
86 |
87 | if util.str.contains(file_pattern, '%'):
88 | file_pattern = util.io.join_path(dataset_dir, file_pattern % split_name)
89 | else:
90 | file_pattern = util.io.join_path(dataset_dir, file_pattern)
91 | # Allowing None in the signature so that dataset_factory can use the default.
92 | if reader is None:
93 | reader = tf.TFRecordReader
94 | keys_to_features = {
95 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
96 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
97 | 'image/filename': tf.FixedLenFeature((), tf.string, default_value=''),
98 | 'image/shape': tf.FixedLenFeature([3], tf.int64),
99 | 'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
100 | 'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
101 | 'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
102 | 'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
103 | 'image/object/bbox/x1': tf.VarLenFeature(dtype=tf.float32),
104 | 'image/object/bbox/x2': tf.VarLenFeature(dtype=tf.float32),
105 | 'image/object/bbox/x3': tf.VarLenFeature(dtype=tf.float32),
106 | 'image/object/bbox/x4': tf.VarLenFeature(dtype=tf.float32),
107 | 'image/object/bbox/y1': tf.VarLenFeature(dtype=tf.float32),
108 | 'image/object/bbox/y2': tf.VarLenFeature(dtype=tf.float32),
109 | 'image/object/bbox/y3': tf.VarLenFeature(dtype=tf.float32),
110 | 'image/object/bbox/y4': tf.VarLenFeature(dtype=tf.float32),
111 | 'image/object/bbox/ignored': tf.VarLenFeature(dtype=tf.int64),
112 | 'image/object/bbox/label': tf.VarLenFeature(dtype=tf.int64),
113 | }
114 | items_to_handlers = {
115 | 'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
116 | 'shape': slim.tfexample_decoder.Tensor('image/shape'),
117 | 'filename': slim.tfexample_decoder.Tensor('image/filename'),
118 | 'object/bbox': slim.tfexample_decoder.BoundingBox(
119 | ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'),
120 | 'object/oriented_bbox/x1': slim.tfexample_decoder.Tensor('image/object/bbox/x1'),
121 | 'object/oriented_bbox/x2': slim.tfexample_decoder.Tensor('image/object/bbox/x2'),
122 | 'object/oriented_bbox/x3': slim.tfexample_decoder.Tensor('image/object/bbox/x3'),
123 | 'object/oriented_bbox/x4': slim.tfexample_decoder.Tensor('image/object/bbox/x4'),
124 | 'object/oriented_bbox/y1': slim.tfexample_decoder.Tensor('image/object/bbox/y1'),
125 | 'object/oriented_bbox/y2': slim.tfexample_decoder.Tensor('image/object/bbox/y2'),
126 | 'object/oriented_bbox/y3': slim.tfexample_decoder.Tensor('image/object/bbox/y3'),
127 | 'object/oriented_bbox/y4': slim.tfexample_decoder.Tensor('image/object/bbox/y4'),
128 | 'object/label': slim.tfexample_decoder.Tensor('image/object/bbox/label'),
129 | 'object/ignored': slim.tfexample_decoder.Tensor('image/object/bbox/ignored')
130 | }
131 | decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
132 |
133 | labels_to_names = {0:'background', 1:'text'}
134 | items_to_descriptions = {
135 | 'image': 'A color image of varying height and width.',
136 | 'shape': 'Shape of the image',
137 | 'object/bbox': 'A list of bounding boxes, one per each object.',
138 | 'object/label': 'A list of labels, one per each object.',
139 | }
140 |
141 | return slim.dataset.Dataset(
142 | data_sources=file_pattern,
143 | reader=reader,
144 | decoder=decoder,
145 | num_samples=num_samples,
146 | items_to_descriptions=items_to_descriptions,
147 | num_classes=2,
148 | labels_to_names=labels_to_names)
149 |
150 |
151 |
152 |
153 | class SynthTextDataFetcher():
154 | def __init__(self, mat_path, root_path):
155 | self.mat_path = mat_path
156 | self.root_path = root_path
157 | self._load_mat()
158 |
159 | # @util.dec.print_calling
160 | def _load_mat(self):
161 | data = util.io.load_mat(self.mat_path)
162 | self.image_paths = data['imnames'][0]
163 | self.image_bbox = data['wordBB'][0]
164 | self.txts = data['txt'][0]
165 | self.num_images = len(self.image_paths)
166 |
167 | def get_image_path(self, idx):
168 | image_path = util.io.join_path(self.root_path, self.image_paths[idx][0])
169 | return image_path
170 |
171 | def get_num_words(self, idx):
172 | try:
173 | return np.shape(self.image_bbox[idx])[2]
174 | except: # error caused by dataset
175 | return 1
176 |
177 |
178 | def get_word_bbox(self, img_idx, word_idx):
179 | boxes = self.image_bbox[img_idx]
180 | if len(np.shape(boxes)) ==2: # error caused by dataset
181 | boxes = np.reshape(boxes, (2, 4, 1))
182 |
183 | xys = boxes[:,:, word_idx]
184 | assert(np.shape(xys) ==(2, 4))
185 | return np.float32(xys)
186 |
187 | def normalize_bbox(self, xys, width, height):
188 | xs = xys[0, :]
189 | ys = xys[1, :]
190 |
191 | min_x = min(xs)
192 | min_y = min(ys)
193 | max_x = max(xs)
194 | max_y = max(ys)
195 |
196 | # bound them in the valid range
197 | min_x = max(0, min_x)
198 | min_y = max(0, min_y)
199 | max_x = min(width, max_x)
200 | max_y = min(height, max_y)
201 |
202 | # check the w, h and area of the rect
203 | w = max_x - min_x
204 | h = max_y - min_y
205 | is_valid = True
206 |
207 | if w < 10 or h < 10:
208 | is_valid = False
209 |
210 | if w * h < 100:
211 | is_valid = False
212 |
213 | xys[0, :] = xys[0, :] / width
214 | xys[1, :] = xys[1, :] / height
215 |
216 | return is_valid, min_x / width, min_y /height, max_x / width, max_y / height, xys
217 |
218 | def get_txt(self, image_idx, word_idx):
219 | txts = self.txts[image_idx]
220 | clean_txts = []
221 | for txt in txts:
222 | clean_txts += txt.split()
223 | return str(clean_txts[word_idx])
224 |
225 |
226 | def fetch_record(self, image_idx):
227 | image_path = self.get_image_path(image_idx)
228 | if not (util.io.exists(image_path)):
229 | return None
230 | img = util.img.imread(image_path)
231 | h, w = img.shape[0:-1]
232 | num_words = self.get_num_words(image_idx)
233 | rect_bboxes = []
234 | full_bboxes = []
235 | txts = []
236 | for word_idx in range(num_words):
237 | xys = self.get_word_bbox(image_idx, word_idx)
238 | is_valid, min_x, min_y, max_x, max_y, xys = self.normalize_bbox(xys, width = w, height = h)
239 | if not is_valid:
240 | continue
241 | rect_bboxes.append([min_x, min_y, max_x, max_y])
242 | xys = np.reshape(np.transpose(xys), -1)
243 | full_bboxes.append(xys)
244 | txt = self.get_txt(image_idx, word_idx)
245 | txts.append(txt)
246 | if len(rect_bboxes) == 0:
247 | return None
248 |
249 | return image_path, img, txts, rect_bboxes, full_bboxes
250 |
251 |
252 |
253 | def cvt_to_tfrecords(output_path , data_path, gt_path, records_per_file = 30000):
254 |
255 | fetcher = SynthTextDataFetcher(root_path = data_path, mat_path = gt_path)
256 | fid = 0
257 | image_idx = -1
258 | while image_idx < fetcher.num_images:
259 | with tf.python_io.TFRecordWriter(output_path%(fid)) as tfrecord_writer:
260 | record_count = 0
261 | while record_count != records_per_file:
262 | image_idx += 1
263 | if image_idx >= fetcher.num_images:
264 | break
265 | print("loading image %d/%d"%(image_idx + 1, fetcher.num_images))
266 | record = fetcher.fetch_record(image_idx)
267 | if record is None:
268 | print('\nimage %d does not exist'%(image_idx + 1))
269 | continue
270 |
271 | image_path, image, txts, rect_bboxes, oriented_bboxes = record
272 | labels = len(rect_bboxes) * [1]
273 | ignored = len(rect_bboxes) * [0]
274 |                 image_data = tf.gfile.FastGFile(image_path, 'rb').read()  # 'rb' so this also works on Python 3
275 |
276 | shape = image.shape
277 | image_name = str(util.io.get_filename(image_path).split('.')[0])
278 | example = convert_to_example(image_data, image_name, labels, ignored, txts, rect_bboxes, oriented_bboxes, shape)
279 | tfrecord_writer.write(example.SerializeToString())
280 | record_count += 1
281 |
282 | fid += 1
283 |
284 | if __name__ == "__main__":
285 | mat_path = util.io.get_absolute_path('/share/SynthText/gt.mat')
286 | root_path = util.io.get_absolute_path('/share/SynthText/')
287 | output_dir = util.io.get_absolute_path('/home/zsz/datasets/synth-tf/')
288 | util.io.mkdir(output_dir)
289 | cvt_to_tfrecords(output_path = util.io.join_path(output_dir, 'SynthText_%d.tfrecord'), data_path = root_path, gt_path = mat_path)
290 |
--------------------------------------------------------------------------------
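
To make the box filtering in `normalize_bbox` concrete, here is a small worked example on a hand-made 2x4 box (x coordinates in the first row, y in the second), assuming a 200x100 image; the numbers are invented for illustration:

```python
import numpy as np

from datasets.synthtext_to_tfrecords_self import SynthTextDataFetcher

# Bypass __init__ (which would load gt.mat); normalize_bbox only uses its arguments.
fetcher = SynthTextDataFetcher.__new__(SynthTextDataFetcher)

xys = np.float32([[20, 80, 80, 20],   # x1..x4
                  [10, 10, 40, 40]])  # y1..y4
is_valid, xmin, ymin, xmax, ymax, norm_xys = fetcher.normalize_bbox(xys, width=200, height=100)

# The enclosing 60x30 rect passes the w >= 10, h >= 10 and area >= 100 checks,
# and the corners come back normalized by the image size:
print(is_valid, xmin, ymin, xmax, ymax)  # True 0.1 0.1 0.4 0.4
```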
/datasets/sythtextprovider.py:
--------------------------------------------------------------------------------
1 | ## an initial version
2 | ## Transform the tfrecord to slim data provider format
3 |
4 | import numpy
5 | import tensorflow as tf
6 | import os
7 | import tensorflow.contrib.slim as slim
8 | import glob
9 |
10 |
11 |
12 | ITEMS_TO_DESCRIPTIONS = {
13 | 'image': 'A color image of varying height and width.',
14 | 'shape': 'Shape of the image',
15 | 'object/bbox': 'A list of bounding boxes, one per each object.',
16 | 'object/label': 'A list of labels, one per each object.',
17 | }
18 | SPLITS_TO_SIZES = {
19 | #'train': 2518, for ppt datasets
20 | 'train': 858750 # for synth text datasets
21 | }
22 | def get_datasets(data_dir,file_pattern = '*.tfrecord'):
23 | file_patterns = os.path.join(data_dir, file_pattern)
24 | print('file_path: {}'.format(file_patterns))
25 | file_path_list = glob.glob(file_patterns)
26 | num_samples = 0
27 | #num_samples = 288688
28 | #num_samples = 858750 only for synth datasets
29 | for file_path in file_path_list:
30 | for record in tf.python_io.tf_record_iterator(file_path):
31 | num_samples += 1
32 | print('num_samples:', num_samples)
33 | reader = tf.TFRecordReader
34 | keys_to_features = {
35 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
36 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
37 | 'image/filename': tf.FixedLenFeature((), tf.string, default_value=''),
38 | 'image/shape': tf.FixedLenFeature([3], tf.int64),
39 | 'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
40 | 'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
41 | 'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
42 | 'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
43 | 'image/object/bbox/x1': tf.VarLenFeature(dtype=tf.float32),
44 | 'image/object/bbox/x2': tf.VarLenFeature(dtype=tf.float32),
45 | 'image/object/bbox/x3': tf.VarLenFeature(dtype=tf.float32),
46 | 'image/object/bbox/x4': tf.VarLenFeature(dtype=tf.float32),
47 | 'image/object/bbox/y1': tf.VarLenFeature(dtype=tf.float32),
48 | 'image/object/bbox/y2': tf.VarLenFeature(dtype=tf.float32),
49 | 'image/object/bbox/y3': tf.VarLenFeature(dtype=tf.float32),
50 | 'image/object/bbox/y4': tf.VarLenFeature(dtype=tf.float32),
51 | 'image/object/bbox/ignored': tf.VarLenFeature(dtype=tf.int64),
52 | 'image/object/bbox/label': tf.VarLenFeature(dtype=tf.int64),
53 | }
54 |
55 | items_to_handlers = {
56 | 'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
57 | 'shape': slim.tfexample_decoder.Tensor('image/shape'),
58 | 'filename': slim.tfexample_decoder.Tensor('image/filename'),
59 | 'object/bbox': slim.tfexample_decoder.BoundingBox(
60 | ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'),
61 | 'object/oriented_bbox/x1': slim.tfexample_decoder.Tensor('image/object/bbox/x1'),
62 | 'object/oriented_bbox/x2': slim.tfexample_decoder.Tensor('image/object/bbox/x2'),
63 | 'object/oriented_bbox/x3': slim.tfexample_decoder.Tensor('image/object/bbox/x3'),
64 | 'object/oriented_bbox/x4': slim.tfexample_decoder.Tensor('image/object/bbox/x4'),
65 | 'object/oriented_bbox/y1': slim.tfexample_decoder.Tensor('image/object/bbox/y1'),
66 | 'object/oriented_bbox/y2': slim.tfexample_decoder.Tensor('image/object/bbox/y2'),
67 | 'object/oriented_bbox/y3': slim.tfexample_decoder.Tensor('image/object/bbox/y3'),
68 | 'object/oriented_bbox/y4': slim.tfexample_decoder.Tensor('image/object/bbox/y4'),
69 | 'object/label': slim.tfexample_decoder.Tensor('image/object/bbox/label'),
70 | 'object/ignored': slim.tfexample_decoder.Tensor('image/object/bbox/ignored')
71 | }
72 |
73 | decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
74 |
75 | labels_to_names = {0:'background', 1:'text'}
76 | return slim.dataset.Dataset(
77 | data_sources=file_patterns,
78 | reader=reader,
79 | decoder=decoder,
80 | num_samples=num_samples,
81 | items_to_descriptions=ITEMS_TO_DESCRIPTIONS,
82 | num_classes=NUM_CLASSES,
83 | labels_to_names=labels_to_names)
84 |
85 | NUM_CLASSES = 2  # referenced by get_datasets above; the name is resolved when the function is called
86 |
--------------------------------------------------------------------------------
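
A sketch of how the `slim.dataset.Dataset` returned by `get_datasets` is typically consumed through slim's `DatasetDataProvider` (the tfrecords directory is hypothetical):

```python
import tensorflow.contrib.slim as slim

from datasets.sythtextprovider import get_datasets

# Hypothetical directory containing the *.tfrecord files from gene_tfrecords.py.
dataset = get_datasets('./tfrecords')

provider = slim.dataset_data_provider.DatasetDataProvider(
    dataset, num_readers=4, shuffle=True,
    common_queue_capacity=512, common_queue_min=128)

# Pull one decoded sample; the item names match items_to_handlers above.
image, shape, bboxes, labels = provider.get(
    ['image', 'shape', 'object/bbox', 'object/label'])
```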
/datasets/sythtextprovider.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/datasets/sythtextprovider.pyc
--------------------------------------------------------------------------------
/datasets/xml_to_tfrecords.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import sys
4 | import random
5 | import numpy as np
6 | import tensorflow as tf
7 | import xml.etree.ElementTree as ET
8 | from datasets.dataset_utils import int64_feature, float_feature, bytes_feature
9 | import tensorflow.contrib.slim as slim
10 |
11 | # TFRecords convertion parameters.
12 |
13 | TXT_LABELS = {
14 | 'none': (0, 'Background'),
15 | 'text': (1, 'Text')
16 | }
17 |
18 | def _process_image(train_img_path, train_xml_path, name):
19 | """Process a image and annotation file.
20 |
21 | Args:
22 |         train_img_path: string, path to the image file, e.g. '/path/to/example.JPG'.
23 |         train_xml_path: string, path to the corresponding annotation XML file.
24 |     Returns:
25 |         A tuple (image_data, shape, bboxes, labels, labels_text, difficult,
26 |         truncated, oriented_bbox, ignored, filename), or None if the size
27 |         metadata in the XML is invalid.
28 | """
29 | # Read the image file.
30 |
31 | image_data = tf.gfile.FastGFile(train_img_path, 'rb').read()
32 |
33 | tree = ET.parse(train_xml_path)
34 | root = tree.getroot()
35 | # Image shape.
36 | size = root.find('size')
37 | height = int(size.find('height').text)
38 | width = int(size.find('width').text)
39 | depth = int(size.find('depth').text)
40 | if height <= 0 or width <= 0 or depth <= 0:
41 |         print('invalid height/width/depth:', height, width, depth)
42 |         return None
43 |
44 | shape = [int(size.find('height').text),
45 | int(size.find('width').text),
46 | int(size.find('depth').text)]
47 | # Find annotations.
48 | bboxes = []
49 | labels = []
50 | labels_text = []
51 | difficult = []
52 | truncated = []
53 | oriented_bbox = []
54 | ignored = 0
55 | filename = root.find('filename').text
56 |
57 | for obj in root.findall('object'):
58 | label = obj.find('name').text
59 |         # Map every non-background name to the single 'text' class.
60 |         if label != 'none':
61 |             label = 'text'
62 |
63 | labels.append(int(TXT_LABELS[label][0]))
64 | labels_text.append(label.encode('ascii'))
65 |
66 | if obj.find('difficult') is not None:
67 | #print('append difficult')
68 | difficult.append(int(obj.find('difficult').text))
69 | else:
70 | difficult.append(0)
71 |         if obj.find('truncated') is not None:  # explicit None check; elements with no children are falsy
72 | truncated.append(int(obj.find('truncated').text))
73 | else:
74 | truncated.append(0)
75 |
76 | bbox = obj.find('bndbox')
77 | ymin = float(bbox.find('ymin').text)
78 | ymax = float(bbox.find('ymax').text)
79 | xmin = float(bbox.find('xmin').text)
80 | xmax = float(bbox.find('xmax').text)
81 |
82 | x1 = float(bbox.find('x1').text)
83 | x2 = float(bbox.find('x2').text)
84 | x3 = float(bbox.find('x3').text)
85 | x4 = float(bbox.find('x4').text)
86 |
87 | y1 = float(bbox.find('y1').text)
88 | y2 = float(bbox.find('y2').text)
89 | y3 = float(bbox.find('y3').text)
90 | y4 = float(bbox.find('y4').text)
91 |
92 |
93 | ymin, ymax = np.clip([ymin, ymax], 0 , height)
94 | xmin, xmax = np.clip([xmin, xmax] ,0 , width)
95 |
96 | x1, x2, x3, x4 = np.clip([x1, x2, x3, x4], 0, width)
97 | y1, y2, y3, y4 = np.clip([y1, y2, y3, y4], 0, height)
98 |
99 | bboxes.append(( ymin / shape[0],
100 | xmin / shape[1],
101 | ymax / shape[0],
102 | xmax / shape[1]
103 | ))
104 |
105 | oriented_bbox.append((x1 / width, x2 / width, x3 / width, x4 /width, y1 /height, y2 / height, y3 / height, y4 / height))
106 |
107 | return image_data, shape, bboxes, labels, labels_text, difficult, truncated, oriented_bbox, ignored, filename
108 |
109 | def _convert_to_example(image_data, labels, labels_text, bboxes, shape,
110 | difficult, truncated, oriented_bbox, ignored, filename):
111 | """Build an Example proto for an image example.
112 |
113 | Args:
114 | image_data: string, JPEG encoding of RGB image;
115 | labels: list of integers, identifier for the ground truth;
116 | labels_text: list of strings, human-readable labels;
117 | bboxes: list of bounding boxes; each box is a list of integers;
118 | specifying [ymin, xmin, ymax, xmax]. All boxes are assumed to belong
119 | to the same label as the image label.
120 | shape: 3 integers, image shapes in pixels.
121 | Returns:
122 | Example proto
123 | """
124 | xmin = []
125 | ymin = []
126 | xmax = []
127 | ymax = []
128 | for b in bboxes:
129 | assert len(b) == 4
130 | # pylint: disable=expression-not-assigned
131 | [l.append(point) for l, point in zip([ymin, xmin, ymax, xmax], b)]
132 | # pylint: enable=expression-not-assigned
133 |
134 | x1 = []
135 | x2 = []
136 | x3 = []
137 | x4 = []
138 |
139 | y1 = []
140 | y2 = []
141 | y3 = []
142 | y4 = []
143 |
144 | for orgin in oriented_bbox:
145 | assert len(orgin) == 8
146 | [l.append(point) for l, point in zip([x1, x2, x3, x4, y1, y2, y3, y4], orgin)]
147 |
148 | image_format = b'JPEG'
149 | example = tf.train.Example(features=tf.train.Features(feature={
150 | 'image/height': int64_feature(shape[0]),
151 | 'image/width': int64_feature(shape[1]),
152 | 'image/channels': int64_feature(shape[2]),
153 | 'image/shape': int64_feature(shape),
154 | 'image/filename': bytes_feature(filename.encode('utf-8')),
155 | 'image/object/bbox/xmin': float_feature(xmin),
156 | 'image/object/bbox/xmax': float_feature(xmax),
157 | 'image/object/bbox/ymin': float_feature(ymin),
158 | 'image/object/bbox/ymax': float_feature(ymax),
159 | 'image/object/bbox/x1': float_feature(x1),
160 | 'image/object/bbox/y1': float_feature(y1),
161 | 'image/object/bbox/x2': float_feature(x2),
162 | 'image/object/bbox/y2': float_feature(y2),
163 | 'image/object/bbox/x3': float_feature(x3),
164 | 'image/object/bbox/y3': float_feature(y3),
165 | 'image/object/bbox/x4': float_feature(x4),
166 | 'image/object/bbox/y4': float_feature(y4),
167 | 'image/object/bbox/label': int64_feature(labels),
168 | 'image/object/bbox/label_text': bytes_feature(labels_text),
169 | 'image/object/bbox/difficult': int64_feature(difficult),
170 | 'image/object/bbox/truncated': int64_feature(truncated),
171 | 'image/object/bbox/ignored': int64_feature(ignored),
172 | 'image/format': bytes_feature(image_format),
173 | 'image/encoded': bytes_feature(image_data)}))
174 | return example
175 |
176 |
177 | def _get_output_filename(output_dir, name, idx):
178 | return '%s/%s_%03d.tfrecord' % (output_dir, name, idx)
179 |
180 |
181 | def _add_to_tfrecord(train_img_path, train_xml_path , name, tfrecord_writer):
182 | """Loads data from image and annotations files and add them to a TFRecord.
183 |
184 | Args:
185 | train_img_path: img path;
186 | train_xml_path: xml path
187 | name: Image name to add to the TFRecord;
188 | tfrecord_writer: The TFRecord writer to use for writing.
189 | """
190 |     result = _process_image(train_img_path, train_xml_path, name)
191 |     if result is None:  # _process_image returns None when the XML size metadata is invalid
192 |         return
193 |     image_data, shape, bboxes, labels, labels_text, difficult, truncated, oriented_bbox, ignored, filename = result
194 |     example = _convert_to_example(image_data, labels, labels_text, bboxes, shape, difficult, truncated, oriented_bbox, ignored, filename)
195 |     tfrecord_writer.write(example.SerializeToString())
196 |
197 | def run(xml_img_txt_path, output_dir, name='icdar15_annotated_data', samples_per_files=200):
198 | """Runs the conversion operation.
199 | Args:
200 |         xml_img_txt_path: txt file listing one "img_path,xml_path" pair per line.
201 | output_dir: Output directory.
202 | """
203 |
204 | if not tf.gfile.Exists(output_dir):
205 | tf.gfile.MakeDirs(output_dir)
206 |
207 | train_txt = open(xml_img_txt_path, 'r')
208 | lines = train_txt.readlines()
209 | train_img_path = []
210 | train_xml_path = []
211 | error_list = []
212 | count = 0
213 | for line in lines:
214 | line = line.strip()
215 | if len(line.split(',')) == 2:
216 | count += 1
217 | train_img_path.append(line.split(',')[0])
218 | train_xml_path.append(line.split(',')[1])
219 | else:
220 | error_list.append(line)
221 | # print('line:',line)
222 | with open(os.path.join(output_dir, 'create_tfrecord_error_list.txt'), 'w') as f:
223 | f.writelines(error_list)
224 | filenames = train_img_path
225 |
226 | # Process dataset files
227 | i = 0
228 | fidx = 0
229 | while i < len(filenames):
230 | #Open new TFRecord file.
231 | tf_filename = _get_output_filename(output_dir, name, fidx)
232 |
233 | with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer:
234 | j = 0
235 | while i < len(filenames) and j < samples_per_files:
236 | sys.stdout.write('\r>> Converting image %d/%d' % (i+1, len(filenames)))
237 | sys.stdout.flush()
238 |
239 | filename = filenames[i]
240 | img_name = filename.split('/')[-1][:-4]
241 | _add_to_tfrecord(filename, train_xml_path[i], img_name, tfrecord_writer)
242 | i += 1
243 | j += 1
244 | fidx += 1
245 |     print('\nFinished converting the dataset!')
246 |
247 |
--------------------------------------------------------------------------------
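
`gene_tfrecords.py` (below) is the command-line wrapper around this module, but `run` can also be called directly. A minimal sketch, assuming a txt file in the same format as /logs/train_xml.txt:

```python
from datasets import xml_to_tfrecords

# Each line of the txt file must be "img_path,xml_path"; malformed lines are
# collected into create_tfrecord_error_list.txt in the output directory.
xml_to_tfrecords.run('./logs/train_xml.txt', 'tfrecords',
                     name='icdar15_annotated_data', samples_per_files=200)
```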
/datasets/xml_to_tfrecords.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/datasets/xml_to_tfrecords.pyc
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import tensorflow as tf
4 | import cv2
5 | import tensorflow.contrib.slim as slim
6 | import codecs
7 | import sys
8 | import time
9 | import random
10 | sys.path.append('./')
11 |
12 | from nets import txtbox_384, np_methods, txtbox_768
13 | from processing import ssd_vgg_preprocessing
14 |
15 | os.environ['CUDA_VISIBLE_DEVICES'] = '3'  # run on GPU 3
16 |
17 | def plt_bboxes(img, classes, scores, bboxes, figsize=(10,10), linewidth=1.5):
18 | """Visualize bounding boxes. Largely inspired by SSD-MXNET!
19 | """
20 | height = img.shape[0]
21 | width = img.shape[1]
22 | colors = dict()
23 | for i in range(classes.shape[0]):
24 | cls_id = int(classes[i])
25 | if cls_id >= 0:
26 | score = scores[i]
27 | if cls_id not in colors:
28 | colors[cls_id] = (random.random(), random.random(), random.random())
29 |
30 | xmin = int(bboxes[i, 0] * width)
31 | ymin = int(bboxes[i, 1] * height)
32 | xmax = int(bboxes[i, 2] * width)
33 | ymax = int(bboxes[i, 3] * height)
34 | img = cv2.rectangle(img, (xmin, ymin), (xmax, ymax), (255, 0, 0))
35 | return img
36 |
37 | gpu_options = tf.GPUOptions(allow_growth=False, per_process_gpu_memory_fraction=0.3)
38 |
39 | config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options)
40 | isess = tf.Session(config=config)
41 |
42 | # Input placeholder.
43 | net_shape = (384, 384)
44 | #net_shape = (768, 768)
45 | data_format = 'NHWC'
46 | img_input = tf.placeholder(tf.float32, shape=(None, None, 3))
47 | # Evaluation pre-processing: resize to SSD net shape.
48 | image_pre, labels_pre, bboxes_pre, bbox_img, xs, ys = ssd_vgg_preprocessing.preprocess_for_eval(
49 | img_input, None, None,None, None, net_shape, data_format, resize=ssd_vgg_preprocessing.Resize.WARP_RESIZE)
50 | image_4d = tf.expand_dims(image_pre, 0)
51 | image_4d = tf.cast(image_4d, tf.float32)
52 | # Define the txt_box model.
53 | reuse = True if 'txt_net' in locals() else None
54 |
55 | txt_net = txtbox_384.TextboxNet()
56 | print(txt_net.params.img_shape)
57 | print('reuse:',reuse)
58 |
59 | with slim.arg_scope(txt_net.arg_scope(data_format=data_format)):
60 | predictions,localisations, logits, end_points = txt_net.net(image_4d, is_training=False, reuse=reuse)
61 |
62 | ckpt_dir = 'model'
63 |
64 | isess.run(tf.global_variables_initializer())
65 |
66 | saver = tf.train.Saver()
67 |
68 | ckpt_filename = tf.train.latest_checkpoint(ckpt_dir)
69 | if ckpt_dir and ckpt_filename:
70 | print('checkpoint:',ckpt_dir, os.getcwd(), ckpt_filename)
71 | saver.restore(isess, ckpt_filename)
72 | txt_anchors = txt_net.anchors(net_shape)
73 |
74 | def process_image(img, select_threshold=0.01, nms_threshold=.45, net_shape=net_shape):
75 | # Run txt network.
76 | startTime = time.time()
77 | rimg, rpredictions,rlogits,rlocalisations, rbbox_img = isess.run([image_4d, predictions, logits, localisations, bbox_img],
78 | feed_dict={img_input: img})
79 |
80 | end_time = time.time()
81 | print(end_time - startTime)
82 | # Get classes and bboxes from the net outputs
83 |
84 | rclasses, rscores, rbboxes = np_methods.ssd_bboxes_select(
85 | rpredictions, rlocalisations, txt_anchors,
86 | select_threshold=select_threshold, img_shape=net_shape, num_classes=2, decode=True)
87 |
88 | rbboxes = np_methods.bboxes_clip(rbbox_img, rbboxes)
89 | # print(rscores)
90 | rclasses, rscores, rbboxes = np_methods.bboxes_sort(rclasses, rscores, rbboxes, top_k=400)
91 | rclasses, rscores, rbboxes = np_methods.bboxes_nms(rclasses, rscores, rbboxes, nms_threshold=nms_threshold)
92 | # Resize bboxes to original image shape. Note: useless for Resize.WARP!
93 | rbboxes = np_methods.bboxes_resize(rbbox_img, rbboxes)
94 | return rclasses, rscores, rbboxes
95 |
96 |
97 | # Test on some demo image and visualize output.
98 | path = './demo/'
99 |
100 | img = cv2.imread(path + 'img_1.jpg')
101 | img_cp = img.copy()
102 | rclasses, rscores, rbboxes = process_image(img_cp)
103 |
104 | with codecs.open('demo/detections.txt', 'w', encoding='utf-8') as fout:
105 | for i in range(len(rclasses)):
106 | fout.write('{},{}\n'.format(rbboxes[i], rscores[i]))
107 | img_with_bbox = plt_bboxes(img_cp, rclasses, rscores, rbboxes)
108 | cv2.imwrite(os.path.join(path,'demo_res.png'), img_with_bbox)
109 | print('detection finished')
110 | else:
111 |     raise ValueError('no checkpoint found in {}'.format(ckpt_dir))
112 |
113 |
--------------------------------------------------------------------------------
/demo/example/image0.txt:
--------------------------------------------------------------------------------
1 | text 0.597695052624 3 39 80 51
2 |
--------------------------------------------------------------------------------
/demo/example/image0.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | test\image0.png
5 |
6 | 384
7 | 384
8 | 3
9 |
10 |
20 |
21 |
--------------------------------------------------------------------------------
/demo/example/standard.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 | train_images
4 | img_10.jpg
5 |
6 | 1280
7 | 720
8 | 3
9 |
10 |
29 |
48 |
67 |
86 |
105 |
124 |
143 |
162 |
163 |
--------------------------------------------------------------------------------
/demo/img_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/demo/img_1.jpg
--------------------------------------------------------------------------------
/deployment/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/deployment/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/deployment/__init__.pyc
--------------------------------------------------------------------------------
/deployment/model_deploy.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/deployment/model_deploy.pyc
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | #! -*- encoding: utf-8 -*-
2 | import numpy as np
3 | import cv2
4 | import os
5 | from argparse import ArgumentParser
6 |
7 | import xml.etree.ElementTree as ET
8 | import shutil
9 | import sys
10 | import matplotlib.pyplot as plt
11 | info = sys.version_info
12 | if int(info[0]) == 2:
13 | reload(sys)
14 |     sys.setdefaultencoding('utf-8')  # set the default encoding to utf-8
15 |
16 |
17 | def mat_inter(box1, box2):
18 | xmin_1, ymin_1, xmax_1, ymax_1 = box1
19 | xmin_2, ymin_2, xmax_2, ymax_2 = box2
20 | distance_between_box_x = abs((xmax_1 + xmin_1) / 2 - (xmax_2 + xmin_2) / 2)
21 |     distance_between_box_y = abs((ymax_2 + ymin_2) / 2 - (ymin_1 + ymax_1) / 2)
22 |
23 | distance_box_1_x = abs(xmin_1 - xmax_1)
24 | distance_box_1_y = abs(ymax_1 - ymin_1)
25 | distance_box_2_x = abs(xmax_2 - xmin_2)
26 | distance_box_2_y = abs(ymax_2 - ymin_2)
27 |
28 | if distance_between_box_x < (distance_box_1_x + distance_box_2_x
29 | ) / 2 and distance_between_box_y < (
30 | distance_box_2_y + distance_box_1_y) / 2:
31 | return True
32 | else:
33 | return False
34 |
35 |
36 | class EVAL_MODEL(object):
37 | def __init__(self,
38 | eval_data_dir,
39 | pre_data_dir,
40 | data_type,
41 | save_result_path,
42 | iou_th=0.5,
43 | save_err_path='err_pic'):
44 | # print(eval_data_dir , pre_data_dir ,data_type, save_result_path)
45 | if eval_data_dir is None or pre_data_dir is None or data_type is None or save_result_path is None:
46 | raise ValueError(
47 | 'please input eval_data_dir or pre_data_dir or data_type or save_result_path'
48 | )
49 | self.eval_data_dir = eval_data_dir
50 | self.pre_data_dir = pre_data_dir
51 | self.data_type = data_type
52 | self.save_result_path = save_result_path
53 |
54 | self.allow_post_processing = False
55 | self.draw_err_pic_flag = True
56 | self.xml_path_list = []
57 | self.pre_img_name_list = []
58 | self.eval_data_dict = {}
59 | self.pre_data_dict = {}
60 |
61 | self.pre_data_num = 0
62 | self.gt_data_num = 0
63 | self.hit_precision = 0
64 | self.hit_recall = 0
65 | self.err_gt_dict = {
66 | 'ratio': [],
67 | 'h_scale': [],
68 | 'w_scale': [],
69 | 'height': [],
70 | 'width': []
71 | } #save box ratio and size
72 | self.iou_thresh = float(iou_th)
73 | self.save_err_path = os.path.join(save_err_path, str(self.iou_thresh))
74 | if not os.path.exists(self.save_err_path):
75 | os.makedirs(self.save_err_path)
76 |
77 | def list_from_str(self, st, dtype='float32'):
78 | line = st.split(' ')[2:6]
79 | if dtype == 'float32':
80 | line = [float(a) for a in line]
81 | else:
82 | line = [int(a) for a in line]
83 | return line
84 |
85 | def get_xml_path(self):
86 | for i in os.listdir(self.eval_data_dir):
87 | if i.split('.')[-1] == 'xml':
88 | self.xml_path_list.append(os.path.join(self.eval_data_dir, i))
89 | return self.xml_path_list
90 |
91 | def read_gts(self):
92 | if self.data_type == '1':
93 |             if self.eval_data_dir is None or self.eval_data_dir == '':
94 | raise ValueError('---eval data dir not exists!!!!-----')
95 | for i in self.xml_path_list:
96 | img_name = os.path.splitext(os.path.basename(i))[0]
97 | # img_name = os.path.basename(i)
98 | xml_info = ET.parse(i)
99 | root_node = xml_info.getroot()
100 | bbox_list = []
101 | for obj_node in root_node.findall('object'):
102 | name_node = obj_node.find('name')
103 | # print(name_node)
104 | name = name_node.text
105 | if name == '&*@HUST_special' or name == '&*@HUST_shelter':
106 | continue
107 | difficult = int(obj_node.find('difficult').text)
108 | if difficult == 1:
109 | continue
110 |
111 | bndbox_node = obj_node.find('bndbox')
112 | xmin_filter = int(bndbox_node.find('xmin').text)
113 | ymin_filter = int(bndbox_node.find('ymin').text)
114 | xmax_filter = int(bndbox_node.find('xmax').text)
115 | ymax_filter = int(bndbox_node.find('ymax').text)
116 |
117 | bbox_list.append(
118 | [xmin_filter, ymin_filter, xmax_filter, ymax_filter])
119 |
120 | self.eval_data_dict[img_name] = bbox_list
121 | elif self.data_type == '2':
122 | pass
123 | else:
124 | raise ValueError(' data type error !!! ')
125 |
126 | #list format: xmin ,ymin, xmax, ymax
127 | def bbox_iou_eval(self, list1, list2):
128 | xx1 = np.maximum(list1[0], list2[0])
129 | yy1 = np.maximum(list1[1], list2[1])
130 | xx2 = np.minimum(list1[2], list2[2])
131 | yy2 = np.minimum(list1[3], list2[3])
132 |
133 | w = np.maximum(0.0, xx2 - xx1 + 1)
134 | h = np.maximum(0.0, yy2 - yy1 + 1)
135 | inter = w * h
136 | area1 = (list1[2] - list1[0] + 1) * (list1[3] - list1[1] + 1)
137 | area2 = (list2[2] - list2[0] + 1) * (list2[3] - list2[1] + 1)
138 | iou = float(inter) / (area1 + area2 - inter)
139 | return iou
140 |
141 | def read_pres(self):
142 | for i in os.listdir(self.pre_data_dir):
143 | if i.split('.')[-1] == 'txt':
144 | img_name = os.path.splitext(os.path.basename(i))[0]
145 | # img_name = os.path.basename(i)
146 | self.pre_img_name_list.append(img_name)
147 | nms_outputs = open(os.path.join(self.pre_data_dir, i),
148 | 'r').readlines()
149 | dt_lines = [l.strip() for l in nms_outputs]
150 | dt_lines = [
151 | self.list_from_str(dt, dtype='int32') for dt in dt_lines
152 |                 ]  # each entry: [xmin, ymin, xmax, ymax]
153 | boxes = dt_lines
154 | bbox_without_same = []
155 | for box in boxes:
156 | if box[1] != box[3]:
157 | bbox_without_same.append(box)
158 | boxes = bbox_without_same
159 | final_bboxes = []
160 | if self.allow_post_processing is True:
161 |                 # box format: xmin ymin xmax ymax
162 | del_index = []
163 | for i, box in enumerate(boxes):
164 | if i in del_index:
165 | continue
166 | if len(boxes[i + 1:]) == 0:
167 | if box not in final_bboxes and i not in del_index:
168 | final_bboxes.append(box)
169 | break
170 | for j, box_2rd in enumerate(boxes[i + 1:]):
171 | if j in del_index:
172 | continue
173 | ymin_second = int(box_2rd[1])
174 | ymax_second = int(box_2rd[3])
175 | xmin_second = int(box_2rd[0])
176 | xmax_second = int(box_2rd[2])
177 | ymin_first = int(box[1])
178 | ymax_first = int(box[3])
179 | xmin_first = int(box[0])
180 | xmax_first = int(box[2])
181 | if abs(ymin_second - ymin_first) <= 5 and abs(
182 | ymax_first -
183 | ymax_second) <= 5 and mat_inter(
184 | box, box_2rd):
185 |
186 | xmin_final = min(xmin_first, xmin_second)
187 | ymin_final = min(ymin_first, ymin_second)
188 | xmax_final = max(xmax_first, xmax_second)
189 | ymax_final = max(ymax_first, ymax_second)
190 | temp_box = [
191 | xmin_final, ymin_final, xmax_final,
192 | ymax_final
193 | ]
194 | del_index.append(i)
195 | del_index.append(j + i + 1)
196 | box = temp_box
197 | final_bboxes.append(box)
198 | dt_lines = final_bboxes
199 |
200 | self.pre_data_dict[img_name] = dt_lines
201 |
202 | def contrast_pre_gt(self):
203 | for img_name in self.pre_img_name_list:
204 | pre_gts = self.pre_data_dict[img_name]
205 | eval_gts = self.eval_data_dict[img_name]
206 | error_pre = []
207 | error_eval = []
208 |
209 | for i, eval_gt in enumerate(eval_gts):
210 | flag_strick = 0
211 | err_flag = False
212 | for j, pre_gt in enumerate(pre_gts):
213 | iou = self.bbox_iou_eval(eval_gt, pre_gt)
214 | # print(iou)
215 | if iou >= self.iou_thresh:
216 | flag_strick += 1
217 | err_flag = False
218 | break
219 | else:
220 | err_flag = True
221 | if flag_strick >= 1:
222 | self.hit_precision += 1
223 | self.hit_recall += 1
224 | self.gt_data_num += 1
225 | if err_flag is True:
226 | error_eval.append(eval_gt)
227 |
228 | self.pre_data_num += len(pre_gts)
229 |
230 | for i, pre_gt in enumerate(pre_gts):
231 | err_flag = False
232 | for j, eval_gt in enumerate(eval_gts):
233 | iou = self.bbox_iou_eval(pre_gt, eval_gt)
234 | if iou >= self.iou_thresh:
235 | err_flag = False
236 | break
237 | else:
238 | err_flag = True
239 | if err_flag is True:
240 | error_pre.append(pre_gt)
241 | if len(error_eval) != 0 or len(error_pre) != 0:
242 | self.save_error_eval(img_name, error_eval, error_pre)
243 |
244 | def save_error_eval(self, img_name, error_eval_list, error_pre_list):
245 | # img_path = os.path.join(self.eval_data_dir, img_name)
246 | img_path = os.path.join(self.eval_data_dir, img_name + '.jpg')
247 | img = cv2.imread(img_path)
248 |
249 | if img is None:
250 | img_path = os.path.join(self.eval_data_dir, img_name + '.PNG')
251 | img = cv2.imread(img_path)
252 | img_h, img_w, _ = img.shape
253 | with open(
254 | os.path.join(self.save_err_path, img_name + '_err_gt.txt'),
255 | 'w') as f_target:
256 | for err_eval in error_eval_list:
257 | # print(err_eval)
258 | width = int(err_eval[2]) - int(err_eval[0])
259 | height = int(err_eval[3]) - int(err_eval[1])
260 | if width != 0:
261 | self.err_gt_dict['ratio'].append(float(height) / width)
262 | else:
263 | self.err_gt_dict['ratio'].append(0.)
264 | self.err_gt_dict['h_scale'].append(float(height) / img_h)
265 | self.err_gt_dict['w_scale'].append(float(width) / img_w)
266 | self.err_gt_dict['height'].append(height)
267 | self.err_gt_dict['width'].append(width)
268 | img = self.draw_polygon(
269 | img,
270 | err_eval[0],
271 | err_eval[1],
272 | err_eval[2],
273 | err_eval[3],
274 | is_gt=True)
275 |
276 | f_target.write(','.join([str(i) for i in err_eval]) + '\n')
277 | with open(
278 | os.path.join(self.save_err_path, img_name + '_err_pre.txt'),
279 | 'w') as f_target:
280 | for err_pre in error_pre_list:
281 | f_target.write(','.join([str(j) for j in err_pre]) + '\n')
282 | img = self.draw_polygon(
283 | img,
284 | err_pre[0],
285 | err_pre[1],
286 | err_pre[2],
287 | err_pre[3],
288 | is_gt=False)
289 | cv2.imwrite(os.path.join(self.save_err_path, img_name + '.png'), img)
290 |
291 | def draw_polygon(self, img, xmin, ymin, xmax, ymax, is_gt=False):
292 | xmin = int(xmin)
293 | ymin = int(ymin)
294 | xmax = int(xmax)
295 | ymax = int(ymax)
296 | if is_gt is True:
297 | color = (255, 0, 0)
298 | else:
299 | color = (0, 255, 0)
300 | cv2.line(img, (xmin, ymin), (xmax, ymin), color, 2)
301 | cv2.line(img, (xmax, ymin), (xmax, ymax), color, 2)
302 | cv2.line(img, (xmax, ymax), (xmin, ymax), color, 2)
303 | cv2.line(img, (xmin, ymax), (xmin, ymin), color, 2)
304 | return img
305 |
306 | def cal_precision_recall(self):
307 | recall = float(self.hit_recall) / self.gt_data_num
308 | precision = float(self.hit_precision) / self.pre_data_num
309 | if recall != 0 and precision != 0:
310 | f_measure = 2 * recall * precision / (recall + precision)
311 | else:
312 | f_measure = 0
313 | return [precision, recall, f_measure]
314 |
315 | def draw_err_gt(self):
316 | err_gt_dict = self.err_gt_dict
317 | for key in err_gt_dict.keys():
318 | values = err_gt_dict[key]
319 | fig = plt.figure()
320 | ax = fig.add_subplot(111)
321 | numBins = 50
322 | (counts, bins, patch) = ax.hist(
323 | values, numBins, color='blue', alpha=0.4, rwidth=0.5)
324 | #print('*****', key, '******')
325 | #print(counts)
326 | #print(bins)
327 | for i in range(len(counts)):
328 | plt.text(
329 | bins[i],
330 | counts[i],
331 | str(int(counts[i])),
332 | fontdict={
333 | 'size': 6,
334 | 'color': 'r'
335 | })
336 | if key in ['h_scale', 'w_scale', 'ratio']:
337 | mid = round((float(bins[i]) + float(bins[i + 1])) / 2, 2)
338 | else:
339 |                     mid = int((bins[i] + bins[i + 1]) / 2)
340 | #if i % 5 == 0:
341 | plt.text(
342 | bins[i],
343 | counts[i] + 20,
344 | str(mid),
345 | fontdict={
346 | 'size': 10,
347 | 'color': 'b'
348 | })
349 | #print(patch)
350 | plt.grid(True)
351 | plt.title(u'{}'.format(key))
352 | plt.savefig('{}/{}.png'.format(self.save_err_path, key))
353 | with open('{}/{}.txt'.format(self.save_err_path, key), 'w') as f:
354 | for value in values:
355 | f.write('{}\n'.format(value))
356 |
357 | def start_eval(self):
358 | print('----start eval----')
359 | print('---get xml path---')
360 | self.get_xml_path()
361 | print('---reading gts----')
362 | self.read_gts()
363 | print('---reading predictions---')
364 | self.read_pres()
365 | print('---contrast pre gt-----')
366 | self.contrast_pre_gt()
367 | pre, recall, f_measure = self.cal_precision_recall()
368 | print('pre:{} recall:{} f_measure:{}'.format(pre, recall, f_measure))
369 | with open(self.save_result_path, 'a+') as f:
370 | f.write('iou:{} pre:{} recall:{} f_measure:{}\n'.format(
371 | self.iou_thresh, pre, recall, f_measure))
372 |
373 |         if self.draw_err_pic_flag:
374 | self.draw_err_gt()
375 | print('-----end eval------')
376 |
377 |
378 | if __name__ == '__main__':
379 |
380 | parser = ArgumentParser(description='icdar15 eval model')
381 | parser.add_argument(
382 | '--eval_data_dir',
383 | '-d',
384 | default=
385 | '/home/zsz/datasets/icdar15/test_gts/',
386 | type=str)
387 | #xml and img in same dir
388 | parser.add_argument('--pre_data_dir', '-p', type=str)
389 | #pre_data_dir is prediction format txt: text score xmin ymin xmax ymax
390 | parser.add_argument('--eval_file_type', '-f', default='1', type=str)
391 | parser.add_argument(
392 | '--save_result_path', '-s', default='result.txt', type=str)
393 | parser.add_argument('--iou_th', '-o', default=0.5, type=float)
394 | #1:read gt from xml format:(xmin ymin xmax ymax)
395 | #2:read gt from txt format:(text score xmin ymin xmax ymax)
396 | args = parser.parse_args()
397 | # print(args, args.eval_data_dir)
398 | eval_data_dir = args.eval_data_dir
399 | pre_data_dir = args.pre_data_dir
400 | eval_file_type = args.eval_file_type
401 | save_result_path = args.save_result_path
402 | iou_th = args.iou_th
403 |
404 | eval_model = EVAL_MODEL(eval_data_dir, pre_data_dir, eval_file_type,
405 | save_result_path, iou_th)
406 | eval_model.start_eval()
407 |
--------------------------------------------------------------------------------
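Usage sketch for eval.py above, assuming ground-truth XML files sit next to their images and predictions are txt files in the `text score xmin ymin xmax ymax` layout described by the argument comments (paths are placeholders):

    python eval.py -d /path/to/test_gts/ -p /path/to/predictions/ -f 1 -s result.txt -o 0.5

Because the result file is opened with 'a+', each run appends one `iou pre recall f_measure` line, so sweeping --iou_th accumulates all results in one file.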
/eval_result.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import time
4 |
5 | import tensorflow as tf
6 | from apscheduler.schedulers.blocking import BlockingScheduler
7 |
8 | scheduler = BlockingScheduler()
9 | model_dir = './model'
10 |
11 | def eval_net():
12 | """Run test.py on the latest checkpoint in model_dir."""
13 | # Alternative: iterate over every checkpoint via
14 | # tf.train.get_checkpoint_state(model_dir).all_model_checkpoint_paths
15 | print(time.asctime())
16 | time.sleep(5)
17 | ckpt_path = tf.train.latest_checkpoint(model_dir)
18 | # latest_checkpoint returns None until a checkpoint exists.
19 | if ckpt_path:
20 | os.system("python test.py -c -0 -m {} -o test_result/{}".format(ckpt_path, ckpt_path.split('/')[-1]))
21 |
22 | eval_net()
23 | # Re-evaluate every 20 minutes.
24 | scheduler.add_job(eval_net, 'interval', seconds=1200)
25 |
26 | scheduler.start()
--------------------------------------------------------------------------------
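eval_result.py above simply re-evaluates the newest checkpoint every 1200 seconds. A minimal standalone sketch of that scheduling pattern, assuming a `./model` directory that training populates with checkpoints:

    import tensorflow as tf
    from apscheduler.schedulers.blocking import BlockingScheduler

    def job():
        # Returns None until the first checkpoint has been written.
        print(tf.train.latest_checkpoint('./model'))

    sched = BlockingScheduler()
    sched.add_job(job, 'interval', seconds=1200)
    sched.start()  # blocks; run this as its own process next to train.py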
/gene_tfrecords.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import tensorflow as tf
3 |
4 | from datasets import xml_to_tfrecords
5 | import os
6 | FLAGS = tf.app.flags.FLAGS
7 |
8 | tf.app.flags.DEFINE_string(
9 | 'output_name', 'annotated_data',
10 | 'Basename used for TFRecords output files.')
11 | tf.app.flags.DEFINE_string(
12 | 'output_dir', 'tfrecords',
13 | 'Output directory where to store TFRecords files.')
14 | tf.app.flags.DEFINE_string(
15 | 'xml_img_txt_path', None,
16 | 'Path to the txt file listing the image,xml pairs to convert.'
17 | )
18 |
19 | tf.app.flags.DEFINE_integer(
20 | 'samples_per_files', 2000,
21 | 'How many images are stored in each TFRecord file.'
22 | )
23 |
24 | def main(_):
25 | if not FLAGS.xml_img_txt_path or not os.path.exists(FLAGS.xml_img_txt_path):
26 | raise ValueError('You must supply the dataset directory with --xml_img_txt_path')
27 | print('Dataset directory:', FLAGS.xml_img_txt_path)
28 | print('Output directory:', FLAGS.output_dir)
29 |
30 | xml_to_tfrecords.run(FLAGS.xml_img_txt_path, FLAGS.output_dir, FLAGS.output_name, FLAGS.samples_per_files)
31 |
32 | if __name__ == '__main__':
33 | tf.app.run()
34 |
--------------------------------------------------------------------------------
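Once the records are written, a quick sanity check is to count the serialized examples with the TF 1.x record iterator; the glob pattern below is an assumption, adjust it to however xml_to_tfrecords names its shards:

    import glob
    import tensorflow as tf

    total = 0
    for path in glob.glob('tfrecords/annotated_data*'):
        # Each entry yielded by tf_record_iterator is one serialized tf.train.Example.
        total += sum(1 for _ in tf.python_io.tf_record_iterator(path))
    print('examples:', total)

The total should equal the number of image,xml pairs listed in the input txt.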
/logs/train_xml.txt:
--------------------------------------------------------------------------------
1 | /home/zsz/datasets/weblmtImage 26081096.0_242.0_1377.0_456.0.png,save_xml/weblmtImage 26081096.0_242.0_1377.0_456.0.xml
2 |
--------------------------------------------------------------------------------
/nets/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/nets/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/nets/__init__.pyc
--------------------------------------------------------------------------------
/nets/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/nets/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/nets/__pycache__/custom_layers.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/nets/__pycache__/custom_layers.cpython-35.pyc
--------------------------------------------------------------------------------
/nets/__pycache__/np_methods.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/nets/__pycache__/np_methods.cpython-35.pyc
--------------------------------------------------------------------------------
/nets/__pycache__/textbox_common.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/nets/__pycache__/textbox_common.cpython-35.pyc
--------------------------------------------------------------------------------
/nets/__pycache__/txtbox_384.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/nets/__pycache__/txtbox_384.cpython-35.pyc
--------------------------------------------------------------------------------
/nets/__pycache__/txtbox_768.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/nets/__pycache__/txtbox_768.cpython-35.pyc
--------------------------------------------------------------------------------
/nets/custom_layers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Paul Balanca. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Implement some custom layers, not provided by TensorFlow.
16 |
17 | Trying to follow as much as possible the style/standards used in
18 | tf.contrib.layers
19 | """
20 | import tensorflow as tf
21 |
22 | from tensorflow.contrib.framework.python.ops import add_arg_scope
23 | from tensorflow.contrib.layers.python.layers import initializers
24 | from tensorflow.contrib.framework.python.ops import variables
25 | from tensorflow.contrib.layers.python.layers import utils
26 | from tensorflow.python.ops import nn
27 | from tensorflow.python.ops import init_ops
28 | from tensorflow.python.ops import variable_scope
29 |
30 |
31 | def abs_smooth(x):
32 | """Smoothed absolute function. Useful to compute an L1 smooth error.
33 |
34 | Defined as:
35 | x^2 / 2          if abs(x) < 1
36 | abs(x) - 0.5     otherwise
37 | We use here a differentiable definition built from min(abs(x), 1) and abs(x).
38 | Clearly not optimal, but good enough for our purpose!
39 | """
40 | absx = tf.abs(x)
41 | minx = tf.minimum(absx, 1)
42 | r = 0.5 * ((absx - 1) * minx + absx)
43 | return r
44 |
45 |
46 | @add_arg_scope
47 | def l2_normalization(
48 | inputs,
49 | scaling=False,
50 | scale_initializer=init_ops.ones_initializer(),
51 | reuse=None,
52 | variables_collections=None,
53 | outputs_collections=None,
54 | data_format='NHWC',
55 | trainable=True,
56 | scope=None):
57 | """Implement L2 normalization on every feature (i.e. spatial normalization).
58 |
59 | Should be extended in some near future to other dimensions, providing a more
60 | flexible normalization framework.
61 |
62 | Args:
63 | inputs: a 4-D tensor with dimensions [batch_size, height, width, channels].
64 | scaling: whether or not to add a post scaling operation along the dimensions
65 | which have been normalized.
66 | scale_initializer: An initializer for the weights.
67 | reuse: whether or not the layer and its variables should be reused. To be
68 | able to reuse the layer scope must be given.
69 | variables_collections: optional list of collections for all the variables or
70 | a dictionary containing a different list of collection per variable.
71 | outputs_collections: collection to add the outputs.
72 | data_format: NHWC or NCHW data format.
73 | trainable: If `True` also add variables to the graph collection
74 | `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
75 | scope: Optional scope for `variable_scope`.
76 | Returns:
77 | A `Tensor` representing the output of the operation.
78 | """
79 |
80 | with variable_scope.variable_scope(
81 | scope, 'L2Normalization', [inputs], reuse=reuse) as sc:
82 | inputs_shape = inputs.get_shape()
83 | inputs_rank = inputs_shape.ndims
84 | dtype = inputs.dtype.base_dtype
85 | if data_format == 'NHWC':
86 | # norm_dim = tf.range(1, inputs_rank-1)
87 | norm_dim = tf.range(inputs_rank-1, inputs_rank)
88 | params_shape = inputs_shape[-1:]
89 | elif data_format == 'NCHW':
90 | # norm_dim = tf.range(2, inputs_rank)
91 | norm_dim = tf.range(1, 2)
92 | params_shape = (inputs_shape[1],)  # one scale per channel for NCHW
93 |
94 | # Normalize along the channel dimension.
95 | outputs = nn.l2_normalize(inputs, norm_dim, epsilon=1e-12)
96 | # Additional scaling.
97 | if scaling:
98 | scale_collections = utils.get_variable_collections(
99 | variables_collections, 'scale')
100 | scale = variables.model_variable('gamma',
101 | shape=params_shape,
102 | dtype=dtype,
103 | initializer=scale_initializer,
104 | collections=scale_collections,
105 | trainable=trainable)
106 | if data_format == 'NHWC':
107 | outputs = tf.multiply(outputs, scale)
108 | elif data_format == 'NCHW':
109 | scale = tf.expand_dims(scale, axis=-1)
110 | scale = tf.expand_dims(scale, axis=-1)
111 | outputs = tf.multiply(outputs, scale)
112 | # outputs = tf.transpose(outputs, perm=(0, 2, 3, 1))
113 |
114 | return utils.collect_named_outputs(outputs_collections,
115 | sc.original_name_scope, outputs)
116 |
117 |
118 | @add_arg_scope
119 | def pad2d(inputs,
120 | pad=(0, 0),
121 | mode='CONSTANT',
122 | data_format='NHWC',
123 | trainable=True,
124 | scope=None):
125 | """2D Padding layer, adding a symmetric padding to H and W dimensions.
126 |
127 | Aims to mimic padding in Caffe and MXNet, helping the port of models to
128 | TensorFlow. Tries to follow the naming convention of `tf.contrib.layers`.
129 |
130 | Args:
131 | inputs: 4D input Tensor;
132 | pad: 2-Tuple with padding values for H and W dimensions;
133 | mode: Padding mode. C.f. `tf.pad`
134 | data_format: NHWC or NCHW data format.
135 | """
136 | with tf.name_scope(scope, 'pad2d', [inputs]):
137 | # Padding shape.
138 | if data_format == 'NHWC':
139 | paddings = [[0, 0], [pad[0], pad[0]], [pad[1], pad[1]], [0, 0]]
140 | elif data_format == 'NCHW':
141 | paddings = [[0, 0], [0, 0], [pad[0], pad[0]], [pad[1], pad[1]]]
142 | net = tf.pad(inputs, paddings, mode=mode)
143 | return net
144 |
145 |
146 | @add_arg_scope
147 | def channel_to_last(inputs,
148 | data_format='NHWC',
149 | scope=None):
150 | """Move the channel axis to the last dimension. Allows to
151 | provide a single output format whatever the input data format.
152 |
153 | Args:
154 | inputs: Input Tensor;
155 | data_format: NHWC or NCHW.
156 | Return:
157 | Input in NHWC format.
158 | """
159 | with tf.name_scope(scope, 'channel_to_last', [inputs]):
160 | if data_format == 'NHWC':
161 | net = inputs
162 | elif data_format == 'NCHW':
163 | net = tf.transpose(inputs, perm=(0, 2, 3, 1))
164 | return net
165 |
--------------------------------------------------------------------------------
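A minimal sketch of how these layers are typically wired into an SSD-style backbone, assuming an NHWC feature map; the scaled l2_normalization mirrors the conv4_3 normalization trick from the original SSD:

    import tensorflow as tf
    from nets import custom_layers

    feat = tf.placeholder(tf.float32, [None, 48, 48, 512])
    # L2-normalize each spatial position's channel vector, with a learnable per-channel scale.
    norm = custom_layers.l2_normalization(feat, scaling=True)
    # Symmetric 1-pixel padding on H and W, as Caffe's pad=1 would produce.
    padded = custom_layers.pad2d(norm, pad=(1, 1))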
/nets/custom_layers.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/nets/custom_layers.pyc
--------------------------------------------------------------------------------
/nets/np_methods.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Paul Balanca. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Additional Numpy methods. Big mess of many things!
16 | """
17 | import numpy as np
18 |
19 |
20 | # =========================================================================== #
21 | # Numpy implementations of SSD boxes functions.
22 | # =========================================================================== #
23 | def ssd_bboxes_decode(feat_localizations,
24 | anchor_bboxes,
25 | prior_scaling=[0.1, 0.1, 0.2, 0.2]):
26 | """Compute the relative bounding boxes from the layer features and
27 | reference anchor bounding boxes.
28 |
29 | Return:
30 | numpy array Nx4: ymin, xmin, ymax, xmax
31 | """
32 | # Reshape for easier broadcasting.
33 | l_shape = feat_localizations.shape
34 |
35 | # feat_localizations = np.reshape(feat_localizations,
36 | # (-1, l_shape[-1]))
37 | feat_localizations = np.reshape(feat_localizations,
38 | (-1, l_shape[-2], l_shape[-1]))
39 | anchor_xmin, anchor_ymin, anchor_xmax, anchor_ymax = anchor_bboxes
40 |
41 | xref = (anchor_xmin + anchor_xmax) /2.
42 | yref = (anchor_ymin + anchor_ymax) /2.
43 |
44 | href = anchor_ymax - anchor_ymin
45 | wref = anchor_xmax - anchor_xmin
46 |
47 | decode_bbox_center_x = prior_scaling[0] * feat_localizations[:, :, 0] * wref + xref
48 | decode_bbox_center_y = prior_scaling[1] * feat_localizations[:, :, 1] * href + yref
49 | decode_bbox_width = np.exp(feat_localizations[:, :, 2] * prior_scaling[2]) * wref
50 | decode_bbox_height = np.exp(feat_localizations[: , :, 3] * prior_scaling[3]) * href
51 |
52 | xmin = decode_bbox_center_x - decode_bbox_width / 2.
53 | ymin = decode_bbox_center_y - decode_bbox_height / 2.
54 | xmax = decode_bbox_center_x + decode_bbox_width / 2.
55 | ymax = decode_bbox_center_y + decode_bbox_height / 2.
56 |
57 | x1 = prior_scaling[0] * feat_localizations[:, :, 4] * wref + anchor_xmin
58 | y1 = prior_scaling[1] * feat_localizations[:, :, 5] * href + anchor_ymin
59 | x2 = prior_scaling[0] * feat_localizations[:, :, 6] * wref + anchor_xmax
60 | y2 = prior_scaling[1] * feat_localizations[:, :, 7] * href + anchor_ymin
61 |
62 | x3 = prior_scaling[0] * feat_localizations[:, :, 8] * wref + anchor_xmax
63 | y3 = prior_scaling[1] * feat_localizations[:, :, 9] * href + anchor_ymax
64 | x4 = prior_scaling[0] * feat_localizations[:, :, 10] * wref + anchor_xmin
65 | y4 = prior_scaling[1] * feat_localizations[:, :, 11] * href + anchor_ymax
66 |
67 | # bboxes layout: xmin, ymin, xmax, ymax, then x1..x4 and y1..y4.
68 | bboxes = np.zeros_like(feat_localizations)
69 |
70 | bboxes[:,:,0] = xmin
71 | bboxes[:,:,1] = ymin
72 | bboxes[:,:,2] = xmax
73 | bboxes[:,:,3] = ymax
74 | bboxes[:, :, 4] = x1
75 | bboxes[:, :, 5] = x2
76 | bboxes[:, :, 6] = x3
77 | bboxes[:, :, 7] = x4
78 | bboxes[:, :, 8] = y1
79 | bboxes[:, :, 9] = y2
80 | bboxes[:, :, 10] = y3
81 | bboxes[:, :, 11] = y4
82 |
83 | # Back to original shape.
84 | bboxes = np.reshape(bboxes, l_shape)
85 | return bboxes
86 |
87 |
88 | def ssd_bboxes_select(predictions_net,
89 | localizations_net,
90 | anchors_net,
91 | select_threshold=0.5,
92 | img_shape=(384, 384),
93 | num_classes=2,
94 | decode=True):
95 | """Extract classes, scores and bounding boxes from network output layers.
96 |
97 | Return:
98 | classes, scores, bboxes: Numpy arrays...
99 | """
100 | l_classes = []
101 | l_scores = []
102 | l_bboxes = []
103 | for i in range(len(predictions_net)):
104 | classes, scores, bboxes = ssd_bboxes_select_layer(
105 | predictions_net[i], localizations_net[i], anchors_net[i],
106 | select_threshold, img_shape, num_classes, decode)
107 | l_classes.append(classes)
108 | l_scores.append(scores)
109 | l_bboxes.append(bboxes)
110 | # Debug information.
111 | # l_layers.append(i)
112 | # l_idxes.append((i, idxes))
113 |
114 | classes = np.concatenate(l_classes, 0)
115 | scores = np.concatenate(l_scores, 0)
116 | bboxes = np.concatenate(l_bboxes, 0)
117 | return classes, scores, bboxes
118 |
119 |
120 | def ssd_bboxes_select_layer(predictions_layer,
121 | localizations_layer,
122 | anchors_layer,
123 | select_threshold=0.2,
124 | img_shape=(384, 384),
125 | num_classes=2,
126 | decode=True):
127 | """Extract classes, scores and bounding boxes from features in one layer.
128 |
129 | Return:
130 | classes, scores, bboxes: Numpy arrays...
131 | """
132 | # First decode localizations features if necessary.
133 | if decode:
134 | localizations_layer = ssd_bboxes_decode(localizations_layer, anchors_layer)
135 |
136 | # Reshape features to: Batches x N x N_labels | 4.
137 | p_shape = predictions_layer.shape
138 | batch_size = p_shape[0] if len(p_shape) == 5 else 1
139 | predictions_layer = np.reshape(predictions_layer,
140 | (batch_size, -1, p_shape[-1]))
141 | l_shape = localizations_layer.shape
142 | localizations_layer = np.reshape(localizations_layer,
143 | (batch_size, -1, l_shape[-1]))
144 |
145 | # Boxes selection: use threshold or score > no-label criteria.
146 | if select_threshold is None or select_threshold == 0:
147 | # Class prediction and scores: class 0 is the background/no-label class.
148 | classes = np.argmax(predictions_layer, axis=2)
149 | scores = np.amax(predictions_layer, axis=2)
150 | mask = (classes > 0)
151 | classes = classes[mask]
152 | scores = scores[mask]
153 | bboxes = localizations_layer[mask]
154 | else:  # two-class predictions
155 | sub_predictions = predictions_layer[:, :, 1:]
156 | idxes = np.where(sub_predictions > select_threshold)
157 | classes = idxes[-1]+1
158 | scores = sub_predictions[idxes]
159 | bboxes = localizations_layer[idxes[:-1]]
160 |
161 | return classes, scores, bboxes
162 |
163 |
164 | # =========================================================================== #
165 | # Common functions for bboxes handling and selection.
166 | # =========================================================================== #
167 | def bboxes_sort(classes, scores, bboxes, top_k=400):
168 | """Sort bounding boxes by decreasing order and keep only the top_k
169 | """
170 | # if priority_inside:
171 | # inside = (bboxes[:, 0] > margin) & (bboxes[:, 1] > margin) & \
172 | # (bboxes[:, 2] < 1-margin) & (bboxes[:, 3] < 1-margin)
173 | # idxes = np.argsort(-scores)
174 | # inside = inside[idxes]
175 | # idxes = np.concatenate([idxes[inside], idxes[~inside]])
176 | idxes = np.argsort(-scores)
177 | classes = classes[idxes][:top_k]
178 | scores = scores[idxes][:top_k]
179 | bboxes = bboxes[idxes][:top_k]
180 | return classes, scores, bboxes
181 |
182 |
183 | def bboxes_clip(bbox_ref, bboxes):
184 | """Clip bounding boxes with respect to reference bbox.
185 | """
186 | bboxes = np.copy(bboxes)
187 | bboxes = np.transpose(bboxes)
188 | bbox_ref = np.transpose(bbox_ref)
189 | bboxes[0] = np.maximum(bboxes[0], bbox_ref[0])
190 | bboxes[1] = np.maximum(bboxes[1], bbox_ref[1])
191 | bboxes[2] = np.minimum(bboxes[2], bbox_ref[2])
192 | bboxes[3] = np.minimum(bboxes[3], bbox_ref[3])
193 |
194 | bboxes[4] = np.maximum(bboxes[4], bbox_ref[4])
195 | bboxes[5] = np.maximum(bboxes[5], bbox_ref[5])
196 | bboxes[6] = np.maximum(bboxes[6], bbox_ref[6])
197 | bboxes[7] = np.maximum(bboxes[7], bbox_ref[7])
198 | bboxes[8] = np.maximum(bboxes[8], bbox_ref[8])
199 | bboxes[9] = np.maximum(bboxes[9], bbox_ref[9])
200 | bboxes[10] = np.maximum(bboxes[10], bbox_ref[10])
201 | bboxes[11] = np.maximum(bboxes[11], bbox_ref[11])
202 |
203 | bboxes = np.transpose(bboxes)
204 | return bboxes
205 |
206 |
207 | def bboxes_resize(bbox_ref, bboxes):
208 | """Resize bounding boxes based on a reference bounding box,
209 | assuming that the latter is [0, 0, 1, 1] after transform.
210 | """
211 | bboxes = np.copy(bboxes)
212 | # Translate.
213 | bboxes[:, 0] -= bbox_ref[0]
214 | bboxes[:, 1] -= bbox_ref[1]
215 | bboxes[:, 2] -= bbox_ref[0]
216 | bboxes[:, 3] -= bbox_ref[1]
217 | # Resize.
218 | resize = [bbox_ref[2] - bbox_ref[0], bbox_ref[3] - bbox_ref[1]]
219 | bboxes[:, 0] /= resize[0]
220 | bboxes[:, 1] /= resize[1]
221 | bboxes[:, 2] /= resize[0]
222 | bboxes[:, 3] /= resize[1]
223 | return bboxes
224 |
225 |
226 | def bboxes_jaccard(bboxes1, bboxes2):
227 | """Computing jaccard index between bboxes1 and bboxes2.
228 | Note: bboxes1 and bboxes2 can be multi-dimensional, but should broacastable.
229 | """
230 | bboxes1 = np.transpose(bboxes1)
231 | bboxes2 = np.transpose(bboxes2)
232 | # Intersection bbox and volume.
233 | int_ymin = np.maximum(bboxes1[0], bboxes2[0])
234 | int_xmin = np.maximum(bboxes1[1], bboxes2[1])
235 | int_ymax = np.minimum(bboxes1[2], bboxes2[2])
236 | int_xmax = np.minimum(bboxes1[3], bboxes2[3])
237 |
238 | int_h = np.maximum(int_ymax - int_ymin, 0.)
239 | int_w = np.maximum(int_xmax - int_xmin, 0.)
240 | int_vol = int_h * int_w
241 | # Union volume.
242 | vol1 = (bboxes1[2] - bboxes1[0]) * (bboxes1[3] - bboxes1[1])
243 | vol2 = (bboxes2[2] - bboxes2[0]) * (bboxes2[3] - bboxes2[1])
244 | jaccard = int_vol / (vol1 + vol2 - int_vol)
245 | return jaccard
246 |
247 |
248 | def bboxes_intersection(bboxes_ref, bboxes2):
249 | """Computing jaccard index between bboxes1 and bboxes2.
250 | Note: bboxes1 and bboxes2 can be multi-dimensional, but should broacastable.
251 | """
252 | bboxes_ref = np.transpose(bboxes_ref)
253 | bboxes2 = np.transpose(bboxes2)
254 | # Intersection bbox and volume.
255 | int_ymin = np.maximum(bboxes_ref[0], bboxes2[0])
256 | int_xmin = np.maximum(bboxes_ref[1], bboxes2[1])
257 | int_ymax = np.minimum(bboxes_ref[2], bboxes2[2])
258 | int_xmax = np.minimum(bboxes_ref[3], bboxes2[3])
259 |
260 | int_h = np.maximum(int_ymax - int_ymin, 0.)
261 | int_w = np.maximum(int_xmax - int_xmin, 0.)
262 | int_vol = int_h * int_w
263 | # Reference volume.
264 | vol = (bboxes_ref[2] - bboxes_ref[0]) * (bboxes_ref[3] - bboxes_ref[1])
265 | score = int_vol / vol
266 | return score
267 |
268 |
269 | def bboxes_nms(classes, scores, bboxes, nms_threshold=0.45):
270 | """Apply non-maximum selection to bounding boxes.
271 | """
272 | keep_bboxes = np.ones(scores.shape, dtype=np.bool)
273 | for i in range(scores.size-1):
274 | if keep_bboxes[i]:
275 | # Compute overlap with the bboxes which follow.
276 | overlap = bboxes_jaccard(bboxes[i], bboxes[(i+1):])
277 | # Overlap threshold for keeping + checking part of the same class
278 | keep_overlap = np.logical_or(overlap < nms_threshold, classes[(i+1):] != classes[i])
279 | keep_bboxes[(i+1):] = np.logical_and(keep_bboxes[(i+1):], keep_overlap)
280 |
281 | idxes = np.where(keep_bboxes)
282 | return classes[idxes], scores[idxes], bboxes[idxes]
283 |
284 |
285 | def bboxes_nms_fast(classes, scores, bboxes, threshold=0.45):
286 | """Apply non-maximum selection to bounding boxes.
287 | """
288 | pass
289 |
290 |
--------------------------------------------------------------------------------
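These helpers compose into the usual SSD post-processing chain. A minimal sketch, assuming `predictions`, `localisations` and `anchors` are the per-layer lists produced by running the network:

    from nets import np_methods

    # 1. Threshold the per-layer class scores and decode the regressions against the anchors.
    classes, scores, bboxes = np_methods.ssd_bboxes_select(
        predictions, localisations, anchors, select_threshold=0.2)
    # 2. Keep only the 400 highest-scoring boxes.
    classes, scores, bboxes = np_methods.bboxes_sort(classes, scores, bboxes, top_k=400)
    # 3. Greedy non-maximum suppression at IoU 0.45.
    classes, scores, bboxes = np_methods.bboxes_nms(classes, scores, bboxes, nms_threshold=0.45)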
/nets/np_methods.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/nets/np_methods.pyc
--------------------------------------------------------------------------------
/nets/textbox_common.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import tensorflow as tf
3 | import numpy as np
4 | import math
5 |
6 |
7 |
8 | # =========================================================================== #
9 | # TensorFlow implementation of Text Boxes encoding / decoding.
10 | # =========================================================================== #
11 |
12 | def tf_text_bboxes_encode_layer(glabels, bboxes,gxs , gys,
13 | anchors_layer,
14 | matching_threshold=0.1,
15 | prior_scaling=[0.1, 0.1, 0.2, 0.2],
16 | dtype=tf.float32):
17 |
18 | """
19 | Encode groundtruth labels and bounding boxes using Textbox anchors from
20 | one layer.
21 |
22 | Arguments:
23 | bboxes: Nx4 Tensor(float) with bboxes relative coordinates;
24 | gxs,gys: Nx4 Tensor
25 | anchors_layer: Numpy array with layer anchors;
26 | matching_threshold: Threshold for positive match with groundtruth bboxes;
27 | prior_scaling: Scaling of encoded coordinates.
28 |
29 | Return:
30 | (target_localizations, target_scores): Target Tensors.
31 | # This is a binary problem, so target_scores and target_labels carry the same information.
32 | """
33 | # Anchors coordinates and volume.
34 |
35 | # yref, xref, href, wref = anchors_layer
36 | xmin, ymin, xmax, ymax = anchors_layer
37 | xref = (xmin + xmax) /2
38 | yref = (ymin + ymax) /2
39 | href = ymax - ymin
40 | wref = xmax - xmin
41 | # The Caffe implementation computes the overlap between each predicted bbox and the gt.
42 | print(yref.shape)
43 | print(href.shape)
44 | print(bboxes.shape)
45 | # glabels = tf.Print(glabels, [tf.shape(glabels)], message=' glabels shape is :')
46 |
47 | # ymin = yref - href / 2.
48 | # xmin = xref - wref / 2.
49 | # ymax = yref + href / 2.
50 | # xmax = xref + wref / 2.
51 | vol_anchors = (xmax - xmin) * (ymax - ymin)
52 |
53 | # bboxes = tf.Print(bboxes, [tf.shape(bboxes), bboxes], message=' bboxes in encode shape:', summarize=20)
54 | # glabels = tf.Print(glabels, [tf.shape(glabels)], message=' glabels shape:')
55 |
56 | # xs = np.asarray(gxs, dtype=np.float32)
57 | # ys = np.asarray(gys, dtype=np.float32)
58 | # num_bboxes = xs.shape[0]
59 |
60 |
61 | # Initialize tensors...
62 | shape = (yref.shape[0])
63 | # all after the flatten
64 | feat_labels = tf.zeros(shape, dtype=tf.int64)
65 | feat_scores = tf.zeros(shape, dtype=dtype)
66 |
67 | feat_ymin = tf.zeros(shape, dtype=dtype)
68 | feat_xmin = tf.zeros(shape, dtype=dtype)
69 | feat_ymax = tf.ones(shape, dtype=dtype)
70 | feat_xmax = tf.ones(shape, dtype=dtype)
71 |
72 | feat_x1 = tf.zeros(shape, dtype=dtype)
73 | feat_x2 = tf.zeros(shape, dtype=dtype)
74 | feat_x3 = tf.zeros(shape, dtype=dtype)
75 | feat_x4 = tf.zeros(shape, dtype=dtype)
76 | feat_y1 = tf.zeros(shape, dtype=dtype)
77 | feat_y2 = tf.zeros(shape, dtype=dtype)
78 | feat_y3 = tf.zeros(shape, dtype=dtype)
79 | feat_y4 = tf.zeros(shape, dtype=dtype)
80 |
81 |
82 | # feat_x1 =tf.zeros()
83 |
84 | def jaccard_with_anchors(bbox):
85 | """
86 | Compute jaccard score between a box and the anchors.
87 | """
88 | int_ymin = tf.maximum(ymin, bbox[0])
89 | int_xmin = tf.maximum(xmin, bbox[1])
90 | int_ymax = tf.minimum(ymax, bbox[2])
91 | int_xmax = tf.minimum(xmax, bbox[3])
92 | h = tf.maximum(int_ymax - int_ymin, 0.)
93 | w = tf.maximum(int_xmax - int_xmin, 0.)
94 | # Volumes.
95 | inter_vol = h * w
96 | union_vol = vol_anchors - inter_vol \
97 | + (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
98 | jaccard = tf.div(inter_vol, union_vol)
99 | return jaccard
100 |
101 |
102 | def intersection_with_anchors(bbox):
103 | '''
104 | Compute the intersection score between a box and the anchors.
105 | '''
106 | int_ymin = tf.maximum(ymin, bbox[0])
107 | int_xmin = tf.maximum(xmin, bbox[1])
108 | int_ymax = tf.minimum(ymax, bbox[2])
109 | int_xmax = tf.minimum(xmax, bbox[3])
110 | h = tf.maximum(int_ymax - int_ymin, 0.)
111 | w = tf.maximum(int_xmax - int_xmin, 0.)
112 | inter_vol = h * w
113 | scores = tf.div(inter_vol, vol_anchors)
114 | return scores
115 |
116 | def condition(i,feat_labels, feat_scores,
117 | feat_ymin, feat_xmin, feat_ymax, feat_xmax,
118 | feat_x1, feat_x2, feat_x3, feat_x4,
119 | feat_y1, feat_y2, feat_y3, feat_y4):
120 | """Condition: check label index.
121 | """
122 |
123 | r = tf.less(i, tf.shape(glabels)[0])
124 |
125 | return r
126 |
127 | def body(i,feat_labels, feat_scores,feat_ymin, feat_xmin, feat_ymax, feat_xmax, feat_x1, feat_x2, feat_x3, feat_x4, feat_y1, feat_y2, feat_y3, feat_y4):
128 | """Body: update feature labels, scores and bboxes.
129 | Follow the original SSD paper for that purpose:
130 | - assign values when jaccard > matching_threshold;
131 | - only update an anchor if the new jaccard beats its current best score.
132 | """
133 | # Jaccard score.
134 | label = glabels[i]
135 | bbox = bboxes[i]
136 | gx = gxs[i]
137 | gy = gys[i]
138 | # i = tf.Print(i, [i , tf.shape(glabels), tf.shape(bboxes), tf.shape(gxs), tf.shape(gys)], message='i is :')
139 | jaccard = jaccard_with_anchors(bbox)
140 | # jaccard = tf.Print(jaccard, [tf.shape(jaccard), tf.nn.top_k(jaccard, 100, sorted=True)[0]], message=' jaccard :', summarize=100)
141 | # feat_scores = tf.Print(feat_scores, [tf.shape(feat_scores),tf.count_nonzero(feat_scores), tf.nn.top_k(feat_scores, 100, sorted=True)[0]], message= ' feat_scores: ', summarize= 100)
142 | # Mask: check threshold + scores + no annotations + num_classes.
143 | mask = tf.greater(jaccard, feat_scores)
144 |
145 | # i = tf.Print(i, [tf.shape(i), i], message= ' i is: ')
146 | # tf.Print(mask, [mask])
147 | mask = tf.logical_and(mask, tf.greater(jaccard, matching_threshold))
148 | # mask = tf.logical_and(mask, feat_scores > -0.5)
149 | # mask = tf.logical_and(mask, label < 2)
150 | # mask = tf.Print(mask, [tf.shape(mask), mask[0]], message=' mask is :')
151 | imask = tf.cast(mask, tf.int64)
152 | fmask = tf.cast(mask, dtype)
153 | # Update values using mask.
154 | feat_labels = imask * label + (1 - imask) * feat_labels
155 | feat_scores = tf.where(mask, jaccard, feat_scores)
156 | # bbox ymin xmin ymax xmax gxs gys
157 | # update all box
158 | # bbox = tf.Print(bbox, [tf.shape(bbox), bbox], message= ' bbox : ', summarize=20)
159 | # gx = tf.Print(gx, [gx], message=' gx: ', summarize=20)
160 | # gy = tf.Print(gy, [gy], message= ' gy: ', summarize=20)
161 | # fmask = tf.Print(fmask, [tf.shape(fmask), tf.count_nonzero(fmask), tf.nn.top_k(fmask, 100, sorted=True)[0]], message=' fmask :', summarize=100)
162 |
163 | feat_ymin = fmask * bbox[0] + (1 - fmask) * feat_ymin
164 | feat_xmin = fmask * bbox[1] + (1 - fmask) * feat_xmin
165 | feat_ymax = fmask * bbox[2] + (1 - fmask) * feat_ymax
166 | feat_xmax = fmask * bbox[3] + (1 - fmask) * feat_xmax
167 |
168 | feat_x1 = fmask * gx[0] + (1 - fmask) * feat_x1
169 | feat_x2 = fmask * gx[1] + (1 - fmask) * feat_x2
170 | feat_x3 = fmask * gx[2] + (1 - fmask) * feat_x3
171 | feat_x4 = fmask * gx[3] + (1 - fmask) * feat_x4
172 |
173 | feat_y1 = fmask * gy[0] + (1 - fmask) * feat_y1
174 | feat_y2 = fmask * gy[1] + (1 - fmask) * feat_y2
175 | feat_y3 = fmask * gy[2] + (1 - fmask) * feat_y3
176 | feat_y4 = fmask * gy[3] + (1 - fmask) * feat_y4
177 |
178 | # feat_x1 = tf.Print(feat_x1, [tf.count_nonzero(feat_x1), tf.nn.top_k(feat_x1, 100, sorted=True)[0]], message=' feat x1 : ', summarize=100)
179 |
180 | # feat_ymax = tf.Print(feat_ymax, [tf.shape(feat_ymax), tf.count_nonzero(feat_ymax), feat_ymax,tf.nn.top_k(feat_ymax, 100, sorted=True)[0]], message= ' feat_ymax :', summarize=100)
181 | # feat_ymin = tf.Print(feat_ymin, [tf.shape(feat_ymin), tf.count_nonzero(feat_ymin), feat_ymin, tf.nn.top_k(feat_ymax, 100, sorted=True)[0]], message= ' feat_ymin :', summarize=100)
182 | # feat_xmax = tf.Print(feat_xmax, [tf.shape(feat_xmax), tf.count_nonzero(feat_xmax), feat_xmax, tf.nn.top_k(feat_xmax, 100, sorted=True)[0]], message= ' feat_xmax : ', summarize=100)
183 | # feat_xmin = tf.Print(feat_xmin, [tf.shape(feat_xmin), tf.count_nonzero(feat_xmin), feat_xmin,tf.nn.top_k(feat_xmin, 100, sorted=True)[0]], message= ' feat_xmin: ' , summarize=100)
184 |
185 |
186 | # Check no annotation label: ignore these anchors...
187 | # interscts = intersection_with_anchors(bbox)
188 | #mask = tf.logical_and(interscts > ignore_threshold,
189 | # label == no_annotation_label)
190 | # Replace scores by -1.
191 | #feat_scores = tf.where(mask, -tf.cast(mask, dtype), feat_scores)
192 |
193 | return [i+1, feat_labels, feat_scores,
194 | feat_ymin, feat_xmin, feat_ymax, feat_xmax,
195 | feat_x1, feat_x2, feat_x3, feat_x4,
196 | feat_y1, feat_y2, feat_y3, feat_y4]
197 | # Main loop definition.
198 |
199 | i = 0
200 | [i,feat_labels, feat_scores,
201 | feat_ymin, feat_xmin,
202 | feat_ymax, feat_xmax,
203 | feat_x1, feat_x2, feat_x3, feat_x4,
204 | feat_y1, feat_y2, feat_y3, feat_y4] = tf.while_loop(condition, body,
205 | [i, feat_labels, feat_scores,
206 | feat_ymin, feat_xmin,
207 | feat_ymax, feat_xmax, feat_x1, feat_x2, feat_x3, feat_x4, feat_y1, feat_y2, feat_y3, feat_y4])
208 | # Transform to center / size.
209 | '''
210 | The logic here matches the horizontal bounding rectangle of each gt against the anchor/default boxes;
211 | the resulting IoU mask updates the gt assigned to each anchor, and the offsets between each anchor and its gt are then computed.
212 | '''
213 | #
214 | # feat_ymax = tf.Print(feat_ymax, [tf.shape(feat_ymax), tf.count_nonzero(feat_ymax), feat_ymax], message= ' feat_ymax :', summarize=100)
215 | # feat_ymin = tf.Print(feat_ymin, [tf.shape(feat_ymin), tf.count_nonzero(feat_ymin), feat_ymin], message= ' feat_ymin :', summarize=100)
216 | # feat_xmax = tf.Print(feat_xmax, [tf.shape(feat_xmax), tf.count_nonzero(feat_xmax), feat_xmax], message= ' feat_xmax : ', summarize=100)
217 | # feat_xmin = tf.Print(feat_xmin, [tf.shape(feat_xmin), tf.count_nonzero(feat_xmin), feat_xmin], message= ' feat_xmin: ' , summarize=100)
218 |
219 |
220 | feat_cy = (feat_ymax + feat_ymin) / 2.
221 | feat_cx = (feat_xmax + feat_xmin) / 2.
222 | feat_h = feat_ymax - feat_ymin
223 | feat_w = feat_xmax - feat_xmin
224 | # Encode features.
225 |
226 | # feat_ymin = tf.Print(feat_ymin, [tf.shape(feat_ymin), feat_ymin], message= ' feat_ymin : ', summarize=20)
227 | # feat_xmin = tf.Print(feat_xmin, [tf.shape(feat_xmin), feat_xmin], message= ' feat_xmin : ', summarize=20)
228 |
229 | #
230 |
231 |
232 |
233 | # feat_cy = tf.Print(feat_cy, [tf.shape(feat_cy), feat_cy],message=' feat_cy : ', summarize=20)
234 | # feat_cx = tf.Print(feat_cx, [tf.shape(feat_cx), feat_cx],message=' feat_cy : ', summarize=20)
235 | # feat_h = tf.Print(feat_h, [tf.shape(feat_h), feat_h], message=' feat_h : ', summarize=20)
236 | # feat_w = tf.Print(feat_w, [tf.shape(feat_w), feat_w], message=' feat_w : ', summarize=20)
237 | #
238 | # yref = tf.Print(yref, [tf.shape(yref), yref], message=' yref : ',summarize=20)
239 | # xref = tf.Print(xref, [tf.shape(xref), xref], message=' xref : ',summarize=20)
240 | # href = tf.Print(href, [tf.shape(href), href], message=' href : ',
241 | # summarize=20)
242 | # wref = tf.Print(wref, [tf.shape(wref), wref], message=' wref : ', summarize=20)
243 |
244 | feat_ymin = (feat_cy - yref) / href / prior_scaling[1]
245 | feat_xmin = (feat_cx - xref) / wref / prior_scaling[0]
246 |
247 |
248 | feat_ymax = tf.log(feat_h / href) / prior_scaling[3]
249 | feat_xmax = tf.log(feat_w / wref) / prior_scaling[2]
250 |
251 |
252 | feat_x1 = (feat_x1 - xmin) / wref / prior_scaling[0]
253 | feat_x2 = (feat_x2 - xmax) / wref / prior_scaling[0]
254 | feat_x3 = (feat_x3 - xmax) / wref / prior_scaling[0]
255 | feat_x4 = (feat_x4 - xmin) / wref / prior_scaling[0]
256 |
257 | feat_y1 = (feat_y1 - ymin) / href / prior_scaling[1]
258 | feat_y2 = (feat_y2 - ymin) / href / prior_scaling[1]
259 | feat_y3 = (feat_y3 - ymax) / href / prior_scaling[1]
260 | feat_y4 = (feat_y4 - ymax) / href / prior_scaling[1]
261 |
262 | # Use SSD ordering: x / y / w / h instead of ours,
263 | # plus the four quadrilateral corner offsets (x1..x4, y1..y4).
264 |
265 | # feat_ymin = tf.Print(feat_ymin, [tf.shape(feat_ymin), feat_ymin], message= ' feat_ymin : ', summarize=20)
266 | # feat_xmin = tf.Print(feat_xmin, [tf.shape(feat_xmin), feat_xmin], message= ' feat_xmin : ', summarize=20)
267 |
268 | feat_localizations = tf.stack([feat_xmin, feat_ymin, feat_xmax, feat_ymax ,feat_x1, feat_y1, feat_x2, feat_y2, feat_x3, feat_y3, feat_x4, feat_y4], axis=-1)
269 | # feat_localizations = tf.Print(feat_localizations, [tf.shape(feat_localizations), feat_localizations], message=' feat_localizations: ', summarize=20)
270 | return feat_labels, feat_localizations, feat_scores
271 |
272 |
273 |
274 | def tf_text_bboxes_encode(glabels, bboxes,
275 | anchors,gxs, gys,
276 | matching_threshold=0.1,
277 | prior_scaling=[0.1, 0.1, 0.2, 0.2],
278 | dtype=tf.float32,
279 | scope='text_bboxes_encode'):
280 | """Encode groundtruth labels and bounding boxes using SSD net anchors.
281 | Encoding boxes for all feature layers.
282 |
283 | Arguments:
284 | bboxes: Nx4 Tensor(float) with bboxes relative coordinates;
285 | anchors: List of Numpy array with layer anchors;
286 | gxs,gys:shape = (N,4) with x,y coordinates
287 | matching_threshold: Threshold for positive match with groundtruth bboxes;
288 | prior_scaling: Scaling of encoded coordinates.
289 |
290 | Return:
291 | (target_localizations, target_scores, target_labels):
292 | Each element is a list of target Tensors.
293 | """
294 |
295 | with tf.name_scope(scope):
296 | target_labels = []
297 | target_localizations = []
298 | target_scores = []
299 | for i, anchors_layer in enumerate(anchors):
300 | with tf.name_scope('bboxes_encode_block_%i' % i):
301 | t_label, t_loc, t_scores = \
302 | tf_text_bboxes_encode_layer(glabels, bboxes,gxs, gys, anchors_layer,
303 | matching_threshold,
304 | prior_scaling, dtype)
305 | target_localizations.append(t_loc)
306 | target_scores.append(t_scores)
307 | target_labels.append(t_label)
308 | return target_localizations, target_scores, target_labels
309 |
310 |
311 |
312 |
--------------------------------------------------------------------------------
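For reference, ssd_bboxes_decode in nets/np_methods.py inverts the encoding above. In pseudo-numpy, with `l` the 12-dim encoded vector:

    cx = l[0] * prior_scaling[0] * wref + xref
    cy = l[1] * prior_scaling[1] * href + yref
    w  = wref * np.exp(l[2] * prior_scaling[2])
    h  = href * np.exp(l[3] * prior_scaling[3])
    # Each quadrilateral corner is undone against its matching anchor corner, e.g.:
    x1 = l[4] * prior_scaling[0] * wref + anchor_xmin

so an encode followed by a decode round-trips to the original ground truth (up to the matching mask).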
/nets/textbox_common.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/nets/textbox_common.pyc
--------------------------------------------------------------------------------
/nets/txtbox_384.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/nets/txtbox_384.pyc
--------------------------------------------------------------------------------
/nets/txtbox_768.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/nets/txtbox_768.pyc
--------------------------------------------------------------------------------
/processing/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/processing/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/processing/__init__.pyc
--------------------------------------------------------------------------------
/processing/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/processing/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/processing/__pycache__/ssd_vgg_preprocessing.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/processing/__pycache__/ssd_vgg_preprocessing.cpython-35.pyc
--------------------------------------------------------------------------------
/processing/__pycache__/tf_image.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/processing/__pycache__/tf_image.cpython-35.pyc
--------------------------------------------------------------------------------
/processing/ssd_vgg_preprocessing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 Paul Balanca. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Pre-processing images for SSD-type networks.
16 | """
17 | from enum import Enum, IntEnum
18 | import numpy as np
19 |
20 | import tensorflow as tf
21 | import tf_extended as tfe
22 |
23 | from tensorflow.python.ops import control_flow_ops
24 |
25 | from processing import tf_image
26 |
27 |
28 | slim = tf.contrib.slim
29 |
30 | # Resizing strategies.
31 | Resize = IntEnum('Resize', ('NONE', # Nothing!
32 | 'CENTRAL_CROP', # Crop (and pad if necessary).
33 | 'PAD_AND_RESIZE', # Pad, and resize to output shape.
34 | 'WARP_RESIZE')) # Warp resize.
35 |
36 | # VGG mean parameters.
37 | _R_MEAN = 123.68
38 | _G_MEAN = 116.78
39 | _B_MEAN = 103.94
40 |
41 | # Some training pre-processing parameters.
42 | BBOX_CROP_OVERLAP = 0.15 # Minimum overlap to keep a bbox after cropping.
43 | CROP_RATIO_RANGE = (0.5, 1.5) # Distortion ratio during cropping.
44 | EVAL_SIZE = (384, 384)
45 | AREA_RANGE = [1., 1.]
46 | MIN_OBJECT_COVERED = 0.5
47 |
48 |
49 | def _mean_image_subtraction(image, means):
50 |
51 | if image.get_shape().ndims != 3:
52 | raise ValueError('Input must be of size [height, width, C>0]')
53 | num_channels = image.get_shape().as_list()[-1]
54 | if len(means) != num_channels:
55 | raise ValueError('len(means) must match the number of channels')
56 | channels = tf.split(axis=2, num_or_size_splits=num_channels, value=image)
57 | for i in range(num_channels):
58 | channels[i] -= means[i]
59 | return tf.concat(axis=2, values=channels)
60 |
61 | def tf_image_whitened(image, means=[_R_MEAN, _G_MEAN, _B_MEAN]):
62 | """Subtracts the given means from each image channel.
63 |
64 | Returns:
65 | the centered image.
66 | """
67 | if image.get_shape().ndims != 3:
68 | raise ValueError('Input must be of size [height, width, C>0]')
69 | num_channels = image.get_shape().as_list()[-1]
70 | if len(means) != num_channels:
71 | raise ValueError('len(means) must match the number of channels')
72 |
73 | mean = tf.constant(means, dtype=image.dtype)
74 | image = image - mean
75 | return image
76 |
77 |
78 | def tf_image_unwhitened(image, means=[_R_MEAN, _G_MEAN, _B_MEAN], to_int=True):
79 | """Re-convert to original image distribution, and convert to int if
80 | necessary.
81 |
82 | Returns:
83 | Centered image.
84 | """
85 | mean = tf.constant(means, dtype=image.dtype)
86 | image = image + mean
87 | if to_int:
88 | image = tf.cast(image, tf.int32)
89 | return image
90 |
91 |
92 | def np_image_unwhitened(image, means=[_R_MEAN, _G_MEAN, _B_MEAN], to_int=True):
93 | """Re-convert to original image distribution, and convert to int if
94 | necessary. Numpy version.
95 |
96 | Returns:
97 | Centered image.
98 | """
99 | img = np.copy(image)
100 | img += np.array(means, dtype=img.dtype)
101 | if to_int:
102 | img = img.astype(np.uint8)
103 | return img
104 |
105 |
106 | def tf_summary_image(image, bboxes, name='image', unwhitened=False):
107 | """Add image with bounding boxes to summary.
108 | """
109 | if unwhitened:
110 | image = tf_image_unwhitened(image)
111 | image = tf.expand_dims(image, 0)
112 | bboxes = tf.expand_dims(bboxes, 0)
113 | image_with_box = tf.image.draw_bounding_boxes(image, bboxes)
114 | tf.summary.image(name, image_with_box)
115 |
116 |
117 | def apply_with_random_selector(x, func, num_cases):
118 | """Computes func(x, sel), with sel sampled from [0...num_cases-1].
119 |
120 | Args:
121 | x: input Tensor.
122 | func: Python function to apply.
123 | num_cases: Python int32, number of cases to sample sel from.
124 |
125 | Returns:
126 | The result of func(x, sel), where func receives the value of the
127 | selector as a python integer, but sel is sampled dynamically.
128 | """
129 | sel = tf.random_uniform([], maxval=num_cases, dtype=tf.int32)
130 | # Pass the real x only to one of the func calls.
131 | return control_flow_ops.merge([
132 | func(control_flow_ops.switch(x, tf.equal(sel, case))[1], case)
133 | for case in range(num_cases)])[0]
134 |
135 |
136 | def distort_color(image, color_ordering=0, fast_mode=True, scope=None):
137 | """Distort the color of a Tensor image.
138 |
139 | Each color distortion is non-commutative and thus ordering of the color ops
140 | matters. Ideally we would randomly permute the ordering of the color ops.
141 | Rather than adding that level of complication, we select a distinct ordering
142 | of color ops for each preprocessing thread.
143 |
144 | Args:
145 | image: 3-D Tensor containing single image in [0, 1].
146 | color_ordering: Python int, a type of distortion (valid values: 0-3).
147 | fast_mode: Avoids slower ops (random_hue and random_contrast)
148 | scope: Optional scope for name_scope.
149 | Returns:
150 | 3-D Tensor color-distorted image on range [0, 1]
151 | Raises:
152 | ValueError: if color_ordering not in [0, 3]
153 | """
154 | with tf.name_scope(scope, 'distort_color', [image]):
155 | if fast_mode:
156 | if color_ordering == 0:
157 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
158 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
159 | else:
160 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
161 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
162 | else:
163 | if color_ordering == 0:
164 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
165 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
166 | image = tf.image.random_hue(image, max_delta=0.2)
167 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
168 | elif color_ordering == 1:
169 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
170 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
171 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
172 | image = tf.image.random_hue(image, max_delta=0.2)
173 | elif color_ordering == 2:
174 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
175 | image = tf.image.random_hue(image, max_delta=0.2)
176 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
177 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
178 | elif color_ordering == 3:
179 | image = tf.image.random_hue(image, max_delta=0.2)
180 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
181 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
182 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
183 | else:
184 | raise ValueError('color_ordering must be in [0, 3]')
185 | # The random_* ops do not necessarily clamp.
186 | return tf.clip_by_value(image, 0.0, 1.0)
187 |
188 |
189 | def distorted_bounding_box_crop(image,
190 | labels,
191 | bboxes,
192 | xs, ys,
193 | min_object_covered=0.05,
194 | aspect_ratio_range=(0.9, 1.1),
195 | area_range=(0.1, 1.0),
196 | max_attempts=200,
197 | scope=None):
198 | """Generates cropped_image using a one of the bboxes randomly distorted.
199 |
200 | See `tf.image.sample_distorted_bounding_box` for more documentation.
201 |
202 | Args:
203 | image: 3-D Tensor of image (it will be converted to floats in [0, 1]).
204 | bboxes: 2-D float Tensor of bounding boxes arranged [num_boxes, coords]
205 | where each coordinate is [0, 1) and the coordinates are arranged
206 | as [ymin, xmin, ymax, xmax]. If num_boxes is 0 then it would use the whole
207 | image.
208 | min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
209 | area of the image must contain at least this fraction of any bounding box
210 | supplied.
211 | aspect_ratio_range: An optional list of `floats`. The cropped area of the
212 | image must have an aspect ratio = width / height within this range.
213 | area_range: An optional list of `floats`. The cropped area of the image
214 | must contain a fraction of the supplied image within this range.
215 | max_attempts: An optional `int`. Number of attempts at generating a cropped
216 | region of the image of the specified constraints. After `max_attempts`
217 | failures, return the entire image.
218 | scope: Optional scope for name_scope.
219 | Returns:
220 | A tuple, a 3-D Tensor cropped_image and the distorted bbox
221 | """
222 | with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bboxes, xs,ys]):
223 | # Each bounding box has shape [1, num_boxes, box coords] and
224 | # the coordinates are ordered [ymin, xmin, ymax, xmax].
225 | bbox_begin, bbox_size, distort_bbox = tf.image.sample_distorted_bounding_box(
226 | tf.shape(image),
227 | bounding_boxes=tf.expand_dims(bboxes, 0),
228 | min_object_covered=min_object_covered,
229 | aspect_ratio_range=aspect_ratio_range,
230 | area_range=area_range,
231 | max_attempts=max_attempts,
232 | use_image_if_no_bounding_boxes=True)
233 | distort_bbox = distort_bbox[0, 0]
234 |
235 | # Crop the image to the specified bounding box.
236 | cropped_image = tf.slice(image, bbox_begin, bbox_size)
237 | # Restore the shape since the dynamic slice loses 3rd dimension.
238 | cropped_image.set_shape([None, None, 3])
239 |
240 | # Update bounding boxes: resize and filter out.
241 | bboxes, xs, ys = tfe.bboxes_resize(distort_bbox, bboxes, xs, ys)
242 | labels, bboxes, xs, ys = tfe.bboxes_filter_overlap(labels, bboxes,xs, ys,
243 | BBOX_CROP_OVERLAP)
244 | return cropped_image, labels, bboxes,xs, ys, distort_bbox
245 |
246 |
247 | def preprocess_for_train(image, labels, bboxes, xs, ys,
248 | out_shape, data_format='NHWC',
249 | scope='ssd_preprocessing_train', clip=True, crop_area_range=AREA_RANGE):
250 | """Preprocesses the given image for training.
251 |
252 | Note that the image is randomly cropped, flipped and color-distorted,
253 | then resized to `out_shape` and mean-subtracted.
254 |
255 | Args:
256 | image: A `Tensor` representing an image of arbitrary size.
257 | labels: ground truth labels Tensor.
258 | bboxes: Nx4 Tensor with relative bounding box coordinates.
259 | xs, ys: Nx4 Tensors with the quadrilateral corner coordinates.
260 | out_shape: output height and width of the image after preprocessing.
261 | clip: whether to clip bboxes and corner points to [0, 1].
262 | crop_area_range: area range used by the random crop.
263 |
264 | Returns:
265 | A preprocessed image.
266 | """
267 | fast_mode = False
268 | with tf.name_scope(scope, 'ssd_preprocessing_train', [image, labels, bboxes]):
269 | if image.get_shape().ndims != 3:
270 | raise ValueError('Input must be of size [height, width, C>0]')
271 |
272 | orig_dtype = image.dtype
273 | print('orig_dtype:', orig_dtype)
274 | # Convert to float scaled [0, 1].
275 | if image.dtype != tf.float32:
276 | image = tf.image.convert_image_dtype(image, dtype=tf.float32)
277 | # tf_summary_image(image, bboxes, 'image_with_bboxes')
278 |
279 | # Distort image and bounding boxes.
280 | dst_image = image
281 | dst_image, labels, bboxes,xs, ys, distort_bbox = \
282 | distorted_bounding_box_crop(image, labels, bboxes,xs, ys,
283 | aspect_ratio_range=CROP_RATIO_RANGE,min_object_covered=MIN_OBJECT_COVERED,area_range=crop_area_range)
284 | # Resize image to output size.
285 | dst_image = tf_image.resize_image(dst_image, out_shape,
286 | method=tf.image.ResizeMethod.BILINEAR,
287 | align_corners=False)
288 | #tf_summary_image(dst_image, bboxes, 'image_shape_distorted')
289 |
290 | # Randomly flip the image horizontally.
291 | #bboxes and xs ys all need to random
292 |
293 | dst_image, bboxes, xs, ys = tf_image.random_flip_left_right(dst_image, bboxes, xs, ys)
294 |
295 | # Randomly distort the colors. There are 4 ways to do it.
296 | dst_image = apply_with_random_selector(
297 | dst_image,
298 | lambda x, ordering: distort_color(x, ordering, fast_mode),
299 | num_cases=4)
300 |
301 | tf_summary_image(dst_image, bboxes, 'image_color_distorted')
302 |
303 | # Rescale to VGG input scale.
304 | image = tf.to_float(tf.image.convert_image_dtype(dst_image, orig_dtype, saturate=True))
305 | # image = dst_image * 255.
306 | # image = tf_image_whitened(image, [_R_MEAN, _G_MEAN, _B_MEAN])
307 | image = _mean_image_subtraction(image, [_R_MEAN, _G_MEAN, _B_MEAN] )
308 | # Image data format.
309 | if data_format == 'NCHW':
310 | image = tf.transpose(image, perm=(2, 0, 1))
311 |
312 | if clip:
313 | xy_clip_min = tf.constant([0., 0., 0., 0.])
314 | xy_clip_max = tf.constant([1., 1., 1., 1.])
315 | bbox_img_max = tf.constant([1., 1., 1. , 1.])
316 | bbox_img_min = tf.constant([0., 0., 0., 0.])
317 |
318 | bboxes = tf.minimum(bboxes, bbox_img_max)
319 | bboxes = tf.maximum(bboxes, bbox_img_min)
320 |
321 |
322 | xs = tf.maximum(xs, xy_clip_min)
323 | ys = tf.maximum(ys, xy_clip_min)
324 | xs = tf.minimum(xs, xy_clip_max)
325 | ys = tf.minimum(ys, xy_clip_max)
326 |
327 | tf_summary_image(image, bboxes, ' image whitened')
328 | # image = tf.Print(image, [image[0]], ' image: ', summarize=20)
329 | # xs = tf.Print(xs, [xs, tf.shape(xs)], ' xs ', summarize=20)
330 | # ys = tf.Print(ys, [ys, tf.shape(ys)], ' ys ', summarize=20)
331 | # bboxes = tf.Print(bboxes, [bboxes, tf.shape(bboxes)], ' bboxes ',summarize=20)
332 | return image, labels, bboxes, xs, ys
333 |
334 |
335 |
336 |
337 | def preprocess_for_eval(image, labels, bboxes,xs, ys,
338 | out_shape=EVAL_SIZE, data_format='NHWC',
339 | difficults=None, resize=Resize.WARP_RESIZE,
340 | scope='ssd_preprocessing_eval'):
341 | """Preprocess an image for evaluation.
342 |
343 | Args:
344 | image: A `Tensor` representing an image of arbitrary size.
345 | out_shape: Output shape after pre-processing (if resize != None)
346 | resize: Resize strategy.
347 |
348 | Returns:
349 | A preprocessed image.
350 | """
351 | with tf.name_scope(scope):
352 | if image.get_shape().ndims != 3:
353 | raise ValueError('Input must be of size [height, width, C>0]')
354 |
355 | image = tf.to_float(image)
356 | image = tf_image_whitened(image, [_R_MEAN, _G_MEAN, _B_MEAN])
357 |
358 | # Add image rectangle to bboxes.
359 | bbox_img = tf.constant([[0., 0., 1., 1., 0., 0. , 0., 0., 0., 0., 0., 0.]])
360 | if bboxes is None:
361 | bboxes = bbox_img
362 | else:
363 | bboxes = tf.concat([bbox_img, bboxes], axis=0)
364 |
365 | if resize == Resize.NONE:
366 | # No resizing...
367 | pass
368 | elif resize == Resize.CENTRAL_CROP:
369 | # Central cropping of the image.
370 | image, bboxes = tf_image.resize_image_bboxes_with_crop_or_pad(
371 | image, bboxes, out_shape[0], out_shape[1])
372 | elif resize == Resize.PAD_AND_RESIZE:
373 | # Resize image first: find the correct factor...
374 | shape = tf.shape(image)
375 | factor = tf.minimum(tf.to_double(1.0),
376 | tf.minimum(tf.to_double(out_shape[0] / shape[0]),
377 | tf.to_double(out_shape[1] / shape[1])))
378 | resize_shape = factor * tf.to_double(shape[0:2])
379 | resize_shape = tf.cast(tf.floor(resize_shape), tf.int32)
380 |
381 | image = tf_image.resize_image(image, resize_shape,
382 | method=tf.image.ResizeMethod.BILINEAR,
383 | align_corners=False)
384 | # Pad to expected size.
385 | image, bboxes = tf_image.resize_image_bboxes_with_crop_or_pad(
386 | image, bboxes, out_shape[0], out_shape[1])
387 | elif resize == Resize.WARP_RESIZE:
388 | # Warp resize of the image.
389 | image = tf_image.resize_image(image, out_shape,
390 | method=tf.image.ResizeMethod.BILINEAR,
391 | align_corners=False)
392 |
393 | # Split back bounding boxes.
394 | bbox_img = bboxes[0]
395 | bboxes = bboxes[1:]
396 | # Remove difficult boxes.
397 | if difficults is not None:
398 | mask = tf.logical_not(tf.cast(difficults, tf.bool))
399 | labels = tf.boolean_mask(labels, mask)
400 | bboxes = tf.boolean_mask(bboxes, mask)
401 | # Image data format.
402 | if data_format == 'NCHW':
403 | image = tf.transpose(image, perm=(2, 0, 1))
404 | return image, labels, bboxes, bbox_img, xs, ys
405 |
406 |
407 | def preprocess_image(image,
408 | labels,
409 | bboxes,
410 | xs, ys,
411 | out_shape,
412 |                      data_format='NHWC',
413 | is_training=False,
414 | **kwargs):
415 |     """Pre-process a given image.
416 | 
417 |     Args:
418 |       image: A `Tensor` representing an image of arbitrary size.
419 |       labels: Ground-truth labels of the text instances.
420 |       bboxes: Ground-truth bounding boxes, `[ymin, xmin, ymax, xmax]`.
421 |       xs, ys: Ground-truth quadrilateral corner coordinates.
422 |       out_shape: The (height, width) of the image after preprocessing.
423 |       data_format: 'NHWC' or 'NCHW'.
424 |       is_training: `True` if we're preprocessing the image for training and
425 |         `False` otherwise.
426 |       **kwargs: Extra keyword arguments forwarded to the evaluation
427 |         pipeline (e.g. `difficults`, `resize`).
428 | 
429 |     Returns:
430 |       A preprocessed image, together with the (possibly filtered) labels,
431 |       bounding boxes and quadrilateral coordinates (plus the image
432 |       rectangle for evaluation).
433 |     """
434 | if is_training:
435 |         return preprocess_for_train(image, labels, bboxes, xs, ys,
436 | out_shape=out_shape,
437 | data_format=data_format)
438 | else:
439 |         return preprocess_for_eval(image, labels, bboxes, xs, ys,
440 | out_shape=out_shape,
441 | data_format=data_format,
442 | **kwargs)
443 |
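For illustration, a minimal sketch of how `preprocess_image` can be wired into an input pipeline (not part of the file; the box and quadrilateral values are made up, and `out_shape=(384, 384)` assumes the `txtbox_384` network):

```python
import tensorflow as tf
from processing import ssd_vgg_preprocessing

raw_image = tf.placeholder(tf.uint8, shape=[None, None, 3])   # decoded PNG
labels = tf.constant([1], dtype=tf.int64)                     # one 'text' instance
bboxes = tf.constant([[0.50, 0.125, 0.75, 0.625]])            # [ymin, xmin, ymax, xmax]
xs = tf.constant([[0.125, 0.625, 0.625, 0.125]])              # quadrilateral x corners
ys = tf.constant([[0.50, 0.50, 0.75, 0.75]])                  # quadrilateral y corners

# is_training=True routes to preprocess_for_train and returns five tensors.
image_pre, labels_pre, bboxes_pre, xs_pre, ys_pre = \
    ssd_vgg_preprocessing.preprocess_image(
        raw_image, labels, bboxes, xs, ys,
        out_shape=(384, 384), data_format='NHWC', is_training=True)
```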
--------------------------------------------------------------------------------
/processing/ssd_vgg_preprocessing.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/processing/ssd_vgg_preprocessing.pyc
--------------------------------------------------------------------------------
/processing/tf_image.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 The TensorFlow Authors and Paul Balanca. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Custom image operations.
16 | Most of the following methods extend the TensorFlow image library, and part
17 | of the code is a shameless copy-paste of the former!
18 | """
19 | import tensorflow as tf
20 |
21 | from tensorflow.python.framework import constant_op
22 | from tensorflow.python.framework import dtypes
23 | from tensorflow.python.framework import ops
24 | from tensorflow.python.framework import tensor_shape
25 | from tensorflow.python.framework import tensor_util
26 | from tensorflow.python.ops import array_ops
27 | from tensorflow.python.ops import check_ops
28 | from tensorflow.python.ops import clip_ops
29 | from tensorflow.python.ops import control_flow_ops
30 | from tensorflow.python.ops import gen_image_ops
31 | from tensorflow.python.ops import gen_nn_ops
32 | from tensorflow.python.ops import string_ops
33 | from tensorflow.python.ops import math_ops
34 | from tensorflow.python.ops import random_ops
35 | from tensorflow.python.ops import variables
36 |
37 |
38 | # =========================================================================== #
39 | # Modification of TensorFlow image routines.
40 | # =========================================================================== #
41 | def _assert(cond, ex_type, msg):
42 | """A polymorphic assert, works with tensors and boolean expressions.
43 | If `cond` is not a tensor, behave like an ordinary assert statement, except
44 |     that an empty list is returned. If `cond` is a tensor, return a list
45 | containing a single TensorFlow assert op.
46 | Args:
47 | cond: Something evaluates to a boolean value. May be a tensor.
48 | ex_type: The exception class to use.
49 | msg: The error message.
50 | Returns:
51 | A list, containing at most one assert op.
52 | """
53 | if _is_tensor(cond):
54 | return [control_flow_ops.Assert(cond, [msg])]
55 | else:
56 | if not cond:
57 | raise ex_type(msg)
58 | else:
59 | return []
60 |
61 |
62 | def _is_tensor(x):
63 | """Returns `True` if `x` is a symbolic tensor-like object.
64 | Args:
65 | x: A python object to check.
66 | Returns:
67 | `True` if `x` is a `tf.Tensor` or `tf.Variable`, otherwise `False`.
68 | """
69 | return isinstance(x, (ops.Tensor, variables.Variable))
70 |
71 |
72 | def _ImageDimensions(image):
73 | """Returns the dimensions of an image tensor.
74 | Args:
75 | image: A 3-D Tensor of shape `[height, width, channels]`.
76 | Returns:
77 | A list of `[height, width, channels]` corresponding to the dimensions of the
78 | input image. Dimensions that are statically known are python integers,
79 | otherwise they are integer scalar tensors.
80 | """
81 | if image.get_shape().is_fully_defined():
82 | return image.get_shape().as_list()
83 | else:
84 | static_shape = image.get_shape().with_rank(3).as_list()
85 | dynamic_shape = array_ops.unstack(array_ops.shape(image), 3)
86 | return [s if s is not None else d
87 | for s, d in zip(static_shape, dynamic_shape)]
88 |
89 |
90 | def _Check3DImage(image, require_static=True):
91 | """Assert that we are working with properly shaped image.
92 | Args:
93 | image: 3-D Tensor of shape [height, width, channels]
94 | require_static: If `True`, requires that all dimensions of `image` are
95 | known and non-zero.
96 | Raises:
97 | ValueError: if `image.shape` is not a 3-vector.
98 | Returns:
99 | An empty list, if `image` has fully defined dimensions. Otherwise, a list
100 | containing an assert op is returned.
101 | """
102 | try:
103 | image_shape = image.get_shape().with_rank(3)
104 | except ValueError:
105 | raise ValueError("'image' must be three-dimensional.")
106 | if require_static and not image_shape.is_fully_defined():
107 | raise ValueError("'image' must be fully defined.")
108 | if any(x == 0 for x in image_shape):
109 | raise ValueError("all dims of 'image.shape' must be > 0: %s" %
110 | image_shape)
111 | if not image_shape.is_fully_defined():
112 | return [check_ops.assert_positive(array_ops.shape(image),
113 | ["all dims of 'image.shape' "
114 | "must be > 0."])]
115 | else:
116 | return []
117 |
118 |
119 | def fix_image_flip_shape(image, result):
120 | """Set the shape to 3 dimensional if we don't know anything else.
121 | Args:
122 | image: original image size
123 | result: flipped or transformed image
124 | Returns:
125 |         An image whose shape is at least `[None, None, None]`.
126 | """
127 | image_shape = image.get_shape()
128 | if image_shape == tensor_shape.unknown_shape():
129 | result.set_shape([None, None, None])
130 | else:
131 | result.set_shape(image_shape)
132 | return result
133 |
134 |
135 | # =========================================================================== #
136 | # Image + BBoxes methods: cropping, resizing, flipping, ...
137 | # =========================================================================== #
138 | def bboxes_crop_or_pad(bboxes,
139 | height, width,
140 | offset_y, offset_x,
141 | target_height, target_width):
142 | """Adapt bounding boxes to crop or pad operations.
143 | Coordinates are always supposed to be relative to the image.
144 |
145 | Arguments:
146 | bboxes: Tensor Nx4 with bboxes coordinates [y_min, x_min, y_max, x_max];
147 | height, width: Original image dimension;
148 | offset_y, offset_x: Offset to apply,
149 | negative if cropping, positive if padding;
150 | target_height, target_width: Target dimension after cropping / padding.
151 | """
152 | with tf.name_scope('bboxes_crop_or_pad'):
153 | # Rescale bounding boxes in pixels.
154 | scale = tf.cast(tf.stack([height, width, height, width]), bboxes.dtype)
155 | bboxes = bboxes * scale
156 | # Add offset.
157 | offset = tf.cast(tf.stack([offset_y, offset_x, offset_y, offset_x]), bboxes.dtype)
158 | bboxes = bboxes + offset
159 | # Rescale to target dimension.
160 | scale = tf.cast(tf.stack([target_height, target_width,
161 | target_height, target_width]), bboxes.dtype)
162 | bboxes = bboxes / scale
163 | return bboxes
164 |
165 |
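The three steps above (to pixels, shift, renormalize) are easy to sanity-check in plain numpy. A hedged re-derivation with made-up numbers, padding a 100x100 image to 200x200 with the content centered:

```python
import numpy as np

def bboxes_crop_or_pad_np(bboxes, height, width, offset_y, offset_x,
                          target_height, target_width):
    # Same three steps as the TF op: to pixels, add offset, renormalize.
    b = np.asarray(bboxes, dtype=np.float64)
    b = b * [height, width, height, width]
    b = b + [offset_y, offset_x, offset_y, offset_x]
    return b / [target_height, target_width, target_height, target_width]

print(bboxes_crop_or_pad_np([[0.25, 0.25, 0.75, 0.75]],
                            100, 100, 50, 50, 200, 200))
# -> [[0.375 0.375 0.625 0.625]]
```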
166 | def resize_image_bboxes_with_crop_or_pad(image, bboxes,
167 | target_height, target_width):
168 | """Crops and/or pads an image to a target width and height.
169 | Resizes an image to a target width and height by either centrally
170 | cropping the image or padding it evenly with zeros.
171 |
172 | If `width` or `height` is greater than the specified `target_width` or
173 | `target_height` respectively, this op centrally crops along that dimension.
174 | If `width` or `height` is smaller than the specified `target_width` or
175 | `target_height` respectively, this op centrally pads with 0 along that
176 | dimension.
177 | Args:
178 | image: 3-D tensor of shape `[height, width, channels]`
179 | target_height: Target height.
180 | target_width: Target width.
181 | Raises:
182 | ValueError: if `target_height` or `target_width` are zero or negative.
183 | Returns:
184 | Cropped and/or padded image of shape
185 | `[target_height, target_width, channels]`
186 | """
187 | with tf.name_scope('resize_with_crop_or_pad'):
188 | image = ops.convert_to_tensor(image, name='image')
189 |
190 | assert_ops = []
191 | assert_ops += _Check3DImage(image, require_static=False)
192 | assert_ops += _assert(target_width > 0, ValueError,
193 | 'target_width must be > 0.')
194 | assert_ops += _assert(target_height > 0, ValueError,
195 | 'target_height must be > 0.')
196 |
197 | image = control_flow_ops.with_dependencies(assert_ops, image)
198 | # `crop_to_bounding_box` and `pad_to_bounding_box` have their own checks.
199 | # Make sure our checks come first, so that error messages are clearer.
200 | if _is_tensor(target_height):
201 | target_height = control_flow_ops.with_dependencies(
202 | assert_ops, target_height)
203 | if _is_tensor(target_width):
204 | target_width = control_flow_ops.with_dependencies(assert_ops, target_width)
205 |
206 | def max_(x, y):
207 | if _is_tensor(x) or _is_tensor(y):
208 | return math_ops.maximum(x, y)
209 | else:
210 | return max(x, y)
211 |
212 | def min_(x, y):
213 | if _is_tensor(x) or _is_tensor(y):
214 | return math_ops.minimum(x, y)
215 | else:
216 | return min(x, y)
217 |
218 | def equal_(x, y):
219 | if _is_tensor(x) or _is_tensor(y):
220 | return math_ops.equal(x, y)
221 | else:
222 | return x == y
223 |
224 | height, width, _ = _ImageDimensions(image)
225 | width_diff = target_width - width
226 | offset_crop_width = max_(-width_diff // 2, 0)
227 | offset_pad_width = max_(width_diff // 2, 0)
228 |
229 | height_diff = target_height - height
230 | offset_crop_height = max_(-height_diff // 2, 0)
231 | offset_pad_height = max_(height_diff // 2, 0)
232 |
233 | # Maybe crop if needed.
234 | height_crop = min_(target_height, height)
235 | width_crop = min_(target_width, width)
236 | cropped = tf.image.crop_to_bounding_box(image, offset_crop_height, offset_crop_width,
237 | height_crop, width_crop)
238 | bboxes = bboxes_crop_or_pad(bboxes,
239 | height, width,
240 | -offset_crop_height, -offset_crop_width,
241 | height_crop, width_crop)
242 | # Maybe pad if needed.
243 | resized = tf.image.pad_to_bounding_box(cropped, offset_pad_height, offset_pad_width,
244 | target_height, target_width)
245 | bboxes = bboxes_crop_or_pad(bboxes,
246 | height_crop, width_crop,
247 | offset_pad_height, offset_pad_width,
248 | target_height, target_width)
249 |
250 | # In theory all the checks below are redundant.
251 | if resized.get_shape().ndims is None:
252 | raise ValueError('resized contains no shape.')
253 |
254 | resized_height, resized_width, _ = _ImageDimensions(resized)
255 |
256 | assert_ops = []
257 | assert_ops += _assert(equal_(resized_height, target_height), ValueError,
258 | 'resized height is not correct.')
259 | assert_ops += _assert(equal_(resized_width, target_width), ValueError,
260 | 'resized width is not correct.')
261 |
262 | resized = control_flow_ops.with_dependencies(assert_ops, resized)
263 | return resized, bboxes
264 |
265 |
266 | def resize_image(image, size,
267 | method=tf.image.ResizeMethod.BILINEAR,
268 | align_corners=False):
269 |     """Resize an image to `size` (relative bbox coordinates are unchanged).
270 | """
271 | # Resize image.
272 | with tf.name_scope('resize_image'):
273 | height, width, channels = _ImageDimensions(image)
274 | image = tf.expand_dims(image, 0)
275 | image = tf.image.resize_images(image, size,
276 | method, align_corners)
277 | image = tf.reshape(image, tf.stack([size[0], size[1], channels]))
278 | return image
279 |
280 |
281 | def random_flip_left_right(image, bboxes, xs, ys, seed=None):
282 | """Random flip left-right of an image and its bounding boxes.
283 | """
284 | def flip_bboxes(bboxes):
285 | """Flip bounding boxes coordinates.
286 | """
287 | bboxes = tf.stack([bboxes[:, 0], 1 - bboxes[:, 3],
288 | bboxes[:, 2], 1 - bboxes[:, 1]], axis=-1)
289 | return bboxes
290 |
291 | def flip_xs(xs):
292 | """Flip xs coordinates
293 | """
294 | # xs_temp = tf.ones(xs.shape)
295 | # xs_temp[:, 0] = 1 - xs[:, 1]
296 | # xs_temp[:, 1] = 1 - xs[:, 0]
297 | # xs_temp[:, 2] = 1 - xs[:, 3]
298 | # xs_temp[:, 3] = 1 - xs[:, 2]
299 |
300 |         xs = tf.stack([1 - xs[:, 1], 1 - xs[:, 0], 1 - xs[:, 3], 1 - xs[:, 2]], axis=-1)
301 | return xs
302 |
303 | def flip_ys(ys):
304 | """Flip ys coordinates
305 | """
306 | # ys_temp = tf.ones(ys.shape)
307 | # ys_temp[:, 0] = ys[: ,1]
308 | # ys_temp[:, 1] = ys[:, 0]
309 | # ys_temp[:, 2] = ys[:, 3]
310 | # ys_temp[:, 3] = ys[:, 2]
311 | # return ys_temp
312 |         ys = tf.stack([ys[:, 1], ys[:, 0], ys[:, 3], ys[:, 2]], axis=-1)
313 | return ys
314 |
315 |
316 | # Random flip. Tensorflow implementation.
317 | with tf.name_scope('random_flip_left_right'):
318 | image = ops.convert_to_tensor(image, name='image')
319 | _Check3DImage(image, require_static=False)
320 | uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed)
321 | mirror_cond = math_ops.less(uniform_random, .5)
322 | # Flip image.
323 | result = control_flow_ops.cond(mirror_cond,
324 | lambda: array_ops.reverse_v2(image, [1]),
325 | lambda: image)
326 | # Flip bboxes.
327 | bboxes = control_flow_ops.cond(mirror_cond,
328 | lambda: flip_bboxes(bboxes),
329 | lambda: bboxes)
330 |
331 | xs = control_flow_ops.cond(mirror_cond, lambda: flip_xs(xs), lambda: xs)
332 | ys = control_flow_ops.cond(mirror_cond, lambda: flip_ys(ys), lambda: ys)
333 | return fix_image_flip_shape(image, result), bboxes, xs, ys
334 |
335 |
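The flip rules deserve a concrete check: after a horizontal flip every x becomes `1 - x`, and corners are swapped pairwise so the quadrilateral keeps a consistent corner order. A small cross-check in plain Python with illustrative values (the `_np` helpers are just for this sketch):

```python
def flip_bbox_np(b):
    # [ymin, xmin, ymax, xmax] -> y's unchanged, x's mirrored and swapped.
    return [b[0], 1 - b[3], b[2], 1 - b[1]]

def flip_xs_np(x):
    # Corner order (x1, x2, x3, x4): mirror and swap (1,2) and (3,4).
    return [1 - x[1], 1 - x[0], 1 - x[3], 1 - x[2]]

print(flip_bbox_np([0.5, 0.125, 0.75, 0.625]))   # [0.5, 0.375, 0.75, 0.875]
print(flip_xs_np([0.125, 0.625, 0.625, 0.125]))  # [0.375, 0.875, 0.875, 0.375]
```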
--------------------------------------------------------------------------------
/processing/tf_image.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/processing/tf_image.pyc
--------------------------------------------------------------------------------
/tf_extended/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Paul Balanca. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """TF Extended: additional metrics.
16 | """
17 |
18 | # pylint: disable=unused-import,line-too-long,g-importing-member,wildcard-import
19 | from tf_extended.metrics import *
20 | from tf_extended.tensors import *
21 | from tf_extended.bboxes import *
22 | from tf_extended.image import *
23 | from tf_extended.math import *
24 |
25 |
--------------------------------------------------------------------------------
/tf_extended/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/tf_extended/__init__.pyc
--------------------------------------------------------------------------------
/tf_extended/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/tf_extended/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/tf_extended/__pycache__/bboxes.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/tf_extended/__pycache__/bboxes.cpython-35.pyc
--------------------------------------------------------------------------------
/tf_extended/__pycache__/image.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/tf_extended/__pycache__/image.cpython-35.pyc
--------------------------------------------------------------------------------
/tf_extended/__pycache__/math.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/tf_extended/__pycache__/math.cpython-35.pyc
--------------------------------------------------------------------------------
/tf_extended/__pycache__/metrics.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/tf_extended/__pycache__/metrics.cpython-35.pyc
--------------------------------------------------------------------------------
/tf_extended/__pycache__/tensors.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/tf_extended/__pycache__/tensors.cpython-35.pyc
--------------------------------------------------------------------------------
/tf_extended/bboxes.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/tf_extended/bboxes.pyc
--------------------------------------------------------------------------------
/tf_extended/image.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/tf_extended/image.py
--------------------------------------------------------------------------------
/tf_extended/image.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/tf_extended/image.pyc
--------------------------------------------------------------------------------
/tf_extended/math.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Paul Balanca. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """TF Extended: additional math functions.
16 | """
17 | import tensorflow as tf
18 |
19 | from tensorflow.python.ops import array_ops
20 | from tensorflow.python.ops import math_ops
21 | from tensorflow.python.framework import dtypes
22 | from tensorflow.python.framework import ops
23 |
24 |
25 | def safe_divide(numerator, denominator, name):
26 | """Divides two values, returning 0 if the denominator is <= 0.
27 | Args:
28 | numerator: A real `Tensor`.
29 | denominator: A real `Tensor`, with dtype matching `numerator`.
30 | name: Name for the returned op.
31 | Returns:
32 | 0 if `denominator` <= 0, else `numerator` / `denominator`
33 | """
34 | return tf.where(
35 | math_ops.greater(denominator, 0),
36 | math_ops.divide(numerator, denominator),
37 | tf.zeros_like(numerator),
38 | name=name)
39 |
40 |
41 | def cummax(x, reverse=False, name=None):
42 |     """Compute the cumulative maximum of the tensor `x` along its first axis.
43 |     This operation is similar to the more classic `cumsum`. Only 1D tensors
44 |     are supported for now.
45 |
46 | Args:
47 | x: A `Tensor`. Must be one of the following types: `float32`, `float64`,
48 | `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`,
49 | `complex128`, `qint8`, `quint8`, `qint32`, `half`.
50 |       reverse: A `bool` (default: False); if `True`, accumulate the maximum
51 |         from the end of the tensor towards the beginning.
52 | name: A name for the operation (optional).
53 | Returns:
54 | A `Tensor`. Has the same type as `x`.
55 | """
56 | with ops.name_scope(name, "Cummax", [x]) as name:
57 | x = ops.convert_to_tensor(x, name="x")
58 | # Not very optimal: should directly integrate reverse into tf.scan.
59 | if reverse:
60 | x = tf.reverse(x, axis=[0])
61 |         # 'Accumulating' maximum: ensure it is always increasing.
62 | cmax = tf.scan(lambda a, y: tf.maximum(a, y), x,
63 | initializer=None, parallel_iterations=1,
64 | back_prop=False, swap_memory=False)
65 | if reverse:
66 | cmax = tf.reverse(cmax, axis=[0])
67 | return cmax
68 |
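An illustrative run of `cummax` (a sketch assuming a TF 1.x session):

```python
import tensorflow as tf
from tf_extended import math as tfe_math

x = tf.constant([1., 3., 2., 5., 4.])
with tf.Session() as sess:
    print(sess.run(tfe_math.cummax(x)))                # [1. 3. 3. 5. 5.]
    print(sess.run(tfe_math.cummax(x, reverse=True)))  # [5. 5. 5. 5. 4.]
```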
--------------------------------------------------------------------------------
/tf_extended/math.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/tf_extended/math.pyc
--------------------------------------------------------------------------------
/tf_extended/metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Paul Balanca. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """TF Extended: additional metrics.
16 | """
17 | import tensorflow as tf
18 | import numpy as np
19 |
20 | from tensorflow.contrib.framework.python.ops import variables as contrib_variables
21 | from tensorflow.python.framework import dtypes
22 | from tensorflow.python.framework import ops
23 | from tensorflow.python.ops import array_ops
24 | from tensorflow.python.ops import math_ops
25 | from tensorflow.python.ops import nn
26 | from tensorflow.python.ops import state_ops
27 | from tensorflow.python.ops import variable_scope
28 | from tensorflow.python.ops import variables
29 |
30 | from tf_extended import math as tfe_math
31 |
32 |
33 | # =========================================================================== #
34 | # TensorFlow utils
35 | # =========================================================================== #
36 | def _create_local(name, shape, collections=None, validate_shape=True,
37 | dtype=dtypes.float32):
38 | """Creates a new local variable.
39 | Args:
40 | name: The name of the new or existing variable.
41 | shape: Shape of the new or existing variable.
42 | collections: A list of collection names to which the Variable will be added.
43 | validate_shape: Whether to validate the shape of the variable.
44 | dtype: Data type of the variables.
45 | Returns:
46 | The created variable.
47 | """
48 | # Make sure local variables are added to tf.GraphKeys.LOCAL_VARIABLES
49 | collections = list(collections or [])
50 | collections += [ops.GraphKeys.LOCAL_VARIABLES]
51 | return variables.Variable(
52 | initial_value=array_ops.zeros(shape, dtype=dtype),
53 | name=name,
54 | trainable=False,
55 | collections=collections,
56 | validate_shape=validate_shape)
57 |
58 |
59 | def _safe_div(numerator, denominator, name):
60 | """Divides two values, returning 0 if the denominator is <= 0.
61 | Args:
62 | numerator: A real `Tensor`.
63 | denominator: A real `Tensor`, with dtype matching `numerator`.
64 | name: Name for the returned op.
65 | Returns:
66 | 0 if `denominator` <= 0, else `numerator` / `denominator`
67 | """
68 | return tf.where(
69 | math_ops.greater(denominator, 0),
70 | math_ops.divide(numerator, denominator),
71 | tf.zeros_like(numerator),
72 | name=name)
73 |
74 |
75 | def _broadcast_weights(weights, values):
76 | """Broadcast `weights` to the same shape as `values`.
77 | This returns a version of `weights` following the same broadcast rules as
78 | `mul(weights, values)`. When computing a weighted average, use this function
79 | to broadcast `weights` before summing them; e.g.,
80 | `reduce_sum(w * v) / reduce_sum(_broadcast_weights(w, v))`.
81 | Args:
82 | weights: `Tensor` whose shape is broadcastable to `values`.
83 | values: `Tensor` of any shape.
84 | Returns:
85 | `weights` broadcast to `values` shape.
86 | """
87 | weights_shape = weights.get_shape()
88 | values_shape = values.get_shape()
89 | if(weights_shape.is_fully_defined() and
90 | values_shape.is_fully_defined() and
91 | weights_shape.is_compatible_with(values_shape)):
92 | return weights
93 |   return math_ops.multiply(
94 | weights, array_ops.ones_like(values), name='broadcast_weights')
95 |
96 |
97 | # =========================================================================== #
98 | # TF Extended metrics: TP and FP arrays.
99 | # =========================================================================== #
100 | def precision_recall(num_gbboxes, num_detections, tp, fp, scores,
101 | dtype=tf.float64, scope=None):
102 |     """Compute precision and recall from scores, true positive and false
103 |     positive boolean arrays.
104 | """
105 | # Input dictionaries: dict outputs as streaming metrics.
106 | if isinstance(scores, dict):
107 | d_precision = {}
108 | d_recall = {}
109 | for c in num_gbboxes.keys():
110 | scope = 'precision_recall_%s' % c
111 | p, r = precision_recall(num_gbboxes[c], num_detections[c],
112 | tp[c], fp[c], scores[c],
113 | dtype, scope)
114 | d_precision[c] = p
115 | d_recall[c] = r
116 | return d_precision, d_recall
117 |
118 | # Sort by score.
119 | with tf.name_scope(scope, 'precision_recall',
120 | [num_gbboxes, num_detections, tp, fp, scores]):
121 | # Sort detections by score.
122 | scores, idxes = tf.nn.top_k(scores, k=num_detections, sorted=True)
123 | tp = tf.gather(tp, idxes)
124 | fp = tf.gather(fp, idxes)
125 |         # Compute recall and precision.
126 | tp = tf.cumsum(tf.cast(tp, dtype), axis=0)
127 | fp = tf.cumsum(tf.cast(fp, dtype), axis=0)
128 | recall = _safe_div(tp, tf.cast(num_gbboxes, dtype), 'recall')
129 | precision = _safe_div(tp, tp + fp, 'precision')
130 | return tf.tuple([precision, recall])
131 |
132 |
133 | def streaming_tp_fp_arrays(num_gbboxes, tp, fp, scores,
134 | remove_zero_scores=True,
135 | metrics_collections=None,
136 | updates_collections=None,
137 | name=None):
138 |     """Streaming computation of True and False Positive arrays. This metric
139 |     also keeps track of scores and the number of groundtruth objects.
140 | """
141 | # Input dictionaries: dict outputs as streaming metrics.
142 | if isinstance(scores, dict) or isinstance(fp, dict):
143 | d_values = {}
144 | d_update_ops = {}
145 | for c in num_gbboxes.keys():
146 | scope = 'streaming_tp_fp_%s' % c
147 | v, up = streaming_tp_fp_arrays(num_gbboxes[c], tp[c], fp[c], scores[c],
148 | remove_zero_scores,
149 | metrics_collections,
150 | updates_collections,
151 | name=scope)
152 | d_values[c] = v
153 | d_update_ops[c] = up
154 | return d_values, d_update_ops
155 |
156 | # Input Tensors...
157 | with variable_scope.variable_scope(name, 'streaming_tp_fp',
158 | [num_gbboxes, tp, fp, scores]):
159 | num_gbboxes = math_ops.to_int64(num_gbboxes)
160 | scores = math_ops.to_float(scores)
161 | stype = tf.bool
162 | tp = tf.cast(tp, stype)
163 | fp = tf.cast(fp, stype)
164 | # Reshape TP and FP tensors and clean away 0 class values.
165 | scores = tf.reshape(scores, [-1])
166 | tp = tf.reshape(tp, [-1])
167 | fp = tf.reshape(fp, [-1])
168 | # Remove TP and FP both false.
169 | mask = tf.logical_or(tp, fp)
170 | if remove_zero_scores:
171 | rm_threshold = 1e-4
172 | mask = tf.logical_and(mask, tf.greater(scores, rm_threshold))
173 | scores = tf.boolean_mask(scores, mask)
174 | tp = tf.boolean_mask(tp, mask)
175 | fp = tf.boolean_mask(fp, mask)
176 |
177 |         # Local variables accumulating information over batches.
178 | v_nobjects = _create_local('v_num_gbboxes', shape=[], dtype=tf.int64)
179 | v_ndetections = _create_local('v_num_detections', shape=[], dtype=tf.int32)
180 | v_scores = _create_local('v_scores', shape=[0, ])
181 | v_tp = _create_local('v_tp', shape=[0, ], dtype=stype)
182 | v_fp = _create_local('v_fp', shape=[0, ], dtype=stype)
183 |
184 | # Update operations.
185 | nobjects_op = state_ops.assign_add(v_nobjects,
186 | tf.reduce_sum(num_gbboxes))
187 | ndetections_op = state_ops.assign_add(v_ndetections,
188 | tf.size(scores, out_type=tf.int32))
189 | scores_op = state_ops.assign(v_scores, tf.concat([v_scores, scores], axis=0),
190 | validate_shape=False)
191 | tp_op = state_ops.assign(v_tp, tf.concat([v_tp, tp], axis=0),
192 | validate_shape=False)
193 | fp_op = state_ops.assign(v_fp, tf.concat([v_fp, fp], axis=0),
194 | validate_shape=False)
195 |
196 | # Value and update ops.
197 | val = (v_nobjects, v_ndetections, v_tp, v_fp, v_scores)
198 | with ops.control_dependencies([nobjects_op, ndetections_op,
199 | scores_op, tp_op, fp_op]):
200 | update_op = (nobjects_op, ndetections_op, tp_op, fp_op, scores_op)
201 |
202 | if metrics_collections:
203 | ops.add_to_collections(metrics_collections, val)
204 | if updates_collections:
205 | ops.add_to_collections(updates_collections, update_op)
206 | return val, update_op
207 |
208 |
209 | # =========================================================================== #
210 | # Average precision computations.
211 | # =========================================================================== #
212 | def average_precision_voc12(precision, recall, name=None):
213 | """Compute (interpolated) average precision from precision and recall Tensors.
214 |
215 | The implementation follows Pascal 2012 and ILSVRC guidelines.
216 | See also: https://sanchom.wordpress.com/tag/average-precision/
217 | """
218 | with tf.name_scope(name, 'average_precision_voc12', [precision, recall]):
219 | # Convert to float64 to decrease error on Riemann sums.
220 | precision = tf.cast(precision, dtype=tf.float64)
221 | recall = tf.cast(recall, dtype=tf.float64)
222 |
223 | # Add bounds values to precision and recall.
224 | precision = tf.concat([[0.], precision, [0.]], axis=0)
225 | recall = tf.concat([[0.], recall, [1.]], axis=0)
226 |         # Ensure precision is non-increasing (cumulative max from the right).
227 | precision = tfe_math.cummax(precision, reverse=True)
228 |
229 | # Riemann sums for estimating the integral.
230 | # mean_pre = (precision[1:] + precision[:-1]) / 2.
231 | mean_pre = precision[1:]
232 | diff_rec = recall[1:] - recall[:-1]
233 | ap = tf.reduce_sum(mean_pre * diff_rec)
234 | return ap
235 |
236 |
237 | def average_precision_voc07(precision, recall, name=None):
238 | """Compute (interpolated) average precision from precision and recall Tensors.
239 |
240 | The implementation follows Pascal 2007 guidelines.
241 | See also: https://sanchom.wordpress.com/tag/average-precision/
242 | """
243 | with tf.name_scope(name, 'average_precision_voc07', [precision, recall]):
244 | # Convert to float64 to decrease error on cumulated sums.
245 | precision = tf.cast(precision, dtype=tf.float64)
246 | recall = tf.cast(recall, dtype=tf.float64)
247 | # Add zero-limit value to avoid any boundary problem...
248 | precision = tf.concat([precision, [0.]], axis=0)
249 | recall = tf.concat([recall, [np.inf]], axis=0)
250 |
251 |         # 11-point interpolation: recall thresholds 0.0, 0.1, ..., 1.0.
252 | l_aps = []
253 | for t in np.arange(0., 1.1, 0.1):
254 | mask = tf.greater_equal(recall, t)
255 | v = tf.reduce_max(tf.boolean_mask(precision, mask))
256 | l_aps.append(v / 11.)
257 | ap = tf.add_n(l_aps)
258 | return ap
259 |
260 |
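For intuition, the same two integrals re-derived in plain numpy (a sketch; the TF versions above are what evaluation actually runs). VOC12 integrates the monotonized precision/recall curve, while VOC07 averages interpolated precision over the 11 recall thresholds 0.0, 0.1, ..., 1.0:

```python
import numpy as np

def average_precision_voc12_np(precision, recall):
    p = np.concatenate([[0.], precision, [0.]])
    r = np.concatenate([[0.], recall, [1.]])
    # Cumulative maximum from the right, mirroring cummax(reverse=True).
    p = np.maximum.accumulate(p[::-1])[::-1]
    # Riemann sum over the recall increments.
    return np.sum(p[1:] * np.diff(r))

def average_precision_voc07_np(precision, recall):
    p = np.concatenate([precision, [0.]])
    r = np.concatenate([recall, [np.inf]])
    # 11-point interpolation: max precision at recall >= t, averaged over t.
    return sum(np.max(p[r >= t]) / 11. for t in np.arange(0., 1.1, 0.1))
```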
261 | def precision_recall_values(xvals, precision, recall, name=None):
262 | """Compute values on the precision/recall curve.
263 |
264 | Args:
265 |       xvals: Python list of floats;
266 | precision: 1D Tensor decreasing.
267 | recall: 1D Tensor increasing.
268 | Return:
269 | list of precision values.
270 | """
271 | with ops.name_scope(name, "precision_recall_values",
272 | [precision, recall]) as name:
273 | # Add bounds values to precision and recall.
274 | precision = tf.concat([[0.], precision, [0.]], axis=0)
275 | recall = tf.concat([[0.], recall, [1.]], axis=0)
276 | precision = tfe_math.cummax(precision, reverse=True)
277 |
278 | prec_values = []
279 | for x in xvals:
280 | mask = tf.less_equal(recall, x)
281 | val = tf.reduce_min(tf.boolean_mask(precision, mask))
282 | prec_values.append(val)
283 | return tf.tuple(prec_values)
284 |
285 |
286 | # =========================================================================== #
287 | # TF Extended metrics: old stuff!
288 | # =========================================================================== #
289 | def _precision_recall(n_gbboxes, n_detections, scores, tp, fp, scope=None):
290 |     """Compute precision and recall from scores, true positive and false
291 |     positive boolean arrays.
292 | """
293 | # Sort by score.
294 | with tf.name_scope(scope, 'prec_rec', [n_gbboxes, scores, tp, fp]):
295 | # Sort detections by score.
296 | scores, idxes = tf.nn.top_k(scores, k=n_detections, sorted=True)
297 | tp = tf.gather(tp, idxes)
298 | fp = tf.gather(fp, idxes)
299 |         # Compute recall and precision.
300 | dtype = tf.float64
301 | tp = tf.cumsum(tf.cast(tp, dtype), axis=0)
302 | fp = tf.cumsum(tf.cast(fp, dtype), axis=0)
303 | recall = _safe_div(tp, tf.cast(n_gbboxes, dtype), 'recall')
304 | precision = _safe_div(tp, tp + fp, 'precision')
305 |
306 | return tf.tuple([precision, recall])
307 |
308 |
309 | def streaming_precision_recall_arrays(n_gbboxes, rclasses, rscores,
310 | tp_tensor, fp_tensor,
311 | remove_zero_labels=True,
312 | metrics_collections=None,
313 | updates_collections=None,
314 | name=None):
315 |     """Streaming computation of precision / recall arrays. This metric
316 |     keeps track of boolean True positive and False positive arrays.
317 | """
318 | with variable_scope.variable_scope(name, 'stream_precision_recall',
319 | [n_gbboxes, rclasses, tp_tensor, fp_tensor]):
320 | n_gbboxes = math_ops.to_int64(n_gbboxes)
321 | rclasses = math_ops.to_int64(rclasses)
322 | rscores = math_ops.to_float(rscores)
323 |
324 | stype = tf.int32
325 | tp_tensor = tf.cast(tp_tensor, stype)
326 | fp_tensor = tf.cast(fp_tensor, stype)
327 |
328 | # Reshape TP and FP tensors and clean away 0 class values.
329 | rclasses = tf.reshape(rclasses, [-1])
330 | rscores = tf.reshape(rscores, [-1])
331 | tp_tensor = tf.reshape(tp_tensor, [-1])
332 | fp_tensor = tf.reshape(fp_tensor, [-1])
333 | if remove_zero_labels:
334 | mask = tf.greater(rclasses, 0)
335 | rclasses = tf.boolean_mask(rclasses, mask)
336 | rscores = tf.boolean_mask(rscores, mask)
337 | tp_tensor = tf.boolean_mask(tp_tensor, mask)
338 | fp_tensor = tf.boolean_mask(fp_tensor, mask)
339 |
340 |         # Local variables accumulating information over batches.
341 | v_nobjects = _create_local('v_nobjects', shape=[], dtype=tf.int64)
342 | v_ndetections = _create_local('v_ndetections', shape=[], dtype=tf.int32)
343 | v_scores = _create_local('v_scores', shape=[0, ])
344 | v_tp = _create_local('v_tp', shape=[0, ], dtype=stype)
345 | v_fp = _create_local('v_fp', shape=[0, ], dtype=stype)
346 |
347 | # Update operations.
348 | nobjects_op = state_ops.assign_add(v_nobjects,
349 | tf.reduce_sum(n_gbboxes))
350 | ndetections_op = state_ops.assign_add(v_ndetections,
351 | tf.size(rscores, out_type=tf.int32))
352 | scores_op = state_ops.assign(v_scores, tf.concat([v_scores, rscores], axis=0),
353 | validate_shape=False)
354 | tp_op = state_ops.assign(v_tp, tf.concat([v_tp, tp_tensor], axis=0),
355 | validate_shape=False)
356 | fp_op = state_ops.assign(v_fp, tf.concat([v_fp, fp_tensor], axis=0),
357 | validate_shape=False)
358 |
359 | # Precision and recall computations.
360 | # r = _precision_recall(nobjects_op, scores_op, tp_op, fp_op, 'value')
361 | r = _precision_recall(v_nobjects, v_ndetections, v_scores,
362 | v_tp, v_fp, 'value')
363 |
364 | with ops.control_dependencies([nobjects_op, ndetections_op,
365 | scores_op, tp_op, fp_op]):
366 | update_op = _precision_recall(nobjects_op, ndetections_op,
367 | scores_op, tp_op, fp_op, 'update_op')
368 |
369 | # update_op = tf.Print(update_op,
370 | # [tf.reduce_sum(tf.cast(mask, tf.int64)),
371 | # tf.reduce_sum(tf.cast(mask2, tf.int64)),
372 | # tf.reduce_min(rscores),
373 | # tf.reduce_sum(n_gbboxes)],
374 | # 'Metric: ')
375 | # Some debugging stuff!
376 | # update_op = tf.Print(update_op,
377 | # [tf.shape(tp_op),
378 | # tf.reduce_sum(tf.cast(tp_op, tf.int64), axis=0)],
379 | # 'TP and FP shape: ')
380 | # update_op[0] = tf.Print(update_op,
381 | # [nobjects_op],
382 | # '# Groundtruth bboxes: ')
383 | # update_op = tf.Print(update_op,
384 | # [update_op[0][0],
385 | # update_op[0][-1],
386 | # tf.reduce_min(update_op[0]),
387 | # tf.reduce_max(update_op[0]),
388 | # tf.reduce_min(update_op[1]),
389 | # tf.reduce_max(update_op[1])],
390 | # 'Precision and recall :')
391 |
392 | if metrics_collections:
393 | ops.add_to_collections(metrics_collections, r)
394 | if updates_collections:
395 | ops.add_to_collections(updates_collections, update_op)
396 | return r, update_op
397 |
398 |
--------------------------------------------------------------------------------
/tf_extended/metrics.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/tf_extended/metrics.pyc
--------------------------------------------------------------------------------
/tf_extended/tensors.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Paul Balanca. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """TF Extended: additional tensors operations.
16 | """
17 | import tensorflow as tf
18 |
19 | from tensorflow.contrib.framework.python.ops import variables as contrib_variables
20 | from tensorflow.contrib.metrics.python.ops import set_ops
21 | from tensorflow.python.framework import dtypes
22 | from tensorflow.python.framework import ops
23 | from tensorflow.python.framework import sparse_tensor
24 | from tensorflow.python.ops import array_ops
25 | from tensorflow.python.ops import check_ops
26 | from tensorflow.python.ops import control_flow_ops
27 | from tensorflow.python.ops import math_ops
28 | from tensorflow.python.ops import nn
29 | from tensorflow.python.ops import state_ops
30 | from tensorflow.python.ops import variable_scope
31 | from tensorflow.python.ops import variables
32 |
33 |
34 | def get_shape(x, rank=None):
35 |     """Returns the dimensions of a Tensor as a list of integers or scalar tensors.
36 |
37 | Args:
38 | x: N-d Tensor;
39 | rank: Rank of the Tensor. If None, will try to guess it.
40 | Returns:
41 | A list of `[d1, d2, ..., dN]` corresponding to the dimensions of the
42 | input tensor. Dimensions that are statically known are python integers,
43 | otherwise they are integer scalar tensors.
44 | """
45 | if x.get_shape().is_fully_defined():
46 | return x.get_shape().as_list()
47 | else:
48 | static_shape = x.get_shape()
49 | if rank is None:
50 | static_shape = static_shape.as_list()
51 | rank = len(static_shape)
52 | else:
53 | static_shape = x.get_shape().with_rank(rank).as_list()
54 | dynamic_shape = tf.unstack(tf.shape(x), rank)
55 | return [s if s is not None else d
56 | for s, d in zip(static_shape, dynamic_shape)]
57 |
58 |
59 | def pad_axis(x, offset, size, axis=0, name=None):
60 | """Pad a tensor on an axis, with a given offset and output size.
61 |     The tensor is padded with zero (i.e. CONSTANT mode). Note that if
62 |     `size` is smaller than the existing size plus `offset`, the output
63 |     tensor keeps the larger dimension.
64 |
65 | Args:
66 | x: Tensor to pad;
67 | offset: Offset to add on the dimension chosen;
68 | size: Final size of the dimension.
69 | Return:
70 | Padded tensor whose dimension on `axis` is `size`, or greater if
71 | the input vector was larger.
72 | """
73 | with tf.name_scope(name, 'pad_axis'):
74 | shape = get_shape(x)
75 | rank = len(shape)
76 | # Padding description.
77 | new_size = tf.maximum(size-offset-shape[axis], 0)
78 | pad1 = tf.stack([0]*axis + [offset] + [0]*(rank-axis-1))
79 | pad2 = tf.stack([0]*axis + [new_size] + [0]*(rank-axis-1))
80 | paddings = tf.stack([pad1, pad2], axis=1)
81 | x = tf.pad(x, paddings, mode='CONSTANT')
82 | # Reshape, to get fully defined shape if possible.
83 | # TODO: fix with tf.slice
84 | shape[axis] = size
85 | x = tf.reshape(x, tf.stack(shape))
86 | return x
87 |
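Illustrative behaviour of `pad_axis` (a sketch assuming a TF 1.x session): with `offset=1` and `size=6` on a length-3 vector, one zero goes in front and `6 - 1 - 3 = 2` zeros go behind:

```python
import tensorflow as tf
from tf_extended import tensors as tfe_tensors

x = tf.constant([1., 2., 3.])
y = tfe_tensors.pad_axis(x, offset=1, size=6, axis=0)
with tf.Session() as sess:
    print(sess.run(y))   # [0. 1. 2. 3. 0. 0.]
```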
88 |
89 | # def select_at_index(idx, val, t):
90 | # """Return a tensor.
91 | # """
92 | # idx = tf.expand_dims(tf.expand_dims(idx, 0), 0)
93 | # val = tf.expand_dims(val, 0)
94 | # t = t + tf.scatter_nd(idx, val, tf.shape(t))
95 | # return t
96 |
--------------------------------------------------------------------------------
/tf_extended/tensors.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/tf_extended/tensors.pyc
--------------------------------------------------------------------------------
/tf_extended/tf_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Paul Balanca. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Diverse TensorFlow utils, for training, evaluation and so on!
16 | """
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | import os
21 | from pprint import pprint
22 |
23 | import tensorflow as tf
24 | from tensorflow.contrib.slim.python.slim.data import parallel_reader
25 |
26 | slim = tf.contrib.slim
27 |
28 |
29 | # =========================================================================== #
30 | # General tools.
31 | # =========================================================================== #
32 | def reshape_list(l, shape=None):
33 |     """Reshape a list (of lists): 1D to 2D or the other way around.
34 |
35 | Args:
36 | l: List or List of list.
37 | shape: 1D or 2D shape.
38 | Return
39 | Reshaped list.
40 | """
41 | r = []
42 | if shape is None:
43 | # Flatten everything.
44 | for a in l:
45 | if isinstance(a, (list, tuple)):
46 | r = r + list(a)
47 | else:
48 | r.append(a)
49 | else:
50 | # Reshape to list of list.
51 | i = 0
52 | for s in shape:
53 | if s == 1:
54 | r.append(l[i])
55 | else:
56 | r.append(l[i:i+s])
57 | i += s
58 | return r
59 |
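A quick round-trip with `reshape_list`, using made-up placeholder values (assumes the function above is in scope):

```python
flat = reshape_list(['a', ('b', 'c'), 'd'])      # -> ['a', 'b', 'c', 'd']
nested = reshape_list(flat, shape=[1, 2, 1])     # -> ['a', ['b', 'c'], 'd']
```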
60 |
61 | # =========================================================================== #
62 | # Training utils.
63 | # =========================================================================== #
64 | def print_configuration(flags, ssd_params, data_sources, save_dir=None):
65 | """Print the training configuration.
66 | """
67 | def print_config(stream=None):
68 | print('\n# =========================================================================== #', file=stream)
69 | print('# Training | Evaluation flags:', file=stream)
70 | print('# =========================================================================== #', file=stream)
71 | pprint(flags, stream=stream)
72 |
73 | print('\n# =========================================================================== #', file=stream)
74 | print('# SSD net parameters:', file=stream)
75 | print('# =========================================================================== #', file=stream)
76 | pprint(dict(ssd_params._asdict()), stream=stream)
77 |
78 | print('\n# =========================================================================== #', file=stream)
79 | print('# Training | Evaluation dataset files:', file=stream)
80 | print('# =========================================================================== #', file=stream)
81 | data_files = parallel_reader.get_data_files(data_sources)
82 | pprint(data_files, stream=stream)
83 | print('', file=stream)
84 |
85 | print_config(None)
86 | # Save to a text file as well.
87 | if save_dir is not None:
88 | if not os.path.exists(save_dir):
89 | os.makedirs(save_dir)
90 | path = os.path.join(save_dir, 'training_config.txt')
91 | with open(path, "w") as out:
92 | print_config(out)
93 |
94 |
95 | def configure_learning_rate(flags, num_samples_per_epoch, global_step):
96 | """Configures the learning rate.
97 |
98 | Args:
99 | num_samples_per_epoch: The number of samples in each epoch of training.
100 | global_step: The global_step tensor.
101 | Returns:
102 | A `Tensor` representing the learning rate.
103 | """
104 | decay_steps = int(num_samples_per_epoch / flags.batch_size *
105 | flags.num_epochs_per_decay)
106 |
107 | if flags.learning_rate_decay_type == 'exponential':
108 | return tf.train.exponential_decay(flags.learning_rate,
109 | global_step,
110 | decay_steps,
111 | flags.learning_rate_decay_factor,
112 | staircase=True,
113 | name='exponential_decay_learning_rate')
114 | elif flags.learning_rate_decay_type == 'fixed':
115 | return tf.constant(flags.learning_rate, name='fixed_learning_rate')
116 | elif flags.learning_rate_decay_type == 'polynomial':
117 | return tf.train.polynomial_decay(flags.learning_rate,
118 | global_step,
119 | decay_steps,
120 | flags.end_learning_rate,
121 | power=1.0,
122 | cycle=False,
123 | name='polynomial_decay_learning_rate')
124 | else:
125 |         raise ValueError('learning_rate_decay_type [%s] was not recognized'
126 |                          % flags.learning_rate_decay_type)
127 |
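A worked example of the decay arithmetic above, with hypothetical flag values: 64,000 samples per epoch, `batch_size=32` and `num_epochs_per_decay=2` give `decay_steps = int(64000 / 32 * 2) = 4000`, i.e. the exponential schedule multiplies the learning rate by `learning_rate_decay_factor` every 4000 global steps:

```python
num_samples_per_epoch = 64000   # hypothetical dataset size
batch_size = 32
num_epochs_per_decay = 2
decay_steps = int(num_samples_per_epoch / batch_size * num_epochs_per_decay)
print(decay_steps)  # 4000
```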
128 |
129 | def configure_optimizer(flags, learning_rate):
130 | """Configures the optimizer used for training.
131 |
132 | Args:
133 | learning_rate: A scalar or `Tensor` learning rate.
134 | Returns:
135 | An instance of an optimizer.
136 | """
137 | if flags.optimizer == 'adadelta':
138 | optimizer = tf.train.AdadeltaOptimizer(
139 | learning_rate,
140 | rho=flags.adadelta_rho,
141 | epsilon=flags.opt_epsilon)
142 | elif flags.optimizer == 'adagrad':
143 | optimizer = tf.train.AdagradOptimizer(
144 | learning_rate,
145 | initial_accumulator_value=flags.adagrad_initial_accumulator_value)
146 | elif flags.optimizer == 'adam':
147 | optimizer = tf.train.AdamOptimizer(
148 | learning_rate,
149 | beta1=flags.adam_beta1,
150 | beta2=flags.adam_beta2,
151 | epsilon=flags.opt_epsilon)
152 | elif flags.optimizer == 'ftrl':
153 | optimizer = tf.train.FtrlOptimizer(
154 | learning_rate,
155 | learning_rate_power=flags.ftrl_learning_rate_power,
156 | initial_accumulator_value=flags.ftrl_initial_accumulator_value,
157 | l1_regularization_strength=flags.ftrl_l1,
158 | l2_regularization_strength=flags.ftrl_l2)
159 | elif flags.optimizer == 'momentum':
160 | optimizer = tf.train.MomentumOptimizer(
161 | learning_rate,
162 | momentum=flags.momentum,
163 | name='Momentum')
164 | elif flags.optimizer == 'rmsprop':
165 | optimizer = tf.train.RMSPropOptimizer(
166 | learning_rate,
167 | decay=flags.rmsprop_decay,
168 | momentum=flags.rmsprop_momentum,
169 | epsilon=flags.opt_epsilon)
170 | elif flags.optimizer == 'sgd':
171 | optimizer = tf.train.GradientDescentOptimizer(learning_rate)
172 | else:
173 |         raise ValueError('Optimizer [%s] was not recognized' % flags.optimizer)
174 | return optimizer
175 |
176 |
177 | def add_variables_summaries(learning_rate):
178 | summaries = []
179 | for variable in slim.get_model_variables():
180 | summaries.append(tf.summary.histogram(variable.op.name, variable))
181 | summaries.append(tf.summary.scalar('training/Learning Rate', learning_rate))
182 | return summaries
183 |
184 |
185 | def update_model_scope(var, ckpt_scope, new_scope):
186 |     return var.op.name.replace(new_scope, 'vgg_16')
187 |
188 |
189 | def get_init_fn(flags):
190 | """Returns a function run by the chief worker to warm-start the training.
191 | Note that the init_fn is only run when initializing the model during the very
192 | first global step.
193 |
194 | Returns:
195 | An init function run by the supervisor.
196 | """
197 | if flags.checkpoint_path is None:
198 | return None
199 | # Warn the user if a checkpoint exists in the train_dir. Then ignore.
200 | if tf.train.latest_checkpoint(flags.train_dir):
201 | tf.logging.info(
202 | 'Ignoring --checkpoint_path because a checkpoint already exists in %s'
203 | % flags.train_dir)
204 | return None
205 |
206 | exclusions = []
207 | if flags.checkpoint_exclude_scopes:
208 | exclusions = [scope.strip()
209 | for scope in flags.checkpoint_exclude_scopes.split(',')]
210 |
211 | # TODO(sguada) variables.filter_variables()
212 | variables_to_restore = []
213 | for var in slim.get_model_variables():
214 | excluded = False
215 | for exclusion in exclusions:
216 | if var.op.name.startswith(exclusion):
217 | excluded = True
218 | break
219 | if not excluded:
220 | variables_to_restore.append(var)
221 | # Change model scope if necessary.
222 | if flags.checkpoint_model_scope is not None:
223 | variables_to_restore = \
224 | {var.op.name.replace(flags.model_name,
225 | flags.checkpoint_model_scope): var
226 | for var in variables_to_restore}
227 | print(variables_to_restore)
228 | if tf.gfile.IsDirectory(flags.checkpoint_path):
229 | checkpoint_path = tf.train.latest_checkpoint(flags.checkpoint_path)
230 | else:
231 | checkpoint_path = flags.checkpoint_path
232 | tf.logging.info('Fine-tuning from %s' % checkpoint_path)
233 |
234 | return slim.assign_from_checkpoint_fn(
235 | checkpoint_path,
236 | variables_to_restore,
237 | ignore_missing_vars=flags.ignore_missing_vars)
238 |
239 |
240 | def get_variables_to_train(flags):
241 | """Returns a list of variables to train.
242 |
243 | Returns:
244 | A list of variables to train by the optimizer.
245 | """
246 | if flags.trainable_scopes is None:
247 | return tf.trainable_variables()
248 | else:
249 | scopes = [scope.strip() for scope in flags.trainable_scopes.split(',')]
250 |
251 | variables_to_train = []
252 | for scope in scopes:
253 | variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
254 | variables_to_train.extend(variables)
255 | return variables_to_train
256 |
257 |
258 | # =========================================================================== #
259 | # Evaluation utils.
260 | # =========================================================================== #
261 |
--------------------------------------------------------------------------------
/tf_extended/tf_utils.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/tf_extended/tf_utils.pyc
--------------------------------------------------------------------------------
/tools/_init_paths.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import sys
3 |
4 | def add_path(path):
5 | if path not in sys.path:
6 | sys.path.insert(0, path)
7 |
8 | this_dir = osp.dirname(__file__)
9 |
10 | # Add lib to PYTHONPATH
11 | lib_path = osp.join(this_dir, '..')
12 | add_path(lib_path)
13 |
--------------------------------------------------------------------------------
/tools/convert_xml_format.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import sys
4 | import random
5 | import time
6 | import numpy as np
7 | import codecs
8 | import cv2
9 | import xml.etree.ElementTree as ET
10 | from xml.etree.ElementTree import SubElement
11 |
12 |
13 | def process_convert(name, DIRECTORY_ANNOTATIONS, img_path, save_xml_path):
14 | # Read the XML annotation file.
15 | filename = os.path.join(DIRECTORY_ANNOTATIONS, name)
16 | try:
17 | tree = ET.parse(filename)
18 |     except (IOError, ET.ParseError):
19 |         print('error:', filename, 'missing or unparseable')
20 | return False
21 |
22 | root = tree.getroot()
23 | size = root.find('size')
24 | if size is None:
25 | img = cv2.imread(img_path)
26 | print('jpg_path', img_path, img.shape)
27 | shape = [int(img.shape[0]), int(img.shape[1]), int(img.shape[2])]
28 | # size = SubElement(root, 'size')
29 |
30 | elif size.find('height').text is None or size.find('width').text is None:
31 | img = cv2.imread(img_path)
32 | print('jpg_path height', img_path, img.shape)
33 | shape = [int(img.shape[0]), int(img.shape[1]), int(img.shape[2])]
34 | elif int(size.find('height').text) == 0 or int(
35 | size.find('width').text) == 0:
36 |
37 | img = cv2.imread(img_path)
38 | print('jpg_path zero', img_path, img.shape)
39 | shape = [int(img.shape[0]), int(img.shape[1]), int(img.shape[2])]
40 | else:
41 | shape = [
42 | int(size.find('height').text),
43 | int(size.find('width').text),
44 | int(size.find('depth').text)
45 | ]
46 |
47 | height = size.find('height')
48 | height.text = str(shape[0])
49 | width = size.find('width')
50 | width.text = str(shape[1])
51 |
52 | for obj in root.findall('object'):
53 | difficult = int(obj.find('difficult').text)
54 | content = obj.find('name').text
55 | content = content.replace('\t', ' ')
56 |
57 | #if int(difficult) == 1 and content == '&*@HUST_special':
58 | '''
59 | 这里代表HUST_vertical是text
60 | '''
61 | if difficult == 0 and content != '&*@HUST_special' and content != '&*HUST_shelter':
62 | label_name = 'text'
63 | else:
64 | label_name = 'none'
65 |
66 | bbox = obj.find('bndbox')
67 | if obj.find('content') is None:
68 | content_sub = SubElement(obj, 'content')
69 | content_sub.text = content
70 | else:
71 | obj.find('content').text = content
72 |
73 | name_ele = obj.find('name')
74 | name_ele.text = label_name
75 |
76 | xmin = bbox.find('xmin').text
77 | ymin = bbox.find('ymin').text
78 | xmax = bbox.find('xmax').text
79 | ymax = bbox.find('ymax').text
80 |
81 | x1 = xmin
82 | x2 = xmax
83 | x3 = xmax
84 | x4 = xmin
85 |
86 | y1 = ymin
87 | y2 = ymin
88 | y3 = ymax
89 | y4 = ymax
90 |
91 | if bbox.find('x1') is None:
92 | x1_sub = SubElement(bbox, 'x1')
93 | x1_sub.text = x1
94 | x2_sub = SubElement(bbox, 'x2')
95 | x2_sub.text = x2
96 | x3_sub = SubElement(bbox, 'x3')
97 | x3_sub.text = x3
98 | x4_sub = SubElement(bbox, 'x4')
99 | x4_sub.text = x4
100 |
101 | y1_sub = SubElement(bbox, 'y1')
102 | y1_sub.text = y1
103 | y2_sub = SubElement(bbox, 'y2')
104 | y2_sub.text = y2
105 | y3_sub = SubElement(bbox, 'y3')
106 | y3_sub.text = y3
107 | y4_sub = SubElement(bbox, 'y4')
108 | y4_sub.text = y4
109 |         else:
110 |             bbox.find('x1').text = xmin; bbox.find('x2').text = xmax
111 |             bbox.find('x3').text = xmax; bbox.find('x4').text = xmin
112 |             bbox.find('y1').text = ymin; bbox.find('y2').text = ymin
113 |             bbox.find('y3').text = ymax; bbox.find('y4').text = ymax
114 | #print(save_xml_path)
115 | tree.write(save_xml_path)
116 | return True
117 |
118 |
119 | def process_convert_txt(name, DIRECTORY_ANNOTATIONS):
120 | # Read the XML annotation file.
121 | filename = os.path.join(DIRECTORY_ANNOTATIONS, name)
122 |     try:
123 |         tree = ET.parse(filename)
124 |     except (IOError, ET.ParseError):
125 |         print('error:', filename, 'does not exist or is not valid XML')
126 |         return
127 | root = tree.getroot()
128 | all_txt_line = []
129 | for obj in root.findall('object'):
130 |
131 | bbox = obj.find('bndbox')
132 |
133 | difficult = int(obj.find('difficult').text)
134 | content = obj.find('content')
135 | if content is not None:
136 | content = content.text
137 | else:
138 |         content = ''  # no <content> node present
139 | if difficult == 1 and content == '&*@HUST_special':
140 | continue
141 | xmin = bbox.find('xmin').text
142 | ymin = bbox.find('ymin').text
143 | xmax = bbox.find('xmax').text
144 | ymax = bbox.find('ymax').text
145 |
146 | x1 = xmin
147 | x2 = xmax
148 | x3 = xmax
149 | x4 = xmin
150 |
151 | y1 = ymin
152 | y2 = ymin
153 | y3 = ymax
154 | y4 = ymax
155 |
156 | all_txt_line.append('{} {} {} {} {} {} {} {}\n'.format(
157 | x1, y1, x2, y2, x3, y3, x4, y4))
158 |
159 | txt_name = os.path.join(DIRECTORY_ANNOTATIONS, name[:-4] + '.txt')
160 | with codecs.open(txt_name, 'w', encoding='utf-8') as f:
161 | f.writelines(all_txt_line)
162 |
163 |
164 | def get_all_img(directory, split_flag, logs_dir, output_dir):
165 | count = 0
166 | ano_path_list = []
167 | img_path_list = []
168 | if output_dir is not None and not os.path.exists(output_dir):
169 | os.makedirs(output_dir)
170 |
171 | start_time = time.time()
172 | for root, dirs, files in os.walk(directory):
173 | for each in files:
174 | if each.split('.')[-1] == 'xml':
175 | xml_path = os.path.join(root, each[:-4] + '.xml')
176 | img_path = os.path.join(root, each[:-4] + '.png')
177 |                 if not os.path.exists(img_path):
178 | img_path = os.path.join(root, each[:-4] + '.PNG')
179 | test_png = cv2.imread(img_path)
180 |                 if test_png is None or not os.path.exists(xml_path):
181 | continue
182 | if output_dir is not None:
183 | sub_path = root[len(directory)+1:]
184 | sub_path = os.path.join(output_dir, sub_path)
185 | if not os.path.exists(sub_path):
186 | os.makedirs(sub_path)
187 | save_xml_path = os.path.join(sub_path, each[:-4] + '.xml')
188 | else:
189 | save_xml_path = xml_path
190 | if process_convert(each, root, img_path, save_xml_path):
191 | ano_path_list.append('{},{}\n'.format(
192 | img_path,
193 | save_xml_path))
194 | img_path_list.append('{}\n'.format(
195 | img_path))
196 | count += 1
197 | if count % 1000 == 0:
198 | print(count, time.time() - start_time)
199 | save_to_text(img_path_list, ano_path_list, count, split_flag, logs_dir)
200 | print('all over:', count)
201 | print('time:', time.time() - start_time)
202 |
203 |
204 | def save_to_text(img_path_list, ano_path_list, count, split_flag, logs_dir):
205 | if split_flag == 'yes':
206 | train_num = int(count / 10. * 9.)
207 | else:
208 | train_num = count
209 | if not os.path.exists(logs_dir):
210 | os.makedirs(logs_dir)
211 |
212 | with codecs.open(
213 | os.path.join(logs_dir, 'train_xml.txt'), 'w',
214 | encoding='utf-8') as f_xml, codecs.open(
215 | os.path.join(logs_dir, 'train.txt'), 'w',
216 | encoding='utf-8') as f_txt:
217 | f_xml.writelines(ano_path_list[:train_num])
218 | f_txt.writelines(img_path_list[:train_num])
219 |
220 | with codecs.open(
221 | os.path.join(logs_dir, 'test_xml.txt'), 'w',
222 | encoding='utf-8') as f_xml, codecs.open(
223 | os.path.join(logs_dir, 'test.txt'), 'w',
224 | encoding='utf-8') as f_txt:
225 | f_xml.writelines(ano_path_list[train_num:])
226 | f_txt.writelines(img_path_list[train_num:])
227 |
228 |
229 | if __name__ == '__main__':
230 | import argparse
231 | parser = argparse.ArgumentParser(
232 |         description='Convert ICDAR15 annotation XMLs to the standard format')
233 | parser.add_argument(
234 | '--in_dir',
235 | '-i',
236 | default=
237 | '/home/zsz/datasets/icdar15_anno/annotated_data_3rd_8thv2_cut_resize_margin8',
238 | type=str)
239 | parser.add_argument('--split_flag', '-s', default='no', type=str)
240 | parser.add_argument('--save_logs', '-l', default='logs', type=str)
241 | parser.add_argument('--output_dir', '-o', default=None, type=str)
242 |
243 | args = parser.parse_args()
244 | directory = args.in_dir
245 | split_flag = args.split_flag
246 | logs_dir = args.save_logs
247 | output_dir = args.output_dir
248 | get_all_img(directory, split_flag, logs_dir, output_dir)
249 |
--------------------------------------------------------------------------------
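For reference, `process_convert` rewrites each `<object>` in place: the transcription is moved from `<name>` into `<content>`, `<name>` is collapsed to the class label (`text` or `none`), and eight corner coordinates are derived clockwise from the axis-aligned box. A sketch with made-up values:

```xml
<!-- before -->
<object>
  <difficult>0</difficult>
  <name>hello world</name>
  <bndbox><xmin>10</xmin><ymin>20</ymin><xmax>110</xmax><ymax>60</ymax></bndbox>
</object>

<!-- after -->
<object>
  <difficult>0</difficult>
  <name>text</name>
  <content>hello world</content>
  <bndbox>
    <xmin>10</xmin><ymin>20</ymin><xmax>110</xmax><ymax>60</ymax>
    <x1>10</x1><x2>110</x2><x3>110</x3><x4>10</x4>
    <y1>20</y1><y2>20</y2><y3>60</y3><y4>60</y4>
  </bndbox>
</object>
```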
/tools/gen_xml.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | from lxml import etree
4 | import xml.dom.minidom
5 | import sys
6 | import random
7 |
8 | import numpy as np
9 | import codecs
10 | import cv2
11 |
12 | def process_convert(txt_name, DIRECTORY_ANNOTATIONS, new_version=True):
13 |
14 | # Read the txt annotation file.
15 | filename = os.path.join(DIRECTORY_ANNOTATIONS, txt_name)
16 | with codecs.open(filename, 'r', encoding='utf-8') as f:
17 | lines = f.readlines()
18 |
19 | annotation_xml = xml.dom.minidom.Document()
20 |
21 | root = annotation_xml.createElement('annotation')
22 | annotation_xml.appendChild(root)
23 |
24 | nodeFolder = annotation_xml.createElement('folder')
25 | root.appendChild(nodeFolder)
26 |
27 | nodeFilename = annotation_xml.createElement('filename')
28 | nodeFilename.appendChild(annotation_xml.createTextNode(filename))
29 | root.appendChild(nodeFilename)
30 |
31 | img_name = filename[:-4]
32 |     img = cv2.imread(img_name)  # read the image once instead of twice
33 |     if img is None:
34 |         raise IOError('img_name error: ' + img_name)
35 |     h, w, c = img.shape
36 | nodeSize = annotation_xml.createElement('size')
37 | nodeWidth = annotation_xml.createElement('width')
38 | nodeHeight = annotation_xml.createElement('height')
39 | nodeDepth = annotation_xml.createElement('depth')
40 | nodeWidth.appendChild(annotation_xml.createTextNode(str(w)))
41 | nodeHeight.appendChild(annotation_xml.createTextNode(str(h)))
42 | nodeDepth.appendChild(annotation_xml.createTextNode(str(c)))
43 |
44 | nodeSize.appendChild(nodeWidth)
45 | nodeSize.appendChild(nodeHeight)
46 | nodeSize.appendChild(nodeDepth)
47 | root.appendChild(nodeSize)
48 | for l in lines:
49 | l = l.encode('utf-8').decode('utf-8-sig')
50 | l = l.strip().split(',')
51 | difficult = 0
52 | label = 1
53 | x1_text = ''
54 | x2_text = ''
55 | x3_text = ''
56 | x4_text = ''
57 | y1_text = ''
58 | y2_text = ''
59 | y3_text = ''
60 | y4_text = ''
61 |
62 | if new_version is True:
63 | label_name = str(l[-1])
64 | if label_name == 'none':
65 | difficult = 1
66 | else:
67 | difficult = 0
68 | label = label_name
69 | xmin_text = str(int(float(l[1])))
70 | ymin_text = str(int(float(l[2])))
71 | xmax_text = str(int(float(l[3])))
72 | ymax_text = str(int(float(l[4])))
73 | x1_text = xmin_text
74 | x2_text = xmax_text
75 | x3_text = xmax_text
76 | x4_text = xmin_text
77 |
78 | y1_text = ymin_text
79 | y2_text = ymin_text
80 | y3_text = ymax_text
81 | y4_text = ymax_text
82 | else:
83 | print(l)
84 | xs = [ int(l[i]) for i in (0, 2, 4, 6)]
85 | ys = [ int(l[i]) for i in (1, 3, 5, 7)]
86 | xmin_text = str(min(xs))
87 | ymin_text = str(min(ys))
88 | xmax_text = str(max(xs))
89 | ymax_text = str(max(ys))
90 | x1_text = str(xs[0])
91 | x2_text = str(xs[1])
92 | x3_text = str(xs[2])
93 | x4_text = str(xs[3])
94 |
95 | y1_text = str(ys[0])
96 | y2_text = str(ys[1])
97 | y3_text = str(ys[2])
98 | y4_text = str(ys[3])
99 |
100 |
101 | nodeObject = annotation_xml.createElement('object')
102 | nodeDifficult = annotation_xml.createElement('difficult')
103 | nodeDifficult.appendChild(annotation_xml.createTextNode(str(difficult)))
104 | nodeName = annotation_xml.createElement('name')
105 | nodeName.appendChild(annotation_xml.createTextNode(str(label)))
106 |
107 | nodeBndbox = annotation_xml.createElement('bndbox')
108 | nodexmin = annotation_xml.createElement('xmin')
109 | nodexmin.appendChild(annotation_xml.createTextNode(xmin_text))
110 | nodeymin = annotation_xml.createElement('ymin')
111 | nodeymin.appendChild(annotation_xml.createTextNode(ymin_text))
112 | nodexmax = annotation_xml.createElement('xmax')
113 | nodexmax.appendChild(annotation_xml.createTextNode(xmax_text))
114 | nodeymax = annotation_xml.createElement('ymax')
115 | nodeymax.appendChild(annotation_xml.createTextNode(ymax_text))
116 |
117 | nodex1 = annotation_xml.createElement('x1')
118 | nodex1.appendChild(annotation_xml.createTextNode(x1_text))
119 | nodex2 = annotation_xml.createElement('x2')
120 | nodex2.appendChild(annotation_xml.createTextNode(x2_text))
121 | nodex3 = annotation_xml.createElement('x3')
122 | nodex3.appendChild(annotation_xml.createTextNode(x3_text))
123 | nodex4 = annotation_xml.createElement('x4')
124 | nodex4.appendChild(annotation_xml.createTextNode(x4_text))
125 |
126 |
127 | nodey1 = annotation_xml.createElement('y1')
128 | nodey1.appendChild(annotation_xml.createTextNode(y1_text))
129 |
130 | nodey2 = annotation_xml.createElement('y2')
131 | nodey2.appendChild(annotation_xml.createTextNode(y2_text))
132 |
133 | nodey3 = annotation_xml.createElement('y3')
134 | nodey3.appendChild(annotation_xml.createTextNode(y3_text))
135 | nodey4 = annotation_xml.createElement('y4')
136 | nodey4.appendChild(annotation_xml.createTextNode(y4_text))
137 |
138 | nodeBndbox.appendChild(nodexmin)
139 | nodeBndbox.appendChild(nodeymin)
140 | nodeBndbox.appendChild(nodexmax)
141 | nodeBndbox.appendChild(nodeymax)
142 | nodeBndbox.appendChild(nodex1)
143 | nodeBndbox.appendChild(nodex2)
144 | nodeBndbox.appendChild(nodex3)
145 | nodeBndbox.appendChild(nodex4)
146 |
147 | nodeBndbox.appendChild(nodey1)
148 | nodeBndbox.appendChild(nodey2)
149 | nodeBndbox.appendChild(nodey3)
150 | nodeBndbox.appendChild(nodey4)
151 |
152 | nodeObject.appendChild(nodeDifficult)
153 | nodeObject.appendChild(nodeName)
154 | nodeObject.appendChild(nodeBndbox)
155 | root.appendChild(nodeObject)
156 |
157 |
158 |
159 | xml_path = os.path.join(DIRECTORY_ANNOTATIONS, txt_name[0:-4] + '.xml')
160 |     with open(xml_path, 'w') as fp:  # close the handle once writing is done
161 |         annotation_xml.writexml(fp, indent='\t', addindent='\t', newl='\n', encoding="utf-8")
162 |
163 |
164 |
165 | def get_all_txt(directory, new_version=False):
166 | count = 0
167 | for root,dirs,files in os.walk(directory):
168 | for each in files:
169 | if each.split('.')[-1] == 'txt':
170 | count += 1
171 | print(count, each)
172 |
173 | process_convert(each, root, new_version)
174 |
175 |
176 |
177 | if __name__ == '__main__':
178 | import argparse
179 |     parser = argparse.ArgumentParser(description='Generate standard-format XMLs from ICDAR15 txt annotations')
180 | parser.add_argument('--in_dir','-i', default='/home/zsz/datasets/icdar15_anno/eval_img_2018.9.25/icdar15_eval_final', type=str)
181 | args = parser.parse_args()
182 | directory = args.in_dir
183 | get_all_txt(directory, False)
184 |
--------------------------------------------------------------------------------
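The expected line formats of the input `.txt` files, as implied by the index arithmetic in `process_convert` (the trailing transcription in the old format and the meaning of the leading field in the new format are assumptions, since the parser ignores them):

```
# old format (new_version=False, the default used in __main__):
# x1,y1,x2,y2,x3,y3,x4,y4[,transcription]  -- one quadrilateral per line
377,117,463,117,465,130,378,130

# new format (new_version=True): field 0 is skipped, fields 1-4 are
# xmin,ymin,xmax,ymax, and the last field is the label ('none' => difficult)
0,377,117,465,130,text
```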
/tools/test_dataset.py:
--------------------------------------------------------------------------------
1 | from datasets import sythtextprovider
2 | import tensorflow as tf
3 | import numpy as np
4 | import cv2
5 | import os
6 | import time
7 | from nets import txtbox_384
8 | from processing import ssd_vgg_preprocessing
9 |
10 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
11 | slim = tf.contrib.slim
12 |
13 | tf.logging.set_verbosity(tf.logging.INFO)
14 |
15 | show_pic_sum = 10
16 | save_dir = 'pic_test_dataset'
17 | tf.app.flags.DEFINE_string(
18 | 'dataset_dir', 'tfrecord_train', 'The directory where the dataset files are stored.')
19 |
20 | tf.app.flags.DEFINE_integer(
21 | 'num_readers', 2,
22 | 'The number of parallel readers that read data from the dataset.')
23 |
24 | tf.app.flags.DEFINE_integer(
25 | 'batch_size', 2, 'The number of samples in each batch.')
26 |
27 | FLAGS = tf.app.flags.FLAGS
28 | if not os.path.exists(save_dir):
29 | os.makedirs(save_dir)
30 |
31 | def draw_polygon(img,x1,y1,x2,y2,x3,y3,x4,y4, color=(255, 0, 0)):
32 | # print(x1, x2, x3, x4, y1, y2, y3, y4)
33 | x1 = int(x1)
34 | x2 = int(x2)
35 | x3 = int(x3)
36 | x4 = int(x4)
37 |
38 | y1 = int(y1)
39 | y2 = int(y2)
40 | y3 = int(y3)
41 | y4 = int(y4)
42 | cv2.line(img,(x1,y1),(x2,y2),color,2)
43 | cv2.line(img,(x2,y2),(x3,y3),color,2)
44 | cv2.line(img,(x3,y3),(x4,y4),color,2)
45 | cv2.line(img,(x4,y4),(x1,y1),color,2)
46 | # cv2_im = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
47 | # cv2_im = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
48 | # cv2.imwrite('test.png', img)
49 | return img
50 |
51 |
52 | def run():
53 | if not FLAGS.dataset_dir:
54 | raise ValueError('You must supply the dataset directory with --dataset_dir')
55 |
56 | print('-----start test-------')
57 | if not os.path.exists(save_dir):
58 | os.makedirs(save_dir)
59 | with tf.device('/GPU:0'):
60 | dataset = sythtextprovider.get_datasets(FLAGS.dataset_dir)
61 | print(dataset)
62 | provider = slim.dataset_data_provider.DatasetDataProvider(
63 | dataset,
64 | num_readers=FLAGS.num_readers,
65 | common_queue_capacity=20 * FLAGS.batch_size,
66 | common_queue_min=10 * FLAGS.batch_size,
67 | shuffle=True)
68 | print('provider:',provider)
69 | [image, shape, glabels, gbboxes, x1, x2, x3, x4, y1, y2, y3, y4] = provider.get(['image', 'shape',
70 | 'object/label',
71 | 'object/bbox',
72 | 'object/oriented_bbox/x1',
73 | 'object/oriented_bbox/x2',
74 | 'object/oriented_bbox/x3',
75 | 'object/oriented_bbox/x4',
76 | 'object/oriented_bbox/y1',
77 | 'object/oriented_bbox/y2',
78 | 'object/oriented_bbox/y3',
79 | 'object/oriented_bbox/y4'
80 | ])
81 | print('image:',image)
82 | print('shape:',shape)
83 | print('glabel:',glabels)
84 | print('gboxes:',gbboxes)
85 |
86 |
87 | gxs = tf.transpose(tf.stack([x1,x2,x3,x4])) #shape = (N,4)
88 | gys = tf.transpose(tf.stack([y1, y2, y3, y4]))
89 |
90 | image = tf.identity(image, 'input_image')
91 | text_shape = (384, 384)
92 | image, glabels, gbboxes, gxs, gys= ssd_vgg_preprocessing.preprocess_image(image, glabels,gbboxes,gxs, gys,
93 | text_shape,is_training=True,
94 | data_format='NHWC')
95 |
96 |
97 | x1, x2 , x3, x4 = tf.unstack(gxs, axis=1)
98 | y1, y2, y3, y4 = tf.unstack(gys, axis=1)
99 |
100 | text_net = txtbox_384.TextboxNet()
101 | text_anchors = text_net.anchors(text_shape)
102 | e_localisations, e_scores, e_labels = text_net.bboxes_encode( glabels, gbboxes, text_anchors, gxs, gys)
103 |
104 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.7)
105 |
106 | config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options, allow_soft_placement=True)
107 | with tf.Session(config=config) as sess:
108 | coord = tf.train.Coordinator()
109 | threads = tf.train.start_queue_runners(sess, coord)
110 | j = 0
111 | all_time = 0
112 | try:
113 | while not coord.should_stop() and j < show_pic_sum:
114 | start_time = time.time()
115 | image_sess, label_sess, gbbox_sess, x1_sess, x2_sess, x3_sess, x4_sess, y1_sess, y2_sess, y3_sess, y4_sess,p_localisations, p_scores, p_labels = sess.run([
116 | image, glabels, gbboxes, x1, x2, x3, x4, y1, y2, y3, y4,e_localisations , e_scores, e_labels])
117 | end_time = time.time() - start_time
118 | all_time += end_time
119 | image_np = image_sess
120 | # print(image_np)
121 | # print('label_sess:',label_sess)
122 |
123 | p_labels_concat = np.concatenate(p_labels)
124 | p_scores_concat = np.concatenate(p_scores)
125 | debug = False
126 | if debug is True:
127 | print(p_labels)
128 |                     print('p_labels:', len(p_labels_concat[p_labels_concat.nonzero()]), p_labels_concat[p_labels_concat.nonzero()])
129 |                     print('p_scores:', len(p_scores_concat[p_scores_concat.nonzero()]), p_scores_concat[p_scores_concat.nonzero()])
130 | # print(img_np.shape)
131 |
132 | print('label_sess:', np.array(list(label_sess)).shape, list(label_sess))
133 | img_np = np.array(image_np)
134 | cv2.imwrite('{}/{}.png'.format(save_dir, j), img_np)
135 | img_np = cv2.imread('{}/{}.png'.format(save_dir, j))
136 |
137 | h, w, d = img_np.shape
138 |
139 | label_sess = list(label_sess)
140 | # for i , label in enumerate(label_sess):
141 | i = 0
142 | num_correct = 0
143 |
144 | for label in label_sess:
145 | # print(int(label) == 1)
146 | if int(label) == 1:
147 | num_correct += 1
148 | img_np = draw_polygon(img_np,x1_sess[i] * w, y1_sess[i]*h, x2_sess[i]*w, y2_sess[i]*h, x3_sess[i]*w, y3_sess[i]*h, x4_sess[i]*w, y4_sess[i]*h)
149 | if int(label) == 0:
150 | img_np = draw_polygon(img_np,x1_sess[i] * w, y1_sess[i]*h, x2_sess[i]*w, y2_sess[i]*h, x3_sess[i]*w, y3_sess[i]*h, x4_sess[i]*w, y4_sess[i]*h, color=(0, 0, 255))
151 | i += 1
152 | img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)
153 | cv2.imwrite('{}'.format(os.path.join(save_dir, str(j)+'.png')), img_np)
154 | j+= 1
155 | print('correct:', num_correct)
156 |             except tf.errors.OutOfRangeError:
157 |                 print('dataset exhausted')
158 |             finally:
159 |                 print('stopping queue runners')
160 | coord.request_stop()
161 | print('all time:', all_time, 'average:', all_time / show_pic_sum)
162 | coord.join(threads=threads)
163 |
164 | if __name__ == '__main__':
165 | run()
166 |
--------------------------------------------------------------------------------
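A quick sanity check of freshly generated tfrecords with the script above (the `--dataset_dir` value is illustrative; the first 10 preprocessed samples and their encoded ground-truth polygons are written to `pic_test_dataset/0.png` ... `9.png`):

```
python tools/test_dataset.py --dataset_dir=tfrecords --batch_size=2 --num_readers=2
```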
/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Paul Balanca. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Generic training script that trains a SSD model using a given dataset."""
16 |
17 | import tensorflow as tf
18 | from tensorflow.python.ops import control_flow_ops
19 |
20 | from datasets import sythtextprovider
21 | from deployment import model_deploy
22 | from nets import txtbox_384, txtbox_768
23 | from processing import ssd_vgg_preprocessing
24 | from tf_extended import tf_utils
25 | import os
26 | import tensorflow.contrib.slim as slim
27 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
28 |
29 | # =========================================================================== #
30 | # Text Network flags.
31 | # =========================================================================== #
32 | tf.app.flags.DEFINE_float('loss_alpha', 0.2,
33 | 'Alpha parameter in the loss function.')
34 | tf.app.flags.DEFINE_float('negative_ratio', 3.,
35 | 'Negative ratio in the loss function.')
36 | tf.app.flags.DEFINE_float('match_threshold', 0.5,
37 | 'Matching threshold in the loss function.')
38 | tf.app.flags.DEFINE_boolean('large_training', False, 'Use 768 to train')
39 | # =========================================================================== #
40 | # General Flags.
41 | # =========================================================================== #
42 | tf.app.flags.DEFINE_string(
43 | 'train_dir',
44 | 'icdar15_model/',
45 | 'Directory where checkpoints and event logs are written to.')
46 | tf.app.flags.DEFINE_integer('num_clones', 1,
47 | 'Number of model clones to deploy.')
48 | tf.app.flags.DEFINE_boolean('clone_on_cpu', False,
49 | 'Use CPUs to deploy clones.')
50 | tf.app.flags.DEFINE_integer(
51 | 'num_readers', 8,
52 | 'The number of parallel readers that read data from the dataset.')
53 | tf.app.flags.DEFINE_integer(
54 | 'num_preprocessing_threads', 8,
55 | 'The number of threads used to create the batches.')
56 |
57 | tf.app.flags.DEFINE_integer('log_every_n_steps', 10,
58 |                             'The frequency with which logs are printed.')
59 | tf.app.flags.DEFINE_integer(
60 | 'save_summaries_secs', 120,
61 | 'The frequency with which summaries are saved, in seconds.')
62 | tf.app.flags.DEFINE_integer(
63 | 'save_interval_secs', 1200,
64 | 'The frequency with which the model is saved, in seconds.')
65 | tf.app.flags.DEFINE_float('gpu_memory_fraction', 0.9,
66 | 'GPU memory fraction to use.')
67 |
68 | # =========================================================================== #
69 | # Optimization Flags.
70 | # =========================================================================== #
71 | tf.app.flags.DEFINE_float('weight_decay', 0.0005,
72 | 'The weight decay on the model weights.')
73 | tf.app.flags.DEFINE_string(
74 | 'optimizer', 'adam',
75 | 'The name of the optimizer, one of "adadelta", "adagrad", "adam",'
76 | '"ftrl", "momentum", "sgd" or "rmsprop".')
77 | tf.app.flags.DEFINE_float('adadelta_rho', 0.95, 'The decay rate for adadelta.')
78 | tf.app.flags.DEFINE_float('adagrad_initial_accumulator_value', 0.1,
79 | 'Starting value for the AdaGrad accumulators.')
80 | tf.app.flags.DEFINE_float(
81 | 'adam_beta1', 0.9,
82 | 'The exponential decay rate for the 1st moment estimates.')
83 | tf.app.flags.DEFINE_float(
84 | 'adam_beta2', 0.999,
85 | 'The exponential decay rate for the 2nd moment estimates.')
86 | tf.app.flags.DEFINE_float('opt_epsilon', 1.0,
87 | 'Epsilon term for the optimizer.')
88 | tf.app.flags.DEFINE_float('ftrl_learning_rate_power', -0.5,
89 | 'The learning rate power.')
90 | tf.app.flags.DEFINE_float('ftrl_initial_accumulator_value', 0.1,
91 | 'Starting value for the FTRL accumulators.')
92 | tf.app.flags.DEFINE_float('ftrl_l1', 0.0,
93 | 'The FTRL l1 regularization strength.')
94 | tf.app.flags.DEFINE_float('ftrl_l2', 0.0,
95 | 'The FTRL l2 regularization strength.')
96 | tf.app.flags.DEFINE_float(
97 | 'momentum', 0.9,
98 | 'The momentum for the MomentumOptimizer and RMSPropOptimizer.')
99 | tf.app.flags.DEFINE_float('rmsprop_momentum', 0.9, 'Momentum.')
100 | tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.')
101 |
102 | # =========================================================================== #
103 | # Learning Rate Flags.
104 | # =========================================================================== #
105 | tf.app.flags.DEFINE_string(
106 | 'learning_rate_decay_type', 'exponential',
107 | 'Specifies how the learning rate is decayed. One of "fixed", "exponential",'
108 | ' or "polynomial"')
109 | tf.app.flags.DEFINE_float('learning_rate', 0.0001, 'Initial learning rate.')
110 | tf.app.flags.DEFINE_float(
111 | 'end_learning_rate', 0.0001,
112 | 'The minimal end learning rate used by a polynomial decay learning rate.')
113 | tf.app.flags.DEFINE_float('label_smoothing', 0.0,
114 | 'The amount of label smoothing.')
115 | tf.app.flags.DEFINE_float('learning_rate_decay_factor', 0.1,
116 | 'Learning rate decay factor.')
117 | tf.app.flags.DEFINE_float(
118 | 'num_epochs_per_decay', 40000,
119 | 'Number of epochs after which learning rate decays.')
120 | tf.app.flags.DEFINE_float(
121 | 'moving_average_decay', None, 'The decay to use for the moving average.'
122 | 'If left as None, then moving averages are not used.')
123 |
124 | # =========================================================================== #
125 | # Dataset Flags.
126 | # =========================================================================== #
127 | tf.app.flags.DEFINE_string('dataset_name', 'sythtext',
128 | 'The name of the dataset to load.')
129 | tf.app.flags.DEFINE_integer('num_classes', 2,
130 | 'Number of classes to use in the dataset.')
131 | tf.app.flags.DEFINE_string('dataset_split_name', 'train',
132 | 'The name of the train/test split.')
133 | tf.app.flags.DEFINE_string(
134 | 'dataset_dir', 'icdar15_tf',
135 |     'The directory where the dataset files are stored.')
136 | tf.app.flags.DEFINE_integer(
137 | 'labels_offset', 0,
138 | 'An offset for the labels in the dataset. This flag is primarily used to '
139 | 'evaluate the VGG and ResNet architectures which do not use a background '
140 | 'class for the ImageNet dataset.')
141 | tf.app.flags.DEFINE_string('model_name', 'text_box_384',
142 | 'The name of the architecture to train.')
143 | tf.app.flags.DEFINE_string(
144 | 'preprocessing_name', None,
145 | 'The name of the preprocessing to use. If left '
146 | 'as `None`, then the model_name flag is used.')
147 | tf.app.flags.DEFINE_integer('batch_size', 16,
148 | 'The number of samples in each batch.')
149 | tf.app.flags.DEFINE_integer('train_image_size', None, 'Train image size')
150 | tf.app.flags.DEFINE_string('training_image_crop_area', '0.1, 1.0',
151 |                            'The crop area range used for image preprocessing in training.')
152 | tf.app.flags.DEFINE_integer('max_number_of_steps', 120000,
153 |                             'The maximum number of training steps.')
154 | # =========================================================================== #
155 | # Fine-Tuning Flags.
156 | # =========================================================================== #
157 | tf.app.flags.DEFINE_string(
158 | #'checkpoint_path','/home/zsz/code/TextBoxes_plusplus_Tensorflow/model/vgg_fc_16_model/vgg_16.ckpt',
159 | 'checkpoint_path', '/home/zsz/TextBoxes_plusplus/models/ckpt/',
160 | 'The path to a checkpoint from which to fine-tune.')
161 | tf.app.flags.DEFINE_string(
162 | 'checkpoint_model_scope', None,
163 | 'Model scope in the checkpoint. None if the same as the trained model.')
164 | tf.app.flags.DEFINE_string(
165 | 'checkpoint_exclude_scopes', None,
166 | 'Comma-separated list of scopes of variables to exclude when restoring '
167 | 'from a checkpoint.')
168 | tf.app.flags.DEFINE_string(
169 | 'trainable_scopes', None,
170 | 'Comma-separated list of scopes to filter the set of variables to train.'
171 | 'By default, None would train all the variables.')
172 | tf.app.flags.DEFINE_boolean(
173 | 'ignore_missing_vars', False,
174 | 'When restoring a checkpoint would ignore missing variables.')
175 |
176 | FLAGS = tf.app.flags.FLAGS
177 |
178 |
179 | # =========================================================================== #
180 | # Main training routine.
181 | # =========================================================================== #
182 | def main(_):
183 | if not FLAGS.dataset_dir:
184 | raise ValueError(
185 | 'You must supply the dataset directory with --dataset_dir')
186 |
187 | tf.logging.set_verbosity(tf.logging.DEBUG)
188 | with tf.Graph().as_default():
189 | # Config model_deploy. Keep TF Slim Models structure.
190 |         # Useful if we want to use multiple GPUs and/or servers in the future.
191 | deploy_config = model_deploy.DeploymentConfig(
192 | num_clones=FLAGS.num_clones,
193 | clone_on_cpu=FLAGS.clone_on_cpu,
194 | replica_id=0,
195 | num_replicas=1,
196 | num_ps_tasks=0)
197 | # Create global_step.
198 | with tf.device(deploy_config.variables_device()):
199 | global_step = slim.create_global_step()
200 |
201 | # Select the dataset.
202 |
203 | dataset = sythtextprovider.get_datasets(FLAGS.dataset_dir)
204 | # Get the TextBoxes++ network and its anchors.
205 | text_net = txtbox_384.TextboxNet()
206 | if FLAGS.large_training:
207 | text_net.params = text_net.params._replace(img_shape = (768, 768))
208 | text_net.params = text_net.params._replace(feat_shapes = [(96, 96), (48,48), (24, 24), (12, 12), (10, 10), (8, 8)])
209 | text_shape = text_net.params.img_shape
210 | print('text_shape ' + str(text_shape))
211 | text_anchors = text_net.anchors(text_shape)
212 |
213 | tf_utils.print_configuration(FLAGS.__flags, text_net.params,
214 | dataset.data_sources, FLAGS.train_dir)
215 | # =================================================================== #
216 | # Create a dataset provider and batches.
217 | # =================================================================== #
218 |
219 | with tf.device(deploy_config.inputs_device()):
220 | with tf.name_scope(FLAGS.dataset_name + '_data_provider'):
221 | provider = slim.dataset_data_provider.DatasetDataProvider(
222 | dataset,
223 | num_readers=FLAGS.num_readers,
224 | common_queue_capacity=1000 * FLAGS.batch_size,
225 | common_queue_min=300 * FLAGS.batch_size,
226 | shuffle=True)
227 | # Get for SSD network: image, labels, bboxes.
228 | [image, shape, glabels, gbboxes, x1, x2, x3, x4, y1, y2, y3,
229 | y4] = provider.get([
230 | 'image', 'shape', 'object/label', 'object/bbox',
231 | 'object/oriented_bbox/x1', 'object/oriented_bbox/x2',
232 | 'object/oriented_bbox/x3', 'object/oriented_bbox/x4',
233 | 'object/oriented_bbox/y1', 'object/oriented_bbox/y2',
234 | 'object/oriented_bbox/y3', 'object/oriented_bbox/y4'
235 | ])
236 | gxs = tf.transpose(tf.stack([x1, x2, x3, x4])) #shape = (N,4)
237 | gys = tf.transpose(tf.stack([y1, y2, y3, y4]))
238 |
239 | image = tf.identity(image, 'input_image')
240 |
241 | init_op = tf.global_variables_initializer()
242 | # tf.global_variables_initializer()
243 | # Pre-processing image, labels and bboxes.
244 | training_image_crop_area = FLAGS.training_image_crop_area
245 | area_split = training_image_crop_area.split(',')
246 | assert len(area_split) == 2
247 | training_image_crop_area = [
248 | float(area_split[0]),
249 | float(area_split[1])
250 | ]
251 |
252 | image, glabels, gbboxes, gxs, gys= \
253 | ssd_vgg_preprocessing.preprocess_for_train(image, glabels,gbboxes,gxs, gys,
254 | text_shape,
255 | data_format='NHWC', crop_area_range=training_image_crop_area)
256 |
257 | # Encode groundtruth labels and bboxes.
258 |
259 | image = tf.identity(image, 'processed_image')
260 |
261 | glocalisations, gscores, glabels = \
262 | text_net.bboxes_encode( glabels, gbboxes, text_anchors, gxs, gys)
263 | batch_shape = [1] + [len(text_anchors)] * 3
264 |
265 | # Training batches and queue.
266 |
267 | r = tf.train.batch(
268 | tf_utils.reshape_list(
269 | [image, glocalisations, gscores, glabels]),
270 | batch_size=FLAGS.batch_size,
271 | num_threads=FLAGS.num_preprocessing_threads,
272 | capacity=5 * FLAGS.batch_size)
273 |
274 | b_image, b_glocalisations, b_gscores, b_glabels= \
275 | tf_utils.reshape_list(r, batch_shape)
276 |
277 | # Intermediate queueing: unique batch computation pipeline for all
278 | # GPUs running the training.
279 | batch_queue = slim.prefetch_queue.prefetch_queue(
280 | tf_utils.reshape_list(
281 | [b_image, b_glocalisations, b_gscores, b_glabels]),
282 | capacity=2 * deploy_config.num_clones)
283 |
284 | # =================================================================== #
285 | # Define the model running on every GPU.
286 | # =================================================================== #
287 | def clone_fn(batch_queue):
288 |
289 | #Allows data parallelism by creating multiple
290 | #clones of network_fn.
291 | # Dequeue batch.
292 | b_image, b_glocalisations, b_gscores, b_glabels = \
293 | tf_utils.reshape_list(batch_queue.dequeue(), batch_shape)
294 |
295 | # Construct TextBoxes network.
296 | arg_scope = text_net.arg_scope(weight_decay=FLAGS.weight_decay)
297 | with slim.arg_scope(arg_scope):
298 | predictions,localisations, logits, end_points = \
299 | text_net.net(b_image, is_training=True)
300 | # Add loss function.
301 |
302 | text_net.losses(
303 | logits,
304 | localisations,
305 | b_glabels,
306 | b_glocalisations,
307 | b_gscores,
308 | match_threshold=FLAGS.match_threshold,
309 | negative_ratio=FLAGS.negative_ratio,
310 | alpha=FLAGS.loss_alpha,
311 | label_smoothing=FLAGS.label_smoothing,
312 | batch_size=FLAGS.batch_size)
313 | return end_points
314 |
315 | # Gather initial summaries.
316 | summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
317 |
318 | # =================================================================== #
319 | # Add summaries from first clone.
320 | # =================================================================== #
321 | clones = model_deploy.create_clones(deploy_config, clone_fn,
322 | [batch_queue])
323 | first_clone_scope = deploy_config.clone_scope(0)
324 | # Gather update_ops from the first clone. These contain, for example,
325 | # the updates for the batch_norm variables created by network_fn.
326 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
327 |
328 | # Add summaries for end_points.
329 | end_points = clones[0].outputs
330 | for end_point in end_points:
331 | x = end_points[end_point]
332 | summaries.add(tf.summary.histogram('activations/' + end_point, x))
333 | summaries.add(
334 | tf.summary.scalar('sparsity/' + end_point,
335 | tf.nn.zero_fraction(x)))
336 | # Add summaries for losses and extra losses.
337 | for loss in tf.get_collection(tf.GraphKeys.LOSSES):
338 | summaries.add(tf.summary.scalar(loss.op.name, loss))
339 | for loss in tf.get_collection('EXTRA_LOSSES'):
340 | summaries.add(tf.summary.scalar(loss.op.name, loss))
341 |
342 | # Add summaries for variables.
343 | for variable in slim.get_model_variables():
344 | summaries.add(tf.summary.histogram(variable.op.name, variable))
345 |
346 | # =================================================================== #
347 | # Configure the moving averages.
348 | # =================================================================== #
349 | if FLAGS.moving_average_decay:
350 | moving_average_variables = slim.get_model_variables()
351 | variable_averages = tf.train.ExponentialMovingAverage(
352 | FLAGS.moving_average_decay, global_step)
353 | else:
354 | moving_average_variables, variable_averages = None, None
355 |
356 | # =================================================================== #
357 | # Configure the optimization procedure.
358 | # =================================================================== #
359 | with tf.device(deploy_config.optimizer_device()):
360 | learning_rate = tf_utils.configure_learning_rate(
361 | FLAGS, dataset.num_samples, global_step)
362 | optimizer = tf_utils.configure_optimizer(FLAGS, learning_rate)
363 | summaries.add(tf.summary.scalar('learning_rate', learning_rate))
364 |
365 | if FLAGS.moving_average_decay:
366 | # Update ops executed locally by trainer.
367 | update_ops.append(
368 | variable_averages.apply(moving_average_variables))
369 |
370 | # Variables to train.
371 | variables_to_train = tf_utils.get_variables_to_train(FLAGS)
372 |
373 |         # optimize_clones computes the total loss and the gradients over all clones.
374 | total_loss, clones_gradients = model_deploy.optimize_clones(
375 | clones, optimizer, var_list=variables_to_train)
376 | # Add total_loss to summary.
377 |
378 | summaries.add(tf.summary.scalar('total_loss', total_loss))
379 |
380 | # Create gradient updates.
381 | grad_updates = optimizer.apply_gradients(
382 | clones_gradients, global_step=global_step)
383 | update_ops.append(grad_updates)
384 | update_op = tf.group(*update_ops)
385 | train_tensor = control_flow_ops.with_dependencies(
386 | [update_op], total_loss, name='train_op')
387 |
388 | # Add the summaries from the first clone. These contain the summaries
389 | summaries |= set(
390 | tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope))
391 | # Merge all summaries together.
392 | summary_op = tf.summary.merge(list(summaries), name='summary_op')
393 |
394 | # =================================================================== #
395 | # Kicks off the training.
396 | # =================================================================== #
397 | gpu_options = tf.GPUOptions(
398 | per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)
399 |
400 | config = tf.ConfigProto(
401 | log_device_placement=False,
402 | allow_soft_placement=True,
403 | gpu_options=gpu_options)
404 | saver = tf.train.Saver(
405 | max_to_keep=100,
406 | keep_checkpoint_every_n_hours=1.0,
407 | write_version=2,
408 | pad_step_number=False)
409 |
410 | slim.learning.train(
411 | train_tensor,
412 | logdir=FLAGS.train_dir,
413 | master='',
414 | is_chief=True,
415 | # init_op=init_op,
416 | init_fn=tf_utils.get_init_fn(FLAGS),
417 | summary_op=summary_op, ##output variables to logdir
418 | number_of_steps=FLAGS.max_number_of_steps,
419 | log_every_n_steps=FLAGS.log_every_n_steps,
420 | save_summaries_secs=FLAGS.save_summaries_secs,
421 | saver=saver,
422 | save_interval_secs=FLAGS.save_interval_secs,
423 | session_config=config,
424 | sync_optimizer=None)
425 |
426 |
427 | if __name__ == '__main__':
428 | tf.app.run()
429 |
--------------------------------------------------------------------------------
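Putting the fine-tuning flags together: a hypothetical run that restores a checkpoint but updates only selected scopes (the scope name below is a placeholder; real scope names must match those created by `nets/txtbox_384.py`, and `get_variables_to_train` in `tf_extended/tf_utils.py` does the filtering):

```
python train.py \
    --checkpoint_path=model/vgg_16.ckpt \
    --checkpoint_exclude_scopes=some_hypothetical_scope \
    --trainable_scopes=some_hypothetical_scope \
    --ignore_missing_vars=True
```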