├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── caffe_to_tensorflow.py ├── config.py ├── datasets ├── __init__.py ├── dataset_factory.py ├── dataset_utils.py ├── icdar2013_to_tfrecords.py ├── icdar2015_to_tfrecords.py ├── scut_to_tfrecords.py └── synthtext_to_tfrecords.py ├── eval_seglink.py ├── img_10_pred.jpg ├── img_31_pred.jpg ├── nets ├── __init__.py ├── anchor_layer.py ├── net_factory.py ├── seglink_symbol.py └── vgg.py ├── preprocessing ├── __init__.py ├── preprocessing_factory.py ├── ssd_vgg_preprocessing.py └── tf_image.py ├── push.sh ├── scripts ├── eval.sh ├── test.sh ├── train.sh └── vis.sh ├── test ├── test_batch_and_gt.py └── test_preprocessing.py ├── test_seglink.py ├── tf_extended ├── __init__.py ├── bboxes.py ├── math.py ├── metrics.py └── seglink.py ├── train_seglink.py └── visualize_detection_result.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.settings/ 6 | *.project 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | env/ 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # dotenv 84 | .env 85 | 86 | # virtualenv 87 | .venv 88 | venv/ 89 | ENV/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | .spyproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | 98 | # mkdocs documentation 99 | /site 100 | 101 | # mypy 102 | .mypy_cache/ 103 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "pylib"] 2 | path = pylib 3 | url = git@github.com:dengdan/pylib.git 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Tips: A more recent scene text detection algorithm: [PixelLink](https://arxiv.org/abs/1801.01315), has been implemented here: https://github.com/ZJULearning/pixel_link 2 | 3 | 4 | Contents: 5 | 1. [Introduction](https://github.com/dengdan/seglink#introduction) 6 | 2. [Installation&requirements](https://github.com/dengdan/seglink#installationrequirements) 7 | 3. [Datasets](https://github.com/dengdan/seglink#datasets) 8 | 3. [Problems](https://github.com/dengdan/seglink#problems) 9 | 5. 
[Models](https://github.com/dengdan/seglink#models) 10 | 6. [Test Your own images](https://github.com/dengdan/seglink#test-your-own-images) 11 | 7. [Training and evaluation](https://github.com/dengdan/seglink#training-and-evaluation) 12 | 8. [Some Comments](https://github.com/dengdan/seglink#some-comments) 13 |
14 | 15 | # Introduction 16 | 17 | This is a re-implementation of the SegLink text detection algorithm described in the paper [Detecting Oriented Text in Natural Images by Linking Segments, Baoguang Shi, Xiang Bai, Serge Belongie](https://arxiv.org/abs/1703.06520). 18 | 19 | 20 | 21 | # Installation&requirements 22 | 23 | 1. tensorflow-gpu 1.1.0 24 | 25 | 2. cv2. I'm using 2.4.9.1, but other versions below 3 should be OK too. If not, try switching to the same version as mine. 26 | 27 | 3. Download the project [pylib](https://github.com/dengdan/pylib) and add its `src` folder to your `PYTHONPATH` 28 | 29 | 30 | 31 | If any other requirements are unmet, just install them following the error messages. 32 | 33 | 34 | 35 | # Datasets 36 | 37 | 1. [SynthText](http://www.robots.ox.ac.uk/~vgg/data/scenetext/) 38 | 39 | 2. [ICDAR2015](http://rrc.cvc.uab.es/?ch=4&com=downloads) 40 | 41 | Convert them into the tfrecords format using the scripts in `datasets` if you want to train your own model. 42 | 43 | 44 | 45 | # Problems 46 | 47 | The convergence speed of my seglink is quite slow compared with that described in the paper. For example, the authors of the SegLink paper report that a good result can be obtained by training on SynthText for less than 100k iterations and on IC15-train for less than 10k iterations. However, using my implementation, I have to train on SynthText for about 200k iterations and for more than another 100k iterations on IC15-train to get a competitive result. 48 | 49 | Several reasons may contribute to the slow convergence of my model: 50 | 51 | 1. Batch size. I don't have 4 12G-Titans for training, as described in the paper. Instead, I trained my model on two 8G GeForce GTX 1080 cards or two Titans. 52 | 2. Learning rate. In the paper, both 10^-3 and 10^-4 were used, but I adopted a fixed learning rate of 10^-4. 53 | 3. Different initialization model. I used the pretrained VGG model from [SSD-caffe on coco](https://gist.github.com/weiliu89/2ed6e13bfd5b57cf81d6), because I thought it would be better than VGG trained on ImageNet. However, this assumption does not seem to hold. 54 | 4. There may be other differences that I am not aware of. 55 | 56 | 57 | 58 | # Models 59 | 60 | Two models, trained on SynthText and IC15-train, can be downloaded. 61 | 62 | 1. [seglink-384](https://pan.baidu.com/s/1slqaYux). Trained with an image size of 384x384, the same as in the paper. Its Hmean is comparable to the result reported in the paper. 63 | 64 | ![](http://fromwiz.com/share/resources/b3a92ec9-764c-470f-89a9-958c7cdeea1f/index_files/490589735.png) 65 | 66 | The `hust_orientedText` entry is the result from the paper. 67 | 68 | 2. [seglink-512](https://pan.baidu.com/s/1slqaYux). Trained with an image size of 512x512, about one point better than the 384x384 model. 69 | 70 | ![](http://fromwiz.com/share/resources/0f0c6085-322f-46bc-8535-9fed33620997/index_files/1569377909.png) 71 | 72 | 73 | 74 | They were trained: 75 | 76 | * on SynthText for about 200k iterations, and on IC15-train for 100k~200k iterations; 77 | 78 | * with learning_rate = 1e-4; 79 | 80 | * on two GPUs; 81 | 82 | * 384: GTX 1080, batch_size = 24; 512: Titan, batch_size = 20. 83 | 84 | **Both models perform best at `seg_conf_threshold=0.8` and `link_conf_threshold=0.5`**, which is another difference from the paper, where 0.9 and 0.7 are used respectively.
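These thresholds are plain module-level values in `config.py`, set through `config.init_config`. A minimal sketch of picking them up in your own Python code (assuming the repo root and `pylib/src` are on your `PYTHONPATH`; the image shape below is just an example):

```
import config

# Build the SegLink config with the best-performing thresholds reported above.
# init_config also builds a dummy SegLinkNet to infer the feature-map shapes and
# default anchors, so it needs the repo's `nets` package and TensorFlow installed.
config.init_config(image_shape=(512, 512),  # (height, width); use (384, 384) for seglink-384
                   batch_size=1,
                   seg_conf_threshold=0.8,
                   link_conf_threshold=0.5)
print('seg_conf_threshold=%.2f, link_conf_threshold=%.2f'
      % (config.seg_conf_threshold, config.link_conf_threshold))
```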
85 | 86 | # Test Your own images 87 | 88 | Use the script `test_seglink.py`; a shortcut has been created in `scripts/test.sh`: 89 | 90 | Go to the seglink root directory and execute the command: 91 | 92 | ``` 93 | 94 | ./scripts/test.sh GPU_ID CKPT_PATH DATASET_DIR 95 | 96 | ``` 97 | 98 | For example: 99 | 100 | ``` 101 | 102 | ./scripts/test.sh 0 ~/models/seglink/model.ckpt-217867 ~/dataset/ICDAR2015/Challenge4/ch4_training_images 103 | 104 | ``` 105 | 106 | I have only tested my models on IC15-test, but any other images can be used for testing: just put your images into a directory and pass that directory's path as `DATASET_DIR` in the command. 107 | 108 | A bunch of txt files and a zip file are created after testing. If you are using IC15-test, you can upload this zip file to the [icdar evaluation server](http://rrc.cvc.uab.es/) directly. 109 | 110 | 111 | 112 | The text files are placed in a subdirectory of the checkpoint directory; they contain the detected bounding boxes and can be visualized using the script `visualize_detection_result.py`. 113 | 114 | The command looks like: 115 | 116 | ``` 117 | 118 | python visualize_detection_result.py \ 119 | 120 | --image=where your images are put 121 | 122 | --det=the directory of the text files output by test_seglink.py 123 | 124 | --output=the output directory for detection results drawn on images 125 | 126 | ``` 127 | 128 | For example: 129 | 130 | ``` 131 | 132 | python visualize_detection_result.py \ 133 | 134 | --image=~/dataset/ICDAR2015/Challenge4/ch4_training_images/ \ 135 | 136 | --det=~/models/seglink/seglink_icdar2015_without_ignored/eval/icdar2015_train/model.ckpt-72885/seg_link_conf_th_0.900000_0.700000/txt \ 137 | --output=~/temp/no-use/seglink_result_512_train 138 | 139 | ``` 140 | 141 | ![](https://github.com/dengdan/seglink/blob/master/img_10_pred.jpg?raw=true) 142 | ![](https://github.com/dengdan/seglink/blob/master/img_31_pred.jpg?raw=true) 143 | 144 | # Training and evaluation 145 | 146 | Training requires the data to be converted into tfrecords first; the conversion scripts are in the `datasets` directory. The scripts `train_seglink.py` and `eval_seglink.py` are the training and evaluation scripts respectively. In particular, I have implemented an offline evaluation function, which calculates Recall/Precision/Hmean in the same way as the ICDAR test server and can be used for cross validation and grid search (a minimal sketch of the final Hmean computation is given after the Some Comments section below). The resulting scores may differ slightly from those of the test server, but it does not matter that much. 147 | Sorry for the incomplete documentation here. Read and modify the scripts if you want to train your own model. 148 | 149 | 150 | 151 | # Some Comments 152 | 153 | Thanks should be given to the authors of the SegLink paper: Baoguang Shi, Xiang Bai, and Serge Belongie. 154 | 155 | [EAST](https://arxiv.org/abs/1704.03155) is another text detection paper accepted at CVPR 2017, and its reported result is better than that of SegLink. However, when both use the same VGG16 backbone, their performance is quite similar. 156 | 157 | Contact me through GitHub issues if you have any problems.
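As mentioned in the Training and evaluation section above, the offline evaluation combines the matched detections into Recall/Precision/Hmean in the usual ICDAR way. A minimal standalone sketch of that final step in plain Python, using made-up scalar counts rather than the streaming TensorFlow metrics actually used in `eval_seglink.py`:

```
def precision_recall_hmean(num_true_positives, num_detections, num_gt_boxes):
    # Precision: correct detections / all detections.
    precision = num_true_positives / float(num_detections) if num_detections else 0.0
    # Recall: correct detections / all (non-ignored) ground-truth boxes.
    recall = num_true_positives / float(num_gt_boxes) if num_gt_boxes else 0.0
    # Hmean (F-measure) is the harmonic mean of precision and recall.
    if precision + recall == 0:
        return precision, recall, 0.0
    return precision, recall, 2.0 * precision * recall / (precision + recall)

# Example with made-up counts: 75 correct detections out of 100, 90 ground-truth boxes.
print(precision_recall_hmean(75, 100, 90))  # -> (0.75, 0.833..., 0.789...)
```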
158 | 159 | # Some Notes On Implementation Detail 160 | How the groundtruth is calculated, in Chinese: http://fromwiz.com/share/s/34GeEW1RFx7x2iIM0z1ZXVvc2yLl5t2fTkEg2ZVhJR2n50xg 161 | 162 | -------------------------------------------------------------------------------- /caffe_to_tensorflow.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import caffe 4 | from caffe.proto import caffe_pb2 5 | import util 6 | from nets import seglink_symbol 7 | 8 | caffemodel_path=util.io.get_absolute_path('~/models/ssd-pretrain/VGG_coco_SSD_512x512_iter_360000.caffemodel') 9 | class CaffeScope(): 10 | def __init__(self): 11 | print('Loading Caffe file:', caffemodel_path) 12 | caffemodel_params = caffe_pb2.NetParameter() 13 | caffemodel_str = open(caffemodel_path, 'rb').read() 14 | caffemodel_params.ParseFromString(caffemodel_str) 15 | caffe_layers = caffemodel_params.layer 16 | self.layers = [] 17 | self.counter = 0 18 | self.bgr_to_rgb = False 19 | for layer in caffe_layers: 20 | if layer.type == 'Convolution': 21 | self.layers.append(layer) 22 | 23 | def conv_weights_init(self): 24 | def _initializer(shape, dtype, partition_info=None): 25 | layer = self.layers[self.counter] 26 | w = np.array(layer.blobs[0].data) 27 | # Weights: reshape and transpose dimensions. 28 | w = np.reshape(w, layer.blobs[0].shape.dim) 29 | # w = np.transpose(w, (1, 0, 2, 3)) 30 | w = np.transpose(w, (2, 3, 1, 0)) 31 | if self.bgr_to_rgb and w.shape[2] == 3: 32 | print('Convert BGR to RGB in convolution layer:', layer.name) 33 | w[:, :, (0, 1, 2)] = w[:, :, (2, 1, 0)] 34 | self.bgr_to_rgb = False 35 | np.testing.assert_equal(w.shape, shape) 36 | print('Load weights from convolution layer:', layer.name, w.shape, shape) 37 | return tf.cast(w, dtype) 38 | 39 | return _initializer 40 | 41 | def conv_biases_init(self): 42 | def _initializer(shape, dtype, partition_info=None): 43 | layer = self.layers[self.counter] 44 | self.counter = self.counter + 1 45 | b = np.array(layer.blobs[1].data) 46 | 47 | print('Load biases from convolution layer:', layer.name, b.shape) 48 | return tf.cast(b, dtype) 49 | return _initializer 50 | 51 | 52 | caffe_scope = CaffeScope() 53 | 54 | # def get_seglink_model(): 55 | fake_image = tf.placeholder(dtype = tf.float32, shape = [1, 512, 1024, 3]) 56 | seglink_net = seglink_symbol.SegLinkNet(inputs = fake_image, weight_decay = 0.01, 57 | weights_initializer = caffe_scope.conv_weights_init(), biases_initializer = caffe_scope.conv_biases_init()) 58 | init_op = tf.global_variables_initializer() 59 | with tf.Session() as session: 60 | # Run the init operation. 61 | session.run(init_op) 62 | 63 | # Save model in checkpoint. 
64 | saver = tf.train.Saver(write_version=2) 65 | parent_dir = util.io.get_dir(caffemodel_path) 66 | filename = util.io.get_filename(caffemodel_path) 67 | parent_dir = util.io.mkdir(util.io.join_path(parent_dir, 'seglink')) 68 | filename = filename.replace('.caffemodel', '.ckpt') 69 | ckpt_path = util.io.join_path(parent_dir, filename) 70 | saver.save(session, ckpt_path, write_meta_graph=False) 71 | 72 | vars = tf.global_variables() 73 | layers_to_convert = ['conv1_1', 'conv1_2', 74 | 'conv2_1', 'conv2_2', 75 | 'conv3_1', 'conv3_2', 'conv3_3', 76 | 'conv4_1', 'conv4_2', 'conv4_3', 77 | 'conv5_1', 'conv5_2', 'conv5_3', 78 | 'fc6', 'fc7', 79 | 'conv6_1', 'conv6_2', 80 | 'conv7_1', 'conv7_2', 81 | 'conv8_1', 'conv8_2', 82 | 'conv9_1', 'conv9_2', 83 | # 'conv10_1', 'conv10_2' 84 | ] 85 | 86 | def check_var(name): 87 | tf_weights = None 88 | tf_biases = None 89 | 90 | for var in vars: 91 | if util.str.contains(str(var.name), name) and util.str.contains(str(var.name), 'weight') and not util.str.contains(str(var.name), 'seglink'): 92 | tf_weights = var 93 | 94 | if util.str.contains(str(var.name), name) and util.str.contains(str(var.name), 'bias') and not util.str.contains(str(var.name), 'seglink'): 95 | tf_biases = var 96 | 97 | caffe_weights = None 98 | caffe_biases = None 99 | for layer in caffe_scope.layers: 100 | if name == layer.name: 101 | caffe_weights = layer.blobs[0].data 102 | caffe_biases = layer.blobs[1].data 103 | 104 | np.testing.assert_almost_equal(actual = np.mean(caffe_weights), desired = np.mean(tf_weights.eval(session))) 105 | np.testing.assert_almost_equal(actual = np.mean(caffe_biases), desired = np.mean(tf_biases.eval(session))) 106 | 107 | # check all vgg and extra layer weights/biases have been converted in a right way. 108 | for name in layers_to_convert: 109 | check_var(name) 110 | 111 | # just have peek into the values of seglink layers. The weights should not be initialized to 0. Just have a look. 
112 | for var in vars: 113 | if util.str.contains(str(var.name), 'seglink'): 114 | print var.name, np.mean(var.eval(session)), np.std(var.eval(session)) 115 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from pprint import pprint 3 | import numpy as np 4 | from tensorflow.contrib.slim.python.slim.data import parallel_reader 5 | import tensorflow as tf 6 | slim = tf.contrib.slim 7 | import util 8 | 9 | 10 | global feat_shapes 11 | global image_shape 12 | 13 | 14 | global default_anchors 15 | global defalt_anchor_map 16 | global default_anchor_center_set 17 | global num_anchors 18 | global num_links 19 | 20 | 21 | global batch_size 22 | global batch_size_per_gpu 23 | global gpus 24 | global num_clones 25 | global clone_scopes 26 | 27 | 28 | global train_with_ignored 29 | global seg_loc_loss_weight 30 | global conf_cls_loss_weight 31 | 32 | global seg_conf_threshold 33 | global link_conf_threshold 34 | 35 | anchor_offset = 0.5 36 | anchor_scale_gamma = 1.5 37 | feat_layers = ['conv4_3','fc7', 'conv6_2', 'conv7_2', 'conv8_2', 'conv9_2'] 38 | # feat_norms = [20] + [-1] * len(feat_layers) 39 | max_height_ratio = 1.5 40 | # prior_scaling = [0.1, 0.2, 0.1, 0.2, 20.0] 41 | prior_scaling = [0.2, 0.5, 0.2, 0.5, 20.0] 42 | # prior_scaling = [1.0] * 5 43 | 44 | max_neg_pos_ratio = 3 45 | 46 | data_format = 'NHWC' 47 | def _set_image_shape(shape): 48 | global image_shape 49 | image_shape = shape 50 | 51 | def _set_feat_shapes(shapes): 52 | global feat_shapes 53 | feat_shapes = shapes 54 | 55 | def _set_batch_size(bz): 56 | global batch_size 57 | batch_size = bz 58 | 59 | def _set_det_th(seg_conf_th, link_conf_th): 60 | global seg_conf_threshold 61 | global link_conf_threshold 62 | 63 | seg_conf_threshold = seg_conf_th 64 | link_conf_threshold = link_conf_th 65 | 66 | def _set_loss_weight(seg_loc_loss_w, link_cls_loss_w): 67 | global seg_loc_loss_weight 68 | global link_cls_loss_weight 69 | seg_loc_loss_weight = seg_loc_loss_w 70 | link_cls_loss_weight = link_cls_loss_w 71 | 72 | def _set_train_with_ignored(train_with_ignored_): 73 | global train_with_ignored 74 | train_with_ignored = train_with_ignored_ 75 | 76 | def _build_anchor_map(): 77 | global default_anchor_map 78 | global default_anchor_center_set 79 | import collections 80 | default_anchor_map = collections.defaultdict(list) 81 | for anchor_idx, anchor in enumerate(default_anchors): 82 | default_anchor_map[(int(anchor[1]), int(anchor[0]))].append(anchor_idx) 83 | default_anchor_center_set = set(default_anchor_map.keys()) 84 | 85 | def init_config(image_shape, batch_size = 1, 86 | weight_decay = 0.0005, 87 | num_gpus = 1, 88 | train_with_ignored = False, 89 | seg_loc_loss_weight = 1.0, 90 | link_cls_loss_weight = 1.0, 91 | seg_conf_threshold = 0.5, 92 | link_conf_threshold = 0.5): 93 | 94 | _set_det_th(seg_conf_threshold, link_conf_threshold) 95 | _set_loss_weight(seg_loc_loss_weight, link_cls_loss_weight) 96 | _set_train_with_ignored(train_with_ignored) 97 | 98 | h, w = image_shape 99 | from nets import anchor_layer 100 | from nets import seglink_symbol 101 | fake_image = tf.ones((1, h, w, 3)) 102 | fake_net = seglink_symbol.SegLinkNet(inputs = fake_image, weight_decay = weight_decay) 103 | feat_shapes = fake_net.get_shapes(); 104 | 105 | # the placement of the following lines are extremely important 106 | _set_image_shape(image_shape) 107 | _set_feat_shapes(feat_shapes) 
108 | 109 | anchors, _ = anchor_layer.generate_anchors() 110 | global default_anchors 111 | default_anchors = anchors 112 | 113 | global num_anchors 114 | num_anchors = len(anchors) 115 | 116 | _build_anchor_map() 117 | 118 | global num_links 119 | num_links = num_anchors * 8 + (num_anchors - np.prod(feat_shapes[feat_layers[0]])) * 4 120 | 121 | #init batch size 122 | global gpus 123 | gpus = util.tf.get_available_gpus(num_gpus) 124 | 125 | global num_clones 126 | num_clones = len(gpus) 127 | 128 | global clone_scopes 129 | clone_scopes = ['clone_%d'%(idx) for idx in xrange(num_clones)] 130 | 131 | _set_batch_size(batch_size) 132 | 133 | global batch_size_per_gpu 134 | batch_size_per_gpu = batch_size / num_clones 135 | if batch_size_per_gpu < 1: 136 | raise ValueError('Invalid batch_size [=%d], resulting in 0 images per gpu.'%(batch_size)) 137 | 138 | 139 | def print_config(flags, dataset, save_dir = None, print_to_file = True): 140 | def do_print(stream=None): 141 | print('\n# =========================================================================== #', file=stream) 142 | print('# Training flags:', file=stream) 143 | print('# =========================================================================== #', file=stream) 144 | pprint(flags.__flags, stream=stream) 145 | 146 | print('\n# =========================================================================== #', file=stream) 147 | print('# seglink net parameters:', file=stream) 148 | print('# =========================================================================== #', file=stream) 149 | vars = globals() 150 | for key in vars: 151 | var = vars[key] 152 | if util.dtype.is_number(var) or util.dtype.is_str(var) or util.dtype.is_list(var) or util.dtype.is_tuple(var): 153 | pprint('%s=%s'%(key, str(var)), stream = stream) 154 | 155 | print('\n# =========================================================================== #', file=stream) 156 | print('# Training | Evaluation dataset files:', file=stream) 157 | print('# =========================================================================== #', file=stream) 158 | data_files = parallel_reader.get_data_files(dataset.data_sources) 159 | pprint(sorted(data_files), stream=stream) 160 | print('', file=stream) 161 | do_print(None) 162 | 163 | if print_to_file: 164 | # Save to a text file as well. 
165 | if save_dir is None: 166 | save_dir = flags.train_dir 167 | 168 | util.io.mkdir(save_dir) 169 | path = util.io.join_path(save_dir, 'training_config.txt') 170 | with open(path, "a") as out: 171 | do_print(out) 172 | 173 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /datasets/dataset_factory.py: -------------------------------------------------------------------------------- 1 | """A factory-pattern class which returns classification image/label pairs.""" 2 | from datasets import dataset_utils 3 | 4 | class DatasetConfig(): 5 | def __init__(self, file_pattern, split_sizes): 6 | self.file_pattern = file_pattern 7 | self.split_sizes = split_sizes 8 | 9 | icdar2013 = DatasetConfig( 10 | file_pattern = 'icdar2013_%s.tfrecord', 11 | split_sizes = { 12 | 'train': 229, 13 | 'test': 233 14 | } 15 | ) 16 | icdar2015 = DatasetConfig( 17 | file_pattern = 'icdar2015_%s.tfrecord', 18 | split_sizes = { 19 | 'train': 1000, 20 | 'test': 500 21 | } 22 | ) 23 | 24 | scut = DatasetConfig( 25 | file_pattern = 'scut_%s.tfrecord', 26 | split_sizes = { 27 | 'train': 1715 28 | } 29 | ) 30 | 31 | synthtext = DatasetConfig( 32 | file_pattern = 'SynthText_*.tfrecord', 33 | split_sizes = { 34 | 'train': 858750 35 | } 36 | ) 37 | 38 | datasets_map = { 39 | 'icdar2013':icdar2013, 40 | 'icdar2015':icdar2015, 41 | 'scut':scut, 42 | 'synthtext':synthtext 43 | } 44 | 45 | 46 | def get_dataset(dataset_name, split_name, dataset_dir, reader=None): 47 | """Given a dataset dataset_name and a split_name returns a Dataset. 48 | Args: 49 | dataset_name: String, the dataset_name of the dataset. 50 | split_name: A train/test split dataset_name. 51 | dataset_dir: The directory where the dataset files are stored. 52 | reader: The subclass of tf.ReaderBase. If left as `None`, then the default 53 | reader defined by each dataset is used. 54 | Returns: 55 | A `Dataset` class. 56 | Raises: 57 | ValueError: If the dataset `dataset_name` is unknown. 58 | """ 59 | if dataset_name not in datasets_map: 60 | raise ValueError('Name of dataset unknown %s' % dataset_name) 61 | dataset_config = datasets_map[dataset_name]; 62 | file_pattern = dataset_config.file_pattern 63 | num_samples = dataset_config.split_sizes[split_name] 64 | return dataset_utils.get_split(split_name, dataset_dir,file_pattern, num_samples, reader) 65 | -------------------------------------------------------------------------------- /datasets/dataset_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Contains utilities for downloading and converting datasets.""" 16 | 17 | import tensorflow as tf 18 | import numpy as np 19 | slim = tf.contrib.slim 20 | 21 | import util 22 | 23 | 24 | 25 | def int64_feature(value): 26 | """Wrapper for inserting int64 features into Example proto. 27 | """ 28 | if not isinstance(value, list): 29 | value = [value] 30 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 31 | 32 | 33 | def float_feature(value): 34 | """Wrapper for inserting float features into Example proto. 35 | """ 36 | if not isinstance(value, list): 37 | value = [value] 38 | return tf.train.Feature(float_list=tf.train.FloatList(value=value)) 39 | 40 | 41 | def bytes_feature(value): 42 | """Wrapper for inserting bytes features into Example proto. 43 | """ 44 | if not isinstance(value, list): 45 | value = [value] 46 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) 47 | 48 | 49 | def image_to_tfexample(image_data, image_format, height, width, class_id): 50 | return tf.train.Example(features=tf.train.Features(feature={ 51 | 'image/encoded': bytes_feature(image_data), 52 | 'image/format': bytes_feature(image_format), 53 | 'image/class/label': int64_feature(class_id), 54 | 'image/height': int64_feature(height), 55 | 'image/width': int64_feature(width), 56 | })) 57 | 58 | 59 | def convert_to_example(image_data, filename, labels, ignored, labels_text, bboxes, oriented_bboxes, shape): 60 | """Build an Example proto for an image example. 61 | Args: 62 | image_data: string, JPEG encoding of RGB image; 63 | labels: list of integers, identifier for the ground truth; 64 | labels_text: list of strings, human-readable labels; 65 | oriented_bboxes: list of bounding oriented boxes; each box is a list of floats in [0, 1]; 66 | specifying [x1, y1, x2, y2, x3, y3, x4, y4] 67 | bboxes: list of bbox in rectangle, [xmin, ymin, xmax, ymax] 68 | Returns: 69 | Example proto 70 | """ 71 | 72 | image_format = b'JPEG' 73 | oriented_bboxes = np.asarray(oriented_bboxes) 74 | bboxes = np.asarray(bboxes) 75 | example = tf.train.Example(features=tf.train.Features(feature={ 76 | 'image/shape': int64_feature(list(shape)), 77 | 'image/object/bbox/xmin': float_feature(list(bboxes[:, 0])), 78 | 'image/object/bbox/ymin': float_feature(list(bboxes[:, 1])), 79 | 'image/object/bbox/xmax': float_feature(list(bboxes[:, 2])), 80 | 'image/object/bbox/ymax': float_feature(list(bboxes[:, 3])), 81 | 'image/object/bbox/x1': float_feature(list(oriented_bboxes[:, 0])), 82 | 'image/object/bbox/y1': float_feature(list(oriented_bboxes[:, 1])), 83 | 'image/object/bbox/x2': float_feature(list(oriented_bboxes[:, 2])), 84 | 'image/object/bbox/y2': float_feature(list(oriented_bboxes[:, 3])), 85 | 'image/object/bbox/x3': float_feature(list(oriented_bboxes[:, 4])), 86 | 'image/object/bbox/y3': float_feature(list(oriented_bboxes[:, 5])), 87 | 'image/object/bbox/x4': float_feature(list(oriented_bboxes[:, 6])), 88 | 'image/object/bbox/y4': float_feature(list(oriented_bboxes[:, 7])), 89 | 'image/object/bbox/label': int64_feature(labels), 90 | 'image/object/bbox/label_text': bytes_feature(labels_text), 91 | 'image/object/bbox/ignored': int64_feature(ignored), 92 | 'image/format': bytes_feature(image_format), 93 | 'image/filename': bytes_feature(filename), 94 | 'image/encoded': bytes_feature(image_data)})) 95 | return example 96 | 97 | 98 | 99 | def get_split(split_name, dataset_dir, file_pattern, num_samples, reader=None): 100 
| dataset_dir = util.io.get_absolute_path(dataset_dir) 101 | 102 | if util.str.contains(file_pattern, '%'): 103 | file_pattern = util.io.join_path(dataset_dir, file_pattern % split_name) 104 | else: 105 | file_pattern = util.io.join_path(dataset_dir, file_pattern) 106 | # Allowing None in the signature so that dataset_factory can use the default. 107 | if reader is None: 108 | reader = tf.TFRecordReader 109 | keys_to_features = { 110 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 111 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'), 112 | 'image/filename': tf.FixedLenFeature((), tf.string, default_value=''), 113 | 'image/shape': tf.FixedLenFeature([3], tf.int64), 114 | 'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32), 115 | 'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32), 116 | 'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32), 117 | 'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32), 118 | 'image/object/bbox/x1': tf.VarLenFeature(dtype=tf.float32), 119 | 'image/object/bbox/x2': tf.VarLenFeature(dtype=tf.float32), 120 | 'image/object/bbox/x3': tf.VarLenFeature(dtype=tf.float32), 121 | 'image/object/bbox/x4': tf.VarLenFeature(dtype=tf.float32), 122 | 'image/object/bbox/y1': tf.VarLenFeature(dtype=tf.float32), 123 | 'image/object/bbox/y2': tf.VarLenFeature(dtype=tf.float32), 124 | 'image/object/bbox/y3': tf.VarLenFeature(dtype=tf.float32), 125 | 'image/object/bbox/y4': tf.VarLenFeature(dtype=tf.float32), 126 | 'image/object/bbox/ignored': tf.VarLenFeature(dtype=tf.int64), 127 | 'image/object/bbox/label': tf.VarLenFeature(dtype=tf.int64), 128 | } 129 | items_to_handlers = { 130 | 'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'), 131 | 'shape': slim.tfexample_decoder.Tensor('image/shape'), 132 | 'filename': slim.tfexample_decoder.Tensor('image/filename'), 133 | 'object/bbox': slim.tfexample_decoder.BoundingBox( 134 | ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'), 135 | 'object/oriented_bbox/x1': slim.tfexample_decoder.Tensor('image/object/bbox/x1'), 136 | 'object/oriented_bbox/x2': slim.tfexample_decoder.Tensor('image/object/bbox/x2'), 137 | 'object/oriented_bbox/x3': slim.tfexample_decoder.Tensor('image/object/bbox/x3'), 138 | 'object/oriented_bbox/x4': slim.tfexample_decoder.Tensor('image/object/bbox/x4'), 139 | 'object/oriented_bbox/y1': slim.tfexample_decoder.Tensor('image/object/bbox/y1'), 140 | 'object/oriented_bbox/y2': slim.tfexample_decoder.Tensor('image/object/bbox/y2'), 141 | 'object/oriented_bbox/y3': slim.tfexample_decoder.Tensor('image/object/bbox/y3'), 142 | 'object/oriented_bbox/y4': slim.tfexample_decoder.Tensor('image/object/bbox/y4'), 143 | 'object/label': slim.tfexample_decoder.Tensor('image/object/bbox/label'), 144 | 'object/ignored': slim.tfexample_decoder.Tensor('image/object/bbox/ignored') 145 | } 146 | decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers) 147 | 148 | labels_to_names = {0:'background', 1:'text'} 149 | items_to_descriptions = { 150 | 'image': 'A color image of varying height and width.', 151 | 'shape': 'Shape of the image', 152 | 'object/bbox': 'A list of bounding boxes, one per each object.', 153 | 'object/label': 'A list of labels, one per each object.', 154 | } 155 | 156 | return slim.dataset.Dataset( 157 | data_sources=file_pattern, 158 | reader=reader, 159 | decoder=decoder, 160 | num_samples=num_samples, 161 | items_to_descriptions=items_to_descriptions, 162 | num_classes=2, 163 | 
labels_to_names=labels_to_names) 164 | -------------------------------------------------------------------------------- /datasets/icdar2013_to_tfrecords.py: -------------------------------------------------------------------------------- 1 | #encoding=utf-8 2 | import numpy as np; 3 | import tensorflow as tf 4 | import util 5 | from dataset_utils import int64_feature, float_feature, bytes_feature, convert_to_example 6 | 7 | 8 | def cvt_to_tfrecords(output_path , data_path, gt_path): 9 | image_names = util.io.ls(data_path, '.jpg')#[0:10]; 10 | print "%d images found in %s"%(len(image_names), data_path); 11 | 12 | with tf.python_io.TFRecordWriter(output_path) as tfrecord_writer: 13 | for idx, image_name in enumerate(image_names): 14 | oriented_bboxes = []; 15 | bboxes = [] 16 | labels = []; 17 | labels_text = []; 18 | ignored = [] 19 | path = util.io.join_path(data_path, image_name); 20 | print "\tconverting image: %d/%d %s"%(idx, len(image_names), image_name); 21 | image_data = tf.gfile.FastGFile(path, 'r').read() 22 | 23 | image = util.img.imread(path, rgb = True); 24 | shape = image.shape 25 | h, w = shape[0:2]; 26 | h *= 1.0; 27 | w *= 1.0; 28 | image_name = util.str.split(image_name, '.')[0]; 29 | gt_name = 'gt_' + image_name + '.txt'; 30 | gt_filepath = util.io.join_path(gt_path, gt_name); 31 | lines = util.io.read_lines(gt_filepath); 32 | 33 | for line in lines: 34 | gt = util.str.remove_all(line, ',') 35 | gt = util.str.split(gt, ' '); 36 | bbox = [int(gt[i]) for i in range(4)]; 37 | xmin, ymin, xmax, ymax = np.asarray(bbox) / [w, h, w, h]; 38 | oriented_bboxes.append([xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax]); 39 | bboxes.append([xmin, ymin, xmax, ymax]) 40 | ignored.append(0); 41 | labels_text.append(line.split('"')[1]); 42 | labels.append(1); 43 | example = convert_to_example(image_data, image_name, labels, ignored, labels_text, bboxes, oriented_bboxes, shape) 44 | tfrecord_writer.write(example.SerializeToString()) 45 | 46 | if __name__ == "__main__": 47 | root_dir = util.io.get_absolute_path('~/dataset/ICDAR2015/Challenge2.Task123/') 48 | training_data_dir = util.io.join_path(root_dir, 'Challenge2_Training_Task12_Images') 49 | training_gt_dir = util.io.join_path(root_dir,'Challenge2_Training_Task1_GT') 50 | test_data_dir = util.io.join_path(root_dir,'Challenge2_Test_Task12_Images') 51 | test_gt_dir = util.io.join_path(root_dir,'Challenge2_Test_Task1_GT') 52 | 53 | output_dir = util.io.get_absolute_path('~/dataset/SSD-tf/ICDAR/') 54 | util.io.mkdir(output_dir); 55 | cvt_to_tfrecords(output_path = util.io.join_path(output_dir, 'icdar2013_train.tfrecord'), data_path = training_data_dir, gt_path = training_gt_dir) 56 | cvt_to_tfrecords(output_path = util.io.join_path(output_dir, 'icdar2013_test.tfrecord'), data_path = test_data_dir, gt_path = test_gt_dir) 57 | -------------------------------------------------------------------------------- /datasets/icdar2015_to_tfrecords.py: -------------------------------------------------------------------------------- 1 | #encoding=utf-8 2 | import numpy as np; 3 | import tensorflow as tf 4 | import util 5 | from dataset_utils import int64_feature, float_feature, bytes_feature, convert_to_example 6 | 7 | 8 | def cvt_to_tfrecords(output_path , data_path, gt_path): 9 | image_names = util.io.ls(data_path, '.jpg')#[0:10]; 10 | print "%d images found in %s"%(len(image_names), data_path); 11 | 12 | with tf.python_io.TFRecordWriter(output_path) as tfrecord_writer: 13 | for idx, image_name in enumerate(image_names): 14 | oriented_bboxes = []; 15 | 
bboxes = [] 16 | labels = []; 17 | labels_text = []; 18 | ignored = [] 19 | path = util.io.join_path(data_path, image_name); 20 | print "\tconverting image: %d/%d %s"%(idx, len(image_names), image_name); 21 | image_data = tf.gfile.FastGFile(path, 'r').read() 22 | 23 | image = util.img.imread(path, rgb = True); 24 | shape = image.shape 25 | h, w = shape[0:2]; 26 | h *= 1.0; 27 | w *= 1.0; 28 | image_name = util.str.split(image_name, '.')[0]; 29 | gt_name = 'gt_' + image_name + '.txt'; 30 | gt_filepath = util.io.join_path(gt_path, gt_name); 31 | lines = util.io.read_lines(gt_filepath); 32 | 33 | for line in lines: 34 | line = util.str.remove_all(line, '\xef\xbb\xbf') 35 | gt = util.str.split(line, ','); 36 | oriented_box = [int(gt[i]) for i in range(8)]; 37 | oriented_box = np.asarray(oriented_box) / ([w, h] * 4); 38 | oriented_bboxes.append(oriented_box); 39 | 40 | xs = oriented_box.reshape(4, 2)[:, 0] 41 | ys = oriented_box.reshape(4, 2)[:, 1] 42 | xmin = xs.min() 43 | xmax = xs.max() 44 | ymin = ys.min() 45 | ymax = ys.max() 46 | bboxes.append([xmin, ymin, xmax, ymax]) 47 | ignored.append(util.str.contains(gt[-1], '###')); 48 | 49 | # might be wrong here, but it doesn't matter because the label is not going to be used in detection 50 | labels_text.append(gt[-1]); 51 | labels.append(1); 52 | example = convert_to_example(image_data, image_name, labels, ignored, labels_text, bboxes, oriented_bboxes, shape) 53 | tfrecord_writer.write(example.SerializeToString()) 54 | 55 | if __name__ == "__main__": 56 | root_dir = util.io.get_absolute_path('~/dataset/ICDAR2015/Challenge4/') 57 | output_dir = util.io.get_absolute_path('~/dataset/SSD-tf/ICDAR/') 58 | util.io.mkdir(output_dir); 59 | 60 | training_data_dir = util.io.join_path(root_dir, 'ch4_training_images') 61 | training_gt_dir = util.io.join_path(root_dir,'ch4_training_localization_transcription_gt') 62 | cvt_to_tfrecords(output_path = util.io.join_path(output_dir, 'icdar2015_train.tfrecord'), data_path = training_data_dir, gt_path = training_gt_dir) 63 | 64 | # test_data_dir = util.io.join_path(root_dir, 'ch4_test_images') 65 | # test_gt_dir = util.io.join_path(root_dir,'ch4_test_localization_transcription_gt') 66 | # cvt_to_tfrecords(output_path = util.io.join_path(output_dir, 'icdar2015_test.tfrecord'), data_path = test_data_dir, gt_path = test_gt_dir) 67 | -------------------------------------------------------------------------------- /datasets/scut_to_tfrecords.py: -------------------------------------------------------------------------------- 1 | #encoding=utf-8 2 | import numpy as np; 3 | import tensorflow as tf 4 | import util; 5 | from dataset_utils import int64_feature, float_feature, bytes_feature, convert_to_example 6 | 7 | 8 | def cvt_to_tfrecords(output_path , data_path, gt_path): 9 | image_names = util.io.ls(data_path, '.jpg')#[0:10]; 10 | print "%d images found in %s"%(len(image_names), data_path); 11 | with tf.python_io.TFRecordWriter(output_path) as tfrecord_writer: 12 | for idx, image_name in enumerate(image_names): 13 | bboxes = []; 14 | oriented_bboxes = [] 15 | labels = []; 16 | labels_text = []; 17 | ignored = [] 18 | path = util.io.join_path(data_path, image_name); 19 | if not util.img.is_valid_jpg(path): 20 | continue 21 | image = util.img.imread(path) 22 | print "\tconverting image:%s, %d/%d"%(image_name, idx, len(image_names)); 23 | image_data = tf.gfile.FastGFile(path, 'r').read() 24 | #image = util.img.imread(path, rgb = True); 25 | shape = image.shape 26 | h, w = shape[0:2]; 27 | h *= 1.0; 28 | w *= 1.0; 29 | 
image_name = util.str.split(image_name, '.')[0]; 30 | gt_name = image_name + '.txt'; 31 | gt_filepath = util.io.join_path(gt_path, gt_name); 32 | lines = util.io.read_lines(gt_filepath); 33 | for line in lines: 34 | spt = line.split(',') 35 | locs = spt[0: -1] 36 | xmin, ymin, bw, bh = [int(v) for v in locs] 37 | xmax = xmin + bw - 1 38 | ymax = ymin + bh - 1 39 | xmin, ymin, xmax, ymax = xmin / w, ymin/ h, xmax / w, ymax / h 40 | 41 | bboxes.append([xmin, ymin, xmax, ymax]); 42 | oriented_bboxes.append([xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax]) 43 | 44 | labels_text.append(str(spt[-1])); 45 | labels.append(1); 46 | ignored.append(0) 47 | example = convert_to_example(image_data, image_name, labels, ignored, labels_text, bboxes, oriented_bboxes, shape) 48 | tfrecord_writer.write(example.SerializeToString()) 49 | 50 | if __name__ == "__main__": 51 | root_dir = util.io.get_absolute_path('~/dataset/SCUT/SCUT_FORU_DB_Release/English2k/') 52 | training_data_dir = util.io.join_path(root_dir, 'word_img') 53 | training_gt_dir = util.io.join_path(root_dir,'word_annotation') 54 | output_dir = util.io.get_absolute_path('~/dataset/SSD-tf/SCUT/') 55 | util.io.mkdir(output_dir); 56 | cvt_to_tfrecords(output_path = util.io.join_path(output_dir, 'scut_train.tfrecord'), data_path = training_data_dir, gt_path = training_gt_dir) 57 | 58 | -------------------------------------------------------------------------------- /datasets/synthtext_to_tfrecords.py: -------------------------------------------------------------------------------- 1 | #encoding=utf-8 2 | import numpy as np; 3 | import tensorflow as tf 4 | import util; 5 | from dataset_utils import int64_feature, float_feature, bytes_feature, convert_to_example 6 | 7 | # encoding = utf-8 8 | import numpy as np 9 | import time 10 | 11 | import util 12 | 13 | 14 | class SynthTextDataFetcher(): 15 | def __init__(self, mat_path, root_path): 16 | self.mat_path = mat_path 17 | self.root_path = root_path 18 | self._load_mat() 19 | 20 | @util.dec.print_calling 21 | def _load_mat(self): 22 | data = util.io.load_mat(self.mat_path) 23 | self.image_paths = data['imnames'][0] 24 | self.image_bbox = data['wordBB'][0] 25 | self.txts = data['txt'][0] 26 | self.num_images = len(self.image_paths) 27 | 28 | def get_image_path(self, idx): 29 | image_path = util.io.join_path(self.root_path, self.image_paths[idx][0]) 30 | return image_path 31 | 32 | def get_num_words(self, idx): 33 | try: 34 | return np.shape(self.image_bbox[idx])[2] 35 | except: # error caused by dataset 36 | return 1 37 | 38 | 39 | def get_word_bbox(self, img_idx, word_idx): 40 | boxes = self.image_bbox[img_idx] 41 | if len(np.shape(boxes)) ==2: # error caused by dataset 42 | boxes = np.reshape(boxes, (2, 4, 1)) 43 | 44 | xys = boxes[:,:, word_idx] 45 | assert(np.shape(xys) ==(2, 4)) 46 | return np.float32(xys) 47 | 48 | def normalize_bbox(self, xys, width, height): 49 | xs = xys[0, :] 50 | ys = xys[1, :] 51 | 52 | min_x = min(xs) 53 | min_y = min(ys) 54 | max_x = max(xs) 55 | max_y = max(ys) 56 | 57 | # bound them in the valid range 58 | min_x = max(0, min_x) 59 | min_y = max(0, min_y) 60 | max_x = min(width, max_x) 61 | max_y = min(height, max_y) 62 | 63 | # check the w, h and area of the rect 64 | w = max_x - min_x 65 | h = max_y - min_y 66 | is_valid = True 67 | 68 | if w < 10 or h < 10: 69 | is_valid = False 70 | 71 | if w * h < 100: 72 | is_valid = False 73 | 74 | xys[0, :] = xys[0, :] / width 75 | xys[1, :] = xys[1, :] / height 76 | 77 | return is_valid, min_x / width, min_y /height, max_x / 
width, max_y / height, xys 78 | 79 | def get_txt(self, image_idx, word_idx): 80 | txts = self.txts[image_idx]; 81 | clean_txts = [] 82 | for txt in txts: 83 | clean_txts += txt.split() 84 | return str(clean_txts[word_idx]) 85 | 86 | 87 | def fetch_record(self, image_idx): 88 | image_path = self.get_image_path(image_idx) 89 | if not (util.io.exists(image_path)): 90 | return None; 91 | img = util.img.imread(image_path) 92 | h, w = img.shape[0:-1]; 93 | num_words = self.get_num_words(image_idx) 94 | rect_bboxes = [] 95 | full_bboxes = [] 96 | txts = [] 97 | for word_idx in xrange(num_words): 98 | xys = self.get_word_bbox(image_idx, word_idx); 99 | is_valid, min_x, min_y, max_x, max_y, xys = self.normalize_bbox(xys, width = w, height = h) 100 | if not is_valid: 101 | continue; 102 | rect_bboxes.append([min_x, min_y, max_x, max_y]) 103 | xys = np.reshape(np.transpose(xys), -1) 104 | full_bboxes.append(xys); 105 | txt = self.get_txt(image_idx, word_idx); 106 | txts.append(txt); 107 | if len(rect_bboxes) == 0: 108 | return None; 109 | 110 | return image_path, img, txts, rect_bboxes, full_bboxes 111 | 112 | 113 | 114 | def cvt_to_tfrecords(output_path , data_path, gt_path, records_per_file = 50000): 115 | 116 | fetcher = SynthTextDataFetcher(root_path = data_path, mat_path = gt_path) 117 | fid = 0 118 | image_idx = -1 119 | while image_idx < fetcher.num_images: 120 | with tf.python_io.TFRecordWriter(output_path%(fid)) as tfrecord_writer: 121 | record_count = 0; 122 | while record_count != records_per_file: 123 | image_idx += 1; 124 | if image_idx >= fetcher.num_images: 125 | break; 126 | print "loading image %d/%d"%(image_idx + 1, fetcher.num_images) 127 | record = fetcher.fetch_record(image_idx); 128 | if record is None: 129 | print '\nimage %d does not exist'%(image_idx + 1) 130 | continue; 131 | 132 | image_path, image, txts, rect_bboxes, oriented_bboxes = record; 133 | labels = len(rect_bboxes) * [1]; 134 | ignored = len(rect_bboxes) * [0]; 135 | image_data = tf.gfile.FastGFile(image_path, 'r').read() 136 | shape = image.shape 137 | image_name = str(util.io.get_filename(image_path).split('.')[0]) 138 | example = convert_to_example(image_data, image_name, labels, ignored, txts, rect_bboxes, oriented_bboxes, shape) 139 | tfrecord_writer.write(example.SerializeToString()) 140 | record_count += 1; 141 | 142 | fid += 1; 143 | 144 | if __name__ == "__main__": 145 | mat_path = util.io.get_absolute_path('~/dataset/SynthText/gt.mat') 146 | root_path = util.io.get_absolute_path('~/dataset/SynthText/') 147 | output_dir = util.io.get_absolute_path('~/dataset/SSD-tf/SynthText/') 148 | util.io.mkdir(output_dir); 149 | cvt_to_tfrecords(output_path = util.io.join_path(output_dir, 'SynthText_%d.tfrecord'), data_path = root_path, gt_path = mat_path) 150 | -------------------------------------------------------------------------------- /eval_seglink.py: -------------------------------------------------------------------------------- 1 | #encoding = utf-8 2 | 3 | import numpy as np 4 | import math 5 | import tensorflow as tf 6 | from tensorflow.python.ops import control_flow_ops 7 | from tensorflow.contrib.training.python.training import evaluation 8 | from datasets import dataset_factory 9 | from preprocessing import ssd_vgg_preprocessing 10 | from tf_extended import seglink, metrics as tfe_metrics, bboxes as tfe_bboxes 11 | import util 12 | import cv2 13 | from nets import seglink_symbol, anchor_layer 14 | 15 | 16 | slim = tf.contrib.slim 17 | import config 18 | # 
=========================================================================== # 19 | # model threshold parameters 20 | # =========================================================================== # 21 | tf.app.flags.DEFINE_string('train_with_ignored', False, 22 | 'whether to use ignored bbox (in ic15) in training.') 23 | tf.app.flags.DEFINE_boolean('do_grid_search', False, 24 | 'whether to do grid search to find a best combinations of \ 25 | seg_conf_threshold and link_conf_threshold.') 26 | tf.app.flags.DEFINE_float('seg_loc_loss_weight', 1.0, 27 | 'the loss weight of segment localization') 28 | tf.app.flags.DEFINE_float('link_cls_loss_weight', 1.0, 29 | 'the loss weight of linkage classification loss') 30 | 31 | tf.app.flags.DEFINE_float('seg_conf_threshold', 0.9, 32 | 'the threshold on the confidence of segment') 33 | tf.app.flags.DEFINE_float('link_conf_threshold', 0.7, 34 | 'the threshold on the confidence of linkage') 35 | 36 | 37 | # =========================================================================== # 38 | # Checkpoint and running Flags 39 | # =========================================================================== # 40 | tf.app.flags.DEFINE_string('checkpoint_path', None, 41 | 'the path of checkpoint to be evaluated. \ 42 | If it is a directory containing many checkpoints, \ 43 | the lastest will be evaluated.') 44 | tf.app.flags.DEFINE_float('gpu_memory_fraction', 0.1, 45 | 'the gpu memory fraction to be used. \ 46 | If less than 0, allow_growth = True is used.') 47 | tf.app.flags.DEFINE_bool('using_moving_average', False, 48 | 'Whether to use ExponentionalMovingAverage') 49 | tf.app.flags.DEFINE_float('moving_average_decay', 0.9999, 50 | 'The decay rate of ExponentionalMovingAverage') 51 | 52 | # =========================================================================== # 53 | # I/O and preprocessing Flags. 54 | # =========================================================================== # 55 | tf.app.flags.DEFINE_integer( 56 | 'num_readers', 4, 57 | 'The number of parallel readers that read data from the dataset.') 58 | tf.app.flags.DEFINE_integer( 59 | 'num_preprocessing_threads', 1, 60 | 'The number of threads used to create the batches.') 61 | 62 | # =========================================================================== # 63 | # Dataset Flags. 
64 | # =========================================================================== # 65 | tf.app.flags.DEFINE_string( 66 | 'dataset_name', None, 'The name of the dataset to load.') 67 | tf.app.flags.DEFINE_string( 68 | 'dataset_split_name', 'train', 'The name of the train/test split.') 69 | tf.app.flags.DEFINE_string( 70 | 'dataset_dir', None, 'The directory where the dataset files are stored.') 71 | tf.app.flags.DEFINE_string( 72 | 'model_name', 'seglink_vgg', 'The name of the architecture to train.') 73 | tf.app.flags.DEFINE_integer('eval_image_width', 1280, 'Train image size') 74 | tf.app.flags.DEFINE_integer('eval_image_height', 768, 'Train image size') 75 | 76 | 77 | FLAGS = tf.app.flags.FLAGS 78 | 79 | def config_initialization(): 80 | # image shape and feature layers shape inference 81 | image_shape = (FLAGS.eval_image_height, FLAGS.eval_image_width) 82 | 83 | if not FLAGS.dataset_dir: 84 | raise ValueError('You must supply the dataset directory with --dataset_dir') 85 | tf.logging.set_verbosity(tf.logging.DEBUG) 86 | 87 | config.init_config(image_shape, 88 | batch_size = 1, 89 | seg_conf_threshold = FLAGS.seg_conf_threshold, 90 | link_conf_threshold = FLAGS.link_conf_threshold, 91 | train_with_ignored = FLAGS.train_with_ignored, 92 | seg_loc_loss_weight = FLAGS.seg_loc_loss_weight, 93 | link_cls_loss_weight = FLAGS.link_cls_loss_weight, 94 | ) 95 | 96 | 97 | util.proc.set_proc_name('eval_' + FLAGS.model_name + '_' + FLAGS.dataset_name ) 98 | dataset = dataset_factory.get_dataset(FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir) 99 | config.print_config(FLAGS, dataset, print_to_file = False) 100 | 101 | return dataset 102 | 103 | def read_dataset(dataset): 104 | with tf.name_scope(FLAGS.dataset_name +'_' + FLAGS.dataset_split_name + '_data_provider'): 105 | provider = slim.dataset_data_provider.DatasetDataProvider( 106 | dataset, 107 | num_readers=FLAGS.num_readers, 108 | shuffle=False) 109 | 110 | [image, shape, filename, gignored, gbboxes, x1, x2, x3, x4, y1, y2, y3, y4] = provider.get([ 111 | 'image', 'shape', 'filename', 112 | 'object/ignored', 113 | 'object/bbox', 114 | 'object/oriented_bbox/x1', 115 | 'object/oriented_bbox/x2', 116 | 'object/oriented_bbox/x3', 117 | 'object/oriented_bbox/x4', 118 | 'object/oriented_bbox/y1', 119 | 'object/oriented_bbox/y2', 120 | 'object/oriented_bbox/y3', 121 | 'object/oriented_bbox/y4' 122 | ]) 123 | gxs = tf.transpose(tf.stack([x1, x2, x3, x4])) #shape = (N, 4) 124 | gys = tf.transpose(tf.stack([y1, y2, y3, y4])) 125 | image = tf.identity(image, 'input_image') 126 | 127 | # Pre-processing image, labels and bboxes. 
128 | image, gignored, gbboxes, gxs, gys = ssd_vgg_preprocessing.preprocess_image( 129 | image, gignored, gbboxes, gxs, gys, 130 | out_shape = config.image_shape, 131 | data_format = config.data_format, 132 | is_training = False) 133 | image = tf.identity(image, 'processed_image') 134 | 135 | # calculate ground truth 136 | seg_label, seg_loc, link_gt = seglink.tf_get_all_seglink_gt(gxs, gys, gignored) 137 | 138 | return image, seg_label, seg_loc, link_gt, filename, shape, gignored, gxs, gys 139 | 140 | def eval(dataset): 141 | dict_metrics = {} 142 | checkpoint_dir = util.io.get_dir(FLAGS.checkpoint_path) 143 | logdir = util.io.join_path(checkpoint_dir, 144 | 'eval', 145 | "%s_%s"%(FLAGS.dataset_name, FLAGS.dataset_split_name)) 146 | 147 | global_step = slim.get_or_create_global_step() 148 | with tf.name_scope('evaluation_%dx%d'%(FLAGS.eval_image_height, FLAGS.eval_image_width)): 149 | with tf.variable_scope(tf.get_variable_scope(), reuse = True):# the variables has been created in config.init_config 150 | # get input tensor 151 | image, seg_label, seg_loc, link_gt, filename, shape, gignored, gxs, gys = read_dataset(dataset) 152 | # expand dim if needed 153 | b_image = tf.expand_dims(image, axis = 0); 154 | b_seg_label = tf.expand_dims(seg_label, axis = 0) 155 | b_seg_loc = tf.expand_dims(seg_loc, axis = 0) 156 | b_link_gt = tf.expand_dims(link_gt, axis = 0) 157 | b_shape = tf.expand_dims(shape, axis = 0) 158 | 159 | # build seglink loss 160 | net = seglink_symbol.SegLinkNet(inputs = b_image, data_format = config.data_format) 161 | net.build_loss(seg_labels = b_seg_label, 162 | seg_offsets = b_seg_loc, 163 | link_labels = b_link_gt, 164 | do_summary = False) # the summary will be added in the following lines 165 | 166 | # gather seglink losses 167 | losses = tf.get_collection(tf.GraphKeys.LOSSES) 168 | assert len(losses) == 3 # 3 is the number of seglink losses: seg_cls, seg_loc, link_cls 169 | for loss in tf.get_collection(tf.GraphKeys.LOSSES): 170 | dict_metrics[loss.op.name] = slim.metrics.streaming_mean(loss) 171 | 172 | seglink_loss = tf.add_n(losses) 173 | dict_metrics['seglink_loss'] = slim.metrics.streaming_mean(seglink_loss) 174 | 175 | # Add metrics to summaries. 
176 | for name, metric in dict_metrics.items(): 177 | tf.summary.scalar(name, metric[0]) 178 | 179 | # shape = (height, width, channels) when format = NHWC TODO 180 | gxs = gxs * tf.cast(shape[1], gxs.dtype) 181 | gys = gys * tf.cast(shape[0], gys.dtype) 182 | if FLAGS.do_grid_search: 183 | # grid search 184 | seg_ths = np.arange(0.5, 0.91, 0.1) 185 | link_ths = seg_ths 186 | else: 187 | seg_ths = [FLAGS.seg_conf_threshold] 188 | link_ths = [FLAGS.link_conf_threshold] 189 | 190 | eval_result_path = util.io.join_path(logdir, 'eval_on_%s_%s.log'%(FLAGS.dataset_name, FLAGS.dataset_split_name)) 191 | for seg_th in seg_ths: 192 | for link_th in link_ths: 193 | config._set_det_th(seg_th, link_th) 194 | 195 | eval_result_msg = 'seg_conf_threshold=%f, link_conf_threshold = %f, '\ 196 | %(config.seg_conf_threshold, config.link_conf_threshold) 197 | eval_result_msg += 'iter = %r, recall = %r, precision = %f, fmean = %r' 198 | 199 | with tf.name_scope('seglink_conf_th_%f_%f'\ 200 | %(config.seg_conf_threshold, config.link_conf_threshold)): 201 | # decode seglink to bbox output, with absolute length, instead of being within [0,1] 202 | bboxes_pred = seglink.tf_seglink_to_bbox(net.seg_scores, net.link_scores, net.seg_offsets, 203 | b_shape, seg_conf_threshold = seg_th, link_conf_threshold = link_th) 204 | # bboxes_pred = tf.Print(bboxes_pred, [tf.shape(bboxes_pred)], '%f_%f, shape of bboxes = '%(seg_th, link_th)) 205 | # calculate true positive and false positive 206 | # the xs and ys from tfrecord is 0~1, resize them to absolute length before matching. 207 | num_gt_bboxes, tp, fp = tfe_bboxes.bboxes_matching(bboxes_pred, gxs, gys, gignored) 208 | tp_fp_metric = tfe_metrics.streaming_tp_fp_arrays(num_gt_bboxes, tp, fp) 209 | dict_metrics['tp_fp_%f_%f'%(config.seg_conf_threshold, config.link_conf_threshold)] = (tp_fp_metric[0], tp_fp_metric[1]) 210 | 211 | # precision and recall 212 | precision, recall = tfe_metrics.precision_recall(*tp_fp_metric[0]) 213 | 214 | fmean = tfe_metrics.fmean(precision, recall) 215 | fmean = util.tf.Print(fmean, data = [global_step, recall, precision, fmean], 216 | msg = eval_result_msg, 217 | file = eval_result_path, mode = 'a') 218 | fmean = tf.Print(fmean, [recall, precision, fmean], '%f_%f, Recall, Precision, Fmean = '%(seg_th, link_th)) 219 | tf.summary.scalar('Precision', precision) 220 | tf.summary.scalar('Recall', recall) 221 | tf.summary.scalar('F-mean', fmean) 222 | 223 | names_to_values, names_to_updates = slim.metrics.aggregate_metric_map(dict_metrics) 224 | 225 | 226 | sess_config = tf.ConfigProto(log_device_placement = False, allow_soft_placement = True) 227 | if FLAGS.gpu_memory_fraction < 0: 228 | sess_config.gpu_options.allow_growth = True 229 | elif FLAGS.gpu_memory_fraction > 0: 230 | sess_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_memory_fraction; 231 | 232 | # Variables to restore: moving avg. or normal weights. 
233 | if FLAGS.using_moving_average: 234 | variable_averages = tf.train.ExponentialMovingAverage( 235 | FLAGS.moving_average_decay) 236 | variables_to_restore = variable_averages.variables_to_restore( 237 | slim.get_model_variables()) 238 | variables_to_restore[global_step.op.name] = global_step 239 | else: 240 | variables_to_restore = slim.get_variables_to_restore() 241 | 242 | if util.io.is_dir(FLAGS.checkpoint_path): 243 | slim.evaluation.evaluation_loop( 244 | master = '', 245 | eval_op=list(names_to_updates.values()), 246 | num_evals=dataset.num_samples, 247 | variables_to_restore=variables_to_restore, 248 | checkpoint_dir = checkpoint_dir, 249 | logdir = logdir, 250 | session_config=sess_config) 251 | else: 252 | slim.evaluation.evaluate_once( 253 | master = '', 254 | eval_op=list(names_to_updates.values()), 255 | variables_to_restore=variables_to_restore, 256 | num_evals=2,#dataset.num_samples, 257 | checkpoint_path = FLAGS.checkpoint_path, 258 | logdir = logdir, 259 | session_config=sess_config) 260 | 261 | 262 | 263 | def main(_): 264 | eval(config_initialization()) 265 | 266 | 267 | if __name__ == '__main__': 268 | tf.app.run() 269 | -------------------------------------------------------------------------------- /img_10_pred.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dengdan/seglink/cc36732d78a637ac10587c11befe19944ec1c1ea/img_10_pred.jpg -------------------------------------------------------------------------------- /img_31_pred.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dengdan/seglink/cc36732d78a637ac10587c11befe19944ec1c1ea/img_31_pred.jpg -------------------------------------------------------------------------------- /nets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dengdan/seglink/cc36732d78a637ac10587c11befe19944ec1c1ea/nets/__init__.py -------------------------------------------------------------------------------- /nets/anchor_layer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import config 3 | def generate_anchors(): 4 | """ 5 | 6 | """ 7 | all_anchors = [] 8 | layer_anchors = {} 9 | h_I, w_I = config.image_shape; 10 | for layer_name in config.feat_layers: 11 | feat_shape = config.feat_shapes[layer_name]; 12 | h_l, w_l = feat_shape 13 | anchors = _generate_anchors_one_layer(h_I, w_I, h_l, w_l) 14 | all_anchors.append(anchors) 15 | layer_anchors[layer_name] = anchors 16 | all_anchors = _reshape_and_concat(all_anchors) 17 | return all_anchors, layer_anchors 18 | 19 | 20 | def _reshape_and_concat(tensors): 21 | tensors = [np.reshape(t, (-1, t.shape[-1])) for t in tensors] 22 | return np.vstack(tensors) 23 | 24 | def _generate_anchors_one_layer(h_I, w_I, h_l, w_l): 25 | """ 26 | generate anchors on on layer 27 | return a ndarray with shape (h_l, w_l, 4), and the last dimmension in the order:[cx, cy, w, h] 28 | """ 29 | y, x = np.mgrid[0: h_l, 0:w_l] 30 | cy = (y + config.anchor_offset) / h_l * h_I 31 | cx = (x + config.anchor_offset) / w_l * w_I 32 | 33 | anchor_scale = _get_scale(w_I, w_l) 34 | anchor_w = np.ones_like(cx) * anchor_scale 35 | anchor_h = np.ones_like(cx) * anchor_scale # cx.shape == cy.shape 36 | 37 | anchors = np.asarray([cx, cy, anchor_w, anchor_h]) 38 | anchors = np.transpose(anchors, (1, 2, 0)) 39 | 40 | return anchors 41 | 42 | 43 | def 
_get_scale(w_I, w_l): 44 | return config.anchor_scale_gamma * 1.0 * w_I / w_l 45 | 46 | 47 | 48 | def _test_generate_anchors_one_layer(): 49 | """ 50 | test _generate_anchors_one_layer method by visualizing it in an image. 51 | """ 52 | import util 53 | image_shape = (512, 512) 54 | h_I, w_I = image_shape 55 | stride = 256 56 | feat_shape = (h_I/stride, w_I / stride) 57 | h_l, w_l = feat_shape 58 | anchors = _generate_anchors_one_layer(h_I, w_I, h_l, w_l, gamma = 1.5) 59 | assert(anchors.shape == (h_l, w_l, 4)) 60 | mask = util.img.black(image_shape) 61 | for x in xrange(w_l): 62 | for y in xrange(h_l): 63 | cx, cy, w, h = anchors[y, x, :] 64 | xmin = (cx - w / 2) 65 | ymin = (cy - h / 2) 66 | 67 | xmax = (cx + w / 2) 68 | ymax = (cy + h / 2) 69 | 70 | cxy = (int(cx), int(cy)) 71 | util.img.circle(mask, cxy, 3, color = 255) 72 | util.img.rectangle(mask, (xmin, ymin), (xmax, ymax), color = 255) 73 | 74 | util.sit(mask) 75 | 76 | 77 | if __name__ == "__main__": 78 | _test_generate_anchors_one_layer(); 79 | -------------------------------------------------------------------------------- /nets/net_factory.py: -------------------------------------------------------------------------------- 1 | import vgg 2 | 3 | net_dict = { 4 | "vgg": vgg 5 | } 6 | 7 | def get_basenet(name, inputs): 8 | net = net_dict[name]; 9 | return net.basenet(inputs); 10 | -------------------------------------------------------------------------------- /nets/seglink_symbol.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.slim as slim 3 | import net_factory 4 | import config 5 | 6 | 7 | class SegLinkNet(object): 8 | def __init__(self, inputs, weight_decay = None, basenet_type = 'vgg', data_format = 'NHWC', 9 | weights_initializer = None, biases_initializer = None): 10 | self.inputs = inputs; 11 | self.weight_decay = weight_decay 12 | self.feat_layers = config.feat_layers 13 | self.basenet_type = basenet_type; 14 | self.data_format = data_format; 15 | if weights_initializer is None: 16 | weights_initializer = tf.contrib.layers.xavier_initializer() 17 | if biases_initializer is None: 18 | biases_initializer = tf.zeros_initializer() 19 | self.weights_initializer = weights_initializer 20 | self.biases_initializer = biases_initializer 21 | 22 | self._build_network(); 23 | self.shapes = self.get_shapes(); 24 | def get_shapes(self): 25 | shapes = {} 26 | 27 | for layer in self.end_points: 28 | shapes[layer] = tensor_shape(self.end_points[layer])[1:-1] 29 | return shapes 30 | def get_shape(self, name): 31 | return self.shapes[name] 32 | 33 | def _build_network(self): 34 | 35 | with slim.arg_scope([slim.conv2d], 36 | activation_fn=tf.nn.relu, 37 | weights_regularizer=slim.l2_regularizer(self.weight_decay), 38 | weights_initializer= self.weights_initializer, 39 | biases_initializer = self.biases_initializer): 40 | with slim.arg_scope([slim.conv2d, slim.max_pool2d], 41 | padding='SAME', 42 | data_format = self.data_format): 43 | with tf.variable_scope(self.basenet_type): 44 | basenet, end_points = net_factory.get_basenet(self.basenet_type, self.inputs); 45 | 46 | with tf.variable_scope('extra_layers'): 47 | self.net, self.end_points = self._add_extra_layers(basenet, end_points); 48 | 49 | with tf.variable_scope('seglink_layers'): 50 | self._add_seglink_layers(); 51 | 52 | def _add_extra_layers(self, inputs, end_points): 53 | # Additional SSD blocks. 54 | # conv6/7/8/9/10: 1x1 and 3x3 convolutions stride 2 (except lasts). 
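# Added commentary (not in the original source): each conv*_2 below uses stride 2, so every
# extra block halves the spatial resolution of the previous feature map. For a 512x512 input
# this gives roughly conv4_3: 64x64, fc7: 32x32, conv6_2: 16x16, conv7_2: 8x8, conv8_2: 4x4
# and conv9_2: 2x2; the actual prediction layers and shapes are taken from config.feat_layers
# and config.feat_shapes, so these sizes are illustrative only.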
55 | net = slim.conv2d(inputs, 256, [1, 1], scope='conv6_1') 56 | net = slim.conv2d(net, 512, [3, 3], stride=2, scope='conv6_2', padding='SAME') 57 | end_points['conv6_2'] = net 58 | 59 | net = slim.conv2d(net, 128, [1, 1], scope='conv7_1') 60 | net = slim.conv2d(net, 256, [3, 3], stride=2, scope='conv7_2', padding='SAME') 61 | end_points['conv7_2'] = net 62 | 63 | net = slim.conv2d(net, 128, [1, 1], scope='conv8_1') 64 | net = slim.conv2d(net, 256, [3, 3], stride=2, scope='conv8_2', padding='SAME') 65 | end_points['conv8_2'] = net 66 | 67 | net = slim.conv2d(net, 128, [1, 1], scope='conv9_1') 68 | net = slim.conv2d(net, 256, [3, 3], stride=2, scope='conv9_2', padding='SAME') 69 | end_points['conv9_2'] = net 70 | 71 | 72 | # net = slim.conv2d(net, 128, [1, 1], scope='conv10_1') 73 | 74 | # Padding to use kernel of size 4, to be compatible with caffe ssd model 75 | # The minimal input dimension should be 512, resulting in 2x2. After padding, it becomes 4x4 76 | # paddings = [[0, 0], [1, 1], [1, 1], [0, 0]] 77 | # net = tf.pad(net, paddings) 78 | # net = slim.conv2d(net, 256, [4, 4], scope='conv10_2', padding='VALID') 79 | # end_points['conv10_2'] = net 80 | return net, end_points; 81 | 82 | def _build_seg_link_layer(self, layer_name): 83 | net = self.end_points[layer_name] 84 | batch_size, h, w = tensor_shape(net)[:-1] 85 | 86 | if layer_name == 'conv4_3': 87 | net = tf.nn.l2_normalize(net, -1) * 20 88 | 89 | with slim.arg_scope([slim.conv2d], 90 | activation_fn = None, 91 | weights_regularizer=slim.l2_regularizer(self.weight_decay), 92 | weights_initializer = tf.contrib.layers.xavier_initializer(), 93 | biases_initializer = tf.zeros_initializer()): 94 | 95 | # segment scores 96 | num_cls_pred = 2 97 | seg_scores = slim.conv2d(net, num_cls_pred, [3, 3], scope='seg_scores') 98 | 99 | # segment offsets 100 | num_offset_pred = 5 101 | seg_offsets = slim.conv2d(net, num_offset_pred, [3, 3], scope = 'seg_offsets') 102 | 103 | # within-layer link scores 104 | num_within_layer_link_scores_pred = 16 105 | within_layer_link_scores = slim.conv2d(net, num_within_layer_link_scores_pred, [3, 3], scope = 'within_layer_link_scores') 106 | within_layer_link_scores = tf.reshape(within_layer_link_scores, tensor_shape(within_layer_link_scores)[:-1] + [8, 2]) 107 | 108 | # cross-layer link scores 109 | num_cross_layer_link_scores_pred = 8 110 | cross_layer_link_scores = None; 111 | if layer_name != 'conv4_3': 112 | cross_layer_link_scores = slim.conv2d(net, num_cross_layer_link_scores_pred, [3, 3], scope = 'cross_layer_link_scores') 113 | cross_layer_link_scores = tf.reshape(cross_layer_link_scores, tensor_shape(cross_layer_link_scores)[:-1] + [4, 2]) 114 | 115 | return seg_scores, seg_offsets, within_layer_link_scores, cross_layer_link_scores 116 | 117 | 118 | def _add_seglink_layers(self): 119 | all_seg_scores = [] 120 | all_seg_offsets = [] 121 | all_within_layer_link_scores = [] 122 | all_cross_layer_link_scores = [] 123 | for layer_name in self.feat_layers: 124 | with tf.variable_scope(layer_name): 125 | seg_scores, seg_offsets, within_layer_link_scores, cross_layer_link_scores = self._build_seg_link_layer(layer_name) 126 | all_seg_scores.append(seg_scores) 127 | all_seg_offsets.append(seg_offsets) 128 | all_within_layer_link_scores.append(within_layer_link_scores) 129 | all_cross_layer_link_scores.append(cross_layer_link_scores) 130 | 131 | self.seg_score_logits = reshape_and_concat(all_seg_scores) # (batch_size, N, 2) 132 | self.seg_scores = slim.softmax(self.seg_score_logits) # (batch_size, N, 2) 133 | 
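# Added commentary (not in the original source): N is the total anchor count, i.e. the sum of
# h_l * w_l over the feature layers after reshape_and_concat flattens the spatial dimensions.
# Each anchor regresses 5 offsets (cx, cy, w, h, theta), matching the per-channel losses named
# in build_loss below, and has 8 within-layer plus 4 cross-layer link neighbours (see the
# [8, 2] and [4, 2] reshapes in _build_seg_link_layer).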
self.seg_offsets = reshape_and_concat(all_seg_offsets) # (batch_size, N, 5) 134 | self.cross_layer_link_scores = reshape_and_concat(all_cross_layer_link_scores) # (batch_size, 8N, 2) 135 | self.within_layer_link_scores = reshape_and_concat(all_within_layer_link_scores) # (batch_size, 4(N - N_conv4_3), 2) 136 | self.link_score_logits = tf.concat([self.within_layer_link_scores, self.cross_layer_link_scores], axis = 1) 137 | self.link_scores = slim.softmax(self.link_score_logits) 138 | 139 | tf.summary.histogram('link_scores', self.link_scores) 140 | tf.summary.histogram('seg_scores', self.seg_scores) 141 | 142 | def build_loss(self, seg_labels, seg_offsets, link_labels, do_summary = True): 143 | batch_size = config.batch_size_per_gpu 144 | 145 | # note that for label values in both seg_labels and link_labels: 146 | # -1 stands for negative 147 | # 1 stands for positive 148 | # 0 stands for ignored 149 | def get_pos_and_neg_masks(labels): 150 | if config.train_with_ignored: 151 | pos_mask = labels >= 0 152 | neg_mask = tf.logical_not(pos_mask) 153 | else: 154 | pos_mask = tf.equal(labels, 1) 155 | neg_mask = tf.equal(labels, -1) 156 | 157 | return pos_mask, neg_mask 158 | 159 | def OHNM_single_image(scores, n_pos, neg_mask): 160 | """Online Hard Negative Mining. 161 | scores: the scores of being predicted as negative cls 162 | n_pos: the number of positive samples 163 | neg_mask: mask of negative samples 164 | Return: 165 | the mask of selected negative samples. 166 | if n_pos == 0, no negative samples will be selected. 167 | """ 168 | def has_pos(): 169 | n_neg = n_pos * config.max_neg_pos_ratio 170 | max_neg_entries = tf.reduce_sum(tf.cast(neg_mask, tf.int32)) 171 | n_neg = tf.minimum(n_neg, max_neg_entries) 172 | n_neg = tf.cast(n_neg, tf.int32) 173 | neg_conf = tf.boolean_mask(scores, neg_mask) 174 | vals, _ = tf.nn.top_k(-neg_conf, k=n_neg) 175 | threshold = vals[-1]# a negtive value 176 | selected_neg_mask = tf.logical_and(neg_mask, scores <= -threshold) 177 | return tf.cast(selected_neg_mask, tf.float32) 178 | 179 | def no_pos(): 180 | return tf.zeros_like(neg_mask, tf.float32) 181 | 182 | return tf.cond(n_pos > 0, has_pos, no_pos) 183 | 184 | def OHNM_batch(neg_conf, pos_mask, neg_mask): 185 | selected_neg_mask = [] 186 | for image_idx in xrange(batch_size): 187 | image_neg_conf = neg_conf[image_idx, :] 188 | image_neg_mask = neg_mask[image_idx, :] 189 | image_pos_mask = pos_mask[image_idx, :] 190 | n_pos = tf.reduce_sum(tf.cast(image_pos_mask, tf.int32)) 191 | selected_neg_mask.append(OHNM_single_image(image_neg_conf, n_pos, image_neg_mask)) 192 | 193 | selected_neg_mask = tf.stack(selected_neg_mask) 194 | selected_mask = tf.cast(pos_mask, tf.float32) + selected_neg_mask 195 | return selected_mask 196 | 197 | 198 | # OHNM on segments 199 | seg_neg_scores = self.seg_scores[:, :, 0] 200 | seg_pos_mask, seg_neg_mask = get_pos_and_neg_masks(seg_labels) 201 | seg_selected_mask = OHNM_batch(seg_neg_scores, seg_pos_mask, seg_neg_mask) 202 | n_seg_pos = tf.reduce_sum(tf.cast(seg_pos_mask, tf.float32)) 203 | 204 | with tf.name_scope('seg_cls_loss'): 205 | def has_pos(): 206 | seg_cls_loss = tf.nn.sparse_softmax_cross_entropy_with_logits( 207 | logits = self.seg_score_logits, 208 | labels = tf.cast(seg_pos_mask, dtype = tf.int32)) 209 | return tf.reduce_sum(seg_cls_loss * seg_selected_mask) / n_seg_pos 210 | def no_pos(): 211 | return tf.constant(.0); 212 | seg_cls_loss = tf.cond(n_seg_pos > 0, has_pos, no_pos) 213 | tf.add_to_collection(tf.GraphKeys.LOSSES, seg_cls_loss) 214 | 215 | def 
smooth_l1_loss(pred, target, weights): 216 | diff = pred - target 217 | abs_diff = tf.abs(diff) 218 | abs_diff_lt_1 = tf.less(abs_diff, 1) 219 | if len(target.shape) != len(weights.shape): 220 | loss = tf.reduce_sum(tf.where(abs_diff_lt_1, 0.5 * tf.square(abs_diff), abs_diff - 0.5), axis = 2) 221 | return tf.reduce_sum(loss * tf.cast(weights, tf.float32)) 222 | else: 223 | loss = tf.where(abs_diff_lt_1, 0.5 * tf.square(abs_diff), abs_diff - 0.5) 224 | return tf.reduce_sum(loss * tf.cast(weights, tf.float32)) 225 | 226 | with tf.name_scope('seg_loc_loss'): 227 | def has_pos(): 228 | seg_loc_loss = smooth_l1_loss(self.seg_offsets, seg_offsets, seg_pos_mask) * config.seg_loc_loss_weight / n_seg_pos 229 | names= ['loc_cx_loss', 'loc_cy_loss', 'loc_w_loss', 'loc_h_loss', 'loc_theta_loss'] 230 | sub_loc_losses = [] 231 | from tensorflow.python.ops import control_flow_ops 232 | for idx, name in enumerate(names): 233 | name_loss = smooth_l1_loss(self.seg_offsets[:, :, idx], seg_offsets[:,:, idx], seg_pos_mask) * config.seg_loc_loss_weight / n_seg_pos 234 | name_loss = tf.identity(name_loss, name = name) 235 | if do_summary: 236 | tf.summary.scalar(name, name_loss) 237 | sub_loc_losses.append(name_loss) 238 | seg_loc_loss = control_flow_ops.with_dependencies(sub_loc_losses, seg_loc_loss) 239 | return seg_loc_loss 240 | def no_pos(): 241 | return tf.constant(.0); 242 | seg_loc_loss = tf.cond(n_seg_pos > 0, has_pos, no_pos) 243 | tf.add_to_collection(tf.GraphKeys.LOSSES, seg_loc_loss) 244 | 245 | 246 | link_neg_scores = self.link_scores[:,:,0] 247 | link_pos_mask, link_neg_mask = get_pos_and_neg_masks(link_labels) 248 | link_selected_mask = OHNM_batch(link_neg_scores, link_pos_mask, link_neg_mask) 249 | n_link_pos = tf.reduce_sum(tf.cast(link_pos_mask, dtype = tf.float32)) 250 | with tf.name_scope('link_cls_loss'): 251 | def has_pos(): 252 | link_cls_loss = tf.nn.sparse_softmax_cross_entropy_with_logits( 253 | logits = self.link_score_logits, 254 | labels = tf.cast(link_pos_mask, tf.int32)) 255 | return tf.reduce_sum(link_cls_loss * link_selected_mask) / n_link_pos 256 | def no_pos(): 257 | return tf.constant(.0); 258 | link_cls_loss = tf.cond(n_link_pos > 0, has_pos, no_pos) * config.link_cls_loss_weight 259 | tf.add_to_collection(tf.GraphKeys.LOSSES, link_cls_loss) 260 | 261 | if do_summary: 262 | tf.summary.scalar('seg_cls_loss', seg_cls_loss) 263 | tf.summary.scalar('seg_loc_loss', seg_loc_loss) 264 | tf.summary.scalar('link_cls_loss', link_cls_loss) 265 | 266 | 267 | def reshape_and_concat(tensors): 268 | def reshape(t): 269 | shape = tensor_shape(t) 270 | if len(shape) == 4: 271 | shape = (shape[0], -1, shape[-1]) 272 | t = tf.reshape(t, shape) 273 | elif len(shape) == 5: 274 | shape = (shape[0], -1, shape[-2], shape[-1]) 275 | t = tf.reshape(t, shape) 276 | t = tf.reshape(t, [shape[0], -1, shape[-1]]) 277 | else: 278 | raise ValueError("invalid tensor shape: %s, shape = %s"%(t.name, shape)) 279 | return t; 280 | reshaped_tensors = [reshape(t) for t in tensors if t is not None] 281 | return tf.concat(reshaped_tensors, axis = 1) 282 | 283 | def tensor_shape(t): 284 | t.get_shape().assert_is_fully_defined() 285 | return t.get_shape().as_list() 286 | -------------------------------------------------------------------------------- /nets/vgg.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | slim = tf.contrib.slim 4 | 5 | 6 | def basenet(inputs): 7 | """ 8 | backbone net of vgg16 9 | """ 10 | # End_points collect relevant 
activations for external use. 11 | end_points = {} 12 | # Original VGG-16 blocks. 13 | with slim.arg_scope([slim.conv2d, slim.max_pool2d], padding='SAME'): 14 | net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1') 15 | end_points['conv1_2'] = net 16 | net = slim.max_pool2d(net, [2, 2], scope='pool1') 17 | # Block 2. 18 | net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2') 19 | end_points['conv2_2'] = net 20 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 21 | # Block 3. 22 | net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3') 23 | end_points['conv3_3'] = net 24 | net = slim.max_pool2d(net, [2, 2], scope='pool3') 25 | # Block 4. 26 | net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4') 27 | end_points['conv4_3'] = net 28 | net = slim.max_pool2d(net, [2, 2], scope='pool4') 29 | # Block 5. 30 | net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5') 31 | end_points['conv5_3'] = net 32 | net = slim.max_pool2d(net, [3, 3], 1, scope='pool5') 33 | 34 | # fc6 as conv, dilation is added 35 | net = slim.conv2d(net, 1024, [3, 3], rate=6, scope='fc6') 36 | end_points['fc6'] = net 37 | 38 | # fc7 as conv 39 | net = slim.conv2d(net, 1024, [1, 1], scope='fc7') 40 | end_points['fc7'] = net 41 | 42 | return net, end_points; 43 | 44 | 45 | -------------------------------------------------------------------------------- /preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /preprocessing/preprocessing_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a factory for building various models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from preprocessing import ssd_vgg_preprocessing 24 | 25 | slim = tf.contrib.slim 26 | 27 | 28 | def get_preprocessing(is_training=False): 29 | """Returns preprocessing_fn(image, height, width, **kwargs). 30 | 31 | Args: 32 | name: The name of the preprocessing function. 33 | is_training: `True` if the model is being used for training. 34 | 35 | Returns: 36 | preprocessing_fn: A function that preprocessing a single image (pre-batch). 37 | It has the following signature: 38 | image = preprocessing_fn(image, output_height, output_width, ...). 39 | 40 | Raises: 41 | ValueError: If Preprocessing `name` is not recognized. 
42 | """ 43 | 44 | 45 | def preprocessing_fn(image, labels, bboxes, xs, ys, 46 | out_shape, data_format='NHWC', **kwargs): 47 | return ssd_vgg_preprocessing.preprocess_image( 48 | image, labels, bboxes, out_shape, xs, ys, data_format=data_format, 49 | is_training=is_training, **kwargs) 50 | return preprocessing_fn 51 | -------------------------------------------------------------------------------- /preprocessing/ssd_vgg_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Paul Balanca. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Pre-processing images for SSD-type networks. 16 | """ 17 | from enum import Enum, IntEnum 18 | import numpy as np 19 | 20 | import tensorflow as tf 21 | import tf_extended as tfe 22 | 23 | from tensorflow.python.ops import control_flow_ops 24 | 25 | from preprocessing import tf_image 26 | 27 | slim = tf.contrib.slim 28 | 29 | # Resizing strategies. 30 | Resize = IntEnum('Resize', ('NONE', # Nothing! 31 | 'CENTRAL_CROP', # Crop (and pad if necessary). 32 | 'PAD_AND_RESIZE', # Pad, and resize to output shape. 33 | 'WARP_RESIZE')) # Warp resize. 34 | 35 | # VGG mean parameters. 36 | _R_MEAN = 123. 37 | _G_MEAN = 117. 38 | _B_MEAN = 104. 39 | 40 | # Some training pre-processing parameters. 41 | BBOX_CROP_OVERLAP = 0.1 # Minimum overlap to keep a bbox after cropping. 42 | MIN_OBJECT_COVERED = 0.5 43 | CROP_ASPECT_RATIO_RANGE = (0.5, 2.) # Distortion ratio during cropping. 44 | EVAL_SIZE = (300, 300) 45 | AREA_RANGE = [0.1, 1] 46 | FLIP = False 47 | 48 | def tf_image_whitened(image, means=[_R_MEAN, _G_MEAN, _B_MEAN]): 49 | """Subtracts the given means from each image channel. 50 | 51 | Returns: 52 | the centered image. 53 | """ 54 | if image.get_shape().ndims != 3: 55 | raise ValueError('Input must be of size [height, width, C>0]') 56 | num_channels = image.get_shape().as_list()[-1] 57 | if len(means) != num_channels: 58 | raise ValueError('len(means) must match the number of channels') 59 | 60 | mean = tf.constant(means, dtype=image.dtype) 61 | image = image - mean 62 | return image 63 | 64 | 65 | def tf_image_unwhitened(image, means=[_R_MEAN, _G_MEAN, _B_MEAN], to_int=True): 66 | """Re-convert to original image distribution, and convert to int if 67 | necessary. 68 | 69 | Returns: 70 | Centered image. 71 | """ 72 | mean = tf.constant(means, dtype=image.dtype) 73 | image = image + mean 74 | if to_int: 75 | image = tf.cast(image, tf.int32) 76 | return image 77 | 78 | 79 | def np_image_unwhitened(image, means=[_R_MEAN, _G_MEAN, _B_MEAN], to_int=True): 80 | """Re-convert to original image distribution, and convert to int if 81 | necessary. Numpy version. 82 | 83 | Returns: 84 | Centered image. 
85 | """ 86 | img = np.copy(image) 87 | img += np.array(means, dtype=img.dtype) 88 | if to_int: 89 | img = img.astype(np.uint8) 90 | return img 91 | 92 | 93 | def tf_summary_image(image, bboxes, name='image', unwhitened=False): 94 | """Add image with bounding boxes to summary. 95 | """ 96 | if unwhitened: 97 | image = tf_image_unwhitened(image) 98 | image = tf.expand_dims(image, 0) 99 | bboxes = tf.expand_dims(bboxes, 0) 100 | image_with_box = tf.image.draw_bounding_boxes(image, bboxes) 101 | tf.summary.image(name, image_with_box) 102 | 103 | 104 | def apply_with_random_selector(x, func, num_cases): 105 | """Computes func(x, sel), with sel sampled from [0...num_cases-1]. 106 | 107 | Args: 108 | x: input Tensor. 109 | func: Python function to apply. 110 | num_cases: Python int32, number of cases to sample sel from. 111 | 112 | Returns: 113 | The result of func(x, sel), where func receives the value of the 114 | selector as a python integer, but sel is sampled dynamically. 115 | """ 116 | sel = tf.random_uniform([], maxval=num_cases, dtype=tf.int32) 117 | # Pass the real x only to one of the func calls. 118 | return control_flow_ops.merge([ 119 | func(control_flow_ops.switch(x, tf.equal(sel, case))[1], case) 120 | for case in range(num_cases)])[0] 121 | 122 | 123 | def distort_color(image, color_ordering=0, fast_mode=True, scope=None): 124 | """Distort the color of a Tensor image. 125 | 126 | Each color distortion is non-commutative and thus ordering of the color ops 127 | matters. Ideally we would randomly permute the ordering of the color ops. 128 | Rather then adding that level of complication, we select a distinct ordering 129 | of color ops for each preprocessing thread. 130 | 131 | Args: 132 | image: 3-D Tensor containing single image in [0, 1]. 133 | color_ordering: Python int, a type of distortion (valid values: 0-3). 134 | fast_mode: Avoids slower ops (random_hue and random_contrast) 135 | scope: Optional scope for name_scope. 136 | Returns: 137 | 3-D Tensor color-distorted image on range [0, 1] 138 | Raises: 139 | ValueError: if color_ordering not in [0, 3] 140 | """ 141 | with tf.name_scope(scope, 'distort_color', [image]): 142 | if fast_mode: 143 | if color_ordering == 0: 144 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 145 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 146 | else: 147 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 148 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 149 | else: 150 | if color_ordering == 0: 151 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 152 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 153 | image = tf.image.random_hue(image, max_delta=0.2) 154 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 155 | elif color_ordering == 1: 156 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 157 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 158 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 159 | image = tf.image.random_hue(image, max_delta=0.2) 160 | elif color_ordering == 2: 161 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 162 | image = tf.image.random_hue(image, max_delta=0.2) 163 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 
164 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 165 | elif color_ordering == 3: 166 | image = tf.image.random_hue(image, max_delta=0.2) 167 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 168 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 169 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 170 | else: 171 | raise ValueError('color_ordering must be in [0, 3]') 172 | # The random_* ops do not necessarily clamp. 173 | return tf.clip_by_value(image, 0.0, 1.0) 174 | 175 | 176 | def distorted_bounding_box_crop(image, 177 | labels, 178 | bboxes, 179 | xs, ys, 180 | min_object_covered, 181 | aspect_ratio_range, 182 | area_range, 183 | max_attempts = 200, 184 | scope=None): 185 | """Generates cropped_image using a one of the bboxes randomly distorted. 186 | 187 | See `tf.image.sample_distorted_bounding_box` for more documentation. 188 | 189 | Args: 190 | image: 3-D Tensor of image (it will be converted to floats in [0, 1]). 191 | bbox: 2-D float Tensor of bounding boxes arranged [num_boxes, coords] 192 | where each coordinate is [0, 1) and the coordinates are arranged 193 | as [ymin, xmin, ymax, xmax]. If num_boxes is 0 then it would use the whole 194 | image. 195 | min_object_covered: An optional `float`. Defaults to `0.1`. The cropped 196 | area of the image must contain at least this fraction of any bounding box 197 | supplied. 198 | aspect_ratio_range: An optional list of `floats`. The cropped area of the 199 | image must have an aspect ratio = width / height within this range. 200 | area_range: An optional list of `floats`. The cropped area of the image 201 | must contain a fraction of the supplied image within in this range. 202 | max_attempts: An optional `int`. Number of attempts at generating a cropped 203 | region of the image of the specified constraints. After `max_attempts` 204 | failures, return the entire image. 205 | scope: Optional scope for name_scope. 206 | Returns: 207 | A tuple, a 3-D Tensor cropped_image and the distorted bbox 208 | """ 209 | with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bboxes, xs, ys]): 210 | # Each bounding box has shape [1, num_boxes, box coords] and 211 | # the coordinates are ordered [ymin, xmin, ymax, xmax]. 212 | bbox_begin, bbox_size, distort_bbox = tf.image.sample_distorted_bounding_box( 213 | tf.shape(image), 214 | bounding_boxes=tf.expand_dims(bboxes, 0), 215 | min_object_covered=min_object_covered, 216 | aspect_ratio_range=aspect_ratio_range, 217 | area_range=area_range, 218 | max_attempts=max_attempts, 219 | use_image_if_no_bounding_boxes=True) 220 | distort_bbox = distort_bbox[0, 0] 221 | 222 | # Crop the image to the specified bounding box. 223 | cropped_image = tf.slice(image, bbox_begin, bbox_size) 224 | # Restore the shape since the dynamic slice loses 3rd dimension. 225 | cropped_image.set_shape([None, None, 3]) 226 | 227 | # Update bounding boxes: resize and filter out. 228 | bboxes, xs, ys = tfe.bboxes_resize(distort_bbox, bboxes, xs, ys) 229 | labels, bboxes, xs, ys = tfe.bboxes_filter_overlap(labels, bboxes, xs, ys, 230 | threshold=BBOX_CROP_OVERLAP, assign_negative = False) 231 | return cropped_image, labels, bboxes, xs, ys, distort_bbox 232 | 233 | 234 | def preprocess_for_train(image, labels, bboxes, xs, ys, 235 | out_shape, data_format='NHWC', 236 | scope='ssd_preprocessing_train'): 237 | """Preprocesses the given image for training. 
238 | 239 | Note that the actual resizing scale is sampled from 240 | [`resize_size_min`, `resize_size_max`]. 241 | 242 | Args: 243 | image: A `Tensor` representing an image of arbitrary size. 244 | output_height: The height of the image after preprocessing. 245 | output_width: The width of the image after preprocessing. 246 | resize_side_min: The lower bound for the smallest side of the image for 247 | aspect-preserving resizing. 248 | resize_side_max: The upper bound for the smallest side of the image for 249 | aspect-preserving resizing. 250 | 251 | Returns: 252 | A preprocessed image. 253 | """ 254 | fast_mode = False 255 | with tf.name_scope(scope, 'ssd_preprocessing_train', [image, labels, bboxes]): 256 | if image.get_shape().ndims != 3: 257 | raise ValueError('Input must be of size [height, width, C>0]') 258 | # Convert to float scaled [0, 1]. 259 | if image.dtype != tf.float32: 260 | image = tf.image.convert_image_dtype(image, dtype=tf.float32) 261 | # tf_summary_image(image, bboxes, 'image_with_bboxes') 262 | 263 | 264 | # Distort image and bounding boxes. 265 | dst_image = image 266 | dst_image, labels, bboxes, xs, ys, distort_bbox = \ 267 | distorted_bounding_box_crop(image, labels, bboxes, xs, ys, 268 | min_object_covered = MIN_OBJECT_COVERED, 269 | aspect_ratio_range = CROP_ASPECT_RATIO_RANGE, 270 | area_range = AREA_RANGE) 271 | 272 | # Resize image to output size. 273 | dst_image = tf_image.resize_image(dst_image, out_shape, 274 | method=tf.image.ResizeMethod.BILINEAR, 275 | align_corners=False) 276 | tf_summary_image(dst_image, bboxes, 'image_shape_distorted') 277 | 278 | # Randomly flip the image horizontally. 279 | if FLIP: 280 | dst_image, bboxes = tf_image.random_flip_left_right(dst_image, bboxes) 281 | 282 | # Randomly distort the colors. There are 4 ways to do it. 283 | dst_image = apply_with_random_selector( 284 | dst_image, 285 | lambda x, ordering: distort_color(x, ordering, fast_mode), 286 | num_cases=4) 287 | tf_summary_image(dst_image, bboxes, 'image_color_distorted') 288 | 289 | # Rescale to VGG input scale. 290 | image = dst_image * 255. 291 | image = tf_image_whitened(image, [_R_MEAN, _G_MEAN, _B_MEAN]) 292 | # Image data format. 293 | if data_format == 'NCHW': 294 | image = tf.transpose(image, perm=(2, 0, 1)) 295 | return image, labels, bboxes, xs, ys 296 | 297 | 298 | def preprocess_for_eval(image, labels, bboxes, xs, ys, 299 | out_shape=EVAL_SIZE, data_format='NHWC', 300 | difficults=None, resize=Resize.WARP_RESIZE, 301 | scope='ssd_preprocessing_train'): 302 | """Preprocess an image for evaluation. 303 | 304 | Args: 305 | image: A `Tensor` representing an image of arbitrary size. 306 | out_shape: Output shape after pre-processing (if resize != None) 307 | resize: Resize strategy. 308 | 309 | Returns: 310 | A preprocessed image. 311 | """ 312 | with tf.name_scope(scope): 313 | if image.get_shape().ndims != 3: 314 | raise ValueError('Input must be of size [height, width, C>0]') 315 | 316 | image = tf.to_float(image) 317 | image = tf_image_whitened(image, [_R_MEAN, _G_MEAN, _B_MEAN]) 318 | 319 | if resize == Resize.NONE: 320 | pass 321 | else: 322 | image = tf_image.resize_image(image, out_shape, 323 | method=tf.image.ResizeMethod.BILINEAR, 324 | align_corners=False) 325 | 326 | # Image data format. 
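# Added commentary (not in the original source): the transpose below converts the single
# HWC image to CHW; it is applied only when the network is built with data_format='NCHW'.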
327 | if data_format == 'NCHW': 328 | image = tf.transpose(image, perm=(2, 0, 1)) 329 | return image, labels, bboxes, xs, ys 330 | 331 | 332 | def preprocess_image(image, 333 | labels, 334 | bboxes, 335 | xs, ys, 336 | out_shape, 337 | data_format = 'NHWC', 338 | is_training=False, 339 | **kwargs): 340 | """Pre-process an given image. 341 | 342 | Args: 343 | image: A `Tensor` representing an image of arbitrary size. 344 | output_height: The height of the image after preprocessing. 345 | output_width: The width of the image after preprocessing. 346 | is_training: `True` if we're preprocessing the image for training and 347 | `False` otherwise. 348 | resize_side_min: The lower bound for the smallest side of the image for 349 | aspect-preserving resizing. If `is_training` is `False`, then this value 350 | is used for rescaling. 351 | resize_side_max: The upper bound for the smallest side of the image for 352 | aspect-preserving resizing. If `is_training` is `False`, this value is 353 | ignored. Otherwise, the resize side is sampled from 354 | [resize_size_min, resize_size_max]. 355 | 356 | Returns: 357 | A preprocessed image. 358 | """ 359 | if is_training: 360 | return preprocess_for_train(image, labels, bboxes, xs, ys, 361 | out_shape=out_shape, 362 | data_format=data_format) 363 | else: 364 | return preprocess_for_eval(image, labels, bboxes, xs, ys, 365 | out_shape=out_shape, 366 | data_format=data_format, 367 | **kwargs) 368 | -------------------------------------------------------------------------------- /preprocessing/tf_image.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors and Paul Balanca. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Custom image operations. 16 | Most of the following methods extend TensorFlow image library, and part of 17 | the code is shameless copy-paste of the former! 18 | """ 19 | import tensorflow as tf 20 | 21 | from tensorflow.python.framework import constant_op 22 | from tensorflow.python.framework import dtypes 23 | from tensorflow.python.framework import ops 24 | from tensorflow.python.framework import tensor_shape 25 | from tensorflow.python.framework import tensor_util 26 | from tensorflow.python.ops import array_ops 27 | from tensorflow.python.ops import check_ops 28 | from tensorflow.python.ops import clip_ops 29 | from tensorflow.python.ops import control_flow_ops 30 | from tensorflow.python.ops import gen_image_ops 31 | from tensorflow.python.ops import gen_nn_ops 32 | from tensorflow.python.ops import string_ops 33 | from tensorflow.python.ops import math_ops 34 | from tensorflow.python.ops import random_ops 35 | from tensorflow.python.ops import variables 36 | 37 | 38 | # =========================================================================== # 39 | # Modification of TensorFlow image routines. 
40 | # =========================================================================== # 41 | def _assert(cond, ex_type, msg): 42 | """A polymorphic assert, works with tensors and boolean expressions. 43 | If `cond` is not a tensor, behave like an ordinary assert statement, except 44 | that a empty list is returned. If `cond` is a tensor, return a list 45 | containing a single TensorFlow assert op. 46 | Args: 47 | cond: Something evaluates to a boolean value. May be a tensor. 48 | ex_type: The exception class to use. 49 | msg: The error message. 50 | Returns: 51 | A list, containing at most one assert op. 52 | """ 53 | if _is_tensor(cond): 54 | return [control_flow_ops.Assert(cond, [msg])] 55 | else: 56 | if not cond: 57 | raise ex_type(msg) 58 | else: 59 | return [] 60 | 61 | 62 | def _is_tensor(x): 63 | """Returns `True` if `x` is a symbolic tensor-like object. 64 | Args: 65 | x: A python object to check. 66 | Returns: 67 | `True` if `x` is a `tf.Tensor` or `tf.Variable`, otherwise `False`. 68 | """ 69 | return isinstance(x, (ops.Tensor, variables.Variable)) 70 | 71 | 72 | def _ImageDimensions(image): 73 | """Returns the dimensions of an image tensor. 74 | Args: 75 | image: A 3-D Tensor of shape `[height, width, channels]`. 76 | Returns: 77 | A list of `[height, width, channels]` corresponding to the dimensions of the 78 | input image. Dimensions that are statically known are python integers, 79 | otherwise they are integer scalar tensors. 80 | """ 81 | if image.get_shape().is_fully_defined(): 82 | return image.get_shape().as_list() 83 | else: 84 | static_shape = image.get_shape().with_rank(3).as_list() 85 | dynamic_shape = array_ops.unstack(array_ops.shape(image), 3) 86 | return [s if s is not None else d 87 | for s, d in zip(static_shape, dynamic_shape)] 88 | 89 | 90 | def _Check3DImage(image, require_static=True): 91 | """Assert that we are working with properly shaped image. 92 | Args: 93 | image: 3-D Tensor of shape [height, width, channels] 94 | require_static: If `True`, requires that all dimensions of `image` are 95 | known and non-zero. 96 | Raises: 97 | ValueError: if `image.shape` is not a 3-vector. 98 | Returns: 99 | An empty list, if `image` has fully defined dimensions. Otherwise, a list 100 | containing an assert op is returned. 101 | """ 102 | try: 103 | image_shape = image.get_shape().with_rank(3) 104 | except ValueError: 105 | raise ValueError("'image' must be three-dimensional.") 106 | if require_static and not image_shape.is_fully_defined(): 107 | raise ValueError("'image' must be fully defined.") 108 | if any(x == 0 for x in image_shape): 109 | raise ValueError("all dims of 'image.shape' must be > 0: %s" % 110 | image_shape) 111 | if not image_shape.is_fully_defined(): 112 | return [check_ops.assert_positive(array_ops.shape(image), 113 | ["all dims of 'image.shape' " 114 | "must be > 0."])] 115 | else: 116 | return [] 117 | 118 | 119 | def fix_image_flip_shape(image, result): 120 | """Set the shape to 3 dimensional if we don't know anything else. 121 | Args: 122 | image: original image size 123 | result: flipped or transformed image 124 | Returns: 125 | An image whose shape is at least None,None,None. 126 | """ 127 | image_shape = image.get_shape() 128 | if image_shape == tensor_shape.unknown_shape(): 129 | result.set_shape([None, None, None]) 130 | else: 131 | result.set_shape(image_shape) 132 | return result 133 | 134 | 135 | # =========================================================================== # 136 | # Image + BBoxes methods: cropping, resizing, flipping, ... 
137 | # =========================================================================== # 138 | def bboxes_crop_or_pad(bboxes, 139 | height, width, 140 | offset_y, offset_x, 141 | target_height, target_width): 142 | """Adapt bounding boxes to crop or pad operations. 143 | Coordinates are always supposed to be relative to the image. 144 | 145 | Arguments: 146 | bboxes: Tensor Nx4 with bboxes coordinates [y_min, x_min, y_max, x_max]; 147 | height, width: Original image dimension; 148 | offset_y, offset_x: Offset to apply, 149 | negative if cropping, positive if padding; 150 | target_height, target_width: Target dimension after cropping / padding. 151 | """ 152 | with tf.name_scope('bboxes_crop_or_pad'): 153 | # Rescale bounding boxes in pixels. 154 | scale = tf.cast(tf.stack([height, width, height, width]), bboxes.dtype) 155 | bboxes = bboxes * scale 156 | # Add offset. 157 | offset = tf.cast(tf.stack([offset_y, offset_x, offset_y, offset_x]), bboxes.dtype) 158 | bboxes = bboxes + offset 159 | # Rescale to target dimension. 160 | scale = tf.cast(tf.stack([target_height, target_width, 161 | target_height, target_width]), bboxes.dtype) 162 | bboxes = bboxes / scale 163 | return bboxes 164 | 165 | 166 | def resize_image_bboxes_with_crop_or_pad(image, bboxes, 167 | target_height, target_width): 168 | """Crops and/or pads an image to a target width and height. 169 | Resizes an image to a target width and height by either centrally 170 | cropping the image or padding it evenly with zeros. 171 | 172 | If `width` or `height` is greater than the specified `target_width` or 173 | `target_height` respectively, this op centrally crops along that dimension. 174 | If `width` or `height` is smaller than the specified `target_width` or 175 | `target_height` respectively, this op centrally pads with 0 along that 176 | dimension. 177 | Args: 178 | image: 3-D tensor of shape `[height, width, channels]` 179 | target_height: Target height. 180 | target_width: Target width. 181 | Raises: 182 | ValueError: if `target_height` or `target_width` are zero or negative. 183 | Returns: 184 | Cropped and/or padded image of shape 185 | `[target_height, target_width, channels]` 186 | """ 187 | with tf.name_scope('resize_with_crop_or_pad'): 188 | image = ops.convert_to_tensor(image, name='image') 189 | 190 | assert_ops = [] 191 | assert_ops += _Check3DImage(image, require_static=False) 192 | assert_ops += _assert(target_width > 0, ValueError, 193 | 'target_width must be > 0.') 194 | assert_ops += _assert(target_height > 0, ValueError, 195 | 'target_height must be > 0.') 196 | 197 | image = control_flow_ops.with_dependencies(assert_ops, image) 198 | # `crop_to_bounding_box` and `pad_to_bounding_box` have their own checks. 199 | # Make sure our checks come first, so that error messages are clearer. 
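# Added commentary (not in the original source): the steps below first centrally crop the
# image to min(target, original) in each dimension, then pad it up to the target size, and
# call bboxes_crop_or_pad after each step so the relative box coordinates stay consistent
# with the transformed image.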
200 | if _is_tensor(target_height): 201 | target_height = control_flow_ops.with_dependencies( 202 | assert_ops, target_height) 203 | if _is_tensor(target_width): 204 | target_width = control_flow_ops.with_dependencies(assert_ops, target_width) 205 | 206 | def max_(x, y): 207 | if _is_tensor(x) or _is_tensor(y): 208 | return math_ops.maximum(x, y) 209 | else: 210 | return max(x, y) 211 | 212 | def min_(x, y): 213 | if _is_tensor(x) or _is_tensor(y): 214 | return math_ops.minimum(x, y) 215 | else: 216 | return min(x, y) 217 | 218 | def equal_(x, y): 219 | if _is_tensor(x) or _is_tensor(y): 220 | return math_ops.equal(x, y) 221 | else: 222 | return x == y 223 | 224 | height, width, _ = _ImageDimensions(image) 225 | width_diff = target_width - width 226 | offset_crop_width = max_(-width_diff // 2, 0) 227 | offset_pad_width = max_(width_diff // 2, 0) 228 | 229 | height_diff = target_height - height 230 | offset_crop_height = max_(-height_diff // 2, 0) 231 | offset_pad_height = max_(height_diff // 2, 0) 232 | 233 | # Maybe crop if needed. 234 | height_crop = min_(target_height, height) 235 | width_crop = min_(target_width, width) 236 | cropped = tf.image.crop_to_bounding_box(image, offset_crop_height, offset_crop_width, 237 | height_crop, width_crop) 238 | bboxes = bboxes_crop_or_pad(bboxes, 239 | height, width, 240 | -offset_crop_height, -offset_crop_width, 241 | height_crop, width_crop) 242 | # Maybe pad if needed. 243 | resized = tf.image.pad_to_bounding_box(cropped, offset_pad_height, offset_pad_width, 244 | target_height, target_width) 245 | bboxes = bboxes_crop_or_pad(bboxes, 246 | height_crop, width_crop, 247 | offset_pad_height, offset_pad_width, 248 | target_height, target_width) 249 | 250 | # In theory all the checks below are redundant. 251 | if resized.get_shape().ndims is None: 252 | raise ValueError('resized contains no shape.') 253 | 254 | resized_height, resized_width, _ = _ImageDimensions(resized) 255 | 256 | assert_ops = [] 257 | assert_ops += _assert(equal_(resized_height, target_height), ValueError, 258 | 'resized height is not correct.') 259 | assert_ops += _assert(equal_(resized_width, target_width), ValueError, 260 | 'resized width is not correct.') 261 | 262 | resized = control_flow_ops.with_dependencies(assert_ops, resized) 263 | return resized, bboxes 264 | 265 | 266 | def resize_image(image, size, 267 | method=tf.image.ResizeMethod.BILINEAR, 268 | align_corners=False): 269 | """Resize an image and bounding boxes. 270 | """ 271 | # Resize image. 272 | with tf.name_scope('resize_image'): 273 | height, width, channels = _ImageDimensions(image) 274 | image = tf.expand_dims(image, 0) 275 | image = tf.image.resize_images(image, size, 276 | method, align_corners) 277 | image = tf.reshape(image, tf.stack([size[0], size[1], channels])) 278 | return image 279 | 280 | 281 | def random_flip_left_right(image, bboxes, seed=None): 282 | """Random flip left-right of an image and its bounding boxes. 283 | """ 284 | def flip_bboxes(bboxes): 285 | """Flip bounding boxes coordinates. 286 | """ 287 | bboxes = tf.stack([bboxes[:, 0], 1 - bboxes[:, 3], 288 | bboxes[:, 2], 1 - bboxes[:, 1]], axis=-1) 289 | return bboxes 290 | 291 | # Random flip. Tensorflow implementation. 292 | with tf.name_scope('random_flip_left_right'): 293 | image = ops.convert_to_tensor(image, name='image') 294 | _Check3DImage(image, require_static=False) 295 | uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed) 296 | mirror_cond = math_ops.less(uniform_random, .5) 297 | # Flip image. 
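# Added commentary (not in the original source): the same mirror_cond drawn above gates both
# cond branches below, so the image reversal and the bbox flip (xmin -> 1 - xmax,
# xmax -> 1 - xmin in relative coordinates) are always applied together or not at all.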
298 | result = control_flow_ops.cond(mirror_cond, 299 | lambda: array_ops.reverse_v2(image, [1]), 300 | lambda: image) 301 | # Flip bboxes. 302 | bboxes = control_flow_ops.cond(mirror_cond, 303 | lambda: flip_bboxes(bboxes), 304 | lambda: bboxes) 305 | return fix_image_flip_shape(image, result), bboxes 306 | 307 | -------------------------------------------------------------------------------- /push.sh: -------------------------------------------------------------------------------- 1 | git add . --all 2 | git commit -m "update README.md" 3 | git push -u origin master 4 | 5 | -------------------------------------------------------------------------------- /scripts/eval.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | set -e 3 | # ./scripts/eval.sh 1 icdar2013 train 384 384 ckpt 4 | # ./scripts/eval.sh 1 icdar2013 train 512 512 ckpt 5 | 6 | export CUDA_VISIBLE_DEVICES=$1 7 | DATASET=$2 8 | SPLIT=$3 9 | WIDTH=$4 10 | HEIGHT=$5 11 | CHECKPOINT_PATH=$6 12 | 13 | if [ $DATASET == 'synthtext' ] 14 | then 15 | DATA_PATH=SynthText 16 | elif [ $DATASET == 'scut' ] 17 | then 18 | DATA_PATH=SCUT 19 | elif [ $DATASET == 'icdar2013' ] 20 | then 21 | DATA_PATH=ICDAR 22 | elif [ $DATASET == 'icdar2015' ] 23 | then 24 | DATA_PATH=ICDAR 25 | else 26 | echo invalid dataset: $DATASET 27 | exit 28 | fi 29 | 30 | DATASET_DIR=$HOME/dataset/SSD-tf/${DATA_PATH} 31 | 32 | python eval_seglink.py \ 33 | --checkpoint_path=${CHECKPOINT_PATH} \ 34 | --dataset_dir=${DATASET_DIR} \ 35 | --dataset_name=${DATASET} \ 36 | --dataset_split_name=$SPLIT \ 37 | --eval_image_width=${WIDTH} \ 38 | --eval_image_height=${HEIGHT} \ 39 | --gpu_memory_fraction=0.4 \ 40 | --do_grid_search=$7 \ 41 | --using_moving_average=0 42 | 43 | 44 | 45 | 46 | -------------------------------------------------------------------------------- /scripts/test.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | set -e 3 | # ./scripts/test.sh 1 icdar2013 train 384 384 ckpt 4 | # ./scripts/test.sh 1 icdar2013 train 512 512 ckpt 5 | # ./scripts/test.sh 1 icdar2015 train 1280 768 ckpt 6 | 7 | export CUDA_VISIBLE_DEVICES=$1 8 | CHECKPOINT_PATH=$2 9 | DATASET_DIR=$3 10 | 11 | 12 | 13 | python test_seglink.py \ 14 | --checkpoint_path=${CHECKPOINT_PATH} \ 15 | --gpu_memory_fraction=-1 \ 16 | --seg_conf_threshold=0.8 \ 17 | --link_conf_threshold=0.5 \ 18 | --dataset_dir=${DATASET_DIR} 19 | 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | set -e 3 | # ./scripts/train.sh 0 18 synthtext 4 | export CUDA_VISIBLE_DEVICES=$1 5 | IMG_PER_GPU=$2 6 | DATASET=$3 7 | 8 | CHKPT_PATH=${HOME}/models/seglink/seglink_synthtext 9 | TRAIN_DIR=${HOME}/models/seglink/seglink_icdar2015_384 10 | #TRAIN_DIR=${HOME}/temp/no-use/seglink/seglink_icdar2015_384 11 | #CHKPT_PATH=${HOME}/models/ssd-pretrain/seglink 12 | 13 | # get the number of gpus 14 | OLD_IFS="$IFS" 15 | IFS="," 16 | gpus=($CUDA_VISIBLE_DEVICES) 17 | IFS="$OLD_IFS" 18 | NUM_GPUS=${#gpus[@]} 19 | 20 | # batch_size = num_gpus * IMG_PER_GPU 21 | BATCH_SIZE=`expr $NUM_GPUS \* $IMG_PER_GPU` 22 | 23 | #dataset 24 | if [ $DATASET == 'synthtext' ] 25 | then 26 | DATA_PATH=SynthText 27 | elif [ $DATASET == 'scut' ] 28 | then 29 | DATA_PATH=SCUT 30 | elif [ $DATASET == 'icdar2013' ] 31 | then 32 | DATA_PATH=ICDAR 33 | elif [ $DATASET == 'icdar2015' ] 34 | then 35 | 
DATA_PATH=ICDAR 36 | else 37 | echo invalid dataset: $DATASET 38 | exit 39 | fi 40 | 41 | DATASET_DIR=$HOME/dataset/SSD-tf/${DATA_PATH} 42 | 43 | python train_seglink.py \ 44 | --train_dir=${TRAIN_DIR} \ 45 | --num_gpus=${NUM_GPUS} \ 46 | --learning_rate=0.0001 \ 47 | --gpu_memory_fraction=-1 \ 48 | --train_image_width=384 \ 49 | --train_image_height=384 \ 50 | --batch_size=${BATCH_SIZE}\ 51 | --dataset_dir=${DATASET_DIR} \ 52 | --dataset_name=${DATASET} \ 53 | --dataset_split_name=train \ 54 | --train_with_ignored=0 \ 55 | --checkpoint_path=${CHKPT_PATH} \ 56 | --using_moving_average=0 57 | -------------------------------------------------------------------------------- /scripts/vis.sh: -------------------------------------------------------------------------------- 1 | #python visualize_detection_result.py \ 2 | # --image=~/dataset/ICDAR2015/Challenge2.Task123/Challenge2_Test_Task12_Images/ \ 3 | # --gt=~/dataset/ICDAR2015/Challenge2.Task123/Challenge2_Test_Task1_GT/ \ 4 | # --det=~/temp/no-use/seglink_debug_icdar2013/eval/icdar2013_test/model.ckpt-48176/txt/ \ 5 | # --output=~/temp/no-use/seglink_result 6 | 7 | #python visualize_detection_result.py \ 8 | # --image=~/dataset/ICDAR2015/Challenge2.Task123/Challenge2_Training_Task12_Images/ \ 9 | # --gt=~/dataset/ICDAR2015/Challenge2.Task123/Challenge2_Training_Task1_GT/ \ 10 | # --det=~/temp/no-use/seglink_debug_icdar2013/eval/icdar2013_train/model.ckpt-48176/txt/ \ 11 | # --output=~/temp/no-use/seglink_result 12 | 13 | 14 | python visualize_detection_result.py \ 15 | --image=~/dataset/ICDAR2015/Challenge4/ch4_training_images/ \ 16 | --gt=~/dataset/ICDAR2015/Challenge4/ch4_training_localization_transcription_gt/ \ 17 | --det=~/models/seglink/seglink_icdar2015_without_ignored/eval/icdar2015_train/model.ckpt-72885/seg_link_conf_th_0.900000_0.700000/txt \ 18 | --output=~/temp/no-use/seglink_result_512_train 19 | -------------------------------------------------------------------------------- /test/test_batch_and_gt.py: -------------------------------------------------------------------------------- 1 | #test code to make sure the ground truth calculation and data batch works well. 2 | 3 | import numpy as np 4 | import tensorflow as tf # test 5 | 6 | from datasets import dataset_factory 7 | from preprocessing import ssd_vgg_preprocessing 8 | from tf_extended import seglink 9 | import util 10 | import cv2 11 | from nets import seglink_symbol 12 | from nets import anchor_layer 13 | slim = tf.contrib.slim 14 | import config 15 | DATA_FORMAT = 'NHWC' 16 | 17 | # =========================================================================== # 18 | # I/O and preprocessing Flags. 19 | # =========================================================================== # 20 | tf.app.flags.DEFINE_integer( 21 | 'num_readers', 18, 22 | 'The number of parallel readers that read data from the dataset.') 23 | tf.app.flags.DEFINE_integer( 24 | 'num_preprocessing_threads', 4, 25 | 'The number of threads used to create the batches.') 26 | 27 | # =========================================================================== # 28 | # Dataset Flags. 
29 | # =========================================================================== # 30 | tf.app.flags.DEFINE_string( 31 | 'dataset_name', 'synthtext', 'The name of the dataset to load.') 32 | tf.app.flags.DEFINE_integer( 33 | 'num_classes', 2, 'Number of classes to use in the dataset.') 34 | tf.app.flags.DEFINE_string( 35 | 'dataset_split_name', 'train', 'The name of the train/test split.') 36 | tf.app.flags.DEFINE_string( 37 | 'dataset_dir', util.io.get_absolute_path('~/dataset/SSD-tf/SynthText'), 'The directory where the dataset files are stored.') 38 | tf.app.flags.DEFINE_string( 39 | 'model_name', 'seglink_vgg', 'The name of the architecture to train.') 40 | tf.app.flags.DEFINE_integer( 41 | 'batch_size', 2, 'The number of samples in each batch.') 42 | tf.app.flags.DEFINE_integer('train_image_width', 1024, 'Train image size') 43 | tf.app.flags.DEFINE_integer('train_image_height', 512, 'Train image size') 44 | 45 | 46 | FLAGS = tf.app.flags.FLAGS 47 | 48 | 49 | def draw_bbox(mask, bbox, color = util.img.COLOR_RGB_RED): 50 | bbox = np.reshape(bbox, (4, 2)) 51 | cnts = util.img.points_to_contours(bbox) 52 | util.img.draw_contours(mask, cnts, -1, color = color) 53 | 54 | def config_initialization(): 55 | if not FLAGS.dataset_dir: 56 | raise ValueError('You must supply the dataset directory with --dataset_dir') 57 | tf.logging.set_verbosity(tf.logging.DEBUG) 58 | 59 | # image shape and feature layers shape inference 60 | image_shape = (FLAGS.train_image_height, FLAGS.train_image_width) 61 | 62 | config.init_config(image_shape, batch_size = FLAGS.batch_size) 63 | 64 | util.proc.set_proc_name(FLAGS.model_name + '_' + FLAGS.dataset_name) 65 | 66 | dataset = dataset_factory.get_dataset(FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir) 67 | # config.print_config(FLAGS, dataset) 68 | return dataset 69 | 70 | def create_dataset_batch_queue(dataset): 71 | batch_size = config.batch_size 72 | with tf.device('/cpu:0'): 73 | with tf.name_scope(FLAGS.dataset_name + '_data_provider'): 74 | provider = slim.dataset_data_provider.DatasetDataProvider( 75 | dataset, 76 | num_readers=FLAGS.num_readers, 77 | common_queue_capacity=20 * batch_size, 78 | common_queue_min=10 * batch_size, 79 | shuffle=True) 80 | # Get for SSD network: image, labels, bboxes. 81 | [image, shape, gignored, gbboxes, x1, x2, x3, x4, y1, y2, y3, y4] = provider.get([ 82 | 'image', 'shape', 83 | 'object/ignored', 84 | 'object/bbox', 85 | 'object/oriented_bbox/x1', 86 | 'object/oriented_bbox/x2', 87 | 'object/oriented_bbox/x3', 88 | 'object/oriented_bbox/x4', 89 | 'object/oriented_bbox/y1', 90 | 'object/oriented_bbox/y2', 91 | 'object/oriented_bbox/y3', 92 | 'object/oriented_bbox/y4' 93 | ]) 94 | gxs = tf.transpose(tf.stack([x1, x2, x3, x4])) #shape = (N, 4) 95 | gys = tf.transpose(tf.stack([y1, y2, y3, y4])) 96 | image = tf.identity(image, 'input_image') 97 | 98 | # Pre-processing image, labels and bboxes. 
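# Added commentary (not in the original source): with is_training=True this runs the random
# crop / resize / colour-distortion pipeline from ssd_vgg_preprocessing.preprocess_for_train
# and returns the mean-subtracted image together with the correspondingly updated oriented-box
# coordinates, still in relative [0, 1] form.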
99 | image, gignored, gbboxes, gxs, gys = ssd_vgg_preprocessing.preprocess_image( 100 | image, gignored, gbboxes, gxs, gys, 101 | out_shape = config.image_shape, 102 | data_format = config.data_format, 103 | is_training = True) 104 | image = tf.identity(image, 'processed_image') 105 | 106 | # calculate ground truth 107 | seg_label, seg_offsets, link_label = seglink.tf_get_all_seglink_gt(gxs, gys, gignored) 108 | 109 | # batch them 110 | b_image, b_seg_label, b_seg_offsets, b_link_label = tf.train.batch( 111 | [image, seg_label, seg_offsets, link_label], 112 | batch_size = config.batch_size_per_gpu, 113 | num_threads=FLAGS.num_preprocessing_threads, 114 | capacity = 50) 115 | 116 | batch_queue = slim.prefetch_queue.prefetch_queue( 117 | [b_image, b_seg_label, b_seg_offsets, b_link_label], 118 | capacity = 50) 119 | return batch_queue 120 | 121 | # =========================================================================== # 122 | # Main training routine. 123 | # =========================================================================== # 124 | def main(_): 125 | util.init_logger() 126 | dump_path = util.io.get_absolute_path('~/temp/no-use/seglink/') 127 | 128 | dataset = config_initialization() 129 | batch_queue = create_dataset_batch_queue(dataset) 130 | batch_size = config.batch_size 131 | summary_op = tf.summary.merge_all() 132 | with tf.Session() as sess: 133 | tf.train.start_queue_runners(sess) 134 | b_image, b_seg_label, b_seg_offsets, b_link_label = batch_queue.dequeue() 135 | batch_idx = 0; 136 | while True: #batch_idx < 50: 137 | image_data_batch, seg_label_data_batch, seg_offsets_data_batch, link_label_data_batch = \ 138 | sess.run([b_image, b_seg_label, b_seg_offsets, b_link_label]) 139 | for image_idx in xrange(batch_size): 140 | image_data = image_data_batch[image_idx, ...] 141 | seg_label_data = seg_label_data_batch[image_idx, ...] 142 | seg_offsets_data = seg_offsets_data_batch[image_idx, ...] 143 | link_label_data = link_label_data_batch[image_idx, ...] 144 | 145 | image_data = image_data + [123, 117, 104] 146 | image_data = np.asarray(image_data, dtype = np.uint8) 147 | 148 | # decode the encoded ground truth back to bboxes 149 | bboxes = seglink.seglink_to_bbox(seg_scores = seg_label_data, 150 | link_scores = link_label_data, 151 | seg_offsets_pred = seg_offsets_data) 152 | 153 | # draw bboxes on the image 154 | for bbox_idx in xrange(len(bboxes)): 155 | bbox = bboxes[bbox_idx, :] 156 | draw_bbox(image_data, bbox) 157 | 158 | image_path = util.io.join_path(dump_path, '%d_%d.jpg'%(batch_idx, image_idx)) 159 | util.plt.imwrite(image_path, image_data) 160 | print 'Make sure that the text on the image are correctly bounded\ 161 | with oriented boxes:', image_path 162 | batch_idx += 1 163 | 164 | 165 | if __name__ == '__main__': 166 | tf.app.run() 167 | -------------------------------------------------------------------------------- /test/test_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Paul Balanca. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """test code to make sure the preprocessing works all right""" 16 | import numpy as np 17 | import tensorflow as tf 18 | 19 | from datasets import dataset_factory 20 | from preprocessing import ssd_vgg_preprocessing 21 | from tf_extended import seglink as tfe_seglink 22 | import util 23 | slim = tf.contrib.slim 24 | 25 | 26 | # =========================================================================== # 27 | # I/O and preprocessing Flags. 28 | # =========================================================================== # 29 | tf.app.flags.DEFINE_integer( 30 | 'num_readers', 8, 31 | 'The number of parallel readers that read data from the dataset.') 32 | tf.app.flags.DEFINE_integer( 33 | 'num_preprocessing_threads', 4, 34 | 'The number of threads used to create the batches.') 35 | 36 | # =========================================================================== # 37 | # Dataset Flags. 38 | # =========================================================================== # 39 | tf.app.flags.DEFINE_string( 40 | 'dataset_name', 'synthtext', 'The name of the dataset to load.') 41 | tf.app.flags.DEFINE_string( 42 | 'dataset_split_name', 'train', 'The name of the train/test split.') 43 | tf.app.flags.DEFINE_string( 44 | 'dataset_dir', '~/dataset/SSD-tf/SynthText', 'The directory where the dataset files are stored.') 45 | tf.app.flags.DEFINE_string( 46 | 'model_name', 'ssd_vgg', 'The name of the architecture to train.') 47 | tf.app.flags.DEFINE_integer( 48 | 'batch_size', 2, 'The number of samples in each batch.') 49 | tf.app.flags.DEFINE_integer( 50 | 'train_image_size', 512, 'Train image size') 51 | tf.app.flags.DEFINE_integer('max_number_of_steps', None, 52 | 'The maximum number of training steps.') 53 | 54 | 55 | FLAGS = tf.app.flags.FLAGS 56 | 57 | # =========================================================================== # 58 | # Main training routine. 59 | # =========================================================================== # 60 | def main(_): 61 | if not FLAGS.dataset_dir: 62 | raise ValueError('You must supply the dataset directory with --dataset_dir') 63 | tf.logging.set_verbosity(tf.logging.DEBUG) 64 | batch_size = FLAGS.batch_size; 65 | with tf.Graph().as_default(): 66 | # Select the dataset. 67 | dataset = dataset_factory.get_dataset( 68 | FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir) 69 | 70 | util.proc.set_proc_name(FLAGS.model_name + '_' + FLAGS.dataset_name) 71 | 72 | 73 | # =================================================================== # 74 | # Create a dataset provider and batches. 75 | # =================================================================== # 76 | with tf.device('/cpu:0'): 77 | with tf.name_scope(FLAGS.dataset_name + '_data_provider'): 78 | provider = slim.dataset_data_provider.DatasetDataProvider( 79 | dataset, 80 | num_readers=FLAGS.num_readers, 81 | common_queue_capacity=20 * batch_size, 82 | common_queue_min=10 * batch_size, 83 | shuffle=True) 84 | # Get for SSD network: image, labels, bboxes. 
85 | [image, shape, gignored, gbboxes, x1, x2, x3, x4, y1, y2, y3, y4] = provider.get(['image', 'shape', 86 | 'object/ignored', 87 | 'object/bbox', 88 | 'object/oriented_bbox/x1', 89 | 'object/oriented_bbox/x2', 90 | 'object/oriented_bbox/x3', 91 | 'object/oriented_bbox/x4', 92 | 'object/oriented_bbox/y1', 93 | 'object/oriented_bbox/y2', 94 | 'object/oriented_bbox/y3', 95 | 'object/oriented_bbox/y4' 96 | ]) 97 | gxs = tf.transpose(tf.stack([x1, x2, x3, x4])) #shape = (N, 4) 98 | gys = tf.transpose(tf.stack([y1, y2, y3, y4])) 99 | image = tf.identity(image, 'input_image') 100 | # Pre-processing image, labels and bboxes. 101 | image_shape = (FLAGS.train_image_size, FLAGS.train_image_size) 102 | image, gignored, gbboxes, gxs, gys = \ 103 | ssd_vgg_preprocessing.preprocess_image(image, gignored, gbboxes, gxs, gys, 104 | out_shape=image_shape, 105 | is_training = True) 106 | gxs = gxs * tf.cast(image_shape[1], gxs.dtype) 107 | gys = gys * tf.cast(image_shape[0], gys.dtype) 108 | gorbboxes = tfe_seglink.tf_min_area_rect(gxs, gys) 109 | image = tf.identity(image, 'processed_image') 110 | 111 | with tf.Session() as sess: 112 | coord = tf.train.Coordinator() 113 | threads = tf.train.start_queue_runners(sess=sess, coord=coord) 114 | i = 0 115 | while i < 2: 116 | i += 1 117 | image_data, label_data, bbox_data, xs_data, ys_data, orbboxes = \ 118 | sess.run([image, gignored, gbboxes, gxs, gys, gorbboxes]) 119 | image_data = image_data + [123., 117., 104.] 120 | image_data = np.asarray(image_data, np.uint8) 121 | h, w = image_data.shape[0:-1] 122 | bbox_data = bbox_data * [h, w, h, w] 123 | I_bbox = image_data.copy() 124 | I_xys = image_data.copy() 125 | I_orbbox = image_data.copy() 126 | 127 | for idx in range(bbox_data.shape[0]): 128 | 129 | def draw_bbox(): 130 | y1, x1, y2, x2 = bbox_data[idx, :] 131 | util.img.rectangle(I_bbox, (x1, y1), (x2, y2), color = util.img.COLOR_WHITE) 132 | 133 | def draw_xys(): 134 | points = zip(xs_data[idx, :], ys_data[idx, :]) 135 | cnts = util.img.points_to_contours(points); 136 | util.img.draw_contours(I_xys, cnts, -1, color = util.img.COLOR_GREEN) 137 | 138 | def draw_orbbox(): 139 | orbox = orbboxes[idx, :] 140 | import cv2 141 | rect = ((orbox[0], orbox[1]), (orbox[2], orbox[3]), orbox[4]) 142 | box = cv2.cv.BoxPoints(rect) 143 | box = np.int0(box) 144 | cv2.drawContours(I_orbbox, [box], 0, util.img.COLOR_RGB_RED, 1) 145 | 146 | draw_bbox() 147 | draw_xys(); 148 | draw_orbbox(); 149 | 150 | print util.sit(I_bbox) 151 | print util.sit(I_xys) 152 | print util.sit(I_orbbox) 153 | print 'check the images and make sure that bboxes in difference colors are the same.' 154 | coord.request_stop() 155 | coord.join(threads) 156 | if __name__ == '__main__': 157 | tf.app.run() 158 | -------------------------------------------------------------------------------- /test_seglink.py: -------------------------------------------------------------------------------- 1 | #encoding = utf-8 2 | """Read test images, and store the detection result as txt files and zip file. 
3 | The zip file follows the rule of ICDAR2015 Challenge4 Task1 4 | """ 5 | import numpy as np 6 | import math 7 | import tensorflow as tf # test 8 | from tensorflow.python.ops import control_flow_ops 9 | from tensorflow.contrib.training.python.training import evaluation 10 | from datasets import dataset_factory 11 | from preprocessing import ssd_vgg_preprocessing 12 | from tf_extended import seglink, metrics 13 | import util 14 | import cv2 15 | from nets import seglink_symbol, anchor_layer 16 | 17 | slim = tf.contrib.slim 18 | import config 19 | # =========================================================================== # 20 | # model threshold parameters 21 | # =========================================================================== # 22 | tf.app.flags.DEFINE_float('seg_conf_threshold', 0.9, 23 | 'the threshold on the confidence of segment') 24 | tf.app.flags.DEFINE_float('link_conf_threshold', 0.7, 25 | 'the threshold on the confidence of linkage') 26 | 27 | # =========================================================================== # 28 | # Checkpoint and running Flags 29 | # =========================================================================== # 30 | tf.app.flags.DEFINE_string('checkpoint_path', None, 31 | 'the path of checkpoint to be evaluated. If it is a directory containing many checkpoints, the lastest will be evaluated.') 32 | tf.app.flags.DEFINE_float('gpu_memory_fraction', -1, 'the gpu memory fraction to be used. If less than 0, allow_growth = True is used.') 33 | 34 | 35 | # =========================================================================== # 36 | # Dataset Flags. 37 | # =========================================================================== # 38 | tf.app.flags.DEFINE_string( 39 | 'dataset_name', 'icdar2015', 'The name of the dataset to load.') 40 | tf.app.flags.DEFINE_string( 41 | 'dataset_split_name', 'test', 'The name of the train/test split.') 42 | tf.app.flags.DEFINE_string('dataset_dir', 43 | util.io.get_absolute_path('~/dataset/ICDAR2015/Challenge4/ch4_test_images'), 44 | 'The directory where the dataset files are stored.') 45 | tf.app.flags.DEFINE_string( 46 | 'model_name', 'seglink_vgg', 'The name of the architecture to train.') 47 | tf.app.flags.DEFINE_integer('eval_image_width', 1280, 'Train image size') 48 | tf.app.flags.DEFINE_integer('eval_image_height', 768, 'Train image size') 49 | 50 | 51 | FLAGS = tf.app.flags.FLAGS 52 | 53 | def config_initialization(): 54 | # image shape and feature layers shape inference 55 | image_shape = (FLAGS.eval_image_height, FLAGS.eval_image_width) 56 | 57 | if not FLAGS.dataset_dir: 58 | raise ValueError('You must supply the dataset directory with --dataset_dir') 59 | tf.logging.set_verbosity(tf.logging.DEBUG) 60 | 61 | config.init_config(image_shape, batch_size = 1, seg_conf_threshold = FLAGS.seg_conf_threshold, 62 | link_conf_threshold = FLAGS.link_conf_threshold) 63 | 64 | util.proc.set_proc_name('test' + FLAGS.model_name) 65 | 66 | 67 | def write_result(image_name, image_data, bboxes, path): 68 | filename = util.io.join_path(path, 'res_%s.txt'%(image_name)) 69 | print filename 70 | lines = [] 71 | for bbox in bboxes: 72 | line = "%d, %d, %d, %d, %d, %d, %d, %d\r\n"%(int(v) for v in bbox) 73 | lines.append(line) 74 | util.io.write_lines(filename, lines) 75 | 76 | 77 | def eval(): 78 | 79 | with tf.name_scope('test'): 80 | with tf.variable_scope(tf.get_variable_scope(), reuse = True):# the variables has been created in config.init_config 81 | image = tf.placeholder(dtype=tf.int32, shape = [None, None, 
3]) 82 | image_shape = tf.placeholder(dtype = tf.int32, shape = [3, ]) 83 | processed_image, _, _, _, _ = ssd_vgg_preprocessing.preprocess_image(image, None, None, None, None, 84 | out_shape = config.image_shape, 85 | data_format = config.data_format, 86 | is_training = False) 87 | b_image = tf.expand_dims(processed_image, axis = 0) 88 | b_shape = tf.expand_dims(image_shape, axis = 0) 89 | net = seglink_symbol.SegLinkNet(inputs = b_image, data_format = config.data_format) 90 | bboxes_pred = seglink.tf_seglink_to_bbox(net.seg_scores, net.link_scores, 91 | net.seg_offsets, 92 | image_shape = b_shape, 93 | seg_conf_threshold = config.seg_conf_threshold, 94 | link_conf_threshold = config.link_conf_threshold) 95 | 96 | image_names = util.io.ls(FLAGS.dataset_dir) 97 | 98 | sess_config = tf.ConfigProto(log_device_placement = False, allow_soft_placement = True) 99 | if FLAGS.gpu_memory_fraction < 0: 100 | sess_config.gpu_options.allow_growth = True 101 | elif FLAGS.gpu_memory_fraction > 0: 102 | sess_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_memory_fraction; 103 | 104 | checkpoint_dir = util.io.get_dir(FLAGS.checkpoint_path) 105 | logdir = util.io.join_path(FLAGS.checkpoint_path, 'test', FLAGS.dataset_name + '_' +FLAGS.dataset_split_name) 106 | 107 | saver = tf.train.Saver() 108 | if util.io.is_dir(FLAGS.checkpoint_path): 109 | checkpoint = util.tf.get_latest_ckpt(FLAGS.checkpoint_path) 110 | else: 111 | checkpoint = FLAGS.checkpoint_path 112 | 113 | tf.logging.info('testing', checkpoint) 114 | 115 | with tf.Session(config = sess_config) as sess: 116 | saver.restore(sess, checkpoint) 117 | checkpoint_name = util.io.get_filename(str(checkpoint)); 118 | dump_path = util.io.join_path(logdir, checkpoint_name, 119 | 'seg_link_conf_th_%f_%f'%(config.seg_conf_threshold, config.link_conf_threshold)) 120 | 121 | txt_path = util.io.join_path(dump_path,'txt') 122 | zip_path = util.io.join_path(dump_path, '%s_seg_link_conf_th_%f_%f.zip'%(checkpoint_name, config.seg_conf_threshold, config.link_conf_threshold)) 123 | 124 | # write detection result as txt files 125 | def write_result_as_txt(image_name, bboxes, path): 126 | filename = util.io.join_path(path, 'res_%s.txt'%(image_name)) 127 | lines = [] 128 | for b_idx, bbox in enumerate(bboxes): 129 | values = [int(v) for v in bbox] 130 | line = "%d, %d, %d, %d, %d, %d, %d, %d\n"%tuple(values) 131 | lines.append(line) 132 | util.io.write_lines(filename, lines) 133 | print 'result has been written to:', filename 134 | 135 | for iter, image_name in enumerate(image_names): 136 | image_data = util.img.imread(util.io.join_path(FLAGS.dataset_dir, image_name), rgb = True) 137 | image_name = image_name.split('.')[0] 138 | image_bboxes = sess.run([bboxes_pred], feed_dict = {image:image_data, image_shape:image_data.shape}) 139 | print '%d/%d: %s'%(iter + 1, len(image_names), image_name) 140 | write_result_as_txt(image_name, image_bboxes[0], txt_path) 141 | 142 | # create zip file for icdar2015 143 | cmd = 'cd %s;zip -j %s %s/*'%(dump_path, zip_path, txt_path); 144 | print cmd 145 | print util.cmd.cmd(cmd); 146 | print "zip file created: ", util.io.join_path(dump_path, zip_path) 147 | 148 | 149 | def main(_): 150 | config_initialization() 151 | eval() 152 | 153 | 154 | if __name__ == '__main__': 155 | tf.app.run() 156 | -------------------------------------------------------------------------------- /tf_extended/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Paul Balanca. 
All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """TF Extended: additional metrics. 16 | """ 17 | 18 | # pylint: disable=unused-import,line-too-long,g-importing-member,wildcard-import 19 | from tf_extended.metrics import * 20 | from tf_extended.bboxes import * 21 | from tf_extended.math import * 22 | 23 | -------------------------------------------------------------------------------- /tf_extended/bboxes.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | import numpy as np 15 | import tensorflow as tf 16 | import cv2 17 | import util 18 | import config 19 | from tf_extended import math as tfe_math 20 | 21 | def bboxes_resize(bbox_ref, bboxes, xs, ys, name=None): 22 | """Resize bounding boxes based on a reference bounding box, 23 | assuming that the latter is [0, 0, 1, 1] after transform. Useful for 24 | updating a collection of boxes after cropping an image. 25 | """ 26 | # Tensors inputs. 27 | with tf.name_scope(name, 'bboxes_resize'): 28 | h_ref = bbox_ref[2] - bbox_ref[0] 29 | w_ref = bbox_ref[3] - bbox_ref[1] 30 | 31 | # Translate. 32 | v = tf.stack([bbox_ref[0], bbox_ref[1], bbox_ref[0], bbox_ref[1]]) 33 | bboxes = bboxes - v 34 | xs = xs - bbox_ref[1] 35 | ys = ys - bbox_ref[0] 36 | 37 | # Scale. 38 | s = tf.stack([h_ref, w_ref, h_ref, w_ref]) 39 | bboxes = bboxes / s 40 | xs = xs / w_ref; 41 | ys = ys / h_ref; 42 | 43 | return bboxes, xs, ys 44 | 45 | 46 | 47 | # def bboxes_filter_center(labels, bboxes, scope=None): 48 | # """Filter out bounding boxes whose center are not in 49 | # the rectangle [0, 0, 1, 1] + margins. The margin Tensor 50 | # can be used to enforce or loosen this condition. 51 | # 52 | # Return: 53 | # labels, bboxes: Filtered elements. 54 | # """ 55 | # with tf.name_scope(scope, 'bboxes_filter', [labels, bboxes]): 56 | # cy = (bboxes[:, 0] + bboxes[:, 2]) / 2. 57 | # cx = (bboxes[:, 1] + bboxes[:, 3]) / 2. 58 | # mask = tf.greater(cy, 0.) 59 | # mask = tf.logical_and(mask, tf.greater(cx, 0.)) 60 | # mask = tf.logical_and(mask, tf.less(cy, 1.)) 61 | # mask = tf.logical_and(mask, tf.less(cx, 1.)) 62 | # # Boolean masking... 
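# A small NumPy sketch of the same crop-and-renormalize arithmetic as bboxes_resize
# above, on made-up numbers: after cropping to bbox_ref, boxes and corner coordinates
# are re-expressed relative to the crop window (the crop itself maps to [0, 0, 1, 1]).
import numpy as np
bbox_ref = np.array([0.2, 0.1, 0.8, 0.9])              # [ymin, xmin, ymax, xmax] of the crop
h_ref, w_ref = bbox_ref[2] - bbox_ref[0], bbox_ref[3] - bbox_ref[1]
boxes = np.array([[0.3, 0.2, 0.5, 0.6]])               # one box in the same order
boxes = (boxes - [bbox_ref[0], bbox_ref[1], bbox_ref[0], bbox_ref[1]]) / [h_ref, w_ref, h_ref, w_ref]
xs_example = (np.array([[0.2, 0.6, 0.6, 0.2]]) - bbox_ref[1]) / w_ref
ys_example = (np.array([[0.3, 0.3, 0.5, 0.5]]) - bbox_ref[0]) / h_ref
assert np.allclose(boxes[0], [1. / 6, 0.125, 0.5, 0.625])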
63 | # labels = tf.boolean_mask(labels, mask) 64 | # bboxes = tf.boolean_mask(bboxes, mask) 65 | # return labels, bboxes 66 | 67 | 68 | def bboxes_filter_overlap(labels, bboxes,xs, ys, threshold, scope=None, assign_negative = False): 69 | """Filter out bounding boxes based on (relative )overlap with reference 70 | box [0, 0, 1, 1]. Remove completely bounding boxes, or assign negative 71 | labels to the one outside (useful for latter processing...). 72 | 73 | Return: 74 | labels, bboxes: Filtered (or newly assigned) elements. 75 | """ 76 | with tf.name_scope(scope, 'bboxes_filter', [labels, bboxes]): 77 | scores = bboxes_intersection(tf.constant([0, 0, 1, 1], bboxes.dtype),bboxes) 78 | 79 | mask = scores > threshold 80 | if assign_negative: 81 | labels = tf.where(mask, labels, -labels) 82 | else: 83 | labels = tf.boolean_mask(labels, mask) 84 | bboxes = tf.boolean_mask(bboxes, mask) 85 | scores = bboxes_intersection(tf.constant([0, 0, 1, 1], bboxes.dtype),bboxes) 86 | xs = tf.boolean_mask(xs, mask); 87 | ys = tf.boolean_mask(ys, mask); 88 | return labels, bboxes, xs, ys 89 | 90 | 91 | def bboxes_intersection(bbox_ref, bboxes, name=None): 92 | """Compute relative intersection between a reference box and a 93 | collection of bounding boxes. Namely, compute the quotient between 94 | intersection area and box area. 95 | 96 | Args: 97 | bbox_ref: (N, 4) or (4,) Tensor with reference bounding box(es). 98 | bboxes: (N, 4) Tensor, collection of bounding boxes. 99 | Return: 100 | (N,) Tensor with relative intersection. 101 | """ 102 | with tf.name_scope(name, 'bboxes_intersection'): 103 | # Should be more efficient to first transpose. 104 | bboxes = tf.transpose(bboxes) 105 | bbox_ref = tf.transpose(bbox_ref) 106 | # Intersection bbox and volume. 107 | int_ymin = tf.maximum(bboxes[0], bbox_ref[0]) 108 | int_xmin = tf.maximum(bboxes[1], bbox_ref[1]) 109 | int_ymax = tf.minimum(bboxes[2], bbox_ref[2]) 110 | int_xmax = tf.minimum(bboxes[3], bbox_ref[3]) 111 | h = tf.maximum(int_ymax - int_ymin, 0.) 112 | w = tf.maximum(int_xmax - int_xmin, 0.) 113 | # Volumes. 114 | inter_vol = h * w 115 | bboxes_vol = (bboxes[2] - bboxes[0]) * (bboxes[3] - bboxes[1]) 116 | scores = tfe_math.safe_divide(inter_vol, bboxes_vol, 'intersection') 117 | return scores 118 | 119 | 120 | # def bboxes_matching_batch(bboxes, gxs, gys, gignored, matching_threshold=0.5, scope=None): 121 | # """Matching a collection of detected boxes with groundtruth values. 122 | # Batched-inputs version. 123 | # 124 | # Args: 125 | # rbboxes: BxN(x4) Tensors. Detected objects; 126 | # gbboxes: Groundtruth bounding boxes 127 | # matching_threshold: Threshold for a positive match. 128 | # Return: Tuple or Dictionaries with: 129 | # n_gbboxes: Scalar Tensor with number of groundtruth boxes (may difer from size because of zero padding). 130 | # tp: (B, N)-shaped boolean Tensor containing with True Positives. 131 | # fp: (B, N)-shaped boolean Tensor containing with False Positives. 132 | # """ 133 | # # Dictionaries as inputs. 
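# A NumPy sketch of the "relative intersection" score used by bboxes_filter_overlap
# above: unlike IoU, the intersection area is divided by the box's own area, so a box
# lying half inside [0, 0, 1, 1] scores 0.5. Toy values only.
import numpy as np
ref = np.array([0., 0., 1., 1.])                       # [ymin, xmin, ymax, xmax]
boxes = np.array([[0.5, -0.2, 1.0, 0.2]])              # half of this box lies left of the image
int_ymin = np.maximum(boxes[:, 0], ref[0])
int_xmin = np.maximum(boxes[:, 1], ref[1])
int_ymax = np.minimum(boxes[:, 2], ref[2])
int_xmax = np.minimum(boxes[:, 3], ref[3])
inter = np.maximum(int_ymax - int_ymin, 0.) * np.maximum(int_xmax - int_xmin, 0.)
scores = inter / ((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]))   # -> [0.5]
# With threshold = 0.25 this box is kept; with assign_negative=True its label would be
# negated instead of the box being dropped.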
134 | # 135 | # with tf.name_scope(scope, 'bboxes_matching_batch', [bboxes, gxs, gys, gignored]): 136 | # r = tf.map_fn(lambda x: 137 | # bboxes_matching(x[0], x[1], x[2], x[3], matching_threshold), 138 | # (bboxes, gxs, gys, gignored), 139 | # dtype=(tf.int64, tf.bool, tf.bool), 140 | # parallel_iterations=10, 141 | # back_prop=False, 142 | # swap_memory=True, 143 | # infer_shape=True) 144 | # return r[0], r[1], r[2] 145 | 146 | 147 | def bboxes_matching(bboxes, gxs, gys, gignored, matching_threshold = 0.5, scope=None): 148 | """Matching a collection of detected boxes with groundtruth values. 149 | Does not accept batched-inputs. 150 | The algorithm goes as follows: for every detected box, check 151 | if one grountruth box is matching. If none, then considered as False Positive. 152 | If the grountruth box is already matched with another one, it also counts 153 | as a False Positive. We refer the Pascal VOC documentation for the details. 154 | 155 | Args: 156 | rbboxes: Nx4 Tensors. Detected objects, sorted by score; 157 | gbboxes: Groundtruth bounding boxes. May be zero padded, hence 158 | zero-class objects are ignored. 159 | matching_threshold: Threshold for a positive match. 160 | Return: Tuple of: 161 | n_gbboxes: Scalar Tensor with number of groundtruth boxes (may difer from 162 | size because of zero padding). 163 | tp_match: (N,)-shaped boolean Tensor containing with True Positives. 164 | fp_match: (N,)-shaped boolean Tensor containing with False Positives. 165 | """ 166 | with tf.name_scope(scope, 'bboxes_matching_single',[bboxes, gxs, gys, gignored]): 167 | # Number of groundtruth boxes. 168 | gignored = tf.cast(gignored, dtype = tf.bool) 169 | n_gbboxes = tf.count_nonzero(tf.logical_not(gignored)) 170 | # Grountruth matching arrays. 171 | gmatch = tf.zeros(tf.shape(gignored), dtype=tf.bool) 172 | grange = tf.range(tf.size(gignored), dtype=tf.int32) 173 | 174 | # Number of detected boxes 175 | n_bboxes = tf.shape(bboxes)[0] 176 | rshape = (n_bboxes, ) 177 | # True/False positive matching TensorArrays. 178 | # ta is short for TensorArray 179 | ta_tp_bool = tf.TensorArray(tf.bool, size=n_bboxes, dynamic_size=False, infer_shape=True) 180 | ta_fp_bool = tf.TensorArray(tf.bool, size=n_bboxes, dynamic_size=False, infer_shape=True) 181 | 182 | n_ignored_det = 0 183 | # Loop over returned objects. 184 | def m_condition(i, ta_tp, ta_fp, gmatch, n_ignored_det): 185 | r = tf.less(i, tf.shape(bboxes)[0]) 186 | return r 187 | 188 | def m_body(i, ta_tp, ta_fp, gmatch, n_ignored_det): 189 | # Jaccard score with groundtruth bboxes. 190 | rbbox = bboxes[i, :] 191 | # rbbox = tf.Print(rbbox, [rbbox]) 192 | jaccard = bboxes_jaccard(rbbox, gxs, gys) 193 | 194 | # Best fit, checking it's above threshold. 195 | idxmax = tf.cast(tf.argmax(jaccard, axis=0), dtype = tf.int32) 196 | 197 | jcdmax = jaccard[idxmax] 198 | match = jcdmax > matching_threshold 199 | existing_match = gmatch[idxmax] 200 | not_ignored = tf.logical_not(gignored[idxmax]) 201 | 202 | n_ignored_det = n_ignored_det + tf.cast(gignored[idxmax], tf.int32) 203 | # TP: match & no previous match and FP: previous match | no match. 204 | # If ignored: no record, i.e FP=False and TP=False. 205 | tp = tf.logical_and(not_ignored, tf.logical_and(match, tf.logical_not(existing_match))) 206 | ta_tp = ta_tp.write(i, tp) 207 | 208 | fp = tf.logical_and(not_ignored, tf.logical_or(existing_match, tf.logical_not(match))) 209 | ta_fp = ta_fp.write(i, fp) 210 | 211 | # Update grountruth match. 
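# A plain-Python sketch of the greedy TP/FP rule implemented by m_body above, using a
# precomputed Jaccard matrix with invented scores (the `ignored` handling is omitted
# here): each detection, in score order, claims its best ground-truth box; a second
# claim on an already-matched box, or an overlap below the threshold, is a false positive.
import numpy as np
jaccard = np.array([[0.8, 0.1],    # detection 0 vs gt 0 / gt 1
                    [0.7, 0.2],    # detection 1 also overlaps gt 0, which is then taken
                    [0.1, 0.3]])   # detection 2 matches nothing well
matching_threshold = 0.5
gt_matched = np.zeros(2, dtype=bool)
tp, fp = [], []
for det_scores in jaccard:
    best = int(np.argmax(det_scores))
    match = det_scores[best] > matching_threshold
    tp.append(bool(match and not gt_matched[best]))
    fp.append(bool((not match) or gt_matched[best]))
    if match:
        gt_matched[best] = True
assert tp == [True, False, False] and fp == [False, True, True]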
212 | mask = tf.logical_and(tf.equal(grange, idxmax), tf.logical_and(not_ignored, match)) 213 | gmatch = tf.logical_or(gmatch, mask) 214 | return [i+1, ta_tp, ta_fp, gmatch,n_ignored_det] 215 | # Main loop definition. 216 | i = 0 217 | [i, ta_tp_bool, ta_fp_bool, gmatch, n_ignored_det] = \ 218 | tf.while_loop(m_condition, m_body, 219 | [i, ta_tp_bool, ta_fp_bool, gmatch, n_ignored_det], 220 | parallel_iterations=1, 221 | back_prop=False) 222 | # TensorArrays to Tensors and reshape. 223 | tp_match = tf.reshape(ta_tp_bool.stack(), rshape) 224 | fp_match = tf.reshape(ta_fp_bool.stack(), rshape) 225 | 226 | # Some debugging information... 227 | # tp_match = tf.Print(tp_match, 228 | # [n_gbboxes, n_bboxes, 229 | # tf.reduce_sum(tf.cast(tp_match, tf.int64)), 230 | # tf.reduce_sum(tf.cast(fp_match, tf.int64)), 231 | # n_ignored_det, 232 | # tf.reduce_sum(tf.cast(gmatch, tf.int64))], 233 | # 'Matching (NG, ND, TP, FP, n_ignored_det,GM): ') 234 | return n_gbboxes, tp_match, fp_match 235 | 236 | def bboxes_jaccard(bbox, gxs, gys): 237 | jaccard = tf.py_func(np_bboxes_jaccard, [bbox, gxs, gys], tf.float32) 238 | jaccard.set_shape([None, ]) 239 | return jaccard 240 | 241 | def np_bboxes_jaccard(bbox, gxs, gys): 242 | # assert np.shape(bbox) == (8,) 243 | bbox_points = np.reshape(bbox, (4, 2)) 244 | cnts = util.img.points_to_contours(bbox_points) 245 | 246 | # contruct a 0-1 mask to draw contours on 247 | xmax = np.max(bbox_points[:, 0]) 248 | xmax = max(xmax, np.max(gxs)) + 10 249 | ymax = np.max(bbox_points[:, 1]) 250 | ymax = max(ymax, np.max(gys)) + 10 251 | mask = util.img.black((ymax, xmax)) 252 | 253 | # draw bbox on the mask 254 | bbox_mask = mask.copy() 255 | util.img.draw_contours(bbox_mask, cnts, idx = -1, color = 1, border_width = -1) 256 | jaccard = np.zeros((len(gxs),), dtype = np.float32) 257 | # draw ground truth 258 | for gt_idx, gt_bbox in enumerate(zip(gxs, gys)): 259 | gt_mask = mask.copy() 260 | gt_bbox = np.transpose(gt_bbox) 261 | # assert gt_bbox.shape == (4, 2) 262 | gt_cnts = util.img.points_to_contours(gt_bbox) 263 | util.img.draw_contours(gt_mask, gt_cnts, idx = -1, color = 1, border_width = -1) 264 | 265 | intersect = np.sum(bbox_mask * gt_mask) 266 | union = np.sum(bbox_mask + gt_mask >= 1) 267 | # assert intersect == np.sum(bbox_mask * gt_mask) 268 | # assert union == np.sum((bbox_mask + gt_mask) > 0) 269 | iou = intersect * 1.0 / union 270 | jaccard[gt_idx] = iou 271 | return jaccard 272 | -------------------------------------------------------------------------------- /tf_extended/math.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Paul Balanca. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """TF Extended: additional math functions. 
16 | """ 17 | import tensorflow as tf 18 | 19 | from tensorflow.python.ops import array_ops 20 | from tensorflow.python.ops import math_ops 21 | from tensorflow.python.framework import dtypes 22 | from tensorflow.python.framework import ops 23 | 24 | 25 | def safe_divide(numerator, denominator, name): 26 | """Divides two values, returning 0 if the denominator is <= 0. 27 | Args: 28 | numerator: A real `Tensor`. 29 | denominator: A real `Tensor`, with dtype matching `numerator`. 30 | name: Name for the returned op. 31 | Returns: 32 | 0 if `denominator` <= 0, else `numerator` / `denominator` 33 | """ 34 | return tf.where( 35 | math_ops.greater(denominator, 0), 36 | math_ops.divide(numerator, denominator), 37 | tf.zeros_like(numerator), 38 | name=name) 39 | 40 | 41 | -------------------------------------------------------------------------------- /tf_extended/metrics.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.ops import variables 3 | from tensorflow.python.ops import array_ops 4 | from tensorflow.python.framework import ops 5 | from tensorflow.python.ops import state_ops 6 | from tensorflow.python.ops import variable_scope 7 | from tf_extended import math as tfe_math 8 | import util 9 | 10 | def _create_local(name, shape, collections=None, validate_shape=True, 11 | dtype=tf.float32): 12 | """Creates a new local variable. 13 | Args: 14 | name: The name of the new or existing variable. 15 | shape: Shape of the new or existing variable. 16 | collections: A list of collection names to which the Variable will be added. 17 | validate_shape: Whether to validate the shape of the variable. 18 | dtype: Data type of the variables. 19 | Returns: 20 | The created variable. 21 | """ 22 | # Make sure local variables are added to tf.GraphKeys.LOCAL_VARIABLES 23 | collections = list(collections or []) 24 | collections += [ops.GraphKeys.LOCAL_VARIABLES] 25 | return variables.Variable( 26 | initial_value=array_ops.zeros(shape, dtype=dtype), 27 | name=name, 28 | trainable=False, 29 | collections=collections, 30 | validate_shape=validate_shape) 31 | 32 | def streaming_tp_fp_arrays(num_gbboxes, tp, fp, 33 | metrics_collections=None, 34 | updates_collections=None, 35 | name=None): 36 | """Streaming computation of True and False Positive arrays. 37 | """ 38 | with variable_scope.variable_scope(name, 'streaming_tp_fp', 39 | [num_gbboxes, tp, fp]): 40 | num_gbboxes = tf.cast(num_gbboxes, tf.int32) 41 | tp = tf.cast(tp, tf.bool) 42 | fp = tf.cast(fp, tf.bool) 43 | # Reshape TP and FP tensors and clean away 0 class values. 44 | tp = tf.reshape(tp, [-1]) 45 | fp = tf.reshape(fp, [-1]) 46 | 47 | # Local variables accumlating information over batches. 48 | v_num_objects = _create_local('v_num_gbboxes', shape=[], dtype=tf.int32) 49 | v_tp = _create_local('v_tp', shape=[0, ], dtype=tf.bool) 50 | v_fp = _create_local('v_fp', shape=[0, ], dtype=tf.bool) 51 | 52 | 53 | # Update operations. 54 | num_objects_op = state_ops.assign_add(v_num_objects, 55 | tf.reduce_sum(num_gbboxes)) 56 | tp_op = state_ops.assign(v_tp, tf.concat([v_tp, tp], axis=0), 57 | validate_shape=False) 58 | fp_op = state_ops.assign(v_fp, tf.concat([v_fp, fp], axis=0), 59 | validate_shape=False) 60 | 61 | # Value and update ops. 
62 | val = (v_num_objects, v_tp, v_fp) 63 | with ops.control_dependencies([num_objects_op, tp_op, fp_op]): 64 | update_op = (num_objects_op, tp_op, fp_op) 65 | 66 | return val, update_op 67 | 68 | 69 | def precision_recall(num_gbboxes, tp, fp, scope=None): 70 | """Compute precision and recall from true positives and false 71 | positives booleans arrays 72 | """ 73 | 74 | # Sort by score. 75 | with tf.name_scope(scope, 'precision_recall'): 76 | # Computer recall and precision. 77 | tp = tf.reduce_sum(tf.cast(tp, tf.float32), axis=0) 78 | fp = tf.reduce_sum(tf.cast(fp, tf.float32), axis=0) 79 | recall = tfe_math.safe_divide(tp, tf.cast(num_gbboxes, tf.float32), 'recall') 80 | precision = tfe_math.safe_divide(tp, tp + fp, 'precision') 81 | return tf.tuple([precision, recall]) 82 | 83 | def fmean(pre, rec): 84 | """Compute f-mean with precision and recall 85 | """ 86 | return 2 * pre * rec / (pre + rec) -------------------------------------------------------------------------------- /tf_extended/seglink.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | import config 6 | import util 7 | 8 | ############################################################################################################ 9 | # seg_gt calculation # 10 | ############################################################################################################ 11 | 12 | def anchor_rect_height_ratio(anchor, rect): 13 | """calculate the height ratio between anchor and rect 14 | """ 15 | rect_height = min(rect[2], rect[3]) 16 | anchor_height = anchor[2] * 1.0 17 | ratio = anchor_height / rect_height 18 | return max(ratio, 1.0 / ratio) 19 | 20 | def is_anchor_center_in_rect(anchor, xs, ys, bbox_idx): 21 | """tell if the center of the anchor is in the rect represented using xs and ys and bbox_idx 22 | """ 23 | bbox_points = zip(xs[bbox_idx, :], ys[bbox_idx, :]) 24 | cnt = util.img.points_to_contour(bbox_points); 25 | acx, acy, aw, ah = anchor 26 | return util.img.is_in_contour((acx, acy), cnt) 27 | 28 | def min_area_rect(xs, ys): 29 | """ 30 | Args: 31 | xs: numpy ndarray with shape=(N,4). N is the number of oriented bboxes. 4 contains [x1, x2, x3, x4] 32 | ys: numpy ndarray with shape=(N,4), [y1, y2, y3, y4] 33 | Note that [(x1, y1), (x2, y2), (x3, y3), (x4, y4)] can represent an oriented bbox. 34 | Return: 35 | the oriented rects sorrounding the box, in the format:[cx, cy, w, h, theta]. 36 | """ 37 | xs = np.asarray(xs, dtype = np.float32) 38 | ys = np.asarray(ys, dtype = np.float32) 39 | 40 | num_rects = xs.shape[0] 41 | box = np.empty((num_rects, 5))#cx, cy, w, h, theta 42 | for idx in xrange(num_rects): 43 | points = zip(xs[idx, :], ys[idx, :]) 44 | cnt = util.img.points_to_contour(points) 45 | rect = cv2.minAreaRect(cnt) 46 | cx, cy = rect[0] 47 | w, h = rect[1] 48 | theta = rect[2] 49 | box[idx, :] = [cx, cy, w, h, theta] 50 | 51 | box = np.asarray(box, dtype = xs.dtype) 52 | return box 53 | 54 | def tf_min_area_rect(xs, ys): 55 | return tf.py_func(min_area_rect, [xs, ys], xs.dtype) 56 | 57 | def transform_cv_rect(rects): 58 | """Transform the rects from opencv method minAreaRect to our rects. 59 | Step 1 of Figure 5 in seglink paper 60 | 61 | In cv2.minAreaRect, the w, h and theta values in the returned rect are not convenient to use (at least for me), so 62 | the Oriented (or rotated) Rectangle object in seglink algorithm is defined different from cv2. 63 | 64 | Rect definition in Seglink: 65 | 1. 
The angle value between a side and x-axis is: 66 | positive: if it rotates clockwisely, with y-axis increasing downwards. 67 | negative: if it rotates counter-clockwisely. 68 | This is opposite to cv2, and it is only a personal preference. 69 | 70 | 2. The width is the length of side taking a smaller absolute angle with the x-axis. 71 | 3. The theta value of a rect is the signed angle value between width-side and x-axis 72 | 4. To rotate a rect to horizontal direction, just rotate its width-side horizontally, 73 | i.e., rotate it by a angle of theta using cv2 method. 74 | (see the method rotate_oriented_bbox_to_horizontal for rotation detail) 75 | 76 | 77 | Args: 78 | rects: ndarray with shape = (5, ) or (N, 5). 79 | Return: 80 | transformed rects. 81 | """ 82 | only_one = False 83 | if len(np.shape(rects)) == 1: 84 | rects = np.expand_dims(rects, axis = 0) 85 | only_one = True 86 | assert np.shape(rects)[1] == 5, 'The shape of rects must be (N, 5), but meet %s'%(str(np.shape(rects))) 87 | 88 | rects = np.asarray(rects, dtype = np.float32).copy() 89 | num_rects = np.shape(rects)[0] 90 | for idx in xrange(num_rects): 91 | cx, cy, w, h, theta = rects[idx, ...]; 92 | #assert theta < 0 and theta >= -90, "invalid theta: %f"%(theta) 93 | if abs(theta) > 45 or (abs(theta) == 45 and w < h): 94 | w, h = [h, w] 95 | theta = 90 + theta 96 | rects[idx, ...] = [cx, cy, w, h, theta] 97 | if only_one: 98 | return rects[0, ...] 99 | return rects 100 | 101 | 102 | def rotate_oriented_bbox_to_horizontal(center, bbox): 103 | """ 104 | Step 2 of Figure 5 in seglink paper 105 | 106 | Rotate bbox horizontally along a `center` point 107 | Args: 108 | center: the center of rotation 109 | bbox: [cx, cy, w, h, theta] 110 | """ 111 | assert np.shape(center) == (2, ), "center must be a vector of length 2" 112 | assert np.shape(bbox) == (5, ) or np.shape(bbox) == (4, ), "bbox must be a vector of length 4 or 5" 113 | bbox = np.asarray(bbox.copy(), dtype = np.float32) 114 | 115 | cx, cy, w, h, theta = bbox; 116 | M = cv2.getRotationMatrix2D(center, theta, scale = 1) # 2x3 117 | 118 | cx, cy = np.dot(M, np.transpose([cx, cy, 1])) 119 | 120 | bbox[0:2] = [cx, cy] 121 | return bbox 122 | 123 | def crop_horizontal_bbox_using_anchor(bbox, anchor): 124 | """Step 3 in Figure 5 in seglink paper 125 | The crop operation is operated only on the x direction. 126 | Args: 127 | bbox: a horizontal bbox with shape = (5, ) or (4, ). 
128 | """ 129 | assert np.shape(anchor) == (4, ), "anchor must be a vector of length 4" 130 | assert np.shape(bbox) == (5, ) or np.shape(bbox) == (4, ), "bbox must be a vector of length 4 or 5" 131 | 132 | # xmin and xmax of the anchor 133 | acx, acy, aw, ah = anchor 134 | axmin = acx - aw / 2.0; 135 | axmax = acx + aw / 2.0; 136 | 137 | # xmin and xmax of the bbox 138 | cx, cy, w, h = bbox[0:4] 139 | xmin = cx - w / 2.0 140 | xmax = cx + w / 2.0 141 | 142 | # clip operation 143 | xmin = max(xmin, axmin) 144 | xmax = min(xmax, axmax) 145 | 146 | # transform xmin, xmax to cx and w 147 | cx = (xmin + xmax) / 2.0; 148 | w = xmax - xmin 149 | bbox = bbox.copy() 150 | bbox[0:4] = [cx, cy, w, h] 151 | return bbox 152 | 153 | def rotate_horizontal_bbox_to_oriented(center, bbox): 154 | """ 155 | Step 4 of Figure 5 in seglink paper: 156 | Rotate the cropped horizontal bbox back to its original direction 157 | Args: 158 | center: the center of rotation 159 | bbox: [cx, cy, w, h, theta] 160 | Return: the oriented bbox 161 | """ 162 | assert np.shape(center) == (2, ), "center must be a vector of length 2" 163 | assert np.shape(bbox) == (5, ) , "bbox must be a vector of length 4 or 5" 164 | bbox = np.asarray(bbox.copy(), dtype = np.float32) 165 | 166 | cx, cy, w, h, theta = bbox; 167 | M = cv2.getRotationMatrix2D(center, -theta, scale = 1) # 2x3 168 | cx, cy = np.dot(M, np.transpose([cx, cy, 1])) 169 | bbox[0:2] = [cx, cy] 170 | return bbox 171 | 172 | 173 | def cal_seg_loc_for_single_anchor(anchor, rect): 174 | """ 175 | Step 2 to 4 176 | """ 177 | # rotate text box along the center of anchor to horizontal direction 178 | center = (anchor[0], anchor[1]) 179 | rect = rotate_oriented_bbox_to_horizontal(center, rect) 180 | 181 | # crop horizontal text box to anchor 182 | rect = crop_horizontal_bbox_using_anchor(rect, anchor) 183 | 184 | # rotate the box to original direction 185 | rect = rotate_horizontal_bbox_to_oriented(center, rect) 186 | 187 | return rect 188 | 189 | 190 | @util.dec.print_calling_in_short_for_tf 191 | def match_anchor_to_text_boxes(anchors, xs, ys): 192 | """Match anchors to text boxes. 193 | Return: 194 | seg_labels: shape = (N,), the seg_labels of segments. each value is the index of matched box if >=0. 195 | seg_locations: shape = (N, 5), the absolute location of segments. Only the match segments are correctly calculated. 196 | 197 | """ 198 | 199 | assert len(np.shape(anchors)) == 2 and np.shape(anchors)[1] == 4, "the anchors must be a tensor with shape = (num_anchors, 4)" 200 | assert len(np.shape(xs)) == 2 and np.shape(xs) == np.shape(ys) and np.shape(ys)[1] == 4, "the xs, ys must be a tensor with shape = (num_bboxes, 4)" 201 | anchors = np.asarray(anchors, dtype = np.float32) 202 | xs = np.asarray(xs, dtype = np.float32) 203 | ys = np.asarray(ys, dtype = np.float32) 204 | 205 | num_anchors = anchors.shape[0] 206 | seg_labels = np.ones((num_anchors, ), dtype = np.int32) * -1; 207 | seg_locations = np.zeros((num_anchors, 5), dtype = np.float32) 208 | 209 | # to avoid ln(0) in the ending process later. 
210 | # because the height and width will be encoded using ln(w_seg / w_anchor) 211 | seg_locations[:, 2] = anchors[:, 2] 212 | seg_locations[:, 3] = anchors[:, 3] 213 | 214 | num_bboxes = xs.shape[0] 215 | 216 | 217 | #represent bboxes using min area rects 218 | rects = min_area_rect(xs, ys) # shape = (num_bboxes, 5) 219 | rects = transform_cv_rect(rects) 220 | assert rects.shape == (num_bboxes, 5) 221 | 222 | #represent bboxes using contours 223 | cnts = [] 224 | for bbox_idx in xrange(num_bboxes): 225 | bbox_points = zip(xs[bbox_idx, :], ys[bbox_idx, :]) 226 | cnt = util.img.points_to_contour(bbox_points); 227 | cnts.append(cnt) 228 | 229 | import time 230 | start_time = time.time() 231 | # match anchor to bbox 232 | for anchor_idx in xrange(num_anchors): 233 | anchor = anchors[anchor_idx, :] 234 | acx, acy, aw, ah = anchor 235 | center_point_matched = False 236 | height_matched = False 237 | for bbox_idx in xrange(num_bboxes): 238 | # center point check 239 | center_point_matched = util.img.is_in_contour((acx, acy), cnts[bbox_idx]) 240 | if not center_point_matched: 241 | continue 242 | 243 | # height height_ratio check 244 | rect = rects[bbox_idx, :] 245 | height_ratio = anchor_rect_height_ratio(anchor, rect) 246 | height_matched = height_ratio <= config.max_height_ratio 247 | if height_matched and center_point_matched: 248 | # an anchor can only be matched to at most one bbox 249 | seg_labels[anchor_idx] = bbox_idx 250 | seg_locations[anchor_idx, :] = cal_seg_loc_for_single_anchor(anchor, rect) 251 | 252 | end_time = time.time() 253 | tf.logging.info('Time in For Loop: %f'%(end_time - start_time)) 254 | return seg_labels, seg_locations 255 | 256 | # @util.dec.print_calling_in_short_for_tf 257 | def match_anchor_to_text_boxes_fast(anchors, xs, ys): 258 | """Match anchors to text boxes. 259 | Return: 260 | seg_labels: shape = (N,), the seg_labels of segments. each value is the index of matched box if >=0. 261 | seg_locations: shape = (N, 5), the absolute location of segments. Only the match segments are correctly calculated. 262 | 263 | """ 264 | 265 | assert len(np.shape(anchors)) == 2 and np.shape(anchors)[1] == 4, "the anchors must be a tensor with shape = (num_anchors, 4)" 266 | assert len(np.shape(xs)) == 2 and np.shape(xs) == np.shape(ys) and np.shape(ys)[1] == 4, "the xs, ys must be a tensor with shape = (num_bboxes, 4)" 267 | anchors = np.asarray(anchors, dtype = np.float32) 268 | xs = np.asarray(xs, dtype = np.float32) 269 | ys = np.asarray(ys, dtype = np.float32) 270 | 271 | num_anchors = anchors.shape[0] 272 | seg_labels = np.ones((num_anchors, ), dtype = np.int32) * -1; 273 | seg_locations = np.zeros((num_anchors, 5), dtype = np.float32) 274 | 275 | # to avoid ln(0) in the ending process later. 
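# A short sketch of the height-ratio test used by both matching routines above: an
# anchor is allowed to match a text rect only if their heights differ by at most a
# factor of config.max_height_ratio. The 1.5 threshold below is an assumed example
# value, not the project's actual config.
def example_height_ratio(anchor_h, rect_w, rect_h):
    r = anchor_h / float(min(rect_w, rect_h))    # rect "height" = the shorter side
    return max(r, 1.0 / r)

assert example_height_ratio(24.0, 200.0, 20.0) <= 1.5          # ratio 1.2 -> matched
assert not example_height_ratio(24.0, 200.0, 8.0) <= 1.5       # ratio 3.0 -> rejected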
276 | # because the height and width will be encoded using ln(w_seg / w_anchor) 277 | seg_locations[:, 2] = anchors[:, 2] 278 | seg_locations[:, 3] = anchors[:, 3] 279 | 280 | num_bboxes = xs.shape[0] 281 | 282 | 283 | #represent bboxes using min area rects 284 | rects = min_area_rect(xs, ys) # shape = (num_bboxes, 5) 285 | rects = transform_cv_rect(rects) 286 | assert rects.shape == (num_bboxes, 5) 287 | 288 | # construct a bbox point map: keys are the poistion of all points in bbox contours, and 289 | # value being the bbox index 290 | bbox_mask = np.ones(config.image_shape, dtype = np.int32) * (-1) 291 | for bbox_idx in xrange(num_bboxes): 292 | bbox_points = zip(xs[bbox_idx, :], ys[bbox_idx, :]) 293 | bbox_cnts = util.img.points_to_contours(bbox_points) 294 | util.img.draw_contours(bbox_mask, bbox_cnts, -1, color = bbox_idx, border_width = - 1) 295 | 296 | points_in_bbox_mask = np.where(bbox_mask >= 0) 297 | points_in_bbox_mask = set(zip(*points_in_bbox_mask)) 298 | points_in_bbox_mask = points_in_bbox_mask.intersection(config.default_anchor_center_set) 299 | 300 | for point in points_in_bbox_mask: 301 | anchors_here = config.default_anchor_map[point] 302 | for anchor_idx in anchors_here: 303 | anchor = anchors[anchor_idx, :] 304 | bbox_idx = bbox_mask[point] 305 | acx, acy, aw, ah = anchor 306 | height_matched = False 307 | 308 | # height height_ratio check 309 | rect = rects[bbox_idx, :] 310 | height_ratio = anchor_rect_height_ratio(anchor, rect) 311 | height_matched = height_ratio <= config.max_height_ratio 312 | if height_matched: 313 | # an anchor can only be matched to at most one bbox 314 | seg_labels[anchor_idx] = bbox_idx 315 | seg_locations[anchor_idx, :] = cal_seg_loc_for_single_anchor(anchor, rect) 316 | return seg_labels, seg_locations 317 | 318 | 319 | ############################################################################################################ 320 | # link_gt calculation # 321 | ############################################################################################################ 322 | def reshape_link_gt_by_layer(link_gt): 323 | inter_layer_link_gts = {} 324 | cross_layer_link_gts = {} 325 | 326 | idx = 0; 327 | for layer_idx, layer_name in enumerate(config.feat_layers): 328 | layer_shape = config.feat_shapes[layer_name] 329 | lh, lw = layer_shape 330 | 331 | length = lh * lw * 8; 332 | layer_link_gt = link_gt[idx: idx + length] 333 | idx = idx + length; 334 | layer_link_gt = np.reshape(layer_link_gt, (lh, lw, 8)) 335 | inter_layer_link_gts[layer_name] = layer_link_gt 336 | 337 | for layer_idx in xrange(1, len(config.feat_layers)): 338 | layer_name = config.feat_layers[layer_idx] 339 | layer_shape = config.feat_shapes[layer_name] 340 | lh, lw = layer_shape 341 | length = lh * lw * 4; 342 | layer_link_gt = link_gt[idx: idx + length] 343 | idx = idx + length; 344 | layer_link_gt = np.reshape(layer_link_gt, (lh, lw, 4)) 345 | cross_layer_link_gts[layer_name] = layer_link_gt 346 | 347 | assert idx == len(link_gt) 348 | return inter_layer_link_gts, cross_layer_link_gts 349 | 350 | def reshape_labels_by_layer(labels): 351 | layer_labels = {} 352 | idx = 0; 353 | for layer_name in config.feat_layers: 354 | layer_shape = config.feat_shapes[layer_name] 355 | label_length = np.prod(layer_shape) 356 | 357 | layer_match_result = labels[idx: idx + label_length] 358 | idx = idx + label_length; 359 | 360 | layer_match_result = np.reshape(layer_match_result, layer_shape) 361 | 362 | layer_labels[layer_name] = layer_match_result; 363 | assert idx == len(labels) 364 
| return layer_labels; 365 | 366 | def get_inter_layer_neighbours(x, y): 367 | return [(x - 1, y - 1), (x, y - 1), (x + 1, y - 1), \ 368 | (x - 1, y), (x + 1, y), \ 369 | (x - 1, y + 1), (x, y + 1), (x + 1, y + 1)] 370 | 371 | def get_cross_layer_neighbours(x, y): 372 | return [(2 * x, 2 * y), (2 * x + 1, 2 * y), (2 * x, 2 * y + 1), (2 * x + 1, 2 * y + 1)] 373 | 374 | def is_valid_cord(x, y, w, h): 375 | """ 376 | Tell whether the 2D coordinate (x, y) is valid or not. 377 | If valid, it should be on an h x w image 378 | """ 379 | return x >=0 and x < w and y >= 0 and y < h; 380 | 381 | def cal_link_labels(labels): 382 | layer_labels = reshape_labels_by_layer(labels) 383 | inter_layer_link_gts = [] 384 | cross_layer_link_gts = [] 385 | for layer_idx, layer_name in enumerate(config.feat_layers): 386 | layer_match_result = layer_labels[layer_name] 387 | h, w = config.feat_shapes[layer_name] 388 | 389 | # initalize link groundtruth for the current layer 390 | inter_layer_link_gt = np.ones((h, w, 8), dtype = np.int32) * (-1) 391 | 392 | if layer_idx > 0: # no cross-layer link for the first layer. 393 | cross_layer_link_gt = np.ones((h, w, 4), dtype = np.int32) * (-1) 394 | 395 | for x in xrange(w): 396 | for y in xrange(h): 397 | # the value in layer_match_result stands for the bbox idx a segments matches 398 | # if less than 0, not matched. 399 | # only matched segments are considered in link_gt calculation 400 | if layer_match_result[y, x] >= 0: 401 | matched_idx = layer_match_result[y, x] 402 | 403 | 404 | # inter-layer link_gt calculation 405 | # calculate inter-layer link_gt using the bbox matching result of inter-layer neighbours 406 | neighbours = get_inter_layer_neighbours(x, y) 407 | for nidx, nxy in enumerate(neighbours): # n here is short for neighbour 408 | nx, ny = nxy 409 | if is_valid_cord(nx, ny, w, h): 410 | n_matched_idx = layer_match_result[ny, nx] 411 | # if the current default box has matched the same bbox with this neighbour, \ 412 | # the linkage connecting them is labeled as positive. 413 | if matched_idx == n_matched_idx: 414 | inter_layer_link_gt[y, x, nidx] = n_matched_idx; 415 | 416 | # cross layer link_gt calculation 417 | if layer_idx > 0: 418 | previous_layer_name = config.feat_layers[layer_idx - 1]; 419 | ph, pw = config.feat_shapes[previous_layer_name] 420 | previous_layer_match_result = layer_labels[previous_layer_name] 421 | neighbours = get_cross_layer_neighbours(x, y) 422 | for nidx, nxy in enumerate(neighbours): 423 | nx, ny = nxy 424 | if is_valid_cord(nx, ny, pw, ph): 425 | n_matched_idx = previous_layer_match_result[ny, nx] 426 | if matched_idx == n_matched_idx: 427 | cross_layer_link_gt[y, x, nidx] = n_matched_idx; 428 | 429 | inter_layer_link_gts.append(inter_layer_link_gt) 430 | 431 | if layer_idx > 0: 432 | cross_layer_link_gts.append(cross_layer_link_gt) 433 | 434 | # construct the final link_gt from layer-wise data. 435 | # note that this reshape and concat order is the same with that of predicted linkages, which\ 436 | # has been done in the construction of SegLinkNet. 437 | inter_layer_link_gts = np.hstack([np.reshape(t, -1) for t in inter_layer_link_gts]); 438 | cross_layer_link_gts = np.hstack([np.reshape(t, -1) for t in cross_layer_link_gts]); 439 | link_gt = np.hstack([inter_layer_link_gts, cross_layer_link_gts]) 440 | return link_gt 441 | 442 | # @util.dec.print_calling_in_short_for_tf 443 | def encode_seg_offsets(seg_locs): 444 | """ 445 | Args: 446 | seg_locs: a ndarray with shape = (N, 5). 
It contains the abolute values of segment locations 447 | Return: 448 | seg_offsets, i.e., the offsets from default boxes. It is used as the final segment location ground truth. 449 | """ 450 | anchors = config.default_anchors 451 | anchor_cx, anchor_cy, anchor_w, anchor_h = (anchors[:, idx] for idx in range(4)) 452 | seg_cx, seg_cy, seg_w, seg_h = (seg_locs[:, idx] for idx in range(4)) 453 | 454 | #encoding using the formulations from Euqation (2) to (6) of seglink paper 455 | # seg_cx = anchor_cx + anchor_w * offset_cx 456 | offset_cx = (seg_cx - anchor_cx) * 1.0 / anchor_w 457 | 458 | # seg_cy = anchor_cy + anchor_w * offset_cy 459 | offset_cy = (seg_cy - anchor_cy) * 1.0 / anchor_h 460 | 461 | # seg_w = anchor_w * e^(offset_w) 462 | offset_w = np.log(seg_w * 1.0 / anchor_w) 463 | # seg_h = anchor_w * e^(offset_h) 464 | offset_h = np.log(seg_h * 1.0 / anchor_h) 465 | 466 | # prior scaling can be used to adjust the loss weight of loss on offset x, y, w, h, theta 467 | seg_offsets = np.zeros_like(seg_locs) 468 | seg_offsets[:, 0] = offset_cx / config.prior_scaling[0] 469 | seg_offsets[:, 1] = offset_cy / config.prior_scaling[1] 470 | seg_offsets[:, 2] = offset_w / config.prior_scaling[2] 471 | seg_offsets[:, 3] = offset_h / config.prior_scaling[3] 472 | seg_offsets[:, 4] = seg_locs[:, 4] / config.prior_scaling[4] 473 | return seg_offsets 474 | 475 | def decode_seg_offsets_pred(seg_offsets_pred): 476 | anchors = config.default_anchors 477 | anchor_cx, anchor_cy, anchor_w, anchor_h = (anchors[:, idx] for idx in range(4)) 478 | 479 | offset_cx = seg_offsets_pred[:, 0] * config.prior_scaling[0] 480 | offset_cy = seg_offsets_pred[:, 1] * config.prior_scaling[1] 481 | offset_w = seg_offsets_pred[:, 2] * config.prior_scaling[2] 482 | offset_h = seg_offsets_pred[:, 3] * config.prior_scaling[3] 483 | offset_theta = seg_offsets_pred[:, 4] * config.prior_scaling[4] 484 | 485 | seg_cx = anchor_cx + anchor_w * offset_cx 486 | seg_cy = anchor_cy + anchor_h * offset_cy # anchor_h == anchor_w 487 | seg_w = anchor_w * np.exp(offset_w) 488 | seg_h = anchor_h * np.exp(offset_h) 489 | seg_theta = offset_theta 490 | 491 | seg_loc = np.transpose(np.vstack([seg_cx, seg_cy, seg_w, seg_h, seg_theta])) 492 | return seg_loc 493 | 494 | # @util.dec.print_calling_in_short_for_tf 495 | def get_all_seglink_gt(xs, ys, ignored): 496 | 497 | # calculate ground truths. 
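# A NumPy round-trip check of the encode/decode pair above, with invented anchor,
# segment and prior_scaling values (the real ones live in config): decoding the
# encoded offsets must reproduce the original segment location exactly.
import numpy as np
prior_scaling = [0.1, 0.1, 0.2, 0.2, 0.1]             # assumed, stands in for config.prior_scaling
acx, acy, aw, ah = 100.0, 60.0, 24.0, 24.0            # one square default anchor
seg = np.array([108.0, 63.0, 48.0, 18.0, -20.0])      # cx, cy, w, h, theta
offsets = np.array([(seg[0] - acx) / aw / prior_scaling[0],
                    (seg[1] - acy) / ah / prior_scaling[1],
                    np.log(seg[2] / aw) / prior_scaling[2],
                    np.log(seg[3] / ah) / prior_scaling[3],
                    seg[4] / prior_scaling[4]])
decoded = np.array([acx + aw * offsets[0] * prior_scaling[0],
                    acy + ah * offsets[1] * prior_scaling[1],
                    aw * np.exp(offsets[2] * prior_scaling[2]),
                    ah * np.exp(offsets[3] * prior_scaling[3]),
                    offsets[4] * prior_scaling[4]])
assert np.allclose(decoded, seg)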
498 | # for matching results, i.e., seg_labels and link_labels, the values stands for the 499 | # index of matched bbox 500 | assert len(np.shape(xs)) == 2 and \ 501 | np.shape(xs)[-1] == 4 and \ 502 | np.shape(ys) == np.shape(xs), \ 503 | 'the shape of xs and ys must be (N, 4), but got %s and %s'%(np.shape(xs), np.shape(ys)) 504 | 505 | assert len(xs) == len(ignored), 'the length of xs and `ignored` must be the same, \ 506 | but got %s and %s'%(len(xs), len(ignored)) 507 | 508 | anchors = config.default_anchors 509 | seg_labels, seg_locations = match_anchor_to_text_boxes_fast(anchors, xs, ys); 510 | link_labels = cal_link_labels(seg_labels) 511 | seg_offsets = encode_seg_offsets(seg_locations) 512 | 513 | 514 | # deal with ignored: use -2 to denotes ignored matchings temporarily 515 | def set_ignored_labels(labels, idx): 516 | cords = np.where(labels == idx) 517 | labels[cords] = -2 518 | 519 | ignored_bbox_idxes = np.where(ignored == 1)[0] 520 | for ignored_bbox_idx in ignored_bbox_idxes: 521 | set_ignored_labels(link_labels, ignored_bbox_idx) 522 | set_ignored_labels(seg_labels, ignored_bbox_idx) 523 | 524 | 525 | # deal with bbox idxes: use 1 to replace all matched label 526 | def set_positive_labels_to_one(labels): 527 | cords = np.where(labels >= 0) 528 | labels[cords] = 1 529 | 530 | set_positive_labels_to_one(seg_labels) 531 | set_positive_labels_to_one(link_labels) 532 | 533 | # deal with ignored: use 0 to replace all -2 534 | def set_ignored_labels_to_zero(labels): 535 | cords = np.where(labels == -2) 536 | labels[cords] = 0 537 | 538 | set_ignored_labels_to_zero(seg_labels) 539 | set_ignored_labels_to_zero(link_labels) 540 | 541 | # set dtypes 542 | seg_labels = np.asarray(seg_labels, dtype = np.int32) 543 | seg_offsets = np.asarray(seg_offsets, dtype = np.float32) 544 | link_labels = np.asarray(link_labels, dtype = np.int32) 545 | 546 | return seg_labels, seg_offsets, link_labels 547 | 548 | 549 | def tf_get_all_seglink_gt(xs, ys, ignored): 550 | """ 551 | xs, ys: tensors reprensenting ground truth bbox, both with shape=(N, 4), values in 0~1 552 | """ 553 | h_I, w_I = config.image_shape 554 | 555 | xs = xs * w_I 556 | ys = ys * h_I 557 | seg_labels, seg_offsets, link_labels = tf.py_func(get_all_seglink_gt, [xs, ys, ignored], [tf.int32, tf.float32, tf.int32]); 558 | seg_labels.set_shape([config.num_anchors]) 559 | seg_offsets.set_shape([config.num_anchors, 5]) 560 | link_labels.set_shape([config.num_links]) 561 | return seg_labels, seg_offsets, link_labels; 562 | 563 | ############################################################################################################ 564 | # linking segments together # 565 | ############################################################################################################ 566 | def group_segs(seg_scores, link_scores, seg_conf_threshold, link_conf_threshold): 567 | """ 568 | group segments based on their scores and links. 569 | Return: segment groups as a list, consisting of list of segment indexes, reprensting a group of segments belonging to a same bbox. 
570 | """ 571 | 572 | assert len(np.shape(seg_scores)) == 1 573 | assert len(np.shape(link_scores)) == 1 574 | 575 | valid_segs = np.where(seg_scores >= seg_conf_threshold)[0];# `np.where` returns a tuple 576 | assert valid_segs.ndim == 1 577 | mask = {} 578 | for s in valid_segs: 579 | mask[s] = -1; 580 | 581 | def get_root(idx): 582 | parent = mask[idx] 583 | while parent != -1: 584 | idx = parent 585 | parent = mask[parent] 586 | return idx 587 | 588 | def union(idx1, idx2): 589 | root1 = get_root(idx1) 590 | root2 = get_root(idx2) 591 | 592 | if root1 != root2: 593 | mask[root1] = root2 594 | 595 | def to_list(): 596 | result = {} 597 | for idx in mask: 598 | root = get_root(idx) 599 | if root not in result: 600 | result[root] = [] 601 | 602 | result[root].append(idx) 603 | 604 | return [result[root] for root in result] 605 | 606 | 607 | seg_indexes = np.arange(len(seg_scores)) 608 | layer_seg_indexes = reshape_labels_by_layer(seg_indexes) 609 | 610 | layer_inter_link_scores, layer_cross_link_scores = reshape_link_gt_by_layer(link_scores) 611 | 612 | for layer_index, layer_name in enumerate(config.feat_layers): 613 | layer_shape = config.feat_shapes[layer_name] 614 | lh, lw = layer_shape 615 | layer_seg_index = layer_seg_indexes[layer_name] 616 | layer_inter_link_score = layer_inter_link_scores[layer_name] 617 | if layer_index > 0: 618 | previous_layer_name = config.feat_layers[layer_index - 1] 619 | previous_layer_seg_index = layer_seg_indexes[previous_layer_name] 620 | previous_layer_shape = config.feat_shapes[previous_layer_name] 621 | plh, plw = previous_layer_shape 622 | layer_cross_link_score = layer_cross_link_scores[layer_name] 623 | 624 | 625 | for y in xrange(lh): 626 | for x in xrange(lw): 627 | seg_index = layer_seg_index[y, x] 628 | _seg_score = seg_scores[seg_index] 629 | if _seg_score >= seg_conf_threshold: 630 | 631 | # find inter layer linked neighbours 632 | inter_layer_neighbours = get_inter_layer_neighbours(x, y) 633 | for nidx, nxy in enumerate(inter_layer_neighbours): 634 | nx, ny = nxy 635 | 636 | # the condition of connecting neighbour segment: valid coordinate, 637 | # valid segment confidence and valid link confidence. 
638 | if is_valid_cord(nx, ny, lw, lh) and \ 639 | seg_scores[layer_seg_index[ny, nx]] >= seg_conf_threshold and \ 640 | layer_inter_link_score[y, x, nidx] >= link_conf_threshold: 641 | n_seg_index = layer_seg_index[ny, nx] 642 | union(seg_index, n_seg_index) 643 | 644 | # find cross layer linked neighbours 645 | if layer_index > 0: 646 | cross_layer_neighbours = get_cross_layer_neighbours(x, y) 647 | for nidx, nxy in enumerate(cross_layer_neighbours): 648 | nx, ny = nxy 649 | if is_valid_cord(nx, ny, plw, plh) and \ 650 | seg_scores[previous_layer_seg_index[ny, nx]] >= seg_conf_threshold and \ 651 | layer_cross_link_score[y, x, nidx] >= link_conf_threshold: 652 | 653 | n_seg_index = previous_layer_seg_index[ny, nx] 654 | union(seg_index, n_seg_index) 655 | 656 | return to_list() 657 | 658 | 659 | 660 | ############################################################################################################ 661 | # combining segments to bboxes # 662 | ############################################################################################################ 663 | def tf_seglink_to_bbox(seg_cls_pred, link_cls_pred, seg_offsets_pred, image_shape, 664 | seg_conf_threshold = None, link_conf_threshold = None): 665 | if len(seg_cls_pred.shape) == 3: 666 | assert seg_cls_pred.shape[0] == 1 # only batch_size == 1 supported now TODO 667 | seg_cls_pred = seg_cls_pred[0, ...] 668 | link_cls_pred = link_cls_pred[0, ...] 669 | seg_offsets_pred = seg_offsets_pred[0, ...] 670 | image_shape = image_shape[0, :] 671 | 672 | assert seg_cls_pred.shape[-1] == 2 673 | assert link_cls_pred.shape[-1] == 2 674 | assert seg_offsets_pred.shape[-1] == 5 675 | 676 | seg_scores = seg_cls_pred[:, 1] 677 | link_scores = link_cls_pred[:, 1] 678 | image_bboxes = tf.py_func(seglink_to_bbox, 679 | [seg_scores, link_scores, seg_offsets_pred, image_shape, seg_conf_threshold, link_conf_threshold], 680 | tf.float32); 681 | return image_bboxes 682 | 683 | 684 | def seglink_to_bbox(seg_scores, link_scores, seg_offsets_pred, 685 | image_shape = None, seg_conf_threshold = None, link_conf_threshold = None): 686 | """ 687 | Args: 688 | seg_scores: the scores of segments being positive 689 | link_scores: the scores of linkage being positive 690 | seg_offsets_pred 691 | Return: 692 | bboxes, with shape = (N, 5), and N is the number of predicted bboxes 693 | """ 694 | seg_conf_threshold = seg_conf_threshold or config.seg_conf_threshold 695 | link_conf_threshold = link_conf_threshold or config.link_conf_threshold 696 | if image_shape is None: 697 | image_shape = config.image_shape 698 | 699 | seg_groups = group_segs(seg_scores, link_scores, seg_conf_threshold, link_conf_threshold); 700 | seg_locs = decode_seg_offsets_pred(seg_offsets_pred) 701 | 702 | bboxes = [] 703 | ref_h, ref_w = config.image_shape 704 | for group in seg_groups: 705 | group = [seg_locs[idx, :] for idx in group] 706 | bbox = combine_segs(group) 707 | image_h, image_w = image_shape[0:2] 708 | scale = [image_w * 1.0 / ref_w, image_h * 1.0 / ref_h, image_w * 1.0 / ref_w, image_h * 1.0 / ref_h, 1] 709 | bbox = np.asarray(bbox) * scale 710 | bboxes.append(bbox) 711 | 712 | bboxes = bboxes_to_xys(bboxes, image_shape) 713 | return np.asarray(bboxes, dtype = np.float32) 714 | 715 | def sin(theta): 716 | return np.sin(theta / 180.0 * np.pi) 717 | def cos(theta): 718 | return np.cos(theta / 180.0 * np.pi) 719 | def tan(theta): 720 | return np.tan(theta / 180.0 * np.pi) 721 | 722 | def combine_segs(segs, return_bias = False): 723 | segs = np.asarray(segs) 724 | assert 
segs.ndim == 2 725 | assert segs.shape[-1] == 5 726 | 727 | if len(segs) == 1: 728 | return segs[0, :] 729 | 730 | # find the best straight line fitting all center points: y = kx + b 731 | cxs = segs[:, 0] 732 | cys = segs[:, 1] 733 | 734 | ## the slope 735 | bar_theta = np.mean(segs[:, 4])# average theta 736 | k = tan(bar_theta); 737 | 738 | ## the bias: minimize sum (k*x_i + b - y_i)^2 739 | ### let c_i = k*x_i - y_i 740 | ### sum (k*x_i + b - y_i)^2 = sum(c_i + b)^2 741 | ### = sum(c_i^2 + b^2 + 2 * c_i * b) 742 | ### = n * b^2 + 2* sum(c_i) * b + sum(c_i^2) 743 | ### the target b = - sum(c_i) / n = - mean(c_i) = mean(y_i - k * x_i) 744 | b = np.mean(cys - k * cxs) 745 | 746 | # find the projections of all centers on the straight line 747 | ## firstly, move both the line and centers upward by distance b, so as to make the straight line crossing the point(0, 0): y = kx 748 | ## reprensent the line as a vector (1, k), and the projection of vector(x, y) on (1, k) is: proj = (x + k * y) / sqrt(1 + k^2) 749 | ## the projection point of (x, y) on (1, k) is (proj * cos(theta), proj * sin(theta)) 750 | t_cys = cys - b 751 | projs = (cxs + k * t_cys) / np.sqrt(1 + k**2) 752 | proj_points = np.transpose([projs * cos(bar_theta), projs * sin(bar_theta)]) 753 | 754 | # find the max distance 755 | max_dist = -1; 756 | idx1 = -1; 757 | idx2 = -1; 758 | 759 | for i in xrange(len(proj_points)): 760 | point1 = proj_points[i, :] 761 | for j in xrange(i + 1, len(proj_points)): 762 | point2 = proj_points[j, :] 763 | dist = np.sqrt(np.sum((point1 - point2) ** 2)) 764 | if dist > max_dist: 765 | idx1 = i 766 | idx2 = j 767 | max_dist = dist 768 | assert idx1 >= 0 and idx2 >= 0 769 | # the bbox: bcx, bcy, bw, bh, average_theta 770 | seg1 = segs[idx1, :] 771 | seg2 = segs[idx2, :] 772 | bcx, bcy = (seg1[:2] + seg2[:2]) / 2.0 773 | bh = np.mean(segs[:, 3]) 774 | bw = max_dist + (seg1[2] + seg2[2]) / 2.0 775 | 776 | if return_bias: 777 | return bcx, bcy, bw, bh, bar_theta, b# bias is useful for debugging. 778 | else: 779 | return bcx, bcy, bw, bh, bar_theta 780 | 781 | def bboxes_to_xys(bboxes, image_shape): 782 | """Convert Seglink bboxes to xys, i.e., eight points 783 | The `image_shape` is used to to make sure all points return are valid, i.e., within image area 784 | """ 785 | if len(bboxes) == 0: 786 | return [] 787 | 788 | assert np.ndim(bboxes) == 2 and np.shape(bboxes)[-1] == 5, 'invalid `bboxes` param with shape = ' + str(np.shape(bboxes)) 789 | 790 | h, w = image_shape[0:2] 791 | def get_valid_x(x): 792 | if x < 0: 793 | return 0 794 | if x >= w: 795 | return w - 1 796 | return x 797 | 798 | def get_valid_y(y): 799 | if y < 0: 800 | return 0 801 | if y >= h: 802 | return h - 1 803 | return y 804 | 805 | xys = np.zeros((len(bboxes), 8)) 806 | for bbox_idx, bbox in enumerate(bboxes): 807 | bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3]), bbox[4]) 808 | points = cv2.cv.BoxPoints(bbox) 809 | points = np.int0(points) 810 | for i_xy, (x, y) in enumerate(points): 811 | x = get_valid_x(x) 812 | y = get_valid_y(y) 813 | points[i_xy, :] = [x, y] 814 | points = np.reshape(points, -1) 815 | xys[bbox_idx, :] = points 816 | return xys -------------------------------------------------------------------------------- /train_seglink.py: -------------------------------------------------------------------------------- 1 | #test code to make sure the ground truth calculation and data batch works well. 
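# Despite the note above, this file is the training entry point for SegLink. A hedged
# example invocation (the dataset name and all paths below are placeholders; the flags
# themselves are defined further down with tf.app.flags):
#
#   python train_seglink.py \
#       --train_dir=./checkpoints/seglink_icdar2015 \
#       --dataset_name=icdar2015 --dataset_split_name=train \
#       --dataset_dir=./datasets/ICDAR2015-tfrecords \
#       --batch_size=8 --num_gpus=1 \
#       --checkpoint_path=./checkpoints/vgg_16.ckpt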
2 | 3 | import numpy as np 4 | import tensorflow as tf # test 5 | from tensorflow.python.ops import control_flow_ops 6 | 7 | from datasets import dataset_factory 8 | from preprocessing import ssd_vgg_preprocessing 9 | from tf_extended import seglink 10 | import util 11 | import cv2 12 | from nets import seglink_symbol, anchor_layer 13 | 14 | 15 | slim = tf.contrib.slim 16 | import config 17 | # =========================================================================== # 18 | # Checkpoint and running Flags 19 | # =========================================================================== # 20 | tf.app.flags.DEFINE_bool('train_with_ignored', False, 21 | 'whether to use ignored bbox (in ic15) in training.') 22 | tf.app.flags.DEFINE_float('seg_loc_loss_weight', 1.0, 'the loss weight of segment localization') 23 | tf.app.flags.DEFINE_float('link_cls_loss_weight', 1.0, 'the loss weight of linkage classification loss') 24 | 25 | tf.app.flags.DEFINE_string('train_dir', None, 26 | 'the path to store checkpoints and eventfiles for summaries') 27 | 28 | tf.app.flags.DEFINE_string('checkpoint_path', None, 29 | 'the path of pretrained model to be used. If there are checkpoints in train_dir, this config will be ignored.') 30 | 31 | tf.app.flags.DEFINE_float('gpu_memory_fraction', -1, 32 | 'the gpu memory fraction to be used. If less than 0, allow_growth = True is used.') 33 | 34 | tf.app.flags.DEFINE_integer('batch_size', None, 'The number of samples in each batch.') 35 | tf.app.flags.DEFINE_integer('num_gpus', 1, 'The number of gpus can be used.') 36 | tf.app.flags.DEFINE_integer('max_number_of_steps', 1000000, 'The maximum number of training steps.') 37 | tf.app.flags.DEFINE_integer('log_every_n_steps', 1, 'log frequency') 38 | tf.app.flags.DEFINE_bool("ignore_missing_vars", True, '') 39 | tf.app.flags.DEFINE_string('checkpoint_exclude_scopes', None, 'checkpoint_exclude_scopes') 40 | 41 | # =========================================================================== # 42 | # Optimizer configs. 43 | # =========================================================================== # 44 | tf.app.flags.DEFINE_float('learning_rate', 0.001, 'learning rate.') 45 | tf.app.flags.DEFINE_float('momentum', 0.9, 'The momentum for the MomentumOptimizer') 46 | tf.app.flags.DEFINE_float('weight_decay', 0.0005, 'The weight decay on the model weights.') 47 | tf.app.flags.DEFINE_bool('using_moving_average', False, 'Whether to use ExponentionalMovingAverage') 48 | tf.app.flags.DEFINE_float('moving_average_decay', 0.9999, 'The decay rate of ExponentionalMovingAverage') 49 | 50 | # =========================================================================== # 51 | # I/O and preprocessing Flags. 52 | # =========================================================================== # 53 | tf.app.flags.DEFINE_integer( 54 | 'num_readers', 1, 55 | 'The number of parallel readers that read data from the dataset.') 56 | tf.app.flags.DEFINE_integer( 57 | 'num_preprocessing_threads', 1, 58 | 'The number of threads used to create the batches.') 59 | 60 | # =========================================================================== # 61 | # Dataset Flags. 
62 | # =========================================================================== # 63 | tf.app.flags.DEFINE_string( 64 | 'dataset_name', None, 'The name of the dataset to load.') 65 | tf.app.flags.DEFINE_string( 66 | 'dataset_split_name', 'train', 'The name of the train/test split.') 67 | tf.app.flags.DEFINE_string( 68 | 'dataset_dir', None, 'The directory where the dataset files are stored.') 69 | tf.app.flags.DEFINE_string( 70 | 'model_name', 'seglink_vgg', 'The name of the architecture to train.') 71 | tf.app.flags.DEFINE_integer('train_image_width', 512, 'Train image size') 72 | tf.app.flags.DEFINE_integer('train_image_height', 512, 'Train image size') 73 | 74 | 75 | FLAGS = tf.app.flags.FLAGS 76 | 77 | def config_initialization(): 78 | # image shape and feature layers shape inference 79 | image_shape = (FLAGS.train_image_height, FLAGS.train_image_width) 80 | 81 | if not FLAGS.dataset_dir: 82 | raise ValueError('You must supply the dataset directory with --dataset_dir') 83 | tf.logging.set_verbosity(tf.logging.DEBUG) 84 | util.init_logger(log_file = 'log_train_seglink_%d_%d.log'%image_shape, log_path = FLAGS.train_dir, stdout = False, mode = 'a') 85 | 86 | 87 | config.init_config(image_shape, 88 | batch_size = FLAGS.batch_size, 89 | weight_decay = FLAGS.weight_decay, 90 | num_gpus = FLAGS.num_gpus, 91 | train_with_ignored = FLAGS.train_with_ignored, 92 | seg_loc_loss_weight = FLAGS.seg_loc_loss_weight, 93 | link_cls_loss_weight = FLAGS.link_cls_loss_weight, 94 | ) 95 | 96 | batch_size = config.batch_size 97 | batch_size_per_gpu = config.batch_size_per_gpu 98 | 99 | tf.summary.scalar('batch_size', batch_size) 100 | tf.summary.scalar('batch_size_per_gpu', batch_size_per_gpu) 101 | 102 | util.proc.set_proc_name(FLAGS.model_name + '_' + FLAGS.dataset_name) 103 | 104 | dataset = dataset_factory.get_dataset(FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir) 105 | config.print_config(FLAGS, dataset) 106 | return dataset 107 | 108 | def create_dataset_batch_queue(dataset): 109 | with tf.device('/cpu:0'): 110 | with tf.name_scope(FLAGS.dataset_name + '_data_provider'): 111 | provider = slim.dataset_data_provider.DatasetDataProvider( 112 | dataset, 113 | num_readers=FLAGS.num_readers, 114 | common_queue_capacity=50 * config.batch_size, 115 | common_queue_min=30 * config.batch_size, 116 | shuffle=True) 117 | # Get for SSD network: image, labels, bboxes. 118 | [image, gignored, gbboxes, x1, x2, x3, x4, y1, y2, y3, y4] = provider.get([ 119 | 'image', 120 | 'object/ignored', 121 | 'object/bbox', 122 | 'object/oriented_bbox/x1', 123 | 'object/oriented_bbox/x2', 124 | 'object/oriented_bbox/x3', 125 | 'object/oriented_bbox/x4', 126 | 'object/oriented_bbox/y1', 127 | 'object/oriented_bbox/y2', 128 | 'object/oriented_bbox/y3', 129 | 'object/oriented_bbox/y4' 130 | ]) 131 | gxs = tf.transpose(tf.stack([x1, x2, x3, x4])) #shape = (N, 4) 132 | gys = tf.transpose(tf.stack([y1, y2, y3, y4])) 133 | image = tf.identity(image, 'input_image') 134 | 135 | # Pre-processing image, labels and bboxes. 
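            # The preprocessing below resizes/augments the image to config.image_shape and
            # transforms gignored, gbboxes, gxs and gys accordingly, so that the seglink
            # ground truth computed from them (seg_label, seg_loc, link_label) matches the
            # processed image.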
136 | image, gignored, gbboxes, gxs, gys = ssd_vgg_preprocessing.preprocess_image(image, gignored, gbboxes, gxs, gys, 137 | out_shape = config.image_shape, 138 | data_format = config.data_format, 139 | is_training = True) 140 | image = tf.identity(image, 'processed_image') 141 | 142 | # calculate ground truth 143 | seg_label, seg_loc, link_label = seglink.tf_get_all_seglink_gt(gxs, gys, gignored) 144 | 145 | # batch them 146 | b_image, b_seg_label, b_seg_loc, b_link_label = tf.train.batch( 147 | [image, seg_label, seg_loc, link_label], 148 | batch_size = config.batch_size_per_gpu, 149 | num_threads= FLAGS.num_preprocessing_threads, 150 | capacity = 50) 151 | 152 | batch_queue = slim.prefetch_queue.prefetch_queue( 153 | [b_image, b_seg_label, b_seg_loc, b_link_label], 154 | capacity = 50) 155 | return batch_queue 156 | 157 | def sum_gradients(clone_grads): 158 | averaged_grads = [] 159 | for grad_and_vars in zip(*clone_grads): 160 | grads = [] 161 | var = grad_and_vars[0][1] 162 | for g, v in grad_and_vars: 163 | assert v == var 164 | grads.append(g) 165 | grad = tf.add_n(grads, name = v.op.name + '_summed_gradients') 166 | averaged_grads.append((grad, v)) 167 | 168 | tf.summary.histogram("variables_and_gradients_" + grad.op.name, grad) 169 | tf.summary.histogram("variables_and_gradients_" + v.op.name, v) 170 | tf.summary.scalar("variables_and_gradients_" + grad.op.name+'_mean/var_mean', tf.reduce_mean(grad)/tf.reduce_mean(var)) 171 | tf.summary.scalar("variables_and_gradients_" + v.op.name+'_mean', tf.reduce_mean(var)) 172 | return averaged_grads 173 | 174 | 175 | def create_clones(batch_queue): 176 | with tf.device('/cpu:0'): 177 | global_step = slim.create_global_step() 178 | learning_rate = tf.constant(FLAGS.learning_rate, name='learning_rate') 179 | tf.summary.scalar('learning_rate', learning_rate) 180 | optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=FLAGS.momentum, name='Momentum') 181 | 182 | # place clones 183 | seglink_loss = 0; # for summary only 184 | gradients = [] 185 | for clone_idx, gpu in enumerate(config.gpus): 186 | do_summary = clone_idx == 0 # only summary on the first clone 187 | with tf.variable_scope(tf.get_variable_scope(), reuse = True):# the variables has been created in config.init_config 188 | with tf.name_scope(config.clone_scopes[clone_idx]) as clone_scope: 189 | with tf.device(gpu) as clone_device: 190 | b_image, b_seg_label, b_seg_loc, b_link_label = batch_queue.dequeue() 191 | net = seglink_symbol.SegLinkNet(inputs = b_image, data_format = config.data_format) 192 | 193 | # build seglink loss 194 | net.build_loss(seg_labels = b_seg_label, 195 | seg_offsets = b_seg_loc, 196 | link_labels = b_link_label, 197 | do_summary = do_summary) 198 | 199 | 200 | # gather seglink losses 201 | losses = tf.get_collection(tf.GraphKeys.LOSSES, clone_scope) 202 | assert len(losses) == 3 # 3 is the number of seglink losses: seg_cls, seg_loc, link_cls 203 | total_clone_loss = tf.add_n(losses) / config.num_clones 204 | seglink_loss = seglink_loss + total_clone_loss 205 | 206 | # gather regularization loss and add to clone_0 only 207 | if clone_idx == 0: 208 | regularization_loss = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) 209 | total_clone_loss = total_clone_loss + regularization_loss 210 | 211 | # compute clone gradients 212 | clone_gradients = optimizer.compute_gradients(total_clone_loss)# all variables will be updated. 
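                    # total_clone_loss was already divided by config.num_clones above, so summing
                    # the per-clone gradients later in sum_gradients() is equivalent to averaging
                    # them across GPUs.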
213 | gradients.append(clone_gradients) 214 | 215 | tf.summary.scalar('seglink_loss', seglink_loss) 216 | tf.summary.scalar('regularization_loss', regularization_loss) 217 | 218 | # add all gradients together 219 | # note that the gradients do not need to be averaged, because the average operation has been done on loss. 220 | averaged_gradients = sum_gradients(gradients) 221 | 222 | update_op = optimizer.apply_gradients(averaged_gradients, global_step=global_step) 223 | 224 | train_ops = [update_op] 225 | 226 | # moving average 227 | if FLAGS.using_moving_average: 228 | tf.logging.info('using moving average in training, \ 229 | with decay = %f'%(FLAGS.moving_average_decay)) 230 | ema = tf.train.ExponentialMovingAverage(FLAGS.moving_average_decay) 231 | ema_op = ema.apply(tf.trainable_variables()) 232 | with tf.control_dependencies([update_op]):# ema after updating 233 | train_ops.append(tf.group(ema_op)) 234 | 235 | train_op = control_flow_ops.with_dependencies(train_ops, seglink_loss, name='train_op') 236 | return train_op 237 | 238 | 239 | 240 | def train(train_op): 241 | summary_op = tf.summary.merge_all() 242 | sess_config = tf.ConfigProto(log_device_placement = False, allow_soft_placement = True) 243 | if FLAGS.gpu_memory_fraction < 0: 244 | sess_config.gpu_options.allow_growth = True 245 | elif FLAGS.gpu_memory_fraction > 0: 246 | sess_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_memory_fraction; 247 | 248 | init_fn = util.tf.get_init_fn(checkpoint_path = FLAGS.checkpoint_path, train_dir = FLAGS.train_dir, 249 | ignore_missing_vars = FLAGS.ignore_missing_vars, checkpoint_exclude_scopes = FLAGS.checkpoint_exclude_scopes) 250 | saver = tf.train.Saver(max_to_keep = 500, write_version = 2) 251 | slim.learning.train( 252 | train_op, 253 | logdir = FLAGS.train_dir, 254 | init_fn = init_fn, 255 | summary_op = summary_op, 256 | number_of_steps = FLAGS.max_number_of_steps, 257 | log_every_n_steps = FLAGS.log_every_n_steps, 258 | save_summaries_secs = 60, 259 | saver = saver, 260 | save_interval_secs = 1200, 261 | session_config = sess_config 262 | ) 263 | 264 | 265 | def main(_): 266 | # The choice of return dataset object via initialization method maybe confusing, 267 | # but I need to print all configurations in this method, including dataset information. 
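    # Overall pipeline: parse flags and init config -> build the CPU-side input batch
    # queue -> build per-GPU clones and the combined train_op -> run slim.learning.train.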
268 | dataset = config_initialization() 269 | 270 | batch_queue = create_dataset_batch_queue(dataset) 271 | train_op = create_clones(batch_queue) 272 | train(train_op) 273 | 274 | 275 | if __name__ == '__main__': 276 | tf.app.run() 277 | -------------------------------------------------------------------------------- /visualize_detection_result.py: -------------------------------------------------------------------------------- 1 | #encoding utf-8 2 | 3 | import numpy as np 4 | import util 5 | 6 | 7 | def draw_bbox(image_data, line, color): 8 | line = util.str.remove_all(line, '\xef\xbb\xbf') 9 | data = line.split(','); 10 | 11 | def draw_rectangle(): 12 | x1, y1, x2, y2 = [int(v) for v in data[0 : 4]] 13 | util.img.rectangle( 14 | img = image_data, 15 | left_up = (x1, y1), 16 | right_bottom = (x2, y2), 17 | color = color, 18 | border_width = 1) 19 | 20 | def draw_text(): 21 | text = data[-1] 22 | pos = [int(v) for v in data[0:2]] 23 | util.img.put_text( 24 | img = image_data, 25 | text = text, 26 | pos = pos, 27 | scale = 1, 28 | color = color) 29 | def draw_oriented_bbox(): 30 | points = [int(v) for v in data[0:8]] 31 | points = np.reshape(points, (4, 2)) 32 | cnts = util.img.points_to_contours(points) 33 | util.img.draw_contours(image_data, cnts, -1, color = color, border_width = 1) 34 | 35 | draw_oriented_bbox() 36 | # if len(data) == 5: # ic13 gt 37 | # draw_rectangle() 38 | # draw_text() 39 | # elif len(data) == 8:# all det 40 | # draw_oriented_bbox() 41 | # elif len(data) == 9: # ic15 gt 42 | # draw_oriented_bbox() 43 | # draw_text() 44 | # else: 45 | # import pdb 46 | # pdb.set_trace() 47 | # print data 48 | # raise ValueError 49 | 50 | def visualize(image_root, det_root, output_root, gt_root = None): 51 | def read_gt_file(image_name): 52 | gt_file = util.io.join_path(gt_root, 'gt_%s.txt'%(image_name)) 53 | return util.io.read_lines(gt_file) 54 | 55 | def read_det_file(image_name): 56 | det_file = util.io.join_path(det_root, 'res_%s.txt'%(image_name)) 57 | return util.io.read_lines(det_file) 58 | 59 | def read_image_file(image_name): 60 | return util.img.imread(util.io.join_path(image_root, image_name)) 61 | 62 | image_names = util.io.ls(image_root, '.jpg') 63 | for image_idx, image_name in enumerate(image_names): 64 | 65 | print '%d / %d: %s'%(image_idx + 1, len(image_names), image_name) 66 | image_data = read_image_file(image_name) # in BGR 67 | image_name = image_name.split('.')[0] 68 | 69 | 70 | det_image = image_data.copy() 71 | det_lines = read_det_file(image_name) 72 | for line in det_lines: 73 | draw_bbox(det_image, line, color = util.img.COLOR_BGR_RED) 74 | util.img.imwrite(util.io.join_path(output_root, '%s_pred.jpg'%(image_name)), det_image) 75 | 76 | if gt_root is not None: 77 | gt_lines = read_gt_file(image_name) 78 | for line in gt_lines: 79 | draw_bbox(image_data, line, color = util.img.COLOR_GREEN) 80 | util.img.imwrite(util.io.join_path(output_root, '%s_gt.jpg'%(image_name)), image_data) 81 | 82 | if __name__ == '__main__': 83 | import argparse 84 | parser = argparse.ArgumentParser(description='visualize detection result of seglink') 85 | parser.add_argument('--image', type=str, required = True,help='the directory of test image') 86 | parser.add_argument('--gt', type=str, default=None,help='the directory of ground truth txt files') 87 | parser.add_argument('--det', type=str, required = True, help='the directory of detection result') 88 | parser.add_argument('--output', type=str, required = True, help='the directory to store images with bboxes') 89 | 90 | args = 
parser.parse_args() 91 | print('**************Arguments*****************') 92 | print(args) 93 | print('****************************************') 94 | visualize(image_root = args.image, gt_root = args.gt, det_root = args.det, output_root = args.output) 95 | --------------------------------------------------------------------------------
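For reference, a minimal sketch of driving the visualization above programmatically rather than from the command line. All paths are hypothetical placeholders; `visualize()` expects `*.jpg` images, detection files named `res_<image>.txt`, and (optionally) ground-truth files named `gt_<image>.txt`, as read by the helpers defined in `visualize_detection_result.py`.

```python
# Hypothetical paths -- adjust to wherever the test images and detection results live.
from visualize_detection_result import visualize

visualize(image_root='/data/icdar2015/test_images',
          det_root='./det_results',       # directory of res_<image>.txt detection files
          output_root='./vis_output',     # writes <image>_pred.jpg (and <image>_gt.jpg if gt_root is given)
          gt_root=None)                   # or a directory of gt_<image>.txt files
```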