├── src
│   ├── layer
│   │   ├── __init__.py
│   │   └── layers.py
│   ├── models
│   │   ├── __init__.py
│   │   └── BEGAN.py
│   ├── function
│   │   ├── __init__.py
│   │   └── functions.py
│   ├── operator
│   │   ├── __init__.py
│   │   ├── op_base.py
│   │   └── op_BEGAN.py
│   ├── __init__.py
│   └── __init__.pyc
├── Result
│   ├── kt.jpg
│   ├── 64x64.bmp
│   ├── result.gif
│   ├── 128x128.bmp
│   ├── decoder.bmp
│   ├── gamma_0.3.bmp
│   ├── gamma_0.4.bmp
│   ├── gamma_0.5.bmp
│   ├── gamma_0.7.bmp
│   ├── m_global.jpg
│   └── m_global.tif
├── began_cmd.txt
├── Data
│   └── celeba
│       └── face_detect.py
├── main.py
└── README.md

/src/layer/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/function/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/operator/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
## __init__.py
--------------------------------------------------------------------------------
/Result/kt.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/BEGAN-tensorflow/HEAD/Result/kt.jpg
--------------------------------------------------------------------------------
/Result/64x64.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/BEGAN-tensorflow/HEAD/Result/64x64.bmp
--------------------------------------------------------------------------------
/Result/result.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/BEGAN-tensorflow/HEAD/Result/result.gif
--------------------------------------------------------------------------------
/src/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/BEGAN-tensorflow/HEAD/src/__init__.pyc
--------------------------------------------------------------------------------
/Result/128x128.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/BEGAN-tensorflow/HEAD/Result/128x128.bmp
--------------------------------------------------------------------------------
/Result/decoder.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/BEGAN-tensorflow/HEAD/Result/decoder.bmp
--------------------------------------------------------------------------------
/Result/gamma_0.3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/BEGAN-tensorflow/HEAD/Result/gamma_0.3.bmp
--------------------------------------------------------------------------------
/Result/gamma_0.4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/BEGAN-tensorflow/HEAD/Result/gamma_0.4.bmp
--------------------------------------------------------------------------------
/Result/gamma_0.5.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/BEGAN-tensorflow/HEAD/Result/gamma_0.5.bmp
--------------------------------------------------------------------------------
/Result/gamma_0.7.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/BEGAN-tensorflow/HEAD/Result/gamma_0.7.bmp
--------------------------------------------------------------------------------
/Result/m_global.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/BEGAN-tensorflow/HEAD/Result/m_global.jpg
--------------------------------------------------------------------------------
/Result/m_global.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/BEGAN-tensorflow/HEAD/Result/m_global.tif
--------------------------------------------------------------------------------
/src/function/functions.py:
--------------------------------------------------------------------------------
import os
import scipy.misc as scm


def make_project_dir(project_dir):
    if not os.path.exists(project_dir):
        os.makedirs(project_dir)
        os.makedirs(os.path.join(project_dir, 'models'))
        os.makedirs(os.path.join(project_dir, 'result'))
        os.makedirs(os.path.join(project_dir, 'result_test'))


def get_image(img_path):
    img = scm.imread(img_path) / 255. - 0.5  # scale pixels to [-0.5, 0.5]
    img = img[..., ::-1]  # RGB to BGR
    return img


def inverse_image(img):
    img = (img + 0.5) * 255.  # scale back to [0, 255]
    img[img > 255] = 255
    img[img < 0] = 0
    img = img[..., ::-1]  # BGR to RGB
    return img
--------------------------------------------------------------------------------
/began_cmd.txt:
--------------------------------------------------------------------------------
### Train (refer to main.py)
ex) 64x64 img | Nz,Nh 128 | gamma 0.4
python3 main.py -f 1 -p "began" -trd "celeba" -tro "crop" -trs 64 -z 128 -em 128 -fn 64 -b 16 -lr 1e-4 -gm 0.4 -g "0"

ex) 128x128 img | Nz,Nh 64 | gamma 0.7
python3 main.py -f 1 -p "began" -trd "celeba" -tro "crop" -trs 128 -z 64 -em 64 -fn 128 -b 16 -lr 1e-4 -gm 0.7 -g "0"

### Test (refer to main.py; use the same options as training)
ex) 64x64 img | Nz,Nh 128 | gamma 0.4
python3 main.py -f 0 -p "began" -trd "celeba" -tro "crop" -trs 64 -z 128 -em 128 -fn 64 -b 16 -lr 1e-4 -gm 0.4 -g "0"

ex) 128x128 img | Nz,Nh 64 | gamma 0.7
python3 main.py -f 0 -p "began" -trd "celeba" -tro "crop" -trs 128 -z 64 -em 64 -fn 128 -b 16 -lr 1e-4 -gm 0.7 -g "0"
--------------------------------------------------------------------------------
/Data/celeba/face_detect.py:
--------------------------------------------------------------------------------
import cv2
import os

# Create the Haar cascade face detector
cascPath = 'haarcascade_frontalface_default.xml'
faceCascade = cv2.CascadeClassifier(cascPath)

# Detect a face in every raw image, crop it, and save a 128x128 version
for fn in sorted(os.listdir('raw')):
    print(fn)
    image = cv2.imread('raw/' + fn)
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    faces = faceCascade.detectMultiScale(gray, scaleFactor=5, minNeighbors=5)

    if len(faces) > 0:
        x, y, w, h = faces[0]
        image_crop = image[y:y + h, x:x + w, :]  # Haar face boxes are square, so h == w
        image_resize = cv2.resize(image_crop, (128, 128))
        cv2.imwrite('128_crop/' + fn[:-4] + '_crop' + fn[-4:], image_resize)

    # Debug: visualize the detections
    # for (x, y, w, h) in faces:
    #     print(x, y, w, h)
    #     cv2.rectangle(image, (x, y), (x + w, y + h), (0, 255, 0), 2)
--------------------------------------------------------------------------------
/src/layer/layers.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np


def conv2d(x, filter_shape, bias=True, stride=1, padding="SAME", name="conv2d"):
    kw, kh, nin, nout = filter_shape
    pad_size = (kw - 1) // 2  # integer division; `/` would give a float in Python 3

    if padding == "VALID":
        # Pad manually with SYMMETRIC padding, then convolve without further
        # padding, which avoids the border artifacts of implicit zero padding.
        x = tf.pad(x, [[0, 0], [pad_size, pad_size], [pad_size, pad_size], [0, 0]], "SYMMETRIC")

    initializer = tf.random_normal_initializer(0., 0.02)
    with tf.variable_scope(name):
        weight = tf.get_variable("weight", shape=filter_shape, initializer=initializer)
        x = tf.nn.conv2d(x, weight, [1, stride, stride, 1], padding=padding)

        if bias:
            b = tf.get_variable("bias", shape=filter_shape[-1], initializer=tf.constant_initializer(0.))
            x = tf.nn.bias_add(x, b)
    return x


def fc(x, output_shape, bias=True, name='fc'):
    # Flatten all non-batch dimensions before the matrix multiply
    shape = x.get_shape().as_list()
    dim = np.prod(shape[1:])
    x = tf.reshape(x, [-1, dim])
    input_shape = dim

    initializer = tf.random_normal_initializer(0., 0.02)
    with tf.variable_scope(name):
        weight = tf.get_variable("weight", shape=[input_shape, output_shape], initializer=initializer)
        x = tf.matmul(x, weight)

        if bias:
            b = tf.get_variable("bias", shape=[output_shape], initializer=tf.constant_initializer(0.))
            x = tf.nn.bias_add(x, b)
    return x


def pool(x, r=2, s=1):
    return tf.nn.avg_pool(x, ksize=[1, r, r, 1], strides=[1, s, s, 1], padding="SAME")


def l1_loss(x, y):
    return tf.reduce_mean(tf.abs(x - y))


def resize_nn(x, size):
    return tf.image.resize_nearest_neighbor(x, size=(int(size), int(size)))
--------------------------------------------------------------------------------
/src/operator/op_base.py:
--------------------------------------------------------------------------------
import glob
import time
import os
import numpy as np
import tensorflow as tf
from src.function.functions import *

class op_base:
    def __init__(self, args, sess):
        self.sess = sess

        # Train
        self.flag = args.flag
        self.gpu_number = args.gpu_number
        self.project = args.project

        # Train Data
        self.data_dir = args.data_dir  # ./Data
        self.dataset = args.dataset  # celeba
        self.data_size = args.data_size  # 64 or 128
        self.data_opt = args.data_opt  # raw or crop

        # Train Iteration
        self.niter = args.niter
        self.niter_snapshot = args.nsnapshot
        self.max_to_keep = args.max_to_keep

        # Train Parameter
        self.batch_size = args.batch_size
        self.learning_rate = args.learning_rate
        self.mm = args.momentum
        self.mm2 = args.momentum2
        self.lamda = args.lamda
        self.gamma = args.gamma
        self.filter_number = args.filter_number
        self.input_size = args.input_size
        self.embedding = args.embedding

        # Result Dir & File
        self.project_dir = 'assets/{0}_{1}_{2}_{3}/'.format(self.project, self.dataset, self.data_opt, self.data_size)
        self.ckpt_dir = os.path.join(self.project_dir, 'models')
        self.model_name = "{0}.model".format(self.project)
        self.ckpt_model_name = os.path.join(self.ckpt_dir, self.model_name)

        # etc.
        if not os.path.exists('assets'):
            os.makedirs('assets')
        make_project_dir(self.project_dir)

    def load(self, sess, saver, ckpt_dir):
        ckpt = tf.train.get_checkpoint_state(ckpt_dir)
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        saver.restore(sess, os.path.join(ckpt_dir, ckpt_name))
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import argparse
import distutils.util
import os
import tensorflow as tf
import src.models.BEGAN as began


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument("-f", "--flag", type=distutils.util.strtobool, default='true')  # 1: train, 0: test
    parser.add_argument("-g", "--gpu_number", type=str, default="0")
    parser.add_argument("-p", "--project", type=str, default="began")

    # Train Data
    parser.add_argument("-d", "--data_dir", type=str, default="./Data")
    parser.add_argument("-trd", "--dataset", type=str, default="celeba")
    parser.add_argument("-tro", "--data_opt", type=str, default="crop")
    parser.add_argument("-trs", "--data_size", type=int, default=64)

    # Train Iteration
    parser.add_argument("-n", "--niter", type=int, default=50)
    parser.add_argument("-ns", "--nsnapshot", type=int, default=2440)
    parser.add_argument("-mx", "--max_to_keep", type=int, default=5)

    # Train Parameter
    parser.add_argument("-b", "--batch_size", type=int, default=16)
    parser.add_argument("-lr", "--learning_rate", type=float, default=1e-4)
    parser.add_argument("-m", "--momentum", type=float, default=0.5)
    parser.add_argument("-m2", "--momentum2", type=float, default=0.999)
    parser.add_argument("-gm", "--gamma", type=float, default=0.5)
    parser.add_argument("-lm", "--lamda", type=float, default=0.001)
    parser.add_argument("-fn", "--filter_number", type=int, default=64)
    parser.add_argument("-z", "--input_size", type=int, default=64)
    parser.add_argument("-em", "--embedding", type=int, default=64)

    args = parser.parse_args()

    gpu_number = args.gpu_number
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_number

    with tf.device('/gpu:{0}'.format(gpu_number)):
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.90)
        config = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)

        with tf.Session(config=config) as sess:
            model = began.BEGAN(args, sess)

            # TRAIN / TEST
            if args.flag:
                model.train(args.flag)
            else:
                model.test(args.flag)

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# BEGAN: Boundary Equilibrium Generative Adversarial Networks
Implementation of Google Brain's [BEGAN: Boundary Equilibrium Generative Adversarial Networks](https://arxiv.org/abs/1703.10717) in TensorFlow. \
BEGAN is the state of the art when it comes to generating realistic faces.



Figure 1a. 128x128 and 64x64 samples. The 128x128 results are very impressive; you can even see sets of teeth.



Figure 1b. Random samples from my trained models for gamma 0.3 to 0.5, with no cherry-picking. Gamma 0.3: nice, but biased toward women's faces. Gamma 0.4: the best of the three. Gamma 0.5: good texture, but some faces show hole artifacts.



Figure 1c. Training progression, from scratch to 200k iterations.

## Implementation detail
This model is trained at 64x64; a 128x128 version will be updated. The implementation differs from the original paper in two ways: the loss update scheme and the learning rate decay. First, the paper updates Loss_G and Loss_D simultaneously, but when I trained that way the model suffered mode collapse, so this code alternates the two updates instead (see the sketch below the next section). Second, the learning rate is decayed by a factor of 0.95 at every snapshot interval (`-ns`, 2440 iterations by default). These choices come purely from training experience; you can change them or consult the paper.

## Train progress
If you want to see the training progress, download [this Dropbox folder](https://www.dropbox.com/sh/g72k2crptow3ime/AAAhkGlHCw9zQh0aE-Ggdt3Qa?dl=0) and run `tensorboard --logdir=./`. I uploaded two trained models (64x64 and 128x128).

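To make the alternating update described above concrete, here is a minimal sketch of one training step, paraphrased from `src/operator/op_BEGAN.py`. It is illustrative only: `opt_g`, `opt_d`, `loss_real`, `loss_fake`, `feed`, and `snapshot_interval` stand in for the real ops, feed dict, and the `-ns` value.

    # Sketch only (simplified from src/operator/op_BEGAN.py), not a drop-in script
    for step in range(num_steps):
        # Alternate the updates: one generator step, then one discriminator step,
        # instead of the paper's simultaneous update of Loss_G and Loss_D
        _, d_real_loss, d_fake_loss = sess.run([opt_g, loss_real, loss_fake], feed_dict=feed)
        sess.run(opt_d, feed_dict=feed)

        # BEGAN equilibrium control: kt <- clip(kt + lamda * (gamma * L_real - L_fake), 0, 1)
        kt = np.clip(kt + lamda * (gamma * d_real_loss - d_fake_loss), 0., 1.)

        # Decay the learning rate by 0.95 once per snapshot interval
        if step % snapshot_interval == snapshot_interval - 1:
            lr *= 0.95
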


Figure 2. kt graph. Use this as a reference when training your own model: kt does not reach 1.0; in my case it converges to about 0.08.

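For reference, the curve above is the running value of `kt`, updated once per step in `src/operator/op_BEGAN.py` as `kt = clip(kt + lamda * (gamma * d_real_loss - d_fake_loss), 0, 1)`; kt only grows while `gamma * d_real_loss` exceeds `d_fake_loss`, which is why it plateaus well below 1.0 once the two losses balance.
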


Figure 3. Convergence measure (M_global). Similar to the graph in the paper.

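For reference, this measure is computed in `src/operator/op_BEGAN.py` as `m_global = d_real_loss + abs(gamma * d_real_loss - d_fake_loss)`, which is the paper's convergence measure M_global = L(x) + |gamma * L(x) - L(G(z))|; lower values indicate better convergence.
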


Figure 4. Comparison of the generator output and the decoder output.



## Usage
I recommend downloading the trained models from [this Dropbox folder](https://www.dropbox.com/sh/g72k2crptow3ime/AAAhkGlHCw9zQh0aE-Ggdt3Qa?dl=0).

### Make Train Data
1. Download the [celebA dataset (img_align_celeba.zip)](http://pan.baidu.com/s/1eSNpdRG#list/path=%2FCelebA%2FImg) and unzip it to 'Data/celeba/raw'
2. Run `python ./Data/celeba/face_detect.py`

### Train (refer to main.py and began_cmd.txt)
ex) 64x64 img | Nz,Nh 128 | gamma 0.4

    python3 main.py -f 1 -p "began" -trd "celeba" -tro "crop" -trs 64 -z 128 -em 128 -fn 64 -b 16 -lr 1e-4 -gm 0.4 -g "0"

ex) 128x128 img | Nz,Nh 64 | gamma 0.7

    python3 main.py -f 1 -p "began" -trd "celeba" -tro "crop" -trs 128 -z 64 -em 64 -fn 128 -b 16 -lr 1e-4 -gm 0.7 -g "0"

### Test (refer to main.py and began_cmd.txt)
ex) 64x64 img | Nz,Nh 128 | gamma 0.4

    python3 main.py -f 0 -p "began" -trd "celeba" -tro "crop" -trs 64 -z 128 -em 128 -fn 64 -b 16 -lr 1e-4 -gm 0.4 -g "0"

ex) 128x128 img | Nz,Nh 64 | gamma 0.7

    python3 main.py -f 0 -p "began" -trd "celeba" -tro "crop" -trs 128 -z 64 -em 64 -fn 128 -b 16 -lr 1e-4 -gm 0.7 -g "0"


## Requirements
- Python 3.5, scipy 0.18.1, numpy 1.11.2
- TensorFlow 1.1.0

## Author
Heumi / ckhfight@gmail.com

--------------------------------------------------------------------------------
/src/models/BEGAN.py:
--------------------------------------------------------------------------------
from src.layer.layers import *
from src.operator.op_BEGAN import Operator


class BEGAN(Operator):
    def __init__(self, args, sess):
        Operator.__init__(self, args, sess)

    def generator(self, x, reuse=None):
        with tf.variable_scope('gen_') as scope:
            if reuse:
                scope.reuse_variables()

            w = self.data_size
            f = self.filter_number
            p = "SAME"

            # Project the latent vector to an 8x8 feature map
            x = fc(x, 8 * 8 * f, name='fc')
            x = tf.reshape(x, [-1, 8, 8, f])

            x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv1_a')
            x = tf.nn.elu(x)
            x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv1_b')
            x = tf.nn.elu(x)

            # Nearest-neighbor upsampling between conv blocks
            if self.data_size == 128:
                x = resize_nn(x, w / 8)
                x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv2_a')
                x = tf.nn.elu(x)
                x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv2_b')
                x = tf.nn.elu(x)

            x = resize_nn(x, w / 4)
            x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv3_a')
            x = tf.nn.elu(x)
            x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv3_b')
            x = tf.nn.elu(x)

            x = resize_nn(x, w / 2)
            x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv4_a')
            x = tf.nn.elu(x)
            x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv4_b')
            x = tf.nn.elu(x)

            x = resize_nn(x, w)
            x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv5_a')
            x = tf.nn.elu(x)
            x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv5_b')
            x = tf.nn.elu(x)

            x = conv2d(x, [3, 3, f, 3], stride=1, padding=p, name='conv6_a')
            return x

    def encoder(self, x, reuse=None):
        with tf.variable_scope('disc_') as scope:
            if reuse:
                scope.reuse_variables()

            f = self.filter_number
            h = self.embedding
            p = "SAME"

            x = conv2d(x, [3, 3, 3, f], stride=1, padding=p, name='conv1_enc_a')
            x = tf.nn.elu(x)
            x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv2_enc_a')
            x = tf.nn.elu(x)
            x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv2_enc_b')
            x = tf.nn.elu(x)

            # Each scale widens the channels with a 1x1 conv, then downsamples
            x = conv2d(x, [1, 1, f, 2 * f], stride=1, padding=p, name='conv3_enc_0')
            x = pool(x, r=2, s=2)
            x = conv2d(x, [3, 3, 2 * f, 2 * f], stride=1, padding=p, name='conv3_enc_a')
            x = tf.nn.elu(x)
            x = conv2d(x, [3, 3, 2 * f, 2 * f], stride=1, padding=p, name='conv3_enc_b')
            x = tf.nn.elu(x)

            x = conv2d(x, [1, 1, 2 * f, 3 * f], stride=1, padding=p, name='conv4_enc_0')
            x = pool(x, r=2, s=2)
            x = conv2d(x, [3, 3, 3 * f, 3 * f], stride=1, padding=p, name='conv4_enc_a')
            x = tf.nn.elu(x)
            x = conv2d(x, [3, 3, 3 * f, 3 * f], stride=1, padding=p, name='conv4_enc_b')
            x = tf.nn.elu(x)

            x = conv2d(x, [1, 1, 3 * f, 4 * f], stride=1, padding=p, name='conv5_enc_0')
            x = pool(x, r=2, s=2)
            x = conv2d(x, [3, 3, 4 * f, 4 * f], stride=1, padding=p, name='conv5_enc_a')
            x = tf.nn.elu(x)
            x = conv2d(x, [3, 3, 4 * f, 4 * f], stride=1, padding=p, name='conv5_enc_b')
            x = tf.nn.elu(x)

            if self.data_size == 128:
                x = conv2d(x, [1, 1, 4 * f, 5 * f], stride=1, padding=p, name='conv6_enc_0')
                x = pool(x, r=2, s=2)
                x = conv2d(x, [3, 3, 5 * f, 5 * f], stride=1, padding=p, name='conv6_enc_a')
                x = tf.nn.elu(x)
                x = conv2d(x, [3, 3, 5 * f, 5 * f], stride=1, padding=p, name='conv6_enc_b')
                x = tf.nn.elu(x)

            x = fc(x, h, name='enc_fc')
            return x

    def decoder(self, x, reuse=None):
        # The decoder mirrors the generator architecture but lives in the
        # discriminator's 'disc_' variable scope
        with tf.variable_scope('disc_') as scope:
            if reuse:
                scope.reuse_variables()

            w = self.data_size
            f = self.filter_number
            p = "SAME"

            x = fc(x, 8 * 8 * f, name='fc')
            x = tf.reshape(x, [-1, 8, 8, f])

            x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv1_a')
            x = tf.nn.elu(x)
            x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv1_b')
            x = tf.nn.elu(x)

            if self.data_size == 128:
                x = resize_nn(x, w / 8)
                x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv2_a')
                x = tf.nn.elu(x)
                x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv2_b')
                x = tf.nn.elu(x)

            x = resize_nn(x, w / 4)
            x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv3_a')
            x = tf.nn.elu(x)
            x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv3_b')
            x = tf.nn.elu(x)

            x = resize_nn(x, w / 2)
            x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv4_a')
            x = tf.nn.elu(x)
            x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv4_b')
            x = tf.nn.elu(x)

            x = resize_nn(x, w)
            x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv5_a')
            x = tf.nn.elu(x)
            x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv5_b')
            x = tf.nn.elu(x)

            x = conv2d(x, [3, 3, f, 3], stride=1, padding=p, name='conv6_a')
            return x
--------------------------------------------------------------------------------
/src/operator/op_BEGAN.py:
--------------------------------------------------------------------------------
import glob
import time
import datetime
from src.layer.layers import *
from src.function.functions import *
from src.operator.op_base import op_base


class Operator(op_base):
    def __init__(self, args, sess):
        op_base.__init__(self, args, sess)
        self.build_model()

    def build_model(self):
        # Input placeholder
        self.x = tf.placeholder(tf.float32, shape=[self.batch_size, self.input_size], name='x')
        self.y = tf.placeholder(tf.float32, shape=[self.batch_size, self.data_size, self.data_size, 3], name='y')
        self.kt = tf.placeholder(tf.float32, name='kt')
        self.lr = tf.placeholder(tf.float32, name='lr')

        # Generator
        self.recon_gen = self.generator(self.x)

        # Discriminator (critic): an autoencoder that reconstructs its input
        d_real = self.decoder(self.encoder(self.y))
        d_fake = self.decoder(self.encoder(self.recon_gen, reuse=True), reuse=True)
        self.recon_dec = self.decoder(self.x, reuse=True)

        # Loss
        self.d_real_loss = l1_loss(self.y, d_real)
        self.d_fake_loss = l1_loss(self.recon_gen, d_fake)
        self.d_loss = self.d_real_loss - self.kt * self.d_fake_loss
        self.g_loss = self.d_fake_loss
        self.m_global = self.d_real_loss + tf.abs(self.gamma * self.d_real_loss - self.d_fake_loss)

        # Variables
        g_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, "gen_")
        d_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, "disc_")

        # Optimizer
        self.opt_g = tf.train.AdamOptimizer(self.lr, self.mm).minimize(self.g_loss, var_list=g_vars)
        self.opt_d = tf.train.AdamOptimizer(self.lr, self.mm).minimize(self.d_loss, var_list=d_vars)

        # Initializer
        self.sess.run(tf.global_variables_initializer())

        # tf saver
        self.saver = tf.train.Saver(max_to_keep=self.max_to_keep)

        try:
            self.load(self.sess, self.saver, self.ckpt_dir)
        except Exception:
            # No checkpoint found; save the full graph instead
            self.saver.save(self.sess, self.ckpt_model_name, write_meta_graph=True)

        # Summary
        if self.flag:
            tf.summary.scalar('loss/loss', self.d_loss + self.g_loss)
            tf.summary.scalar('loss/g_loss', self.g_loss)
            tf.summary.scalar('loss/d_loss', self.d_loss)
            tf.summary.scalar('loss/d_real_loss', self.d_real_loss)
            tf.summary.scalar('loss/d_fake_loss', self.d_fake_loss)
            tf.summary.scalar('misc/kt', self.kt)
            tf.summary.scalar('misc/m_global', self.m_global)
            self.merged = tf.summary.merge_all()
            self.writer = tf.summary.FileWriter(self.project_dir, self.sess.graph)

    def train(self, train_flag):
        # Load data (cache the file list as a .npy for faster startup)
        data_path = '{0}/{1}/{2}_{3}'.format(self.data_dir, self.dataset, self.data_size, self.data_opt)

        if os.path.exists(data_path + '.npy'):
            data = np.load(data_path + '.npy')
        else:
            data = sorted(glob.glob(os.path.join(data_path, "*.*")))
            np.save(data_path + '.npy', data)

        print('Shuffle ....')
        random_order = np.random.permutation(len(data))
        data = [data[i] for i in random_order[:]]
        print('Shuffle Done')

        # Initial parameters
        start_time = time.time()
        kt = np.float32(0.)
        lr = np.float32(self.learning_rate)
        self.count = 0

        for epoch in range(self.niter):
            batch_idxs = len(data) // self.batch_size

            for idx in range(0, batch_idxs):
                self.count += 1

                batch_x = np.random.uniform(-1., 1., size=[self.batch_size, self.input_size])
                batch_files = data[idx * self.batch_size: (idx + 1) * self.batch_size]
                batch_data = [get_image(batch_file) for batch_file in batch_files]

                # Opt & feed list: G and D are updated alternately (differs from the paper)
                g_opt = [self.opt_g, self.g_loss, self.d_real_loss, self.d_fake_loss]
                d_opt = [self.opt_d, self.d_loss, self.merged]
                feed_dict = {self.x: batch_x, self.y: batch_data, self.kt: kt, self.lr: lr}

                # Run TensorFlow
                _, loss_g, d_real_loss, d_fake_loss = self.sess.run(g_opt, feed_dict=feed_dict)
                _, loss_d, summary = self.sess.run(d_opt, feed_dict=feed_dict)

                # Update kt (clipped to [0, 1]) and the convergence measure m_global
                kt = np.maximum(np.minimum(1., kt + self.lamda * (self.gamma * d_real_loss - d_fake_loss)), 0.)
                m_global = d_real_loss + np.abs(self.gamma * d_real_loss - d_fake_loss)
                loss = loss_g + loss_d

                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, "
                      "loss: %.4f, loss_g: %.4f, loss_d: %.4f, d_real: %.4f, d_fake: %.4f, kt: %.8f, M: %.8f"
                      % (epoch, idx, batch_idxs, time.time() - start_time,
                         loss, loss_g, loss_d, d_real_loss, d_fake_loss, kt, m_global))

                # Write train summary
                self.writer.add_summary(summary, self.count)

                # Test during training
                if self.count % self.niter_snapshot == (self.niter_snapshot - 1):
                    # Decay the learning rate
                    lr *= 0.95
                    # Save & test
                    self.saver.save(self.sess, self.ckpt_model_name, global_step=self.count, write_meta_graph=False)
                    self.test(train_flag)

    def test(self, train_flag=True):
        # Generate output
        img_num = self.batch_size
        img_size = self.data_size

        # Tile the batch into a square grid image
        output_f = int(np.sqrt(img_num))
        im_output_gen = np.zeros([img_size * output_f, img_size * output_f, 3])
        im_output_dec = np.zeros([img_size * output_f, img_size * output_f, 3])

        test_data = np.random.uniform(-1., 1., size=[img_num, self.input_size])
        output_gen = self.sess.run(self.recon_gen, feed_dict={self.x: test_data})  # generator output
        output_dec = self.sess.run(self.recon_dec, feed_dict={self.x: test_data})  # decoder output

        output_gen = [inverse_image(output_gen[i]) for i in range(img_num)]
        output_dec = [inverse_image(output_dec[i]) for i in range(img_num)]

        for i in range(output_f):
            for j in range(output_f):
                im_output_gen[i * img_size:(i + 1) * img_size, j * img_size:(j + 1) * img_size, :] \
                    = output_gen[j + (i * output_f)]
                im_output_dec[i * img_size:(i + 1) * img_size, j * img_size:(j + 1) * img_size, :] \
                    = output_dec[j + (i * output_f)]

        # Save output
        if train_flag:
            scm.imsave(self.project_dir + '/result/' + str(self.count) + '_output.bmp', im_output_gen)
        else:
            now = datetime.datetime.now()
            nowDatetime = now.strftime('%Y-%m-%d_%H:%M:%S')
            scm.imsave(self.project_dir + '/result_test/gen_{}_output.bmp'.format(nowDatetime), im_output_gen)
            scm.imsave(self.project_dir + '/result_test/dec_{}_output.bmp'.format(nowDatetime), im_output_dec)
--------------------------------------------------------------------------------