├── .gitignore ├── ImageAugmenter.py ├── LICENSE ├── README.md ├── apply_convnet.py ├── dataset.py ├── images ├── 1-266-19922494159_78303f8b16_n.jpg ├── 1-39-125521249_b1318298ec_n.jpg ├── 1-56-180653960_21cf28e0b3_n.jpg ├── 1-61-213767259_11c8550a0e_n.jpg ├── 3-2287-2088100404_c0112197e3_n.jpg └── 3-2831-10902603864_4993c4aa1a_n.jpg ├── predictions └── README.md └── train_convnet.py /.gitignore: -------------------------------------------------------------------------------- 1 | predictions/*.jpg 2 | predictions/*.png 3 | apply_locator_output/*.jpg 4 | apply_locator_output/*.png 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | env/ 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *,cover 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | 58 | # Sphinx documentation 59 | docs/_build/ 60 | 61 | # PyBuilder 62 | target/ 63 | -------------------------------------------------------------------------------- /ImageAugmenter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Wrapper functions and classes around scikit-images AffineTransformation. 3 | Simplifies augmentation of images in machine learning. 
4 | 5 | Example usage: 6 | img_width = 32 # width of the images 7 | img_height = 32 # height of the images 8 | images = ... # e.g. load via scipy.misc.imload(filename) 9 | 10 | # For each image: randomly flip it horizontally (50% chance), 11 | # randomly rotate it between -20 and +20 degrees, randomly translate 12 | # it on the x-axis between -5 and +5 pixel. 13 | ia = ImageAugmenter(img_width, img_height, hlip=True, rotation_deg=20, 14 | translation_x_px=5) 15 | augmented_images = ia.augment_batch(images) 16 | """ 17 | from __future__ import division 18 | from skimage import transform as tf 19 | import numpy as np 20 | import random 21 | 22 | def is_minmax_tuple(param): 23 | """Returns whether the parameter is a tuple containing two values. 24 | 25 | Used in create_aug_matrices() and probably useless everywhere else. 26 | 27 | Args: 28 | param: The parameter to check (whether it is a tuple of length 2). 29 | 30 | Returns: 31 | Boolean 32 | """ 33 | return type(param) is tuple and len(param) == 2 34 | 35 | def create_aug_matrices(nb_matrices, img_width_px, img_height_px, 36 | scale_to_percent=1.0, scale_axis_equally=False, 37 | rotation_deg=0, shear_deg=0, 38 | translation_x_px=0, translation_y_px=0, 39 | seed=None): 40 | """Creates the augmentation matrices that may later be used to transform 41 | images. 42 | 43 | This is a wrapper around scikit-image's transform.AffineTransform class. 44 | You can apply those matrices to images using the apply_aug_matrices() 45 | function. 46 | 47 | Args: 48 | nb_matrices: How many matrices to return, e.g. 100 returns 100 different 49 | random-generated matrices (= 100 different transformations). 50 | img_width_px: Width of the images that will be transformed later 51 | on (same as the width of each of the matrices). 52 | img_height_px: Height of the images that will be transformed later 53 | on (same as the height of each of the matrices). 54 | scale_to_percent: Same as in ImageAugmenter.__init__(). 
55 | Up to which percentage the images may be 56 | scaled/zoomed. The negative scaling is automatically derived 57 | from this value. A value of 1.1 allows scaling by any value 58 | between -10% and +10%. You may set min and max values yourself 59 | by using a tuple instead, like (1.1, 1.2) to scale between 60 | +10% and +20%. Default is 1.0 (no scaling). 61 | scale_axis_equally: Same as in ImageAugmenter.__init__(). 62 | Whether to always scale both axis (x and y) 63 | in the same way. If set to False, then e.g. the Augmenter 64 | might scale the x-axis by 20% and the y-axis by -5%. 65 | Default is False. 66 | rotation_deg: Same as in ImageAugmenter.__init__(). 67 | By how much the image may be rotated around its 68 | center (in degrees). The negative rotation will automatically 69 | be derived from this value. E.g. a value of 20 allows any 70 | rotation between -20 degrees and +20 degrees. You may set min 71 | and max values yourself by using a tuple instead, e.g. (5, 20) 72 | to rotate between +5 und +20 degrees. Default is 0 (no 73 | rotation). 74 | shear_deg: Same as in ImageAugmenter.__init__(). 75 | By how much the image may be sheared (in degrees). The 76 | negative value will automatically be derived from this value. 77 | E.g. a value of 20 allows any shear between -20 degrees and 78 | +20 degrees. You may set min and max values yourself by using a 79 | tuple instead, e.g. (5, 20) to shear between +5 und +20 80 | degrees. Default is 0 (no shear). 81 | translation_x_px: Same as in ImageAugmenter.__init__(). 82 | By up to how many pixels the image may be 83 | translated (moved) on the x-axis. The negative value will 84 | automatically be derived from this value. E.g. a value of +7 85 | allows any translation between -7 and +7 pixels on the x-axis. 86 | You may set min and max values yourself by using a tuple 87 | instead, e.g. (5, 20) to translate between +5 und +20 pixels. 88 | Default is 0 (no translation on the x-axis). 
89 | translation_y_px: Same as in ImageAugmenter.__init__(). 90 | See translation_x_px, just for the y-axis. 91 | seed: Seed to use for python's and numpy's random functions. 92 | 93 | Returns: 94 | List of augmentation matrices. 95 | """ 96 | assert nb_matrices > 0 97 | assert img_width_px > 0 98 | assert img_height_px > 0 99 | assert is_minmax_tuple(scale_to_percent) or scale_to_percent >= 1.0 100 | assert is_minmax_tuple(rotation_deg) or rotation_deg >= 0 101 | assert is_minmax_tuple(shear_deg) or shear_deg >= 0 102 | assert is_minmax_tuple(translation_x_px) or translation_x_px >= 0 103 | assert is_minmax_tuple(translation_y_px) or translation_y_px >= 0 104 | 105 | if seed is not None: 106 | random.seed(seed) 107 | np.random.seed(seed) 108 | 109 | result = [] 110 | 111 | shift_x = int(img_width_px / 2.0) 112 | shift_y = int(img_height_px / 2.0) 113 | 114 | # prepare min and max values for 115 | # scaling/zooming (min/max values) 116 | if is_minmax_tuple(scale_to_percent): 117 | scale_x_min = scale_to_percent[0] 118 | scale_x_max = scale_to_percent[1] 119 | else: 120 | scale_x_min = scale_to_percent 121 | scale_x_max = 1.0 - (scale_to_percent - 1.0) 122 | assert scale_x_min > 0.0 123 | #if scale_x_max >= 2.0: 124 | # warnings.warn("Scaling by more than 100 percent (%.2f)." 
% (scale_x_max,)) 125 | scale_y_min = scale_x_min # scale_axis_equally affects the random value generation 126 | scale_y_max = scale_x_max 127 | 128 | # rotation (min/max values) 129 | if is_minmax_tuple(rotation_deg): 130 | rotation_deg_min = rotation_deg[0] 131 | rotation_deg_max = rotation_deg[1] 132 | else: 133 | rotation_deg_min = (-1) * int(rotation_deg) 134 | rotation_deg_max = int(rotation_deg) 135 | 136 | # shear (min/max values) 137 | if is_minmax_tuple(shear_deg): 138 | shear_deg_min = shear_deg[0] 139 | shear_deg_max = shear_deg[1] 140 | else: 141 | shear_deg_min = (-1) * int(shear_deg) 142 | shear_deg_max = int(shear_deg) 143 | 144 | # translation x-axis (min/max values) 145 | if is_minmax_tuple(translation_x_px): 146 | translation_x_px_min = translation_x_px[0] 147 | translation_x_px_max = translation_x_px[1] 148 | else: 149 | translation_x_px_min = (-1) * translation_x_px 150 | translation_x_px_max = translation_x_px 151 | 152 | # translation y-axis (min/max values) 153 | if is_minmax_tuple(translation_y_px): 154 | translation_y_px_min = translation_y_px[0] 155 | translation_y_px_max = translation_y_px[1] 156 | else: 157 | translation_y_px_min = (-1) * translation_y_px 158 | translation_y_px_max = translation_y_px 159 | 160 | # create nb_matrices randomized affine transformation matrices 161 | for _ in range(nb_matrices): 162 | # generate random values for scaling, rotation, shear, translation 163 | scale_x = random.uniform(scale_x_min, scale_x_max) 164 | scale_y = random.uniform(scale_y_min, scale_y_max) 165 | if not scale_axis_equally: 166 | scale_y = random.uniform(scale_y_min, scale_y_max) 167 | else: 168 | scale_y = scale_x 169 | rotation = np.deg2rad(random.randint(rotation_deg_min, rotation_deg_max)) 170 | shear = np.deg2rad(random.randint(shear_deg_min, shear_deg_max)) 171 | translation_x = random.randint(translation_x_px_min, translation_x_px_max) 172 | translation_y = random.randint(translation_y_px_min, translation_y_px_max) 173 | 174 | # 
create three affine transformation matrices 175 | # 1st one moves the image to the top left, 2nd one transforms it, 3rd one 176 | # moves it back to the center. 177 | # The movement is neccessary, because rotation is applied to the top left 178 | # and not to the image's center (same for scaling and shear). 179 | matrix_to_topleft = tf.SimilarityTransform(translation=[-shift_x, -shift_y]) 180 | matrix_transforms = tf.AffineTransform(scale=(scale_x, scale_y), 181 | rotation=rotation, shear=shear, 182 | translation=(translation_x, 183 | translation_y)) 184 | matrix_to_center = tf.SimilarityTransform(translation=[shift_x, shift_y]) 185 | 186 | # Combine the three matrices to one affine transformation (one matrix) 187 | matrix = matrix_to_topleft + matrix_transforms + matrix_to_center 188 | 189 | # one matrix is ready, add it to the result 190 | result.append(matrix.inverse) 191 | 192 | return result 193 | 194 | def apply_aug_matrices(images, matrices, transform_channels_equally=True, 195 | channel_is_first_axis=False, random_order=True, 196 | mode="constant", cval=0.0, interpolation_order=1, 197 | seed=None): 198 | """Augment the given images using the given augmentation matrices. 199 | 200 | This function is a wrapper around scikit-image's transform.warp(). 201 | It is expected to be called by ImageAugmenter.augment_batch(). 202 | The matrices may be generated by create_aug_matrices(). 203 | 204 | Args: 205 | images: Same as in ImageAugmenter.augment_batch(). 206 | Numpy array (dtype: uint8, i.e. values 0-255) with the images. 207 | Expected shape is either (image-index, height, width) for 208 | grayscale images or (image-index, channel, height, width) for 209 | images with channels (e.g. RGB) where the channel has the first 210 | index or (image-index, height, width, channel) for images with 211 | channels, where the channel is the last index. 
212 | If your shape is (image-index, channel, width, height) then 213 | you must also set channel_is_first_axis=True in the constructor. 214 | matrices: A list of augmentation matrices as produced by 215 | create_aug_matrices(). 216 | transform_channels_equally: Same as in ImageAugmenter.__init__(). 217 | Whether to apply the exactly same 218 | transformations to each channel of an image (True). Setting 219 | it to False allows different transformations per channel, 220 | e.g. the red-channel might be rotated by +20 degrees, while 221 | the blue channel (of the same image) might be rotated 222 | by -5 degrees. If you don't have any channels (2D grayscale), 223 | you can simply ignore this setting. 224 | Default is True (transform all equally). 225 | channel_is_first_axis: Same as in ImageAugmenter.__init__(). 226 | Whether the channel (e.g. RGB) is the first 227 | axis of each image (True) or the last axis (False). 228 | False matches the scipy and PIL implementation and is the 229 | default. If your images are 2D-grayscale then you can ignore 230 | this setting (as the augmenter will ignore it too). 231 | random_order: Whether to apply the augmentation matrices in a random 232 | order (True, e.g. the 2nd matrix might be applied to the 233 | 5th image) or in the given order (False, e.g. the 2nd matrix might 234 | be applied to the 2nd image). 235 | Notice that for multi-channel images (e.g. RGB) this function 236 | will use a different matrix for each channel, unless 237 | transform_channels_equally is set to True. 238 | mode: Parameter used for the transform.warp-function of scikit-image. 239 | Can usually be ignored. 240 | cval: Parameter used for the transform.warp-function of scikit-image. 241 | Defines the fill color for "new" pixels, e.g. for empty areas 242 | after rotations. (0.0 is black, 1.0 is white.) 243 | interpolation_order: Parameter used for the transform.warp-function of 244 | scikit-image. 
Defines the order of all interpolations used to 245 | generate the new/augmented image. See their documentation for 246 | further details. 247 | seed: Seed to use for python's and numpy's random functions. 248 | """ 249 | # images must be numpy array 250 | assert type(images).__module__ == np.__name__, "Expected numpy array for " \ 251 | "parameter 'images'." 252 | 253 | # images must have uint8 as dtype (0-255) 254 | assert images.dtype.name == "uint8", "Expected numpy.uint8 as image dtype." 255 | 256 | # 3 axis total (2 per image) for grayscale, 257 | # 4 axis total (3 per image) for RGB (usually) 258 | assert len(images.shape) in [3, 4], """Expected 'images' parameter to have 259 | either shape (image index, y, x) for greyscale 260 | or (image index, channel, y, x) / (image index, y, x, channel) 261 | for multi-channel (usually color) images.""" 262 | 263 | if seed: 264 | np.random.seed(seed) 265 | 266 | nb_images = images.shape[0] 267 | 268 | # estimate number of channels, set to 1 if there is no axis channel, 269 | # otherwise it will usually be 3 270 | has_channels = False 271 | nb_channels = 1 272 | if len(images.shape) == 4: 273 | has_channels = True 274 | if channel_is_first_axis: 275 | nb_channels = images.shape[1] # first axis within each image 276 | else: 277 | nb_channels = images.shape[3] # last axis within each image 278 | 279 | # whether to apply the transformations directly to the whole image 280 | # array (True) or for each channel individually (False) 281 | apply_directly = not has_channels or (transform_channels_equally 282 | and not channel_is_first_axis) 283 | 284 | # We generate here the order in which the matrices may be applied. 285 | # At the end, order_indices will contain the index of the matrix to use 286 | # for each image, e.g. [15, 2] would mean, that the 15th matrix will be 287 | # applied to the 0th image, the 2nd matrix to the 1st image. 288 | # If the images gave multiple channels (e.g. 
RGB) and 289 | # transform_channels_equally has been set to False, we will need one 290 | # matrix per channel instead of per image. 291 | 292 | # 0 to nb_images, but restart at 0 if index is beyond number of matrices 293 | len_indices = nb_images if apply_directly else nb_images * nb_channels 294 | if random_order: 295 | # Notice: This way to choose random matrices is concise, but can create 296 | # problems if there is a low amount of images and matrices. 297 | # E.g. suppose that 2 images are ought to be transformed by either 298 | # 0px translation on the x-axis or 1px translation. So 50% of all 299 | # matrices translate by 0px and 50% by 1px. The following method 300 | # will randomly choose a combination of the two matrices for the 301 | # two images (matrix 0 for image 0 and matrix 0 for image 1, 302 | # matrix 0 for image 0 and matrix 1 for image 1, ...). 303 | # In 50% of these cases, a different matrix will be chosen for image 0 304 | # and image 1 (matrices 0, 1 or matrices 1, 0). But 50% of these 305 | # "different" matrices (different index) will be the same, as 50% 306 | # translate by 1px and 50% by 0px. As a result, 75% of all augmentations 307 | # will transform both images in the same way. 308 | # The effect decreases if more matrices or images are chosen. 
309 | order_indices = np.random.random_integers(0, len(matrices) - 1, len_indices) 310 | else: 311 | # monotonously growing indexes (each by +1), but none of them may be 312 | # higher than or equal to the number of matrices 313 | order_indices = np.arange(0, len_indices) % len(matrices) 314 | 315 | result = np.zeros(images.shape, dtype=np.float32) 316 | matrix_number = 0 317 | 318 | # iterate over every image, find out which matrix to apply and then use 319 | # that matrix to augment the image 320 | for img_idx, image in enumerate(images): 321 | if apply_directly: 322 | # we can apply the matrix to the whole numpy array of the image 323 | # at the same time, so we do that to save time (instead of eg. three 324 | # steps for three channels as in the else-part) 325 | matrix = matrices[order_indices[matrix_number]] 326 | result[img_idx, ...] = tf.warp(image, matrix, mode=mode, cval=cval, 327 | order=interpolation_order) 328 | matrix_number += 1 329 | else: 330 | # we cant apply the matrix to the whole image in one step, instead 331 | # we have to apply it to each channel individually. that happens 332 | # if the channel is the first axis of each image (incompatible with 333 | # tf.warp()) or if it was explicitly requested via 334 | # transform_channels_equally=False. 335 | for channel_idx in range(nb_channels): 336 | matrix = matrices[order_indices[matrix_number]] 337 | if channel_is_first_axis: 338 | warped = tf.warp(image[channel_idx], matrix, mode=mode, 339 | cval=cval, order=interpolation_order) 340 | result[img_idx, channel_idx, ...] 
= warped 341 | else: 342 | warped = tf.warp(image[..., channel_idx], matrix, mode=mode, 343 | cval=cval, order=interpolation_order) 344 | result[img_idx, ..., channel_idx] = warped 345 | 346 | if not transform_channels_equally: 347 | matrix_number += 1 348 | if transform_channels_equally: 349 | matrix_number += 1 350 | 351 | return result 352 | 353 | class ImageAugmenter(object): 354 | """Helper class to randomly augment images, usually for neural networks. 355 | 356 | Example usage: 357 | img_width = 32 # width of the images 358 | img_height = 32 # height of the images 359 | images = ... # e.g. load via scipy.misc.imload(filename) 360 | 361 | # For each image: randomly flip it horizontally (50% chance), 362 | # randomly rotate it between -20 and +20 degrees, randomly translate 363 | # it on the x-axis between -5 and +5 pixel. 364 | ia = ImageAugmenter(img_width, img_height, hlip=True, rotation_deg=20, 365 | translation_x_px=5) 366 | augmented_images = ia.augment_batch(images) 367 | """ 368 | def __init__(self, img_width_px, img_height_px, channel_is_first_axis=False, 369 | hflip=False, vflip=False, 370 | scale_to_percent=1.0, scale_axis_equally=False, 371 | rotation_deg=0, shear_deg=0, 372 | translation_x_px=0, translation_y_px=0, 373 | transform_channels_equally=True): 374 | """ 375 | Args: 376 | img_width_px: The intended width of each image in pixels. 377 | img_height_px: The intended height of each image in pixels. 378 | channel_is_first_axis: Whether the channel (e.g. RGB) is the first 379 | axis of each image (True) or the last axis (False). 380 | False matches the scipy and PIL implementation and is the 381 | default. If your images are 2D-grayscale then you can ignore 382 | this setting (as the augmenter will ignore it too). 383 | hflip: Whether to randomly flip images horizontally (on the y-axis). 384 | You may choose either False (no horizontal flipping), 385 | True (flip with probability 0.5) or use a float 386 | value (probability) between 0.0 and 1.0. 
Default is False. 387 | vflip: Whether to randomly flip images vertically (on the x-axis). 388 | You may choose either False (no vertical flipping), 389 | True (flip with probability 0.5) or use a float 390 | value (probability) between 0.0 and 1.0. Default is False. 391 | scale_to_percent: Up to which percentage the images may be 392 | scaled/zoomed. The negative scaling is automatically derived 393 | from this value. A value of 1.1 allows scaling by any value 394 | between -10% and +10%. You may set min and max values yourself 395 | by using a tuple instead, like (1.1, 1.2) to scale between 396 | +10% and +20%. Default is 1.0 (no scaling). 397 | scale_axis_equally: Whether to always scale both axis (x and y) 398 | in the same way. If set to False, then e.g. the Augmenter 399 | might scale the x-axis by 20% and the y-axis by -5%. 400 | Default is False. 401 | rotation_deg: By how much the image may be rotated around its 402 | center (in degrees). The negative rotation will automatically 403 | be derived from this value. E.g. a value of 20 allows any 404 | rotation between -20 degrees and +20 degrees. You may set min 405 | and max values yourself by using a tuple instead, e.g. (5, 20) 406 | to rotate between +5 und +20 degrees. Default is 0 (no 407 | rotation). 408 | shear_deg: By how much the image may be sheared (in degrees). The 409 | negative value will automatically be derived from this value. 410 | E.g. a value of 20 allows any shear between -20 degrees and 411 | +20 degrees. You may set min and max values yourself by using a 412 | tuple instead, e.g. (5, 20) to shear between +5 und +20 413 | degrees. Default is 0 (no shear). 414 | translation_x_px: By up to how many pixels the image may be 415 | translated (moved) on the x-axis. The negative value will 416 | automatically be derived from this value. E.g. a value of +7 417 | allows any translation between -7 and +7 pixels on the x-axis. 
418 | You may set min and max values yourself by using a tuple 419 | instead, e.g. (5, 20) to translate between +5 und +20 pixels. 420 | Default is 0 (no translation on the x-axis). 421 | translation_y_px: See translation_x_px, just for the y-axis. 422 | transform_channels_equally: Whether to apply the exactly same 423 | transformations to each channel of an image (True). Setting 424 | it to False allows different transformations per channel, 425 | e.g. the red-channel might be rotated by +20 degrees, while 426 | the blue channel (of the same image) might be rotated 427 | by -5 degrees. If you don't have any channels (2D grayscale), 428 | you can simply ignore this setting. 429 | Default is True (transform all equally). 430 | """ 431 | self.img_width_px = img_width_px 432 | self.img_height_px = img_height_px 433 | self.channel_is_first_axis = channel_is_first_axis 434 | 435 | self.hflip_prob = 0.0 436 | # note: we have to check first for floats, otherwise "hflip == True" 437 | # will evaluate to true if hflip is 1.0. So chosing 1.0 (100%) would 438 | # result in hflip_prob to be set to 0.5 (50%). 
439 | if isinstance(hflip, float): 440 | assert hflip >= 0.0 and hflip <= 1.0 441 | self.hflip_prob = hflip 442 | elif hflip == True: 443 | self.hflip_prob = 0.5 444 | elif hflip == False: 445 | self.hflip_prob = 0.0 446 | else: 447 | raise Exception("Unexpected value for parameter 'hflip'.") 448 | 449 | self.vflip_prob = 0.0 450 | if isinstance(vflip, float): 451 | assert vflip >= 0.0 and vflip <= 1.0 452 | self.vflip_prob = vflip 453 | elif vflip == True: 454 | self.vflip_prob = 0.5 455 | elif vflip == False: 456 | self.vflip_prob = 0.0 457 | else: 458 | raise Exception("Unexpected value for parameter 'vflip'.") 459 | 460 | self.scale_to_percent = scale_to_percent 461 | self.scale_axis_equally = scale_axis_equally 462 | self.rotation_deg = rotation_deg 463 | self.shear_deg = shear_deg 464 | self.translation_x_px = translation_x_px 465 | self.translation_y_px = translation_y_px 466 | self.transform_channels_equally = transform_channels_equally 467 | self.cval = 0.0 468 | self.interpolation_order = 1 469 | self.pregenerated_matrices = None 470 | 471 | def pregenerate_matrices(self, nb_matrices, seed=None): 472 | """Pregenerate/cache augmentation matrices. 473 | 474 | If matrices are pregenerated, augment_batch() will reuse them on 475 | each call. The augmentations will not always be the same, 476 | as the order of the matrices will be randomized (when 477 | they are applied to the images). The requirement for that is though 478 | that you pregenerate enough of them (e.g. a couple thousand). 479 | 480 | Note that generating the augmentation matrices is usually fast 481 | and only starts to make sense if you process millions of small images 482 | or many tens of thousands of big images. 483 | 484 | Each call to this method results in pregenerating a new set of matrices, 485 | e.g. to replace a list of matrices that has been used often enough. 
486 | 487 | Calling this method with nb_matrices set to 0 will remove the 488 | pregenerated matrices and augment_batch() returns to its default 489 | behaviour of generating new matrices on each call. 490 | 491 | Args: 492 | nb_matrices: The number of matrices to pregenerate. E.g. a few 493 | thousand. If set to 0, the matrices will be generated again on 494 | each call of augment_batch(). 495 | seed: A random seed to use. 496 | """ 497 | assert nb_matrices >= 0 498 | if nb_matrices == 0: 499 | self.pregenerated_matrices = None 500 | else: 501 | matrices = create_aug_matrices(nb_matrices, 502 | self.img_width_px, 503 | self.img_height_px, 504 | scale_to_percent=self.scale_to_percent, 505 | scale_axis_equally=self.scale_axis_equally, 506 | rotation_deg=self.rotation_deg, 507 | shear_deg=self.shear_deg, 508 | translation_x_px=self.translation_x_px, 509 | translation_y_px=self.translation_y_px, 510 | seed=seed) 511 | self.pregenerated_matrices = matrices 512 | 513 | def augment_batch(self, images, seed=None): 514 | """Augments a batch of images. 515 | 516 | Applies all settings (rotation, shear, translation, ...) that 517 | have been chosen in the constructor. 518 | 519 | Args: 520 | images: Numpy array (dtype: uint8, i.e. values 0-255) with the images. 521 | Expected shape is either (image-index, height, width) for 522 | grayscale images or (image-index, channel, height, width) for 523 | images with channels (e.g. RGB) where the channel has the first 524 | index or (image-index, height, width, channel) for images with 525 | channels, where the channel is the last index. 526 | If your shape is (image-index, channel, width, height) then 527 | you must also set channel_is_first_axis=True in the constructor. 528 | seed: Seed to use for python's and numpy's random functions. 529 | Default is None (dont use a seed). 530 | 531 | Returns: 532 | Augmented images as numpy array of dtype float32 (i.e. values 533 | are between 0.0 and 1.0). 
534 | """ 535 | shape = images.shape 536 | nb_channels = 0 537 | if len(shape) == 3: 538 | # shape like (image_index, y-axis, x-axis) 539 | assert shape[1] == self.img_height_px 540 | assert shape[2] == self.img_width_px 541 | nb_channels = 1 542 | elif len(shape) == 4: 543 | if not self.channel_is_first_axis: 544 | # shape like (image-index, y-axis, x-axis, channel-index) 545 | assert shape[1] == self.img_height_px 546 | assert shape[2] == self.img_width_px 547 | nb_channels = shape[3] 548 | else: 549 | # shape like (image-index, channel-index, y-axis, x-axis) 550 | assert shape[2] == self.img_height_px 551 | assert shape[3] == self.img_width_px 552 | nb_channels = shape[1] 553 | else: 554 | msg = "Mismatch between images shape %s and " \ 555 | "predefined image width/height (%d/%d)." 556 | raise Exception(msg % (str(shape), self.img_width_px, self.img_height_px)) 557 | 558 | if seed: 559 | random.seed(seed) 560 | np.random.seed(seed) 561 | 562 | # -------------------------------- 563 | # horizontal and vertical flipping/mirroring 564 | # -------------------------------- 565 | # This should be done before applying the affine matrices, as otherwise 566 | # contents of image might already be rotated/translated out of the image. 567 | # It is done with numpy instead of the affine matrices, because 568 | # scikit-image doesn't offer a nice interface to add mirroring/flipping 569 | # to affine transformations. The numpy operations are O(1), so they 570 | # shouldn't have a noticeable effect on runtimes. They also won't suffer 571 | # from interpolation problems. 
572 | if self.hflip_prob > 0 or self.vflip_prob > 0: 573 | # TODO this currently ignores the setting in 574 | # transform_channels_equally and will instead always flip all 575 | # channels equally 576 | 577 | # if this is simply a view, then the input array gets flipped too 578 | # for some reason 579 | images_flipped = np.copy(images) 580 | #images_flipped = images.view() 581 | 582 | if len(shape) == 4 and self.channel_is_first_axis: 583 | # roll channel to the last axis 584 | # swapaxes doesnt work here, because 585 | # (image index, channel, y, x) 586 | # would be turned into 587 | # (image index, x, y, channel) 588 | # and y needs to come before x 589 | images_flipped = np.rollaxis(images_flipped, 1, 4) 590 | 591 | y_p = self.hflip_prob 592 | x_p = self.vflip_prob 593 | for i in range(images.shape[0]): 594 | if y_p > 0 and random.random() < y_p: 595 | images_flipped[i] = np.fliplr(images_flipped[i]) 596 | if x_p > 0 and random.random() < x_p: 597 | images_flipped[i] = np.flipud(images_flipped[i]) 598 | 599 | if len(shape) == 4 and self.channel_is_first_axis: 600 | # roll channel back to the second axis (index 1) 601 | images_flipped = np.rollaxis(images_flipped, 3, 1) 602 | images = images_flipped 603 | 604 | # -------------------------------- 605 | # if no augmentation has been chosen, stop early 606 | # for improved performance (evade applying matrices) 607 | # -------------------------------- 608 | if self.pregenerated_matrices is None \ 609 | and self.scale_to_percent == 1.0 and self.rotation_deg == 0 \ 610 | and self.shear_deg == 0 \ 611 | and self.translation_x_px == 0 and self.translation_y_px == 0: 612 | return np.array(images, dtype=np.float32) / 255 613 | 614 | # -------------------------------- 615 | # generate transformation matrices 616 | # -------------------------------- 617 | if self.pregenerated_matrices is not None: 618 | matrices = self.pregenerated_matrices 619 | else: 620 | # estimate the number of matrices required 621 | if 
self.transform_channels_equally: 622 | nb_matrices = shape[0] 623 | else: 624 | nb_matrices = shape[0] * nb_channels 625 | 626 | # generate matrices 627 | matrices = create_aug_matrices(nb_matrices, 628 | self.img_width_px, 629 | self.img_height_px, 630 | scale_to_percent=self.scale_to_percent, 631 | scale_axis_equally=self.scale_axis_equally, 632 | rotation_deg=self.rotation_deg, 633 | shear_deg=self.shear_deg, 634 | translation_x_px=self.translation_x_px, 635 | translation_y_px=self.translation_y_px, 636 | seed=seed) 637 | 638 | # -------------------------------- 639 | # apply transformation matrices (i.e. augment images) 640 | # -------------------------------- 641 | return apply_aug_matrices(images, matrices, 642 | transform_channels_equally=self.transform_channels_equally, 643 | channel_is_first_axis=self.channel_is_first_axis, 644 | cval=self.cval, interpolation_order=self.interpolation_order, 645 | seed=seed) 646 | 647 | def plot_image(self, image, nb_repeat=40, show_plot=True): 648 | """Plot augmented variations of an image. 649 | 650 | This method takes an image and plots it by default in 40 differently 651 | augmented versions. 652 | 653 | This method is intended to visualize the strength of your chosen 654 | augmentations (so for debugging). 655 | 656 | Args: 657 | image: The image to plot. 658 | nb_repeat: How often to plot the image. Each time it is plotted, 659 | the chosen augmentation will be different. (Default: 40). 660 | show_plot: Whether to show the plot. False makes sense if you 661 | don't have a graphical user interface on the machine. 662 | (Default: True) 663 | 664 | Returns: 665 | The figure of the plot. 666 | Use figure.savefig() to save the image. 
667 | """ 668 | if len(image.shape) == 2: 669 | images = np.resize(image, (nb_repeat, image.shape[0], image.shape[1])) 670 | else: 671 | images = np.resize(image, (nb_repeat, image.shape[0], image.shape[1], 672 | image.shape[2])) 673 | return self.plot_images(images, True, show_plot=show_plot) 674 | 675 | def plot_images(self, images, augment, show_plot=True, figure=None): 676 | """Plot augmented variations of images. 677 | 678 | The images will all be shown in the same window. 679 | It is recommended to not plot too many of them (i.e. stay below 100). 680 | 681 | This method is intended to visualize the strength of your chosen 682 | augmentations (so for debugging). 683 | 684 | Args: 685 | images: A numpy array of images. See augment_batch(). 686 | augment: Whether to augment the images (True) or just display 687 | them in the way they are (False). 688 | show_plot: Whether to show the plot. False makes sense if you 689 | don't have a graphical user interface on the machine. 690 | (Default: True) 691 | figure: The figure of the plot in which to draw the images. 692 | Provide the return value of this function (from a prior call) 693 | to draw in the same plot window again. Chosing 'None' will 694 | create a new figure. (Default is None.) 695 | 696 | Returns: 697 | The figure of the plot. 698 | Use figure.savefig() to save the image. 
699 | """ 700 | import matplotlib.pyplot as plt 701 | import matplotlib.cm as cm 702 | 703 | if augment: 704 | images = self.augment_batch(images) 705 | 706 | # (Lists of) Grayscale images have the shape (image index, y, x) 707 | # Multi-Channel images therefore must have 4 or more axes here 708 | if len(images.shape) >= 4: 709 | # The color-channel is expected to be the last axis by matplotlib 710 | # therefore exchange the axes, if its the first one here 711 | if self.channel_is_first_axis: 712 | images = np.rollaxis(images, 1, 4) 713 | 714 | nb_cols = 10 715 | nb_rows = 1 + int(images.shape[0] / nb_cols) 716 | if figure is not None: 717 | fig = figure 718 | plt.figure(fig.number) 719 | fig.clear() 720 | else: 721 | fig = plt.figure(figsize=(10, 10)) 722 | 723 | for i, image in enumerate(images): 724 | image = images[i] 725 | 726 | plot_number = i + 1 727 | ax = fig.add_subplot(nb_rows, nb_cols, plot_number, xticklabels=[], 728 | yticklabels=[]) 729 | ax.set_axis_off() 730 | # "cmap" should restrict the color map to grayscale, but strangely 731 | # also works well with color images 732 | imgplot = plt.imshow(image, cmap=cm.Greys_r, aspect="equal") 733 | 734 | # not showing the plot might be useful e.g. 
on clusters 735 | if show_plot: 736 | plt.show() 737 | 738 | return fig 739 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2015 Alexander Jung 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # About 2 | 3 | Simple convolutional neural network to detect cat bounding boxes in images. 4 | The system is restricted to one bounding box per image, which is localized using regression (i.e. directly predicting the bounding box coordinates). 5 | The model consists of 7 convolutional layers and 2 fully connected layers (including output layer). 
6 | 7 | # Dependencies 8 | 9 | * python 2.7 (only tested with that version) 10 | * keras (tested in v1.06) 11 | * scipy 12 | * numpy 13 | * scikit-image 14 | 15 | # Usage 16 | 17 | * Download the [10k cats dataset](https://web.archive.org/web/20150520175645/http://137.189.35.203/WebUI/CatDatabase/catData.html) and extract it, e.g. into directory `/foo/bar/10k-cats`. That directory should contain the subdirectories `CAT_00`, `CAT_01`, etc. 18 | * Train the model using `train_convnet.py --dataset="/foo/bar/10k-cats"`. 19 | * Apply the model using `apply_convnet.py --images="/foo/bar/directory-with-cat-images"`. 20 | 21 | # Images 22 | 23 | Example results: 24 | 25 | ![Located cat face](images/1-39-125521249_b1318298ec_n.jpg?raw=true "Located cat face") 26 | ![Located cat face](images/1-56-180653960_21cf28e0b3_n.jpg?raw=true "Located cat face") 27 | ![Located cat face](images/1-61-213767259_11c8550a0e_n.jpg?raw=true "Located cat face") 28 | ![Located cat face](images/1-266-19922494159_78303f8b16_n.jpg?raw=true "Located cat face") 29 | ![Located cat face](images/3-2287-2088100404_c0112197e3_n.jpg?raw=true "Located cat face") 30 | ![Located cat face](images/3-2831-10902603864_4993c4aa1a_n.jpg?raw=true "Located cat face") 31 | -------------------------------------------------------------------------------- /apply_convnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | File to apply the trained ConvNet model to a number of images. 4 | It will use the ConvNet to locate cat faces in the images and mark them. 5 | It is expected that each image contains exactly one cat (i.e. a face will be 6 | extracted out of each image, even if there is no cat). 7 | If an image contains multiple cats, only one face will be extracted.
8 | 9 | Usage: 10 | python train.py 11 | python apply_convnet.py 12 | """ 13 | from __future__ import division, print_function 14 | from dataset import Dataset 15 | import os 16 | import re 17 | import numpy as np 18 | import argparse 19 | import random 20 | from scipy import ndimage 21 | from scipy import misc 22 | from train_convnet import MODEL_IMAGE_HEIGHT, MODEL_IMAGE_WIDTH, \ 23 | BATCH_SIZE, SAVE_WEIGHTS_CHECKPOINT_FILEPATH, \ 24 | create_model, draw_predicted_rectangle 25 | from keras.optimizers import Adam 26 | 27 | np.random.seed(42) 28 | random.seed(42) 29 | 30 | CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) 31 | OUT_SCALE = 64 # scale (height, width) of each saved image 32 | 33 | def main(): 34 | """ 35 | Main function. 36 | Does the following step by step: 37 | * Load images (from which to extract cat faces) from SOURCE_DIR 38 | * Initialize model (as trained via train_convnet.py) 39 | * Loads and prepares images for the model. 40 | * Uses trained model to predict locations of cat faces. 41 | * Projects face coordinates onto original images 42 | * Marks faces in original images. 43 | * Saves each marked image. 
44 | """ 45 | parser = argparse.ArgumentParser(description="Apply a trained cat face locator " \ 46 | "model to images.") 47 | parser.add_argument("--images", required=True, help="Directory containing images to analyze.") 48 | parser.add_argument("--weights", required=False, default=SAVE_WEIGHTS_CHECKPOINT_FILEPATH, 49 | help="Filepath to the weights of the model.") 50 | parser.add_argument("--output", required=False, default=os.path.join(CURRENT_DIR, "predictions"), 51 | help="Filepath to the directory in which to save the output.") 52 | args = parser.parse_args() 53 | 54 | # load images 55 | filepaths = get_image_filepaths([args.images]) 56 | filenames = [os.path.basename(fp) for fp in filepaths] # will be used during saving 57 | nb_images = len(filepaths) 58 | X = np.zeros((nb_images, MODEL_IMAGE_HEIGHT, MODEL_IMAGE_WIDTH, 3), dtype=np.float32) 59 | for i, fp in enumerate(filepaths): 60 | image = ndimage.imread(fp, mode="RGB") 61 | image = misc.imresize(image, (MODEL_IMAGE_HEIGHT, MODEL_IMAGE_WIDTH)) 62 | X[i] = image / 255.0 63 | X = np.rollaxis(X, 3, 1) 64 | 65 | # assure that dataset is not empty 66 | print("Found %d images..." % (X.shape[0],)) 67 | assert X.shape[0] > 0, "The dataset appears to be empty (shape of X: %s)." % (X.shape,) 68 | 69 | # create model 70 | model = create_model(MODEL_IMAGE_HEIGHT, MODEL_IMAGE_WIDTH, "mse", Adam()) 71 | model.load_weights(args.weights) 72 | 73 | # predict positions of faces 74 | preds = model.predict(X, batch_size=BATCH_SIZE) 75 | 76 | # Draw predicted rectangles and save 77 | print("Saving images...") 78 | for idx, (y, x, half_height, half_width) in enumerate(preds): 79 | img = draw_predicted_rectangle(X[idx], y, x, half_height, half_width) 80 | filepath = os.path.join(args.output, filenames[idx]) 81 | misc.imsave(filepath, img) 82 | 83 | def get_image_filepaths(dirs): 84 | """Loads filepaths of images from dataset. 
85 | Args: 86 | dirs List of directories as strings 87 | Returns: 88 | List of strings (filepaths)""" 89 | result_img = [] 90 | for fp_dir in dirs: 91 | fps = [f for f in os.listdir(fp_dir) if os.path.isfile(os.path.join(fp_dir, f))] 92 | fps = [os.path.join(fp_dir, f) for f in fps] 93 | fps_img = [fp for fp in fps if re.match(r".*\.jpg$", fp)] 94 | result_img.extend(fps_img) 95 | return result_img 96 | 97 | if __name__ == "__main__": 98 | main() 99 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper classes to handle the process of normalizing and augmenting the 10k cats dataset. 3 | 4 | Classes: 5 | Dataset Handles loading of cat images with keypoints (e.g. eyes, ears) 6 | ImageWithKeypoints Container for one example images with its keypoints, 7 | supports e.g. resizing the image, showing it (in a window), drawing 8 | points/rectangles on it or augmenting it. 9 | Keypoints Helper class to handle the keypoints of one image, 10 | supports e.g. shifting/translating them by N pixels, warping them 11 | via an affine transformation matrix, flipping them or calculating 12 | a face rectangle from them. 13 | Point2D A class to encapsulate a (y, x) coordinate. 14 | PointsList A list of Point2D 15 | Rectangle A rectangle in an image, used for the face rectangles. 16 | 17 | Note that coordinates are usually provided as (y, x) not (x, y). 
18 | """ 19 | from __future__ import print_function, division 20 | import os 21 | import re 22 | import math 23 | import random 24 | from scipy import misc 25 | import numpy as np 26 | from ImageAugmenter import create_aug_matrices 27 | from skimage import transform as tf 28 | from skimage import color 29 | from skimage import exposure 30 | 31 | WARP_KEYPOINTS_MODE = "constant" 32 | WARP_KEYPOINTS_CVAL = 0.0 33 | WARP_KEYPOINTS_INTERPOLATION_ORDER = 1 34 | 35 | class Dataset(object): 36 | """Helper class to load images with facial keypoints.""" 37 | def __init__(self, dirs): 38 | """Initialize the class. 39 | Args: 40 | dirs A list of directories (filepaths) to load from.""" 41 | self.dirs = dirs 42 | self.fps = self.get_image_filepaths() 43 | 44 | def get_images(self, start_at=None, count=None): 45 | """Load images with keypoints. 46 | Args: 47 | start_at Index of first image to load. 48 | count Maximum number of images to load. 49 | Returns: 50 | List of ImageWithKeypoints (generator)""" 51 | start_at = 0 if start_at is None else start_at 52 | end_at = len(self.fps) if count is None else start_at+count 53 | for fp in self.fps[start_at:end_at]: 54 | image = misc.imread(fp) 55 | keypoints = Keypoints(self.get_keypoints(fp, image.shape[0], image.shape[1])) 56 | yield ImageWithKeypoints(image, keypoints) 57 | 58 | def get_image_filepaths(self): 59 | """Loads filepaths of example images. 60 | Returns: 61 | List of strings (filepaths)""" 62 | result_img = [] 63 | for fp_dir in self.dirs: 64 | fps = [f for f in os.listdir(fp_dir) if os.path.isfile(os.path.join(fp_dir, f))] 65 | fps = [os.path.join(fp_dir, f) for f in fps] 66 | fps_img = [fp for fp in fps if re.match(r".*\.jpg$", fp)] 67 | fps_img = [fp for fp in fps if os.path.isfile("%s.cat" % (fp,))] 68 | result_img.extend(fps_img) 69 | 70 | return result_img 71 | 72 | def get_keypoints(self, image_filepath, image_height, image_width): 73 | """Loads the keypoints of one image. 
74 | Args: 75 | image_filepath Filepath of the image for which to load keypoints. 76 | image_height Height of the image. 77 | image_width Width of the image. 78 | Returns: 79 | Numpy array of shape (19,)""" 80 | fp_keypoints = "%s.cat" % (image_filepath,) 81 | if not os.path.isfile(fp_keypoints): 82 | raise Exception("Could not find keypoint coordinates for image '%s'." \ 83 | % (image_filepath,)) 84 | else: 85 | coords_raw = open(fp_keypoints, "r").readlines()[0].strip().split(" ") 86 | coords_raw = [abs(int(coord)) for coord in coords_raw] 87 | keypoints_arr = np.zeros((9*2,), dtype=np.uint16) 88 | for i in range(1, len(coords_raw), 2): # first element is the number of coords 89 | y = clip(0, coords_raw[i+1], image_height-1) 90 | x = clip(0, coords_raw[i], image_width-1) 91 | keypoints_arr[(i-1)] = y 92 | keypoints_arr[(i-1) + 1] = x 93 | return keypoints_arr 94 | 95 | class ImageWithKeypoints(object): 96 | """Container for an example image and its keypoints.""" 97 | def __init__(self, image_arr, keypoints): 98 | """Instantiate an object. 99 | Args: 100 | image_arr Numpy array of the image, shape (height, width, channels) 101 | keypoints Keypoints object""" 102 | assert len(image_arr.shape) == 3 103 | assert image_arr.shape[2] == 3 104 | self.image_arr = image_arr 105 | self.keypoints = keypoints 106 | 107 | def copy(self): 108 | """Copy the object. 109 | Returns: ImageWithKeypoints""" 110 | return ImageWithKeypoints(np.copy(self.image_arr), self.keypoints.copy()) 111 | 112 | def get_height(self): 113 | """Get the image height. 114 | Returns: height, integer""" 115 | return self.image_arr.shape[0] 116 | 117 | def get_width(self): 118 | """Get the image width. 119 | Returns: width, integer""" 120 | return self.image_arr.shape[1] 121 | 122 | def get_center(self): 123 | """Get the center of the image. 
124 | Returns: Point2D""" 125 | y, x = self.get_height()/2, self.get_width()/2 126 | return Point2D(y=int(y), x=int(x)) 127 | 128 | def resize(self, new_height, new_width): 129 | """Resize the image to given height and width. 130 | Args: 131 | new_height Height to resize to. 132 | new_width Width to resize to.""" 133 | self.keypoints.normalize(self) 134 | # unclear in scipy doc if (new_height, new_width) or (new_width, new_height) is correct 135 | #print(self.image_arr.shape) 136 | #self.image_arr = misc.imresize(np.rollaxis(self.image_arr, 2, 0), (new_height, new_width)) 137 | self.image_arr = misc.imresize(self.image_arr, (new_height, new_width)) 138 | #self.image_arr = np.rollaxis(self.image_arr, 0, 3) 139 | #print(self.image_arr.shape) 140 | self.keypoints.unnormalize(self) 141 | 142 | def grayscale(self): 143 | """Converts the image to grayscale.""" 144 | self.image_arr = color.rgb2gray(self.image_arr) 145 | 146 | def equalize(self): 147 | """Perform adaptive histogram equalization.""" 148 | self.image_arr = exposure.equalize_adapthist(self.image_arr, clip_limit=0.03) 149 | self.image_arr = self.image_arr * 256 150 | self.image_arr = np.clip(self.image_arr, 0, 255) 151 | self.image_arr = self.image_arr.astype(np.uint8) 152 | 153 | def pad(self, nb_pixels, mode="median"): 154 | """Adds in-place N pixels to the sides of the image. 155 | Args: 156 | nb_pixels Number of pixels 157 | mode Padding mode for numpy.pad. 
158 | """ 159 | nb_top = nb_pixels 160 | nb_bottom = nb_pixels 161 | nb_left = nb_pixels 162 | nb_right = nb_pixels 163 | if len(self.image_arr.shape) == 2: 164 | self.image_arr = np.pad(self.image_arr, ((nb_top, nb_bottom), \ 165 | (nb_left, nb_right)), \ 166 | mode=mode) 167 | else: 168 | self.image_arr = np.pad(self.image_arr, ((nb_top, nb_bottom), \ 169 | (nb_left, nb_right), \ 170 | (0, 0)), \ 171 | mode=mode) 172 | self.keypoints.shift_y(nb_top, self) 173 | self.keypoints.shift_x(nb_left, self) 174 | 175 | def unpad(self, nb_pixels): 176 | """Removes padding around the image. Updates keypoints accordingly. 177 | Args: nb_pixels: Number of pixels of padding to remove""" 178 | self.image_arr = self.image_arr[nb_pixels:self.get_height()-nb_pixels, nb_pixels:self.get_width()-nb_pixels, ...] 179 | self.keypoints.shift_y(-nb_pixels, self) 180 | self.keypoints.shift_x(-nb_pixels, self) 181 | 182 | def remove_rotation(self): 183 | """Removes the image's rotation by aligning its eyeline parallel to the x axis.""" 184 | angle = math.radians(self.keypoints.get_angle_between_eyes(normalize=False)) 185 | 186 | # move eyes center to top left of image 187 | eyes_center = self.keypoints.get_eyes_center() 188 | img_center = self.get_center() 189 | matrix_to_topleft = tf.SimilarityTransform(translation=[-eyes_center.x, -eyes_center.y]) 190 | 191 | # rotate the image around the top left corner by -$angle degrees 192 | matrix_transforms = tf.AffineTransform(rotation=-angle) 193 | 194 | # move the face to the center of the image 195 | # this protects against parts of the face leaving the image (because of the rotation) 196 | matrix_to_center = tf.SimilarityTransform(translation=[img_center.x, img_center.y]) 197 | 198 | # combine to one affine transformation 199 | matrix = matrix_to_topleft + matrix_transforms + matrix_to_center 200 | matrix = matrix.inverse 201 | 202 | # apply transformations 203 | new_image = tf.warp(self.image_arr, matrix, mode="constant") 204 | new_image = 
np.array(new_image * 255, dtype=np.uint8) 205 | self.image_arr = new_image 206 | 207 | # create new image with N channels for N coordinates 208 | # mark each coordinate's pixel in the respective channel 209 | # rotate 210 | # read out new coordinates (after rotation) 211 | self.keypoints.warp(self, matrix) 212 | 213 | if self.keypoints.mouth().y < self.keypoints.left_eye().y: 214 | print("Warning: mouth is above left eye") 215 | # unclear where this problem comes from, fix it with flipping for now 216 | #self.image_arr = np.flipud(self.image_arr) 217 | #self.keypoints.flipud(self) 218 | if self.keypoints.right_eye().x < self.keypoints.left_eye().x: 219 | print("Warning: right eye is left, left eye is right") 220 | 221 | def extract_rectangle(self, rect, pad): 222 | """Extracts a rectangle within the image as a new ImageWithKeypoints. 223 | Args: 224 | rect Rectangle object 225 | pad Padding in pixels around the rectangle 226 | Returns: 227 | ImageWithKeypoints""" 228 | pad_black_top = 0 229 | pad_black_right = 0 230 | pad_black_bottom = 0 231 | pad_black_left = 0 232 | 233 | if rect.tl_y - pad < 0: 234 | pad_black_top = abs(rect.tl_y - pad) 235 | if rect.tl_x - pad < 0: 236 | pad_black_left = abs(rect.tl_x - pad) 237 | if rect.br_y + pad > (self.get_height() - 1): 238 | pad_black_bottom = (rect.br_y + pad) - (self.get_height() - 1) 239 | if rect.br_x + pad > (self.get_width() - 1): 240 | pad_black_right = (rect.br_x + pad) - (self.get_width() - 1) 241 | 242 | tl_y = clip(0, rect.tl_y - pad, self.get_height()-1) 243 | tl_x = clip(0, rect.tl_x - pad, self.get_width()-1) 244 | br_y = clip(0, rect.br_y + pad, self.get_height()-1) 245 | br_x = clip(0, rect.br_x + pad, self.get_width()-1) 246 | 247 | img_rect = self.image_arr[tl_y:br_y+1, tl_x:br_x+1, ...] 
248 | keypoints = self.keypoints.copy() 249 | img = ImageWithKeypoints(img_rect, keypoints) 250 | keypoints.shift_y(-tl_y, img) 251 | keypoints.shift_x(-tl_x, img) 252 | 253 | img.image_arr = np.pad(img.image_arr, ((pad_black_top, pad_black_bottom), \ 254 | (pad_black_left, pad_black_right), \ 255 | (0, 0)), \ 256 | mode="median") 257 | keypoints.shift_y(pad_black_top, img) 258 | keypoints.shift_x(pad_black_left, img) 259 | 260 | return img 261 | 262 | def extract_face(self, pad): 263 | """Extracts the cat face within the image. 264 | Args: 265 | pad Padding in pixels around the face. 266 | Returns: 267 | ImageWithKeypoints""" 268 | face_rect = self.keypoints.get_rectangle(self) 269 | return self.extract_rectangle(face_rect, pad) 270 | 271 | def augment(self, n, hflip=False, vflip=False, scale_to_percent=1.0, scale_axis_equally=True, 272 | rotation_deg=0, shear_deg=0, translation_x_px=0, translation_y_px=0, 273 | brightness_change=0.0, noise_mean=0.0, noise_std=0.0): 274 | """Generates randomly augmented versions of the image. 275 | Also augments the keypoints accordingly. 276 | 277 | Args: 278 | n Number of augmentations to generate. 279 | hflip Allow horizontal flipping (yes/no). 280 | vflip Allow vertical flipping (yes/no) 281 | scale_to_percent How much scaling/zooming to allow. Values are around 1.0. 282 | E.g. 1.1 is -10% to +10% 283 | E.g. (0.7, 1.05) is -30% to 5%. 284 | scale_axis_equally Whether to enforce equal scaling of x and y axis. 285 | rotation_deg How much rotation to allow. E.g. 5 is -5 degrees to +5 degrees. 286 | shear_deg How much shearing to allow. 287 | translation_x_px How many pixels of translation along the x axis to allow. 288 | translation_y_px How many pixels of translation along the y axis to allow. 289 | brightness_change How much change in brightness to allow. Values are around 0.0. 290 | E.g. 0.2 is -20% to +20%. 291 | noise_mean Mean value of gaussian noise to add. 292 | noise_std Standard deviation of gaussian noise to add. 
293 | Returns: 294 | List of ImageWithKeypoints 295 | """ 296 | assert n >= 0 297 | result = [] 298 | if n == 0: 299 | return result 300 | 301 | matrices = create_aug_matrices(n, 302 | img_width_px=self.get_width(), 303 | img_height_px=self.get_height(), 304 | scale_to_percent=scale_to_percent, 305 | scale_axis_equally=scale_axis_equally, 306 | rotation_deg=rotation_deg, 307 | shear_deg=shear_deg, 308 | translation_x_px=translation_x_px, 309 | translation_y_px=translation_y_px) 310 | for i in range(n): 311 | img = self.copy() 312 | matrix = matrices[i] 313 | 314 | # random horizontal / vertical flip 315 | if hflip and random.random() > 0.5: 316 | img.image_arr = np.fliplr(img.image_arr) 317 | img.keypoints.fliplr(img) 318 | if vflip and random.random() > 0.5: 319 | img.image_arr = np.flipud(img.image_arr) 320 | img.keypoints.flipud(img) 321 | 322 | # random brightness adjustment 323 | by_percent = random.uniform(1.0 - brightness_change, 1.0 + brightness_change) 324 | img.image_arr = img.image_arr * by_percent 325 | 326 | # gaussian noise 327 | # numpy requires a std above 0 328 | if noise_std > 0: 329 | img.image_arr = img.image_arr \ 330 | + (255 * np.random.normal(noise_mean, noise_std, 331 | (img.image_arr.shape))) 332 | 333 | # clip to 0-255 334 | img.image_arr = np.clip(img.image_arr, 0, 255).astype(np.uint8) 335 | 336 | arr = tf.warp(img.image_arr, matrix, mode="constant") # projects to float 0-1 337 | img.image_arr = np.array(arr * 255, dtype=np.uint8) 338 | img.keypoints.warp(img, matrix) 339 | result.append(img) 340 | 341 | return result 342 | 343 | def draw_rectangle(self, rect, color_tuple=None): 344 | """Draw a rectangle with given color onto the image. 345 | Args: 346 | rect The rectangle object 347 | color_tuple Color of the rectangle, e.g. 
(255, 0, 0) for red.""" 348 | self.draw_rectangles([rect], color_tuple=color_tuple) 349 | 350 | def draw_rectangles(self, rects, color_tuple=None): 351 | """Draw several rectangles onto the image.""" 352 | if color_tuple is None: 353 | color_tuple = (255, 0, 0) 354 | 355 | for rect in rects: 356 | for x in range(rect.tl_x, rect.br_x+1): 357 | self.image_arr[rect.tl_y, x, ...] = color_tuple 358 | self.image_arr[rect.br_y, x, ...] = color_tuple 359 | for y in range(rect.tl_y, rect.br_y+1): 360 | self.image_arr[y, rect.tl_x, ...] = color_tuple 361 | self.image_arr[y, rect.br_x, ...] = color_tuple 362 | 363 | def draw_face_rectangles(self): 364 | """Draw all face rectangles onto the image according to the 5 existing methods. 365 | Colors: 366 | Green = Method 0 367 | Blue = Method 1 368 | Red = Method 2 369 | Yellow = Method 3 370 | Cyan = Method 4 371 | """ 372 | self.draw_rectangle(self.keypoints.get_rectangle(self, method=0), color_tuple=(0, 255, 0)) 373 | self.draw_rectangle(self.keypoints.get_rectangle(self, method=1), color_tuple=(0, 0, 255)) 374 | self.draw_rectangle(self.keypoints.get_rectangle(self, method=2), color_tuple=(255, 0, 0)) 375 | self.draw_rectangle(self.keypoints.get_rectangle(self, method=3), color_tuple=(255, 255, 0)) 376 | self.draw_rectangle(self.keypoints.get_rectangle(self, method=4), color_tuple=(0, 255, 255)) 377 | 378 | def draw_point(self, pnt, color_tuple=None): 379 | """Draw a point onto the image.""" 380 | self.draw_point([pnt], color_tuple=color_tuple) 381 | 382 | def draw_points(self, pnts, color_tuple=None): 383 | """Draw several points onto the image.""" 384 | if color_tuple is None: 385 | color_tuple = (255, 0, 0) 386 | 387 | height = self.get_height() 388 | width = self.get_width() 389 | 390 | for pnt in pnts: 391 | self.image_arr[pnt.y, clip(0, pnt.x-1, width-1) \ 392 | :clip(0, pnt.x+2, width-1), ...] = (255, 0, 0) 393 | self.image_arr[clip(0, pnt.y-1, height-1) \ 394 | :clip(0, pnt.y+2, height-1), pnt.x, ...] 
= (255, 0, 0) 395 | 396 | def draw_keypoints(self, color_tuple=None): 397 | """Draw all image's keypoints as crosses.""" 398 | self.draw_points(self.keypoints.get_points(), color_tuple=color_tuple) 399 | 400 | def show(self): 401 | """Show the image in a window.""" 402 | misc.imshow(self.image_arr) 403 | 404 | def to_array(self): 405 | """Return the image content's numpy array. 406 | Returns: numpy array of shape (height, width, channels)""" 407 | return self.image_arr 408 | 409 | class Keypoints(object): 410 | """Helper class to encapsulate the facial keypoints. 411 | 412 | Existing keypoints: 413 | point number | meaning 414 | 1 = left eye 415 | 2 = right eye 416 | 3 = mouth 417 | 4 = left ear 1 (left side start) 418 | 5 = left ear 2 (tip) 419 | 6 = left ear 3 (right side start) 420 | 7 = right ear 1 (left side start) 421 | 8 = right ear 2 (tip) 422 | 9 = right ear 3 (right side start) 423 | (left/right when looking at cat (not from the perspective of the cat)) 424 | 425 | Rough outline on image (frontal perspective on cat): 426 | 427 | 5 8 428 | 6 7 429 | 4 9 430 | 431 | 1 2 432 | 433 | 3 434 | 435 | """ 436 | def __init__(self, keypoints_arr, is_normalized=False): 437 | """Instantiate a new keypoints object. 438 | Args: 439 | keypoints_arr Numpy array of the keypoints of shape (18,) 440 | is_normalized Whether the keypoints are in the range 0-1 (true) or have integer 441 | pixel values. 442 | """ 443 | assert len(keypoints_arr.shape) == 1 444 | assert len(keypoints_arr) == 9*2 445 | if is_normalized: 446 | assert keypoints_arr.dtype == np.float32 and all([0 <= v <= 1.0 for v in keypoints_arr]) 447 | else: 448 | assert keypoints_arr.dtype == np.uint16 and all([v >= 0 for v in keypoints_arr]) 449 | self.keypoints_arr = keypoints_arr 450 | self.is_normalized = is_normalized 451 | 452 | def copy(self): 453 | """Creates a copy of the keypoints object. 
454 | Returns: Keypoints""" 455 | return Keypoints(np.copy(self.keypoints_arr)) 456 | 457 | def normalize(self, image): 458 | """Normalizes the keypoint value to 0-1 floats with respect to the given image's dimensions. 459 | Args: 460 | image ImageWithKeypoints""" 461 | assert not self.is_normalized 462 | height = image.get_height() 463 | width = image.get_width() 464 | self.keypoints_arr = self.keypoints_arr.astype(np.float32) 465 | for i in range(0, len(self.keypoints_arr), 2): 466 | self.keypoints_arr[i] = self.keypoints_arr[i] / height 467 | self.keypoints_arr[i+1] = self.keypoints_arr[i+1] / width 468 | self.is_normalized = True 469 | 470 | def unnormalize(self, image): 471 | """Converts back from 0-1 floats to integer pixel values with respect to the given 472 | image's dimensions. 473 | Args: 474 | image ImageWithKeypoints""" 475 | assert self.is_normalized 476 | height = image.get_height() 477 | width = image.get_width() 478 | for i in range(0, len(self.keypoints_arr), 2): 479 | self.keypoints_arr[i] = self.keypoints_arr[i] * height 480 | self.keypoints_arr[i+1] = self.keypoints_arr[i+1] * width 481 | self.keypoints_arr = self.keypoints_arr.astype(np.uint16) 482 | self.is_normalized = False 483 | 484 | def left_eye(self): 485 | """Returns the coordinates of the left eye as Point2D.""" 486 | return self.get_nth_keypoint(0) 487 | 488 | def right_eye(self): 489 | """Returns the coordinates of the right eye as Point2D.""" 490 | return self.get_nth_keypoint(1) 491 | 492 | def mouth(self): 493 | """Returns the coordinates of the mouth eye as Point2D.""" 494 | return self.get_nth_keypoint(2) 495 | 496 | def get_nth_keypoint(self, nth): 497 | """Returns the coordinates of the n-th (starting with 0) keypoint as Point2D.""" 498 | y = self.keypoints_arr[nth*2] 499 | x = self.keypoints_arr[nth*2 + 1] 500 | if self.is_normalized: 501 | y = float(y) 502 | x = float(x) 503 | else: 504 | y = int(y) 505 | x = int(x) 506 | return Point2D(y=y, x=x) 507 | 508 | def 
get_face_center(self): 509 | """Returns the coordinates of the face center as Point2D.""" 510 | face_center_x = (self.left_eye().x + self.right_eye().x + self.mouth().x) / 3 511 | face_center_y = (self.left_eye().y + self.right_eye().y + self.mouth().y) / 3 512 | face_center = Point2D(y=int(face_center_y), x=int(face_center_x)) 513 | return face_center 514 | 515 | def get_eyes_center(self): 516 | """Returns the coordinates of center between the eyes as Point2D.""" 517 | x = (self.left_eye().x + self.right_eye().x) / 2 518 | y = (self.left_eye().y + self.right_eye().y) / 2 519 | return Point2D(y=int(y), x=int(x)) 520 | 521 | def get_angle_between_eyes(self, normalize): 522 | """Returns with angle of the eyeline with respect to the x axis in degrees. 523 | E.g. a value of -5 indicates that the face is rotated by 5 degrees counter clock wise. 524 | Args: 525 | normalize Whether to normalize the value to the range of -1 (-180) to +1 (+180). 526 | Returns: 527 | Angle in degrees relative to x axis""" 528 | left_eye = self.left_eye().to_array() 529 | right_eye = self.right_eye().to_array() 530 | # conversion to int is here necessary, otherwise eyes_vector cant have negative values 531 | eyes_vector = right_eye.astype(np.int) - left_eye.astype(np.int) 532 | x_axis_vector = np.array([0, 1]) 533 | angle = angle_between(x_axis_vector, eyes_vector) 534 | angle_deg = math.degrees(angle) 535 | 536 | assert -180 <= angle_deg <= 180, angle_deg 537 | if normalize: 538 | return angle_deg / 180 539 | else: 540 | return angle_deg 541 | 542 | def get_points(self): 543 | """Returns all facial keypoints as Point2D-s. 
        Returns: List of Point2D."""
        # NOTE(review): these lines complete a method whose ``def`` line lies
        # above this chunk; they build the Point2D list it returns.
        result = []
        for i in range(0, len(self.keypoints_arr)//2):
            result.append(self.get_nth_keypoint(i))
        return result

    def get_min_x(self):
        """Returns the minimum x value among all facial keypoints."""
        return min([point.x for point in self.get_points()])

    def get_min_y(self):
        """Returns the minimum y value among all facial keypoints."""
        return min([point.y for point in self.get_points()])

    def get_max_x(self):
        """Returns the maximum x value among all facial keypoints."""
        return max([point.x for point in self.get_points()])

    def get_max_y(self):
        """Returns the maximum y value among all facial keypoints."""
        return max([point.y for point in self.get_points()])

    def shift_x(self, n_pixels, image):
        """Shifts all keypoints by N pixels on the x axis.

        ``keypoints_arr`` stores coordinates as interleaved (y, x) pairs,
        so index ``i+1`` of each pair is the x coordinate.

        Args:
            n_pixels Shift by that number of pixels (may be negative).
            image Image with maximum dimensions; shifted values are clipped
                to [0, image.width-1]."""
        for i in range(0, len(self.keypoints_arr), 2):
            new_val = int(self.keypoints_arr[i+1]) + n_pixels
            new_val = clip(0, new_val, image.get_width()-1)
            self.keypoints_arr[i+1] = new_val

    def shift_y(self, n_pixels, image):
        """Shifts all keypoints by N pixels on the y axis.

        Index ``i`` of each (y, x) pair in ``keypoints_arr`` is the
        y coordinate.

        Args:
            n_pixels Shift by that number of pixels (may be negative).
            image Image with maximum dimensions; shifted values are clipped
                to [0, image.height-1]."""
        for i in range(0, len(self.keypoints_arr), 2):
            new_val = int(self.keypoints_arr[i]) + n_pixels
            new_val = clip(0, new_val, image.get_height()-1)
            self.keypoints_arr[i] = new_val

    def warp(self, image, matrix):
        """Warp all keypoints according to an affine transformation matrix.

        Each Point2D is warped individually (see Point2D.warp) and the new
        (y, x) coordinates are written back into ``keypoints_arr``.

        Args:
            image Image with maximum dimensions
            matrix Affine transformation matrix from scikit-image."""
        points = self.get_points()
        for i, pnt in enumerate(points):
            pnt.warp(image, matrix)
            # write back as (y, x) pair, matching the array's storage layout
            self.keypoints_arr[i*2:(i*2)+2] = [pnt.y, pnt.x]

    def fliplr(self, image):
        """Flip all keypoints horizontally (mirror on the vertical axis).

        After mirroring the x coordinates, left/right point indices are
        swapped so that e.g. the left eye keypoint still refers to the
        (anatomically) left eye.

        Args:
            image Image with maximum dimensions."""
        for i in range(0, len(self.keypoints_arr), 2):
            self.keypoints_arr[i+1] = (image.get_width()-1) - self.keypoints_arr[i+1]
        # switch points (indices below are 1-based in the comments,
        # hence the "-1" in the calls)
        # 9 with 4 (right ear 3, left ear 1)
        self._switch_points(9-1, 4-1)
        # 8 with 5 (right ear 2, left ear 2)
        self._switch_points(8-1, 5-1)
        # 7 with 6 (right ear 1, left ear 3)
        self._switch_points(7-1, 6-1)
        # 2 with 1 (right eye, left eye)
        self._switch_points(2-1, 1-1)

    def flipud(self, image):
        """Flip all keypoints vertically (mirror on the horizontal axis).

        NOTE(review): unlike fliplr(), no left/right point indices are
        swapped here — confirm that is intended for vertical flips.

        Args:
            image Image with maximum dimensions."""
        for i in range(0, len(self.keypoints_arr), 2):
            self.keypoints_arr[i] = (image.get_height()-1) - self.keypoints_arr[i]

    def _switch_points(self, index1, index2):
        """Switch the coordinates of two keypoints (in place).
        Args:
            index1 Index of the first keypoint
            index2 Index of the second keypoint
        """
        y1 = self.keypoints_arr[index1*2]
        x1 = self.keypoints_arr[index1*2+1]
        y2 = self.keypoints_arr[index2*2]
        x2 = self.keypoints_arr[index2*2+1]
        self.keypoints_arr[index1*2] = y2
        self.keypoints_arr[index1*2+1] = x2
        self.keypoints_arr[index2*2] = y1
        self.keypoints_arr[index2*2+1] = x1

    def get_rectangle(self, image, method=4):
        """Generate face rectangles based on various methods.

        Face rectangles are rectangles around the facial keypoints that contain various parts
        of the face.

        Methods:
        - 0: Bounding box around all keypoints
        - 1: Rectangle 0, translated to the center of the face
        - 2: Rectangle 0, translated half-way to the center of the face
        - 3: Bounding box around the corners of Rectangle 0 and 2
        - 4: Rectangle 3, squared (this is the main rectangle used)

        Args:
            image Image with maximum dimensions
            method Index of the method
        Returns:
            Rectangle object
        Raises:
            Exception if ``method`` is not one of 0..4.
        """

        image_width = image.get_width()
        image_height = image.get_height()

        face_center = self.get_face_center()

        if method == 0:
            # rectangle 0: bounding box around provided keypoints
            return Rectangle(self.get_min_y(), self.get_min_x(), self.get_max_y(), self.get_max_x())
        elif method == 1:
            # rectangle 1: the same rectangle as rect 0, but translated to the center of the face
            rect = self.get_rectangle(image, method=0)
            rect_center = rect.get_center()
            diff_y = face_center.y - rect_center.y
            diff_x = face_center.x - rect_center.x

            # NOTE(review): unlike method 2 below, these values are not cast
            # to int — confirm the diffs are always integers here.
            min_x_fcenter = max(0, rect.tl_x + diff_x)
            min_y_fcenter = max(0, rect.tl_y + diff_y)
            max_x_fcenter = min(image_width-1, rect.br_x + diff_x)
            max_y_fcenter = min(image_height-1, rect.br_y + diff_y)

            return Rectangle(min_y_fcenter, min_x_fcenter, max_y_fcenter, max_x_fcenter)
        elif method == 2:
            # rectangle 2: the same rectangle as rect 0, but translated _half-way_ towards the
            # center of the face
            rect = self.get_rectangle(image, method=0)
            rect_center = rect.get_center()
            diff_y = face_center.y - rect_center.y
            diff_x = face_center.x - rect_center.x

            min_x_half = int(max(0, rect.tl_x + (diff_x/2)))
            min_y_half = int(max(0, rect.tl_y + (diff_y/2)))
            max_x_half = int(min(image_width-1, rect.br_x + (diff_x/2)))
            max_y_half = int(min(image_height-1, rect.br_y + (diff_y/2)))

            return Rectangle(min_y_half, min_x_half, max_y_half, max_x_half)
        elif method == 3:
            # rectangle 3: a merge between rect 0 and 2 rectangle, essentially a bounding box around
            # the corners of both rectangles

            rect0 = self.get_rectangle(image, method=0)
            rect2 = self.get_rectangle(image, method=2)

            min_x_merge = max(0, min(rect0.tl_x, rect2.tl_x))
            min_y_merge = max(0, min(rect0.tl_y, rect2.tl_y))
            max_x_merge = min(image_width-1, max(rect0.br_x, rect2.br_x))
            max_y_merge = min(image_height-1, max(rect0.br_y, rect2.br_y))

            return Rectangle(min_y_merge, min_x_merge, max_y_merge, max_x_merge)
        elif method == 4:
            # rectangle 4: like 3, but squared with Rectangle.square()

            rect3 = self.get_rectangle(image, method=3)
            rect3.square(image)
            return rect3
        else:
            raise Exception("Unknown rectangle generation method %d chosen." % (method,))

    def get_rectangles(self, image):
        """Returns all facial rectangles (one per generation method 0..4).
        Args: image: Image with maximum dimensions
        Returns: List of Rectangle"""
        return [self.get_rectangle(image, method=i) for i in range(0, 5)]

    def to_array(self):
        """Returns the keypoints as array of shape (18,).

        That is 9 keypoints stored as interleaved (y, x) pairs."""
        return self.keypoints_arr

    def __str__(self):
        """Converts object to string representation."""
        return str(self.keypoints_arr)

class PointsList(object):
    """A helper class encapsulating multiple Point2D."""

    def __init__(self, points):
        """Instantiates a new points list.
        Args:
            points List of Point2D."""
        self.points = points

    def normalize(self, image):
        """Normalizes each point to 0-1 with respect to an image's dimensions.

        Mutates the contained Point2D objects in place."""
        for point in self.points:
            point.normalize(image)

    def unnormalize(self, image):
        """Unnormalizes each point from 0-1 to integer pixel values with respect to an
        image's dimensions.

        Mutates the contained Point2D objects in place."""
        for point in self.points:
            point.unnormalize(image)

    def any_normalized(self):
        """Returns whether any point in the list has normalized coordinates."""
        return any([point.is_normalized for point in self.points])

    def all_normalized(self):
        """Returns whether all points in the list have normalized coordinates."""
        return all([point.is_normalized for point in self.points])

    def to_array(self):
        """Returns the list of points as a numpy array of shape (nb_points*2,).

        Coordinates are stored as interleaved (y, x) pairs."""
        result = np.zeros((len(self.points)*2,), dtype=np.float32)
        for i, point in enumerate(self.points):
            result[i*2] = point.y
            result[i*2 + 1] = point.x
        return result

    def __str__(self):
        """Returns a string representation of this point list."""
        return str([str(pnt) for pnt in self.points])

class Point2D(object):
    """A helper class encapsulating a (y, x) coordinate."""

    def __init__(self, y, x, is_normalized=False):
        """Instantiate a new Point2D object.
        Args:
            y Y-coordinate of point (float if normalized, else int)
            x X-coordinate of point (float if normalized, else int)
            is_normalized Whether the coordinates are normalized to 0-1 instead of integer
                pixel values"""
        if is_normalized:
            assert isinstance(y, float), type(y)
            assert isinstance(x, float), type(x)
        else:
            assert isinstance(y, int), type(y)
            assert isinstance(x, int), type(x)
        self.y = y
        self.x = x
        self.is_normalized = is_normalized

    def normalize(self, image):
        """Normalize the integer pixel values to 0-1 with respect to an image's dimensions.
        Args: image: The image which's dimensions to use."""
        assert not self.is_normalized
        self.y = self.y / image.get_height() # changes y to float
        self.x = self.x / image.get_width() # changes x to float
        self.is_normalized = True

    def unnormalize(self, image):
        """Unnormalize the 0-1 coordinate value to integer pixel values with respect to an
        image's dimensions.
        Args: image: The image which's dimensions to use."""
        assert self.is_normalized
        self.y = int(self.y * image.get_height())
        self.x = int(self.x * image.get_width())
        self.is_normalized = False

    def warp(self, image, matrix):
        """Warp the point's coordinates according to an affine transformation matrix.
        Args:
            image The image which's dimensions to use.
            matrix The affine transformation matrix (from scikit-image)
        """
        assert not self.is_normalized

        # This method draws the point as a white pixel on a black image,
        # then warps that image according to the matrix
        # then reads out the new position of the pixel
        # (if its not found / outside of the image then the coordinates will be unchanged).
        # This is a very wasteful process as many pixels have to be warped instead of just one.
        # There is probably a better method for that, but I don't know it.
        # The WARP_KEYPOINTS_* names are module-level constants defined
        # earlier in this file.
        image_pnt = np.zeros((image.get_height(), image.get_width()), dtype=np.uint8)
        image_pnt[self.y, self.x] = 255
        image_pnt_warped = tf.warp(image_pnt, matrix, mode=WARP_KEYPOINTS_MODE,
                                   cval=WARP_KEYPOINTS_CVAL,
                                   order=WARP_KEYPOINTS_INTERPOLATION_ORDER)
        # brightest pixel of the warped image = new point position;
        # argmax == 0 with a dark (0, 0) pixel means the point was warped
        # out of the image entirely
        maxindex = np.argmax(image_pnt_warped)
        if maxindex == 0 and image_pnt_warped[0, 0] < 0.5:
            # dont change coordinates
            #print("Note: Coordinate (%d, %d) not changed" % (self.y, self.x))
            pass
        else:
            (y, x) = np.unravel_index(maxindex, image_pnt_warped.shape)
            self.y = y
            self.x = x

    def to_array(self):
        """Returns the coordinate as a numpy array of shape (2,), ordered (y, x)."""
        if self.is_normalized:
            return np.array([self.y, self.x], dtype=np.float32)
        else:
            return np.array([self.y, self.x], dtype=np.uint16)

    def __str__(self):
        """Returns a string representation of the coordinate.

        "PN(...)" marks normalized coordinates, "P(...)" pixel coordinates."""
        if self.is_normalized:
            return "PN(%.4f, %.4f)" % (self.y, self.x)
        else:
            return "P(%d, %d)" % (self.y, self.x)

class Rectangle(object):
    """Class representing a rectangle in an image."""
    def __init__(self, tl_y, tl_x, br_y, br_x, is_normalized=False):
        """Instantiate a new rectangle.
        Args:
            tl_y y-coordinate of top left corner
            tl_x x-coordinate of top left corner
            br_y y-coordinate of bottom right corner
            br_x x-coordinate of bottom right corner
            is_normalized Whether the coordinates are normalized to 0-1 instead of integer
                pixel values"""
        # a rectangle must have non-negative corners and strictly positive area
        assert tl_y >= 0 and tl_x >= 0 and br_y >= 0 and br_x >= 0
        assert tl_y < br_y and tl_x < br_x
        if is_normalized:
            assert all(isinstance(v, float) for v in [tl_y, tl_x, br_y, br_x])
        else:
            assert all(isinstance(v, int) for v in [tl_y, tl_x, br_y, br_x])

        self.tl_y = tl_y
        self.tl_x = tl_x
        self.br_y = br_y
        self.br_x = br_x
        self.is_normalized = is_normalized

    def get_width(self):
        """Returns the width of the rectangle."""
        return self.br_x - self.tl_x

    def get_height(self):
        """Returns the height of the rectangle."""
        return self.br_y - self.tl_y

    def get_center(self):
        """Returns the center of the rectangle as a Point2D.

        For pixel rectangles the center is truncated to int."""
        y = self.tl_y + (self.get_height() / 2)
        x = self.tl_x + (self.get_width() / 2)
        if self.is_normalized:
            return Point2D(y=float(y), x=float(x), is_normalized=True)
        else:
            return Point2D(y=int(y), x=int(x), is_normalized=False)

    def square(self, image):
        """Squares the rectangle (in place).
        It first adds columns/rows until the image's borders are reached.
        Then deletes columns/rows until the rectangle is squared.
        Args:
            image Image which's dimensions to use, i.e. rectangle won't be increased in size
                beyond that image's height/width.
        """
        assert not self.is_normalized

        # img_height/img_width are currently only used by the disabled
        # extension code below
        img_height = image.get_height()
        img_width = image.get_width()
        height = self.get_height()
        width = self.get_width()

        # extend by adding cols / rows until borders of image are reached
        # removed, because only removing cols/rows was really tested.
        # Fixme: test with adding cols/rows
        # Todo: change method so that it adds and removes cols/rows at the same time
        """
        i = 0
        while width < height and self.br_x < img_width and self.tl_x > 0:
            if i % 2 == 0:
                self.tl_x -= 1
            else:
                self.br_x += 1
            width += 1
            i += 1

        while height < width and self.br_y < img_height and self.tl_y > 0:
            if i % 2 == 0:
                self.tl_y -= 1
            else:
                self.br_y += 1
            height += 1
            i += 1
        """

        # remove cols / rows until rectangle is squared
        # this part was written at a different time, which is why the removal works differently,
        # it does however the exactle same thing (move yx coordinates of topleft/bottemright
        # corners)
        if height > width:
            # too tall: shrink vertically; the odd leftover row is taken
            # from the top
            diff = height - width
            remove_top = math.floor(diff / 2)
            remove_bottom = math.floor(diff / 2)
            if diff % 2 != 0:
                remove_top += 1
            self.tl_y += int(remove_top)
            self.br_y -= int(remove_bottom)
        elif width > height:
            # too wide: shrink horizontally; the odd leftover column is
            # taken from the left
            diff = width - height
            remove_left = math.floor(diff / 2)
            remove_right = math.floor(diff / 2)
            if diff % 2 != 0:
                remove_left += 1
            self.tl_x += int(remove_left)
            self.br_x -= int(remove_right)

    def normalize(self, image):
        """Normalize integer pixel values to 0-1 with respect to an image (in place).
        Args: image: Image which's dimensions to use."""
        assert not self.is_normalized
        self.tl_y /= image.get_height()
        self.tl_x /= image.get_width()
        self.br_y /= image.get_height()
        self.br_x /= image.get_width()
        self.is_normalized = True

    def unnormalize(self, image):
        """Unnormalize from 0-1 to pixel values with respect to an image (in place).

        NOTE(review): values are scaled but not cast to int here, unlike
        Point2D.unnormalize — confirm callers expect floats.

        Args: image: Image which's dimensions to use."""
        assert self.is_normalized
        self.tl_y *= image.get_height()
        self.tl_x *= image.get_width()
        self.br_y *= image.get_height()
        self.br_x *= image.get_width()
        self.is_normalized = False

    def __str__(self):
        """Returns a string representation of the rectangle.

        "RN(...)" marks normalized coordinates, "R(...)" pixel coordinates."""
        if self.is_normalized:
            return "RN(%.4f, %.4f)x(%.4f, %.4f)" % (self.tl_y, self.tl_x, self.br_y, self.br_x)
        else:
            return "R(%d, %d)x(%d, %d)" % (self.tl_y, self.tl_x, self.br_y, self.br_x)

def unit_vector(vector):
    """Returns the unit vector of the vector."""
    return vector / np.linalg.norm(vector)

def angle_between(v1, v2):
    """ Returns the angle in radians between vectors 'v1' and 'v2'::

            >>> angle_between((1, 0, 0), (0, 1, 0))
            1.5707963267948966
            >>> angle_between((1, 0, 0), (1, 0, 0))
            0.0
            >>> angle_between((1, 0, 0), (-1, 0, 0))
            3.141592653589793

    The result is negated when v2 points into negative y/x-direction
    (first component < 0), giving a signed angle.
    """
    v1_u = unit_vector(v1)
    v2_u = unit_vector(v2)
    angle = np.arccos(np.dot(v1_u, v2_u))
    # arccos returns NaN for parallel/antiparallel vectors due to
    # floating point error in the dot product; disambiguate explicitly
    if np.isnan(angle):
        if (v1_u == v2_u).all():
            v = 0.0
        else:
            v = np.pi
    else:
        v = angle

    if v2_u[0] < 0:
        return -v
    else:
        return v

def clip(minval, val, maxval):
    """Clips a value between min and max (both including)."""
    if val < minval:
        return minval
    elif val > maxval:
        return maxval
    else:
        return val
--------------------------------------------------------------------------------
/images/1-266-19922494159_78303f8b16_n.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aleju/cat-bbs-regression/fe9bdabf4e019632449bd03d6f6dfe7f884c1dbe/images/1-266-19922494159_78303f8b16_n.jpg
-------------------------------------------------------------------------------- /images/1-39-125521249_b1318298ec_n.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aleju/cat-bbs-regression/fe9bdabf4e019632449bd03d6f6dfe7f884c1dbe/images/1-39-125521249_b1318298ec_n.jpg -------------------------------------------------------------------------------- /images/1-56-180653960_21cf28e0b3_n.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aleju/cat-bbs-regression/fe9bdabf4e019632449bd03d6f6dfe7f884c1dbe/images/1-56-180653960_21cf28e0b3_n.jpg -------------------------------------------------------------------------------- /images/1-61-213767259_11c8550a0e_n.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aleju/cat-bbs-regression/fe9bdabf4e019632449bd03d6f6dfe7f884c1dbe/images/1-61-213767259_11c8550a0e_n.jpg -------------------------------------------------------------------------------- /images/3-2287-2088100404_c0112197e3_n.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aleju/cat-bbs-regression/fe9bdabf4e019632449bd03d6f6dfe7f884c1dbe/images/3-2287-2088100404_c0112197e3_n.jpg -------------------------------------------------------------------------------- /images/3-2831-10902603864_4993c4aa1a_n.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aleju/cat-bbs-regression/fe9bdabf4e019632449bd03d6f6dfe7f884c1dbe/images/3-2831-10902603864_4993c4aa1a_n.jpg -------------------------------------------------------------------------------- /predictions/README.md: -------------------------------------------------------------------------------- 1 | Directory in which `train.py` will save images with their predictions. 
--------------------------------------------------------------------------------
/train_convnet.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Trains a model to locate cat faces in images (assumes that the image contains a cat face).

Uses the legacy Keras API (Convolution2D, nb_epoch, border_mode) with
Theano-style channels-first image tensors of shape (N, 3, H, W).
"""
from __future__ import absolute_import, division, print_function
from dataset import Dataset
import numpy as np
import argparse
import random
import os
from scipy import misc
from skimage import draw
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.optimizers import Adam
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.callbacks import ModelCheckpoint

# fixed seeds for reproducible data loading/augmentation
np.random.seed(42)
random.seed(42)

CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))

# input size of the network
MODEL_IMAGE_HEIGHT = 128
MODEL_IMAGE_WIDTH = 128
# pixels of padding added before augmentation and removed afterwards
PADDING = 20
# number of augmented copies generated per source image
AUGMENTATIONS = 2
# intended number of source images to load (capped by dataset size)
NB_LOAD_IMAGES = 9500
# fraction of examples used for validation
SPLIT = 0.1
EPOCHS = 150
BATCH_SIZE = 64
SAVE_WEIGHTS_FILEPATH = os.path.join(CURRENT_DIR, "cat_face_locator.weights")
SAVE_WEIGHTS_CHECKPOINT_FILEPATH = os.path.join(CURRENT_DIR, "cat_face_locator.best.weights")
SAVE_PREDICTIONS = True
SAVE_PREDICTIONS_DIR = os.path.join(CURRENT_DIR, "predictions")

def main():
    """Main method that reads the images, trains a model, then saves weights and predictions."""
    parser = argparse.ArgumentParser(description="Train a model to locate cat faces in images.")
    parser.add_argument("--dataset", required=True, help="Path to your 10k cats dataset directory")
    args = parser.parse_args()

    # initialize dataset
    subdir_names = ["CAT_00", "CAT_01", "CAT_02", "CAT_03", "CAT_04", "CAT_05", "CAT_06"]
    subdirs = [os.path.join(args.dataset, subdir) for subdir in subdir_names]
    dataset = Dataset(subdirs)

    # load images and labels
    print("Loading images...")
    X, y = load_xy(dataset, NB_LOAD_IMAGES, AUGMENTATIONS)

    # split train and val (simple head/tail split, no shuffling here)
    nb_images = X.shape[0]
    nb_train = int(nb_images * (1 - SPLIT))
    X_train = X[0:nb_train, ...]
    y_train = y[0:nb_train, ...]
    X_val = X[nb_train:, ...]
    y_val = y[nb_train:, ...]

    # create model
    print("Creating model...")
    model = create_model(MODEL_IMAGE_HEIGHT, MODEL_IMAGE_WIDTH, "mse", Adam())

    # fit; the checkpoint callback keeps the best weights seen on the
    # validation set, the explicit validation_data overrides validation_split
    checkpoint_cb = ModelCheckpoint(SAVE_WEIGHTS_CHECKPOINT_FILEPATH, verbose=1, \
                                    save_best_only=True)
    model.fit(X_train, y_train, batch_size=BATCH_SIZE, nb_epoch=EPOCHS, validation_split=0.0,
              validation_data=(X_val, y_val),
              callbacks=[checkpoint_cb])

    # save weights
    print("Saving weights...")
    model.save_weights(SAVE_WEIGHTS_FILEPATH, overwrite=True)

    # save predictions on val set
    if SAVE_PREDICTIONS:
        print("Saving example predictions...")
        y_preds = model.predict(X_val, batch_size=BATCH_SIZE)
        # NOTE(review): the loop variable "y" shadows the labels array "y"
        # loaded above; harmless here because y is not used afterwards.
        for img_idx, (y, x, half_height, half_width) in enumerate(y_preds):
            img_arr = draw_predicted_rectangle(X_val[img_idx], y, x, half_height, half_width)
            filepath = os.path.join(SAVE_PREDICTIONS_DIR, "%d.png" % (img_idx,))
            misc.imsave(filepath, np.squeeze(img_arr))

def load_xy(dataset, nb_load, nb_augmentations):
    """Loads X and y (examples with labels) for the dataset.
    Examples are images.
    Labels are the coordinates of the face rectangles' centers with their
    half-heights and half-widths
    (each normalized to 0-1 with respect to the image dimensions.)

    Args:
        dataset The Dataset object.
        nb_load Intended number of images to load.
        nb_augmentations Number of augmentations to perform per image.
    Returns:
        X (numpy array of shape (N, 3, height, width)),
        y (numpy array of shape (N, 4), rows are (center_y, center_x,
           half_height, half_width))
    """
    i = 0
    nb_load = min(nb_load, len(dataset.fps))
    # each source image contributes itself plus nb_augmentations variants
    nb_images = nb_load + nb_load * nb_augmentations
    # X is built channels-last and rolled to channels-first at the end
    X = np.zeros((nb_images, MODEL_IMAGE_HEIGHT, MODEL_IMAGE_WIDTH, 3), dtype=np.float32)
    y = np.zeros((nb_images, 4), dtype=np.float32)
    #X = []
    #y = []

    for img_idx, image in enumerate(dataset.get_images()):
        if img_idx % 100 == 0:
            print("Loading image %d of %d..." % (img_idx+1, nb_load))
        image.resize(MODEL_IMAGE_HEIGHT, MODEL_IMAGE_WIDTH)
        # pad before augmenting so rotations/translations don't cut off
        # face pixels; padding is removed again below
        image.pad(PADDING)
        augs = image.augment(nb_augmentations, hflip=True, vflip=False,
                             scale_to_percent=(0.9, 1.1), scale_axis_equally=False,
                             rotation_deg=10, shear_deg=0, translation_x_px=5, translation_y_px=5,
                             brightness_change=0.1, noise_mean=0.0, noise_std=0.05)
        for aug in [image] + augs:
            aug.unpad(PADDING)
            # pixel values scaled to 0-1
            X[i] = aug.to_array() / 255.0
            face_rect = aug.keypoints.get_rectangle(aug)
            face_rect.normalize(aug)
            center = face_rect.get_center()
            width = face_rect.get_width() / 2
            height = face_rect.get_height() / 2
            y[i] = [center.y, center.x, height, width]
            i += 1

            if i >= nb_images:
                break
        if i >= nb_images:
            break

    #X = np.array(X, dtype=np.float32)
    #y = np.array(y, dtype=np.float32)
    # (N, H, W, 3) -> (N, 3, H, W) channels-first ordering
    X = np.rollaxis(X, 3, 1)

    return X, y

def unnormalize_prediction(y, x, half_height, half_width, \
                           img_height=MODEL_IMAGE_HEIGHT, img_width=MODEL_IMAGE_WIDTH):
    """Transforms a predictions from normalized (0 to 1) y, x, half-width,
    half-height to pixel values (top left y, top left x, bottom right y,
    bottom right x).
    Args:
        y: Normalized y coordinate of rectangle center.
        x: Normalized x coordinate of rectangle center.
        half_height: Normalized half of the rectangle's height.
        half_width: Normalized half of the rectangle's width.
        img_height: Height of the image to use while unnormalizing.
        img_width: Width of the image to use while unnormalizing.
    Returns:
        (top left y in px, top left x in px, bottom right y in px,
        bottom right x in px), guaranteed inside the image and covering
        at least one pixel in each dimension.
    """
    # calculate x, y of corners in pixels
    tl_y = int((y - half_height) * img_height)
    tl_x = int((x - half_width) * img_width)
    br_y = int((y + half_height) * img_height)
    br_x = int((x + half_width) * img_width)

    # make sure that x and y coordinates are within image boundaries
    # (top left is clipped one pixel tighter so a 1px rectangle still fits)
    tl_y = clip(0, tl_y, img_height-2)
    tl_x = clip(0, tl_x, img_width-2)
    br_y = clip(0, br_y, img_height-1)
    br_x = clip(0, br_x, img_width-1)

    # make sure that top left corner is really top left of bottom right values
    if tl_y > br_y:
        tl_y, br_y = br_y, tl_y
    if tl_x > br_x:
        tl_x, br_x = br_x, tl_x

    # make sure that the area covered is at least 1px,
    # move preferably the top left corner
    # but dont move it outside of the image
    if tl_y == br_y:
        if tl_y == 0:
            br_y += 1
        else:
            tl_y -= 1

    if tl_x == br_x:
        if tl_x == 0:
            br_x += 1
        else:
            tl_x -= 1

    return tl_y, tl_x, br_y, br_x

def draw_predicted_rectangle(image_arr, y, x, half_height, half_width):
    """Draws a rectangle onto the image at the provided coordinates.
    Args:
        image_arr: Numpy array of the image, channels-first (3, H, W),
            values in 0-1.
        y: y-coordinate of the rectangle center (normalized to 0-1).
        x: x-coordinate of the rectangle center (normalized to 0-1).
        half_height: Half of the height of the rectangle (normalized to 0-1).
        half_width: Half of the width of the rectangle (normalized to 0-1).
    Returns:
        Modified copy of the image (numpy array, channels-last (H, W, 3),
        values in 0-255)
    """
    assert image_arr.shape[0] == 3, str(image_arr.shape)
    height = image_arr.shape[1]
    width = image_arr.shape[2]
    tl_y, tl_x, br_y, br_x = unnormalize_prediction(y, x, half_height, half_width, \
                                                    img_height=height, img_width=width)
    # back from 0-1 to 0-255 and to channels-last for drawing/saving
    image_arr = np.copy(image_arr) * 255
    image_arr = np.rollaxis(image_arr, 0, 3)
    return draw_rectangle(image_arr, tl_y, tl_x, br_y, br_x)

def draw_rectangle(img, tl_y, tl_x, br_y, br_x):
    """Draws a rectangle onto an image (into the red channel only).
    Args:
        img: The image as a numpy array of shape (row, col, channel).
        tl_y: Top left y coordinate as pixel.
        tl_x: Top left x coordinate as pixel.
        br_y: Bottom right y coordinate as pixel.
        br_x: Bottom right x coordinate as pixel.
    Returns:
        copy of the image with the rectangle drawn
    """
    assert img.shape[2] == 3, img.shape[2]
    img = np.copy(img)
    # the four edges of the rectangle as (y0, x0, y1, x1) line segments
    lines = [
        (tl_y, tl_x, tl_y, br_x), # top left to top right
        (tl_y, br_x, br_y, br_x), # top right to bottom right
        (br_y, br_x, br_y, tl_x), # bottom right to bottom left
        (br_y, tl_x, tl_y, tl_x)  # bottom left to top left
    ]
    for y0, x0, y1, x1 in lines:
        # anti-aliased line; val is the 0-1 intensity per pixel
        rr, cc, val = draw.line_aa(y0, x0, y1, x1)
        img[rr, cc, 0] = val * 255

    return img

def clip(lower, val, upper):
    """Clips a value. For lower bound L, upper bound U and value V it
    makes sure that L <= V <= U holds.
    Args:
        lower: Lower boundary (including)
        val: The value to clip
        upper: Upper boundary (including)
    Returns:
        value within bounds
    """
    if val < lower:
        return lower
    elif val > upper:
        return upper
    else:
        return val

def create_model(image_height, image_width, loss, optimizer):
    """Creates the cat face locator model.

    Output layer is a 4-unit sigmoid, matching the normalized
    (center_y, center_x, half_height, half_width) labels in 0-1.

    Args:
        image_height: The height of the input images.
        image_width: The width of the input images.
        loss: Keras loss function (name or object), e.g. "mse".
        optimizer: Keras optimizer to use, e.g. Adam() or "sgd".
    Returns:
        Sequential
    """

    model = Sequential()

    # shape comments below give the INPUT of each group as channels x H x W

    # 3x128x128
    model.add(Convolution2D(32, 3, 3, border_mode="same", \
                            input_shape=(3, image_height, image_width)))
    model.add(Activation("relu"))
    model.add(Dropout(0.0))

    # 32x128x128
    model.add(Convolution2D(32, 3, 3, border_mode="same"))
    model.add(Activation("relu"))
    model.add(Dropout(0.0))

    # 32x128x128
    model.add(Convolution2D(32, 3, 3, border_mode="same"))
    model.add(Activation("relu"))
    model.add(MaxPooling2D((2, 2)))
    model.add(Dropout(0.5))

    # 32x64x64
    model.add(Convolution2D(64, 3, 3, border_mode="same"))
    model.add(Activation("relu"))
    model.add(MaxPooling2D((2, 2)))
    model.add(Dropout(0.5))

    # 64x32x32
    model.add(Convolution2D(128, 3, 3, border_mode="same"))
    model.add(Activation("relu"))
    model.add(MaxPooling2D((2, 2)))
    model.add(Dropout(0.5))

    # 128x16x16
    model.add(Convolution2D(256, 3, 3, border_mode="same"))
    model.add(Activation("relu"))
    model.add(MaxPooling2D((2, 2)))
    model.add(Dropout(0.5))

    # 256x8x8 = 16384
    model.add(Flatten())

    model.add(Dense(256))
    model.add(Activation("tanh"))
    model.add(Dropout(0.5))

    model.add(Dense(4))
    model.add(Activation("sigmoid"))

    # compile with mean squared error
    print("Compiling...")
    model.compile(loss=loss, optimizer=optimizer)

    return model

if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------