├── CreateDB.bat ├── Edior.bat ├── FursonaGenerator.zip ├── LICENSE ├── README.md ├── TrainGAN.bat ├── cpu.theanorc ├── create_db.py ├── editor.py ├── gpu.theanorc └── train_gan.py /CreateDB.bat: -------------------------------------------------------------------------------- 1 | python create_db.py 2 | pause 3 | -------------------------------------------------------------------------------- /Edior.bat: -------------------------------------------------------------------------------- 1 | python editor.py %1 2 | pause 3 | -------------------------------------------------------------------------------- /FursonaGenerator.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HackerPoet/FursonaGenerator/ef2ee24774d783fe83a6c7700b9a1cdd1c434368/FursonaGenerator.zip -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 HackerPoet 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FursonaGenerator 2 | A Neural Network to Generate Fursonas 3 | 4 | **Note:** This repository is not actively maintained and exists for reference only. 
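### Typical Workflow
**Editor's note:** the steps below are not documented in the original README; they are inferred from the constants in `create_db.py`, `train_gan.py`, and `editor.py`.

1. Put training images in a `good_pics/` folder and run `CreateDB.bat` (i.e. `python create_db.py`) to build the dataset at `data/color128.npy`.
2. Run `TrainGAN.bat` (i.e. `python train_gan.py`) to train; models, statistics, and sample images are written to the `WRITE_DIR` folder (default `test26/`).
3. Set `model_dir` in `editor.py` to that output folder and run `python editor.py` to explore the latent space with sliders, or `python editor.py some_image.png` to encode and reconstruct an existing image.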
5 | 6 | ### Standalone App (Windows 64-bit): 7 | https://github.com/HackerPoet/FursonaGenerator/raw/master/FursonaGenerator.zip 8 | 9 | ### Video Explanation: 10 | https://www.youtube.com/watch?v=nBcZGjxnpDY 11 | -------------------------------------------------------------------------------- /TrainGAN.bat: -------------------------------------------------------------------------------- 1 | python train_gan.py 2 | pause 3 | -------------------------------------------------------------------------------- /cpu.theanorc: -------------------------------------------------------------------------------- 1 | [global] 2 | floatX=float32 3 | device=cpu 4 | -------------------------------------------------------------------------------- /create_db.py: -------------------------------------------------------------------------------- 1 | import os, random, sys 2 | import numpy as np 3 | import cv2 4 | from scipy import ndimage 5 | 6 | IMAGE_DIR = "good_pics" 7 | DATA_DIR = "data" 8 | IMAGE_SIZE = 128 9 | 10 | color_imgs = [] 11 | num_imgs = 0 12 | 13 | if not os.path.exists(DATA_DIR): 14 | os.makedirs(DATA_DIR) 15 | 16 | print "Loading Images..." 17 | for file in os.listdir(IMAGE_DIR): 18 | path = IMAGE_DIR + "/" + file 19 | 20 | #Only attempt to load standard image formats 21 | path_split = path.split('.') 22 | if len(path_split) < 2: continue 23 | if path_split[-1] not in ['bmp', 'gif', 'png', 'jpg', 'jpeg']: 24 | continue 25 | 26 | #Make sure image is valid and not corrupt 27 | img = ndimage.imread(path) 28 | if img is None: 29 | assert(False) 30 | if len(img.shape) != 3 or img.shape[2] < 3: 31 | continue 32 | if img.shape[2] > 3: 33 | img = img[:,:,:3] 34 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 35 | 36 | #Crop center square of image 37 | h = img.shape[0] 38 | w = img.shape[1] 39 | if w > h: 40 | offs = (w - h)/2 41 | img = img[:,offs:offs+h,:] 42 | elif h > w: 43 | offs = (h - w)/2 44 | img = img[offs:offs+w,:,:] 45 | 46 | #Scale all images to a uniform size 47 | img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE), interpolation = cv2.INTER_AREA) 48 | 49 | #Save some samples 50 | if num_imgs < 10: 51 | cv2.imwrite("color" + str(num_imgs) + ".png", img) 52 | 53 | #Add to running list 54 | color_imgs.append(np.transpose(img, (2,0,1))) 55 | 56 | #Show progress 57 | num_imgs += 1 58 | if num_imgs % 10 == 0: 59 | sys.stdout.write('\r') 60 | sys.stdout.write(str(num_imgs)) 61 | sys.stdout.flush() 62 | print "\nLoaded " + str(num_imgs) + " images." 63 | 64 | print "Saving..." 65 | color_imgs = np.stack(color_imgs, axis=0) 66 | np.save(DATA_DIR + '/color' + str(IMAGE_SIZE) + '.npy', color_imgs) 67 | 68 | print "Done" -------------------------------------------------------------------------------- /editor.py: -------------------------------------------------------------------------------- 1 | import pygame 2 | import random, sys 3 | import numpy as np 4 | import cv2 5 | 6 | #User constants 7 | device = "cpu" 8 | model_dir = 'test24/' 9 | is_gan = True 10 | background_color = (210, 210, 210) 11 | edge_color = (60, 60, 60) 12 | slider_color = (20, 20, 20) 13 | num_params = 80 14 | image_scale = 3 15 | image_padding = 10 16 | slider_w = 15 17 | slider_h = 100 18 | slider_px = 5 19 | slider_py = 10 20 | slider_cols = 20 21 | 22 | #Keras 23 | print "Loading Keras..." 
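# --- Editor's note (not part of the original script) --------------------------
# The environment variables set just below must be assigned before theano and
# keras are imported. THEANORC points at one of the two config files shipped
# with the repository, selected by the `device` constant above: cpu.theanorc
# sets device=cpu, while gpu.theanorc sets device=cuda along with the cuDNN
# include/library paths. KERAS_BACKEND forces the Theano backend. To run the
# editor on a GPU, the only change needed should be:
#
#   device = "gpu"   # -> os.environ['THEANORC'] = "./gpu.theanorc"
# -------------------------------------------------------------------------------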
24 | import os 25 | os.environ['THEANORC'] = "./" + device + ".theanorc" 26 | os.environ['KERAS_BACKEND'] = "theano" 27 | import theano 28 | print "Theano Version: " + theano.__version__ 29 | from keras.models import Sequential, load_model, model_from_json 30 | from keras.layers import Dense, Activation, Dropout, Flatten, Reshape 31 | from keras.layers.convolutional import Conv2D, Conv2DTranspose, ZeroPadding2D 32 | from keras.layers.pooling import MaxPooling2D 33 | from keras.layers.noise import GaussianNoise 34 | from keras.layers.local import LocallyConnected2D 35 | from keras.optimizers import Adam, RMSprop, SGD 36 | from keras.regularizers import l2 37 | from keras.layers.advanced_activations import ELU 38 | from keras.preprocessing.image import ImageDataGenerator 39 | from keras.utils import plot_model 40 | from keras import backend as K 41 | K.set_image_data_format('channels_first') 42 | 43 | print "Loading model..." 44 | if is_gan: 45 | gen_model = load_model(model_dir + 'generator.h5') 46 | num_params = gen_model.input_shape[1] 47 | img_c, img_h, img_w = gen_model.output_shape[1:] 48 | 49 | if len(sys.argv) >= 2: 50 | enc_model = load_model(model_dir + 'encoder.h5') 51 | 52 | fname_in = sys.argv[1] 53 | fname_out = fname_in.split('.') 54 | fname_out[-2] += "_out" 55 | fname_out = '.'.join(fname_out) 56 | 57 | img = cv2.imread(fname_in) 58 | h = img.shape[0] 59 | w = img.shape[1] 60 | if w > h: 61 | offs = (w - h)/2 62 | img = img[:,offs:offs+h,:] 63 | elif h > w: 64 | offs = (h - w)/2 65 | img = img[offs:offs+w,:,:] 66 | img = cv2.resize(img, (img_h, img_w), interpolation = cv2.INTER_AREA) 67 | 68 | img = np.transpose(img, (2, 0, 1)) 69 | img = img.astype(np.float32) / 255.0 70 | img = np.expand_dims(img, axis=0) 71 | 72 | w = enc_model.predict(img) 73 | img = gen_model.predict(enc_model.predict(img))[0] 74 | 75 | img = (img * 255.0).astype(np.uint8) 76 | img = np.transpose(img, (1, 2, 0)) 77 | cv2.imwrite(fname_out, img) 78 | exit(0) 79 | else: 80 | model = load_model(model_dir + 'model.h5') 81 | gen_func = K.function([model.get_layer('encoder').input, K.learning_phase()], [model.layers[-1].output]) 82 | num_params = model.get_layer('encoder').input_shape[1] 83 | img_c, img_h, img_w = model.output_shape[1:] 84 | 85 | assert(img_c == 3) 86 | 87 | #Derived constants 88 | slider_w = slider_w + slider_px*2 89 | slider_h = slider_h + slider_py*2 90 | drawing_x = image_padding 91 | drawing_y = image_padding 92 | drawing_w = img_w * image_scale 93 | drawing_h = img_h * image_scale 94 | slider_rows = (num_params - 1) / slider_cols + 1 95 | sliders_x = drawing_x + drawing_w + image_padding 96 | sliders_y = image_padding 97 | sliders_w = slider_w * slider_cols 98 | sliders_h = slider_h * slider_rows 99 | window_w = drawing_w + sliders_w + image_padding*3 100 | window_h = drawing_h + image_padding*2 101 | 102 | #Global variables 103 | prev_mouse_pos = None 104 | mouse_pressed = False 105 | cur_slider_ix = 0 106 | needs_update = True 107 | cur_params = np.zeros((num_params,), dtype=np.float32) 108 | cur_face = np.zeros((img_c, img_h, img_w), dtype=np.uint8) 109 | rgb_array = np.zeros((img_h, img_w, img_c), dtype=np.uint8) 110 | 111 | print "Loading Statistics..." 
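# --- Editor's note (not part of the original script) --------------------------
# The four files loaded below are written by make_rand_faces_normalized() in
# train_gan.py: 'means.npy' and 'stds.npy' are the per-component mean and
# standard deviation of the encoder outputs over the training set, while
# 'evals.npy' and 'evecs.npy' hold the square roots of the eigenvalues and the
# eigenvectors of their covariance matrix. Each slider value in [-3, 3] is
# turned into a latent vector in the main loop via
#
#   x = means + np.dot(cur_params * evals, evecs)
#
# so the sliders move along the principal components of the learned latent
# distribution, scaled by the spread observed in the training data.
# -------------------------------------------------------------------------------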
112 | means = np.load(model_dir + 'means.npy') 113 | stds = np.load(model_dir + 'stds.npy') 114 | evals = np.load(model_dir + 'evals.npy') 115 | evecs = np.load(model_dir + 'evecs.npy') 116 | 117 | #Open a window 118 | pygame.init() 119 | pygame.font.init() 120 | screen = pygame.display.set_mode((window_w, window_h)) 121 | face_surface_mini = pygame.Surface((img_w, img_h)) 122 | face_surface = screen.subsurface((drawing_x, drawing_y, drawing_w, drawing_h)) 123 | pygame.display.set_caption('Fursona Editor - By CodeParade') 124 | font = pygame.font.SysFont("monospace", 15) 125 | 126 | def update_mouse_click(mouse_pos): 127 | global cur_slider_ix 128 | global mouse_pressed 129 | x = (mouse_pos[0] - sliders_x) 130 | y = (mouse_pos[1] - sliders_y) 131 | 132 | if x >= 0 and y >= 0 and x < sliders_w and y < sliders_h: 133 | slider_ix_w = x / slider_w 134 | slider_ix_h = y / slider_h 135 | 136 | cur_slider_ix = slider_ix_h * slider_cols + slider_ix_w 137 | mouse_pressed = True 138 | 139 | def update_mouse_move(mouse_pos): 140 | global needs_update 141 | y = (mouse_pos[1] - sliders_y) 142 | 143 | if y >= 0 and y < sliders_h: 144 | slider_row_ix = cur_slider_ix / slider_cols 145 | slider_val = y - slider_row_ix * slider_h 146 | 147 | slider_val = min(max(slider_val, slider_py), slider_h - slider_py) - slider_py 148 | val = (float(slider_val) / (slider_h - slider_py*2) - 0.5) * 6.0 149 | cur_params[cur_slider_ix] = val 150 | 151 | needs_update = True 152 | 153 | def draw_sliders(): 154 | for i in xrange(num_params): 155 | row = i / slider_cols 156 | col = i % slider_cols 157 | x = sliders_x + col * slider_w 158 | y = sliders_y + row * slider_h 159 | 160 | cx = x + slider_w / 2 161 | cy_1 = y + slider_py 162 | cy_2 = y + slider_h - slider_py 163 | pygame.draw.line(screen, slider_color, (cx, cy_1), (cx, cy_2)) 164 | 165 | py = y + int((cur_params[i] / 6.0 + 0.5) * (slider_h - slider_py*2)) + slider_py 166 | pygame.draw.circle(screen, slider_color, (cx, py), slider_w/2 - slider_px) 167 | 168 | cx_1 = x + slider_px 169 | cx_2 = x + slider_w - slider_px 170 | for j in xrange(7): 171 | ly = y + slider_h/2 + (j-3)*(slider_h/7) 172 | pygame.draw.line(screen, slider_color, (cx_1, ly), (cx_2, ly)) 173 | 174 | def draw_face(): 175 | pygame.surfarray.blit_array(face_surface_mini, np.transpose(cur_face, (2, 1, 0))) 176 | pygame.transform.scale(face_surface_mini, (drawing_w, drawing_h), face_surface) 177 | pygame.draw.rect(screen, (0,0,0), (drawing_x, drawing_y, drawing_w, drawing_h), 1) 178 | 179 | #Main loop 180 | running = True 181 | while running: 182 | #Process events 183 | for event in pygame.event.get(): 184 | if event.type == pygame.QUIT: 185 | running = False 186 | break 187 | elif event.type == pygame.MOUSEBUTTONDOWN: 188 | if pygame.mouse.get_pressed()[0]: 189 | prev_mouse_pos = pygame.mouse.get_pos() 190 | update_mouse_click(prev_mouse_pos) 191 | update_mouse_move(prev_mouse_pos) 192 | elif pygame.mouse.get_pressed()[2]: 193 | cur_params = np.zeros((num_params,), dtype=np.float32) 194 | needs_update = True 195 | elif event.type == pygame.MOUSEBUTTONUP: 196 | mouse_pressed = False 197 | prev_mouse_pos = None 198 | elif event.type == pygame.MOUSEMOTION and mouse_pressed: 199 | update_mouse_move(pygame.mouse.get_pos()) 200 | elif event.type == pygame.KEYDOWN: 201 | if event.key == pygame.K_r: 202 | cur_params = np.clip(np.random.normal(0.0, 1.0, (num_params,)), -3.0, 3.0) 203 | needs_update = True 204 | 205 | #Check if we need an update 206 | if needs_update: 207 | x = means + np.dot(cur_params * evals, 
evecs) 208 | #x = means + stds * cur_params 209 | x = np.expand_dims(x, axis=0) 210 | if is_gan: 211 | y = gen_model.predict(x)[0] 212 | else: 213 | y = gen_func([x, 0])[0][0] 214 | cur_face = (y * 255.0).astype(np.uint8) 215 | needs_update = False 216 | 217 | #Draw to the screen 218 | screen.fill(background_color) 219 | draw_face() 220 | draw_sliders() 221 | 222 | #Flip the screen buffer 223 | pygame.display.flip() 224 | pygame.time.wait(10) 225 | -------------------------------------------------------------------------------- /gpu.theanorc: -------------------------------------------------------------------------------- 1 | [global] 2 | floatX=float32 3 | device=cuda 4 | 5 | [nvcc] 6 | fastmath=True 7 | compiler_bindir=C:\Program Files (x86)\Microsoft Visual Studio 12.0\VC\bin 8 | 9 | [dnn] 10 | enabled=True 11 | include_path=C:\CUDA\v8.0\include 12 | library_path=C:\CUDA\v8.0\lib\x64 13 | -------------------------------------------------------------------------------- /train_gan.py: -------------------------------------------------------------------------------- 1 | import os, sys, random 2 | import numpy as np 3 | import cv2 4 | from matplotlib import pyplot as plt 5 | 6 | WRITE_DIR = "test26/" 7 | DATA_DIR = 'data/' 8 | CONTINUE_TRAIN = False 9 | TRAIN_EDGES = False 10 | USE_EMBEDDING = False 11 | USE_MIRROR = False 12 | USE_BG_SWAP = False 13 | USE_ROLLS = False 14 | PARAM_SIZE = 80 15 | NUM_EPOCHS = 50 16 | BATCH_SIZE = 32 17 | RATIO_G = 1 18 | LR_D = 0.0008 19 | LR_G = 0.0008 20 | BETA_1 = 0.8 21 | EPSILON = 1e-4 22 | ENC_WEIGHT = 400.0 23 | BN_M = 0.9 24 | DO_RATE = 0.5 25 | DO_RATE_G = 0.3 26 | NOISE_SIGMA = 0.15 27 | NUM_RAND_FACES = 10 28 | PREV_V = None 29 | 30 | def save_config(fname): 31 | with open(fname, 'w') as fout: 32 | fout.write('PARAM_SIZE ' + str(PARAM_SIZE ) + '\n') 33 | fout.write('NUM_EPOCHS ' + str(NUM_EPOCHS ) + '\n') 34 | fout.write('BATCH_SIZE ' + str(BATCH_SIZE ) + '\n') 35 | fout.write('RATIO_G ' + str(RATIO_G ) + '\n') 36 | fout.write('LR_D ' + str(LR_D ) + '\n') 37 | fout.write('LR_G ' + str(LR_G ) + '\n') 38 | fout.write('BETA_1 ' + str(BETA_1 ) + '\n') 39 | fout.write('EPSILON ' + str(EPSILON ) + '\n') 40 | fout.write('ENC_WEIGHT ' + str(ENC_WEIGHT ) + '\n') 41 | fout.write('BN_M ' + str(BN_M ) + '\n') 42 | fout.write('DO_RATE ' + str(DO_RATE ) + '\n') 43 | fout.write('DO_RATE_G ' + str(DO_RATE_G ) + '\n') 44 | fout.write('NOISE_SIGMA ' + str(NOISE_SIGMA) + '\n') 45 | 46 | if not os.path.exists(WRITE_DIR): 47 | os.makedirs(WRITE_DIR) 48 | save_config(WRITE_DIR + 'config.txt') 49 | 50 | def plotScores(scores, fname, on_top=True): 51 | plt.clf() 52 | ax = plt.gca() 53 | ax.yaxis.tick_right() 54 | ax.yaxis.set_ticks_position('both') 55 | ax.yaxis.grid(True) 56 | for s in scores: 57 | plt.plot(s) 58 | plt.xlabel('Epoch') 59 | loc = ('upper right' if on_top else 'lower right') 60 | plt.legend(['Dis', 'Gen', 'Enc'], loc=loc) 61 | plt.draw() 62 | plt.savefig(fname) 63 | 64 | def shift_keep(imgs, sx, sy): 65 | assert(len(imgs.shape) == 4) 66 | 67 | #Shift X 68 | result_x = np.empty_like(imgs) 69 | if sx > 0: 70 | result_x[:,:,:,:sx] = imgs[:,:,:,:1] 71 | result_x[:,:,:,sx:] = imgs[:,:,:,:-sx] 72 | elif sx < 0: 73 | result_x[:,:,:,sx:] = imgs[:,:,:,-1:] 74 | result_x[:,:,:,:sx] = imgs[:,:,:,-sx:] 75 | else: 76 | result_x = imgs 77 | 78 | #Shift Y 79 | result_y = np.empty_like(result_x) 80 | if sy > 0: 81 | result_y[:,:,:sy] = result_x[:,:,:1] 82 | result_y[:,:,sy:] = result_x[:,:,:-sy] 83 | elif sy < 0: 84 | result_y[:,:,sy:] = result_x[:,:,-1:] 85 | 
result_y[:,:,:sy] = result_x[:,:,-sy:] 86 | else: 87 | result_y = result_x 88 | 89 | return result_y 90 | 91 | #Load data set 92 | print "Loading Image Data..." 93 | if TRAIN_EDGES: 94 | y_train = np.load(DATA_DIR + 'gray128.npy').astype(np.float32) / 255.0 95 | else: 96 | y_train = np.load(DATA_DIR + 'color128.npy').astype(np.float32) / 255.0 97 | if USE_MIRROR: 98 | y_train = np.concatenate((y_train, np.flip(y_train, axis=3)), axis=0) 99 | if USE_BG_SWAP: 100 | y_train = np.concatenate((y_train, y_train[:,[1,0,2]]), axis=0) 101 | y_orig = y_train 102 | print "Loaded " + str(y_train.shape[0]) + " Samples." 103 | 104 | num_samples = y_train.shape[0] 105 | i_train = np.arange(num_samples) 106 | y_shape = y_train.shape 107 | if USE_EMBEDDING: 108 | x_train = np.expand_dims(np.arange(num_samples), axis=1) 109 | else: 110 | x_train = y_train 111 | x_shape = x_train.shape 112 | 113 | ################################### 114 | # Create Model 115 | ################################### 116 | print "Loading Keras..." 117 | import os, math 118 | os.environ['THEANORC'] = "./gpu.theanorc" 119 | os.environ['KERAS_BACKEND'] = "theano" 120 | import theano 121 | print "Theano Version: " + theano.__version__ 122 | import keras 123 | print "Keras Version: " + keras.__version__ 124 | from keras.initializers import RandomNormal 125 | from keras.layers import Input, Dense, Activation, Dropout, Flatten, Reshape, TimeDistributed, LeakyReLU 126 | from keras.layers.convolutional import Conv2D, Conv2DTranspose, UpSampling2D 127 | from keras.layers.embeddings import Embedding 128 | from keras.layers.local import LocallyConnected2D 129 | from keras.layers.noise import GaussianNoise 130 | from keras.layers.normalization import BatchNormalization 131 | from keras.layers.pooling import MaxPooling2D, AveragePooling2D 132 | from keras.models import Model, Sequential, load_model 133 | from keras.optimizers import Adam, RMSprop, SGD 134 | from keras.regularizers import l2 135 | from keras.utils import plot_model 136 | from keras import backend as K 137 | K.set_image_data_format('channels_first') 138 | 139 | ################################### 140 | # Create Model 141 | ################################### 142 | if CONTINUE_TRAIN: 143 | print "Loading Discriminator..." 144 | discriminator = load_model(WRITE_DIR + 'discriminator.h5') 145 | print "Loading Generator..." 146 | generator = load_model(WRITE_DIR + 'generator.h5') 147 | print "Loading Encoder..." 148 | encoder = load_model(WRITE_DIR + 'encoder.h5') 149 | print "Loading Vectors..." 150 | PREV_V = np.load(WRITE_DIR + 'evecs.npy') 151 | z_test = np.load(WRITE_DIR + 'rand.npy') 152 | else: 153 | print "Building Discriminator..." 
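# --- Editor's note (not part of the original script) --------------------------
# The shift_keep() helper defined earlier in this file implements the USE_ROLLS
# augmentation: it shifts a batch of channels-first images by (sx, sy) pixels
# and fills the exposed border by repeating the edge row/column instead of
# wrapping around. A quick check on a tiny array:
#
#   a = np.arange(4, dtype=np.float32).reshape((1, 1, 1, 4))  # [[0, 1, 2, 3]]
#   shift_keep(a, 1, 0)   # -> [[0, 0, 1, 2]]  (shifted right, left edge repeated)
#   shift_keep(a, -1, 0)  # -> [[1, 2, 3, 3]]  (shifted left, right edge repeated)
# -------------------------------------------------------------------------------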
154 | input_shape = y_shape[1:] 155 | print (None,) + input_shape 156 | discriminator = Sequential() 157 | discriminator.add(GaussianNoise(NOISE_SIGMA, input_shape=input_shape)) 158 | 159 | discriminator.add(Conv2D(40, (5,5), padding='same')) 160 | discriminator.add(MaxPooling2D(2)) 161 | discriminator.add(LeakyReLU(0.2)) 162 | discriminator.add(BatchNormalization(momentum=BN_M, axis=1)) 163 | if DO_RATE > 0: 164 | discriminator.add(Dropout(DO_RATE)) 165 | print discriminator.output_shape 166 | 167 | discriminator.add(Conv2D(60, (5,5), padding='same')) 168 | discriminator.add(MaxPooling2D(2)) 169 | discriminator.add(LeakyReLU(0.2)) 170 | discriminator.add(BatchNormalization(momentum=BN_M, axis=1)) 171 | if DO_RATE > 0: 172 | discriminator.add(Dropout(DO_RATE)) 173 | print discriminator.output_shape 174 | 175 | discriminator.add(Conv2D(120, (5,5), padding='same')) 176 | discriminator.add(MaxPooling2D(8)) 177 | discriminator.add(LeakyReLU(0.2)) 178 | discriminator.add(BatchNormalization(momentum=BN_M, axis=1)) 179 | if DO_RATE > 0: 180 | discriminator.add(Dropout(DO_RATE)) 181 | print discriminator.output_shape 182 | 183 | discriminator.add(Flatten(data_format = 'channels_last')) 184 | print discriminator.output_shape 185 | 186 | discriminator.add(Dense(1, activation='sigmoid')) 187 | print discriminator.output_shape 188 | 189 | print "Building Generator..." 190 | generator = Sequential() 191 | input_shape = (PARAM_SIZE,) 192 | print (None,) + input_shape 193 | 194 | generator.add(Dense(360*4*4, input_shape=input_shape)) 195 | generator.add(Reshape((360,4,4))) 196 | 197 | generator.add(LeakyReLU(0.2)) 198 | print generator.output_shape 199 | if DO_RATE_G > 0: generator.add(Dropout(DO_RATE_G)) 200 | generator.add(BatchNormalization(momentum=BN_M, axis=1)) 201 | print generator.output_shape 202 | 203 | generator.add(Conv2DTranspose(280, (5,5), strides=(2,2), padding='same')) 204 | generator.add(LeakyReLU(0.2)) 205 | if DO_RATE_G > 0: generator.add(Dropout(DO_RATE_G)) 206 | generator.add(BatchNormalization(momentum=BN_M, axis=1)) 207 | print generator.output_shape 208 | 209 | generator.add(Conv2DTranspose(200, (5,5), strides=(2,2), padding='same')) 210 | generator.add(LeakyReLU(0.2)) 211 | if DO_RATE_G > 0: generator.add(Dropout(DO_RATE_G)) 212 | generator.add(BatchNormalization(momentum=BN_M, axis=1)) 213 | print generator.output_shape 214 | 215 | generator.add(Conv2DTranspose(160, (5,5), strides=(2,2), padding='same')) 216 | generator.add(LeakyReLU(0.2)) 217 | if DO_RATE_G > 0: generator.add(Dropout(DO_RATE_G)) 218 | generator.add(BatchNormalization(momentum=BN_M, axis=1)) 219 | print generator.output_shape 220 | 221 | generator.add(Conv2DTranspose(80, (5,5), strides=(2,2), padding='same')) 222 | generator.add(LeakyReLU(0.2)) 223 | if DO_RATE_G > 0: generator.add(Dropout(DO_RATE_G)) 224 | generator.add(BatchNormalization(momentum=BN_M, axis=1)) 225 | print generator.output_shape 226 | 227 | generator.add(Conv2DTranspose(y_train.shape[1], (5,5), strides=(2,2), padding='same', activation='sigmoid')) 228 | print generator.output_shape 229 | 230 | print "Building Encoder..." 
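# --- Editor's note (not part of the original script) --------------------------
# Shape bookkeeping for the two models built above, assuming the 3x128x128
# color data (the printed output_shape lines show the same progression at run
# time):
#
#   discriminator: 3x128x128 -> 40x64x64 -> 60x32x32 -> 120x4x4
#                  -> Flatten (1920) -> Dense(1, sigmoid)
#   generator:     PARAM_SIZE (80) -> Dense(360*4*4) -> reshape to 360x4x4,
#                  then five stride-2 Conv2DTranspose layers double the
#                  resolution each time (4 -> 8 -> 16 -> 32 -> 64 -> 128) with
#                  channel counts 280, 200, 160, 80 and finally 3 (sigmoid).
# -------------------------------------------------------------------------------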
231 | encoder = Sequential()
232 | if USE_EMBEDDING:
233 |     print (None, num_samples)
234 |     encoder.add(Embedding(num_samples, PARAM_SIZE, input_length=1, embeddings_initializer=RandomNormal(stddev=1e-4)))
235 |     encoder.add(Flatten(data_format = 'channels_last'))
236 |     print encoder.output_shape
237 | else:
238 |     input_shape = y_shape[1:]
239 |     print (None,) + input_shape
240 |     encoder = Sequential()
241 | 
242 |     encoder.add(Conv2D(120, (5,5), strides=(2,2), padding='same', input_shape=input_shape))
243 |     encoder.add(LeakyReLU(0.2))
244 |     encoder.add(BatchNormalization(momentum=BN_M, axis=1))
245 |     print encoder.output_shape
246 | 
247 |     encoder.add(Conv2D(200, (5,5), strides=(2,2), padding='same'))
248 |     encoder.add(LeakyReLU(0.2))
249 |     encoder.add(BatchNormalization(momentum=BN_M, axis=1))
250 |     print encoder.output_shape
251 | 
252 |     encoder.add(Conv2D(260, (5,5), strides=(2,2), padding='same'))
253 |     encoder.add(LeakyReLU(0.2))
254 |     encoder.add(BatchNormalization(momentum=BN_M, axis=1))
255 |     print encoder.output_shape
256 | 
257 |     encoder.add(Conv2D(300, (5,5), strides=(2,2), padding='same'))
258 |     encoder.add(LeakyReLU(0.2))
259 |     encoder.add(BatchNormalization(momentum=BN_M, axis=1))
260 |     print encoder.output_shape
261 | 
262 |     encoder.add(Conv2D(360, (5,5), strides=(2,2), padding='same'))
263 |     encoder.add(LeakyReLU(0.2))
264 |     encoder.add(BatchNormalization(momentum=BN_M, axis=1))
265 |     print encoder.output_shape
266 | 
267 |     encoder.add(Flatten(data_format = 'channels_last'))
268 |     print encoder.output_shape
269 | 
270 |     encoder.add(Dense(PARAM_SIZE))
271 |     encoder.add(BatchNormalization(momentum=BN_M))
272 |     print encoder.output_shape
273 | 
274 | print "Building GAN..."
275 | d_optimizer = Adam(lr=LR_D, beta_1=BETA_1, epsilon=EPSILON)
276 | g_optimizer = Adam(lr=LR_G, beta_1=BETA_1, epsilon=EPSILON)
277 | 
278 | discriminator.trainable = True
279 | generator.trainable = False
280 | encoder.trainable = False
281 | d_in_real = Input(shape=y_shape[1:])
282 | d_in_fake = Input(shape=x_shape[1:])
283 | d_fake = generator(encoder(d_in_fake))
284 | d_out_real = discriminator(d_in_real)
285 | d_out_real = Activation('linear', name='d_out_real')(d_out_real)
286 | d_out_fake = discriminator(d_fake)
287 | d_out_fake = Activation('linear', name='d_out_fake')(d_out_fake)
288 | dis_model = Model(inputs=[d_in_real, d_in_fake], outputs=[d_out_real, d_out_fake])
289 | dis_model.compile(
290 |     optimizer=d_optimizer,
291 |     loss={'d_out_real':'binary_crossentropy', 'd_out_fake':'binary_crossentropy'},
292 |     loss_weights={'d_out_real':1.0, 'd_out_fake':1.0})
293 | 
294 | discriminator.trainable = False
295 | generator.trainable = True
296 | encoder.trainable = True
297 | g_in = Input(shape=x_shape[1:])
298 | g_enc = encoder(g_in)
299 | g_out_img = generator(g_enc)
300 | g_out_img = Activation('linear', name='g_out_img')(g_out_img)
301 | g_out_dis = discriminator(g_out_img)
302 | g_out_dis = Activation('linear', name='g_out_dis')(g_out_dis)
303 | gen_dis_model = Model(inputs=[g_in], outputs=[g_out_img, g_out_dis])
304 | gen_dis_model.compile(
305 |     optimizer=g_optimizer,
306 |     loss={'g_out_img':'mse', 'g_out_dis':'binary_crossentropy'},
307 |     loss_weights={'g_out_img':ENC_WEIGHT, 'g_out_dis':1.0})
308 | 
309 | plot_model(generator, to_file=WRITE_DIR + 'generator.png', show_shapes=True)
310 | plot_model(discriminator, to_file=WRITE_DIR + 'discriminator.png', show_shapes=True)
311 | plot_model(encoder, to_file=WRITE_DIR + 'encoder.png', show_shapes=True)
312 | 
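# --- Editor's note (not part of the original script) --------------------------
# Two trainable views of the same three networks are compiled above:
#
#   dis_model      updates only the discriminator, pushing D(real) toward 1 and
#                  D(G(E(x))) toward 0 (binary cross-entropy on both outputs).
#   gen_dis_model  updates the encoder and generator against the frozen
#                  discriminator with the combined loss
#                      ENC_WEIGHT * MSE(y, G(E(x))) + BCE(D(G(E(x))), 1)
#                  where x is the image itself (or an index when USE_EMBEDDING
#                  is on), so the pair acts as an autoencoder whose weighted
#                  reconstruction term (400) dominates the adversarial term (1).
# -------------------------------------------------------------------------------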
313 | ###################################
314 | # Encoder Decoder
315 | ###################################
316 | def save_image(fname, x):
317 |     img = (x * 255.0).astype(np.uint8)
318 |     img = np.transpose(img, (1, 2, 0))
319 |     cv2.imwrite(fname, img)
320 | 
321 | def make_rand_faces(write_dir, x_vecs):
322 |     y_faces = generator.predict(x_vecs)
323 |     for i in xrange(y_faces.shape[0]):
324 |         save_image(write_dir + 'rand' + str(i) + '.png', y_faces[i])
325 | 
326 | def make_rand_faces_normalized(write_dir, rand_vecs):
327 |     global PREV_V
328 |     x_enc = encoder.predict(x_train, batch_size=BATCH_SIZE)
329 | 
330 |     x_mean = np.mean(x_enc, axis=0)
331 |     x_stds = np.std(x_enc, axis=0)
332 |     x_cov = np.cov((x_enc - x_mean).T)
333 |     u, s, v = np.linalg.svd(x_cov)
334 |     e = np.sqrt(s)
335 | 
336 |     # This step is not necessary, but it makes the randomly generated test
337 |     # samples consistent between epochs so you can see the evolution of
338 |     # the training better.
339 |     #
340 |     # Like a square root, each principal component has two solutions that
341 |     # represent opposing vector directions. For each component, just
342 |     # choose the direction that was closest to the last epoch.
343 |     if PREV_V is not None:
344 |         d = np.sum(PREV_V * v, axis=1)
345 |         d = np.where(d > 0.0, 1.0, -1.0)
346 |         v = v * np.expand_dims(d, axis=1)
347 |     PREV_V = v
348 | 
349 |     print "Evals: ", e[:6]
350 | 
351 |     np.save(write_dir + 'means.npy', x_mean)
352 |     np.save(write_dir + 'stds.npy', x_stds)
353 |     np.save(write_dir + 'evals.npy', e)
354 |     np.save(write_dir + 'evecs.npy', v)
355 | 
356 |     x_vecs = x_mean + np.dot(rand_vecs * e, v)
357 |     make_rand_faces(write_dir, x_vecs)
358 | 
359 |     plt.clf()
360 |     e[::-1].sort()
361 |     plt.title('evals')
362 |     plt.bar(np.arange(e.shape[0]), e, align='center')
363 |     plt.draw()
364 |     plt.savefig(write_dir + '_evals.png')
365 | 
366 |     plt.clf()
367 |     plt.title('means')
368 |     plt.bar(np.arange(e.shape[0]), x_mean, align='center')
369 |     plt.draw()
370 |     plt.savefig(write_dir + '_means.png')
371 | 
372 |     plt.clf()
373 |     plt.title('stds')
374 |     plt.bar(np.arange(e.shape[0]), x_stds, align='center')
375 |     plt.draw()
376 |     plt.savefig(write_dir + '_stds.png')
377 | 
378 | def save_models():
379 |     discriminator.save(WRITE_DIR + 'discriminator.h5')
380 |     generator.save(WRITE_DIR + 'generator.h5')
381 |     encoder.save(WRITE_DIR + 'encoder.h5')
382 |     print "Saved"
383 | 
384 | ###################################
385 | # Train
386 | ###################################
387 | print "Training..."
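# --- Editor's note (not part of the original script) --------------------------
# make_rand_faces_normalized() above samples test faces from the empirical
# latent distribution rather than a plain unit Gaussian: with x_mean the mean
# of the encoder outputs, s/v the SVD of their covariance and e = sqrt(s),
#
#   x_vecs = x_mean + np.dot(rand_vecs * e, v)    # rand_vecs ~ N(0, I)
#
# yields vectors whose mean and covariance match the encoder outputs. The
# saved means/evals/evecs files let editor.py apply the same mapping to its
# slider values.
# -------------------------------------------------------------------------------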
388 | generator_loss = [] 389 | discriminator_loss = [] 390 | encoder_loss = [] 391 | 392 | z_test = np.random.normal(0.0, 1.0, (NUM_RAND_FACES, PARAM_SIZE)) 393 | np.save(WRITE_DIR + 'rand.npy', z_test) 394 | 395 | for iters in xrange(NUM_EPOCHS): 396 | if USE_ROLLS: 397 | y_rolls = [] 398 | for i in xrange(10): 399 | sx = random.randint(-8,8) 400 | sy = random.randint(-8,8) 401 | y_rolls.append(shift_keep(y_orig, sx, sy)) 402 | y_train = np.concatenate(y_rolls, axis=0) 403 | x_train = y_train 404 | num_samples = y_train.shape[0] 405 | i_train = np.arange(num_samples) 406 | 407 | loss_d = 0.0 408 | loss_g = 0.0 409 | loss_e = 0.0 410 | num_d = 0 411 | num_g = 0 412 | num_e = 0 413 | 414 | np.random.shuffle(i_train) 415 | for i in xrange(0, num_samples/BATCH_SIZE): 416 | if i % RATIO_G == 0: 417 | #Make samples 418 | j = i / RATIO_G 419 | i_batch1 = i_train[j*BATCH_SIZE:(j + 1)*BATCH_SIZE] 420 | x_batch1 = x_train[i_batch1] 421 | y_batch1 = y_train[i_batch1] 422 | 423 | ones = np.ones((BATCH_SIZE,), dtype=np.float32) 424 | zeros = np.zeros((BATCH_SIZE,), dtype=np.float32) 425 | 426 | losses = dis_model.train_on_batch([y_batch1, x_batch1], [ones, zeros]) 427 | names = dis_model.metrics_names 428 | loss_d += losses[names.index('d_out_real_loss')] 429 | loss_d += losses[names.index('d_out_fake_loss')] 430 | num_d += 2 431 | 432 | i_batch2 = i_train[i*BATCH_SIZE:(i + 1)*BATCH_SIZE] 433 | x_batch2 = x_train[i_batch2] 434 | y_batch2 = y_train[i_batch2] 435 | 436 | losses = gen_dis_model.train_on_batch([x_batch2], [y_batch2, ones]) 437 | names = gen_dis_model.metrics_names 438 | loss_e += losses[names.index('g_out_img_loss')] 439 | loss_g += losses[names.index('g_out_dis_loss')] 440 | num_e += 1 441 | num_g += 1 442 | 443 | progress = (i * 100)*BATCH_SIZE / num_samples 444 | sys.stdout.write( 445 | str(progress) + "%" + 446 | " D:" + str(loss_d / num_d) + 447 | " G:" + str(loss_g / num_g) + 448 | " E:" + str(loss_e / num_e) + " ") 449 | sys.stdout.write('\r') 450 | sys.stdout.flush() 451 | sys.stdout.write('\n') 452 | 453 | discriminator_loss.append(loss_d / num_d) 454 | generator_loss.append(loss_g / num_g) 455 | encoder_loss.append(loss_e * 10.0 / num_e) 456 | 457 | try: 458 | plotScores([discriminator_loss, generator_loss, encoder_loss], WRITE_DIR + 'Scores.png') 459 | save_models() 460 | 461 | make_rand_faces_normalized(WRITE_DIR, z_test) 462 | i_test = i_train[-NUM_RAND_FACES:] 463 | x_test = x_train[i_test] 464 | y_test = y_train[i_test] 465 | y_pred = generator.predict(encoder.predict(x_test)) 466 | for i in xrange(y_pred.shape[0]): 467 | save_image(WRITE_DIR + "gt" + str(i) + ".png", y_test[i]) 468 | save_image(WRITE_DIR + "pred" + str(i) + ".png", y_pred[i]) 469 | except IOError: 470 | pass 471 | 472 | print "Done" 473 | --------------------------------------------------------------------------------
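Editor's appendix (not part of the original repository): a minimal sketch of sampling a trained generator outside of editor.py, assuming training has finished and the output folder (WRITE_DIR, e.g. 'test26/') contains generator.h5, means.npy, evals.npy, and evecs.npy. The file names and the latent mapping mirror what train_gan.py saves; the folder name, sample count, and output file names here are illustrative.

import os
os.environ['THEANORC'] = './cpu.theanorc'   # or './gpu.theanorc'
os.environ['KERAS_BACKEND'] = 'theano'

import numpy as np
import cv2
from keras.models import load_model
from keras import backend as K
K.set_image_data_format('channels_first')

model_dir = 'test26/'   # adjust to the WRITE_DIR used for training
generator = load_model(model_dir + 'generator.h5')
means = np.load(model_dir + 'means.npy')
evals = np.load(model_dir + 'evals.npy')
evecs = np.load(model_dir + 'evecs.npy')

#Draw latent vectors with the statistics of the training encodings
rand_vecs = np.clip(np.random.normal(0.0, 1.0, (16, means.shape[0])), -3.0, 3.0)
x_vecs = means + np.dot(rand_vecs * evals, evecs)

#Generate channels-first BGR images in [0, 1] and save them
y_faces = generator.predict(x_vecs)
for i in xrange(y_faces.shape[0]):
    img = np.transpose(y_faces[i] * 255.0, (1, 2, 0)).astype(np.uint8)
    cv2.imwrite('rand_sample' + str(i) + '.png', img)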