├── README.md ├── figures ├── demo.png └── framework.png ├── main.py ├── model.py ├── ops.py ├── prepare.py ├── requirements.txt ├── scripts ├── recon_flower.sh ├── test_flower.sh └── train_flower.sh └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | ## Introduction 2 | CAIS is a [TensorFlow](http://tensorflow.org/)-based framework for training and testing the models from our paper: ***Component Aware Image Steganography via Adversarial Global-and-Part Checking***. 3 | 4 | ## News 5 | **2022.06.09** - Our paper was accepted by IEEE Transactions on Neural Networks and Learning Systems (TNNLS). 6 | 7 | ## Installation 8 | 1. We use [Anaconda3](https://www.anaconda.com/products/individual) as the base environment. If you have installed Anaconda3 at path `Conda_Path`, create a new virtual environment with `conda create -n tf114`, then activate it with `source activate tf114`. Install `tensorflow-gpu` with `conda install tensorflow-gpu==1.14.0`. 9 | 2. Install the remaining dependencies with `pip install -r requirements.txt` (if necessary); the `requirements.txt` file is provided in this repository. 10 | 3. Please download the pre-trained VGG19 model [imagenet-vgg-verydeep-19.mat](https://pan.baidu.com/s/1wjOFCzV5pxVhgti-T7-24Q) (passwd: a8yv), then place it in the repository root. 11 | 12 | ## Data preparation 13 | ### 102Flowers (Within-domain) 14 | Please download the original image files from [this link](http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz). Decompress the archive and arrange the training and testing images as follows: 15 | ``` 16 | mkdir datasets 17 | cd datasets 18 | mkdir flower 19 | # The directory structure of flower should be as follows: 20 | ├── flower 21 | ├── train_cover 22 | ├── cover1.jpg 23 | └── ... 24 | ├── train_message 25 | ├── message1.jpg 26 | └── ... 27 | ├── test_cover 28 | ├── test_a.jpg (the test cover image that you want) 29 | └── ... 30 | ├── test_message 31 | ├── test_b.jpg (the test message image that you want) 32 | └── ... 33 | ``` 34 | We also provide a simple `prepare.py` script to randomly split the images. Please edit `img_path` to point to the decompressed images before running it. 35 | 36 | ### 102Flowers and Caricature dataset (Cross-domain) 37 | Please download the [caricature image dataset](https://www.kaggle.com/ranjeetapegu/caricature-image). We follow the training/testing split of this dataset and arrange the training/testing images as follows: 38 | ``` 39 | mkdir flowercari 40 | # The directory structure of flowercari should be as follows: 41 | ├── flowercari 42 | ├── train_cover 43 | ├── cover1.jpg (the same 7000 flower images) 44 | └── ... 45 | ├── train_message (the training caricature images) 46 | ├── message1.jpg 47 | └── ... 48 | ├── test_cover 49 | ├── test_a.jpg (the remaining flower images) 50 | └── ... 51 | ├── test_message 52 | ├── test_b.jpg (the test caricature images) 53 | └── ... 54 | ``` 55 | 56 | ## Train 57 | Run `sh scripts/train_flower.sh`. 58 | You can also adjust the default parameters; see `main.py` for the full list. For the cross-domain setting, pass `--dataset_dir flowercari` (assuming the data is arranged under `datasets/flowercari` as above). 59 | 60 | ## Test 61 | Run `sh scripts/test_flower.sh`. This generates a steganographic image from two randomly paired images: one cover image and one message image. 62 | 63 | ## Reconstruction 64 | Run `sh scripts/recon_flower.sh` to reconstruct the message images from the steganographic images. Please specify `stegano_dir` and `recon_dir` when running this step. 65 | 66 | ## Losses 67 | - [Perceptual loss](https://arxiv.org/abs/1603.08155).
68 | - `LSGAN`: [Least Squares GAN](https://arxiv.org/abs/1611.04076) (a minimal NumPy sketch of both loss criteria appears at the end of this README). 69 | 70 | ## Image steganography description 71 | ![Demo illustration](figures/demo.png) 72 | 73 | ## The detailed CAIS model 74 | ![Overview framework](figures/framework.png) 75 | 76 | ## Pre-trained model 77 | * Flower: [Google Drive](https://drive.google.com/drive/folders/1F4o41UNba8I02hsx0JAmz_QeRUoHsHBY?usp=sharing); [BaiduYun](https://pan.baidu.com/s/1mF1UpfYLs_O9s1r4-0LCtA) (1u8c) 78 | * Flowercari: [Google Drive](https://drive.google.com/drive/folders/1F4o41UNba8I02hsx0JAmz_QeRUoHsHBY?usp=sharing); [BaiduYun](https://pan.baidu.com/s/1NOyNTDgdpR44fjt_CciWTw) (ykc3) 79 | * ImageNet: [Google Drive](https://drive.google.com/drive/folders/1F4o41UNba8I02hsx0JAmz_QeRUoHsHBY?usp=sharing); [BaiduYun](https://pan.baidu.com/s/1D0yst70-_1OiQZfw7i_taA) (9fto) 80 | 81 | All models are trained at resolution `256x256`. Please download them to the path `./check` and unzip them there. 82 | 83 | 84 | ## TODO list 85 | * Add the analysis tools. 86 | 87 | ## Citation 88 | 89 | If you find our work useful in your research, please consider citing: 90 | 91 | ```bibtex 92 | @article{zheng2022composition, 93 | title={Composition-Aware Image Steganography Through Adversarial Self-Generated Supervision}, 94 | author={Zheng, Ziqiang and Hu, Yuanmeng and Bin, Yi and Xu, Xing and Yang, Yang and Shen, Heng Tao}, 95 | journal={IEEE Transactions on Neural Networks and Learning Systems}, 96 | year={2022}, 97 | publisher={IEEE} 98 | } 99 | ``` 100 | 101 | ## Acknowledgments 102 | The code borrows from [CycleGAN](https://github.com/junyanz/CycleGAN) and [DCGAN](https://github.com/carpedm20/DCGAN-tensorflow). The network architecture design is modified from DCGAN, and the generative network is adapted from neural-style with Instance Normalization.
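## Loss criteria sketch
For reference, here is a minimal NumPy sketch of the two training criteria listed under **Losses**, mirroring `mae_criterion` (the least-squares GAN objective) and `abs_criterion` (the L1 term) in `model.py`. NumPy stands in for TensorFlow here; the snippet is illustrative only and not part of the training graph.
```python
import numpy as np

def mae_criterion(pred, target):
    # Least-squares GAN loss: mean squared error between the
    # discriminator output map and its target labels (cf. model.py).
    return np.mean((pred - target) ** 2)

def abs_criterion(x, y):
    # L1 loss, used for the cover/message reconstruction terms (cf. model.py).
    return np.mean(np.abs(x - y))

# Toy usage: real samples are pushed toward 1, fake samples toward 0.
logits = np.random.rand(1, 32, 32, 1)
d_loss_real = mae_criterion(logits, np.ones_like(logits))
d_loss_fake = mae_criterion(logits, np.zeros_like(logits))
print(d_loss_real, d_loss_fake)
```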
103 | -------------------------------------------------------------------------------- /figures/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhengziqiang/CAIS/a76e8c727b8890331e6d4e363a7d833aaeef5b41/figures/demo.png -------------------------------------------------------------------------------- /figures/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhengziqiang/CAIS/a76e8c727b8890331e6d4e363a7d833aaeef5b41/figures/framework.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import tensorflow as tf 4 | from model import cyclegan 5 | parser = argparse.ArgumentParser(description='') 6 | parser.add_argument('--dataset_dir', dest='dataset_dir', default='flower', help='path of the dataset') 7 | parser.add_argument('--epoch', dest='epoch', type=int, default=40, help='# of epoch') 8 | parser.add_argument('--epoch_step', dest='epoch_step', type=int, default=20, help='# of epoch to decay lr') 9 | parser.add_argument('--batch_size', dest='batch_size', type=int, default=1, help='# images in batch') 10 | parser.add_argument('--train_size', dest='train_size', type=int, default=1e8, help='# images used to train') 11 | parser.add_argument('--load_size', dest='load_size', type=int, default=286, help='scale images to this size') 12 | parser.add_argument('--fine_size', dest='fine_size', type=int, default=256, help='then crop to this size') 13 | parser.add_argument('--ngf', dest='ngf', type=int, default=64, help='# of gen filters in first conv layer') 14 | parser.add_argument('--ndf', dest='ndf', type=int, default=64, help='# of discri filters in first conv layer') 15 | parser.add_argument('--gpu', dest='gpu', type=int, default=0, help='# index of gpu device') 16 | parser.add_argument('--input_nc', dest='input_nc', type=int, default=3, help='# of input image channels') 17 | parser.add_argument('--output_nc', dest='output_nc', type=int, default=3, help='# of output image channels') 18 | parser.add_argument('--lr', dest='lr', type=float, default=0.0002, help='initial learning rate for adam') 19 | parser.add_argument('--beta1', dest='beta1', type=float, default=0.5, help='momentum term of adam') 20 | parser.add_argument('--which_direction', dest='which_direction', default='AtoB', help='AtoB or BtoA') 21 | parser.add_argument('--phase', dest='phase', default='train', help='train, test') 22 | parser.add_argument('--save_freq', dest='save_freq', type=int, default=1000, 23 | help='save a model every save_freq iterations') 24 | parser.add_argument('--print_freq', dest='print_freq', type=int, default=100, 25 | help='print the debug information every print_freq iterations') 26 | parser.add_argument('--continue_train', dest='continue_train', type=bool, default=False, 27 | help='if continue training, load the latest model: 1: true, 0: false') 28 | parser.add_argument('--checkpoint_dir', dest='checkpoint_dir', default='./checkpoint', help='models are saved here') 29 | parser.add_argument('--sample_dir', dest='sample_dir', default='./sample', help='sample are saved here') 30 | parser.add_argument('--stegano_dir', dest='stegano_dir', default='./stegano', help='steganographic images are saved here') 31 | parser.add_argument('--recon_dir', dest='recon_dir', default='./recon', help='reconstructed images 
are saved here') 32 | parser.add_argument('--test_dir', dest='test_dir', default='./test', help='test sample are saved here') 33 | parser.add_argument('--L1_lambda', dest='L1_lambda', type=float, default=10.0, help='weight on L1 term in objective') 34 | parser.add_argument('--use_lsgan', dest='use_lsgan', type=bool, default=True, help='gan loss defined in lsgan') 35 | parser.add_argument('--max_size', dest='max_size', type=int, default=50, 36 | help='max size of image pool, 0 means do not use image pool') 37 | 38 | args = parser.parse_args() 39 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 40 | def main(_): 41 | if not os.path.exists(args.checkpoint_dir): 42 | os.makedirs(args.checkpoint_dir) 43 | if not os.path.exists(args.sample_dir): 44 | os.makedirs(args.sample_dir) 45 | if not os.path.exists(args.test_dir): 46 | os.makedirs(args.test_dir) 47 | 48 | tfconfig = tf.ConfigProto(allow_soft_placement=True) 49 | tfconfig.gpu_options.allow_growth = True 50 | with tf.Session(config=tfconfig) as sess: 51 | model = cyclegan(sess, args) 52 | if args.phase== 'train': 53 | model.train(args) 54 | if args.phase== 'test': 55 | model.test(args) 56 | if args.phase== 'recon': 57 | model.test_reverse(args) 58 | if __name__ == '__main__': 59 | tf.app.run() -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from ops import * 3 | from utils import * 4 | from glob import glob 5 | import time 6 | import scipy.io 7 | def build_net_vgg(ntype, nin, nwb=None, name=None): 8 | if ntype == 'conv': 9 | return tf.nn.relu(tf.nn.conv2d(nin, nwb[0], strides=[1, 1, 1, 1], padding='SAME', name=name) + nwb[1]) 10 | elif ntype == 'pool': 11 | return tf.nn.avg_pool(nin, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 12 | 13 | 14 | 15 | def get_weight_bias(vgg_layers, i): 16 | weights = vgg_layers[i][0][0][2][0][0] 17 | weights = tf.constant(weights) 18 | bias = vgg_layers[i][0][0][2][0][1] 19 | bias = tf.constant(np.reshape(bias, (bias.size))) 20 | return weights, bias 21 | 22 | def build_vgg19(input, reuse=False): 23 | if reuse: 24 | tf.get_variable_scope().reuse_variables() 25 | net = {} 26 | vgg_rawnet = scipy.io.loadmat('imagenet-vgg-verydeep-19.mat') 27 | vgg_layers = vgg_rawnet['layers'][0] 28 | net['input'] = input - np.array([-0.029960784313725397, -0.084086274509804, -0.1847921568627452]).reshape((1, 1, 1, 3)) 29 | net['conv1_1'] = build_net_vgg('conv', net['input'], get_weight_bias(vgg_layers, 0), name='vgg_conv1_1') 30 | net['conv1_2'] = build_net_vgg('conv', net['conv1_1'], get_weight_bias(vgg_layers, 2), name='vgg_conv1_2') 31 | net['pool1'] = build_net_vgg('pool', net['conv1_2']) 32 | net['conv2_1'] = build_net_vgg('conv', net['pool1'], get_weight_bias(vgg_layers, 5), name='vgg_conv2_1') 33 | net['conv2_2'] = build_net_vgg('conv', net['conv2_1'], get_weight_bias(vgg_layers, 7), name='vgg_conv2_2') 34 | net['pool2'] = build_net_vgg('pool', net['conv2_2']) 35 | net['conv3_1'] = build_net_vgg('conv', net['pool2'], get_weight_bias(vgg_layers, 10), name='vgg_conv3_1') 36 | net['conv3_2'] = build_net_vgg('conv', net['conv3_1'], get_weight_bias(vgg_layers, 12), name='vgg_conv3_2') 37 | net['conv3_3'] = build_net_vgg('conv', net['conv3_2'], get_weight_bias(vgg_layers, 14), name='vgg_conv3_3') 38 | net['conv3_4'] = build_net_vgg('conv', net['conv3_3'], get_weight_bias(vgg_layers, 16), name='vgg_conv3_4') 39 | net['pool3'] = 
build_net_vgg('pool', net['conv3_4']) 40 | net['conv4_1'] = build_net_vgg('conv', net['pool3'], get_weight_bias(vgg_layers, 19), name='vgg_conv4_1') 41 | net['conv4_2'] = build_net_vgg('conv', net['conv4_1'], get_weight_bias(vgg_layers, 21), name='vgg_conv4_2') 42 | net['conv4_3'] = build_net_vgg('conv', net['conv4_2'], get_weight_bias(vgg_layers, 23), name='vgg_conv4_3') 43 | net['conv4_4'] = build_net_vgg('conv', net['conv4_3'], get_weight_bias(vgg_layers, 25), name='vgg_conv4_4') 44 | net['pool4'] = build_net_vgg('pool', net['conv4_4']) 45 | net['conv5_1'] = build_net_vgg('conv', net['pool4'], get_weight_bias(vgg_layers, 28), name='vgg_conv5_1') 46 | net['conv5_2'] = build_net_vgg('conv', net['conv5_1'], get_weight_bias(vgg_layers, 30), name='vgg_conv5_2') 47 | net['conv5_3'] = build_net_vgg('conv', net['conv5_2'], get_weight_bias(vgg_layers, 32), name='vgg_conv5_3') 48 | net['conv5_4'] = build_net_vgg('conv', net['conv5_3'], get_weight_bias(vgg_layers, 34), name='vgg_conv5_4') 49 | net['pool5'] = build_net_vgg('pool', net['conv5_4']) 50 | return net 51 | 52 | def compute_error(real, fake): 53 | return tf.reduce_mean(tf.abs(real - fake)) 54 | 55 | def discriminator(image, options, reuse=False, name="discriminator"): 56 | with tf.variable_scope(name): 57 | # image is 256 x 256 x input_c_dim 58 | if reuse: 59 | tf.get_variable_scope().reuse_variables() 60 | else: 61 | assert tf.get_variable_scope().reuse is False 62 | h0 = lrelu(conv2d(image, options.df_dim, name='d_h0_conv')) 63 | h1 = lrelu(instance_norm(conv2d(h0, options.df_dim*2, name='d_h1_conv'), 'd_bn1')) 64 | h2 = lrelu(instance_norm(conv2d(h1, options.df_dim*4, name='d_h2_conv'), 'd_bn2')) 65 | h3 = lrelu(instance_norm(conv2d(h2, options.df_dim*8, name='d_h3_conv'), 'd_bn3')) 66 | h4_logit = conv2d(h3, 1, s=1, name='d_h3_pred') 67 | h4 = lrelu(instance_norm(conv2d(h3, options.df_dim * 8, name='d_h4_conv'), 'd_bn4')) 68 | h5 = lrelu(instance_norm(conv2d(h4, options.df_dim * 4, name='d_h5_conv'), 'd_bn5')) 69 | h6 = lrelu(instance_norm(conv2d(h5, options.df_dim * 2, name='d_h6_conv'), 'd_bn6')) 70 | h6_logit = conv2d(h6, 1, s=1, name='d_h6_pred') 71 | return h4_logit,tf.reshape(tf.reduce_mean(h6_logit,axis=[1,2]),[1,1,1,-1]) 72 | 73 | def generator_resnet(image, options, reuse=False, name="hiding"): 74 | with tf.variable_scope(name): 75 | # image is 256 x 256 x input_c_dim 76 | if reuse: 77 | tf.get_variable_scope().reuse_variables() 78 | else: 79 | assert tf.get_variable_scope().reuse is False 80 | 81 | def residule_block(x, dim, ks=3, s=1, name='res'): 82 | p = int((ks - 1) / 2) 83 | y = tf.pad(x, [[0, 0], [p, p], [p, p], [0, 0]], "REFLECT") 84 | y = instance_norm(conv2d(y, dim, ks, s, padding='VALID', name=name + '_c1'), name + '_bn1') 85 | y = tf.pad(tf.nn.relu(y), [[0, 0], [p, p], [p, p], [0, 0]], "REFLECT") 86 | y = instance_norm(conv2d(y, dim, ks, s, padding='VALID', name=name + '_c2'), name + '_bn2') 87 | return y + x 88 | c0 = tf.pad(image, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT") 89 | c1 = tf.nn.relu(instance_norm(conv2d(c0, options.gf_dim, 7, 1, padding='VALID', name='g_e1_c'), 'g_e1_bn')) 90 | c2 = tf.nn.relu(instance_norm(conv2d(c1, options.gf_dim * 2, 3, 2, name='g_e2_c'), 'g_e2_bn')) 91 | c3 = tf.nn.relu(instance_norm(conv2d(c2, options.gf_dim * 4, 3, 2, name='g_e3_c'), 'g_e3_bn')) 92 | # define G network with 9 resnet blocks 93 | r1 = residule_block(c3, options.gf_dim * 4, name='g_r1') 94 | r2 = residule_block(r1, options.gf_dim * 4, name='g_r2') 95 | r3 = residule_block(r2, options.gf_dim * 4, name='g_r3') 96 
| r4 = residule_block(r3, options.gf_dim * 4, name='g_r4') 97 | r5 = residule_block(r4, options.gf_dim * 4, name='g_r5') 98 | r6 = residule_block(r5, options.gf_dim * 4, name='g_r6') 99 | r7 = residule_block(r6, options.gf_dim * 4, name='g_r7') 100 | r8 = residule_block(r7, options.gf_dim * 4, name='g_r8') 101 | r9 = residule_block(r8, options.gf_dim * 4, name='g_r9') 102 | d1 = deconv2d(r9, options.gf_dim * 2, 3, 2, name='g_d1_dc') 103 | d1 = tf.nn.relu(instance_norm(d1, 'g_d1_bn')) 104 | d2 = deconv2d(d1, options.gf_dim, 3, 2, name='g_d2_dc') 105 | d2 = tf.nn.relu(instance_norm(d2, 'g_d2_bn')) 106 | d2 = tf.pad(d2, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT") 107 | pred = tf.nn.tanh(conv2d(d2, options.output_c_dim, 7, 1, padding='VALID', name='g_pred_c')) 108 | 109 | return pred 110 | 111 | def generator_resnet_recon(image, options, reuse=False, name="revealing"): 112 | with tf.variable_scope(name): 113 | if reuse: 114 | tf.get_variable_scope().reuse_variables() 115 | else: 116 | assert tf.get_variable_scope().reuse is False 117 | 118 | def residule_block(x, dim, ks=3, s=1, name='res'): 119 | p = int((ks - 1) / 2) 120 | y = tf.pad(x, [[0, 0], [p, p], [p, p], [0, 0]], "REFLECT") 121 | y = instance_norm(conv2d(y, dim, ks, s, padding='VALID', name=name + '_c1'), name + '_bn1') 122 | y = tf.pad(tf.nn.relu(y), [[0, 0], [p, p], [p, p], [0, 0]], "REFLECT") 123 | y = instance_norm(conv2d(y, dim, ks, s, padding='VALID', name=name + '_c2'), name + '_bn2') 124 | return y + x 125 | 126 | c0 = tf.pad(image, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT") 127 | c1 = tf.nn.relu(instance_norm(conv2d(c0, options.gf_dim, 7, 1, padding='VALID', name='g_e1_c'), 'g_e1_bn')) 128 | c2 = tf.nn.relu(instance_norm(conv2d(c1, options.gf_dim * 2, 3, 2, name='g_e2_c'), 'g_e2_bn')) 129 | c3 = tf.nn.relu(instance_norm(conv2d(c2, options.gf_dim * 4, 3, 2, name='g_e3_c'), 'g_e3_bn')) 130 | # define G network with 9 resnet blocks 131 | r1 = residule_block(c3, options.gf_dim * 4, name='g_r1') 132 | r2 = residule_block(r1, options.gf_dim * 4, name='g_r2') 133 | r3 = residule_block(r2, options.gf_dim * 4, name='g_r3') 134 | r4 = residule_block(r3, options.gf_dim * 4, name='g_r4') 135 | r5 = residule_block(r4, options.gf_dim * 4, name='g_r5') 136 | r6 = residule_block(r5, options.gf_dim * 4, name='g_r6') 137 | r7 = residule_block(r6, options.gf_dim * 4, name='g_r7') 138 | r8 = residule_block(r7, options.gf_dim * 4, name='g_r8') 139 | r9 = residule_block(r8, options.gf_dim * 4, name='g_r9') 140 | 141 | d1 = deconv2d(r9, options.gf_dim * 2, 3, 2, name='g_d1_dc') 142 | d1 = tf.nn.relu(instance_norm(d1, 'g_d1_bn')) 143 | d2 = deconv2d(d1, options.gf_dim, 3, 2, name='g_d2_dc') 144 | d2 = tf.nn.relu(instance_norm(d2, 'g_d2_bn')) 145 | d2 = tf.pad(d2, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT") 146 | pred = tf.nn.tanh(conv2d(d2, options.output_c_dim*2, 7, 1, padding='VALID', name='g_pred_c')) 147 | 148 | return pred[:,:,:,:options.output_c_dim],pred[:,:,:,options.output_c_dim:options.output_c_dim*2] 149 | 150 | 151 | def abs_criterion(in_, target): 152 | return tf.reduce_mean(tf.abs(in_ - target)) 153 | 154 | 155 | def mae_criterion(in_, target): 156 | return tf.reduce_mean((in_ - target) ** 2) 157 | 158 | def sce_criterion(logits, labels): 159 | return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels)) 160 | 161 | class cyclegan(object): 162 | def __init__(self, sess, args): 163 | self.sess = sess 164 | self.batch_size = args.batch_size 165 | self.image_size = args.fine_size 166 | self.patch_size = 
self.image_size//2 167 | self.input_c_dim = args.input_nc 168 | self.output_c_dim = args.output_nc 169 | self.L1_lambda = args.L1_lambda 170 | self.dataset_dir = args.dataset_dir 171 | 172 | self.discriminator = discriminator 173 | self.generator = generator_resnet 174 | self.generator_recon=generator_resnet_recon 175 | 176 | if args.use_lsgan: 177 | self.criterionGAN = mae_criterion 178 | else: 179 | self.criterionGAN = sce_criterion 180 | 181 | OPTIONS = namedtuple('OPTIONS', 'batch_size image_size \ 182 | gf_dim df_dim output_c_dim is_training') 183 | self.options = OPTIONS._make((args.batch_size, args.fine_size, 184 | args.ngf, args.ndf, args.output_nc, 185 | args.phase == 'train')) 186 | 187 | self._build_model() 188 | self.saver = tf.train.Saver() 189 | self.pool = ImagePool(args.max_size) 190 | 191 | def _build_model(self): 192 | self.real_data = tf.placeholder(tf.float32,[self.batch_size, self.image_size, self.image_size,self.input_c_dim*3],name='collected_images') 193 | 194 | self.alpha = tf.placeholder(tf.float32,[self.batch_size, 1],name='alpha') 195 | 196 | self.real_cover = self.real_data[:, :, :, :self.input_c_dim] 197 | self.real_message = self.real_data[:, :, :, self.input_c_dim:self.input_c_dim*2] 198 | self.weighted = self.real_data[:, :, :, self.input_c_dim*2:self.input_c_dim*3] 199 | 200 | self.fake_stegano = self.generator(tf.concat([self.real_cover,self.real_message],axis=3), self.options, False, name="generator_hiding") 201 | self.recon_cover,self.recon_message = self.generator_recon(self.fake_stegano, self.options, False, name="generator_revealing") 202 | self.fake_part = tf.random_crop(self.fake_stegano, [1, self.patch_size, self.patch_size, 3]) 203 | self.real_part = tf.random_crop(self.real_cover, [1, self.patch_size, self.patch_size, 3]) 204 | 205 | 206 | self.Dstegano_fake,self.weighted_est= self.discriminator(self.fake_stegano, self.options, reuse=False, name="discriminator") 207 | self.Dstegano_fake_part, self.weighted_est_part = self.discriminator(self.fake_part, self.options, reuse=False,name="part") 208 | 209 | self.g_adv=(self.criterionGAN(self.Dstegano_fake, tf.ones_like(self.Dstegano_fake))+abs_criterion(self.weighted_est,tf.zeros_like(self.weighted_est))+ 210 | self.criterionGAN(self.Dstegano_fake_part, tf.ones_like(self.Dstegano_fake_part))+abs_criterion(self.weighted_est_part,tf.zeros_like(self.weighted_est_part))) 211 | 212 | self.g_est=(abs_criterion(self.weighted_est,tf.zeros_like(self.weighted_est))+abs_criterion(self.weighted_est_part,tf.zeros_like(self.weighted_est_part))) 213 | 214 | 215 | with tf.variable_scope("VGG_loss"): 216 | vgg_real = build_vgg19(self.real_message) 217 | vgg_fake = build_vgg19(self.recon_message, reuse=True) 218 | p1 = compute_error(vgg_real['conv1_2'], vgg_fake['conv1_2']) 219 | p2 = compute_error(vgg_real['conv2_2'], vgg_fake['conv2_2']) 220 | p3 = compute_error(vgg_real['conv3_2'], vgg_fake['conv3_2']) 221 | p4 = compute_error(vgg_real['conv4_2'], vgg_fake['conv4_2']) 222 | p5 = compute_error(vgg_real['conv5_2'], vgg_fake['conv5_2']) 223 | self.G_loss = (p1 + p2 + p3 + p4 + p5) * 0.1 224 | 225 | self.g_loss = self.g_adv\ 226 | + self.G_loss \ 227 | + self.L1_lambda * abs_criterion(self.real_cover, self.fake_stegano)\ 228 | + self.L1_lambda * abs_criterion(self.real_cover, self.recon_cover) \ 229 | + self.L1_lambda * abs_criterion(self.real_message, self.recon_message) 230 | 231 | self.fake_stegano_sample = tf.placeholder(tf.float32,[self.batch_size, self.image_size, self.image_size,self.output_c_dim], 
name='fake_stegano_sample') 232 | self.fake_stegano_sample_part = tf.random_crop(self.fake_stegano_sample, [1, self.patch_size, self.patch_size, 3]) 233 | self.weighted_part = tf.random_crop(self.weighted, [1, self.patch_size, self.patch_size, 3]) 234 | 235 | self.Dcover_real,self.Dcover_est = self.discriminator(self.real_cover, self.options, reuse=True, name="discriminator") 236 | self.Dstegano_fake_sample,_ = self.discriminator(self.fake_stegano_sample, self.options, reuse=True, name="discriminator") 237 | _, self.est_real = self.discriminator(self.weighted, self.options, reuse=True,name="discriminator") 238 | 239 | self.d_loss_real = self.criterionGAN(self.Dcover_real, tf.ones_like(self.Dcover_real)) 240 | self.d_loss_fake = self.criterionGAN(self.Dstegano_fake_sample, tf.zeros_like(self.Dstegano_fake_sample)) 241 | self.d_adv_loss = (self.d_loss_real + self.d_loss_fake) / 2 242 | self.Dcover_est_loss = abs_criterion(self.Dcover_est,tf.zeros_like(self.Dcover_est)) 243 | self.weight_loss = abs_criterion(self.est_real,tf.reshape(self.alpha,[1,1,1,-1])) 244 | self.d_loss = (self.d_adv_loss + self.Dcover_est_loss+self.weight_loss) 245 | self.est_loss=self.Dcover_est_loss+self.weight_loss 246 | 247 | self.Dcover_real_part, self.Dcover_est_part = self.discriminator(self.real_part, self.options, reuse=True,name="part") 248 | self.Dstegano_fake_sample_part, _ = self.discriminator(self.fake_stegano_sample_part, self.options, reuse=True,name="part") 249 | _, self.est_real_part = self.discriminator(self.weighted_part, self.options, reuse=True,name="part") 250 | self.d_loss_real_part = self.criterionGAN(self.Dcover_real_part, tf.ones_like(self.Dcover_real_part)) 251 | self.d_loss_fake_part = self.criterionGAN(self.Dstegano_fake_sample_part, tf.zeros_like(self.Dstegano_fake_sample_part)) 252 | self.d_adv_loss_part = (self.d_loss_real_part + self.d_loss_fake_part) / 2 253 | self.Dcover_est_loss_part = abs_criterion(self.Dcover_est_part, tf.zeros_like(self.Dcover_est_part)) 254 | self.weight_loss_part = abs_criterion(self.est_real_part, tf.reshape(self.alpha, [1, 1, 1, -1])) 255 | self.d_loss_part = (self.d_adv_loss_part + self.Dcover_est_loss_part + self.weight_loss_part) 256 | 257 | ### G summary 258 | self.g_adv_sum = tf.summary.scalar("g_adv", self.g_adv) 259 | self.g_est_sum = tf.summary.scalar("g_est", self.g_est) 260 | self.G_loss_sum = tf.summary.scalar("G_loss", self.G_loss) 261 | self.g_sum = tf.summary.merge([self.g_adv_sum,self.g_est_sum,self.G_loss_sum]) 262 | ### D summary 263 | self.d_adv_loss_sum = tf.summary.scalar("d_adv_loss", self.d_adv_loss) 264 | self.d_adv_loss_part_sum = tf.summary.scalar("d_adv_loss_part", self.d_adv_loss_part) 265 | 266 | self.d_loss_sum = tf.summary.scalar("d_loss", self.d_loss) 267 | self.d_loss_part_sum = tf.summary.scalar("d_loss_part", self.d_loss_part) 268 | 269 | self.d_loss_real_sum = tf.summary.scalar("d_loss_real", self.d_loss_real) 270 | self.d_loss_fake_sum = tf.summary.scalar("d_loss_fake", self.d_loss_fake) 271 | self.d_loss_real_part_sum = tf.summary.scalar("d_loss_real_part", self.d_loss_real_part) 272 | self.d_loss_fake_part_sum = tf.summary.scalar("d_loss_fake_part", self.d_loss_fake_part) 273 | 274 | self.d_sum = tf.summary.merge( 275 | [self.d_adv_loss_sum, self.d_loss_real_sum, self.d_loss_fake_sum, 276 | self.d_adv_loss_part_sum, self.d_loss_real_part_sum, self.d_loss_fake_part_sum, 277 | self.d_loss_sum,self.d_loss_part_sum] 278 | ) 279 | 280 | self.test_cover = tf.placeholder(tf.float32,[self.batch_size, self.image_size, 
self.image_size,self.input_c_dim], name='test_cover') 281 | self.test_message = tf.placeholder(tf.float32,[self.batch_size, self.image_size, self.image_size,self.output_c_dim], name='test_message') 282 | 283 | self.test_stega = self.generator(tf.concat([self.test_cover,self.test_message],axis=3), self.options, True, name="generator_hiding") 284 | self.test_cover_recon,self.test_message_recon = self.generator_recon(self.test_stega, self.options, True, name="generator_revealing") 285 | 286 | t_vars = tf.trainable_variables() 287 | self.g_vars = [var for var in t_vars if 'generator' in var.name] 288 | self.d_vars=[var for var in t_vars if 'discriminator' in var.name] 289 | self.part_vars = [var for var in t_vars if 'part' in var.name] 290 | 291 | for var in t_vars: print(var.name) 292 | 293 | def train(self, args): 294 | """Train cyclegan""" 295 | self.lr = tf.placeholder(tf.float32, None, name='learning_rate') 296 | self.d_optim = tf.train.AdamOptimizer(self.lr, beta1=args.beta1).minimize(self.d_loss, var_list=self.d_vars) 297 | self.part_optim = tf.train.AdamOptimizer(self.lr, beta1=args.beta1).minimize(self.d_loss_part, var_list=self.part_vars) 298 | self.g_optim = tf.train.AdamOptimizer(self.lr, beta1=args.beta1).minimize(self.g_loss, var_list=self.g_vars) 299 | 300 | init_op = tf.global_variables_initializer() 301 | self.sess.run(init_op) 302 | self.writer = tf.summary.FileWriter(os.path.join(args.checkpoint_dir,"logs"), self.sess.graph) 303 | 304 | counter = 1 305 | start_time = time.time() 306 | 307 | if args.continue_train: 308 | if self.load(args.checkpoint_dir): 309 | print(" [*] Load SUCCESS") 310 | else: 311 | print(" [!] Load failed...") 312 | 313 | for epoch in range(args.epoch): 314 | data_cover = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/train_cover')) 315 | data_message = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/train_message')) 316 | np.random.shuffle(data_cover) 317 | np.random.shuffle(data_message) 318 | batch_idxs = min(min(len(data_cover), len(data_message)), args.train_size) // self.batch_size 319 | lr = args.lr if epoch < args.epoch_step else args.lr * (args.epoch - epoch) / (args.epoch - args.epoch_step) 320 | 321 | for idx in range(0, batch_idxs): 322 | batch_files = list(zip(data_cover[idx * self.batch_size:(idx + 1) * self.batch_size], 323 | data_message[idx * self.batch_size:(idx + 1) * self.batch_size])) 324 | batch_images,alpha = load_train_data(batch_files[0], args.load_size, args.fine_size) 325 | batch_images = [batch_images] 326 | alpha = [alpha] 327 | alpha = np.reshape(alpha,[self.batch_size,1]) 328 | batch_images = np.array(batch_images).astype(np.float32) 329 | # Update G network and record fake outputs 330 | fake_stegano,_,g_loss,g_est,g_adv,G_loss,summary_str = self.sess.run( 331 | [self.fake_stegano, self.g_optim,self.g_loss,self.g_est,self.g_adv,self.G_loss, self.g_sum], 332 | feed_dict={self.real_data: batch_images, self.lr: lr,self.alpha:alpha}) 333 | self.writer.add_summary(summary_str, counter) 334 | # Update D network 335 | _,_,d_loss,est_loss,d_summary_str = self.sess.run( 336 | [self.d_optim,self.part_optim, self.d_loss,self.est_loss,self.d_sum], 337 | feed_dict={self.real_data: batch_images, 338 | self.fake_stegano_sample: fake_stegano, 339 | self.lr: lr,self.alpha:alpha}) 340 | self.writer.add_summary(d_summary_str, counter) 341 | counter += 1 342 | print(("Epoch:[%2d][%4d/%4d] time: %4.4f g_loss: %4.4f g_est: %4.4f g_adv: %4.4f G_loss: %4.4f d_loss: %4.4f est_loss: %4.4f" % ( 343 | epoch, idx, batch_idxs, time.time() - 
start_time,g_loss,g_est,g_adv,G_loss,d_loss,est_loss))) 344 | if np.mod(counter, args.print_freq) == 1: 345 | self.sample_model(args.sample_dir, epoch, idx) 346 | 347 | if np.mod(counter, args.save_freq) == 2: 348 | self.save(args.checkpoint_dir, counter) 349 | 350 | def save(self, checkpoint_dir, step): 351 | model_name = "stegano.model" 352 | model_dir = "%s_%s" % (self.dataset_dir, self.image_size) 353 | checkpoint_dir = os.path.join(checkpoint_dir, model_dir) 354 | 355 | if not os.path.exists(checkpoint_dir): 356 | os.makedirs(checkpoint_dir) 357 | 358 | self.saver.save(self.sess, 359 | os.path.join(checkpoint_dir, model_name), 360 | global_step=step) 361 | 362 | def load(self, checkpoint_dir): 363 | print(" [*] Reading checkpoint...") 364 | 365 | model_dir = "%s_%s" % (self.dataset_dir, self.image_size) 366 | checkpoint_dir = os.path.join(checkpoint_dir, model_dir) 367 | 368 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 369 | if ckpt and ckpt.model_checkpoint_path: 370 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 371 | self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name)) 372 | return True 373 | else: 374 | return False 375 | 376 | def sample_model(self, sample_dir, epoch, idx): 377 | data_cover = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/test_cover')) 378 | data_message = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/test_message')) 379 | np.random.shuffle(data_cover) 380 | np.random.shuffle(data_message) 381 | batch_files = list(zip(data_cover[:self.batch_size], data_message[:self.batch_size])) 382 | sample_images,_ =load_train_data(batch_files[0],is_testing=True) 383 | sample_images=[sample_images] 384 | sample_images = np.array(sample_images).astype(np.float32) 385 | 386 | fake_stegano,rec_cover,rec_message = self.sess.run( 387 | [self.fake_stegano,self.recon_cover,self.recon_message], 388 | feed_dict={self.real_data: sample_images} 389 | ) 390 | real_cover = sample_images[:, :, :, :self.input_c_dim] 391 | real_message = sample_images[:, :, :, self.input_c_dim:self.input_c_dim*2] 392 | weighted = sample_images[:, :, :, self.input_c_dim*2:self.input_c_dim*3] 393 | 394 | merge = np.concatenate([real_cover,real_message,weighted, fake_stegano,rec_cover,rec_message], axis=2) 395 | check_folder('./{}/{:02d}'.format(sample_dir, epoch)) 396 | save_images(merge, [self.batch_size, 1], 397 | './{}/{:02d}/{:04d}.jpg'.format(sample_dir, epoch, idx)) 398 | 399 | def test(self, args): 400 | """Test cyclegan""" 401 | init_op = tf.global_variables_initializer() 402 | self.sess.run(init_op) 403 | 404 | if self.load(args.checkpoint_dir): 405 | print(" [*] Load SUCCESS") 406 | else: 407 | print(" [!] 
Load failed...") 408 | data_cover = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/test_cover')) 409 | data_message = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/test_message')) 410 | for sample_file in data_message: 411 | print('Processing image: ' + sample_file) 412 | random_idx=np.random.randint(0,len(data_cover)) 413 | cover_path=data_cover[random_idx] 414 | sample_image = [load_test_data(cover_path,sample_file, args.fine_size)] 415 | sample_image = np.array(sample_image).astype(np.float32) 416 | jpg_name=os.path.basename(sample_file) 417 | image_path = os.path.join(args.test_dir,'{0}_{1}'.format("merge",jpg_name[:-4]+".png" )) 418 | test_A,test_B,fake_img,recon_cover,recon_message = self.sess.run([self.test_cover,self.test_message,self.test_stega,self.test_cover_recon,self.test_message_recon], 419 | feed_dict={self.test_cover: sample_image[:,:,:,:3],self.test_message:sample_image[:,:,:,3:6]}) 420 | merge=np.concatenate([test_A,test_B,fake_img,recon_cover,recon_message],axis=2) 421 | save_images(merge, [1, 1], image_path) 422 | split_path = os.path.join(args.test_dir, jpg_name[:-4] + ".png") 423 | save_images(fake_img, [1, 1], split_path) 424 | 425 | def test_reverse(self, args): 426 | """Test cyclegan""" 427 | init_op = tf.global_variables_initializer() 428 | self.sess.run(init_op) 429 | 430 | if self.load(args.checkpoint_dir): 431 | print(" [*] Load SUCCESS") 432 | else: 433 | print(" [!] Load failed...") 434 | data_stegano = glob('{}/*.*'.format(args.stegano_dir)) 435 | for sample_file in data_stegano: 436 | print('Processing image: '+sample_file) 437 | sample_image = [load_reverse_data( sample_file, args.fine_size)] 438 | sample_image = np.array(sample_image).astype(np.float32) 439 | jpg_name = os.path.basename(sample_file) 440 | image_path = os.path.join(args.recon_dir, '{}'.format(jpg_name[:-4] + ".png")) 441 | recon_cover,recon_message = self.sess.run( 442 | [self.test_cover_recon,self.test_message_recon], 443 | feed_dict={self.test_stega: sample_image}) 444 | merge = np.concatenate([recon_cover,recon_message], axis=2) 445 | save_images(merge, [1, 1], image_path) -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.slim as slim 3 | def instance_norm(input, name="instance_norm"): 4 | with tf.variable_scope(name): 5 | depth = input.get_shape()[3] 6 | scale = tf.get_variable("scale", [depth], initializer=tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32)) 7 | offset = tf.get_variable("offset", [depth], initializer=tf.constant_initializer(0.0)) 8 | mean, variance = tf.nn.moments(input, axes=[1, 2], keep_dims=True) 9 | epsilon = 1e-5 10 | inv = tf.rsqrt(variance + epsilon) 11 | normalized = (input - mean) * inv 12 | return scale * normalized + offset 13 | 14 | 15 | def conv2d(input_, output_dim, ks=4, s=2, stddev=0.02, padding='SAME', name="conv2d"): 16 | with tf.variable_scope(name): 17 | return slim.conv2d(input_, output_dim, ks, s, padding=padding, activation_fn=None, 18 | weights_initializer=tf.truncated_normal_initializer(stddev=stddev), 19 | biases_initializer=None) 20 | 21 | 22 | def deconv2d(input_, output_dim, ks=4, s=2, stddev=0.02, name="deconv2d"): 23 | with tf.variable_scope(name): 24 | return slim.conv2d_transpose(input_, output_dim, ks, s, padding='SAME', activation_fn=None, 25 | weights_initializer=tf.truncated_normal_initializer(stddev=stddev), 26 | 
biases_initializer=None) 27 | 28 | 29 | def lrelu(x, leak=0.2, name="lrelu"): 30 | return tf.maximum(x, leak * x) 31 | 32 | 33 | def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False): 34 | with tf.variable_scope(scope or "Linear"): 35 | matrix = tf.get_variable("Matrix", [input_.get_shape()[-1], output_size], tf.float32, 36 | tf.random_normal_initializer(stddev=stddev)) 37 | bias = tf.get_variable("bias", [output_size], 38 | initializer=tf.constant_initializer(bias_start)) 39 | if with_w: 40 | return tf.matmul(input_, matrix) + bias, matrix, bias 41 | else: 42 | return tf.matmul(input_, matrix) + bias -------------------------------------------------------------------------------- /prepare.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import random 4 | import shutil 5 | img_path = ""  # edit this to point to the decompressed images (see README) 6 | # make sure the target directories exist before copying 7 | for sub in ["train_cover", "train_message", "test_cover", "test_message"]: 8 | os.makedirs(os.path.join("datasets", "flower", sub), exist_ok=True) 9 | files_list = glob.glob(img_path + "/*.*") 10 | random.shuffle(files_list)  # random split, as described in the README 11 | cnt = 0 12 | for files in files_list: 13 | p, n = os.path.split(files) 14 | if cnt < 7000: 15 | shutil.copyfile(files, "datasets/flower/train_cover/" + n) 16 | shutil.copyfile(files, "datasets/flower/train_message/" + n) 17 | else: 18 | shutil.copyfile(files, "datasets/flower/test_cover/" + n) 19 | shutil.copyfile(files, "datasets/flower/test_message/" + n) 20 | cnt += 1 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy==1.1.0 3 | pillow 4 | opencv-python 5 | gast==0.2.2 -------------------------------------------------------------------------------- /scripts/recon_flower.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python main.py --phase recon --stegano_dir ./check/flower/stegano --recon_dir ./check/flower/recon --gpu 0 --checkpoint_dir ./check/flower -------------------------------------------------------------------------------- /scripts/test_flower.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python main.py --phase test --epoch 120 --dataset_dir flower --checkpoint_dir ./check/flower --test_dir ./check/flower/test --gpu 0 -------------------------------------------------------------------------------- /scripts/train_flower.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python main.py --phase train --epoch 40 --dataset_dir flower --checkpoint_dir ./check/flower --sample_dir ./check/flower/sample --gpu 0 --L1_lambda 10 3 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pprint 3 | import scipy.misc 4 | import numpy as np 5 | import copy 6 | import os 7 | import cv2 8 | pp = pprint.PrettyPrinter() 9 | get_stddev = lambda x, k_h, k_w: 1/math.sqrt(k_w*k_h*x.get_shape()[-1]) 10 | # ----------------------------- 11 | # newly added functions for cyclegan 12 | class ImagePool(object): 13 | def __init__(self, maxsize=50): 14 | self.maxsize = maxsize 15 | self.num_img = 0 16 | self.images = [] 17 | 18 | def __call__(self, image): 19 | if self.maxsize <= 0: 20 | return image 21 | if self.num_img < self.maxsize: 22 | self.images.append(image) 23 | self.num_img += 1 24 | return image 25 | if np.random.rand() > 0.5: 26 | idx = int(np.random.rand()*self.maxsize) 27 | tmp1 = copy.copy(self.images[idx])[0] 28 | self.images[idx][0] = image[0] 29 | idx =
int(np.random.rand()*self.maxsize) 30 | tmp2 = copy.copy(self.images[idx])[1] 31 | self.images[idx][1] = image[1] 32 | return [tmp1, tmp2] 33 | else: 34 | return image 35 | 36 | def load_test_data(image_path,image_path_B, fine_size=256): 37 | img = imread(image_path) 38 | img = scipy.misc.imresize(img, [fine_size, fine_size]) 39 | img = img/127.5 - 1 40 | 41 | img_B = imread(image_path_B) 42 | img_B = scipy.misc.imresize(img_B, [fine_size, fine_size]) 43 | img_B = img_B / 127.5 - 1 44 | return np.concatenate((img, img_B), axis=2) 45 | 46 | 47 | def load_reverse_data(image_path, fine_size=256): 48 | img = imread(image_path) 49 | img = scipy.misc.imresize(img, [fine_size, fine_size]) 50 | img = img / 127.5 - 1 51 | return img 52 | 53 | def check_folder(path): 54 | if not os.path.exists(path): 55 | os.mkdir(path) 56 | 57 | def load_train_data(image_path, load_size=286, fine_size=256, is_testing=False): 58 | img_A = imread(image_path[0]) 59 | img_B = imread(image_path[1]) 60 | if not is_testing: 61 | img_A = scipy.misc.imresize(img_A, [load_size, load_size]) 62 | img_B = scipy.misc.imresize(img_B, [load_size, load_size]) 63 | h1 = int(np.ceil(np.random.uniform(1e-2, load_size-fine_size))) 64 | w1 = int(np.ceil(np.random.uniform(1e-2, load_size-fine_size))) 65 | img_A = img_A[h1:h1+fine_size, w1:w1+fine_size] 66 | img_B = img_B[h1:h1+fine_size, w1:w1+fine_size] 67 | 68 | if np.random.random() > 0.5: 69 | img_A = np.fliplr(img_A) 70 | img_B = np.fliplr(img_B) 71 | else: 72 | img_A = scipy.misc.imresize(img_A, [fine_size, fine_size]) 73 | img_B = scipy.misc.imresize(img_B, [fine_size, fine_size]) 74 | 75 | img_A = img_A/127.5 - 1. 76 | img_B = img_B/127.5 - 1. 77 | alpha = np.random.uniform(0.0,1.0) 78 | beta = 1 - alpha 79 | gamma = 0 80 | img_C = cv2.addWeighted(img_A, alpha, img_B, beta, gamma) 81 | img_AB = np.concatenate((img_A, img_B, img_C), axis=2) 82 | return img_AB,alpha 83 | 84 | 85 | def get_image(image_path, image_size, is_crop=True, resize_w=64, is_grayscale = False): 86 | return transform(imread(image_path, is_grayscale), image_size, is_crop, resize_w) 87 | 88 | def save_images(images, size, image_path): 89 | return imsave(inverse_transform(images), size, image_path) 90 | 91 | def imread(path, is_grayscale = False): 92 | if (is_grayscale): 93 | return scipy.misc.imread(path, flatten = True).astype(np.float) 94 | else: 95 | return scipy.misc.imread(path, mode='RGB').astype(np.float) 96 | 97 | def merge_images(images, size): 98 | return inverse_transform(images) 99 | 100 | def merge(images, size): 101 | h, w = images.shape[1], images.shape[2] 102 | img = np.zeros((h * size[0], w * size[1], 3)) 103 | for idx, image in enumerate(images): 104 | i = idx % size[1] 105 | j = idx // size[1] 106 | img[j*h:j*h+h, i*w:i*w+w, :] = image 107 | 108 | return img 109 | 110 | def imsave(images, size, path): 111 | return scipy.misc.imsave(path, merge(images, size)) 112 | 113 | def center_crop(x, crop_h, crop_w, 114 | resize_h=64, resize_w=64): 115 | if crop_w is None: 116 | crop_w = crop_h 117 | h, w = x.shape[:2] 118 | j = int(round((h - crop_h)/2.)) 119 | i = int(round((w - crop_w)/2.)) 120 | return scipy.misc.imresize( 121 | x[j:j+crop_h, i:i+crop_w], [resize_h, resize_w]) 122 | 123 | def transform(image, npx=64, is_crop=True, resize_w=64): 124 | # npx : # of pixels width/height of image 125 | if is_crop: 126 | cropped_image = center_crop(image, npx, resize_w=resize_w) 127 | else: 128 | cropped_image = image 129 | return np.array(cropped_image)/127.5 - 1. 
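# Note: transform() above maps uint8 pixel values into [-1, 1] for the network,
# while inverse_transform() below maps network outputs back into [0, 1] so that
# scipy.misc.imsave can rescale them when writing image files.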
130 | 131 | def inverse_transform(images): 132 | return (images+1.)/2. --------------------------------------------------------------------------------
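The TODO list in the README mentions adding analysis tools. As a minimal starting point, here is a hypothetical PSNR helper for comparing an original message image against its reconstruction from `recon_dir`; it is not part of this repository, and the function name and the usage paths are illustrative assumptions.
```python
import numpy as np

def psnr(img_a, img_b, peak=255.0):
    # Peak signal-to-noise ratio between two same-sized uint8 images.
    a = img_a.astype(np.float64)
    b = img_b.astype(np.float64)
    mse = np.mean((a - b) ** 2)
    if mse == 0:
        return float("inf")
    return 10.0 * np.log10(peak ** 2 / mse)

# Usage sketch (hypothetical paths):
# import cv2
# a = cv2.imread("datasets/flower/test_message/test_b.jpg")
# b = cv2.imread("check/flower/recon/test_b.png")
# print(psnr(a, b))
```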