├── README.md
├── datasets
│   ├── __init__.py
│   ├── __init__.pyc
│   ├── dataset_utils.py
│   ├── dataset_utils.pyc
│   ├── synthtext_to_tfrecords_self.py
│   ├── sythtextprovider.py
│   ├── sythtextprovider.pyc
│   ├── xml_to_tfrecords.py
│   └── xml_to_tfrecords.pyc
├── demo.py
├── demo
│   ├── example
│   │   ├── image0.txt
│   │   ├── image0.xml
│   │   └── standard.xml
│   └── img_1.jpg
├── deployment
│   ├── __init__.py
│   ├── __init__.pyc
│   ├── model_deploy.py
│   ├── model_deploy.pyc
│   └── model_deploy_test.py
├── eval.py
├── eval_result.py
├── gene_tfrecords.py
├── logs
│   └── train_xml.txt
├── nets
│   ├── __init__.py
│   ├── __init__.pyc
│   ├── __pycache__
│   │   ├── __init__.cpython-35.pyc
│   │   ├── custom_layers.cpython-35.pyc
│   │   ├── np_methods.cpython-35.pyc
│   │   ├── textbox_common.cpython-35.pyc
│   │   ├── txtbox_384.cpython-35.pyc
│   │   └── txtbox_768.cpython-35.pyc
│   ├── custom_layers.py
│   ├── custom_layers.pyc
│   ├── np_methods.py
│   ├── np_methods.pyc
│   ├── textbox_common.py
│   ├── textbox_common.pyc
│   ├── txtbox_384.py
│   ├── txtbox_384.pyc
│   ├── txtbox_768.py
│   └── txtbox_768.pyc
├── processing
│   ├── __init__.py
│   ├── __init__.pyc
│   ├── __pycache__
│   │   ├── __init__.cpython-35.pyc
│   │   ├── ssd_vgg_preprocessing.cpython-35.pyc
│   │   └── tf_image.cpython-35.pyc
│   ├── ssd_vgg_preprocessing.py
│   ├── ssd_vgg_preprocessing.pyc
│   ├── tf_image.py
│   └── tf_image.pyc
├── test.py
├── tf_extended
│   ├── __init__.py
│   ├── __init__.pyc
│   ├── __pycache__
│   │   ├── __init__.cpython-35.pyc
│   │   ├── bboxes.cpython-35.pyc
│   │   ├── image.cpython-35.pyc
│   │   ├── math.cpython-35.pyc
│   │   ├── metrics.cpython-35.pyc
│   │   └── tensors.cpython-35.pyc
│   ├── bboxes.py
│   ├── bboxes.pyc
│   ├── image.py
│   ├── image.pyc
│   ├── math.py
│   ├── math.pyc
│   ├── metrics.py
│   ├── metrics.pyc
│   ├── tensors.py
│   ├── tensors.pyc
│   ├── tf_utils.py
│   └── tf_utils.pyc
├── tools
│   ├── _init_paths.py
│   ├── convert_xml_format.py
│   ├── gen_xml.py
│   └── test_dataset.py
└── train.py

/README.md:
--------------------------------------------------------------------------------
1 | # TextBoxes++-TensorFlow
2 | A TensorFlow re-implementation of TextBoxes++.
3 | This project is greatly inspired by the [slim project](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim),
4 | and many functions are adapted from the [SSD-Tensorflow project](https://github.com/balancap/SSD-Tensorflow).
5 | 
6 | Author:
7 | Zhisheng Zou zzsshun13@gmail.com
8 | 
9 | # pretrained model
10 | 1. [Google drive](https://drive.google.com/open?id=1kkRyVrx9iFtwEar6OJBKWNVyTLSYsF28)
11 | 
12 | # environment
13 | `python2.7/python3.5`
14 | 
15 | `tensorflow-gpu 1.8.0`
16 | 
17 | `at least one gpu`
18 | 
19 | # how to use
20 | 
21 | 1. Prepare an XML annotation like this [example xml](./demo/example/image0.xml) and keep it next to its image, because the later steps expect the format of this [standard xml](./demo/example/standard.xml)
22 |    1. picture format: *.png or *.PNG
23 | 2. Convert the XMLs and choose the flags
24 |    Make sure each XML file is in the same directory as its corresponding image, then execute [convert_xml_format.py](./tools/convert_xml_format.py):
25 |    1. `python tools/convert_xml_format.py -i in_dir -s split_flag -l save_logs -o output_dir`
26 |    2. in_dir is the absolute directory that contains the images and XMLs
27 |    3. split_flag controls whether to split the dataset
28 |    4. save_logs controls whether to save train_xml.txt
29 |    5. output_dir is where the converted XMLs are saved
30 | 3. Generate the tfrecords (you can sanity-check the output with the snippet after this list)
31 |    1. `python gene_tfrecords.py --xml_img_txt_path=./logs/train_xml.txt --output_dir=tfrecords`
32 |    2. xml_img_txt_path is a txt file like this [train xml](./logs/train_xml.txt)
33 |    3. output_dir is where the tfrecords are saved
34 | 4. Training
35 |    1. `python train.py --train_dir=some_path --dataset_dir=some_path --checkpoint_path=some_path`
36 |    2. train_dir stores the checkpoints produced during training
37 |    3. dataset_dir holds the tfrecords used for training
38 |    4. checkpoint_path points to the model that should be fine-tuned
39 | 5. Testing
40 |    1. `python test.py -m /home/model.ckpt-858 -o test`
41 |    2. -m is the model checkpoint
42 |    3. -o is the output result directory
43 |    4. -i is the test image directory
44 |    5. -c is the device on which to run the test
45 |    6. -n is the NMS threshold
46 |    7. -s is the score threshold
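47 | 
48 | To sanity-check the generated tfrecords before training, you can count the serialized examples. This is a minimal sketch (the glob pattern assumes the default `--output_dir=tfrecords` and the `.tfrecord` suffix used above):
49 | 
50 | ```python
51 | import glob
52 | import tensorflow as tf  # written against tensorflow-gpu 1.8.0
53 | 
54 | # Iterate over every shard and count the tf.train.Example records it holds.
55 | total = 0
56 | for shard in glob.glob('tfrecords/*.tfrecord'):
57 |     for _ in tf.python_io.tf_record_iterator(shard):
58 |         total += 1
59 | print('number of examples:', total)
60 | ```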
61 | 
62 | 
63 | # Note:
64 | 
65 | 1. While the model is training, you can run eval_result.py to evaluate the latest checkpoint and save the results
66 | 
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/datasets/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/datasets/__init__.pyc
--------------------------------------------------------------------------------
/datasets/dataset_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains utilities for downloading and converting datasets."""
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 | 
20 | import os
21 | import sys
22 | import tarfile
23 | 
24 | from six.moves import urllib
25 | import tensorflow as tf
26 | 
27 | LABELS_FILENAME = 'labels.txt'
28 | def norm(x):
29 |     if x < 0:
30 |         x = 0
31 |     else:
32 |         if x > 1:
33 |             x = 1
34 |     return x
35 | 
36 | def int64_feature(value):
37 |     """Wrapper for inserting int64 features into Example proto.
38 |     """
39 |     if not isinstance(value, list):
40 |         value = [value]
41 |     return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
42 | 
43 | 
44 | def float_feature(value):
45 |     """Wrapper for inserting float features into Example proto.
46 |     """
47 |     if not isinstance(value, list):
48 |         value = [value]
49 |     return tf.train.Feature(float_list=tf.train.FloatList(value=value))
50 | 
51 | 
52 | def bytes_feature(value):
53 |     """Wrapper for inserting bytes features into Example proto.
54 | """ 55 | if not isinstance(value, list): 56 | value = [value] 57 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) 58 | 59 | 60 | def image_to_tfexample(image_data, image_format, height, width, class_id): 61 | return tf.train.Example(features=tf.train.Features(feature={ 62 | 'image/encoded': bytes_feature(image_data), 63 | 'image/format': bytes_feature(image_format), 64 | 'image/class/label': int64_feature(class_id), 65 | 'image/height': int64_feature(height), 66 | 'image/width': int64_feature(width), 67 | })) 68 | 69 | 70 | def download_and_uncompress_tarball(tarball_url, dataset_dir): 71 | """Downloads the `tarball_url` and uncompresses it locally. 72 | 73 | Args: 74 | tarball_url: The URL of a tarball file. 75 | dataset_dir: The directory where the temporary files are stored. 76 | """ 77 | filename = tarball_url.split('/')[-1] 78 | filepath = os.path.join(dataset_dir, filename) 79 | 80 | def _progress(count, block_size, total_size): 81 | sys.stdout.write('\r>> Downloading %s %.1f%%' % ( 82 | filename, float(count * block_size) / float(total_size) * 100.0)) 83 | sys.stdout.flush() 84 | filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress) 85 | statinfo = os.stat(filepath) 86 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 87 | tarfile.open(filepath, 'r:gz').extractall(dataset_dir) 88 | 89 | 90 | def write_label_file(labels_to_class_names, dataset_dir, 91 | filename=LABELS_FILENAME): 92 | """Writes a file with the list of class names. 93 | 94 | Args: 95 | labels_to_class_names: A map of (integer) labels to class names. 96 | dataset_dir: The directory in which the labels file should be written. 97 | filename: The filename where the class names are written. 98 | """ 99 | labels_filename = os.path.join(dataset_dir, filename) 100 | with tf.gfile.Open(labels_filename, 'w') as f: 101 | for label in labels_to_class_names: 102 | class_name = labels_to_class_names[label] 103 | f.write('%d:%s\n' % (label, class_name)) 104 | 105 | 106 | def has_labels(dataset_dir, filename=LABELS_FILENAME): 107 | """Specifies whether or not the dataset directory contains a label map file. 108 | 109 | Args: 110 | dataset_dir: The directory in which the labels file is found. 111 | filename: The filename where the class names are written. 112 | 113 | Returns: 114 | `True` if the labels file exists and `False` otherwise. 115 | """ 116 | return tf.gfile.Exists(os.path.join(dataset_dir, filename)) 117 | 118 | 119 | def read_label_file(dataset_dir, filename=LABELS_FILENAME): 120 | """Reads the labels file and returns a mapping from ID to class name. 121 | 122 | Args: 123 | dataset_dir: The directory in which the labels file is found. 124 | filename: The filename where the class names are written. 125 | 126 | Returns: 127 | A map from a label (integer) to class name. 128 | """ 129 | labels_filename = os.path.join(dataset_dir, filename) 130 | with tf.gfile.Open(labels_filename, 'rb') as f: 131 | lines = f.read() 132 | lines = lines.split(b'\n') 133 | lines = filter(None, lines) 134 | 135 | labels_to_class_names = {} 136 | for line in lines: 137 | index = line.index(b':') 138 | labels_to_class_names[int(line[:index])] = line[index+1:] 139 | return labels_to_class_names 140 | 141 | 142 | class ImageCoder(object): 143 | """Helper class that provides TensorFlow image coding utilities.""" 144 | 145 | def __init__(self): 146 | # Create a single Session to run all image coding calls. 
147 | self._sess = tf.Session() 148 | 149 | # Initializes function that converts PNG to JPEG data. 150 | self._png_data = tf.placeholder(dtype=tf.string) 151 | image = tf.image.decode_png(self._png_data, channels=3) 152 | self._png_to_jpeg = tf.image.encode_jpeg(image, format='rgb', quality=100) 153 | 154 | # Initializes function that converts CMYK JPEG data to RGB JPEG data. 155 | self._cmyk_data = tf.placeholder(dtype=tf.string) 156 | image = tf.image.decode_jpeg(self._cmyk_data, channels=0) 157 | self._cmyk_to_rgb = tf.image.encode_jpeg(image, format='rgb', quality=100) 158 | 159 | # Initializes function that decodes RGB JPEG data. 160 | self._decode_jpeg_data = tf.placeholder(dtype=tf.string) 161 | self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3) 162 | 163 | def png_to_jpeg(self, image_data): 164 | return self._sess.run(self._png_to_jpeg, 165 | feed_dict={self._png_data: image_data}) 166 | 167 | def cmyk_to_rgb(self, image_data): 168 | return self._sess.run(self._cmyk_to_rgb, 169 | feed_dict={self._cmyk_data: image_data}) 170 | 171 | def decode_jpeg(self, image_data): 172 | image = self._sess.run(self._decode_jpeg, 173 | feed_dict={self._decode_jpeg_data: image_data}) 174 | assert len(image.shape) == 3 175 | assert image.shape[2] == 3 176 | return image 177 | -------------------------------------------------------------------------------- /datasets/dataset_utils.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/datasets/dataset_utils.pyc -------------------------------------------------------------------------------- /datasets/synthtext_to_tfrecords_self.py: -------------------------------------------------------------------------------- 1 | #encoding=utf-8 2 | import numpy as np 3 | import tensorflow as tf 4 | import time 5 | import tensorflow.contrib.slim as slim 6 | import util 7 | 8 | 9 | 10 | def int64_feature(value): 11 | """Wrapper for inserting int64 features into Example proto. 12 | """ 13 | if not isinstance(value, list): 14 | value = [value] 15 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 16 | 17 | 18 | def float_feature(value): 19 | """Wrapper for inserting float features into Example proto. 20 | """ 21 | if not isinstance(value, list): 22 | value = [value] 23 | return tf.train.Feature(float_list=tf.train.FloatList(value=value)) 24 | 25 | 26 | def bytes_feature(value): 27 | """Wrapper for inserting bytes features into Example proto. 28 | """ 29 | if not isinstance(value, list): 30 | value = [value] 31 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) 32 | 33 | 34 | def image_to_tfexample(image_data, image_format, height, width, class_id): 35 | return tf.train.Example(features=tf.train.Features(feature={ 36 | 'image/encoded': bytes_feature(image_data), 37 | 'image/format': bytes_feature(image_format), 38 | 'image/class/label': int64_feature(class_id), 39 | 'image/height': int64_feature(height), 40 | 'image/width': int64_feature(width), 41 | })) 42 | 43 | 44 | def convert_to_example(image_data, filename, labels, ignored, labels_text, bboxes, oriented_bboxes, shape): 45 | """Build an Example proto for an image example. 
46 | Args: 47 | image_data: string, JPEG encoding of RGB image 48 | labels: list of integers, identifier for the ground truth 49 | labels_text: list of strings, human-readable labels 50 | oriented_bboxes: list of bounding oriented boxes each box is a list of floats in [0, 1] 51 | specifying [x1, y1, x2, y2, x3, y3, x4, y4] 52 | bboxes: list of bbox in rectangle, [xmin, ymin, xmax, ymax] 53 | Returns: 54 | Example proto 55 | """ 56 | 57 | image_format = b'JPEG' 58 | oriented_bboxes = np.asarray(oriented_bboxes) 59 | bboxes = np.asarray(bboxes) 60 | example = tf.train.Example(features=tf.train.Features(feature={ 61 | 'image/shape': int64_feature(list(shape)), 62 | 'image/object/bbox/xmin': float_feature(list(bboxes[:, 0])), 63 | 'image/object/bbox/ymin': float_feature(list(bboxes[:, 1])), 64 | 'image/object/bbox/xmax': float_feature(list(bboxes[:, 2])), 65 | 'image/object/bbox/ymax': float_feature(list(bboxes[:, 3])), 66 | 'image/object/bbox/x1': float_feature(list(oriented_bboxes[:, 0])), 67 | 'image/object/bbox/y1': float_feature(list(oriented_bboxes[:, 1])), 68 | 'image/object/bbox/x2': float_feature(list(oriented_bboxes[:, 2])), 69 | 'image/object/bbox/y2': float_feature(list(oriented_bboxes[:, 3])), 70 | 'image/object/bbox/x3': float_feature(list(oriented_bboxes[:, 4])), 71 | 'image/object/bbox/y3': float_feature(list(oriented_bboxes[:, 5])), 72 | 'image/object/bbox/x4': float_feature(list(oriented_bboxes[:, 6])), 73 | 'image/object/bbox/y4': float_feature(list(oriented_bboxes[:, 7])), 74 | 'image/object/bbox/label': int64_feature(labels), 75 | 'image/object/bbox/label_text': bytes_feature(labels_text), 76 | 'image/object/bbox/ignored': int64_feature(ignored), 77 | 'image/format': bytes_feature(image_format), 78 | 'image/filename': bytes_feature(filename), 79 | 'image/encoded': bytes_feature(image_data)})) 80 | return example 81 | 82 | 83 | 84 | def get_split(split_name, dataset_dir, file_pattern, num_samples, reader=None): 85 | dataset_dir = util.io.get_absolute_path(dataset_dir) 86 | 87 | if util.str.contains(file_pattern, '%'): 88 | file_pattern = util.io.join_path(dataset_dir, file_pattern % split_name) 89 | else: 90 | file_pattern = util.io.join_path(dataset_dir, file_pattern) 91 | # Allowing None in the signature so that dataset_factory can use the default. 
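    # keys_to_features / items_to_handlers below mirror the fields serialized in
    # convert_to_example above: every stored field gets a parser, and the decoder
    # exposes it to slim's DatasetDataProvider under the item name used downstream.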
92 | if reader is None: 93 | reader = tf.TFRecordReader 94 | keys_to_features = { 95 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 96 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'), 97 | 'image/filename': tf.FixedLenFeature((), tf.string, default_value=''), 98 | 'image/shape': tf.FixedLenFeature([3], tf.int64), 99 | 'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32), 100 | 'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32), 101 | 'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32), 102 | 'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32), 103 | 'image/object/bbox/x1': tf.VarLenFeature(dtype=tf.float32), 104 | 'image/object/bbox/x2': tf.VarLenFeature(dtype=tf.float32), 105 | 'image/object/bbox/x3': tf.VarLenFeature(dtype=tf.float32), 106 | 'image/object/bbox/x4': tf.VarLenFeature(dtype=tf.float32), 107 | 'image/object/bbox/y1': tf.VarLenFeature(dtype=tf.float32), 108 | 'image/object/bbox/y2': tf.VarLenFeature(dtype=tf.float32), 109 | 'image/object/bbox/y3': tf.VarLenFeature(dtype=tf.float32), 110 | 'image/object/bbox/y4': tf.VarLenFeature(dtype=tf.float32), 111 | 'image/object/bbox/ignored': tf.VarLenFeature(dtype=tf.int64), 112 | 'image/object/bbox/label': tf.VarLenFeature(dtype=tf.int64), 113 | } 114 | items_to_handlers = { 115 | 'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'), 116 | 'shape': slim.tfexample_decoder.Tensor('image/shape'), 117 | 'filename': slim.tfexample_decoder.Tensor('image/filename'), 118 | 'object/bbox': slim.tfexample_decoder.BoundingBox( 119 | ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'), 120 | 'object/oriented_bbox/x1': slim.tfexample_decoder.Tensor('image/object/bbox/x1'), 121 | 'object/oriented_bbox/x2': slim.tfexample_decoder.Tensor('image/object/bbox/x2'), 122 | 'object/oriented_bbox/x3': slim.tfexample_decoder.Tensor('image/object/bbox/x3'), 123 | 'object/oriented_bbox/x4': slim.tfexample_decoder.Tensor('image/object/bbox/x4'), 124 | 'object/oriented_bbox/y1': slim.tfexample_decoder.Tensor('image/object/bbox/y1'), 125 | 'object/oriented_bbox/y2': slim.tfexample_decoder.Tensor('image/object/bbox/y2'), 126 | 'object/oriented_bbox/y3': slim.tfexample_decoder.Tensor('image/object/bbox/y3'), 127 | 'object/oriented_bbox/y4': slim.tfexample_decoder.Tensor('image/object/bbox/y4'), 128 | 'object/label': slim.tfexample_decoder.Tensor('image/object/bbox/label'), 129 | 'object/ignored': slim.tfexample_decoder.Tensor('image/object/bbox/ignored') 130 | } 131 | decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers) 132 | 133 | labels_to_names = {0:'background', 1:'text'} 134 | items_to_descriptions = { 135 | 'image': 'A color image of varying height and width.', 136 | 'shape': 'Shape of the image', 137 | 'object/bbox': 'A list of bounding boxes, one per each object.', 138 | 'object/label': 'A list of labels, one per each object.', 139 | } 140 | 141 | return slim.dataset.Dataset( 142 | data_sources=file_pattern, 143 | reader=reader, 144 | decoder=decoder, 145 | num_samples=num_samples, 146 | items_to_descriptions=items_to_descriptions, 147 | num_classes=2, 148 | labels_to_names=labels_to_names) 149 | 150 | 151 | 152 | 153 | class SynthTextDataFetcher(): 154 | def __init__(self, mat_path, root_path): 155 | self.mat_path = mat_path 156 | self.root_path = root_path 157 | self._load_mat() 158 | 159 | # @util.dec.print_calling 160 | def _load_mat(self): 161 | data = util.io.load_mat(self.mat_path) 162 | self.image_paths = 
data['imnames'][0]
163 |         self.image_bbox = data['wordBB'][0]
164 |         self.txts = data['txt'][0]
165 |         self.num_images = len(self.image_paths)
166 | 
167 |     def get_image_path(self, idx):
168 |         image_path = util.io.join_path(self.root_path, self.image_paths[idx][0])
169 |         return image_path
170 | 
171 |     def get_num_words(self, idx):
172 |         try:
173 |             return np.shape(self.image_bbox[idx])[2]
174 |         except Exception:  # error caused by dataset
175 |             return 1
176 | 
177 | 
178 |     def get_word_bbox(self, img_idx, word_idx):
179 |         boxes = self.image_bbox[img_idx]
180 |         if len(np.shape(boxes)) == 2:  # error caused by dataset
181 |             boxes = np.reshape(boxes, (2, 4, 1))
182 | 
183 |         xys = boxes[:, :, word_idx]
184 |         assert np.shape(xys) == (2, 4)
185 |         return np.float32(xys)
186 | 
187 |     def normalize_bbox(self, xys, width, height):
188 |         xs = xys[0, :]
189 |         ys = xys[1, :]
190 | 
191 |         min_x = min(xs)
192 |         min_y = min(ys)
193 |         max_x = max(xs)
194 |         max_y = max(ys)
195 | 
196 |         # bound them in the valid range
197 |         min_x = max(0, min_x)
198 |         min_y = max(0, min_y)
199 |         max_x = min(width, max_x)
200 |         max_y = min(height, max_y)
201 | 
202 |         # check the w, h and area of the rect
203 |         w = max_x - min_x
204 |         h = max_y - min_y
205 |         is_valid = True
206 | 
207 |         if w < 10 or h < 10:
208 |             is_valid = False
209 | 
210 |         if w * h < 100:
211 |             is_valid = False
212 | 
213 |         xys[0, :] = xys[0, :] / width
214 |         xys[1, :] = xys[1, :] / height
215 | 
216 |         return is_valid, min_x / width, min_y / height, max_x / width, max_y / height, xys
217 | 
218 |     def get_txt(self, image_idx, word_idx):
219 |         txts = self.txts[image_idx]
220 |         clean_txts = []
221 |         for txt in txts:
222 |             clean_txts += txt.split()
223 |         return str(clean_txts[word_idx])
224 | 
225 | 
226 |     def fetch_record(self, image_idx):
227 |         image_path = self.get_image_path(image_idx)
228 |         if not (util.io.exists(image_path)):
229 |             return None
230 |         img = util.img.imread(image_path)
231 |         h, w = img.shape[0:-1]
232 |         num_words = self.get_num_words(image_idx)
233 |         rect_bboxes = []
234 |         full_bboxes = []
235 |         txts = []
236 |         for word_idx in range(num_words):
237 |             xys = self.get_word_bbox(image_idx, word_idx)
238 |             is_valid, min_x, min_y, max_x, max_y, xys = self.normalize_bbox(xys, width=w, height=h)
239 |             if not is_valid:
240 |                 continue
241 |             rect_bboxes.append([min_x, min_y, max_x, max_y])
242 |             xys = np.reshape(np.transpose(xys), -1)
243 |             full_bboxes.append(xys)
244 |             txt = self.get_txt(image_idx, word_idx)
245 |             txts.append(txt)
246 |         if len(rect_bboxes) == 0:
247 |             return None
248 | 
249 |         return image_path, img, txts, rect_bboxes, full_bboxes
250 | 
251 | 
252 | 
253 | def cvt_to_tfrecords(output_path, data_path, gt_path, records_per_file=30000):
254 | 
255 |     fetcher = SynthTextDataFetcher(root_path=data_path, mat_path=gt_path)
256 |     fid = 0
257 |     image_idx = -1
258 |     while image_idx < fetcher.num_images:
259 |         with tf.python_io.TFRecordWriter(output_path % (fid)) as tfrecord_writer:
260 |             record_count = 0
261 |             while record_count != records_per_file:
262 |                 image_idx += 1
263 |                 if image_idx >= fetcher.num_images:
264 |                     break
265 |                 print("loading image %d/%d" % (image_idx + 1, fetcher.num_images))
266 |                 record = fetcher.fetch_record(image_idx)
267 |                 if record is None:
268 |                     print('\nimage %d does not exist' % (image_idx + 1))
269 |                     continue
270 | 
271 |                 image_path, image, txts, rect_bboxes, oriented_bboxes = record
272 |                 labels = len(rect_bboxes) * [1]
273 |                 ignored = len(rect_bboxes) * [0]
274 |                 image_data = tf.gfile.FastGFile(image_path, 'rb').read()  # read in binary mode so this also works on python3
275 | 
276 |                 shape = image.shape
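                # shape is (height, width, channels); it is serialized into the
                # example so readers can recover the image size without decoding pixels.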
277 | image_name = str(util.io.get_filename(image_path).split('.')[0]) 278 | example = convert_to_example(image_data, image_name, labels, ignored, txts, rect_bboxes, oriented_bboxes, shape) 279 | tfrecord_writer.write(example.SerializeToString()) 280 | record_count += 1 281 | 282 | fid += 1 283 | 284 | if __name__ == "__main__": 285 | mat_path = util.io.get_absolute_path('/share/SynthText/gt.mat') 286 | root_path = util.io.get_absolute_path('/share/SynthText/') 287 | output_dir = util.io.get_absolute_path('/home/zsz/datasets/synth-tf/') 288 | util.io.mkdir(output_dir) 289 | cvt_to_tfrecords(output_path = util.io.join_path(output_dir, 'SynthText_%d.tfrecord'), data_path = root_path, gt_path = mat_path) 290 | -------------------------------------------------------------------------------- /datasets/sythtextprovider.py: -------------------------------------------------------------------------------- 1 | ## an initial version 2 | ## Transform the tfrecord to slim data provider format 3 | 4 | import numpy 5 | import tensorflow as tf 6 | import os 7 | import tensorflow.contrib.slim as slim 8 | import glob 9 | 10 | 11 | 12 | ITEMS_TO_DESCRIPTIONS = { 13 | 'image': 'A color image of varying height and width.', 14 | 'shape': 'Shape of the image', 15 | 'object/bbox': 'A list of bounding boxes, one per each object.', 16 | 'object/label': 'A list of labels, one per each object.', 17 | } 18 | SPLITS_TO_SIZES = { 19 | #'train': 2518, for ppt datasets 20 | 'train': 858750 # for synth text datasets 21 | } 22 | def get_datasets(data_dir,file_pattern = '*.tfrecord'): 23 | file_patterns = os.path.join(data_dir, file_pattern) 24 | print('file_path: {}'.format(file_patterns)) 25 | file_path_list = glob.glob(file_patterns) 26 | num_samples = 0 27 | #num_samples = 288688 28 | #num_samples = 858750 only for synth datasets 29 | for file_path in file_path_list: 30 | for record in tf.python_io.tf_record_iterator(file_path): 31 | num_samples += 1 32 | print('num_samples:', num_samples) 33 | reader = tf.TFRecordReader 34 | keys_to_features = { 35 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 36 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'), 37 | 'image/filename': tf.FixedLenFeature((), tf.string, default_value=''), 38 | 'image/shape': tf.FixedLenFeature([3], tf.int64), 39 | 'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32), 40 | 'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32), 41 | 'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32), 42 | 'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32), 43 | 'image/object/bbox/x1': tf.VarLenFeature(dtype=tf.float32), 44 | 'image/object/bbox/x2': tf.VarLenFeature(dtype=tf.float32), 45 | 'image/object/bbox/x3': tf.VarLenFeature(dtype=tf.float32), 46 | 'image/object/bbox/x4': tf.VarLenFeature(dtype=tf.float32), 47 | 'image/object/bbox/y1': tf.VarLenFeature(dtype=tf.float32), 48 | 'image/object/bbox/y2': tf.VarLenFeature(dtype=tf.float32), 49 | 'image/object/bbox/y3': tf.VarLenFeature(dtype=tf.float32), 50 | 'image/object/bbox/y4': tf.VarLenFeature(dtype=tf.float32), 51 | 'image/object/bbox/ignored': tf.VarLenFeature(dtype=tf.int64), 52 | 'image/object/bbox/label': tf.VarLenFeature(dtype=tf.int64), 53 | } 54 | 55 | items_to_handlers = { 56 | 'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'), 57 | 'shape': slim.tfexample_decoder.Tensor('image/shape'), 58 | 'filename': slim.tfexample_decoder.Tensor('image/filename'), 59 | 'object/bbox': slim.tfexample_decoder.BoundingBox( 
60 | ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'), 61 | 'object/oriented_bbox/x1': slim.tfexample_decoder.Tensor('image/object/bbox/x1'), 62 | 'object/oriented_bbox/x2': slim.tfexample_decoder.Tensor('image/object/bbox/x2'), 63 | 'object/oriented_bbox/x3': slim.tfexample_decoder.Tensor('image/object/bbox/x3'), 64 | 'object/oriented_bbox/x4': slim.tfexample_decoder.Tensor('image/object/bbox/x4'), 65 | 'object/oriented_bbox/y1': slim.tfexample_decoder.Tensor('image/object/bbox/y1'), 66 | 'object/oriented_bbox/y2': slim.tfexample_decoder.Tensor('image/object/bbox/y2'), 67 | 'object/oriented_bbox/y3': slim.tfexample_decoder.Tensor('image/object/bbox/y3'), 68 | 'object/oriented_bbox/y4': slim.tfexample_decoder.Tensor('image/object/bbox/y4'), 69 | 'object/label': slim.tfexample_decoder.Tensor('image/object/bbox/label'), 70 | 'object/ignored': slim.tfexample_decoder.Tensor('image/object/bbox/ignored') 71 | } 72 | 73 | decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers) 74 | 75 | labels_to_names = {0:'background', 1:'text'} 76 | return slim.dataset.Dataset( 77 | data_sources=file_patterns, 78 | reader=reader, 79 | decoder=decoder, 80 | num_samples=num_samples, 81 | items_to_descriptions=ITEMS_TO_DESCRIPTIONS, 82 | num_classes=NUM_CLASSES, 83 | labels_to_names=labels_to_names) 84 | 85 | NUM_CLASSES = 2 86 | -------------------------------------------------------------------------------- /datasets/sythtextprovider.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/datasets/sythtextprovider.pyc -------------------------------------------------------------------------------- /datasets/xml_to_tfrecords.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import sys 4 | import random 5 | import numpy as np 6 | import tensorflow as tf 7 | import xml.etree.ElementTree as ET 8 | from datasets.dataset_utils import int64_feature, float_feature, bytes_feature 9 | import tensorflow.contrib.slim as slim 10 | 11 | # TFRecords convertion parameters. 12 | 13 | TXT_LABELS = { 14 | 'none': (0, 'Background'), 15 | 'text': (1, 'Text') 16 | } 17 | 18 | def _process_image(train_img_path, train_xml_path, name): 19 | """Process a image and annotation file. 20 | 21 | Args: 22 | filename: string, path to an image file e.g., '/path/to/example.JPG'. 23 | coder: instance of ImageCoder to provide TensorFlow image coding utils. 24 | Returns: 25 | image_buffer: string, JPEG encoding of RGB image. 26 | height: integer, image height in pixels. 27 | width: integer, image width in pixels. 28 | """ 29 | # Read the image file. 30 | 31 | image_data = tf.gfile.FastGFile(train_img_path, 'rb').read() 32 | 33 | tree = ET.parse(train_xml_path) 34 | root = tree.getroot() 35 | # Image shape. 36 | size = root.find('size') 37 | height = int(size.find('height').text) 38 | width = int(size.find('width').text) 39 | depth = int(size.find('depth').text) 40 | if height <= 0 or width <= 0 or depth <= 0: 41 | print('height or width depth error',height, width, depth) 42 | return 43 | 44 | shape = [int(size.find('height').text), 45 | int(size.find('width').text), 46 | int(size.find('depth').text)] 47 | # Find annotations. 
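    # Each <object> node carries both an axis-aligned box (xmin/ymin/xmax/ymax)
    # and a four-point oriented box (x1..x4 / y1..y4); both are collected below
    # and normalized by the image width/height before serialization.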
48 | bboxes = [] 49 | labels = [] 50 | labels_text = [] 51 | difficult = [] 52 | truncated = [] 53 | oriented_bbox = [] 54 | ignored = 0 55 | filename = root.find('filename').text 56 | 57 | for obj in root.findall('object'): 58 | label = obj.find('name').text 59 | if label == 'none': 60 | label = 'none' 61 | else: 62 | label = 'text' 63 | labels.append(int(TXT_LABELS[label][0])) 64 | labels_text.append(label.encode('ascii')) 65 | 66 | if obj.find('difficult') is not None: 67 | #print('append difficult') 68 | difficult.append(int(obj.find('difficult').text)) 69 | else: 70 | difficult.append(0) 71 | if obj.find('truncated'): 72 | truncated.append(int(obj.find('truncated').text)) 73 | else: 74 | truncated.append(0) 75 | 76 | bbox = obj.find('bndbox') 77 | ymin = float(bbox.find('ymin').text) 78 | ymax = float(bbox.find('ymax').text) 79 | xmin = float(bbox.find('xmin').text) 80 | xmax = float(bbox.find('xmax').text) 81 | 82 | x1 = float(bbox.find('x1').text) 83 | x2 = float(bbox.find('x2').text) 84 | x3 = float(bbox.find('x3').text) 85 | x4 = float(bbox.find('x4').text) 86 | 87 | y1 = float(bbox.find('y1').text) 88 | y2 = float(bbox.find('y2').text) 89 | y3 = float(bbox.find('y3').text) 90 | y4 = float(bbox.find('y4').text) 91 | 92 | 93 | ymin, ymax = np.clip([ymin, ymax], 0 , height) 94 | xmin, xmax = np.clip([xmin, xmax] ,0 , width) 95 | 96 | x1, x2, x3, x4 = np.clip([x1, x2, x3, x4], 0, width) 97 | y1, y2, y3, y4 = np.clip([y1, y2, y3, y4], 0, height) 98 | 99 | bboxes.append(( ymin / shape[0], 100 | xmin / shape[1], 101 | ymax / shape[0], 102 | xmax / shape[1] 103 | )) 104 | 105 | oriented_bbox.append((x1 / width, x2 / width, x3 / width, x4 /width, y1 /height, y2 / height, y3 / height, y4 / height)) 106 | 107 | return image_data, shape, bboxes, labels, labels_text, difficult, truncated, oriented_bbox, ignored, filename 108 | 109 | def _convert_to_example(image_data, labels, labels_text, bboxes, shape, 110 | difficult, truncated, oriented_bbox, ignored, filename): 111 | """Build an Example proto for an image example. 112 | 113 | Args: 114 | image_data: string, JPEG encoding of RGB image; 115 | labels: list of integers, identifier for the ground truth; 116 | labels_text: list of strings, human-readable labels; 117 | bboxes: list of bounding boxes; each box is a list of integers; 118 | specifying [ymin, xmin, ymax, xmax]. All boxes are assumed to belong 119 | to the same label as the image label. 120 | shape: 3 integers, image shapes in pixels. 
121 | Returns: 122 | Example proto 123 | """ 124 | xmin = [] 125 | ymin = [] 126 | xmax = [] 127 | ymax = [] 128 | for b in bboxes: 129 | assert len(b) == 4 130 | # pylint: disable=expression-not-assigned 131 | [l.append(point) for l, point in zip([ymin, xmin, ymax, xmax], b)] 132 | # pylint: enable=expression-not-assigned 133 | 134 | x1 = [] 135 | x2 = [] 136 | x3 = [] 137 | x4 = [] 138 | 139 | y1 = [] 140 | y2 = [] 141 | y3 = [] 142 | y4 = [] 143 | 144 | for orgin in oriented_bbox: 145 | assert len(orgin) == 8 146 | [l.append(point) for l, point in zip([x1, x2, x3, x4, y1, y2, y3, y4], orgin)] 147 | 148 | image_format = b'JPEG' 149 | example = tf.train.Example(features=tf.train.Features(feature={ 150 | 'image/height': int64_feature(shape[0]), 151 | 'image/width': int64_feature(shape[1]), 152 | 'image/channels': int64_feature(shape[2]), 153 | 'image/shape': int64_feature(shape), 154 | 'image/filename': bytes_feature(filename.encode('utf-8')), 155 | 'image/object/bbox/xmin': float_feature(xmin), 156 | 'image/object/bbox/xmax': float_feature(xmax), 157 | 'image/object/bbox/ymin': float_feature(ymin), 158 | 'image/object/bbox/ymax': float_feature(ymax), 159 | 'image/object/bbox/x1': float_feature(x1), 160 | 'image/object/bbox/y1': float_feature(y1), 161 | 'image/object/bbox/x2': float_feature(x2), 162 | 'image/object/bbox/y2': float_feature(y2), 163 | 'image/object/bbox/x3': float_feature(x3), 164 | 'image/object/bbox/y3': float_feature(y3), 165 | 'image/object/bbox/x4': float_feature(x4), 166 | 'image/object/bbox/y4': float_feature(y4), 167 | 'image/object/bbox/label': int64_feature(labels), 168 | 'image/object/bbox/label_text': bytes_feature(labels_text), 169 | 'image/object/bbox/difficult': int64_feature(difficult), 170 | 'image/object/bbox/truncated': int64_feature(truncated), 171 | 'image/object/bbox/ignored': int64_feature(ignored), 172 | 'image/format': bytes_feature(image_format), 173 | 'image/encoded': bytes_feature(image_data)})) 174 | return example 175 | 176 | 177 | def _get_output_filename(output_dir, name, idx): 178 | return '%s/%s_%03d.tfrecord' % (output_dir, name, idx) 179 | 180 | 181 | def _add_to_tfrecord(train_img_path, train_xml_path , name, tfrecord_writer): 182 | """Loads data from image and annotations files and add them to a TFRecord. 183 | 184 | Args: 185 | train_img_path: img path; 186 | train_xml_path: xml path 187 | name: Image name to add to the TFRecord; 188 | tfrecord_writer: The TFRecord writer to use for writing. 189 | """ 190 | image_data, shape, bboxes, labels, labels_text, difficult, truncated, oriented_bbox, ignored, filename = \ 191 | _process_image(train_img_path, train_xml_path , name) 192 | example = _convert_to_example(image_data, labels, labels_text, 193 | bboxes, shape, difficult, truncated, oriented_bbox, ignored, filename) 194 | tfrecord_writer.write(example.SerializeToString()) 195 | 196 | 197 | def run(xml_img_txt_path, output_dir, name='icdar15_annotated_data', samples_per_files=200): 198 | """Runs the conversion operation. 199 | Args: 200 | xml_img_txt_path: The txt stored where the dataset is stored. 201 | output_dir: Output directory. 
202 | """ 203 | 204 | if not tf.gfile.Exists(output_dir): 205 | tf.gfile.MakeDirs(output_dir) 206 | 207 | train_txt = open(xml_img_txt_path, 'r') 208 | lines = train_txt.readlines() 209 | train_img_path = [] 210 | train_xml_path = [] 211 | error_list = [] 212 | count = 0 213 | for line in lines: 214 | line = line.strip() 215 | if len(line.split(',')) == 2: 216 | count += 1 217 | train_img_path.append(line.split(',')[0]) 218 | train_xml_path.append(line.split(',')[1]) 219 | else: 220 | error_list.append(line) 221 | # print('line:',line) 222 | with open(os.path.join(output_dir, 'create_tfrecord_error_list.txt'), 'w') as f: 223 | f.writelines(error_list) 224 | filenames = train_img_path 225 | 226 | # Process dataset files 227 | i = 0 228 | fidx = 0 229 | while i < len(filenames): 230 | #Open new TFRecord file. 231 | tf_filename = _get_output_filename(output_dir, name, fidx) 232 | 233 | with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer: 234 | j = 0 235 | while i < len(filenames) and j < samples_per_files: 236 | sys.stdout.write('\r>> Converting image %d/%d' % (i+1, len(filenames))) 237 | sys.stdout.flush() 238 | 239 | filename = filenames[i] 240 | img_name = filename.split('/')[-1][:-4] 241 | _add_to_tfrecord(filename, train_xml_path[i], img_name, tfrecord_writer) 242 | i += 1 243 | j += 1 244 | fidx += 1 245 | print('\nFinished converting the charts dataset!') 246 | 247 | -------------------------------------------------------------------------------- /datasets/xml_to_tfrecords.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/datasets/xml_to_tfrecords.pyc -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import tensorflow as tf 4 | import cv2 5 | import tensorflow.contrib.slim as slim 6 | import codecs 7 | import sys 8 | import time 9 | import random 10 | sys.path.append('./') 11 | 12 | from nets import txtbox_384, np_methods, txtbox_768 13 | from processing import ssd_vgg_preprocessing 14 | 15 | os.environ['CUDA_VISIBLE_DEVICES'] = '3' #using GPU 0 16 | 17 | def plt_bboxes(img, classes, scores, bboxes, figsize=(10,10), linewidth=1.5): 18 | """Visualize bounding boxes. Largely inspired by SSD-MXNET! 19 | """ 20 | height = img.shape[0] 21 | width = img.shape[1] 22 | colors = dict() 23 | for i in range(classes.shape[0]): 24 | cls_id = int(classes[i]) 25 | if cls_id >= 0: 26 | score = scores[i] 27 | if cls_id not in colors: 28 | colors[cls_id] = (random.random(), random.random(), random.random()) 29 | 30 | xmin = int(bboxes[i, 0] * width) 31 | ymin = int(bboxes[i, 1] * height) 32 | xmax = int(bboxes[i, 2] * width) 33 | ymax = int(bboxes[i, 3] * height) 34 | img = cv2.rectangle(img, (xmin, ymin), (xmax, ymax), (255, 0, 0)) 35 | return img 36 | 37 | gpu_options = tf.GPUOptions(allow_growth=False, per_process_gpu_memory_fraction=0.3) 38 | 39 | config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options) 40 | isess = tf.Session(config=config) 41 | 42 | # Input placeholder. 43 | net_shape = (384, 384) 44 | #net_shape = (768, 768) 45 | data_format = 'NHWC' 46 | img_input = tf.placeholder(tf.float32, shape=(None, None, 3)) 47 | # Evaluation pre-processing: resize to SSD net shape. 
48 | image_pre, labels_pre, bboxes_pre, bbox_img, xs, ys = ssd_vgg_preprocessing.preprocess_for_eval( 49 | img_input, None, None,None, None, net_shape, data_format, resize=ssd_vgg_preprocessing.Resize.WARP_RESIZE) 50 | image_4d = tf.expand_dims(image_pre, 0) 51 | image_4d = tf.cast(image_4d, tf.float32) 52 | # Define the txt_box model. 53 | reuse = True if 'txt_net' in locals() else None 54 | 55 | txt_net = txtbox_384.TextboxNet() 56 | print(txt_net.params.img_shape) 57 | print('reuse:',reuse) 58 | 59 | with slim.arg_scope(txt_net.arg_scope(data_format=data_format)): 60 | predictions,localisations, logits, end_points = txt_net.net(image_4d, is_training=False, reuse=reuse) 61 | 62 | ckpt_dir = 'model' 63 | 64 | isess.run(tf.global_variables_initializer()) 65 | 66 | saver = tf.train.Saver() 67 | 68 | ckpt_filename = tf.train.latest_checkpoint(ckpt_dir) 69 | if ckpt_dir and ckpt_filename: 70 | print('checkpoint:',ckpt_dir, os.getcwd(), ckpt_filename) 71 | saver.restore(isess, ckpt_filename) 72 | txt_anchors = txt_net.anchors(net_shape) 73 | 74 | def process_image(img, select_threshold=0.01, nms_threshold=.45, net_shape=net_shape): 75 | # Run txt network. 76 | startTime = time.time() 77 | rimg, rpredictions,rlogits,rlocalisations, rbbox_img = isess.run([image_4d, predictions, logits, localisations, bbox_img], 78 | feed_dict={img_input: img}) 79 | 80 | end_time = time.time() 81 | print(end_time - startTime) 82 | # Get classes and bboxes from the net outputs 83 | 84 | rclasses, rscores, rbboxes = np_methods.ssd_bboxes_select( 85 | rpredictions, rlocalisations, txt_anchors, 86 | select_threshold=select_threshold, img_shape=net_shape, num_classes=2, decode=True) 87 | 88 | rbboxes = np_methods.bboxes_clip(rbbox_img, rbboxes) 89 | # print(rscores) 90 | rclasses, rscores, rbboxes = np_methods.bboxes_sort(rclasses, rscores, rbboxes, top_k=400) 91 | rclasses, rscores, rbboxes = np_methods.bboxes_nms(rclasses, rscores, rbboxes, nms_threshold=nms_threshold) 92 | # Resize bboxes to original image shape. Note: useless for Resize.WARP! 93 | rbboxes = np_methods.bboxes_resize(rbbox_img, rbboxes) 94 | return rclasses, rscores, rbboxes 95 | 96 | 97 | # Test on some demo image and visualize output. 
98 | path = './demo/' 99 | 100 | img = cv2.imread(path + 'img_1.jpg') 101 | img_cp = img.copy() 102 | rclasses, rscores, rbboxes = process_image(img_cp) 103 | 104 | with codecs.open('demo/detections.txt', 'w', encoding='utf-8') as fout: 105 | for i in range(len(rclasses)): 106 | fout.write('{},{}\n'.format(rbboxes[i], rscores[i])) 107 | img_with_bbox = plt_bboxes(img_cp, rclasses, rscores, rbboxes) 108 | cv2.imwrite(os.path.join(path,'demo_res.png'), img_with_bbox) 109 | print('detection finished') 110 | else: 111 | raise ('no ckpt') 112 | 113 | -------------------------------------------------------------------------------- /demo/example/image0.txt: -------------------------------------------------------------------------------- 1 | text 0.597695052624 3 39 80 51 2 | -------------------------------------------------------------------------------- /demo/example/image0.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | test\image0.png 5 | 6 | 384 7 | 384 8 | 3 9 | 10 | 11 | 0 12 | 1 13 | 14 | 3 15 | 39 16 | 80 17 | 51 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /demo/example/standard.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | train_images 4 | img_10.jpg 5 | 6 | 1280 7 | 720 8 | 3 9 | 10 | 11 | 1 12 | ### 13 | none 14 | 15 | 261 16 | 138 17 | 284 18 | 140 19 | 279 20 | 158 21 | 260 22 | 158 23 | 260 24 | 138 25 | 284 26 | 158 27 | 28 | 29 | 30 | 0 31 | HarbourFront 32 | text 33 | 34 | 288 35 | 138 36 | 417 37 | 140 38 | 416 39 | 161 40 | 290 41 | 157 42 | 288 43 | 138 44 | 417 45 | 161 46 | 47 | 48 | 49 | 0 50 | CC22 51 | text 52 | 53 | 743 54 | 145 55 | 779 56 | 146 57 | 780 58 | 163 59 | 746 60 | 163 61 | 743 62 | 145 63 | 780 64 | 163 65 | 66 | 67 | 68 | 0 69 | bua 70 | text 71 | 72 | 783 73 | 129 74 | 831 75 | 132 76 | 833 77 | 155 78 | 785 79 | 153 80 | 783 81 | 129 82 | 833 83 | 155 84 | 85 | 86 | 87 | 1 88 | ### 89 | none 90 | 91 | 831 92 | 133 93 | 870 94 | 135 95 | 874 96 | 156 97 | 835 98 | 155 99 | 831 100 | 133 101 | 874 102 | 156 103 | 104 | 105 | 106 | 1 107 | ### 108 | none 109 | 110 | 159 111 | 205 112 | 230 113 | 204 114 | 231 115 | 218 116 | 159 117 | 219 118 | 159 119 | 204 120 | 231 121 | 219 122 | 123 | 124 | 125 | 1 126 | ### 127 | none 128 | 129 | 785 130 | 158 131 | 856 132 | 158 133 | 860 134 | 178 135 | 787 136 | 179 137 | 785 138 | 158 139 | 860 140 | 179 141 | 142 | 143 | 144 | 1 145 | ### 146 | none 147 | 148 | 1011 149 | 157 150 | 1079 151 | 160 152 | 1076 153 | 173 154 | 1011 155 | 170 156 | 1011 157 | 157 158 | 1079 159 | 173 160 | 161 | 162 | 163 | -------------------------------------------------------------------------------- /demo/img_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/demo/img_1.jpg -------------------------------------------------------------------------------- /deployment/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /deployment/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/deployment/__init__.pyc 
--------------------------------------------------------------------------------
/deployment/model_deploy.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/deployment/model_deploy.pyc
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | #! -*- encoding: utf-8 -*-
2 | import numpy as np
3 | import cv2
4 | import os
5 | from argparse import ArgumentParser
6 | 
7 | import xml.etree.ElementTree as ET
8 | import shutil
9 | import sys
10 | import matplotlib.pyplot as plt
11 | info = sys.version_info
12 | if int(info[0]) == 2:
13 |     reload(sys)
14 |     sys.setdefaultencoding('utf-8')  # set the default encoding to 'utf-8'
15 | 
16 | 
17 | def mat_inter(box1, box2):
18 |     xmin_1, ymin_1, xmax_1, ymax_1 = box1
19 |     xmin_2, ymin_2, xmax_2, ymax_2 = box2
20 |     distance_between_box_x = abs((xmax_1 + xmin_1) / 2 - (xmax_2 + xmin_2) / 2)
21 |     distance_between_box_y = abs((ymax_2 + ymin_2) / 2 - (ymin_1 + ymax_1) / 2)
22 | 
23 |     distance_box_1_x = abs(xmin_1 - xmax_1)
24 |     distance_box_1_y = abs(ymax_1 - ymin_1)
25 |     distance_box_2_x = abs(xmax_2 - xmin_2)
26 |     distance_box_2_y = abs(ymax_2 - ymin_2)
27 | 
28 |     if distance_between_box_x < (distance_box_1_x + distance_box_2_x
29 |                                  ) / 2 and distance_between_box_y < (
30 |                                      distance_box_2_y + distance_box_1_y) / 2:
31 |         return True
32 |     else:
33 |         return False
34 | 
35 | 
36 | class EVAL_MODEL(object):
37 |     def __init__(self,
38 |                  eval_data_dir,
39 |                  pre_data_dir,
40 |                  data_type,
41 |                  save_result_path,
42 |                  iou_th=0.5,
43 |                  save_err_path='err_pic'):
44 |         # print(eval_data_dir , pre_data_dir ,data_type, save_result_path)
45 |         if eval_data_dir is None or pre_data_dir is None or data_type is None or save_result_path is None:
46 |             raise ValueError(
47 |                 'please input eval_data_dir or pre_data_dir or data_type or save_result_path'
48 |             )
49 |         self.eval_data_dir = eval_data_dir
50 |         self.pre_data_dir = pre_data_dir
51 |         self.data_type = data_type
52 |         self.save_result_path = save_result_path
53 | 
54 |         self.allow_post_processing = False
55 |         self.draw_err_pic_flag = True
56 |         self.xml_path_list = []
57 |         self.pre_img_name_list = []
58 |         self.eval_data_dict = {}
59 |         self.pre_data_dict = {}
60 | 
61 |         self.pre_data_num = 0
62 |         self.gt_data_num = 0
63 |         self.hit_precision = 0
64 |         self.hit_recall = 0
65 |         self.err_gt_dict = {
66 |             'ratio': [],
67 |             'h_scale': [],
68 |             'w_scale': [],
69 |             'height': [],
70 |             'width': []
71 |         }  # save box ratio and size
72 |         self.iou_thresh = float(iou_th)
73 |         self.save_err_path = os.path.join(save_err_path, str(self.iou_thresh))
74 |         if not os.path.exists(self.save_err_path):
75 |             os.makedirs(self.save_err_path)
76 | 
77 |     def list_from_str(self, st, dtype='float32'):
78 |         line = st.split(' ')[2:6]
79 |         if dtype == 'float32':
80 |             line = [float(a) for a in line]
81 |         else:
82 |             line = [int(a) for a in line]
83 |         return line
84 | 
85 |     def get_xml_path(self):
86 |         for i in os.listdir(self.eval_data_dir):
87 |             if i.split('.')[-1] == 'xml':
88 |                 self.xml_path_list.append(os.path.join(self.eval_data_dir, i))
89 |         return self.xml_path_list
90 | 
91 |     def read_gts(self):
92 |         if self.data_type == '1':
93 |             if self.eval_data_dir is None or self.eval_data_dir == '':
94 |                 raise ValueError('---eval data dir not exists!!!!-----')
95 |             for i in self.xml_path_list:
96 |                 img_name = os.path.splitext(os.path.basename(i))[0]
97 | # img_name = os.path.basename(i) 98 | xml_info = ET.parse(i) 99 | root_node = xml_info.getroot() 100 | bbox_list = [] 101 | for obj_node in root_node.findall('object'): 102 | name_node = obj_node.find('name') 103 | # print(name_node) 104 | name = name_node.text 105 | if name == '&*@HUST_special' or name == '&*@HUST_shelter': 106 | continue 107 | difficult = int(obj_node.find('difficult').text) 108 | if difficult == 1: 109 | continue 110 | 111 | bndbox_node = obj_node.find('bndbox') 112 | xmin_filter = int(bndbox_node.find('xmin').text) 113 | ymin_filter = int(bndbox_node.find('ymin').text) 114 | xmax_filter = int(bndbox_node.find('xmax').text) 115 | ymax_filter = int(bndbox_node.find('ymax').text) 116 | 117 | bbox_list.append( 118 | [xmin_filter, ymin_filter, xmax_filter, ymax_filter]) 119 | 120 | self.eval_data_dict[img_name] = bbox_list 121 | elif self.data_type == '2': 122 | pass 123 | else: 124 | raise ValueError(' data type error !!! ') 125 | 126 | #list format: xmin ,ymin, xmax, ymax 127 | def bbox_iou_eval(self, list1, list2): 128 | xx1 = np.maximum(list1[0], list2[0]) 129 | yy1 = np.maximum(list1[1], list2[1]) 130 | xx2 = np.minimum(list1[2], list2[2]) 131 | yy2 = np.minimum(list1[3], list2[3]) 132 | 133 | w = np.maximum(0.0, xx2 - xx1 + 1) 134 | h = np.maximum(0.0, yy2 - yy1 + 1) 135 | inter = w * h 136 | area1 = (list1[2] - list1[0] + 1) * (list1[3] - list1[1] + 1) 137 | area2 = (list2[2] - list2[0] + 1) * (list2[3] - list2[1] + 1) 138 | iou = float(inter) / (area1 + area2 - inter) 139 | return iou 140 | 141 | def read_pres(self): 142 | for i in os.listdir(self.pre_data_dir): 143 | if i.split('.')[-1] == 'txt': 144 | img_name = os.path.splitext(os.path.basename(i))[0] 145 | # img_name = os.path.basename(i) 146 | self.pre_img_name_list.append(img_name) 147 | nms_outputs = open(os.path.join(self.pre_data_dir, i), 148 | 'r').readlines() 149 | dt_lines = [l.strip() for l in nms_outputs] 150 | dt_lines = [ 151 | self.list_from_str(dt, dtype='int32') for dt in dt_lines 152 | ] # score xmin ymin xmax ymax 153 | boxes = dt_lines 154 | bbox_without_same = [] 155 | for box in boxes: 156 | if box[1] != box[3]: 157 | bbox_without_same.append(box) 158 | boxes = bbox_without_same 159 | final_bboxes = [] 160 | if self.allow_post_processing is True: 161 | #box format : score xmin ymin xmax ymax 162 | del_index = [] 163 | for i, box in enumerate(boxes): 164 | if i in del_index: 165 | continue 166 | if len(boxes[i + 1:]) == 0: 167 | if box not in final_bboxes and i not in del_index: 168 | final_bboxes.append(box) 169 | break 170 | for j, box_2rd in enumerate(boxes[i + 1:]): 171 | if j in del_index: 172 | continue 173 | ymin_second = int(box_2rd[1]) 174 | ymax_second = int(box_2rd[3]) 175 | xmin_second = int(box_2rd[0]) 176 | xmax_second = int(box_2rd[2]) 177 | ymin_first = int(box[1]) 178 | ymax_first = int(box[3]) 179 | xmin_first = int(box[0]) 180 | xmax_first = int(box[2]) 181 | if abs(ymin_second - ymin_first) <= 5 and abs( 182 | ymax_first - 183 | ymax_second) <= 5 and mat_inter( 184 | box, box_2rd): 185 | 186 | xmin_final = min(xmin_first, xmin_second) 187 | ymin_final = min(ymin_first, ymin_second) 188 | xmax_final = max(xmax_first, xmax_second) 189 | ymax_final = max(ymax_first, ymax_second) 190 | temp_box = [ 191 | xmin_final, ymin_final, xmax_final, 192 | ymax_final 193 | ] 194 | del_index.append(i) 195 | del_index.append(j + i + 1) 196 | box = temp_box 197 | final_bboxes.append(box) 198 | dt_lines = final_bboxes 199 | 200 | self.pre_data_dict[img_name] = dt_lines 201 | 202 | def 
contrast_pre_gt(self): 203 | for img_name in self.pre_img_name_list: 204 | pre_gts = self.pre_data_dict[img_name] 205 | eval_gts = self.eval_data_dict[img_name] 206 | error_pre = [] 207 | error_eval = [] 208 | 209 | for i, eval_gt in enumerate(eval_gts): 210 | flag_strick = 0 211 | err_flag = False 212 | for j, pre_gt in enumerate(pre_gts): 213 | iou = self.bbox_iou_eval(eval_gt, pre_gt) 214 | # print(iou) 215 | if iou >= self.iou_thresh: 216 | flag_strick += 1 217 | err_flag = False 218 | break 219 | else: 220 | err_flag = True 221 | if flag_strick >= 1: 222 | self.hit_precision += 1 223 | self.hit_recall += 1 224 | self.gt_data_num += 1 225 | if err_flag is True: 226 | error_eval.append(eval_gt) 227 | 228 | self.pre_data_num += len(pre_gts) 229 | 230 | for i, pre_gt in enumerate(pre_gts): 231 | err_flag = False 232 | for j, eval_gt in enumerate(eval_gts): 233 | iou = self.bbox_iou_eval(pre_gt, eval_gt) 234 | if iou >= self.iou_thresh: 235 | err_flag = False 236 | break 237 | else: 238 | err_flag = True 239 | if err_flag is True: 240 | error_pre.append(pre_gt) 241 | if len(error_eval) != 0 or len(error_pre) != 0: 242 | self.save_error_eval(img_name, error_eval, error_pre) 243 | 244 | def save_error_eval(self, img_name, error_eval_list, error_pre_list): 245 | # img_path = os.path.join(self.eval_data_dir, img_name) 246 | img_path = os.path.join(self.eval_data_dir, img_name + '.jpg') 247 | img = cv2.imread(img_path) 248 | 249 | if img is None: 250 | img_path = os.path.join(self.eval_data_dir, img_name + '.PNG') 251 | img = cv2.imread(img_path) 252 | img_h, img_w, _ = img.shape 253 | with open( 254 | os.path.join(self.save_err_path, img_name + '_err_gt.txt'), 255 | 'w') as f_target: 256 | for err_eval in error_eval_list: 257 | # print(err_eval) 258 | width = int(err_eval[2]) - int(err_eval[0]) 259 | height = int(err_eval[3]) - int(err_eval[1]) 260 | if width != 0: 261 | self.err_gt_dict['ratio'].append(float(height) / width) 262 | else: 263 | self.err_gt_dict['ratio'].append(0.) 
264 | self.err_gt_dict['h_scale'].append(float(height) / img_h) 265 | self.err_gt_dict['w_scale'].append(float(width) / img_w) 266 | self.err_gt_dict['height'].append(height) 267 | self.err_gt_dict['width'].append(width) 268 | img = self.draw_polygon( 269 | img, 270 | err_eval[0], 271 | err_eval[1], 272 | err_eval[2], 273 | err_eval[3], 274 | is_gt=True) 275 | 276 | f_target.write(','.join([str(i) for i in err_eval]) + '\n') 277 | with open( 278 | os.path.join(self.save_err_path, img_name + '_err_pre.txt'), 279 | 'w') as f_target: 280 | for err_pre in error_pre_list: 281 | f_target.write(','.join([str(j) for j in err_pre]) + '\n') 282 | img = self.draw_polygon( 283 | img, 284 | err_pre[0], 285 | err_pre[1], 286 | err_pre[2], 287 | err_pre[3], 288 | is_gt=False) 289 | cv2.imwrite(os.path.join(self.save_err_path, img_name + '.png'), img) 290 | 291 | def draw_polygon(self, img, xmin, ymin, xmax, ymax, is_gt=False): 292 | xmin = int(xmin) 293 | ymin = int(ymin) 294 | xmax = int(xmax) 295 | ymax = int(ymax) 296 | if is_gt is True: 297 | color = (255, 0, 0) 298 | else: 299 | color = (0, 255, 0) 300 | cv2.line(img, (xmin, ymin), (xmax, ymin), color, 2) 301 | cv2.line(img, (xmax, ymin), (xmax, ymax), color, 2) 302 | cv2.line(img, (xmax, ymax), (xmin, ymax), color, 2) 303 | cv2.line(img, (xmin, ymax), (xmin, ymin), color, 2) 304 | return img 305 | 306 | def cal_precision_recall(self): 307 | recall = float(self.hit_recall) / self.gt_data_num 308 | precision = float(self.hit_precision) / self.pre_data_num 309 | if recall != 0 and precision != 0: 310 | f_measure = 2 * recall * precision / (recall + precision) 311 | else: 312 | f_measure = 0 313 | return [precision, recall, f_measure] 314 | 315 | def draw_err_gt(self): 316 | err_gt_dict = self.err_gt_dict 317 | for key in err_gt_dict.keys(): 318 | values = err_gt_dict[key] 319 | fig = plt.figure() 320 | ax = fig.add_subplot(111) 321 | numBins = 50 322 | (counts, bins, patch) = ax.hist( 323 | values, numBins, color='blue', alpha=0.4, rwidth=0.5) 324 | #print('*****', key, '******') 325 | #print(counts) 326 | #print(bins) 327 | for i in range(len(counts)): 328 | plt.text( 329 | bins[i], 330 | counts[i], 331 | str(int(counts[i])), 332 | fontdict={ 333 | 'size': 6, 334 | 'color': 'r' 335 | }) 336 | if key in ['h_scale', 'w_scale', 'ratio']: 337 | mid = round((float(bins[i]) + float(bins[i + 1])) / 2, 2) 338 | else: 339 | mid = int(bins[i] + bins[i + 1] / 2) 340 | #if i % 5 == 0: 341 | plt.text( 342 | bins[i], 343 | counts[i] + 20, 344 | str(mid), 345 | fontdict={ 346 | 'size': 10, 347 | 'color': 'b' 348 | }) 349 | #print(patch) 350 | plt.grid(True) 351 | plt.title(u'{}'.format(key)) 352 | plt.savefig('{}/{}.png'.format(self.save_err_path, key)) 353 | with open('{}/{}.txt'.format(self.save_err_path, key), 'w') as f: 354 | for value in values: 355 | f.write('{}\n'.format(value)) 356 | 357 | def start_eval(self): 358 | print('----start eval----') 359 | print('---get xml path---') 360 | self.get_xml_path() 361 | print('---reading gts----') 362 | self.read_gts() 363 | print('---reading predictions---') 364 | self.read_pres() 365 | print('---contrast pre gt-----') 366 | self.contrast_pre_gt() 367 | pre, recall, f_measure = self.cal_precision_recall() 368 | print('pre:{} recall:{} f_measure:{}'.format(pre, recall, f_measure)) 369 | with open(self.save_result_path, 'a+') as f: 370 | f.write('iou:{} pre:{} recall:{} f_measure:{}\n'.format( 371 | self.iou_thresh, pre, recall, f_measure)) 372 | 373 | if self.draw_err_pic_flag == True: 374 | self.draw_err_gt() 375 | 
print('-----end eval------') 376 | 377 | 378 | if __name__ == '__main__': 379 | 380 | parser = ArgumentParser(description='icdar15 eval model') 381 | parser.add_argument( 382 | '--eval_data_dir', 383 | '-d', 384 | default= 385 | '/home/zsz/datasets/icdar15/test_gts/', 386 | type=str) 387 | #xml and img in same dir 388 | parser.add_argument('--pre_data_dir', '-p', type=str) 389 | #pre_data_dir is prediction format txt: text score xmin ymin xmax ymax 390 | parser.add_argument('--eval_file_type', '-f', default='1', type=str) 391 | parser.add_argument( 392 | '--save_result_path', '-s', default='result.txt', type=str) 393 | parser.add_argument('--iou_th', '-o', default=0.5, type=float) 394 | #1:read gt from xml format:(xmin ymin xmax ymax) 395 | #2:read gt from txt format:(text score xmin ymin xmax ymax) 396 | args = parser.parse_args() 397 | # print(args, args.eval_data_dir) 398 | eval_data_dir = args.eval_data_dir 399 | pre_data_dir = args.pre_data_dir 400 | eval_file_type = args.eval_file_type 401 | save_result_path = args.save_result_path 402 | iou_th = args.iou_th 403 | 404 | eval_model = EVAL_MODEL(eval_data_dir, pre_data_dir, eval_file_type, 405 | save_result_path, iou_th) 406 | eval_model.start_eval() 407 | -------------------------------------------------------------------------------- /eval_result.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | 4 | import os 5 | 6 | from apscheduler.schedulers.blocking import BlockingScheduler 7 | 8 | import time 9 | import tensorflow as tf 10 | 11 | scheduler = BlockingScheduler() 12 | model_dir = './model' 13 | def eval_net(): 14 | # ckpt_list = tf.train.get_checkpoint_state(model_dir).all_model_checkpoint_paths 15 | 16 | # for ckpt_path in ckpt_list: 17 | # if os.path.exists('{}.data-00000-of-00001'.format(ckpt_path)): 18 | print(time.asctime()) 19 | time.sleep(5) 20 | ckpt_path = tf.train.latest_checkpoint(model_dir) 21 | os.system("python test.py -c -0 -m {} -o test_result/{}".format(ckpt_path, ckpt_path.split('/')[-1])) 22 | 23 | eval_net() 24 | scheduler.add_job(eval_net, 'interval', seconds=1200) 25 | 26 | scheduler.start() 27 | -------------------------------------------------------------------------------- /gene_tfrecords.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import tensorflow as tf 3 | 4 | from datasets import xml_to_tfrecords 5 | import os 6 | FLAGS = tf.app.flags.FLAGS 7 | 8 | tf.app.flags.DEFINE_string( 9 | 'output_name', 'annotated_data', 10 | 'Basename used for TFRecords output files.') 11 | tf.app.flags.DEFINE_string( 12 | 'output_dir', 'tfrecords', 13 | 'Output directory where to store TFRecords files.') 14 | tf.app.flags.DEFINE_string( 15 | 'xml_img_txt_path', None, 16 | 'the path means the txt' 17 | ) 18 | 19 | tf.app.flags.DEFINE_integer( 20 | 'samples_per_files', 2000, 21 | 'the number means one tf_record save how many pictures' 22 | ) 23 | 24 | def main(_): 25 | if not FLAGS.xml_img_txt_path or not os.path.exists(FLAGS.xml_img_txt_path): 26 | raise ValueError('You must supply the dataset directory with --xml_img_txt_path') 27 | print('Dataset directory:', FLAGS.xml_img_txt_path) 28 | print('Output directory:', FLAGS.output_dir) 29 | 30 | xml_to_tfrecords.run(FLAGS.xml_img_txt_path, FLAGS.output_dir, FLAGS.output_name, FLAGS.samples_per_files) 31 | 32 | if __name__ == '__main__': 33 | tf.app.run() 34 | 
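# --- Illustration (not a file from the repo) ------------------------------- #
# A minimal, self-contained sketch of the matching arithmetic that eval.py
# implements via bbox_iou_eval / contrast_pre_gt / cal_precision_recall:
# greedy IoU matching of predictions against ground truths, then
# precision / recall / F-measure. Boxes are assumed axis-aligned
# (xmin, ymin, xmax, ymax); the names below are illustrative, not the repo's API.

def iou(a, b):
    # Intersection rectangle.
    iw = max(0., min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0., min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    # Union = area(a) + area(b) - intersection.
    union = ((a[2] - a[0]) * (a[3] - a[1]) +
             (b[2] - b[0]) * (b[3] - b[1]) - inter)
    return inter / union if union > 0 else 0.

def precision_recall_f(gts, preds, iou_thresh=0.5):
    # A ground truth counts as a hit if any prediction reaches the IoU threshold.
    hits = sum(1 for g in gts if any(iou(g, p) >= iou_thresh for p in preds))
    recall = float(hits) / len(gts) if gts else 0.
    precision = float(hits) / len(preds) if preds else 0.
    f_measure = (2 * recall * precision / (recall + precision)
                 if recall and precision else 0.)
    return precision, recall, f_measure

# One matched gt out of one, one spurious prediction out of two:
# precision_recall_f([[0, 0, 10, 10]], [[1, 1, 9, 9], [50, 50, 60, 60]])
# -> (0.5, 1.0, 0.666...)
# --------------------------------------------------------------------------- #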
-------------------------------------------------------------------------------- /logs/train_xml.txt: -------------------------------------------------------------------------------- 1 | /home/zsz/datasets/weblmtImage 26081096.0_242.0_1377.0_456.0.png,save_xml/weblmtImage 26081096.0_242.0_1377.0_456.0.xml 2 | -------------------------------------------------------------------------------- /nets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /nets/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/nets/__init__.pyc -------------------------------------------------------------------------------- /nets/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/nets/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /nets/__pycache__/custom_layers.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/nets/__pycache__/custom_layers.cpython-35.pyc -------------------------------------------------------------------------------- /nets/__pycache__/np_methods.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/nets/__pycache__/np_methods.cpython-35.pyc -------------------------------------------------------------------------------- /nets/__pycache__/textbox_common.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/nets/__pycache__/textbox_common.cpython-35.pyc -------------------------------------------------------------------------------- /nets/__pycache__/txtbox_384.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/nets/__pycache__/txtbox_384.cpython-35.pyc -------------------------------------------------------------------------------- /nets/__pycache__/txtbox_768.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/nets/__pycache__/txtbox_768.cpython-35.pyc -------------------------------------------------------------------------------- /nets/custom_layers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Paul Balanca. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Implement some custom layers, not provided by TensorFlow. 16 | 17 | Trying to follow as much as possible the style/standards used in 18 | tf.contrib.layers 19 | """ 20 | import tensorflow as tf 21 | 22 | from tensorflow.contrib.framework.python.ops import add_arg_scope 23 | from tensorflow.contrib.layers.python.layers import initializers 24 | from tensorflow.contrib.framework.python.ops import variables 25 | from tensorflow.contrib.layers.python.layers import utils 26 | from tensorflow.python.ops import nn 27 | from tensorflow.python.ops import init_ops 28 | from tensorflow.python.ops import variable_scope 29 | 30 | 31 | def abs_smooth(x): 32 | """Smoothed absolute function. Useful to compute an L1 smooth error. 33 | 34 | Define as: 35 | x^2 / 2 if abs(x) < 1 36 | abs(x) - 0.5 if abs(x) > 1 37 | We use here a differentiable definition using min(x) and abs(x). Clearly 38 | not optimal, but good enough for our purpose! 39 | """ 40 | absx = tf.abs(x) 41 | minx = tf.minimum(absx, 1) 42 | r = 0.5 * ((absx - 1) * minx + absx) 43 | return r 44 | 45 | 46 | @add_arg_scope 47 | def l2_normalization( 48 | inputs, 49 | scaling=False, 50 | scale_initializer=init_ops.ones_initializer(), 51 | reuse=None, 52 | variables_collections=None, 53 | outputs_collections=None, 54 | data_format='NHWC', 55 | trainable=True, 56 | scope=None): 57 | """Implement L2 normalization on every feature (i.e. spatial normalization). 58 | 59 | Should be extended in some near future to other dimensions, providing a more 60 | flexible normalization framework. 61 | 62 | Args: 63 | inputs: a 4-D tensor with dimensions [batch_size, height, width, channels]. 64 | scaling: whether or not to add a post scaling operation along the dimensions 65 | which have been normalized. 66 | scale_initializer: An initializer for the weights. 67 | reuse: whether or not the layer and its variables should be reused. To be 68 | able to reuse the layer scope must be given. 69 | variables_collections: optional list of collections for all the variables or 70 | a dictionary containing a different list of collection per variable. 71 | outputs_collections: collection to add the outputs. 72 | data_format: NHWC or NCHW data format. 73 | trainable: If `True` also add variables to the graph collection 74 | `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). 75 | scope: Optional scope for `variable_scope`. 76 | Returns: 77 | A `Tensor` representing the output of the operation. 
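
    Example (illustrative only):
        # Scale-normalize a VGG feature map, as TextBoxes++/SSD do for conv4_3:
        net = l2_normalization(net, scaling=True, scope='conv4_3_norm')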
78 | """ 79 | 80 | with variable_scope.variable_scope( 81 | scope, 'L2Normalization', [inputs], reuse=reuse) as sc: 82 | inputs_shape = inputs.get_shape() 83 | inputs_rank = inputs_shape.ndims 84 | dtype = inputs.dtype.base_dtype 85 | if data_format == 'NHWC': 86 | # norm_dim = tf.range(1, inputs_rank-1) 87 | norm_dim = tf.range(inputs_rank-1, inputs_rank) 88 | params_shape = inputs_shape[-1:] 89 | elif data_format == 'NCHW': 90 | # norm_dim = tf.range(2, inputs_rank) 91 | norm_dim = tf.range(1, 2) 92 | params_shape = (inputs_shape[1]) 93 | 94 | # Normalize along spatial dimensions. 95 | outputs = nn.l2_normalize(inputs, norm_dim, epsilon=1e-12) 96 | # Additional scaling. 97 | if scaling: 98 | scale_collections = utils.get_variable_collections( 99 | variables_collections, 'scale') 100 | scale = variables.model_variable('gamma', 101 | shape=params_shape, 102 | dtype=dtype, 103 | initializer=scale_initializer, 104 | collections=scale_collections, 105 | trainable=trainable) 106 | if data_format == 'NHWC': 107 | outputs = tf.multiply(outputs, scale) 108 | elif data_format == 'NCHW': 109 | scale = tf.expand_dims(scale, axis=-1) 110 | scale = tf.expand_dims(scale, axis=-1) 111 | outputs = tf.multiply(outputs, scale) 112 | # outputs = tf.transpose(outputs, perm=(0, 2, 3, 1)) 113 | 114 | return utils.collect_named_outputs(outputs_collections, 115 | sc.original_name_scope, outputs) 116 | 117 | 118 | @add_arg_scope 119 | def pad2d(inputs, 120 | pad=(0, 0), 121 | mode='CONSTANT', 122 | data_format='NHWC', 123 | trainable=True, 124 | scope=None): 125 | """2D Padding layer, adding a symmetric padding to H and W dimensions. 126 | 127 | Aims to mimic padding in Caffe and MXNet, helping the port of models to 128 | TensorFlow. Tries to follow the naming convention of `tf.contrib.layers`. 129 | 130 | Args: 131 | inputs: 4D input Tensor; 132 | pad: 2-Tuple with padding values for H and W dimensions; 133 | mode: Padding mode. C.f. `tf.pad` 134 | data_format: NHWC or NCHW data format. 135 | """ 136 | with tf.name_scope(scope, 'pad2d', [inputs]): 137 | # Padding shape. 138 | if data_format == 'NHWC': 139 | paddings = [[0, 0], [pad[0], pad[0]], [pad[1], pad[1]], [0, 0]] 140 | elif data_format == 'NCHW': 141 | paddings = [[0, 0], [0, 0], [pad[0], pad[0]], [pad[1], pad[1]]] 142 | net = tf.pad(inputs, paddings, mode=mode) 143 | return net 144 | 145 | 146 | @add_arg_scope 147 | def channel_to_last(inputs, 148 | data_format='NHWC', 149 | scope=None): 150 | """Move the channel axis to the last dimension. Allows to 151 | provide a single output format whatever the input data format. 152 | 153 | Args: 154 | inputs: Input Tensor; 155 | data_format: NHWC or NCHW. 156 | Return: 157 | Input in NHWC format. 158 | """ 159 | with tf.name_scope(scope, 'channel_to_last', [inputs]): 160 | if data_format == 'NHWC': 161 | net = inputs 162 | elif data_format == 'NCHW': 163 | net = tf.transpose(inputs, perm=(0, 2, 3, 1)) 164 | return net 165 | -------------------------------------------------------------------------------- /nets/custom_layers.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/nets/custom_layers.pyc -------------------------------------------------------------------------------- /nets/np_methods.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Paul Balanca. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Additional Numpy methods. Big mess of many things! 16 | """ 17 | import numpy as np 18 | 19 | 20 | # =========================================================================== # 21 | # Numpy implementations of SSD boxes functions. 22 | # =========================================================================== # 23 | def ssd_bboxes_decode(feat_localizations, 24 | anchor_bboxes, 25 | prior_scaling=[0.1, 0.1, 0.2, 0.2]): 26 | """Compute the relative bounding boxes from the layer features and 27 | reference anchor bounding boxes. 28 | 29 | Return: 30 | numpy array Nx4: ymin, xmin, ymax, xmax 31 | """ 32 | # Reshape for easier broadcasting. 33 | l_shape = feat_localizations.shape 34 | 35 | # feat_localizations = np.reshape(feat_localizations, 36 | # (-1, l_shape[-1])) 37 | feat_localizations = np.reshape(feat_localizations, 38 | (-1, l_shape[-2], l_shape[-1])) 39 | anchor_xmin, anchor_ymin, anchor_xmax, anchor_ymax = anchor_bboxes 40 | 41 | xref = (anchor_xmin + anchor_xmax) /2. 42 | yref = (anchor_ymin + anchor_ymax) /2. 43 | 44 | href = anchor_ymax - anchor_ymin 45 | wref = anchor_xmax - anchor_xmin 46 | 47 | decode_bbox_center_x = prior_scaling[0] * feat_localizations[:, :, 0] * wref + xref 48 | decode_bbox_center_y = prior_scaling[1] * feat_localizations[:, :, 1] * href + yref 49 | decode_bbox_width = np.exp(feat_localizations[:, :, 2] * prior_scaling[2]) * wref 50 | decode_bbox_height = np.exp(feat_localizations[: , :, 3] * prior_scaling[3]) * href 51 | 52 | xmin = decode_bbox_center_x - decode_bbox_width / 2. 53 | ymin = decode_bbox_center_y - decode_bbox_height / 2. 54 | xmax = decode_bbox_center_x + decode_bbox_width / 2. 55 | ymax = decode_bbox_center_y + decode_bbox_height / 2. 56 | 57 | x1 = prior_scaling[0] * feat_localizations[:, :, 4] * wref + anchor_xmin 58 | y1 = prior_scaling[1] * feat_localizations[:, :, 5] * href + anchor_ymin 59 | x2 = prior_scaling[0] * feat_localizations[:, :, 6] * wref + anchor_xmax 60 | y2 = prior_scaling[1] * feat_localizations[:, :, 7] * href + anchor_ymin 61 | 62 | x3 = prior_scaling[0] * feat_localizations[:, :, 8] * wref + anchor_xmax 63 | y3 = prior_scaling[1] * feat_localizations[:, :, 9] * href + anchor_ymax 64 | x4 = prior_scaling[0] * feat_localizations[:, :, 10] * wref + anchor_xmin 65 | y4 = prior_scaling[1] * feat_localizations[:, :, 11] * href + anchor_ymax 66 | 67 | # bboxes: ymin, xmin, xmax, ymax. 68 | bboxes = np.zeros_like(feat_localizations) 69 | 70 | bboxes[:,:,0] = xmin 71 | bboxes[:,:,1] = ymin 72 | bboxes[:,:,2] = xmax 73 | bboxes[:,:,3] = ymax 74 | bboxes[:, :, 4] = x1 75 | bboxes[:, :, 5] = x2 76 | bboxes[:, :, 6] = x3 77 | bboxes[:, :, 7] = x4 78 | bboxes[:, :, 8] = y1 79 | bboxes[:, :, 9] = y2 80 | bboxes[:, :, 10] = y3 81 | bboxes[:, :, 11] = y4 82 | 83 | # Back to original shape. 
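    # Note: despite the Nx4 note in the docstring, 12 values are produced per
    # anchor; along the last axis the layout filled in above is
    #   [0:4]  xmin, ymin, xmax, ymax  (horizontal rectangle)
    #   [4:8]  x1, x2, x3, x4          (quadrilateral corner xs)
    #   [8:12] y1, y2, y3, y4          (quadrilateral corner ys)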
84 |     bboxes = np.reshape(bboxes, l_shape)
85 |     return bboxes
86 | 
87 | 
88 | def ssd_bboxes_select(predictions_net,
89 |                       localizations_net,
90 |                       anchors_net,
91 |                       select_threshold=0.5,
92 |                       img_shape=(384, 384),
93 |                       num_classes=2,
94 |                       decode=True):
95 |     """Extract classes, scores and bounding boxes from network output layers.
96 | 
97 |     Return:
98 |       classes, scores, bboxes: Numpy arrays...
99 |     """
100 |     l_classes = []
101 |     l_scores = []
102 |     l_bboxes = []
103 |     for i in range(len(predictions_net)):
104 |         classes, scores, bboxes = ssd_bboxes_select_layer(
105 |             predictions_net[i], localizations_net[i], anchors_net[i],
106 |             select_threshold, img_shape, num_classes, decode)
107 |         l_classes.append(classes)
108 |         l_scores.append(scores)
109 |         l_bboxes.append(bboxes)
110 |         # Debug information.
111 |         # l_layers.append(i)
112 |         # l_idxes.append((i, idxes))
113 | 
114 |     classes = np.concatenate(l_classes, 0)
115 |     scores = np.concatenate(l_scores, 0)
116 |     bboxes = np.concatenate(l_bboxes, 0)
117 |     return classes, scores, bboxes
118 | 
119 | 
120 | def ssd_bboxes_select_layer(predictions_layer,
121 |                             localizations_layer,
122 |                             anchors_layer,
123 |                             select_threshold=0.2,
124 |                             img_shape=(384, 384),
125 |                             num_classes=2,
126 |                             decode=True):
127 |     """Extract classes, scores and bounding boxes from features in one layer.
128 | 
129 |     Return:
130 |       classes, scores, bboxes: Numpy arrays...
131 |     """
132 |     # First decode localizations features if necessary.
133 |     if decode:
134 |         localizations_layer = ssd_bboxes_decode(localizations_layer, anchors_layer)
135 | 
136 |     # Reshape features to: Batches x N x N_labels | 4.
137 |     p_shape = predictions_layer.shape
138 |     batch_size = p_shape[0] if len(p_shape) == 5 else 1
139 |     predictions_layer = np.reshape(predictions_layer,
140 |                                    (batch_size, -1, p_shape[-1]))
141 |     l_shape = localizations_layer.shape
142 |     localizations_layer = np.reshape(localizations_layer,
143 |                                      (batch_size, -1, l_shape[-1]))
144 | 
145 |     # Boxes selection: use threshold or score > no-label criteria.
146 |     if select_threshold is None or select_threshold == 0:
147 |         # Class prediction and scores: keep the best non-background class.
148 |         classes = np.argmax(predictions_layer, axis=2)
149 |         scores = np.amax(predictions_layer, axis=2)  # max score, not argmax
150 |         mask = (classes > 0)
151 |         classes = classes[mask]
152 |         scores = scores[mask]
153 |         bboxes = localizations_layer[mask]
154 |     else:  # two predictions
155 |         sub_predictions = predictions_layer[:, :, 1:]
156 |         idxes = np.where(sub_predictions > select_threshold)
157 |         classes = idxes[-1]+1
158 |         scores = sub_predictions[idxes]
159 |         bboxes = localizations_layer[idxes[:-1]]
160 | 
161 |     return classes, scores, bboxes
162 | 
163 | 
164 | # =========================================================================== #
165 | # Common functions for bboxes handling and selection.
166 | # =========================================================================== # 167 | def bboxes_sort(classes, scores, bboxes, top_k=400): 168 | """Sort bounding boxes by decreasing order and keep only the top_k 169 | """ 170 | # if priority_inside: 171 | # inside = (bboxes[:, 0] > margin) & (bboxes[:, 1] > margin) & \ 172 | # (bboxes[:, 2] < 1-margin) & (bboxes[:, 3] < 1-margin) 173 | # idxes = np.argsort(-scores) 174 | # inside = inside[idxes] 175 | # idxes = np.concatenate([idxes[inside], idxes[~inside]]) 176 | idxes = np.argsort(-scores) 177 | classes = classes[idxes][:top_k] 178 | scores = scores[idxes][:top_k] 179 | bboxes = bboxes[idxes][:top_k] 180 | return classes, scores, bboxes 181 | 182 | 183 | def bboxes_clip(bbox_ref, bboxes): 184 | """Clip bounding boxes with respect to reference bbox. 185 | """ 186 | bboxes = np.copy(bboxes) 187 | bboxes = np.transpose(bboxes) 188 | bbox_ref = np.transpose(bbox_ref) 189 | bboxes[0] = np.maximum(bboxes[0], bbox_ref[0]) 190 | bboxes[1] = np.maximum(bboxes[1], bbox_ref[1]) 191 | bboxes[2] = np.minimum(bboxes[2], bbox_ref[2]) 192 | bboxes[3] = np.minimum(bboxes[3], bbox_ref[3]) 193 | 194 | bboxes[4] = np.maximum(bboxes[4], bbox_ref[4]) 195 | bboxes[5] = np.maximum(bboxes[5], bbox_ref[5]) 196 | bboxes[6] = np.maximum(bboxes[6], bbox_ref[6]) 197 | bboxes[7] = np.maximum(bboxes[7], bbox_ref[7]) 198 | bboxes[8] = np.maximum(bboxes[8], bbox_ref[8]) 199 | bboxes[9] = np.maximum(bboxes[9], bbox_ref[9]) 200 | bboxes[10] = np.maximum(bboxes[10], bbox_ref[10]) 201 | bboxes[11] = np.maximum(bboxes[11], bbox_ref[11]) 202 | 203 | bboxes = np.transpose(bboxes) 204 | return bboxes 205 | 206 | 207 | def bboxes_resize(bbox_ref, bboxes): 208 | """Resize bounding boxes based on a reference bounding box, 209 | assuming that the latter is [0, 0, 1, 1] after transform. 210 | """ 211 | bboxes = np.copy(bboxes) 212 | # Translate. 213 | bboxes[:, 0] -= bbox_ref[0] 214 | bboxes[:, 1] -= bbox_ref[1] 215 | bboxes[:, 2] -= bbox_ref[0] 216 | bboxes[:, 3] -= bbox_ref[1] 217 | # Resize. 218 | resize = [bbox_ref[2] - bbox_ref[0], bbox_ref[3] - bbox_ref[1]] 219 | bboxes[:, 0] /= resize[0] 220 | bboxes[:, 1] /= resize[1] 221 | bboxes[:, 2] /= resize[0] 222 | bboxes[:, 3] /= resize[1] 223 | return bboxes 224 | 225 | 226 | def bboxes_jaccard(bboxes1, bboxes2): 227 | """Computing jaccard index between bboxes1 and bboxes2. 228 | Note: bboxes1 and bboxes2 can be multi-dimensional, but should broacastable. 229 | """ 230 | bboxes1 = np.transpose(bboxes1) 231 | bboxes2 = np.transpose(bboxes2) 232 | # Intersection bbox and volume. 233 | int_ymin = np.maximum(bboxes1[0], bboxes2[0]) 234 | int_xmin = np.maximum(bboxes1[1], bboxes2[1]) 235 | int_ymax = np.minimum(bboxes1[2], bboxes2[2]) 236 | int_xmax = np.minimum(bboxes1[3], bboxes2[3]) 237 | 238 | int_h = np.maximum(int_ymax - int_ymin, 0.) 239 | int_w = np.maximum(int_xmax - int_xmin, 0.) 240 | int_vol = int_h * int_w 241 | # Union volume. 242 | vol1 = (bboxes1[2] - bboxes1[0]) * (bboxes1[3] - bboxes1[1]) 243 | vol2 = (bboxes2[2] - bboxes2[0]) * (bboxes2[3] - bboxes2[1]) 244 | jaccard = int_vol / (vol1 + vol2 - int_vol) 245 | return jaccard 246 | 247 | 248 | def bboxes_intersection(bboxes_ref, bboxes2): 249 | """Computing jaccard index between bboxes1 and bboxes2. 250 | Note: bboxes1 and bboxes2 can be multi-dimensional, but should broacastable. 251 | """ 252 | bboxes_ref = np.transpose(bboxes_ref) 253 | bboxes2 = np.transpose(bboxes2) 254 | # Intersection bbox and volume. 
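    # Note: unlike bboxes_jaccard above, the score computed below is normalized
    # by the reference box area only, i.e. it measures how much of bboxes_ref
    # is covered rather than the Jaccard index the docstring mentions.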
255 |     int_ymin = np.maximum(bboxes_ref[0], bboxes2[0])
256 |     int_xmin = np.maximum(bboxes_ref[1], bboxes2[1])
257 |     int_ymax = np.minimum(bboxes_ref[2], bboxes2[2])
258 |     int_xmax = np.minimum(bboxes_ref[3], bboxes2[3])
259 | 
260 |     int_h = np.maximum(int_ymax - int_ymin, 0.)
261 |     int_w = np.maximum(int_xmax - int_xmin, 0.)
262 |     int_vol = int_h * int_w
263 |     # Reference volume.
264 |     vol = (bboxes_ref[2] - bboxes_ref[0]) * (bboxes_ref[3] - bboxes_ref[1])
265 |     score = int_vol / vol
266 |     return score
267 | 
268 | 
269 | def bboxes_nms(classes, scores, bboxes, nms_threshold=0.45):
270 |     """Apply non-maximum suppression to bounding boxes.
271 |     """
272 |     keep_bboxes = np.ones(scores.shape, dtype=np.bool)
273 |     for i in range(scores.size-1):
274 |         if keep_bboxes[i]:
275 |             # Compute overlap with the bboxes which follow.
276 |             overlap = bboxes_jaccard(bboxes[i], bboxes[(i+1):])
277 |             # Keep a following box when its overlap is below the threshold or its class differs.
278 |             keep_overlap = np.logical_or(overlap < nms_threshold, classes[(i+1):] != classes[i])
279 |             keep_bboxes[(i+1):] = np.logical_and(keep_bboxes[(i+1):], keep_overlap)
280 | 
281 |     idxes = np.where(keep_bboxes)
282 |     return classes[idxes], scores[idxes], bboxes[idxes]
283 | 
284 | 
285 | def bboxes_nms_fast(classes, scores, bboxes, threshold=0.45):
286 |     """Apply non-maximum suppression to bounding boxes (not implemented).
287 |     """
288 |     pass
289 | 
290 | 
--------------------------------------------------------------------------------
/nets/np_methods.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/nets/np_methods.pyc
--------------------------------------------------------------------------------
/nets/textbox_common.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import tensorflow as tf
3 | import numpy as np
4 | import math
5 | 
6 | 
7 | 
8 | # =========================================================================== #
9 | # TensorFlow implementation of Text Boxes encoding / decoding.
10 | # =========================================================================== #
11 | 
12 | def tf_text_bboxes_encode_layer(glabels, bboxes, gxs, gys,
13 |                                 anchors_layer,
14 |                                 matching_threshold=0.1,
15 |                                 prior_scaling=[0.1, 0.1, 0.2, 0.2],
16 |                                 dtype=tf.float32):
17 | 
18 |     """
19 |     Encode groundtruth labels and bounding boxes using Textbox anchors from
20 |     one layer.
21 | 
22 |     Arguments:
23 |       bboxes: Nx4 Tensor(float) with bboxes relative coordinates;
24 |       gxs, gys: Nx4 Tensor
25 |       anchors_layer: Numpy array with layer anchors;
26 |       matching_threshold: Threshold for positive match with groundtruth bboxes;
27 |       prior_scaling: Scaling of encoded coordinates.
28 | 
29 |     Return:
30 |       (target_localizations, target_scores): Target Tensors.
31 |     # this is a binary problem, so target_score and target_labels are the same.
32 |     """
33 |     # Anchors coordinates and volume.
34 | 
35 |     # yref, xref, href, wref = anchors_layer
36 |     xmin, ymin, xmax, ymax = anchors_layer
37 |     xref = (xmin + xmax) / 2
38 |     yref = (ymin + ymax) / 2
39 |     href = ymax - ymin
40 |     wref = xmax - xmin
41 |     # The Caffe source computes the overlap between every pred_bbox (predicted bbox) and the gt.
42 |     print(yref.shape)
43 |     print(href.shape)
44 |     print(bboxes.shape)
45 |     # glabels = tf.Print(glabels, [tf.shape(glabels)], message=' glabels shape is :')
46 | 
47 |     # ymin = yref - href / 2.
48 |     # xmin = xref - wref / 2.
49 |     # ymax = yref + href / 2.
50 | # xmax = xref + wref / 2. 51 | vol_anchors = (xmax - xmin) * (ymax - ymin) 52 | 53 | # bboxes = tf.Print(bboxes, [tf.shape(bboxes), bboxes], message=' bboxes in encode shape:', summarize=20) 54 | # glabels = tf.Print(glabels, [tf.shape(glabels)], message=' glabels shape:') 55 | 56 | # xs = np.asarray(gxs, dtype=np.float32) 57 | # ys = np.asarray(gys, dtype=np.float32) 58 | # num_bboxes = xs.shape[0] 59 | 60 | 61 | # Initialize tensors... 62 | shape = (yref.shape[0]) 63 | # all after the flatten 64 | feat_labels = tf.zeros(shape, dtype=tf.int64) 65 | feat_scores = tf.zeros(shape, dtype=dtype) 66 | 67 | feat_ymin = tf.zeros(shape, dtype=dtype) 68 | feat_xmin = tf.zeros(shape, dtype=dtype) 69 | feat_ymax = tf.ones(shape, dtype=dtype) 70 | feat_xmax = tf.ones(shape, dtype=dtype) 71 | 72 | feat_x1 = tf.zeros(shape, dtype=dtype) 73 | feat_x2 = tf.zeros(shape, dtype=dtype) 74 | feat_x3 = tf.zeros(shape, dtype=dtype) 75 | feat_x4 = tf.zeros(shape, dtype=dtype) 76 | feat_y1 = tf.zeros(shape, dtype=dtype) 77 | feat_y2 = tf.zeros(shape, dtype=dtype) 78 | feat_y3 = tf.zeros(shape, dtype=dtype) 79 | feat_y4 = tf.zeros(shape, dtype=dtype) 80 | 81 | 82 | # feat_x1 =tf.zeros() 83 | 84 | def jaccard_with_anchors(bbox): 85 | """ 86 | Compute jaccard score between a box and the anchors. 87 | """ 88 | int_ymin = tf.maximum(ymin, bbox[0]) 89 | int_xmin = tf.maximum(xmin, bbox[1]) 90 | int_ymax = tf.minimum(ymax, bbox[2]) 91 | int_xmax = tf.minimum(xmax, bbox[3]) 92 | h = tf.maximum(int_ymax - int_ymin, 0.) 93 | w = tf.maximum(int_xmax - int_xmin, 0.) 94 | # Volumes. 95 | inter_vol = h * w 96 | union_vol = vol_anchors - inter_vol \ 97 | + (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) 98 | jaccard = tf.div(inter_vol, union_vol) 99 | return jaccard 100 | 101 | 102 | def intersection_with_anchors(bbox): 103 | ''' 104 | Compute intersection between score a box and the anchors. 105 | ''' 106 | int_ymin = tf.maximum(ymin, bbox[0]) 107 | int_xmin = tf.maximum(xmin, bbox[1]) 108 | int_ymax = tf.minimum(ymax, bbox[2]) 109 | int_xmax = tf.minimum(xmax, bbox[3]) 110 | h = tf.maximum(int_ymax - int_ymin, 0.) 111 | w = tf.maximum(int_xmax - int_xmin, 0.) 112 | inter_vol = h * w 113 | scores = tf.div(inter_vol, vol_anchors) 114 | return scores 115 | 116 | def condition(i,feat_labels, feat_scores, 117 | feat_ymin, feat_xmin, feat_ymax, feat_xmax, 118 | feat_x1, feat_x2, feat_x3, feat_x4, 119 | feat_y1, feat_y2, feat_y3, feat_y4): 120 | """Condition: check label index. 121 | """ 122 | 123 | r = tf.less(i, tf.shape(glabels)[0]) 124 | 125 | return r 126 | 127 | def body(i,feat_labels, feat_scores,feat_ymin, feat_xmin, feat_ymax, feat_xmax, feat_x1, feat_x2, feat_x3, feat_x4, feat_y1, feat_y2, feat_y3, feat_y4): 128 | """Body: update feature labels, scores and bboxes. 129 | Follow the original SSD paper for that purpose: 130 | - assign values when jaccard > 0.5; 131 | - only update if beat the score of other bboxes. 132 | """ 133 | # Jaccard score. 
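        # Greedy assignment: for ground-truth box i, an anchor is (re)assigned
        # to it only when their IoU both exceeds matching_threshold and beats
        # the best score the anchor has seen so far, so every anchor ends up
        # matched to its highest-IoU ground truth.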
134 | label = glabels[i] 135 | bbox = bboxes[i] 136 | gx = gxs[i] 137 | gy = gys[i] 138 | # i = tf.Print(i, [i , tf.shape(glabels), tf.shape(bboxes), tf.shape(gxs), tf.shape(gys)], message='i is :') 139 | jaccard = jaccard_with_anchors(bbox) 140 | # jaccard = tf.Print(jaccard, [tf.shape(jaccard), tf.nn.top_k(jaccard, 100, sorted=True)[0]], message=' jaccard :', summarize=100) 141 | # feat_scores = tf.Print(feat_scores, [tf.shape(feat_scores),tf.count_nonzero(feat_scores), tf.nn.top_k(feat_scores, 100, sorted=True)[0]], message= ' feat_scores: ', summarize= 100) 142 | # Mask: check threshold + scores + no annotations + num_classes. 143 | mask = tf.greater(jaccard, feat_scores) 144 | 145 | # i = tf.Print(i, [tf.shape(i), i], message= ' i is: ') 146 | # tf.Print(mask, [mask]) 147 | mask = tf.logical_and(mask, tf.greater(jaccard, matching_threshold)) 148 | # mask = tf.logical_and(mask, feat_scores > -0.5) 149 | # mask = tf.logical_and(mask, label < 2) 150 | # mask = tf.Print(mask, [tf.shape(mask), mask[0]], message=' mask is :') 151 | imask = tf.cast(mask, tf.int64) 152 | fmask = tf.cast(mask, dtype) 153 | # Update values using mask. 154 | feat_labels = imask * label + (1 - imask) * feat_labels 155 | feat_scores = tf.where(mask, jaccard, feat_scores) 156 | # bbox ymin xmin ymax xmax gxs gys 157 | # update all box 158 | # bbox = tf.Print(bbox, [tf.shape(bbox), bbox], message= ' bbox : ', summarize=20) 159 | # gx = tf.Print(gx, [gx], message=' gx: ', summarize=20) 160 | # gy = tf.Print(gy, [gy], message= ' gy: ', summarize=20) 161 | # fmask = tf.Print(fmask, [tf.shape(fmask), tf.count_nonzero(fmask), tf.nn.top_k(fmask, 100, sorted=True)[0]], message=' fmask :', summarize=100) 162 | 163 | feat_ymin = fmask * bbox[0] + (1 - fmask) * feat_ymin 164 | feat_xmin = fmask * bbox[1] + (1 - fmask) * feat_xmin 165 | feat_ymax = fmask * bbox[2] + (1 - fmask) * feat_ymax 166 | feat_xmax = fmask * bbox[3] + (1 - fmask) * feat_xmax 167 | 168 | feat_x1 = fmask * gx[0] + (1 - fmask) * feat_x1 169 | feat_x2 = fmask * gx[1] + (1 - fmask) * feat_x2 170 | feat_x3 = fmask * gx[2] + (1 - fmask) * feat_x3 171 | feat_x4 = fmask * gx[3] + (1 - fmask) * feat_x4 172 | 173 | feat_y1 = fmask * gy[0] + (1 - fmask) * feat_y1 174 | feat_y2 = fmask * gy[1] + (1 - fmask) * feat_y2 175 | feat_y3 = fmask * gy[2] + (1 - fmask) * feat_y3 176 | feat_y4 = fmask * gy[3] + (1 - fmask) * feat_y4 177 | 178 | # feat_x1 = tf.Print(feat_x1, [tf.count_nonzero(feat_x1), tf.nn.top_k(feat_x1, 100, sorted=True)[0]], message=' feat x1 : ', summarize=100) 179 | 180 | # feat_ymax = tf.Print(feat_ymax, [tf.shape(feat_ymax), tf.count_nonzero(feat_ymax), feat_ymax,tf.nn.top_k(feat_ymax, 100, sorted=True)[0]], message= ' feat_ymax :', summarize=100) 181 | # feat_ymin = tf.Print(feat_ymin, [tf.shape(feat_ymin), tf.count_nonzero(feat_ymin), feat_ymin, tf.nn.top_k(feat_ymax, 100, sorted=True)[0]], message= ' feat_ymin :', summarize=100) 182 | # feat_xmax = tf.Print(feat_xmax, [tf.shape(feat_xmax), tf.count_nonzero(feat_xmax), feat_xmax, tf.nn.top_k(feat_xmax, 100, sorted=True)[0]], message= ' feat_xmax : ', summarize=100) 183 | # feat_xmin = tf.Print(feat_xmin, [tf.shape(feat_xmin), tf.count_nonzero(feat_xmin), feat_xmin,tf.nn.top_k(feat_xmin, 100, sorted=True)[0]], message= ' feat_xmin: ' , summarize=100) 184 | 185 | 186 | # Check no annotation label: ignore these anchors... 187 | # interscts = intersection_with_anchors(bbox) 188 | #mask = tf.logical_and(interscts > ignore_threshold, 189 | # label == no_annotation_label) 190 | # Replace scores by -1. 
#feat_scores = tf.where(mask, -tf.cast(mask, dtype), feat_scores)
192 | 
193 |         return [i+1, feat_labels, feat_scores,
194 |                 feat_ymin, feat_xmin, feat_ymax, feat_xmax,
195 |                 feat_x1, feat_x2, feat_x3, feat_x4,
196 |                 feat_y1, feat_y2, feat_y3, feat_y4]
197 |     # Main loop definition.
198 | 
199 |     i = 0
200 |     [i, feat_labels, feat_scores,
201 |      feat_ymin, feat_xmin,
202 |      feat_ymax, feat_xmax,
203 |      feat_x1, feat_x2, feat_x3, feat_x4,
204 |      feat_y1, feat_y2, feat_y3, feat_y4] = tf.while_loop(condition, body,
205 |                                             [i, feat_labels, feat_scores,
206 |                                              feat_ymin, feat_xmin,
207 |                                              feat_ymax, feat_xmax, feat_x1, feat_x2, feat_x3, feat_x4, feat_y1, feat_y2, feat_y3, feat_y4])
208 |     # Transform to center / size.
209 |     '''
210 |     The logic here matches each ground truth's horizontal bounding rectangle against the anchors / default boxes; the resulting IoU mask updates the gt assigned to each anchor,
211 |     and the offsets from each anchor to its assigned gt are then computed.
212 |     '''
213 |     #
214 |     # feat_ymax = tf.Print(feat_ymax, [tf.shape(feat_ymax), tf.count_nonzero(feat_ymax), feat_ymax], message= ' feat_ymax :', summarize=100)
215 |     # feat_ymin = tf.Print(feat_ymin, [tf.shape(feat_ymin), tf.count_nonzero(feat_ymin), feat_ymin], message= ' feat_ymin :', summarize=100)
216 |     # feat_xmax = tf.Print(feat_xmax, [tf.shape(feat_xmax), tf.count_nonzero(feat_xmax), feat_xmax], message= ' feat_xmax : ', summarize=100)
217 |     # feat_xmin = tf.Print(feat_xmin, [tf.shape(feat_xmin), tf.count_nonzero(feat_xmin), feat_xmin], message= ' feat_xmin: ' , summarize=100)
218 | 
219 | 
220 |     feat_cy = (feat_ymax + feat_ymin) / 2.
221 |     feat_cx = (feat_xmax + feat_xmin) / 2.
222 |     feat_h = feat_ymax - feat_ymin
223 |     feat_w = feat_xmax - feat_xmin
224 |     # Encode features.
225 | 
226 |     # feat_ymin = tf.Print(feat_ymin, [tf.shape(feat_ymin), feat_ymin], message= ' feat_ymin : ', summarize=20)
227 |     # feat_xmin = tf.Print(feat_xmin, [tf.shape(feat_xmin), feat_xmin], message= ' feat_xmin : ', summarize=20)
228 | 
229 |     #
230 | 
231 | 
232 | 
233 |     # feat_cy = tf.Print(feat_cy, [tf.shape(feat_cy), feat_cy], message=' feat_cy : ', summarize=20)
234 |     # feat_cx = tf.Print(feat_cx, [tf.shape(feat_cx), feat_cx], message=' feat_cx : ', summarize=20)
235 |     # feat_h = tf.Print(feat_h, [tf.shape(feat_h), feat_h], message=' feat_h : ', summarize=20)
236 |     # feat_w = tf.Print(feat_w, [tf.shape(feat_w), feat_w], message=' feat_w : ', summarize=20)
237 |     #
238 |     # yref = tf.Print(yref, [tf.shape(yref), yref], message=' yref : ',summarize=20)
239 |     # xref = tf.Print(xref, [tf.shape(xref), xref], message=' xref : ',summarize=20)
240 |     # href = tf.Print(href, [tf.shape(href), href], message=' href : ',
241 |     # summarize=20)
242 |     # wref = tf.Print(wref, [tf.shape(wref), wref], message=' wref : ', summarize=20)
243 | 
244 |     feat_ymin = (feat_cy - yref) / href / prior_scaling[1]
245 |     feat_xmin = (feat_cx - xref) / wref / prior_scaling[0]
246 | 
247 | 
248 |     feat_ymax = tf.log(feat_h / href) / prior_scaling[3]
249 |     feat_xmax = tf.log(feat_w / wref) / prior_scaling[2]
250 | 
251 | 
252 |     feat_x1 = (feat_x1 - xmin) / wref / prior_scaling[0]
253 |     feat_x2 = (feat_x2 - xmax) / wref / prior_scaling[0]
254 |     feat_x3 = (feat_x3 - xmax) / wref / prior_scaling[0]
255 |     feat_x4 = (feat_x4 - xmin) / wref / prior_scaling[0]
256 | 
257 |     feat_y1 = (feat_y1 - ymin) / href / prior_scaling[1]
258 |     feat_y2 = (feat_y2 - ymin) / href / prior_scaling[1]
259 |     feat_y3 = (feat_y3 - ymax) / href / prior_scaling[1]
260 |     feat_y4 = (feat_y4 - ymax) / href / prior_scaling[1]
261 | 
262 |     # Use SSD ordering: x / y / w / h instead of ours.
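    # The stack below therefore packs 12 targets per anchor: the (cx, cy)
    # center offsets, the (log w, log h) size offsets, then the four
    # quadrilateral corner offsets interleaved as x1, y1, ..., x4, y4.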
263 | # add xy1, 2,3,4 264 | 265 | # feat_ymin = tf.Print(feat_ymin, [tf.shape(feat_ymin), feat_ymin], message= ' feat_ymin : ', summarize=20) 266 | # feat_xmin = tf.Print(feat_xmin, [tf.shape(feat_xmin), feat_xmin], message= ' feat_xmin : ', summarize=20) 267 | 268 | feat_localizations = tf.stack([feat_xmin, feat_ymin, feat_xmax, feat_ymax ,feat_x1, feat_y1, feat_x2, feat_y2, feat_x3, feat_y3, feat_x4, feat_y4], axis=-1) 269 | # feat_localizations = tf.Print(feat_localizations, [tf.shape(feat_localizations), feat_localizations], message=' feat_localizations: ', summarize=20) 270 | return feat_labels, feat_localizations, feat_scores 271 | 272 | 273 | 274 | def tf_text_bboxes_encode(glabels, bboxes, 275 | anchors,gxs, gys, 276 | matching_threshold=0.1, 277 | prior_scaling=[0.1, 0.1, 0.2, 0.2], 278 | dtype=tf.float32, 279 | scope='text_bboxes_encode'): 280 | """Encode groundtruth labels and bounding boxes using SSD net anchors. 281 | Encoding boxes for all feature layers. 282 | 283 | Arguments: 284 | bboxes: Nx4 Tensor(float) with bboxes relative coordinates; 285 | anchors: List of Numpy array with layer anchors; 286 | gxs,gys:shape = (N,4) with x,y coordinates 287 | matching_threshold: Threshold for positive match with groundtruth bboxes; 288 | prior_scaling: Scaling of encoded coordinates. 289 | 290 | Return: 291 | (target_labels, target_localizations, target_scores): 292 | Each element is a list of target Tensors. 293 | """ 294 | 295 | with tf.name_scope('text_bboxes_encode'): 296 | target_labels = [] 297 | target_localizations = [] 298 | target_scores = [] 299 | for i, anchors_layer in enumerate(anchors): 300 | with tf.name_scope('bboxes_encode_block_%i' % i): 301 | t_label, t_loc, t_scores = \ 302 | tf_text_bboxes_encode_layer(glabels, bboxes,gxs, gys, anchors_layer, 303 | matching_threshold, 304 | prior_scaling, dtype) 305 | target_localizations.append(t_loc) 306 | target_scores.append(t_scores) 307 | target_labels.append(t_label) 308 | return target_localizations, target_scores, target_labels 309 | 310 | 311 | 312 | -------------------------------------------------------------------------------- /nets/textbox_common.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/nets/textbox_common.pyc -------------------------------------------------------------------------------- /nets/txtbox_384.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/nets/txtbox_384.pyc -------------------------------------------------------------------------------- /nets/txtbox_768.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/nets/txtbox_768.pyc -------------------------------------------------------------------------------- /processing/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /processing/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/processing/__init__.pyc 
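# --- Illustration (not a file from the repo) ------------------------------- #
# A minimal NumPy sketch of the horizontal-box half of the encode/decode pair
# implemented in nets/textbox_common.py and nets/np_methods.py: encoding stores
# (center, log-size) offsets relative to an anchor, scaled by prior_scaling,
# and decoding inverts it exactly. Names are illustrative only.
import numpy as np

PRIOR_SCALING = [0.1, 0.1, 0.2, 0.2]

def encode_hbox(box, anchor, ps=PRIOR_SCALING):
    """box, anchor: (xmin, ymin, xmax, ymax) in relative coordinates."""
    xref, yref = (anchor[0] + anchor[2]) / 2., (anchor[1] + anchor[3]) / 2.
    wref, href = anchor[2] - anchor[0], anchor[3] - anchor[1]
    cx, cy = (box[0] + box[2]) / 2., (box[1] + box[3]) / 2.
    w, h = box[2] - box[0], box[3] - box[1]
    return np.array([(cx - xref) / wref / ps[0],
                     (cy - yref) / href / ps[1],
                     np.log(w / wref) / ps[2],
                     np.log(h / href) / ps[3]])

def decode_hbox(t, anchor, ps=PRIOR_SCALING):
    xref, yref = (anchor[0] + anchor[2]) / 2., (anchor[1] + anchor[3]) / 2.
    wref, href = anchor[2] - anchor[0], anchor[3] - anchor[1]
    cx, cy = ps[0] * t[0] * wref + xref, ps[1] * t[1] * href + yref
    w, h = np.exp(t[2] * ps[2]) * wref, np.exp(t[3] * ps[3]) * href
    return np.array([cx - w / 2., cy - h / 2., cx + w / 2., cy + h / 2.])

# Round trip: decode(encode(box)) recovers the box.
anchor = np.array([0.40, 0.40, 0.60, 0.60])
box = np.array([0.42, 0.45, 0.58, 0.55])
assert np.allclose(decode_hbox(encode_hbox(box, anchor), anchor), box)
# --------------------------------------------------------------------------- #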
-------------------------------------------------------------------------------- /processing/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/processing/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /processing/__pycache__/ssd_vgg_preprocessing.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/processing/__pycache__/ssd_vgg_preprocessing.cpython-35.pyc -------------------------------------------------------------------------------- /processing/__pycache__/tf_image.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/processing/__pycache__/tf_image.cpython-35.pyc -------------------------------------------------------------------------------- /processing/ssd_vgg_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Paul Balanca. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Pre-processing images for SSD-type networks. 16 | """ 17 | from enum import Enum, IntEnum 18 | import numpy as np 19 | 20 | import tensorflow as tf 21 | import tf_extended as tfe 22 | 23 | from tensorflow.python.ops import control_flow_ops 24 | 25 | from processing import tf_image 26 | 27 | 28 | slim = tf.contrib.slim 29 | 30 | # Resizing strategies. 31 | Resize = IntEnum('Resize', ('NONE', # Nothing! 32 | 'CENTRAL_CROP', # Crop (and pad if necessary). 33 | 'PAD_AND_RESIZE', # Pad, and resize to output shape. 34 | 'WARP_RESIZE')) # Warp resize. 35 | 36 | # VGG mean parameters. 37 | _R_MEAN = 123.68 38 | _G_MEAN = 116.78 39 | _B_MEAN = 103.94 40 | 41 | # Some training pre-processing parameters. 42 | BBOX_CROP_OVERLAP = 0.15 # Minimum overlap to keep a bbox after cropping. 43 | CROP_RATIO_RANGE = (0.5, 1.5) # Distortion ratio during cropping. 44 | EVAL_SIZE = (384, 384) 45 | AREA_RANGE = [1., 1.] 
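# Together with MIN_OBJECT_COVERED below, AREA_RANGE is handed to
# tf.image.sample_distorted_bounding_box in distorted_bounding_box_crop;
# [1., 1.] forces the sampled crop to keep the full image area, so the
# cropping variance comes from the other parameters.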
46 | MIN_OBJECT_COVERED = 0.5 47 | 48 | 49 | def _mean_image_subtraction(image, means): 50 | 51 | if image.get_shape().ndims != 3: 52 | raise ValueError('Input must be of size [height, width, C>0]') 53 | num_channels = image.get_shape().as_list()[-1] 54 | if len(means) != num_channels: 55 | raise ValueError('len(means) must match the number of channels') 56 | channels = tf.split(axis=2, num_or_size_splits=num_channels, value=image) 57 | for i in range(num_channels): 58 | channels[i] -= means[i] 59 | return tf.concat(axis=2, values=channels) 60 | 61 | def tf_image_whitened(image, means=[_R_MEAN, _G_MEAN, _B_MEAN]): 62 | """Subtracts the given means from each image channel. 63 | 64 | Returns: 65 | the centered image. 66 | """ 67 | if image.get_shape().ndims != 3: 68 | raise ValueError('Input must be of size [height, width, C>0]') 69 | num_channels = image.get_shape().as_list()[-1] 70 | if len(means) != num_channels: 71 | raise ValueError('len(means) must match the number of channels') 72 | 73 | mean = tf.constant(means, dtype=image.dtype) 74 | image = image - mean 75 | return image 76 | 77 | 78 | def tf_image_unwhitened(image, means=[_R_MEAN, _G_MEAN, _B_MEAN], to_int=True): 79 | """Re-convert to original image distribution, and convert to int if 80 | necessary. 81 | 82 | Returns: 83 | Centered image. 84 | """ 85 | mean = tf.constant(means, dtype=image.dtype) 86 | image = image + mean 87 | if to_int: 88 | image = tf.cast(image, tf.int32) 89 | return image 90 | 91 | 92 | def np_image_unwhitened(image, means=[_R_MEAN, _G_MEAN, _B_MEAN], to_int=True): 93 | """Re-convert to original image distribution, and convert to int if 94 | necessary. Numpy version. 95 | 96 | Returns: 97 | Centered image. 98 | """ 99 | img = np.copy(image) 100 | img += np.array(means, dtype=img.dtype) 101 | if to_int: 102 | img = img.astype(np.uint8) 103 | return img 104 | 105 | 106 | def tf_summary_image(image, bboxes, name='image', unwhitened=False): 107 | """Add image with bounding boxes to summary. 108 | """ 109 | if unwhitened: 110 | image = tf_image_unwhitened(image) 111 | image = tf.expand_dims(image, 0) 112 | bboxes = tf.expand_dims(bboxes, 0) 113 | image_with_box = tf.image.draw_bounding_boxes(image, bboxes) 114 | tf.summary.image(name, image_with_box) 115 | 116 | 117 | def apply_with_random_selector(x, func, num_cases): 118 | """Computes func(x, sel), with sel sampled from [0...num_cases-1]. 119 | 120 | Args: 121 | x: input Tensor. 122 | func: Python function to apply. 123 | num_cases: Python int32, number of cases to sample sel from. 124 | 125 | Returns: 126 | The result of func(x, sel), where func receives the value of the 127 | selector as a python integer, but sel is sampled dynamically. 128 | """ 129 | sel = tf.random_uniform([], maxval=num_cases, dtype=tf.int32) 130 | # Pass the real x only to one of the func calls. 131 | return control_flow_ops.merge([ 132 | func(control_flow_ops.switch(x, tf.equal(sel, case))[1], case) 133 | for case in range(num_cases)])[0] 134 | 135 | 136 | def distort_color(image, color_ordering=0, fast_mode=True, scope=None): 137 | """Distort the color of a Tensor image. 138 | 139 | Each color distortion is non-commutative and thus ordering of the color ops 140 | matters. Ideally we would randomly permute the ordering of the color ops. 141 | Rather then adding that level of complication, we select a distinct ordering 142 | of color ops for each preprocessing thread. 143 | 144 | Args: 145 | image: 3-D Tensor containing single image in [0, 1]. 
146 | color_ordering: Python int, a type of distortion (valid values: 0-3). 147 | fast_mode: Avoids slower ops (random_hue and random_contrast) 148 | scope: Optional scope for name_scope. 149 | Returns: 150 | 3-D Tensor color-distorted image on range [0, 1] 151 | Raises: 152 | ValueError: if color_ordering not in [0, 3] 153 | """ 154 | with tf.name_scope(scope, 'distort_color', [image]): 155 | if fast_mode: 156 | if color_ordering == 0: 157 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 158 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 159 | else: 160 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 161 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 162 | else: 163 | if color_ordering == 0: 164 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 165 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 166 | image = tf.image.random_hue(image, max_delta=0.2) 167 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 168 | elif color_ordering == 1: 169 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 170 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 171 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 172 | image = tf.image.random_hue(image, max_delta=0.2) 173 | elif color_ordering == 2: 174 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 175 | image = tf.image.random_hue(image, max_delta=0.2) 176 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 177 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 178 | elif color_ordering == 3: 179 | image = tf.image.random_hue(image, max_delta=0.2) 180 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 181 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 182 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 183 | else: 184 | raise ValueError('color_ordering must be in [0, 3]') 185 | # The random_* ops do not necessarily clamp. 186 | return tf.clip_by_value(image, 0.0, 1.0) 187 | 188 | 189 | def distorted_bounding_box_crop(image, 190 | labels, 191 | bboxes, 192 | xs, ys, 193 | min_object_covered=0.05, 194 | aspect_ratio_range=(0.9, 1.1), 195 | area_range=(0.1, 1.0), 196 | max_attempts=200, 197 | scope=None): 198 | """Generates cropped_image using a one of the bboxes randomly distorted. 199 | 200 | See `tf.image.sample_distorted_bounding_box` for more documentation. 201 | 202 | Args: 203 | image: 3-D Tensor of image (it will be converted to floats in [0, 1]). 204 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] 205 | where each coordinate is [0, 1) and the coordinates are arranged 206 | as [ymin, xmin, ymax, xmax]. If num_boxes is 0 then it would use the whole 207 | image. 208 | min_object_covered: An optional `float`. Defaults to `0.1`. The cropped 209 | area of the image must contain at least this fraction of any bounding box 210 | supplied. 211 | aspect_ratio_range: An optional list of `floats`. The cropped area of the 212 | image must have an aspect ratio = width / height within this range. 213 | area_range: An optional list of `floats`. The cropped area of the image 214 | must contain a fraction of the supplied image within in this range. 215 | max_attempts: An optional `int`. Number of attempts at generating a cropped 216 | region of the image of the specified constraints. After `max_attempts` 217 | failures, return the entire image. 
218 | scope: Optional scope for name_scope. 219 | Returns: 220 | A tuple, a 3-D Tensor cropped_image and the distorted bbox 221 | """ 222 | with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bboxes, xs,ys]): 223 | # Each bounding box has shape [1, num_boxes, box coords] and 224 | # the coordinates are ordered [ymin, xmin, ymax, xmax]. 225 | bbox_begin, bbox_size, distort_bbox = tf.image.sample_distorted_bounding_box( 226 | tf.shape(image), 227 | bounding_boxes=tf.expand_dims(bboxes, 0), 228 | min_object_covered=min_object_covered, 229 | aspect_ratio_range=aspect_ratio_range, 230 | area_range=area_range, 231 | max_attempts=max_attempts, 232 | use_image_if_no_bounding_boxes=True) 233 | distort_bbox = distort_bbox[0, 0] 234 | 235 | # Crop the image to the specified bounding box. 236 | cropped_image = tf.slice(image, bbox_begin, bbox_size) 237 | # Restore the shape since the dynamic slice loses 3rd dimension. 238 | cropped_image.set_shape([None, None, 3]) 239 | 240 | # Update bounding boxes: resize and filter out. 241 | bboxes, xs, ys = tfe.bboxes_resize(distort_bbox, bboxes, xs, ys) 242 | labels, bboxes, xs, ys = tfe.bboxes_filter_overlap(labels, bboxes,xs, ys, 243 | BBOX_CROP_OVERLAP) 244 | return cropped_image, labels, bboxes,xs, ys, distort_bbox 245 | 246 | 247 | def preprocess_for_train(image, labels, bboxes, xs, ys, 248 | out_shape, data_format='NHWC', 249 | scope='ssd_preprocessing_train', clip=True, crop_area_range=AREA_RANGE): 250 | """Preprocesses the given image for training. 251 | 252 | Note that the actual resizing scale is sampled from 253 | [`resize_size_min`, `resize_size_max`]. 254 | 255 | Args: 256 | image: A `Tensor` representing an image of arbitrary size. 257 | output_height: The height of the image after preprocessing. 258 | output_width: The width of the image after preprocessing. 259 | resize_side_min: The lower bound for the smallest side of the image for 260 | aspect-preserving resizing. 261 | resize_side_max: The upper bound for the smallest side of the image for 262 | aspect-preserving resizing. 263 | 264 | Returns: 265 | A preprocessed image. 266 | """ 267 | fast_mode = False 268 | with tf.name_scope(scope, 'ssd_preprocessing_train', [image, labels, bboxes]): 269 | if image.get_shape().ndims != 3: 270 | raise ValueError('Input must be of size [height, width, C>0]') 271 | 272 | orig_dtype = image.dtype 273 | print('orig_dtype:', orig_dtype) 274 | # Convert to float scaled [0, 1]. 275 | if image.dtype != tf.float32: 276 | image = tf.image.convert_image_dtype(image, dtype=tf.float32) 277 | # tf_summary_image(image, bboxes, 'image_with_bboxes') 278 | 279 | # Distort image and bounding boxes. 280 | dst_image = image 281 | dst_image, labels, bboxes,xs, ys, distort_bbox = \ 282 | distorted_bounding_box_crop(image, labels, bboxes,xs, ys, 283 | aspect_ratio_range=CROP_RATIO_RANGE,min_object_covered=MIN_OBJECT_COVERED,area_range=crop_area_range) 284 | # Resize image to output size. 285 | dst_image = tf_image.resize_image(dst_image, out_shape, 286 | method=tf.image.ResizeMethod.BILINEAR, 287 | align_corners=False) 288 | #tf_summary_image(dst_image, bboxes, 'image_shape_distorted') 289 | 290 | # Randomly flip the image horizontally. 291 | #bboxes and xs ys all need to random 292 | 293 | dst_image, bboxes, xs, ys = tf_image.random_flip_left_right(dst_image, bboxes, xs, ys) 294 | 295 | # Randomly distort the colors. There are 4 ways to do it. 
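        # apply_with_random_selector draws sel uniformly from {0, 1, 2, 3} and
        # routes the tensor through exactly one distort_color ordering, so each
        # sample sees a single randomly chosen color pipeline per step.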
296 | dst_image = apply_with_random_selector( 297 | dst_image, 298 | lambda x, ordering: distort_color(x, ordering, fast_mode), 299 | num_cases=4) 300 | 301 | tf_summary_image(dst_image, bboxes, 'image_color_distorted') 302 | 303 | # Rescale to VGG input scale. 304 | image = tf.to_float(tf.image.convert_image_dtype(dst_image, orig_dtype, saturate=True)) 305 | # image = dst_image * 255. 306 | # image = tf_image_whitened(image, [_R_MEAN, _G_MEAN, _B_MEAN]) 307 | image = _mean_image_subtraction(image, [_R_MEAN, _G_MEAN, _B_MEAN] ) 308 | # Image data format. 309 | if data_format == 'NCHW': 310 | image = tf.transpose(image, perm=(2, 0, 1)) 311 | 312 | if clip: 313 | xy_clip_min = tf.constant([0., 0., 0., 0.]) 314 | xy_clip_max = tf.constant([1., 1., 1., 1.]) 315 | bbox_img_max = tf.constant([1., 1., 1. , 1.]) 316 | bbox_img_min = tf.constant([0., 0., 0., 0.]) 317 | 318 | bboxes = tf.minimum(bboxes, bbox_img_max) 319 | bboxes = tf.maximum(bboxes, bbox_img_min) 320 | 321 | 322 | xs = tf.maximum(xs, xy_clip_min) 323 | ys = tf.maximum(ys, xy_clip_min) 324 | xs = tf.minimum(xs, xy_clip_max) 325 | ys = tf.minimum(ys, xy_clip_max) 326 | 327 | tf_summary_image(image, bboxes, ' image whitened') 328 | # image = tf.Print(image, [image[0]], ' image: ', summarize=20) 329 | # xs = tf.Print(xs, [xs, tf.shape(xs)], ' xs ', summarize=20) 330 | # ys = tf.Print(ys, [ys, tf.shape(ys)], ' ys ', summarize=20) 331 | # bboxes = tf.Print(bboxes, [bboxes, tf.shape(bboxes)], ' bboxes ',summarize=20) 332 | return image, labels, bboxes, xs, ys 333 | 334 | 335 | 336 | 337 | def preprocess_for_eval(image, labels, bboxes,xs, ys, 338 | out_shape=EVAL_SIZE, data_format='NHWC', 339 | difficults=None, resize=Resize.WARP_RESIZE, 340 | scope='ssd_preprocessing_train'): 341 | """Preprocess an image for evaluation. 342 | 343 | Args: 344 | image: A `Tensor` representing an image of arbitrary size. 345 | out_shape: Output shape after pre-processing (if resize != None) 346 | resize: Resize strategy. 347 | 348 | Returns: 349 | A preprocessed image. 350 | """ 351 | with tf.name_scope(scope): 352 | if image.get_shape().ndims != 3: 353 | raise ValueError('Input must be of size [height, width, C>0]') 354 | 355 | image = tf.to_float(image) 356 | image = tf_image_whitened(image, [_R_MEAN, _G_MEAN, _B_MEAN]) 357 | 358 | # Add image rectangle to bboxes. 359 | bbox_img = tf.constant([[0., 0., 1., 1., 0., 0. , 0., 0., 0., 0., 0., 0.]]) 360 | if bboxes is None: 361 | bboxes = bbox_img 362 | else: 363 | bboxes = tf.concat([bbox_img, bboxes], axis=0) 364 | 365 | if resize == Resize.NONE: 366 | # No resizing... 367 | pass 368 | elif resize == Resize.CENTRAL_CROP: 369 | # Central cropping of the image. 370 | image, bboxes = tf_image.resize_image_bboxes_with_crop_or_pad( 371 | image, bboxes, out_shape[0], out_shape[1]) 372 | elif resize == Resize.PAD_AND_RESIZE: 373 | # Resize image first: find the correct factor... 374 | shape = tf.shape(image) 375 | factor = tf.minimum(tf.to_double(1.0), 376 | tf.minimum(tf.to_double(out_shape[0] / shape[0]), 377 | tf.to_double(out_shape[1] / shape[1]))) 378 | resize_shape = factor * tf.to_double(shape[0:2]) 379 | resize_shape = tf.cast(tf.floor(resize_shape), tf.int32) 380 | 381 | image = tf_image.resize_image(image, resize_shape, 382 | method=tf.image.ResizeMethod.BILINEAR, 383 | align_corners=False) 384 | # Pad to expected size. 
385 | image, bboxes = tf_image.resize_image_bboxes_with_crop_or_pad( 386 | image, bboxes, out_shape[0], out_shape[1]) 387 | elif resize == Resize.WARP_RESIZE: 388 | # Warp resize of the image. 389 | image = tf_image.resize_image(image, out_shape, 390 | method=tf.image.ResizeMethod.BILINEAR, 391 | align_corners=False) 392 | 393 | # Split back bounding boxes. 394 | bbox_img = bboxes[0] 395 | bboxes = bboxes[1:] 396 | # Remove difficult boxes. 397 | if difficults is not None: 398 | mask = tf.logical_not(tf.cast(difficults, tf.bool)) 399 | labels = tf.boolean_mask(labels, mask) 400 | bboxes = tf.boolean_mask(bboxes, mask) 401 | # Image data format. 402 | if data_format == 'NCHW': 403 | image = tf.transpose(image, perm=(2, 0, 1)) 404 | return image, labels, bboxes, bbox_img, xs, ys 405 | 406 | 407 | def preprocess_image(image, 408 | labels, 409 | bboxes, 410 | xs, ys, 411 | out_shape, 412 | data_format = 'NHWC', 413 | is_training=False, 414 | **kwargs): 415 | """Pre-process an given image. 416 | 417 | Args: 418 | image: A `Tensor` representing an image of arbitrary size. 419 | output_height: The height of the image after preprocessing. 420 | output_width: The width of the image after preprocessing. 421 | is_training: `True` if we're preprocessing the image for training and 422 | `False` otherwise. 423 | resize_side_min: The lower bound for the smallest side of the image for 424 | aspect-preserving resizing. If `is_training` is `False`, then this value 425 | is used for rescaling. 426 | resize_side_max: The upper bound for the smallest side of the image for 427 | aspect-preserving resizing. If `is_training` is `False`, this value is 428 | ignored. Otherwise, the resize side is sampled from 429 | [resize_size_min, resize_size_max]. 430 | 431 | Returns: 432 | A preprocessed image. 433 | """ 434 | if is_training: 435 | return preprocess_for_train(image, labels, bboxes,xs, ys, 436 | out_shape=out_shape, 437 | data_format=data_format) 438 | else: 439 | return preprocess_for_eval(image, labels, bboxes,xs, ys, 440 | out_shape=out_shape, 441 | data_format=data_format, 442 | **kwargs) 443 | -------------------------------------------------------------------------------- /processing/ssd_vgg_preprocessing.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/processing/ssd_vgg_preprocessing.pyc -------------------------------------------------------------------------------- /processing/tf_image.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors and Paul Balanca. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Custom image operations. 
16 | Most of the following methods extend TensorFlow image library, and part of 17 | the code is shameless copy-paste of the former! 18 | """ 19 | import tensorflow as tf 20 | 21 | from tensorflow.python.framework import constant_op 22 | from tensorflow.python.framework import dtypes 23 | from tensorflow.python.framework import ops 24 | from tensorflow.python.framework import tensor_shape 25 | from tensorflow.python.framework import tensor_util 26 | from tensorflow.python.ops import array_ops 27 | from tensorflow.python.ops import check_ops 28 | from tensorflow.python.ops import clip_ops 29 | from tensorflow.python.ops import control_flow_ops 30 | from tensorflow.python.ops import gen_image_ops 31 | from tensorflow.python.ops import gen_nn_ops 32 | from tensorflow.python.ops import string_ops 33 | from tensorflow.python.ops import math_ops 34 | from tensorflow.python.ops import random_ops 35 | from tensorflow.python.ops import variables 36 | 37 | 38 | # =========================================================================== # 39 | # Modification of TensorFlow image routines. 40 | # =========================================================================== # 41 | def _assert(cond, ex_type, msg): 42 | """A polymorphic assert, works with tensors and boolean expressions. 43 | If `cond` is not a tensor, behave like an ordinary assert statement, except 44 | that a empty list is returned. If `cond` is a tensor, return a list 45 | containing a single TensorFlow assert op. 46 | Args: 47 | cond: Something evaluates to a boolean value. May be a tensor. 48 | ex_type: The exception class to use. 49 | msg: The error message. 50 | Returns: 51 | A list, containing at most one assert op. 52 | """ 53 | if _is_tensor(cond): 54 | return [control_flow_ops.Assert(cond, [msg])] 55 | else: 56 | if not cond: 57 | raise ex_type(msg) 58 | else: 59 | return [] 60 | 61 | 62 | def _is_tensor(x): 63 | """Returns `True` if `x` is a symbolic tensor-like object. 64 | Args: 65 | x: A python object to check. 66 | Returns: 67 | `True` if `x` is a `tf.Tensor` or `tf.Variable`, otherwise `False`. 68 | """ 69 | return isinstance(x, (ops.Tensor, variables.Variable)) 70 | 71 | 72 | def _ImageDimensions(image): 73 | """Returns the dimensions of an image tensor. 74 | Args: 75 | image: A 3-D Tensor of shape `[height, width, channels]`. 76 | Returns: 77 | A list of `[height, width, channels]` corresponding to the dimensions of the 78 | input image. Dimensions that are statically known are python integers, 79 | otherwise they are integer scalar tensors. 80 | """ 81 | if image.get_shape().is_fully_defined(): 82 | return image.get_shape().as_list() 83 | else: 84 | static_shape = image.get_shape().with_rank(3).as_list() 85 | dynamic_shape = array_ops.unstack(array_ops.shape(image), 3) 86 | return [s if s is not None else d 87 | for s, d in zip(static_shape, dynamic_shape)] 88 | 89 | 90 | def _Check3DImage(image, require_static=True): 91 | """Assert that we are working with properly shaped image. 92 | Args: 93 | image: 3-D Tensor of shape [height, width, channels] 94 | require_static: If `True`, requires that all dimensions of `image` are 95 | known and non-zero. 96 | Raises: 97 | ValueError: if `image.shape` is not a 3-vector. 98 | Returns: 99 | An empty list, if `image` has fully defined dimensions. Otherwise, a list 100 | containing an assert op is returned. 
101 | """ 102 | try: 103 | image_shape = image.get_shape().with_rank(3) 104 | except ValueError: 105 | raise ValueError("'image' must be three-dimensional.") 106 | if require_static and not image_shape.is_fully_defined(): 107 | raise ValueError("'image' must be fully defined.") 108 | if any(x == 0 for x in image_shape): 109 | raise ValueError("all dims of 'image.shape' must be > 0: %s" % 110 | image_shape) 111 | if not image_shape.is_fully_defined(): 112 | return [check_ops.assert_positive(array_ops.shape(image), 113 | ["all dims of 'image.shape' " 114 | "must be > 0."])] 115 | else: 116 | return [] 117 | 118 | 119 | def fix_image_flip_shape(image, result): 120 | """Set the shape to 3 dimensional if we don't know anything else. 121 | Args: 122 | image: original image size 123 | result: flipped or transformed image 124 | Returns: 125 | An image whose shape is at least None,None,None. 126 | """ 127 | image_shape = image.get_shape() 128 | if image_shape == tensor_shape.unknown_shape(): 129 | result.set_shape([None, None, None]) 130 | else: 131 | result.set_shape(image_shape) 132 | return result 133 | 134 | 135 | # =========================================================================== # 136 | # Image + BBoxes methods: cropping, resizing, flipping, ... 137 | # =========================================================================== # 138 | def bboxes_crop_or_pad(bboxes, 139 | height, width, 140 | offset_y, offset_x, 141 | target_height, target_width): 142 | """Adapt bounding boxes to crop or pad operations. 143 | Coordinates are always supposed to be relative to the image. 144 | 145 | Arguments: 146 | bboxes: Tensor Nx4 with bboxes coordinates [y_min, x_min, y_max, x_max]; 147 | height, width: Original image dimension; 148 | offset_y, offset_x: Offset to apply, 149 | negative if cropping, positive if padding; 150 | target_height, target_width: Target dimension after cropping / padding. 151 | """ 152 | with tf.name_scope('bboxes_crop_or_pad'): 153 | # Rescale bounding boxes in pixels. 154 | scale = tf.cast(tf.stack([height, width, height, width]), bboxes.dtype) 155 | bboxes = bboxes * scale 156 | # Add offset. 157 | offset = tf.cast(tf.stack([offset_y, offset_x, offset_y, offset_x]), bboxes.dtype) 158 | bboxes = bboxes + offset 159 | # Rescale to target dimension. 160 | scale = tf.cast(tf.stack([target_height, target_width, 161 | target_height, target_width]), bboxes.dtype) 162 | bboxes = bboxes / scale 163 | return bboxes 164 | 165 | 166 | def resize_image_bboxes_with_crop_or_pad(image, bboxes, 167 | target_height, target_width): 168 | """Crops and/or pads an image to a target width and height. 169 | Resizes an image to a target width and height by either centrally 170 | cropping the image or padding it evenly with zeros. 171 | 172 | If `width` or `height` is greater than the specified `target_width` or 173 | `target_height` respectively, this op centrally crops along that dimension. 174 | If `width` or `height` is smaller than the specified `target_width` or 175 | `target_height` respectively, this op centrally pads with 0 along that 176 | dimension. 177 | Args: 178 | image: 3-D tensor of shape `[height, width, channels]` 179 | target_height: Target height. 180 | target_width: Target width. 181 | Raises: 182 | ValueError: if `target_height` or `target_width` are zero or negative. 
183 | Returns: 184 | Cropped and/or padded image of shape 185 | `[target_height, target_width, channels]` 186 | """ 187 | with tf.name_scope('resize_with_crop_or_pad'): 188 | image = ops.convert_to_tensor(image, name='image') 189 | 190 | assert_ops = [] 191 | assert_ops += _Check3DImage(image, require_static=False) 192 | assert_ops += _assert(target_width > 0, ValueError, 193 | 'target_width must be > 0.') 194 | assert_ops += _assert(target_height > 0, ValueError, 195 | 'target_height must be > 0.') 196 | 197 | image = control_flow_ops.with_dependencies(assert_ops, image) 198 | # `crop_to_bounding_box` and `pad_to_bounding_box` have their own checks. 199 | # Make sure our checks come first, so that error messages are clearer. 200 | if _is_tensor(target_height): 201 | target_height = control_flow_ops.with_dependencies( 202 | assert_ops, target_height) 203 | if _is_tensor(target_width): 204 | target_width = control_flow_ops.with_dependencies(assert_ops, target_width) 205 | 206 | def max_(x, y): 207 | if _is_tensor(x) or _is_tensor(y): 208 | return math_ops.maximum(x, y) 209 | else: 210 | return max(x, y) 211 | 212 | def min_(x, y): 213 | if _is_tensor(x) or _is_tensor(y): 214 | return math_ops.minimum(x, y) 215 | else: 216 | return min(x, y) 217 | 218 | def equal_(x, y): 219 | if _is_tensor(x) or _is_tensor(y): 220 | return math_ops.equal(x, y) 221 | else: 222 | return x == y 223 | 224 | height, width, _ = _ImageDimensions(image) 225 | width_diff = target_width - width 226 | offset_crop_width = max_(-width_diff // 2, 0) 227 | offset_pad_width = max_(width_diff // 2, 0) 228 | 229 | height_diff = target_height - height 230 | offset_crop_height = max_(-height_diff // 2, 0) 231 | offset_pad_height = max_(height_diff // 2, 0) 232 | 233 | # Maybe crop if needed. 234 | height_crop = min_(target_height, height) 235 | width_crop = min_(target_width, width) 236 | cropped = tf.image.crop_to_bounding_box(image, offset_crop_height, offset_crop_width, 237 | height_crop, width_crop) 238 | bboxes = bboxes_crop_or_pad(bboxes, 239 | height, width, 240 | -offset_crop_height, -offset_crop_width, 241 | height_crop, width_crop) 242 | # Maybe pad if needed. 243 | resized = tf.image.pad_to_bounding_box(cropped, offset_pad_height, offset_pad_width, 244 | target_height, target_width) 245 | bboxes = bboxes_crop_or_pad(bboxes, 246 | height_crop, width_crop, 247 | offset_pad_height, offset_pad_width, 248 | target_height, target_width) 249 | 250 | # In theory all the checks below are redundant. 251 | if resized.get_shape().ndims is None: 252 | raise ValueError('resized contains no shape.') 253 | 254 | resized_height, resized_width, _ = _ImageDimensions(resized) 255 | 256 | assert_ops = [] 257 | assert_ops += _assert(equal_(resized_height, target_height), ValueError, 258 | 'resized height is not correct.') 259 | assert_ops += _assert(equal_(resized_width, target_width), ValueError, 260 | 'resized width is not correct.') 261 | 262 | resized = control_flow_ops.with_dependencies(assert_ops, resized) 263 | return resized, bboxes 264 | 265 | 266 | def resize_image(image, size, 267 | method=tf.image.ResizeMethod.BILINEAR, 268 | align_corners=False): 269 | """Resize an image and bounding boxes. 270 | """ 271 | # Resize image. 
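    # The expand_dims / reshape pair below temporarily adds a batch dimension
    # for tf.image.resize_images and then restores a 3-D tensor whose height
    # and width are statically known (size[0], size[1]). Relative bbox
    # coordinates are unchanged by a pure warp resize, so only the image needs
    # touching here.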
272 | with tf.name_scope('resize_image'): 273 | height, width, channels = _ImageDimensions(image) 274 | image = tf.expand_dims(image, 0) 275 | image = tf.image.resize_images(image, size, 276 | method, align_corners) 277 | image = tf.reshape(image, tf.stack([size[0], size[1], channels])) 278 | return image 279 | 280 | 281 | def random_flip_left_right(image, bboxes, xs, ys, seed=None): 282 | """Random flip left-right of an image and its bounding boxes. 283 | """ 284 | def flip_bboxes(bboxes): 285 | """Flip bounding boxes coordinates. 286 | """ 287 | bboxes = tf.stack([bboxes[:, 0], 1 - bboxes[:, 3], 288 | bboxes[:, 2], 1 - bboxes[:, 1]], axis=-1) 289 | return bboxes 290 | 291 | def flip_xs(xs): 292 | """Flip xs coordinates 293 | """ 294 | # xs_temp = tf.ones(xs.shape) 295 | # xs_temp[:, 0] = 1 - xs[:, 1] 296 | # xs_temp[:, 1] = 1 - xs[:, 0] 297 | # xs_temp[:, 2] = 1 - xs[:, 3] 298 | # xs_temp[:, 3] = 1 - xs[:, 2] 299 | 300 | xs = tf.stack([1 - xs[:, 1], 1 - xs[:, 0], 1 - xs[ :, 3], 1 - xs[ :, 2]], axis=-1) 301 | return xs 302 | 303 | def flip_ys(ys): 304 | """Flip ys coordinates 305 | """ 306 | # ys_temp = tf.ones(ys.shape) 307 | # ys_temp[:, 0] = ys[: ,1] 308 | # ys_temp[:, 1] = ys[:, 0] 309 | # ys_temp[:, 2] = ys[:, 3] 310 | # ys_temp[:, 3] = ys[:, 2] 311 | # return ys_temp 312 | ys = tf.stack([ys[:, 1], ys[: ,0], ys[: ,3], ys[:, 2]], axis=-1) 313 | return ys 314 | 315 | 316 | # Random flip. Tensorflow implementation. 317 | with tf.name_scope('random_flip_left_right'): 318 | image = ops.convert_to_tensor(image, name='image') 319 | _Check3DImage(image, require_static=False) 320 | uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed) 321 | mirror_cond = math_ops.less(uniform_random, .5) 322 | # Flip image. 323 | result = control_flow_ops.cond(mirror_cond, 324 | lambda: array_ops.reverse_v2(image, [1]), 325 | lambda: image) 326 | # Flip bboxes. 327 | bboxes = control_flow_ops.cond(mirror_cond, 328 | lambda: flip_bboxes(bboxes), 329 | lambda: bboxes) 330 | 331 | xs = control_flow_ops.cond(mirror_cond, lambda: flip_xs(xs), lambda: xs) 332 | ys = control_flow_ops.cond(mirror_cond, lambda: flip_ys(ys), lambda: ys) 333 | return fix_image_flip_shape(image, result), bboxes, xs, ys 334 | 335 | -------------------------------------------------------------------------------- /processing/tf_image.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/processing/tf_image.pyc -------------------------------------------------------------------------------- /tf_extended/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Paul Balanca. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """TF Extended: additional metrics. 
16 | """ 17 | 18 | # pylint: disable=unused-import,line-too-long,g-importing-member,wildcard-import 19 | from tf_extended.metrics import * 20 | from tf_extended.tensors import * 21 | from tf_extended.bboxes import * 22 | from tf_extended.image import * 23 | from tf_extended.math import * 24 | 25 | -------------------------------------------------------------------------------- /tf_extended/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/tf_extended/__init__.pyc -------------------------------------------------------------------------------- /tf_extended/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/tf_extended/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /tf_extended/__pycache__/bboxes.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/tf_extended/__pycache__/bboxes.cpython-35.pyc -------------------------------------------------------------------------------- /tf_extended/__pycache__/image.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/tf_extended/__pycache__/image.cpython-35.pyc -------------------------------------------------------------------------------- /tf_extended/__pycache__/math.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/tf_extended/__pycache__/math.cpython-35.pyc -------------------------------------------------------------------------------- /tf_extended/__pycache__/metrics.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/tf_extended/__pycache__/metrics.cpython-35.pyc -------------------------------------------------------------------------------- /tf_extended/__pycache__/tensors.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/tf_extended/__pycache__/tensors.cpython-35.pyc -------------------------------------------------------------------------------- /tf_extended/bboxes.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/tf_extended/bboxes.pyc -------------------------------------------------------------------------------- /tf_extended/image.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/tf_extended/image.py 
--------------------------------------------------------------------------------
/tf_extended/image.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/tf_extended/image.pyc
--------------------------------------------------------------------------------
/tf_extended/math.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Paul Balanca. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """TF Extended: additional math functions.
16 | """
17 | import tensorflow as tf
18 | 
19 | from tensorflow.python.ops import array_ops
20 | from tensorflow.python.ops import math_ops
21 | from tensorflow.python.framework import dtypes
22 | from tensorflow.python.framework import ops
23 | 
24 | 
25 | def safe_divide(numerator, denominator, name):
26 |     """Divides two values, returning 0 if the denominator is <= 0.
27 |     Args:
28 |       numerator: A real `Tensor`.
29 |       denominator: A real `Tensor`, with dtype matching `numerator`.
30 |       name: Name for the returned op.
31 |     Returns:
32 |       0 if `denominator` <= 0, else `numerator` / `denominator`
33 |     """
34 |     return tf.where(
35 |         math_ops.greater(denominator, 0),
36 |         math_ops.divide(numerator, denominator),
37 |         tf.zeros_like(numerator),
38 |         name=name)
39 | 
40 | 
41 | def cummax(x, reverse=False, name=None):
42 |     """Compute the cumulative maximum of the tensor `x` along its first axis.
43 |     This operation is similar to the more classic `cumsum`. Only supports 1D
44 |     Tensors for now.
45 | 
46 |     Args:
47 |       x: A `Tensor`. Must be one of the following types: `float32`, `float64`,
48 |         `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`, `complex64`,
49 |         `complex128`, `qint8`, `quint8`, `qint32`, `half`.
50 |       reverse: A `bool` (default: False). If True, the running maximum is
51 |         taken from the end of the tensor towards the start.
52 |       name: A name for the operation (optional).
53 |     Returns:
54 |       A `Tensor`. Has the same type as `x`.
55 |     """
56 |     with ops.name_scope(name, "Cummax", [x]) as name:
57 |         x = ops.convert_to_tensor(x, name="x")
58 |         # Not very optimal: should directly integrate reverse into tf.scan.
59 |         if reverse:
60 |             x = tf.reverse(x, axis=[0])
61 |         # 'Accumulating' maximum: ensure it is always increasing.
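        # Illustrative example: for x = [1., 3., 2.] the scan below yields
        # [1., 3., 3.]; with reverse=True the input is flipped to [2., 3., 1.],
        # scanned to [2., 3., 3.] and flipped back, giving [3., 3., 2.] -- the
        # running maximum taken from the end of the tensor.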
62 | cmax = tf.scan(lambda a, y: tf.maximum(a, y), x, 63 | initializer=None, parallel_iterations=1, 64 | back_prop=False, swap_memory=False) 65 | if reverse: 66 | cmax = tf.reverse(cmax, axis=[0]) 67 | return cmax 68 | -------------------------------------------------------------------------------- /tf_extended/math.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/tf_extended/math.pyc -------------------------------------------------------------------------------- /tf_extended/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Paul Balanca. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """TF Extended: additional metrics. 16 | """ 17 | import tensorflow as tf 18 | import numpy as np 19 | 20 | from tensorflow.contrib.framework.python.ops import variables as contrib_variables 21 | from tensorflow.python.framework import dtypes 22 | from tensorflow.python.framework import ops 23 | from tensorflow.python.ops import array_ops 24 | from tensorflow.python.ops import math_ops 25 | from tensorflow.python.ops import nn 26 | from tensorflow.python.ops import state_ops 27 | from tensorflow.python.ops import variable_scope 28 | from tensorflow.python.ops import variables 29 | 30 | from tf_extended import math as tfe_math 31 | 32 | 33 | # =========================================================================== # 34 | # TensorFlow utils 35 | # =========================================================================== # 36 | def _create_local(name, shape, collections=None, validate_shape=True, 37 | dtype=dtypes.float32): 38 | """Creates a new local variable. 39 | Args: 40 | name: The name of the new or existing variable. 41 | shape: Shape of the new or existing variable. 42 | collections: A list of collection names to which the Variable will be added. 43 | validate_shape: Whether to validate the shape of the variable. 44 | dtype: Data type of the variables. 45 | Returns: 46 | The created variable. 47 | """ 48 | # Make sure local variables are added to tf.GraphKeys.LOCAL_VARIABLES 49 | collections = list(collections or []) 50 | collections += [ops.GraphKeys.LOCAL_VARIABLES] 51 | return variables.Variable( 52 | initial_value=array_ops.zeros(shape, dtype=dtype), 53 | name=name, 54 | trainable=False, 55 | collections=collections, 56 | validate_shape=validate_shape) 57 | 58 | 59 | def _safe_div(numerator, denominator, name): 60 | """Divides two values, returning 0 if the denominator is <= 0. 61 | Args: 62 | numerator: A real `Tensor`. 63 | denominator: A real `Tensor`, with dtype matching `numerator`. 64 | name: Name for the returned op. 
65 |     Returns:
66 |       0 if `denominator` <= 0, else `numerator` / `denominator`
67 |     """
68 |     return tf.where(
69 |         math_ops.greater(denominator, 0),
70 |         math_ops.divide(numerator, denominator),
71 |         tf.zeros_like(numerator),
72 |         name=name)
73 | 
74 | 
75 | def _broadcast_weights(weights, values):
76 |     """Broadcast `weights` to the same shape as `values`.
77 |     This returns a version of `weights` following the same broadcast rules as
78 |     `mul(weights, values)`. When computing a weighted average, use this function
79 |     to broadcast `weights` before summing them; e.g.,
80 |     `reduce_sum(w * v) / reduce_sum(_broadcast_weights(w, v))`.
81 |     Args:
82 |       weights: `Tensor` whose shape is broadcastable to `values`.
83 |       values: `Tensor` of any shape.
84 |     Returns:
85 |       `weights` broadcast to `values` shape.
86 |     """
87 |     weights_shape = weights.get_shape()
88 |     values_shape = values.get_shape()
89 |     if (weights_shape.is_fully_defined() and
90 |             values_shape.is_fully_defined() and
91 |             weights_shape.is_compatible_with(values_shape)):
92 |         return weights
93 |     return math_ops.multiply(  # `mul` was renamed to `multiply` in TF 1.x.
94 |         weights, array_ops.ones_like(values), name='broadcast_weights')
95 | 
96 | 
97 | # =========================================================================== #
98 | # TF Extended metrics: TP and FP arrays.
99 | # =========================================================================== #
100 | def precision_recall(num_gbboxes, num_detections, tp, fp, scores,
101 |                      dtype=tf.float64, scope=None):
102 |     """Compute precision and recall from scores and boolean true-positive
103 |     and false-positive arrays.
104 |     """
105 |     # Input dictionaries: dict outputs as streaming metrics.
106 |     if isinstance(scores, dict):
107 |         d_precision = {}
108 |         d_recall = {}
109 |         for c in num_gbboxes.keys():
110 |             scope = 'precision_recall_%s' % c
111 |             p, r = precision_recall(num_gbboxes[c], num_detections[c],
112 |                                     tp[c], fp[c], scores[c],
113 |                                     dtype, scope)
114 |             d_precision[c] = p
115 |             d_recall[c] = r
116 |         return d_precision, d_recall
117 | 
118 |     # Sort by score.
119 |     with tf.name_scope(scope, 'precision_recall',
120 |                        [num_gbboxes, num_detections, tp, fp, scores]):
121 |         # Sort detections by score.
122 |         scores, idxes = tf.nn.top_k(scores, k=num_detections, sorted=True)
123 |         tp = tf.gather(tp, idxes)
124 |         fp = tf.gather(fp, idxes)
125 |         # Compute recall and precision.
126 |         tp = tf.cumsum(tf.cast(tp, dtype), axis=0)
127 |         fp = tf.cumsum(tf.cast(fp, dtype), axis=0)
128 |         recall = _safe_div(tp, tf.cast(num_gbboxes, dtype), 'recall')
129 |         precision = _safe_div(tp, tp + fp, 'precision')
130 |         return tf.tuple([precision, recall])
131 | 
132 | 
133 | def streaming_tp_fp_arrays(num_gbboxes, tp, fp, scores,
134 |                            remove_zero_scores=True,
135 |                            metrics_collections=None,
136 |                            updates_collections=None,
137 |                            name=None):
138 |     """Streaming computation of True and False Positive arrays. This metric
139 |     also keeps track of scores and of the number of groundtruth objects.
140 |     """
141 |     # Input dictionaries: dict outputs as streaming metrics.
142 |     if isinstance(scores, dict) or isinstance(fp, dict):
143 |         d_values = {}
144 |         d_update_ops = {}
145 |         for c in num_gbboxes.keys():
146 |             scope = 'streaming_tp_fp_%s' % c
147 |             v, up = streaming_tp_fp_arrays(num_gbboxes[c], tp[c], fp[c], scores[c],
148 |                                            remove_zero_scores,
149 |                                            metrics_collections,
150 |                                            updates_collections,
151 |                                            name=scope)
152 |             d_values[c] = v
153 |             d_update_ops[c] = up
154 |         return d_values, d_update_ops
155 | 
156 |     # Input Tensors...
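    # Typical streaming usage (a sketch, not from the original source): run the
    # update op once per evaluation batch, then read the value tuple
    # (n_gbboxes, n_detections, tp, fp, scores) at the end, e.g.
    #     val, update = streaming_tp_fp_arrays(num_gbboxes, tp, fp, scores)
    #     for _ in range(num_batches):
    #         sess.run(update)
    #     n_gt, n_det, v_tp, v_fp, v_scores = sess.run(val)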
157 | with variable_scope.variable_scope(name, 'streaming_tp_fp', 158 | [num_gbboxes, tp, fp, scores]): 159 | num_gbboxes = math_ops.to_int64(num_gbboxes) 160 | scores = math_ops.to_float(scores) 161 | stype = tf.bool 162 | tp = tf.cast(tp, stype) 163 | fp = tf.cast(fp, stype) 164 | # Reshape TP and FP tensors and clean away 0 class values. 165 | scores = tf.reshape(scores, [-1]) 166 | tp = tf.reshape(tp, [-1]) 167 | fp = tf.reshape(fp, [-1]) 168 | # Remove TP and FP both false. 169 | mask = tf.logical_or(tp, fp) 170 | if remove_zero_scores: 171 | rm_threshold = 1e-4 172 | mask = tf.logical_and(mask, tf.greater(scores, rm_threshold)) 173 | scores = tf.boolean_mask(scores, mask) 174 | tp = tf.boolean_mask(tp, mask) 175 | fp = tf.boolean_mask(fp, mask) 176 | 177 | # Local variables accumlating information over batches. 178 | v_nobjects = _create_local('v_num_gbboxes', shape=[], dtype=tf.int64) 179 | v_ndetections = _create_local('v_num_detections', shape=[], dtype=tf.int32) 180 | v_scores = _create_local('v_scores', shape=[0, ]) 181 | v_tp = _create_local('v_tp', shape=[0, ], dtype=stype) 182 | v_fp = _create_local('v_fp', shape=[0, ], dtype=stype) 183 | 184 | # Update operations. 185 | nobjects_op = state_ops.assign_add(v_nobjects, 186 | tf.reduce_sum(num_gbboxes)) 187 | ndetections_op = state_ops.assign_add(v_ndetections, 188 | tf.size(scores, out_type=tf.int32)) 189 | scores_op = state_ops.assign(v_scores, tf.concat([v_scores, scores], axis=0), 190 | validate_shape=False) 191 | tp_op = state_ops.assign(v_tp, tf.concat([v_tp, tp], axis=0), 192 | validate_shape=False) 193 | fp_op = state_ops.assign(v_fp, tf.concat([v_fp, fp], axis=0), 194 | validate_shape=False) 195 | 196 | # Value and update ops. 197 | val = (v_nobjects, v_ndetections, v_tp, v_fp, v_scores) 198 | with ops.control_dependencies([nobjects_op, ndetections_op, 199 | scores_op, tp_op, fp_op]): 200 | update_op = (nobjects_op, ndetections_op, tp_op, fp_op, scores_op) 201 | 202 | if metrics_collections: 203 | ops.add_to_collections(metrics_collections, val) 204 | if updates_collections: 205 | ops.add_to_collections(updates_collections, update_op) 206 | return val, update_op 207 | 208 | 209 | # =========================================================================== # 210 | # Average precision computations. 211 | # =========================================================================== # 212 | def average_precision_voc12(precision, recall, name=None): 213 | """Compute (interpolated) average precision from precision and recall Tensors. 214 | 215 | The implementation follows Pascal 2012 and ILSVRC guidelines. 216 | See also: https://sanchom.wordpress.com/tag/average-precision/ 217 | """ 218 | with tf.name_scope(name, 'average_precision_voc12', [precision, recall]): 219 | # Convert to float64 to decrease error on Riemann sums. 220 | precision = tf.cast(precision, dtype=tf.float64) 221 | recall = tf.cast(recall, dtype=tf.float64) 222 | 223 | # Add bounds values to precision and recall. 224 | precision = tf.concat([[0.], precision, [0.]], axis=0) 225 | recall = tf.concat([[0.], recall, [1.]], axis=0) 226 | # Ensures precision is increasing in reverse order. 227 | precision = tfe_math.cummax(precision, reverse=True) 228 | 229 | # Riemann sums for estimating the integral. 230 | # mean_pre = (precision[1:] + precision[:-1]) / 2. 
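        # Illustrative example: precision = [1., .5], recall = [.5, 1.] become
        # [0., 1., .5, 0.] and [0., .5, 1., 1.] after padding; the reverse
        # cummax turns precision into [1., 1., .5, 0.], so mean_pre =
        # [1., .5, 0.], diff_rec = [.5, .5, 0.] and ap = 1*.5 + .5*.5 = 0.75 --
        # a right-endpoint Riemann sum over the interpolated curve.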
231 | mean_pre = precision[1:] 232 | diff_rec = recall[1:] - recall[:-1] 233 | ap = tf.reduce_sum(mean_pre * diff_rec) 234 | return ap 235 | 236 | 237 | def average_precision_voc07(precision, recall, name=None): 238 | """Compute (interpolated) average precision from precision and recall Tensors. 239 | 240 | The implementation follows Pascal 2007 guidelines. 241 | See also: https://sanchom.wordpress.com/tag/average-precision/ 242 | """ 243 | with tf.name_scope(name, 'average_precision_voc07', [precision, recall]): 244 | # Convert to float64 to decrease error on cumulated sums. 245 | precision = tf.cast(precision, dtype=tf.float64) 246 | recall = tf.cast(recall, dtype=tf.float64) 247 | # Add zero-limit value to avoid any boundary problem... 248 | precision = tf.concat([precision, [0.]], axis=0) 249 | recall = tf.concat([recall, [np.inf]], axis=0) 250 | 251 | # Split the integral into 10 bins. 252 | l_aps = [] 253 | for t in np.arange(0., 1.1, 0.1): 254 | mask = tf.greater_equal(recall, t) 255 | v = tf.reduce_max(tf.boolean_mask(precision, mask)) 256 | l_aps.append(v / 11.) 257 | ap = tf.add_n(l_aps) 258 | return ap 259 | 260 | 261 | def precision_recall_values(xvals, precision, recall, name=None): 262 | """Compute values on the precision/recall curve. 263 | 264 | Args: 265 | x: Python list of floats; 266 | precision: 1D Tensor decreasing. 267 | recall: 1D Tensor increasing. 268 | Return: 269 | list of precision values. 270 | """ 271 | with ops.name_scope(name, "precision_recall_values", 272 | [precision, recall]) as name: 273 | # Add bounds values to precision and recall. 274 | precision = tf.concat([[0.], precision, [0.]], axis=0) 275 | recall = tf.concat([[0.], recall, [1.]], axis=0) 276 | precision = tfe_math.cummax(precision, reverse=True) 277 | 278 | prec_values = [] 279 | for x in xvals: 280 | mask = tf.less_equal(recall, x) 281 | val = tf.reduce_min(tf.boolean_mask(precision, mask)) 282 | prec_values.append(val) 283 | return tf.tuple(prec_values) 284 | 285 | 286 | # =========================================================================== # 287 | # TF Extended metrics: old stuff! 288 | # =========================================================================== # 289 | def _precision_recall(n_gbboxes, n_detections, scores, tp, fp, scope=None): 290 | """Compute precision and recall from scores, true positives and false 291 | positives booleans arrays 292 | """ 293 | # Sort by score. 294 | with tf.name_scope(scope, 'prec_rec', [n_gbboxes, scores, tp, fp]): 295 | # Sort detections by score. 296 | scores, idxes = tf.nn.top_k(scores, k=n_detections, sorted=True) 297 | tp = tf.gather(tp, idxes) 298 | fp = tf.gather(fp, idxes) 299 | # Computer recall and precision. 300 | dtype = tf.float64 301 | tp = tf.cumsum(tf.cast(tp, dtype), axis=0) 302 | fp = tf.cumsum(tf.cast(fp, dtype), axis=0) 303 | recall = _safe_div(tp, tf.cast(n_gbboxes, dtype), 'recall') 304 | precision = _safe_div(tp, tp + fp, 'precision') 305 | 306 | return tf.tuple([precision, recall]) 307 | 308 | 309 | def streaming_precision_recall_arrays(n_gbboxes, rclasses, rscores, 310 | tp_tensor, fp_tensor, 311 | remove_zero_labels=True, 312 | metrics_collections=None, 313 | updates_collections=None, 314 | name=None): 315 | """Streaming computation of precision / recall arrays. This metrics 316 | keeps tracks of boolean True positives and False positives arrays. 
317 | """ 318 | with variable_scope.variable_scope(name, 'stream_precision_recall', 319 | [n_gbboxes, rclasses, tp_tensor, fp_tensor]): 320 | n_gbboxes = math_ops.to_int64(n_gbboxes) 321 | rclasses = math_ops.to_int64(rclasses) 322 | rscores = math_ops.to_float(rscores) 323 | 324 | stype = tf.int32 325 | tp_tensor = tf.cast(tp_tensor, stype) 326 | fp_tensor = tf.cast(fp_tensor, stype) 327 | 328 | # Reshape TP and FP tensors and clean away 0 class values. 329 | rclasses = tf.reshape(rclasses, [-1]) 330 | rscores = tf.reshape(rscores, [-1]) 331 | tp_tensor = tf.reshape(tp_tensor, [-1]) 332 | fp_tensor = tf.reshape(fp_tensor, [-1]) 333 | if remove_zero_labels: 334 | mask = tf.greater(rclasses, 0) 335 | rclasses = tf.boolean_mask(rclasses, mask) 336 | rscores = tf.boolean_mask(rscores, mask) 337 | tp_tensor = tf.boolean_mask(tp_tensor, mask) 338 | fp_tensor = tf.boolean_mask(fp_tensor, mask) 339 | 340 | # Local variables accumlating information over batches. 341 | v_nobjects = _create_local('v_nobjects', shape=[], dtype=tf.int64) 342 | v_ndetections = _create_local('v_ndetections', shape=[], dtype=tf.int32) 343 | v_scores = _create_local('v_scores', shape=[0, ]) 344 | v_tp = _create_local('v_tp', shape=[0, ], dtype=stype) 345 | v_fp = _create_local('v_fp', shape=[0, ], dtype=stype) 346 | 347 | # Update operations. 348 | nobjects_op = state_ops.assign_add(v_nobjects, 349 | tf.reduce_sum(n_gbboxes)) 350 | ndetections_op = state_ops.assign_add(v_ndetections, 351 | tf.size(rscores, out_type=tf.int32)) 352 | scores_op = state_ops.assign(v_scores, tf.concat([v_scores, rscores], axis=0), 353 | validate_shape=False) 354 | tp_op = state_ops.assign(v_tp, tf.concat([v_tp, tp_tensor], axis=0), 355 | validate_shape=False) 356 | fp_op = state_ops.assign(v_fp, tf.concat([v_fp, fp_tensor], axis=0), 357 | validate_shape=False) 358 | 359 | # Precision and recall computations. 360 | # r = _precision_recall(nobjects_op, scores_op, tp_op, fp_op, 'value') 361 | r = _precision_recall(v_nobjects, v_ndetections, v_scores, 362 | v_tp, v_fp, 'value') 363 | 364 | with ops.control_dependencies([nobjects_op, ndetections_op, 365 | scores_op, tp_op, fp_op]): 366 | update_op = _precision_recall(nobjects_op, ndetections_op, 367 | scores_op, tp_op, fp_op, 'update_op') 368 | 369 | # update_op = tf.Print(update_op, 370 | # [tf.reduce_sum(tf.cast(mask, tf.int64)), 371 | # tf.reduce_sum(tf.cast(mask2, tf.int64)), 372 | # tf.reduce_min(rscores), 373 | # tf.reduce_sum(n_gbboxes)], 374 | # 'Metric: ') 375 | # Some debugging stuff! 
376 | # update_op = tf.Print(update_op, 377 | # [tf.shape(tp_op), 378 | # tf.reduce_sum(tf.cast(tp_op, tf.int64), axis=0)], 379 | # 'TP and FP shape: ') 380 | # update_op[0] = tf.Print(update_op, 381 | # [nobjects_op], 382 | # '# Groundtruth bboxes: ') 383 | # update_op = tf.Print(update_op, 384 | # [update_op[0][0], 385 | # update_op[0][-1], 386 | # tf.reduce_min(update_op[0]), 387 | # tf.reduce_max(update_op[0]), 388 | # tf.reduce_min(update_op[1]), 389 | # tf.reduce_max(update_op[1])], 390 | # 'Precision and recall :') 391 | 392 | if metrics_collections: 393 | ops.add_to_collections(metrics_collections, r) 394 | if updates_collections: 395 | ops.add_to_collections(updates_collections, update_op) 396 | return r, update_op 397 | 398 | -------------------------------------------------------------------------------- /tf_extended/metrics.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/tf_extended/metrics.pyc -------------------------------------------------------------------------------- /tf_extended/tensors.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Paul Balanca. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """TF Extended: additional tensors operations. 16 | """ 17 | import tensorflow as tf 18 | 19 | from tensorflow.contrib.framework.python.ops import variables as contrib_variables 20 | from tensorflow.contrib.metrics.python.ops import set_ops 21 | from tensorflow.python.framework import dtypes 22 | from tensorflow.python.framework import ops 23 | from tensorflow.python.framework import sparse_tensor 24 | from tensorflow.python.ops import array_ops 25 | from tensorflow.python.ops import check_ops 26 | from tensorflow.python.ops import control_flow_ops 27 | from tensorflow.python.ops import math_ops 28 | from tensorflow.python.ops import nn 29 | from tensorflow.python.ops import state_ops 30 | from tensorflow.python.ops import variable_scope 31 | from tensorflow.python.ops import variables 32 | 33 | 34 | def get_shape(x, rank=None): 35 | """Returns the dimensions of a Tensor as list of integers or scale tensors. 36 | 37 | Args: 38 | x: N-d Tensor; 39 | rank: Rank of the Tensor. If None, will try to guess it. 40 | Returns: 41 | A list of `[d1, d2, ..., dN]` corresponding to the dimensions of the 42 | input tensor. Dimensions that are statically known are python integers, 43 | otherwise they are integer scalar tensors. 
44 | """ 45 | if x.get_shape().is_fully_defined(): 46 | return x.get_shape().as_list() 47 | else: 48 | static_shape = x.get_shape() 49 | if rank is None: 50 | static_shape = static_shape.as_list() 51 | rank = len(static_shape) 52 | else: 53 | static_shape = x.get_shape().with_rank(rank).as_list() 54 | dynamic_shape = tf.unstack(tf.shape(x), rank) 55 | return [s if s is not None else d 56 | for s, d in zip(static_shape, dynamic_shape)] 57 | 58 | 59 | def pad_axis(x, offset, size, axis=0, name=None): 60 | """Pad a tensor on an axis, with a given offset and output size. 61 | The tensor is padded with zero (i.e. CONSTANT mode). Note that the if the 62 | `size` is smaller than existing size + `offset`, the output tensor 63 | was the latter dimension. 64 | 65 | Args: 66 | x: Tensor to pad; 67 | offset: Offset to add on the dimension chosen; 68 | size: Final size of the dimension. 69 | Return: 70 | Padded tensor whose dimension on `axis` is `size`, or greater if 71 | the input vector was larger. 72 | """ 73 | with tf.name_scope(name, 'pad_axis'): 74 | shape = get_shape(x) 75 | rank = len(shape) 76 | # Padding description. 77 | new_size = tf.maximum(size-offset-shape[axis], 0) 78 | pad1 = tf.stack([0]*axis + [offset] + [0]*(rank-axis-1)) 79 | pad2 = tf.stack([0]*axis + [new_size] + [0]*(rank-axis-1)) 80 | paddings = tf.stack([pad1, pad2], axis=1) 81 | x = tf.pad(x, paddings, mode='CONSTANT') 82 | # Reshape, to get fully defined shape if possible. 83 | # TODO: fix with tf.slice 84 | shape[axis] = size 85 | x = tf.reshape(x, tf.stack(shape)) 86 | return x 87 | 88 | 89 | # def select_at_index(idx, val, t): 90 | # """Return a tensor. 91 | # """ 92 | # idx = tf.expand_dims(tf.expand_dims(idx, 0), 0) 93 | # val = tf.expand_dims(val, 0) 94 | # t = t + tf.scatter_nd(idx, val, tf.shape(t)) 95 | # return t 96 | -------------------------------------------------------------------------------- /tf_extended/tensors.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/tf_extended/tensors.pyc -------------------------------------------------------------------------------- /tf_extended/tf_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Paul Balanca. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Diverse TensorFlow utils, for training, evaluation and so on! 
16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | import os 21 | from pprint import pprint 22 | 23 | import tensorflow as tf 24 | from tensorflow.contrib.slim.python.slim.data import parallel_reader 25 | 26 | slim = tf.contrib.slim 27 | 28 | 29 | # =========================================================================== # 30 | # General tools. 31 | # =========================================================================== # 32 | def reshape_list(l, shape=None): 33 | """Reshape list of (list): 1D to 2D or the other way around. 34 | 35 | Args: 36 | l: List or List of list. 37 | shape: 1D or 2D shape. 38 | Return 39 | Reshaped list. 40 | """ 41 | r = [] 42 | if shape is None: 43 | # Flatten everything. 44 | for a in l: 45 | if isinstance(a, (list, tuple)): 46 | r = r + list(a) 47 | else: 48 | r.append(a) 49 | else: 50 | # Reshape to list of list. 51 | i = 0 52 | for s in shape: 53 | if s == 1: 54 | r.append(l[i]) 55 | else: 56 | r.append(l[i:i+s]) 57 | i += s 58 | return r 59 | 60 | 61 | # =========================================================================== # 62 | # Training utils. 63 | # =========================================================================== # 64 | def print_configuration(flags, ssd_params, data_sources, save_dir=None): 65 | """Print the training configuration. 66 | """ 67 | def print_config(stream=None): 68 | print('\n# =========================================================================== #', file=stream) 69 | print('# Training | Evaluation flags:', file=stream) 70 | print('# =========================================================================== #', file=stream) 71 | pprint(flags, stream=stream) 72 | 73 | print('\n# =========================================================================== #', file=stream) 74 | print('# SSD net parameters:', file=stream) 75 | print('# =========================================================================== #', file=stream) 76 | pprint(dict(ssd_params._asdict()), stream=stream) 77 | 78 | print('\n# =========================================================================== #', file=stream) 79 | print('# Training | Evaluation dataset files:', file=stream) 80 | print('# =========================================================================== #', file=stream) 81 | data_files = parallel_reader.get_data_files(data_sources) 82 | pprint(data_files, stream=stream) 83 | print('', file=stream) 84 | 85 | print_config(None) 86 | # Save to a text file as well. 87 | if save_dir is not None: 88 | if not os.path.exists(save_dir): 89 | os.makedirs(save_dir) 90 | path = os.path.join(save_dir, 'training_config.txt') 91 | with open(path, "w") as out: 92 | print_config(out) 93 | 94 | 95 | def configure_learning_rate(flags, num_samples_per_epoch, global_step): 96 | """Configures the learning rate. 97 | 98 | Args: 99 | num_samples_per_epoch: The number of samples in each epoch of training. 100 | global_step: The global_step tensor. 101 | Returns: 102 | A `Tensor` representing the learning rate. 
103 | """ 104 | decay_steps = int(num_samples_per_epoch / flags.batch_size * 105 | flags.num_epochs_per_decay) 106 | 107 | if flags.learning_rate_decay_type == 'exponential': 108 | return tf.train.exponential_decay(flags.learning_rate, 109 | global_step, 110 | decay_steps, 111 | flags.learning_rate_decay_factor, 112 | staircase=True, 113 | name='exponential_decay_learning_rate') 114 | elif flags.learning_rate_decay_type == 'fixed': 115 | return tf.constant(flags.learning_rate, name='fixed_learning_rate') 116 | elif flags.learning_rate_decay_type == 'polynomial': 117 | return tf.train.polynomial_decay(flags.learning_rate, 118 | global_step, 119 | decay_steps, 120 | flags.end_learning_rate, 121 | power=1.0, 122 | cycle=False, 123 | name='polynomial_decay_learning_rate') 124 | else: 125 | raise ValueError('learning_rate_decay_type [%s] was not recognized', 126 | flags.learning_rate_decay_type) 127 | 128 | 129 | def configure_optimizer(flags, learning_rate): 130 | """Configures the optimizer used for training. 131 | 132 | Args: 133 | learning_rate: A scalar or `Tensor` learning rate. 134 | Returns: 135 | An instance of an optimizer. 136 | """ 137 | if flags.optimizer == 'adadelta': 138 | optimizer = tf.train.AdadeltaOptimizer( 139 | learning_rate, 140 | rho=flags.adadelta_rho, 141 | epsilon=flags.opt_epsilon) 142 | elif flags.optimizer == 'adagrad': 143 | optimizer = tf.train.AdagradOptimizer( 144 | learning_rate, 145 | initial_accumulator_value=flags.adagrad_initial_accumulator_value) 146 | elif flags.optimizer == 'adam': 147 | optimizer = tf.train.AdamOptimizer( 148 | learning_rate, 149 | beta1=flags.adam_beta1, 150 | beta2=flags.adam_beta2, 151 | epsilon=flags.opt_epsilon) 152 | elif flags.optimizer == 'ftrl': 153 | optimizer = tf.train.FtrlOptimizer( 154 | learning_rate, 155 | learning_rate_power=flags.ftrl_learning_rate_power, 156 | initial_accumulator_value=flags.ftrl_initial_accumulator_value, 157 | l1_regularization_strength=flags.ftrl_l1, 158 | l2_regularization_strength=flags.ftrl_l2) 159 | elif flags.optimizer == 'momentum': 160 | optimizer = tf.train.MomentumOptimizer( 161 | learning_rate, 162 | momentum=flags.momentum, 163 | name='Momentum') 164 | elif flags.optimizer == 'rmsprop': 165 | optimizer = tf.train.RMSPropOptimizer( 166 | learning_rate, 167 | decay=flags.rmsprop_decay, 168 | momentum=flags.rmsprop_momentum, 169 | epsilon=flags.opt_epsilon) 170 | elif flags.optimizer == 'sgd': 171 | optimizer = tf.train.GradientDescentOptimizer(learning_rate) 172 | else: 173 | raise ValueError('Optimizer [%s] was not recognized', flags.optimizer) 174 | return optimizer 175 | 176 | 177 | def add_variables_summaries(learning_rate): 178 | summaries = [] 179 | for variable in slim.get_model_variables(): 180 | summaries.append(tf.summary.histogram(variable.op.name, variable)) 181 | summaries.append(tf.summary.scalar('training/Learning Rate', learning_rate)) 182 | return summaries 183 | 184 | 185 | def update_model_scope(var, ckpt_scope, new_scope): 186 | return var.op.name.replace(new_scope,'vgg_16') 187 | 188 | 189 | def get_init_fn(flags): 190 | """Returns a function run by the chief worker to warm-start the training. 191 | Note that the init_fn is only run when initializing the model during the very 192 | first global step. 193 | 194 | Returns: 195 | An init function run by the supervisor. 196 | """ 197 | if flags.checkpoint_path is None: 198 | return None 199 | # Warn the user if a checkpoint exists in the train_dir. Then ignore. 
200 | if tf.train.latest_checkpoint(flags.train_dir): 201 | tf.logging.info( 202 | 'Ignoring --checkpoint_path because a checkpoint already exists in %s' 203 | % flags.train_dir) 204 | return None 205 | 206 | exclusions = [] 207 | if flags.checkpoint_exclude_scopes: 208 | exclusions = [scope.strip() 209 | for scope in flags.checkpoint_exclude_scopes.split(',')] 210 | 211 | # TODO(sguada) variables.filter_variables() 212 | variables_to_restore = [] 213 | for var in slim.get_model_variables(): 214 | excluded = False 215 | for exclusion in exclusions: 216 | if var.op.name.startswith(exclusion): 217 | excluded = True 218 | break 219 | if not excluded: 220 | variables_to_restore.append(var) 221 | # Change model scope if necessary. 222 | if flags.checkpoint_model_scope is not None: 223 | variables_to_restore = \ 224 | {var.op.name.replace(flags.model_name, 225 | flags.checkpoint_model_scope): var 226 | for var in variables_to_restore} 227 | print(variables_to_restore) 228 | if tf.gfile.IsDirectory(flags.checkpoint_path): 229 | checkpoint_path = tf.train.latest_checkpoint(flags.checkpoint_path) 230 | else: 231 | checkpoint_path = flags.checkpoint_path 232 | tf.logging.info('Fine-tuning from %s' % checkpoint_path) 233 | 234 | return slim.assign_from_checkpoint_fn( 235 | checkpoint_path, 236 | variables_to_restore, 237 | ignore_missing_vars=flags.ignore_missing_vars) 238 | 239 | 240 | def get_variables_to_train(flags): 241 | """Returns a list of variables to train. 242 | 243 | Returns: 244 | A list of variables to train by the optimizer. 245 | """ 246 | if flags.trainable_scopes is None: 247 | return tf.trainable_variables() 248 | else: 249 | scopes = [scope.strip() for scope in flags.trainable_scopes.split(',')] 250 | 251 | variables_to_train = [] 252 | for scope in scopes: 253 | variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope) 254 | variables_to_train.extend(variables) 255 | return variables_to_train 256 | 257 | 258 | # =========================================================================== # 259 | # Evaluation utils. 260 | # =========================================================================== # 261 | -------------------------------------------------------------------------------- /tf_extended/tf_utils.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shun14/TextBoxes_plusplus_Tensorflow/0b60e4675db6eea9e910a4c016073f9b696f3614/tf_extended/tf_utils.pyc -------------------------------------------------------------------------------- /tools/_init_paths.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import sys 3 | 4 | def add_path(path): 5 | if path not in sys.path: 6 | sys.path.insert(0, path) 7 | 8 | this_dir = osp.dirname(__file__) 9 | 10 | # Add lib to PYTHONPATH 11 | lib_path = osp.join(this_dir, '..') 12 | add_path(lib_path) 13 | -------------------------------------------------------------------------------- /tools/convert_xml_format.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import sys 4 | import random 5 | import time 6 | import numpy as np 7 | import codecs 8 | import cv2 9 | import xml.etree.ElementTree as ET 10 | from xml.etree.ElementTree import SubElement 11 | 12 | 13 | def process_convert(name, DIRECTORY_ANNOTATIONS, img_path, save_xml_path): 14 | # Read the XML annotation file. 
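    # Sketch of the expected VOC-style input, reconstructed from the fields
    # accessed below (not a verbatim sample -- see demo/example/standard.xml
    # for the target format):
    #   <annotation>
    #     <size><height>h</height><width>w</width><depth>3</depth></size>
    #     <object>
    #       <name>...</name><difficult>0</difficult>
    #       <bndbox><xmin/><ymin/><xmax/><ymax/></bndbox>
    #     </object>
    #   </annotation>
    # The converter fills <content> and the quad corners x1..x4 / y1..y4.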
15 |     filename = os.path.join(DIRECTORY_ANNOTATIONS, name)
16 |     try:
17 |         tree = ET.parse(filename)
18 |     except (IOError, ET.ParseError):
19 |         print('error:', filename, 'does not exist or cannot be parsed')
20 |         return False
21 | 
22 |     root = tree.getroot()
23 |     size = root.find('size')
24 |     if size is None:
25 |         img = cv2.imread(img_path)
26 |         print('jpg_path', img_path, img.shape)
27 |         shape = [int(img.shape[0]), int(img.shape[1]), int(img.shape[2])]
28 |         # size = SubElement(root, 'size')
29 | 
30 |     elif size.find('height').text is None or size.find('width').text is None:
31 |         img = cv2.imread(img_path)
32 |         print('jpg_path height', img_path, img.shape)
33 |         shape = [int(img.shape[0]), int(img.shape[1]), int(img.shape[2])]
34 |     elif int(size.find('height').text) == 0 or int(
35 |             size.find('width').text) == 0:
36 | 
37 |         img = cv2.imread(img_path)
38 |         print('jpg_path zero', img_path, img.shape)
39 |         shape = [int(img.shape[0]), int(img.shape[1]), int(img.shape[2])]
40 |     else:
41 |         shape = [
42 |             int(size.find('height').text),
43 |             int(size.find('width').text),
44 |             int(size.find('depth').text)
45 |         ]
46 | 
47 |     height = size.find('height')
48 |     height.text = str(shape[0])
49 |     width = size.find('width')
50 |     width.text = str(shape[1])
51 | 
52 |     for obj in root.findall('object'):
53 |         difficult = int(obj.find('difficult').text)
54 |         content = obj.find('name').text
55 |         content = content.replace('\t', ' ')
56 | 
57 |         #if int(difficult) == 1 and content == '&*@HUST_special':
58 |         # Here HUST_vertical is treated as text: any non-difficult box whose
59 |         # content is not the special or shelter marker keeps the 'text' label;
60 |         # everything else is mapped to 'none'.
61 |         if difficult == 0 and content != '&*@HUST_special' and content != '&*HUST_shelter':
62 |             label_name = 'text'
63 |         else:
64 |             label_name = 'none'
65 | 
66 |         bbox = obj.find('bndbox')
67 |         if obj.find('content') is None:
68 |             content_sub = SubElement(obj, 'content')
69 |             content_sub.text = content
70 |         else:
71 |             obj.find('content').text = content
72 | 
73 |         name_ele = obj.find('name')
74 |         name_ele.text = label_name
75 | 
76 |         xmin = bbox.find('xmin').text
77 |         ymin = bbox.find('ymin').text
78 |         xmax = bbox.find('xmax').text
79 |         ymax = bbox.find('ymax').text
80 | 
81 |         x1 = xmin
82 |         x2 = xmax
83 |         x3 = xmax
84 |         x4 = xmin
85 | 
86 |         y1 = ymin
87 |         y2 = ymin
88 |         y3 = ymax
89 |         y4 = ymax
90 | 
91 |         if bbox.find('x1') is None:
92 |             x1_sub = SubElement(bbox, 'x1')
93 |             x1_sub.text = x1
94 |             x2_sub = SubElement(bbox, 'x2')
95 |             x2_sub.text = x2
96 |             x3_sub = SubElement(bbox, 'x3')
97 |             x3_sub.text = x3
98 |             x4_sub = SubElement(bbox, 'x4')
99 |             x4_sub.text = x4
100 | 
101 |             y1_sub = SubElement(bbox, 'y1')
102 |             y1_sub.text = y1
103 |             y2_sub = SubElement(bbox, 'y2')
104 |             y2_sub.text = y2
105 |             y3_sub = SubElement(bbox, 'y3')
106 |             y3_sub.text = y3
107 |             y4_sub = SubElement(bbox, 'y4')
108 |             y4_sub.text = y4
109 |         else:  # existing quad: only the y's are refreshed, x1..x4 are kept.
110 |             bbox.find('y1').text = ymin
111 |             bbox.find('y2').text = ymin
112 |             bbox.find('y3').text = ymax
113 |             bbox.find('y4').text = ymax
114 |     #print(save_xml_path)
115 |     tree.write(save_xml_path)
116 |     return True
117 | 
118 | 
119 | def process_convert_txt(name, DIRECTORY_ANNOTATIONS):
120 |     # Read the XML annotation file.
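    # Each kept object is written out as one line of eight pixel coordinates,
    # clockwise from the top-left corner:
    #     x1 y1 x2 y2 x3 y3 x4 y4
    # e.g. "10 20 110 20 110 60 10 60" for an axis-aligned 100x40 box.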
121 | filename = os.path.join(DIRECTORY_ANNOTATIONS, name)
122 | try:
123 | tree = ET.parse(filename)
124 | except Exception:
125 | print('error: cannot read or parse', filename)
126 | return
127 | root = tree.getroot()
128 | all_txt_line = []
129 | for obj in root.findall('object'):
130 | 
131 | bbox = obj.find('bndbox')
132 | 
133 | difficult = int(obj.find('difficult').text)
134 | content = obj.find('content')
135 | if content is not None:
136 | content = content.text
137 | else:
138 | content = None
139 | if difficult == 1 and content == '&*@HUST_special':
140 | continue
141 | xmin = bbox.find('xmin').text
142 | ymin = bbox.find('ymin').text
143 | xmax = bbox.find('xmax').text
144 | ymax = bbox.find('ymax').text
145 | 
146 | x1 = xmin
147 | x2 = xmax
148 | x3 = xmax
149 | x4 = xmin
150 | 
151 | y1 = ymin
152 | y2 = ymin
153 | y3 = ymax
154 | y4 = ymax
155 | 
156 | all_txt_line.append('{} {} {} {} {} {} {} {}\n'.format(
157 | x1, y1, x2, y2, x3, y3, x4, y4))
158 | 
159 | txt_name = os.path.join(DIRECTORY_ANNOTATIONS, name[:-4] + '.txt')
160 | with codecs.open(txt_name, 'w', encoding='utf-8') as f:
161 | f.writelines(all_txt_line)
162 | 
163 | 
164 | def get_all_img(directory, split_flag, logs_dir, output_dir):
165 | count = 0
166 | ano_path_list = []
167 | img_path_list = []
168 | if output_dir is not None and not os.path.exists(output_dir):
169 | os.makedirs(output_dir)
170 | 
171 | start_time = time.time()
172 | for root, dirs, files in os.walk(directory):
173 | for each in files:
174 | if each.split('.')[-1] == 'xml':
175 | xml_path = os.path.join(root, each[:-4] + '.xml')
176 | img_path = os.path.join(root, each[:-4] + '.png')
177 | if not os.path.exists(img_path):
178 | img_path = os.path.join(root, each[:-4] + '.PNG')
179 | test_png = cv2.imread(img_path)
180 | if test_png is None or not os.path.exists(xml_path):
181 | continue
182 | if output_dir is not None:
183 | sub_path = root[len(directory)+1:]
184 | sub_path = os.path.join(output_dir, sub_path)
185 | if not os.path.exists(sub_path):
186 | os.makedirs(sub_path)
187 | save_xml_path = os.path.join(sub_path, each[:-4] + '.xml')
188 | else:
189 | save_xml_path = xml_path
190 | if process_convert(each, root, img_path, save_xml_path):
191 | ano_path_list.append('{},{}\n'.format(
192 | img_path,
193 | save_xml_path))
194 | img_path_list.append('{}\n'.format(
195 | img_path))
196 | count += 1
197 | if count % 1000 == 0:
198 | print(count, time.time() - start_time)
199 | save_to_text(img_path_list, ano_path_list, count, split_flag, logs_dir)
200 | print('all over:', count)
201 | print('time:', time.time() - start_time)
202 | 
203 | 
204 | def save_to_text(img_path_list, ano_path_list, count, split_flag, logs_dir):
205 | if split_flag == 'yes':
206 | train_num = int(count / 10. * 9.)
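# split_flag == 'yes' holds out roughly one tenth of the samples:
# e.g. count = 1000 gives train_num = 900, so 900 entries land in the
# train lists and the remaining 100 in the test lists below.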
207 | else: 208 | train_num = count 209 | if not os.path.exists(logs_dir): 210 | os.makedirs(logs_dir) 211 | 212 | with codecs.open( 213 | os.path.join(logs_dir, 'train_xml.txt'), 'w', 214 | encoding='utf-8') as f_xml, codecs.open( 215 | os.path.join(logs_dir, 'train.txt'), 'w', 216 | encoding='utf-8') as f_txt: 217 | f_xml.writelines(ano_path_list[:train_num]) 218 | f_txt.writelines(img_path_list[:train_num]) 219 | 220 | with codecs.open( 221 | os.path.join(logs_dir, 'test_xml.txt'), 'w', 222 | encoding='utf-8') as f_xml, codecs.open( 223 | os.path.join(logs_dir, 'test.txt'), 'w', 224 | encoding='utf-8') as f_txt: 225 | f_xml.writelines(ano_path_list[train_num:]) 226 | f_txt.writelines(img_path_list[train_num:]) 227 | 228 | 229 | if __name__ == '__main__': 230 | import argparse 231 | parser = argparse.ArgumentParser( 232 | description='icdar15 generate xml tools for standard format') 233 | parser.add_argument( 234 | '--in_dir', 235 | '-i', 236 | default= 237 | '/home/zsz/datasets/icdar15_anno/annotated_data_3rd_8thv2_cut_resize_margin8', 238 | type=str) 239 | parser.add_argument('--split_flag', '-s', default='no', type=str) 240 | parser.add_argument('--save_logs', '-l', default='logs', type=str) 241 | parser.add_argument('--output_dir', '-o', default=None, type=str) 242 | 243 | args = parser.parse_args() 244 | directory = args.in_dir 245 | split_flag = args.split_flag 246 | logs_dir = args.save_logs 247 | output_dir = args.output_dir 248 | get_all_img(directory, split_flag, logs_dir, output_dir) 249 | -------------------------------------------------------------------------------- /tools/gen_xml.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | from lxml import etree 4 | import xml.dom.minidom 5 | import sys 6 | import random 7 | 8 | import numpy as np 9 | import codecs 10 | import cv2 11 | 12 | def process_convert(txt_name, DIRECTORY_ANNOTATIONS, new_version=True): 13 | 14 | # Read the txt annotation file. 
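# gen_xml.process_convert goes the other way: it turns a txt annotation into
# the standard XML, reading the image via cv2 to fill <size>. Two input
# layouts are handled (field meanings below are inferred from the parsing
# code): with new_version=True a line reads 'idx,xmin,ymin,xmax,ymax,...,label'
# (only fields 1-4 and the last are used); otherwise a line carries the eight
# comma-separated corner coordinates 'x1,y1,x2,y2,x3,y3,x4,y4'.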
15 | filename = os.path.join(DIRECTORY_ANNOTATIONS, txt_name) 16 | with codecs.open(filename, 'r', encoding='utf-8') as f: 17 | lines = f.readlines() 18 | 19 | annotation_xml = xml.dom.minidom.Document() 20 | 21 | root = annotation_xml.createElement('annotation') 22 | annotation_xml.appendChild(root) 23 | 24 | nodeFolder = annotation_xml.createElement('folder') 25 | root.appendChild(nodeFolder) 26 | 27 | nodeFilename = annotation_xml.createElement('filename') 28 | nodeFilename.appendChild(annotation_xml.createTextNode(filename)) 29 | root.appendChild(nodeFilename) 30 | 31 | img_name = filename[:-4] 32 | if cv2.imread(img_name) is not None: 33 | h, w, c = cv2.imread(img_name).shape 34 | else: 35 | raise KeyError('img_name error:', img_name) 36 | nodeSize = annotation_xml.createElement('size') 37 | nodeWidth = annotation_xml.createElement('width') 38 | nodeHeight = annotation_xml.createElement('height') 39 | nodeDepth = annotation_xml.createElement('depth') 40 | nodeWidth.appendChild(annotation_xml.createTextNode(str(w))) 41 | nodeHeight.appendChild(annotation_xml.createTextNode(str(h))) 42 | nodeDepth.appendChild(annotation_xml.createTextNode(str(c))) 43 | 44 | nodeSize.appendChild(nodeWidth) 45 | nodeSize.appendChild(nodeHeight) 46 | nodeSize.appendChild(nodeDepth) 47 | root.appendChild(nodeSize) 48 | for l in lines: 49 | l = l.encode('utf-8').decode('utf-8-sig') 50 | l = l.strip().split(',') 51 | difficult = 0 52 | label = 1 53 | x1_text = '' 54 | x2_text = '' 55 | x3_text = '' 56 | x4_text = '' 57 | y1_text = '' 58 | y2_text = '' 59 | y3_text = '' 60 | y4_text = '' 61 | 62 | if new_version is True: 63 | label_name = str(l[-1]) 64 | if label_name == 'none': 65 | difficult = 1 66 | else: 67 | difficult = 0 68 | label = label_name 69 | xmin_text = str(int(float(l[1]))) 70 | ymin_text = str(int(float(l[2]))) 71 | xmax_text = str(int(float(l[3]))) 72 | ymax_text = str(int(float(l[4]))) 73 | x1_text = xmin_text 74 | x2_text = xmax_text 75 | x3_text = xmax_text 76 | x4_text = xmin_text 77 | 78 | y1_text = ymin_text 79 | y2_text = ymin_text 80 | y3_text = ymax_text 81 | y4_text = ymax_text 82 | else: 83 | print(l) 84 | xs = [ int(l[i]) for i in (0, 2, 4, 6)] 85 | ys = [ int(l[i]) for i in (1, 3, 5, 7)] 86 | xmin_text = str(min(xs)) 87 | ymin_text = str(min(ys)) 88 | xmax_text = str(max(xs)) 89 | ymax_text = str(max(ys)) 90 | x1_text = str(xs[0]) 91 | x2_text = str(xs[1]) 92 | x3_text = str(xs[2]) 93 | x4_text = str(xs[3]) 94 | 95 | y1_text = str(ys[0]) 96 | y2_text = str(ys[1]) 97 | y3_text = str(ys[2]) 98 | y4_text = str(ys[3]) 99 | 100 | 101 | nodeObject = annotation_xml.createElement('object') 102 | nodeDifficult = annotation_xml.createElement('difficult') 103 | nodeDifficult.appendChild(annotation_xml.createTextNode(str(difficult))) 104 | nodeName = annotation_xml.createElement('name') 105 | nodeName.appendChild(annotation_xml.createTextNode(str(label))) 106 | 107 | nodeBndbox = annotation_xml.createElement('bndbox') 108 | nodexmin = annotation_xml.createElement('xmin') 109 | nodexmin.appendChild(annotation_xml.createTextNode(xmin_text)) 110 | nodeymin = annotation_xml.createElement('ymin') 111 | nodeymin.appendChild(annotation_xml.createTextNode(ymin_text)) 112 | nodexmax = annotation_xml.createElement('xmax') 113 | nodexmax.appendChild(annotation_xml.createTextNode(xmax_text)) 114 | nodeymax = annotation_xml.createElement('ymax') 115 | nodeymax.appendChild(annotation_xml.createTextNode(ymax_text)) 116 | 117 | nodex1 = annotation_xml.createElement('x1') 118 | 
nodex1.appendChild(annotation_xml.createTextNode(x1_text)) 119 | nodex2 = annotation_xml.createElement('x2') 120 | nodex2.appendChild(annotation_xml.createTextNode(x2_text)) 121 | nodex3 = annotation_xml.createElement('x3') 122 | nodex3.appendChild(annotation_xml.createTextNode(x3_text)) 123 | nodex4 = annotation_xml.createElement('x4') 124 | nodex4.appendChild(annotation_xml.createTextNode(x4_text)) 125 | 126 | 127 | nodey1 = annotation_xml.createElement('y1') 128 | nodey1.appendChild(annotation_xml.createTextNode(y1_text)) 129 | 130 | nodey2 = annotation_xml.createElement('y2') 131 | nodey2.appendChild(annotation_xml.createTextNode(y2_text)) 132 | 133 | nodey3 = annotation_xml.createElement('y3') 134 | nodey3.appendChild(annotation_xml.createTextNode(y3_text)) 135 | nodey4 = annotation_xml.createElement('y4') 136 | nodey4.appendChild(annotation_xml.createTextNode(y4_text)) 137 | 138 | nodeBndbox.appendChild(nodexmin) 139 | nodeBndbox.appendChild(nodeymin) 140 | nodeBndbox.appendChild(nodexmax) 141 | nodeBndbox.appendChild(nodeymax) 142 | nodeBndbox.appendChild(nodex1) 143 | nodeBndbox.appendChild(nodex2) 144 | nodeBndbox.appendChild(nodex3) 145 | nodeBndbox.appendChild(nodex4) 146 | 147 | nodeBndbox.appendChild(nodey1) 148 | nodeBndbox.appendChild(nodey2) 149 | nodeBndbox.appendChild(nodey3) 150 | nodeBndbox.appendChild(nodey4) 151 | 152 | nodeObject.appendChild(nodeDifficult) 153 | nodeObject.appendChild(nodeName) 154 | nodeObject.appendChild(nodeBndbox) 155 | root.appendChild(nodeObject) 156 | 157 | 158 | 159 | xml_path = os.path.join(DIRECTORY_ANNOTATIONS, txt_name[0:-4] + '.xml') 160 | fp = open(xml_path, 'w') 161 | annotation_xml.writexml(fp, indent='\t', addindent='\t', newl='\n', encoding="utf-8") 162 | 163 | 164 | 165 | def get_all_txt(directory, new_version=False): 166 | count = 0 167 | for root,dirs,files in os.walk(directory): 168 | for each in files: 169 | if each.split('.')[-1] == 'txt': 170 | count += 1 171 | print(count, each) 172 | 173 | process_convert(each, root, new_version) 174 | 175 | 176 | 177 | if __name__ == '__main__': 178 | import argparse 179 | parser = argparse.ArgumentParser(description='icdar15 generate xml tools') 180 | parser.add_argument('--in_dir','-i', default='/home/zsz/datasets/icdar15_anno/eval_img_2018.9.25/icdar15_eval_final', type=str) 181 | args = parser.parse_args() 182 | directory = args.in_dir 183 | get_all_txt(directory, False) 184 | -------------------------------------------------------------------------------- /tools/test_dataset.py: -------------------------------------------------------------------------------- 1 | from datasets import sythtextprovider 2 | import tensorflow as tf 3 | import numpy as np 4 | import cv2 5 | import os 6 | import time 7 | from nets import txtbox_384 8 | from processing import ssd_vgg_preprocessing 9 | 10 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 11 | slim = tf.contrib.slim 12 | 13 | tf.logging.set_verbosity(tf.logging.INFO) 14 | 15 | show_pic_sum = 10 16 | save_dir = 'pic_test_dataset' 17 | tf.app.flags.DEFINE_string( 18 | 'dataset_dir', 'tfrecord_train', 'The directory where the dataset files are stored.') 19 | 20 | tf.app.flags.DEFINE_integer( 21 | 'num_readers', 2, 22 | 'The number of parallel readers that read data from the dataset.') 23 | 24 | tf.app.flags.DEFINE_integer( 25 | 'batch_size', 2, 'The number of samples in each batch.') 26 | 27 | FLAGS = tf.app.flags.FLAGS 28 | if not os.path.exists(save_dir): 29 | os.makedirs(save_dir) 30 | 31 | def draw_polygon(img,x1,y1,x2,y2,x3,y3,x4,y4, color=(255, 0, 
0)): 32 | # print(x1, x2, x3, x4, y1, y2, y3, y4) 33 | x1 = int(x1) 34 | x2 = int(x2) 35 | x3 = int(x3) 36 | x4 = int(x4) 37 | 38 | y1 = int(y1) 39 | y2 = int(y2) 40 | y3 = int(y3) 41 | y4 = int(y4) 42 | cv2.line(img,(x1,y1),(x2,y2),color,2) 43 | cv2.line(img,(x2,y2),(x3,y3),color,2) 44 | cv2.line(img,(x3,y3),(x4,y4),color,2) 45 | cv2.line(img,(x4,y4),(x1,y1),color,2) 46 | # cv2_im = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 47 | # cv2_im = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 48 | # cv2.imwrite('test.png', img) 49 | return img 50 | 51 | 52 | def run(): 53 | if not FLAGS.dataset_dir: 54 | raise ValueError('You must supply the dataset directory with --dataset_dir') 55 | 56 | print('-----start test-------') 57 | if not os.path.exists(save_dir): 58 | os.makedirs(save_dir) 59 | with tf.device('/GPU:0'): 60 | dataset = sythtextprovider.get_datasets(FLAGS.dataset_dir) 61 | print(dataset) 62 | provider = slim.dataset_data_provider.DatasetDataProvider( 63 | dataset, 64 | num_readers=FLAGS.num_readers, 65 | common_queue_capacity=20 * FLAGS.batch_size, 66 | common_queue_min=10 * FLAGS.batch_size, 67 | shuffle=True) 68 | print('provider:',provider) 69 | [image, shape, glabels, gbboxes, x1, x2, x3, x4, y1, y2, y3, y4] = provider.get(['image', 'shape', 70 | 'object/label', 71 | 'object/bbox', 72 | 'object/oriented_bbox/x1', 73 | 'object/oriented_bbox/x2', 74 | 'object/oriented_bbox/x3', 75 | 'object/oriented_bbox/x4', 76 | 'object/oriented_bbox/y1', 77 | 'object/oriented_bbox/y2', 78 | 'object/oriented_bbox/y3', 79 | 'object/oriented_bbox/y4' 80 | ]) 81 | print('image:',image) 82 | print('shape:',shape) 83 | print('glabel:',glabels) 84 | print('gboxes:',gbboxes) 85 | 86 | 87 | gxs = tf.transpose(tf.stack([x1,x2,x3,x4])) #shape = (N,4) 88 | gys = tf.transpose(tf.stack([y1, y2, y3, y4])) 89 | 90 | image = tf.identity(image, 'input_image') 91 | text_shape = (384, 384) 92 | image, glabels, gbboxes, gxs, gys= ssd_vgg_preprocessing.preprocess_image(image, glabels,gbboxes,gxs, gys, 93 | text_shape,is_training=True, 94 | data_format='NHWC') 95 | 96 | 97 | x1, x2 , x3, x4 = tf.unstack(gxs, axis=1) 98 | y1, y2, y3, y4 = tf.unstack(gys, axis=1) 99 | 100 | text_net = txtbox_384.TextboxNet() 101 | text_anchors = text_net.anchors(text_shape) 102 | e_localisations, e_scores, e_labels = text_net.bboxes_encode( glabels, gbboxes, text_anchors, gxs, gys) 103 | 104 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.7) 105 | 106 | config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options, allow_soft_placement=True) 107 | with tf.Session(config=config) as sess: 108 | coord = tf.train.Coordinator() 109 | threads = tf.train.start_queue_runners(sess, coord) 110 | j = 0 111 | all_time = 0 112 | try: 113 | while not coord.should_stop() and j < show_pic_sum: 114 | start_time = time.time() 115 | image_sess, label_sess, gbbox_sess, x1_sess, x2_sess, x3_sess, x4_sess, y1_sess, y2_sess, y3_sess, y4_sess,p_localisations, p_scores, p_labels = sess.run([ 116 | image, glabels, gbboxes, x1, x2, x3, x4, y1, y2, y3, y4,e_localisations , e_scores, e_labels]) 117 | end_time = time.time() - start_time 118 | all_time += end_time 119 | image_np = image_sess 120 | # print(image_np) 121 | # print('label_sess:',label_sess) 122 | 123 | p_labels_concat = np.concatenate(p_labels) 124 | p_scores_concat = np.concatenate(p_scores) 125 | debug = False 126 | if debug is True: 127 | print(p_labels) 128 | print('l_labels:', len(p_labels_concat[p_labels_concat.nonzero()]),p_labels_concat[p_labels_concat.nonzero()] ) 129 | 
print('p_scores:', len(p_scores_concat[p_scores_concat.nonzero()]), p_scores_concat[p_scores_concat.nonzero()])
130 | # print(img_np.shape)
131 | 
132 | print('label_sess:', np.array(list(label_sess)).shape, list(label_sess))
133 | img_np = np.array(image_np)
134 | cv2.imwrite('{}/{}.png'.format(save_dir, j), img_np)
135 | img_np = cv2.imread('{}/{}.png'.format(save_dir, j))
136 | 
137 | h, w, d = img_np.shape
138 | 
139 | label_sess = list(label_sess)
140 | # for i , label in enumerate(label_sess):
141 | i = 0
142 | num_correct = 0
143 | 
144 | for label in label_sess:
145 | # print(int(label) == 1)
146 | if int(label) == 1:
147 | num_correct += 1
148 | img_np = draw_polygon(img_np, x1_sess[i] * w, y1_sess[i] * h, x2_sess[i] * w, y2_sess[i] * h, x3_sess[i] * w, y3_sess[i] * h, x4_sess[i] * w, y4_sess[i] * h)
149 | if int(label) == 0:
150 | img_np = draw_polygon(img_np, x1_sess[i] * w, y1_sess[i] * h, x2_sess[i] * w, y2_sess[i] * h, x3_sess[i] * w, y3_sess[i] * h, x4_sess[i] * w, y4_sess[i] * h, color=(0, 0, 255))
151 | i += 1
152 | img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)
153 | cv2.imwrite('{}'.format(os.path.join(save_dir, str(j) + '.png')), img_np)
154 | j += 1
155 | print('correct:', num_correct)
156 | except tf.errors.OutOfRangeError:
157 | print('done')
158 | finally:
159 | print('done')
160 | coord.request_stop()
161 | print('all time:', all_time, 'average:', all_time / show_pic_sum)
162 | coord.join(threads=threads)
163 | 
164 | if __name__ == '__main__':
165 | run()
166 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 Paul Balanca. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Generic training script that trains an SSD model using a given dataset."""
16 | 
17 | import tensorflow as tf
18 | from tensorflow.python.ops import control_flow_ops
19 | 
20 | from datasets import sythtextprovider
21 | from deployment import model_deploy
22 | from nets import txtbox_384, txtbox_768
23 | from processing import ssd_vgg_preprocessing
24 | from tf_extended import tf_utils
25 | import os
26 | import tensorflow.contrib.slim as slim
27 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
28 | 
29 | # =========================================================================== #
30 | # Text Network flags.
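# As in the SSD-style loss this network borrows: loss_alpha weights the
# localisation term against the confidence term, negative_ratio caps hard
# negative mining at ~3 negatives per positive, and match_threshold is the
# anchor/ground-truth overlap required for a positive match.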
31 | # =========================================================================== #
32 | tf.app.flags.DEFINE_float('loss_alpha', 0.2,
33 | 'Alpha parameter in the loss function.')
34 | tf.app.flags.DEFINE_float('negative_ratio', 3.,
35 | 'Negative ratio in the loss function.')
36 | tf.app.flags.DEFINE_float('match_threshold', 0.5,
37 | 'Matching threshold in the loss function.')
38 | tf.app.flags.DEFINE_boolean('large_training', False, 'Use a 768x768 input size for training.')
39 | # =========================================================================== #
40 | # General Flags.
41 | # =========================================================================== #
42 | tf.app.flags.DEFINE_string(
43 | 'train_dir',
44 | 'icdar15_model/',
45 | 'Directory where checkpoints and event logs are written to.')
46 | tf.app.flags.DEFINE_integer('num_clones', 1,
47 | 'Number of model clones to deploy.')
48 | tf.app.flags.DEFINE_boolean('clone_on_cpu', False,
49 | 'Use CPUs to deploy clones.')
50 | tf.app.flags.DEFINE_integer(
51 | 'num_readers', 8,
52 | 'The number of parallel readers that read data from the dataset.')
53 | tf.app.flags.DEFINE_integer(
54 | 'num_preprocessing_threads', 8,
55 | 'The number of threads used to create the batches.')
56 | 
57 | tf.app.flags.DEFINE_integer('log_every_n_steps', 10,
58 | 'The frequency with which logs are printed.')
59 | tf.app.flags.DEFINE_integer(
60 | 'save_summaries_secs', 120,
61 | 'The frequency with which summaries are saved, in seconds.')
62 | tf.app.flags.DEFINE_integer(
63 | 'save_interval_secs', 1200,
64 | 'The frequency with which the model is saved, in seconds.')
65 | tf.app.flags.DEFINE_float('gpu_memory_fraction', 0.9,
66 | 'GPU memory fraction to use.')
67 | 
68 | # =========================================================================== #
69 | # Optimization Flags.
70 | # =========================================================================== #
71 | tf.app.flags.DEFINE_float('weight_decay', 0.0005,
72 | 'The weight decay on the model weights.')
73 | tf.app.flags.DEFINE_string(
74 | 'optimizer', 'adam',
75 | 'The name of the optimizer, one of "adadelta", "adagrad", "adam", '
76 | '"ftrl", "momentum", "sgd" or "rmsprop".')
77 | tf.app.flags.DEFINE_float('adadelta_rho', 0.95, 'The decay rate for adadelta.')
78 | tf.app.flags.DEFINE_float('adagrad_initial_accumulator_value', 0.1,
79 | 'Starting value for the AdaGrad accumulators.')
80 | tf.app.flags.DEFINE_float(
81 | 'adam_beta1', 0.9,
82 | 'The exponential decay rate for the 1st moment estimates.')
83 | tf.app.flags.DEFINE_float(
84 | 'adam_beta2', 0.999,
85 | 'The exponential decay rate for the 2nd moment estimates.')
86 | tf.app.flags.DEFINE_float('opt_epsilon', 1.0,
87 | 'Epsilon term for the optimizer.')
88 | tf.app.flags.DEFINE_float('ftrl_learning_rate_power', -0.5,
89 | 'The learning rate power.')
90 | tf.app.flags.DEFINE_float('ftrl_initial_accumulator_value', 0.1,
91 | 'Starting value for the FTRL accumulators.')
92 | tf.app.flags.DEFINE_float('ftrl_l1', 0.0,
93 | 'The FTRL l1 regularization strength.')
94 | tf.app.flags.DEFINE_float('ftrl_l2', 0.0,
95 | 'The FTRL l2 regularization strength.')
96 | tf.app.flags.DEFINE_float(
97 | 'momentum', 0.9,
98 | 'The momentum for the MomentumOptimizer and RMSPropOptimizer.')
99 | tf.app.flags.DEFINE_float('rmsprop_momentum', 0.9, 'Momentum.')
100 | tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.')
101 | 
102 | # =========================================================================== #
103 | # Learning Rate Flags.
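# With learning_rate_decay_type 'exponential' the schedule follows
# tf.train.exponential_decay, i.e. (a sketch, assuming the usual slim setup
# where decay_steps is derived from num_epochs_per_decay):
#   lr(step) = learning_rate * learning_rate_decay_factor ** (step / decay_steps)
# so with the defaults below the rate drops by 10x per decay interval.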
104 | # =========================================================================== #
105 | tf.app.flags.DEFINE_string(
106 | 'learning_rate_decay_type', 'exponential',
107 | 'Specifies how the learning rate is decayed. One of "fixed", "exponential",'
108 | ' or "polynomial"')
109 | tf.app.flags.DEFINE_float('learning_rate', 0.0001, 'Initial learning rate.')
110 | tf.app.flags.DEFINE_float(
111 | 'end_learning_rate', 0.0001,
112 | 'The minimal end learning rate used by a polynomial decay learning rate.')
113 | tf.app.flags.DEFINE_float('label_smoothing', 0.0,
114 | 'The amount of label smoothing.')
115 | tf.app.flags.DEFINE_float('learning_rate_decay_factor', 0.1,
116 | 'Learning rate decay factor.')
117 | tf.app.flags.DEFINE_float(
118 | 'num_epochs_per_decay', 40000,
119 | 'Number of epochs after which learning rate decays.')
120 | tf.app.flags.DEFINE_float(
121 | 'moving_average_decay', None, 'The decay to use for the moving average. '
122 | 'If left as None, then moving averages are not used.')
123 | 
124 | # =========================================================================== #
125 | # Dataset Flags.
126 | # =========================================================================== #
127 | tf.app.flags.DEFINE_string('dataset_name', 'sythtext',
128 | 'The name of the dataset to load.')
129 | tf.app.flags.DEFINE_integer('num_classes', 2,
130 | 'Number of classes to use in the dataset.')
131 | tf.app.flags.DEFINE_string('dataset_split_name', 'train',
132 | 'The name of the train/test split.')
133 | tf.app.flags.DEFINE_string(
134 | 'dataset_dir', 'icdar15_tf',
135 | 'The directory where the dataset files are stored.')
136 | tf.app.flags.DEFINE_integer(
137 | 'labels_offset', 0,
138 | 'An offset for the labels in the dataset. This flag is primarily used to '
139 | 'evaluate the VGG and ResNet architectures which do not use a background '
140 | 'class for the ImageNet dataset.')
141 | tf.app.flags.DEFINE_string('model_name', 'text_box_384',
142 | 'The name of the architecture to train.')
143 | tf.app.flags.DEFINE_string(
144 | 'preprocessing_name', None,
145 | 'The name of the preprocessing to use. If left '
146 | 'as `None`, then the model_name flag is used.')
147 | tf.app.flags.DEFINE_integer('batch_size', 16,
148 | 'The number of samples in each batch.')
149 | tf.app.flags.DEFINE_integer('train_image_size', None, 'Train image size')
150 | tf.app.flags.DEFINE_string('training_image_crop_area', '0.1, 1.0',
151 | 'The min and max fraction of image area kept by the random training crop.')
152 | tf.app.flags.DEFINE_integer('max_number_of_steps', 120000,
153 | 'The maximum number of training steps.')
154 | # =========================================================================== #
155 | # Fine-Tuning Flags.
156 | # =========================================================================== #
157 | tf.app.flags.DEFINE_string(
158 | #'checkpoint_path','/home/zsz/code/TextBoxes_plusplus_Tensorflow/model/vgg_fc_16_model/vgg_16.ckpt',
159 | 'checkpoint_path', '/home/zsz/TextBoxes_plusplus/models/ckpt/',
160 | 'The path to a checkpoint from which to fine-tune.')
161 | tf.app.flags.DEFINE_string(
162 | 'checkpoint_model_scope', None,
163 | 'Model scope in the checkpoint. 
None if the same as the trained model.')
164 | tf.app.flags.DEFINE_string(
165 | 'checkpoint_exclude_scopes', None,
166 | 'Comma-separated list of scopes of variables to exclude when restoring '
167 | 'from a checkpoint.')
168 | tf.app.flags.DEFINE_string(
169 | 'trainable_scopes', None,
170 | 'Comma-separated list of scopes to filter the set of variables to train. '
171 | 'By default, None would train all the variables.')
172 | tf.app.flags.DEFINE_boolean(
173 | 'ignore_missing_vars', False,
174 | 'When restoring a checkpoint, ignore missing variables.')
175 | 
176 | FLAGS = tf.app.flags.FLAGS
177 | 
178 | 
179 | # =========================================================================== #
180 | # Main training routine.
181 | # =========================================================================== #
182 | def main(_):
183 | if not FLAGS.dataset_dir:
184 | raise ValueError(
185 | 'You must supply the dataset directory with --dataset_dir')
186 | 
187 | tf.logging.set_verbosity(tf.logging.DEBUG)
188 | with tf.Graph().as_default():
189 | # Config model_deploy. Keep TF Slim Models structure.
190 | # Useful if we want to use multiple GPUs and/or servers in the future.
191 | deploy_config = model_deploy.DeploymentConfig(
192 | num_clones=FLAGS.num_clones,
193 | clone_on_cpu=FLAGS.clone_on_cpu,
194 | replica_id=0,
195 | num_replicas=1,
196 | num_ps_tasks=0)
197 | # Create global_step.
198 | with tf.device(deploy_config.variables_device()):
199 | global_step = slim.create_global_step()
200 | 
201 | # Select the dataset.
202 | 
203 | dataset = sythtextprovider.get_datasets(FLAGS.dataset_dir)
204 | # Get the TextBoxes++ network and its anchors.
205 | text_net = txtbox_384.TextboxNet()
206 | if FLAGS.large_training:
207 | text_net.params = text_net.params._replace(img_shape=(768, 768))
208 | text_net.params = text_net.params._replace(feat_shapes=[(96, 96), (48, 48), (24, 24), (12, 12), (10, 10), (8, 8)])
209 | text_shape = text_net.params.img_shape
210 | print('text_shape ' + str(text_shape))
211 | text_anchors = text_net.anchors(text_shape)
212 | 
213 | tf_utils.print_configuration(FLAGS.__flags, text_net.params,
214 | dataset.data_sources, FLAGS.train_dir)
215 | # =================================================================== #
216 | # Create a dataset provider and batches.
217 | # =================================================================== #
218 | 
219 | with tf.device(deploy_config.inputs_device()):
220 | with tf.name_scope(FLAGS.dataset_name + '_data_provider'):
221 | provider = slim.dataset_data_provider.DatasetDataProvider(
222 | dataset,
223 | num_readers=FLAGS.num_readers,
224 | common_queue_capacity=1000 * FLAGS.batch_size,
225 | common_queue_min=300 * FLAGS.batch_size,
226 | shuffle=True)
227 | # Get image, labels and bboxes for the SSD-style network.
228 | [image, shape, glabels, gbboxes, x1, x2, x3, x4, y1, y2, y3,
229 | y4] = provider.get([
230 | 'image', 'shape', 'object/label', 'object/bbox',
231 | 'object/oriented_bbox/x1', 'object/oriented_bbox/x2',
232 | 'object/oriented_bbox/x3', 'object/oriented_bbox/x4',
233 | 'object/oriented_bbox/y1', 'object/oriented_bbox/y2',
234 | 'object/oriented_bbox/y3', 'object/oriented_bbox/y4'
235 | ])
236 | gxs = tf.transpose(tf.stack([x1, x2, x3, x4])) # shape = (N, 4)
237 | gys = tf.transpose(tf.stack([y1, y2, y3, y4]))
238 | 
239 | image = tf.identity(image, 'input_image')
240 | 
241 | init_op = tf.global_variables_initializer()
242 | 
243 | # Pre-processing image, labels and bboxes.
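# --training_image_crop_area is a comma-separated pair: '0.1, 1.0' means the
# random crop sampled in preprocess_for_train keeps between 10% and 100% of
# the original image area before the crop is resized to text_shape.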
244 | training_image_crop_area = FLAGS.training_image_crop_area 245 | area_split = training_image_crop_area.split(',') 246 | assert len(area_split) == 2 247 | training_image_crop_area = [ 248 | float(area_split[0]), 249 | float(area_split[1]) 250 | ] 251 | 252 | image, glabels, gbboxes, gxs, gys= \ 253 | ssd_vgg_preprocessing.preprocess_for_train(image, glabels,gbboxes,gxs, gys, 254 | text_shape, 255 | data_format='NHWC', crop_area_range=training_image_crop_area) 256 | 257 | # Encode groundtruth labels and bboxes. 258 | 259 | image = tf.identity(image, 'processed_image') 260 | 261 | glocalisations, gscores, glabels = \ 262 | text_net.bboxes_encode( glabels, gbboxes, text_anchors, gxs, gys) 263 | batch_shape = [1] + [len(text_anchors)] * 3 264 | 265 | # Training batches and queue. 266 | 267 | r = tf.train.batch( 268 | tf_utils.reshape_list( 269 | [image, glocalisations, gscores, glabels]), 270 | batch_size=FLAGS.batch_size, 271 | num_threads=FLAGS.num_preprocessing_threads, 272 | capacity=5 * FLAGS.batch_size) 273 | 274 | b_image, b_glocalisations, b_gscores, b_glabels= \ 275 | tf_utils.reshape_list(r, batch_shape) 276 | 277 | # Intermediate queueing: unique batch computation pipeline for all 278 | # GPUs running the training. 279 | batch_queue = slim.prefetch_queue.prefetch_queue( 280 | tf_utils.reshape_list( 281 | [b_image, b_glocalisations, b_gscores, b_glabels]), 282 | capacity=2 * deploy_config.num_clones) 283 | 284 | # =================================================================== # 285 | # Define the model running on every GPU. 286 | # =================================================================== # 287 | def clone_fn(batch_queue): 288 | 289 | #Allows data parallelism by creating multiple 290 | #clones of network_fn. 291 | # Dequeue batch. 292 | b_image, b_glocalisations, b_gscores, b_glabels = \ 293 | tf_utils.reshape_list(batch_queue.dequeue(), batch_shape) 294 | 295 | # Construct TextBoxes network. 296 | arg_scope = text_net.arg_scope(weight_decay=FLAGS.weight_decay) 297 | with slim.arg_scope(arg_scope): 298 | predictions,localisations, logits, end_points = \ 299 | text_net.net(b_image, is_training=True) 300 | # Add loss function. 301 | 302 | text_net.losses( 303 | logits, 304 | localisations, 305 | b_glabels, 306 | b_glocalisations, 307 | b_gscores, 308 | match_threshold=FLAGS.match_threshold, 309 | negative_ratio=FLAGS.negative_ratio, 310 | alpha=FLAGS.loss_alpha, 311 | label_smoothing=FLAGS.label_smoothing, 312 | batch_size=FLAGS.batch_size) 313 | return end_points 314 | 315 | # Gather initial summaries. 316 | summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) 317 | 318 | # =================================================================== # 319 | # Add summaries from first clone. 320 | # =================================================================== # 321 | clones = model_deploy.create_clones(deploy_config, clone_fn, 322 | [batch_queue]) 323 | first_clone_scope = deploy_config.clone_scope(0) 324 | # Gather update_ops from the first clone. These contain, for example, 325 | # the updates for the batch_norm variables created by network_fn. 326 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 327 | 328 | # Add summaries for end_points. 
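# For every end point of the first clone, record an activation histogram and
# a zero-fraction ('sparsity') scalar; losses, EXTRA_LOSSES and model
# variables each get their own summaries below.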
329 | end_points = clones[0].outputs 330 | for end_point in end_points: 331 | x = end_points[end_point] 332 | summaries.add(tf.summary.histogram('activations/' + end_point, x)) 333 | summaries.add( 334 | tf.summary.scalar('sparsity/' + end_point, 335 | tf.nn.zero_fraction(x))) 336 | # Add summaries for losses and extra losses. 337 | for loss in tf.get_collection(tf.GraphKeys.LOSSES): 338 | summaries.add(tf.summary.scalar(loss.op.name, loss)) 339 | for loss in tf.get_collection('EXTRA_LOSSES'): 340 | summaries.add(tf.summary.scalar(loss.op.name, loss)) 341 | 342 | # Add summaries for variables. 343 | for variable in slim.get_model_variables(): 344 | summaries.add(tf.summary.histogram(variable.op.name, variable)) 345 | 346 | # =================================================================== # 347 | # Configure the moving averages. 348 | # =================================================================== # 349 | if FLAGS.moving_average_decay: 350 | moving_average_variables = slim.get_model_variables() 351 | variable_averages = tf.train.ExponentialMovingAverage( 352 | FLAGS.moving_average_decay, global_step) 353 | else: 354 | moving_average_variables, variable_averages = None, None 355 | 356 | # =================================================================== # 357 | # Configure the optimization procedure. 358 | # =================================================================== # 359 | with tf.device(deploy_config.optimizer_device()): 360 | learning_rate = tf_utils.configure_learning_rate( 361 | FLAGS, dataset.num_samples, global_step) 362 | optimizer = tf_utils.configure_optimizer(FLAGS, learning_rate) 363 | summaries.add(tf.summary.scalar('learning_rate', learning_rate)) 364 | 365 | if FLAGS.moving_average_decay: 366 | # Update ops executed locally by trainer. 367 | update_ops.append( 368 | variable_averages.apply(moving_average_variables)) 369 | 370 | # Variables to train. 371 | variables_to_train = tf_utils.get_variables_to_train(FLAGS) 372 | 373 | # and returns a train_tensor and summary_op 374 | total_loss, clones_gradients = model_deploy.optimize_clones( 375 | clones, optimizer, var_list=variables_to_train) 376 | # Add total_loss to summary. 377 | 378 | summaries.add(tf.summary.scalar('total_loss', total_loss)) 379 | 380 | # Create gradient updates. 381 | grad_updates = optimizer.apply_gradients( 382 | clones_gradients, global_step=global_step) 383 | update_ops.append(grad_updates) 384 | update_op = tf.group(*update_ops) 385 | train_tensor = control_flow_ops.with_dependencies( 386 | [update_op], total_loss, name='train_op') 387 | 388 | # Add the summaries from the first clone. These contain the summaries 389 | summaries |= set( 390 | tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope)) 391 | # Merge all summaries together. 392 | summary_op = tf.summary.merge(list(summaries), name='summary_op') 393 | 394 | # =================================================================== # 395 | # Kicks off the training. 
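# slim.learning.train drives the loop: weights are restored through
# tf_utils.get_init_fn(FLAGS) (see tf_utils above), summaries are written
# every save_summaries_secs, checkpoints every save_interval_secs, and the
# session stops after max_number_of_steps.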
396 | # =================================================================== # 397 | gpu_options = tf.GPUOptions( 398 | per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction) 399 | 400 | config = tf.ConfigProto( 401 | log_device_placement=False, 402 | allow_soft_placement=True, 403 | gpu_options=gpu_options) 404 | saver = tf.train.Saver( 405 | max_to_keep=100, 406 | keep_checkpoint_every_n_hours=1.0, 407 | write_version=2, 408 | pad_step_number=False) 409 | 410 | slim.learning.train( 411 | train_tensor, 412 | logdir=FLAGS.train_dir, 413 | master='', 414 | is_chief=True, 415 | # init_op=init_op, 416 | init_fn=tf_utils.get_init_fn(FLAGS), 417 | summary_op=summary_op, ##output variables to logdir 418 | number_of_steps=FLAGS.max_number_of_steps, 419 | log_every_n_steps=FLAGS.log_every_n_steps, 420 | save_summaries_secs=FLAGS.save_summaries_secs, 421 | saver=saver, 422 | save_interval_secs=FLAGS.save_interval_secs, 423 | session_config=config, 424 | sync_optimizer=None) 425 | 426 | 427 | if __name__ == '__main__': 428 | tf.app.run() 429 | --------------------------------------------------------------------------------