├── DCGAN-mydataset ├── download.py ├── main.py ├── model.py ├── ops.py ├── readme.md └── utils.py ├── README.md └── gan-mnist ├── README.md ├── mnist_eval.py ├── mnist_inference.py └── mnist_train.py /DCGAN-mydataset/download.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modification of https://github.com/stanfordnlp/treelstm/blob/master/scripts/download.py 3 | 4 | Downloads the following: 5 | - Celeb-A dataset 6 | - LSUN dataset 7 | - MNIST dataset 8 | """ 9 | 10 | from __future__ import print_function 11 | import os 12 | import sys 13 | import gzip 14 | import json 15 | import shutil 16 | import zipfile 17 | import argparse 18 | import requests 19 | import subprocess 20 | from tqdm import tqdm 21 | from six.moves import urllib 22 | 23 | parser = argparse.ArgumentParser(description='Download dataset for DCGAN.') 24 | parser.add_argument('datasets', metavar='N', type=str, nargs='+', choices=['celebA', 'lsun', 'mnist'], 25 | help='name of dataset to download [celebA, lsun, mnist]') 26 | 27 | def download(url, dirpath): 28 | filename = url.split('/')[-1] 29 | filepath = os.path.join(dirpath, filename) 30 | u = urllib.request.urlopen(url) 31 | f = open(filepath, 'wb') 32 | filesize = int(u.headers["Content-Length"]) 33 | print("Downloading: %s Bytes: %s" % (filename, filesize)) 34 | 35 | downloaded = 0 36 | block_sz = 8192 37 | status_width = 70 38 | while True: 39 | buf = u.read(block_sz) 40 | if not buf: 41 | print('') 42 | break 43 | else: 44 | print('', end='\r') 45 | downloaded += len(buf) 46 | f.write(buf) 47 | status = (("[%-" + str(status_width + 1) + "s] %3.2f%%") % 48 | ('=' * int(float(downloaded) / filesize * status_width) + '>', downloaded * 100. / filesize)) 49 | print(status, end='') 50 | sys.stdout.flush() 51 | f.close() 52 | return filepath 53 | 54 | def download_file_from_google_drive(id, destination): 55 | URL = "https://docs.google.com/uc?export=download" 56 | session = requests.Session() 57 | 58 | response = session.get(URL, params={ 'id': id }, stream=True) 59 | token = get_confirm_token(response) 60 | 61 | if token: 62 | params = { 'id' : id, 'confirm' : token } 63 | response = session.get(URL, params=params, stream=True) 64 | 65 | save_response_content(response, destination) 66 | 67 | def get_confirm_token(response): 68 | for key, value in response.cookies.items(): 69 | if key.startswith('download_warning'): 70 | return value 71 | return None 72 | 73 | def save_response_content(response, destination, chunk_size=32*1024): 74 | total_size = int(response.headers.get('content-length', 0)) 75 | with open(destination, "wb") as f: 76 | for chunk in tqdm(response.iter_content(chunk_size), total=total_size, 77 | unit='B', unit_scale=True, desc=destination): 78 | if chunk: # filter out keep-alive new chunks 79 | f.write(chunk) 80 | 81 | def unzip(filepath): 82 | print("Extracting: " + filepath) 83 | dirpath = os.path.dirname(filepath) 84 | with zipfile.ZipFile(filepath) as zf: 85 | zf.extractall(dirpath) 86 | os.remove(filepath) 87 | 88 | def download_celeb_a(dirpath): 89 | data_dir = 'celebA' 90 | if os.path.exists(os.path.join(dirpath, data_dir)): 91 | print('Found Celeb-A - skip') 92 | return 93 | 94 | filename, drive_id = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM" 95 | save_path = os.path.join(dirpath, filename) 96 | 97 | if os.path.exists(save_path): 98 | print('[*] {} already exists'.format(save_path)) 99 | else: 100 | download_file_from_google_drive(drive_id, save_path) 101 | 102 | zip_dir = '' 103 | with zipfile.ZipFile(save_path) as zf: 104 | zip_dir = zf.namelist()[0] 105 | zf.extractall(dirpath) 106 | os.remove(save_path) 107 | os.rename(os.path.join(dirpath, zip_dir), os.path.join(dirpath, data_dir)) 108 | 109 | def _list_categories(tag): 110 | url = 'http://lsun.cs.princeton.edu/htbin/list.cgi?tag=' + tag 111 | f = urllib.request.urlopen(url) 112 | return json.loads(f.read()) 113 | 114 | def _download_lsun(out_dir, category, set_name, tag): 115 | url = 'http://lsun.cs.princeton.edu/htbin/download.cgi?tag={tag}' \ 116 | '&category={category}&set={set_name}'.format(**locals()) 117 | print(url) 118 | if set_name == 'test': 119 | out_name = 'test_lmdb.zip' 120 | else: 121 | out_name = '{category}_{set_name}_lmdb.zip'.format(**locals()) 122 | out_path = os.path.join(out_dir, out_name) 123 | cmd = ['curl', url, '-o', out_path] 124 | print('Downloading', category, set_name, 'set') 125 | subprocess.call(cmd) 126 | 127 | def download_lsun(dirpath): 128 | data_dir = os.path.join(dirpath, 'lsun') 129 | if os.path.exists(data_dir): 130 | print('Found LSUN - skip') 131 | return 132 | else: 133 | os.mkdir(data_dir) 134 | 135 | tag = 'latest' 136 | #categories = _list_categories(tag) 137 | categories = ['bedroom'] 138 | 139 | for category in categories: 140 | _download_lsun(data_dir, category, 'train', tag) 141 | _download_lsun(data_dir, category, 'val', tag) 142 | _download_lsun(data_dir, '', 'test', tag) 143 | 144 | def download_mnist(dirpath): 145 | data_dir = os.path.join(dirpath, 'mnist') 146 | if os.path.exists(data_dir): 147 | print('Found MNIST - skip') 148 | return 149 | else: 150 | os.mkdir(data_dir) 151 | url_base = 'http://yann.lecun.com/exdb/mnist/' 152 | file_names = ['train-images-idx3-ubyte.gz', 153 | 'train-labels-idx1-ubyte.gz', 154 | 't10k-images-idx3-ubyte.gz', 155 | 't10k-labels-idx1-ubyte.gz'] 156 | for file_name in file_names: 157 | url = (url_base+file_name).format(**locals()) 158 | print(url) 159 | out_path = os.path.join(data_dir,file_name) 160 | cmd = ['curl', url, '-o', out_path] 161 | print('Downloading ', file_name) 162 | subprocess.call(cmd) 163 | cmd = ['gzip', '-d', out_path] 164 | print('Decompressing ', file_name) 165 | subprocess.call(cmd) 166 | 167 | def prepare_data_dir(path = './data'): 168 | if not os.path.exists(path): 169 | os.mkdir(path) 170 | 171 | if __name__ == '__main__': 172 | args = parser.parse_args() 173 | prepare_data_dir() 174 | 175 | if any(name in args.datasets for name in ['CelebA', 'celebA', 'celebA']): 176 | download_celeb_a('./data') 177 | if 'lsun' in args.datasets: 178 | download_lsun('./data') 179 | if 'mnist' in args.datasets: 180 | download_mnist('./data') 181 | -------------------------------------------------------------------------------- /DCGAN-mydataset/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import scipy.misc 3 | import numpy as np 4 | 5 | from model import DCGAN 6 | from utils import pp, visualize, to_json, show_all_variables 7 | 8 | import tensorflow as tf 9 | 10 | flags = tf.app.flags 11 | flags.DEFINE_integer("epoch", 10, "Epoch to train [25]") 12 | flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]") 13 | flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]") 14 | flags.DEFINE_float("train_size", np.inf, "The size of train images [np.inf]") 15 | flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]") 16 | 17 | flags.DEFINE_integer("input_height", 96, "The size of image to use (will be center cropped). [108]") 18 | flags.DEFINE_integer("input_width", None, "The size of image to use (will be center cropped). If None, same value as input_height [None]") 19 | flags.DEFINE_integer("output_height", 64, "The size of the output images to produce [64]") 20 | flags.DEFINE_integer("output_width", None, "The size of the output images to produce. If None, same value as output_height [None]") 21 | 22 | flags.DEFINE_string("dataset", "faces", "The name of dataset [celebA, mnist, lsun]") 23 | flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*]") 24 | flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]") 25 | flags.DEFINE_string("data_dir", "E:/code/dataset/", "Root directory of dataset [data]") 26 | flags.DEFINE_string("sample_dir", "samplesnew", "Directory name to save the image samples [samples]") 27 | 28 | flags.DEFINE_boolean("train", True, "True for training, False for testing [False]") 29 | flags.DEFINE_boolean("crop", True, "True for training, False for testing [False]") 30 | 31 | flags.DEFINE_boolean("visualize", True, "True for visualizing, False for nothing [False]") 32 | flags.DEFINE_integer("generate_test_images", 100, "Number of images to generate during test. [100]") 33 | FLAGS = flags.FLAGS 34 | 35 | def main(_): 36 | pp.pprint(flags.FLAGS.__flags) 37 | 38 | if FLAGS.input_width is None: 39 | FLAGS.input_width = FLAGS.input_height 40 | if FLAGS.output_width is None: 41 | FLAGS.output_width = FLAGS.output_height 42 | 43 | if not os.path.exists(FLAGS.checkpoint_dir): 44 | os.makedirs(FLAGS.checkpoint_dir) 45 | if not os.path.exists(FLAGS.sample_dir): 46 | os.makedirs(FLAGS.sample_dir) 47 | 48 | #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333) 49 | run_config = tf.ConfigProto() 50 | run_config.gpu_options.allow_growth=True 51 | 52 | with tf.Session(config=run_config) as sess: 53 | if FLAGS.dataset == 'mnist': 54 | dcgan = DCGAN( 55 | sess, 56 | input_width=FLAGS.input_width, 57 | input_height=FLAGS.input_height, 58 | output_width=FLAGS.output_width, 59 | output_height=FLAGS.output_height, 60 | batch_size=FLAGS.batch_size, 61 | sample_num=FLAGS.batch_size, 62 | y_dim=10, 63 | z_dim=FLAGS.generate_test_images, 64 | dataset_name=FLAGS.dataset, 65 | input_fname_pattern=FLAGS.input_fname_pattern, 66 | crop=FLAGS.crop, 67 | checkpoint_dir=FLAGS.checkpoint_dir, 68 | sample_dir=FLAGS.sample_dir, 69 | data_dir=FLAGS.data_dir) 70 | else: 71 | dcgan = DCGAN( 72 | sess, 73 | input_width=FLAGS.input_width, 74 | input_height=FLAGS.input_height, 75 | output_width=FLAGS.output_width, 76 | output_height=FLAGS.output_height, 77 | batch_size=FLAGS.batch_size, 78 | sample_num=FLAGS.batch_size, 79 | z_dim=FLAGS.generate_test_images, 80 | dataset_name=FLAGS.dataset, 81 | input_fname_pattern=FLAGS.input_fname_pattern, 82 | crop=FLAGS.crop, 83 | checkpoint_dir=FLAGS.checkpoint_dir, 84 | sample_dir=FLAGS.sample_dir, 85 | data_dir=FLAGS.data_dir) 86 | 87 | show_all_variables() 88 | 89 | if FLAGS.train: 90 | dcgan.train(FLAGS) 91 | else: 92 | if not dcgan.load(FLAGS.checkpoint_dir)[0]: 93 | raise Exception("[!] Train a model first, then run test mode") 94 | 95 | 96 | # to_json("./web/js/layers.js", [dcgan.h0_w, dcgan.h0_b, dcgan.g_bn0], 97 | # [dcgan.h1_w, dcgan.h1_b, dcgan.g_bn1], 98 | # [dcgan.h2_w, dcgan.h2_b, dcgan.g_bn2], 99 | # [dcgan.h3_w, dcgan.h3_b, dcgan.g_bn3], 100 | # [dcgan.h4_w, dcgan.h4_b, None]) 101 | 102 | # Below is codes for visualization 103 | OPTION = 1 104 | visualize(sess, dcgan, FLAGS, OPTION) 105 | 106 | if __name__ == '__main__': 107 | tf.app.run() 108 | -------------------------------------------------------------------------------- /DCGAN-mydataset/model.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | import time 4 | import math 5 | from glob import glob 6 | import tensorflow as tf 7 | import numpy as np 8 | from six.moves import xrange 9 | 10 | from ops import * 11 | from utils import * 12 | 13 | def conv_out_size_same(size, stride): 14 | return int(math.ceil(float(size) / float(stride))) 15 | 16 | class DCGAN(object): 17 | def __init__(self, sess, input_height=108, input_width=108, crop=True, 18 | batch_size=64, sample_num = 64, output_height=64, output_width=64, 19 | y_dim=None, z_dim=100, gf_dim=64, df_dim=64, 20 | gfc_dim=1024, dfc_dim=1024, c_dim=3, dataset_name='default', 21 | input_fname_pattern='*.jpg', checkpoint_dir=None, sample_dir=None, data_dir='./data'): 22 | """ 23 | 24 | Args: 25 | sess: TensorFlow session 26 | batch_size: The size of batch. Should be specified before training. 27 | y_dim: (optional) Dimension of dim for y. [None] 28 | z_dim: (optional) Dimension of dim for Z. [100] 29 | gf_dim: (optional) Dimension of gen filters in first conv layer. [64] 30 | df_dim: (optional) Dimension of discrim filters in first conv layer. [64] 31 | gfc_dim: (optional) Dimension of gen units for for fully connected layer. [1024] 32 | dfc_dim: (optional) Dimension of discrim units for fully connected layer. [1024] 33 | c_dim: (optional) Dimension of image color. For grayscale input, set to 1. [3] 34 | """ 35 | self.sess = sess 36 | self.crop = crop 37 | 38 | self.batch_size = batch_size 39 | self.sample_num = sample_num 40 | 41 | self.input_height = input_height 42 | self.input_width = input_width 43 | self.output_height = output_height 44 | self.output_width = output_width 45 | 46 | self.y_dim = y_dim 47 | self.z_dim = z_dim 48 | 49 | self.gf_dim = gf_dim 50 | self.df_dim = df_dim 51 | 52 | self.gfc_dim = gfc_dim 53 | self.dfc_dim = dfc_dim 54 | 55 | # batch normalization : deals with poor initialization helps gradient flow 56 | self.d_bn1 = batch_norm(name='d_bn1') 57 | self.d_bn2 = batch_norm(name='d_bn2') 58 | 59 | if not self.y_dim: 60 | self.d_bn3 = batch_norm(name='d_bn3') 61 | 62 | self.g_bn0 = batch_norm(name='g_bn0') 63 | self.g_bn1 = batch_norm(name='g_bn1') 64 | self.g_bn2 = batch_norm(name='g_bn2') 65 | 66 | if not self.y_dim: 67 | self.g_bn3 = batch_norm(name='g_bn3') 68 | 69 | self.dataset_name = dataset_name 70 | self.input_fname_pattern = input_fname_pattern 71 | self.checkpoint_dir = checkpoint_dir 72 | self.data_dir = data_dir 73 | 74 | if self.dataset_name == 'mnist': 75 | self.data_X, self.data_y = self.load_mnist() 76 | self.c_dim = self.data_X[0].shape[-1] 77 | else: 78 | data_path = os.path.join(self.data_dir, self.dataset_name, self.input_fname_pattern) 79 | self.data = glob(data_path) 80 | if len(self.data) == 0: 81 | raise Exception("[!] No data found in '" + data_path + "'") 82 | np.random.shuffle(self.data) 83 | imreadImg = imread(self.data[0]) 84 | if len(imreadImg.shape) >= 3: #check if image is a non-grayscale image by checking channel number 85 | self.c_dim = imread(self.data[0]).shape[-1] 86 | else: 87 | self.c_dim = 1 88 | 89 | if len(self.data) < self.batch_size: 90 | raise Exception("[!] Entire dataset size is less than the configured batch_size") 91 | 92 | self.grayscale = (self.c_dim == 1) 93 | 94 | self.build_model() 95 | 96 | def build_model(self): 97 | if self.y_dim: 98 | self.y = tf.placeholder(tf.float32, [self.batch_size, self.y_dim], name='y') 99 | else: 100 | self.y = None 101 | 102 | if self.crop: 103 | image_dims = [self.output_height, self.output_width, self.c_dim] 104 | else: 105 | image_dims = [self.input_height, self.input_width, self.c_dim] 106 | 107 | self.inputs = tf.placeholder( 108 | tf.float32, [self.batch_size] + image_dims, name='real_images') 109 | 110 | inputs = self.inputs 111 | 112 | self.z = tf.placeholder( 113 | tf.float32, [None, self.z_dim], name='z') 114 | self.z_sum = histogram_summary("z", self.z) 115 | 116 | self.G = self.generator(self.z, self.y) 117 | self.D, self.D_logits = self.discriminator(inputs, self.y, reuse=False) 118 | self.sampler = self.sampler(self.z, self.y) 119 | self.D_, self.D_logits_ = self.discriminator(self.G, self.y, reuse=True) 120 | 121 | #生成分布图 122 | self.d_sum = histogram_summary("d", self.D) 123 | self.d__sum = histogram_summary("d_", self.D_) 124 | self.G_sum = image_summary("G", self.G) 125 | 126 | #image_summary输出一个包含图像的summary, 这个图像是通过一个4维张量构建的[batch_size,height,width,channels] 127 | 128 | # [batch_size,height, width, channels] 129 | def sigmoid_cross_entropy_with_logits(x, y): 130 | try: 131 | return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, labels=y) 132 | except: 133 | return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, targets=y) 134 | 135 | self.d_loss_real = tf.reduce_mean( 136 | sigmoid_cross_entropy_with_logits(self.D_logits, tf.ones_like(self.D))) 137 | self.d_loss_fake = tf.reduce_mean( 138 | sigmoid_cross_entropy_with_logits(self.D_logits_, tf.zeros_like(self.D_))) 139 | self.g_loss = tf.reduce_mean( 140 | sigmoid_cross_entropy_with_logits(self.D_logits_, tf.ones_like(self.D_))) 141 | 142 | self.d_loss_real_sum = scalar_summary("d_loss_real", self.d_loss_real) 143 | self.d_loss_fake_sum = scalar_summary("d_loss_fake", self.d_loss_fake) 144 | 145 | self.d_loss = self.d_loss_real + self.d_loss_fake 146 | 147 | self.g_loss_sum = scalar_summary("g_loss", self.g_loss) 148 | self.d_loss_sum = scalar_summary("d_loss", self.d_loss) 149 | 150 | t_vars = tf.trainable_variables() 151 | 152 | self.d_vars = [var for var in t_vars if 'd_' in var.name] 153 | self.g_vars = [var for var in t_vars if 'g_' in var.name] 154 | 155 | self.saver = tf.train.Saver() 156 | 157 | def train(self, config): 158 | d_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \ 159 | .minimize(self.d_loss, var_list=self.d_vars) 160 | g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \ 161 | .minimize(self.g_loss, var_list=self.g_vars) 162 | try: 163 | tf.global_variables_initializer().run() 164 | except: 165 | tf.initialize_all_variables().run() 166 | 167 | self.g_sum = merge_summary([self.z_sum, self.d__sum, 168 | self.G_sum, self.d_loss_fake_sum, self.g_loss_sum]) 169 | self.d_sum = merge_summary( 170 | [self.z_sum, self.d_sum, self.d_loss_real_sum, self.d_loss_sum]) 171 | self.writer = SummaryWriter("./logs", self.sess.graph) 172 | 173 | sample_z = np.random.uniform(-1, 1, size=(self.sample_num , self.z_dim)) 174 | 175 | if config.dataset == 'mnist': 176 | sample_inputs = self.data_X[0:self.sample_num] 177 | sample_labels = self.data_y[0:self.sample_num] 178 | else: 179 | sample_files = self.data[0:self.sample_num] 180 | sample = [ 181 | get_image(sample_file, 182 | input_height=self.input_height, 183 | input_width=self.input_width, 184 | resize_height=self.output_height, 185 | resize_width=self.output_width, 186 | crop=self.crop, 187 | grayscale=self.grayscale) for sample_file in sample_files] 188 | if (self.grayscale): 189 | sample_inputs = np.array(sample).astype(np.float32)[:, :, :, None] 190 | else: 191 | sample_inputs = np.array(sample).astype(np.float32) 192 | 193 | counter = 1 194 | start_time = time.time() 195 | could_load, checkpoint_counter = self.load(self.checkpoint_dir) 196 | if could_load: 197 | counter = checkpoint_counter 198 | print(" [*] Load SUCCESS") 199 | else: 200 | print(" [!] Load failed...") 201 | 202 | for epoch in xrange(config.epoch): 203 | if config.dataset == 'mnist': 204 | batch_idxs = min(len(self.data_X), config.train_size) // config.batch_size 205 | else: 206 | self.data = glob(os.path.join( 207 | config.data_dir, config.dataset, self.input_fname_pattern)) 208 | np.random.shuffle(self.data) 209 | batch_idxs = min(len(self.data), config.train_size) // config.batch_size 210 | 211 | for idx in xrange(0, int(batch_idxs)): 212 | if config.dataset == 'mnist': 213 | batch_images = self.data_X[idx*config.batch_size:(idx+1)*config.batch_size] 214 | batch_labels = self.data_y[idx*config.batch_size:(idx+1)*config.batch_size] 215 | else: 216 | batch_files = self.data[idx*config.batch_size:(idx+1)*config.batch_size] 217 | batch = [ 218 | get_image(batch_file, 219 | input_height=self.input_height, 220 | input_width=self.input_width, 221 | resize_height=self.output_height, 222 | resize_width=self.output_width, 223 | crop=self.crop, 224 | grayscale=self.grayscale) for batch_file in batch_files] 225 | if self.grayscale: 226 | batch_images = np.array(batch).astype(np.float32)[:, :, :, None] 227 | else: 228 | batch_images = np.array(batch).astype(np.float32) 229 | 230 | batch_z = np.random.uniform(-1, 1, [config.batch_size, self.z_dim]) \ 231 | .astype(np.float32) 232 | 233 | if config.dataset == 'mnist': 234 | # Update D network 235 | _, summary_str = self.sess.run([d_optim, self.d_sum], 236 | feed_dict={ 237 | self.inputs: batch_images, 238 | self.z: batch_z, 239 | self.y:batch_labels, 240 | }) 241 | self.writer.add_summary(summary_str, counter) 242 | 243 | # Update G network 244 | _, summary_str = self.sess.run([g_optim, self.g_sum], 245 | feed_dict={ 246 | self.z: batch_z, 247 | self.y:batch_labels, 248 | }) 249 | self.writer.add_summary(summary_str, counter) 250 | 251 | # Run g_optim twice to make sure that d_loss does not go to zero (different from paper) 252 | _, summary_str = self.sess.run([g_optim, self.g_sum], 253 | feed_dict={ self.z: batch_z, self.y:batch_labels }) 254 | self.writer.add_summary(summary_str, counter) 255 | 256 | errD_fake = self.d_loss_fake.eval({ 257 | self.z: batch_z, 258 | self.y:batch_labels 259 | }) 260 | errD_real = self.d_loss_real.eval({ 261 | self.inputs: batch_images, 262 | self.y:batch_labels 263 | }) 264 | errG = self.g_loss.eval({ 265 | self.z: batch_z, 266 | self.y: batch_labels 267 | }) 268 | else: 269 | # Update D network 270 | _, summary_str = self.sess.run([d_optim, self.d_sum], 271 | feed_dict={ self.inputs: batch_images, self.z: batch_z }) 272 | self.writer.add_summary(summary_str, counter) 273 | 274 | # Update G network four times 275 | _, summary_str = self.sess.run([g_optim, self.g_sum], 276 | feed_dict={ self.z: batch_z }) 277 | self.writer.add_summary(summary_str, counter) 278 | 279 | 280 | 281 | # Run g_optim twice to make sure that d_loss does not go to zero (different from paper) 282 | _, summary_str = self.sess.run([g_optim, self.g_sum], 283 | feed_dict={ self.z: batch_z }) 284 | self.writer.add_summary(summary_str, counter) 285 | 286 | errD_fake = self.d_loss_fake.eval({ self.z: batch_z }) 287 | errD_real = self.d_loss_real.eval({ self.inputs: batch_images }) 288 | errG = self.g_loss.eval({self.z: batch_z}) 289 | 290 | counter += 1 291 | print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \ 292 | % (epoch, config.epoch, idx, batch_idxs, 293 | time.time() - start_time, errD_fake+errD_real, errG)) 294 | 295 | if np.mod(counter, 100) == 1: 296 | if config.dataset == 'mnist': 297 | samples, d_loss, g_loss = self.sess.run( 298 | [self.sampler, self.d_loss, self.g_loss], 299 | feed_dict={ 300 | self.z: sample_z, 301 | self.inputs: sample_inputs, 302 | self.y:sample_labels, 303 | } 304 | ) 305 | save_images(samples, image_manifold_size(samples.shape[0]), 306 | './{}/train_{:02d}_{:04d}.png'.format(config.sample_dir, epoch, idx)) 307 | print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss)) 308 | else: 309 | try: 310 | samples, d_loss, g_loss = self.sess.run( 311 | [self.sampler, self.d_loss, self.g_loss], 312 | feed_dict={ 313 | self.z: sample_z, 314 | self.inputs: sample_inputs, 315 | }, 316 | ) 317 | save_images(samples, image_manifold_size(samples.shape[0]), 318 | './{}/train_{:02d}_{:04d}.png'.format(config.sample_dir, epoch, idx)) 319 | print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss)) 320 | except: 321 | print("one pic error!...") 322 | 323 | if np.mod(counter, 500) == 2: 324 | self.save(config.checkpoint_dir, counter) 325 | 326 | def discriminator(self, image, y=None, reuse=False): 327 | with tf.variable_scope("discriminator") as scope: 328 | if reuse: 329 | scope.reuse_variables() 330 | 331 | if not self.y_dim: 332 | h0 = lrelu(conv2d(image, self.df_dim, name='d_h0_conv')) 333 | h1 = lrelu(self.d_bn1(conv2d(h0, self.df_dim*2, name='d_h1_conv'))) 334 | h2 = lrelu(self.d_bn2(conv2d(h1, self.df_dim*4, name='d_h2_conv'))) 335 | h3 = lrelu(self.d_bn3(conv2d(h2, self.df_dim*8, name='d_h3_conv'))) 336 | h4 = linear(tf.reshape(h3, [self.batch_size, -1]), 1, 'd_h4_lin') 337 | 338 | return tf.nn.sigmoid(h4), h4 339 | else: 340 | yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim]) 341 | x = conv_cond_concat(image, yb) 342 | 343 | h0 = lrelu(conv2d(x, self.c_dim + self.y_dim, name='d_h0_conv')) 344 | h0 = conv_cond_concat(h0, yb) 345 | 346 | h1 = lrelu(self.d_bn1(conv2d(h0, self.df_dim + self.y_dim, name='d_h1_conv'))) 347 | h1 = tf.reshape(h1, [self.batch_size, -1]) 348 | h1 = concat([h1, y], 1) 349 | 350 | h2 = lrelu(self.d_bn2(linear(h1, self.dfc_dim, 'd_h2_lin'))) 351 | h2 = concat([h2, y], 1) 352 | 353 | h3 = linear(h2, 1, 'd_h3_lin') 354 | 355 | return tf.nn.sigmoid(h3), h3 356 | 357 | def generator(self, z, y=None): 358 | with tf.variable_scope("generator") as scope: 359 | if not self.y_dim: 360 | s_h, s_w = self.output_height, self.output_width 361 | s_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2) 362 | s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2) 363 | s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2) 364 | s_h16, s_w16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2) 365 | 366 | # project `z` and reshape 367 | self.z_, self.h0_w, self.h0_b = linear( 368 | z, self.gf_dim*8*s_h16*s_w16, 'g_h0_lin', with_w=True) 369 | 370 | self.h0 = tf.reshape( 371 | self.z_, [-1, s_h16, s_w16, self.gf_dim * 8]) 372 | 373 | #对数据进行归一化处理,加快收敛速度g_bn0 374 | h0 = tf.nn.relu(self.g_bn0(self.h0)) 375 | 376 | self.h1, self.h1_w, self.h1_b = deconv2d( 377 | h0, [self.batch_size, s_h8, s_w8, self.gf_dim*4], name='g_h1', with_w=True) 378 | h1 = tf.nn.relu(self.g_bn1(self.h1)) 379 | 380 | h2, self.h2_w, self.h2_b = deconv2d( 381 | h1, [self.batch_size, s_h4, s_w4, self.gf_dim*2], name='g_h2', with_w=True) 382 | h2 = tf.nn.relu(self.g_bn2(h2)) 383 | 384 | h3, self.h3_w, self.h3_b = deconv2d( 385 | h2, [self.batch_size, s_h2, s_w2, self.gf_dim*1], name='g_h3', with_w=True) 386 | h3 = tf.nn.relu(self.g_bn3(h3)) 387 | 388 | h4, self.h4_w, self.h4_b = deconv2d( 389 | h3, [self.batch_size, s_h, s_w, self.c_dim], name='g_h4', with_w=True) 390 | 391 | return tf.nn.tanh(h4) 392 | else: 393 | s_h, s_w = self.output_height, self.output_width 394 | s_h2, s_h4 = int(s_h/2), int(s_h/4) 395 | s_w2, s_w4 = int(s_w/2), int(s_w/4) 396 | 397 | # yb = tf.expand_dims(tf.expand_dims(y, 1),2) 398 | yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim]) 399 | z = concat([z, y], 1) 400 | 401 | h0 = tf.nn.relu( 402 | self.g_bn0(linear(z, self.gfc_dim, 'g_h0_lin'))) 403 | h0 = concat([h0, y], 1) 404 | 405 | h1 = tf.nn.relu(self.g_bn1( 406 | linear(h0, self.gf_dim*2*s_h4*s_w4, 'g_h1_lin'))) 407 | h1 = tf.reshape(h1, [self.batch_size, s_h4, s_w4, self.gf_dim * 2]) 408 | 409 | h1 = conv_cond_concat(h1, yb) 410 | 411 | h2 = tf.nn.relu(self.g_bn2(deconv2d(h1, 412 | [self.batch_size, s_h2, s_w2, self.gf_dim * 2], name='g_h2'))) 413 | h2 = conv_cond_concat(h2, yb) 414 | 415 | return tf.nn.sigmoid( 416 | deconv2d(h2, [self.batch_size, s_h, s_w, self.c_dim], name='g_h3')) 417 | 418 | def sampler(self, z, y=None): 419 | with tf.variable_scope("generator") as scope: 420 | scope.reuse_variables() 421 | 422 | if not self.y_dim: 423 | s_h, s_w = self.output_height, self.output_width 424 | s_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2) 425 | s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2) 426 | s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2) 427 | s_h16, s_w16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2) 428 | 429 | # project `z` and reshape 430 | h0 = tf.reshape( 431 | linear(z, self.gf_dim*8*s_h16*s_w16, 'g_h0_lin'), 432 | [-1, s_h16, s_w16, self.gf_dim * 8]) 433 | h0 = tf.nn.relu(self.g_bn0(h0, train=False)) 434 | 435 | h1 = deconv2d(h0, [self.batch_size, s_h8, s_w8, self.gf_dim*4], name='g_h1') 436 | h1 = tf.nn.relu(self.g_bn1(h1, train=False)) 437 | 438 | h2 = deconv2d(h1, [self.batch_size, s_h4, s_w4, self.gf_dim*2], name='g_h2') 439 | h2 = tf.nn.relu(self.g_bn2(h2, train=False)) 440 | 441 | h3 = deconv2d(h2, [self.batch_size, s_h2, s_w2, self.gf_dim*1], name='g_h3') 442 | h3 = tf.nn.relu(self.g_bn3(h3, train=False)) 443 | 444 | h4 = deconv2d(h3, [self.batch_size, s_h, s_w, self.c_dim], name='g_h4') 445 | 446 | return tf.nn.tanh(h4) 447 | else: 448 | s_h, s_w = self.output_height, self.output_width 449 | s_h2, s_h4 = int(s_h/2), int(s_h/4) 450 | s_w2, s_w4 = int(s_w/2), int(s_w/4) 451 | 452 | # yb = tf.reshape(y, [-1, 1, 1, self.y_dim]) 453 | yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim]) 454 | z = concat([z, y], 1) 455 | 456 | h0 = tf.nn.relu(self.g_bn0(linear(z, self.gfc_dim, 'g_h0_lin'), train=False)) 457 | h0 = concat([h0, y], 1) 458 | 459 | h1 = tf.nn.relu(self.g_bn1( 460 | linear(h0, self.gf_dim*2*s_h4*s_w4, 'g_h1_lin'), train=False)) 461 | h1 = tf.reshape(h1, [self.batch_size, s_h4, s_w4, self.gf_dim * 2]) 462 | h1 = conv_cond_concat(h1, yb) 463 | 464 | h2 = tf.nn.relu(self.g_bn2( 465 | deconv2d(h1, [self.batch_size, s_h2, s_w2, self.gf_dim * 2], name='g_h2'), train=False)) 466 | h2 = conv_cond_concat(h2, yb) 467 | 468 | return tf.nn.sigmoid(deconv2d(h2, [self.batch_size, s_h, s_w, self.c_dim], name='g_h3')) 469 | 470 | def load_mnist(self): 471 | data_dir = os.path.join(self.data_dir, self.dataset_name) 472 | 473 | fd = open(os.path.join(data_dir,'train-images-idx3-ubyte')) 474 | loaded = np.fromfile(file=fd,dtype=np.uint8) 475 | trX = loaded[16:].reshape((60000,28,28,1)).astype(np.float) 476 | 477 | fd = open(os.path.join(data_dir,'train-labels-idx1-ubyte')) 478 | loaded = np.fromfile(file=fd,dtype=np.uint8) 479 | trY = loaded[8:].reshape((60000)).astype(np.float) 480 | 481 | fd = open(os.path.join(data_dir,'t10k-images-idx3-ubyte')) 482 | loaded = np.fromfile(file=fd,dtype=np.uint8) 483 | teX = loaded[16:].reshape((10000,28,28,1)).astype(np.float) 484 | 485 | fd = open(os.path.join(data_dir,'t10k-labels-idx1-ubyte')) 486 | loaded = np.fromfile(file=fd,dtype=np.uint8) 487 | teY = loaded[8:].reshape((10000)).astype(np.float) 488 | 489 | trY = np.asarray(trY) 490 | teY = np.asarray(teY) 491 | 492 | X = np.concatenate((trX, teX), axis=0) 493 | y = np.concatenate((trY, teY), axis=0).astype(np.int) 494 | 495 | seed = 547 496 | np.random.seed(seed) 497 | np.random.shuffle(X) 498 | np.random.seed(seed) 499 | np.random.shuffle(y) 500 | 501 | y_vec = np.zeros((len(y), self.y_dim), dtype=np.float) 502 | for i, label in enumerate(y): 503 | y_vec[i,y[i]] = 1.0 504 | 505 | return X/255.,y_vec 506 | 507 | @property 508 | def model_dir(self): 509 | return "{}_{}_{}_{}".format( 510 | self.dataset_name, self.batch_size, 511 | self.output_height, self.output_width) 512 | 513 | def save(self, checkpoint_dir, step): 514 | model_name = "DCGAN.model" 515 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 516 | 517 | if not os.path.exists(checkpoint_dir): 518 | os.makedirs(checkpoint_dir) 519 | 520 | self.saver.save(self.sess, 521 | os.path.join(checkpoint_dir, model_name), 522 | global_step=step) 523 | 524 | def load(self, checkpoint_dir): 525 | import re 526 | print(" [*] Reading checkpoints...") 527 | checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir) 528 | 529 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 530 | if ckpt and ckpt.model_checkpoint_path: 531 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 532 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) 533 | counter = int(next(re.finditer("(\d+)(?!.*\d)",ckpt_name)).group(0)) 534 | print(" [*] Success to read {}".format(ckpt_name)) 535 | return True, counter 536 | else: 537 | print(" [*] Failed to find a checkpoint") 538 | return False, 0 539 | -------------------------------------------------------------------------------- /DCGAN-mydataset/ops.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | from tensorflow.python.framework import ops 6 | 7 | from utils import * 8 | 9 | try: 10 | image_summary = tf.image_summary 11 | scalar_summary = tf.scalar_summary 12 | histogram_summary = tf.histogram_summary 13 | merge_summary = tf.merge_summary 14 | SummaryWriter = tf.train.SummaryWriter 15 | except: 16 | image_summary = tf.summary.image 17 | scalar_summary = tf.summary.scalar 18 | histogram_summary = tf.summary.histogram 19 | merge_summary = tf.summary.merge 20 | SummaryWriter = tf.summary.FileWriter 21 | 22 | if "concat_v2" in dir(tf): 23 | def concat(tensors, axis, *args, **kwargs): 24 | return tf.concat_v2(tensors, axis, *args, **kwargs) 25 | else: 26 | def concat(tensors, axis, *args, **kwargs): 27 | return tf.concat(tensors, axis, *args, **kwargs) 28 | 29 | class batch_norm(object): 30 | def __init__(self, epsilon=1e-5, momentum = 0.9, name="batch_norm"): 31 | with tf.variable_scope(name): 32 | self.epsilon = epsilon 33 | self.momentum = momentum 34 | self.name = name 35 | 36 | def __call__(self, x, train=True): 37 | return tf.contrib.layers.batch_norm(x, 38 | decay=self.momentum, 39 | updates_collections=None, 40 | epsilon=self.epsilon, 41 | scale=True, 42 | is_training=train, 43 | scope=self.name) 44 | 45 | def conv_cond_concat(x, y): 46 | """Concatenate conditioning vector on feature map axis.""" 47 | x_shapes = x.get_shape() 48 | y_shapes = y.get_shape() 49 | return concat([ 50 | x, y*tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])], 3) 51 | 52 | def conv2d(input_, output_dim, 53 | k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, 54 | name="conv2d"): 55 | with tf.variable_scope(name): 56 | w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim], 57 | initializer=tf.truncated_normal_initializer(stddev=stddev)) 58 | conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME') 59 | 60 | biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0)) 61 | conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) 62 | 63 | return conv 64 | 65 | def deconv2d(input_, output_shape, 66 | k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, 67 | name="deconv2d", with_w=False): 68 | with tf.variable_scope(name): 69 | # filter : [height, width, output_channels, in_channels] 70 | w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]], 71 | initializer=tf.random_normal_initializer(stddev=stddev)) 72 | 73 | try: 74 | deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape, 75 | strides=[1, d_h, d_w, 1]) 76 | 77 | # Support for verisons of TensorFlow before 0.7.0 78 | except AttributeError: 79 | deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape, 80 | strides=[1, d_h, d_w, 1]) 81 | 82 | biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0)) 83 | deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape()) 84 | 85 | if with_w: 86 | return deconv, w, biases 87 | else: 88 | return deconv 89 | 90 | def lrelu(x, leak=0.2, name="lrelu"): 91 | return tf.maximum(x, leak*x) 92 | 93 | def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False): 94 | shape = input_.get_shape().as_list() 95 | 96 | with tf.variable_scope(scope or "Linear"): 97 | try: 98 | matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32, 99 | tf.random_normal_initializer(stddev=stddev)) 100 | except ValueError as err: 101 | msg = "NOTE: Usually, this is due to an issue with the image dimensions. Did you correctly set '--crop' or '--input_height' or '--output_height'?" 102 | err.args = err.args + (msg,) 103 | raise 104 | bias = tf.get_variable("bias", [output_size], 105 | initializer=tf.constant_initializer(bias_start)) 106 | if with_w: 107 | return tf.matmul(input_, matrix) + bias, matrix, bias 108 | else: 109 | return tf.matmul(input_, matrix) + bias 110 | -------------------------------------------------------------------------------- /DCGAN-mydataset/readme.md: -------------------------------------------------------------------------------- 1 | |框架|语言|GAN类别| 2 | |:--:|:--:|:--:| 3 | |tensorflow | python | DCGAN | 4 | |tensorflow | python | DCGAN | 5 | |tensorflow | python | DCGAN | 6 | -------------------------------------------------------------------------------- /DCGAN-mydataset/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Some codes from https://github.com/Newmu/dcgan_code 3 | """ 4 | from __future__ import division 5 | import math 6 | import json 7 | import random 8 | import pprint 9 | import scipy.misc 10 | import cv2 11 | import numpy as np 12 | from time import gmtime, strftime 13 | from six.moves import xrange 14 | 15 | import tensorflow as tf 16 | import tensorflow.contrib.slim as slim 17 | 18 | pp = pprint.PrettyPrinter() 19 | 20 | get_stddev = lambda x, k_h, k_w: 1/math.sqrt(k_w*k_h*x.get_shape()[-1]) 21 | 22 | def show_all_variables(): 23 | model_vars = tf.trainable_variables() 24 | slim.model_analyzer.analyze_vars(model_vars, print_info=True) 25 | 26 | #输入的图像裁剪为64*64的图像,并将RGB值归一化到[-1,1] 27 | def get_image(image_path, input_height, input_width, 28 | resize_height=64, resize_width=64, 29 | crop=True, grayscale=False): 30 | image = imread(image_path, grayscale) 31 | return transform(image, input_height, input_width, 32 | resize_height, resize_width, crop) 33 | 34 | def save_images(images, size, image_path): 35 | return imsave(inverse_transform(images), size, image_path) 36 | 37 | def imread(path, grayscale = False): 38 | if (grayscale): 39 | return scipy.misc.imread(path, flatten = True).astype(np.float) 40 | else: 41 | # Reference: https://github.com/carpedm20/DCGAN-tensorflow/issues/162#issuecomment-315519747 42 | img_bgr = cv2.imread(path) 43 | # Reference: https://stackoverflow.com/a/15074748/ 44 | img_rgb = img_bgr[..., ::-1] 45 | return img_rgb.astype(np.float) 46 | 47 | def merge_images(images, size): 48 | return inverse_transform(images) 49 | 50 | def merge(images, size): 51 | h, w = images.shape[1], images.shape[2] 52 | if (images.shape[3] in (3,4)): 53 | c = images.shape[3] 54 | img = np.zeros((h * size[0], w * size[1], c)) 55 | for idx, image in enumerate(images): 56 | i = idx % size[1] 57 | j = idx // size[1] 58 | img[j * h:j * h + h, i * w:i * w + w, :] = image 59 | return img 60 | elif images.shape[3]==1: 61 | img = np.zeros((h * size[0], w * size[1])) 62 | for idx, image in enumerate(images): 63 | i = idx % size[1] 64 | j = idx // size[1] 65 | img[j * h:j * h + h, i * w:i * w + w] = image[:,:,0] 66 | return img 67 | else: 68 | raise ValueError('in merge(images,size) images parameter ' 69 | 'must have dimensions: HxW or HxWx3 or HxWx4') 70 | 71 | def imsave(images, size, path): 72 | image = np.squeeze(merge(images, size)) 73 | return scipy.misc.imsave(path, image) 74 | 75 | def center_crop(x, crop_h, crop_w, 76 | resize_h=64, resize_w=64): 77 | if crop_w is None: 78 | crop_w = crop_h 79 | h, w = x.shape[:2] 80 | j = int(round((h - crop_h)/2.)) 81 | i = int(round((w - crop_w)/2.)) 82 | return scipy.misc.imresize( 83 | x[j:j+crop_h, i:i+crop_w], [resize_h, resize_w]) 84 | 85 | def transform(image, input_height, input_width, 86 | resize_height=64, resize_width=64, crop=True): 87 | if crop: 88 | cropped_image = center_crop( 89 | image, input_height, input_width, 90 | resize_height, resize_width) 91 | else: 92 | #截取为64*64的图像 93 | cropped_image = scipy.misc.imresize(image, [resize_height, resize_width]) 94 | #[0,255]正则化到[-1,1] 95 | return np.array(cropped_image)/127.5 - 1. 96 | 97 | def inverse_transform(images): 98 | return (images+1.)/2. 99 | 100 | def to_json(output_path, *layers): 101 | with open(output_path, "w") as layer_f: 102 | lines = "" 103 | for w, b, bn in layers: 104 | layer_idx = w.name.split('/')[0].split('h')[1] 105 | 106 | B = b.eval() 107 | 108 | if "lin/" in w.name: 109 | W = w.eval() 110 | depth = W.shape[1] 111 | else: 112 | W = np.rollaxis(w.eval(), 2, 0) 113 | depth = W.shape[0] 114 | 115 | biases = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(B)]} 116 | if bn != None: 117 | gamma = bn.gamma.eval() 118 | beta = bn.beta.eval() 119 | 120 | gamma = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(gamma)]} 121 | beta = {"sy": 1, "sx": 1, "depth": depth, "w": ['%.2f' % elem for elem in list(beta)]} 122 | else: 123 | gamma = {"sy": 1, "sx": 1, "depth": 0, "w": []} 124 | beta = {"sy": 1, "sx": 1, "depth": 0, "w": []} 125 | 126 | if "lin/" in w.name: 127 | fs = [] 128 | for w in W.T: 129 | fs.append({"sy": 1, "sx": 1, "depth": W.shape[0], "w": ['%.2f' % elem for elem in list(w)]}) 130 | 131 | lines += """ 132 | var layer_%s = { 133 | "layer_type": "fc", 134 | "sy": 1, "sx": 1, 135 | "out_sx": 1, "out_sy": 1, 136 | "stride": 1, "pad": 0, 137 | "out_depth": %s, "in_depth": %s, 138 | "biases": %s, 139 | "gamma": %s, 140 | "beta": %s, 141 | "filters": %s 142 | };""" % (layer_idx.split('_')[0], W.shape[1], W.shape[0], biases, gamma, beta, fs) 143 | else: 144 | fs = [] 145 | for w_ in W: 146 | fs.append({"sy": 5, "sx": 5, "depth": W.shape[3], "w": ['%.2f' % elem for elem in list(w_.flatten())]}) 147 | 148 | lines += """ 149 | var layer_%s = { 150 | "layer_type": "deconv", 151 | "sy": 5, "sx": 5, 152 | "out_sx": %s, "out_sy": %s, 153 | "stride": 2, "pad": 1, 154 | "out_depth": %s, "in_depth": %s, 155 | "biases": %s, 156 | "gamma": %s, 157 | "beta": %s, 158 | "filters": %s 159 | };""" % (layer_idx, 2**(int(layer_idx)+2), 2**(int(layer_idx)+2), 160 | W.shape[0], W.shape[3], biases, gamma, beta, fs) 161 | layer_f.write(" ".join(lines.replace("'","").split())) 162 | 163 | def make_gif(images, fname, duration=2, true_image=False): 164 | import moviepy.editor as mpy 165 | 166 | def make_frame(t): 167 | try: 168 | x = images[int(len(images)/duration*t)] 169 | except: 170 | x = images[-1] 171 | 172 | if true_image: 173 | return x.astype(np.uint8) 174 | else: 175 | return ((x+1)/2*255).astype(np.uint8) 176 | 177 | clip = mpy.VideoClip(make_frame, duration=duration) 178 | clip.write_gif(fname, fps = len(images) / duration) 179 | 180 | def visualize(sess, dcgan, config, option): 181 | image_frame_dim = int(math.ceil(config.batch_size**.5)) 182 | if option == 0: 183 | z_sample = np.random.uniform(-0.5, 0.5, size=(config.batch_size, dcgan.z_dim)) 184 | samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}) 185 | save_images(samples, [image_frame_dim, image_frame_dim], './samples/test_%s.png' % strftime("%Y-%m-%d-%H-%M-%S", gmtime())) 186 | elif option == 1: 187 | values = np.arange(0, 1, 1./config.batch_size) 188 | for idx in xrange(dcgan.z_dim): 189 | print(" [*] %d" % idx) 190 | z_sample = np.random.uniform(-1, 1, size=(config.batch_size , dcgan.z_dim)) 191 | for kdx, z in enumerate(z_sample): 192 | z[idx] = values[kdx] 193 | 194 | if config.dataset == "mnist": 195 | y = np.random.choice(10, config.batch_size) 196 | y_one_hot = np.zeros((config.batch_size, 10)) 197 | y_one_hot[np.arange(config.batch_size), y] = 1 198 | 199 | samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample, dcgan.y: y_one_hot}) 200 | else: 201 | samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}) 202 | 203 | save_images(samples, [image_frame_dim, image_frame_dim], './samples/test_arange_%s.png' % (idx)) 204 | elif option == 2: 205 | values = np.arange(0, 1, 1./config.batch_size) 206 | for idx in [random.randint(0, dcgan.z_dim - 1) for _ in xrange(dcgan.z_dim)]: 207 | print(" [*] %d" % idx) 208 | z = np.random.uniform(-0.2, 0.2, size=(dcgan.z_dim)) 209 | z_sample = np.tile(z, (config.batch_size, 1)) 210 | #z_sample = np.zeros([config.batch_size, dcgan.z_dim]) 211 | for kdx, z in enumerate(z_sample): 212 | z[idx] = values[kdx] 213 | 214 | if config.dataset == "mnist": 215 | y = np.random.choice(10, config.batch_size) 216 | y_one_hot = np.zeros((config.batch_size, 10)) 217 | y_one_hot[np.arange(config.batch_size), y] = 1 218 | 219 | samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample, dcgan.y: y_one_hot}) 220 | else: 221 | samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}) 222 | 223 | try: 224 | make_gif(samples, './samples/test_gif_%s.gif' % (idx)) 225 | except: 226 | save_images(samples, [image_frame_dim, image_frame_dim], './samples/test_%s.png' % strftime("%Y-%m-%d-%H-%M-%S", gmtime())) 227 | elif option == 3: 228 | values = np.arange(0, 1, 1./config.batch_size) 229 | for idx in xrange(dcgan.z_dim): 230 | print(" [*] %d" % idx) 231 | z_sample = np.zeros([config.batch_size, dcgan.z_dim]) 232 | for kdx, z in enumerate(z_sample): 233 | z[idx] = values[kdx] 234 | 235 | samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}) 236 | make_gif(samples, './samples/test_gif_%s.gif' % (idx)) 237 | elif option == 4: 238 | image_set = [] 239 | values = np.arange(0, 1, 1./config.batch_size) 240 | 241 | for idx in xrange(dcgan.z_dim): 242 | print(" [*] %d" % idx) 243 | z_sample = np.zeros([config.batch_size, dcgan.z_dim]) 244 | for kdx, z in enumerate(z_sample): z[idx] = values[kdx] 245 | 246 | image_set.append(sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})) 247 | make_gif(image_set[-1], './samples/test_gif_%s.gif' % (idx)) 248 | 249 | new_image_set = [merge(np.array([images[idx] for images in image_set]), [10, 10]) \ 250 | for idx in range(64) + range(63, -1, -1)] 251 | make_gif(new_image_set, './samples/test_gif_merged.gif', duration=8) 252 | 253 | 254 | def image_manifold_size(num_images): 255 | manifold_h = int(np.floor(np.sqrt(num_images))) 256 | manifold_w = int(np.ceil(np.sqrt(num_images))) 257 | assert manifold_h * manifold_w == num_images 258 | return manifold_h, manifold_w 259 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tensorflowCNN 2 | ### 1.gan-mnist:GAN was trained using the mnist dataset. 3 | ### 2.DCGAN-mydataset:By using our own data set to train the model, the quadric face is generated.Details show in my Blog [https://blog.csdn.net/qq_27855219/article/details/89371186] 4 | -------------------------------------------------------------------------------- /gan-mnist/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rui012345/tensorflowCNN/a20a508ef2a4729c27adbede538134855ec25471/gan-mnist/README.md -------------------------------------------------------------------------------- /gan-mnist/mnist_eval.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import time 3 | import tensorflow as tf 4 | from tensorflow.examples.tutorials.mnist import input_data 5 | 6 | import mnist_inference 7 | import mnist_train 8 | 9 | # every 10 sec load the newest model 10 | EVAL_INTERVAL_SECS = 10 11 | 12 | 13 | def evaluate(mnist): 14 | with tf.Graph().as_default() as g: 15 | x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input') 16 | y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input') 17 | validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels} 18 | # 直接调用封装好的函数来计算前向传播的结果 19 | y = mnist_inference.inference(x, None) 20 | # 计算正确率 21 | correcgt_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) 22 | accuracy = tf.reduce_mean(tf.cast(correcgt_prediction, tf.float32)) 23 | # 通过变量重命名的方式加载模型 24 | variable_averages = tf.train.ExponentialMovingAverage(mnist_train.MOVING_AVERAGE_DECAY) 25 | variable_to_restore = variable_averages.variables_to_restore() 26 | saver = tf.train.Saver(variable_to_restore) 27 | # 每隔10秒调用一次计算正确率的过程以检测训练过程中正确率的变化 28 | while True: 29 | with tf.Session() as sess: 30 | ckpt = tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH) 31 | if ckpt and ckpt.model_checkpoint_path: 32 | # load the model 33 | saver.restore(sess, ckpt.model_checkpoint_path) 34 | global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] 35 | accuracy_score = sess.run(accuracy, feed_dict=validate_feed) 36 | print("After %s training steps, validation accuracy = %g" % (global_step, accuracy_score)) 37 | 38 | else: 39 | print('No checkpoint file found') 40 | return 41 | time.sleep(EVAL_INTERVAL_SECS) 42 | 43 | 44 | def main(argv=None): 45 | mnist = input_data.read_data_sets("./mnist", one_hot=True) 46 | evaluate(mnist) 47 | 48 | 49 | if __name__ == '__main__': 50 | tf.app.run() 51 | -------------------------------------------------------------------------------- /gan-mnist/mnist_inference.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | INPUT_NODE = 784 4 | OUTPUT_NODE = 10 5 | LAYER1_NODE = 500 6 | 7 | # 通过tf.get_variable函数获取变量 8 | def get_weight_variable(shape,regularizer): 9 | weights = tf.get_variable("weights",shape,initializer=tf.truncated_normal_initializer(stddev=1.0)) 10 | # losses集合,add_to_collection张量加入集合 11 | if regularizer != None: 12 | tf.add_to_collection('losses',regularizer(weights)) 13 | return weights 14 | 15 | # 定义神经网络的前向传播 16 | def inference(input_tensor,regularizer): 17 | #声明第一层神经网络的变量并完成前向传播过程 18 | with tf.variable_scope('layer1'): 19 | weights = get_weight_variable([INPUT_NODE,LAYER1_NODE],regularizer) 20 | biases = tf.get_variable('biases',[LAYER1_NODE],initializer=tf.constant_initializer(0.0)) 21 | layer1 = tf.nn.relu(tf.matmul(input_tensor,weights) + biases) 22 | 23 | # 声明第二层神经网络的变量并完成前向传播过程 24 | with tf.variable_scope('layer2'): 25 | weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer) 26 | biases = tf.get_variable('biases', [LAYER1_NODE], initializer=tf.constant_initializer(0.0)) 27 | layer2 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases) 28 | #返回最后前向传播结果 29 | return layer2 30 | 31 | 32 | -------------------------------------------------------------------------------- /gan-mnist/mnist_train.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import os 3 | 4 | import tensorflow as tf 5 | from tensorflow.examples.tutorials.mnist import input_data 6 | 7 | import mnist_inference 8 | 9 | BATCH_SIZE = 100 10 | LEARNING_RATE_BASE = 0.8 11 | LEARNING_RATE_DECAY = 0.99 12 | REGULARAZTION_RATE = 0.0001 13 | TRAINING_STEPS = 30000 14 | MOVING_AVERAGE_DECAY = 0.99 15 | MODEL_SAVE_PATH = "./model/" # 模型保存的路径 16 | MODEL_NAME = "model.ckpt" 17 | 18 | 19 | def train(mnist): 20 | x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input') 21 | y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input') 22 | 23 | regularizer = tf.contrib.layers.l2_regularizer(REGULARAZTION_RATE) 24 | # 直接使用mnist_inference.py中定义的前向传播过程 25 | y = mnist_inference.inference(x, regularizer) 26 | global_step = tf.Variable(0, trainable=False) 27 | # 定义损失函数、指数衰减学习率、滑动平均操作以及训练过程 28 | variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step) 29 | variable_averages_op = variable_averages.apply(tf.trainable_variables()) 30 | cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1)) 31 | cross_entropy_mean = tf.reduce_mean(cross_entropy) 32 | loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses')) 33 | learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE, global_step, mnist.train.num_examples / BATCH_SIZE, 34 | LEARNING_RATE_DECAY) 35 | train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step) 36 | 37 | with tf.control_dependencies([train_step, variable_averages_op]): 38 | train_op = tf.no_op(name='train') 39 | # 初始化TensorFlow持久化类 40 | saver = tf.train.Saver() 41 | with tf.Session() as sess: 42 | tf.global_variables_initializer().run() 43 | # 训练过程 44 | for i in range(TRAINING_STEPS): 45 | xs, ys = mnist.train.next_batch(BATCH_SIZE) 46 | _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys}) 47 | # 每1000轮保存一次模型 48 | if i % 1000 == 0: 49 | print("After %d training step(s), loss on training batch is %g." % (step, loss_value)) 50 | print 51 | os.path.join(MODEL_SAVE_PATH, MODEL_NAME) 52 | saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step) 53 | 54 | 55 | def main(argv=None): 56 | mnist = input_data.read_data_sets("./mnist", one_hot=True) 57 | train(mnist) 58 | 59 | 60 | if __name__ == '__main__': 61 | tf.app.run() 62 | --------------------------------------------------------------------------------