├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── config.py ├── datasets ├── __init__.py ├── dataset_factory.py ├── dataset_utils.py ├── icdar2015_to_tfrecords.py └── synthtext_to_tfrecords.py ├── nets ├── __init__.py ├── pixel_link_symbol.py └── vgg.py ├── pixel_link.py ├── pixel_link_decode.pyx ├── pixel_link_env.txt ├── preprocessing ├── __init__.py ├── preprocessing_factory.py ├── ssd_vgg_preprocessing.py └── tf_image.py ├── samples ├── img_249_pred.jpg └── img_333_pred.jpg ├── scripts ├── test.sh ├── test_any.sh ├── train.sh └── vis.sh ├── test_pixel_link.py ├── test_pixel_link_on_any_image.py ├── tf_extended ├── __init__.py ├── bboxes.py ├── math.py └── metrics.py ├── train_pixel_link.py └── visualize_detection_result.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "pylib"] 2 | path = pylib 3 | url = git@github.com:dengdan/pylib.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 ZJULearning 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial 
portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Code for the AAAI18 paper [PixelLink: Detecting Scene Text via Instance Segmentation](https://arxiv.org/abs/1801.01315), by Dan Deng, Haifeng Liu, Xuelong Li, and Deng Cai.
2 | 
3 | Contributions to this repo are welcome, e.g., some other backbone networks (including the model definition and pretrained models).
4 | 
5 | PLEASE CHECK EXISTING ISSUES BEFORE OPENING YOUR OWN. IF THE SAME OR A SIMILAR ISSUE HAS BEEN POSTED BEFORE, JUST REFER TO IT, AND DO NOT OPEN A NEW ONE.
6 | 
7 | # Installation
8 | ## Clone the repo
9 | ```
10 | git clone --recursive git@github.com:ZJULearning/pixel_link.git
11 | ```
12 | 
13 | Denote the root directory path of pixel_link by `${pixel_link_root}`.
14 | 
15 | Add the path of `${pixel_link_root}/pylib/src` to your `PYTHONPATH`:
16 | ```
17 | export PYTHONPATH=${pixel_link_root}/pylib/src:$PYTHONPATH
18 | ```
19 | 
20 | ## Prerequisites
21 | (Only tested on) Ubuntu 14.04 and 16.04 with:
22 | * Python 2.7
23 | * Tensorflow-gpu >= 1.1
24 | * opencv2
25 | * setproctitle
26 | * matplotlib
27 | 
28 | Anaconda is recommended for an easier installation:
29 | 
30 | 1. Install [Anaconda](https://anaconda.org/)
31 | 2. Create and activate the required virtual environment by:
32 | ```
33 | conda env create --file pixel_link_env.txt
34 | source activate pixel_link
35 | ```
36 | 
37 | # Testing
38 | ## Download the pretrained model
39 | * PixelLink + VGG16 4s [Baidu Netdisk](https://pan.baidu.com/s/1jsOc-cutC4GyF-wMMyj5-w) | [GoogleDrive](https://drive.google.com/file/d/19mlX5W8OBalSjhf5oTTS6qEq2eAU8Tg9/view?usp=sharing), trained on IC15
40 | * PixelLink + VGG16 2s [Baidu Netdisk](https://pan.baidu.com/s/1asSFsRSgviU2GnvGt2lAUw) | [GoogleDrive](https://drive.google.com/file/d/1QleZxu_6PSI733G7wzbqeFtc8A3-LmWW/view?usp=sharing), trained on IC15
41 | 
42 | Unzip the downloaded model. It contains 4 files:
43 | 
44 | * config.py
45 | * model.ckpt-xxx.data-00000-of-00001
46 | * model.ckpt-xxx.index
47 | * model.ckpt-xxx.meta
48 | 
49 | Denote their parent directory as `${model_path}`.
50 | 
51 | ## Test on ICDAR2015
52 | The reported results on ICDAR2015 are:
53 | 
54 | |Model|Recall|Precision|F-mean|
55 | |---|---|---|---|
56 | |PixelLink+VGG16 2s|82.0|85.5|83.7|
57 | |PixelLink+VGG16 4s|81.7|82.9|82.3|
58 | 
59 | Assuming you have downloaded the [ICDAR2015 dataset](http://rrc.cvc.uab.es/?ch=4&com=downloads), execute the following commands to test the model on ICDAR2015:
60 | ```
61 | cd ${pixel_link_root}
62 | ./scripts/test.sh ${GPU_ID} ${model_path}/model.ckpt-xxx ${path_to_icdar2015}/ch4_test_images
63 | ```
64 | For example:
65 | ```
66 | ./scripts/test.sh 3 ~/temp/conv3_3/model.ckpt-38055 ~/dataset/ICDAR2015/Challenge4/ch4_test_images
67 | ```
68 | 
69 | The program will create a zip file of detection results, which can be submitted to the ICDAR2015 server directly.
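For reference, the F-mean column in the table above is simply the harmonic mean of recall and precision; a quick sanity check (illustrative snippet, not part of this repo):
```python
# Illustrative only: recompute the F-mean values reported in the table above.
def f_mean(recall, precision):
    return 2.0 * recall * precision / (recall + precision)

print(round(f_mean(82.0, 85.5), 1))  # PixelLink+VGG16 2s -> 83.7
print(round(f_mean(81.7, 82.9), 1))  # PixelLink+VGG16 4s -> 82.3
```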
70 | The detection results can be visualized via `scripts/vis.sh`.
71 | 
72 | Here are some samples:
73 | ![./samples/img_333_pred.jpg](./samples/img_333_pred.jpg)
74 | ![./samples/img_249_pred.jpg](./samples/img_249_pred.jpg)
75 | 
76 | 
77 | ## Test on any images
78 | Put the images to be tested in a single directory, denoted by `${image_dir}`. Then:
79 | ```
80 | cd ${pixel_link_root}
81 | ./scripts/test_any.sh ${GPU_ID} ${model_path}/model.ckpt-xxx ${image_dir}
82 | ```
83 | For example:
84 | ```
85 | ./scripts/test_any.sh 3 ~/temp/conv3_3/model.ckpt-38055 ~/dataset/ICDAR2015/Challenge4/ch4_training_images
86 | ```
87 | 
88 | The program will visualize the detection results directly on the images. If the detection results are not satisfactory, try to:
89 | 
90 | 1. Adjust the inference parameters like `eval_image_width`, `eval_image_height`, `pixel_conf_threshold`, `link_conf_threshold`.
91 | 2. Or train your own model.
92 | 
93 | # Training
94 | ## Converting the dataset to tfrecords files
95 | Scripts for converting the ICDAR2015 and SynthText datasets are provided in the `datasets` directory.
96 | It is not hard to write a conversion script for your own dataset.
97 | 
98 | ## Train your own model
99 | 
100 | * Modify `scripts/train.sh` to configure your dataset name and dataset path like:
101 | ```
102 | DATASET=icdar2015
103 | DATASET_DIR=$HOME/dataset/pixel_link/icdar2015
104 | ```
105 | * Start training
106 | ```
107 | ./scripts/train.sh ${GPU_IDs} ${IMG_PER_GPU}
108 | ```
109 | For example, `./scripts/train.sh 0,1,2 8`.
110 | 
111 | The existing training strategy in `scripts/train.sh` is configured for icdar2015; modify it if necessary. Many training and model options are available in `config.py`; try them if you are interested.
112 | 
113 | # Acknowledgement
114 | ![](http://www.cad.zju.edu.cn/templets/default/imgzd/logo.jpg)
115 | ![](http://www.cvte.com/images/logo.png)
116 | 
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from pprint import pprint
3 | import numpy as np
4 | from tensorflow.contrib.slim.python.slim.data import parallel_reader
5 | import tensorflow as tf
6 | import util
7 | from nets import pixel_link_symbol
8 | import pixel_link
9 | slim = tf.contrib.slim
10 | 
11 | #=====================================================================
12 | #====================Pre-processing params START======================
13 | # VGG mean parameters.
14 | r_mean = 123.
15 | g_mean = 117.
16 | b_mean = 104.
17 | rgb_mean = [r_mean, g_mean, b_mean]
18 | 
19 | # scale, crop, filtering and resize parameters
20 | use_rotation = True
21 | rotation_prob = 0.5
22 | max_expand_scale = 1
23 | expand_prob = 0
24 | min_object_covered = 0.1 # Minimum object to be cropped in random crop.
25 | bbox_crop_overlap = 0.2 # Minimum overlap to keep a bbox after cropping.
26 | crop_aspect_ratio_range = (0.5, 2.) # Distortion ratio during cropping.
27 | area_range = [0.1, 1] 28 | flip = False 29 | using_shorter_side_filtering=True 30 | min_shorter_side = 10 31 | max_shorter_side = np.infty 32 | #====================Pre-processing params END======================== 33 | #===================================================================== 34 | 35 | 36 | 37 | 38 | #===================================================================== 39 | #====================Post-processing params START===================== 40 | decode_method = pixel_link.DECODE_METHOD_join 41 | min_area = 300 42 | min_height = 10 43 | #====================Post-processing params END======================= 44 | #===================================================================== 45 | 46 | 47 | 48 | #===================================================================== 49 | #====================Training and model params START ================= 50 | dropout_ratio = 0 51 | max_neg_pos_ratio = 3 52 | 53 | feat_fuse_type = pixel_link_symbol.FUSE_TYPE_cascade_conv1x1_upsample_sum 54 | # feat_fuse_type = pixel_link_symbol.FUSE_TYPE_cascade_conv1x1_128_upsamle_sum_conv1x1_2 55 | # feat_fuse_type = pixel_link_symbol.FUSE_TYPE_cascade_conv1x1_128_upsamle_concat_conv1x1_2 56 | 57 | pixel_neighbour_type = pixel_link.PIXEL_NEIGHBOUR_TYPE_8 58 | #pixel_neighbour_type = pixel_link.PIXEL_NEIGHBOUR_TYPE_4 59 | 60 | 61 | #model_type = pixel_link_symbol.MODEL_TYPE_vgg16 62 | #feat_layers = ['conv2_2', 'conv3_3', 'conv4_3', 'conv5_3', 'fc7'] 63 | #strides = [2] 64 | model_type = pixel_link_symbol.MODEL_TYPE_vgg16 65 | feat_layers = ['conv3_3', 'conv4_3', 'conv5_3', 'fc7'] 66 | strides = [4] 67 | 68 | pixel_cls_weight_method = pixel_link.PIXEL_CLS_WEIGHT_bbox_balanced 69 | bbox_border_width = 1 70 | pixel_cls_border_weight_lambda = 1.0 71 | pixel_cls_loss_weight_lambda = 2.0 72 | pixel_link_neg_loss_weight_lambda = 1.0 73 | pixel_link_loss_weight = 1.0 74 | #====================Training and model params END ================== 75 | #===================================================================== 76 | 77 | 78 | #===================================================================== 79 | #====================do-not-change configurations START=============== 80 | num_classes = 2 81 | ignore_label = -1 82 | background_label = 0 83 | text_label = 1 84 | data_format = 'NHWC' 85 | train_with_ignored = False 86 | #====================do-not-change configurations END================= 87 | #===================================================================== 88 | 89 | global weight_decay 90 | 91 | global train_image_shape 92 | global image_shape 93 | global score_map_shape 94 | 95 | global batch_size 96 | global batch_size_per_gpu 97 | global gpus 98 | global num_clones 99 | global clone_scopes 100 | 101 | global num_neighbours 102 | 103 | global pixel_conf_threshold 104 | global link_conf_threshold 105 | 106 | def _set_weight_decay(wd): 107 | global weight_decay 108 | weight_decay = wd 109 | 110 | def _set_image_shape(shape): 111 | h, w = shape 112 | global train_image_shape 113 | global score_map_shape 114 | global image_shape 115 | 116 | assert w % 4 == 0 117 | assert h % 4 == 0 118 | 119 | train_image_shape = [h, w] 120 | score_map_shape = (h / strides[0], w / strides[0]) 121 | image_shape = train_image_shape 122 | 123 | def _set_batch_size(bz): 124 | global batch_size 125 | batch_size = bz 126 | 127 | def _set_seg_th(pixel_conf_th, link_conf_th): 128 | global pixel_conf_threshold 129 | global link_conf_threshold 130 | 131 | pixel_conf_threshold = pixel_conf_th 132 | link_conf_threshold 
= link_conf_th 133 | 134 | 135 | def _set_train_with_ignored(train_with_ignored_): 136 | global train_with_ignored 137 | train_with_ignored = train_with_ignored_ 138 | 139 | 140 | def init_config(image_shape, batch_size = 1, 141 | weight_decay = 0.0005, 142 | num_gpus = 1, 143 | pixel_conf_threshold = 0.6, 144 | link_conf_threshold = 0.9): 145 | _set_seg_th(pixel_conf_threshold, link_conf_threshold) 146 | _set_weight_decay(weight_decay) 147 | _set_image_shape(image_shape) 148 | 149 | #init batch size 150 | global gpus 151 | gpus = util.tf.get_available_gpus(num_gpus) 152 | 153 | global num_clones 154 | num_clones = len(gpus) 155 | 156 | global clone_scopes 157 | clone_scopes = ['clone_%d'%(idx) for idx in xrange(num_clones)] 158 | 159 | _set_batch_size(batch_size) 160 | 161 | global batch_size_per_gpu 162 | batch_size_per_gpu = batch_size / num_clones 163 | if batch_size_per_gpu < 1: 164 | raise ValueError('Invalid batch_size [=%d], \ 165 | resulting in 0 images per gpu.'%(batch_size)) 166 | 167 | global num_neighbours 168 | num_neighbours = pixel_link.get_neighbours_fn()[1] 169 | 170 | 171 | def print_config(flags, dataset, save_dir = None, print_to_file = True): 172 | def do_print(stream=None): 173 | print(util.log.get_date_str(), file = stream) 174 | print('\n# =========================================================================== #', file=stream) 175 | print('# Training flags:', file=stream) 176 | print('# =========================================================================== #', file=stream) 177 | 178 | def print_ckpt(path): 179 | ckpt = util.tf.get_latest_ckpt(path) 180 | if ckpt is not None: 181 | print('Resume Training from : %s'%(ckpt), file = stream) 182 | return True 183 | return False 184 | 185 | if not print_ckpt(flags.train_dir): 186 | print_ckpt(flags.checkpoint_path) 187 | 188 | pprint(flags.__flags, stream=stream) 189 | 190 | print('\n# =========================================================================== #', file=stream) 191 | print('# pixel_link net parameters:', file=stream) 192 | print('# =========================================================================== #', file=stream) 193 | vars = globals() 194 | for key in vars: 195 | var = vars[key] 196 | if util.dtype.is_number(var) or util.dtype.is_str(var) or util.dtype.is_list(var) or util.dtype.is_tuple(var): 197 | pprint('%s=%s'%(key, str(var)), stream = stream) 198 | 199 | print('\n# =========================================================================== #', file=stream) 200 | print('# Training | Evaluation dataset files:', file=stream) 201 | print('# =========================================================================== #', file=stream) 202 | data_files = parallel_reader.get_data_files(dataset.data_sources) 203 | pprint(sorted(data_files), stream=stream) 204 | print('', file=stream) 205 | do_print(None) 206 | 207 | if print_to_file: 208 | # Save to a text file as well. 
209 | if save_dir is None: 210 | save_dir = flags.train_dir 211 | 212 | util.io.mkdir(save_dir) 213 | path = util.io.join_path(save_dir, 'training_config.txt') 214 | with open(path, "a") as out: 215 | do_print(out) 216 | 217 | def load_config(path): 218 | if not util.io.is_dir(path): 219 | path = util.io.get_dir(path) 220 | 221 | config_file = util.io.join_path(path, 'config.py') 222 | if util.io.exists(config_file): 223 | tf.logging.info('loading config.py from %s'%(config_file)) 224 | config = util.mod.load_mod_from_path(config_file) 225 | else: 226 | util.io.copy('config.py', path) 227 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /datasets/dataset_factory.py: -------------------------------------------------------------------------------- 1 | """A factory-pattern class which returns classification image/label pairs.""" 2 | from datasets import dataset_utils 3 | 4 | class DatasetConfig(): 5 | def __init__(self, file_pattern, split_sizes): 6 | self.file_pattern = file_pattern 7 | self.split_sizes = split_sizes 8 | 9 | icdar2013 = DatasetConfig( 10 | file_pattern = '*_%s.tfrecord', 11 | split_sizes = { 12 | 'train': 229, 13 | 'test': 233 14 | } 15 | ) 16 | icdar2015 = DatasetConfig( 17 | file_pattern = 'icdar2015_%s.tfrecord', 18 | split_sizes = { 19 | 'train': 1000, 20 | 'test': 500 21 | } 22 | ) 23 | td500 = DatasetConfig( 24 | file_pattern = '*_%s.tfrecord', 25 | split_sizes = { 26 | 'train': 300, 27 | 'test': 200 28 | } 29 | ) 30 | tr400 = DatasetConfig( 31 | file_pattern = 'tr400_%s.tfrecord', 32 | split_sizes = { 33 | 'train': 400 34 | } 35 | ) 36 | scut = DatasetConfig( 37 | file_pattern = 'scut_%s.tfrecord', 38 | split_sizes = { 39 | 'train': 1715 40 | } 41 | ) 42 | 43 | synthtext = DatasetConfig( 44 | file_pattern = '*.tfrecord', 45 | # file_pattern = 'SynthText_*.tfrecord', 46 | split_sizes = { 47 | 'train': 858750 48 | } 49 | ) 50 | 51 | datasets_map = { 52 | 'icdar2013':icdar2013, 53 | 'icdar2015':icdar2015, 54 | 'scut':scut, 55 | 'td500':td500, 56 | 'tr400':tr400, 57 | 'synthtext':synthtext 58 | } 59 | 60 | 61 | def get_dataset(dataset_name, split_name, dataset_dir, reader=None): 62 | """Given a dataset dataset_name and a split_name returns a Dataset. 63 | Args: 64 | dataset_name: String, the dataset_name of the dataset. 65 | split_name: A train/test split dataset_name. 66 | dataset_dir: The directory where the dataset files are stored. 67 | reader: The subclass of tf.ReaderBase. If left as `None`, then the default 68 | reader defined by each dataset is used. 69 | Returns: 70 | A `Dataset` class. 71 | Raises: 72 | ValueError: If the dataset `dataset_name` is unknown. 73 | """ 74 | if dataset_name not in datasets_map: 75 | raise ValueError('Name of dataset unknown %s' % dataset_name) 76 | dataset_config = datasets_map[dataset_name]; 77 | file_pattern = dataset_config.file_pattern 78 | num_samples = dataset_config.split_sizes[split_name] 79 | return dataset_utils.get_split(split_name, dataset_dir,file_pattern, num_samples, reader) 80 | -------------------------------------------------------------------------------- /datasets/dataset_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains utilities for downloading and converting datasets.""" 16 | 17 | import tensorflow as tf 18 | import numpy as np 19 | slim = tf.contrib.slim 20 | 21 | import util 22 | 23 | 24 | 25 | def int64_feature(value): 26 | """Wrapper for inserting int64 features into Example proto. 27 | """ 28 | if not isinstance(value, list): 29 | value = [value] 30 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 31 | 32 | 33 | def float_feature(value): 34 | """Wrapper for inserting float features into Example proto. 35 | """ 36 | if not isinstance(value, list): 37 | value = [value] 38 | return tf.train.Feature(float_list=tf.train.FloatList(value=value)) 39 | 40 | 41 | def bytes_feature(value): 42 | """Wrapper for inserting bytes features into Example proto. 43 | """ 44 | if not isinstance(value, list): 45 | value = [value] 46 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) 47 | 48 | 49 | def image_to_tfexample(image_data, image_format, height, width, class_id): 50 | return tf.train.Example(features=tf.train.Features(feature={ 51 | 'image/encoded': bytes_feature(image_data), 52 | 'image/format': bytes_feature(image_format), 53 | 'image/class/label': int64_feature(class_id), 54 | 'image/height': int64_feature(height), 55 | 'image/width': int64_feature(width), 56 | })) 57 | 58 | 59 | # def convert_to_example(image_data, filename, labels, labels_text, bboxes, oriented_bboxes, shape): 60 | # """Build an Example proto for an image example. 
61 | # Args: 62 | # image_data: string, JPEG encoding of RGB image; 63 | # labels: list of integers, identifier for the ground truth; 64 | # labels_text: list of strings, human-readable labels; 65 | # oriented_bboxes: list of bounding oriented boxes; each box is a list of floats in [0, 1]; 66 | # specifying [x1, y1, x2, y2, x3, y3, x4, y4] 67 | # bboxes: list of bbox in rectangle, [xmin, ymin, xmax, ymax] 68 | # Returns: 69 | # Example proto 70 | # """ 71 | # 72 | # image_format = b'JPEG' 73 | # oriented_bboxes = np.asarray(oriented_bboxes) 74 | # bboxes = np.asarray(bboxes) 75 | # example = tf.train.Example(features=tf.train.Features(feature={ 76 | # 'image/shape': int64_feature(list(shape)), 77 | # 'image/object/bbox/xmin': float_feature(list(bboxes[:, 0])), 78 | # 'image/object/bbox/ymin': float_feature(list(bboxes[:, 1])), 79 | # 'image/object/bbox/xmax': float_feature(list(bboxes[:, 2])), 80 | # 'image/object/bbox/ymax': float_feature(list(bboxes[:, 3])), 81 | # 'image/object/bbox/x1': float_feature(list(oriented_bboxes[:, 0])), 82 | # 'image/object/bbox/y1': float_feature(list(oriented_bboxes[:, 1])), 83 | # 'image/object/bbox/x2': float_feature(list(oriented_bboxes[:, 2])), 84 | # 'image/object/bbox/y2': float_feature(list(oriented_bboxes[:, 3])), 85 | # 'image/object/bbox/x3': float_feature(list(oriented_bboxes[:, 4])), 86 | # 'image/object/bbox/y3': float_feature(list(oriented_bboxes[:, 5])), 87 | # 'image/object/bbox/x4': float_feature(list(oriented_bboxes[:, 6])), 88 | # 'image/object/bbox/y4': float_feature(list(oriented_bboxes[:, 7])), 89 | # 'image/object/bbox/label': int64_feature(labels), 90 | # 'image/object/bbox/label_text': bytes_feature(labels_text), 91 | # 'image/format': bytes_feature(image_format), 92 | # 'image/filename': bytes_feature(filename), 93 | # 'image/encoded': bytes_feature(image_data)})) 94 | # return example 95 | 96 | def convert_to_example(image_data, filename, labels, labels_text, bboxes, oriented_bboxes, shape): 97 | """Build an Example proto for an image example. 
98 | Args: 99 | image_data: string, JPEG encoding of RGB image; 100 | labels: list of integers, identifier for the ground truth; 101 | labels_text: list of strings, human-readable labels; 102 | oriented_bboxes: list of bounding oriented boxes; each box is a list of floats in [0, 1]; 103 | specifying [x1, y1, x2, y2, x3, y3, x4, y4] 104 | bboxes: list of bbox in rectangle, [xmin, ymin, xmax, ymax] 105 | Returns: 106 | Example proto 107 | """ 108 | 109 | image_format = b'JPEG' 110 | oriented_bboxes = np.asarray(oriented_bboxes) 111 | if len(bboxes) == 0: 112 | print filename, 'has no bboxes' 113 | 114 | bboxes = np.asarray(bboxes) 115 | def get_list(obj, idx): 116 | if len(obj) > 0: 117 | return list(obj[:, idx]) 118 | return [] 119 | example = tf.train.Example(features=tf.train.Features(feature={ 120 | 'image/shape': int64_feature(list(shape)), 121 | 'image/object/bbox/xmin': float_feature(get_list(bboxes, 0)), 122 | 'image/object/bbox/ymin': float_feature(get_list(bboxes, 1)), 123 | 'image/object/bbox/xmax': float_feature(get_list(bboxes, 2)), 124 | 'image/object/bbox/ymax': float_feature(get_list(bboxes, 3)), 125 | 'image/object/bbox/x1': float_feature(get_list(oriented_bboxes, 0)), 126 | 'image/object/bbox/y1': float_feature(get_list(oriented_bboxes, 1)), 127 | 'image/object/bbox/x2': float_feature(get_list(oriented_bboxes, 2)), 128 | 'image/object/bbox/y2': float_feature(get_list(oriented_bboxes, 3)), 129 | 'image/object/bbox/x3': float_feature(get_list(oriented_bboxes, 4)), 130 | 'image/object/bbox/y3': float_feature(get_list(oriented_bboxes, 5)), 131 | 'image/object/bbox/x4': float_feature(get_list(oriented_bboxes, 6)), 132 | 'image/object/bbox/y4': float_feature(get_list(oriented_bboxes, 7)), 133 | 'image/object/bbox/label': int64_feature(labels), 134 | 'image/object/bbox/label_text': bytes_feature(labels_text), 135 | 'image/format': bytes_feature(image_format), 136 | 'image/filename': bytes_feature(filename), 137 | 'image/encoded': bytes_feature(image_data)})) 138 | return example 139 | 140 | def get_split(split_name, dataset_dir, file_pattern, num_samples, reader=None): 141 | dataset_dir = util.io.get_absolute_path(dataset_dir) 142 | 143 | if util.str.contains(file_pattern, '%'): 144 | file_pattern = util.io.join_path(dataset_dir, file_pattern % split_name) 145 | else: 146 | file_pattern = util.io.join_path(dataset_dir, file_pattern) 147 | # Allowing None in the signature so that dataset_factory can use the default. 
148 | if reader is None: 149 | reader = tf.TFRecordReader 150 | keys_to_features = { 151 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 152 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'), 153 | 'image/filename': tf.FixedLenFeature((), tf.string, default_value=''), 154 | 'image/shape': tf.FixedLenFeature([3], tf.int64), 155 | 'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32), 156 | 'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32), 157 | 'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32), 158 | 'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32), 159 | 'image/object/bbox/x1': tf.VarLenFeature(dtype=tf.float32), 160 | 'image/object/bbox/x2': tf.VarLenFeature(dtype=tf.float32), 161 | 'image/object/bbox/x3': tf.VarLenFeature(dtype=tf.float32), 162 | 'image/object/bbox/x4': tf.VarLenFeature(dtype=tf.float32), 163 | 'image/object/bbox/y1': tf.VarLenFeature(dtype=tf.float32), 164 | 'image/object/bbox/y2': tf.VarLenFeature(dtype=tf.float32), 165 | 'image/object/bbox/y3': tf.VarLenFeature(dtype=tf.float32), 166 | 'image/object/bbox/y4': tf.VarLenFeature(dtype=tf.float32), 167 | 'image/object/bbox/label': tf.VarLenFeature(dtype=tf.int64), 168 | } 169 | items_to_handlers = { 170 | 'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'), 171 | 'shape': slim.tfexample_decoder.Tensor('image/shape'), 172 | 'filename': slim.tfexample_decoder.Tensor('image/filename'), 173 | 'object/bbox': slim.tfexample_decoder.BoundingBox( 174 | ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'), 175 | 'object/oriented_bbox/x1': slim.tfexample_decoder.Tensor('image/object/bbox/x1'), 176 | 'object/oriented_bbox/x2': slim.tfexample_decoder.Tensor('image/object/bbox/x2'), 177 | 'object/oriented_bbox/x3': slim.tfexample_decoder.Tensor('image/object/bbox/x3'), 178 | 'object/oriented_bbox/x4': slim.tfexample_decoder.Tensor('image/object/bbox/x4'), 179 | 'object/oriented_bbox/y1': slim.tfexample_decoder.Tensor('image/object/bbox/y1'), 180 | 'object/oriented_bbox/y2': slim.tfexample_decoder.Tensor('image/object/bbox/y2'), 181 | 'object/oriented_bbox/y3': slim.tfexample_decoder.Tensor('image/object/bbox/y3'), 182 | 'object/oriented_bbox/y4': slim.tfexample_decoder.Tensor('image/object/bbox/y4'), 183 | 'object/label': slim.tfexample_decoder.Tensor('image/object/bbox/label') 184 | } 185 | decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers) 186 | 187 | labels_to_names = {0:'background', 1:'text'} 188 | items_to_descriptions = { 189 | 'image': 'A color image of varying height and width.', 190 | 'shape': 'Shape of the image', 191 | 'object/bbox': 'A list of bounding boxes, one per each object.', 192 | 'object/label': 'A list of labels, one per each object.', 193 | } 194 | 195 | return slim.dataset.Dataset( 196 | data_sources=file_pattern, 197 | reader=reader, 198 | decoder=decoder, 199 | num_samples=num_samples, 200 | items_to_descriptions=items_to_descriptions, 201 | num_classes=2, 202 | labels_to_names=labels_to_names) 203 | -------------------------------------------------------------------------------- /datasets/icdar2015_to_tfrecords.py: -------------------------------------------------------------------------------- 1 | #encoding=utf-8 2 | import numpy as np; 3 | import tensorflow as tf 4 | import util 5 | from dataset_utils import int64_feature, float_feature, bytes_feature, convert_to_example 6 | import config 7 | 8 | 9 | def cvt_to_tfrecords(output_path , data_path, gt_path): 10 | 
image_names = util.io.ls(data_path, '.jpg')#[0:10]; 11 | print "%d images found in %s"%(len(image_names), data_path); 12 | with tf.python_io.TFRecordWriter(output_path) as tfrecord_writer: 13 | for idx, image_name in enumerate(image_names): 14 | oriented_bboxes = []; 15 | bboxes = [] 16 | labels = []; 17 | labels_text = []; 18 | path = util.io.join_path(data_path, image_name); 19 | print "\tconverting image: %d/%d %s"%(idx, len(image_names), image_name); 20 | image_data = tf.gfile.FastGFile(path, 'r').read() 21 | 22 | image = util.img.imread(path, rgb = True); 23 | shape = image.shape 24 | h, w = shape[0:2]; 25 | h *= 1.0; 26 | w *= 1.0; 27 | image_name = util.str.split(image_name, '.')[0]; 28 | gt_name = 'gt_' + image_name + '.txt'; 29 | gt_filepath = util.io.join_path(gt_path, gt_name); 30 | lines = util.io.read_lines(gt_filepath); 31 | 32 | for line in lines: 33 | line = util.str.remove_all(line, '\xef\xbb\xbf') 34 | gt = util.str.split(line, ','); 35 | oriented_box = [int(gt[i]) for i in range(8)]; 36 | oriented_box = np.asarray(oriented_box) / ([w, h] * 4); 37 | oriented_bboxes.append(oriented_box); 38 | 39 | xs = oriented_box.reshape(4, 2)[:, 0] 40 | ys = oriented_box.reshape(4, 2)[:, 1] 41 | xmin = xs.min() 42 | xmax = xs.max() 43 | ymin = ys.min() 44 | ymax = ys.max() 45 | bboxes.append([xmin, ymin, xmax, ymax]) 46 | 47 | # might be wrong here, but it doesn't matter because the label is not going to be used in detection 48 | labels_text.append(gt[-1]); 49 | ignored = util.str.contains(gt[-1], '###') 50 | if ignored: 51 | labels.append(config.ignore_label); 52 | else: 53 | labels.append(config.text_label) 54 | example = convert_to_example(image_data, image_name, labels, labels_text, bboxes, oriented_bboxes, shape) 55 | tfrecord_writer.write(example.SerializeToString()) 56 | 57 | if __name__ == "__main__": 58 | root_dir = util.io.get_absolute_path('~/dataset/ICDAR2015/Challenge4/') 59 | output_dir = util.io.get_absolute_path('~/dataset/pixel_link/ICDAR/') 60 | util.io.mkdir(output_dir); 61 | 62 | training_data_dir = util.io.join_path(root_dir, 'ch4_training_images') 63 | training_gt_dir = util.io.join_path(root_dir,'ch4_training_localization_transcription_gt') 64 | cvt_to_tfrecords(output_path = util.io.join_path(output_dir, 'icdar2015_train.tfrecord'), data_path = training_data_dir, gt_path = training_gt_dir) 65 | 66 | test_data_dir = util.io.join_path(root_dir, 'ch4_test_images') 67 | test_gt_dir = util.io.join_path(root_dir,'ch4_test_localization_transcription_gt') 68 | cvt_to_tfrecords(output_path = util.io.join_path(output_dir, 'icdar2015_test.tfrecord'), data_path = test_data_dir, gt_path = test_gt_dir) 69 | -------------------------------------------------------------------------------- /datasets/synthtext_to_tfrecords.py: -------------------------------------------------------------------------------- 1 | #encoding=utf-8 2 | import numpy as np; 3 | import tensorflow as tf 4 | import util; 5 | from dataset_utils import int64_feature, float_feature, bytes_feature, convert_to_example 6 | 7 | # encoding = utf-8 8 | import numpy as np 9 | import time 10 | import config 11 | import util 12 | 13 | 14 | class SynthTextDataFetcher(): 15 | def __init__(self, mat_path, root_path): 16 | self.mat_path = mat_path 17 | self.root_path = root_path 18 | self._load_mat() 19 | 20 | @util.dec.print_calling 21 | def _load_mat(self): 22 | data = util.io.load_mat(self.mat_path) 23 | self.image_paths = data['imnames'][0] 24 | self.image_bbox = data['wordBB'][0] 25 | self.txts = data['txt'][0] 26 | 
self.num_images = len(self.image_paths) 27 | 28 | def get_image_path(self, idx): 29 | image_path = util.io.join_path(self.root_path, self.image_paths[idx][0]) 30 | return image_path 31 | 32 | def get_num_words(self, idx): 33 | try: 34 | return np.shape(self.image_bbox[idx])[2] 35 | except: # error caused by dataset 36 | return 1 37 | 38 | 39 | def get_word_bbox(self, img_idx, word_idx): 40 | boxes = self.image_bbox[img_idx] 41 | if len(np.shape(boxes)) ==2: # error caused by dataset 42 | boxes = np.reshape(boxes, (2, 4, 1)) 43 | 44 | xys = boxes[:,:, word_idx] 45 | assert(np.shape(xys) ==(2, 4)) 46 | return np.float32(xys) 47 | 48 | def normalize_bbox(self, xys, width, height): 49 | xs = xys[0, :] 50 | ys = xys[1, :] 51 | 52 | min_x = min(xs) 53 | min_y = min(ys) 54 | max_x = max(xs) 55 | max_y = max(ys) 56 | 57 | # bound them in the valid range 58 | min_x = max(0, min_x) 59 | min_y = max(0, min_y) 60 | max_x = min(width, max_x) 61 | max_y = min(height, max_y) 62 | 63 | # check the w, h and area of the rect 64 | w = max_x - min_x 65 | h = max_y - min_y 66 | is_valid = True 67 | 68 | if w < 10 or h < 10: 69 | is_valid = False 70 | 71 | if w * h < 100: 72 | is_valid = False 73 | 74 | xys[0, :] = xys[0, :] / width 75 | xys[1, :] = xys[1, :] / height 76 | 77 | return is_valid, min_x / width, min_y /height, max_x / width, max_y / height, xys 78 | 79 | def get_txt(self, image_idx, word_idx): 80 | txts = self.txts[image_idx]; 81 | clean_txts = [] 82 | for txt in txts: 83 | clean_txts += txt.split() 84 | return str(clean_txts[word_idx]) 85 | 86 | 87 | def fetch_record(self, image_idx): 88 | image_path = self.get_image_path(image_idx) 89 | if not (util.io.exists(image_path)): 90 | return None; 91 | img = util.img.imread(image_path) 92 | h, w = img.shape[0:-1]; 93 | num_words = self.get_num_words(image_idx) 94 | rect_bboxes = [] 95 | full_bboxes = [] 96 | txts = [] 97 | for word_idx in xrange(num_words): 98 | xys = self.get_word_bbox(image_idx, word_idx); 99 | is_valid, min_x, min_y, max_x, max_y, xys = self.normalize_bbox(xys, width = w, height = h) 100 | if not is_valid: 101 | continue; 102 | rect_bboxes.append([min_x, min_y, max_x, max_y]) 103 | xys = np.reshape(np.transpose(xys), -1) 104 | full_bboxes.append(xys); 105 | txt = self.get_txt(image_idx, word_idx); 106 | txts.append(txt); 107 | if len(rect_bboxes) == 0: 108 | return None; 109 | 110 | return image_path, img, txts, rect_bboxes, full_bboxes 111 | 112 | 113 | 114 | def cvt_to_tfrecords(output_path , data_path, gt_path, records_per_file = 50000): 115 | 116 | fetcher = SynthTextDataFetcher(root_path = data_path, mat_path = gt_path) 117 | image_idxes = range(fetcher.num_images) 118 | np.random.shuffle(image_idxes) 119 | record_count = 0; 120 | for image_idx in image_idxes: 121 | if record_count % records_per_file == 0: 122 | fid = record_count / records_per_file 123 | tfrecord_writer = tf.python_io.TFRecordWriter(output_path%(fid)) 124 | 125 | print "converting image %d/%d"%(record_count, fetcher.num_images) 126 | record = fetcher.fetch_record(image_idx); 127 | if record is None: 128 | print '\nimage %d does not exist'%(image_idx + 1) 129 | continue; 130 | record_count += 1 131 | image_path, image, txts, rect_bboxes, oriented_bboxes = record; 132 | labels = []; 133 | for txt in txts: 134 | if len(txt) < 3: 135 | labels.append(config.ignore_label) 136 | else: 137 | labels.append(config.text_label) 138 | image_data = tf.gfile.FastGFile(image_path, 'r').read() 139 | shape = image.shape 140 | image_name = 
str(util.io.get_filename(image_path).split('.')[0]) 141 | example = convert_to_example(image_data, image_name, labels, txts, rect_bboxes, oriented_bboxes, shape) 142 | tfrecord_writer.write(example.SerializeToString()) 143 | 144 | 145 | if __name__ == "__main__": 146 | mat_path = util.io.get_absolute_path('~/dataset/SynthText/gt.mat') 147 | root_path = util.io.get_absolute_path('~/dataset/SynthText/') 148 | output_dir = util.io.get_absolute_path('~/dataset/pixel_link/SynthText/') 149 | util.io.mkdir(output_dir); 150 | cvt_to_tfrecords(output_path = util.io.join_path(output_dir, 'SynthText_%d.tfrecord'), data_path = root_path, gt_path = mat_path) 151 | -------------------------------------------------------------------------------- /nets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZJULearning/pixel_link/00cb9aacc80583a8aba77d6834748ab4cca03254/nets/__init__.py -------------------------------------------------------------------------------- /nets/pixel_link_symbol.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | slim = tf.contrib.slim 3 | 4 | MODEL_TYPE_vgg16 = 'vgg16' 5 | MODEL_TYPE_vgg16_no_dilation = 'vgg16_no_dilation' 6 | 7 | FUSE_TYPE_cascade_conv1x1_upsample_sum = 'cascade_conv1x1_upsample_sum' 8 | FUSE_TYPE_cascade_conv1x1_128_upsamle_sum_conv1x1_2 = \ 9 | 'cascade_conv1x1_128_upsamle_sum_conv1x1_2' 10 | FUSE_TYPE_cascade_conv1x1_128_upsamle_concat_conv1x1_2 = \ 11 | 'cascade_conv1x1_128_upsamle_concat_conv1x1_2' 12 | 13 | class PixelLinkNet(object): 14 | def __init__(self, inputs, is_training): 15 | self.inputs = inputs 16 | self.is_training = is_training 17 | self._build_network() 18 | self._fuse_feat_layers() 19 | self._logits_to_scores() 20 | 21 | def _build_network(self): 22 | import config 23 | if config.model_type == MODEL_TYPE_vgg16: 24 | from nets import vgg 25 | with slim.arg_scope([slim.conv2d], 26 | activation_fn=tf.nn.relu, 27 | weights_regularizer=slim.l2_regularizer(config.weight_decay), 28 | weights_initializer= tf.contrib.layers.xavier_initializer(), 29 | biases_initializer = tf.zeros_initializer()): 30 | with slim.arg_scope([slim.conv2d, slim.max_pool2d], 31 | padding='SAME') as sc: 32 | self.arg_scope = sc 33 | self.net, self.end_points = vgg.basenet( 34 | inputs = self.inputs) 35 | 36 | elif config.model_type == MODEL_TYPE_vgg16_no_dilation: 37 | from nets import vgg 38 | with slim.arg_scope([slim.conv2d], 39 | activation_fn=tf.nn.relu, 40 | weights_regularizer=slim.l2_regularizer(config.weight_decay), 41 | weights_initializer= tf.contrib.layers.xavier_initializer(), 42 | biases_initializer = tf.zeros_initializer()): 43 | with slim.arg_scope([slim.conv2d, slim.max_pool2d], 44 | padding='SAME') as sc: 45 | self.arg_scope = sc 46 | self.net, self.end_points = vgg.basenet( 47 | inputs = self.inputs, dilation = False) 48 | else: 49 | raise ValueError('model_type not supported:%s'%(config.model_type)) 50 | 51 | def _score_layer(self, input_layer, num_classes, scope): 52 | import config 53 | with slim.arg_scope(self.arg_scope): 54 | logits = slim.conv2d(input_layer, num_classes, [1, 1], 55 | stride=1, 56 | activation_fn=None, 57 | scope='score_from_%s'%scope, 58 | normalizer_fn=None) 59 | try: 60 | use_dropout = config.dropout_ratio > 0 61 | except: 62 | use_dropout = False 63 | 64 | if use_dropout: 65 | if self.is_training: 66 | dropout_ratio = config.dropout_ratio 67 | else: 68 | dropout_ratio = 0 69 | keep_prob = 1.0 
- dropout_ratio 70 | tf.logging.info('Using Dropout, with keep_prob = %f'%(keep_prob)) 71 | logits = tf.nn.dropout(logits, keep_prob) 72 | return logits 73 | 74 | def _upscore_layer(self, layer, target_layer): 75 | # target_shape = target_layer.shape[1:-1] # NHWC 76 | target_shape = tf.shape(target_layer)[1:-1] 77 | upscored = tf.image.resize_images(layer, target_shape) 78 | return upscored 79 | def _fuse_by_cascade_conv1x1_128_upsamle_sum_conv1x1_2(self, scope): 80 | """ 81 | The feature fuse fashion of 82 | 'Deep Direct Regression for Multi-Oriented Scene Text Detection' 83 | 84 | Instead of fusion of scores, feature map from 1x1, 128 conv are fused, 85 | and the scores are predicted on it. 86 | """ 87 | base_map = self._fuse_by_cascade_conv1x1_upsample_sum(num_classes = 128, 88 | scope = 'feature_fuse') 89 | return base_map 90 | 91 | def _fuse_by_cascade_conv1x1_128_upsamle_concat_conv1x1_2(self, scope, num_classes = 32): 92 | import config 93 | num_layers = len(config.feat_layers) 94 | 95 | with tf.variable_scope(scope): 96 | smaller_score_map = None 97 | for idx in range(0, len(config.feat_layers))[::-1]: #[4, 3, 2, 1, 0] 98 | current_layer_name = config.feat_layers[idx] 99 | current_layer = self.end_points[current_layer_name] 100 | current_score_map = self._score_layer(current_layer, 101 | num_classes, current_layer_name) 102 | if smaller_score_map is None: 103 | smaller_score_map = current_score_map 104 | else: 105 | upscore_map = self._upscore_layer(smaller_score_map, current_score_map) 106 | smaller_score_map = tf.concat([current_score_map, upscore_map], axis = 0) 107 | 108 | return smaller_score_map 109 | 110 | 111 | def _fuse_by_cascade_conv1x1_upsample_sum(self, num_classes, scope): 112 | """ 113 | The feature fuse fashion of FCN for semantic segmentation: 114 | Suppose there are several feature maps with decreasing sizes , 115 | and we are going to get a single score map from them. 116 | 117 | Every feature map contributes to the final score map: 118 | predict score on all the feature maps using 1x1 conv, with 119 | depth equal to num_classes 120 | 121 | The score map is upsampled and added in a cascade way: 122 | start from the smallest score map, upsmale it to the size 123 | of the next score map with a larger size, and add them 124 | to get a fused score map. Upsample this fused score map and 125 | add it to the next sibling larger score map. 
The final 126 | score map is got when all score maps are fused together 127 | """ 128 | import config 129 | num_layers = len(config.feat_layers) 130 | 131 | with tf.variable_scope(scope): 132 | smaller_score_map = None 133 | for idx in range(0, len(config.feat_layers))[::-1]: #[4, 3, 2, 1, 0] 134 | current_layer_name = config.feat_layers[idx] 135 | current_layer = self.end_points[current_layer_name] 136 | current_score_map = self._score_layer(current_layer, 137 | num_classes, current_layer_name) 138 | if smaller_score_map is None: 139 | smaller_score_map = current_score_map 140 | else: 141 | upscore_map = self._upscore_layer(smaller_score_map, current_score_map) 142 | smaller_score_map = current_score_map + upscore_map 143 | 144 | return smaller_score_map 145 | 146 | def _fuse_feat_layers(self): 147 | import config 148 | if config.feat_fuse_type == FUSE_TYPE_cascade_conv1x1_upsample_sum: 149 | self.pixel_cls_logits = self._fuse_by_cascade_conv1x1_upsample_sum( 150 | config.num_classes, scope = 'pixel_cls') 151 | 152 | self.pixel_link_logits = self._fuse_by_cascade_conv1x1_upsample_sum( 153 | config.num_neighbours * 2, scope = 'pixel_link') 154 | 155 | elif config.feat_fuse_type == FUSE_TYPE_cascade_conv1x1_128_upsamle_sum_conv1x1_2: 156 | base_map = self._fuse_by_cascade_conv1x1_128_upsamle_sum_conv1x1_2( 157 | scope = 'fuse_feature') 158 | 159 | self.pixel_cls_logits = self._score_layer(base_map, 160 | config.num_classes, scope = 'pixel_cls') 161 | 162 | self.pixel_link_logits = self._score_layer(base_map, 163 | config.num_neighbours * 2, scope = 'pixel_link') 164 | elif config.feat_fuse_type == FUSE_TYPE_cascade_conv1x1_128_upsamle_concat_conv1x1_2: 165 | base_map = self._fuse_by_cascade_conv1x1_128_upsamle_concat_conv1x1_2( 166 | scope = 'fuse_feature') 167 | else: 168 | raise ValueError('feat_fuse_type not supported:%s'%(config.feat_fuse_type)) 169 | 170 | def _flat_pixel_cls_values(self, values): 171 | shape = values.shape.as_list() 172 | values = tf.reshape(values, shape = [shape[0], -1, shape[-1]]) 173 | return values 174 | 175 | 176 | def _logits_to_scores(self): 177 | self.pixel_cls_scores = tf.nn.softmax(self.pixel_cls_logits) 178 | self.pixel_cls_logits_flatten = \ 179 | self._flat_pixel_cls_values(self.pixel_cls_logits) 180 | self.pixel_cls_scores_flatten = \ 181 | self._flat_pixel_cls_values(self.pixel_cls_scores) 182 | 183 | import config 184 | # shape = self.pixel_link_logits.shape.as_list() 185 | shape = tf.shape(self.pixel_link_logits) 186 | self.pixel_link_logits = tf.reshape(self.pixel_link_logits, 187 | [shape[0], shape[1], shape[2], config.num_neighbours, 2]) 188 | 189 | self.pixel_link_scores = tf.nn.softmax(self.pixel_link_logits) 190 | 191 | self.pixel_pos_scores = self.pixel_cls_scores[:, :, :, 1] 192 | self.link_pos_scores = self.pixel_link_scores[:, :, :, :, 1] 193 | 194 | def build_loss(self, pixel_cls_labels, pixel_cls_weights, 195 | pixel_link_labels, pixel_link_weights, 196 | do_summary = True 197 | ): 198 | """ 199 | The loss consists of two parts: pixel_cls_loss + link_cls_loss, 200 | and link_cls_loss is calculated only on positive pixels 201 | """ 202 | import config 203 | count_warning = tf.get_local_variable( 204 | name = 'count_warning', initializer = tf.constant(0.0)) 205 | batch_size = config.batch_size_per_gpu 206 | ignore_label = config.ignore_label 207 | background_label = config.background_label 208 | text_label = config.text_label 209 | pixel_link_neg_loss_weight_lambda = config.pixel_link_neg_loss_weight_lambda 210 | 
pixel_cls_loss_weight_lambda = config.pixel_cls_loss_weight_lambda 211 | pixel_link_loss_weight = config.pixel_link_loss_weight 212 | 213 | def OHNM_single_image(scores, n_pos, neg_mask): 214 | """Online Hard Negative Mining. 215 | scores: the scores of being predicted as negative cls 216 | n_pos: the number of positive samples 217 | neg_mask: mask of negative samples 218 | Return: 219 | the mask of selected negative samples. 220 | if n_pos == 0, top 10000 negative samples will be selected. 221 | """ 222 | def has_pos(): 223 | return n_pos * config.max_neg_pos_ratio 224 | def no_pos(): 225 | return tf.constant(10000, dtype = tf.int32) 226 | 227 | n_neg = tf.cond(n_pos > 0, has_pos, no_pos) 228 | max_neg_entries = tf.reduce_sum(tf.cast(neg_mask, tf.int32)) 229 | 230 | n_neg = tf.minimum(n_neg, max_neg_entries) 231 | n_neg = tf.cast(n_neg, tf.int32) 232 | def has_neg(): 233 | neg_conf = tf.boolean_mask(scores, neg_mask) 234 | vals, _ = tf.nn.top_k(-neg_conf, k=n_neg) 235 | threshold = vals[-1]# a negtive value 236 | selected_neg_mask = tf.logical_and(neg_mask, scores <= -threshold) 237 | return selected_neg_mask 238 | def no_neg(): 239 | selected_neg_mask = tf.zeros_like(neg_mask) 240 | return selected_neg_mask 241 | 242 | selected_neg_mask = tf.cond(n_neg > 0, has_neg, no_neg) 243 | return tf.cast(selected_neg_mask, tf.int32) 244 | 245 | def OHNM_batch(neg_conf, pos_mask, neg_mask): 246 | selected_neg_mask = [] 247 | for image_idx in xrange(batch_size): 248 | image_neg_conf = neg_conf[image_idx, :] 249 | image_neg_mask = neg_mask[image_idx, :] 250 | image_pos_mask = pos_mask[image_idx, :] 251 | n_pos = tf.reduce_sum(tf.cast(image_pos_mask, tf.int32)) 252 | selected_neg_mask.append(OHNM_single_image(image_neg_conf, n_pos, image_neg_mask)) 253 | 254 | selected_neg_mask = tf.stack(selected_neg_mask) 255 | return selected_neg_mask 256 | 257 | # OHNM on pixel classification task 258 | pixel_cls_labels_flatten = tf.reshape(pixel_cls_labels, [batch_size, -1]) 259 | pos_pixel_weights_flatten = tf.reshape(pixel_cls_weights, [batch_size, -1]) 260 | 261 | pos_mask = tf.equal(pixel_cls_labels_flatten, text_label) 262 | neg_mask = tf.equal(pixel_cls_labels_flatten, background_label) 263 | 264 | n_pos = tf.reduce_sum(tf.cast(pos_mask, dtype = tf.float32)) 265 | 266 | with tf.name_scope('pixel_cls_loss'): 267 | def no_pos(): 268 | return tf.constant(.0); 269 | def has_pos(): 270 | pixel_cls_loss = tf.nn.sparse_softmax_cross_entropy_with_logits( 271 | logits = self.pixel_cls_logits_flatten, 272 | labels = tf.cast(pos_mask, dtype = tf.int32)) 273 | 274 | pixel_neg_scores = self.pixel_cls_scores_flatten[:, :, 0] 275 | selected_neg_pixel_mask = OHNM_batch(pixel_neg_scores, pos_mask, neg_mask) 276 | 277 | pixel_cls_weights = pos_pixel_weights_flatten + \ 278 | tf.cast(selected_neg_pixel_mask, tf.float32) 279 | n_neg = tf.cast(tf.reduce_sum(selected_neg_pixel_mask), tf.float32) 280 | loss = tf.reduce_sum(pixel_cls_loss * pixel_cls_weights) / (n_neg + n_pos) 281 | return loss 282 | 283 | # pixel_cls_loss = tf.cond(n_pos > 0, has_pos, no_pos) 284 | pixel_cls_loss = has_pos() 285 | tf.add_to_collection(tf.GraphKeys.LOSSES, pixel_cls_loss * pixel_cls_loss_weight_lambda) 286 | 287 | 288 | with tf.name_scope('pixel_link_loss'): 289 | def no_pos(): 290 | return tf.constant(.0), tf.constant(.0); 291 | 292 | def has_pos(): 293 | pixel_link_loss = tf.nn.sparse_softmax_cross_entropy_with_logits( 294 | logits = self.pixel_link_logits, 295 | labels = pixel_link_labels) 296 | 297 | def get_loss(label): 298 | link_mask = 
tf.equal(pixel_link_labels, label) 299 | link_weights = pixel_link_weights * tf.cast(link_mask, tf.float32) 300 | n_links = tf.reduce_sum(link_weights) 301 | loss = tf.reduce_sum(pixel_link_loss * link_weights) / n_links 302 | return loss 303 | 304 | neg_loss = get_loss(0) 305 | pos_loss = get_loss(1) 306 | return neg_loss, pos_loss 307 | 308 | pixel_neg_link_loss, pixel_pos_link_loss = \ 309 | tf.cond(n_pos > 0, has_pos, no_pos) 310 | 311 | pixel_link_loss = pixel_pos_link_loss + \ 312 | pixel_neg_link_loss * pixel_link_neg_loss_weight_lambda 313 | 314 | tf.add_to_collection(tf.GraphKeys.LOSSES, 315 | pixel_link_loss_weight * pixel_link_loss) 316 | 317 | if do_summary: 318 | tf.summary.scalar('pixel_cls_loss', pixel_cls_loss) 319 | tf.summary.scalar('pixel_pos_link_loss', pixel_pos_link_loss) 320 | tf.summary.scalar('pixel_neg_link_loss', pixel_neg_link_loss) 321 | -------------------------------------------------------------------------------- /nets/vgg.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | slim = tf.contrib.slim 4 | 5 | 6 | def basenet(inputs, fatness = 64, dilation = True): 7 | """ 8 | backbone net of vgg16 9 | """ 10 | # End_points collect relevant activations for external use. 11 | end_points = {} 12 | # Original VGG-16 blocks. 13 | with slim.arg_scope([slim.conv2d, slim.max_pool2d], padding='SAME'): 14 | # Block1 15 | net = slim.repeat(inputs, 2, slim.conv2d, fatness, [3, 3], scope='conv1') 16 | end_points['conv1_2'] = net 17 | net = slim.max_pool2d(net, [2, 2], scope='pool1') 18 | end_points['pool1'] = net 19 | 20 | 21 | # Block 2. 22 | net = slim.repeat(net, 2, slim.conv2d, fatness * 2, [3, 3], scope='conv2') 23 | end_points['conv2_2'] = net 24 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 25 | end_points['pool2'] = net 26 | 27 | 28 | # Block 3. 29 | net = slim.repeat(net, 3, slim.conv2d, fatness * 4, [3, 3], scope='conv3') 30 | end_points['conv3_3'] = net 31 | net = slim.max_pool2d(net, [2, 2], scope='pool3') 32 | end_points['pool3'] = net 33 | 34 | # Block 4. 35 | net = slim.repeat(net, 3, slim.conv2d, fatness * 8, [3, 3], scope='conv4') 36 | end_points['conv4_3'] = net 37 | net = slim.max_pool2d(net, [2, 2], scope='pool4') 38 | end_points['pool4'] = net 39 | 40 | # Block 5. 
41 | net = slim.repeat(net, 3, slim.conv2d, fatness * 8, [3, 3], scope='conv5') 42 | end_points['conv5_3'] = net 43 | net = slim.max_pool2d(net, [3, 3], 1, scope='pool5') 44 | end_points['pool5'] = net 45 | 46 | # fc6 as conv, dilation is added 47 | if dilation: 48 | net = slim.conv2d(net, fatness * 16, [3, 3], rate=6, scope='fc6') 49 | else: 50 | net = slim.conv2d(net, fatness * 16, [3, 3], scope='fc6') 51 | end_points['fc6'] = net 52 | 53 | # fc7 as conv 54 | net = slim.conv2d(net, fatness * 16, [1, 1], scope='fc7') 55 | end_points['fc7'] = net 56 | 57 | return net, end_points; 58 | 59 | -------------------------------------------------------------------------------- /pixel_link.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import cv2 4 | 5 | import util 6 | 7 | PIXEL_CLS_WEIGHT_all_ones = 'PIXEL_CLS_WEIGHT_all_ones' 8 | PIXEL_CLS_WEIGHT_bbox_balanced = 'PIXEL_CLS_WEIGHT_bbox_balanced' 9 | PIXEL_NEIGHBOUR_TYPE_4 = 'PIXEL_NEIGHBOUR_TYPE_4' 10 | PIXEL_NEIGHBOUR_TYPE_8 = 'PIXEL_NEIGHBOUR_TYPE_8' 11 | 12 | DECODE_METHOD_join = 'DECODE_METHOD_join' 13 | 14 | 15 | def get_neighbours_8(x, y): 16 | """ 17 | Get 8 neighbours of point(x, y) 18 | """ 19 | return [(x - 1, y - 1), (x, y - 1), (x + 1, y - 1), \ 20 | (x - 1, y), (x + 1, y), \ 21 | (x - 1, y + 1), (x, y + 1), (x + 1, y + 1)] 22 | 23 | 24 | def get_neighbours_4(x, y): 25 | return [(x - 1, y), (x + 1, y), (x, y + 1), (x, y - 1)] 26 | 27 | 28 | def get_neighbours(x, y): 29 | import config 30 | neighbour_type = config.pixel_neighbour_type 31 | if neighbour_type == PIXEL_NEIGHBOUR_TYPE_4: 32 | return get_neighbours_4(x, y) 33 | else: 34 | return get_neighbours_8(x, y) 35 | 36 | def get_neighbours_fn(): 37 | import config 38 | neighbour_type = config.pixel_neighbour_type 39 | if neighbour_type == PIXEL_NEIGHBOUR_TYPE_4: 40 | return get_neighbours_4, 4 41 | else: 42 | return get_neighbours_8, 8 43 | 44 | 45 | 46 | def is_valid_cord(x, y, w, h): 47 | """ 48 | Tell whether the 2D coordinate (x, y) is valid or not. 
49 | If valid, it should be on an h x w image 50 | """ 51 | return x >=0 and x < w and y >= 0 and y < h; 52 | 53 | #=====================Ground Truth Calculation Begin================== 54 | def tf_cal_gt_for_single_image(xs, ys, labels): 55 | pixel_cls_label, pixel_cls_weight, \ 56 | pixel_link_label, pixel_link_weight = \ 57 | tf.py_func( 58 | cal_gt_for_single_image, 59 | [xs, ys, labels], 60 | [tf.int32, tf.float32, tf.int32, tf.float32] 61 | ) 62 | import config 63 | score_map_shape = config.score_map_shape 64 | num_neighbours = config.num_neighbours 65 | h, w = score_map_shape 66 | pixel_cls_label.set_shape(score_map_shape) 67 | pixel_cls_weight.set_shape(score_map_shape) 68 | pixel_link_label.set_shape([h, w, num_neighbours]) 69 | pixel_link_weight.set_shape([h, w, num_neighbours]) 70 | return pixel_cls_label, pixel_cls_weight, \ 71 | pixel_link_label, pixel_link_weight 72 | 73 | 74 | def cal_gt_for_single_image(normed_xs, normed_ys, labels): 75 | """ 76 | Args: 77 | xs, ys: both in shape of (N, 4), 78 | and N is the number of bboxes, 79 | their values are normalized to [0,1] 80 | labels: shape = (N,), only two values are allowed: 81 | -1: ignored 82 | 1: text 83 | Return: 84 | pixel_cls_label 85 | pixel_cls_weight 86 | pixel_link_label 87 | pixel_link_weight 88 | """ 89 | import config 90 | score_map_shape = config.score_map_shape 91 | pixel_cls_weight_method = config.pixel_cls_weight_method 92 | h, w = score_map_shape 93 | text_label = config.text_label 94 | ignore_label = config.ignore_label 95 | background_label = config.background_label 96 | num_neighbours = config.num_neighbours 97 | bbox_border_width = config.bbox_border_width 98 | pixel_cls_border_weight_lambda = config.pixel_cls_border_weight_lambda 99 | 100 | # validate the args 101 | assert np.ndim(normed_xs) == 2 102 | assert np.shape(normed_xs)[-1] == 4 103 | assert np.shape(normed_xs) == np.shape(normed_ys) 104 | assert len(normed_xs) == len(labels) 105 | 106 | # assert set(labels).issubset(set([text_label, ignore_label, background_label])) 107 | 108 | num_positive_bboxes = np.sum(np.asarray(labels) == text_label) 109 | # rescale normalized xys to absolute values 110 | xs = normed_xs * w 111 | ys = normed_ys * h 112 | 113 | # initialize ground truth values 114 | mask = np.zeros(score_map_shape, dtype = np.int32) 115 | pixel_cls_label = np.ones(score_map_shape, dtype = np.int32) * background_label 116 | pixel_cls_weight = np.zeros(score_map_shape, dtype = np.float32) 117 | 118 | pixel_link_label = np.zeros((h, w, num_neighbours), dtype = np.int32) 119 | pixel_link_weight = np.ones((h, w, num_neighbours), dtype = np.float32) 120 | 121 | # find overlapped pixels, and consider them as ignored in pixel_cls_weight 122 | # and pixels in ignored bboxes are ignored as well 123 | # That is to say, only the weights of not ignored pixels are set to 1 124 | 125 | ## get the masks of all bboxes 126 | bbox_masks = [] 127 | pos_mask = mask.copy() 128 | for bbox_idx, (bbox_xs, bbox_ys) in enumerate(zip(xs, ys)): 129 | if labels[bbox_idx] == background_label: 130 | continue 131 | 132 | bbox_mask = mask.copy() 133 | 134 | bbox_points = zip(bbox_xs, bbox_ys) 135 | bbox_contours = util.img.points_to_contours(bbox_points) 136 | util.img.draw_contours(bbox_mask, bbox_contours, idx = -1, 137 | color = 1, border_width = -1) 138 | 139 | bbox_masks.append(bbox_mask) 140 | 141 | if labels[bbox_idx] == text_label: 142 | pos_mask += bbox_mask 143 | 144 | # treat overlapped in-bbox pixels as negative, 145 | # and non-overlapped ones as positive 146 
| pos_mask = np.asarray(pos_mask == 1, dtype = np.int32) 147 | num_positive_pixels = np.sum(pos_mask) 148 | 149 | ## add all bbox_maskes, find non-overlapping pixels 150 | sum_mask = np.sum(bbox_masks, axis = 0) 151 | not_overlapped_mask = sum_mask == 1 152 | 153 | 154 | ## gt and weight calculation 155 | for bbox_idx, bbox_mask in enumerate(bbox_masks): 156 | bbox_label = labels[bbox_idx] 157 | if bbox_label == ignore_label: 158 | # for ignored bboxes, only non-overlapped pixels are encoded as ignored 159 | bbox_ignore_pixel_mask = bbox_mask * not_overlapped_mask 160 | pixel_cls_label += bbox_ignore_pixel_mask * ignore_label 161 | continue 162 | 163 | if labels[bbox_idx] == background_label: 164 | continue 165 | # from here on, only text boxes left. 166 | 167 | # for positive bboxes, all pixels within it and pos_mask are positive 168 | bbox_positive_pixel_mask = bbox_mask * pos_mask 169 | # background or text is encoded into cls gt 170 | pixel_cls_label += bbox_positive_pixel_mask * bbox_label 171 | 172 | # for the pixel cls weights, only positive pixels are set to ones 173 | if pixel_cls_weight_method == PIXEL_CLS_WEIGHT_all_ones: 174 | pixel_cls_weight += bbox_positive_pixel_mask 175 | elif pixel_cls_weight_method == PIXEL_CLS_WEIGHT_bbox_balanced: 176 | # let N denote num_positive_pixels 177 | # weight per pixel = N /num_positive_bboxes / n_pixels_in_bbox 178 | # so all pixel weights in this bbox sum to N/num_positive_bboxes 179 | # and all pixels weights in this image sum to N, the same 180 | # as setting all weights to 1 181 | num_bbox_pixels = np.sum(bbox_positive_pixel_mask) 182 | if num_bbox_pixels > 0: 183 | per_bbox_weight = num_positive_pixels * 1.0 / num_positive_bboxes 184 | per_pixel_weight = per_bbox_weight / num_bbox_pixels 185 | pixel_cls_weight += bbox_positive_pixel_mask * per_pixel_weight 186 | else: 187 | raise ValueError, 'pixel_cls_weight_method not supported:%s'\ 188 | %(pixel_cls_weight_method) 189 | 190 | 191 | ## calculate the labels and weights of links 192 | ### for all pixels in bboxes, all links are positive at first 193 | bbox_point_cords = np.where(bbox_positive_pixel_mask) 194 | pixel_link_label[bbox_point_cords] = 1 195 | 196 | 197 | ## the border of bboxes might be distored because of overlapping 198 | ## so recalculate it, and find the border mask 199 | new_bbox_contours = util.img.find_contours(bbox_positive_pixel_mask) 200 | bbox_border_mask = mask.copy() 201 | util.img.draw_contours(bbox_border_mask, new_bbox_contours, -1, 202 | color = 1, border_width = bbox_border_width * 2 + 1) 203 | bbox_border_mask *= bbox_positive_pixel_mask 204 | bbox_border_cords = np.where(bbox_border_mask) 205 | 206 | ## give more weight to the border pixels if configured 207 | pixel_cls_weight[bbox_border_cords] *= pixel_cls_border_weight_lambda 208 | 209 | ### change link labels according to their neighbour status 210 | border_points = zip(*bbox_border_cords) 211 | def in_bbox(nx, ny): 212 | return bbox_positive_pixel_mask[ny, nx] 213 | 214 | for y, x in border_points: 215 | neighbours = get_neighbours(x, y) 216 | for n_idx, (nx, ny) in enumerate(neighbours): 217 | if not is_valid_cord(nx, ny, w, h) or not in_bbox(nx, ny): 218 | pixel_link_label[y, x, n_idx] = 0 219 | 220 | pixel_cls_weight = np.asarray(pixel_cls_weight, dtype = np.float32) 221 | pixel_link_weight *= np.expand_dims(pixel_cls_weight, axis = -1) 222 | 223 | # try: 224 | # np.testing.assert_almost_equal(np.sum(pixel_cls_weight), num_positive_pixels, decimal = 1) 225 | # except: 226 | # print 
num_positive_pixels, np.sum(pixel_cls_label), np.sum(pixel_cls_weight) 227 | # import pdb 228 | # pdb.set_trace() 229 | return pixel_cls_label, pixel_cls_weight, pixel_link_label, pixel_link_weight 230 | 231 | #=====================Ground Truth Calculation End==================== 232 | 233 | 234 | #============================Decode Begin============================= 235 | 236 | def tf_decode_score_map_to_mask_in_batch(pixel_cls_scores, pixel_link_scores): 237 | masks = tf.py_func(decode_batch, 238 | [pixel_cls_scores, pixel_link_scores], tf.int32) 239 | b, h, w = pixel_cls_scores.shape.as_list() 240 | masks.set_shape([b, h, w]) 241 | return masks 242 | 243 | 244 | 245 | def decode_batch(pixel_cls_scores, pixel_link_scores, 246 | pixel_conf_threshold = None, link_conf_threshold = None): 247 | import config 248 | 249 | if pixel_conf_threshold is None: 250 | pixel_conf_threshold = config.pixel_conf_threshold 251 | 252 | if link_conf_threshold is None: 253 | link_conf_threshold = config.link_conf_threshold 254 | 255 | batch_size = pixel_cls_scores.shape[0] 256 | batch_mask = [] 257 | for image_idx in xrange(batch_size): 258 | image_pos_pixel_scores = pixel_cls_scores[image_idx, :, :] 259 | image_pos_link_scores = pixel_link_scores[image_idx, :, :, :] 260 | mask = decode_image( 261 | image_pos_pixel_scores, image_pos_link_scores, 262 | pixel_conf_threshold, link_conf_threshold 263 | ) 264 | batch_mask.append(mask) 265 | return np.asarray(batch_mask, np.int32) 266 | 267 | # @util.dec.print_calling_in_short 268 | # @util.dec.timeit 269 | def decode_image(pixel_scores, link_scores, 270 | pixel_conf_threshold, link_conf_threshold): 271 | import config 272 | if config.decode_method == DECODE_METHOD_join: 273 | mask = decode_image_by_join(pixel_scores, link_scores, 274 | pixel_conf_threshold, link_conf_threshold) 275 | return mask 276 | elif config.decode_method == DECODE_METHOD_border_split: 277 | return decode_image_by_border(pixel_scores, link_scores, 278 | pixel_conf_threshold, link_conf_threshold) 279 | else: 280 | raise ValueError('Unknow decode method:%s'%(config.decode_method)) 281 | 282 | 283 | import pyximport; pyximport.install() 284 | from pixel_link_decode import decode_image_by_join 285 | 286 | def min_area_rect(cnt): 287 | """ 288 | Args: 289 | xs: numpy ndarray with shape=(N,4). N is the number of oriented bboxes. 4 contains [x1, x2, x3, x4] 290 | ys: numpy ndarray with shape=(N,4), [y1, y2, y3, y4] 291 | Note that [(x1, y1), (x2, y2), (x3, y3), (x4, y4)] can represent an oriented bbox. 292 | Return: 293 | the oriented rects sorrounding the box, in the format:[cx, cy, w, h, theta]. 
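        The rect area (w * h) is returned as well, so callers get a (box, area) pair.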
294 | """ 295 | rect = cv2.minAreaRect(cnt) 296 | cx, cy = rect[0] 297 | w, h = rect[1] 298 | theta = rect[2] 299 | box = [cx, cy, w, h, theta] 300 | return box, w * h 301 | 302 | def rect_to_xys(rect, image_shape): 303 | """Convert rect to xys, i.e., eight points 304 | The `image_shape` is used to to make sure all points return are valid, i.e., within image area 305 | """ 306 | h, w = image_shape[0:2] 307 | def get_valid_x(x): 308 | if x < 0: 309 | return 0 310 | if x >= w: 311 | return w - 1 312 | return x 313 | 314 | def get_valid_y(y): 315 | if y < 0: 316 | return 0 317 | if y >= h: 318 | return h - 1 319 | return y 320 | 321 | rect = ((rect[0], rect[1]), (rect[2], rect[3]), rect[4]) 322 | points = cv2.cv.BoxPoints(rect) 323 | points = np.int0(points) 324 | for i_xy, (x, y) in enumerate(points): 325 | x = get_valid_x(x) 326 | y = get_valid_y(y) 327 | points[i_xy, :] = [x, y] 328 | points = np.reshape(points, -1) 329 | return points 330 | 331 | # @util.dec.print_calling_in_short 332 | # @util.dec.timeit 333 | def mask_to_bboxes(mask, image_shape = None, min_area = None, 334 | min_height = None, min_aspect_ratio = None): 335 | import config 336 | feed_shape = config.train_image_shape 337 | 338 | if image_shape is None: 339 | image_shape = feed_shape 340 | 341 | image_h, image_w = image_shape[0:2] 342 | 343 | if min_area is None: 344 | min_area = config.min_area 345 | 346 | if min_height is None: 347 | min_height = config.min_height 348 | bboxes = [] 349 | max_bbox_idx = mask.max() 350 | mask = util.img.resize(img = mask, size = (image_w, image_h), 351 | interpolation = cv2.INTER_NEAREST) 352 | 353 | for bbox_idx in xrange(1, max_bbox_idx + 1): 354 | bbox_mask = mask == bbox_idx 355 | # if bbox_mask.sum() < 10: 356 | # continue 357 | cnts = util.img.find_contours(bbox_mask) 358 | if len(cnts) == 0: 359 | continue 360 | cnt = cnts[0] 361 | rect, rect_area = min_area_rect(cnt) 362 | 363 | w, h = rect[2:-1] 364 | if min(w, h) < min_height: 365 | continue 366 | 367 | if rect_area < min_area: 368 | continue 369 | 370 | # if max(w, h) * 1.0 / min(w, h) < 2: 371 | # continue 372 | xys = rect_to_xys(rect, image_shape) 373 | bboxes.append(xys) 374 | 375 | return bboxes 376 | 377 | 378 | #============================Decode End=============================== 379 | -------------------------------------------------------------------------------- /pixel_link_decode.pyx: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | import util 5 | PIXEL_NEIGHBOUR_TYPE_4 = 'PIXEL_NEIGHBOUR_TYPE_4' 6 | PIXEL_NEIGHBOUR_TYPE_8 = 'PIXEL_NEIGHBOUR_TYPE_8' 7 | 8 | 9 | def get_neighbours_8(x, y): 10 | """ 11 | Get 8 neighbours of point(x, y) 12 | """ 13 | return [(x - 1, y - 1), (x, y - 1), (x + 1, y - 1), \ 14 | (x - 1, y), (x + 1, y), \ 15 | (x - 1, y + 1), (x, y + 1), (x + 1, y + 1)] 16 | 17 | 18 | def get_neighbours_4(x, y): 19 | return [(x - 1, y), (x + 1, y), (x, y + 1), (x, y - 1)] 20 | 21 | 22 | def get_neighbours(x, y): 23 | import config 24 | neighbour_type = config.pixel_neighbour_type 25 | if neighbour_type == PIXEL_NEIGHBOUR_TYPE_4: 26 | return get_neighbours_4(x, y) 27 | else: 28 | return get_neighbours_8(x, y) 29 | 30 | def get_neighbours_fn(): 31 | import config 32 | neighbour_type = config.pixel_neighbour_type 33 | if neighbour_type == PIXEL_NEIGHBOUR_TYPE_4: 34 | return get_neighbours_4, 4 35 | else: 36 | return get_neighbours_8, 8 37 | 38 | 39 | 40 | def is_valid_cord(x, y, w, h): 41 | """ 42 | Tell whether the 2D coordinate (x, 
y) is valid or not. 43 | If valid, it should be on an h x w image 44 | """ 45 | return x >=0 and x < w and y >= 0 and y < h; 46 | 47 | 48 | 49 | def decode_image_by_join(pixel_scores, link_scores, 50 | pixel_conf_threshold, link_conf_threshold): 51 | pixel_mask = pixel_scores >= pixel_conf_threshold 52 | link_mask = link_scores >= link_conf_threshold 53 | points = zip(*np.where(pixel_mask)) 54 | h, w = np.shape(pixel_mask) 55 | group_mask = dict.fromkeys(points, -1) 56 | def find_parent(point): 57 | return group_mask[point] 58 | 59 | def set_parent(point, parent): 60 | group_mask[point] = parent 61 | 62 | def is_root(point): 63 | return find_parent(point) == -1 64 | 65 | def find_root(point): 66 | root = point 67 | update_parent = False 68 | while not is_root(root): 69 | root = find_parent(root) 70 | update_parent = True 71 | 72 | # for acceleration of find_root 73 | if update_parent: 74 | set_parent(point, root) 75 | 76 | return root 77 | 78 | def join(p1, p2): 79 | root1 = find_root(p1) 80 | root2 = find_root(p2) 81 | 82 | if root1 != root2: 83 | set_parent(root1, root2) 84 | 85 | def get_all(): 86 | root_map = {} 87 | def get_index(root): 88 | if root not in root_map: 89 | root_map[root] = len(root_map) + 1 90 | return root_map[root] 91 | 92 | mask = np.zeros_like(pixel_mask, dtype = np.int32) 93 | for point in points: 94 | point_root = find_root(point) 95 | bbox_idx = get_index(point_root) 96 | mask[point] = bbox_idx 97 | return mask 98 | 99 | # join by link 100 | for point in points: 101 | y, x = point 102 | neighbours = get_neighbours(x, y) 103 | for n_idx, (nx, ny) in enumerate(neighbours): 104 | if is_valid_cord(nx, ny, w, h): 105 | # reversed_neighbours = get_neighbours(nx, ny) 106 | # reversed_idx = reversed_neighbours.index((x, y)) 107 | link_value = link_mask[y, x, n_idx]# and link_mask[ny, nx, reversed_idx] 108 | pixel_cls = pixel_mask[ny, nx] 109 | if link_value and pixel_cls: 110 | join(point, (ny, nx)) 111 | 112 | mask = get_all() 113 | return mask 114 | 115 | -------------------------------------------------------------------------------- /pixel_link_env.txt: -------------------------------------------------------------------------------- 1 | name: pixel_link 2 | channels: 3 | - menpo 4 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free 5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 6 | - defaults 7 | dependencies: 8 | - certifi=2016.2.28=py27_0 9 | - cudatoolkit=7.5=2 10 | - cudnn=5.1=0 11 | - funcsigs=1.0.2=py27_0 12 | - libprotobuf=3.4.0=0 13 | - mkl=2017.0.3=0 14 | - mock=2.0.0=py27_0 15 | - numpy=1.12.1=py27_0 16 | - openssl=1.0.2l=0 17 | - pbr=1.10.0=py27_0 18 | - pip=9.0.1=py27_1 19 | - protobuf=3.4.0=py27_0 20 | - python=2.7.13=0 21 | - readline=6.2=2 22 | - setuptools=36.4.0=py27_1 23 | - six=1.10.0=py27_0 24 | - sqlite=3.13.0=0 25 | - tensorflow-gpu=1.1.0=np112py27_0 26 | - tk=8.5.18=0 27 | - werkzeug=0.12.2=py27_0 28 | - wheel=0.29.0=py27_0 29 | - zlib=1.2.11=0 30 | - opencv=2.4.11=nppy27_0 31 | - pip: 32 | - backports.functools-lru-cache==1.5 33 | - bottle==0.12.13 34 | - cycler==0.10.0 35 | - cython==0.28.2 36 | - enum34==1.1.6 37 | - kiwisolver==1.0.1 38 | - matplotlib==2.2.2 39 | - olefile==0.44 40 | - pillow==4.3.0 41 | - polygon2==2.0.8 42 | - pyparsing==2.2.0 43 | - python-dateutil==2.7.2 44 | - pytz==2018.4 45 | - setproctitle==1.1.10 46 | - subprocess32==3.2.7 47 | - tensorflow==1.1.0 48 | - virtualenv==15.1.0 49 | prefix: /home/dengdan/anaconda2/envs/pixel_link 50 | 51 | 
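The union-find grouping in `decode_image_by_join` (pixel_link_decode.pyx above) can be hard to follow inside the Cython module. Below is a minimal, self-contained sketch of the same idea on a toy mask. It ignores the link scores for brevity (the real code additionally requires `link_mask[y, x, n_idx]` to be on before joining), and the `toy_decode` helper and the toy array are made up for illustration only; they are not part of the repo.
```
import numpy as np

def toy_decode(pixel_mask):
    # every positive pixel starts as its own root
    points = list(zip(*np.where(pixel_mask)))
    parent = {p: p for p in points}

    def find_root(p):
        while parent[p] != p:
            parent[p] = parent[parent[p]]  # path compression
            p = parent[p]
        return p

    def join(p1, p2):
        r1, r2 = find_root(p1), find_root(p2)
        if r1 != r2:
            parent[r1] = r2

    # join 8-connected positive neighbours
    h, w = pixel_mask.shape
    for (y, x) in points:
        for dy in (-1, 0, 1):
            for dx in (-1, 0, 1):
                ny, nx = y + dy, x + dx
                if 0 <= ny < h and 0 <= nx < w and pixel_mask[ny, nx]:
                    join((y, x), (ny, nx))

    # relabel each connected component with a bbox index starting from 1
    mask = np.zeros_like(pixel_mask, dtype=np.int32)
    root_ids = {}
    for p in points:
        mask[p] = root_ids.setdefault(find_root(p), len(root_ids) + 1)
    return mask

# two separate blobs are labelled 1 and 2
toy = np.array([[1, 1, 0, 0],
                [0, 1, 0, 1],
                [0, 0, 0, 1]], dtype=bool)
print(toy_decode(toy))
```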
-------------------------------------------------------------------------------- /preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /preprocessing/preprocessing_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a factory for building various models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from preprocessing import ssd_vgg_preprocessing 24 | 25 | slim = tf.contrib.slim 26 | 27 | 28 | def get_preprocessing(is_training=False): 29 | """Returns preprocessing_fn(image, height, width, **kwargs). 30 | 31 | Args: 32 | name: The name of the preprocessing function. 33 | is_training: `True` if the model is being used for training. 34 | 35 | Returns: 36 | preprocessing_fn: A function that preprocessing a single image (pre-batch). 37 | It has the following signature: 38 | image = preprocessing_fn(image, output_height, output_width, ...). 39 | 40 | Raises: 41 | ValueError: If Preprocessing `name` is not recognized. 42 | """ 43 | 44 | 45 | def preprocessing_fn(image, labels, bboxes, xs, ys, 46 | out_shape, data_format='NHWC', **kwargs): 47 | return ssd_vgg_preprocessing.preprocess_image( 48 | image, labels, bboxes, out_shape, xs, ys, data_format=data_format, 49 | is_training=is_training, **kwargs) 50 | return preprocessing_fn 51 | -------------------------------------------------------------------------------- /preprocessing/ssd_vgg_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Paul Balanca. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Pre-processing images for SSD-type networks. 
16 | """ 17 | from enum import Enum, IntEnum 18 | import numpy as np 19 | 20 | import tensorflow as tf 21 | import tf_extended as tfe 22 | 23 | from tensorflow.python.ops import control_flow_ops 24 | import cv2 25 | import util 26 | from preprocessing import tf_image 27 | 28 | slim = tf.contrib.slim 29 | 30 | # Resizing strategies. 31 | Resize = IntEnum('Resize', ('NONE', # Nothing! 32 | 'CENTRAL_CROP', # Crop (and pad if necessary). 33 | 'PAD_AND_RESIZE', # Pad, and resize to output shape. 34 | 'WARP_RESIZE')) # Warp resize. 35 | 36 | import config 37 | # VGG mean parameters. 38 | _R_MEAN = config.r_mean 39 | _G_MEAN = config.g_mean 40 | _B_MEAN = config.b_mean 41 | 42 | # Some training pre-processing parameters. 43 | MAX_EXPAND_SCALE = config.max_expand_scale 44 | BBOX_CROP_OVERLAP = config.bbox_crop_overlap # Minimum overlap to keep a bbox after cropping. 45 | MIN_OBJECT_COVERED = config.min_object_covered 46 | CROP_ASPECT_RATIO_RANGE = config.crop_aspect_ratio_range # Distortion ratio during cropping. 47 | AREA_RANGE = config.area_range 48 | FLIP = config.flip 49 | LABEL_IGNORE = config.ignore_label 50 | USING_SHORTER_SIDE_FILTERING = config.using_shorter_side_filtering 51 | 52 | MIN_SHORTER_SIDE = config.min_shorter_side 53 | MAX_SHORTER_SIDE = config.max_shorter_side 54 | 55 | USE_ROTATION = config.use_rotation 56 | 57 | def tf_image_whitened(image, means=[_R_MEAN, _G_MEAN, _B_MEAN]): 58 | """Subtracts the given means from each image channel. 59 | 60 | Returns: 61 | the centered image. 62 | """ 63 | if image.get_shape().ndims != 3: 64 | raise ValueError('Input must be of size [height, width, C>0]') 65 | num_channels = image.get_shape().as_list()[-1] 66 | if len(means) != num_channels: 67 | raise ValueError('len(means) must match the number of channels') 68 | 69 | mean = tf.constant(means, dtype=image.dtype) 70 | image = image - mean 71 | return image 72 | 73 | 74 | def tf_image_unwhitened(image, means=[_R_MEAN, _G_MEAN, _B_MEAN], to_int=True): 75 | """Re-convert to original image distribution, and convert to int if 76 | necessary. 77 | 78 | Returns: 79 | Centered image. 80 | """ 81 | mean = tf.constant(means, dtype=image.dtype) 82 | image = image + mean 83 | if to_int: 84 | image = tf.cast(image, tf.int32) 85 | return image 86 | 87 | 88 | def np_image_unwhitened(image, means=[_R_MEAN, _G_MEAN, _B_MEAN], to_int=True): 89 | """Re-convert to original image distribution, and convert to int if 90 | necessary. Numpy version. 91 | 92 | Returns: 93 | Centered image. 94 | """ 95 | img = np.copy(image) 96 | img += np.array(means, dtype=img.dtype) 97 | if to_int: 98 | img = img.astype(np.uint8) 99 | return img 100 | 101 | 102 | def tf_summary_image(image, bboxes, name='image', unwhitened=False): 103 | """Add image with bounding boxes to summary. 104 | """ 105 | if unwhitened: 106 | image = tf_image_unwhitened(image) 107 | image = tf.expand_dims(image, 0) 108 | bboxes = tf.expand_dims(bboxes, 0) 109 | image_with_box = tf.image.draw_bounding_boxes(image, bboxes) 110 | tf.summary.image(name, image_with_box) 111 | 112 | 113 | def apply_with_random_selector(x, func, num_cases): 114 | """Computes func(x, sel), with sel sampled from [0...num_cases-1]. 115 | 116 | Args: 117 | x: input Tensor. 118 | func: Python function to apply. 119 | num_cases: Python int32, number of cases to sample sel from. 120 | 121 | Returns: 122 | The result of func(x, sel), where func receives the value of the 123 | selector as a python integer, but sel is sampled dynamically. 
124 | """ 125 | sel = tf.random_uniform([], maxval=num_cases, dtype=tf.int32) 126 | # Pass the real x only to one of the func calls. 127 | return control_flow_ops.merge([ 128 | func(control_flow_ops.switch(x, tf.equal(sel, case))[1], case) 129 | for case in range(num_cases)])[0] 130 | 131 | 132 | def distort_color(image, color_ordering=0, fast_mode=True, scope=None): 133 | """Distort the color of a Tensor image. 134 | 135 | Each color distortion is non-commutative and thus ordering of the color ops 136 | matters. Ideally we would randomly permute the ordering of the color ops. 137 | Rather then adding that level of complication, we select a distinct ordering 138 | of color ops for each preprocessing thread. 139 | 140 | Args: 141 | image: 3-D Tensor containing single image in [0, 1]. 142 | color_ordering: Python int, a type of distortion (valid values: 0-3). 143 | fast_mode: Avoids slower ops (random_hue and random_contrast) 144 | scope: Optional scope for name_scope. 145 | Returns: 146 | 3-D Tensor color-distorted image on range [0, 1] 147 | Raises: 148 | ValueError: if color_ordering not in [0, 3] 149 | """ 150 | with tf.name_scope(scope, 'distort_color', [image]): 151 | if fast_mode: 152 | if color_ordering == 0: 153 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 154 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 155 | else: 156 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 157 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 158 | else: 159 | if color_ordering == 0: 160 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 161 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 162 | image = tf.image.random_hue(image, max_delta=0.2) 163 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 164 | elif color_ordering == 1: 165 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 166 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 167 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 168 | image = tf.image.random_hue(image, max_delta=0.2) 169 | elif color_ordering == 2: 170 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 171 | image = tf.image.random_hue(image, max_delta=0.2) 172 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 173 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 174 | elif color_ordering == 3: 175 | image = tf.image.random_hue(image, max_delta=0.2) 176 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 177 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 178 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 179 | else: 180 | raise ValueError('color_ordering must be in [0, 3]') 181 | # The random_* ops do not necessarily clamp. 182 | return tf.clip_by_value(image, 0.0, 1.0) 183 | 184 | 185 | def distorted_bounding_box_crop(image, 186 | labels, 187 | bboxes, 188 | xs, ys, 189 | min_object_covered, 190 | aspect_ratio_range, 191 | area_range, 192 | max_attempts = 200, 193 | scope=None): 194 | """Generates cropped_image using a one of the bboxes randomly distorted. 195 | 196 | See `tf.image.sample_distorted_bounding_box` for more documentation. 197 | 198 | Args: 199 | image: 3-D Tensor of image (it will be converted to floats in [0, 1]). 200 | bbox: 2-D float Tensor of bounding boxes arranged [num_boxes, coords] 201 | where each coordinate is [0, 1) and the coordinates are arranged 202 | as [ymin, xmin, ymax, xmax]. 
If num_boxes is 0 then it would use the whole 203 | image. 204 | min_object_covered: An optional `float`. Defaults to `0.1`. The cropped 205 | area of the image must contain at least this fraction of any bounding box 206 | supplied. 207 | aspect_ratio_range: An optional list of `floats`. The cropped area of the 208 | image must have an aspect ratio = width / height within this range. 209 | area_range: An optional list of `floats`. The cropped area of the image 210 | must contain a fraction of the supplied image within in this range. 211 | max_attempts: An optional `int`. Number of attempts at generating a cropped 212 | region of the image of the specified constraints. After `max_attempts` 213 | failures, return the entire image. 214 | scope: Optional scope for name_scope. 215 | Returns: 216 | A tuple, a 3-D Tensor cropped_image and the distorted bbox 217 | """ 218 | with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bboxes, xs, ys]): 219 | # Each bounding box has shape [1, num_boxes, box coords] and 220 | # the coordinates are ordered [ymin, xmin, ymax, xmax]. 221 | num_bboxes = tf.shape(bboxes)[0] 222 | def has_bboxes(): 223 | return bboxes, labels, xs, ys 224 | def no_bboxes(): 225 | xmin = tf.random_uniform((1,1), minval = 0, maxval = 0.9) 226 | ymin = tf.random_uniform((1,1), minval = 0, maxval = 0.9) 227 | w = tf.constant(0.1, dtype = tf.float32) 228 | h = w 229 | xmax = xmin + w 230 | ymax = ymin + h 231 | rnd_bboxes = tf.concat([ymin, xmin, ymax, xmax], axis = 1) 232 | rnd_labels = tf.constant([config.background_label], dtype = tf.int64) 233 | rnd_xs = tf.concat([xmin, xmax, xmax, xmin], axis = 1) 234 | rnd_ys = tf.concat([ymin, ymin, ymax, ymax], axis = 1) 235 | 236 | return rnd_bboxes, rnd_labels, rnd_xs, rnd_ys 237 | 238 | bboxes, labels, xs, ys = tf.cond(num_bboxes > 0, has_bboxes, no_bboxes) 239 | bbox_begin, bbox_size, distort_bbox = tf.image.sample_distorted_bounding_box( 240 | tf.shape(image), 241 | bounding_boxes=tf.expand_dims(bboxes, 0), 242 | min_object_covered=min_object_covered, 243 | aspect_ratio_range=aspect_ratio_range, 244 | area_range=area_range, 245 | max_attempts=max_attempts, 246 | use_image_if_no_bounding_boxes=True) 247 | distort_bbox = distort_bbox[0, 0] 248 | 249 | # Crop the image to the specified bounding box. 250 | cropped_image = tf.slice(image, bbox_begin, bbox_size) 251 | # Restore the shape since the dynamic slice loses 3rd dimension. 252 | cropped_image.set_shape([None, None, 3]) 253 | 254 | # Update bounding boxes: resize and filter out. 
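        # (added note) `tfe.bboxes_resize` re-expresses the bbox and corner coordinates
        # relative to the sampled crop window `distort_bbox`, and `bboxes_filter_overlap`
        # marks boxes whose remaining overlap with the crop is below BBOX_CROP_OVERLAP by
        # assigning them LABEL_IGNORE instead of dropping them, so heavily-cut text is
        # ignored in the loss rather than treated as background.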
255 | bboxes, xs, ys = tfe.bboxes_resize(distort_bbox, bboxes, xs, ys) 256 | labels, bboxes, xs, ys = tfe.bboxes_filter_overlap(labels, bboxes, xs, ys, 257 | threshold=BBOX_CROP_OVERLAP, assign_value = LABEL_IGNORE) 258 | return cropped_image, labels, bboxes, xs, ys, distort_bbox 259 | 260 | 261 | def tf_rotate_image(image, xs, ys): 262 | image, bboxes, xs, ys = tf.py_func(rotate_image, [image, xs, ys], [tf.uint8, tf.float32, tf.float32, tf.float32]) 263 | image.set_shape([None, None, 3]) 264 | bboxes.set_shape([None, 4]) 265 | xs.set_shape([None, 4]) 266 | ys.set_shape([None, 4]) 267 | return image, bboxes, xs, ys 268 | 269 | 270 | 271 | def rotate_image(image, xs, ys): 272 | rotation_angle = np.random.randint(low = -90, high = 90); 273 | scale = np.random.uniform(low = MIN_ROTATION_SCLAE, high = MAX_ROTATION_SCLAE) 274 | # scale = 1.0 275 | h, w = image.shape[0:2] 276 | # rotate image 277 | image, M = util.img.rotate_about_center(image, rotation_angle, scale = scale) 278 | 279 | nh, nw = image.shape[0:2] 280 | 281 | # rotate bboxes 282 | xs = xs * w 283 | ys = ys * h 284 | def rotate_xys(xs, ys): 285 | xs = np.reshape(xs, -1) 286 | ys = np.reshape(ys, -1) 287 | xs, ys = np.dot(M, np.transpose([xs, ys, 1])) 288 | xs = np.reshape(xs, (-1, 4)) 289 | ys = np.reshape(ys, (-1, 4)) 290 | return xs, ys 291 | xs, ys = rotate_xys(xs, ys) 292 | xs = xs * 1.0 / nw 293 | ys = ys * 1.0 / nh 294 | xmin = np.min(xs, axis = 1) 295 | xmin[np.where(xmin < 0)] = 0 296 | 297 | xmax = np.max(xs, axis = 1) 298 | xmax[np.where(xmax > 1)] = 1 299 | 300 | ymin = np.min(ys, axis = 1) 301 | ymin[np.where(ymin < 0)] = 0 302 | 303 | ymax = np.max(ys, axis = 1) 304 | ymax[np.where(ymax > 1)] = 1 305 | 306 | bboxes = np.transpose(np.asarray([ymin, xmin, ymax, xmax])) 307 | image = np.asarray(image, np.uint8) 308 | return image, bboxes, xs, ys 309 | 310 | def preprocess_for_train(image, labels, bboxes, xs, ys, 311 | out_shape, data_format='NHWC', 312 | scope='ssd_preprocessing_train'): 313 | """Preprocesses the given image for training. 314 | 315 | Note that the actual resizing scale is sampled from 316 | [`resize_size_min`, `resize_size_max`]. 317 | 318 | Args: 319 | image: A `Tensor` representing an image of arbitrary size. 320 | output_height: The height of the image after preprocessing. 321 | output_width: The width of the image after preprocessing. 322 | resize_side_min: The lower bound for the smallest side of the image for 323 | aspect-preserving resizing. 324 | resize_side_max: The upper bound for the smallest side of the image for 325 | aspect-preserving resizing. 326 | 327 | Returns: 328 | A preprocessed image. 
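        The filtered labels, bboxes, xs and ys are returned along with it.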
329 | """ 330 | fast_mode = False 331 | with tf.name_scope(scope, 'ssd_preprocessing_train', [image, labels, bboxes]): 332 | if image.get_shape().ndims != 3: 333 | raise ValueError('Input must be of size [height, width, C>0]') 334 | 335 | # rotate image by 0, 0.5 * pi, pi, 1.5 * pi randomly 336 | # if USE_ROTATION: 337 | # image, bboxes, xs, ys = tf_image.random_rotate90(image, bboxes, xs, ys) 338 | # rotate image by 0, 0.5 * pi, pi, 1.5 * pi randomly 339 | if USE_ROTATION: 340 | rnd = tf.random_uniform((), minval = 0, maxval = 1) 341 | def rotate(): 342 | return tf_image.random_rotate90(image, bboxes, xs, ys) 343 | 344 | def no_rotate(): 345 | return image, bboxes, xs, ys 346 | 347 | image, bboxes, xs, ys = tf.cond(tf.less(rnd, config.rotation_prob), rotate, no_rotate) 348 | 349 | # expand image 350 | if MAX_EXPAND_SCALE > 1: 351 | rnd2 = tf.random_uniform((), minval = 0, maxval = 1) 352 | def expand(): 353 | scale = tf.random_uniform([], minval = 1.0, 354 | maxval = MAX_EXPAND_SCALE, dtype=tf.float32) 355 | image_shape = tf.cast(tf.shape(image), dtype = tf.float32) 356 | image_h, image_w = image_shape[0], image_shape[1] 357 | target_h = tf.cast(image_h * scale, dtype = tf.int32) 358 | target_w = tf.cast(image_w * scale, dtype = tf.int32) 359 | tf.logging.info('expanded') 360 | return tf_image.resize_image_bboxes_with_crop_or_pad( 361 | image, bboxes, xs, ys, target_h, target_w) 362 | 363 | def no_expand(): 364 | return image, bboxes, xs, ys 365 | 366 | image, bboxes, xs, ys = tf.cond(tf.less(rnd2, config.expand_prob), expand, no_expand) 367 | 368 | 369 | # Convert to float scaled [0, 1]. 370 | if image.dtype != tf.float32: 371 | image = tf.image.convert_image_dtype(image, dtype=tf.float32) 372 | # tf_summary_image(image, bboxes, 'image_with_bboxes') 373 | 374 | # Distort image and bounding boxes. 375 | dst_image = image 376 | dst_image, labels, bboxes, xs, ys, distort_bbox = \ 377 | distorted_bounding_box_crop(image, labels, bboxes, xs, ys, 378 | min_object_covered = MIN_OBJECT_COVERED, 379 | aspect_ratio_range = CROP_ASPECT_RATIO_RANGE, 380 | area_range = AREA_RANGE) 381 | # Resize image to output size. 382 | dst_image = tf_image.resize_image(dst_image, out_shape, 383 | method=tf.image.ResizeMethod.BILINEAR, 384 | align_corners=False) 385 | tf_summary_image(dst_image, bboxes, 'image_shape_distorted') 386 | 387 | # Filter bboxes using the length of shorter sides 388 | if USING_SHORTER_SIDE_FILTERING: 389 | xs = xs * out_shape[1] 390 | ys = ys * out_shape[0] 391 | labels, bboxes, xs, ys = tfe.bboxes_filter_by_shorter_side(labels, 392 | bboxes, xs, ys, 393 | min_height = MIN_SHORTER_SIDE, max_height = MAX_SHORTER_SIDE, 394 | assign_value = LABEL_IGNORE) 395 | xs = xs / out_shape[1] 396 | ys = ys / out_shape[0] 397 | 398 | # Randomly distort the colors. There are 4 ways to do it. 399 | dst_image = apply_with_random_selector( 400 | dst_image, 401 | lambda x, ordering: distort_color(x, ordering, fast_mode), 402 | num_cases=4) 403 | tf_summary_image(dst_image, bboxes, 'image_color_distorted') 404 | 405 | # Rescale to VGG input scale. 406 | image = dst_image * 255. 407 | image = tf_image_whitened(image, [_R_MEAN, _G_MEAN, _B_MEAN]) 408 | # Image data format. 
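    # (added note) 'NCHW' means channels-first; tf.transpose with perm=(2, 0, 1)
    # moves the channel axis of the HWC image to the front to match that layout.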
409 | if data_format == 'NCHW': 410 | image = tf.transpose(image, perm=(2, 0, 1)) 411 | return image, labels, bboxes, xs, ys 412 | 413 | 414 | def preprocess_for_eval(image, labels, bboxes, xs, ys, 415 | out_shape, data_format='NHWC', 416 | resize=Resize.WARP_RESIZE, 417 | do_resize = True, 418 | scope='ssd_preprocessing_train'): 419 | """Preprocess an image for evaluation. 420 | 421 | Args: 422 | image: A `Tensor` representing an image of arbitrary size. 423 | out_shape: Output shape after pre-processing (if resize != None) 424 | resize: Resize strategy. 425 | 426 | Returns: 427 | A preprocessed image. 428 | """ 429 | with tf.name_scope(scope): 430 | if image.get_shape().ndims != 3: 431 | raise ValueError('Input must be of size [height, width, C>0]') 432 | 433 | image = tf.to_float(image) 434 | image = tf_image_whitened(image, [_R_MEAN, _G_MEAN, _B_MEAN]) 435 | 436 | if do_resize: 437 | if resize == Resize.NONE: 438 | pass 439 | else: 440 | image = tf_image.resize_image(image, out_shape, 441 | method=tf.image.ResizeMethod.BILINEAR, 442 | align_corners=False) 443 | 444 | # Image data format. 445 | if data_format == 'NCHW': 446 | image = tf.transpose(image, perm=(2, 0, 1)) 447 | return image, labels, bboxes, xs, ys 448 | 449 | 450 | def preprocess_image(image, 451 | labels = None, 452 | bboxes = None, 453 | xs = None, ys = None, 454 | out_shape = None, 455 | data_format = 'NHWC', 456 | is_training=False, 457 | **kwargs): 458 | """Pre-process an given image. 459 | 460 | Args: 461 | image: A `Tensor` representing an image of arbitrary size. 462 | output_height: The height of the image after preprocessing. 463 | output_width: The width of the image after preprocessing. 464 | is_training: `True` if we're preprocessing the image for training and 465 | `False` otherwise. 466 | resize_side_min: The lower bound for the smallest side of the image for 467 | aspect-preserving resizing. If `is_training` is `False`, then this value 468 | is used for rescaling. 469 | resize_side_max: The upper bound for the smallest side of the image for 470 | aspect-preserving resizing. If `is_training` is `False`, this value is 471 | ignored. Otherwise, the resize side is sampled from 472 | [resize_size_min, resize_size_max]. 473 | 474 | Returns: 475 | A preprocessed image. 476 | """ 477 | if is_training: 478 | return preprocess_for_train(image, labels, bboxes, xs, ys, 479 | out_shape=out_shape, 480 | data_format=data_format) 481 | else: 482 | return preprocess_for_eval(image, labels, bboxes, xs, ys, 483 | out_shape=out_shape, 484 | data_format=data_format, 485 | **kwargs) 486 | -------------------------------------------------------------------------------- /preprocessing/tf_image.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors and Paul Balanca. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Custom image operations. 16 | Most of the following methods extend TensorFlow image library, and part of 17 | the code is shameless copy-paste of the former! 18 | """ 19 | import tensorflow as tf 20 | 21 | from tensorflow.python.framework import constant_op 22 | from tensorflow.python.framework import dtypes 23 | from tensorflow.python.framework import ops 24 | from tensorflow.python.framework import tensor_shape 25 | from tensorflow.python.framework import tensor_util 26 | from tensorflow.python.ops import array_ops 27 | from tensorflow.python.ops import check_ops 28 | from tensorflow.python.ops import clip_ops 29 | from tensorflow.python.ops import control_flow_ops 30 | from tensorflow.python.ops import gen_image_ops 31 | from tensorflow.python.ops import gen_nn_ops 32 | from tensorflow.python.ops import string_ops 33 | from tensorflow.python.ops import math_ops 34 | from tensorflow.python.ops import random_ops 35 | from tensorflow.python.ops import variables 36 | 37 | import util 38 | 39 | # =========================================================================== # 40 | # Modification of TensorFlow image routines. 41 | # =========================================================================== # 42 | def _assert(cond, ex_type, msg): 43 | """A polymorphic assert, works with tensors and boolean expressions. 44 | If `cond` is not a tensor, behave like an ordinary assert statement, except 45 | that a empty list is returned. If `cond` is a tensor, return a list 46 | containing a single TensorFlow assert op. 47 | Args: 48 | cond: Something evaluates to a boolean value. May be a tensor. 49 | ex_type: The exception class to use. 50 | msg: The error message. 51 | Returns: 52 | A list, containing at most one assert op. 53 | """ 54 | if _is_tensor(cond): 55 | return [control_flow_ops.Assert(cond, [msg])] 56 | else: 57 | if not cond: 58 | raise ex_type(msg) 59 | else: 60 | return [] 61 | 62 | 63 | def _is_tensor(x): 64 | """Returns `True` if `x` is a symbolic tensor-like object. 65 | Args: 66 | x: A python object to check. 67 | Returns: 68 | `True` if `x` is a `tf.Tensor` or `tf.Variable`, otherwise `False`. 69 | """ 70 | return isinstance(x, (ops.Tensor, variables.Variable)) 71 | 72 | 73 | def _ImageDimensions(image): 74 | """Returns the dimensions of an image tensor. 75 | Args: 76 | image: A 3-D Tensor of shape `[height, width, channels]`. 77 | Returns: 78 | A list of `[height, width, channels]` corresponding to the dimensions of the 79 | input image. Dimensions that are statically known are python integers, 80 | otherwise they are integer scalar tensors. 81 | """ 82 | if image.get_shape().is_fully_defined(): 83 | return image.get_shape().as_list() 84 | else: 85 | static_shape = image.get_shape().with_rank(3).as_list() 86 | dynamic_shape = array_ops.unstack(array_ops.shape(image), 3) 87 | return [s if s is not None else d 88 | for s, d in zip(static_shape, dynamic_shape)] 89 | 90 | 91 | def _Check3DImage(image, require_static=True): 92 | """Assert that we are working with properly shaped image. 93 | Args: 94 | image: 3-D Tensor of shape [height, width, channels] 95 | require_static: If `True`, requires that all dimensions of `image` are 96 | known and non-zero. 97 | Raises: 98 | ValueError: if `image.shape` is not a 3-vector. 99 | Returns: 100 | An empty list, if `image` has fully defined dimensions. Otherwise, a list 101 | containing an assert op is returned. 
102 | """ 103 | try: 104 | image_shape = image.get_shape().with_rank(3) 105 | except ValueError: 106 | raise ValueError("'image' must be three-dimensional.") 107 | if require_static and not image_shape.is_fully_defined(): 108 | raise ValueError("'image' must be fully defined.") 109 | if any(x == 0 for x in image_shape): 110 | raise ValueError("all dims of 'image.shape' must be > 0: %s" % 111 | image_shape) 112 | if not image_shape.is_fully_defined(): 113 | return [check_ops.assert_positive(array_ops.shape(image), 114 | ["all dims of 'image.shape' " 115 | "must be > 0."])] 116 | else: 117 | return [] 118 | 119 | 120 | def fix_image_flip_shape(image, result): 121 | """Set the shape to 3 dimensional if we don't know anything else. 122 | Args: 123 | image: original image size 124 | result: flipped or transformed image 125 | Returns: 126 | An image whose shape is at least None,None,None. 127 | """ 128 | image_shape = image.get_shape() 129 | if image_shape == tensor_shape.unknown_shape(): 130 | result.set_shape([None, None, None]) 131 | else: 132 | result.set_shape(image_shape) 133 | return result 134 | 135 | 136 | # =========================================================================== # 137 | # Image + BBoxes methods: cropping, resizing, flipping, ... 138 | # =========================================================================== # 139 | def bboxes_crop_or_pad(bboxes, xs, ys, 140 | height, width, 141 | offset_y, offset_x, 142 | target_height, target_width): 143 | """Adapt bounding boxes to crop or pad operations. 144 | Coordinates are always supposed to be relative to the image. 145 | 146 | Arguments: 147 | bboxes: Tensor Nx4 with bboxes coordinates [y_min, x_min, y_max, x_max]; 148 | height, width: Original image dimension; 149 | offset_y, offset_x: Offset to apply, 150 | negative if cropping, positive if padding; 151 | target_height, target_width: Target dimension after cropping / padding. 152 | """ 153 | with tf.name_scope('bboxes_crop_or_pad'): 154 | # Rescale bounding boxes in pixels. 155 | scale = tf.cast(tf.stack([height, width, height, width]), bboxes.dtype) 156 | bboxes = bboxes * scale 157 | xs *= tf.cast(width, bboxes.dtype) 158 | ys *= tf.cast(height, bboxes.dtype) 159 | # Add offset. 160 | offset = tf.cast(tf.stack([offset_y, offset_x, offset_y, offset_x]), bboxes.dtype) 161 | bboxes = bboxes + offset 162 | xs += tf.cast(offset_x, bboxes.dtype) 163 | ys += tf.cast(offset_y, bboxes.dtype) 164 | 165 | # Rescale to target dimension. 166 | scale = tf.cast(tf.stack([target_height, target_width, 167 | target_height, target_width]), bboxes.dtype) 168 | bboxes = bboxes / scale 169 | xs = xs / tf.cast(target_width, xs.dtype) 170 | ys = ys / tf.cast(target_height, ys.dtype) 171 | return bboxes, xs, ys 172 | 173 | 174 | def resize_image_bboxes_with_crop_or_pad(image, bboxes, xs, ys, 175 | target_height, target_width): 176 | """Crops and/or pads an image to a target width and height. 177 | Resizes an image to a target width and height by either centrally 178 | cropping the image or padding it evenly with zeros. 179 | 180 | If `width` or `height` is greater than the specified `target_width` or 181 | `target_height` respectively, this op centrally crops along that dimension. 182 | If `width` or `height` is smaller than the specified `target_width` or 183 | `target_height` respectively, this op centrally pads with 0 along that 184 | dimension. 185 | Args: 186 | image: 3-D tensor of shape `[height, width, channels]` 187 | target_height: Target height. 188 | target_width: Target width. 
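        bboxes, xs, ys: Bounding boxes (Nx4, [ymin, xmin, ymax, xmax]) and the corner
            coordinates of the boxes, normalized to the image; they are adjusted to the
            cropped/padded result along with the image.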
189 | Raises: 190 | ValueError: if `target_height` or `target_width` are zero or negative. 191 | Returns: 192 | Cropped and/or padded image of shape 193 | `[target_height, target_width, channels]` 194 | """ 195 | with tf.name_scope('resize_with_crop_or_pad'): 196 | image = ops.convert_to_tensor(image, name='image') 197 | 198 | assert_ops = [] 199 | assert_ops += _Check3DImage(image, require_static=False) 200 | assert_ops += _assert(target_width > 0, ValueError, 201 | 'target_width must be > 0.') 202 | assert_ops += _assert(target_height > 0, ValueError, 203 | 'target_height must be > 0.') 204 | 205 | image = control_flow_ops.with_dependencies(assert_ops, image) 206 | # `crop_to_bounding_box` and `pad_to_bounding_box` have their own checks. 207 | # Make sure our checks come first, so that error messages are clearer. 208 | if _is_tensor(target_height): 209 | target_height = control_flow_ops.with_dependencies( 210 | assert_ops, target_height) 211 | if _is_tensor(target_width): 212 | target_width = control_flow_ops.with_dependencies(assert_ops, target_width) 213 | 214 | def max_(x, y): 215 | if _is_tensor(x) or _is_tensor(y): 216 | return math_ops.maximum(x, y) 217 | else: 218 | return max(x, y) 219 | 220 | def min_(x, y): 221 | if _is_tensor(x) or _is_tensor(y): 222 | return math_ops.minimum(x, y) 223 | else: 224 | return min(x, y) 225 | 226 | def equal_(x, y): 227 | if _is_tensor(x) or _is_tensor(y): 228 | return math_ops.equal(x, y) 229 | else: 230 | return x == y 231 | 232 | height, width, _ = _ImageDimensions(image) 233 | width_diff = target_width - width 234 | offset_crop_width = max_(-width_diff // 2, 0) 235 | offset_pad_width = max_(width_diff // 2, 0) 236 | 237 | height_diff = target_height - height 238 | offset_crop_height = max_(-height_diff // 2, 0) 239 | offset_pad_height = max_(height_diff // 2, 0) 240 | 241 | # Maybe crop if needed. 242 | height_crop = min_(target_height, height) 243 | width_crop = min_(target_width, width) 244 | cropped = tf.image.crop_to_bounding_box(image, offset_crop_height, offset_crop_width, 245 | height_crop, width_crop) 246 | bboxes, xs, ys = bboxes_crop_or_pad(bboxes, xs, ys, 247 | height, width, 248 | -offset_crop_height, -offset_crop_width, 249 | height_crop, width_crop) 250 | # Maybe pad if needed. 251 | resized = tf.image.pad_to_bounding_box(cropped, offset_pad_height, offset_pad_width, 252 | target_height, target_width) 253 | bboxes, xs, ys = bboxes_crop_or_pad(bboxes, xs, ys, 254 | height_crop, width_crop, 255 | offset_pad_height, offset_pad_width, 256 | target_height, target_width) 257 | 258 | # In theory all the checks below are redundant. 259 | if resized.get_shape().ndims is None: 260 | raise ValueError('resized contains no shape.') 261 | 262 | resized_height, resized_width, _ = _ImageDimensions(resized) 263 | 264 | assert_ops = [] 265 | assert_ops += _assert(equal_(resized_height, target_height), ValueError, 266 | 'resized height is not correct.') 267 | assert_ops += _assert(equal_(resized_width, target_width), ValueError, 268 | 'resized width is not correct.') 269 | 270 | resized = control_flow_ops.with_dependencies(assert_ops, resized) 271 | return resized, bboxes, xs, ys 272 | 273 | 274 | def resize_image(image, size, 275 | method=tf.image.ResizeMethod.BILINEAR, 276 | align_corners=False): 277 | """Resize an image and bounding boxes. 278 | """ 279 | # Resize image. 
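    # (added note) tf.image.resize_images expects a 4-D batch, so the image is expanded
    # to [1, H, W, C], resized, and then reshaped back to 3-D; the reshape also re-attaches
    # a static [size[0], size[1], channels] shape to the resized tensor.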
280 | with tf.name_scope('resize_image'): 281 | height, width, channels = _ImageDimensions(image) 282 | image = tf.expand_dims(image, 0) 283 | image = tf.image.resize_images(image, size, 284 | method, align_corners) 285 | image = tf.reshape(image, tf.stack([size[0], size[1], channels])) 286 | return image 287 | 288 | 289 | def random_flip_left_right(image, bboxes, seed=None): 290 | """Random flip left-right of an image and its bounding boxes. 291 | """ 292 | def flip_bboxes(bboxes): 293 | """Flip bounding boxes coordinates. 294 | """ 295 | bboxes = tf.stack([bboxes[:, 0], 1 - bboxes[:, 3], 296 | bboxes[:, 2], 1 - bboxes[:, 1]], axis=-1) 297 | return bboxes 298 | 299 | # Random flip. Tensorflow implementation. 300 | with tf.name_scope('random_flip_left_right'): 301 | image = ops.convert_to_tensor(image, name='image') 302 | _Check3DImage(image, require_static=False) 303 | uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed) 304 | mirror_cond = math_ops.less(uniform_random, .5) 305 | # Flip image. 306 | result = control_flow_ops.cond(mirror_cond, 307 | lambda: array_ops.reverse_v2(image, [1]), 308 | lambda: image) 309 | # Flip bboxes. 310 | bboxes = control_flow_ops.cond(mirror_cond, 311 | lambda: flip_bboxes(bboxes), 312 | lambda: bboxes) 313 | return fix_image_flip_shape(image, result), bboxes 314 | 315 | 316 | def random_rotate90(image, bboxes, xs, ys): 317 | with tf.name_scope('random_rotate90'): 318 | k = random_ops.random_uniform([], 0, 10000) 319 | k = tf.cast(k, tf.int32) 320 | 321 | image_shape = tf.shape(image) 322 | h, w = image_shape[0], image_shape[1] 323 | image = tf.image.rot90(image, k = k) 324 | bboxes, xs, ys = rotate90(bboxes, xs, ys, k) 325 | return image, bboxes, xs, ys 326 | 327 | def tf_rotate_point_by_90(x, y, k): 328 | return tf.py_func(util.img.rotate_point_by_90, [x, y, k], 329 | [tf.float32, tf.float32]) 330 | 331 | def rotate90(bboxes, xs, ys, k): 332 | # bboxes = tf.Print(bboxes, [bboxes], 'before rotate',summarize = 100) 333 | ymin, xmin, ymax, xmax = [bboxes[:, i] for i in range(4)] 334 | xmin, ymin = tf_rotate_point_by_90(xmin, ymin, k) 335 | xmax, ymax = tf_rotate_point_by_90(xmax, ymax, k) 336 | 337 | new_xmin = tf.minimum(xmin, xmax) 338 | new_xmax = tf.maximum(xmin, xmax) 339 | 340 | new_ymin = tf.minimum(ymin, ymax) 341 | new_ymax = tf.maximum(ymin, ymax) 342 | 343 | bboxes = tf.stack([new_ymin, new_xmin, new_ymax, new_xmax]) 344 | bboxes = tf.transpose(bboxes) 345 | 346 | xs, ys = tf_rotate_point_by_90(xs, ys, k) 347 | return bboxes, xs, ys 348 | 349 | if __name__ == "__main__": 350 | import util 351 | image_path = '~/Pictures/img_1.jpg' 352 | image_data = util.img.imread(image_path, rgb = True) 353 | bbox_data = [[100, 100, 300, 300], [400, 400, 500, 500]] 354 | def draw_bbox(img, bbox): 355 | xmin, ymin, xmax, ymax = bbox 356 | util.img.rectangle(img, left_up = (xmin, ymin), 357 | right_bottom = (xmax, ymax), 358 | color = util.img.COLOR_RGB_RED, 359 | border_width = 10) 360 | 361 | image = tf.placeholder(dtype = tf.uint8) 362 | bboxes = tf.placeholder(dtype = tf.int32) 363 | 364 | bboxes_float32 = tf.cast(bboxes, dtype = tf.float32) 365 | image_shape = tf.cast(tf.shape(image), dtype = tf.float32) 366 | image_h, image_w = image_shape[0], image_shape[1] 367 | xmin, ymin, xmax, ymax = [bboxes_float32[:, i] for i in range(4)] 368 | bboxes_normed = tf.stack([xmin / image_w, ymin / image_h, 369 | xmax / image_w, ymax / image_h]) 370 | bboxes_normed = tf.transpose(bboxes_normed) 371 | 372 | target_height = image_h * 2 373 | target_width = 
image_w * 2 374 | target_height = tf.cast(target_height, tf.int32) 375 | target_width = tf.cast(target_width, tf.int32) 376 | 377 | processed_image, processed_bboxes = resize_image_bboxes_with_crop_or_pad(image, bboxes_normed, 378 | target_height, target_width) 379 | 380 | with tf.Session() as sess: 381 | resized_image, resized_bboxes = sess.run( 382 | [processed_image, processed_bboxes], 383 | feed_dict = {image: image_data, bboxes: bbox_data}) 384 | for _bbox in bbox_data: 385 | draw_bbox(image_data, _bbox) 386 | util.plt.imshow('image_data', image_data) 387 | 388 | h, w = resized_image.shape[0:2] 389 | for _bbox in resized_bboxes: 390 | _bbox *= [w, h, w, h] 391 | draw_bbox(resized_image, _bbox) 392 | util.plt.imshow('resized_image', resized_image) 393 | -------------------------------------------------------------------------------- /samples/img_249_pred.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZJULearning/pixel_link/00cb9aacc80583a8aba77d6834748ab4cca03254/samples/img_249_pred.jpg -------------------------------------------------------------------------------- /samples/img_333_pred.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZJULearning/pixel_link/00cb9aacc80583a8aba77d6834748ab4cca03254/samples/img_333_pred.jpg -------------------------------------------------------------------------------- /scripts/test.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | set -e 3 | export CUDA_VISIBLE_DEVICES=$1 4 | python test_pixel_link.py \ 5 | --checkpoint_path=$2 \ 6 | --dataset_dir=$3\ 7 | --gpu_memory_fraction=-1 8 | 9 | 10 | -------------------------------------------------------------------------------- /scripts/test_any.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | set -e 3 | 4 | export CUDA_VISIBLE_DEVICES=$1 5 | 6 | python test_pixel_link_on_any_image.py \ 7 | --checkpoint_path=$2 \ 8 | --dataset_dir=$3 \ 9 | --eval_image_width=1280\ 10 | --eval_image_height=768\ 11 | --pixel_conf_threshold=0.5\ 12 | --link_conf_threshold=0.5\ 13 | --gpu_memory_fraction=-1 14 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | set -e 3 | export CUDA_VISIBLE_DEVICES=$1 4 | IMG_PER_GPU=$2 5 | 6 | TRAIN_DIR=${HOME}/models/pixel_link 7 | 8 | # get the number of gpus 9 | OLD_IFS="$IFS" 10 | IFS="," 11 | gpus=($CUDA_VISIBLE_DEVICES) 12 | IFS="$OLD_IFS" 13 | NUM_GPUS=${#gpus[@]} 14 | 15 | # batch_size = num_gpus * IMG_PER_GPU 16 | BATCH_SIZE=`expr $NUM_GPUS \* $IMG_PER_GPU` 17 | 18 | #DATASET=synthtext 19 | #DATASET_PATH=SynthText 20 | 21 | DATASET=icdar2015 22 | DATASET_DIR=$HOME/dataset/pixel_link/icdar2015 23 | 24 | python train_pixel_link.py \ 25 | --train_dir=${TRAIN_DIR} \ 26 | --num_gpus=${NUM_GPUS} \ 27 | --learning_rate=1e-3\ 28 | --gpu_memory_fraction=-1 \ 29 | --train_image_width=512 \ 30 | --train_image_height=512 \ 31 | --batch_size=${BATCH_SIZE}\ 32 | --dataset_dir=${DATASET_DIR} \ 33 | --dataset_name=${DATASET} \ 34 | --dataset_split_name=train \ 35 | --max_number_of_steps=100\ 36 | --checkpoint_path=${CKPT_PATH} \ 37 | --using_moving_average=1 38 | 39 | python train_pixel_link.py \ 40 | --train_dir=${TRAIN_DIR} \ 41 | --num_gpus=${NUM_GPUS} \ 42 | --learning_rate=1e-2\ 43 | --gpu_memory_fraction=-1 \ 
44 | --train_image_width=512 \ 45 | --train_image_height=512 \ 46 | --batch_size=${BATCH_SIZE}\ 47 | --dataset_dir=${DATASET_DIR} \ 48 | --dataset_name=${DATASET} \ 49 | --dataset_split_name=train \ 50 | --checkpoint_path=${CKPT_PATH} \ 51 | --using_moving_average=1\ 52 | 2>&1 | tee -a ${TRAIN_DIR}/log.log 53 | 54 | -------------------------------------------------------------------------------- /scripts/vis.sh: -------------------------------------------------------------------------------- 1 | # !/bin/bash 2 | 3 | python visualize_detection_result.py \ 4 | --image=$1 \ 5 | --det=$2 \ 6 | --output=~/temp/no-use/pixel_result 7 | -------------------------------------------------------------------------------- /test_pixel_link.py: -------------------------------------------------------------------------------- 1 | #encoding = utf-8 2 | 3 | import numpy as np 4 | import math 5 | import tensorflow as tf 6 | from tensorflow.python.ops import control_flow_ops 7 | from tensorflow.contrib.training.python.training import evaluation 8 | from datasets import dataset_factory 9 | from preprocessing import ssd_vgg_preprocessing 10 | from tf_extended import metrics as tfe_metrics 11 | import util 12 | import cv2 13 | import pixel_link 14 | from nets import pixel_link_symbol 15 | 16 | 17 | slim = tf.contrib.slim 18 | import config 19 | # =========================================================================== # 20 | # Checkpoint and running Flags 21 | # =========================================================================== # 22 | tf.app.flags.DEFINE_string('checkpoint_path', None, 23 | 'the path of pretrained model to be used. If there are checkpoints\ 24 | in train_dir, this config will be ignored.') 25 | 26 | tf.app.flags.DEFINE_float('gpu_memory_fraction', -1, 27 | 'the gpu memory fraction to be used. If less than 0, allow_growth = True is used.') 28 | 29 | 30 | # =========================================================================== # 31 | # I/O and preprocessing Flags. 32 | # =========================================================================== # 33 | tf.app.flags.DEFINE_integer( 34 | 'num_readers', 1, 35 | 'The number of parallel readers that read data from the dataset.') 36 | tf.app.flags.DEFINE_integer( 37 | 'num_preprocessing_threads', 4, 38 | 'The number of threads used to create the batches.') 39 | tf.app.flags.DEFINE_bool('preprocessing_use_rotation', False, 40 | 'Whether to use rotation for data augmentation') 41 | 42 | # =========================================================================== # 43 | # Dataset Flags. 
44 | # =========================================================================== # 45 | tf.app.flags.DEFINE_string( 46 | 'dataset_name', 'icdar2015', 'The name of the dataset to load.') 47 | tf.app.flags.DEFINE_string( 48 | 'dataset_split_name', 'test', 'The name of the train/test split.') 49 | tf.app.flags.DEFINE_string('dataset_dir', 50 | util.io.get_absolute_path('~/dataset/ICDAR2015/Challenge4/ch4_test_images'), 51 | 'The directory where the dataset files are stored.') 52 | 53 | tf.app.flags.DEFINE_integer('eval_image_width', 1280, 'Train image size') 54 | tf.app.flags.DEFINE_integer('eval_image_height', 768, 'Train image size') 55 | tf.app.flags.DEFINE_bool('using_moving_average', True, 56 | 'Whether to use ExponentionalMovingAverage') 57 | tf.app.flags.DEFINE_float('moving_average_decay', 0.9999, 58 | 'The decay rate of ExponentionalMovingAverage') 59 | 60 | 61 | FLAGS = tf.app.flags.FLAGS 62 | 63 | def config_initialization(): 64 | # image shape and feature layers shape inference 65 | image_shape = (FLAGS.eval_image_height, FLAGS.eval_image_width) 66 | 67 | if not FLAGS.dataset_dir: 68 | raise ValueError('You must supply the dataset directory with --dataset_dir') 69 | 70 | tf.logging.set_verbosity(tf.logging.DEBUG) 71 | config.load_config(FLAGS.checkpoint_path) 72 | config.init_config(image_shape, 73 | batch_size = 1, 74 | pixel_conf_threshold = 0.8, 75 | link_conf_threshold = 0.8, 76 | num_gpus = 1, 77 | ) 78 | 79 | util.proc.set_proc_name('test_pixel_link_on'+ '_' + FLAGS.dataset_name) 80 | 81 | 82 | 83 | def to_txt(txt_path, image_name, 84 | image_data, pixel_pos_scores, link_pos_scores): 85 | # write detection result as txt files 86 | def write_result_as_txt(image_name, bboxes, path): 87 | filename = util.io.join_path(path, 'res_%s.txt'%(image_name)) 88 | lines = [] 89 | for b_idx, bbox in enumerate(bboxes): 90 | values = [int(v) for v in bbox] 91 | line = "%d, %d, %d, %d, %d, %d, %d, %d\n"%tuple(values) 92 | lines.append(line) 93 | util.io.write_lines(filename, lines) 94 | print 'result has been written to:', filename 95 | 96 | mask = pixel_link.decode_batch(pixel_pos_scores, link_pos_scores)[0, ...] 97 | bboxes = pixel_link.mask_to_bboxes(mask, image_data.shape) 98 | write_result_as_txt(image_name, bboxes, txt_path) 99 | 100 | def test(): 101 | with tf.name_scope('test'): 102 | image = tf.placeholder(dtype=tf.int32, shape = [None, None, 3]) 103 | image_shape = tf.placeholder(dtype = tf.int32, shape = [3, ]) 104 | processed_image, _, _, _, _ = ssd_vgg_preprocessing.preprocess_image(image, None, None, None, None, 105 | out_shape = config.image_shape, 106 | data_format = config.data_format, 107 | is_training = False) 108 | b_image = tf.expand_dims(processed_image, axis = 0) 109 | net = pixel_link_symbol.PixelLinkNet(b_image, is_training = True) 110 | global_step = slim.get_or_create_global_step() 111 | 112 | 113 | sess_config = tf.ConfigProto(log_device_placement = False, allow_soft_placement = True) 114 | if FLAGS.gpu_memory_fraction < 0: 115 | sess_config.gpu_options.allow_growth = True 116 | elif FLAGS.gpu_memory_fraction > 0: 117 | sess_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_memory_fraction; 118 | 119 | checkpoint_dir = util.io.get_dir(FLAGS.checkpoint_path) 120 | logdir = util.io.join_path(checkpoint_dir, 'test', FLAGS.dataset_name + '_' +FLAGS.dataset_split_name) 121 | 122 | # Variables to restore: moving avg. or normal weights. 
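    # (added note) variables_to_restore() on the ExponentialMovingAverage maps each
    # variable to its moving-average (shadow) counterpart in the checkpoint, so the
    # smoothed weights are loaded instead of the raw ones; global_step is kept as-is.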
123 | if FLAGS.using_moving_average: 124 | variable_averages = tf.train.ExponentialMovingAverage( 125 | FLAGS.moving_average_decay) 126 | variables_to_restore = variable_averages.variables_to_restore() 127 | variables_to_restore[global_step.op.name] = global_step 128 | else: 129 | variables_to_restore = slim.get_variables_to_restore() 130 | 131 | saver = tf.train.Saver(var_list = variables_to_restore) 132 | 133 | 134 | image_names = util.io.ls(FLAGS.dataset_dir) 135 | image_names.sort() 136 | 137 | checkpoint = FLAGS.checkpoint_path 138 | checkpoint_name = util.io.get_filename(str(checkpoint)); 139 | dump_path = util.io.join_path(logdir, checkpoint_name) 140 | txt_path = util.io.join_path(dump_path,'txt') 141 | zip_path = util.io.join_path(dump_path, checkpoint_name + '_det.zip') 142 | 143 | with tf.Session(config = sess_config) as sess: 144 | saver.restore(sess, checkpoint) 145 | 146 | for iter, image_name in enumerate(image_names): 147 | image_data = util.img.imread( 148 | util.io.join_path(FLAGS.dataset_dir, image_name), rgb = True) 149 | image_name = image_name.split('.')[0] 150 | pixel_pos_scores, link_pos_scores = sess.run( 151 | [net.pixel_pos_scores, net.link_pos_scores], 152 | feed_dict = { 153 | image:image_data 154 | }) 155 | 156 | print '%d/%d: %s'%(iter + 1, len(image_names), image_name) 157 | to_txt(txt_path, 158 | image_name, image_data, 159 | pixel_pos_scores, link_pos_scores) 160 | 161 | 162 | # create zip file for icdar2015 163 | cmd = 'cd %s;zip -j %s %s/*'%(dump_path, zip_path, txt_path); 164 | print cmd 165 | util.cmd.cmd(cmd); 166 | print "zip file created: ", util.io.join_path(dump_path, zip_path) 167 | 168 | 169 | 170 | def main(_): 171 | config_initialization() 172 | test() 173 | 174 | 175 | if __name__ == '__main__': 176 | tf.app.run() 177 | -------------------------------------------------------------------------------- /test_pixel_link_on_any_image.py: -------------------------------------------------------------------------------- 1 | #encoding = utf-8 2 | 3 | import numpy as np 4 | import math 5 | import tensorflow as tf 6 | from tensorflow.python.ops import control_flow_ops 7 | from tensorflow.contrib.training.python.training import evaluation 8 | from datasets import dataset_factory 9 | from preprocessing import ssd_vgg_preprocessing 10 | from tf_extended import metrics as tfe_metrics 11 | import util 12 | import cv2 13 | import pixel_link 14 | from nets import pixel_link_symbol 15 | 16 | 17 | slim = tf.contrib.slim 18 | import config 19 | # =========================================================================== # 20 | # Checkpoint and running Flags 21 | # =========================================================================== # 22 | tf.app.flags.DEFINE_string('checkpoint_path', None, 23 | 'the path of pretrained model to be used. If there are checkpoints\ 24 | in train_dir, this config will be ignored.') 25 | 26 | tf.app.flags.DEFINE_float('gpu_memory_fraction', -1, 27 | 'the gpu memory fraction to be used. If less than 0, allow_growth = True is used.') 28 | 29 | 30 | # =========================================================================== # 31 | # Dataset Flags. 
32 | # =========================================================================== # 33 | tf.app.flags.DEFINE_string( 34 | 'dataset_dir', 'None', 35 | 'The directory where the dataset files are stored.') 36 | 37 | tf.app.flags.DEFINE_integer('eval_image_width', None, 'resized image width for inference') 38 | tf.app.flags.DEFINE_integer('eval_image_height', None, 'resized image height for inference') 39 | tf.app.flags.DEFINE_float('pixel_conf_threshold', None, 'threshold on the pixel confidence') 40 | tf.app.flags.DEFINE_float('link_conf_threshold', None, 'threshold on the link confidence') 41 | 42 | 43 | tf.app.flags.DEFINE_bool('using_moving_average', True, 44 | 'Whether to use ExponentionalMovingAverage') 45 | tf.app.flags.DEFINE_float('moving_average_decay', 0.9999, 46 | 'The decay rate of ExponentionalMovingAverage') 47 | 48 | 49 | FLAGS = tf.app.flags.FLAGS 50 | 51 | def config_initialization(): 52 | # image shape and feature layers shape inference 53 | image_shape = (FLAGS.eval_image_height, FLAGS.eval_image_width) 54 | 55 | if not FLAGS.dataset_dir: 56 | raise ValueError('You must supply the dataset directory with --dataset_dir') 57 | 58 | tf.logging.set_verbosity(tf.logging.DEBUG) 59 | 60 | config.init_config(image_shape, 61 | batch_size = 1, 62 | pixel_conf_threshold = FLAGS.pixel_conf_threshold, 63 | link_conf_threshold = FLAGS.link_conf_threshold, 64 | num_gpus = 1, 65 | ) 66 | 67 | 68 | def test(): 69 | checkpoint_dir = util.io.get_dir(FLAGS.checkpoint_path) 70 | 71 | global_step = slim.get_or_create_global_step() 72 | with tf.name_scope('evaluation_%dx%d'%(FLAGS.eval_image_height, FLAGS.eval_image_width)): 73 | with tf.variable_scope(tf.get_variable_scope(), reuse = False): 74 | image = tf.placeholder(dtype=tf.int32, shape = [None, None, 3]) 75 | image_shape = tf.placeholder(dtype = tf.int32, shape = [3, ]) 76 | processed_image, _, _, _, _ = ssd_vgg_preprocessing.preprocess_image(image, None, None, None, None, 77 | out_shape = config.image_shape, 78 | data_format = config.data_format, 79 | is_training = False) 80 | b_image = tf.expand_dims(processed_image, axis = 0) 81 | 82 | # build model and loss 83 | net = pixel_link_symbol.PixelLinkNet(b_image, is_training = False) 84 | masks = pixel_link.tf_decode_score_map_to_mask_in_batch( 85 | net.pixel_pos_scores, net.link_pos_scores) 86 | 87 | sess_config = tf.ConfigProto(log_device_placement = False, allow_soft_placement = True) 88 | if FLAGS.gpu_memory_fraction < 0: 89 | sess_config.gpu_options.allow_growth = True 90 | elif FLAGS.gpu_memory_fraction > 0: 91 | sess_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_memory_fraction; 92 | 93 | # Variables to restore: moving avg. or normal weights. 
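# Same moving-average restore logic as in test_pixel_link.py, except that the
# shadow-name mapping is built only for tf.trainable_variables(); all other
# variables are restored under their own names.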
94 | if FLAGS.using_moving_average: 95 | variable_averages = tf.train.ExponentialMovingAverage( 96 | FLAGS.moving_average_decay) 97 | variables_to_restore = variable_averages.variables_to_restore( 98 | tf.trainable_variables()) 99 | variables_to_restore[global_step.op.name] = global_step 100 | else: 101 | variables_to_restore = slim.get_variables_to_restore() 102 | 103 | 104 | saver = tf.train.Saver(var_list = variables_to_restore) 105 | with tf.Session() as sess: 106 | saver.restore(sess, util.tf.get_latest_ckpt(FLAGS.checkpoint_path)) 107 | 108 | files = util.io.ls(FLAGS.dataset_dir) 109 | 110 | for image_name in files: 111 | file_path = util.io.join_path(FLAGS.dataset_dir, image_name) 112 | image_data = util.img.imread(file_path) 113 | link_scores, pixel_scores, mask_vals = sess.run( 114 | [net.link_pos_scores, net.pixel_pos_scores, masks], 115 | feed_dict = {image: image_data}) 116 | h, w, _ =image_data.shape 117 | def resize(img): 118 | return util.img.resize(img, size = (w, h), 119 | interpolation = cv2.INTER_NEAREST) 120 | 121 | def get_bboxes(mask): 122 | return pixel_link.mask_to_bboxes(mask, image_data.shape) 123 | 124 | def draw_bboxes(img, bboxes, color): 125 | for bbox in bboxes: 126 | points = np.reshape(bbox, [4, 2]) 127 | cnts = util.img.points_to_contours(points) 128 | util.img.draw_contours(img, contours = cnts, 129 | idx = -1, color = color, border_width = 1) 130 | image_idx = 0 131 | pixel_score = pixel_scores[image_idx, ...] 132 | mask = mask_vals[image_idx, ...] 133 | 134 | bboxes_det = get_bboxes(mask) 135 | 136 | mask = resize(mask) 137 | pixel_score = resize(pixel_score) 138 | 139 | draw_bboxes(image_data, bboxes_det, util.img.COLOR_RGB_RED) 140 | # print util.sit(pixel_score) 141 | # print util.sit(mask) 142 | print util.sit(image_data) 143 | 144 | 145 | def main(_): 146 | dataset = config_initialization() 147 | test() 148 | 149 | 150 | if __name__ == '__main__': 151 | tf.app.run() 152 | -------------------------------------------------------------------------------- /tf_extended/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Paul Balanca. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """TF Extended: additional metrics. 16 | """ 17 | 18 | # pylint: disable=unused-import,line-too-long,g-importing-member,wildcard-import 19 | from tf_extended.metrics import * 20 | from tf_extended.bboxes import * 21 | from tf_extended.math import * 22 | 23 | -------------------------------------------------------------------------------- /tf_extended/bboxes.py: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 
4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | import numpy as np 15 | import tensorflow as tf 16 | import cv2 17 | import util 18 | import config 19 | from tf_extended import math as tfe_math 20 | def bboxes_resize(bbox_ref, bboxes, xs, ys, name=None): 21 | """Resize bounding boxes based on a reference bounding box, 22 | assuming that the latter is [0, 0, 1, 1] after transform. Useful for 23 | updating a collection of boxes after cropping an image. 24 | """ 25 | # Tensors inputs. 26 | with tf.name_scope(name, 'bboxes_resize'): 27 | h_ref = bbox_ref[2] - bbox_ref[0] 28 | w_ref = bbox_ref[3] - bbox_ref[1] 29 | 30 | # Translate. 31 | v = tf.stack([bbox_ref[0], bbox_ref[1], bbox_ref[0], bbox_ref[1]]) 32 | bboxes = bboxes - v 33 | xs = xs - bbox_ref[1] 34 | ys = ys - bbox_ref[0] 35 | 36 | # Scale. 37 | s = tf.stack([h_ref, w_ref, h_ref, w_ref]) 38 | bboxes = bboxes / s 39 | xs = xs / w_ref; 40 | ys = ys / h_ref; 41 | 42 | return bboxes, xs, ys 43 | 44 | 45 | 46 | # def bboxes_filter_center(labels, bboxes, scope=None): 47 | # """Filter out bounding boxes whose center are not in 48 | # the rectangle [0, 0, 1, 1] + margins. The margin Tensor 49 | # can be used to enforce or loosen this condition. 50 | # 51 | # Return: 52 | # labels, bboxes: Filtered elements. 53 | # """ 54 | # with tf.name_scope(scope, 'bboxes_filter', [labels, bboxes]): 55 | # cy = (bboxes[:, 0] + bboxes[:, 2]) / 2. 56 | # cx = (bboxes[:, 1] + bboxes[:, 3]) / 2. 57 | # mask = tf.greater(cy, 0.) 58 | # mask = tf.logical_and(mask, tf.greater(cx, 0.)) 59 | # mask = tf.logical_and(mask, tf.less(cy, 1.)) 60 | # mask = tf.logical_and(mask, tf.less(cx, 1.)) 61 | # # Boolean masking... 62 | # labels = tf.boolean_mask(labels, mask) 63 | # bboxes = tf.boolean_mask(bboxes, mask) 64 | # return labels, bboxes 65 | 66 | def bboxes_filter_overlap(labels, bboxes,xs, ys, threshold, scope=None, assign_value = None): 67 | """Filter out bounding boxes based on (relative )overlap with reference 68 | box [0, 0, 1, 1]. Remove completely bounding boxes, or assign negative 69 | labels to the one outside (useful for latter processing...). 70 | 71 | Return: 72 | labels, bboxes: Filtered (or newly assigned) elements. 
73 | """ 74 | with tf.name_scope(scope, 'bboxes_filter_overlap', [labels, bboxes]): 75 | scores = bboxes_intersection(tf.constant([0, 0, 1, 1], bboxes.dtype),bboxes) 76 | 77 | if assign_value is not None: 78 | mask = scores < threshold 79 | mask = tf.logical_and(mask, tf.equal(labels, config.text_label)) 80 | labels = tf.where(mask, tf.ones_like(labels) * assign_value, labels) 81 | else: 82 | mask = scores > threshold 83 | labels = tf.boolean_mask(labels, mask) 84 | bboxes = tf.boolean_mask(bboxes, mask) 85 | scores = bboxes_intersection(tf.constant([0, 0, 1, 1], bboxes.dtype),bboxes) 86 | xs = tf.boolean_mask(xs, mask); 87 | ys = tf.boolean_mask(ys, mask); 88 | return labels, bboxes, xs, ys 89 | 90 | 91 | def bboxes_filter_by_shorter_side(labels, bboxes, xs, ys, min_height = 16, max_height = 32, assign_value = None): 92 | """ 93 | Filtering bboxes by the length of shorter side 94 | """ 95 | with tf.name_scope('bboxes_filter_by_shorter_side', [labels, bboxes]): 96 | bbox_rects = util.tf.min_area_rect(xs, ys) 97 | ws, hs = bbox_rects[:, 2], bbox_rects[:, 3] 98 | shorter_sides = tf.minimum(ws, hs) 99 | if assign_value is not None: 100 | mask = tf.logical_or(shorter_sides < min_height, shorter_sides > max_height) 101 | mask = tf.logical_and(mask, tf.equal(labels, config.text_label)) 102 | labels = tf.where(mask, tf.ones_like(labels) * assign_value, labels) 103 | else: 104 | mask = tf.logical_and(shorter_sides >= min_height, shorter_sides <= max_height) 105 | labels = tf.boolean_mask(labels, mask) 106 | bboxes = tf.boolean_mask(bboxes, mask) 107 | xs = tf.boolean_mask(xs, mask); 108 | ys = tf.boolean_mask(ys, mask); 109 | return labels, bboxes, xs, ys 110 | 111 | def bboxes_intersection(bbox_ref, bboxes, name=None): 112 | """Compute relative intersection between a reference box and a 113 | collection of bounding boxes. Namely, compute the quotient between 114 | intersection area and box area. 115 | 116 | Args: 117 | bbox_ref: (N, 4) or (4,) Tensor with reference bounding box(es). 118 | bboxes: (N, 4) Tensor, collection of bounding boxes. 119 | Return: 120 | (N,) Tensor with relative intersection. 121 | """ 122 | with tf.name_scope(name, 'bboxes_intersection'): 123 | # Should be more efficient to first transpose. 124 | bboxes = tf.transpose(bboxes) 125 | bbox_ref = tf.transpose(bbox_ref) 126 | # Intersection bbox and volume. 127 | int_ymin = tf.maximum(bboxes[0], bbox_ref[0]) 128 | int_xmin = tf.maximum(bboxes[1], bbox_ref[1]) 129 | int_ymax = tf.minimum(bboxes[2], bbox_ref[2]) 130 | int_xmax = tf.minimum(bboxes[3], bbox_ref[3]) 131 | h = tf.maximum(int_ymax - int_ymin, 0.) 132 | w = tf.maximum(int_xmax - int_xmin, 0.) 133 | # Volumes. 134 | inter_vol = h * w 135 | bboxes_vol = (bboxes[2] - bboxes[0]) * (bboxes[3] - bboxes[1]) 136 | scores = tfe_math.safe_divide(inter_vol, bboxes_vol, 'intersection') 137 | return scores 138 | 139 | 140 | def bboxes_matching(bboxes, gxs, gys, gignored, matching_threshold = 0.5, scope=None): 141 | """Matching a collection of detected boxes with groundtruth values. 142 | Does not accept batched-inputs. 143 | The algorithm goes as follows: for every detected box, check 144 | if one grountruth box is matching. If none, then considered as False Positive. 145 | If the grountruth box is already matched with another one, it also counts 146 | as a False Positive. We refer the Pascal VOC documentation for the details. 147 | 148 | Args: 149 | rbboxes: Nx4 Tensors. Detected objects, sorted by score; 150 | gbboxes: Groundtruth bounding boxes. 
May be zero padded, hence 151 | zero-class objects are ignored. 152 | matching_threshold: Threshold for a positive match. 153 | Return: Tuple of: 154 | n_gbboxes: Scalar Tensor with number of groundtruth boxes (may difer from 155 | size because of zero padding). 156 | tp_match: (N,)-shaped boolean Tensor containing with True Positives. 157 | fp_match: (N,)-shaped boolean Tensor containing with False Positives. 158 | """ 159 | with tf.name_scope(scope, 'bboxes_matching_single',[bboxes, gxs, gys, gignored]): 160 | # Number of groundtruth boxes. 161 | gignored = tf.cast(gignored, dtype = tf.bool) 162 | n_gbboxes = tf.count_nonzero(tf.logical_not(gignored)) 163 | # Grountruth matching arrays. 164 | gmatch = tf.zeros(tf.shape(gignored), dtype=tf.bool) 165 | grange = tf.range(tf.size(gignored), dtype=tf.int32) 166 | 167 | # Number of detected boxes 168 | n_bboxes = tf.shape(bboxes)[0] 169 | rshape = (n_bboxes, ) 170 | # True/False positive matching TensorArrays. 171 | # ta is short for TensorArray 172 | ta_tp_bool = tf.TensorArray(tf.bool, size=n_bboxes, dynamic_size=False, infer_shape=True) 173 | ta_fp_bool = tf.TensorArray(tf.bool, size=n_bboxes, dynamic_size=False, infer_shape=True) 174 | 175 | n_ignored_det = 0 176 | # Loop over returned objects. 177 | def m_condition(i, ta_tp, ta_fp, gmatch, n_ignored_det): 178 | r = tf.less(i, tf.shape(bboxes)[0]) 179 | return r 180 | 181 | def m_body(i, ta_tp, ta_fp, gmatch, n_ignored_det): 182 | # Jaccard score with groundtruth bboxes. 183 | rbbox = bboxes[i, :] 184 | # rbbox = tf.Print(rbbox, [rbbox]) 185 | jaccard = bboxes_jaccard(rbbox, gxs, gys) 186 | 187 | # Best fit, checking it's above threshold. 188 | idxmax = tf.cast(tf.argmax(jaccard, axis=0), dtype = tf.int32) 189 | 190 | jcdmax = jaccard[idxmax] 191 | match = jcdmax > matching_threshold 192 | existing_match = gmatch[idxmax] 193 | not_ignored = tf.logical_not(gignored[idxmax]) 194 | 195 | n_ignored_det = n_ignored_det + tf.cast(gignored[idxmax], tf.int32) 196 | # TP: match & no previous match and FP: previous match | no match. 197 | # If ignored: no record, i.e FP=False and TP=False. 198 | tp = tf.logical_and(not_ignored, tf.logical_and(match, tf.logical_not(existing_match))) 199 | ta_tp = ta_tp.write(i, tp) 200 | 201 | fp = tf.logical_and(not_ignored, tf.logical_or(existing_match, tf.logical_not(match))) 202 | ta_fp = ta_fp.write(i, fp) 203 | 204 | # Update grountruth match. 205 | mask = tf.logical_and(tf.equal(grange, idxmax), tf.logical_and(not_ignored, match)) 206 | gmatch = tf.logical_or(gmatch, mask) 207 | return [i+1, ta_tp, ta_fp, gmatch,n_ignored_det] 208 | # Main loop definition. 209 | i = 0 210 | [i, ta_tp_bool, ta_fp_bool, gmatch, n_ignored_det] = \ 211 | tf.while_loop(m_condition, m_body, 212 | [i, ta_tp_bool, ta_fp_bool, gmatch, n_ignored_det], 213 | parallel_iterations=1, 214 | back_prop=False) 215 | # TensorArrays to Tensors and reshape. 216 | tp_match = tf.reshape(ta_tp_bool.stack(), rshape) 217 | fp_match = tf.reshape(ta_fp_bool.stack(), rshape) 218 | 219 | # Some debugging information... 
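# Uncommenting the tf.Print below logs, per image, the number of ground-truth
# boxes, detections, TPs, FPs, ignored detections and matched ground truths.
# Example of the matching rule above: with 2 ground-truth boxes and 3 detections
# sorted by score, if detections 0 and 1 both overlap ground truth 0 above the
# threshold, only detection 0 counts as a true positive; detection 1 (duplicate
# match) and an unmatched detection 2 are both false positives, giving
# tp_match = [True, False, False] and fp_match = [False, True, True].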
220 | # tp_match = tf.Print(tp_match, 221 | # [n_gbboxes, n_bboxes, 222 | # tf.reduce_sum(tf.cast(tp_match, tf.int64)), 223 | # tf.reduce_sum(tf.cast(fp_match, tf.int64)), 224 | # n_ignored_det, 225 | # tf.reduce_sum(tf.cast(gmatch, tf.int64))], 226 | # 'Matching (NG, ND, TP, FP, n_ignored_det,GM): ') 227 | return n_gbboxes, tp_match, fp_match 228 | 229 | def bboxes_jaccard(bbox, gxs, gys): 230 | jaccard = tf.py_func(np_bboxes_jaccard, [bbox, gxs, gys], tf.float32) 231 | jaccard.set_shape([None, ]) 232 | return jaccard 233 | 234 | def np_bboxes_jaccard(bbox, gxs, gys): 235 | # assert np.shape(bbox) == (8,) 236 | bbox_points = np.reshape(bbox, (4, 2)) 237 | cnts = util.img.points_to_contours(bbox_points) 238 | 239 | # contruct a 0-1 mask to draw contours on 240 | xmax = np.max(bbox_points[:, 0]) 241 | xmax = max(xmax, np.max(gxs)) + 10 242 | ymax = np.max(bbox_points[:, 1]) 243 | ymax = max(ymax, np.max(gys)) + 10 244 | mask = util.img.black((ymax, xmax)) 245 | 246 | # draw bbox on the mask 247 | bbox_mask = mask.copy() 248 | util.img.draw_contours(bbox_mask, cnts, idx = -1, color = 1, border_width = -1) 249 | jaccard = np.zeros((len(gxs),), dtype = np.float32) 250 | # draw ground truth 251 | for gt_idx, gt_bbox in enumerate(zip(gxs, gys)): 252 | gt_mask = mask.copy() 253 | gt_bbox = np.transpose(gt_bbox) 254 | # assert gt_bbox.shape == (4, 2) 255 | gt_cnts = util.img.points_to_contours(gt_bbox) 256 | util.img.draw_contours(gt_mask, gt_cnts, idx = -1, color = 1, border_width = -1) 257 | 258 | intersect = np.sum(bbox_mask * gt_mask) 259 | union = np.sum(bbox_mask + gt_mask >= 1) 260 | # assert intersect == np.sum(bbox_mask * gt_mask) 261 | # assert union == np.sum((bbox_mask + gt_mask) > 0) 262 | iou = intersect * 1.0 / union 263 | jaccard[gt_idx] = iou 264 | return jaccard 265 | -------------------------------------------------------------------------------- /tf_extended/math.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Paul Balanca. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """TF Extended: additional math functions. 16 | """ 17 | import tensorflow as tf 18 | 19 | from tensorflow.python.ops import array_ops 20 | from tensorflow.python.ops import math_ops 21 | from tensorflow.python.framework import dtypes 22 | from tensorflow.python.framework import ops 23 | 24 | 25 | def safe_divide(numerator, denominator, name): 26 | """Divides two values, returning 0 if the denominator is <= 0. 27 | Args: 28 | numerator: A real `Tensor`. 29 | denominator: A real `Tensor`, with dtype matching `numerator`. 30 | name: Name for the returned op. 
31 | Returns: 32 | 0 if `denominator` <= 0, else `numerator` / `denominator` 33 | """ 34 | return tf.where( 35 | math_ops.greater(denominator, 0), 36 | math_ops.divide(numerator, denominator), 37 | tf.zeros_like(numerator), 38 | name=name) 39 | 40 | 41 | -------------------------------------------------------------------------------- /tf_extended/metrics.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.ops import variables 3 | from tensorflow.python.ops import array_ops 4 | from tensorflow.python.framework import ops 5 | from tensorflow.python.ops import state_ops 6 | from tensorflow.python.ops import variable_scope 7 | from tf_extended import math as tfe_math 8 | import util 9 | 10 | def _create_local(name, shape, collections=None, validate_shape=True, 11 | dtype=tf.float32): 12 | """Creates a new local variable. 13 | Args: 14 | name: The name of the new or existing variable. 15 | shape: Shape of the new or existing variable. 16 | collections: A list of collection names to which the Variable will be added. 17 | validate_shape: Whether to validate the shape of the variable. 18 | dtype: Data type of the variables. 19 | Returns: 20 | The created variable. 21 | """ 22 | # Make sure local variables are added to tf.GraphKeys.LOCAL_VARIABLES 23 | collections = list(collections or []) 24 | collections += [ops.GraphKeys.LOCAL_VARIABLES] 25 | return variables.Variable( 26 | initial_value=array_ops.zeros(shape, dtype=dtype), 27 | name=name, 28 | trainable=False, 29 | collections=collections, 30 | validate_shape=validate_shape) 31 | 32 | def streaming_tp_fp_arrays(num_gbboxes, tp, fp, 33 | metrics_collections=None, 34 | updates_collections=None, 35 | name=None): 36 | """Streaming computation of True and False Positive arrays. 37 | """ 38 | with variable_scope.variable_scope(name, 'streaming_tp_fp', 39 | [num_gbboxes, tp, fp]): 40 | num_gbboxes = tf.cast(num_gbboxes, tf.int32) 41 | tp = tf.cast(tp, tf.bool) 42 | fp = tf.cast(fp, tf.bool) 43 | # Reshape TP and FP tensors and clean away 0 class values. 44 | tp = tf.reshape(tp, [-1]) 45 | fp = tf.reshape(fp, [-1]) 46 | 47 | # Local variables accumlating information over batches. 48 | v_num_objects = _create_local('v_num_gbboxes', shape=[], dtype=tf.int32) 49 | v_tp = _create_local('v_tp', shape=[0, ], dtype=tf.bool) 50 | v_fp = _create_local('v_fp', shape=[0, ], dtype=tf.bool) 51 | 52 | 53 | # Update operations. 54 | num_objects_op = state_ops.assign_add(v_num_objects, 55 | tf.reduce_sum(num_gbboxes)) 56 | tp_op = state_ops.assign(v_tp, tf.concat([v_tp, tp], axis=0), 57 | validate_shape=False) 58 | fp_op = state_ops.assign(v_fp, tf.concat([v_fp, fp], axis=0), 59 | validate_shape=False) 60 | 61 | # Value and update ops. 62 | val = (v_num_objects, v_tp, v_fp) 63 | with ops.control_dependencies([num_objects_op, tp_op, fp_op]): 64 | update_op = (num_objects_op, tp_op, fp_op) 65 | 66 | return val, update_op 67 | 68 | 69 | def precision_recall(num_gbboxes, tp, fp, scope=None): 70 | """Compute precision and recall from true positives and false 71 | positives booleans arrays 72 | """ 73 | 74 | # Sort by score. 75 | with tf.name_scope(scope, 'precision_recall'): 76 | # Computer recall and precision. 
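# The accumulated boolean arrays are reduced to counts here:
#   precision = TP / (TP + FP),   recall = TP / num_gbboxes
# safe_divide() returns 0 whenever the denominator is 0 (e.g. no detections or
# no ground truth yet), and fmean() below combines the two as 2*P*R / (P + R)
# with the same zero guard.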
77 | tp = tf.reduce_sum(tf.cast(tp, tf.float32), axis=0) 78 | fp = tf.reduce_sum(tf.cast(fp, tf.float32), axis=0) 79 | recall = tfe_math.safe_divide(tp, tf.cast(num_gbboxes, tf.float32), 'recall') 80 | precision = tfe_math.safe_divide(tp, tp + fp, 'precision') 81 | return tf.tuple([precision, recall]) 82 | 83 | def fmean(pre, rec): 84 | """Compute f-mean with precision and recall 85 | """ 86 | def zero(): 87 | return tf.zeros([]) 88 | def not_zero(): 89 | return 2 * pre * rec / (pre + rec) 90 | 91 | return tf.cond(pre + rec > 0, not_zero, zero) 92 | -------------------------------------------------------------------------------- /train_pixel_link.py: -------------------------------------------------------------------------------- 1 | #test code to make sure the ground truth calculation and data batch works well. 2 | 3 | import numpy as np 4 | import tensorflow as tf # test 5 | from tensorflow.python.ops import control_flow_ops 6 | 7 | from datasets import dataset_factory 8 | 9 | from nets import pixel_link_symbol 10 | import util 11 | import pixel_link 12 | 13 | slim = tf.contrib.slim 14 | import config 15 | # =========================================================================== # 16 | # Checkpoint and running Flags 17 | # =========================================================================== # 18 | tf.app.flags.DEFINE_string('train_dir', None, 19 | 'the path to store checkpoints and eventfiles for summaries') 20 | 21 | tf.app.flags.DEFINE_string('checkpoint_path', None, 22 | 'the path of pretrained model to be used. If there are checkpoints in train_dir, this config will be ignored.') 23 | 24 | tf.app.flags.DEFINE_float('gpu_memory_fraction', -1, 25 | 'the gpu memory fraction to be used. If less than 0, allow_growth = True is used.') 26 | 27 | tf.app.flags.DEFINE_integer('batch_size', None, 'The number of samples in each batch.') 28 | tf.app.flags.DEFINE_integer('num_gpus', 1, 'The number of gpus can be used.') 29 | tf.app.flags.DEFINE_integer('max_number_of_steps', 1000000, 'The maximum number of training steps.') 30 | tf.app.flags.DEFINE_integer('log_every_n_steps', 1, 'log frequency') 31 | tf.app.flags.DEFINE_bool("ignore_missing_vars", False, '') 32 | tf.app.flags.DEFINE_string('checkpoint_exclude_scopes', None, 'checkpoint_exclude_scopes') 33 | 34 | # =========================================================================== # 35 | # Optimizer configs. 36 | # =========================================================================== # 37 | tf.app.flags.DEFINE_float('learning_rate', 0.001, 'learning rate.') 38 | tf.app.flags.DEFINE_float('momentum', 0.9, 'The momentum for the MomentumOptimizer') 39 | tf.app.flags.DEFINE_float('weight_decay', 0.0001, 'The weight decay on the model weights.') 40 | tf.app.flags.DEFINE_bool('using_moving_average', True, 'Whether to use ExponentionalMovingAverage') 41 | tf.app.flags.DEFINE_float('moving_average_decay', 0.9999, 'The decay rate of ExponentionalMovingAverage') 42 | 43 | # =========================================================================== # 44 | # I/O and preprocessing Flags. 
45 | # =========================================================================== # 46 | tf.app.flags.DEFINE_integer( 47 | 'num_readers', 1, 48 | 'The number of parallel readers that read data from the dataset.') 49 | tf.app.flags.DEFINE_integer( 50 | 'num_preprocessing_threads', 24, 51 | 'The number of threads used to create the batches.') 52 | 53 | # =========================================================================== # 54 | # Dataset Flags. 55 | # =========================================================================== # 56 | tf.app.flags.DEFINE_string( 57 | 'dataset_name', None, 'The name of the dataset to load.') 58 | tf.app.flags.DEFINE_string( 59 | 'dataset_split_name', 'train', 'The name of the train/test split.') 60 | tf.app.flags.DEFINE_string( 61 | 'dataset_dir', None, 'The directory where the dataset files are stored.') 62 | tf.app.flags.DEFINE_integer('train_image_width', 512, 'Train image size') 63 | tf.app.flags.DEFINE_integer('train_image_height', 512, 'Train image size') 64 | 65 | 66 | FLAGS = tf.app.flags.FLAGS 67 | def config_initialization(): 68 | # image shape and feature layers shape inference 69 | image_shape = (FLAGS.train_image_height, FLAGS.train_image_width) 70 | 71 | if not FLAGS.dataset_dir: 72 | raise ValueError('You must supply the dataset directory with --dataset_dir') 73 | 74 | tf.logging.set_verbosity(tf.logging.DEBUG) 75 | util.init_logger( 76 | log_file = 'log_train_pixel_link_%d_%d.log'%image_shape, 77 | log_path = FLAGS.train_dir, stdout = False, mode = 'a') 78 | 79 | 80 | config.load_config(FLAGS.train_dir) 81 | 82 | config.init_config(image_shape, 83 | batch_size = FLAGS.batch_size, 84 | weight_decay = FLAGS.weight_decay, 85 | num_gpus = FLAGS.num_gpus 86 | ) 87 | 88 | batch_size = config.batch_size 89 | batch_size_per_gpu = config.batch_size_per_gpu 90 | 91 | tf.summary.scalar('batch_size', batch_size) 92 | tf.summary.scalar('batch_size_per_gpu', batch_size_per_gpu) 93 | 94 | util.proc.set_proc_name('train_pixel_link_on'+ '_' + FLAGS.dataset_name) 95 | 96 | dataset = dataset_factory.get_dataset(FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir) 97 | config.print_config(FLAGS, dataset) 98 | return dataset 99 | 100 | def create_dataset_batch_queue(dataset): 101 | from preprocessing import ssd_vgg_preprocessing 102 | 103 | with tf.device('/cpu:0'): 104 | with tf.name_scope(FLAGS.dataset_name + '_data_provider'): 105 | provider = slim.dataset_data_provider.DatasetDataProvider( 106 | dataset, 107 | num_readers=FLAGS.num_readers, 108 | common_queue_capacity=1000 * config.batch_size, 109 | common_queue_min=700 * config.batch_size, 110 | shuffle=True) 111 | # Get for SSD network: image, labels, bboxes. 112 | [image, glabel, gbboxes, x1, x2, x3, x4, y1, y2, y3, y4] = provider.get([ 113 | 'image', 114 | 'object/label', 115 | 'object/bbox', 116 | 'object/oriented_bbox/x1', 117 | 'object/oriented_bbox/x2', 118 | 'object/oriented_bbox/x3', 119 | 'object/oriented_bbox/x4', 120 | 'object/oriented_bbox/y1', 121 | 'object/oriented_bbox/y2', 122 | 'object/oriented_bbox/y3', 123 | 'object/oriented_bbox/y4' 124 | ]) 125 | gxs = tf.transpose(tf.stack([x1, x2, x3, x4])) #shape = (N, 4) 126 | gys = tf.transpose(tf.stack([y1, y2, y3, y4])) 127 | image = tf.identity(image, 'input_image') 128 | 129 | # Pre-processing image, labels and bboxes. 
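# preprocess_image() applies the SSD-style augmentation used by PixelLink
# (random cropping, resizing to config.train_image_shape, 512x512 by default,
# and optional rotation), keeping the boxes and the oriented-bbox corner
# coordinates gxs/gys consistent with the transformed image.
# tf_cal_gt_for_single_image() then turns those polygons into the dense
# per-pixel text/non-text labels, link labels and loss weights.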
130 | image, glabel, gbboxes, gxs, gys = \ 131 | ssd_vgg_preprocessing.preprocess_image( 132 | image, glabel, gbboxes, gxs, gys, 133 | out_shape = config.train_image_shape, 134 | data_format = config.data_format, 135 | use_rotation = config.use_rotation, 136 | is_training = True) 137 | image = tf.identity(image, 'processed_image') 138 | 139 | # calculate ground truth 140 | pixel_cls_label, pixel_cls_weight, \ 141 | pixel_link_label, pixel_link_weight = \ 142 | pixel_link.tf_cal_gt_for_single_image(gxs, gys, glabel) 143 | 144 | # batch them 145 | with tf.name_scope(FLAGS.dataset_name + '_batch'): 146 | b_image, b_pixel_cls_label, b_pixel_cls_weight, \ 147 | b_pixel_link_label, b_pixel_link_weight = \ 148 | tf.train.batch( 149 | [image, pixel_cls_label, pixel_cls_weight, 150 | pixel_link_label, pixel_link_weight], 151 | batch_size = config.batch_size_per_gpu, 152 | num_threads= FLAGS.num_preprocessing_threads, 153 | capacity = 500) 154 | with tf.name_scope(FLAGS.dataset_name + '_prefetch_queue'): 155 | batch_queue = slim.prefetch_queue.prefetch_queue( 156 | [b_image, b_pixel_cls_label, b_pixel_cls_weight, 157 | b_pixel_link_label, b_pixel_link_weight], 158 | capacity = 50) 159 | return batch_queue 160 | 161 | def sum_gradients(clone_grads): 162 | averaged_grads = [] 163 | for grad_and_vars in zip(*clone_grads): 164 | grads = [] 165 | var = grad_and_vars[0][1] 166 | try: 167 | for g, v in grad_and_vars: 168 | assert v == var 169 | grads.append(g) 170 | grad = tf.add_n(grads, name = v.op.name + '_summed_gradients') 171 | except: 172 | import pdb 173 | pdb.set_trace() 174 | 175 | averaged_grads.append((grad, v)) 176 | 177 | # tf.summary.histogram("variables_and_gradients_" + grad.op.name, grad) 178 | # tf.summary.histogram("variables_and_gradients_" + v.op.name, v) 179 | # tf.summary.scalar("variables_and_gradients_" + grad.op.name+\ 180 | # '_mean/var_mean', tf.reduce_mean(grad)/tf.reduce_mean(var)) 181 | # tf.summary.scalar("variables_and_gradients_" + v.op.name+'_mean',tf.reduce_mean(var)) 182 | return averaged_grads 183 | 184 | 185 | def create_clones(batch_queue): 186 | with tf.device('/cpu:0'): 187 | global_step = slim.create_global_step() 188 | learning_rate = tf.constant(FLAGS.learning_rate, name='learning_rate') 189 | optimizer = tf.train.MomentumOptimizer(learning_rate, 190 | momentum=FLAGS.momentum, name='Momentum') 191 | 192 | tf.summary.scalar('learning_rate', learning_rate) 193 | # place clones 194 | pixel_link_loss = 0; # for summary only 195 | gradients = [] 196 | for clone_idx, gpu in enumerate(config.gpus): 197 | do_summary = clone_idx == 0 # only summary on the first clone 198 | reuse = clone_idx > 0 199 | with tf.variable_scope(tf.get_variable_scope(), reuse = reuse): 200 | with tf.name_scope(config.clone_scopes[clone_idx]) as clone_scope: 201 | with tf.device(gpu) as clone_device: 202 | b_image, b_pixel_cls_label, b_pixel_cls_weight, \ 203 | b_pixel_link_label, b_pixel_link_weight = batch_queue.dequeue() 204 | # build model and loss 205 | net = pixel_link_symbol.PixelLinkNet(b_image, is_training = True) 206 | net.build_loss( 207 | pixel_cls_labels = b_pixel_cls_label, 208 | pixel_cls_weights = b_pixel_cls_weight, 209 | pixel_link_labels = b_pixel_link_label, 210 | pixel_link_weights = b_pixel_link_weight, 211 | do_summary = do_summary) 212 | 213 | # gather losses 214 | losses = tf.get_collection(tf.GraphKeys.LOSSES, clone_scope) 215 | assert len(losses) == 2 216 | total_clone_loss = tf.add_n(losses) / config.num_clones 217 | pixel_link_loss += total_clone_loss 218 | 219 | # 
gather regularization loss and add to clone_0 only 220 | if clone_idx == 0: 221 | regularization_loss = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) 222 | total_clone_loss = total_clone_loss + regularization_loss 223 | 224 | # compute clone gradients 225 | clone_gradients = optimizer.compute_gradients(total_clone_loss) 226 | gradients.append(clone_gradients) 227 | 228 | tf.summary.scalar('pixel_link_loss', pixel_link_loss) 229 | tf.summary.scalar('regularization_loss', regularization_loss) 230 | 231 | # add all gradients together 232 | # note that the gradients do not need to be averaged, because the average operation has been done on loss. 233 | averaged_gradients = sum_gradients(gradients) 234 | 235 | apply_grad_op = optimizer.apply_gradients(averaged_gradients, global_step=global_step) 236 | 237 | train_ops = [apply_grad_op] 238 | 239 | bn_update_op = util.tf.get_update_op() 240 | if bn_update_op is not None: 241 | train_ops.append(bn_update_op) 242 | 243 | # moving average 244 | if FLAGS.using_moving_average: 245 | tf.logging.info('using moving average in training, \ 246 | with decay = %f'%(FLAGS.moving_average_decay)) 247 | ema = tf.train.ExponentialMovingAverage(FLAGS.moving_average_decay) 248 | ema_op = ema.apply(tf.trainable_variables()) 249 | with tf.control_dependencies([apply_grad_op]):# ema after updating 250 | train_ops.append(tf.group(ema_op)) 251 | 252 | train_op = control_flow_ops.with_dependencies(train_ops, pixel_link_loss, name='train_op') 253 | return train_op 254 | 255 | 256 | 257 | def train(train_op): 258 | summary_op = tf.summary.merge_all() 259 | sess_config = tf.ConfigProto(log_device_placement = False, allow_soft_placement = True) 260 | if FLAGS.gpu_memory_fraction < 0: 261 | sess_config.gpu_options.allow_growth = True 262 | elif FLAGS.gpu_memory_fraction > 0: 263 | sess_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_memory_fraction; 264 | 265 | init_fn = util.tf.get_init_fn(checkpoint_path = FLAGS.checkpoint_path, train_dir = FLAGS.train_dir, 266 | ignore_missing_vars = FLAGS.ignore_missing_vars, checkpoint_exclude_scopes = FLAGS.checkpoint_exclude_scopes) 267 | saver = tf.train.Saver(max_to_keep = 500, write_version = 2) 268 | slim.learning.train( 269 | train_op, 270 | logdir = FLAGS.train_dir, 271 | init_fn = init_fn, 272 | summary_op = summary_op, 273 | number_of_steps = FLAGS.max_number_of_steps, 274 | log_every_n_steps = FLAGS.log_every_n_steps, 275 | save_summaries_secs = 30, 276 | saver = saver, 277 | save_interval_secs = 1200, 278 | session_config = sess_config 279 | ) 280 | 281 | 282 | def main(_): 283 | # The choice of return dataset object via initialization method maybe confusing, 284 | # but I need to print all configurations in this method, including dataset information. 
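# Overall flow: load config and dataset, build the CPU-side input and
# ground-truth batch queue, build one model clone per GPU and sum their
# gradients, then hand the resulting train_op to slim.learning.train(), which
# takes care of checkpointing, summaries and the training loop.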
285 | dataset = config_initialization() 286 | 287 | batch_queue = create_dataset_batch_queue(dataset) 288 | train_op = create_clones(batch_queue) 289 | train(train_op) 290 | 291 | 292 | if __name__ == '__main__': 293 | tf.app.run() 294 | -------------------------------------------------------------------------------- /visualize_detection_result.py: -------------------------------------------------------------------------------- 1 | #encoding utf-8 2 | 3 | import numpy as np 4 | import util 5 | 6 | 7 | def draw_bbox(image_data, line, color): 8 | line = util.str.remove_all(line, '\xef\xbb\xbf') 9 | data = line.split(','); 10 | points = [int(v) for v in data[0:8]] 11 | points = np.reshape(points, (4, 2)) 12 | cnts = util.img.points_to_contours(points) 13 | util.img.draw_contours(image_data, cnts, -1, color = color, border_width = 3) 14 | 15 | 16 | def visualize(image_root, det_root, output_root, gt_root = None): 17 | def read_gt_file(image_name): 18 | gt_file = util.io.join_path(gt_root, 'gt_%s.txt'%(image_name)) 19 | return util.io.read_lines(gt_file) 20 | 21 | def read_det_file(image_name): 22 | det_file = util.io.join_path(det_root, 'res_%s.txt'%(image_name)) 23 | return util.io.read_lines(det_file) 24 | 25 | def read_image_file(image_name): 26 | return util.img.imread(util.io.join_path(image_root, image_name)) 27 | 28 | image_names = util.io.ls(image_root, '.jpg') 29 | for image_idx, image_name in enumerate(image_names): 30 | print '%d / %d: %s'%(image_idx + 1, len(image_names), image_name) 31 | image_data = read_image_file(image_name) # in BGR 32 | image_name = image_name.split('.')[0] 33 | det_image = image_data.copy() 34 | det_lines = read_det_file(image_name) 35 | for line in det_lines: 36 | draw_bbox(det_image, line, color = util.img.COLOR_GREEN) 37 | output_path = util.io.join_path(output_root, '%s_pred.jpg'%(image_name)) 38 | util.img.imwrite(output_path, det_image) 39 | print "Detection result has been written to ", util.io.get_absolute_path(output_path) 40 | 41 | if gt_root is not None: 42 | gt_lines = read_gt_file(image_name) 43 | for line in gt_lines: 44 | draw_bbox(image_data, line, color = util.img.COLOR_GREEN) 45 | util.img.imwrite(util.io.join_path(output_root, '%s_gt.jpg'%(image_name)), image_data) 46 | 47 | if __name__ == '__main__': 48 | import argparse 49 | parser = argparse.ArgumentParser(description='visualize detection result of pixel_link') 50 | parser.add_argument('--image', type=str, required = True,help='the directory of test image') 51 | parser.add_argument('--gt', type=str, default=None,help='the directory of ground truth txt files') 52 | parser.add_argument('--det', type=str, required = True, help='the directory of detection result') 53 | parser.add_argument('--output', type=str, required = True, help='the directory to store images with bboxes') 54 | 55 | args = parser.parse_args() 56 | print('**************Arguments*****************') 57 | print(args) 58 | print('****************************************') 59 | visualize(image_root = args.image, gt_root = args.gt, det_root = args.det, output_root = args.output) 60 | --------------------------------------------------------------------------------
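For environments where the `pylib` submodule (the `util` package) is not available, the detection files written by `test_pixel_link.py` (one `res_<image>.txt` per image, each line holding `x1,y1,x2,y2,x3,y3,x4,y4`) can also be drawn with plain OpenCV. The snippet below is only an illustrative sketch, not part of the repo; the file names in `__main__` are placeholders:
```
import cv2
import numpy as np

def draw_detections(image_path, det_txt_path, output_path):
    """Draw the quadrilaterals from a res_*.txt file onto the image."""
    image = cv2.imread(image_path)  # BGR, as in the repo's visualization code
    with open(det_txt_path) as f:
        for line in f:
            line = line.strip().replace('\xef\xbb\xbf', '')  # drop BOM, if any
            if not line:
                continue
            values = [int(v) for v in line.split(',')[:8]]
            points = np.array(values, dtype=np.int32).reshape(4, 2)
            cv2.polylines(image, [points], isClosed=True,
                          color=(0, 255, 0), thickness=3)
    cv2.imwrite(output_path, image)

if __name__ == '__main__':
    # placeholder paths; point them at a test image and its detection file
    draw_detections('img_1.jpg', 'res_img_1.txt', 'img_1_pred.jpg')
```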