├── MNIST_data
│   ├── t10k-images-idx3-ubyte.gz
│   ├── t10k-labels-idx1-ubyte.gz
│   ├── train-images-idx3-ubyte.gz
│   └── train-labels-idx1-ubyte.gz
├── README.md
├── examples
│   ├── example_linear.gif
│   ├── example_sinusoid.gif
│   ├── inverse.mp4
│   ├── inverse.webm
│   ├── inverse_23_gen.mp4
│   ├── normal.mp4
│   ├── normal.webm
│   ├── normal_23_gen.mp4
│   └── normal_23_gen.webm
├── images2gif.py
├── mnist_data.py
├── model.py
├── ops.py
├── sampler.py
├── save
│   ├── checkpoint
│   ├── config.pkl
│   ├── model.ckpt-2
│   ├── model.ckpt-2.meta
│   ├── model.ckpt-23
│   ├── model.ckpt-23.meta
│   ├── model.ckpt-4
│   └── model.ckpt-4.meta
└── train.py

/MNIST_data/t10k-images-idx3-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/resnet-cppn-gan-tensorflow/9206e06512c118e932fbc789c91a5cf4f9e5d2b9/MNIST_data/t10k-images-idx3-ubyte.gz
--------------------------------------------------------------------------------
/MNIST_data/t10k-labels-idx1-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/resnet-cppn-gan-tensorflow/9206e06512c118e932fbc789c91a5cf4f9e5d2b9/MNIST_data/t10k-labels-idx1-ubyte.gz
--------------------------------------------------------------------------------
/MNIST_data/train-images-idx3-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/resnet-cppn-gan-tensorflow/9206e06512c118e932fbc789c91a5cf4f9e5d2b9/MNIST_data/train-images-idx3-ubyte.gz
--------------------------------------------------------------------------------
/MNIST_data/train-labels-idx1-ubyte.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/resnet-cppn-gan-tensorflow/9206e06512c118e932fbc789c91a5cf4f9e5d2b9/MNIST_data/train-labels-idx1-ubyte.gz
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # resnet-cppn-gan-tensorflow
2 | 
3 | Improvements made for training a [Compositional Pattern Producing Network](https://en.wikipedia.org/wiki/Compositional_pattern-producing_network) as a generative model, using Residual Generative Adversarial Network and Variational Autoencoder techniques to produce high-resolution images.
4 | 
5 | ![Morphing2](https://cdn.rawgit.com/hardmaru/resnet-cppn-gan-tensorflow/master/examples/example_sinusoid.gif)
6 | 
7 | Run `python train.py` from the command line to train from scratch and experiment with different settings.
8 | 
9 | `sampler.py` can be used inside IPython to interactively see results from the models being trained.
10 | 
11 | See my blog post at [blog.otoro.net](http://blog.otoro.net/2016/06/02/generating-large-images-from-latent-vectors-part-two/) for more details.
12 | 
13 | ![Morphing1](https://cdn.rawgit.com/hardmaru/resnet-cppn-gan-tensorflow/master/examples/example_linear.gif)
14 | 
15 | I tested the implementation on TensorFlow 0.8.0.
16 | 
17 | Uses `images2gif.py`, written by Almar Klein, Ant1, and Marius van Voorden.
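Because the generator is a CPPN, it maps each pixel's coordinates (plus a shared latent vector) directly to an intensity value, so a trained model can be sampled at any resolution. The snippet below is a minimal NumPy sketch of that idea for readers new to CPPNs: it mirrors the coordinate normalization in `model.py`, but the network weights are random, the helper names (`coordinates`, `cppn_image`) are made up for illustration, and it is not the generator that `train.py` actually trains.

```python
import numpy as np

def coordinates(x_dim, y_dim, scale=8.0):
    # Same normalization as CPPNVAE.coordinates(): x and y span [-scale, scale],
    # r is the distance of each pixel from the image centre.
    x_range = scale * (np.arange(x_dim) - (x_dim - 1) / 2.0) / (x_dim - 1) / 0.5
    y_range = scale * (np.arange(y_dim) - (y_dim - 1) / 2.0) / (y_dim - 1) / 0.5
    x_mat = np.matmul(np.ones((y_dim, 1)), x_range.reshape((1, x_dim)))
    y_mat = np.matmul(y_range.reshape((y_dim, 1)), np.ones((1, x_dim)))
    r_mat = np.sqrt(x_mat * x_mat + y_mat * y_mat)
    return x_mat.flatten(), y_mat.flatten(), r_mat.flatten()

def cppn_image(z, x_dim=512, y_dim=512, net_size=32, depth=4, seed=0):
    # Tiny random tanh network mapping (x, y, r, z) -> a pixel intensity in [0, 1].
    rng = np.random.RandomState(seed)
    x, y, r = coordinates(x_dim, y_dim)
    z_tiled = np.tile(z, (x_dim * y_dim, 1))  # the same latent vector at every pixel
    h = np.concatenate([x[:, None], y[:, None], r[:, None], z_tiled], axis=1)
    for _ in range(depth):
        h = np.tanh(h.dot(rng.randn(h.shape[1], net_size)))
    out = 1.0 / (1.0 + np.exp(-h.dot(rng.randn(net_size, 1))))  # sigmoid output
    return out.reshape(y_dim, x_dim)

image = cppn_image(np.random.randn(32))  # 32 matches the default z_dim; render at any size
```

Only the coordinate grid changes with resolution, so the same latent vector can be rendered at the 26x26 training size or at a much larger size; that property is what produces the high-resolution samples and the morphing animations in `examples/`.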
18 | 
19 | # License
20 | 
21 | BSD - images2gif.py
22 | 
23 | MIT - everything else
24 | 
--------------------------------------------------------------------------------
/examples/example_linear.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/resnet-cppn-gan-tensorflow/9206e06512c118e932fbc789c91a5cf4f9e5d2b9/examples/example_linear.gif
--------------------------------------------------------------------------------
/examples/example_sinusoid.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/resnet-cppn-gan-tensorflow/9206e06512c118e932fbc789c91a5cf4f9e5d2b9/examples/example_sinusoid.gif
--------------------------------------------------------------------------------
/examples/inverse.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/resnet-cppn-gan-tensorflow/9206e06512c118e932fbc789c91a5cf4f9e5d2b9/examples/inverse.mp4
--------------------------------------------------------------------------------
/examples/inverse.webm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/resnet-cppn-gan-tensorflow/9206e06512c118e932fbc789c91a5cf4f9e5d2b9/examples/inverse.webm
--------------------------------------------------------------------------------
/examples/inverse_23_gen.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/resnet-cppn-gan-tensorflow/9206e06512c118e932fbc789c91a5cf4f9e5d2b9/examples/inverse_23_gen.mp4
--------------------------------------------------------------------------------
/examples/normal.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/resnet-cppn-gan-tensorflow/9206e06512c118e932fbc789c91a5cf4f9e5d2b9/examples/normal.mp4
--------------------------------------------------------------------------------
/examples/normal.webm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/resnet-cppn-gan-tensorflow/9206e06512c118e932fbc789c91a5cf4f9e5d2b9/examples/normal.webm
--------------------------------------------------------------------------------
/examples/normal_23_gen.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/resnet-cppn-gan-tensorflow/9206e06512c118e932fbc789c91a5cf4f9e5d2b9/examples/normal_23_gen.mp4
--------------------------------------------------------------------------------
/examples/normal_23_gen.webm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hardmaru/resnet-cppn-gan-tensorflow/9206e06512c118e932fbc789c91a5cf4f9e5d2b9/examples/normal_23_gen.webm
--------------------------------------------------------------------------------
/images2gif.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (C) 2012, Almar Klein, Ant1, Marius van Voorden
3 | #
4 | # This code is subject to the (new) BSD license:
5 | #
6 | # Redistribution and use in source and binary forms, with or without
7 | # modification, are permitted provided that the following conditions are met:
8 | #     * Redistributions of source code must retain
the above copyright 9 | # notice, this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # * Neither the name of the nor the 14 | # names of its contributors may be used to endorse or promote products 15 | # derived from this software without specific prior written permission. 16 | # 17 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 | # ARE DISCLAIMED. IN NO EVENT SHALL BE LIABLE FOR ANY 21 | # DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 22 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 23 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 24 | # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | 28 | """ Module images2gif 29 | 30 | Provides functionality for reading and writing animated GIF images. 31 | Use writeGif to write a series of numpy arrays or PIL images as an 32 | animated GIF. Use readGif to read an animated gif as a series of numpy 33 | arrays. 34 | 35 | Note that since July 2004, all patents on the LZW compression patent have 36 | expired. Therefore the GIF format may now be used freely. 37 | 38 | Acknowledgements 39 | ---------------- 40 | 41 | Many thanks to Ant1 for: 42 | * noting the use of "palette=PIL.Image.ADAPTIVE", which significantly 43 | improves the results. 44 | * the modifications to save each image with its own palette, or optionally 45 | the global palette (if its the same). 46 | 47 | Many thanks to Marius van Voorden for porting the NeuQuant quantization 48 | algorithm of Anthony Dekker to Python (See the NeuQuant class for its 49 | license). 50 | 51 | Many thanks to Alex Robinson for implementing the concept of subrectangles, 52 | which (depening on image content) can give a very significant reduction in 53 | file size. 54 | 55 | This code is based on gifmaker (in the scripts folder of the source 56 | distribution of PIL) 57 | 58 | 59 | Usefull links 60 | ------------- 61 | * http://tronche.com/computer-graphics/gif/ 62 | * http://en.wikipedia.org/wiki/Graphics_Interchange_Format 63 | * http://www.w3.org/Graphics/GIF/spec-gif89a.txt 64 | 65 | """ 66 | # todo: This module should be part of imageio (or at least based on) 67 | 68 | import os, time 69 | 70 | def encode(x): 71 | if False: 72 | return x.encode('utf-8') 73 | return x 74 | 75 | try: 76 | import PIL 77 | from PIL import Image 78 | from PIL.GifImagePlugin import getheader, getdata 79 | except ImportError: 80 | PIL = None 81 | 82 | try: 83 | import numpy as np 84 | except ImportError: 85 | np = None 86 | 87 | def get_cKDTree(): 88 | try: 89 | from scipy.spatial import cKDTree 90 | except ImportError: 91 | cKDTree = None 92 | return cKDTree 93 | 94 | 95 | # getheader gives a 87a header and a color palette (two elements in a list). 96 | # getdata()[0] gives the Image Descriptor up to (including) "LZW min code size". 
97 | # getdatas()[1:] is the image data itself in chuncks of 256 bytes (well 98 | # technically the first byte says how many bytes follow, after which that 99 | # amount (max 255) follows). 100 | 101 | def checkImages(images): 102 | """ checkImages(images) 103 | Check numpy images and correct intensity range etc. 104 | The same for all movie formats. 105 | """ 106 | # Init results 107 | images2 = [] 108 | 109 | for im in images: 110 | if PIL and isinstance(im, PIL.Image.Image): 111 | # We assume PIL images are allright 112 | images2.append(im) 113 | 114 | elif np and isinstance(im, np.ndarray): 115 | # Check and convert dtype 116 | if im.dtype == np.uint8: 117 | images2.append(im) # Ok 118 | elif im.dtype in [np.float32, np.float64]: 119 | im = im.copy() 120 | im[im<0] = 0 121 | im[im>1] = 1 122 | im *= 255 123 | images2.append( im.astype(np.uint8) ) 124 | else: 125 | im = im.astype(np.uint8) 126 | images2.append(im) 127 | # Check size 128 | if im.ndim == 2: 129 | pass # ok 130 | elif im.ndim == 3: 131 | if im.shape[2] not in [3,4]: 132 | raise ValueError('This array can not represent an image.') 133 | else: 134 | raise ValueError('This array can not represent an image.') 135 | else: 136 | raise ValueError('Invalid image type: ' + str(type(im))) 137 | 138 | # Done 139 | return images2 140 | 141 | 142 | def intToBin(i): 143 | """ Integer to two bytes """ 144 | # devide in two parts (bytes) 145 | i1 = i % 256 146 | i2 = int( i/256) 147 | # make string (little endian) 148 | return chr(i1) + chr(i2) 149 | 150 | 151 | class GifWriter: 152 | """ GifWriter() 153 | 154 | Class that contains methods for helping write the animated GIF file. 155 | 156 | """ 157 | 158 | def getheaderAnim(self, im): 159 | """ getheaderAnim(im) 160 | 161 | Get animation header. To replace PILs getheader()[0] 162 | 163 | """ 164 | bb = "GIF89a" 165 | bb += intToBin(im.size[0]) 166 | bb += intToBin(im.size[1]) 167 | bb += "\x87\x00\x00" 168 | return bb 169 | 170 | 171 | def getImageDescriptor(self, im, xy=None): 172 | """ getImageDescriptor(im, xy=None) 173 | 174 | Used for the local color table properties per image. 175 | Otherwise global color table applies to all frames irrespective of 176 | whether additional colors comes in play that require a redefined 177 | palette. Still a maximum of 256 color per frame, obviously. 178 | 179 | Written by Ant1 on 2010-08-22 180 | Modified by Alex Robinson in Janurari 2011 to implement subrectangles. 181 | 182 | """ 183 | 184 | # Defaule use full image and place at upper left 185 | if xy is None: 186 | xy = (0,0) 187 | 188 | # Image separator, 189 | bb = '\x2C' 190 | 191 | # Image position and size 192 | bb += intToBin( xy[0] ) # Left position 193 | bb += intToBin( xy[1] ) # Top position 194 | bb += intToBin( im.size[0] ) # image width 195 | bb += intToBin( im.size[1] ) # image height 196 | 197 | # packed field: local color table flag1, interlace0, sorted table0, 198 | # reserved00, lct size111=7=2^(7+1)=256. 199 | bb += '\x87' 200 | 201 | # LZW minimum size code now comes later, begining of [image data] blocks 202 | return bb 203 | 204 | 205 | def getAppExt(self, loops=float('inf')): 206 | """ getAppExt(loops=float('inf')) 207 | 208 | Application extention. This part specifies the amount of loops. 209 | If loops is 0 or inf, it goes on infinitely. 
210 | 211 | """ 212 | 213 | if loops==0 or loops==float('inf'): 214 | loops = 2**16-1 215 | #bb = "" # application extension should not be used 216 | # (the extension interprets zero loops 217 | # to mean an infinite number of loops) 218 | # Mmm, does not seem to work 219 | if True: 220 | bb = "\x21\xFF\x0B" # application extension 221 | bb += "NETSCAPE2.0" 222 | bb += "\x03\x01" 223 | bb += intToBin(loops) 224 | bb += '\x00' # end 225 | return bb 226 | 227 | 228 | def getGraphicsControlExt(self, duration=0.1, dispose=2): 229 | """ getGraphicsControlExt(duration=0.1, dispose=2) 230 | 231 | Graphics Control Extension. A sort of header at the start of 232 | each image. Specifies duration and transparancy. 233 | 234 | Dispose 235 | ------- 236 | * 0 - No disposal specified. 237 | * 1 - Do not dispose. The graphic is to be left in place. 238 | * 2 - Restore to background color. The area used by the graphic 239 | must be restored to the background color. 240 | * 3 - Restore to previous. The decoder is required to restore the 241 | area overwritten by the graphic with what was there prior to 242 | rendering the graphic. 243 | * 4-7 -To be defined. 244 | 245 | """ 246 | 247 | bb = '\x21\xF9\x04' 248 | bb += chr((dispose & 3) << 2) # low bit 1 == transparency, 249 | # 2nd bit 1 == user input , next 3 bits, the low two of which are used, 250 | # are dispose. 251 | bb += intToBin( int(duration*100) ) # in 100th of seconds 252 | bb += '\x00' # no transparant color 253 | bb += '\x00' # end 254 | return bb 255 | 256 | 257 | def handleSubRectangles(self, images, subRectangles): 258 | """ handleSubRectangles(images) 259 | 260 | Handle the sub-rectangle stuff. If the rectangles are given by the 261 | user, the values are checked. Otherwise the subrectangles are 262 | calculated automatically. 263 | 264 | """ 265 | 266 | if isinstance(subRectangles, (tuple,list)): 267 | # xy given directly 268 | 269 | # Check xy 270 | xy = subRectangles 271 | if xy is None: 272 | xy = (0,0) 273 | if hasattr(xy, '__len__'): 274 | if len(xy) == len(images): 275 | xy = [xxyy for xxyy in xy] 276 | else: 277 | raise ValueError("len(xy) doesn't match amount of images.") 278 | else: 279 | xy = [xy for im in images] 280 | xy[0] = (0,0) 281 | 282 | else: 283 | # Calculate xy using some basic image processing 284 | 285 | # Check Numpy 286 | if np is None: 287 | raise RuntimeError("Need Numpy to use auto-subRectangles.") 288 | 289 | # First make numpy arrays if required 290 | for i in range(len(images)): 291 | im = images[i] 292 | if isinstance(im, Image.Image): 293 | tmp = im.convert() # Make without palette 294 | a = np.asarray(tmp) 295 | if len(a.shape)==0: 296 | raise MemoryError("Too little memory to convert PIL image to array") 297 | images[i] = a 298 | 299 | # Determine the sub rectangles 300 | images, xy = self.getSubRectangles(images) 301 | 302 | # Done 303 | return images, xy 304 | 305 | 306 | def getSubRectangles(self, ims): 307 | """ getSubRectangles(ims) 308 | 309 | Calculate the minimal rectangles that need updating each frame. 310 | Returns a two-element tuple containing the cropped images and a 311 | list of x-y positions. 312 | 313 | Calculating the subrectangles takes extra time, obviously. However, 314 | if the image sizes were reduced, the actual writing of the GIF 315 | goes faster. In some cases applying this method produces a GIF faster. 
316 | 317 | """ 318 | 319 | # Check image count 320 | if len(ims) < 2: 321 | return ims, [(0,0) for i in ims] 322 | 323 | # We need numpy 324 | if np is None: 325 | raise RuntimeError("Need Numpy to calculate sub-rectangles. ") 326 | 327 | # Prepare 328 | ims2 = [ims[0]] 329 | xy = [(0,0)] 330 | t0 = time.time() 331 | 332 | # Iterate over images 333 | prev = ims[0] 334 | for im in ims[1:]: 335 | 336 | # Get difference, sum over colors 337 | diff = np.abs(im-prev) 338 | if diff.ndim==3: 339 | diff = diff.sum(2) 340 | # Get begin and end for both dimensions 341 | X = np.argwhere(diff.sum(0)) 342 | Y = np.argwhere(diff.sum(1)) 343 | # Get rect coordinates 344 | if X.size and Y.size: 345 | x0, x1 = X[0], X[-1]+1 346 | y0, y1 = Y[0], Y[-1]+1 347 | else: # No change ... make it minimal 348 | x0, x1 = 0, 2 349 | y0, y1 = 0, 2 350 | 351 | # Cut out and store 352 | im2 = im[y0:y1,x0:x1] 353 | prev = im 354 | ims2.append(im2) 355 | xy.append((x0,y0)) 356 | 357 | # Done 358 | #print('%1.2f seconds to determine subrectangles of %i images' % 359 | # (time.time()-t0, len(ims2)) ) 360 | return ims2, xy 361 | 362 | 363 | def convertImagesToPIL(self, images, dither, nq=0): 364 | """ convertImagesToPIL(images, nq=0) 365 | 366 | Convert images to Paletted PIL images, which can then be 367 | written to a single animaged GIF. 368 | 369 | """ 370 | 371 | # Convert to PIL images 372 | images2 = [] 373 | for im in images: 374 | if isinstance(im, Image.Image): 375 | images2.append(im) 376 | elif np and isinstance(im, np.ndarray): 377 | if im.ndim==3 and im.shape[2]==3: 378 | im = Image.fromarray(im,'RGB') 379 | elif im.ndim==3 and im.shape[2]==4: 380 | im = Image.fromarray(im[:,:,:3],'RGB') 381 | elif im.ndim==2: 382 | im = Image.fromarray(im,'L') 383 | images2.append(im) 384 | 385 | # Convert to paletted PIL images 386 | images, images2 = images2, [] 387 | if nq >= 1: 388 | # NeuQuant algorithm 389 | for im in images: 390 | im = im.convert("RGBA") # NQ assumes RGBA 391 | nqInstance = NeuQuant(im, int(nq)) # Learn colors from image 392 | if dither: 393 | im = im.convert("RGB").quantize(palette=nqInstance.paletteImage()) 394 | else: 395 | im = nqInstance.quantize(im) # Use to quantize the image itself 396 | images2.append(im) 397 | else: 398 | # Adaptive PIL algorithm 399 | AD = Image.ADAPTIVE 400 | for im in images: 401 | im = im.convert('P', palette=AD, dither=dither) 402 | images2.append(im) 403 | 404 | # Done 405 | return images2 406 | 407 | 408 | def writeGifToFile(self, fp, images, durations, loops, xys, disposes): 409 | """ writeGifToFile(fp, images, durations, loops, xys, disposes) 410 | 411 | Given a set of images writes the bytes to the specified stream. 
412 | 413 | """ 414 | 415 | # Obtain palette for all images and count each occurance 416 | palettes, occur = [], [] 417 | for im in images: 418 | #palette = getheader(im)[1] 419 | palette = getheader(im)[0][-1] 420 | if not palette: 421 | #palette = PIL.ImagePalette.ImageColor 422 | palette = im.palette.tobytes() 423 | palettes.append(palette) 424 | for palette in palettes: 425 | occur.append( palettes.count( palette ) ) 426 | 427 | # Select most-used palette as the global one (or first in case no max) 428 | globalPalette = palettes[ occur.index(max(occur)) ] 429 | 430 | # Init 431 | frames = 0 432 | firstFrame = True 433 | 434 | 435 | for im, palette in zip(images, palettes): 436 | 437 | if firstFrame: 438 | # Write header 439 | 440 | # Gather info 441 | header = self.getheaderAnim(im) 442 | appext = self.getAppExt(loops) 443 | 444 | # Write 445 | fp.write(encode(header)) 446 | fp.write(globalPalette) 447 | fp.write(encode(appext)) 448 | 449 | # Next frame is not the first 450 | firstFrame = False 451 | 452 | if True: 453 | # Write palette and image data 454 | 455 | # Gather info 456 | data = getdata(im) 457 | imdes, data = data[0], data[1:] 458 | graphext = self.getGraphicsControlExt(durations[frames], 459 | disposes[frames]) 460 | # Make image descriptor suitable for using 256 local color palette 461 | lid = self.getImageDescriptor(im, xys[frames]) 462 | 463 | # Write local header 464 | if (palette != globalPalette) or (disposes[frames] != 2): 465 | # Use local color palette 466 | fp.write(encode(graphext)) 467 | fp.write(encode(lid)) # write suitable image descriptor 468 | fp.write(palette) # write local color table 469 | fp.write(encode('\x08')) # LZW minimum size code 470 | else: 471 | # Use global color palette 472 | fp.write(encode(graphext)) 473 | fp.write(imdes) # write suitable image descriptor 474 | 475 | # Write image data 476 | for d in data: 477 | fp.write(d) 478 | 479 | # Prepare for next round 480 | frames = frames + 1 481 | 482 | fp.write(encode(";")) # end gif 483 | return frames 484 | 485 | 486 | 487 | 488 | ## Exposed functions 489 | 490 | def writeGif(filename, images, duration=0.1, repeat=True, dither=False, 491 | nq=0, subRectangles=True, dispose=None): 492 | """ writeGif(filename, images, duration=0.1, repeat=True, dither=False, 493 | nq=0, subRectangles=True, dispose=None) 494 | 495 | Write an animated gif from the specified images. 496 | 497 | Parameters 498 | ---------- 499 | filename : string 500 | The name of the file to write the image to. 501 | images : list 502 | Should be a list consisting of PIL images or numpy arrays. 503 | The latter should be between 0 and 255 for integer types, and 504 | between 0 and 1 for float types. 505 | duration : scalar or list of scalars 506 | The duration for all frames, or (if a list) for each frame. 507 | repeat : bool or integer 508 | The amount of loops. If True, loops infinitetely. 509 | dither : bool 510 | Whether to apply dithering 511 | nq : integer 512 | If nonzero, applies the NeuQuant quantization algorithm to create 513 | the color palette. This algorithm is superior, but slower than 514 | the standard PIL algorithm. The value of nq is the quality 515 | parameter. 1 represents the best quality. 10 is in general a 516 | good tradeoff between quality and speed. When using this option, 517 | better results are usually obtained when subRectangles is False. 518 | subRectangles : False, True, or a list of 2-element tuples 519 | Whether to use sub-rectangles. 
If True, the minimal rectangle that 520 | is required to update each frame is automatically detected. This 521 | can give significant reductions in file size, particularly if only 522 | a part of the image changes. One can also give a list of x-y 523 | coordinates if you want to do the cropping yourself. The default 524 | is True. 525 | dispose : int 526 | How to dispose each frame. 1 means that each frame is to be left 527 | in place. 2 means the background color should be restored after 528 | each frame. 3 means the decoder should restore the previous frame. 529 | If subRectangles==False, the default is 2, otherwise it is 1. 530 | 531 | """ 532 | 533 | # Check PIL 534 | if PIL is None: 535 | raise RuntimeError("Need PIL to write animated gif files.") 536 | 537 | # Check images 538 | images = checkImages(images) 539 | 540 | # Instantiate writer object 541 | gifWriter = GifWriter() 542 | 543 | # Check loops 544 | if repeat is False: 545 | loops = 1 546 | elif repeat is True: 547 | loops = 0 # zero means infinite 548 | else: 549 | loops = int(repeat) 550 | 551 | # Check duration 552 | if hasattr(duration, '__len__'): 553 | if len(duration) == len(images): 554 | duration = [d for d in duration] 555 | else: 556 | raise ValueError("len(duration) doesn't match amount of images.") 557 | else: 558 | duration = [duration for im in images] 559 | 560 | # Check subrectangles 561 | if subRectangles: 562 | images, xy = gifWriter.handleSubRectangles(images, subRectangles) 563 | defaultDispose = 1 # Leave image in place 564 | else: 565 | # Normal mode 566 | xy = [(0,0) for im in images] 567 | defaultDispose = 2 # Restore to background color. 568 | 569 | # Check dispose 570 | if dispose is None: 571 | dispose = defaultDispose 572 | if hasattr(dispose, '__len__'): 573 | if len(dispose) != len(images): 574 | raise ValueError("len(xy) doesn't match amount of images.") 575 | else: 576 | dispose = [dispose for im in images] 577 | 578 | 579 | # Make images in a format that we can write easy 580 | images = gifWriter.convertImagesToPIL(images, dither, nq) 581 | 582 | # Write 583 | fp = open(filename, 'wb') 584 | try: 585 | gifWriter.writeGifToFile(fp, images, duration, loops, xy, dispose) 586 | finally: 587 | fp.close() 588 | 589 | 590 | 591 | def readGif(filename, asNumpy=True): 592 | """ readGif(filename, asNumpy=True) 593 | 594 | Read images from an animated GIF file. Returns a list of numpy 595 | arrays, or, if asNumpy is false, a list if PIL images. 
596 | 597 | """ 598 | 599 | # Check PIL 600 | if PIL is None: 601 | raise RuntimeError("Need PIL to read animated gif files.") 602 | 603 | # Check Numpy 604 | if np is None: 605 | raise RuntimeError("Need Numpy to read animated gif files.") 606 | 607 | # Check whether it exists 608 | if not os.path.isfile(filename): 609 | raise IOError('File not found: '+str(filename)) 610 | 611 | # Load file using PIL 612 | pilIm = PIL.Image.open(filename) 613 | pilIm.seek(0) 614 | 615 | # Read all images inside 616 | images = [] 617 | try: 618 | while True: 619 | # Get image as numpy array 620 | tmp = pilIm.convert() # Make without palette 621 | a = np.asarray(tmp) 622 | if len(a.shape)==0: 623 | raise MemoryError("Too little memory to convert PIL image to array") 624 | # Store, and next 625 | images.append(a) 626 | pilIm.seek(pilIm.tell()+1) 627 | except EOFError: 628 | pass 629 | 630 | # Convert to normal PIL images if needed 631 | if not asNumpy: 632 | images2 = images 633 | images = [] 634 | for im in images2: 635 | images.append( PIL.Image.fromarray(im) ) 636 | 637 | # Done 638 | return images 639 | 640 | 641 | class NeuQuant: 642 | """ NeuQuant(image, samplefac=10, colors=256) 643 | 644 | samplefac should be an integer number of 1 or higher, 1 645 | being the highest quality, but the slowest performance. 646 | With avalue of 10, one tenth of all pixels are used during 647 | training. This value seems a nice tradeof between speed 648 | and quality. 649 | 650 | colors is the amount of colors to reduce the image to. This 651 | should best be a power of two. 652 | 653 | See also: 654 | http://members.ozemail.com.au/~dekker/NEUQUANT.HTML 655 | 656 | License of the NeuQuant Neural-Net Quantization Algorithm 657 | --------------------------------------------------------- 658 | 659 | Copyright (c) 1994 Anthony Dekker 660 | Ported to python by Marius van Voorden in 2010 661 | 662 | NEUQUANT Neural-Net quantization algorithm by Anthony Dekker, 1994. 663 | See "Kohonen neural networks for optimal colour quantization" 664 | in "network: Computation in Neural Systems" Vol. 5 (1994) pp 351-367. 665 | for a discussion of the algorithm. 666 | See also http://members.ozemail.com.au/~dekker/NEUQUANT.HTML 667 | 668 | Any party obtaining a copy of these files from the author, directly or 669 | indirectly, is granted, free of charge, a full and unrestricted irrevocable, 670 | world-wide, paid up, royalty-free, nonexclusive right and license to deal 671 | in this software and documentation files (the "Software"), including without 672 | limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, 673 | and/or sell copies of the Software, and to permit persons who receive 674 | copies from any such party to do so, with the only requirement being 675 | that this copyright notice remain intact. 
676 | 677 | """ 678 | 679 | NCYCLES = None # Number of learning cycles 680 | NETSIZE = None # Number of colours used 681 | SPECIALS = None # Number of reserved colours used 682 | BGCOLOR = None # Reserved background colour 683 | CUTNETSIZE = None 684 | MAXNETPOS = None 685 | 686 | INITRAD = None # For 256 colours, radius starts at 32 687 | RADIUSBIASSHIFT = None 688 | RADIUSBIAS = None 689 | INITBIASRADIUS = None 690 | RADIUSDEC = None # Factor of 1/30 each cycle 691 | 692 | ALPHABIASSHIFT = None 693 | INITALPHA = None # biased by 10 bits 694 | 695 | GAMMA = None 696 | BETA = None 697 | BETAGAMMA = None 698 | 699 | network = None # The network itself 700 | colormap = None # The network itself 701 | 702 | netindex = None # For network lookup - really 256 703 | 704 | bias = None # Bias and freq arrays for learning 705 | freq = None 706 | 707 | pimage = None 708 | 709 | # Four primes near 500 - assume no image has a length so large 710 | # that it is divisible by all four primes 711 | PRIME1 = 499 712 | PRIME2 = 491 713 | PRIME3 = 487 714 | PRIME4 = 503 715 | MAXPRIME = PRIME4 716 | 717 | pixels = None 718 | samplefac = None 719 | 720 | a_s = None 721 | 722 | 723 | def setconstants(self, samplefac, colors): 724 | self.NCYCLES = 100 # Number of learning cycles 725 | self.NETSIZE = colors # Number of colours used 726 | self.SPECIALS = 3 # Number of reserved colours used 727 | self.BGCOLOR = self.SPECIALS-1 # Reserved background colour 728 | self.CUTNETSIZE = self.NETSIZE - self.SPECIALS 729 | self.MAXNETPOS = self.NETSIZE - 1 730 | 731 | self.INITRAD = self.NETSIZE/8 # For 256 colours, radius starts at 32 732 | self.RADIUSBIASSHIFT = 6 733 | self.RADIUSBIAS = 1 << self.RADIUSBIASSHIFT 734 | self.INITBIASRADIUS = self.INITRAD * self.RADIUSBIAS 735 | self.RADIUSDEC = 30 # Factor of 1/30 each cycle 736 | 737 | self.ALPHABIASSHIFT = 10 # Alpha starts at 1 738 | self.INITALPHA = 1 << self.ALPHABIASSHIFT # biased by 10 bits 739 | 740 | self.GAMMA = 1024.0 741 | self.BETA = 1.0/1024.0 742 | self.BETAGAMMA = self.BETA * self.GAMMA 743 | 744 | self.network = np.empty((self.NETSIZE, 3), dtype='float64') # The network itself 745 | self.colormap = np.empty((self.NETSIZE, 4), dtype='int32') # The network itself 746 | 747 | self.netindex = np.empty(256, dtype='int32') # For network lookup - really 256 748 | 749 | self.bias = np.empty(self.NETSIZE, dtype='float64') # Bias and freq arrays for learning 750 | self.freq = np.empty(self.NETSIZE, dtype='float64') 751 | 752 | self.pixels = None 753 | self.samplefac = samplefac 754 | 755 | self.a_s = {} 756 | 757 | def __init__(self, image, samplefac=10, colors=256): 758 | 759 | # Check Numpy 760 | if np is None: 761 | raise RuntimeError("Need Numpy for the NeuQuant algorithm.") 762 | 763 | # Check image 764 | if image.size[0] * image.size[1] < NeuQuant.MAXPRIME: 765 | raise IOError("Image is too small") 766 | if image.mode != "RGBA": 767 | raise IOError("Image mode should be RGBA.") 768 | 769 | # Initialize 770 | self.setconstants(samplefac, colors) 771 | self.pixels = np.fromstring(image.tostring(), np.uint32) 772 | self.setUpArrays() 773 | 774 | self.learn() 775 | self.fix() 776 | self.inxbuild() 777 | 778 | def writeColourMap(self, rgb, outstream): 779 | for i in range(self.NETSIZE): 780 | bb = self.colormap[i,0]; 781 | gg = self.colormap[i,1]; 782 | rr = self.colormap[i,2]; 783 | outstream.write(rr if rgb else bb) 784 | outstream.write(gg) 785 | outstream.write(bb if rgb else rr) 786 | return self.NETSIZE 787 | 788 | def setUpArrays(self): 789 | 
self.network[0,0] = 0.0 # Black 790 | self.network[0,1] = 0.0 791 | self.network[0,2] = 0.0 792 | 793 | self.network[1,0] = 255.0 # White 794 | self.network[1,1] = 255.0 795 | self.network[1,2] = 255.0 796 | 797 | # RESERVED self.BGCOLOR # Background 798 | 799 | for i in range(self.SPECIALS): 800 | self.freq[i] = 1.0 / self.NETSIZE 801 | self.bias[i] = 0.0 802 | 803 | for i in range(self.SPECIALS, self.NETSIZE): 804 | p = self.network[i] 805 | p[:] = (255.0 * (i-self.SPECIALS)) / self.CUTNETSIZE 806 | 807 | self.freq[i] = 1.0 / self.NETSIZE 808 | self.bias[i] = 0.0 809 | 810 | # Omitted: setPixels 811 | 812 | def altersingle(self, alpha, i, b, g, r): 813 | """Move neuron i towards biased (b,g,r) by factor alpha""" 814 | n = self.network[i] # Alter hit neuron 815 | n[0] -= (alpha*(n[0] - b)) 816 | n[1] -= (alpha*(n[1] - g)) 817 | n[2] -= (alpha*(n[2] - r)) 818 | 819 | def geta(self, alpha, rad): 820 | try: 821 | return self.a_s[(alpha, rad)] 822 | except KeyError: 823 | length = rad*2-1 824 | mid = int(length//2) 825 | q = np.array(list(range(mid-1,-1,-1))+list(range(-1,mid))) 826 | a = alpha*(rad*rad - q*q)/(rad*rad) 827 | a[mid] = 0 828 | self.a_s[(alpha, rad)] = a 829 | return a 830 | 831 | def alterneigh(self, alpha, rad, i, b, g, r): 832 | if i-rad >= self.SPECIALS-1: 833 | lo = i-rad 834 | start = 0 835 | else: 836 | lo = self.SPECIALS-1 837 | start = (self.SPECIALS-1 - (i-rad)) 838 | 839 | if i+rad <= self.NETSIZE: 840 | hi = i+rad 841 | end = rad*2-1 842 | else: 843 | hi = self.NETSIZE 844 | end = (self.NETSIZE - (i+rad)) 845 | 846 | a = self.geta(alpha, rad)[start:end] 847 | 848 | p = self.network[lo+1:hi] 849 | p -= np.transpose(np.transpose(p - np.array([b, g, r])) * a) 850 | 851 | #def contest(self, b, g, r): 852 | # """ Search for biased BGR values 853 | # Finds closest neuron (min dist) and updates self.freq 854 | # finds best neuron (min dist-self.bias) and returns position 855 | # for frequently chosen neurons, self.freq[i] is high and self.bias[i] is negative 856 | # self.bias[i] = self.GAMMA*((1/self.NETSIZE)-self.freq[i])""" 857 | # 858 | # i, j = self.SPECIALS, self.NETSIZE 859 | # dists = abs(self.network[i:j] - np.array([b,g,r])).sum(1) 860 | # bestpos = i + np.argmin(dists) 861 | # biasdists = dists - self.bias[i:j] 862 | # bestbiaspos = i + np.argmin(biasdists) 863 | # self.freq[i:j] -= self.BETA * self.freq[i:j] 864 | # self.bias[i:j] += self.BETAGAMMA * self.freq[i:j] 865 | # self.freq[bestpos] += self.BETA 866 | # self.bias[bestpos] -= self.BETAGAMMA 867 | # return bestbiaspos 868 | def contest(self, b, g, r): 869 | """ Search for biased BGR values 870 | Finds closest neuron (min dist) and updates self.freq 871 | finds best neuron (min dist-self.bias) and returns position 872 | for frequently chosen neurons, self.freq[i] is high and self.bias[i] is negative 873 | self.bias[i] = self.GAMMA*((1/self.NETSIZE)-self.freq[i])""" 874 | i, j = self.SPECIALS, self.NETSIZE 875 | dists = abs(self.network[i:j] - np.array([b,g,r])).sum(1) 876 | bestpos = i + np.argmin(dists) 877 | biasdists = dists - self.bias[i:j] 878 | bestbiaspos = i + np.argmin(biasdists) 879 | self.freq[i:j] *= (1-self.BETA) 880 | self.bias[i:j] += self.BETAGAMMA * self.freq[i:j] 881 | self.freq[bestpos] += self.BETA 882 | self.bias[bestpos] -= self.BETAGAMMA 883 | return bestbiaspos 884 | 885 | 886 | 887 | 888 | def specialFind(self, b, g, r): 889 | for i in range(self.SPECIALS): 890 | n = self.network[i] 891 | if n[0] == b and n[1] == g and n[2] == r: 892 | return i 893 | return -1 894 | 895 | def 
learn(self): 896 | biasRadius = self.INITBIASRADIUS 897 | alphadec = 30 + ((self.samplefac-1)/3) 898 | lengthcount = self.pixels.size 899 | samplepixels = lengthcount / self.samplefac 900 | delta = samplepixels / self.NCYCLES 901 | alpha = self.INITALPHA 902 | 903 | i = 0; 904 | rad = biasRadius * 2**self.RADIUSBIASSHIFT 905 | if rad <= 1: 906 | rad = 0 907 | 908 | print("Beginning 1D learning: samplepixels = %1.2f rad = %i" % 909 | (samplepixels, rad) ) 910 | step = 0 911 | pos = 0 912 | if lengthcount%NeuQuant.PRIME1 != 0: 913 | step = NeuQuant.PRIME1 914 | elif lengthcount%NeuQuant.PRIME2 != 0: 915 | step = NeuQuant.PRIME2 916 | elif lengthcount%NeuQuant.PRIME3 != 0: 917 | step = NeuQuant.PRIME3 918 | else: 919 | step = NeuQuant.PRIME4 920 | 921 | i = 0 922 | printed_string = '' 923 | while i < samplepixels: 924 | if i%100 == 99: 925 | tmp = '\b'*len(printed_string) 926 | printed_string = str((i+1)*100/samplepixels)+"%\n" 927 | print(tmp + printed_string) 928 | p = self.pixels[pos] 929 | r = (p >> 16) & 0xff 930 | g = (p >> 8) & 0xff 931 | b = (p ) & 0xff 932 | 933 | if i == 0: # Remember background colour 934 | self.network[self.BGCOLOR] = [b, g, r] 935 | 936 | j = self.specialFind(b, g, r) 937 | if j < 0: 938 | j = self.contest(b, g, r) 939 | 940 | if j >= self.SPECIALS: # Don't learn for specials 941 | a = (1.0 * alpha) / self.INITALPHA 942 | self.altersingle(a, j, b, g, r) 943 | if rad > 0: 944 | self.alterneigh(a, rad, j, b, g, r) 945 | 946 | pos = (pos+step)%lengthcount 947 | 948 | i += 1 949 | if i%delta == 0: 950 | alpha -= alpha / alphadec 951 | biasRadius -= biasRadius / self.RADIUSDEC 952 | rad = biasRadius * 2**self.RADIUSBIASSHIFT 953 | if rad <= 1: 954 | rad = 0 955 | 956 | finalAlpha = (1.0*alpha)/self.INITALPHA 957 | print("Finished 1D learning: final alpha = %1.2f!" % finalAlpha) 958 | 959 | def fix(self): 960 | for i in range(self.NETSIZE): 961 | for j in range(3): 962 | x = int(0.5 + self.network[i,j]) 963 | x = max(0, x) 964 | x = min(255, x) 965 | self.colormap[i,j] = x 966 | self.colormap[i,3] = i 967 | 968 | def inxbuild(self): 969 | previouscol = 0 970 | startpos = 0 971 | for i in range(self.NETSIZE): 972 | p = self.colormap[i] 973 | q = None 974 | smallpos = i 975 | smallval = p[1] # Index on g 976 | # Find smallest in i..self.NETSIZE-1 977 | for j in range(i+1, self.NETSIZE): 978 | q = self.colormap[j] 979 | if q[1] < smallval: # Index on g 980 | smallpos = j 981 | smallval = q[1] # Index on g 982 | 983 | q = self.colormap[smallpos] 984 | # Swap p (i) and q (smallpos) entries 985 | if i != smallpos: 986 | p[:],q[:] = q, p.copy() 987 | 988 | # smallval entry is now in position i 989 | if smallval != previouscol: 990 | self.netindex[previouscol] = (startpos+i) >> 1 991 | for j in range(previouscol+1, smallval): 992 | self.netindex[j] = i 993 | previouscol = smallval 994 | startpos = i 995 | self.netindex[previouscol] = (startpos+self.MAXNETPOS) >> 1 996 | for j in range(previouscol+1, 256): # Really 256 997 | self.netindex[j] = self.MAXNETPOS 998 | 999 | 1000 | def paletteImage(self): 1001 | """ PIL weird interface for making a paletted image: create an image which 1002 | already has the palette, and use that in Image.quantize. This function 1003 | returns this palette image. 
""" 1004 | if self.pimage is None: 1005 | palette = [] 1006 | for i in range(self.NETSIZE): 1007 | palette.extend(self.colormap[i][:3]) 1008 | 1009 | palette.extend([0]*(256-self.NETSIZE)*3) 1010 | 1011 | # a palette image to use for quant 1012 | self.pimage = Image.new("P", (1, 1), 0) 1013 | self.pimage.putpalette(palette) 1014 | return self.pimage 1015 | 1016 | 1017 | def quantize(self, image): 1018 | """ Use a kdtree to quickly find the closest palette colors for the pixels """ 1019 | if get_cKDTree(): 1020 | return self.quantize_with_scipy(image) 1021 | else: 1022 | print('Scipy not available, falling back to slower version.') 1023 | return self.quantize_without_scipy(image) 1024 | 1025 | 1026 | def quantize_with_scipy(self, image): 1027 | w,h = image.size 1028 | px = np.asarray(image).copy() 1029 | px2 = px[:,:,:3].reshape((w*h,3)) 1030 | 1031 | cKDTree = get_cKDTree() 1032 | kdtree = cKDTree(self.colormap[:,:3],leafsize=10) 1033 | result = kdtree.query(px2) 1034 | colorindex = result[1] 1035 | print("Distance: %1.2f" % (result[0].sum()/(w*h)) ) 1036 | px2[:] = self.colormap[colorindex,:3] 1037 | 1038 | return Image.fromarray(px).convert("RGB").quantize(palette=self.paletteImage()) 1039 | 1040 | 1041 | def quantize_without_scipy(self, image): 1042 | """" This function can be used if no scipy is availabe. 1043 | It's 7 times slower though. 1044 | """ 1045 | w,h = image.size 1046 | px = np.asarray(image).copy() 1047 | memo = {} 1048 | for j in range(w): 1049 | for i in range(h): 1050 | key = (px[i,j,0],px[i,j,1],px[i,j,2]) 1051 | try: 1052 | val = memo[key] 1053 | except KeyError: 1054 | val = self.convert(*key) 1055 | memo[key] = val 1056 | px[i,j,0],px[i,j,1],px[i,j,2] = val 1057 | return Image.fromarray(px).convert("RGB").quantize(palette=self.paletteImage()) 1058 | 1059 | def convert(self, *color): 1060 | i = self.inxsearch(*color) 1061 | return self.colormap[i,:3] 1062 | 1063 | def inxsearch(self, r, g, b): 1064 | """Search for BGR values 0..255 and return colour index""" 1065 | dists = (self.colormap[:,:3] - np.array([r,g,b])) 1066 | a= np.argmin((dists*dists).sum(1)) 1067 | return a 1068 | 1069 | 1070 | 1071 | if __name__ == '__main__': 1072 | im = np.zeros((200,200), dtype=np.uint8) 1073 | im[10:30,:] = 100 1074 | im[:,80:120] = 255 1075 | im[-50:-40,:] = 50 1076 | 1077 | images = [im*1.0, im*0.8, im*0.6, im*0.4, im*0] 1078 | writeGif('lala3.gif',images, duration=0.5, dither=0) 1079 | -------------------------------------------------------------------------------- /mnist_data.py: -------------------------------------------------------------------------------- 1 | """Functions for downloading and reading MNIST data.""" 2 | ''' 3 | modified for the purpose of cppgan demo 4 | ''' 5 | import gzip 6 | import os 7 | import urllib 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/' 11 | def maybe_download(filename, work_directory): 12 | """Download the data from Yann's website, unless it's already here.""" 13 | if not os.path.exists(work_directory): 14 | os.mkdir(work_directory) 15 | filepath = os.path.join(work_directory, filename) 16 | if not os.path.exists(filepath): 17 | filepath, _ = urllib.urlretrieve(SOURCE_URL + filename, filepath) 18 | statinfo = os.stat(filepath) 19 | print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.') 20 | return filepath 21 | def _read32(bytestream): 22 | dt = np.dtype(np.uint32).newbyteorder('>') 23 | return np.frombuffer(bytestream.read(4), dtype=dt) 24 | def 
extract_images(filename): 25 | """Extract the images into a 4D uint8 numpy array [index, y, x, depth].""" 26 | print('Extracting', filename) 27 | with gzip.open(filename) as bytestream: 28 | magic = _read32(bytestream) 29 | if magic != 2051: 30 | raise ValueError( 31 | 'Invalid magic number %d in MNIST image file: %s' % 32 | (magic, filename)) 33 | num_images = _read32(bytestream) 34 | rows = _read32(bytestream) 35 | cols = _read32(bytestream) 36 | buf = bytestream.read(rows * cols * num_images) 37 | data = np.frombuffer(buf, dtype=np.uint8) 38 | data = data.reshape(num_images, rows, cols, 1) 39 | return data 40 | def dense_to_one_hot(labels_dense, num_classes=10): 41 | """Convert class labels from scalars to one-hot vectors.""" 42 | num_labels = labels_dense.shape[0] 43 | index_offset = np.arange(num_labels) * num_classes 44 | labels_one_hot = np.zeros((num_labels, num_classes)) 45 | labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1 46 | return labels_one_hot 47 | def extract_labels(filename, one_hot=False): 48 | """Extract the labels into a 1D uint8 numpy array [index].""" 49 | print('Extracting', filename) 50 | with gzip.open(filename) as bytestream: 51 | magic = _read32(bytestream) 52 | if magic != 2049: 53 | raise ValueError( 54 | 'Invalid magic number %d in MNIST label file: %s' % 55 | (magic, filename)) 56 | num_items = _read32(bytestream) 57 | buf = bytestream.read(num_items) 58 | labels = np.frombuffer(buf, dtype=np.uint8) 59 | if one_hot: 60 | return dense_to_one_hot(labels) 61 | return labels 62 | 63 | # class to store mnist data 64 | class DataSet(object): 65 | def __init__(self, images, labels): 66 | # Convert from [0, 255] -> [0.0, 1.0]. 67 | images = images.astype(np.float32) 68 | images = np.multiply(images, 1.0 / 255.0) 69 | self._num_examples = len(images) 70 | perm = np.arange(self._num_examples) 71 | np.random.shuffle(perm) 72 | self._images = images[perm] 73 | self._labels = labels[perm] 74 | self._epochs_completed = 0 75 | self._index_in_epoch = 0 76 | 77 | @property 78 | def images(self): 79 | return self._images 80 | @property 81 | def labels(self): 82 | return self._labels 83 | @property 84 | def num_examples(self): 85 | return self._num_examples 86 | @property 87 | def epochs_completed(self): 88 | return self._epochs_completed 89 | def next_batch(self, batch_size, with_label = False): 90 | """Return the next `batch_size` examples from this data set.""" 91 | start = self._index_in_epoch 92 | self._index_in_epoch += batch_size 93 | if self._index_in_epoch > self._num_examples: 94 | # Finished epoch 95 | self._epochs_completed += 1 96 | # Shuffle the data 97 | perm = np.arange(self._num_examples) 98 | np.random.shuffle(perm) 99 | self._images = self._images[perm] 100 | self._labels = self._labels[perm] 101 | # Start next epoch 102 | start = 0 103 | self._index_in_epoch = batch_size 104 | assert batch_size <= self._num_examples 105 | end = self._index_in_epoch 106 | if with_label == True: 107 | return self.distort_batch(self._images[start:end]), self._labels[start:end] 108 | return self.distort_batch(self._images[start:end]) 109 | def distort_batch(self, batch): 110 | batch_size = len(batch) 111 | row_distort = np.random.randint(0, 3, batch_size) 112 | col_distort = np.random.randint(0, 3, batch_size) 113 | result = np.zeros(shape=(batch_size, 26, 26, 1), dtype=np.float32) 114 | for i in range(batch_size): 115 | result[i, :, :, :] = batch[i, row_distort[i]:row_distort[i]+26, col_distort[i]:col_distort[i]+26, :] 116 | return result 117 | def show_image(self, 
image): 118 | plt.subplot(1, 1, 1) 119 | plt.imshow(np.reshape(image, (26, 26)), cmap='Greys', interpolation='nearest') 120 | plt.axis('off') 121 | plt.show() 122 | def shuffle_data(self): 123 | perm = np.arange(self._num_examples) 124 | np.random.shuffle(perm) 125 | self._images = self._images[perm] 126 | self._labels = self._labels[perm] 127 | 128 | def read_data_sets(train_dir = 'MNIST_data', one_hot=False): 129 | TRAIN_IMAGES = 'train-images-idx3-ubyte.gz' 130 | TRAIN_LABELS = 'train-labels-idx1-ubyte.gz' 131 | TEST_IMAGES = 't10k-images-idx3-ubyte.gz' 132 | TEST_LABELS = 't10k-labels-idx1-ubyte.gz' 133 | VALIDATION_SIZE = 5000 134 | local_file = maybe_download(TRAIN_IMAGES, train_dir) 135 | train_images = extract_images(local_file) 136 | local_file = maybe_download(TRAIN_LABELS, train_dir) 137 | train_labels = extract_labels(local_file, one_hot=one_hot) 138 | local_file = maybe_download(TEST_IMAGES, train_dir) 139 | test_images = extract_images(local_file) 140 | local_file = maybe_download(TEST_LABELS, train_dir) 141 | test_labels = extract_labels(local_file, one_hot=one_hot) 142 | 143 | all_images = np.vstack((train_images, test_images)) 144 | all_labels = np.concatenate((train_labels, test_labels)) 145 | 146 | data_sets = DataSet(all_images, all_labels) # train on all train+test sets 70k 147 | #data_sets = DataSet(train_images, train_labels) # train only only train set 60k 148 | return data_sets 149 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import tensorflow as tf 5 | from ops import * 6 | 7 | ''' 8 | cppgan-vae 9 | 10 | compositional pattern-producing generative adversarial network combined with variational autoencoder 11 | 12 | UPDATE: modified to get rid of likelihood function, and use pure GAN to draw to correct class. 13 | 14 | I learned a lot from studying the below pages: 15 | 16 | https://github.com/carpedm20/DCGAN-tensorflow 17 | https://jmetzen.github.io/2015-11-27/vae.html 18 | 19 | it wouldn't have been possible without referencing those two guy's code! 20 | 21 | Description of CPPNs: 22 | 23 | https://en.wikipedia.org/wiki/Compositional_pattern-producing_network 24 | 25 | ''' 26 | 27 | class CPPNVAE(): 28 | def __init__(self, batch_size=1, z_dim=32, 29 | x_dim = 26, y_dim = 26, c_dim = 1, scale = 8.0, 30 | learning_rate_g= 0.01, learning_rate_d= 0.001, learning_rate_vae = 0.0001, beta1 = 0.9, net_size_g = 6, net_depth_g = 24, subnet_depth_g = 4, 31 | net_size_q = 512, keep_prob = 1.0, df_dim = 32, model_name = "cppnvae", grad_clip = 5.0): 32 | """ 33 | 34 | Args: 35 | z_dim dimensionality of the latent vector 36 | x_dim, y_dim default resolution of generated images for training 37 | c_dim 1 for monotone, 3 for colour 38 | learning_rate_g learning rate for the generator 39 | _d learning rate for the discriminiator 40 | _vae learning rate for the variational autoencoder 41 | net_size_g number of activations per layer for cppn generator function 42 | net_depth_g depth of generator 43 | net_size_q number of activations per layer for decoder (real image -> z). 2 layers. 44 | df_dim discriminiator is a convnet. higher -> more activtions -> smarter. 
45 | keep_prob dropout probability 46 | 47 | when training, use I used dropout on training the decoder, batch norm on discriminator, nothing on cppn 48 | choose training parameters so that over the long run, decoder and encoder log errors hover around 0.7 each (so they are at the same skill level) 49 | while the error for vae should slowly move lower over time with D and G balanced. 50 | 51 | """ 52 | 53 | self.batch_size = batch_size 54 | self.learning_rate_g = learning_rate_g 55 | self.learning_rate_d = learning_rate_d 56 | self.learning_rate_vae = learning_rate_vae 57 | self.beta1 = beta1 58 | self.net_size_g = net_size_g 59 | self.net_size_q = net_size_q 60 | self.x_dim = x_dim 61 | self.y_dim = y_dim 62 | self.scale = scale 63 | self.c_dim = c_dim 64 | self.z_dim = z_dim 65 | self.net_depth_g = net_depth_g 66 | self.subnet_depth_g = subnet_depth_g 67 | self.model_name = model_name 68 | self.keep_prob = keep_prob 69 | self.df_dim = df_dim 70 | self.num_class = 11 # 0->9 are MNIST classes, 10 are fake digits. 71 | self.grad_clip = grad_clip 72 | 73 | # tf Graph batch of image (batch_size, height, width, depth) 74 | self.batch = tf.placeholder(tf.float32, [batch_size, x_dim, y_dim, c_dim]) 75 | self.batch_flatten = tf.reshape(self.batch, [batch_size, -1]) 76 | self.batch_label = tf.placeholder(tf.float32, [batch_size, self.num_class]) # mnist labels for the batch (one-hot) 77 | self.fake_label = np.array(self.batch_size*[10], dtype=np.int32) # label of fake batches 78 | self.fake_label_one_hot = self.to_one_hot(self.fake_label) 79 | 80 | n_points = x_dim * y_dim 81 | self.n_points = n_points 82 | 83 | self.x_vec, self.y_vec, self.r_vec = self.coordinates(x_dim, y_dim, scale) 84 | 85 | # latent vector 86 | # self.z = tf.placeholder(tf.float32, [self.batch_size, self.z_dim]) 87 | # inputs to cppn, like coordinates and radius from centre 88 | self.x = tf.placeholder(tf.float32, [self.batch_size, None, 1]) 89 | self.y = tf.placeholder(tf.float32, [self.batch_size, None, 1]) 90 | self.r = tf.placeholder(tf.float32, [self.batch_size, None, 1]) 91 | 92 | # batch normalization : deals with poor initialization helps gradient flow 93 | self.d_bn1 = batch_norm(batch_size, name=self.model_name+'_d_bn1') 94 | self.d_bn2 = batch_norm(batch_size, name=self.model_name+'_d_bn2') 95 | 96 | # Use recognition network to determine mean and 97 | # (log) variance of Gaussian distribution in latent 98 | # space 99 | self.z_mean, self.z_log_sigma_sq = self.encoder() 100 | 101 | # Draw one sample z from Gaussian distribution 102 | eps = tf.random_normal((self.batch_size, self.z_dim), 0, 1, dtype=tf.float32) 103 | # z = mu + sigma*epsilon 104 | self.z = tf.add(self.z_mean, tf.mul(tf.sqrt(tf.exp(self.z_log_sigma_sq)), eps)) 105 | 106 | # Use generator to determine mean of 107 | # Bernoulli distribution of reconstructed input 108 | self.G = self.generator() 109 | #self.batch_reconstruct_flatten = tf.reshape(self.G, [batch_size, -1]) # not needed 110 | 111 | self.predict_real_samples = self.discriminator(self.batch) # discriminiator on correct examples 112 | self.predict_fake_samples = self.discriminator(self.G, reuse=True) # feed generated images into D 113 | 114 | self.create_vae_loss_terms() 115 | self.create_gan_loss_terms() 116 | 117 | self.balanced_loss = 1.0 * self.g_loss + 10.0 * self.vae_loss # can try to weight these. 
118 | 119 | self.t_vars = tf.trainable_variables() 120 | 121 | self.q_vars = [var for var in self.t_vars if (self.model_name+'_q_') in var.name] 122 | self.g_vars = [var for var in self.t_vars if (self.model_name+'_g_') in var.name] 123 | self.d_vars = [var for var in self.t_vars if (self.model_name+'_d_') in var.name] 124 | self.both_vars = self.q_vars+self.g_vars 125 | #self.vae_vars = self.q_vars # in this version, g_vars don't concern vae_loss 126 | 127 | # clip gradients 128 | d_opt_real_grads, _ = tf.clip_by_global_norm(tf.gradients(self.d_loss_real, self.d_vars), self.grad_clip) 129 | d_opt_grads, _ = tf.clip_by_global_norm(tf.gradients(self.d_loss, self.d_vars), self.grad_clip) 130 | g_opt_grads, _ = tf.clip_by_global_norm(tf.gradients(self.balanced_loss, self.both_vars), self.grad_clip) 131 | vae_opt_grads, _ = tf.clip_by_global_norm(tf.gradients(self.vae_loss, self.q_vars), self.grad_clip) 132 | 133 | d_real_optimizer = tf.train.AdamOptimizer(self.learning_rate_d, beta1=self.beta1) 134 | d_optimizer = tf.train.AdamOptimizer(self.learning_rate_d, beta1=self.beta1) 135 | g_optimizer = tf.train.AdamOptimizer(self.learning_rate_g, beta1=self.beta1) 136 | vae_optimizer = tf.train.AdamOptimizer(self.learning_rate_vae, beta1=self.beta1) 137 | 138 | self.d_opt_real = d_real_optimizer.apply_gradients(zip(d_opt_real_grads, self.d_vars)) 139 | self.d_opt = d_optimizer.apply_gradients(zip(d_opt_grads, self.d_vars)) 140 | self.g_opt = g_optimizer.apply_gradients(zip(g_opt_grads, self.both_vars)) 141 | self.vae_opt = vae_optimizer.apply_gradients(zip(vae_opt_grads, self.q_vars)) 142 | 143 | ''' 144 | self.d_opt_real = tf.train.AdamOptimizer(self.learning_rate_d, beta1=self.beta1) \ 145 | .minimize(self.d_loss_real, var_list=self.d_vars) 146 | self.d_opt = tf.train.AdamOptimizer(self.learning_rate_d, beta1=self.beta1) \ 147 | .minimize(self.d_loss, var_list=self.d_vars) 148 | self.g_opt = tf.train.AdamOptimizer(self.learning_rate_g, beta1=self.beta1) \ 149 | .minimize(self.g_loss, var_list=self.both_vars) 150 | self.vae_opt = tf.train.AdamOptimizer(self.learning_rate_vae, beta1=self.beta1) \ 151 | .minimize(self.vae_loss, var_list=self.q_vars) 152 | ''' 153 | 154 | ''' 155 | tvars = tf.trainable_variables() 156 | grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars), args.grad_clip) 157 | optimizer = tf.train.AdamOptimizer(self.lr, epsilon=0.001) 158 | self.train_op = optimizer.apply_gradients(zip(grads, tvars)) 159 | ''' 160 | 161 | self.init() 162 | self.saver = tf.train.Saver(tf.all_variables()) 163 | 164 | def init(self): 165 | 166 | # Initializing the tensor flow variables 167 | init = tf.initialize_all_variables() 168 | # Launch the session 169 | self.sess = tf.InteractiveSession() 170 | self.sess.run(init) 171 | 172 | def reinit(self): 173 | init = tf.initialize_variables(tf.trainable_variables()) 174 | self.sess.run(init) 175 | 176 | def to_one_hot(self, label): 177 | # convert labels, a numpy list of labels (of size batch_size) to the one hot equivalent 178 | return np.eye(self.num_class)[label] 179 | 180 | def create_vae_loss_terms(self): 181 | # The loss is composed of two terms: 182 | # 1.) The reconstruction loss (the negative log probability 183 | # of the input under the reconstructed Bernoulli distribution 184 | # induced by the decoder in the data space). 185 | # This can be interpreted as the number of "nats" required 186 | # for reconstructing the input when the activation in latent 187 | # is given. 
188 | # Adding 1e-10 to avoid evaluatio of log(0.0) 189 | 190 | # stop using likelihood function for similarity 191 | # reconstr_loss = \ 192 | # -tf.reduce_sum(self.batch_flatten * tf.log(1e-10 + self.batch_reconstruct_flatten) 193 | # + (1-self.batch_flatten) * tf.log(1e-10 + 1 - self.batch_reconstruct_flatten), 1) 194 | 195 | # 2.) The latent loss, which is defined as the Kullback Leibler divergence 196 | ## between the distribution in latent space induced by the encoder on 197 | # the data and some prior. This acts as a kind of regularizer. 198 | # This can be interpreted as the number of "nats" required 199 | # for transmitting the the latent space distribution given 200 | # the prior. 201 | 202 | latent_loss = -0.5 * tf.reduce_sum(1 + self.z_log_sigma_sq 203 | - tf.square(self.z_mean) 204 | - tf.exp(self.z_log_sigma_sq), 1) 205 | 206 | #self.vae_loss = tf.reduce_mean(reconstr_loss + latent_loss) / self.n_points # average over batch and pixel 207 | 208 | # vae loss is now purely kl divergence loss term. let GAN take care of mnist class accuracy. 209 | self.vae_loss = tf.reduce_mean(latent_loss) / self.n_points # average over batch and pixel 210 | 211 | def create_gan_loss_terms(self): 212 | # Define loss function and optimiser 213 | ''' replace below with class-based disriminiator 214 | self.d_loss_real = binary_cross_entropy_with_logits(tf.ones_like(self.D_right), self.D_right) 215 | self.d_loss_fake = binary_cross_entropy_with_logits(tf.zeros_like(self.D_wrong), self.D_wrong) 216 | self.d_loss = 1.0*(self.d_loss_real + self.d_loss_fake)/ 2.0 217 | self.g_loss = 1.0*binary_cross_entropy_with_logits(tf.ones_like(self.D_wrong), self.D_wrong) 218 | ''' 219 | 220 | # cross entropy loss of predicting real mnist to real classes 221 | self.d_loss_real = tf.reduce_mean(-tf.reduce_sum(self.batch_label * tf.log(self.predict_real_samples), reduction_indices=[1])) 222 | # accuracy of using discriminiator as a normal mnist classifier 223 | self.d_loss_real_accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(self.predict_real_samples,1), tf.argmax(self.batch_label,1)), tf.float32)) 224 | # cross entropy loss of predicting that fake generated mnist are in fact fake 225 | self.d_loss_fake = tf.reduce_mean(-tf.reduce_sum(self.fake_label_one_hot * tf.log(self.predict_fake_samples), reduction_indices=[1])) 226 | # accuracy of discriminator predicting a fake mnist digit 227 | self.d_loss_fake_accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(self.predict_fake_samples,1), tf.argmax(self.fake_label_one_hot,1)), tf.float32)) 228 | # take the average of two d_loss to be the defacto d_loss 229 | self.d_loss = (10.0*self.d_loss_real + self.d_loss_fake)/ 11.0 # balanc out the classes 230 | # cross entropy of generator fooling discriminiator that its shit is real. 
231 | self.g_loss = tf.reduce_mean(-tf.reduce_sum(self.batch_label * tf.log(self.predict_fake_samples), reduction_indices=[1])) 232 | # accuracy of generated samples being fooled to be classified as their supposed ground truth labels 233 | self.g_loss_accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(self.predict_fake_samples,1), tf.argmax(self.batch_label,1)), tf.float32)) 234 | 235 | def coordinates(self, x_dim = 32, y_dim = 32, scale = 1.0): 236 | n_pixel = x_dim * y_dim 237 | x_range = scale*(np.arange(x_dim)-(x_dim-1)/2.0)/(x_dim-1)/0.5 238 | y_range = scale*(np.arange(y_dim)-(y_dim-1)/2.0)/(y_dim-1)/0.5 239 | x_mat = np.matmul(np.ones((y_dim, 1)), x_range.reshape((1, x_dim))) 240 | y_mat = np.matmul(y_range.reshape((y_dim, 1)), np.ones((1, x_dim))) 241 | r_mat = np.sqrt(x_mat*x_mat + y_mat*y_mat) 242 | x_mat = np.tile(x_mat.flatten(), self.batch_size).reshape(self.batch_size, n_pixel, 1) 243 | y_mat = np.tile(y_mat.flatten(), self.batch_size).reshape(self.batch_size, n_pixel, 1) 244 | r_mat = np.tile(r_mat.flatten(), self.batch_size).reshape(self.batch_size, n_pixel, 1) 245 | return x_mat, y_mat, r_mat 246 | 247 | def show_image(self, image): 248 | ''' 249 | image is in [height width depth] 250 | ''' 251 | plt.subplot(1, 1, 1) 252 | y_dim = image.shape[0] 253 | x_dim = image.shape[1] 254 | if self.c_dim > 1: 255 | plt.imshow(image, interpolation='nearest') 256 | else: 257 | plt.imshow(image.reshape(y_dim, x_dim), cmap='Greys', interpolation='nearest') 258 | plt.axis('off') 259 | plt.show() 260 | 261 | def encoder(self): 262 | # Generate probabilistic encoder (recognition network), which 263 | # maps inputs onto a normal distribution in latent space. 264 | # The transformation is parametrized and can be learned. 265 | H1 = tf.nn.dropout(tf.nn.softplus(linear(self.batch_flatten, self.net_size_q, self.model_name+'_q_lin1')), self.keep_prob) 266 | H2 = tf.nn.dropout(tf.nn.softplus(linear(H1, self.net_size_q, self.model_name+'_q_lin2')), self.keep_prob) 267 | z_mean = linear(H2, self.z_dim, self.model_name+'_q_lin3_mean') 268 | z_log_sigma_sq = linear(H2, self.z_dim, self.model_name+'_q_lin3_log_sigma_sq') 269 | return (z_mean, z_log_sigma_sq) 270 | 271 | def discriminator(self, image, reuse=False): 272 | 273 | if reuse: 274 | tf.get_variable_scope().reuse_variables() 275 | 276 | h0 = lrelu(conv2d(image, self.df_dim, name=self.model_name+'_d_h0_conv')) 277 | h1 = lrelu(self.d_bn1(conv2d(h0, self.df_dim*2, name=self.model_name+'_d_h1_conv'))) 278 | h2 = lrelu(self.d_bn2(conv2d(h1, self.df_dim*4, name=self.model_name+'_d_h2_conv'))) 279 | h3 = linear(tf.reshape(h2, [self.batch_size, -1]), self.num_class, self.model_name+'_d_h2_lin') 280 | 281 | return tf.nn.softmax(h3) 282 | 283 | def generator(self, gen_x_dim = 26, gen_y_dim = 26, reuse = False): 284 | 285 | if reuse: 286 | tf.get_variable_scope().reuse_variables() 287 | 288 | n_network = self.net_size_g 289 | gen_n_points = gen_x_dim * gen_y_dim 290 | 291 | z_scaled = tf.reshape(self.z, [self.batch_size, 1, self.z_dim]) * \ 292 | tf.ones([gen_n_points, 1], dtype=tf.float32) * self.scale 293 | z_unroll = tf.reshape(z_scaled, [self.batch_size*gen_n_points, self.z_dim]) 294 | x_unroll = tf.reshape(self.x, [self.batch_size*gen_n_points, 1]) 295 | y_unroll = tf.reshape(self.y, [self.batch_size*gen_n_points, 1]) 296 | r_unroll = tf.reshape(self.r, [self.batch_size*gen_n_points, 1]) 297 | 298 | U = fully_connected(z_unroll, n_network, self.model_name+'_g_0_z') + \ 299 | fully_connected(x_unroll, n_network, self.model_name+'_g_0_x', with_bias 
= False) + \ 300 | fully_connected(y_unroll, n_network, self.model_name+'_g_0_y', with_bias = False) + \ 301 | fully_connected(r_unroll, n_network, self.model_name+'_g_0_r', with_bias = False) 302 | 303 | #H = tf.nn.relu(U) 304 | H = tf.nn.tanh(U) 305 | 306 | for i in range(0, self.net_depth_g): 307 | H0 = H 308 | for j in range(0, self.subnet_depth_g): 309 | H0 = tf.nn.relu(fully_connected(H0, n_network, self.model_name+'_g_relu_skip_'+str(i)+'_'+str(j), stddev = 1.0)) 310 | H0 = tf.nn.tanh(fully_connected(H0, n_network, self.model_name+'_g_tanh_skip_'+str(i), stddev = 0.001)) 311 | H = H + H0 312 | 313 | output = tf.sigmoid(fully_connected(H, self.c_dim, self.model_name+'_g_'+str(self.net_depth_g))) 314 | 315 | result = tf.reshape(output, [self.batch_size, gen_y_dim, gen_x_dim, self.c_dim]) 316 | 317 | return result 318 | 319 | 320 | def partial_train(self, batch, label): 321 | """Train the model on a mini-batch of input data. 322 | 323 | Return the costs for the mini-batch. 324 | 325 | The tricks below should really be separated out into parameters, such as the number of passes 326 | per network and the accuracy thresholds that regulate when each network gets trained. 327 | """ 328 | 329 | counter = 0 330 | 331 | label_one_hot = self.to_one_hot(label) 332 | 333 | ''' 334 | for i in range(1): 335 | counter += 1 336 | _, vae_loss = self.sess.run((self.vae_opt, self.vae_loss), 337 | feed_dict={self.batch: batch, self.x: self.x_vec, self.y: self.y_vec, self.r: self.r_vec, self.batch_label: label_one_hot}) 338 | ''' 339 | 340 | for i in range(16): 341 | counter += 1 342 | _, g_loss, vae_loss, g_accuracy = self.sess.run((self.g_opt, self.g_loss, self.vae_loss, self.g_loss_accuracy), 343 | feed_dict={self.batch: batch, self.x: self.x_vec, self.y: self.y_vec, self.r: self.r_vec, self.batch_label: label_one_hot}) 344 | if g_accuracy > 0.98: 345 | break 346 | 347 | # train the classifier on real mnist digits only 348 | # _ = self.sess.run((self.d_opt_real), feed_dict={self.batch: batch, self.x: self.x_vec, self.y: self.y_vec, self.r: self.r_vec, self.batch_label: label_one_hot}) 349 | 350 | # calculate accuracy before deciding whether to train the discriminator 351 | d_loss, d_loss_real, d_loss_fake, d_real_accuracy, d_fake_accuracy = self.sess.run((self.d_loss, self.d_loss_real, self.d_loss_fake, self.d_loss_real_accuracy, self.d_loss_fake_accuracy), 352 | feed_dict={self.batch: batch, self.x: self.x_vec, self.y: self.y_vec, self.r: self.r_vec, self.batch_label: label_one_hot}) 353 | 354 | if d_fake_accuracy < 0.7 and g_accuracy > 0.6: # only train the discriminator if the generator is good and d is lagging behind. 
355 | for i in range(8): 356 | counter += 1 357 | _, d_loss, d_loss_real, d_loss_fake, d_real_accuracy, d_fake_accuracy = self.sess.run((self.d_opt, self.d_loss, self.d_loss_real, self.d_loss_fake, self.d_loss_real_accuracy, self.d_loss_fake_accuracy), 358 | feed_dict={self.batch: batch, self.x: self.x_vec, self.y: self.y_vec, self.r: self.r_vec, self.batch_label: label_one_hot}) 359 | if d_fake_accuracy > 0.75: 360 | break 361 | elif d_real_accuracy < 0.6: 362 | for i in range(8): 363 | counter += 1 364 | _, d_real_accuracy = self.sess.run((self.d_opt_real, self.d_loss_real_accuracy), feed_dict={self.batch: batch, self.x: self.x_vec, self.y: self.y_vec, self.r: self.r_vec, self.batch_label: label_one_hot}) 365 | if d_real_accuracy > 0.7: 366 | break 367 | 368 | return d_loss, g_loss, vae_loss, counter, d_real_accuracy, d_fake_accuracy, g_accuracy, d_loss_real, d_loss_fake 369 | 370 | def encode(self, X): 371 | """Transform data by mapping it into the latent space.""" 372 | # Note: This maps to mean of distribution, we could alternatively 373 | # sample from Gaussian distribution 374 | return self.sess.run(self.z_mean, feed_dict={self.batch: X}) 375 | 376 | def generate(self, z=None, x_dim = 26, y_dim = 26, scale = 5.0): 377 | """ Generate data by sampling from latent space. 378 | 379 | If z is not None, data for this point in latent space is 380 | generated. Otherwise, z is drawn from prior in latent 381 | space. 382 | """ 383 | if z is None: 384 | z = np.random.normal(size=self.z_dim).astype(np.float32) 385 | # Note: This maps to mean of distribution, we could alternatively 386 | # sample from Gaussian distribution 387 | 388 | z = np.reshape(z, (self.batch_size, self.z_dim)) 389 | 390 | G = self.generator(gen_x_dim = x_dim, gen_y_dim = y_dim, reuse = True) 391 | gen_x_vec, gen_y_vec, gen_r_vec = self.coordinates(x_dim, y_dim, scale = scale) 392 | image = self.sess.run(G, feed_dict={self.z: z, self.x: gen_x_vec, self.y: gen_y_vec, self.r: gen_r_vec}) 393 | return image 394 | 395 | def save_model(self, checkpoint_path, epoch): 396 | """ saves the model to a file """ 397 | self.saver.save(self.sess, checkpoint_path, global_step = epoch) 398 | 399 | def load_model(self, checkpoint_path): 400 | 401 | ckpt = tf.train.get_checkpoint_state(checkpoint_path) 402 | print "loading model: ",ckpt.model_checkpoint_path 403 | 404 | #self.saver.restore(self.sess, checkpoint_path+'/'+ckpt.model_checkpoint_path) 405 | # use the below line for tensorflow 0.7 406 | self.saver.restore(self.sess, ckpt.model_checkpoint_path) 407 | 408 | def close(self): 409 | self.sess.close() 410 | 411 | 412 | -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | from tensorflow.python.framework import ops 6 | 7 | #from utils import * 8 | 9 | class batch_norm(object): 10 | """Code modification of http://stackoverflow.com/a/33950177""" 11 | def __init__(self, batch_size, epsilon=1e-5, momentum = 0.1, name="batch_norm"): 12 | with tf.variable_scope(name) as scope: 13 | self.epsilon = epsilon 14 | self.momentum = momentum 15 | self.batch_size = batch_size 16 | 17 | self.ema = tf.train.ExponentialMovingAverage(decay=self.momentum) 18 | self.name=name 19 | 20 | def __call__(self, x, train=True): 21 | shape = x.get_shape().as_list() 22 | 23 | with tf.variable_scope(self.name) as scope: 24 | self.gamma = tf.get_variable("gamma", [shape[-1]], 25 
| initializer=tf.random_normal_initializer(1., 0.02)) 26 | self.beta = tf.get_variable("beta", [shape[-1]], 27 | initializer=tf.constant_initializer(0.)) 28 | 29 | self.mean, self.variance = tf.nn.moments(x, [0, 1, 2]) 30 | 31 | return tf.nn.batch_norm_with_global_normalization( 32 | x, self.mean, self.variance, self.beta, self.gamma, self.epsilon, 33 | scale_after_normalization=True) 34 | 35 | def binary_cross_entropy_with_logits(logits, targets, name=None): 36 | """Computes binary cross entropy given `logits`. 37 | 38 | For brevity, let `x = logits`, `z = targets`. The logistic loss is 39 | 40 | loss(x, z) = - sum_i (x[i] * log(z[i]) + (1 - x[i]) * log(1 - z[i])) 41 | 42 | Args: 43 | logits: A `Tensor` of type `float32` or `float64`. 44 | targets: A `Tensor` of the same type and shape as `logits`. 45 | """ 46 | eps = 1e-12 47 | with ops.op_scope([logits, targets], name, "bce_loss") as name: 48 | logits = ops.convert_to_tensor(logits, name="logits") 49 | targets = ops.convert_to_tensor(targets, name="targets") 50 | return tf.reduce_mean(-(logits * tf.log(targets + eps) + 51 | (1. - logits) * tf.log(1. - targets + eps))) 52 | 53 | def conv_cond_concat(x, y): 54 | """Concatenate conditioning vector on feature map axis.""" 55 | x_shapes = x.get_shape() 56 | y_shapes = y.get_shape() 57 | return tf.concat(3, [x, y*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])]) 58 | 59 | def conv2d(input_, output_dim, 60 | k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, 61 | name="conv2d"): 62 | with tf.variable_scope(name): 63 | w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim], 64 | initializer=tf.truncated_normal_initializer(stddev=stddev)) 65 | conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME') 66 | 67 | biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0)) 68 | conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) 69 | 70 | return conv 71 | 72 | def deconv2d(input_, output_shape, 73 | k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, 74 | name="deconv2d", with_w=False): 75 | with tf.variable_scope(name): 76 | # filter : [height, width, output_channels, in_channels] 77 | w = tf.get_variable('w', [k_h, k_h, output_shape[-1], input_.get_shape()[-1]], 78 | initializer=tf.random_normal_initializer(stddev=stddev)) 79 | deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape, 80 | strides=[1, d_h, d_w, 1]) 81 | 82 | biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0)) 83 | deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape()) 84 | 85 | if with_w: 86 | return deconv, w, biases 87 | else: 88 | return deconv 89 | 90 | def lrelu(x, leak=0.2, name="lrelu"): 91 | with tf.variable_scope(name): 92 | f1 = 0.5 * (1 + leak) 93 | f2 = 0.5 * (1 - leak) 94 | return f1 * x + f2 * abs(x) 95 | 96 | def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False): 97 | shape = input_.get_shape().as_list() 98 | 99 | with tf.variable_scope(scope or "Linear"): 100 | matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32, 101 | tf.random_normal_initializer(stddev=stddev)) 102 | bias = tf.get_variable("bias", [output_size], 103 | initializer=tf.constant_initializer(bias_start)) 104 | if with_w: 105 | return tf.matmul(input_, matrix) + bias, matrix, bias 106 | else: 107 | return tf.matmul(input_, matrix) + bias 108 | 109 | def fully_connected(input_, output_size, scope=None, stddev=0.05, with_bias = True): 110 | shape = 
input_.get_shape().as_list() 111 | 112 | with tf.variable_scope(scope or "FC"): 113 | matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32, 114 | tf.random_normal_initializer(stddev=stddev)) 115 | 116 | result = tf.matmul(input_, matrix) 117 | 118 | if with_bias: 119 | bias = tf.get_variable("bias", [1, output_size], 120 | initializer=tf.random_normal_initializer(stddev=stddev)) 121 | result += bias*tf.ones([shape[0], 1], dtype=tf.float32) 122 | 123 | return result 124 | -------------------------------------------------------------------------------- /sampler.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Implementation of Compositional Pattern Producing Networks in Tensorflow 3 | 4 | https://en.wikipedia.org/wiki/Compositional_pattern-producing_network 5 | 6 | @hardmaru, 2016 7 | 8 | Sampler Class 9 | 10 | This file is meant to be run inside an IPython session, as it is meant 11 | to be used interacively for experimentation. 12 | 13 | It shouldn't be that hard to take bits of this code into a normal 14 | command line environment though if you want to use outside of IPython. 15 | 16 | usage: 17 | 18 | %run -i sampler.py 19 | 20 | sampler = Sampler() 21 | 22 | ''' 23 | 24 | import numpy as np 25 | import tensorflow as tf 26 | import math 27 | import random 28 | import PIL 29 | from PIL import Image 30 | import pylab 31 | from model import CPPNVAE 32 | import matplotlib.pyplot as plt 33 | import images2gif 34 | from images2gif import writeGif 35 | 36 | mgc = get_ipython().magic 37 | mgc(u'matplotlib inline') 38 | mgc(u'run -i mnist_data.py') 39 | pylab.rcParams['figure.figsize'] = (10.0, 10.0) 40 | 41 | class Sampler(): 42 | def __init__(self): 43 | self.mnist = None 44 | self.model = CPPNVAE() 45 | self.z = self.generate_z() 46 | def load_model(self): 47 | self.model.load_model('save') 48 | def get_random_mnist(self, with_label = False): 49 | if self.mnist == None: 50 | self.mnist = read_data_sets() 51 | if with_label == True: 52 | data, label = self.mnist.next_batch(1, with_label) 53 | return data[0], label[0] 54 | return self.mnist.next_batch(1)[0] 55 | def get_random_specific_mnist(self, label = 2): 56 | m, l = self.get_random_mnist(with_label = True) 57 | for i in range(100): 58 | if l == label: 59 | break 60 | m, l = self.get_random_mnist(with_label = True) 61 | return m 62 | def generate_random_label(self, label): 63 | m = self.get_random_specific_mnist(label) 64 | self.show_image(m) 65 | self.show_image_from_z(self.encode(m)) 66 | def generate_z(self): 67 | z = np.random.normal(size=self.model.z_dim).astype(np.float32) 68 | return z 69 | def encode(self, mnist_data): 70 | new_shape = [1]+list(mnist_data.shape) 71 | return self.model.encode(np.reshape(mnist_data, new_shape)) 72 | def generate(self, z=None, x_dim=512, y_dim=512, scale = 8.0): 73 | if z is None: 74 | z = self.generate_z() 75 | else: 76 | z = np.reshape(z, (1, self.model.z_dim)) 77 | self.z = z 78 | return self.model.generate(z, x_dim, y_dim, scale)[0] 79 | def show_image(self, image_data): 80 | ''' 81 | image_data is a tensor, in [height width depth] 82 | image_data is NOT the PIL.Image class 83 | ''' 84 | plt.subplot(1, 1, 1) 85 | y_dim = image_data.shape[0] 86 | x_dim = image_data.shape[1] 87 | c_dim = self.model.c_dim 88 | if c_dim > 1: 89 | plt.imshow(image_data, interpolation='nearest') 90 | else: 91 | plt.imshow(image_data.reshape(y_dim, x_dim), cmap='Greys', interpolation='nearest') 92 | plt.axis('off') 93 | plt.show() 94 | def 
show_image_from_z(self, z): 95 | self.show_image(self.generate(z)) 96 | def save_png(self, image_data, filename, specific_size = None): 97 | img_data = np.array(1-image_data) 98 | y_dim = image_data.shape[0] 99 | x_dim = image_data.shape[1] 100 | c_dim = self.model.c_dim 101 | if c_dim > 1: 102 | img_data = np.array(img_data.reshape((y_dim, x_dim, c_dim))*255.0, dtype=np.uint8) 103 | else: 104 | img_data = np.array(img_data.reshape((y_dim, x_dim))*255.0, dtype=np.uint8) 105 | im = Image.fromarray(img_data) 106 | if specific_size != None: 107 | im = im.resize(specific_size) 108 | im.save(filename) 109 | def to_image(self, image_data): 110 | # convert to PIL.Image format from np array (0, 1) 111 | img_data = np.array(1-image_data) 112 | y_dim = image_data.shape[0] 113 | x_dim = image_data.shape[1] 114 | c_dim = self.model.c_dim 115 | if c_dim > 1: 116 | img_data = np.array(img_data.reshape((y_dim, x_dim, c_dim))*255.0, dtype=np.uint8) 117 | else: 118 | img_data = np.array(img_data.reshape((y_dim, x_dim))*255.0, dtype=np.uint8) 119 | im = Image.fromarray(img_data) 120 | return im 121 | def morph(self, z1, z2, n_total_frame = 10, x_dim = 512, y_dim = 512, scale = 8.0, sinusoid = False): 122 | ''' 123 | returns a list of img_data to represent morph between z1 and z2 124 | default to linear morph, but can try sinusoid for more time near the anchor pts 125 | n_total_frame must be >= 2, since by definition there's one frame for z1 and z2 126 | ''' 127 | delta_z = 1.0 / (n_total_frame-1) 128 | diff_z = (z2-z1) 129 | img_data_array = [] 130 | for i in range(n_total_frame): 131 | percentage = delta_z*float(i) 132 | factor = percentage 133 | if sinusoid == True: 134 | factor = np.sin(percentage*np.pi/2) 135 | z = z1 + diff_z*factor 136 | print "processing image ", i 137 | img_data_array.append(self.generate(z, x_dim, y_dim, scale)) 138 | return img_data_array 139 | def save_anim_gif(self, img_data_array, filename, duration = 0.1): 140 | ''' 141 | this saves an animated gif given a list of img_data (numpy arrays) 142 | ''' 143 | images = [] 144 | for i in range(len(img_data_array)): 145 | images.append(self.to_image(img_data_array[i])) 146 | writeGif(filename, images, duration = duration) 147 | -------------------------------------------------------------------------------- /save/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "model.ckpt-4" 2 | all_model_checkpoint_paths: "model.ckpt-2" 3 | all_model_checkpoint_paths: "model.ckpt-4" 4 | all_model_checkpoint_paths: "model.ckpt-23" 5 | -------------------------------------------------------------------------------- /save/config.pkl: -------------------------------------------------------------------------------- 1 | ccopy_reg 2 | _reconstructor 3 | p1 4 | (cargparse 5 | Namespace 6 | p2 7 | c__builtin__ 8 | object 9 | p3 10 | NtRp4 11 | (dp5 12 | S'beta1' 13 | p6 14 | F0.75 15 | sS'keep_prob' 16 | p7 17 | F0.65000000000000002 18 | sS'batch_size' 19 | p8 20 | I2000 21 | sS'checkpoint_step' 22 | p9 23 | I1 24 | sS'learning_rate_d' 25 | p10 26 | F0.001 27 | sS'learning_rate_g' 28 | p11 29 | F0.001 30 | sS'display_step' 31 | p12 32 | I1 33 | sS'training_epochs' 34 | p13 35 | I3000 36 | sS'learning_rate_vae' 37 | p14 38 | F0.0001 39 | sb. 
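save/config.pkl above is a Python 2 cPickle dump of the argparse Namespace that train.py writes at the start of training, so the hyperparameters behind a checkpoint (beta1=0.75, keep_prob=0.65, batch_size=2000, and the three learning rates) can be recovered later. A minimal sketch of inspecting it, assuming a Python 2 interpreter started in the repository root:

import cPickle

# load the saved argparse Namespace back from the text-protocol pickle
with open('save/config.pkl') as f:
    saved_args = cPickle.load(f)

print vars(saved_args)  # dict of the hyperparameters the checkpoint was trained with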
-------------------------------------------------------------------------------- /save/model.ckpt-2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardmaru/resnet-cppn-gan-tensorflow/9206e06512c118e932fbc789c91a5cf4f9e5d2b9/save/model.ckpt-2 -------------------------------------------------------------------------------- /save/model.ckpt-2.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardmaru/resnet-cppn-gan-tensorflow/9206e06512c118e932fbc789c91a5cf4f9e5d2b9/save/model.ckpt-2.meta -------------------------------------------------------------------------------- /save/model.ckpt-23: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardmaru/resnet-cppn-gan-tensorflow/9206e06512c118e932fbc789c91a5cf4f9e5d2b9/save/model.ckpt-23 -------------------------------------------------------------------------------- /save/model.ckpt-23.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardmaru/resnet-cppn-gan-tensorflow/9206e06512c118e932fbc789c91a5cf4f9e5d2b9/save/model.ckpt-23.meta -------------------------------------------------------------------------------- /save/model.ckpt-4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardmaru/resnet-cppn-gan-tensorflow/9206e06512c118e932fbc789c91a5cf4f9e5d2b9/save/model.ckpt-4 -------------------------------------------------------------------------------- /save/model.ckpt-4.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hardmaru/resnet-cppn-gan-tensorflow/9206e06512c118e932fbc789c91a5cf4f9e5d2b9/save/model.ckpt-4.meta -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | import argparse 5 | import time 6 | import os 7 | import cPickle 8 | 9 | from mnist_data import * 10 | from model import CPPNVAE 11 | 12 | ''' 13 | cppn vae: 14 | 15 | compositional pattern-producing generative adversarial network 16 | 17 | LOADS of help was taken from: 18 | 19 | https://github.com/carpedm20/DCGAN-tensorflow 20 | https://jmetzen.github.io/2015-11-27/vae.html 21 | 22 | ''' 23 | 24 | def main(): 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--training_epochs', type=int, default=3000, 27 | help='training epochs') 28 | parser.add_argument('--display_step', type=int, default=1, 29 | help='display step') 30 | parser.add_argument('--checkpoint_step', type=int, default=1, 31 | help='checkpoint step') 32 | parser.add_argument('--batch_size', type=int, default=1000, 33 | help='batch size') 34 | parser.add_argument('--learning_rate_g', type=float, default=0.001, 35 | help='learning rate for G and VAE') 36 | parser.add_argument('--learning_rate_vae', type=float, default=0.0001, 37 | help='learning rate for VAE') 38 | parser.add_argument('--learning_rate_d', type=float, default=0.001, 39 | help='learning rate for D') 40 | parser.add_argument('--keep_prob', type=float, default=0.65, 41 | help='dropout keep probability') 42 | parser.add_argument('--beta1', type=float, default=0.75, 43 | help='adam momentum param for descriminator') 44 | args = parser.parse_args() 45 | return 
train(args) 46 | 47 | def train(args): 48 | 49 | learning_rate_g = args.learning_rate_g 50 | learning_rate_d = args.learning_rate_d 51 | learning_rate_vae = args.learning_rate_vae 52 | batch_size = args.batch_size 53 | training_epochs = args.training_epochs 54 | display_step = args.display_step 55 | checkpoint_step = args.checkpoint_step # save training results every check point step 56 | beta1 = args.beta1 57 | keep_prob = args.keep_prob 58 | dirname = 'save' 59 | if not os.path.exists(dirname): 60 | os.makedirs(dirname) 61 | 62 | with open(os.path.join(dirname, 'config.pkl'), 'w') as f: 63 | cPickle.dump(args, f) 64 | 65 | mnist = read_data_sets() 66 | n_samples = mnist.num_examples 67 | 68 | cppnvae = CPPNVAE(batch_size=batch_size, learning_rate_g = learning_rate_g, learning_rate_d = learning_rate_d, learning_rate_vae = learning_rate_vae, beta1 = beta1, keep_prob = keep_prob) 69 | 70 | # load previously trained model if appilcable 71 | ckpt = tf.train.get_checkpoint_state(dirname) 72 | if ckpt: 73 | cppnvae.load_model(dirname) 74 | 75 | counter = 0 76 | 77 | # Training cycle 78 | for epoch in range(training_epochs): 79 | avg_d_loss = 0. 80 | avg_d_loss_real = 0. 81 | avg_d_loss_fake = 0. 82 | avg_q_loss = 0. 83 | avg_vae_loss = 0. 84 | avg_d_real_accuracy = 0. 85 | avg_d_fake_accuracy = 0. 86 | avg_g_accuracy = 0. 87 | mnist.shuffle_data() 88 | total_batch = int(n_samples / batch_size) 89 | # Loop over all batches 90 | for i in range(total_batch): 91 | batch_images, batch_labels = mnist.next_batch(batch_size, with_label = True) # obtain training labels 92 | 93 | d_loss, g_loss, vae_loss, n_operations, d_real_accuracy, d_fake_accuracy, g_accuracy, d_loss_real, d_loss_fake = cppnvae.partial_train(batch_images, batch_labels) 94 | 95 | assert( vae_loss < 1000000 ) # make sure it is not NaN or Inf 96 | assert( d_loss < 1000000 ) # make sure it is not NaN or Inf 97 | assert( g_loss < 1000000 ) # make sure it is not NaN or Inf 98 | assert( d_loss_real < 1000000 ) # make sure it is not NaN or Inf 99 | assert( d_loss_fake < 1000000 ) # make sure it is not NaN or Inf 100 | assert( d_real_accuracy < 1000000 ) # make sure it is not NaN or Inf 101 | assert( d_fake_accuracy < 1000000 ) # make sure it is not NaN or Inf 102 | assert( g_accuracy < 1000000 ) # make sure it is not NaN or Inf 103 | 104 | # Display logs per epoch step 105 | if (counter+1) % display_step == 0: 106 | print "Sample:", '%d' % ((i+1)*batch_size), " Epoch:", '%d' % (epoch), \ 107 | "d_loss=", "{:.4f}".format(d_loss), \ 108 | "d_real=", "{:.4f}".format(d_loss_real), \ 109 | "d_fake=", "{:.4f}".format(d_loss_fake), \ 110 | "g_loss=", "{:.4f}".format(g_loss), \ 111 | "vae_loss=", "{:.4f}".format(vae_loss), \ 112 | "d_real_accuracy=", "{:.2f}".format(d_real_accuracy), \ 113 | "d_fake_accuracy=", "{:.2f}".format(d_fake_accuracy), \ 114 | "g_accuracy=", "{:.2f}".format(g_accuracy), \ 115 | "n_op=", '%d' % (n_operations) 116 | counter += 1 117 | # Compute average loss 118 | avg_d_loss += d_loss / n_samples * batch_size 119 | avg_d_loss_real += d_loss_real / n_samples * batch_size 120 | avg_d_loss_fake += d_loss_fake / n_samples * batch_size 121 | avg_q_loss += g_loss / n_samples * batch_size 122 | avg_vae_loss += vae_loss / n_samples * batch_size 123 | avg_d_real_accuracy += d_real_accuracy / n_samples * batch_size 124 | avg_d_fake_accuracy += d_fake_accuracy / n_samples * batch_size 125 | avg_g_accuracy += g_accuracy / n_samples * batch_size 126 | 127 | # Display logs per epoch step 128 | if epoch >= 0: 129 | print "Epoch:", '%04d' % 
(epoch), \ 130 | "avg_d_loss=", "{:.6f}".format(avg_d_loss), \ 131 | "avg_d_real=", "{:.6f}".format(avg_d_loss_real), \ 132 | "avg_d_fake=", "{:.6f}".format(avg_d_loss_fake), \ 133 | "avg_q_loss=", "{:.6f}".format(avg_q_loss), \ 134 | "d_real_accuracy=", "{:.2f}".format(avg_d_real_accuracy), \ 135 | "d_fake_accuracy=", "{:.2f}".format(avg_d_fake_accuracy), \ 136 | "g_accuracy=", "{:.2f}".format(avg_g_accuracy), \ 137 | "avg_vae_loss=", "{:.6f}".format(avg_vae_loss) 138 | 139 | # save model 140 | if epoch >= 0 and epoch % checkpoint_step == 0: 141 | checkpoint_path = os.path.join('save', 'model.ckpt') 142 | cppnvae.save_model(checkpoint_path, epoch) 143 | print "model saved to {}".format(checkpoint_path) 144 | 145 | # save model one last time, under zero label to denote finish. 146 | cppnvae.save_model(checkpoint_path, 0) 147 | 148 | if __name__ == '__main__': 149 | main() 150 | --------------------------------------------------------------------------------
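For reference, a minimal sketch of how the pieces above fit together interactively, following the usage pattern in sampler.py. It assumes an IPython session started in the repository root, with MNIST_data/ present and a trained checkpoint already sitting in save/ (as produced by train.py); the output filename is only a placeholder.

%run -i sampler.py

sampler = Sampler()              # builds a CPPNVAE model with default settings
sampler.load_model()             # restores the latest checkpoint from ./save

m = sampler.get_random_mnist()   # grab a random MNIST digit
z = sampler.encode(m)            # map it into the latent space
sampler.show_image_from_z(z)     # redraw it through the CPPN generator at 512x512

z2 = sampler.generate_z()        # a second, random latent vector
frames = sampler.morph(z, z2, n_total_frame=10, x_dim=512, y_dim=512, scale=8.0)
sampler.save_anim_gif(frames, 'example_morph.gif', duration=0.1)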