├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── Test Run After Training.ipynb ├── Visualization & Train Test Split.ipynb ├── assets ├── example_input.jpg ├── example_output.jpg ├── output.gif ├── processed.mp4 ├── result1.png ├── result2.png ├── result3.png ├── result4.png ├── result5.png ├── result6.png └── result7.png ├── tests └── utils │ └── data_test.py ├── train.py └── utils ├── __init__.py ├── data.py └── image.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Data Files 2 | mask/ 3 | resize/ 4 | 5 | 6 | ###Python### 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | env/ 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | 57 | # Sphinx documentation 58 | docs/_build/ 59 | 60 | # PyBuilder 61 | target/ 62 | 63 | 64 | ###IPythonNotebook### 65 | 66 | # Temporary data 67 | .ipynb_checkpoints/ 68 | 69 | 70 | ###PyCharm### 71 | 72 | # PyCharm 73 | # http://www.jetbrains.com/pycharm/webhelp/project.html 74 | .idea 75 | .iml 76 | 77 | 78 | ###OSX### 79 | 80 | .DS_Store 81 | .AppleDouble 82 | .LSOverride 83 | 84 | # Icon must end with two \r 85 | Icon 86 | 87 | # Thumbnails 88 | ._* 89 | 90 | # Files that might appear on external disk 91 | .Spotlight-V100 92 | .Trashes 93 | 94 | # Directories potentially created on remote AFP share 95 | .AppleDB 96 | .AppleDesktop 97 | Network Trash Folder 98 | Temporary Items 99 | .apdisk 100 | 101 | 102 | ###Linux### 103 | 104 | *~ 105 | 106 | # KDE directory preferences 107 | .directory 108 | 109 | 110 | ###Windows### 111 | 112 | # Windows image file caches 113 | Thumbs.db 114 | ehthumbs.db 115 | 116 | # Folder config file 117 | Desktop.ini 118 | 119 | # Recycle Bin used on file shares 120 | $RECYCLE.BIN/ 121 | 122 | # Windows Installer files 123 | *.cab 124 | *.msi 125 | *.msm 126 | *.msp 127 | 128 | # Windows shortcuts 129 | *.lnk 130 | .mypy_cache 131 | /.projectile 132 | data 133 | data_resuize 134 | data_resize 135 | /labels_resized.csv 136 | object-detection-crowdai/ 137 | /udacity-annoations-crowdai 138 | logdir 139 | /train.csv 140 | /test.csv 141 | models 142 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Kyung Mo Kweon 4 | 5 | Permission is hereby granted, 
free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: download generate fresh clean cleaner 2 | 3 | # Download raw datasets (by Udacity) 4 | download: SHELL:=/bin/bash 5 | download: 6 | @if [[ -f "data/object-detection-crowdai.tar.gz" ]]; then \ 7 | echo "Data exists"; \ 8 | if [[ ! -d "object-detection-crowdai" ]]; then \ 9 | tar xvf data/object-detection-crowdai.tar.gz; \ 10 | fi; \ 11 | else \ 12 | mkdir -p data; \ 13 | wget -O data/object-detection-crowdai.tar.gz "https://s3.amazonaws.com/udacity-sdc/annotations/object-detection-crowdai.tar.gz"; \ 14 | tar xvf data/object-detection-crowdai.tar.gz; \ 15 | fi; \ 16 | if [[ ! 
-f "data/labels_crowdai.csv" ]]; then \ 17 | wget -O data/labels_crowdai.csv "https://raw.githubusercontent.com/udacity/self-driving-car/master/annotations/labels_crowdai.csv"; \ 18 | fi 19 | 20 | # Generate training images 21 | generate: 22 | python utils/data.py 23 | 24 | # Fresh Training 25 | fresh: 26 | rm -rf logdir models 27 | 28 | # Run tensorboard 29 | tensorboard: 30 | tensorboard --logdir logdir 31 | 32 | # Remove augmented data 33 | clean: 34 | rm -rf data_resize mask 35 | 36 | # Remove raw original data 37 | cleaner: clean 38 | rm -rf data 39 | rm -rf object-detection-crowdai 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # U-Net Implementation in TensorFlow 2 | 3 | 4 | 5 | Re implementation of U-Net in Tensorflow 6 | - to check how image segmentations can be used for detection problems 7 | 8 | Original Paper 9 | - [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597) 10 | 11 | ## Summary 12 | 13 | Vehicle Detection using U-Net 14 | 15 | Objective: detect vehicles 16 | Find a function f such that y = f(X) 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 |
| Input | Shape | Explanation | Example |
| --- | --- | --- | --- |
| X: 3-D Tensor | (640, 960, 3) | RGB image in an array | |
| y: 3-D Tensor | (640, 960, 1) | Binarized image. Background is 0, vehicle is masked as 255 | |
 37 | 38 | Loss function: maximize IOU 39 | ``` 40 | (intersection of prediction & ground truth) 41 | ------------------------------------------- 42 | (union of prediction & ground truth) 43 | ``` 44 | 45 | ### Examples on Test Data: trained for 3 epochs 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | ## Get Started 56 | 57 | ### Download dataset 58 | 59 | - the annotated driving dataset is provided by [Udacity](https://github.com/udacity/self-driving-car/tree/master/annotations) 60 | - In total, 9,423 frames with 65,000 labels at 1920x1200 resolution. 61 | 62 | ```bash 63 | make download 64 | ``` 65 | 66 | ### Resize image and generate mask images 67 | 68 | - [utils/data.py](./utils/data.py) is used to resize images and generate masks 69 | 70 | ```bash 71 | make generate 72 | ``` 73 | 74 | ### Train Test Split 75 | 76 | Make sure the masks match the bounding boxes, then create the train/test split 77 | 78 | ```bash 79 | jupyter notebook "Visualization & Train Test Split.ipynb" 80 | ``` 81 | 82 | ### Train 83 | 84 | ```bash 85 | # Train for 1 epoch 86 | python train.py 87 | ``` 88 | 89 | or 90 | 91 | ```bash 92 | $ python train.py --help 93 | usage: train.py [-h] [--epochs EPOCHS] [--batch-size BATCH_SIZE] 94 | [--logdir LOGDIR] [--reg REG] [--ckdir CKDIR] 95 | 96 | optional arguments: 97 | -h, --help show this help message and exit 98 | --epochs EPOCHS Number of epochs (default: 1) 99 | --batch-size BATCH_SIZE 100 | Batch size (default: 4) 101 | --logdir LOGDIR Tensorboard log directory (default: logdir) 102 | --reg REG L2 Regularizer Term (default: 0.1) 103 | --ckdir CKDIR Checkpoint directory (default: models) 104 | ``` 105 | 106 | ### Test 107 | 108 | - Open the Jupyter notebook file to run against test data 109 | 110 | ```bash 111 | jupyter notebook "./Test Run After Training.ipynb" 112 | ``` 113 | -------------------------------------------------------------------------------- /assets/example_input.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kkweon/UNet-in-Tensorflow/ef36bfbd7398d4ffd02bb7a9329715f4419aa54f/assets/example_input.jpg -------------------------------------------------------------------------------- /assets/example_output.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kkweon/UNet-in-Tensorflow/ef36bfbd7398d4ffd02bb7a9329715f4419aa54f/assets/example_output.jpg -------------------------------------------------------------------------------- /assets/output.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kkweon/UNet-in-Tensorflow/ef36bfbd7398d4ffd02bb7a9329715f4419aa54f/assets/output.gif -------------------------------------------------------------------------------- /assets/processed.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kkweon/UNet-in-Tensorflow/ef36bfbd7398d4ffd02bb7a9329715f4419aa54f/assets/processed.mp4 -------------------------------------------------------------------------------- /assets/result1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kkweon/UNet-in-Tensorflow/ef36bfbd7398d4ffd02bb7a9329715f4419aa54f/assets/result1.png -------------------------------------------------------------------------------- /assets/result2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kkweon/UNet-in-Tensorflow/ef36bfbd7398d4ffd02bb7a9329715f4419aa54f/assets/result2.png -------------------------------------------------------------------------------- /assets/result3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kkweon/UNet-in-Tensorflow/ef36bfbd7398d4ffd02bb7a9329715f4419aa54f/assets/result3.png 
-------------------------------------------------------------------------------- /assets/result4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kkweon/UNet-in-Tensorflow/ef36bfbd7398d4ffd02bb7a9329715f4419aa54f/assets/result4.png -------------------------------------------------------------------------------- /assets/result5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kkweon/UNet-in-Tensorflow/ef36bfbd7398d4ffd02bb7a9329715f4419aa54f/assets/result5.png -------------------------------------------------------------------------------- /assets/result6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kkweon/UNet-in-Tensorflow/ef36bfbd7398d4ffd02bb7a9329715f4419aa54f/assets/result6.png -------------------------------------------------------------------------------- /assets/result7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kkweon/UNet-in-Tensorflow/ef36bfbd7398d4ffd02bb7a9329715f4419aa54f/assets/result7.png -------------------------------------------------------------------------------- /tests/utils/data_test.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pytest 3 | 4 | 5 | @pytest.fixture 6 | def original(): 7 | original = pd.read_csv("data/labels_crowdai.csv") 8 | original = original[original["Label"].isin(["Car", "Truck"])] 9 | return original.reset_index(drop=True) 10 | 11 | 12 | @pytest.fixture 13 | def resized(): 14 | return pd.read_csv("labels_resized.csv") 15 | 16 | 17 | def test_shape_of_labels(original: pd.DataFrame, 18 | resized: pd.DataFrame) -> None: 19 | """Check the shape of CSV Files""" 20 | 21 | shape_original = original.shape 22 | shape_resized = resized.shape 23 | 24 | # Same Row 25 | 
assert shape_original[0] == shape_resized[0] 26 | 27 | # Mask column was added to the original 28 | assert shape_original[1] + 1 == shape_resized[1] 29 | 30 | 31 | 32 | def test_bbox_is_smaller(original, resized) -> None: 33 | """resized bbox should be always smaller than original""" 34 | 35 | assert (resized["xmin"] > original["xmin"]).sum() == 0 36 | assert (resized["xmax"] > original["xmax"]).sum() == 0 37 | assert (resized["ymin"] > original["ymin"]).sum() == 0 38 | assert (resized["ymax"] > original["ymax"]).sum() == 0 39 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple U-Net implementation in TensorFlow 3 | 4 | Objective: detect vehicles 5 | 6 | y = f(X) 7 | 8 | X: image (640, 960, 3) 9 | y: mask (640, 960, 1) 10 | - binary image 11 | - background is masked 0 12 | - vehicle is masked 255 13 | 14 | Loss function: maximize IOU 15 | 16 | (intersection of prediction & grount truth) 17 | ------------------------------- 18 | (union of prediction & ground truth) 19 | 20 | Notes: 21 | In the paper, the pixel-wise softmax was used. 
22 | But, I used the IOU because the datasets I used are 23 | not labeled for segmentations 24 | 25 | Original Paper: 26 | https://arxiv.org/abs/1505.04597 27 | """ 28 | import time 29 | import os 30 | import pandas as pd 31 | import tensorflow as tf 32 | 33 | 34 | def image_augmentation(image, mask): 35 | """Returns (maybe) augmented images 36 | 37 | (1) Random flip (left <--> right) 38 | (2) Random flip (up <--> down) 39 | (3) Random brightness 40 | (4) Random hue 41 | 42 | Args: 43 | image (3-D Tensor): Image tensor of (H, W, C) 44 | mask (3-D Tensor): Mask image tensor of (H, W, 1) 45 | 46 | Returns: 47 | image: Maybe augmented image (same shape as input `image`) 48 | mask: Maybe augmented mask (same shape as input `mask`) 49 | """ 50 | concat_image = tf.concat([image, mask], axis=-1) 51 | 52 | maybe_flipped = tf.image.random_flip_left_right(concat_image) 53 | maybe_flipped = tf.image.random_flip_up_down(concat_image) 54 | 55 | image = maybe_flipped[:, :, :-1] 56 | mask = maybe_flipped[:, :, -1:] 57 | 58 | image = tf.image.random_brightness(image, 0.7) 59 | image = tf.image.random_hue(image, 0.3) 60 | 61 | return image, mask 62 | 63 | 64 | def get_image_mask(queue, augmentation=True): 65 | """Returns `image` and `mask` 66 | 67 | Input pipeline: 68 | Queue -> CSV -> FileRead -> Decode JPEG 69 | 70 | (1) Queue contains a CSV filename 71 | (2) Text Reader opens the CSV 72 | CSV file contains two columns 73 | ["path/to/image.jpg", "path/to/mask.jpg"] 74 | (3) File Reader opens both files 75 | (4) Decode JPEG to tensors 76 | 77 | Notes: 78 | height, width = 640, 960 79 | 80 | Returns 81 | image (3-D Tensor): (640, 960, 3) 82 | mask (3-D Tensor): (640, 960, 1) 83 | """ 84 | text_reader = tf.TextLineReader(skip_header_lines=1) 85 | _, csv_content = text_reader.read(queue) 86 | 87 | image_path, mask_path = tf.decode_csv( 88 | csv_content, record_defaults=[[""], [""]]) 89 | 90 | image_file = tf.read_file(image_path) 91 | mask_file = tf.read_file(mask_path) 92 | 93 | 
image = tf.image.decode_jpeg(image_file, channels=3) 94 | image.set_shape([640, 960, 3]) 95 | image = tf.cast(image, tf.float32) 96 | 97 | mask = tf.image.decode_jpeg(mask_file, channels=1) 98 | mask.set_shape([640, 960, 1]) 99 | mask = tf.cast(mask, tf.float32) 100 | mask = mask / (tf.reduce_max(mask) + 1e-7) 101 | 102 | if augmentation: 103 | image, mask = image_augmentation(image, mask) 104 | 105 | return image, mask 106 | 107 | 108 | def conv_conv_pool(input_, 109 | n_filters, 110 | training, 111 | flags, 112 | name, 113 | pool=True, 114 | activation=tf.nn.relu): 115 | """{Conv -> BN -> RELU}x2 -> {Pool, optional} 116 | 117 | Args: 118 | input_ (4-D Tensor): (batch_size, H, W, C) 119 | n_filters (list): number of filters [int, int] 120 | training (1-D Tensor): Boolean Tensor 121 | name (str): name postfix 122 | pool (bool): If True, MaxPool2D 123 | activation: Activaion functions 124 | 125 | Returns: 126 | net: output of the Convolution operations 127 | pool (optional): output of the max pooling operations 128 | """ 129 | net = input_ 130 | 131 | with tf.variable_scope("layer{}".format(name)): 132 | for i, F in enumerate(n_filters): 133 | net = tf.layers.conv2d( 134 | net, 135 | F, (3, 3), 136 | activation=None, 137 | padding='same', 138 | kernel_regularizer=tf.contrib.layers.l2_regularizer(flags.reg), 139 | name="conv_{}".format(i + 1)) 140 | net = tf.layers.batch_normalization( 141 | net, training=training, name="bn_{}".format(i + 1)) 142 | net = activation(net, name="relu{}_{}".format(name, i + 1)) 143 | 144 | if pool is False: 145 | return net 146 | 147 | pool = tf.layers.max_pooling2d( 148 | net, (2, 2), strides=(2, 2), name="pool_{}".format(name)) 149 | 150 | return net, pool 151 | 152 | 153 | def upconv_concat(inputA, input_B, n_filter, flags, name): 154 | """Upsample `inputA` and concat with `input_B` 155 | 156 | Args: 157 | input_A (4-D Tensor): (N, H, W, C) 158 | input_B (4-D Tensor): (N, 2*H, 2*H, C2) 159 | name (str): name of the concat operation 
160 | 161 | Returns: 162 | output (4-D Tensor): (N, 2*H, 2*W, C + C2) 163 | """ 164 | up_conv = upconv_2D(inputA, n_filter, flags, name) 165 | 166 | return tf.concat( 167 | [up_conv, input_B], axis=-1, name="concat_{}".format(name)) 168 | 169 | 170 | def upconv_2D(tensor, n_filter, flags, name): 171 | """Up Convolution `tensor` by 2 times 172 | 173 | Args: 174 | tensor (4-D Tensor): (N, H, W, C) 175 | n_filter (int): Filter Size 176 | name (str): name of upsampling operations 177 | 178 | Returns: 179 | output (4-D Tensor): (N, 2 * H, 2 * W, C) 180 | """ 181 | 182 | return tf.layers.conv2d_transpose( 183 | tensor, 184 | filters=n_filter, 185 | kernel_size=2, 186 | strides=2, 187 | kernel_regularizer=tf.contrib.layers.l2_regularizer(flags.reg), 188 | name="upsample_{}".format(name)) 189 | 190 | 191 | def make_unet(X, training, flags=None): 192 | """Build a U-Net architecture 193 | 194 | Args: 195 | X (4-D Tensor): (N, H, W, C) 196 | training (1-D Tensor): Boolean Tensor is required for batchnormalization layers 197 | 198 | Returns: 199 | output (4-D Tensor): (N, H, W, C) 200 | Same shape as the `input` tensor 201 | 202 | Notes: 203 | U-Net: Convolutional Networks for Biomedical Image Segmentation 204 | https://arxiv.org/abs/1505.04597 205 | """ 206 | net = X / 127.5 - 1 207 | conv1, pool1 = conv_conv_pool(net, [8, 8], training, flags, name=1) 208 | conv2, pool2 = conv_conv_pool(pool1, [16, 16], training, flags, name=2) 209 | conv3, pool3 = conv_conv_pool(pool2, [32, 32], training, flags, name=3) 210 | conv4, pool4 = conv_conv_pool(pool3, [64, 64], training, flags, name=4) 211 | conv5 = conv_conv_pool( 212 | pool4, [128, 128], training, flags, name=5, pool=False) 213 | 214 | up6 = upconv_concat(conv5, conv4, 64, flags, name=6) 215 | conv6 = conv_conv_pool(up6, [64, 64], training, flags, name=6, pool=False) 216 | 217 | up7 = upconv_concat(conv6, conv3, 32, flags, name=7) 218 | conv7 = conv_conv_pool(up7, [32, 32], training, flags, name=7, pool=False) 219 | 220 | up8 = 
upconv_concat(conv7, conv2, 16, flags, name=8) 221 | conv8 = conv_conv_pool(up8, [16, 16], training, flags, name=8, pool=False) 222 | 223 | up9 = upconv_concat(conv8, conv1, 8, flags, name=9) 224 | conv9 = conv_conv_pool(up9, [8, 8], training, flags, name=9, pool=False) 225 | 226 | return tf.layers.conv2d( 227 | conv9, 228 | 1, (1, 1), 229 | name='final', 230 | activation=tf.nn.sigmoid, 231 | padding='same') 232 | 233 | 234 | def IOU_(y_pred, y_true): 235 | """Returns a (approx) IOU score 236 | 237 | intesection = y_pred.flatten() * y_true.flatten() 238 | Then, IOU = 2 * intersection / (y_pred.sum() + y_true.sum() + 1e-7) + 1e-7 239 | 240 | Args: 241 | y_pred (4-D array): (N, H, W, 1) 242 | y_true (4-D array): (N, H, W, 1) 243 | 244 | Returns: 245 | float: IOU score 246 | """ 247 | H, W, _ = y_pred.get_shape().as_list()[1:] 248 | 249 | pred_flat = tf.reshape(y_pred, [-1, H * W]) 250 | true_flat = tf.reshape(y_true, [-1, H * W]) 251 | 252 | intersection = 2 * tf.reduce_sum(pred_flat * true_flat, axis=1) + 1e-7 253 | denominator = tf.reduce_sum( 254 | pred_flat, axis=1) + tf.reduce_sum( 255 | true_flat, axis=1) + 1e-7 256 | 257 | return tf.reduce_mean(intersection / denominator) 258 | 259 | 260 | def make_train_op(y_pred, y_true): 261 | """Returns a training operation 262 | 263 | Loss function = - IOU(y_pred, y_true) 264 | 265 | IOU is 266 | 267 | (the area of intersection) 268 | -------------------------- 269 | (the area of two boxes) 270 | 271 | Args: 272 | y_pred (4-D Tensor): (N, H, W, 1) 273 | y_true (4-D Tensor): (N, H, W, 1) 274 | 275 | Returns: 276 | train_op: minimize operation 277 | """ 278 | loss = -IOU_(y_pred, y_true) 279 | 280 | global_step = tf.train.get_or_create_global_step() 281 | 282 | optim = tf.train.AdamOptimizer() 283 | return optim.minimize(loss, global_step=global_step) 284 | 285 | 286 | def read_flags(): 287 | """Returns flags""" 288 | 289 | import argparse 290 | 291 | parser = argparse.ArgumentParser( 292 | 
formatter_class=argparse.ArgumentDefaultsHelpFormatter) 293 | parser.add_argument( 294 | "--epochs", default=1, type=int, help="Number of epochs") 295 | 296 | parser.add_argument("--batch-size", default=4, type=int, help="Batch size") 297 | 298 | parser.add_argument( 299 | "--logdir", default="logdir", help="Tensorboard log directory") 300 | 301 | parser.add_argument( 302 | "--reg", type=float, default=0.1, help="L2 Regularizer Term") 303 | 304 | parser.add_argument( 305 | "--ckdir", default="models", help="Checkpoint directory") 306 | 307 | flags = parser.parse_args() 308 | return flags 309 | 310 | 311 | def main(flags): 312 | train = pd.read_csv("./train.csv") 313 | n_train = train.shape[0] 314 | 315 | test = pd.read_csv("./test.csv") 316 | n_test = test.shape[0] 317 | 318 | current_time = time.strftime("%m/%d/%H/%M/%S") 319 | train_logdir = os.path.join(flags.logdir, "train", current_time) 320 | test_logdir = os.path.join(flags.logdir, "test", current_time) 321 | 322 | tf.reset_default_graph() 323 | X = tf.placeholder(tf.float32, shape=[None, 640, 960, 3], name="X") 324 | y = tf.placeholder(tf.float32, shape=[None, 640, 960, 1], name="y") 325 | mode = tf.placeholder(tf.bool, name="mode") 326 | 327 | pred = make_unet(X, mode, flags) 328 | 329 | tf.add_to_collection("inputs", X) 330 | tf.add_to_collection("inputs", mode) 331 | tf.add_to_collection("outputs", pred) 332 | 333 | tf.summary.histogram("Predicted Mask", pred) 334 | tf.summary.image("Predicted Mask", pred) 335 | 336 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 337 | 338 | with tf.control_dependencies(update_ops): 339 | train_op = make_train_op(pred, y) 340 | 341 | IOU_op = IOU_(pred, y) 342 | IOU_op = tf.Print(IOU_op, [IOU_op]) 343 | tf.summary.scalar("IOU", IOU_op) 344 | 345 | train_csv = tf.train.string_input_producer(['train.csv']) 346 | test_csv = tf.train.string_input_producer(['test.csv']) 347 | train_image, train_mask = get_image_mask(train_csv) 348 | test_image, test_mask = 
get_image_mask(test_csv, augmentation=False) 349 | 350 | X_batch_op, y_batch_op = tf.train.shuffle_batch( 351 | [train_image, train_mask], 352 | batch_size=flags.batch_size, 353 | capacity=flags.batch_size * 5, 354 | min_after_dequeue=flags.batch_size * 2, 355 | allow_smaller_final_batch=True) 356 | 357 | X_test_op, y_test_op = tf.train.batch( 358 | [test_image, test_mask], 359 | batch_size=flags.batch_size, 360 | capacity=flags.batch_size * 2, 361 | allow_smaller_final_batch=True) 362 | 363 | summary_op = tf.summary.merge_all() 364 | 365 | with tf.Session() as sess: 366 | train_summary_writer = tf.summary.FileWriter(train_logdir, sess.graph) 367 | test_summary_writer = tf.summary.FileWriter(test_logdir) 368 | 369 | init = tf.global_variables_initializer() 370 | sess.run(init) 371 | 372 | saver = tf.train.Saver() 373 | if os.path.exists(flags.ckdir) and tf.train.checkpoint_exists( 374 | flags.ckdir): 375 | latest_check_point = tf.train.latest_checkpoint(flags.ckdir) 376 | saver.restore(sess, latest_check_point) 377 | 378 | else: 379 | try: 380 | os.rmdir(flags.ckdir) 381 | except FileNotFoundError: 382 | pass 383 | os.mkdir(flags.ckdir) 384 | 385 | try: 386 | global_step = tf.train.get_global_step(sess.graph) 387 | 388 | coord = tf.train.Coordinator() 389 | threads = tf.train.start_queue_runners(coord=coord) 390 | 391 | for epoch in range(flags.epochs): 392 | 393 | for step in range(0, n_train, flags.batch_size): 394 | 395 | X_batch, y_batch = sess.run([X_batch_op, y_batch_op]) 396 | 397 | _, step_iou, step_summary, global_step_value = sess.run( 398 | [train_op, IOU_op, summary_op, global_step], 399 | feed_dict={X: X_batch, 400 | y: y_batch, 401 | mode: True}) 402 | 403 | train_summary_writer.add_summary(step_summary, 404 | global_step_value) 405 | 406 | total_iou = 0 407 | for step in range(0, n_test, flags.batch_size): 408 | X_test, y_test = sess.run([X_test_op, y_test_op]) 409 | step_iou, step_summary = sess.run( 410 | [IOU_op, summary_op], 411 | feed_dict={X: 
X_test, 412 | y: y_test, 413 | mode: False}) 414 | 415 | total_iou += step_iou * X_test.shape[0] 416 | 417 | test_summary_writer.add_summary(step_summary, 418 | (epoch + 1) * (step + 1)) 419 | 420 | saver.save(sess, "{}/model.ckpt".format(flags.ckdir)) 421 | 422 | finally: 423 | coord.request_stop() 424 | coord.join(threads) 425 | saver.save(sess, "{}/model.ckpt".format(flags.ckdir)) 426 | 427 | 428 | if __name__ == '__main__': 429 | flags = read_flags() 430 | main(flags) 431 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kkweon/UNet-in-Tensorflow/ef36bfbd7398d4ffd02bb7a9329715f4419aa54f/utils/__init__.py -------------------------------------------------------------------------------- /utils/data.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=E1101,C0103,C0326,W1202 2 | """ 3 | Data Related Functions 4 | 5 | """ 6 | import argparse 7 | import logging 8 | import os 9 | import shutil 10 | import time 11 | from collections import namedtuple 12 | from multiprocessing.pool import Pool 13 | # For Typing Annotation 14 | from typing import List, Tuple 15 | 16 | import cv2 17 | import numpy as np 18 | import pandas as pd 19 | 20 | from .image import read_image, read_image_and_resize 21 | 22 | logging.basicConfig(level=logging.INFO) 23 | LOGGER = logging.getLogger(__name__) 24 | 25 | Box = namedtuple("Box", ["left_top", "right_bot"]) 26 | 27 | 28 | def read_flags(): 29 | """Returns global variables""" 30 | 31 | parser = argparse.ArgumentParser( 32 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 33 | description="resize image and adjusts coordinates") 34 | parser.add_argument( 35 | "--src_csv", 36 | default="data/labels_crowdai.csv", 37 | help="/path/to/labels.csv") 38 | 39 | parser.add_argument( 40 | "--data_dir", 41 | 
default="object-detection-crowdai", 42 | help="Directory where training datasets are located") 43 | 44 | parser.add_argument( 45 | "--save_dir", 46 | default="data_resize", 47 | help="path to the directory in which resize image will be saved") 48 | 49 | parser.add_argument( 50 | "--target_width", default=960, help="new target width (default: 960)") 51 | 52 | parser.add_argument( 53 | "--target_height", default=640, help="target height (default: 640)") 54 | 55 | parser.add_argument( 56 | "--target_csv", 57 | default="labels_resized.csv", 58 | help="target csv filename") 59 | 60 | return parser.parse_args() 61 | 62 | 63 | def get_boxes(df: pd.DataFrame) -> List[Box]: 64 | """Given relevant DATAFRAME return a list of BOX""" 65 | boxes = [] 66 | for _, items in df.iterrows(): 67 | 68 | left_top = items["xmin"], items["ymin"] 69 | right_bot = items["xmax"], items["ymax"] 70 | 71 | boxes.append(Box(left_top, right_bot)) 72 | return boxes 73 | 74 | 75 | def create_clean_dir(dirname: str) -> None: 76 | """Create an empty directory 77 | 78 | Args: 79 | dirname (str): An empty directory name to create 80 | """ 81 | 82 | if os.path.exists(dirname): 83 | shutil.rmtree(dirname) 84 | 85 | assert os.path.exists(dirname) is False 86 | 87 | os.mkdir(dirname) 88 | 89 | assert not os.listdir(dirname) 90 | 91 | 92 | def adjust_bbox(bboxframe: pd.DataFrame, 93 | src_size: Tuple[int, int], 94 | dst_size: Tuple[int, int]) -> pd.DataFrame: 95 | """Returns a new dataframe with adjusted coordinates 96 | 97 | W W_new 98 | +----+ ----> +-+ 99 | | | H | | H_new 100 | +----+ +-+ 101 | Args: 102 | bboxframe (pd.DataFrame): Bounding box infor dataframe 103 | src_size (Tuple[int, int]): Original image (width, height) 104 | dst_size (Tuple[int, int]): New image (width, height) 105 | 106 | Returns: 107 | pd.DataFrame: Its coordinates are adjusted to a new size 108 | """ 109 | W, H = src_size 110 | W_new, H_new = dst_size 111 | 112 | bboxframe = bboxframe.copy() 113 | 114 | bboxframe['xmin'] = 
def get_relevant_frames(image_path: str,
                        dataframe: pd.DataFrame) -> pd.DataFrame:
    """Returns the rows of `dataframe` that belong to a single image

    Args:
        image_path (str): "path/to/image.jpg"
        dataframe (pd.DataFrame): The base frame to be searched
            It must have a "Frame" column holding image paths

    Returns:
        pd.DataFrame: Rows whose "Frame" equals `image_path`,
            re-indexed from 0
    """

    cond = dataframe["Frame"] == image_path
    return dataframe[cond].reset_index(drop=True)


def _paint_bboxes(mask: np.ndarray,
                  bbox_frame: pd.DataFrame,
                  value: float) -> np.ndarray:
    """Fills every bounding box of `bbox_frame` into `mask`

    Args:
        mask (2-D array): Array of shape (H, W), modified in place
        bbox_frame (pd.DataFrame): Needs xmin, xmax, ymin, ymax columns
        value (float): Value written inside each bounding box

    Returns:
        2-D array: The same `mask` object
    """
    coords = zip(bbox_frame['xmin'], bbox_frame['xmax'],
                 bbox_frame['ymin'], bbox_frame['ymax'])

    for xmin, xmax, ymin, ymax in coords:
        # Slice bounds must be plain non-negative ints: float coordinates
        # raise a TypeError and negative ones would wrap around to the
        # opposite edge of the image instead of being clipped.
        W_beg, W_end = max(int(xmin), 0), max(int(xmax), 0)
        H_beg, H_end = max(int(ymin), 0), max(int(ymax), 0)

        mask[H_beg:H_end, W_beg:W_end] = value

    return mask


def get_mask(image: np.ndarray, bbox_frame: pd.DataFrame) -> np.ndarray:
    """Returns a binary mask

    Args:
        image (2-D or 3-D array): Numpy array of shape (H, W) or (H, W, C)
            Only its height and width are used
        bbox_frame (pd.DataFrame): Dataframe related with the input image
            It contains bounding box coordinates

    Returns:
        2-D array: Mask shape (H, W)
            1 for bounding box area
            0 for background
    """

    H, W = image.shape[:2]

    return _paint_bboxes(np.zeros((H, W)), bbox_frame, 1)


def create_mask(image_WH: Tuple[int, int],
                image_path: str,
                dataframe: pd.DataFrame) -> np.ndarray:
    """Returns a mask array

    Object = 255
    Else = 0

    Args:
        image_WH (Tuple[int, int]): Image size (width, height)
        image_path (str): /path/to/image.jpg
        dataframe (pd.DataFrame): DataFrame contains bbox information

    Returns:
        2-D array: Mask array of shape (height, width)

    Examples:
        >>> image_WH = (960, 640)
        >>> image_path = "images/image000.jpg"
        >>> mask = create_mask(image_WH, image_path, dataframe)
    """

    W, H = image_WH
    bbox_frame = get_relevant_frames(image_path, dataframe)

    return _paint_bboxes(np.zeros((H, W)), bbox_frame, 255)


def generate_mask_pipeline(image_WH: Tuple[int, int],
                           image_path: str,
                           dataframe: pd.DataFrame,
                           save_dir: str="mask") -> None:
    """Create a mask and save as JPG

    Args:
        image_WH (Tuple[int, int]): (width: int, height: int)
        image_path (str): path/to/image.jpg
        dataframe (pd.DataFrame): labels.csv
        save_dir (str): Save directory

    Raises:
        IOError: If the mask file could not be written
    """
    filename = os.path.basename(image_path)
    full_path = os.path.join(save_dir, filename)
    mask = create_mask(image_WH, image_path, dataframe)

    # cv2.imwrite reports a failed write through its return value
    if not cv2.imwrite(full_path, mask):
        raise IOError("Failed to write mask to {}".format(full_path))


def main(FLAGS):
    """Main Function

    Args:
        FLAGS: Parsed flags. The attributes read here are `target_width`,
            `target_height`, `src_csv`, `data_dir`, `save_dir`
            and `target_csv`

    Raises:
        ValueError: If `src_csv` has no "Car" or "Truck" labels

    Notes:
        1. Read image and resize to Target Width, Height
        2. Resize bounding box coordinates accordingly
        3. Create masks with the bounding box
            background is 0 and vehicle is 255

    """
    mask_dir = "mask"
    new_WH = (FLAGS.target_width, FLAGS.target_height)

    data = pd.read_csv(FLAGS.src_csv)
    # Only consider car and truck images
    data = data[data["Label"].isin(["Car", "Truck"])].reset_index(drop=True)

    if data.empty:
        raise ValueError(
            "No Car or Truck labels found in {}".format(FLAGS.src_csv))

    # 123.jpg -> object-detection-crowdai/123.jpg
    data["Frame"] = data["Frame"].map(
        lambda x: os.path.join(FLAGS.data_dir, x))

    # IF dir exists, clean it
    create_clean_dir(FLAGS.save_dir)
    LOGGER.info("Cleaned {} directory".format(FLAGS.save_dir))

    LOGGER.info("Resizing begins")
    start = time.time()
    resize_tasks = [(image_path, new_WH, FLAGS.save_dir)
                    for image_path in data["Frame"].unique()]
    # Blocking `starmap` re-raises worker exceptions here; the previous
    # `starmap_async` result was never collected, so failures were lost.
    with Pool() as pool:
        pool.starmap(read_image_and_resize, resize_tasks)
    end = time.time()

    LOGGER.info("Time elapsed: {}".format(end - start))
    LOGGER.info("Resizing ends")

    LOGGER.info("Adjusting dataframe")

    # Read any image file to get the WIDTH and HEIGHT
    image = read_image(data["Frame"].iloc[0])

    H, W, _ = image.shape
    src_size = (W, H)

    labels = adjust_bbox(data, src_size, new_WH)

    # object-.../123.jpg -> data_resize/123.jpg
    labels["Frame"] = labels["Frame"].map(
        lambda x: os.path.join(FLAGS.save_dir, os.path.basename(x)))

    create_clean_dir(mask_dir)
    LOGGER.info("Cleaned {} directory".format(mask_dir))
    LOGGER.info("Masking begin")
    start = time.time()

    # Hand each worker only the rows of its own image. Shipping the whole
    # label table with every task pickles it once per image and makes every
    # worker scan all rows again.
    mask_tasks = [(new_WH, image_path, frame_labels, mask_dir)
                  for image_path, frame_labels
                  in labels.groupby("Frame", sort=False)]
    with Pool() as pool:
        pool.starmap(generate_mask_pipeline, mask_tasks)
    end = time.time()
    LOGGER.info("Masking ends. Time elapsed: {}".format(end - start))

    labels["Mask"] = labels["Frame"].map(
        lambda x: os.path.join(mask_dir, os.path.basename(x)))
    labels.to_csv(FLAGS.target_csv, index=False)

    LOGGER.info("Adjustment saved to {}".format(FLAGS.target_csv))


if __name__ == '__main__':
    flags = read_flags()
    main(flags)
def read_image(image_path: str, gray: bool=False) -> np.ndarray:
    """Returns an image array

    Args:
        image_path (str): Path to image.jpg
        gray (bool): Grayscale flag

    Returns:
        np.ndarray:
            3D numpy array of shape (H, W, 3) in RGB order
            or 2D Grayscale Image (H, W)

    Raises:
        OSError: If the image could not be read
    """
    flag = cv2.IMREAD_GRAYSCALE if gray else cv2.IMREAD_COLOR
    image = cv2.imread(image_path, flag)

    # cv2.imread signals failure by returning None instead of raising
    if image is None:
        raise OSError("Unable to read image: {}".format(image_path))

    if gray:
        return image

    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def read_image_and_resize(image_path: str,
                          new_WH: Tuple[int, int]=(512, 512),
                          save_dir: str="resize") -> str:
    """Reads an image and resize it

    1) open `image_path` that is image.jpg
    2) resize to `new_WH`
    3) save to save_dir/image.jpg
    4) returns `image_path`

    Args:
        image_path (str): /path/to/image.jpg
        new_WH (tuple): Target width & height to resize
        save_dir (str): Directory name to save a resized image
            It must already exist

    Returns:
        image_path (str): same as input `image_path`

    Raises:
        FileNotFoundError: If `save_dir` is not an existing directory
        OSError: If the image could not be read or written
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`
    if not os.path.isdir(save_dir):
        raise FileNotFoundError(
            "Save directory does not exist: {}".format(save_dir))

    new_path = os.path.join(save_dir, os.path.basename(image_path))

    image = cv2.imread(image_path)
    if image is None:
        raise OSError("Unable to read image: {}".format(image_path))

    image = cv2.resize(image, new_WH, interpolation=cv2.INTER_AREA)

    # cv2.imwrite reports a failed write through its return value
    if not cv2.imwrite(new_path, image):
        raise OSError("Unable to write image: {}".format(new_path))

    return image_path


def plot_image(image: np.ndarray, title: Optional[str]=None, **kwargs) -> None:
    """Plot a single image

    Args:
        image (2-D or 3-D array): image as a numpy array (H, W) or (H, W, C)
        title (str, optional): title for a plot
        **kwargs: keyword arguments for `plt.imshow`

    Raises:
        TypeError: If `image` is neither 2-D nor 3-D
    """
    shape = image.shape

    if len(shape) not in (2, 3):
        raise TypeError(
            "2-D array or 3-D array should be given but {} was given".format(
                shape))

    plt.imshow(image, **kwargs)

    if title:
        plt.title(title)


def plot_two_images(image_A: np.ndarray,
                    title_A: str,
                    image_B: np.ndarray,
                    title_B: str,
                    figsize: Tuple[int, int]=(15, 15),
                    kwargs_1: Optional[dict]=None,
                    kwargs_2: Optional[dict]=None) -> None:
    """Plot two images side by side

    Args:
        image_A (2-D or 3-D array): Image drawn on the left
        title_A (str): Title of the left image
        image_B (2-D or 3-D array): Image drawn on the right
        title_B (str): Title of the right image
        figsize (Tuple[int, int]): Figure size passed to `plt.figure`
        kwargs_1 (dict, optional): `plt.imshow` kwargs for the left image
        kwargs_2 (dict, optional): `plt.imshow` kwargs for the right image
    """
    # None defaults avoid sharing one mutable dict between calls
    kwargs_1 = {} if kwargs_1 is None else kwargs_1
    kwargs_2 = {} if kwargs_2 is None else kwargs_2

    plt.figure(figsize=figsize)
    plt.subplot(1, 2, 1)
    plot_image(image_A, title=title_A, **kwargs_1)

    plt.subplot(1, 2, 2)
    plot_image(image_B, title=title_B, **kwargs_2)


def draw_bbox(image: np.ndarray,
              left_top: Tuple[int, int],
              right_bot: Tuple[int, int],
              color: Tuple[int, int, int],
              thickness: int,
              **kwargs) -> np.ndarray:
    """Returns an image with the bounding box

    The drawing is delegated to `cv2.rectangle`
    NOTE(review): OpenCV documents that it draws onto `image` itself,
    pass a copy to keep the original untouched

    Args:
        image (np.ndarray): Numpy array of shape (H, W, C)
        left_top (Tuple[int, int]): The left top coordinate, (column, row)
        right_bot (Tuple[int, int]): The right bottom coordinate (column, row)
        color (Tuple[int, int, int]): Color (R, G, B)
        thickness (int): Thickness of the box
        **kwargs (dict): kwargs for cv2.rectangle

    Returns:
        np.ndarray: Numpy array of same shape with bounding boxes
    """
    return cv2.rectangle(
        image, left_top, right_bot, color=color, thickness=thickness, **kwargs)
--------------------------------------------------------------------------------