├── src
│   ├── layer
│   │   ├── __init__.py
│   │   └── layers.py
│   ├── models
│   │   ├── __init__.py
│   │   └── BEGAN.py
│   ├── function
│   │   ├── __init__.py
│   │   └── functions.py
│   ├── operator
│   │   ├── __init__.py
│   │   ├── op_base.py
│   │   └── op_BEGAN.py
│   ├── __init__.py
│   └── __init__.pyc
├── Result
│   ├── kt.jpg
│   ├── 64x64.bmp
│   ├── result.gif
│   ├── 128x128.bmp
│   ├── decoder.bmp
│   ├── gamma_0.3.bmp
│   ├── gamma_0.4.bmp
│   ├── gamma_0.5.bmp
│   ├── gamma_0.7.bmp
│   ├── m_global.jpg
│   └── m_global.tif
├── began_cmd.txt
├── Data
│   └── celeba
│       └── face_detect.py
├── main.py
└── README.md
/src/layer/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/function/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/operator/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | ## __init__.py
--------------------------------------------------------------------------------
/Result/kt.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/BEGAN-tensorflow/HEAD/Result/kt.jpg
--------------------------------------------------------------------------------
/Result/64x64.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/BEGAN-tensorflow/HEAD/Result/64x64.bmp
--------------------------------------------------------------------------------
/Result/result.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/BEGAN-tensorflow/HEAD/Result/result.gif
--------------------------------------------------------------------------------
/src/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/BEGAN-tensorflow/HEAD/src/__init__.pyc
--------------------------------------------------------------------------------
/Result/128x128.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/BEGAN-tensorflow/HEAD/Result/128x128.bmp
--------------------------------------------------------------------------------
/Result/decoder.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/BEGAN-tensorflow/HEAD/Result/decoder.bmp
--------------------------------------------------------------------------------
/Result/gamma_0.3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/BEGAN-tensorflow/HEAD/Result/gamma_0.3.bmp
--------------------------------------------------------------------------------
/Result/gamma_0.4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/BEGAN-tensorflow/HEAD/Result/gamma_0.4.bmp
--------------------------------------------------------------------------------
/Result/gamma_0.5.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/BEGAN-tensorflow/HEAD/Result/gamma_0.5.bmp
--------------------------------------------------------------------------------
/Result/gamma_0.7.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/BEGAN-tensorflow/HEAD/Result/gamma_0.7.bmp
--------------------------------------------------------------------------------
/Result/m_global.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/BEGAN-tensorflow/HEAD/Result/m_global.jpg
--------------------------------------------------------------------------------
/Result/m_global.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmi88/BEGAN-tensorflow/HEAD/Result/m_global.tif
--------------------------------------------------------------------------------
/src/function/functions.py:
--------------------------------------------------------------------------------
1 | import os
2 | import scipy.misc as scm
3 |
4 | def make_project_dir(project_dir):  # create models/, result/ and result_test/ under project_dir
5 |     if not os.path.exists(project_dir):
6 |         os.makedirs(project_dir)
7 |         os.makedirs(os.path.join(project_dir, 'models'))
8 |         os.makedirs(os.path.join(project_dir, 'result'))
9 |         os.makedirs(os.path.join(project_dir, 'result_test'))
10 |
11 |
12 | def get_image(img_path):
13 |     img = scm.imread(img_path) / 255. - 0.5  # scale pixels to [-0.5, 0.5]
14 |     img = img[..., ::-1]  # RGB to BGR
15 |     return img
16 |
17 |
18 | def inverse_image(img):
19 |     img = (img + 0.5) * 255.  # back to [0, 255]
20 |     img[img > 255] = 255  # clip overflow
21 |     img[img < 0] = 0  # clip underflow
22 |     img = img[..., ::-1]  # BGR to RGB
23 |     return img
24 |
25 |
--------------------------------------------------------------------------------
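
Note: scm.imread/scm.imsave exist in the scipy 0.18.1 pinned by the README, but both were removed from SciPy 1.2+. On a newer environment, a hypothetical drop-in replacement (a sketch assuming "pip install imageio"; not part of this repo) could look like:

import imageio

def get_image(img_path):
    img = imageio.imread(img_path) / 255. - 0.5  # scale pixels to [-0.5, 0.5]
    return img[..., ::-1]                        # RGB to BGR

def inverse_image(img):
    img = (img + 0.5) * 255.                     # back to [0, 255]
    img = img.clip(0, 255)                       # clip out-of-range values
    return img[..., ::-1].astype('uint8')        # BGR to RGB

--------------------------------------------------------------------------------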
/began_cmd.txt:
--------------------------------------------------------------------------------
1 | ### Train
2 | ex) 64x64 img | Nz,Nh 128 | gamma 0.4
3 | python3 main.py -f 1 -p "began" -trd "celeba" -tro "crop" -trs 64 -z 128 -em 128 -fn 64 -b 16 -lr 1e-4 -gm 0.4 -g "0"
4 |
5 | ex) 128x128 img | Nz,Nh 64 | gamma 0.7
6 | python3 main.py -f 1 -p "began" -trd "celeba" -tro "crop" -trs 128 -z 64 -em 64 -fn 128 -b 16 -lr 1e-4 -gm 0.7 -g "0"
7 |
8 | ### Test (refer to main.py)
9 | ex) 64x64 img | Nz,Nh 128 | gamma 0.4
10 | python3 main.py -f 0 -p "began" -trd "celeba" -tro "crop" -trs 64 -z 128 -em 128 -fn 64 -b 16 -lr 1e-4 -gm 0.4 -g "0"
11 |
12 | ex) 128x128 img | Nz,Nh 64 | gamma 0.7
13 | python3 main.py -f 0 -p "began" -trd "celeba" -tro "crop" -trs 128 -z 64 -em 64 -fn 128 -b 16 -lr 1e-4 -gm 0.7 -g "0"
14 |
--------------------------------------------------------------------------------
/Data/celeba/face_detect.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 |
4 | # Create the Haar cascade face detector
5 | # (haarcascade_frontalface_default.xml ships with OpenCV)
6 | cascPath = 'haarcascade_frontalface_default.xml'
7 | faceCascade = cv2.CascadeClassifier(cascPath)
8 |
9 | # Output directory for the cropped faces (cv2.imwrite fails silently if it is missing)
10 | if not os.path.exists('128_crop'):
11 |     os.makedirs('128_crop')
12 |
13 | # Detect, crop, and resize a face from every raw image
14 | for fn in sorted(os.listdir('raw')):
15 |     print(fn)
16 |     image = cv2.imread('raw/' + fn)
17 |     gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
18 |
19 |     faces = faceCascade.detectMultiScale(gray, 5, 5)  # scaleFactor=5, minNeighbors=5
20 |
21 |     if len(faces) > 0:
22 |         x, y, w, h = faces[0]  # keep only the first detected face
23 |         image_crop = image[y:y+h, x:x+w, :]
24 |         image_resize = cv2.resize(image_crop, (128, 128))
25 |         cv2.imwrite('128_crop/' + fn[:-4] + '_crop' + fn[-4:], image_resize)
26 |
--------------------------------------------------------------------------------
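
Note: the script expects haarcascade_frontalface_default.xml in its working directory. Rather than copying the file by hand, recent opencv-python wheels expose the bundled cascades via cv2.data; a sketch (assuming such a build is installed, since very old builds lack cv2.data):

import cv2

# cv2.data is an assumption: it exists in current opencv-python wheels, not in very old builds
cascPath = cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
faceCascade = cv2.CascadeClassifier(cascPath)

--------------------------------------------------------------------------------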
/src/layer/layers.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 |
5 | def conv2d(x, filter_shape, bias=True, stride=1, padding="SAME", name="conv2d"):
6 |     kw, kh, nin, nout = filter_shape
7 |     pad_size = (kw - 1) // 2  # integer division: `/` yields a float in Python 3, which tf.pad rejects
8 |
9 |     if padding == "VALID":  # symmetric pre-pad keeps the spatial size without zero-padding artifacts
10 |         x = tf.pad(x, [[0, 0], [pad_size, pad_size], [pad_size, pad_size], [0, 0]], "SYMMETRIC")
11 |
12 |     initializer = tf.random_normal_initializer(0., 0.02)
13 |     with tf.variable_scope(name):
14 |         weight = tf.get_variable("weight", shape=filter_shape, initializer=initializer)
15 |         x = tf.nn.conv2d(x, weight, [1, stride, stride, 1], padding=padding)
16 |
17 |         if bias:
18 |             b = tf.get_variable("bias", shape=filter_shape[-1], initializer=tf.constant_initializer(0.))
19 |             x = tf.nn.bias_add(x, b)
20 |     return x
21 |
22 |
23 | def fc(x, output_shape, bias=True, name='fc'):
24 |     shape = x.get_shape().as_list()
25 |     dim = np.prod(shape[1:])  # flatten everything but the batch dimension
26 |     x = tf.reshape(x, [-1, dim])
27 |     input_shape = dim
28 |
29 |     initializer = tf.random_normal_initializer(0., 0.02)
30 |     with tf.variable_scope(name):
31 |         weight = tf.get_variable("weight", shape=[input_shape, output_shape], initializer=initializer)
32 |         x = tf.matmul(x, weight)
33 |
34 |         if bias:
35 |             b = tf.get_variable("bias", shape=[output_shape], initializer=tf.constant_initializer(0.))
36 |             x = tf.nn.bias_add(x, b)
37 |     return x
38 |
39 |
40 | def pool(x, r=2, s=1):
41 |     return tf.nn.avg_pool(x, ksize=[1, r, r, 1], strides=[1, s, s, 1], padding="SAME")
42 |
43 |
44 | def l1_loss(x, y):
45 |     return tf.reduce_mean(tf.abs(x - y))
46 |
47 |
48 | def resize_nn(x, size):
49 |     return tf.image.resize_nearest_neighbor(x, size=(int(size), int(size)))  # int() tolerates float sizes like w / 2
50 |
--------------------------------------------------------------------------------
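
Note: in conv2d above, padding="VALID" means "symmetric pre-pad, then convolve with no implicit zero padding", which preserves the spatial size while avoiding zero-padding border artifacts. A minimal shape check (a sketch assuming the TensorFlow 1.x pinned in the README, run from the repo root):

import tensorflow as tf
from src.layer.layers import conv2d

x = tf.placeholder(tf.float32, [16, 64, 64, 3])
# 3x3 kernel -> pad_size = 1, so 64x64 is padded to 66x66 and the VALID conv returns 64x64
y = conv2d(x, [3, 3, 3, 32], padding="VALID", name="demo_conv")
print(y.get_shape())  # (16, 64, 64, 32)

--------------------------------------------------------------------------------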
/src/operator/op_base.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import time
3 | import numpy as np
4 | import tensorflow as tf
5 | from src.function.functions import *
6 |
7 | class op_base:
8 |     def __init__(self, args, sess):
9 |         self.sess = sess
10 |
11 |         # Train
12 |         self.flag = args.flag
13 |         self.gpu_number = args.gpu_number
14 |         self.project = args.project
15 |
16 |         # Train Data
17 |         self.data_dir = args.data_dir    # ./Data
18 |         self.dataset = args.dataset      # celeba
19 |         self.data_size = args.data_size  # 64 or 128
20 |         self.data_opt = args.data_opt    # raw or crop
21 |
22 |         # Train Iteration
23 |         self.niter = args.niter
24 |         self.niter_snapshot = args.nsnapshot
25 |         self.max_to_keep = args.max_to_keep
26 |
27 |         # Train Parameter
28 |         self.batch_size = args.batch_size
29 |         self.learning_rate = args.learning_rate
30 |         self.mm = args.momentum
31 |         self.mm2 = args.momentum2
32 |         self.lamda = args.lamda
33 |         self.gamma = args.gamma
34 |         self.filter_number = args.filter_number
35 |         self.input_size = args.input_size
36 |         self.embedding = args.embedding
37 |
38 |         # Result Dir & File
39 |         self.project_dir = 'assets/{0}_{1}_{2}_{3}/'.format(self.project, self.dataset, self.data_opt, self.data_size)
40 |         self.ckpt_dir = os.path.join(self.project_dir, 'models')
41 |         self.model_name = "{0}.model".format(self.project)
42 |         self.ckpt_model_name = os.path.join(self.ckpt_dir, self.model_name)
43 |
44 |         # etc.
45 |         if not os.path.exists('assets'):
46 |             os.makedirs('assets')
47 |         make_project_dir(self.project_dir)
48 |
49 |     def load(self, sess, saver, ckpt_dir):
50 |         ckpt = tf.train.get_checkpoint_state(ckpt_dir)
51 |         if ckpt is None or ckpt.model_checkpoint_path is None:
52 |             raise IOError('no checkpoint found in {0}'.format(ckpt_dir))  # caller falls back to a fresh save
53 |         ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
54 |         saver.restore(sess, os.path.join(ckpt_dir, ckpt_name))
55 |
--------------------------------------------------------------------------------
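
Note: with the default flags, op_base and make_project_dir produce the following layout (project "began", dataset "celeba", data_opt "crop", data_size 64):

assets/began_celeba_crop_64/
├── models/        # checkpoints named began.model-<step>
├── result/        # image grids saved at every training snapshot
└── result_test/   # gen_*/dec_* grids written by test runs

--------------------------------------------------------------------------------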
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import distutils.util
3 | import os
4 | import tensorflow as tf
5 | import src.models.BEGAN as began
6 |
7 |
8 | def main():
9 |     parser = argparse.ArgumentParser()
10 |
11 |     parser.add_argument("-f", "--flag", type=distutils.util.strtobool, default='true')  # 1: train, 0: test
12 |     parser.add_argument("-g", "--gpu_number", type=str, default="0")
13 |     parser.add_argument("-p", "--project", type=str, default="began")
14 |
15 |     # Train Data
16 |     parser.add_argument("-d", "--data_dir", type=str, default="./Data")
17 |     parser.add_argument("-trd", "--dataset", type=str, default="celeba")
18 |     parser.add_argument("-tro", "--data_opt", type=str, default="crop")
19 |     parser.add_argument("-trs", "--data_size", type=int, default=64)
20 |
21 |     # Train Iteration
22 |     parser.add_argument("-n", "--niter", type=int, default=50)
23 |     parser.add_argument("-ns", "--nsnapshot", type=int, default=2440)
24 |     parser.add_argument("-mx", "--max_to_keep", type=int, default=5)
25 |
26 |     # Train Parameter
27 |     parser.add_argument("-b", "--batch_size", type=int, default=16)
28 |     parser.add_argument("-lr", "--learning_rate", type=float, default=1e-4)
29 |     parser.add_argument("-m", "--momentum", type=float, default=0.5)
30 |     parser.add_argument("-m2", "--momentum2", type=float, default=0.999)
31 |     parser.add_argument("-gm", "--gamma", type=float, default=0.5)
32 |     parser.add_argument("-lm", "--lamda", type=float, default=0.001)
33 |     parser.add_argument("-fn", "--filter_number", type=int, default=64)
34 |     parser.add_argument("-z", "--input_size", type=int, default=64)  # Nz
35 |     parser.add_argument("-em", "--embedding", type=int, default=64)  # Nh
36 |
37 |     args = parser.parse_args()
38 |
39 |     # CUDA_VISIBLE_DEVICES remaps the chosen GPU to device 0 inside TensorFlow,
40 |     os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_number  # so the graph is always placed on '/gpu:0'
41 |
42 |     with tf.device('/gpu:0'):
43 |         gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.90)
44 |         config = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)
45 |
46 |         with tf.Session(config=config) as sess:
47 |             model = began.BEGAN(args, sess)
48 |
49 |             # TRAIN / TEST
50 |             if args.flag:
51 |                 model.train(args.flag)
52 |             else:
53 |                 model.test(args.flag)
54 |
55 | if __name__ == '__main__':
56 |     main()
57 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # BEGAN: Boundary Equilibrium Generative Adversarial Networks
2 | Implementation of Google Brain's [BEGAN: Boundary Equilibrium Generative Adversarial Networks](https://arxiv.org/abs/1703.10717) in TensorFlow. \
3 | BEGAN is the state of the art when it comes to generating realistic faces.
4 |
5 |
6 |
7 |
8 |
9 |
10 | Figure 1a. 128x128 and 64x64 images. The 128x128 results are very impressive; you can even see a set of teeth.
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 | Figure 1b. Random results from my trained models, gamma 0.3 to 0.5, with no cherry-picking. Gamma 0.3: nice, but biased toward women's faces. Gamma 0.4: the best. Gamma 0.5: good texture, but suffers from hole artifacts.
19 |
20 |
21 |
22 |
23 |
24 | Figure 1c. From scratch to 200k iterations.
25 |
26 | ## Implementation detail
27 | The provided trained model is 64x64; a 128x128 model will be updated later. This implementation differs from the original paper in two ways: the loss-update method and the learning-rate decay. First, the paper updates Loss_G and Loss_D simultaneously, but when I trained that way the models mode-collapsed, so this code updates the two losses alternately. Second, the learning rate decays by a factor of 0.95 every 2000 iterations. These values simply come from training experience; you can change them or consult the paper (the exact update equations are restated just after this README).
28 |
29 | ## Train progress
30 | If you want to see the training progress, download [this dropbox folder](https://www.dropbox.com/sh/g72k2crptow3ime/AAAhkGlHCw9zQh0aE-Ggdt3Qa?dl=0) and run "tensorboard --logdir='./'". I uploaded two trained models (64x64 and 128x128).
31 |
32 |
33 |
34 |
35 |
36 | Figure 2. k_t graph. Use it as a reference when training your own model: k_t does not reach 1.0; in my case it converges to about 0.08.
37 |
38 |
39 |
40 |
41 |
42 | Figure 3. Convergence measure (M_global), similar to the paper's graph.
43 |
44 |
45 |
46 |
47 |
48 |
49 | Figure 4. Comparison of the generator output and the decoder output.
50 |
51 |
52 |
53 | ## Usage
54 | I recommend downloading the trained models from [this dropbox folder](https://www.dropbox.com/sh/g72k2crptow3ime/AAAhkGlHCw9zQh0aE-Ggdt3Qa?dl=0).
55 |
56 | ### Make Train Data
57 | 1. Download the [celebA dataset (img_align_celeba.zip)](http://pan.baidu.com/s/1eSNpdRG#list/path=%2FCelebA%2FImg) and unzip it to 'Data/celeba/raw'
58 | 2. From 'Data/celeba', run ' python face_detect.py ' (the script reads 'raw/' and writes '128_crop/' relative to its own directory)
59 |
60 | ### Train (refer to main.py and began_cmd.txt)
61 | ex) 64x64 img | Nz,Nh 128 | gamma 0.4
62 | python3 main.py -f 1 -p "began" -trd "celeba" -tro "crop" -trs 64 -z 128 -em 128 -fn 64 -b 16 -lr 1e-4 -gm 0.4 -g "0"
63 |
64 | ex) 128x128 img | Nz,Nh 64 | gamma 0.7
65 | python3 main.py -f 1 -p "began" -trd "celeba" -tro "crop" -trs 128 -z 64 -em 64 -fn 128 -b 16 -lr 1e-4 -gm 0.7 -g "0"
66 |
67 | ### Test (refer to main.py and began_cmd.txt)
68 | ex) 64x64 img | Nz,Nh 128 | gamma 0.4
69 | python3 main.py -f 0 -p "began" -trd "celeba" -tro "crop" -trs 64 -z 128 -em 128 -fn 64 -b 16 -lr 1e-4 -gm 0.4 -g "0"
70 |
71 | ex) 128x128 img | Nz,Nh 64 | gamma 0.7
72 | python3 main.py -f 0 -p "began" -trd "celeba" -tro "crop" -trs 128 -z 64 -em 64 -fn 128 -b 16 -lr 1e-4 -gm 0.7 -g "0"
73 |
74 |
75 | ## Requirements
76 | - Python 3.5, scipy 0.18.1, numpy 1.11.2
77 | - TensorFlow 1.1.0
78 |
79 | ## Author
80 | Heumi / ckhfight@gmail.com
81 |
82 |
--------------------------------------------------------------------------------
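
Note: the "Implementation detail" section of the README refers to the BEGAN equilibrium updates. With \mathcal{L} denoting the pixelwise L1 autoencoder reconstruction loss, the quantities computed in src/operator/op_BEGAN.py are

\mathcal{L}_D = \mathcal{L}(x) - k_t \, \mathcal{L}(G(z)), \qquad \mathcal{L}_G = \mathcal{L}(G(z))

k_{t+1} = \operatorname{clip}_{[0,1]}\!\big[\, k_t + \lambda \, (\gamma \, \mathcal{L}(x) - \mathcal{L}(G(z))) \,\big], \qquad \mathcal{M}_{\mathrm{global}} = \mathcal{L}(x) + \lvert \gamma \, \mathcal{L}(x) - \mathcal{L}(G(z)) \rvert

--------------------------------------------------------------------------------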
/src/models/BEGAN.py:
--------------------------------------------------------------------------------
1 | from src.layer.layers import *
2 | from src.operator.op_BEGAN import Operator
3 |
4 |
5 | class BEGAN(Operator):
6 |     def __init__(self, args, sess):
7 |         Operator.__init__(self, args, sess)
8 |
9 |     def generator(self, x, reuse=None):
10 |         with tf.variable_scope('gen_') as scope:
11 |             if reuse:
12 |                 scope.reuse_variables()
13 |
14 |             w = self.data_size
15 |             f = self.filter_number
16 |             p = "SAME"
17 |
18 |             x = fc(x, 8 * 8 * f, name='fc')  # project z onto an 8x8 feature map
19 |             x = tf.reshape(x, [-1, 8, 8, f])
20 |
21 |             x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv1_a')
22 |             x = tf.nn.elu(x)
23 |             x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv1_b')
24 |             x = tf.nn.elu(x)
25 |
26 |             if self.data_size == 128:  # extra upsampling stage, only needed to reach 128x128
27 |                 x = resize_nn(x, w / 8)
28 |                 x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv2_a')
29 |                 x = tf.nn.elu(x)
30 |                 x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv2_b')
31 |                 x = tf.nn.elu(x)
32 |
33 |             x = resize_nn(x, w / 4)
34 |             x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv3_a')
35 |             x = tf.nn.elu(x)
36 |             x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv3_b')
37 |             x = tf.nn.elu(x)
38 |
39 |             x = resize_nn(x, w / 2)
40 |             x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv4_a')
41 |             x = tf.nn.elu(x)
42 |             x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv4_b')
43 |             x = tf.nn.elu(x)
44 |
45 |             x = resize_nn(x, w)
46 |             x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv5_a')
47 |             x = tf.nn.elu(x)
48 |             x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv5_b')
49 |             x = tf.nn.elu(x)
50 |
51 |             x = conv2d(x, [3, 3, f, 3], stride=1, padding=p, name='conv6_a')  # to RGB, no activation
52 |             return x
53 |
54 |     def encoder(self, x, reuse=None):
55 |         with tf.variable_scope('disc_') as scope:
56 |             if reuse:
57 |                 scope.reuse_variables()
58 |
59 |             f = self.filter_number
60 |             h = self.embedding
61 |             p = "SAME"
62 |
63 |             x = conv2d(x, [3, 3, 3, f], stride=1, padding=p, name='conv1_enc_a')
64 |             x = tf.nn.elu(x)
65 |
66 |             x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv2_enc_a')
67 |             x = tf.nn.elu(x)
68 |             x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv2_enc_b')
69 |             x = tf.nn.elu(x)
70 |
71 |             x = conv2d(x, [1, 1, f, 2 * f], stride=1, padding=p, name='conv3_enc_0')  # 1x1 conv widens channels
72 |             x = pool(x, r=2, s=2)  # 2x average-pool downsamples
73 |             x = conv2d(x, [3, 3, 2 * f, 2 * f], stride=1, padding=p, name='conv3_enc_a')
74 |             x = tf.nn.elu(x)
75 |             x = conv2d(x, [3, 3, 2 * f, 2 * f], stride=1, padding=p, name='conv3_enc_b')
76 |             x = tf.nn.elu(x)
77 |
78 |             x = conv2d(x, [1, 1, 2 * f, 3 * f], stride=1, padding=p, name='conv4_enc_0')
79 |             x = pool(x, r=2, s=2)
80 |             x = conv2d(x, [3, 3, 3 * f, 3 * f], stride=1, padding=p, name='conv4_enc_a')
81 |             x = tf.nn.elu(x)
82 |             x = conv2d(x, [3, 3, 3 * f, 3 * f], stride=1, padding=p, name='conv4_enc_b')
83 |             x = tf.nn.elu(x)
84 |
85 |             x = conv2d(x, [1, 1, 3 * f, 4 * f], stride=1, padding=p, name='conv5_enc_0')
86 |             x = pool(x, r=2, s=2)
87 |             x = conv2d(x, [3, 3, 4 * f, 4 * f], stride=1, padding=p, name='conv5_enc_a')
88 |             x = tf.nn.elu(x)
89 |             x = conv2d(x, [3, 3, 4 * f, 4 * f], stride=1, padding=p, name='conv5_enc_b')
90 |             x = tf.nn.elu(x)
91 |
92 |             if self.data_size == 128:  # extra downsampling stage for 128x128 inputs
93 |                 x = conv2d(x, [1, 1, 4 * f, 5 * f], stride=1, padding=p, name='conv6_enc_0')
94 |                 x = pool(x, r=2, s=2)
95 |                 x = conv2d(x, [3, 3, 5 * f, 5 * f], stride=1, padding=p, name='conv6_enc_a')
96 |                 x = tf.nn.elu(x)
97 |                 x = conv2d(x, [3, 3, 5 * f, 5 * f], stride=1, padding=p, name='conv6_enc_b')
98 |                 x = tf.nn.elu(x)
99 |
100 |             x = fc(x, h, name='enc_fc')  # flatten the 8x8 map to the Nh-dimensional embedding
101 |             return x
102 |
103 |     def decoder(self, x, reuse=None):
104 |         with tf.variable_scope('disc_') as scope:  # mirrors the generator, inside the discriminator scope
105 |             if reuse:
106 |                 scope.reuse_variables()
107 |
108 |             w = self.data_size
109 |             f = self.filter_number
110 |             p = "SAME"
111 |
112 |             x = fc(x, 8 * 8 * f, name='fc')
113 |             x = tf.reshape(x, [-1, 8, 8, f])
114 |
115 |             x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv1_a')
116 |             x = tf.nn.elu(x)
117 |             x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv1_b')
118 |             x = tf.nn.elu(x)
119 |
120 |             if self.data_size == 128:
121 |                 x = resize_nn(x, w / 8)
122 |                 x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv2_a')
123 |                 x = tf.nn.elu(x)
124 |                 x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv2_b')
125 |                 x = tf.nn.elu(x)
126 |
127 |             x = resize_nn(x, w / 4)
128 |             x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv3_a')
129 |             x = tf.nn.elu(x)
130 |             x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv3_b')
131 |             x = tf.nn.elu(x)
132 |
133 |             x = resize_nn(x, w / 2)
134 |             x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv4_a')
135 |             x = tf.nn.elu(x)
136 |             x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv4_b')
137 |             x = tf.nn.elu(x)
138 |
139 |             x = resize_nn(x, w)
140 |             x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv5_a')
141 |             x = tf.nn.elu(x)
142 |             x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv5_b')
143 |             x = tf.nn.elu(x)
144 |
145 |             x = conv2d(x, [3, 3, f, 3], stride=1, padding=p, name='conv6_a')
146 |             return x
147 |
--------------------------------------------------------------------------------
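
Note: a quick sanity check of the upsampling schedule in BEGAN.py (a standalone sketch, not repo code). Both the generator and the decoder start from an 8x8 map and double the resolution each stage, with one extra stage when data_size is 128:

# spatial sizes visited by the generator/decoder for each supported output size
for w in (64, 128):
    sizes = [8]
    if w == 128:
        sizes.append(w // 8)          # extra stage for 128x128
    sizes += [w // 4, w // 2, w]
    print(w, sizes)                   # 64 -> [8, 16, 32, 64]; 128 -> [8, 16, 32, 64, 128]

--------------------------------------------------------------------------------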
/src/operator/op_BEGAN.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import time
3 | import datetime
4 | from src.layer.layers import *
5 | from src.function.functions import *
6 | from src.operator.op_base import op_base
7 |
8 |
9 | class Operator(op_base):
10 |     def __init__(self, args, sess):
11 |         op_base.__init__(self, args, sess)
12 |         self.build_model()
13 |
14 |     def build_model(self):
15 |         # Input placeholders
16 |         self.x = tf.placeholder(tf.float32, shape=[self.batch_size, self.input_size], name='x')
17 |         self.y = tf.placeholder(tf.float32, shape=[self.batch_size, self.data_size, self.data_size, 3], name='y')
18 |         self.kt = tf.placeholder(tf.float32, name='kt')
19 |         self.lr = tf.placeholder(tf.float32, name='lr')
20 |
21 |         # Generator
22 |         self.recon_gen = self.generator(self.x)
23 |
24 |         # Discriminator (autoencoder critic)
25 |         d_real = self.decoder(self.encoder(self.y))
26 |         d_fake = self.decoder(self.encoder(self.recon_gen, reuse=True), reuse=True)
27 |         self.recon_dec = self.decoder(self.x, reuse=True)  # NB: reusing the decoder fc requires input_size (Nz) == embedding (Nh)
28 |
29 |         # Loss (BEGAN equilibrium objective)
30 |         self.d_real_loss = l1_loss(self.y, d_real)
31 |         self.d_fake_loss = l1_loss(self.recon_gen, d_fake)
32 |         self.d_loss = self.d_real_loss - self.kt * self.d_fake_loss
33 |         self.g_loss = self.d_fake_loss
34 |         self.m_global = self.d_real_loss + tf.abs(self.gamma * self.d_real_loss - self.d_fake_loss)
35 |
36 |         # Variables
37 |         g_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, "gen_")
38 |         d_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, "disc_")
39 |
40 |         # Optimizers (beta1 = momentum, beta2 = momentum2)
41 |         self.opt_g = tf.train.AdamOptimizer(self.lr, self.mm, self.mm2).minimize(self.g_loss, var_list=g_vars)
42 |         self.opt_d = tf.train.AdamOptimizer(self.lr, self.mm, self.mm2).minimize(self.d_loss, var_list=d_vars)
43 |
44 |
45 |         # initializer
46 |         self.sess.run(tf.global_variables_initializer())
47 |
48 |         # tf saver
49 |         self.saver = tf.train.Saver(max_to_keep=self.max_to_keep)
50 |
51 |         try:
52 |             self.load(self.sess, self.saver, self.ckpt_dir)
53 |         except Exception:
54 |             # no checkpoint yet: save the full graph once
55 |             self.saver.save(self.sess, self.ckpt_model_name, write_meta_graph=True)
56 |
57 |         # Summary
58 |         if self.flag:
59 |             tf.summary.scalar('loss/loss', self.d_loss + self.g_loss)
60 |             tf.summary.scalar('loss/g_loss', self.g_loss)
61 |             tf.summary.scalar('loss/d_loss', self.d_loss)
62 |             tf.summary.scalar('loss/d_real_loss', self.d_real_loss)
63 |             tf.summary.scalar('loss/d_fake_loss', self.d_fake_loss)
64 |             tf.summary.scalar('misc/kt', self.kt)
65 |             tf.summary.scalar('misc/m_global', self.m_global)
66 |             self.merged = tf.summary.merge_all()
67 |             self.writer = tf.summary.FileWriter(self.project_dir, self.sess.graph)
68 |
69 |     def train(self, train_flag):
70 |         # load (and cache) the list of training image paths
71 |         data_path = '{0}/{1}/{2}_{3}'.format(self.data_dir, self.dataset, self.data_size, self.data_opt)
72 |
73 |         if os.path.exists(data_path + '.npy'):
74 |             data = np.load(data_path + '.npy')
75 |         else:
76 |             data = sorted(glob.glob(os.path.join(data_path, "*.*")))
77 |             np.save(data_path + '.npy', data)
78 |
79 |         print('Shuffle ....')
80 |         random_order = np.random.permutation(len(data))
81 |         data = [data[i] for i in random_order]
82 |         print('Shuffle Done')
83 |
84 |         # initial parameters
85 |         start_time = time.time()
86 |         kt = np.float32(0.)
87 |         lr = np.float32(self.learning_rate)
88 |         self.count = 0
89 |
90 |         for epoch in range(self.niter):
91 |             batch_idxs = len(data) // self.batch_size
92 |
93 |             for idx in range(0, batch_idxs):
94 |                 self.count += 1
95 |
96 |                 batch_x = np.random.uniform(-1., 1., size=[self.batch_size, self.input_size])
97 |                 batch_files = data[idx * self.batch_size: (idx + 1) * self.batch_size]
98 |                 batch_data = [get_image(batch_file) for batch_file in batch_files]
99 |
100 |                 # opt & feed lists (alternating G/D updates, different from the paper)
101 |                 g_opt = [self.opt_g, self.g_loss, self.d_real_loss, self.d_fake_loss]
102 |                 d_opt = [self.opt_d, self.d_loss, self.merged]
103 |                 feed_dict = {self.x: batch_x, self.y: batch_data, self.kt: kt, self.lr: lr}
104 |
105 |                 # run tensorflow
106 |                 _, loss_g, d_real_loss, d_fake_loss = self.sess.run(g_opt, feed_dict=feed_dict)
107 |                 _, loss_d, summary = self.sess.run(d_opt, feed_dict=feed_dict)
108 |
109 |                 # update kt (clipped to [0, 1]) and m_global
110 |                 kt = np.maximum(np.minimum(1., kt + self.lamda * (self.gamma * d_real_loss - d_fake_loss)), 0.)
111 |                 m_global = d_real_loss + np.abs(self.gamma * d_real_loss - d_fake_loss)
112 |                 loss = loss_g + loss_d
113 |
114 |                 print("Epoch: [%2d] [%4d/%4d] time: %4.4f, "
115 |                       "loss: %.4f, loss_g: %.4f, loss_d: %.4f, d_real: %.4f, d_fake: %.4f, kt: %.8f, M: %.8f"
116 |                       % (epoch, idx, batch_idxs, time.time() - start_time,
117 |                          loss, loss_g, loss_d, d_real_loss, d_fake_loss, kt, m_global))
118 |
119 |                 # write train summary
120 |                 self.writer.add_summary(summary, self.count)
121 |
122 |                 # Test during Training
123 |                 if self.count % self.niter_snapshot == (self.niter_snapshot - 1):
124 |                     # update learning rate
125 |                     lr *= 0.95
126 |                     # save & test
127 |                     self.saver.save(self.sess, self.ckpt_model_name, global_step=self.count, write_meta_graph=False)
128 |                     self.test(train_flag)
129 |
130 |     def test(self, train_flag=True):
131 |         # tile batch_size generated images into a square grid
132 |         img_num = self.batch_size
133 |         img_size = self.data_size
134 |
135 |         output_f = int(np.sqrt(img_num))
136 |         im_output_gen = np.zeros([img_size * output_f, img_size * output_f, 3])
137 |         im_output_dec = np.zeros([img_size * output_f, img_size * output_f, 3])
138 |
139 |         test_data = np.random.uniform(-1., 1., size=[img_num, self.input_size])
140 |         output_gen = self.sess.run(self.recon_gen, feed_dict={self.x: test_data})  # generator output
141 |         output_dec = self.sess.run(self.recon_dec, feed_dict={self.x: test_data})  # decoder output
142 |
143 |         output_gen = [inverse_image(output_gen[i]) for i in range(img_num)]
144 |         output_dec = [inverse_image(output_dec[i]) for i in range(img_num)]
145 |
146 |         for i in range(output_f):
147 |             for j in range(output_f):
148 |                 im_output_gen[i * img_size:(i + 1) * img_size, j * img_size:(j + 1) * img_size, :] \
149 |                     = output_gen[j + (i * output_f)]
150 |                 im_output_dec[i * img_size:(i + 1) * img_size, j * img_size:(j + 1) * img_size, :] \
151 |                     = output_dec[j + (i * output_f)]
152 |
153 |         # output save
154 |         if train_flag:
155 |             scm.imsave(self.project_dir + '/result/' + str(self.count) + '_output.bmp', im_output_gen)
156 |         else:
157 |             now = datetime.datetime.now()
158 |             nowDatetime = now.strftime('%Y-%m-%d_%H:%M:%S')
159 |             scm.imsave(self.project_dir + '/result_test/gen_{}_output.bmp'.format(nowDatetime), im_output_gen)
160 |             scm.imsave(self.project_dir + '/result_test/dec_{}_output.bmp'.format(nowDatetime), im_output_dec)
161 |
--------------------------------------------------------------------------------
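
Note: the k_t update inside train() is a small proportional-control loop; in isolation (a condensed sketch of the same arithmetic, not repo code):

import numpy as np

def update_kt(kt, d_real_loss, d_fake_loss, gamma=0.5, lamda=0.001):
    # k_t grows when the discriminator wins (gamma * L_real > L_fake) and is clipped
    # to [0, 1], exactly as in Operator.train() above
    kt = kt + lamda * (gamma * d_real_loss - d_fake_loss)
    return float(np.clip(kt, 0., 1.))

--------------------------------------------------------------------------------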