├── .gitignore
├── README.md
├── config.py
├── download_imagenet.py
├── img
│   ├── SRGAN Result.pptx
│   ├── SRGAN Result2.pptx
│   ├── SRGAN Result3.pptx
│   ├── SRGAN_Result.png
│   ├── SRGAN_Result2.png
│   ├── SRGAN_Result3.png
│   └── model.jpeg
├── main.py
├── model.py
├── tensorlayer
│   ├── __init__.py
│   ├── activation.py
│   ├── cost.py
│   ├── db.py
│   ├── files.py
│   ├── iterate.py
│   ├── layers.py
│   ├── main-Copy3.py
│   ├── nlp.py
│   ├── ops.py
│   ├── prepro.py
│   ├── rein.py
│   ├── utils.py
│   └── visualize.py
└── utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ._*
2 | *.pyc
3 | .DS_Store
4 | *.npz
5 | sample/
6 | samples/
7 | checkpoint/
8 | __pycache__/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## SRGAN_Wasserstein
Applying Wasserstein GAN to SRGAN, a GAN-based super-resolution algorithm.

***This repo was forked from @zsdonghao's [tensorlayer/srgan](https://github.com/tensorlayer/srgan) repo. Based on that original repo, I changed some code to apply the Wasserstein loss, which makes the training procedure more stable. Thanks again to @zsdonghao for his great reimplementation.***

### SRGAN Architecture
![](http://ormr426d5.bkt.clouddn.com/18-5-18/43943225.jpg)

TensorFlow implementation of ["Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network"](https://arxiv.org/abs/1609.04802)

### Wasserstein GAN

When SRGAN was first proposed in 2016, [Wasserstein GAN](https://arxiv.org/abs/1701.07875) (2017) did not exist yet. WGAN uses the Wasserstein distance to measure the difference between the real and the generated data distributions. With the original GAN loss, there is no good signal for when to stop training the discriminator or the generator to get a nice result; with the Wasserstein loss, the result improves as the loss decreases. We are not going to explain the math behind WGAN here; the following steps are enough to apply it (see the sketch after this list):

* Remove the sigmoid activation from the last layer of the discriminator. (```model.py```, line 218-219)
* Don't take the logarithm in the discriminator and generator losses. (```main.py```, line 105-108)
* Clip the discriminator weights to a constant range [-c, c]. (```main.py```, line 136)
* Don't use momentum-based optimizers such as Adam; RMSProp or SGD work better. (```main.py```, line 132-133)

These steps were given by an excellent article [[4]](https://zhuanlan.zhihu.com/p/25071913), whose author explains WGAN in a very straightforward way (it is written in Chinese).
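In code, the four changes boil down to a few lines. This is a minimal sketch in this repo's TF 1.x style; the names (`logits_real`, `logits_fake`, `lr_v`, `d_vars`, `g_vars`, the loss terms) follow `main.py`, and the clip constant 0.01 is the value used there:

```python
import tensorflow as tf

# Critic (discriminator) loss: plain difference of means, no sigmoid, no log.
d_loss = tf.reduce_mean(logits_fake) - tf.reduce_mean(logits_real)
# Generator adversarial term: no log either (main.py scales it by 1e-3).
g_gan_loss = -1e-3 * tf.reduce_mean(logits_fake)
g_loss = mse_loss + vgg_loss + g_gan_loss  # as in main.py

# RMSProp instead of a momentum-based optimizer such as Adam.
g_optim = tf.train.RMSPropOptimizer(lr_v).minimize(g_loss, var_list=g_vars)
d_optim = tf.train.RMSPropOptimizer(lr_v).minimize(d_loss, var_list=d_vars)

# Weight clipping: run this op after every discriminator update.
clip_D = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in d_vars]
```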
### Loss Curve and Results
![](http://ormr426d5.bkt.clouddn.com/18-5-18/8141442.jpg)

![](http://ormr426d5.bkt.clouddn.com/18-5-18/22508558.jpg)

![](http://ormr426d5.bkt.clouddn.com/18-5-18/83166966.jpg)

![](http://ormr426d5.bkt.clouddn.com/18-5-18/96883821.jpg)


### Prepare Data and Pre-trained VGG

- 1. Download the pretrained VGG19 model from [here](https://mega.nz/#!xZ8glS6J!MAnE91ND_WyfZ_8mvkuSa2YcA7q-1ehfSm-Q1fxOvvs), as [tutorial_vgg19.py](https://github.com/zsdonghao/tensorlayer/blob/master/example/tutorial_vgg19.py) shows.
- 2. Prepare high-resolution images for training.
  - In this experiment, I used images from the [DIV2K - bicubic downscaling x4 competition](http://www.vision.ee.ethz.ch/ntire17/), so the hyper-parameters in `config.py` (like the number of epochs) were selected based on that dataset; if you switch to a larger dataset, you can reduce the number of epochs.
  - If you don't want to use the DIV2K dataset, you can also use [Yahoo MirFlickr25k](http://press.liacs.nl/mirflickr/mirdownload.html): simply download it with `train_hr_imgs = tl.files.load_flickr25k_dataset(tag=None)` in `main.py`.
  - If you want to use your own images, set the path to your image folder via `config.TRAIN.hr_img_path` in `config.py`.

### Run

We run this script under [TensorFlow](https://www.tensorflow.org) 1.4 and [TensorLayer](https://github.com/tensorlayer/tensorlayer) 1.8.0+.

* Installation

```
pip install tensorlayer==1.8.0
pip install tensorflow-gpu==1.4.0   # or: conda install tensorflow-gpu==1.3.0
pip install easydict
```

- Download the [DIV2K - bicubic downscaling x4 competition](http://www.vision.ee.ethz.ch/ntire17/) dataset, and set your image folder in `config.py`.
- Other links for DIV2K, in case you can't find it: [test\_LR\_bicubic_X4](https://data.vision.ee.ethz.ch/cvl/DIV2K/validation_release/DIV2K_test_LR_bicubic_X4.zip), [train_HR](https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip), [train\_LR\_bicubic_X4](https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic_X4.zip), [valid_HR](https://data.vision.ee.ethz.ch/cvl/DIV2K/validation_release/DIV2K_valid_HR.zip), [valid\_LR\_bicubic_X4](https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X4.zip).

```python
config.TRAIN.img_path = "your_image_folder/"
```

- TensorBoard logdir.

I added TensorBoard summaries to monitor the training procedure; please change the logdir to your own folder.

```python
config.VALID.logdir = 'your_tensorboard_folder'
```

- Start training.

```bash
python main.py
```

- Start evaluation. ([pretrained model](https://github.com/tensorlayer/srgan/releases/tag/1.2.0) for DIV2K)

**An important note:** these pretrained weights are provided by the original author @zsdonghao. In his version, the final conv layer of ```SRGAN_g``` (```model.py``` line 53) uses a 1×1 kernel, but I changed this kernel to 9×9, so if you load these pretrained weights you may get a "weights unequal" error.
Two suggestions:
1) Train the whole network from scratch; you'll get the 9×9 version of the weights, for further training or for evaluating images.
2) Change ```SRGAN_g```'s final conv kernel (```model.py``` line 53) to (1, 1) instead of (9, 9), and change the ```model.py``` line 35 conv kernel from (9, 9) to (3, 3), so that you can use the pretrained weights.

```bash
python main.py --mode=evaluate
```

### What's New?

Compared with the original version, I made the following changes:

1. Added the WGAN loss, as described in the Wasserstein GAN section above.
2. Added TensorBoard, to monitor the training procedure.
3. Modified the last conv layer of `SRGAN_g` in ```model.py``` (line 100), changing the kernel size from (1, 1) to (9, 9), as the paper proposed (see the sketch below).
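For reference, the kernel swap in suggestion 2) is a one-line change. This sketch shows the final layer of `SRGAN_g` as it appears in `model.py`, with the 1×1 variant (compatible with the pretrained release) commented out:

```python
# current repo: 9x9 output kernel, as the SRGAN paper proposes
n = Conv2d(n, 3, (9, 9), (1, 1), act=tf.nn.tanh, padding='SAME', W_init=w_init, name='out')
# to reuse @zsdonghao's pretrained weights, switch back to the 1x1 kernel:
# n = Conv2d(n, 3, (1, 1), (1, 1), act=tf.nn.tanh, padding='SAME', W_init=w_init, name='out')
```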
### Reference
* [1] [Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network](https://arxiv.org/abs/1609.04802)
* [2] [Is the deconvolution layer the same as a convolutional layer?](https://arxiv.org/abs/1609.07009)
* [3] [Wasserstein GAN](https://arxiv.org/abs/1701.07875)
* [4] [令人拍案叫绝的Wasserstein GAN](https://zhuanlan.zhihu.com/p/25071913) ("The astonishing Wasserstein GAN", in Chinese)
* [5] [SRGAN With WGAN,让超分辨率算法训练更稳定 - 知乎专栏](https://zhuanlan.zhihu.com/p/37009085) ("SRGAN with WGAN: making super-resolution training more stable", the Chinese version of this readme, on Zhihu)

### Author
- [zsdonghao](https://github.com/zsdonghao)
- [justinho](https://github.com/JustinhoCHN)

### License

- For academic and non-commercial use only.
- For commercial use, please contact tensorlayer@gmail.com.
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 | import json
3 | 
4 | config = edict()
5 | config.TRAIN = edict()
6 | 
7 | ## Adam
8 | config.TRAIN.batch_size = 16
9 | config.TRAIN.lr_init = 1e-4
10 | config.TRAIN.beta1 = 0.9
11 | 
12 | ## initialize G
13 | config.TRAIN.n_epoch_init = 0
14 | # config.TRAIN.lr_decay_init = 0.1
15 | # config.TRAIN.decay_every_init = int(config.TRAIN.n_epoch_init / 2)
16 | 
17 | ## adversarial learning (SRGAN)
18 | config.TRAIN.n_epoch = 1500
19 | config.TRAIN.lr_decay = 0.1
20 | config.TRAIN.decay_every = int(config.TRAIN.n_epoch / 2)
21 | 
22 | ## train set location
23 | config.TRAIN.hr_img_path = '/home/ubuntu/huzhihao/SRGAN_Wasserstein/dataset/DIV2K_train_HR/'
24 | config.TRAIN.lr_img_path = '/home/ubuntu/huzhihao/SRGAN_Wasserstein/dataset/DIV2K_train_LR_bicubic/X4/'
25 | #config.TRAIN.hr_img_path = '/home/ubuntu/dataset/image_tag/srgan_all_jpg/trn_hr/'
26 | #config.TRAIN.lr_img_path = '/home/ubuntu/dataset/image_tag/srgan_all_jpg/trn_lr/'
27 | 
28 | 
29 | config.VALID = edict()
30 | ## test set location
31 | config.VALID.hr_img_path = '/home/ubuntu/huzhihao/SRGAN_Wasserstein/dataset/DIV2K_valid_HR/'
32 | config.VALID.lr_img_path = '/home/ubuntu/huzhihao/SRGAN_Wasserstein/dataset/DIV2K_valid_LR_bicubic/X4/'
33 | #config.VALID.hr_img_path = '/home/ubuntu/dataset/image_tag/srgan_all_jpg/val_hr/'
34 | #config.VALID.lr_img_path = '/home/ubuntu/dataset/image_tag/srgan_all_jpg/val_lr/'
35 | 
36 | config.VALID.logdir = '/home/ubuntu/SRGAN_Wasserstein/log/'
37 | def log_config(filename, cfg):
38 |     with open(filename, 'w') as f:
39 |         f.write("================================================\n")
40 |         f.write(json.dumps(cfg, indent=4))
41 |         f.write("\n================================================\n")
42 | 
--------------------------------------------------------------------------------
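A side note on `log_config` above: it simply dumps the EasyDict as JSON between separator lines, which is handy for recording the exact hyper-parameters of a run. A minimal usage sketch (the output path is just an example):

```python
from config import config, log_config

# write the current hyper-parameters to a text file for later reference
log_config("checkpoint/run_config.txt", config)  # example path, pick your own
```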
/download_imagenet.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import socket
3 | import os
4 | import urllib.request  # Python 3: urlretrieve lives in urllib.request
5 | import numpy as np
6 | from PIL import Image
7 | 
8 | from joblib import Parallel, delayed
9 | 
10 | 
11 | def download_image(download_str, save_dir):
12 |     img_name, img_url = download_str.strip().split('\t')
13 |     save_img = os.path.join(save_dir, "{}.jpg".format(img_name))
14 |     downloaded = False
15 |     try:
16 |         if not os.path.isfile(save_img):
17 |             print("Downloading {} to {}.jpg".format(img_url, img_name))
18 |             urllib.request.urlretrieve(img_url, save_img)
19 | 
20 |             # Check the size of the image
21 |             downloaded = True
22 |             with Image.open(save_img) as img:
23 |                 width, height = img.size
24 | 
25 |             img_size_bytes = os.path.getsize(save_img)
26 |             img_size_KB = img_size_bytes / 1024
27 | 
28 |             if width < 500 or height < 500 or img_size_KB < 200:
29 |                 os.remove(save_img)
30 |                 print("Removed downloaded image (w:{}, h:{}, s:{}KB)".format(width, height, img_size_KB))
31 |         else:
32 |             print("Already downloaded {}".format(save_img))
33 |     except Exception:
34 |         if not downloaded:
35 |             print("Cannot download.")
36 |         else:
37 |             print("Removing failed, partially downloaded image.")
38 | 
39 |             if os.path.isfile(save_img):
40 |                 os.remove(save_img)
41 | 
42 | 
43 | def main():
44 |     parser = argparse.ArgumentParser()
45 |     parser.add_argument("--img_url_file", type=str, required=True,
46 |                         help="File that contains the list of image IDs and URLs.")
47 |     parser.add_argument("--output_dir", type=str, required=True,
48 |                         help="Directory where to save outputs.")
49 |     parser.add_argument("--n_download_urls", type=int, default=20000,
50 |                         help="Number of image URLs to download.")
51 |     args = parser.parse_args()
52 | 
53 |     # np.random.seed(123456)
54 | 
55 |     socket.setdefaulttimeout(10)
56 | 
57 |     with open(args.img_url_file) as f:
58 |         lines = f.readlines()
59 |     lines = np.random.choice(lines, size=args.n_download_urls, replace=False)
60 | 
61 |     Parallel(n_jobs=12)(delayed(download_image)(line, args.output_dir) for line in lines)
62 | 
63 | 
64 | if __name__ == "__main__":
65 |     main()
66 | 
--------------------------------------------------------------------------------
/img/SRGAN Result.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JustinhoCHN/SRGAN_Wasserstein/08cb76028880f95cbeea1353c5bfc5b2b356ae83/img/SRGAN Result.pptx
--------------------------------------------------------------------------------
/img/SRGAN Result2.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JustinhoCHN/SRGAN_Wasserstein/08cb76028880f95cbeea1353c5bfc5b2b356ae83/img/SRGAN Result2.pptx
--------------------------------------------------------------------------------
/img/SRGAN Result3.pptx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JustinhoCHN/SRGAN_Wasserstein/08cb76028880f95cbeea1353c5bfc5b2b356ae83/img/SRGAN Result3.pptx
--------------------------------------------------------------------------------
/img/SRGAN_Result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JustinhoCHN/SRGAN_Wasserstein/08cb76028880f95cbeea1353c5bfc5b2b356ae83/img/SRGAN_Result.png
--------------------------------------------------------------------------------
/img/SRGAN_Result2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JustinhoCHN/SRGAN_Wasserstein/08cb76028880f95cbeea1353c5bfc5b2b356ae83/img/SRGAN_Result2.png
--------------------------------------------------------------------------------
/img/SRGAN_Result3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JustinhoCHN/SRGAN_Wasserstein/08cb76028880f95cbeea1353c5bfc5b2b356ae83/img/SRGAN_Result3.png
--------------------------------------------------------------------------------
/img/model.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JustinhoCHN/SRGAN_Wasserstein/08cb76028880f95cbeea1353c5bfc5b2b356ae83/img/model.jpeg
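Stepping back to `download_imagenet.py` above: a typical invocation looks like this. The URL list file name is just an example; the file is expected to contain one `image_id<TAB>image_url` pair per line:

```bash
# download up to 20000 random images (only images >= 500px and >= 200KB are kept)
python download_imagenet.py --img_url_file image_urls.txt \
    --output_dir ./imagenet_images/ --n_download_urls 20000
```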
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python
2 | # -*- coding: utf8 -*-
3 | 
4 | import os, time, pickle, random
5 | from datetime import datetime
6 | import numpy as np
7 | from time import localtime, strftime
8 | import logging, scipy
9 | 
10 | import tensorflow as tf
11 | import tensorlayer as tl
12 | from model import *
13 | from utils import *
14 | from config import config, log_config
15 | 
16 | ###====================== HYPER-PARAMETERS ===========================###
17 | ## Adam
18 | batch_size = config.TRAIN.batch_size
19 | lr_init = config.TRAIN.lr_init
20 | beta1 = config.TRAIN.beta1
21 | ## initialize G
22 | n_epoch_init = config.TRAIN.n_epoch_init
23 | ## adversarial learning (SRGAN)
24 | n_epoch = config.TRAIN.n_epoch
25 | lr_decay = config.TRAIN.lr_decay
26 | decay_every = config.TRAIN.decay_every
27 | logdir = config.VALID.logdir
28 | 
29 | ni = int(np.sqrt(batch_size))
30 | 
31 | 
32 | def read_all_imgs(img_list, path='', n_threads=32):
33 |     """ Return all images as a list, given the path and the file name of each image. """
34 |     imgs = []
35 |     for idx in range(0, len(img_list), n_threads):
36 |         b_imgs_list = img_list[idx : idx + n_threads]
37 |         b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=path)
38 |         # print(b_imgs.shape)
39 |         imgs.extend(b_imgs)
40 |         print('read %d from %s' % (len(imgs), path))
41 |     return imgs
42 | 
43 | def train():
44 |     ## create folders to save result images and trained model
45 |     save_dir_ginit = "samples/{}_ginit".format(tl.global_flag['mode'])
46 |     save_dir_gan = "samples/{}_gan".format(tl.global_flag['mode'])
47 |     tl.files.exists_or_mkdir(save_dir_ginit)
48 |     tl.files.exists_or_mkdir(save_dir_gan)
49 |     checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
50 |     tl.files.exists_or_mkdir(checkpoint_dir)
51 | 
52 |     ###====================== PRE-LOAD DATA ===========================###
53 |     train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))
54 |     train_lr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx='.*.png', printable=False))
55 |     valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False))
56 |     valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False))
57 | 
58 |     ## If your machine has enough memory, pre-load the whole train set.
59 | train_hr_imgs = read_all_imgs(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=32) 60 | # for im in train_hr_imgs: 61 | # print(im.shape) 62 | # valid_lr_imgs = read_all_imgs(valid_lr_img_list, path=config.VALID.lr_img_path, n_threads=32) 63 | # for im in valid_lr_imgs: 64 | # print(im.shape) 65 | # valid_hr_imgs = read_all_imgs(valid_hr_img_list, path=config.VALID.hr_img_path, n_threads=32) 66 | # for im in valid_hr_imgs: 67 | # print(im.shape) 68 | # exit() 69 | 70 | ###========================== DEFINE MODEL ============================### 71 | ## train inference 72 | t_image = tf.placeholder('float32', [batch_size, 96, 96, 3], name='t_image_input_to_SRGAN_generator') 73 | t_target_image = tf.placeholder('float32', [batch_size, 384, 384, 3], name='t_target_image') 74 | 75 | net_g = SRGAN_g(t_image, is_train=True, reuse=False) 76 | net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False) 77 | _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True) 78 | 79 | net_g.print_params(False) 80 | net_d.print_params(False) 81 | 82 | ## vgg inference. 0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA 83 | t_target_image_224 = tf.image.resize_images(t_target_image, size=[224, 224], method=0, align_corners=False) # resize_target_image_for_vgg # http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers.html#UpSampling2dLayer 84 | t_predict_image_224 = tf.image.resize_images(net_g.outputs, size=[224, 224], method=0, align_corners=False) # resize_generate_image_for_vgg 85 | 86 | net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224+1)/2, reuse=False) 87 | _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224+1)/2, reuse=True) 88 | 89 | ## test inference 90 | net_g_test = SRGAN_g(t_image, is_train=False, reuse=True) 91 | 92 | # ###========================== DEFINE TRAIN OPS ==========================### 93 | # d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real), name='d1') 94 | # d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake), name='d2') 95 | 96 | # d_loss1 = tl.cost.cross_entropy(logits_real, tf.ones_like(logits_real), name='d1') 97 | # d_loss2 = tl.cost.cross_entropy(logits_fake, tf.zeros_like(logits_fake), name='d2') 98 | # 99 | # d_loss = d_loss1 + d_loss2 100 | 101 | # Wasserstein GAN Loss 102 | with tf.name_scope('w_loss/WARS_1'): 103 | d_loss = tf.reduce_mean(logits_fake) - tf.reduce_mean(logits_real) 104 | tf.summary.scalar('w_loss', d_loss) 105 | 106 | merged = tf.summary.merge_all() 107 | # loss_writer = tf.summary.FileWriter('/home/ubuntu/huzhihao/WARS/log/', sess.graph) 108 | # g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake), name='g') 109 | g_gan_loss = - 1e-3 * tf.reduce_mean(logits_fake) 110 | mse_loss = tl.cost.mean_squared_error(net_g.outputs , t_target_image, is_mean=True) 111 | vgg_loss = 2e-6 * tl.cost.mean_squared_error(vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True) 112 | 113 | g_loss = mse_loss + vgg_loss + g_gan_loss 114 | 115 | g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True) 116 | d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True) 117 | 118 | with tf.variable_scope('learning_rate'): 119 | lr_v = tf.Variable(lr_init, trainable=False) 120 | ## Pretrain 121 | # g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(mse_loss, var_list=g_vars) 122 | g_optim_init = tf.train.RMSPropOptimizer(lr_v).minimize(mse_loss, var_list=g_vars) 123 | 124 | ## SRGAN 125 | # g_optim 
= tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(g_loss, var_list=g_vars)
126 | # d_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(d_loss, var_list=d_vars)
127 | 
128 |     g_optim = tf.train.RMSPropOptimizer(lr_v).minimize(g_loss, var_list=g_vars)
129 |     d_optim = tf.train.RMSPropOptimizer(lr_v).minimize(d_loss, var_list=d_vars)
130 | 
131 |     # clip op
132 |     clip_D = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in d_vars]
133 | 
134 | 
135 |     ###========================== RESTORE MODEL =============================###
136 |     sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
137 |     loss_writer = tf.summary.FileWriter(logdir, sess.graph)
138 |     tl.layers.initialize_global_variables(sess)
139 |     if tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir+'/g_{}.npz'.format(tl.global_flag['mode']), network=net_g) is False:
140 |         tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir+'/g_{}_init.npz'.format(tl.global_flag['mode']), network=net_g)
141 |     tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir+'/d_{}.npz'.format(tl.global_flag['mode']), network=net_d)
142 | 
143 |     ###============================= LOAD VGG ===============================###
144 |     vgg19_npy_path = "vgg19.npy"
145 |     if not os.path.isfile(vgg19_npy_path):
146 |         print("Please download vgg19.npy from : https://github.com/machrisaa/tensorflow-vgg")
147 |         exit()
148 |     npz = np.load(vgg19_npy_path, encoding='latin1').item()
149 | 
150 |     params = []
151 |     for val in sorted( npz.items() ):
152 |         W = np.asarray(val[1][0])
153 |         b = np.asarray(val[1][1])
154 |         print("  Loading %s: %s, %s" % (val[0], W.shape, b.shape))
155 |         params.extend([W, b])
156 |     tl.files.assign_params(sess, params, net_vgg)
157 |     # net_vgg.print_params(False)
158 |     # net_vgg.print_layers()
159 | 
160 |     ###============================= TRAINING ===============================###
161 |     ## use the first `batch_size` images of the train set for a quick test during training
162 |     sample_imgs = train_hr_imgs[0:batch_size]
163 |     # sample_imgs = read_all_imgs(train_hr_img_list[0:batch_size], path=config.TRAIN.hr_img_path, n_threads=32) # if no pre-load train set
164 |     sample_imgs_384 = tl.prepro.threading_data(sample_imgs, fn=crop_sub_imgs_fn, is_random=False)
165 |     print('sample HR sub-image:', sample_imgs_384.shape, sample_imgs_384.min(), sample_imgs_384.max())
166 |     sample_imgs_96 = tl.prepro.threading_data(sample_imgs_384, fn=downsample_fn)
167 |     print('sample LR sub-image:', sample_imgs_96.shape, sample_imgs_96.min(), sample_imgs_96.max())
168 |     tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_ginit+'/_train_sample_96.png')
169 |     tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_ginit+'/_train_sample_384.png')
170 |     tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_gan+'/_train_sample_96.png')
171 |     tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_gan+'/_train_sample_384.png')
172 | 
173 |     ###========================= initialize G ====================###
174 |     ## fixed learning rate
175 |     sess.run(tf.assign(lr_v, lr_init))
176 |     print(" ** fixed learning rate: %f (for init G)" % lr_init)
177 |     for epoch in range(0, n_epoch_init+1):
178 |         epoch_time = time.time()
179 |         total_mse_loss, n_iter = 0, 0
180 | 
181 |         ## If your machine cannot load all images into memory, use this
182 |         ## block to load a batch of images while training.
183 | # random.shuffle(train_hr_img_list) 184 | # for idx in range(0, len(train_hr_img_list), batch_size): 185 | # step_time = time.time() 186 | # b_imgs_list = train_hr_img_list[idx : idx + batch_size] 187 | # b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=config.TRAIN.hr_img_path) 188 | # b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=crop_sub_imgs_fn, is_random=True) 189 | # b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn) 190 | 191 | ## If your machine have enough memory, please pre-load the whole train set. 192 | for idx in range(0, len(train_hr_imgs), batch_size): 193 | step_time = time.time() 194 | b_imgs_384 = tl.prepro.threading_data( 195 | train_hr_imgs[idx : idx + batch_size], 196 | fn=crop_sub_imgs_fn, is_random=True) 197 | b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn) 198 | ## update G 199 | errM, _ = sess.run([mse_loss, g_optim_init], {t_image: b_imgs_96, t_target_image: b_imgs_384}) 200 | print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f " % (epoch, n_epoch_init, n_iter, time.time() - step_time, errM)) 201 | total_mse_loss += errM 202 | n_iter += 1 203 | log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (epoch, n_epoch_init, time.time() - epoch_time, total_mse_loss/n_iter) 204 | print(log) 205 | 206 | ## quick evaluation on train set 207 | if (epoch != 0) and (epoch % 10 == 0): 208 | out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})#; print('gen sub-image:', out.shape, out.min(), out.max()) 209 | print("[*] save images") 210 | tl.vis.save_images(out, [ni, ni], save_dir_ginit+'/train_%d.png' % epoch) 211 | 212 | ## save model 213 | if (epoch != 0) and (epoch % 10 == 0): 214 | tl.files.save_npz(net_g.all_params, name=checkpoint_dir+'/g_{}_init.npz'.format(tl.global_flag['mode']), sess=sess) 215 | 216 | ###========================= train GAN (SRGAN) =========================### 217 | 218 | # clipping method 219 | # clip_discriminator_var_op = [var.assign(tf.clip_by_value(var, self.clip_values[0], self.clip_values[1])) for 220 | # var in self.discriminator_variables] 221 | 222 | for epoch in range(0, n_epoch+1): 223 | ## update learning rate 224 | if epoch !=0 and (epoch % decay_every == 0): 225 | new_lr_decay = lr_decay ** (epoch // decay_every) 226 | sess.run(tf.assign(lr_v, lr_init * new_lr_decay)) 227 | log = " ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay) 228 | print(log) 229 | elif epoch == 0: 230 | sess.run(tf.assign(lr_v, lr_init)) 231 | log = " ** init lr: %f decay_every_init: %d, lr_decay: %f (for GAN)" % (lr_init, decay_every, lr_decay) 232 | print(log) 233 | 234 | epoch_time = time.time() 235 | total_d_loss, total_g_loss, n_iter = 0, 0, 0 236 | 237 | ## If your machine cannot load all images into memory, you should use 238 | ## this one to load batch of images while training. 239 | # random.shuffle(train_hr_img_list) 240 | # for idx in range(0, len(train_hr_img_list), batch_size): 241 | # step_time = time.time() 242 | # b_imgs_list = train_hr_img_list[idx : idx + batch_size] 243 | # b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=config.TRAIN.hr_img_path) 244 | # b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=crop_sub_imgs_fn, is_random=True) 245 | # b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn) 246 | 247 | ## If your machine have enough memory, please pre-load the whole train set. 
248 | for idx in range(0, len(train_hr_imgs), batch_size): 249 | step_time = time.time() 250 | b_imgs_384 = tl.prepro.threading_data( 251 | train_hr_imgs[idx : idx + batch_size], 252 | fn=crop_sub_imgs_fn, is_random=True) 253 | b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn) 254 | ## update D 255 | 256 | errD, summary, _, _ = sess.run([d_loss, merged, d_optim, clip_D], {t_image: b_imgs_96, t_target_image: b_imgs_384}) 257 | loss_writer.add_summary(summary, idx) 258 | # d_vars = sess.run(clip_discriminator_var_op) 259 | ## update G 260 | errG, errM, errV, errA, _ = sess.run([g_loss, mse_loss, vgg_loss, g_gan_loss, g_optim],{t_image: b_imgs_96, t_target_image: b_imgs_384}) 261 | 262 | print("Epoch [%2d/%2d] %4d time: %4.4fs, W_loss: %.8f g_loss: %.8f (mse: %.6f vgg: %.6f adv: %.6f)" 263 | % (epoch, n_epoch, n_iter, time.time() - step_time, errD, errG, errM, errV, errA)) 264 | total_d_loss += errD 265 | total_g_loss += errG 266 | n_iter += 1 267 | 268 | log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (epoch, n_epoch, time.time() - epoch_time, total_d_loss/n_iter, total_g_loss/n_iter) 269 | print(log) 270 | 271 | ## quick evaluation on train set 272 | if (epoch != 0) and (epoch % 10 == 0): 273 | out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})#; print('gen sub-image:', out.shape, out.min(), out.max()) 274 | print("[*] save images") 275 | tl.vis.save_images(out, [ni, ni], save_dir_gan+'/train_%d.png' % epoch) 276 | 277 | ## save model 278 | if (epoch != 0) and (epoch % 10 == 0): 279 | tl.files.save_npz(net_g.all_params, name=checkpoint_dir+'/g_{}.npz'.format(tl.global_flag['mode']), sess=sess) 280 | tl.files.save_npz(net_d.all_params, name=checkpoint_dir+'/d_{}.npz'.format(tl.global_flag['mode']), sess=sess) 281 | 282 | def evaluate(): 283 | ## create folders to save result images 284 | save_dir = "samples/{}".format(tl.global_flag['mode']) 285 | tl.files.exists_or_mkdir(save_dir) 286 | checkpoint_dir = "checkpoint" 287 | 288 | ###====================== PRE-LOAD DATA ===========================### 289 | # train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.jpg', printable=False)) 290 | # train_lr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx='.*.jpg', printable=False)) 291 | valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False)) 292 | valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False)) 293 | 294 | ## If your machine have enough memory, please pre-load the whole train set. 
# train_hr_imgs = read_all_imgs(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=32)
296 | # for im in train_hr_imgs:
297 | #     print(im.shape)
298 |     valid_lr_imgs = read_all_imgs(valid_lr_img_list, path=config.VALID.lr_img_path, n_threads=32)
299 | # for im in valid_lr_imgs:
300 | #     print(im.shape)
301 |     valid_hr_imgs = read_all_imgs(valid_hr_img_list, path=config.VALID.hr_img_path, n_threads=32)
302 | # for im in valid_hr_imgs:
303 | #     print(im.shape)
304 | # exit()
305 | 
306 |     ###========================== DEFINE MODEL ============================###
307 |     imid = 64  # 0: penguin  81: butterfly  53: bird  64: castle
308 |     valid_lr_img = valid_lr_imgs[imid]
309 |     valid_hr_img = valid_hr_imgs[imid]
310 |     # img_name = '0010_80.jpg'
311 |     # valid_lr_img = get_imgs_fn(img_name, '/home/ubuntu/dataset/sr_test/testing/')  # if you want to test your own image
312 |     valid_lr_img = (valid_lr_img / 127.5) - 1  # rescale to [-1, 1]
313 |     # print(valid_lr_img.min(), valid_lr_img.max())
314 | 
315 |     size = valid_lr_img.shape
316 |     t_image = tf.placeholder('float32', [None, size[0], size[1], size[2]], name='input_image')
317 |     # t_image = tf.placeholder('float32', [1, None, None, 3], name='input_image')
318 | 
319 |     net_g = SRGAN_g(t_image, is_train=False, reuse=False)
320 | 
321 |     ###========================== RESTORE G =============================###
322 |     sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
323 |     tl.layers.initialize_global_variables(sess)
324 |     tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir+'/g_srgan.npz', network=net_g)
325 | 
326 |     ###======================= EVALUATION =============================###
327 |     start_time = time.time()
328 |     out = sess.run(net_g.outputs, {t_image: [valid_lr_img]})
329 |     print("took: %4.4fs" % (time.time() - start_time))
330 | 
331 |     print("LR size: %s / generated HR size: %s" % (size, out.shape))  # LR size: (339, 510, 3) / gen HR size: (1, 1356, 2040, 3)
332 |     print("[*] save images")
333 |     # tl.vis.save_image(out[0], save_dir + '/gen_' + img_name[:-4] + '.png')
334 |     tl.vis.save_image(out[0], save_dir + '/valid_gen.png')
335 |     # tl.vis.save_image(valid_lr_img, save_dir+'/valid_lr.png')
336 |     # tl.vis.save_image(valid_hr_img, save_dir+'/valid_hr.png')
337 | 
338 |     out_bicu = scipy.misc.imresize(valid_lr_img, [size[0]*4, size[1]*4], interp='bicubic', mode=None)
339 |     # tl.vis.save_image(out_bicu, save_dir + '/bicubic_' + img_name[:-4] + '.png')
340 |     tl.vis.save_image(out_bicu, save_dir + '/valid_bicubic.png')
341 | 
342 | if __name__ == '__main__':
343 |     import argparse
344 |     parser = argparse.ArgumentParser()
345 | 
346 |     parser.add_argument('--mode', type=str, default='srgan', help='srgan, evaluate')
347 | 
348 |     args = parser.parse_args()
349 | 
350 |     tl.global_flag['mode'] = args.mode
351 | 
352 |     if tl.global_flag['mode'] == 'srgan':
353 |         train()
354 |     elif tl.global_flag['mode'] == 'evaluate':
355 |         evaluate()
356 |     else:
357 |         raise Exception("Unknown --mode")
358 | 
--------------------------------------------------------------------------------
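One note on `evaluate()` above: to super-resolve your own image instead of a DIV2K validation image, follow the commented-out lines. A minimal sketch (the folder and file name are placeholders; `get_imgs_fn` comes from `utils.py`):

```python
# inside evaluate(), instead of picking valid_lr_imgs[imid]:
img_name = 'my_photo.jpg'                                  # placeholder file name
valid_lr_img = get_imgs_fn(img_name, 'path/to/your/test_images/')
valid_lr_img = (valid_lr_img / 127.5) - 1                  # rescale to [-1, 1], as the generator expects
```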
/model.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python
2 | # -*- coding: utf8 -*-
3 | 
4 | import tensorflow as tf
5 | import tensorlayer as tl
6 | from tensorlayer.layers import *
7 | # from tensorflow.python.ops import variable_scope as vs
8 | # from tensorflow.python.ops import math_ops, init_ops, array_ops, nn
9 | # from tensorflow.python.util import nest
10 | # from tensorflow.contrib.rnn.python.ops import core_rnn_cell
11 | 
12 | # https://github.com/david-gpu/srez/blob/master/srez_model.py
13 | 
14 | def SRGAN_g(t_image, is_train=False, reuse=False):
15 |     """ Generator in Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network.
16 |     Layer names encode the number of feature maps (n) and the stride (s), e.g. n64s1.
17 |     """
18 |     w_init = tf.random_normal_initializer(stddev=0.02)
19 |     b_init = None  # tf.constant_initializer(value=0.0)
20 |     g_init = tf.random_normal_initializer(1., 0.02)
21 |     with tf.variable_scope("SRGAN_g", reuse=reuse) as vs:
22 |         tl.layers.set_name_reuse(reuse)
23 |         n = InputLayer(t_image, name='in')
24 |         n = Conv2d(n, 64, (9, 9), (1, 1), act=tf.nn.relu, padding='SAME', W_init=w_init, name='n64s1/c')
25 |         # temp = n
26 | 
27 |         # Artifact reduction block, added by huzhihao, reference: http://arxiv.org/abs/1608.02778
28 |         # n = Conv2d(n, 32, (1, 1), (1, 1), act=tf.nn.relu, padding='SAME', W_init=w_init, name='n32s1/c0/1')
29 |         # n = Conv2d(n, 32, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', W_init=w_init, name='n32s1/c0/2')
30 |         # n = Conv2d(n, 64, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', W_init=w_init, name='n32s1/c0/3')
31 |         # n = Conv2d(n, 64, (7, 7), (1, 1), act=tf.nn.relu, padding='SAME', W_init=w_init, name='n32s1/c0/2')
32 |         temp = n
33 |         # B residual blocks
34 |         for i in range(16):
35 |             nn = Conv2d(n, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c1/%s' % i)
36 |             nn = BatchNormLayer(nn, act=tf.nn.relu, is_train=is_train, gamma_init=g_init, name='n64s1/b1/%s' % i)
37 |             nn = Conv2d(nn, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c2/%s' % i)
38 |             nn = BatchNormLayer(nn, is_train=is_train, gamma_init=g_init, name='n64s1/b2/%s' % i)
39 |             nn = ElementwiseLayer([n, nn], tf.add, 'b_residual_add/%s' % i)
40 |             n = nn
41 | 
42 |         n = Conv2d(n, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c/m')
43 |         n = BatchNormLayer(n, is_train=is_train, gamma_init=g_init, name='n64s1/b/m')
44 |         n = ElementwiseLayer([n, temp], tf.add, 'add3')
45 |         # B residual blocks end
46 | 
47 |         n = Conv2d(n, 256, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='n256s1/1')
48 |         n = SubpixelConv2d(n, scale=2, n_out_channel=None, act=tf.nn.relu, name='pixelshufflerx2/1')
49 | 
50 |         n = Conv2d(n, 256, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='n256s1/2')
51 |         n = SubpixelConv2d(n, scale=2, n_out_channel=None, act=tf.nn.relu, name='pixelshufflerx2/2')
52 | 
53 |         n = Conv2d(n, 3, (9, 9), (1, 1), act=tf.nn.tanh, padding='SAME', W_init=w_init, name='out')
54 |         return n
55 | 
56 | 
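The two `SubpixelConv2d` layers above perform the 4x upscaling: each scale-2 step rearranges an (H, W, 4C) feature map into a (2H, 2W, C) one (the "pixel shuffle" of reference [2]). A minimal sketch of that rearrangement, assuming TensorLayer's layer reduces to TensorFlow's depth-to-space op for scale 2:

```python
import tensorflow as tf

# one scale-2 pixel-shuffle step: (batch, H, W, C*2*2) -> (batch, 2H, 2W, C)
x = tf.placeholder('float32', [16, 96, 96, 256])
y = tf.depth_to_space(x, block_size=2)  # shape: (16, 192, 192, 64)
```

Applied twice, as in SRGAN_g, this turns the 96x96 input into the 384x384 output.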
57 | def SRGAN_g2(t_image, is_train=False, reuse=False):
58 |     """ Generator in Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network.
59 |     Layer names encode the number of feature maps (n) and the stride (s), e.g. n64s1.
60 | 
61 |     96x96 --> 384x384
62 | 
63 |     Uses resize convolutions instead of sub-pixel convolutions.
64 |     """
65 |     w_init = tf.random_normal_initializer(stddev=0.02)
66 |     b_init = None  # tf.constant_initializer(value=0.0)
67 |     g_init = tf.random_normal_initializer(1., 0.02)
68 | 
69 |     size = t_image.get_shape().as_list()
70 | 
71 |     with tf.variable_scope("SRGAN_g", reuse=reuse) as vs:
72 |         tl.layers.set_name_reuse(reuse)
73 |         n = InputLayer(t_image, name='in')
74 |         n = Conv2d(n, 64, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', W_init=w_init, name='n64s1/c')
75 |         temp = n
76 | 
77 |         # B residual blocks
78 |         for i in range(16):
79 |             nn = Conv2d(n, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c1/%s' % i)
80 |             nn = BatchNormLayer(nn, act=tf.nn.relu, is_train=is_train, gamma_init=g_init, name='n64s1/b1/%s' % i)
81 |             nn = Conv2d(nn, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c2/%s' % i)
82 |             nn = BatchNormLayer(nn, is_train=is_train, gamma_init=g_init, name='n64s1/b2/%s' % i)
83 |             nn = ElementwiseLayer([n, nn], tf.add, 'b_residual_add/%s' % i)
84 |             n = nn
85 | 
86 |         n = Conv2d(n, 64, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, b_init=b_init, name='n64s1/c/m')
87 |         n = BatchNormLayer(n, is_train=is_train, gamma_init=g_init, name='n64s1/b/m')
88 |         n = ElementwiseLayer([n, temp], tf.add, 'add3')
89 |         # B residual blocks end
90 | 
91 |         # n = Conv2d(n, 256, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='n256s1/1')
92 |         # n = SubpixelConv2d(n, scale=2, n_out_channel=None, act=tf.nn.relu, name='pixelshufflerx2/1')
93 |         #
94 |         # n = Conv2d(n, 256, (3, 3), (1, 1), act=None, padding='SAME', W_init=w_init, name='n256s1/2')
95 |         # n = SubpixelConv2d(n, scale=2, n_out_channel=None, act=tf.nn.relu, name='pixelshufflerx2/2')
96 | 
97 |         ## method: 0=BILINEAR, 1=NEAREST, 2=BICUBIC, 3=AREA
98 |         n = UpSampling2dLayer(n, size=[size[1]*2, size[2]*2], is_scale=False, method=1, align_corners=False, name='up1/upsample2d')
99 |         n = Conv2d(n, 64, (3, 3), (1, 1),
100 |             padding='SAME', W_init=w_init, b_init=b_init, name='up1/conv2d')  # <-- may need to increase n_filter
101 |         n = BatchNormLayer(n, act=tf.nn.relu,
102 |             is_train=is_train, gamma_init=g_init, name='up1/batch_norm')
103 | 
104 |         n = UpSampling2dLayer(n, size=[size[1]*4, size[2]*4], is_scale=False, method=1, align_corners=False, name='up2/upsample2d')
105 |         n = Conv2d(n, 32, (3, 3), (1, 1),
106 |             padding='SAME', W_init=w_init, b_init=b_init, name='up2/conv2d')  # <-- may need to increase n_filter
107 |         n = BatchNormLayer(n, act=tf.nn.relu,
108 |             is_train=is_train, gamma_init=g_init, name='up2/batch_norm')
109 | 
110 |         n = Conv2d(n, 3, (1, 1), (1, 1), act=tf.nn.tanh, padding='SAME', W_init=w_init, name='out')
111 |         return n
112 | 
113 | 
114 | def SRGAN_d2(t_image, is_train=False, reuse=False):
115 |     """ Discriminator in Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network.
116 |     Layer names encode the number of feature maps (n) and the stride (s), e.g. n64s1.
117 |     """
118 |     w_init = tf.random_normal_initializer(stddev=0.02)
119 |     b_init = None
120 |     g_init = tf.random_normal_initializer(1., 0.02)
121 |     lrelu = lambda x: tl.act.lrelu(x, 0.2)
122 |     with tf.variable_scope("SRGAN_d", reuse=reuse) as vs:
123 |         tl.layers.set_name_reuse(reuse)
124 |         n = InputLayer(t_image, name='in')
125 |         n = Conv2d(n, 64, (3, 3), (1, 1), act=lrelu, padding='SAME', W_init=w_init, name='n64s1/c')
126 | 
127 |         n = Conv2d(n, 64, (3, 3), (2, 2), act=lrelu, padding='SAME', W_init=w_init, b_init=b_init, name='n64s2/c')
128 |         n = BatchNormLayer(n, is_train=is_train, gamma_init=g_init, name='n64s2/b')
129 | 
130 |         n = Conv2d(n, 128, (3, 3), (1, 1), act=lrelu, padding='SAME', W_init=w_init, b_init=b_init, name='n128s1/c')
131 |         n = BatchNormLayer(n, is_train=is_train,
gamma_init=g_init, name='n128s1/b') 132 | 133 | n = Conv2d(n, 128, (3, 3), (2, 2), act=lrelu, padding='SAME', W_init=w_init, b_init=b_init, name='n128s2/c') 134 | n = BatchNormLayer(n, is_train=is_train, gamma_init=g_init, name='n128s2/b') 135 | 136 | n = Conv2d(n, 256, (3, 3), (1, 1), act=lrelu, padding='SAME', W_init=w_init, b_init=b_init, name='n256s1/c') 137 | n = BatchNormLayer(n, is_train=is_train, gamma_init=g_init, name='n256s1/b') 138 | 139 | n = Conv2d(n, 256, (3, 3), (2, 2), act=lrelu, padding='SAME', W_init=w_init, b_init=b_init, name='n256s2/c') 140 | n = BatchNormLayer(n, is_train=is_train, gamma_init=g_init, name='n256s2/b') 141 | 142 | n = Conv2d(n, 512, (3, 3), (1, 1), act=lrelu, padding='SAME', W_init=w_init, b_init=b_init, name='n512s1/c') 143 | n = BatchNormLayer(n, is_train=is_train, gamma_init=g_init, name='n512s1/b') 144 | 145 | n = Conv2d(n, 512, (3, 3), (2, 2), act=lrelu, padding='SAME', W_init=w_init, b_init=b_init, name='n512s2/c') 146 | n = BatchNormLayer(n, is_train=is_train, gamma_init=g_init, name='n512s2/b') 147 | 148 | n = FlattenLayer(n, name='f') 149 | n = DenseLayer(n, n_units=1024, act=lrelu, name='d1024') 150 | n = DenseLayer(n, n_units=1, name='out') 151 | 152 | logits = n.outputs 153 | n.outputs = tf.nn.sigmoid(n.outputs) 154 | 155 | return n, logits 156 | 157 | def SRGAN_d(input_images, is_train=True, reuse=False): 158 | w_init = tf.random_normal_initializer(stddev=0.02) 159 | b_init = None # tf.constant_initializer(value=0.0) 160 | gamma_init=tf.random_normal_initializer(1., 0.02) 161 | df_dim = 64 162 | lrelu = lambda x: tl.act.lrelu(x, 0.2) 163 | with tf.variable_scope("SRGAN_d", reuse=reuse): 164 | tl.layers.set_name_reuse(reuse) 165 | net_in = InputLayer(input_images, name='input/images') 166 | net_h0 = Conv2d(net_in, df_dim, (4, 4), (2, 2), act=lrelu, 167 | padding='SAME', W_init=w_init, name='h0/c') 168 | 169 | net_h1 = Conv2d(net_h0, df_dim*2, (4, 4), (2, 2), act=None, 170 | padding='SAME', W_init=w_init, b_init=b_init, name='h1/c') 171 | net_h1 = BatchNormLayer(net_h1, act=lrelu, is_train=is_train, 172 | gamma_init=gamma_init, name='h1/bn') 173 | net_h2 = Conv2d(net_h1, df_dim*4, (4, 4), (2, 2), act=None, 174 | padding='SAME', W_init=w_init, b_init=b_init, name='h2/c') 175 | net_h2 = BatchNormLayer(net_h2, act=lrelu, is_train=is_train, 176 | gamma_init=gamma_init, name='h2/bn') 177 | net_h3 = Conv2d(net_h2, df_dim*8, (4, 4), (2, 2), act=None, 178 | padding='SAME', W_init=w_init, b_init=b_init, name='h3/c') 179 | net_h3 = BatchNormLayer(net_h3, act=lrelu, is_train=is_train, 180 | gamma_init=gamma_init, name='h3/bn') 181 | net_h4 = Conv2d(net_h3, df_dim*16, (4, 4), (2, 2), act=None, 182 | padding='SAME', W_init=w_init, b_init=b_init, name='h4/c') 183 | net_h4 = BatchNormLayer(net_h4, act=lrelu, is_train=is_train, 184 | gamma_init=gamma_init, name='h4/bn') 185 | net_h5 = Conv2d(net_h4, df_dim*32, (4, 4), (2, 2), act=None, 186 | padding='SAME', W_init=w_init, b_init=b_init, name='h5/c') 187 | net_h5 = BatchNormLayer(net_h5, act=lrelu, is_train=is_train, 188 | gamma_init=gamma_init, name='h5/bn') 189 | net_h6 = Conv2d(net_h5, df_dim*16, (1, 1), (1, 1), act=None, 190 | padding='SAME', W_init=w_init, b_init=b_init, name='h6/c') 191 | net_h6 = BatchNormLayer(net_h6, act=lrelu, is_train=is_train, 192 | gamma_init=gamma_init, name='h6/bn') 193 | net_h7 = Conv2d(net_h6, df_dim*8, (1, 1), (1, 1), act=None, 194 | padding='SAME', W_init=w_init, b_init=b_init, name='h7/c') 195 | net_h7 = BatchNormLayer(net_h7, is_train=is_train, 196 | 
gamma_init=gamma_init, name='h7/bn') 197 | 198 | net = Conv2d(net_h7, df_dim*2, (1, 1), (1, 1), act=None, 199 | padding='SAME', W_init=w_init, b_init=b_init, name='res/c') 200 | net = BatchNormLayer(net, act=lrelu, is_train=is_train, 201 | gamma_init=gamma_init, name='res/bn') 202 | net = Conv2d(net, df_dim*2, (3, 3), (1, 1), act=None, 203 | padding='SAME', W_init=w_init, b_init=b_init, name='res/c2') 204 | net = BatchNormLayer(net, act=lrelu, is_train=is_train, 205 | gamma_init=gamma_init, name='res/bn2') 206 | net = Conv2d(net, df_dim*8, (3, 3), (1, 1), act=None, 207 | padding='SAME', W_init=w_init, b_init=b_init, name='res/c3') 208 | net = BatchNormLayer(net, is_train=is_train, 209 | gamma_init=gamma_init, name='res/bn3') 210 | net_h8 = ElementwiseLayer(layer=[net_h7, net], 211 | combine_fn=tf.add, name='res/add') 212 | net_h8.outputs = tl.act.lrelu(net_h8.outputs, 0.2) 213 | 214 | net_ho = FlattenLayer(net_h8, name='ho/flatten') 215 | net_ho = DenseLayer(net_ho, n_units=1, act=tf.identity, 216 | W_init = w_init, name='ho/dense') 217 | logits = net_ho.outputs 218 | # Wasserstein GAN doesn't need the sigmoid output 219 | # net_ho.outputs = tf.nn.sigmoid(net_ho.outputs) 220 | 221 | return net_ho, logits 222 | 223 | def Vgg19_simple_api(rgb, reuse): 224 | """ 225 | Build the VGG 19 Model 226 | 227 | Parameters 228 | ----------- 229 | rgb : rgb image placeholder [batch, height, width, 3] values scaled [0, 1] 230 | """ 231 | VGG_MEAN = [103.939, 116.779, 123.68] 232 | with tf.variable_scope("VGG19", reuse=reuse) as vs: 233 | start_time = time.time() 234 | print("build model started") 235 | rgb_scaled = rgb * 255.0 236 | # Convert RGB to BGR 237 | if tf.__version__ <= '0.11': 238 | red, green, blue = tf.split(3, 3, rgb_scaled) 239 | else: # TF 1.0 240 | # print(rgb_scaled) 241 | red, green, blue = tf.split(rgb_scaled, 3, 3) 242 | assert red.get_shape().as_list()[1:] == [224, 224, 1] 243 | assert green.get_shape().as_list()[1:] == [224, 224, 1] 244 | assert blue.get_shape().as_list()[1:] == [224, 224, 1] 245 | if tf.__version__ <= '0.11': 246 | bgr = tf.concat(3, [ 247 | blue - VGG_MEAN[0], 248 | green - VGG_MEAN[1], 249 | red - VGG_MEAN[2], 250 | ]) 251 | else: 252 | bgr = tf.concat([ 253 | blue - VGG_MEAN[0], 254 | green - VGG_MEAN[1], 255 | red - VGG_MEAN[2], 256 | ], axis=3) 257 | assert bgr.get_shape().as_list()[1:] == [224, 224, 3] 258 | 259 | """ input layer """ 260 | net_in = InputLayer(bgr, name='input') 261 | """ conv1 """ 262 | network = Conv2d(net_in, n_filter=64, filter_size=(3, 3), 263 | strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv1_1') 264 | network = Conv2d(network, n_filter=64, filter_size=(3, 3), 265 | strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv1_2') 266 | network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2), 267 | padding='SAME', name='pool1') 268 | """ conv2 """ 269 | network = Conv2d(network, n_filter=128, filter_size=(3, 3), 270 | strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv2_1') 271 | network = Conv2d(network, n_filter=128, filter_size=(3, 3), 272 | strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv2_2') 273 | network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2), 274 | padding='SAME', name='pool2') 275 | """ conv3 """ 276 | network = Conv2d(network, n_filter=256, filter_size=(3, 3), 277 | strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv3_1') 278 | network = Conv2d(network, n_filter=256, filter_size=(3, 3), 279 | strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv3_2') 280 | network = 
Conv2d(network, n_filter=256, filter_size=(3, 3), 281 | strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv3_3') 282 | network = Conv2d(network, n_filter=256, filter_size=(3, 3), 283 | strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv3_4') 284 | network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2), 285 | padding='SAME', name='pool3') 286 | """ conv4 """ 287 | network = Conv2d(network, n_filter=512, filter_size=(3, 3), 288 | strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv4_1') 289 | network = Conv2d(network, n_filter=512, filter_size=(3, 3), 290 | strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv4_2') 291 | network = Conv2d(network, n_filter=512, filter_size=(3, 3), 292 | strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv4_3') 293 | network = Conv2d(network, n_filter=512, filter_size=(3, 3), 294 | strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv4_4') 295 | network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2), 296 | padding='SAME', name='pool4') # (batch_size, 14, 14, 512) 297 | conv = network 298 | """ conv5 """ 299 | network = Conv2d(network, n_filter=512, filter_size=(3, 3), 300 | strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv5_1') 301 | network = Conv2d(network, n_filter=512, filter_size=(3, 3), 302 | strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv5_2') 303 | network = Conv2d(network, n_filter=512, filter_size=(3, 3), 304 | strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv5_3') 305 | network = Conv2d(network, n_filter=512, filter_size=(3, 3), 306 | strides=(1, 1), act=tf.nn.relu,padding='SAME', name='conv5_4') 307 | network = MaxPool2d(network, filter_size=(2, 2), strides=(2, 2), 308 | padding='SAME', name='pool5') # (batch_size, 7, 7, 512) 309 | """ fc 6~8 """ 310 | network = FlattenLayer(network, name='flatten') 311 | network = DenseLayer(network, n_units=4096, act=tf.nn.relu, name='fc6') 312 | network = DenseLayer(network, n_units=4096, act=tf.nn.relu, name='fc7') 313 | network = DenseLayer(network, n_units=1000, act=tf.identity, name='fc8') 314 | print("build model finished: %fs" % (time.time() - start_time)) 315 | return network, conv 316 | 317 | # def vgg16_cnn_emb(t_image, reuse=False): 318 | # """ t_image = 244x244 [0~255] """ 319 | # with tf.variable_scope("vgg16_cnn", reuse=reuse) as vs: 320 | # tl.layers.set_name_reuse(reuse) 321 | # 322 | # mean = tf.constant([123.68, 116.779, 103.939], dtype=tf.float32, shape=[1, 1, 1, 3], name='img_mean') 323 | # net_in = InputLayer(t_image - mean, name='vgg_input_im') 324 | # """ conv1 """ 325 | # network = tl.layers.Conv2dLayer(net_in, 326 | # act = tf.nn.relu, 327 | # shape = [3, 3, 3, 64], # 64 features for each 3x3 patch 328 | # strides = [1, 1, 1, 1], 329 | # padding='SAME', 330 | # name ='vgg_conv1_1') 331 | # network = tl.layers.Conv2dLayer(network, 332 | # act = tf.nn.relu, 333 | # shape = [3, 3, 64, 64], # 64 features for each 3x3 patch 334 | # strides = [1, 1, 1, 1], 335 | # padding='SAME', 336 | # name ='vgg_conv1_2') 337 | # network = tl.layers.PoolLayer(network, 338 | # ksize=[1, 2, 2, 1], 339 | # strides=[1, 2, 2, 1], 340 | # padding='SAME', 341 | # pool = tf.nn.max_pool, 342 | # name ='vgg_pool1') 343 | # """ conv2 """ 344 | # network = tl.layers.Conv2dLayer(network, 345 | # act = tf.nn.relu, 346 | # shape = [3, 3, 64, 128], # 128 features for each 3x3 patch 347 | # strides = [1, 1, 1, 1], 348 | # padding='SAME', 349 | # name ='vgg_conv2_1') 350 | # network = tl.layers.Conv2dLayer(network, 351 | # act = tf.nn.relu, 352 | # 
shape = [3, 3, 128, 128], # 128 features for each 3x3 patch 353 | # strides = [1, 1, 1, 1], 354 | # padding='SAME', 355 | # name ='vgg_conv2_2') 356 | # network = tl.layers.PoolLayer(network, 357 | # ksize=[1, 2, 2, 1], 358 | # strides=[1, 2, 2, 1], 359 | # padding='SAME', 360 | # pool = tf.nn.max_pool, 361 | # name ='vgg_pool2') 362 | # """ conv3 """ 363 | # network = tl.layers.Conv2dLayer(network, 364 | # act = tf.nn.relu, 365 | # shape = [3, 3, 128, 256], # 256 features for each 3x3 patch 366 | # strides = [1, 1, 1, 1], 367 | # padding='SAME', 368 | # name ='vgg_conv3_1') 369 | # network = tl.layers.Conv2dLayer(network, 370 | # act = tf.nn.relu, 371 | # shape = [3, 3, 256, 256], # 256 features for each 3x3 patch 372 | # strides = [1, 1, 1, 1], 373 | # padding='SAME', 374 | # name ='vgg_conv3_2') 375 | # network = tl.layers.Conv2dLayer(network, 376 | # act = tf.nn.relu, 377 | # shape = [3, 3, 256, 256], # 256 features for each 3x3 patch 378 | # strides = [1, 1, 1, 1], 379 | # padding='SAME', 380 | # name ='vgg_conv3_3') 381 | # network = tl.layers.PoolLayer(network, 382 | # ksize=[1, 2, 2, 1], 383 | # strides=[1, 2, 2, 1], 384 | # padding='SAME', 385 | # pool = tf.nn.max_pool, 386 | # name ='vgg_pool3') 387 | # """ conv4 """ 388 | # network = tl.layers.Conv2dLayer(network, 389 | # act = tf.nn.relu, 390 | # shape = [3, 3, 256, 512], # 512 features for each 3x3 patch 391 | # strides = [1, 1, 1, 1], 392 | # padding='SAME', 393 | # name ='vgg_conv4_1') 394 | # network = tl.layers.Conv2dLayer(network, 395 | # act = tf.nn.relu, 396 | # shape = [3, 3, 512, 512], # 512 features for each 3x3 patch 397 | # strides = [1, 1, 1, 1], 398 | # padding='SAME', 399 | # name ='vgg_conv4_2') 400 | # network = tl.layers.Conv2dLayer(network, 401 | # act = tf.nn.relu, 402 | # shape = [3, 3, 512, 512], # 512 features for each 3x3 patch 403 | # strides = [1, 1, 1, 1], 404 | # padding='SAME', 405 | # name ='vgg_conv4_3') 406 | # 407 | # network = tl.layers.PoolLayer(network, 408 | # ksize=[1, 2, 2, 1], 409 | # strides=[1, 2, 2, 1], 410 | # padding='SAME', 411 | # pool = tf.nn.max_pool, 412 | # name ='vgg_pool4') 413 | # conv4 = network 414 | # 415 | # """ conv5 """ 416 | # network = tl.layers.Conv2dLayer(network, 417 | # act = tf.nn.relu, 418 | # shape = [3, 3, 512, 512], # 512 features for each 3x3 patch 419 | # strides = [1, 1, 1, 1], 420 | # padding='SAME', 421 | # name ='vgg_conv5_1') 422 | # network = tl.layers.Conv2dLayer(network, 423 | # act = tf.nn.relu, 424 | # shape = [3, 3, 512, 512], # 512 features for each 3x3 patch 425 | # strides = [1, 1, 1, 1], 426 | # padding='SAME', 427 | # name ='vgg_conv5_2') 428 | # network = tl.layers.Conv2dLayer(network, 429 | # act = tf.nn.relu, 430 | # shape = [3, 3, 512, 512], # 512 features for each 3x3 patch 431 | # strides = [1, 1, 1, 1], 432 | # padding='SAME', 433 | # name ='vgg_conv5_3') 434 | # network = tl.layers.PoolLayer(network, 435 | # ksize=[1, 2, 2, 1], 436 | # strides=[1, 2, 2, 1], 437 | # padding='SAME', 438 | # pool = tf.nn.max_pool, 439 | # name ='vgg_pool5') 440 | # 441 | # network = FlattenLayer(network, name='vgg_flatten') 442 | # 443 | # # # network = DropoutLayer(network, keep=0.6, is_fix=True, is_train=is_train, name='vgg_out/drop1') 444 | # # new_network = tl.layers.DenseLayer(network, n_units=4096, 445 | # # act = tf.nn.relu, 446 | # # name = 'vgg_out/dense') 447 | # # 448 | # # # new_network = DropoutLayer(new_network, keep=0.8, is_fix=True, is_train=is_train, name='vgg_out/drop2') 449 | # # new_network = DenseLayer(new_network, z_dim, 
#num_lstm_units, 450 | # # b_init=None, name='vgg_out/out') 451 | # return conv4, network 452 | -------------------------------------------------------------------------------- /tensorlayer/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Deep learning and Reinforcement learning library for Researchers and Engineers 3 | """ 4 | from __future__ import absolute_import 5 | 6 | 7 | try: 8 | install_instr = "Please make sure you install a recent enough version of TensorFlow." 9 | import tensorflow 10 | except ImportError: 11 | raise ImportError("__init__.py : Could not import TensorFlow." + install_instr) 12 | 13 | from . import activation 14 | from . import cost 15 | from . import files 16 | from . import iterate 17 | from . import layers 18 | from . import ops 19 | from . import utils 20 | from . import visualize 21 | from . import prepro 22 | from . import nlp 23 | from . import rein 24 | 25 | # alias 26 | act = activation 27 | vis = visualize 28 | 29 | __version__ = "1.5.0" 30 | 31 | global_flag = {} 32 | global_dict = {} 33 | -------------------------------------------------------------------------------- /tensorlayer/activation.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- coding: utf8 -*- 3 | 4 | 5 | 6 | import tensorflow as tf 7 | 8 | def identity(x, name=None): 9 | """The identity activation function, Shortcut is ``linear``. 10 | 11 | Parameters 12 | ---------- 13 | x : a tensor input 14 | input(s) 15 | 16 | 17 | Returns 18 | -------- 19 | A `Tensor` with the same type as `x`. 20 | """ 21 | return x 22 | 23 | # Shortcut 24 | linear = identity 25 | 26 | def ramp(x=None, v_min=0, v_max=1, name=None): 27 | """The ramp activation function. 28 | 29 | Parameters 30 | ---------- 31 | x : a tensor input 32 | input(s) 33 | v_min : float 34 | if input(s) smaller than v_min, change inputs to v_min 35 | v_max : float 36 | if input(s) greater than v_max, change inputs to v_max 37 | name : a string or None 38 | An optional name to attach to this activation function. 39 | 40 | 41 | Returns 42 | -------- 43 | A `Tensor` with the same type as `x`. 44 | """ 45 | return tf.clip_by_value(x, clip_value_min=v_min, clip_value_max=v_max, name=name) 46 | 47 | def leaky_relu(x=None, alpha=0.1, name="LeakyReLU"): 48 | """The LeakyReLU, Shortcut is ``lrelu``. 49 | 50 | Modified version of ReLU, introducing a nonzero gradient for negative 51 | input. 52 | 53 | Parameters 54 | ---------- 55 | x : A `Tensor` with type `float`, `double`, `int32`, `int64`, `uint8`, 56 | `int16`, or `int8`. 57 | alpha : `float`. slope. 58 | name : a string or None 59 | An optional name to attach to this activation function. 60 | 61 | Examples 62 | --------- 63 | >>> network = tl.layers.DenseLayer(network, n_units=100, name = 'dense_lrelu', 64 | ... act= lambda x : tl.act.lrelu(x, 0.2)) 65 | 66 | References 67 | ------------ 68 | - `Rectifier Nonlinearities Improve Neural Network Acoustic Models, Maas et al. (2013) `_ 69 | """ 70 | with tf.name_scope(name) as scope: 71 | # x = tf.nn.relu(x) 72 | # m_x = tf.nn.relu(-x) 73 | # x -= alpha * m_x 74 | x = tf.maximum(x, alpha * x) 75 | return x 76 | 77 | #Shortcut 78 | lrelu = leaky_relu 79 | 80 | def pixel_wise_softmax(output, name='pixel_wise_softmax'): 81 | """Return the softmax outputs of images, every pixels have multiple label, the sum of a pixel is 1. 82 | Usually be used for image segmentation. 
83 | 84 | Parameters 85 | ------------ 86 | output : tensor 87 | - For 2d image, 4D tensor [batch_size, height, weight, channel], channel >= 2. 88 | - For 3d image, 5D tensor [batch_size, depth, height, weight, channel], channel >= 2. 89 | 90 | Examples 91 | --------- 92 | >>> outputs = pixel_wise_softmax(network.outputs) 93 | >>> dice_loss = 1 - dice_coe(outputs, y_, epsilon=1e-5) 94 | 95 | References 96 | ----------- 97 | - `tf.reverse `_ 98 | """ 99 | with tf.name_scope(name) as scope: 100 | return tf.nn.softmax(output) 101 | ## old implementation 102 | # exp_map = tf.exp(output) 103 | # if output.get_shape().ndims == 4: # 2d image 104 | # evidence = tf.add(exp_map, tf.reverse(exp_map, [False, False, False, True])) 105 | # elif output.get_shape().ndims == 5: # 3d image 106 | # evidence = tf.add(exp_map, tf.reverse(exp_map, [False, False, False, False, True])) 107 | # else: 108 | # raise Exception("output parameters should be 2d or 3d image, not %s" % str(output._shape)) 109 | # return tf.div(exp_map, evidence) 110 | -------------------------------------------------------------------------------- /tensorlayer/cost.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- coding: utf8 -*- 3 | 4 | import logging 5 | 6 | import tensorflow as tf 7 | import numbers 8 | from tensorflow.python.framework import ops 9 | from tensorflow.python.ops import standard_ops 10 | 11 | ## Cost Functions 12 | 13 | def cross_entropy(output, target, name=None): 14 | """It is a softmax cross-entropy operation, returns the TensorFlow expression of cross-entropy of two distributions, implement 15 | softmax internally. See ``tf.nn.sparse_softmax_cross_entropy_with_logits``. 16 | 17 | Parameters 18 | ---------- 19 | output : Tensorflow variable 20 | A distribution with shape: [batch_size, n_feature]. 21 | target : Tensorflow variable 22 | A batch of index with shape: [batch_size, ]. 23 | name : string 24 | Name of this loss. 25 | 26 | Examples 27 | -------- 28 | >>> ce = tl.cost.cross_entropy(y_logits, y_target_logits, 'my_loss') 29 | 30 | References 31 | ----------- 32 | - About cross-entropy: `wiki `_.\n 33 | - The code is borrowed from: `here `_. 34 | """ 35 | try: # old 36 | return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=output, targets=target)) 37 | except: # TF 1.0 38 | assert name is not None, "Please give a unique name to tl.cost.cross_entropy for TF1.0+" 39 | return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target, logits=output, name=name)) 40 | 41 | def sigmoid_cross_entropy(output, target, name=None): 42 | """It is a sigmoid cross-entropy operation, see ``tf.nn.sigmoid_cross_entropy_with_logits``. 43 | """ 44 | try: # TF 1.0 45 | return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output, name=name)) 46 | except: 47 | return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=output, targets=target)) 48 | 49 | 50 | def binary_cross_entropy(output, target, epsilon=1e-8, name='bce_loss'): 51 | """Computes binary cross entropy given `output`. 52 | 53 | For brevity, let `x = output`, `z = target`. The binary cross entropy loss is 54 | 55 | loss(x, z) = - sum_i (x[i] * log(z[i]) + (1 - x[i]) * log(1 - z[i])) 56 | 57 | Parameters 58 | ---------- 59 | output : tensor of type `float32` or `float64`. 60 | target : tensor of the same type and shape as `output`. 61 | epsilon : float 62 | A small value to avoid output is zero. 
63 | name : string 64 | An optional name to attach to this layer. 65 | 66 | References 67 | ----------- 68 | - `DRAW `_ 69 | """ 70 | # from tensorflow.python.framework import ops 71 | # with ops.op_scope([output, target], name, "bce_loss") as name: 72 | # output = ops.convert_to_tensor(output, name="preds") 73 | # target = ops.convert_to_tensor(targets, name="target") 74 | with tf.name_scope(name): 75 | return tf.reduce_mean(tf.reduce_sum(-(target * tf.log(output + epsilon) + 76 | (1. - target) * tf.log(1. - output + epsilon)), axis=1)) 77 | 78 | 79 | def mean_squared_error(output, target, is_mean=False): 80 | """Return the TensorFlow expression of mean-squre-error of two distributions. 81 | 82 | Parameters 83 | ---------- 84 | output : 2D or 4D tensor. 85 | target : 2D or 4D tensor. 86 | is_mean : boolean, if True, use ``tf.reduce_mean`` to compute the loss of one data, otherwise, use ``tf.reduce_sum`` (default). 87 | 88 | References 89 | ------------ 90 | - `Wiki Mean Squared Error `_ 91 | """ 92 | with tf.name_scope("mean_squared_error_loss"): 93 | if output.get_shape().ndims == 2: # [batch_size, n_feature] 94 | if is_mean: 95 | mse = tf.reduce_mean(tf.reduce_mean(tf.squared_difference(output, target), 1)) 96 | else: 97 | mse = tf.reduce_mean(tf.reduce_sum(tf.squared_difference(output, target), 1)) 98 | elif output.get_shape().ndims == 4: # [batch_size, w, h, c] 99 | if is_mean: 100 | mse = tf.reduce_mean(tf.reduce_mean(tf.squared_difference(output, target), [1, 2, 3])) 101 | else: 102 | mse = tf.reduce_mean(tf.reduce_sum(tf.squared_difference(output, target), [1, 2, 3])) 103 | return mse 104 | 105 | def normalized_mean_square_error(output, target): 106 | """Return the TensorFlow expression of normalized mean-squre-error of two distributions. 107 | 108 | Parameters 109 | ---------- 110 | output : 2D or 4D tensor. 111 | target : 2D or 4D tensor. 112 | """ 113 | with tf.name_scope("mean_squared_error_loss"): 114 | if output.get_shape().ndims == 2: # [batch_size, n_feature] 115 | nmse_a = tf.sqrt(tf.reduce_sum(tf.squared_difference(output, target), axis=1)) 116 | nmse_b = tf.sqrt(tf.reduce_sum(tf.square(target), axis=1)) 117 | elif output.get_shape().ndims == 4: # [batch_size, w, h, c] 118 | nmse_a = tf.sqrt(tf.reduce_sum(tf.squared_difference(output, target), axis=[1,2,3])) 119 | nmse_b = tf.sqrt(tf.reduce_sum(tf.square(target), axis=[1,2,3])) 120 | nmse = tf.reduce_mean(nmse_a / nmse_b) 121 | return nmse 122 | 123 | 124 | def dice_coe(output, target, epsilon=1e-10): 125 | """Sørensen–Dice coefficient for comparing the similarity of two distributions, 126 | usually be used for binary image segmentation i.e. labels are binary. 127 | The coefficient = [0, 1], 1 if totally match. 128 | 129 | Parameters 130 | ----------- 131 | output : tensor 132 | A distribution with shape: [batch_size, ....], (any dimensions). 133 | target : tensor 134 | A distribution with shape: [batch_size, ....], (any dimensions). 135 | epsilon : float 136 | An optional name to attach to this layer. 
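`mean_squared_error` above is the pixel loss this repo uses for the SRGAN generator; a minimal hedged sketch of its `is_mean` switch on a 4D image batch (shapes and values are illustrative only):

```python
import numpy as np
import tensorflow as tf
import tensorlayer as tl

out = tf.placeholder(tf.float32, [None, 384, 384, 3])  # generated images
ref = tf.placeholder(tf.float32, [None, 384, 384, 3])  # ground-truth HR images

# is_mean=True averages the squared error over the pixel axes of each image
# before averaging over the batch; is_mean=False sums over the pixel axes.
mse_mean = tl.cost.mean_squared_error(out, ref, is_mean=True)
mse_sum = tl.cost.mean_squared_error(out, ref, is_mean=False)

with tf.Session() as sess:
    a = np.zeros([2, 384, 384, 3], np.float32)
    b = np.ones([2, 384, 384, 3], np.float32)
    m, s = sess.run([mse_mean, mse_sum], {out: a, ref: b})
    # m == 1.0 and s == 384 * 384 * 3 == 442368.0 for this all-ones difference
```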
137 | 138 | Examples 139 | --------- 140 | >>> outputs = tl.act.pixel_wise_softmax(network.outputs) 141 | >>> dice_loss = 1 - tl.cost.dice_coe(outputs, y_, epsilon=1e-5) 142 | 143 | References 144 | ----------- 145 | - `wiki-dice `_ 146 | """ 147 | # inse = tf.reduce_sum( tf.mul(output, target) ) 148 | # l = tf.reduce_sum( tf.mul(output, output) ) 149 | # r = tf.reduce_sum( tf.mul(target, target) ) 150 | inse = tf.reduce_sum( output * target ) 151 | l = tf.reduce_sum( output * output ) 152 | r = tf.reduce_sum( target * target ) 153 | dice = 2 * (inse) / (l + r) 154 | if epsilon == 0: 155 | return dice 156 | else: 157 | return tf.clip_by_value(dice, 0, 1.0-epsilon) 158 | 159 | 160 | def dice_hard_coe(output, target, epsilon=1e-10): 161 | """Non-differentiable Sørensen–Dice coefficient for comparing the similarity of two distributions, 162 | usually be used for binary image segmentation i.e. labels are binary. 163 | The coefficient = [0, 1], 1 if totally match. 164 | 165 | Parameters 166 | ----------- 167 | output : tensor 168 | A distribution with shape: [batch_size, ....], (any dimensions). 169 | target : tensor 170 | A distribution with shape: [batch_size, ....], (any dimensions). 171 | epsilon : float 172 | An optional name to attach to this layer. 173 | 174 | Examples 175 | --------- 176 | >>> outputs = pixel_wise_softmax(network.outputs) 177 | >>> dice_loss = 1 - dice_coe(outputs, y_, epsilon=1e-5) 178 | 179 | References 180 | ----------- 181 | - `wiki-dice `_ 182 | """ 183 | output = tf.cast(output > 0.5, dtype=tf.float32) 184 | target = tf.cast(target > 0.5, dtype=tf.float32) 185 | inse = tf.reduce_sum( output * target ) 186 | l = tf.reduce_sum( output * output ) 187 | r = tf.reduce_sum( target * target ) 188 | dice = 2 * (inse) / (l + r) 189 | if epsilon == 0: 190 | return dice 191 | else: 192 | return tf.clip_by_value(dice, 0, 1.0-epsilon) 193 | 194 | def iou_coe(output, target, threshold=0.5, epsilon=1e-10): 195 | """Non-differentiable Intersection over Union, usually be used for evaluating binary image segmentation. 196 | The coefficient = [0, 1], 1 means totally match. 197 | 198 | Parameters 199 | ----------- 200 | output : tensor 201 | A distribution with shape: [batch_size, ....], (any dimensions). 202 | target : tensor 203 | A distribution with shape: [batch_size, ....], (any dimensions). 204 | threshold : float 205 | The threshold value to be true. 206 | epsilon : float 207 | A small value to avoid zero denominator when both output and target output nothing. 208 | 209 | Examples 210 | --------- 211 | >>> outputs = tl.act.pixel_wise_softmax(network.outputs) 212 | >>> iou = tl.cost.iou_coe(outputs[:,:,:,0], y_[:,:,:,0]) 213 | 214 | Notes 215 | ------ 216 | - IOU cannot be used as training loss, people usually use dice coefficient for training, and IOU for evaluating. 217 | """ 218 | pre = tf.cast(output > threshold, dtype=tf.float32) 219 | truth = tf.cast(target > threshold, dtype=tf.float32) 220 | intersection = tf.reduce_sum(pre * truth) 221 | union = tf.reduce_sum(tf.cast((pre + truth) > threshold, dtype=tf.float32)) 222 | return tf.reduce_sum(intersection) / (tf.reduce_sum(union) + epsilon) 223 | 224 | 225 | def cross_entropy_seq(logits, target_seqs, batch_size=None):#, batch_size=1, num_steps=None): 226 | """Returns the expression of cross-entropy of two sequences, implement 227 | softmax internally. Normally be used for Fixed Length RNN outputs. 
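To make the segmentation metrics above concrete, here is a small hedged check of `dice_coe` and `iou_coe` on hand-made binary masks (the masks are invented for illustration):

```python
import tensorflow as tf
import tensorlayer as tl

pred = tf.constant([[1., 1., 0., 0.]])  # predicted mask, already thresholded
true = tf.constant([[1., 0., 1., 0.]])  # ground-truth mask

dice = tl.cost.dice_coe(pred, true)               # 2|A∩B| / (|A|+|B|) = 2/(2+2) = 0.5
iou = tl.cost.iou_coe(pred, true, threshold=0.5)  # |A∩B| / |A∪B| = 1/3

with tf.Session() as sess:
    print(sess.run([dice, iou]))   # ~[0.5, 0.333]
```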
228 | 229 | Parameters 230 | ---------- 231 | logits : Tensorflow variable 232 | 2D tensor, ``network.outputs``, [batch_size*n_steps (n_examples), number of output units] 233 | target_seqs : Tensorflow variable 234 | target : 2D tensor [batch_size, n_steps], if the number of step is dynamic, please use ``cross_entropy_seq_with_mask`` instead. 235 | batch_size : None or int. 236 | If not None, the return cost will be divided by batch_size. 237 | 238 | Examples 239 | -------- 240 | >>> see PTB tutorial for more details 241 | >>> input_data = tf.placeholder(tf.int32, [batch_size, num_steps]) 242 | >>> targets = tf.placeholder(tf.int32, [batch_size, num_steps]) 243 | >>> cost = tl.cost.cross_entropy_seq(network.outputs, targets) 244 | """ 245 | try: # TF 1.0 246 | sequence_loss_by_example_fn = tf.contrib.legacy_seq2seq.sequence_loss_by_example 247 | except: 248 | sequence_loss_by_example_fn = tf.nn.seq2seq.sequence_loss_by_example 249 | 250 | loss = sequence_loss_by_example_fn( 251 | [logits], 252 | [tf.reshape(target_seqs, [-1])], 253 | [tf.ones_like(tf.reshape(target_seqs, [-1]), dtype=tf.float32)]) 254 | # [tf.ones([batch_size * num_steps])]) 255 | cost = tf.reduce_sum(loss) #/ batch_size 256 | if batch_size is not None: 257 | cost = cost / batch_size 258 | return cost 259 | 260 | 261 | def cross_entropy_seq_with_mask(logits, target_seqs, input_mask, return_details=False, name=None): 262 | """Returns the expression of cross-entropy of two sequences, implement 263 | softmax internally. Normally be used for Dynamic RNN outputs. 264 | 265 | Parameters 266 | ----------- 267 | logits : network identity outputs 268 | 2D tensor, ``network.outputs``, [batch_size, number of output units]. 269 | target_seqs : int of tensor, like word ID. 270 | [batch_size, ?] 271 | input_mask : the mask to compute loss 272 | The same size with target_seqs, normally 0 and 1. 273 | return_details : boolean 274 | - If False (default), only returns the loss. 275 | - If True, returns the loss, losses, weights and targets (reshape to one vetcor). 276 | 277 | Examples 278 | -------- 279 | - see Image Captioning Example. 280 | """ 281 | targets = tf.reshape(target_seqs, [-1]) # to one vector 282 | weights = tf.to_float(tf.reshape(input_mask, [-1])) # to one vector like targets 283 | losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=targets, name=name) * weights 284 | #losses = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=targets, name=name)) # for TF1.0 and others 285 | 286 | try: ## TF1.0 287 | loss = tf.divide(tf.reduce_sum(losses), # loss from mask. reduce_sum before element-wise mul with mask !! 288 | tf.reduce_sum(weights), 289 | name="seq_loss_with_mask") 290 | except: ## TF0.12 291 | loss = tf.div(tf.reduce_sum(losses), # loss from mask. reduce_sum before element-wise mul with mask !! 292 | tf.reduce_sum(weights), 293 | name="seq_loss_with_mask") 294 | if return_details: 295 | return loss, losses, weights, targets 296 | else: 297 | return loss 298 | 299 | 300 | def cosine_similarity(v1, v2): 301 | """Cosine similarity [-1, 1], `wiki `_. 302 | 303 | Parameters 304 | ----------- 305 | v1, v2 : tensor of [batch_size, n_feature], with the same number of features. 
306 | 307 | Returns 308 | ----------- 309 | a tensor of [batch_size, ] 310 | """ 311 | try: ## TF1.0 312 | cost = tf.reduce_sum(tf.multiply(v1, v2), 1) / (tf.sqrt(tf.reduce_sum(tf.multiply(v1, v1), 1)) * tf.sqrt(tf.reduce_sum(tf.multiply(v2, v2), 1))) 313 | except: ## TF0.12 314 | cost = tf.reduce_sum(tf.mul(v1, v2), reduction_indices=1) / (tf.sqrt(tf.reduce_sum(tf.mul(v1, v1), reduction_indices=1)) * tf.sqrt(tf.reduce_sum(tf.mul(v2, v2), reduction_indices=1))) 315 | return cost 316 | 317 | 318 | ## Regularization Functions 319 | def li_regularizer(scale, scope=None): 320 | """li regularization removes the neurons of previous layer, `i` represents `inputs`.\n 321 | Returns a function that can be used to apply group li regularization to weights.\n 322 | The implementation follows `TensorFlow contrib `_. 323 | 324 | Parameters 325 | ---------- 326 | scale : float 327 | A scalar multiplier `Tensor`. 0.0 disables the regularizer. 328 | scope: An optional scope name for TF12+. 329 | 330 | Returns 331 | -------- 332 | A function with signature `li(weights, name=None)` that apply Li regularization. 333 | 334 | Raises 335 | ------ 336 | ValueError : if scale is outside of the range [0.0, 1.0] or if scale is not a float. 337 | """ 338 | import numbers 339 | from tensorflow.python.framework import ops 340 | from tensorflow.python.ops import standard_ops 341 | # from tensorflow.python.platform import tf_logging as logging 342 | 343 | if isinstance(scale, numbers.Integral): 344 | raise ValueError('scale cannot be an integer: %s' % scale) 345 | if isinstance(scale, numbers.Real): 346 | if scale < 0.: 347 | raise ValueError('Setting a scale less than 0 on a regularizer: %g' % 348 | scale) 349 | if scale >= 1.: 350 | raise ValueError('Setting a scale greater than 1 on a regularizer: %g' % 351 | scale) 352 | if scale == 0.: 353 | logging.info('Scale of 0 disables regularizer.') 354 | return lambda _, name=None: None 355 | 356 | def li(weights, name=None): 357 | """Applies li regularization to weights.""" 358 | with tf.name_scope('li_regularizer') as scope: 359 | my_scale = ops.convert_to_tensor(scale, 360 | dtype=weights.dtype.base_dtype, 361 | name='scale') 362 | if tf.__version__ <= '0.12': 363 | standard_ops_fn = standard_ops.mul 364 | else: 365 | standard_ops_fn = standard_ops.multiply 366 | return standard_ops_fn( 367 | my_scale, 368 | standard_ops.reduce_sum(standard_ops.sqrt(standard_ops.reduce_sum(tf.square(weights), 1))), 369 | name=scope) 370 | return li 371 | 372 | 373 | 374 | def lo_regularizer(scale, scope=None): 375 | """lo regularization removes the neurons of current layer, `o` represents `outputs`\n 376 | Returns a function that can be used to apply group lo regularization to weights.\n 377 | The implementation follows `TensorFlow contrib `_. 378 | 379 | Parameters 380 | ---------- 381 | scale : float 382 | A scalar multiplier `Tensor`. 0.0 disables the regularizer. 383 | scope: An optional scope name for TF12+. 384 | 385 | Returns 386 | ------- 387 | A function with signature `lo(weights, name=None)` that apply Lo regularization. 388 | 389 | Raises 390 | ------ 391 | ValueError : If scale is outside of the range [0.0, 1.0] or if scale is not a float. 
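The regularizers in this file all follow the same factory pattern: call one with a `scale` to get a function back, then apply that function to a weight matrix and add the resulting scalar to your cost. A minimal hedged sketch (the weight shape and scale are illustrative):

```python
import tensorflow as tf
import tensorlayer as tl

W = tf.Variable(tf.ones([100, 50]), name='W')  # a stand-in weight matrix

# Each factory returns a callable; calling it on W yields a scalar penalty tensor.
li_penalty = tl.cost.li_regularizer(0.001)(W)  # L2 norm per row (input unit), summed
lo_penalty = tl.cost.lo_regularizer(0.001)(W)  # L2 norm per column (output unit), summed

# Typically added on top of the data term:
# cost = tl.cost.cross_entropy(y_logits, y_target, 'ce') + li_penalty + lo_penalty
```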
392 | """ 393 | import numbers 394 | from tensorflow.python.framework import ops 395 | from tensorflow.python.ops import standard_ops 396 | # from tensorflow.python.platform import tf_logging as logging 397 | 398 | if isinstance(scale, numbers.Integral): 399 | raise ValueError('scale cannot be an integer: %s' % scale) 400 | if isinstance(scale, numbers.Real): 401 | if scale < 0.: 402 | raise ValueError('Setting a scale less than 0 on a regularizer: %g' % 403 | scale) 404 | if scale >= 1.: 405 | raise ValueError('Setting a scale greater than 1 on a regularizer: %g' % 406 | scale) 407 | if scale == 0.: 408 | logging.info('Scale of 0 disables regularizer.') 409 | return lambda _, name=None: None 410 | 411 | def lo(weights, name='lo_regularizer'): 412 | """Applies group column regularization to weights.""" 413 | with tf.name_scope(name) as scope: 414 | my_scale = ops.convert_to_tensor(scale, 415 | dtype=weights.dtype.base_dtype, 416 | name='scale') 417 | if tf.__version__ <= '0.12': 418 | standard_ops_fn = standard_ops.mul 419 | else: 420 | standard_ops_fn = standard_ops.multiply 421 | return standard_ops_fn( 422 | my_scale, 423 | standard_ops.reduce_sum(standard_ops.sqrt(standard_ops.reduce_sum(tf.square(weights), 0))), 424 | name=scope) 425 | return lo 426 | 427 | def maxnorm_regularizer(scale=1.0, scope=None): 428 | """Max-norm regularization returns a function that can be used 429 | to apply max-norm regularization to weights. 430 | About max-norm: `wiki `_.\n 431 | The implementation follows `TensorFlow contrib `_. 432 | 433 | Parameters 434 | ---------- 435 | scale : float 436 | A scalar multiplier `Tensor`. 0.0 disables the regularizer. 437 | scope: An optional scope name. 438 | 439 | Returns 440 | --------- 441 | A function with signature `mn(weights, name=None)` that apply Lo regularization. 442 | 443 | Raises 444 | -------- 445 | ValueError : If scale is outside of the range [0.0, 1.0] or if scale is not a float. 446 | """ 447 | import numbers 448 | from tensorflow.python.framework import ops 449 | from tensorflow.python.ops import standard_ops 450 | 451 | if isinstance(scale, numbers.Integral): 452 | raise ValueError('scale cannot be an integer: %s' % scale) 453 | if isinstance(scale, numbers.Real): 454 | if scale < 0.: 455 | raise ValueError('Setting a scale less than 0 on a regularizer: %g' % 456 | scale) 457 | # if scale >= 1.: 458 | # raise ValueError('Setting a scale greater than 1 on a regularizer: %g' % 459 | # scale) 460 | if scale == 0.: 461 | logging.info('Scale of 0 disables regularizer.') 462 | return lambda _, name=None: None 463 | 464 | def mn(weights, name='max_regularizer'): 465 | """Applies max-norm regularization to weights.""" 466 | with tf.name_scope(name) as scope: 467 | my_scale = ops.convert_to_tensor(scale, 468 | dtype=weights.dtype.base_dtype, 469 | name='scale') 470 | if tf.__version__ <= '0.12': 471 | standard_ops_fn = standard_ops.mul 472 | else: 473 | standard_ops_fn = standard_ops.multiply 474 | return standard_ops_fn(my_scale, standard_ops.reduce_max(standard_ops.abs(weights)), name=scope) 475 | return mn 476 | 477 | def maxnorm_o_regularizer(scale, scope): 478 | """Max-norm output regularization removes the neurons of current layer.\n 479 | Returns a function that can be used to apply max-norm regularization to each column of weight matrix.\n 480 | The implementation follows `TensorFlow contrib `_. 481 | 482 | Parameters 483 | ---------- 484 | scale : float 485 | A scalar multiplier `Tensor`. 0.0 disables the regularizer. 
486 | scope: An optional scope name. 487 | 488 | Returns 489 | --------- 490 | A function with signature `mn_o(weights, name=None)` that apply Lo regularization. 491 | 492 | Raises 493 | --------- 494 | ValueError : If scale is outside of the range [0.0, 1.0] or if scale is not a float. 495 | """ 496 | import numbers 497 | from tensorflow.python.framework import ops 498 | from tensorflow.python.ops import standard_ops 499 | 500 | if isinstance(scale, numbers.Integral): 501 | raise ValueError('scale cannot be an integer: %s' % scale) 502 | if isinstance(scale, numbers.Real): 503 | if scale < 0.: 504 | raise ValueError('Setting a scale less than 0 on a regularizer: %g' % 505 | scale) 506 | # if scale >= 1.: 507 | # raise ValueError('Setting a scale greater than 1 on a regularizer: %g' % 508 | # scale) 509 | if scale == 0.: 510 | logging.info('Scale of 0 disables regularizer.') 511 | return lambda _, name=None: None 512 | 513 | def mn_o(weights, name='maxnorm_o_regularizer'): 514 | """Applies max-norm regularization to weights.""" 515 | with tf.name_scope(name) as scope: 516 | my_scale = ops.convert_to_tensor(scale, 517 | dtype=weights.dtype.base_dtype, 518 | name='scale') 519 | if tf.__version__ <= '0.12': 520 | standard_ops_fn = standard_ops.mul 521 | else: 522 | standard_ops_fn = standard_ops.multiply 523 | return standard_ops_fn(my_scale, standard_ops.reduce_sum(standard_ops.reduce_max(standard_ops.abs(weights), 0)), name=scope) 524 | return mn_o 525 | 526 | def maxnorm_i_regularizer(scale, scope=None): 527 | """Max-norm input regularization removes the neurons of previous layer.\n 528 | Returns a function that can be used to apply max-norm regularization to each row of weight matrix.\n 529 | The implementation follows `TensorFlow contrib `_. 530 | 531 | Parameters 532 | ---------- 533 | scale : float 534 | A scalar multiplier `Tensor`. 0.0 disables the regularizer. 535 | scope: An optional scope name. 536 | 537 | Returns 538 | --------- 539 | A function with signature `mn_i(weights, name=None)` that apply Lo regularization. 540 | 541 | Raises 542 | --------- 543 | ValueError : If scale is outside of the range [0.0, 1.0] or if scale is not a float. 
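The three max-norm variants above differ only in the axis of the reduction; a quick hedged comparison on an all-ones weight, chosen so the expected values are easy to verify by hand:

```python
import tensorflow as tf
import tensorlayer as tl

W = tf.Variable(tf.ones([100, 50]))

mn = tl.cost.maxnorm_regularizer(1.0)(W)            # max over all entries        -> 1.0
mn_o = tl.cost.maxnorm_o_regularizer(1.0, None)(W)  # per-column max, then summed -> 50.0
mn_i = tl.cost.maxnorm_i_regularizer(1.0)(W)        # per-row max, then summed    -> 100.0
```

Note that `maxnorm_o_regularizer` takes `scope` as a required second argument, while the other two default it to `None`.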
544 | """ 545 | import numbers 546 | from tensorflow.python.framework import ops 547 | from tensorflow.python.ops import standard_ops 548 | 549 | if isinstance(scale, numbers.Integral): 550 | raise ValueError('scale cannot be an integer: %s' % scale) 551 | if isinstance(scale, numbers.Real): 552 | if scale < 0.: 553 | raise ValueError('Setting a scale less than 0 on a regularizer: %g' % 554 | scale) 555 | # if scale >= 1.: 556 | # raise ValueError('Setting a scale greater than 1 on a regularizer: %g' % 557 | # scale) 558 | if scale == 0.: 559 | logging.info('Scale of 0 disables regularizer.') 560 | return lambda _, name=None: None 561 | 562 | def mn_i(weights, name='maxnorm_i_regularizer'): 563 | """Applies max-norm regularization to weights.""" 564 | with tf.name_scope(name) as scope: 565 | my_scale = ops.convert_to_tensor(scale, 566 | dtype=weights.dtype.base_dtype, 567 | name='scale') 568 | if tf.__version__ <= '0.12': 569 | standard_ops_fn = standard_ops.mul 570 | else: 571 | standard_ops_fn = standard_ops.multiply 572 | return standard_ops_fn(my_scale, standard_ops.reduce_sum(standard_ops.reduce_max(standard_ops.abs(weights), 1)), name=scope) 573 | return mn_i 574 | 575 | 576 | 577 | 578 | 579 | # 580 | -------------------------------------------------------------------------------- /tensorlayer/db.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- coding: utf8 -*- 3 | """ 4 | Experimental Database Management System. 5 | 6 | Latest Version 7 | """ 8 | 9 | 10 | import tensorflow as tf 11 | import tensorlayer as tl 12 | import numpy as np 13 | import time 14 | import math 15 | 16 | 17 | import uuid 18 | 19 | import pymongo 20 | import gridfs 21 | import pickle 22 | from pymongo import MongoClient 23 | from datetime import datetime 24 | 25 | import inspect 26 | 27 | def AutoFill(func): 28 | def func_wrapper(self,*args,**kwargs): 29 | d=inspect.getcallargs(func,self,*args,**kwargs) 30 | d['args'].update({"studyID":self.studyID}) 31 | return func(**d) 32 | return func_wrapper 33 | 34 | 35 | 36 | 37 | 38 | 39 | class TensorDB(object): 40 | """TensorDB is a MongoDB based manager that help you to manage data, network topology, parameters and logging. 41 | 42 | Parameters 43 | ------------- 44 | ip : string, localhost or IP address. 45 | port : int, port number. 46 | db_name : string, database name. 47 | user_name : string, set to None if it donnot need authentication. 48 | password : string. 49 | 50 | Properties 51 | ------------ 52 | db : ``pymongo.MongoClient[db_name]``, xxxxxx 53 | datafs : ``gridfs.GridFS(self.db, collection="datafs")``, xxxxxxxxxx 54 | modelfs : ``gridfs.GridFS(self.db, collection="modelfs")``, 55 | paramsfs : ``gridfs.GridFS(self.db, collection="paramsfs")``, 56 | db.Params : Collection for 57 | db.TrainLog : Collection for 58 | db.ValidLog : Collection for 59 | db.TestLog : Collection for 60 | studyID : string, unique ID, if None random generate one. 61 | 62 | Dependencies 63 | ------------- 64 | 1 : MongoDB, as TensorDB is based on MongoDB, you need to install it in your 65 | local machine or remote machine. 66 | 2 : pip install pymongo, for MongoDB python API. 67 | 68 | Optional Tools 69 | ---------------- 70 | 1 : You may like to install MongoChef or Mongo Management Studo APP for 71 | visualizing or testing your MongoDB. 
72 | """ 73 | def __init__( 74 | self, 75 | ip = 'localhost', 76 | port = 27017, 77 | db_name = 'db_name', 78 | user_name = None, 79 | password = 'password', 80 | studyID=None 81 | ): 82 | ## connect mongodb 83 | client = MongoClient(ip, port) 84 | self.db = client[db_name] 85 | if user_name != None: 86 | self.db.authenticate(user_name, password) 87 | 88 | 89 | if studyID is None: 90 | self.studyID=str(uuid.uuid1()) 91 | else: 92 | self.studyID=studyID 93 | 94 | ## define file system (Buckets) 95 | self.datafs = gridfs.GridFS(self.db, collection="datafs") 96 | self.modelfs = gridfs.GridFS(self.db, collection="modelfs") 97 | self.paramsfs = gridfs.GridFS(self.db, collection="paramsfs") 98 | self.archfs=gridfs.GridFS(self.db,collection="ModelArchitecture") 99 | ## 100 | print("[TensorDB] Connect SUCCESS {}:{} {} {} {}".format(ip, port, db_name, user_name, studyID)) 101 | 102 | self.ip = ip 103 | self.port = port 104 | self.db_name = db_name 105 | self.user_name = user_name 106 | 107 | def __autofill(self,args): 108 | return args.update({'studyID':self.studyID}) 109 | 110 | def __serialization(self,ps): 111 | return pickle.dumps(ps, protocol=2) 112 | 113 | def __deserialization(self,ps): 114 | return pickle.loads(ps) 115 | 116 | def save_params(self, params=[], args={}):#, file_name='parameters'): 117 | """ Save parameters into MongoDB Buckets, and save the file ID into Params Collections. 118 | 119 | Parameters 120 | ---------- 121 | params : a list of parameters 122 | args : dictionary, item meta data. 123 | 124 | Returns 125 | --------- 126 | f_id : the Buckets ID of the parameters. 127 | """ 128 | self.__autofill(args) 129 | s = time.time() 130 | f_id = self.paramsfs.put(self.__serialization(params))#, file_name=file_name) 131 | args.update({'f_id': f_id, 'time': datetime.utcnow()}) 132 | self.db.Params.insert_one(args) 133 | # print("[TensorDB] Save params: {} SUCCESS, took: {}s".format(file_name, round(time.time()-s, 2))) 134 | print("[TensorDB] Save params: SUCCESS, took: {}s".format(round(time.time()-s, 2))) 135 | return f_id 136 | 137 | @AutoFill 138 | def find_one_params(self, args={},sort=None): 139 | """ Find one parameter from MongoDB Buckets. 140 | 141 | Parameters 142 | ---------- 143 | args : dictionary, find items. 144 | 145 | Returns 146 | -------- 147 | params : the parameters, return False if nothing found. 148 | f_id : the Buckets ID of the parameters, return False if nothing found. 149 | """ 150 | 151 | s = time.time() 152 | # print(args) 153 | d = self.db.Params.find_one(filter=args,sort=sort) 154 | 155 | if d is not None: 156 | f_id = d['f_id'] 157 | else: 158 | print("[TensorDB] FAIL! Cannot find: {}".format(args)) 159 | return False, False 160 | try: 161 | params = self.__deserialization(self.paramsfs.get(f_id).read()) 162 | print("[TensorDB] Find one params SUCCESS, {} took: {}s".format(args, round(time.time()-s, 2))) 163 | return params, f_id 164 | except: 165 | return False, False 166 | 167 | @AutoFill 168 | def find_all_params(self, args={}): 169 | """ Find all parameter from MongoDB Buckets 170 | 171 | Parameters 172 | ---------- 173 | args : dictionary, find items 174 | 175 | Returns 176 | -------- 177 | params : the parameters, return False if nothing found. 
178 | 179 | """ 180 | 181 | s = time.time() 182 | pc = self.db.Params.find(args) 183 | 184 | if pc is not None: 185 | f_id_list = pc.distinct('f_id') 186 | params = [] 187 | for f_id in f_id_list: # you may have multiple Buckets files 188 | tmp = self.paramsfs.get(f_id).read() 189 | params.append(self.__deserialization(tmp)) 190 | else: 191 | print("[TensorDB] FAIL! Cannot find any: {}".format(args)) 192 | return False 193 | 194 | print("[TensorDB] Find all params SUCCESS, took: {}s".format(round(time.time()-s, 2))) 195 | return params 196 | 197 | @AutoFill 198 | def del_params(self, args={}): 199 | """ Delete params in MongoDB uckets. 200 | 201 | Parameters 202 | ----------- 203 | args : dictionary, find items to delete, leave it empty to delete all parameters. 204 | """ 205 | 206 | pc = self.db.Params.find(args) 207 | f_id_list = pc.distinct('f_id') 208 | # remove from Buckets 209 | for f in f_id_list: 210 | self.paramsfs.delete(f) 211 | # remove from Collections 212 | self.db.Params.remove(args) 213 | 214 | print("[TensorDB] Delete params SUCCESS: {}".format(args)) 215 | 216 | def _print_dict(self, args): 217 | # return " / ".join(str(key) + ": "+ str(value) for key, value in args.items()) 218 | 219 | string = '' 220 | for key, value in args.items(): 221 | if key is not '_id': 222 | string += str(key) + ": "+ str(value) + " / " 223 | return string 224 | 225 | ## =========================== LOG =================================== ## 226 | @AutoFill 227 | def train_log(self, args={}): 228 | """Save the training log. 229 | 230 | Parameters 231 | ----------- 232 | args : dictionary, items to save. 233 | 234 | Examples 235 | --------- 236 | >>> db.train_log(time=time.time(), {'loss': loss, 'acc': acc}) 237 | """ 238 | 239 | _result = self.db.TrainLog.insert_one(args) 240 | _log = self._print_dict(args) 241 | #print("[TensorDB] TrainLog: " +_log) 242 | return _result 243 | 244 | @AutoFill 245 | def del_train_log(self, args={}): 246 | """ Delete train log. 247 | 248 | Parameters 249 | ----------- 250 | args : dictionary, find items to delete, leave it empty to delete all log. 251 | """ 252 | 253 | self.db.TrainLog.delete_many(args) 254 | print("[TensorDB] Delete TrainLog SUCCESS") 255 | 256 | @AutoFill 257 | def valid_log(self, args={}): 258 | """Save the validating log. 259 | 260 | Parameters 261 | ----------- 262 | args : dictionary, items to save. 263 | 264 | Examples 265 | --------- 266 | >>> db.valid_log(time=time.time(), {'loss': loss, 'acc': acc}) 267 | """ 268 | 269 | _result = self.db.ValidLog.insert_one(args) 270 | # _log = "".join(str(key) + ": " + str(value) for key, value in args.items()) 271 | _log = self._print_dict(args) 272 | print("[TensorDB] ValidLog: " +_log) 273 | return _result 274 | 275 | @AutoFill 276 | def del_valid_log(self, args={}): 277 | """ Delete validation log. 278 | 279 | Parameters 280 | ----------- 281 | args : dictionary, find items to delete, leave it empty to delete all log. 282 | """ 283 | self.db.ValidLog.delete_many(args) 284 | print("[TensorDB] Delete ValidLog SUCCESS") 285 | 286 | @AutoFill 287 | def test_log(self, args={}): 288 | """Save the testing log. 289 | 290 | Parameters 291 | ----------- 292 | args : dictionary, items to save. 
293 | 294 | Examples 295 | --------- 296 | >>> db.test_log(time=time.time(), {'loss': loss, 'acc': acc}) 297 | """ 298 | 299 | _result = self.db.TestLog.insert_one(args) 300 | # _log = "".join(str(key) + str(value) for key, value in args.items()) 301 | _log = self._print_dict(args) 302 | print("[TensorDB] TestLog: " +_log) 303 | return _result 304 | 305 | @AutoFill 306 | def del_test_log(self, args={}): 307 | """ Delete test log. 308 | 309 | Parameters 310 | ----------- 311 | args : dictionary, find items to delete, leave it empty to delete all log. 312 | """ 313 | 314 | self.db.TestLog.delete_many(args) 315 | print("[TensorDB] Delete TestLog SUCCESS") 316 | 317 | ## =========================== Network Architecture ================== ## 318 | @AutoFill 319 | def save_model_architecture(self,s,args={}): 320 | self.__autofill(args) 321 | fid=self.archfs.put(s,filename="modelarchitecture") 322 | args.update({"fid":fid}) 323 | self.db.march.insert_one(args) 324 | 325 | @AutoFill 326 | def load_model_architecture(self,args={}): 327 | 328 | d = self.db.march.find_one(args) 329 | if d is not None: 330 | fid = d['fid'] 331 | print(d) 332 | print(fid) 333 | # "print find" 334 | else: 335 | print("[TensorDB] FAIL! Cannot find: {}".format(args)) 336 | print ("no idtem") 337 | return False, False 338 | try: 339 | archs = self.archfs.get(fid).read() 340 | '''print("[TensorDB] Find one params SUCCESS, {} took: {}s".format(args, round(time.time()-s, 2)))''' 341 | return archs, fid 342 | except Exception as e: 343 | print("exception") 344 | print(e) 345 | return False, False 346 | 347 | @AutoFill 348 | def save_job(self, script=None, args={}): 349 | """Save the job. 350 | 351 | Parameters 352 | ----------- 353 | script : a script file name or None. 354 | args : dictionary, items to save. 355 | 356 | Examples 357 | --------- 358 | >>> # Save your job 359 | >>> db.save_job('your_script.py', {'job_id': 1, 'learning_rate': 0.01, 'n_units': 100}) 360 | >>> # Run your job 361 | >>> temp = db.find_one_job(args={'job_id': 1}) 362 | >>> print(temp['learning_rate']) 363 | ... 0.01 364 | >>> import _your_script 365 | ... running your script 366 | """ 367 | self.__autofill(args) 368 | if script is not None: 369 | _script = open(script, 'rb').read() 370 | args.update({'script': _script, 'script_name': script}) 371 | # _result = self.db.Job.insert_one(args) 372 | _result = self.db.Job.replace_one(args, args, upsert=True) 373 | _log = self._print_dict(args) 374 | print("[TensorDB] Save Job: script={}, args={}".format(script, args)) 375 | return _result 376 | 377 | @AutoFill 378 | def find_one_job(self, args={}): 379 | """ Find one job from MongoDB Job Collections. 380 | 381 | Parameters 382 | ---------- 383 | args : dictionary, find items. 384 | 385 | Returns 386 | -------- 387 | dictionary : contains all meta data and script. 388 | """ 389 | 390 | 391 | temp = self.db.Job.find_one(args) 392 | 393 | if temp is not None: 394 | if 'script_name' in temp.keys(): 395 | f = open('_' + temp['script_name'], 'wb') 396 | f.write(temp['script']) 397 | f.close() 398 | print("[TensorDB] Find Job: {}".format(args)) 399 | else: 400 | print("[TensorDB] FAIL! 
Cannot find any: {}".format(args)) 401 | return False 402 | 403 | return temp 404 | 405 | def push_job(self,margs, wargs,dargs,epoch): 406 | 407 | ms,mid=self.load_model_architecture(margs) 408 | weight,wid=self.find_one_params(wargs) 409 | args={"weight":wid,"model":mid,"dargs":dargs,"epoch":epoch,"time":datetime.utcnow(),"Running":False} 410 | self.__autofill(args) 411 | self.db.JOBS.insert_one(args) 412 | 413 | def peek_job(self): 414 | args={'Running':False} 415 | self.__autofill(args) 416 | m=self.db.JOBS.find_one(args) 417 | print(m) 418 | if m is None: 419 | return False 420 | 421 | s=self.paramsfs.get(m['weight']).read() 422 | w=self.__deserialization(s) 423 | 424 | ach=self.archfs.get(m['model']).read() 425 | 426 | return m['_id'], ach,w,m["dargs"],m['epoch'] 427 | 428 | def run_job(self,jid): 429 | self.db.JOBS.find_one_and_update({'_id':jid},{'$set': {'Running': True,"Since":datetime.utcnow()}}) 430 | 431 | def del_job(self,jid): 432 | self.db.JOBS.find_one_and_update({'_id':jid},{'$set': {'Running': True,"Finished":datetime.utcnow()}}) 433 | 434 | def __str__(self): 435 | _s = "[TensorDB] Info:\n" 436 | _t = _s + " " + str(self.db) 437 | return _t 438 | 439 | # def save_bulk_data(self, data=None, filename='filename'): 440 | # """ Put bulk data into TensorDB.datafs, return file ID. 441 | # When you have a very large data, you may like to save it into GridFS Buckets 442 | # instead of Collections, then when you want to load it, XXXX 443 | # 444 | # Parameters 445 | # ----------- 446 | # data : serialized data. 447 | # filename : string, GridFS Buckets. 448 | # 449 | # References 450 | # ----------- 451 | # - MongoDB find, xxxxx 452 | # """ 453 | # s = time.time() 454 | # f_id = self.datafs.put(data, filename=filename) 455 | # print("[TensorDB] save_bulk_data: {} took: {}s".format(filename, round(time.time()-s, 2))) 456 | # return f_id 457 | # 458 | # def save_collection(self, data=None, collect_name='collect_name'): 459 | # """ Insert data into MongoDB Collections, return xx. 460 | # 461 | # Parameters 462 | # ----------- 463 | # data : serialized data. 464 | # collect_name : string, MongoDB collection name. 465 | # 466 | # References 467 | # ----------- 468 | # - MongoDB find, xxxxx 469 | # """ 470 | # s = time.time() 471 | # rl = self.db[collect_name].insert_many(data) 472 | # print("[TensorDB] save_collection: {} took: {}s".format(collect_name, round(time.time()-s, 2))) 473 | # return rl 474 | # 475 | # def find(self, args={}, collect_name='collect_name'): 476 | # """ Find data from MongoDB Collections. 477 | # 478 | # Parameters 479 | # ----------- 480 | # args : dictionary, arguments for finding. 481 | # collect_name : string, MongoDB collection name. 
482 | # 483 | # References 484 | # ----------- 485 | # - MongoDB find, xxxxx 486 | # """ 487 | # s = time.time() 488 | # 489 | # pc = self.db[collect_name].find(args) # pymongo.cursor.Cursor object 490 | # flist = pc.distinct('f_id') 491 | # fldict = {} 492 | # for f in flist: # you may have multiple Buckets files 493 | # # fldict[f] = pickle.loads(self.datafs.get(f).read()) 494 | # # s2 = time.time() 495 | # tmp = self.datafs.get(f).read() 496 | # # print(time.time()-s2) 497 | # fldict[f] = pickle.loads(tmp) 498 | # # print(time.time()-s2) 499 | # # exit() 500 | # # print(round(time.time()-s, 2)) 501 | # data = [fldict[x['f_id']][x['id']] for x in pc] 502 | # data = np.asarray(data) 503 | # print("[TensorDB] find: {} get: {} took: {}s".format(collect_name, pc.count(), round(time.time()-s, 2))) 504 | # return data 505 | 506 | 507 | 508 | class DBLogger: 509 | """ """ 510 | def __init__(self,db,model): 511 | self.db=db 512 | self.model=model 513 | 514 | def on_train_begin(self,logs={}): 515 | print("start") 516 | 517 | def on_train_end(self,logs={}): 518 | print("end") 519 | 520 | def on_epoch_begin(self,epoch,logs={}): 521 | self.epoch=epoch 522 | self.et=time.time() 523 | return 524 | 525 | def on_epoch_end(self, epoch, logs={}): 526 | self.et=time.time()-self.et 527 | print("ending") 528 | print(epoch) 529 | logs['epoch']=epoch 530 | logs['time']=datetime.utcnow() 531 | logs['stepTime']=self.et 532 | logs['acc']=np.asscalar(logs['acc']) 533 | print(logs) 534 | 535 | w=self.model.Params 536 | fid=self.db.save_params(w,logs) 537 | logs.update({'params':fid}) 538 | self.db.valid_log(logs) 539 | def on_batch_begin(self, batch,logs={}): 540 | self.t=time.time() 541 | self.losses = [] 542 | self.batch=batch 543 | 544 | def on_batch_end(self, batch, logs={}): 545 | self.t2=time.time()-self.t 546 | logs['acc']=np.asscalar(logs['acc']) 547 | #logs['loss']=np.asscalar(logs['loss']) 548 | logs['step_time']=self.t2 549 | logs['time']=datetime.utcnow() 550 | logs['epoch']=self.epoch 551 | logs['batch']=self.batch 552 | self.db.train_log(logs) 553 | -------------------------------------------------------------------------------- /tensorlayer/iterate.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- coding: utf8 -*- 3 | 4 | 5 | 6 | import numpy as np 7 | from six.moves import xrange 8 | 9 | def minibatches(inputs=None, targets=None, batch_size=None, shuffle=False): 10 | """Generate a generator that input a group of example in numpy.array and 11 | their labels, return the examples and labels by the given batchsize. 12 | 13 | Parameters 14 | ---------- 15 | inputs : numpy.array 16 | (X) The input features, every row is a example. 17 | targets : numpy.array 18 | (y) The labels of inputs, every row is a example. 19 | batch_size : int 20 | The batch size. 21 | shuffle : boolean 22 | Indicating whether to use a shuffling queue, shuffle the dataset before return. 23 | 24 | Hints 25 | ------- 26 | - If you have two inputs, e.g. X1 (1000, 100) and X2 (1000, 80), you can ``np.hstack((X1, X2)) 27 | into (1000, 180) and feed into ``inputs``, then you can split a batch of X1 and X2. 28 | 29 | Examples 30 | -------- 31 | >>> X = np.asarray([['a','a'], ['b','b'], ['c','c'], ['d','d'], ['e','e'], ['f','f']]) 32 | >>> y = np.asarray([0,1,2,3,4,5]) 33 | >>> for batch in tl.iterate.minibatches(inputs=X, targets=y, batch_size=2, shuffle=False): 34 | >>> print(batch) 35 | ... (array([['a', 'a'], 36 | ... ['b', 'b']], 37 | ... 
dtype='>> X = np.asarray([['a','a'], ['b','b'], ['c','c'], ['d','d'], ['e','e'], ['f','f']]) 64 | >>> y = np.asarray([0, 1, 2, 3, 4, 5]) 65 | >>> for batch in tl.iterate.seq_minibatches(inputs=X, targets=y, batch_size=2, seq_length=2, stride=1): 66 | >>> print(batch) 67 | ... (array([['a', 'a'], 68 | ... ['b', 'b'], 69 | ... ['b', 'b'], 70 | ... ['c', 'c']], 71 | ... dtype='>> return_last = True 82 | >>> num_steps = 2 83 | >>> X = np.asarray([['a','a'], ['b','b'], ['c','c'], ['d','d'], ['e','e'], ['f','f']]) 84 | >>> Y = np.asarray([0,1,2,3,4,5]) 85 | >>> for batch in tl.iterate.seq_minibatches(inputs=X, targets=Y, batch_size=2, seq_length=num_steps, stride=1): 86 | >>> x, y = batch 87 | >>> if return_last: 88 | >>> tmp_y = y.reshape((-1, num_steps) + y.shape[1:]) 89 | >>> y = tmp_y[:, -1] 90 | >>> print(x, y) 91 | ... [['a' 'a'] 92 | ... ['b' 'b'] 93 | ... ['b' 'b'] 94 | ... ['c' 'c']] [1 2] 95 | ... [['c' 'c'] 96 | ... ['d' 'd'] 97 | ... ['d' 'd'] 98 | ... ['e' 'e']] [3 4] 99 | """ 100 | assert len(inputs) == len(targets) 101 | n_loads = (batch_size * stride) + (seq_length - stride) 102 | for start_idx in range(0, len(inputs) - n_loads + 1, (batch_size * stride)): 103 | seq_inputs = np.zeros((batch_size, seq_length) + inputs.shape[1:], 104 | dtype=inputs.dtype) 105 | seq_targets = np.zeros((batch_size, seq_length) + targets.shape[1:], 106 | dtype=targets.dtype) 107 | for b_idx in xrange(batch_size): 108 | start_seq_idx = start_idx + (b_idx * stride) 109 | end_seq_idx = start_seq_idx + seq_length 110 | seq_inputs[b_idx] = inputs[start_seq_idx:end_seq_idx] 111 | seq_targets[b_idx] = targets[start_seq_idx:end_seq_idx] 112 | flatten_inputs = seq_inputs.reshape((-1,) + inputs.shape[1:]) 113 | flatten_targets = seq_targets.reshape((-1,) + targets.shape[1:]) 114 | yield flatten_inputs, flatten_targets 115 | 116 | def seq_minibatches2(inputs, targets, batch_size, num_steps): 117 | """Generate a generator that iterates on two list of words. Yields (Returns) the source contexts and 118 | the target context by the given batch_size and num_steps (sequence_length), 119 | see ``PTB tutorial``. In TensorFlow's tutorial, this generates the batch_size pointers into the raw 120 | PTB data, and allows minibatch iteration along these pointers. 121 | 122 | - Hint, if the input data are images, you can modify the code as follow. 123 | 124 | .. code-block:: python 125 | 126 | from 127 | data = np.zeros([batch_size, batch_len) 128 | to 129 | data = np.zeros([batch_size, batch_len, inputs.shape[1], inputs.shape[2], inputs.shape[3]]) 130 | 131 | Parameters 132 | ---------- 133 | inputs : a list 134 | the context in list format; note that context usually be 135 | represented by splitting by space, and then convert to unique 136 | word IDs. 137 | targets : a list 138 | the context in list format; note that context usually be 139 | represented by splitting by space, and then convert to unique 140 | word IDs. 141 | batch_size : int 142 | the batch size. 143 | num_steps : int 144 | the number of unrolls. i.e. sequence_length 145 | 146 | Yields 147 | ------ 148 | Pairs of the batched data, each a matrix of shape [batch_size, num_steps]. 149 | 150 | Raises 151 | ------ 152 | ValueError : if batch_size or num_steps are too high. 153 | 154 | Examples 155 | -------- 156 | >>> X = [i for i in range(20)] 157 | >>> Y = [i for i in range(20,40)] 158 | >>> for batch in tl.iterate.seq_minibatches2(X, Y, batch_size=2, num_steps=3): 159 | ... x, y = batch 160 | ... print(x, y) 161 | ... 162 | ... [[ 0. 1. 2.] 163 | ... [ 10. 11. 
12.]] 164 | ... [[ 20. 21. 22.] 165 | ... [ 30. 31. 32.]] 166 | ... 167 | ... [[ 3. 4. 5.] 168 | ... [ 13. 14. 15.]] 169 | ... [[ 23. 24. 25.] 170 | ... [ 33. 34. 35.]] 171 | ... 172 | ... [[ 6. 7. 8.] 173 | ... [ 16. 17. 18.]] 174 | ... [[ 26. 27. 28.] 175 | ... [ 36. 37. 38.]] 176 | 177 | Code References 178 | --------------- 179 | - ``tensorflow/models/rnn/ptb/reader.py`` 180 | """ 181 | assert len(inputs) == len(targets) 182 | data_len = len(inputs) 183 | batch_len = data_len // batch_size 184 | # data = np.zeros([batch_size, batch_len]) 185 | data = np.zeros((batch_size, batch_len) + inputs.shape[1:], 186 | dtype=inputs.dtype) 187 | data2 = np.zeros([batch_size, batch_len]) 188 | 189 | for i in range(batch_size): 190 | data[i] = inputs[batch_len * i:batch_len * (i + 1)] 191 | data2[i] = targets[batch_len * i:batch_len * (i + 1)] 192 | 193 | epoch_size = (batch_len - 1) // num_steps 194 | 195 | if epoch_size == 0: 196 | raise ValueError("epoch_size == 0, decrease batch_size or num_steps") 197 | 198 | for i in range(epoch_size): 199 | x = data[:, i*num_steps:(i+1)*num_steps] 200 | x2 = data2[:, i*num_steps:(i+1)*num_steps] 201 | yield (x, x2) 202 | 203 | 204 | def ptb_iterator(raw_data, batch_size, num_steps): 205 | """ 206 | Generate a generator that iterates on a list of words, see PTB tutorial. Yields (Returns) the source contexts and 207 | the target context by the given batch_size and num_steps (sequence_length).\n 208 | see ``PTB tutorial``. 209 | 210 | e.g. x = [0, 1, 2] y = [1, 2, 3] , when batch_size = 1, num_steps = 3, 211 | raw_data = [i for i in range(100)] 212 | 213 | In TensorFlow's tutorial, this generates batch_size pointers into the raw 214 | PTB data, and allows minibatch iteration along these pointers. 215 | 216 | Parameters 217 | ---------- 218 | raw_data : a list 219 | the context in list format; note that context usually be 220 | represented by splitting by space, and then convert to unique 221 | word IDs. 222 | batch_size : int 223 | the batch size. 224 | num_steps : int 225 | the number of unrolls. i.e. sequence_length 226 | 227 | Yields 228 | ------ 229 | Pairs of the batched data, each a matrix of shape [batch_size, num_steps]. 230 | The second element of the tuple is the same data time-shifted to the 231 | right by one. 232 | 233 | Raises 234 | ------ 235 | ValueError : if batch_size or num_steps are too high. 236 | 237 | Examples 238 | -------- 239 | >>> train_data = [i for i in range(20)] 240 | >>> for batch in tl.iterate.ptb_iterator(train_data, batch_size=2, num_steps=3): 241 | >>> x, y = batch 242 | >>> print(x, y) 243 | ... [[ 0 1 2] <---x 1st subset/ iteration 244 | ... [10 11 12]] 245 | ... [[ 1 2 3] <---y 246 | ... [11 12 13]] 247 | ... 248 | ... [[ 3 4 5] <--- 1st batch input 2nd subset/ iteration 249 | ... [13 14 15]] <--- 2nd batch input 250 | ... [[ 4 5 6] <--- 1st batch target 251 | ... [14 15 16]] <--- 2nd batch target 252 | ... 253 | ... [[ 6 7 8] 3rd subset/ iteration 254 | ... [16 17 18]] 255 | ... [[ 7 8 9] 256 | ... 
[17 18 19]] 257 | 258 | Code References 259 | ---------------- 260 | - ``tensorflow/models/rnn/ptb/reader.py`` 261 | """ 262 | raw_data = np.array(raw_data, dtype=np.int32) 263 | 264 | data_len = len(raw_data) 265 | batch_len = data_len // batch_size 266 | data = np.zeros([batch_size, batch_len], dtype=np.int32) 267 | for i in range(batch_size): 268 | data[i] = raw_data[batch_len * i:batch_len * (i + 1)] 269 | 270 | epoch_size = (batch_len - 1) // num_steps 271 | 272 | if epoch_size == 0: 273 | raise ValueError("epoch_size == 0, decrease batch_size or num_steps") 274 | 275 | for i in range(epoch_size): 276 | x = data[:, i*num_steps:(i+1)*num_steps] 277 | y = data[:, i*num_steps+1:(i+1)*num_steps+1] 278 | yield (x, y) 279 | 280 | 281 | 282 | # def minibatches_for_sequence2D(inputs, targets, batch_size, sequence_length, stride=1): 283 | # """ 284 | # Input a group of example in 2D numpy.array and their labels. 285 | # Return the examples and labels by the given batchsize, sequence_length. 286 | # Use for RNN. 287 | # 288 | # Parameters 289 | # ---------- 290 | # inputs : numpy.array 291 | # (X) The input features, every row is a example. 292 | # targets : numpy.array 293 | # (y) The labels of inputs, every row is a example. 294 | # batchsize : int 295 | # The batch size must be a multiple of sequence_length: int(batch_size % sequence_length) == 0 296 | # sequence_length : int 297 | # The sequence length 298 | # stride : int 299 | # The stride step 300 | # 301 | # Examples 302 | # -------- 303 | # >>> sequence_length = 2 304 | # >>> batch_size = 4 305 | # >>> stride = 1 306 | # >>> X_train = np.asarray([[1,2,3],[4,5,6],[7,8,9],[10,11,12],[13,14,15],[16,17,18],[19,20,21],[22,23,24]]) 307 | # >>> y_train = np.asarray(['0','1','2','3','4','5','6','7']) 308 | # >>> print('X_train = %s' % X_train) 309 | # >>> print('y_train = %s' % y_train) 310 | # >>> for batch in minibatches_for_sequence2D(X_train, y_train, batch_size=batch_size, sequence_length=sequence_length, stride=stride): 311 | # >>> inputs, targets = batch 312 | # >>> print(inputs) 313 | # >>> print(targets) 314 | # ... [[ 1. 2. 3.] 315 | # ... [ 4. 5. 6.] 316 | # ... [ 4. 5. 6.] 317 | # ... [ 7. 8. 9.]] 318 | # ... [1 2] 319 | # ... [[ 4. 5. 6.] 320 | # ... [ 7. 8. 9.] 321 | # ... [ 7. 8. 9.] 322 | # ... [ 10. 11. 12.]] 323 | # ... [2 3] 324 | # ... ... 325 | # ... [[ 16. 17. 18.] 326 | # ... [ 19. 20. 21.] 327 | # ... [ 19. 20. 21.] 328 | # ... [ 22. 23. 24.]] 329 | # ... 
[6 7] 330 | # """ 331 | # print('len(targets)=%d batch_size=%d sequence_length=%d stride=%d' % (len(targets), batch_size, sequence_length, stride)) 332 | # assert len(inputs) == len(targets), '1 feature vector have 1 target vector/value' #* sequence_length 333 | # # assert int(batch_size % sequence_length) == 0, 'batch_size % sequence_length must == 0\ 334 | # # batch_size is number of examples rather than number of targets' 335 | # 336 | # # print(inputs.shape, len(inputs), len(inputs[0])) 337 | # 338 | # n_targets = int(batch_size/sequence_length) 339 | # # n_targets = int(np.ceil(batch_size/sequence_length)) 340 | # X = np.empty(shape=(0,len(inputs[0])), dtype=np.float32) 341 | # y = np.zeros(shape=(1, n_targets), dtype=np.int32) 342 | # 343 | # for idx in range(sequence_length, len(inputs), stride): # go through all example during 1 epoch 344 | # for n in range(n_targets): # for num of target 345 | # X = np.concatenate((X, inputs[idx-sequence_length+n:idx+n])) 346 | # y[0][n] = targets[idx-1+n] 347 | # # y = np.vstack((y, targets[idx-1+n])) 348 | # yield X, y[0] 349 | # X = np.empty(shape=(0,len(inputs[0]))) 350 | # # y = np.empty(shape=(1,0)) 351 | # 352 | # 353 | # def minibatches_for_sequence4D(inputs, targets, batch_size, sequence_length, stride=1): # 354 | # """ 355 | # Input a group of example in 4D numpy.array and their labels. 356 | # Return the examples and labels by the given batchsize, sequence_length. 357 | # Use for RNN. 358 | # 359 | # Parameters 360 | # ---------- 361 | # inputs : numpy.array 362 | # (X) The input features, every row is a example. 363 | # targets : numpy.array 364 | # (y) The labels of inputs, every row is a example. 365 | # batchsize : int 366 | # The batch size must be a multiple of sequence_length: int(batch_size % sequence_length) == 0 367 | # sequence_length : int 368 | # The sequence length 369 | # stride : int 370 | # The stride step 371 | # 372 | # Examples 373 | # -------- 374 | # >>> sequence_length = 2 375 | # >>> batch_size = 2 376 | # >>> stride = 1 377 | # >>> X_train = np.asarray([[1,2,3],[4,5,6],[7,8,9],[10,11,12],[13,14,15],[16,17,18],[19,20,21],[22,23,24]]) 378 | # >>> y_train = np.asarray(['0','1','2','3','4','5','6','7']) 379 | # >>> X_train = np.expand_dims(X_train, axis=1) 380 | # >>> X_train = np.expand_dims(X_train, axis=3) 381 | # >>> for batch in minibatches_for_sequence4D(X_train, y_train, batch_size=batch_size, sequence_length=sequence_length, stride=stride): 382 | # >>> inputs, targets = batch 383 | # >>> print(inputs) 384 | # >>> print(targets) 385 | # ... [[[[ 1.] 386 | # ... [ 2.] 387 | # ... [ 3.]]] 388 | # ... [[[ 4.] 389 | # ... [ 5.] 390 | # ... [ 6.]]]] 391 | # ... [1] 392 | # ... [[[[ 4.] 393 | # ... [ 5.] 394 | # ... [ 6.]]] 395 | # ... [[[ 7.] 396 | # ... [ 8.] 397 | # ... [ 9.]]]] 398 | # ... [2] 399 | # ... ... 400 | # ... [[[[ 19.] 401 | # ... [ 20.] 402 | # ... [ 21.]]] 403 | # ... [[[ 22.] 404 | # ... [ 23.] 405 | # ... [ 24.]]]] 406 | # ... 
[7] 407 | # """ 408 | # print('len(targets)=%d batch_size=%d sequence_length=%d stride=%d' % (len(targets), batch_size, sequence_length, stride)) 409 | # assert len(inputs) == len(targets), '1 feature vector have 1 target vector/value' #* sequence_length 410 | # # assert int(batch_size % sequence_length) == 0, 'in LSTM, batch_size % sequence_length must == 0\ 411 | # # batch_size is number of X_train rather than number of targets' 412 | # assert stride >= 1, 'stride must be >=1, at least move 1 step for each iteration' 413 | # 414 | # n_example, n_channels, width, height = inputs.shape 415 | # print('n_example=%d n_channels=%d width=%d height=%d' % (n_example, n_channels, width, height)) 416 | # 417 | # n_targets = int(np.ceil(batch_size/sequence_length)) # actually batch_size/sequence_length + 1 418 | # print(n_targets) 419 | # X = np.zeros(shape=(batch_size, n_channels, width, height), dtype=np.float32) 420 | # # X = np.zeros(shape=(n_targets, sequence_length, n_channels, width, height), dtype=np.float32) 421 | # y = np.zeros(shape=(1,n_targets), dtype=np.int32) 422 | # # y = np.empty(shape=(0,1), dtype=np.float32) 423 | # # time.sleep(2) 424 | # for idx in range(sequence_length, n_example-n_targets+2, stride): # go through all examples during 1 epoch 425 | # for n in range(n_targets): # for each target 426 | # # print(idx+n, inputs[idx-sequence_length+n : idx+n].shape) 427 | # X[n*sequence_length : (n+1)*sequence_length] = inputs[idx+n-sequence_length : idx+n] 428 | # # X[n] = inputs[idx-sequence_length+n:idx+n] 429 | # y[0][n] = targets[idx+n-1] 430 | # # y = np.vstack((y, targets[idx-1+n])) 431 | # # y = targets[idx: idx+n_targets] 432 | # yield X, y[0] 433 | -------------------------------------------------------------------------------- /tensorlayer/main-Copy3.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- coding: utf8 -*- 3 | 4 | import os, time, pickle, random 5 | from datetime import datetime 6 | import numpy as np 7 | from time import localtime, strftime 8 | import logging, scipy 9 | import keras.backend as K 10 | 11 | import tensorflow as tf 12 | import tensorlayer as tl 13 | from model import * 14 | from utils import * 15 | from config import config, log_config 16 | 17 | ###====================== HYPER-PARAMETERS ===========================### 18 | ## Adam 19 | batch_size = config.TRAIN.batch_size 20 | lr_init = config.TRAIN.lr_init 21 | beta1 = config.TRAIN.beta1 22 | ## initialize G 23 | n_epoch_init = config.TRAIN.n_epoch_init 24 | ## adversarial learning (SRGAN) 25 | n_epoch = config.TRAIN.n_epoch 26 | lr_decay = config.TRAIN.lr_decay 27 | decay_every = config.TRAIN.decay_every 28 | 29 | ni = int(np.sqrt(batch_size)) 30 | 31 | def limit_mem(): 32 | cfg = K.tf.ConfigProto() 33 | cfg.gpu_options.allow_growth = True 34 | K.set_session(K.tf.Session(config=cfg)) 35 | 36 | def read_all_imgs(img_list, path='', n_threads=32): 37 | """ Return all images as arrays, given a path and the name of each image file.
""" 38 | imgs = [] 39 | for idx in range(0, len(img_list), n_threads): 40 | b_imgs_list = img_list[idx : idx + n_threads] 41 | b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=path) 42 | # print(b_imgs.shape) 43 | imgs.extend(b_imgs) 44 | print('read %d from %s' % (len(imgs), path)) 45 | return imgs 46 | 47 | def train(): 48 | ## create folders to save result images and trained model 49 | save_dir_ginit = "samples/{}_ginit".format(tl.global_flag['mode']) 50 | save_dir_gan = "samples/{}_gan".format(tl.global_flag['mode']) 51 | tl.files.exists_or_mkdir(save_dir_ginit) 52 | tl.files.exists_or_mkdir(save_dir_gan) 53 | checkpoint_dir = "checkpoint" # checkpoint_resize_conv 54 | tl.files.exists_or_mkdir(checkpoint_dir) 55 | 56 | ###====================== PRE-LOAD DATA ===========================### 57 | train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.jpg', printable=False)) 58 | train_lr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx='.*.jpg', printable=False)) 59 | valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.jpg', printable=False)) 60 | valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.jpg', printable=False)) 61 | 62 | ## If your machine have enough memory, please pre-load the whole train set. 63 | train_hr_imgs = read_all_imgs(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=32) 64 | # for im in train_hr_imgs: 65 | # print(im.shape) 66 | # valid_lr_imgs = read_all_imgs(valid_lr_img_list, path=config.VALID.lr_img_path, n_threads=32) 67 | # for im in valid_lr_imgs: 68 | # print(im.shape) 69 | # valid_hr_imgs = read_all_imgs(valid_hr_img_list, path=config.VALID.hr_img_path, n_threads=32) 70 | # for im in valid_hr_imgs: 71 | # print(im.shape) 72 | # exit() 73 | 74 | ###========================== DEFINE MODEL ============================### 75 | ## train inference 76 | t_image = tf.placeholder('float32', [batch_size, 96, 96, 3], name='t_image_input_to_SRGAN_generator') 77 | t_target_image = tf.placeholder('float32', [batch_size, 384, 384, 3], name='t_target_image') 78 | 79 | net_g = SRGAN_g(t_image, is_train=True, reuse=False) 80 | net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False) 81 | _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True) 82 | 83 | net_g.print_params(False) 84 | net_d.print_params(False) 85 | 86 | ## vgg inference. 
0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA 87 | t_target_image_224 = tf.image.resize_images(t_target_image, size=[224, 224], method=0, align_corners=False) # resize_target_image_for_vgg # http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers.html#UpSampling2dLayer 88 | t_predict_image_224 = tf.image.resize_images(net_g.outputs, size=[224, 224], method=0, align_corners=False) # resize_generate_image_for_vgg 89 | 90 | net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224+1)/2, reuse=False) 91 | _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224+1)/2, reuse=True) 92 | 93 | ## test inference 94 | net_g_test = SRGAN_g(t_image, is_train=False, reuse=True) 95 | 96 | # ###========================== DEFINE TRAIN OPS ==========================### 97 | # d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real), name='d1') 98 | # d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake), name='d2') 99 | 100 | # d_loss1 = tl.cost.cross_entropy(logits_real, tf.ones_like(logits_real), name='d1') 101 | # d_loss2 = tl.cost.cross_entropy(logits_fake, tf.zeros_like(logits_fake), name='d2') 102 | # 103 | # d_loss = d_loss1 + d_loss2 104 | 105 | # Wasserstein GAN Loss 106 | 107 | d_loss = tf.reduce_mean(logits_real) - tf.reduce_mean(logits_fake) 108 | 109 | # g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake), name='g') 110 | g_gan_loss = 1e-3 * tf.reduce_mean(logits_fake) 111 | mse_loss = tl.cost.mean_squared_error(net_g.outputs , t_target_image, is_mean=True) 112 | vgg_loss = 2e-6 * tl.cost.mean_squared_error(vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True) 113 | 114 | g_loss = mse_loss + vgg_loss + g_gan_loss 115 | 116 | g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True) 117 | d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True) 118 | 119 | with tf.variable_scope('learning_rate'): 120 | lr_v = tf.Variable(lr_init, trainable=False) 121 | ## Pretrain 122 | # g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(mse_loss, var_list=g_vars) 123 | g_optim_init = tf.train.RMSpropOptimizer(lr_v).minimize(mse_loss, var_list=g_vars) 124 | 125 | ## SRGAN 126 | # g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(g_loss, var_list=g_vars) 127 | # d_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(d_loss, var_list=d_vars) 128 | 129 | g_optim = tf.train.RMSpropOptimizer(lr_v).minimize(g_loss, var_list=g_vars) 130 | d_optim = tf.train.RMSpropOptimizer(lr_v).minimize(d_loss, var_list=d_vars) 131 | 132 | # clip op 133 | clip_D = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in d_vars] 134 | 135 | 136 | ###========================== RESTORE MODEL =============================### 137 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)) 138 | tl.layers.initialize_global_variables(sess) 139 | if tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir+'/g_{}.npz'.format(tl.global_flag['mode']), network=net_g) is False: 140 | tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir+'/g_{}_init.npz'.format(tl.global_flag['mode']), network=net_g) 141 | tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir+'/d_{}.npz'.format(tl.global_flag['mode']), network=net_d) 142 | 143 | ###============================= LOAD VGG ===============================### 144 | vgg19_npy_path = "vgg19.npy" 145 | if not os.path.isfile(vgg19_npy_path): 146 | print("Please download vgg19.npz from : 
https://github.com/machrisaa/tensorflow-vgg") 147 | exit() 148 | npz = np.load(vgg19_npy_path, encoding='latin1').item() 149 | 150 | params = [] 151 | for val in sorted( npz.items() ): 152 | W = np.asarray(val[1][0]) 153 | b = np.asarray(val[1][1]) 154 | print(" Loading %s: %s, %s" % (val[0], W.shape, b.shape)) 155 | params.extend([W, b]) 156 | tl.files.assign_params(sess, params, net_vgg) 157 | # net_vgg.print_params(False) 158 | # net_vgg.print_layers() 159 | 160 | ###============================= TRAINING ===============================### 161 | ## use first `batch_size` of train set to have a quick test during training 162 | sample_imgs = train_hr_imgs[0:batch_size] 163 | # sample_imgs = read_all_imgs(train_hr_img_list[0:batch_size], path=config.TRAIN.hr_img_path, n_threads=32) # if no pre-load train set 164 | sample_imgs_384 = tl.prepro.threading_data(sample_imgs, fn=crop_sub_imgs_fn, is_random=False) 165 | print('sample HR sub-image:',sample_imgs_384.shape, sample_imgs_384.min(), sample_imgs_384.max()) 166 | sample_imgs_96 = tl.prepro.threading_data(sample_imgs_384, fn=downsample_fn) 167 | print('sample LR sub-image:', sample_imgs_96.shape, sample_imgs_96.min(), sample_imgs_96.max()) 168 | tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_ginit+'/_train_sample_96.jpg') 169 | tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_ginit+'/_train_sample_384.jpg') 170 | tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_gan+'/_train_sample_96.jpg') 171 | tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_gan+'/_train_sample_384.jpg') 172 | 173 | ###========================= initialize G ====================### 174 | ## fixed learning rate 175 | sess.run(tf.assign(lr_v, lr_init)) 176 | print(" ** fixed learning rate: %f (for init G)" % lr_init) 177 | for epoch in range(0, n_epoch_init+1): 178 | epoch_time = time.time() 179 | total_mse_loss, n_iter = 0, 0 180 | 181 | ## If your machine cannot load all images into memory, you should use 182 | ## this one to load batch of images while training. 183 | # random.shuffle(train_hr_img_list) 184 | # for idx in range(0, len(train_hr_img_list), batch_size): 185 | # step_time = time.time() 186 | # b_imgs_list = train_hr_img_list[idx : idx + batch_size] 187 | # b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=config.TRAIN.hr_img_path) 188 | # b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=crop_sub_imgs_fn, is_random=True) 189 | # b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn) 190 | 191 | ## If your machine have enough memory, please pre-load the whole train set. 
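## For reference, `tl.prepro.threading_data(data, fn, **kwargs)` (used heavily
## below) maps `fn` over every element of `data` with a pool of threads and
## stacks the results; a rough single-threaded sketch of its behaviour:
##
##   def threading_data_sketch(data, fn, **kwargs):
##       return np.asarray([fn(d, **kwargs) for d in data])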
192 | for idx in range(0, len(train_hr_imgs), batch_size): 193 | step_time = time.time() 194 | b_imgs_384 = tl.prepro.threading_data( 195 | train_hr_imgs[idx : idx + batch_size], 196 | fn=crop_sub_imgs_fn, is_random=True) 197 | b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn) 198 | ## update G 199 | errM, _ = sess.run([mse_loss, g_optim_init], {t_image: b_imgs_96, t_target_image: b_imgs_384}) 200 | print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f " % (epoch, n_epoch_init, n_iter, time.time() - step_time, errM)) 201 | total_mse_loss += errM 202 | n_iter += 1 203 | log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (epoch, n_epoch_init, time.time() - epoch_time, total_mse_loss/n_iter) 204 | print(log) 205 | 206 | ## quick evaluation on train set 207 | if (epoch != 0) and (epoch % 10 == 0): 208 | out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})#; print('gen sub-image:', out.shape, out.min(), out.max()) 209 | print("[*] save images") 210 | tl.vis.save_images(out, [ni, ni], save_dir_ginit+'/train_%d.jpg' % epoch) 211 | 212 | ## save model 213 | if (epoch != 0) and (epoch % 10 == 0): 214 | tl.files.save_npz(net_g.all_params, name=checkpoint_dir+'/g_{}_init.npz'.format(tl.global_flag['mode']), sess=sess) 215 | 216 | ###========================= train GAN (SRGAN) =========================### 217 | 218 | # clipping method 219 | # clip_discriminator_var_op = [var.assign(tf.clip_by_value(var, self.clip_values[0], self.clip_values[1])) for 220 | # var in self.discriminator_variables] 221 | 222 | for epoch in range(0, n_epoch+1): 223 | ## update learning rate 224 | if epoch !=0 and (epoch % decay_every == 0): 225 | new_lr_decay = lr_decay ** (epoch // decay_every) 226 | sess.run(tf.assign(lr_v, lr_init * new_lr_decay)) 227 | log = " ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay) 228 | print(log) 229 | elif epoch == 0: 230 | sess.run(tf.assign(lr_v, lr_init)) 231 | log = " ** init lr: %f decay_every_init: %d, lr_decay: %f (for GAN)" % (lr_init, decay_every, lr_decay) 232 | print(log) 233 | 234 | epoch_time = time.time() 235 | total_d_loss, total_g_loss, n_iter = 0, 0, 0 236 | 237 | ## If your machine cannot load all images into memory, you should use 238 | ## this one to load batch of images while training. 239 | # random.shuffle(train_hr_img_list) 240 | # for idx in range(0, len(train_hr_img_list), batch_size): 241 | # step_time = time.time() 242 | # b_imgs_list = train_hr_img_list[idx : idx + batch_size] 243 | # b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=config.TRAIN.hr_img_path) 244 | # b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=crop_sub_imgs_fn, is_random=True) 245 | # b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn) 246 | 247 | ## If your machine have enough memory, please pre-load the whole train set. 
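## Note on the WGAN update in the loop below: the critic step minimizes
## d_loss = E[D(real)] - E[D(fake)] and is immediately followed by the hard
## weight clip `clip_D` to [-0.01, 0.01], so this critic appears to be the
## negative of the one in the WGAN paper and the Wasserstein estimate is
## -d_loss; the generator step minimizing g_gan_loss = 1e-3 * E[D(fake)] is
## consistent with that sign convention.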
248 | for idx in range(0, len(train_hr_imgs), batch_size): 249 | step_time = time.time() 250 | b_imgs_384 = tl.prepro.threading_data( 251 | train_hr_imgs[idx : idx + batch_size], 252 | fn=crop_sub_imgs_fn, is_random=True) 253 | b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn) 254 | ## update D 255 | errD, _, _ = sess.run([d_loss, d_optim, clip_D], {t_image: b_imgs_96, t_target_image: b_imgs_384}) 256 | # d_vars = sess.run(clip_discriminator_var_op) 257 | ## update G 258 | errG, errM, errV, errA, _ = sess.run([g_loss, mse_loss, vgg_loss, g_gan_loss, g_optim], 259 | {t_image: b_imgs_96, t_target_image: b_imgs_384}) 260 | 261 | print("Epoch [%2d/%2d] %4d time: %4.4fs, W_loss: %.8f g_loss: %.8f (mse: %.6f vgg: %.6f adv: %.6f)" 262 | % (epoch, n_epoch, n_iter, time.time() - step_time, errD, errG, errM, errV, errA)) 263 | total_d_loss += errD 264 | total_g_loss += errG 265 | n_iter += 1 266 | 267 | log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (epoch, n_epoch, time.time() - epoch_time, total_d_loss/n_iter, total_g_loss/n_iter) 268 | print(log) 269 | 270 | ## quick evaluation on train set 271 | if (epoch != 0) and (epoch % 10 == 0): 272 | out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})#; print('gen sub-image:', out.shape, out.min(), out.max()) 273 | print("[*] save images") 274 | tl.vis.save_images(out, [ni, ni], save_dir_gan+'/train_%d.jpg' % epoch) 275 | 276 | ## save model 277 | if (epoch != 0) and (epoch % 10 == 0): 278 | tl.files.save_npz(net_g.all_params, name=checkpoint_dir+'/g_{}.npz'.format(tl.global_flag['mode']), sess=sess) 279 | tl.files.save_npz(net_d.all_params, name=checkpoint_dir+'/d_{}.npz'.format(tl.global_flag['mode']), sess=sess) 280 | 281 | def evaluate(): 282 | ## create folders to save result images 283 | save_dir = "samples/{}".format(tl.global_flag['mode']) 284 | tl.files.exists_or_mkdir(save_dir) 285 | checkpoint_dir = "checkpoint" 286 | 287 | ###====================== PRE-LOAD DATA ===========================### 288 | # train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.jpg', printable=False)) 289 | # train_lr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx='.*.jpg', printable=False)) 290 | valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.jpg', printable=False)) 291 | valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.jpg', printable=False)) 292 | 293 | ## If your machine have enough memory, please pre-load the whole train set. 
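## If the validation set does not fit in memory, a lazy loader can stand in
## for `read_all_imgs` below; a minimal sketch, assuming `get_imgs_fn` from
## utils.py reads one image given its file name and folder:
##
##   def iter_imgs(img_list, path=''):
##       for name in img_list:
##           yield get_imgs_fn(name, path)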
294 | # train_hr_imgs = read_all_imgs(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=32)
295 | # for im in train_hr_imgs:
296 | # print(im.shape)
297 | valid_lr_imgs = read_all_imgs(valid_lr_img_list, path=config.VALID.lr_img_path, n_threads=32)
298 | # for im in valid_lr_imgs:
299 | # print(im.shape)
300 | valid_hr_imgs = read_all_imgs(valid_hr_img_list, path=config.VALID.hr_img_path, n_threads=32)
301 | # for im in valid_hr_imgs:
302 | # print(im.shape)
303 | # exit()
304 | 305 | ###========================== DEFINE MODEL ============================###
306 | imid = 64 # 0: penguin 81: butterfly 53: bird 64: castle
307 | #valid_lr_img = valid_lr_imgs[imid]
308 | #valid_hr_img = valid_hr_imgs[imid]
309 | img_name = 'SKU473723_1.jpg'
310 | valid_lr_img = get_imgs_fn(img_name, '/home/ubuntu/dataset/sr_test/testing/') # if you want to test your own image
311 | valid_lr_img = (valid_lr_img / 127.5) - 1 # rescale to [-1, 1]
312 | # print(valid_lr_img.min(), valid_lr_img.max())
313 | 314 | size = valid_lr_img.shape
315 | t_image = tf.placeholder('float32', [None, size[0], size[1], size[2]], name='input_image')
316 | # t_image = tf.placeholder('float32', [1, None, None, 3], name='input_image')
317 | 318 | net_g = SRGAN_g(t_image, is_train=False, reuse=False)
319 | 320 | ###========================== RESTORE G =============================###
321 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
322 | tl.layers.initialize_global_variables(sess)
323 | tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir+'/g_srgan.npz', network=net_g)
324 | 325 | ###======================= EVALUATION =============================###
326 | start_time = time.time()
327 | out = sess.run(net_g.outputs, {t_image: [valid_lr_img]})
328 | print("took: %4.4fs" % (time.time() - start_time))
329 | 330 | print("LR size: %s / generated HR size: %s" % (size, out.shape)) # LR size: (339, 510, 3) / gen HR size: (1, 1356, 2040, 3)
331 | print("[*] save images")
332 | tl.vis.save_image(out[0], save_dir+ '/gen_' + img_name[:-4] + '.jpg')
333 | #tl.vis.save_image(valid_lr_img, save_dir+'/valid_lr.png')
334 | #tl.vis.save_image(valid_hr_img, save_dir+'/valid_hr.png')
335 | 336 | out_bicu = scipy.misc.imresize(valid_lr_img, [size[0]*4, size[1]*4], interp='bicubic', mode=None)
337 | tl.vis.save_image(out_bicu, save_dir + '/bicubic_' + img_name[:-4] + '.jpg')
338 | 339 | if __name__ == '__main__':
340 | import argparse
341 | parser = argparse.ArgumentParser()
342 | 343 | parser.add_argument('--mode', type=str, default='srgan', help='srgan, evaluate')
344 | 345 | args = parser.parse_args()
346 | 347 | tl.global_flag['mode'] = args.mode
348 | 349 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152
350 | os.environ["CUDA_VISIBLE_DEVICES"] = "2"
351 | 352 | 353 | if tl.global_flag['mode'] == 'srgan':
354 | limit_mem()
355 | train()
356 | elif tl.global_flag['mode'] == 'evaluate':
357 | limit_mem()
358 | evaluate()
359 | else:
360 | raise Exception("Unknown --mode")
361 | -------------------------------------------------------------------------------- /tensorlayer/nlp.py: -------------------------------------------------------------------------------- 1 | #!
/usr/bin/python 2 | # -*- coding: utf8 -*- 3 | 4 | 5 | 6 | 7 | import tensorflow as tf 8 | import os 9 | from sys import platform as _platform 10 | import collections 11 | import random 12 | import numpy as np 13 | import warnings 14 | from six.moves import xrange 15 | from tensorflow.python.platform import gfile 16 | import re 17 | 18 | ## Iteration functions 19 | def generate_skip_gram_batch(data, batch_size, num_skips, skip_window, data_index=0): 20 | """Generate a training batch for the Skip-Gram model. 21 | 22 | Parameters 23 | ---------- 24 | data : a list 25 | To present context. 26 | batch_size : an int 27 | Batch size to return. 28 | num_skips : an int 29 | How many times to reuse an input to generate a label. 30 | skip_window : an int 31 | How many words to consider left and right. 32 | data_index : an int 33 | Index of the context location. 34 | without using yield, this code use data_index to instead. 35 | 36 | Returns 37 | -------- 38 | batch : a list 39 | Inputs 40 | labels : a list 41 | Labels 42 | data_index : an int 43 | Index of the context location. 44 | 45 | Examples 46 | -------- 47 | >>> Setting num_skips=2, skip_window=1, use the right and left words. 48 | >>> In the same way, num_skips=4, skip_window=2 means use the nearby 4 words. 49 | 50 | >>> data = [1,2,3,4,5,6,7,8,9,10,11] 51 | >>> batch, labels, data_index = tl.nlp.generate_skip_gram_batch(data=data, batch_size=8, num_skips=2, skip_window=1, data_index=0) 52 | >>> print(batch) 53 | ... [2 2 3 3 4 4 5 5] 54 | >>> print(labels) 55 | ... [[3] 56 | ... [1] 57 | ... [4] 58 | ... [2] 59 | ... [5] 60 | ... [3] 61 | ... [4] 62 | ... [6]] 63 | 64 | References 65 | ----------- 66 | - `TensorFlow word2vec tutorial `_ 67 | """ 68 | # global data_index # you can put data_index outside the function, then 69 | # modify the global data_index in the function without return it. 70 | # note: without using yield, this code use data_index to instead. 71 | assert batch_size % num_skips == 0 72 | assert num_skips <= 2 * skip_window 73 | batch = np.ndarray(shape=(batch_size), dtype=np.int32) 74 | labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32) 75 | span = 2 * skip_window + 1 # [ skip_window target skip_window ] 76 | buffer = collections.deque(maxlen=span) 77 | for _ in range(span): 78 | buffer.append(data[data_index]) 79 | data_index = (data_index + 1) % len(data) 80 | for i in range(batch_size // num_skips): 81 | target = skip_window # target label at the center of the buffer 82 | targets_to_avoid = [ skip_window ] 83 | for j in range(num_skips): 84 | while target in targets_to_avoid: 85 | target = random.randint(0, span - 1) 86 | targets_to_avoid.append(target) 87 | batch[i * num_skips + j] = buffer[skip_window] 88 | labels[i * num_skips + j, 0] = buffer[target] 89 | buffer.append(data[data_index]) 90 | data_index = (data_index + 1) % len(data) 91 | return batch, labels, data_index 92 | 93 | 94 | ## Sampling functions 95 | def sample(a=[], temperature=1.0): 96 | """Sample an index from a probability array. 97 | 98 | Parameters 99 | ---------- 100 | a : a list 101 | List of probabilities. 
102 | temperature : float or None
103 | The higher the temperature, the more uniform the output distribution.\n
104 | When a = [0.1, 0.2, 0.7],\n
105 | temperature = 0.7, the distribution will be sharpened: [ 0.05048273 0.13588945 0.81362782]\n
106 | temperature = 1.0, the distribution will be the same: [0.1 0.2 0.7]\n
107 | temperature = 1.5, the distribution will be flattened: [ 0.16008435 0.25411807 0.58579758]\n
108 | If None, it will be ``np.argmax(a)``
109 | 110 | Notes
111 | ------
112 | No matter what the temperature and input list are, the sum of all probabilities will be one.
113 | Even if input list = [1, 100, 200], the sum of all probabilities will still be one.
114 | 115 | For large vocabulary_size, choose a higher temperature to avoid error.
116 | """
117 | b = np.copy(a)
118 | try:
119 | if temperature == 1:
120 | return np.argmax(np.random.multinomial(1, a, 1))
121 | if temperature is None:
122 | return np.argmax(a)
123 | else:
124 | a = np.log(a) / temperature
125 | a = np.exp(a) / np.sum(np.exp(a))
126 | return np.argmax(np.random.multinomial(1, a, 1))
127 | except:
128 | # np.set_printoptions(threshold=np.nan)
129 | # print(a)
130 | # print(np.sum(a))
131 | # print(np.max(a))
132 | # print(np.min(a))
133 | # exit()
134 | message = "For large vocabulary_size, choose a higher temperature\ 135 | to avoid log error. Hint : use ``sample_top``. "
136 | warnings.warn(message, Warning)
137 | # print(a)
138 | # print(b)
139 | return np.argmax(np.random.multinomial(1, b, 1))
140 | 141 | def sample_top(a=[], top_k=10):
142 | """Sample from ``top_k`` probabilities.
143 | 144 | Parameters
145 | ----------
146 | a : a list
147 | List of probabilities.
148 | top_k : int
149 | Number of candidates to be considered.
150 | """
151 | idx = np.argpartition(a, -top_k)[-top_k:]
152 | probs = a[idx]
153 | # print("new", probs)
154 | probs = probs / np.sum(probs)
155 | choice = np.random.choice(idx, p=probs)
156 | return choice
157 | ## old implementation
158 | # a = np.array(a)
159 | # idx = np.argsort(a)[::-1]
160 | # idx = idx[:top_k]
161 | # # a = a[idx]
162 | # probs = a[idx]
163 | # print("prev", probs)
164 | # # probs = probs / np.sum(probs)
165 | # # choice = np.random.choice(idx, p=probs)
166 | # # return choice
167 | 168 | 169 | ## Vector representations of words (Advanced) UNDOCUMENTED
170 | class SimpleVocabulary(object):
171 | """Simple vocabulary wrapper, see create_vocab().
172 | 173 | Parameters
174 | ------------
175 | vocab : A dictionary of word to word_id.
176 | unk_id : Id of the special 'unknown' word.
177 | """
178 | 179 | def __init__(self, vocab, unk_id):
180 | """Initializes the vocabulary."""
181 | 182 | 183 | self._vocab = vocab
184 | self._unk_id = unk_id
185 | 186 | def word_to_id(self, word):
187 | """Returns the integer id of a word string."""
188 | if word in self._vocab:
189 | return self._vocab[word]
190 | else:
191 | return self._unk_id
192 | 193 | class Vocabulary(object):
194 | """Vocabulary class built from a given vocabulary file, providing word-to-id and id-to-word conversion; 195 | see create_vocab() and ``tutorial_tfrecord3.py``.
196 | 197 | Parameters
198 | -----------
199 | vocab_file : File containing the vocabulary, where the words are the first 200 | whitespace-separated token on each line (other tokens are ignored) and 201 | the word ids are the corresponding line numbers.
202 | start_word : Special word denoting sentence start.
203 | end_word : Special word denoting sentence end.
204 | unk_word : Special word denoting unknown words.
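pad_word : Special word denoting padding.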
205 | 206 | Properties
207 | ------------
208 | vocab : a dictionary from word to id.
209 | reverse_vocab : a list from id to word.
210 | start_id : int of start id
211 | end_id : int of end id
212 | unk_id : int of unk id
213 | pad_id : int of padding id
214 | 215 | Vocab_files
216 | -------------
217 | >>> Looks as follows; includes `start_word` and `end_word` but no `unk_word` .
218 | >>> a 969108
219 | >>> <S> 586368
220 | >>> </S> 586368
221 | >>> . 440479
222 | >>> on 213612
223 | >>> of 202290
224 | >>> the 196219
225 | >>> in 182598
226 | >>> with 152984
227 | >>> and 139109
228 | >>> is 97322
229 | """
230 | 231 | def __init__(self,
232 | vocab_file,
233 | start_word="<S>",
234 | end_word="</S>",
235 | unk_word="<UNK>",
236 | pad_word="<PAD>"):
237 | if not tf.gfile.Exists(vocab_file):
238 | tf.logging.fatal("Vocab file %s not found.", vocab_file)
239 | tf.logging.info("Initializing vocabulary from file: %s", vocab_file)
240 | 241 | with tf.gfile.GFile(vocab_file, mode="r") as f:
242 | reverse_vocab = list(f.readlines())
243 | reverse_vocab = [line.split()[0] for line in reverse_vocab]
244 | assert start_word in reverse_vocab
245 | assert end_word in reverse_vocab
246 | if unk_word not in reverse_vocab:
247 | reverse_vocab.append(unk_word)
248 | vocab = dict([(x, y) for (y, x) in enumerate(reverse_vocab)])
249 | 250 | print(" [TL] Vocabulary from %s : %s %s %s" % (vocab_file, start_word, end_word, unk_word))
251 | print(" vocabulary with %d words (includes start_word, end_word, unk_word)" % len(vocab))
252 | # tf.logging.info(" vocabulary with %d words" % len(vocab))
253 | 254 | self.vocab = vocab # vocab[word] = id
255 | self.reverse_vocab = reverse_vocab # reverse_vocab[id] = word
256 | 257 | # Save special word ids.
258 | self.start_id = vocab[start_word]
259 | self.end_id = vocab[end_word]
260 | self.unk_id = vocab[unk_word]
261 | self.pad_id = vocab[pad_word]
262 | print(" start_id: %d" % self.start_id)
263 | print(" end_id: %d" % self.end_id)
264 | print(" unk_id: %d" % self.unk_id)
265 | print(" pad_id: %d" % self.pad_id)
266 | 267 | def word_to_id(self, word):
268 | """Returns the integer word id of a word string."""
269 | if word in self.vocab:
270 | return self.vocab[word]
271 | else:
272 | return self.unk_id
273 | 274 | def id_to_word(self, word_id):
275 | """Returns the word string of an integer word id."""
276 | if word_id >= len(self.reverse_vocab):
277 | return self.reverse_vocab[self.unk_id]
278 | else:
279 | return self.reverse_vocab[word_id]
280 | 281 | def process_sentence(sentence, start_word="<S>", end_word="</S>"):
282 | """Converts a sentence string into a list of string words, adding start_word and end_word; 283 | see ``create_vocab()`` and ``tutorial_tfrecord3.py``.
284 | 285 | Parameters
286 | ---------
287 | sentence : a sentence in string.
288 | start_word : a string or None; if None, no start word will be prepended.
289 | end_word : a string or None; if None, no end word will be appended.
290 | 291 | Returns
292 | ---------
293 | A list of strings; the processed caption.
294 | 295 | Examples
296 | -----------
297 | >>> c = "how are you?"
298 | >>> c = tl.nlp.process_sentence(c)
299 | >>> print(c)
300 | ... ['<S>', 'how', 'are', 'you', '?', '</S>']
301 | 302 | Notes
303 | -------
304 | - You have to install the following package and data.
305 | - `Installing NLTK <http://www.nltk.org/install.html>`_
306 | - `Installing NLTK data <http://www.nltk.org/data.html>`_
307 | """
308 | try:
309 | import nltk
310 | except:
311 | raise Exception("Hint : NLTK is required.")
312 | if start_word is not None:
313 | process_sentence = [start_word]
314 | else:
315 | process_sentence = []
316 | process_sentence.extend(nltk.tokenize.word_tokenize(sentence.lower()))
317 | if end_word is not None:
318 | process_sentence.append(end_word)
319 | return process_sentence
320 | 321 | def create_vocab(sentences, word_counts_output_file, min_word_count=1):
322 | """Creates a vocabulary mapping word to word_id; see ``tutorial_tfrecord3.py``.
323 | 324 | The vocabulary is saved to disk in a text file of word counts. The id of each 325 | word in the file is its corresponding 0-based line number.
326 | 327 | Parameters
328 | ------------
329 | sentences : a list of lists of strings.
330 | word_counts_output_file : A string
331 | The file name.
332 | min_word_count : an int
333 | Minimum number of occurrences for a word.
334 | 335 | Returns
336 | --------
337 | - tl.nlp.SimpleVocabulary object.
338 | 339 | See Also
340 | -----
341 | - ``tl.nlp.build_vocab()``
342 | 343 | Examples
344 | --------
345 | >>> captions = ["one two , three", "four five five"]
346 | >>> processed_capts = []
347 | >>> for c in captions:
348 | >>> c = tl.nlp.process_sentence(c, start_word="<S>", end_word="</S>")
349 | >>> processed_capts.append(c)
350 | >>> print(processed_capts)
351 | ...[['<S>', 'one', 'two', ',', 'three', '</S>'], ['<S>', 'four', 'five', 'five', '</S>']]
352 | 353 | >>> tl.nlp.create_vocab(processed_capts, word_counts_output_file='vocab.txt', min_word_count=1)
354 | ... [TL] Creating vocabulary.
355 | ... Total words: 8
356 | ... Words in vocabulary: 8
357 | ... Wrote vocabulary file: vocab.txt
358 | >>> vocab = tl.nlp.Vocabulary('vocab.txt', start_word="<S>", end_word="</S>", unk_word="<UNK>")
359 | ... INFO:tensorflow:Initializing vocabulary from file: vocab.txt
360 | ... [TL] Vocabulary from vocab.txt : <S> </S> <UNK>
361 | ... vocabulary with 10 words (includes start_word, end_word, unk_word)
362 | ... start_id: 2
363 | ... end_id: 3
364 | ... unk_id: 9
365 | ... pad_id: 0
366 | """
367 | from collections import Counter
368 | print(" [TL] Creating vocabulary.")
369 | counter = Counter()
370 | for c in sentences:
371 | counter.update(c)
372 | # print('c',c)
373 | print(" Total words: %d" % len(counter))
374 | 375 | # Filter uncommon words and sort by descending count.
376 | word_counts = [x for x in counter.items() if x[1] >= min_word_count]
377 | word_counts.sort(key=lambda x: x[1], reverse=True)
378 | word_counts = [("<PAD>", 0)] + word_counts # 1st id should be reserved for padding
379 | # print(word_counts)
380 | print(" Words in vocabulary: %d" % len(word_counts))
381 | 382 | # Write out the word counts file.
383 | with tf.gfile.FastGFile(word_counts_output_file, "w") as f:
384 | f.write("\n".join(["%s %d" % (w, c) for w, c in word_counts]))
385 | print(" Wrote vocabulary file: %s" % word_counts_output_file)
386 | 387 | # Create the vocabulary dictionary.
388 | reverse_vocab = [x[0] for x in word_counts]
389 | unk_id = len(reverse_vocab)
390 | vocab_dict = dict([(x, y) for (y, x) in enumerate(reverse_vocab)])
391 | vocab = SimpleVocabulary(vocab_dict, unk_id)
392 | 393 | return vocab
394 | 395 | 396 | ## Vector representations of words
397 | def simple_read_words(filename="nietzsche.txt"):
398 | """Read context from a file without any preprocessing.
399 | 400 | Parameters
401 | ----------
402 | filename : a string
403 | A file path (like .txt file)
404 | 405 | Returns
406 | --------
407 | The context in a string
408 | """
409 | with open(filename, "r") as f:
410 | words = f.read()
411 | return words
412 | 413 | def read_words(filename="nietzsche.txt", replace = ['\n', '<eos>']):
414 | """Read a file into a list of words. Note that this function cannot handle punctuation.
415 | For a customized read_words method, see ``tutorial_generate_text.py``.
416 | 417 | Parameters
418 | ----------
419 | filename : a string
420 | A file path (like .txt file),
421 | replace : a list
422 | [original string, target string], to disable replace use ['', '']
423 | 424 | Returns
425 | --------
426 | The context in a list, split by space by default, and use ``'<eos>'`` to represent ``'\n'``, 427 | e.g. ``[... 'how', 'useful', 'it', "'s" ... ]``.
428 | 429 | Code References
430 | ---------------
431 | - `tensorflow.models.rnn.ptb.reader `_
432 | """
433 | with tf.gfile.GFile(filename, "r") as f:
434 | try: # python 3.4 or older
435 | context_list = f.read().replace(*replace).split()
436 | except: # python 3.5
437 | f.seek(0)
438 | replace = [x.encode('utf-8') for x in replace]
439 | context_list = f.read().replace(*replace).split()
440 | return context_list
441 | 442 | def read_analogies_file(eval_file='questions-words.txt', word2id={}):
443 | """Reads through an analogy question file, return its id format.
444 | 445 | Parameters
446 | ----------
447 | eval_file : a string
448 | The file name.
449 | word2id : a dictionary
450 | Mapping words to unique IDs.
451 | 452 | Returns
453 | --------
454 | analogy_questions : a [n, 4] numpy array containing the analogy question's 455 | word ids.
456 | questions_skipped: questions skipped due to unknown words.
457 | 458 | Examples
459 | ---------
460 | >>> eval_file should be in this format :
461 | >>> : capital-common-countries
462 | >>> Athens Greece Baghdad Iraq
463 | >>> Athens Greece Bangkok Thailand
464 | >>> Athens Greece Beijing China
465 | >>> Athens Greece Berlin Germany
466 | >>> Athens Greece Bern Switzerland
467 | >>> Athens Greece Cairo Egypt
468 | >>> Athens Greece Canberra Australia
469 | >>> Athens Greece Hanoi Vietnam
470 | >>> Athens Greece Havana Cuba
471 | ...
472 | 473 | >>> words = tl.files.load_matt_mahoney_text8_dataset()
474 | >>> data, count, dictionary, reverse_dictionary = \ 475 | tl.nlp.build_words_dataset(words, vocabulary_size, True)
476 | >>> analogy_questions = tl.nlp.read_analogies_file( \ 477 | eval_file='questions-words.txt', word2id=dictionary)
478 | >>> print(analogy_questions)
479 | ... [[ 3068 1248 7161 1581]
480 | ... [ 3068 1248 28683 5642]
481 | ... [ 3068 1248 3878 486]
482 | ... ...,
483 | ... [ 1216 4309 19982 25506]
484 | ... [ 1216 4309 3194 8650]
485 | ... [ 1216 4309 140 312]]
486 | """
487 | questions = []
488 | questions_skipped = 0
489 | with open(eval_file, "rb") as analogy_f:
490 | for line in analogy_f:
491 | if line.startswith(b":"): # Skip comments.
492 | continue 493 | words = line.strip().lower().split(b" ") # lowercase 494 | ids = [word2id.get(w.strip()) for w in words] 495 | if None in ids or len(ids) != 4: 496 | questions_skipped += 1 497 | else: 498 | questions.append(np.array(ids)) 499 | print("Eval analogy file: ", eval_file) 500 | print("Questions: ", len(questions)) 501 | print("Skipped: ", questions_skipped) 502 | analogy_questions = np.array(questions, dtype=np.int32) 503 | return analogy_questions 504 | 505 | def build_vocab(data): 506 | """Build vocabulary. 507 | Given the context in list format. 508 | Return the vocabulary, which is a dictionary for word to id. 509 | e.g. {'campbell': 2587, 'atlantic': 2247, 'aoun': 6746 .... } 510 | 511 | Parameters 512 | ---------- 513 | data : a list of string 514 | the context in list format 515 | 516 | Returns 517 | -------- 518 | word_to_id : a dictionary 519 | mapping words to unique IDs. e.g. {'campbell': 2587, 'atlantic': 2247, 'aoun': 6746 .... } 520 | 521 | Code References 522 | --------------- 523 | - `tensorflow.models.rnn.ptb.reader `_ 524 | 525 | Examples 526 | -------- 527 | >>> data_path = os.getcwd() + '/simple-examples/data' 528 | >>> train_path = os.path.join(data_path, "ptb.train.txt") 529 | >>> word_to_id = build_vocab(read_txt_words(train_path)) 530 | """ 531 | # data = _read_words(filename) 532 | counter = collections.Counter(data) 533 | # print('counter', counter) # dictionary for the occurrence number of each word, e.g. 'banknote': 1, 'photography': 1, 'kia': 1 534 | count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0])) 535 | # print('count_pairs',count_pairs) # convert dictionary to list of tuple, e.g. ('ssangyong', 1), ('swapo', 1), ('wachter', 1) 536 | words, _ = list(zip(*count_pairs)) 537 | word_to_id = dict(zip(words, range(len(words)))) 538 | # print(words) # list of words 539 | # print(word_to_id) # dictionary for word to id, e.g. 'campbell': 2587, 'atlantic': 2247, 'aoun': 6746 540 | return word_to_id 541 | 542 | def build_reverse_dictionary(word_to_id): 543 | """Given a dictionary for converting word to integer id. 544 | Returns a reverse dictionary for converting a id to word. 545 | 546 | Parameters 547 | ---------- 548 | word_to_id : dictionary 549 | mapping words to unique ids 550 | 551 | Returns 552 | -------- 553 | reverse_dictionary : a dictionary 554 | mapping ids to words 555 | """ 556 | reverse_dictionary = dict(zip(word_to_id.values(), word_to_id.keys())) 557 | return reverse_dictionary 558 | 559 | def build_words_dataset(words=[], vocabulary_size=50000, printable=True, unk_key = 'UNK'): 560 | """Build the words dictionary and replace rare words with 'UNK' token. 561 | The most common word has the smallest integer id. 562 | 563 | Parameters 564 | ---------- 565 | words : a list of string or byte 566 | The context in list format. You may need to do preprocessing on the words, 567 | such as lower case, remove marks etc. 568 | vocabulary_size : an int 569 | The maximum vocabulary size, limiting the vocabulary size. 570 | Then the script replaces rare words with 'UNK' token. 571 | printable : boolean 572 | Whether to print the read vocabulary size of the given words. 573 | unk_key : a string 574 | Unknown words = unk_key 575 | 576 | Returns 577 | -------- 578 | data : a list of integer 579 | The context in a list of ids 580 | count : a list of tuple and list 581 | count[0] is a list : the number of rare words\n 582 | count[1:] are tuples : the number of occurrence of each word\n 583 | e.g. 
[['UNK', 418391], (b'the', 1061396), (b'of', 593677), (b'and', 416629), (b'one', 411764)] 584 | dictionary : a dictionary 585 | word_to_id, mapping words to unique IDs. 586 | reverse_dictionary : a dictionary 587 | id_to_word, mapping id to unique word. 588 | 589 | Examples 590 | -------- 591 | >>> words = tl.files.load_matt_mahoney_text8_dataset() 592 | >>> vocabulary_size = 50000 593 | >>> data, count, dictionary, reverse_dictionary = tl.nlp.build_words_dataset(words, vocabulary_size) 594 | 595 | Code References 596 | ----------------- 597 | - `tensorflow/examples/tutorials/word2vec/word2vec_basic.py `_ 598 | """ 599 | import collections 600 | count = [[unk_key, -1]] 601 | count.extend(collections.Counter(words).most_common(vocabulary_size - 1)) 602 | dictionary = dict() 603 | for word, _ in count: 604 | dictionary[word] = len(dictionary) 605 | data = list() 606 | unk_count = 0 607 | for word in words: 608 | if word in dictionary: 609 | index = dictionary[word] 610 | else: 611 | index = 0 # dictionary['UNK'] 612 | unk_count += 1 613 | data.append(index) 614 | count[0][1] = unk_count 615 | reverse_dictionary = dict(zip(dictionary.values(), dictionary.keys())) 616 | if printable: 617 | print('Real vocabulary size %d' % len(collections.Counter(words).keys())) 618 | print('Limited vocabulary size {}'.format(vocabulary_size)) 619 | assert len(collections.Counter(words).keys()) >= vocabulary_size , \ 620 | "the limited vocabulary_size must be less than or equal to the read vocabulary_size" 621 | return data, count, dictionary, reverse_dictionary 622 | 623 | def words_to_word_ids(data=[], word_to_id={}, unk_key = 'UNK'): 624 | """Given a context (words) in list format and the vocabulary, 625 | Returns a list of IDs to represent the context. 626 | 627 | Parameters 628 | ---------- 629 | data : a list of string or byte 630 | the context in list format 631 | word_to_id : a dictionary 632 | mapping words to unique IDs. 633 | unk_key : a string 634 | Unknown words = unk_key 635 | 636 | Returns 637 | -------- 638 | A list of IDs to represent the context. 639 | 640 | Examples 641 | -------- 642 | >>> words = tl.files.load_matt_mahoney_text8_dataset() 643 | >>> vocabulary_size = 50000 644 | >>> data, count, dictionary, reverse_dictionary = \ 645 | ... tl.nlp.build_words_dataset(words, vocabulary_size, True) 646 | >>> context = [b'hello', b'how', b'are', b'you'] 647 | >>> ids = tl.nlp.words_to_word_ids(words, dictionary) 648 | >>> context = tl.nlp.word_ids_to_words(ids, reverse_dictionary) 649 | >>> print(ids) 650 | ... [6434, 311, 26, 207] 651 | >>> print(context) 652 | ... 
[b'hello', b'how', b'are', b'you'] 653 | 654 | Code References 655 | --------------- 656 | - `tensorflow.models.rnn.ptb.reader `_ 657 | """ 658 | # if isinstance(data[0], six.string_types): 659 | # print(type(data[0])) 660 | # # exit() 661 | # print(data[0]) 662 | # print(word_to_id) 663 | # return [word_to_id[str(word)] for word in data] 664 | # else: 665 | 666 | word_ids = [] 667 | for word in data: 668 | if word_to_id.get(word) is not None: 669 | word_ids.append(word_to_id[word]) 670 | else: 671 | word_ids.append(word_to_id[unk_key]) 672 | return word_ids 673 | # return [word_to_id[word] for word in data] # this one 674 | 675 | # if isinstance(data[0], str): 676 | # # print('is a string object') 677 | # return [word_to_id[word] for word in data] 678 | # else:#if isinstance(s, bytes): 679 | # # print('is a unicode object') 680 | # # print(data[0]) 681 | # return [word_to_id[str(word)] f 682 | 683 | def word_ids_to_words(data, id_to_word): 684 | """Given a context (ids) in list format and the vocabulary, 685 | Returns a list of words to represent the context. 686 | 687 | Parameters 688 | ---------- 689 | data : a list of integer 690 | the context in list format 691 | id_to_word : a dictionary 692 | mapping id to unique word. 693 | 694 | Returns 695 | -------- 696 | A list of string or byte to represent the context. 697 | 698 | Examples 699 | --------- 700 | >>> see words_to_word_ids 701 | """ 702 | return [id_to_word[i] for i in data] 703 | 704 | def save_vocab(count=[], name='vocab.txt'): 705 | """Save the vocabulary to a file so the model can be reloaded. 706 | 707 | Parameters 708 | ---------- 709 | count : a list of tuple and list 710 | count[0] is a list : the number of rare words\n 711 | count[1:] are tuples : the number of occurrence of each word\n 712 | e.g. [['UNK', 418391], (b'the', 1061396), (b'of', 593677), (b'and', 416629), (b'one', 411764)] 713 | 714 | Examples 715 | --------- 716 | >>> words = tl.files.load_matt_mahoney_text8_dataset() 717 | >>> vocabulary_size = 50000 718 | >>> data, count, dictionary, reverse_dictionary = \ 719 | ... tl.nlp.build_words_dataset(words, vocabulary_size, True) 720 | >>> tl.nlp.save_vocab(count, name='vocab_text8.txt') 721 | >>> vocab_text8.txt 722 | ... UNK 418391 723 | ... the 1061396 724 | ... of 593677 725 | ... and 416629 726 | ... one 411764 727 | ... in 372201 728 | ... a 325873 729 | ... to 316376 730 | """ 731 | pwd = os.getcwd() 732 | vocabulary_size = len(count) 733 | with open(os.path.join(pwd, name), "w") as f: 734 | for i in xrange(vocabulary_size): 735 | f.write("%s %d\n" % (tf.compat.as_text(count[i][0]), count[i][1])) 736 | print("%d vocab saved to %s in %s" % (vocabulary_size, name, pwd)) 737 | 738 | ## Functions for translation 739 | def basic_tokenizer(sentence, _WORD_SPLIT=re.compile(b"([.,!?\"':;)(])")): 740 | """Very basic tokenizer: split the sentence into a list of tokens. 741 | 742 | Parameters 743 | ----------- 744 | sentence : tensorflow.python.platform.gfile.GFile Object 745 | _WORD_SPLIT : regular expression for word spliting. 746 | 747 | 748 | Examples 749 | -------- 750 | >>> see create_vocabulary 751 | >>> from tensorflow.python.platform import gfile 752 | >>> train_path = "wmt/giga-fren.release2" 753 | >>> with gfile.GFile(train_path + ".en", mode="rb") as f: 754 | >>> for line in f: 755 | >>> tokens = tl.nlp.basic_tokenizer(line) 756 | >>> print(tokens) 757 | >>> exit() 758 | ... [b'Changing', b'Lives', b'|', b'Changing', b'Society', b'|', b'How', 759 | ... 
b'It', b'Works', b'|', b'Technology', b'Drives', b'Change', b'Home', 760 | ... b'|', b'Concepts', b'|', b'Teachers', b'|', b'Search', b'|', b'Overview', 761 | ... b'|', b'Credits', b'|', b'HHCC', b'Web', b'|', b'Reference', b'|', 762 | ... b'Feedback', b'Virtual', b'Museum', b'of', b'Canada', b'Home', b'Page'] 763 | 764 | References 765 | ---------- 766 | - Code from ``/tensorflow/models/rnn/translation/data_utils.py`` 767 | """ 768 | words = [] 769 | sentence = tf.compat.as_bytes(sentence) 770 | for space_separated_fragment in sentence.strip().split(): 771 | words.extend(re.split(_WORD_SPLIT, space_separated_fragment)) 772 | return [w for w in words if w] 773 | 774 | def create_vocabulary(vocabulary_path, data_path, max_vocabulary_size, 775 | tokenizer=None, normalize_digits=True, 776 | _DIGIT_RE=re.compile(br"\d"), 777 | _START_VOCAB=[b"_PAD", b"_GO", b"_EOS", b"_UNK"]): 778 | """Create vocabulary file (if it does not exist yet) from data file. 779 | 780 | Data file is assumed to contain one sentence per line. Each sentence is 781 | tokenized and digits are normalized (if normalize_digits is set). 782 | Vocabulary contains the most-frequent tokens up to max_vocabulary_size. 783 | We write it to vocabulary_path in a one-token-per-line format, so that later 784 | token in the first line gets id=0, second line gets id=1, and so on. 785 | 786 | Parameters 787 | ----------- 788 | vocabulary_path : path where the vocabulary will be created. 789 | data_path : data file that will be used to create vocabulary. 790 | max_vocabulary_size : limit on the size of the created vocabulary. 791 | tokenizer : a function to use to tokenize each data sentence. 792 | if None, basic_tokenizer will be used. 793 | normalize_digits : Boolean 794 | if true, all digits are replaced by 0s. 795 | 796 | References 797 | ---------- 798 | - Code from ``/tensorflow/models/rnn/translation/data_utils.py`` 799 | """ 800 | if not gfile.Exists(vocabulary_path): 801 | print("Creating vocabulary %s from data %s" % (vocabulary_path, data_path)) 802 | vocab = {} 803 | with gfile.GFile(data_path, mode="rb") as f: 804 | counter = 0 805 | for line in f: 806 | counter += 1 807 | if counter % 100000 == 0: 808 | print(" processing line %d" % counter) 809 | tokens = tokenizer(line) if tokenizer else basic_tokenizer(line) 810 | for w in tokens: 811 | word = re.sub(_DIGIT_RE, b"0", w) if normalize_digits else w 812 | if word in vocab: 813 | vocab[word] += 1 814 | else: 815 | vocab[word] = 1 816 | vocab_list = _START_VOCAB + sorted(vocab, key=vocab.get, reverse=True) 817 | if len(vocab_list) > max_vocabulary_size: 818 | vocab_list = vocab_list[:max_vocabulary_size] 819 | with gfile.GFile(vocabulary_path, mode="wb") as vocab_file: 820 | for w in vocab_list: 821 | vocab_file.write(w + b"\n") 822 | else: 823 | print("Vocabulary %s from data %s exists" % (vocabulary_path, data_path)) 824 | 825 | def initialize_vocabulary(vocabulary_path): 826 | """Initialize vocabulary from file, return the word_to_id (dictionary) 827 | and id_to_word (list). 828 | 829 | We assume the vocabulary is stored one-item-per-line, so a file:\n 830 | dog\n 831 | cat\n 832 | will result in a vocabulary {"dog": 0, "cat": 1}, and this function will 833 | also return the reversed-vocabulary ["dog", "cat"]. 834 | 835 | Parameters 836 | ----------- 837 | vocabulary_path : path to the file containing the vocabulary. 838 | 839 | Returns 840 | -------- 841 | vocab : a dictionary 842 | Word to id. A dictionary mapping string to integers. 
843 | rev_vocab : a list 844 | Id to word. The reversed vocabulary (a list, which reverses the vocabulary mapping). 845 | 846 | Examples 847 | --------- 848 | >>> Assume 'test' contains 849 | ... dog 850 | ... cat 851 | ... bird 852 | >>> vocab, rev_vocab = tl.nlp.initialize_vocabulary("test") 853 | >>> print(vocab) 854 | >>> {b'cat': 1, b'dog': 0, b'bird': 2} 855 | >>> print(rev_vocab) 856 | >>> [b'dog', b'cat', b'bird'] 857 | 858 | Raises 859 | ------- 860 | ValueError : if the provided vocabulary_path does not exist. 861 | """ 862 | if gfile.Exists(vocabulary_path): 863 | rev_vocab = [] 864 | with gfile.GFile(vocabulary_path, mode="rb") as f: 865 | rev_vocab.extend(f.readlines()) 866 | rev_vocab = [tf.compat.as_bytes(line.strip()) for line in rev_vocab] 867 | vocab = dict([(x, y) for (y, x) in enumerate(rev_vocab)]) 868 | return vocab, rev_vocab 869 | else: 870 | raise ValueError("Vocabulary file %s not found.", vocabulary_path) 871 | 872 | def sentence_to_token_ids(sentence, vocabulary, 873 | tokenizer=None, normalize_digits=True, 874 | UNK_ID=3, _DIGIT_RE=re.compile(br"\d")): 875 | """Convert a string to list of integers representing token-ids. 876 | 877 | For example, a sentence "I have a dog" may become tokenized into 878 | ["I", "have", "a", "dog"] and with vocabulary {"I": 1, "have": 2, 879 | "a": 4, "dog": 7"} this function will return [1, 2, 4, 7]. 880 | 881 | Parameters 882 | ----------- 883 | sentence : tensorflow.python.platform.gfile.GFile Object 884 | The sentence in bytes format to convert to token-ids.\n 885 | see basic_tokenizer(), data_to_token_ids() 886 | vocabulary : a dictionary mapping tokens to integers. 887 | tokenizer : a function to use to tokenize each sentence; 888 | If None, basic_tokenizer will be used. 889 | normalize_digits : Boolean 890 | If true, all digits are replaced by 0s. 891 | 892 | Returns 893 | -------- 894 | A list of integers, the token-ids for the sentence. 895 | """ 896 | 897 | if tokenizer: 898 | words = tokenizer(sentence) 899 | else: 900 | words = basic_tokenizer(sentence) 901 | if not normalize_digits: 902 | return [vocabulary.get(w, UNK_ID) for w in words] 903 | # Normalize digits by 0 before looking words up in the vocabulary. 904 | return [vocabulary.get(re.sub(_DIGIT_RE, b"0", w), UNK_ID) for w in words] 905 | 906 | def data_to_token_ids(data_path, target_path, vocabulary_path, 907 | tokenizer=None, normalize_digits=True, 908 | UNK_ID=3, _DIGIT_RE=re.compile(br"\d")): 909 | """Tokenize data file and turn into token-ids using given vocabulary file. 910 | 911 | This function loads data line-by-line from data_path, calls the above 912 | sentence_to_token_ids, and saves the result to target_path. See comment 913 | for sentence_to_token_ids on the details of token-ids format. 914 | 915 | Parameters 916 | ----------- 917 | data_path : path to the data file in one-sentence-per-line format. 918 | target_path : path where the file with token-ids will be created. 919 | vocabulary_path : path to the vocabulary file. 920 | tokenizer : a function to use to tokenize each sentence; 921 | if None, basic_tokenizer will be used. 922 | normalize_digits : Boolean; if true, all digits are replaced by 0s. 
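Examples
---------
>>> # a hypothetical call; the three file paths are placeholders
>>> tl.nlp.data_to_token_ids("train.en", "train.ids.en", "vocab.en")
... Tokenizing data in train.en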
923 | 924 | References 925 | ---------- 926 | - Code from ``/tensorflow/models/rnn/translation/data_utils.py`` 927 | """ 928 | if not gfile.Exists(target_path): 929 | print("Tokenizing data in %s" % data_path) 930 | vocab, _ = initialize_vocabulary(vocabulary_path) 931 | with gfile.GFile(data_path, mode="rb") as data_file: 932 | with gfile.GFile(target_path, mode="w") as tokens_file: 933 | counter = 0 934 | for line in data_file: 935 | counter += 1 936 | if counter % 100000 == 0: 937 | print(" tokenizing line %d" % counter) 938 | token_ids = sentence_to_token_ids(line, vocab, tokenizer, 939 | normalize_digits, UNK_ID=UNK_ID, 940 | _DIGIT_RE=_DIGIT_RE) 941 | tokens_file.write(" ".join([str(tok) for tok in token_ids]) + "\n") 942 | else: 943 | print("Target path %s exists" % target_path) 944 | -------------------------------------------------------------------------------- /tensorlayer/ops.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- coding: utf8 -*- 3 | 4 | 5 | 6 | 7 | import tensorflow as tf 8 | import os 9 | import sys 10 | from sys import platform as _platform 11 | 12 | 13 | def exit_tf(sess=None): 14 | """Close tensorboard and nvidia-process if available 15 | 16 | Parameters 17 | ---------- 18 | sess : a session instance of TensorFlow 19 | TensorFlow session 20 | """ 21 | text = "[tl] Close tensorboard and nvidia-process if available" 22 | sess.close() 23 | # import time 24 | # time.sleep(2) 25 | if _platform == "linux" or _platform == "linux2": 26 | print('linux: %s' % text) 27 | os.system('nvidia-smi') 28 | os.system('fuser 6006/tcp -k') # kill tensorboard 6006 29 | os.system("nvidia-smi | grep python |awk '{print $3}'|xargs kill") # kill all nvidia-smi python process 30 | elif _platform == "darwin": 31 | print('OS X: %s' % text) 32 | os.system("lsof -i tcp:6006 | grep -v PID | awk '{print $2}' | xargs kill") # kill tensorboard 6006 33 | elif _platform == "win32": 34 | print('Windows: %s' % text) 35 | else: 36 | print(_platform) 37 | exit() 38 | 39 | def clear_all(printable=True): 40 | """Clears all the placeholder variables of keep prob, 41 | including keeping probabilities of all dropout, denoising, dropconnect etc. 42 | 43 | Parameters 44 | ---------- 45 | printable : boolean 46 | If True, print all deleted variables. 47 | """ 48 | print('clear all .....................................') 49 | gl = globals().copy() 50 | for var in gl: 51 | if var[0] == '_': continue 52 | if 'func' in str(globals()[var]): continue 53 | if 'module' in str(globals()[var]): continue 54 | if 'class' in str(globals()[var]): continue 55 | 56 | if printable: 57 | print(" clear_all ------- %s" % str(globals()[var])) 58 | 59 | del globals()[var] 60 | 61 | # def clear_all2(vars, printable=True): 62 | # """ 63 | # The :function:`clear_all()` Clears all the placeholder variables of keep prob, 64 | # including keeping probabilities of all dropout, denoising, dropconnect 65 | # Parameters 66 | # ---------- 67 | # printable : if True, print all deleted variables. 68 | # """ 69 | # print('clear all .....................................') 70 | # for var in vars: 71 | # if var[0] == '_': continue 72 | # if 'func' in str(var): continue 73 | # if 'module' in str(var): continue 74 | # if 'class' in str(var): continue 75 | # 76 | # if printable: 77 | # print(" clear_all ------- %s" % str(var)) 78 | # 79 | # del var 80 | 81 | def set_gpu_fraction(sess=None, gpu_fraction=0.3): 82 | """Set the GPU memory fraction for the application. 
83 | 84 | Parameters 85 | ---------- 86 | sess : a session instance of TensorFlow 87 | TensorFlow session 88 | gpu_fraction : a float 89 | Fraction of GPU memory, (0 ~ 1] 90 | 91 | References 92 | ---------- 93 | - `TensorFlow using GPU `_ 94 | """ 95 | print(" tensorlayer: GPU MEM Fraction %f" % gpu_fraction) 96 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_fraction) 97 | sess = tf.Session(config = tf.ConfigProto(gpu_options = gpu_options)) 98 | return sess 99 | 100 | 101 | 102 | 103 | 104 | def disable_print(): 105 | """Disable console output, ``suppress_stdout`` is recommended. 106 | 107 | Examples 108 | --------- 109 | >>> print("You can see me") 110 | >>> tl.ops.disable_print() 111 | >>> print(" You can't see me") 112 | >>> tl.ops.enable_print() 113 | >>> print("You can see me") 114 | """ 115 | # sys.stdout = os.devnull # this one kill the process 116 | sys.stdout = None 117 | sys.stderr = os.devnull 118 | 119 | def enable_print(): 120 | """Enable console output, ``suppress_stdout`` is recommended. 121 | 122 | Examples 123 | -------- 124 | - see tl.ops.disable_print() 125 | """ 126 | sys.stdout = sys.__stdout__ 127 | sys.stderr = sys.__stderr__ 128 | 129 | 130 | # class temporary_disable_print: 131 | # """Temporarily disable console output. 132 | # 133 | # Examples 134 | # --------- 135 | # >>> print("You can see me") 136 | # >>> with tl.ops.temporary_disable_print() as t: 137 | # >>> print("You can't see me") 138 | # >>> print("You can see me") 139 | # """ 140 | # def __init__(self): 141 | # pass 142 | # def __enter__(self): 143 | # sys.stdout = None 144 | # sys.stderr = os.devnull 145 | # def __exit__(self, type, value, traceback): 146 | # sys.stdout = sys.__stdout__ 147 | # sys.stderr = sys.__stderr__ 148 | # return isinstance(value, TypeError) 149 | 150 | 151 | from contextlib import contextmanager 152 | @contextmanager 153 | def suppress_stdout(): 154 | """Temporarily disable console output. 155 | 156 | Examples 157 | --------- 158 | >>> print("You can see me") 159 | >>> with tl.ops.suppress_stdout(): 160 | >>> print("You can't see me") 161 | >>> print("You can see me") 162 | 163 | References 164 | ----------- 165 | - `stackoverflow `_ 166 | """ 167 | with open(os.devnull, "w") as devnull: 168 | old_stdout = sys.stdout 169 | sys.stdout = devnull 170 | try: 171 | yield 172 | finally: 173 | sys.stdout = old_stdout 174 | 175 | 176 | 177 | def get_site_packages_directory(): 178 | """Print and return the site-packages directory. 179 | 180 | Examples 181 | --------- 182 | >>> loc = tl.ops.get_site_packages_directory() 183 | """ 184 | import site 185 | try: 186 | loc = site.getsitepackages() 187 | print(" tl.ops : site-packages in ", loc) 188 | return loc 189 | except: 190 | print(" tl.ops : Cannot find package dir from virtual environment") 191 | return False 192 | 193 | 194 | 195 | def empty_trash(): 196 | """Empty trash folder. 
197 | 198 | """ 199 | text = "[tl] Empty the trash" 200 | if _platform == "linux" or _platform == "linux2": 201 | print('linux: %s' % text) 202 | os.system("rm -rf ~/.local/share/Trash/*") 203 | elif _platform == "darwin": 204 | print('OS X: %s' % text) 205 | os.system("sudo rm -rf ~/.Trash/*") 206 | elif _platform == "win32": 207 | print('Windows: %s' % text) 208 | try: 209 | os.system("rd /s c:\$Recycle.Bin") # Windows 7 or Server 2008 210 | except: 211 | pass 212 | try: 213 | os.system("rd /s c:\recycler") # Windows XP, Vista, or Server 2003 214 | except: 215 | pass 216 | else: 217 | print(_platform) 218 | 219 | # 220 | -------------------------------------------------------------------------------- /tensorlayer/rein.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- coding: utf8 -*- 3 | 4 | 5 | 6 | import tensorflow as tf 7 | import numpy as np 8 | from six.moves import xrange 9 | 10 | def discount_episode_rewards(rewards=[], gamma=0.99, mode=0): 11 | """ Take 1D float array of rewards and compute discounted rewards for an 12 | episode. When encount a non-zero value, consider as the end a of an episode. 13 | 14 | Parameters 15 | ---------- 16 | rewards : numpy list 17 | a list of rewards 18 | gamma : float 19 | discounted factor 20 | mode : int 21 | if mode == 0, reset the discount process when encount a non-zero reward (Ping-pong game). 22 | if mode == 1, would not reset the discount process. 23 | 24 | Examples 25 | ---------- 26 | >>> rewards = np.asarray([0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1]) 27 | >>> gamma = 0.9 28 | >>> discount_rewards = tl.rein.discount_episode_rewards(rewards, gamma) 29 | >>> print(discount_rewards) 30 | ... [ 0.72899997 0.81 0.89999998 1. 0.72899997 0.81 31 | ... 0.89999998 1. 0.72899997 0.81 0.89999998 1. ] 32 | >>> discount_rewards = tl.rein.discount_episode_rewards(rewards, gamma, mode=1) 33 | >>> print(discount_rewards) 34 | ... [ 1.52110755 1.69011939 1.87791049 2.08656716 1.20729685 1.34144104 35 | ... 1.49048996 1.65610003 0.72899997 0.81 0.89999998 1. ] 36 | """ 37 | discounted_r = np.zeros_like(rewards, dtype=np.float32) 38 | running_add = 0 39 | for t in reversed(xrange(0, rewards.size)): 40 | if mode == 0: 41 | if rewards[t] != 0: running_add = 0 42 | 43 | running_add = running_add * gamma + rewards[t] 44 | discounted_r[t] = running_add 45 | return discounted_r 46 | 47 | 48 | def cross_entropy_reward_loss(logits, actions, rewards, name=None): 49 | """ Calculate the loss for Policy Gradient Network. 50 | 51 | Parameters 52 | ---------- 53 | logits : tensor 54 | The network outputs without softmax. This function implements softmax 55 | inside. 56 | actions : tensor/ placeholder 57 | The agent actions. 58 | rewards : tensor/ placeholder 59 | The rewards. 
60 | 61 | Examples 62 | ---------- 63 | >>> states_batch_pl = tf.placeholder(tf.float32, shape=[None, D]) # observation for training 64 | >>> network = tl.layers.InputLayer(states_batch_pl, name='input_layer') 65 | >>> network = tl.layers.DenseLayer(network, n_units=H, act = tf.nn.relu, name='relu1') 66 | >>> network = tl.layers.DenseLayer(network, n_units=3, act = tl.activation.identity, name='output_layer') 67 | >>> probs = network.outputs 68 | >>> sampling_prob = tf.nn.softmax(probs) 69 | >>> actions_batch_pl = tf.placeholder(tf.int32, shape=[None]) 70 | >>> discount_rewards_batch_pl = tf.placeholder(tf.float32, shape=[None]) 71 | >>> loss = cross_entropy_reward_loss(probs, actions_batch_pl, discount_rewards_batch_pl) 72 | >>> train_op = tf.train.RMSPropOptimizer(learning_rate, decay_rate).minimize(loss) 73 | """ 74 | 75 | try: # TF 1.0 76 | cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=actions, logits=logits, name=name) 77 | except: 78 | cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, targets=actions) 79 | # cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, actions) 80 | 81 | try: ## TF1.0 82 | loss = tf.reduce_sum(tf.multiply(cross_entropy, rewards)) 83 | except: ## TF0.12 84 | loss = tf.reduce_sum(tf.mul(cross_entropy, rewards)) # element-wise mul 85 | return loss 86 | -------------------------------------------------------------------------------- /tensorlayer/utils.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- coding: utf8 -*- 3 | import tensorflow as tf 4 | import tensorlayer as tl 5 | from . import iterate 6 | import numpy as np 7 | import time 8 | import math 9 | import random 10 | 11 | 12 | def fit(sess, network, train_op, cost, X_train, y_train, x, y_, acc=None, batch_size=100, 13 | n_epoch=100, print_freq=5, X_val=None, y_val=None, eval_train=True, 14 | tensorboard=False, tensorboard_epoch_freq=5, tensorboard_weight_histograms=True, tensorboard_graph_vis=True): 15 | """Traing a given non time-series network by the given cost function, training data, batch_size, n_epoch etc. 16 | 17 | Parameters 18 | ---------- 19 | sess : TensorFlow session 20 | sess = tf.InteractiveSession() 21 | network : a TensorLayer layer 22 | the network will be trained 23 | train_op : a TensorFlow optimizer 24 | like tf.train.AdamOptimizer 25 | X_train : numpy array 26 | the input of training data 27 | y_train : numpy array 28 | the target of training data 29 | x : placeholder 30 | for inputs 31 | y_ : placeholder 32 | for targets 33 | acc : the TensorFlow expression of accuracy (or other metric) or None 34 | if None, would not display the metric 35 | batch_size : int 36 | batch size for training and evaluating 37 | n_epoch : int 38 | the number of training epochs 39 | print_freq : int 40 | display the training information every ``print_freq`` epochs 41 | X_val : numpy array or None 42 | the input of validation data 43 | y_val : numpy array or None 44 | the target of validation data 45 | eval_train : boolean 46 | if X_val and y_val are not None, it refects whether to evaluate the training data 47 | tensorboard : boolean 48 | if True summary data will be stored to the log/ direcory for visualization with tensorboard. 49 | See also detailed tensorboard_X settings for specific configurations of features. 
(default False)
50 |         Also runs tl.layers.initialize_global_variables(sess) internally in fit() to set up the summary nodes, see the Note below:
51 |     tensorboard_epoch_freq : int
52 |         how many epochs between storing tensorboard checkpoints for visualization to the log/ directory (default 5)
53 |     tensorboard_weight_histograms : boolean
54 |         if True, updates tensorboard data in the logs/ directory for visualization
55 |         of the weight histograms every tensorboard_epoch_freq epochs (default True)
56 |     tensorboard_graph_vis : boolean
57 |         if True, stores the graph in the tensorboard summaries saved to log/ (default True)
58 | 
59 |     Examples
60 |     --------
61 |     >>> see tutorial_mnist_simple.py
62 |     >>> tl.utils.fit(sess, network, train_op, cost, X_train, y_train, x, y_,
63 |     ...            acc=acc, batch_size=500, n_epoch=200, print_freq=5,
64 |     ...            X_val=X_val, y_val=y_val, eval_train=False)
65 |     >>> tl.utils.fit(sess, network, train_op, cost, X_train, y_train, x, y_,
66 |     ...            acc=acc, batch_size=500, n_epoch=200, print_freq=5,
67 |     ...            X_val=X_val, y_val=y_val, eval_train=False,
68 |     ...            tensorboard=True, tensorboard_weight_histograms=True, tensorboard_graph_vis=True)
69 | 
70 |     Note
71 |     --------
72 |     If tensorboard=True, the global_variables_initializer will be run inside the fit function
73 |     in order to initialize the automatically generated summary nodes used for tensorboard visualization,
74 |     so any tf.global_variables_initializer().run() issued before the fit() call will be overridden.
75 |     """
76 |     assert X_train.shape[0] >= batch_size, "Number of training examples should be bigger than the batch size"
77 | 
78 |     if(tensorboard):
79 |         print("Setting up tensorboard ...")
80 |         # Set up tensorboard summaries and saver
81 |         tl.files.exists_or_mkdir('logs/')
82 | 
83 |         # Only write summaries for more recent TensorFlow versions
84 |         if hasattr(tf, 'summary') and hasattr(tf.summary, 'FileWriter'):
85 |             if tensorboard_graph_vis:
86 |                 train_writer = tf.summary.FileWriter('logs/train', sess.graph)
87 |                 val_writer = tf.summary.FileWriter('logs/validation', sess.graph)
88 |             else:
89 |                 train_writer = tf.summary.FileWriter('logs/train')
90 |                 val_writer = tf.summary.FileWriter('logs/validation')
91 | 
92 |         # Set up summary nodes
93 |         if(tensorboard_weight_histograms):
94 |             for param in network.all_params:
95 |                 if hasattr(tf, 'summary') and hasattr(tf.summary, 'histogram'):
96 |                     print('Param name ', param.name)
97 |                     tf.summary.histogram(param.name, param)
98 | 
99 |         if hasattr(tf, 'summary') and hasattr(tf.summary, 'histogram'):
100 |             tf.summary.scalar('cost', cost)
101 | 
102 |         merged = tf.summary.merge_all()
103 | 
104 |         # Initialize all variables and summaries
105 |         tl.layers.initialize_global_variables(sess)
106 |         print("Finished!
use $tensorboard --logdir=logs/ to start server") 107 | 108 | print("Start training the network ...") 109 | start_time_begin = time.time() 110 | tensorboard_train_index, tensorboard_val_index = 0, 0 111 | for epoch in range(n_epoch): 112 | start_time = time.time() 113 | loss_ep = 0; n_step = 0 114 | for X_train_a, y_train_a in iterate.minibatches(X_train, y_train, 115 | batch_size, shuffle=True): 116 | feed_dict = {x: X_train_a, y_: y_train_a} 117 | feed_dict.update( network.all_drop ) # enable noise layers 118 | loss, _ = sess.run([cost, train_op], feed_dict=feed_dict) 119 | loss_ep += loss 120 | n_step += 1 121 | loss_ep = loss_ep/ n_step 122 | 123 | if tensorboard and hasattr(tf, 'summary'): 124 | if epoch+1 == 1 or (epoch+1) % tensorboard_epoch_freq == 0: 125 | for X_train_a, y_train_a in iterate.minibatches( 126 | X_train, y_train, batch_size, shuffle=True): 127 | dp_dict = dict_to_one( network.all_drop ) # disable noise layers 128 | feed_dict = {x: X_train_a, y_: y_train_a} 129 | feed_dict.update(dp_dict) 130 | result = sess.run(merged, feed_dict=feed_dict) 131 | train_writer.add_summary(result, tensorboard_train_index) 132 | tensorboard_train_index += 1 133 | if (X_val is not None) and (y_val is not None): 134 | for X_val_a, y_val_a in iterate.minibatches( 135 | X_val, y_val, batch_size, shuffle=True): 136 | dp_dict = dict_to_one( network.all_drop ) # disable noise layers 137 | feed_dict = {x: X_val_a, y_: y_val_a} 138 | feed_dict.update(dp_dict) 139 | result = sess.run(merged, feed_dict=feed_dict) 140 | val_writer.add_summary(result, tensorboard_val_index) 141 | tensorboard_val_index += 1 142 | 143 | if epoch + 1 == 1 or (epoch + 1) % print_freq == 0: 144 | if (X_val is not None) and (y_val is not None): 145 | print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time)) 146 | if eval_train is True: 147 | train_loss, train_acc, n_batch = 0, 0, 0 148 | for X_train_a, y_train_a in iterate.minibatches( 149 | X_train, y_train, batch_size, shuffle=True): 150 | dp_dict = dict_to_one( network.all_drop ) # disable noise layers 151 | feed_dict = {x: X_train_a, y_: y_train_a} 152 | feed_dict.update(dp_dict) 153 | if acc is not None: 154 | err, ac = sess.run([cost, acc], feed_dict=feed_dict) 155 | train_acc += ac 156 | else: 157 | err = sess.run(cost, feed_dict=feed_dict) 158 | train_loss += err; n_batch += 1 159 | print(" train loss: %f" % (train_loss/ n_batch)) 160 | if acc is not None: 161 | print(" train acc: %f" % (train_acc/ n_batch)) 162 | val_loss, val_acc, n_batch = 0, 0, 0 163 | for X_val_a, y_val_a in iterate.minibatches( 164 | X_val, y_val, batch_size, shuffle=True): 165 | dp_dict = dict_to_one( network.all_drop ) # disable noise layers 166 | feed_dict = {x: X_val_a, y_: y_val_a} 167 | feed_dict.update(dp_dict) 168 | if acc is not None: 169 | err, ac = sess.run([cost, acc], feed_dict=feed_dict) 170 | val_acc += ac 171 | else: 172 | err = sess.run(cost, feed_dict=feed_dict) 173 | val_loss += err; n_batch += 1 174 | print(" val loss: %f" % (val_loss/ n_batch)) 175 | if acc is not None: 176 | print(" val acc: %f" % (val_acc/ n_batch)) 177 | else: 178 | print("Epoch %d of %d took %fs, loss %f" % (epoch + 1, n_epoch, time.time() - start_time, loss_ep)) 179 | print("Total training time: %fs" % (time.time() - start_time_begin)) 180 | 181 | 182 | def test(sess, network, acc, X_test, y_test, x, y_, batch_size, cost=None): 183 | """ 184 | Test a given non time-series network by the given test data and metric. 
185 | 
186 |     Parameters
187 |     ----------
188 |     sess : TensorFlow session
189 |         sess = tf.InteractiveSession()
190 |     network : a TensorLayer layer
191 |         the network to be tested
192 |     acc : the TensorFlow expression of accuracy (or other metric) or None
193 |         if None, the metric will not be displayed
194 |     X_test : numpy array
195 |         the input of test data
196 |     y_test : numpy array
197 |         the target of test data
198 |     x : placeholder
199 |         for inputs
200 |     y_ : placeholder
201 |         for targets
202 |     batch_size : int or None
203 |         batch size for testing; when the dataset is large, we should use minibatches for testing;
204 |         when the dataset is small, we can set it to None.
205 |     cost : the TensorFlow expression of cost or None
206 |         if None, the cost will not be displayed
207 | 
208 |     Examples
209 |     --------
210 |     >>> see tutorial_mnist_simple.py
211 |     >>> tl.utils.test(sess, network, acc, X_test, y_test, x, y_, batch_size=None, cost=cost)
212 |     """
213 |     print('Start testing the network ...')
214 |     if batch_size is None:
215 |         dp_dict = dict_to_one( network.all_drop )
216 |         feed_dict = {x: X_test, y_: y_test}
217 |         feed_dict.update(dp_dict)
218 |         if cost is not None:
219 |             print(" test loss: %f" % sess.run(cost, feed_dict=feed_dict))
220 |         print(" test acc: %f" % sess.run(acc, feed_dict=feed_dict))
221 |         # print(" test acc: %f" % np.mean(y_test == sess.run(y_op,
222 |         #                                  feed_dict=feed_dict)))
223 |     else:
224 |         test_loss, test_acc, n_batch = 0, 0, 0
225 |         for X_test_a, y_test_a in iterate.minibatches(
226 |                                     X_test, y_test, batch_size, shuffle=True):
227 |             dp_dict = dict_to_one( network.all_drop )    # disable noise layers
228 |             feed_dict = {x: X_test_a, y_: y_test_a}
229 |             feed_dict.update(dp_dict)
230 |             if cost is not None:
231 |                 err, ac = sess.run([cost, acc], feed_dict=feed_dict)
232 |                 test_loss += err
233 |             else:
234 |                 ac = sess.run(acc, feed_dict=feed_dict)
235 |             test_acc += ac; n_batch += 1
236 |         if cost is not None:
237 |             print(" test loss: %f" % (test_loss/ n_batch))
238 |         print(" test acc: %f" % (test_acc/ n_batch))
239 | 
240 | 
241 | def predict(sess, network, X, x, y_op, batch_size=None):
242 |     """
243 |     Return the prediction results of the given non time-series network.
244 | 
245 |     Parameters
246 |     ----------
247 |     sess : TensorFlow session
248 |         sess = tf.InteractiveSession()
249 |     network : a TensorLayer layer
250 |         the network to be used for prediction
251 |     X : numpy array
252 |         the input
253 |     x : placeholder
254 |         for inputs
255 |     y_op : tensor
256 |         the argmax expression of the softmax outputs
257 |     batch_size : int or None
258 |         batch size for prediction; when the dataset is large, we should use minibatches for prediction;
259 |         when the dataset is small, we can set it to None.
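
    Note
    ------
    A caveat worth stating here (assuming the vendored ``iterate.minibatches``
    keeps the usual TensorLayer behaviour of yielding only full batches): when
    ``batch_size`` does not evenly divide ``len(X)``, the trailing remainder is
    skipped, so the returned array can be shorter than ``X``. Pad the input or
    pass ``batch_size=None`` if every row must be predicted.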
260 | 
261 |     Examples
262 |     --------
263 |     >>> see tutorial_mnist_simple.py
264 |     >>> y = network.outputs
265 |     >>> y_op = tf.argmax(tf.nn.softmax(y), 1)
266 |     >>> print(tl.utils.predict(sess, network, X_test, x, y_op))
267 |     """
268 |     if batch_size is None:
269 |         dp_dict = dict_to_one( network.all_drop )    # disable noise layers
270 |         feed_dict = {x: X,}
271 |         feed_dict.update(dp_dict)
272 |         return sess.run(y_op, feed_dict=feed_dict)
273 |     else:
274 |         result = None
275 |         for X_a, _ in iterate.minibatches(
276 |                             X, X, batch_size, shuffle=False):
277 |             dp_dict = dict_to_one( network.all_drop )
278 |             feed_dict = {x: X_a, }
279 |             feed_dict.update(dp_dict)
280 |             result_a = sess.run(y_op, feed_dict=feed_dict)
281 |             if result is None:
282 |                 result = result_a
283 |             else:
284 |                 result = np.hstack((result, result_a))
285 |         return result
286 | 
287 | 
288 | ## Evaluation
289 | def evaluation(y_test=None, y_predict=None, n_classes=None):
290 |     """
291 |     Input the predicted results, target results and
292 |     the number of classes; return the confusion matrix, F1-score of each class,
293 |     accuracy and macro F1-score.
294 | 
295 |     Parameters
296 |     ----------
297 |     y_test : numpy.array or list
298 |         target results
299 |     y_predict : numpy.array or list
300 |         predicted results
301 |     n_classes : int
302 |         number of classes
303 | 
304 |     Examples
305 |     --------
306 |     >>> c_mat, f1, acc, f1_macro = evaluation(y_test, y_predict, n_classes)
307 |     """
308 |     from sklearn.metrics import confusion_matrix, f1_score, accuracy_score
309 |     c_mat = confusion_matrix(y_test, y_predict, labels = [x for x in range(n_classes)])
310 |     f1 = f1_score(y_test, y_predict, average = None, labels = [x for x in range(n_classes)])
311 |     f1_macro = f1_score(y_test, y_predict, average='macro')
312 |     acc = accuracy_score(y_test, y_predict)
313 |     print('confusion matrix: \n',c_mat)
314 |     print('f1-score:',f1)
315 |     print('f1-score(macro):',f1_macro)   # same output as f1_score(y_true, y_pred, average='macro')
316 |     print('accuracy-score:', acc)
317 |     return c_mat, f1, acc, f1_macro
318 | 
319 | def dict_to_one(dp_dict={}):
320 |     """
321 |     Input a dictionary, return a dictionary in which all items are set to one;
322 |     used to disable dropout, dropconnect layers and so on.
323 | 
324 |     Parameters
325 |     ----------
326 |     dp_dict : dictionary
327 |         keeping probabilities
328 | 
329 |     Examples
330 |     --------
331 |     >>> dp_dict = dict_to_one( network.all_drop )
332 |     >>> dp_dict = dict_to_one( network.all_drop )
333 |     >>> feed_dict.update(dp_dict)
334 |     """
335 |     return {x: 1 for x in dp_dict}
336 | 
337 | def flatten_list(list_of_list=[[],[]]):
338 |     """
339 |     Input a list of lists, return a flat list containing all the items.
340 | 
341 |     Parameters
342 |     ----------
343 |     list_of_list : a list of lists
344 | 
345 |     Examples
346 |     --------
347 |     >>> tl.utils.flatten_list([[1, 2, 3],[4, 5],[6]])
348 |     ... [1, 2, 3, 4, 5, 6]
349 |     """
350 |     return sum(list_of_list, [])
351 | 
352 | 
353 | def class_balancing_oversample(X_train=None, y_train=None, printable=True):
354 |     """Input the features and labels, return the features and labels after oversampling.
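    A small worked illustration of the strategy implemented below: with
    y_train = [0, 0, 0, 1], class 1 holds 1 instance and class 0 holds 3, so
    the single class-1 row is stacked onto itself until it reaches at least 3
    rows and is then truncated to exactly 3; afterwards every class holds 3
    instances and the features/labels are re-assembled in dict order.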
355 | 
356 |     Parameters
357 |     ----------
358 |     X_train : numpy.array
359 |         Features, each row is an example
360 |     y_train : numpy.array
361 |         Labels
362 | 
363 |     Examples
364 |     --------
365 |     - One X
366 |     >>> X_train, y_train = class_balancing_oversample(X_train, y_train, printable=True)
367 | 
368 |     - Two X
369 |     >>> X, y = tl.utils.class_balancing_oversample(X_train=np.hstack((X1, X2)), y_train=y, printable=False)
370 |     >>> X1 = X[:, 0:5]
371 |     >>> X2 = X[:, 5:]
372 |     """
373 |     # ======== Classes balancing
374 |     if printable:
375 |         print("Classes balancing for training examples...")
376 |     from collections import Counter
377 |     c = Counter(y_train)
378 |     if printable:
379 |         print('the occurrence number of each stage: %s' % c.most_common())
380 |         print('the least common stage is Label %s with %s instances' % c.most_common()[-1])
381 |         print('the most common stage is Label %s with %s instances' % c.most_common(1)[0])
382 |     most_num = c.most_common(1)[0][1]
383 |     if printable:
384 |         print('most num is %d; all classes will be oversampled to this num' % most_num)
385 | 
386 |     locations = {}
387 |     number = {}
388 | 
389 |     for lab, num in c.most_common():    # find the index from y_train
390 |         number[lab] = num
391 |         locations[lab] = np.where(np.array(y_train)==lab)[0]
392 |     if printable:
393 |         print('convert list(np.array) to dict format')
394 |     X = {}  # convert list to dict
395 |     for lab, num in number.items():
396 |         X[lab] = X_train[locations[lab]]
397 | 
398 |     # oversampling
399 |     if printable:
400 |         print('start oversampling')
401 |     for key in X:
402 |         temp = X[key]
403 |         while True:
404 |             if len(X[key]) >= most_num:
405 |                 break
406 |             X[key] = np.vstack((X[key], temp))
407 |     if printable:
408 |         print('first features of label 0 >', len(X[0][0]))
409 |         print('the occurrence num of each stage after oversampling')
410 |     for key in X:
411 |         print(key, len(X[key]))
412 |     if printable:
413 |         print('make each stage have the same num of instances')
414 |     for key in X:
415 |         X[key] = X[key][0:most_num,:]
416 |         print(key, len(X[key]))
417 | 
418 |     # convert dict to list
419 |     if printable:
420 |         print('convert from dict to list format')
421 |     y_train = []
422 |     X_train = np.empty(shape=(0,len(X[0][0])))
423 |     for key in X:
424 |         X_train = np.vstack( (X_train, X[key] ) )
425 |         y_train.extend([key for i in range(len(X[key]))])
426 |     # print(len(X_train), len(y_train))
427 |     c = Counter(y_train)
428 |     if printable:
429 |         print('the occurrence number of each stage after oversampling: %s' % c.most_common())
430 |     # ================ End of Classes balancing
431 |     return X_train, y_train
432 | 
433 | ## Random
434 | def get_random_int(min=0, max=10, number=5, seed=None):
435 |     """Return a list of random integers within the given range and of the given quantity.
436 | 
437 |     Examples
438 |     ---------
439 |     >>> r = get_random_int(min=0, max=10, number=5)
440 |     ...
[10, 2, 3, 3, 7]
441 |     """
442 |     rnd = random.Random()
443 |     if seed is not None:
444 |         rnd = random.Random(seed)
445 |     # return [random.randint(min,max) for p in range(0, number)]
446 |     return [rnd.randint(min,max) for p in range(0, number)]
447 | 
448 | #
449 | # def class_balancing_sequence_4D(X_train, y_train, sequence_length, model='downsampling' ,printable=True):
450 | #     ''' Input and output are both in sequence format
451 | #         oversampling or downsampling
452 | #     '''
453 | #     n_features = X_train.shape[2]
454 | #     # ======== Classes balancing for sequence
455 | #     if printable:
456 | #         print("Classes balancing for 4D sequence training examples...")
457 | #     from collections import Counter
458 | #     c = Counter(y_train)    # Counter({2: 454, 4: 267, 3: 124, 1: 57, 0: 48})
459 | #     if printable:
460 | #         print('the occurrence number of each stage: %s' % c.most_common())
461 | #         print('the least common Label %s with %s instances' % c.most_common()[-1])
462 | #         print('the most common Label %s with %s instances' % c.most_common(1)[0])
463 | #     # print(c.most_common()) # [(2, 454), (4, 267), (3, 124), (1, 57), (0, 48)]
464 | #     most_num = c.most_common(1)[0][1]
465 | #     less_num = c.most_common()[-1][1]
466 | #
467 | #     locations = {}
468 | #     number = {}
469 | #     for lab, num in c.most_common():
470 | #         number[lab] = num
471 | #         locations[lab] = np.where(np.array(y_train)==lab)[0]
472 | #     # print(locations)
473 | #     # print(number)
474 | #     if printable:
475 | #         print('  convert list to dict')
476 | #     X = {}  # convert list to dict
477 | #     ### a sequence
478 | #     for lab, _ in number.items():
479 | #         X[lab] = np.empty(shape=(0,1,n_features,1)) # 4D
480 | #     for lab, _ in number.items():
481 | #         #X[lab] = X_train[locations[lab]
482 | #         for l in locations[lab]:
483 | #             X[lab] = np.vstack((X[lab], X_train[l*sequence_length : (l+1)*(sequence_length)]))
484 | #         # X[lab] = X_train[locations[lab]*sequence_length : locations[lab]*(sequence_length+1)]  # a sequence
485 | #     # print(X)
486 | #
487 | #     if model=='oversampling':
488 | #         if printable:
489 | #             print('  oversampling -- most num is %d, all classes tend to be this num\nshuffle applied' % most_num)
490 | #         for key in X:
491 | #             temp = X[key]
492 | #             while True:
493 | #                 if len(X[key]) >= most_num * sequence_length:   # sequence
494 | #                     break
495 | #                 X[key] = np.vstack((X[key], temp))
496 | #         # print(key, len(X[key]))
497 | #         if printable:
498 | #             print('  make each stage have the same num of instances')
499 | #         for key in X:
500 | #             X[key] = X[key][0:most_num*sequence_length,:]   # sequence
501 | #             if printable:
502 | #                 print(key, len(X[key]))
503 | #     elif model=='downsampling':
504 | #         import random
505 | #         if printable:
506 | #             print('  downsampling -- less num is %d, all classes tend to be this num by random choice without replacement\nshuffle applied' % less_num)
507 | #         for key in X:
508 | #             # print(key, len(X[key]))#, len(X[key])/sequence_length)
509 | #             s_idx = [ i for i in range(int(len(X[key])/sequence_length))]
510 | #             s_idx = np.asarray(s_idx)*sequence_length   # start index of sequence in X[key]
511 | #             # print('s_idx',s_idx)
512 | #             r_idx = np.random.choice(s_idx, less_num, replace=False)    # randomly choose less_num of s_idx
513 | #             # print('r_idx',r_idx)
514 | #             temp = X[key]
515 | #             X[key] = np.empty(shape=(0,1,n_features,1)) # 4D
516 | #             for idx in r_idx:
517 | #                 X[key] = np.vstack((X[key], temp[idx:idx+sequence_length]))
518 | #             # print(key, X[key])
519 | #     # np.random.choice(l, len(l), replace=False)
520 | #     else:
521 | #         raise Exception('  model should be oversampling or downsampling')
522 | #
523 | #     #
convert dict to list
524 | #     if printable:
525 | #         print('  convert dict to list')
526 | #     y_train = []
527 | #     # X_train = np.empty(shape=(0,len(X[0][0])))
528 | #     # X_train = np.empty(shape=(0,len(X[1][0]))) # 2D
529 | #     X_train = np.empty(shape=(0,1,n_features,1))    # 4D
530 | #     l_key = list(X.keys())  # shuffle
531 | #     random.shuffle(l_key)   # shuffle
532 | #     # for key in X:     # no shuffle
533 | #     for key in l_key:   # shuffle
534 | #         X_train = np.vstack( (X_train, X[key] ) )
535 | #         # print(len(X[key]))
536 | #         y_train.extend([key for i in range(int(len(X[key])/sequence_length))])
537 | #     # print(X_train,y_train, type(X_train), type(y_train))
538 | #     # ================ End of Classes balancing for sequence
539 | #     # print(X_train.shape, len(y_train))
540 | #     return X_train, np.asarray(y_train)
541 | 
--------------------------------------------------------------------------------
/tensorlayer/visualize.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python
2 | # -*- coding: utf8 -*-
3 | 
4 | 
5 | import matplotlib
6 | matplotlib.use('Agg')
7 | import matplotlib.pyplot as plt
8 | # import matplotlib.pyplot as plt
9 | import numpy as np
10 | import os
11 | 
12 | 
13 | ## Save images
14 | import scipy.misc
15 | 
16 | def save_image(image, image_path):
17 |     """Save one image.
18 | 
19 |     Parameters
20 |     -----------
21 |     image : numpy array [w, h, c]
22 |     image_path : string.
23 |     """
24 |     scipy.misc.imsave(image_path, image)
25 | 
26 | def save_images(images, size, image_path):
27 |     """Save multiple images into one single image.
28 | 
29 |     Parameters
30 |     -----------
31 |     images : numpy array [batch, w, h, c]
32 |     size : list of two ints, row and column number.
33 |         number of images should be less than or equal to size[0] * size[1]
34 |     image_path : string.
35 | 
36 |     Examples
37 |     ---------
38 |     >>> images = np.random.rand(64, 100, 100, 3)
39 |     >>> tl.visualize.save_images(images, [8, 8], 'temp.png')
40 |     """
41 |     def merge(images, size):
42 |         h, w = images.shape[1], images.shape[2]
43 |         img = np.zeros((h * size[0], w * size[1], 3))
44 |         for idx, image in enumerate(images):
45 |             i = idx % size[1]
46 |             j = idx // size[1]
47 |             img[j*h:j*h+h, i*w:i*w+w, :] = image
48 |         return img
49 | 
50 |     def imsave(images, size, path):
51 |         return scipy.misc.imsave(path, merge(images, size))
52 | 
53 |     assert len(images) <= size[0] * size[1], "number of images should be less than or equal to size[0] * size[1] {}".format(len(images))
54 |     return imsave(images, size, image_path)
55 | 
56 | def W(W=None, second=10, saveable=True, shape=[28,28], name='mnist', fig_idx=2396512):
57 |     """Visualize every column of the weight matrix as a group of greyscale images.
58 | 
59 |     Parameters
60 |     ----------
61 |     W : numpy.array
62 |         The weight matrix
63 |     second : int
64 |         The display second(s) for the image(s), if saveable is False.
65 |     saveable : boolean
66 |         Save or plot the figure.
67 |     shape : a list with 2 int
68 |         The shape of the feature image, MNIST is [28, 28].
69 |     name : a string
70 |         A name to save the image, if saveable is True.
71 |     fig_idx : int
72 |         matplotlib figure index.
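
    Note
    ----------
    As implemented below, each column is L2-normalized before display
    (feature = W[:, i] / sqrt(sum(W[:, i]**2))), so the plots show the
    direction of each learned feature rather than its magnitude.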
73 | 
74 |     Examples
75 |     --------
76 |     >>> tl.visualize.W(network.all_params[0].eval(), second=10, saveable=True, name='weight_of_1st_layer', fig_idx=2012)
77 |     """
78 |     if saveable is False:
79 |         plt.ion()
80 |     fig = plt.figure(fig_idx)      # show all feature images
81 |     size = W.shape[0]
82 |     n_units = W.shape[1]
83 | 
84 |     num_r = int(np.sqrt(n_units))  # number shown per row, e.g. 25 hidden units -> 5 per row
85 |     num_c = int(np.ceil(n_units/num_r))
86 |     count = int(1)
87 |     for row in range(1, num_r+1):
88 |         for col in range(1, num_c+1):
89 |             if count > n_units:
90 |                 break
91 |             a = fig.add_subplot(num_r, num_c, count)
92 |             # ------------------------------------------------------------
93 |             # plt.imshow(np.reshape(W[:,count-1],(28,28)), cmap='gray')
94 |             # ------------------------------------------------------------
95 |             feature = W[:,count-1] / np.sqrt( (W[:,count-1]**2).sum())
96 |             # feature[feature<0.0001] = 0   # value threshold
97 |             # if count == 1 or count == 2:
98 |             #     print(np.mean(feature))
99 |             # if np.std(feature) < 0.03:      # condition threshold
100 |             #     feature = np.zeros_like(feature)
101 |             # if np.mean(feature) < -0.015:   # condition threshold
102 |             #     feature = np.zeros_like(feature)
103 |             plt.imshow(np.reshape(feature ,(shape[0],shape[1])),
104 |                         cmap='gray', interpolation="nearest")#, vmin=np.min(feature), vmax=np.max(feature))
105 |             # plt.title(name)
106 |             # ------------------------------------------------------------
107 |             # plt.imshow(np.reshape(W[:,count-1] ,(np.sqrt(size),np.sqrt(size))), cmap='gray', interpolation="nearest")
108 |             plt.gca().xaxis.set_major_locator(plt.NullLocator())    # disable ticks
109 |             plt.gca().yaxis.set_major_locator(plt.NullLocator())
110 |             count = count + 1
111 |     if saveable:
112 |         plt.savefig(name+'.pdf',format='pdf')
113 |     else:
114 |         plt.draw()
115 |         plt.pause(second)
116 | 
117 | def frame(I=None, second=5, saveable=True, name='frame', cmap=None, fig_idx=12836):
118 |     """Display a frame (image). Make sure OpenAI Gym's render() is disabled before using it.
119 | 
120 |     Parameters
121 |     ----------
122 |     I : numpy.array
123 |         The image
124 |     second : int
125 |         The display second(s) for the image(s), if saveable is False.
126 |     saveable : boolean
127 |         Save or plot the figure.
128 |     name : a string
129 |         A name to save the image, if saveable is True.
130 |     cmap : None or string
131 |         'gray' for greyscale, None for default, etc.
132 |     fig_idx : int
133 |         matplotlib figure index.
134 | 
135 |     Examples
136 |     --------
137 |     >>> env = gym.make("Pong-v0")
138 |     >>> observation = env.reset()
139 |     >>> tl.visualize.frame(observation)
140 |     """
141 |     if saveable is False:
142 |         plt.ion()
143 |     fig = plt.figure(fig_idx)      # show all feature images
144 | 
145 |     if len(I.shape) == 3 and I.shape[-1] == 1:  # (10,10,1) --> (10,10)
146 |         I = I[:,:,0]
147 | 
148 |     plt.imshow(I, cmap)
149 |     plt.title(name)
150 |     # plt.gca().xaxis.set_major_locator(plt.NullLocator())    # disable ticks
151 |     # plt.gca().yaxis.set_major_locator(plt.NullLocator())
152 | 
153 |     if saveable:
154 |         plt.savefig(name+'.pdf',format='pdf')
155 |     else:
156 |         plt.draw()
157 |         plt.pause(second)
158 | 
159 | def CNN2d(CNN=None, second=10, saveable=True, name='cnn', fig_idx=3119362):
160 |     """Display a group of RGB or greyscale CNN masks.
161 | 
162 |     Parameters
163 |     ----------
164 |     CNN : numpy.array
165 |         The image. e.g: 64 5x5 RGB images can be (5, 5, 3, 64).
166 |     second : int
167 |         The display second(s) for the image(s), if saveable is False.
168 |     saveable : boolean
169 |         Save or plot the figure.
170 |     name : a string
171 |         A name to save the image, if saveable is True.
172 |     fig_idx : int
173 |         matplotlib figure index.
174 | 
175 |     Examples
176 |     --------
177 |     >>> tl.visualize.CNN2d(network.all_params[0].eval(), second=10, saveable=True, name='cnn1_mnist', fig_idx=2012)
178 |     """
179 |     # print(CNN.shape)    # (5, 5, 3, 64)
180 |     # exit()
181 |     n_mask = CNN.shape[3]
182 |     n_row = CNN.shape[0]
183 |     n_col = CNN.shape[1]
184 |     n_color = CNN.shape[2]
185 |     row = int(np.sqrt(n_mask))
186 |     col = int(np.ceil(n_mask/row))
187 |     plt.ion()   # interactive mode
188 |     fig = plt.figure(fig_idx)
189 |     count = 1
190 |     for ir in range(1, row+1):
191 |         for ic in range(1, col+1):
192 |             if count > n_mask:
193 |                 break
194 |             a = fig.add_subplot(col, row, count)
195 |             # print(CNN[:,:,:,count-1].shape, n_row, n_col)   # (5, 1, 32) 5 5
196 |             # exit()
197 |             # plt.imshow(
198 |             #         np.reshape(CNN[count-1,:,:,:], (n_row, n_col)),
199 |             #         cmap='gray', interpolation="nearest")     # theano
200 |             if n_color == 1:
201 |                 plt.imshow(
202 |                         np.reshape(CNN[:,:,:,count-1], (n_row, n_col)),
203 |                         cmap='gray', interpolation="nearest")
204 |             elif n_color == 3:
205 |                 plt.imshow(
206 |                         np.reshape(CNN[:,:,:,count-1], (n_row, n_col, n_color)),
207 |                         cmap='gray', interpolation="nearest")
208 |             else:
209 |                 raise Exception("Unknown n_color")
210 |             plt.gca().xaxis.set_major_locator(plt.NullLocator())    # disable ticks
211 |             plt.gca().yaxis.set_major_locator(plt.NullLocator())
212 |             count = count + 1
213 |     if saveable:
214 |         plt.savefig(name+'.pdf',format='pdf')
215 |     else:
216 |         plt.draw()
217 |         plt.pause(second)
218 | 
219 | 
220 | def images2d(images=None, second=10, saveable=True, name='images', dtype=None,
221 |              fig_idx=3119362):
222 |     """Display a group of RGB or greyscale images.
223 | 
224 |     Parameters
225 |     ----------
226 |     images : numpy.array
227 |         The images.
228 |     second : int
229 |         The display second(s) for the image(s), if saveable is False.
230 |     saveable : boolean
231 |         Save or plot the figure.
232 |     name : a string
233 |         A name to save the image, if saveable is True.
234 |     dtype : None or numpy data type
235 |         The data type for displaying the images.
236 |     fig_idx : int
237 |         matplotlib figure index.
238 | 
239 |     Examples
240 |     --------
241 |     >>> X_train, y_train, X_test, y_test = tl.files.load_cifar10_dataset(shape=(-1, 32, 32, 3), plotable=False)
242 |     >>> tl.visualize.images2d(X_train[0:100,:,:,:], second=10, saveable=False, name='cifar10', dtype=np.uint8, fig_idx=20212)
243 |     """
244 |     # print(images.shape)    # (50000, 32, 32, 3)
245 |     # exit()
246 |     if dtype:
247 |         images = np.asarray(images, dtype=dtype)
248 |     n_mask = images.shape[0]
249 |     n_row = images.shape[1]
250 |     n_col = images.shape[2]
251 |     n_color = images.shape[3]
252 |     row = int(np.sqrt(n_mask))
253 |     col = int(np.ceil(n_mask/row))
254 |     plt.ion()   # interactive mode
255 |     fig = plt.figure(fig_idx)
256 |     count = 1
257 |     for ir in range(1, row+1):
258 |         for ic in range(1, col+1):
259 |             if count > n_mask:
260 |                 break
261 |             a = fig.add_subplot(col, row, count)
262 |             # print(images[:,:,:,count-1].shape, n_row, n_col)   # (5, 1, 32) 5 5
263 |             # plt.imshow(
264 |             #         np.reshape(images[count-1,:,:,:], (n_row, n_col)),
265 |             #         cmap='gray', interpolation="nearest")     # theano
266 |             if n_color == 1:
267 |                 plt.imshow(
268 |                         np.reshape(images[count-1,:,:], (n_row, n_col)),
269 |                         cmap='gray', interpolation="nearest")
270 |                 # plt.title(name)
271 |             elif n_color == 3:
272 |                 plt.imshow(images[count-1,:,:],
273 |                         cmap='gray', interpolation="nearest")
274 |                 # plt.title(name)
275 |             else:
276 |                 raise Exception("Unknown n_color")
277 |             plt.gca().xaxis.set_major_locator(plt.NullLocator())    # disable ticks
278 |             plt.gca().yaxis.set_major_locator(plt.NullLocator())
279 |             count = count + 1
280 |     if saveable:
281 |         plt.savefig(name+'.pdf',format='pdf')
282 |     else:
283 |         plt.draw()
284 |         plt.pause(second)
285 | 
286 | def tsne_embedding(embeddings, reverse_dictionary, plot_only=500,
287 |                    second=5, saveable=False, name='tsne', fig_idx=9862):
288 |     """Visualize the embeddings by using t-SNE.
289 | 
290 |     Parameters
291 |     ----------
292 |     embeddings : a matrix
293 |         The embedding matrix.
294 |     reverse_dictionary : a dictionary
295 |         id_to_word, mapping id to unique word.
296 |     plot_only : int
297 |         The number of examples to plot; choose the most common words.
298 |     second : int
299 |         The display second(s) for the image(s), if saveable is False.
300 |     saveable : boolean
301 |         Save or plot the figure.
302 |     name : a string
303 |         A name to save the image, if saveable is True.
304 |     fig_idx : int
305 |         matplotlib figure index.
306 | 
307 |     Examples
308 |     --------
309 |     >>> see 'tutorial_word2vec_basic.py'
310 |     >>> final_embeddings = normalized_embeddings.eval()
311 |     >>> tl.visualize.tsne_embedding(final_embeddings, reverse_dictionary,
312 |     ...
plot_only=500, second=5, saveable=False, name='tsne')
313 |     """
314 |     def plot_with_labels(low_dim_embs, labels, figsize=(18, 18), second=5,
315 |                          saveable=True, name='tsne', fig_idx=9862):
316 |         assert low_dim_embs.shape[0] >= len(labels), "More labels than embeddings"
317 |         if saveable is False:
318 |             plt.ion()
319 |             plt.figure(fig_idx)
320 |         plt.figure(figsize=figsize)  # in inches
321 |         for i, label in enumerate(labels):
322 |             x, y = low_dim_embs[i,:]
323 |             plt.scatter(x, y)
324 |             plt.annotate(label,
325 |                          xy=(x, y),
326 |                          xytext=(5, 2),
327 |                          textcoords='offset points',
328 |                          ha='right',
329 |                          va='bottom')
330 |         if saveable:
331 |             plt.savefig(name+'.pdf',format='pdf')
332 |         else:
333 |             plt.draw()
334 |             plt.pause(second)
335 | 
336 |     try:
337 |         from sklearn.manifold import TSNE
338 |         import matplotlib.pyplot as plt
339 |         from six.moves import xrange
340 | 
341 |         tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000)
342 |         # plot_only = 500
343 |         low_dim_embs = tsne.fit_transform(embeddings[:plot_only,:])
344 |         labels = [reverse_dictionary[i] for i in xrange(plot_only)]
345 |         plot_with_labels(low_dim_embs, labels, second=second, saveable=saveable, \
346 |                          name=name, fig_idx=fig_idx)
347 |     except ImportError:
348 |         print("Please install sklearn and matplotlib to visualize embeddings.")
349 | 
350 | 
351 | #
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorlayer as tl
3 | from tensorlayer.prepro import *
4 | # from config import config, log_config
5 | #
6 | # img_path = config.TRAIN.img_path
7 | 
8 | import scipy
9 | import numpy as np
10 | 
11 | def get_imgs_fn(file_name, path):
12 |     """ Input an image path and name, return an image array """
13 |     # return scipy.misc.imread(path + file_name).astype(np.float)
14 |     return scipy.misc.imread(path + file_name, mode='RGB')
15 | 
16 | def crop_sub_imgs_fn(x, is_random=True):
17 |     x = crop(x, wrg=384, hrg=384, is_random=is_random)
18 |     x = x / (255. / 2.)
19 |     x = x - 1.
20 |     return x
21 | 
22 | def downsample_fn(x):
23 |     # We obtained the LR images by downsampling the HR images using a bicubic kernel with downsampling factor r = 4.
24 |     x = imresize(x, size=[96, 96], interp='bicubic', mode=None)
25 |     x = x / (255. / 2.)
26 |     x = x - 1.
27 |     return x
28 | 
--------------------------------------------------------------------------------
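The two helpers above, `crop_sub_imgs_fn` and `downsample_fn`, produce the paired training data for SRGAN: random 384×384 HR patches scaled to [-1, 1], and the matching 96×96 bicubic LR inputs (the 4× factor). A minimal sketch of how they can be combined with `tl.prepro.threading_data`; the folder path and file list here are placeholders, not values taken from this repo:

```python
import tensorlayer as tl
from utils import get_imgs_fn, crop_sub_imgs_fn, downsample_fn

# load a few HR images from a placeholder folder
hr_imgs = [get_imgs_fn(name, 'path/to/hr_images/') for name in ['0001.png', '0002.png']]

# random 384x384 HR crops, rescaled to [-1, 1]
hr_patches = tl.prepro.threading_data(hr_imgs, fn=crop_sub_imgs_fn, is_random=True)
# matching 96x96 LR inputs via bicubic x4 downsampling, also in [-1, 1]
lr_patches = tl.prepro.threading_data(hr_patches, fn=downsample_fn)
```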