├── .gitignore ├── README.md ├── example ├── phd08.gif └── sample499.jpg ├── gan.py └── phd08 ├── phd08_data_1.npy └── phd08_labels_1.npy /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea/ 3 | MNIST_data/ 4 | output/ 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## GAN 2 | > Forked from [GAN-tensorflow](https://github.com/ckmarkoh/GAN-tensorflow) 3 | 4 | Implementation of [Generative Adversarial Networks](https://arxiv.org/abs/1406.2661) in TensorFlow, 5 | trained on the [phd08](https://www.dropbox.com/s/69cwkkqt4m1xl55/phd08.alz?dl=0) dataset. 6 | 7 | ## Requirements 8 | 9 | - [TensorFlow](https://www.tensorflow.org/) 1.x (the code uses `tf.placeholder` and `tf.Session`), NumPy and scikit-image 10 | 11 | ## Examples 12 | 13 | ### Training Progress 14 | 15 | ![phd08](example/phd08.gif) 16 | 17 | ### Final Result 18 | 19 | - max_epoch : 500 20 | 21 | - batch_size : 256 22 | 23 | ![phd08 final](example/sample499.jpg) 24 | 25 | ## Run 26 | 27 | ### Train 28 | 29 | ``` 30 | python gan.py 31 | ``` 32 | 33 | ### Test 34 | 35 | > After training, set `to_train = False` in `gan.py`; running it again restores the latest checkpoint from `output/` and writes `output/test_result.jpg`. 36 | 37 | ``` 38 | python gan.py 39 | ``` -------------------------------------------------------------------------------- /example/phd08.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ckmarkoh/GAN-tensorflow/4ccb228912892d4e867f12e3e99af25133fbcc58/example/phd08.gif -------------------------------------------------------------------------------- /example/sample499.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ckmarkoh/GAN-tensorflow/4ccb228912892d4e867f12e3e99af25133fbcc58/example/sample499.jpg -------------------------------------------------------------------------------- /gan.py: 
"""Vanilla fully-connected GAN (TensorFlow 1.x graph API) trained on phd08.

Run with ``to_train = True`` to train and write per-epoch samples and
checkpoints into ``output_path``. Set ``to_train = False`` to restore the
latest checkpoint and write one grid of generated images.

The GAN does not use class labels.
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
from skimage.io import imsave
import os
import shutil

img_height = 28
img_width = 28
img_size = img_height * img_width

to_train = True
to_restore = False
output_path = "output"
data_path = "phd08/phd08_data_1.npy"

max_epoch = 500

h1_size = 150
h2_size = 300
z_size = 100
batch_size = 256

learning_rate = 0.0001
# Probability of KEEPING a unit in the discriminator's dropout layers while training.
dropout_keep_prob = 0.7


# Generator (G)
def build_generator(z_prior):
    """Build the generator MLP: latent vector -> flattened image.

    Args:
        z_prior: float32 tensor of latent vectors with z_size columns
            (train()/test() feed shape (batch_size, z_size)).

    Returns:
        (x_generate, g_params): x_generate has img_size columns with values in
        (-1, 1) because of the final tanh; g_params is the list of generator
        variables, which is passed to the optimizer as its var_list.
    """
    # Fully connected layer 1: z_size (100) -> h1_size (150), ReLU
    w1 = tf.Variable(tf.truncated_normal([z_size, h1_size], stddev=0.1), name="g_w1", dtype=tf.float32)
    b1 = tf.Variable(tf.zeros([h1_size]), name="g_b1", dtype=tf.float32)
    h1 = tf.nn.relu(tf.matmul(z_prior, w1) + b1)

    # Fully connected layer 2: h1_size (150) -> h2_size (300), ReLU
    w2 = tf.Variable(tf.truncated_normal([h1_size, h2_size], stddev=0.1), name="g_w2", dtype=tf.float32)
    b2 = tf.Variable(tf.zeros([h2_size]), name="g_b2", dtype=tf.float32)
    h2 = tf.nn.relu(tf.matmul(h1, w2) + b2)

    # Fully connected layer 3: h2_size (300) -> img_size (28 * 28)
    w3 = tf.Variable(tf.truncated_normal([h2_size, img_size], stddev=0.1), name="g_w3", dtype=tf.float32)
    b3 = tf.Variable(tf.zeros([img_size]), name="g_b3", dtype=tf.float32)
    h3 = tf.matmul(h2, w3) + b3
    # tanh keeps outputs in (-1, 1), matching the [-1, 1] scaling of the real data in train().
    x_generate = tf.nn.tanh(h3)

    g_params = [w1, b1, w2, b2, w3, b3]
    return x_generate, g_params


# Discriminator (D)
def build_discriminator(x_data, x_generated, keep_prob):
    """Build the discriminator MLP and score real and generated images in one pass.

    Args:
        x_data: real images, batch_size rows of img_size values.
        x_generated: generated images with img_size columns.
        keep_prob: scalar dropout keep probability.

    Returns:
        (y_data, y_generated, d_params):
        - y_data: sigmoid scores for the first batch_size rows (the real images).
        - y_generated: sigmoid scores for the remaining rows (the generated images).
        - d_params: the list of discriminator variables.
    """
    # Score real and generated images with the same weights in a single matmul chain.
    x_in = tf.concat([x_data, x_generated], 0)

    # Fully connected layer 1: img_size -> h2_size (300), ReLU + dropout
    w1 = tf.Variable(tf.truncated_normal([img_size, h2_size], stddev=0.1), name="d_w1", dtype=tf.float32)
    b1 = tf.Variable(tf.zeros([h2_size]), name="d_b1", dtype=tf.float32)
    h1 = tf.nn.dropout(tf.nn.relu(tf.matmul(x_in, w1) + b1), keep_prob)

    # Fully connected layer 2: h2_size (300) -> h1_size (150), ReLU + dropout
    w2 = tf.Variable(tf.truncated_normal([h2_size, h1_size], stddev=0.1), name="d_w2", dtype=tf.float32)
    b2 = tf.Variable(tf.zeros([h1_size]), name="d_b2", dtype=tf.float32)
    h2 = tf.nn.dropout(tf.nn.relu(tf.matmul(h1, w2) + b2), keep_prob)

    # Fully connected layer 3: h1_size (150) -> 1 logit
    w3 = tf.Variable(tf.truncated_normal([h1_size, 1], stddev=0.1), name="d_w3", dtype=tf.float32)
    b3 = tf.Variable(tf.zeros([1]), name="d_b3", dtype=tf.float32)
    h3 = tf.matmul(h2, w3) + b3

    # Split the logits back into the real part and the generated part.
    # With the placeholders built in train() (both inputs have batch_size rows):
    #   h3: (2 * batch_size, 1), y_data: (batch_size, 1), y_generated: (batch_size, 1)
    y_data = tf.nn.sigmoid(tf.slice(h3, [0, 0], [batch_size, -1], name=None))
    y_generated = tf.nn.sigmoid(tf.slice(h3, [batch_size, 0], [-1, -1], name=None))

    d_params = [w1, b1, w2, b2, w3, b3]
    return y_data, y_generated, d_params


def show_result(batch_res, fname, grid_size=(8, 8), grid_pad=5):
    """Tile generated images into one grayscale grid and save it to fname.

    Args:
        batch_res: array whose rows each reshape to (img_height, img_width).
            Values are expected in [-1, 1], which is the generator's tanh range.
        fname: output image path passed to skimage.io.imsave.
        grid_size: (rows, cols) of the grid; images beyond rows * cols are ignored.
        grid_pad: pixels of black padding between neighbouring tiles.
    """
    n_rows, n_cols = grid_size
    # Map [-1, 1] -> [0, 1].
    batch_res = 0.5 * batch_res.reshape((batch_res.shape[0], img_height, img_width)) + 0.5
    img_h, img_w = batch_res.shape[1], batch_res.shape[2]
    grid_h = img_h * n_rows + grid_pad * (n_rows - 1)
    grid_w = img_w * n_cols + grid_pad * (n_cols - 1)
    img_grid = np.zeros((grid_h, grid_w), dtype=np.uint8)
    for i, res in enumerate(batch_res[:n_rows * n_cols]):
        img = (np.clip(res, 0., 1.) * 255.).astype(np.uint8)
        # Row-major placement: the row index advances once per n_cols tiles.
        row = (i // n_cols) * (img_h + grid_pad)
        col = (i % n_cols) * (img_w + grid_pad)
        img_grid[row:row + img_h, col:col + img_w] = img
    imsave(fname, img_grid)


def train():
    """Train the GAN on the phd08 array at data_path.

    After every epoch this writes two sample grids into output_path:
    - sample<epoch>.jpg, from a fixed latent batch;
    - random_sample<epoch>.jpg, from a fresh latent batch.
    It then saves a checkpoint.

    Raises:
        ValueError: if the dataset holds fewer than batch_size samples.
        FileNotFoundError: if to_restore is set but output_path holds no checkpoint.
    """
    # To train on MNIST instead:
    # mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

    phd08 = np.load(data_path)
    size = phd08.shape[0]
    # Each sample must hold exactly img_size (28 * 28) values, otherwise reshape raises.
    phd08 = phd08.reshape((size, img_size))
    # Placeholders have a fixed batch dimension, so only full batches can be fed;
    # the trailing size % batch_size samples are skipped each epoch.
    n_batches = size // batch_size
    if n_batches == 0:
        raise ValueError("dataset has %d samples, fewer than batch_size=%d" % (size, batch_size))

    x_data = tf.placeholder(tf.float32, [batch_size, img_size], name="x_data")
    z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name="z_prior")
    keep_prob = tf.placeholder(tf.float32, name="keep_prob")  # dropout keep probability
    # Counts completed epochs; saved in checkpoints so training can resume.
    global_step = tf.Variable(0, name="global_step", trainable=False)

    # x_generated: generator output; g_params: generator variables.
    x_generated, g_params = build_generator(z_prior)
    # The discriminator scores real and generated images together.
    y_data, y_generated, d_params = build_discriminator(x_data, x_generated, keep_prob)

    # Separate losses for D and G; G uses the non-saturating -log(D(G(z))) form.
    # These losses are per-sample tensors. minimize() differentiates their sum over the batch.
    # NOTE(review): tf.log(sigmoid(.)) can reach -inf if D saturates; switching to
    # tf.nn.sigmoid_cross_entropy_with_logits would be more stable but changes the training run.
    d_loss = - (tf.log(y_data) + tf.log(1 - y_generated))
    g_loss = - tf.log(y_generated)

    optimizer = tf.train.AdamOptimizer(learning_rate)

    # Each network only updates its own variables.
    d_trainer = optimizer.minimize(d_loss, var_list=d_params)
    g_trainer = optimizer.minimize(g_loss, var_list=g_params)

    # Build the epoch-counter update once. Calling tf.assign inside the training
    # loop would add a new op to the graph on every epoch.
    next_step = tf.placeholder(tf.int32, shape=[], name="next_step")
    set_global_step = tf.assign(global_step, next_step)

    init = tf.global_variables_initializer()
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(init)

        if to_restore:
            chkpt_fname = tf.train.latest_checkpoint(output_path)
            if chkpt_fname is None:
                raise FileNotFoundError("no checkpoint found in %r" % output_path)
            saver.restore(sess, chkpt_fname)
        else:
            if os.path.exists(output_path):
                shutil.rmtree(output_path)
            os.mkdir(output_path)

        # Fixed latent batch so sample<epoch>.jpg images are comparable across epochs.
        z_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
        train_keep_prob = np.float32(dropout_keep_prob)

        for i in range(sess.run(global_step), max_epoch):
            for j in range(n_batches):
                print("epoch:%s, iter:%s" % (i, j))

                # For MNIST:
                # x_value, _ = mnist.train.next_batch(batch_size)
                # x_value = 2 * x_value.astype(np.float32) - 1

                x_value = phd08[j * batch_size:(j + 1) * batch_size]
                # Scale to [-1, 1] to match the generator's tanh output.
                # Assumes raw pixel values in [0, 255] -- TODO confirm against the .npy file.
                x_value = 2 * (x_value / 255.) - 1

                z_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
                # The G step still needs x_data because y_generated is sliced from the joint D pass.
                feed = {x_data: x_value, z_prior: z_value, keep_prob: train_keep_prob}
                sess.run(d_trainer, feed_dict=feed)
                sess.run(g_trainer, feed_dict=feed)

            x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_sample_val})
            show_result(x_gen_val, os.path.join(output_path, "sample%s.jpg" % i))
            z_random_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
            x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_random_sample_val})
            show_result(x_gen_val, os.path.join(output_path, "random_sample%s.jpg" % i))
            sess.run(set_global_step, feed_dict={next_step: i + 1})
            saver.save(sess, os.path.join(output_path, "model"), global_step=global_step)


def test():
    """Restore the generator from the latest checkpoint and save one sample grid.

    Writes output_path/test_result.jpg.

    Raises:
        FileNotFoundError: if output_path holds no checkpoint (run train() first).
    """
    z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name="z_prior")
    x_generated, _ = build_generator(z_prior)
    chkpt_fname = tf.train.latest_checkpoint(output_path)
    if chkpt_fname is None:
        raise FileNotFoundError("no checkpoint found in %r; run train() first" % output_path)

    # This graph only contains the g_* variables, so only those are restored from the checkpoint.
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, chkpt_fname)
        z_test_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
        x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_test_value})
        show_result(x_gen_val, os.path.join(output_path, "test_result.jpg"))


if __name__ == '__main__':
    if to_train:
        train()
    else:
        test()

# --------------------------------------------------------------------------------
# Remaining repository files from the dump (binary, referenced by URL):
# /phd08/phd08_data_1.npy:
#   https://raw.githubusercontent.com/ckmarkoh/GAN-tensorflow/4ccb228912892d4e867f12e3e99af25133fbcc58/phd08/phd08_data_1.npy
# /phd08/phd08_labels_1.npy:
#   https://raw.githubusercontent.com/ckmarkoh/GAN-tensorflow/4ccb228912892d4e867f12e3e99af25133fbcc58/phd08/phd08_labels_1.npy