├── Data └── BOSSBase_512 │ ├── 1.pgm │ ├── 3.pgm │ ├── 4.pgm │ ├── 5.pgm │ ├── 6.pgm │ ├── 7.pgm │ ├── 8.pgm │ ├── 9.pgm │ └── 10.pgm ├── Implement ├── SRM_Kernels.npy ├── .idea │ ├── misc.xml │ ├── modules.xml │ ├── Implement.iml │ └── workspace.xml ├── testfiles │ ├── command.sh │ ├── test_data_split.py │ └── utils_test.py ├── generator.py ├── main.py ├── YeNet.py ├── preprocessing.py ├── layers.py └── utils.py ├── README.md ├── command_SUNI_0.4_15000_No_1.sh ├── command_SUNI_0.4_15000_No_2.sh ├── command_BOSS.sh └── command_BOSSTEST.sh /Data/BOSSBase_512/1.pgm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/changshihyoung/TensorFlow-YeNet/HEAD/Data/BOSSBase_512/1.pgm -------------------------------------------------------------------------------- /Data/BOSSBase_512/3.pgm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/changshihyoung/TensorFlow-YeNet/HEAD/Data/BOSSBase_512/3.pgm -------------------------------------------------------------------------------- /Data/BOSSBase_512/4.pgm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/changshihyoung/TensorFlow-YeNet/HEAD/Data/BOSSBase_512/4.pgm -------------------------------------------------------------------------------- /Data/BOSSBase_512/5.pgm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/changshihyoung/TensorFlow-YeNet/HEAD/Data/BOSSBase_512/5.pgm -------------------------------------------------------------------------------- /Data/BOSSBase_512/6.pgm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/changshihyoung/TensorFlow-YeNet/HEAD/Data/BOSSBase_512/6.pgm -------------------------------------------------------------------------------- /Data/BOSSBase_512/7.pgm: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/changshihyoung/TensorFlow-YeNet/HEAD/Data/BOSSBase_512/7.pgm -------------------------------------------------------------------------------- /Data/BOSSBase_512/8.pgm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/changshihyoung/TensorFlow-YeNet/HEAD/Data/BOSSBase_512/8.pgm -------------------------------------------------------------------------------- /Data/BOSSBase_512/9.pgm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/changshihyoung/TensorFlow-YeNet/HEAD/Data/BOSSBase_512/9.pgm -------------------------------------------------------------------------------- /Data/BOSSBase_512/10.pgm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/changshihyoung/TensorFlow-YeNet/HEAD/Data/BOSSBase_512/10.pgm -------------------------------------------------------------------------------- /Implement/SRM_Kernels.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/changshihyoung/TensorFlow-YeNet/HEAD/Implement/SRM_Kernels.npy -------------------------------------------------------------------------------- /Implement/.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /Implement/.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow-YeNet 2 | Implementation of "Deep Learning Hierarchical 
Representations for Image Steganalysis" by TensorFlow 3 | 4 | ## Usage 5 | Example commands (*.sh) can be found in the root directory. 6 | 7 | ## Publication 8 | Ye, Jian, J. Ni, and Y. Yi. 9 | "Deep Learning Hierarchical Representations for Image Steganalysis." 10 | IEEE Transactions on Information Forensics & Security 12.11(2017):2545-2557. 11 | [**publication page**](http://ieeexplore.ieee.org/document/7937836/) 12 | -------------------------------------------------------------------------------- /Implement/.idea/Implement.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /command_SUNI_0.4_15000_No_1.sh: -------------------------------------------------------------------------------- 1 | #train 2 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_1/train/cover /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_1/train/stego /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_1/valid/cover /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_1/valid/stego --log-path="/home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_1/log" --use-batch-norm --lr 4e-1 --max-epochs=200 --log-interval=20 --gpu="0,1" 3 | 4 | #test 5 | python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_1/test/cover /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_1/test/stego --log-path="/home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_1/log" --gpu="0,1" 6 | 7 | #data_split 8 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Data/SUNI_0.4_15000 /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2 --train-percent=0.6 --valid-percent=0.2 --test-percent=0.2 --gpu=" " 9 |
-------------------------------------------------------------------------------- /command_SUNI_0.4_15000_No_2.sh: -------------------------------------------------------------------------------- 1 | #train 2 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/train/cover /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/train/stego /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/valid/cover /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/valid/stego --log-path="/home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/log" --use-batch-norm --lr 4e-1 --max-epochs=200 --log-interval=20 --gpu="2,3" 3 | 4 | #test 5 | python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/test/cover /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/test/stego --log-path="/home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/log" --gpu="2,3" 6 | 7 | #data_split 8 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Data/SUNI_0.4_15000 /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2 --train-percent=0.6 --valid-percent=0.2 --test-percent=0.2 --gpu=" " 9 | -------------------------------------------------------------------------------- /Implement/testfiles/command.sh: -------------------------------------------------------------------------------- 1 | #train 2 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Experiment/TryCaenorstStructure/train/cover /home/carlchang/YeNetTensorflow/Experiment/TryCaenorstStructure/train/stego /home/carlchang/YeNetTensorflow/Experiment/TryCaenorstStructure/valid/cover /home/carlchang/YeNetTensorflow/Experiment/TryCaenorstStructure/valid/stego --log-path="/home/carlchang/YeNetTensorflow/Experiment/TryCaenorstStructure/log" --use-batch-norm --lr 4e-1 
--max-epochs=300 --log-interval=24 --gpu="0,1,2,3" 3 | 4 | #test 5 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Experiment/TryCaenorstStructure/test/cover /home/carlchang/YeNetTensorflow/Experiment/TryCaenorstStructure/test/stego --log-path="/home/carlchang/YeNetTensorflow/Experiment/TryCaenorstStructure/log" --gpu="0,1,2,3" 6 | 7 | #data_split 8 | python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Data/SUNI_0.4_15000 /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2 --train-percent=0.6 --valid-percent=0.2 --test-percent=0.2 --gpu=" " 9 | -------------------------------------------------------------------------------- /command_BOSS.sh: -------------------------------------------------------------------------------- 1 | ####data_transfer 2 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Data/BOSS /home/carlchang/YeNetTensorflow/DataTransfer --required-size 256 --required-operation="resize,crop,subsample" 3 | ####data_aug 4 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/DataTransfer/BOSS_256_resize/cover /home/carlchang/YeNetTensorflow/DataAug --ratio-rot=0.5 5 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/DataTransfer/BOSS_256_crop/cover /home/carlchang/YeNetTensorflow/DataAug --ratio-rot=0.5 6 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/DataTransfer/BOSS_256_subsample/cover /home/carlchang/YeNetTensorflow/DataAug --ratio-rot=0.5 7 | ####data_split 8 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Data/SUNI_0.4_15000 /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2 --train-percent=0.6 --valid-percent=0.2 --test-percent=0.2 --gpu=" " 9 | ####train 10 | #python /home/carlchang/YeNetTensorflow/Implement/main.py 
/home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/train/cover /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/train/stego /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/valid/cover /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/valid/stego --log-path="/home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/log" --use-batch-norm --lr 4e-1 --max-epochs=200 --log-interval=20 --gpu="2,3" 11 | ####test 12 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/test/cover /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/test/stego --log-path="/home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/log" --gpu="2,3" 13 | -------------------------------------------------------------------------------- /command_BOSSTEST.sh: -------------------------------------------------------------------------------- 1 | #data_transfer 2 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Data/BOSSTEST /home/carlchang/YeNetTensorflow/DataTransfer --required-size 256 --required-operation="resize,crop,subsample" 3 | 4 | #data_aug 5 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/DataTransfer/BOSSTEST_256_resize/cover /home/carlchang/YeNetTensorflow/DataAug --ratio-rot=0.5 6 | 7 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/DataTransfer/BOSSTEST_256_crop/cover /home/carlchang/YeNetTensorflow/DataAug --ratio-rot=0.5 8 | 9 | python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/DataTransfer/BOSSTEST_256_subsample/cover /home/carlchang/YeNetTensorflow/DataAug --ratio-rot=0.5 10 | 11 | #train 12 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/train/cover 
/home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/train/stego /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/valid/cover /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/valid/stego --log-path="/home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/log" --use-batch-norm --lr 4e-1 --max-epochs=200 --log-interval=20 --gpu="2,3" 13 | 14 | #test 15 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/test/cover /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/test/stego --log-path="/home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/log" --gpu="2,3" 16 | 17 | #data_split 18 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Data/SUNI_0.4_15000 /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2 --train-percent=0.6 --valid-percent=0.2 --test-percent=0.2 --gpu=" " 19 | -------------------------------------------------------------------------------- /Implement/generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import tensorflow as tf 4 | from scipy import misc, io 5 | from random import shuffle 6 | 7 | def get_files(cover_dir, stego_dir, use_shuf_pair=False): 8 | """ 9 | 从cover和stego文件夹中提取图片,返回到get_batches组成batch 10 | shuf_pair决定了组成batch时,cover与stego是否成对 11 | """ 12 | file = [] 13 | for filename in os.listdir(cover_dir + '/'): 14 | file.append(filename) 15 | shuffle(file) 16 | file_shuf1 = file 17 | 18 | img = [] 19 | img_label = [] 20 | if use_shuf_pair: 21 | shuffle(file) 22 | file_shuf2 = file 23 | for file_idx in range(len(file_shuf1)): 24 | img.append(cover_dir + '/' + file_shuf1[file_idx]) 25 | img_label.append(0) 26 | img.append(stego_dir + '/' + file_shuf2[file_idx]) 27 | img_label.append(1) 28 | else: 29 | for filename in file_shuf1: 30 | img.append(cover_dir + '/' + filename) 31 | 
img_label.append(0) 32 | img.append(stego_dir + '/' + filename) 33 | img_label.append(1) 34 | 35 | #将img_list和img_label写入cover路径下的img_label_list.txt 36 | #with open(cover_dir + '/' + 'img_label_list.txt', 'w') as f: 37 | # for img_idx in range(len(img)): 38 | # f.write(img[img_idx]+' '+str(img_label[img_idx])+'\n') 39 | 40 | return img, img_label 41 | 42 | def get_minibatches(img, img_label, batch_size): 43 | """ 44 | 替代get_batches函数的作用,批次读取数据,每次返回batch_size大小的数据 45 | """ 46 | for start_idx in range(0, len(img) - batch_size + 1, batch_size): 47 | excerpt = slice(start_idx, start_idx + batch_size) 48 | img_minibatch = img[excerpt] 49 | img_label_minibatch = img_label[excerpt] 50 | yield img_minibatch, img_label_minibatch 51 | 52 | def get_minibatches_content_img(train_img_minibatch_list, img_height, img_width): 53 | """ 54 | 读取get_minibatches函数返回路径对应的内容,将图片实际内容转换为batch,作为返回值 55 | """ 56 | img_num = len(train_img_minibatch_list) 57 | image_minibatch_content = np.zeros([img_num, img_height, img_width, 1], dtype=np.float32) 58 | 59 | i = 0 60 | for img_file in train_img_minibatch_list: 61 | content = misc.imread(img_file) 62 | image_minibatch_content[i, :, :, 0] = content 63 | i = i + 1 64 | 65 | return image_minibatch_content 66 | 67 | """ 68 | def get_batches(img, img_label, batch_size, capacity): 69 | # 70 | #根据get_files返回的图片列表和标签列表,生成训练用batch 71 | #需要注意的是:输入图片应具有相同的高、宽 72 | # 73 | img = tf.cast(img, tf.string) 74 | img_label = tf.cast(img_label, tf.int32) 75 | 76 | # 生成输入队列(queue),tensorflow有多种方法,这里展示image与label分开时的情况 77 | input_queue = tf.train.slice_input_producer([img, img_label]) 78 | 79 | # 从队列里读出label,image(需要对相应的图片进行解码) 80 | label = input_queue[1] 81 | image_contents = tf.read_file(input_queue[0]) #pgm图像不能这么使用 82 | image = tf.image.decode_image(image_contents, channels=1) #pgm图像不能这么使用 83 | ##数据集augmentation的部分 84 | 85 | # 对数据进行大小标准化等操作,tf.image下有很多对image的处理,randomflip等 86 | #image = tf.image.resize_image_with_crop_or_pad(image, image_W, image_H) 87 | #image 
= tf.image.per_image_standardization(image) 88 | 89 | #[image, label]是tensor型变量 90 | image_batch, label_batch = tf.train.batch([image, label], 91 | batch_size=batch_size, 92 | num_threads=64, 93 | capacity=capacity) 94 | label_batch = tf.reshape(label_batch, [batch_size]) 95 | 96 | return image_batch, label_batch 97 | """ -------------------------------------------------------------------------------- /Implement/testfiles/test_data_split.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from random import shuffle 4 | 5 | # *数据集分割主函数 6 | def data_split(source_dir, dest_dir, 7 | batch_size, 8 | train_percent=0.6, 9 | valid_percent=0.2, 10 | test_percent=0.2): 11 | """ 12 | 根据传入的source_dir中cover/stego图像路径,根据各percent参数 13 | 在dest_dir路径中分割成train/valid/test数据集 14 | 抽取方式是随机的 15 | """ 16 | # *判断输入百分比是否合法 17 | if (train_percent + valid_percent + test_percent) > 1: 18 | raise ValueError('sum of train valid test percentage larger than 1') 19 | 20 | if os.path.exists(dest_dir + '/') is False: 21 | os.mkdir(dest_dir + '/') 22 | if os.path.exists(source_dir + '/') is False: 23 | raise OSError('source direction not exist') 24 | 25 | source_cover_dir = source_dir + '/cover' 26 | source_stego_dir = source_dir + '/stego' 27 | 28 | # *清理非对应文件 29 | file_clean(source_cover_dir, source_stego_dir) 30 | 31 | # *在dest_dir路径下创建train/valid/test路径 32 | dest_train_dir, dest_valid_dir, dest_test_dir = file_dir_mk_trainvalidtest_dir(dest_dir) 33 | 34 | # *对source_dir中的文件顺序进行shuffle 35 | source_cover_list = [] 36 | for filename in os.listdir(source_cover_dir + '/'): 37 | source_cover_list.append(filename) 38 | shuffle(source_cover_list) 39 | 40 | # *计算train/valid/test数据集容量 41 | half_batch_size = batch_size // 2 42 | train_ds_capacity = ( int( len(source_cover_list)*train_percent ) // half_batch_size ) * half_batch_size 43 | valid_ds_capacity = ( int( len(source_cover_list)*valid_percent ) // half_batch_size ) * half_batch_size 
44 | test_ds_capacity = ( int( len(source_cover_list)*test_percent ) // half_batch_size ) * half_batch_size 45 | 46 | for fileidx in range(train_ds_capacity): 47 | srcfile_cover = source_cover_dir + '/' + source_cover_list[fileidx] 48 | dstfile_cover = dest_train_dir + '/cover/' + source_cover_list[fileidx] 49 | shutil.copyfile(srcfile_cover, dstfile_cover) 50 | srcfile_stego = source_stego_dir + '/' + source_cover_list[fileidx] 51 | dstfile_stego = dest_train_dir + '/stego/' + source_cover_list[fileidx] 52 | shutil.copyfile(srcfile_stego, dstfile_stego) 53 | for fileidx in range(train_ds_capacity, train_ds_capacity + valid_ds_capacity): 54 | srcfile_cover = source_cover_dir + '/' + source_cover_list[fileidx] 55 | dstfile_cover = dest_valid_dir + '/cover/' + source_cover_list[fileidx] 56 | shutil.copyfile(srcfile_cover, dstfile_cover) 57 | srcfile_stego = source_stego_dir + '/' + source_cover_list[fileidx] 58 | dstfile_stego = dest_valid_dir + '/stego/' + source_cover_list[fileidx] 59 | shutil.copyfile(srcfile_stego, dstfile_stego) 60 | for fileidx in range(train_ds_capacity + valid_ds_capacity, 61 | train_ds_capacity + valid_ds_capacity + test_ds_capacity): 62 | srcfile_cover = source_cover_dir + '/' + source_cover_list[fileidx] 63 | dstfile_cover = dest_test_dir + '/cover/' + source_cover_list[fileidx] 64 | shutil.copyfile(srcfile_cover, dstfile_cover) 65 | srcfile_stego = source_stego_dir + '/' + source_cover_list[fileidx] 66 | dstfile_stego = dest_test_dir + '/stego/' + source_cover_list[fileidx] 67 | shutil.copyfile(srcfile_stego, dstfile_stego) 68 | 69 | def file_dir_mk_trainvalidtest_dir(dest_dir): 70 | """ 71 | 在dest_dir路径下创建train/valid/test路径 72 | """ 73 | if os.path.exists(dest_dir + '/train/') is False: 74 | os.mkdir(dest_dir + '/train/') 75 | if os.path.exists(dest_dir + '/train/cover/') is False: 76 | os.mkdir(dest_dir + '/train/cover/') 77 | if os.path.exists(dest_dir + '/train/stego/') is False: 78 | os.mkdir(dest_dir + '/train/stego/') 79 | if 
os.path.exists(dest_dir + '/valid/') is False: 80 | os.mkdir(dest_dir + '/valid/') 81 | if os.path.exists(dest_dir + '/valid/cover/') is False: 82 | os.mkdir(dest_dir + '/valid/cover/') 83 | if os.path.exists(dest_dir + '/valid/stego/') is False: 84 | os.mkdir(dest_dir + '/valid/stego/') 85 | if os.path.exists(dest_dir + '/test/') is False: 86 | os.mkdir(dest_dir + '/test/') 87 | if os.path.exists(dest_dir + '/test/cover/') is False: 88 | os.mkdir(dest_dir + '/test/cover/') 89 | if os.path.exists(dest_dir + '/test/stego/') is False: 90 | os.mkdir(dest_dir + '/test/stego/') 91 | if os.path.exists(dest_dir + '/log/') is False: 92 | os.mkdir(dest_dir + '/log/') 93 | return dest_dir + '/train', dest_dir + '/valid', dest_dir + '/test' 94 | 95 | def file_clean(cover_dir, stego_dir): 96 | """ 97 | 对cover和stego里的文件进行清理,将只存在于单个文件夹的文件、后缀名不匹配的文件删除。 98 | """ 99 | cover_dir = cover_dir + '/' 100 | stego_dir = stego_dir + '/' 101 | cover_list = [] 102 | stego_list = [] 103 | for root, dirs, files in os.walk(cover_dir): 104 | for filenames in files: 105 | cover_list.append(filenames) 106 | for root, dirs, files in os.walk(stego_dir): 107 | for filenames in files: 108 | stego_list.append(filenames) 109 | diff_cover_list = set(cover_list).difference(set(stego_list)) 110 | diff_stego_list = set(stego_list).difference(set(cover_list)) 111 | print('About to delete: ', len(diff_cover_list), 'files in ', cover_dir, 'Continue?') 112 | os.system('pause') 113 | for filenames in diff_cover_list: 114 | os.remove(cover_dir + filenames) 115 | print('About to delete: ', len(diff_stego_list), 'files in ', stego_dir, 'Continue?') 116 | os.system('pause') 117 | for filenames in diff_stego_list: 118 | os.remove(stego_dir + filenames) 119 | print('file_clean process has completed.') 120 | 121 | if __name__ == '__main__': 122 | source_dir = 'E:\@ChangShihyoung\TensorFlow-YeNet\Data\SUNI_13_0.4' 123 | dest_dir = 'E:\@ChangShihyoung\TensorFlow-YeNet\Experiment\SUNI_13_0.4_No_1' 124 | 
data_split(source_dir, dest_dir, 125 | 4, 126 | train_percent=0.6, 127 | valid_percent=0.2, 128 | test_percent=0.2) -------------------------------------------------------------------------------- /Implement/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import tensorflow as tf 4 | from glob import glob 5 | from preprocessing import * 6 | from generator import * 7 | from utils import * 8 | from YeNet import YeNet 9 | 10 | # *定义命令行输入变量 11 | parser = argparse.ArgumentParser(description='Tensorflow implementation of YeNet') 12 | 13 | # *根据不同操作进行不同的命令行参数定义 14 | operation_set = ['train', 'test', 'datatransfer', 'dataaug', 'datasplit'] 15 | print('Operation set ', operation_set) 16 | input_operation = input('The operation you want to perform: ') 17 | if 'train' in input_operation: 18 | parser.add_argument('train_cover_dir', type=str, metavar='PATH', 19 | help='directory of training cover images') 20 | parser.add_argument('train_stego_dir', type=str, metavar='PATH', 21 | help='directory of training stego images') 22 | parser.add_argument('valid_cover_dir', type=str, metavar='PATH', 23 | help='directory of validation cover images') 24 | parser.add_argument('valid_stego_dir', type=str, metavar='PATH', 25 | help='directory of validation stego images') 26 | if 'test' in input_operation: 27 | parser.add_argument('test_cover_dir', type=str, metavar='PATH', 28 | help='directory of testing cover images') 29 | parser.add_argument('test_stego_dir', type=str, metavar='PATH', 30 | help='directory of testing stego images') 31 | if 'datatransfer' in input_operation: 32 | parser.add_argument('source_dir', type=str, metavar='PATH', 33 | help='directory of source images') 34 | parser.add_argument('dest_dir', type=str, metavar='PATH', 35 | help='directory of destination images') 36 | parser.add_argument('--required-size', type=int, default=256, metavar='N', 37 | help='required size of destination 
images (default: 256)') 38 | parser.add_argument('--required-operation', type=str, default='resize,crop,subsample', metavar='S', 39 | help='transfer operation for source image (default: resize,crop,subsample)') 40 | if 'dataaug' in input_operation: 41 | parser.add_argument('source_dir', type=str, metavar='PATH', 42 | help='directory of source images') 43 | parser.add_argument('dest_dir', type=str, metavar='PATH', 44 | help='directory of destination images') 45 | parser.add_argument('--ratio-rot', type=float, default=0.5, metavar='F', 46 | help='percentage of dataset augmented by rotation (default: 0.5)') 47 | if 'datasplit' in input_operation: 48 | parser.add_argument('source_dir', type=str, metavar='PATH', 49 | help='directory of source cover and stego images') 50 | parser.add_argument('dest_dir', type=str, metavar='PATH', 51 | help='directory of separated dataset') 52 | parser.add_argument('--train-percent', type=float, default=0.6, metavar='F', 53 | help='percentage of dataset used for training (default: 0.6)') 54 | parser.add_argument('--valid-percent', type=float, default=0.2, metavar='F', 55 | help='percentage of dataset used for validation (default: 0.2)') 56 | parser.add_argument('--test-percent', type=float, default=0.2, metavar='F', 57 | help='percentage of dataset used for testing (default: 0.2)') 58 | 59 | if input_operation not in operation_set: 60 | raise NotImplementedError('invalid operation') 61 | 62 | # *定义余下可选命令行参数 63 | parser.add_argument('--use-shuf-pair', action='store_true', default=False, 64 | help='matching cover and stego image when batch is constructed (default: False)') 65 | parser.add_argument('--use-batch-norm', action='store_true', default=False, 66 | help='use batch normalization after each activation (default: False)') 67 | parser.add_argument('--batch-size', type=int, default=32, metavar='N', 68 | help='input batch size for training, testing and validation (default: 32)') 69 | parser.add_argument('--max-epochs', type=int, 
default=200, metavar='N', 70 | help='number of epochs to train (default: 200)') 71 | parser.add_argument('--lr', type=float, default=4e-1, metavar='F', 72 | help='learning rate (default: 4e-1)') 73 | parser.add_argument('--gpu', type=str, default='0', metavar='S', 74 | help='index of gpu used (default: 0)') 75 | parser.add_argument('--tfseed', type=int, default=1, metavar='S', 76 | help='random seed (default: 1)') 77 | parser.add_argument('--log-interval', type=int, default=20, metavar='N', 78 | help='number of batches before logging training status') 79 | parser.add_argument('--log-path', type=str, default='logs/', 80 | metavar='PATH', help='directory of log file') 81 | 82 | args = parser.parse_args() 83 | 84 | import os 85 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 86 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 87 | 88 | # *设置tf随机种子 89 | tf.set_random_seed(args.tfseed) 90 | 91 | # *根据不同操作输入执行相应函数 92 | if 'datatransfer' in input_operation: 93 | # *数据集预处理主函数 94 | data_transfer(args.source_dir, args.dest_dir, 95 | required_size=args.required_size, 96 | required_operation=args.required_operation) 97 | 98 | if 'dataaug' in input_operation: 99 | # *数据集增广主函数 100 | data_aug(args.source_dir, args.dest_dir, args.ratio_rot) 101 | 102 | if 'datasplit' in input_operation: 103 | # *数据集分割主函数 104 | data_split(args.source_dir, args.dest_dir, 105 | args.batch_size, 106 | train_percent=args.train_percent, 107 | valid_percent=args.valid_percent, 108 | test_percent=args.test_percent) 109 | 110 | if 'train' in input_operation: 111 | # *计算train/valid数据集大小 112 | train_ds_size = len(glob(args.train_cover_dir + '/*')) * 2 113 | if train_ds_size % args.batch_size != 0: 114 | raise ValueError('change batch size for training') 115 | valid_ds_size = len(glob(args.valid_cover_dir + '/*')) * 2 116 | if valid_ds_size % args.batch_size != 0: 117 | raise ValueError('change batch size for validation') 118 | # *训练主函数 119 | train(YeNet, args.use_batch_norm, args.use_shuf_pair, 120 | 
args.train_cover_dir, args.train_stego_dir, 121 | args.valid_cover_dir, args.valid_stego_dir, 122 | args.batch_size, train_ds_size, valid_ds_size, 123 | args.log_interval, args.max_epochs, args.lr, 124 | args.log_path) 125 | 126 | if 'test' in input_operation: 127 | # *计算test数据集大小 128 | test_ds_size = len(glob(args.test_cover_dir + '/*')) * 2 129 | if test_ds_size % args.batch_size != 0: 130 | raise ValueError('change batch size for testing') 131 | # *查找最佳模型主函数 132 | test_dataset_findbest(YeNet, args.use_shuf_pair, 133 | args.test_cover_dir, args.test_stego_dir, args.max_epochs, 134 | args.batch_size, test_ds_size, args.log_path) 135 | 136 | 137 | -------------------------------------------------------------------------------- /Implement/YeNet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib import layers 3 | from tensorflow.contrib.framework import add_arg_scope, arg_scope, arg_scoped_arguments 4 | import layers as my_layers 5 | from utils import * 6 | 7 | SRM_Kernels = np.load('/home/carlchang/YeNetTensorflow/Implement/SRM_Kernels.npy') 8 | 9 | class YeNet(Model): 10 | def __init__(self, is_training=None, data_format='NCHW', 11 | with_bn=False, tlu_threshold=3): 12 | super(YeNet, self).__init__(is_training=is_training, 13 | data_format=data_format) 14 | self.with_bn = with_bn 15 | self.tlu_threshold = tlu_threshold 16 | 17 | def _build_model(self, inputs): 18 | self.inputs = inputs 19 | if self.data_format == 'NCHW': 20 | channel_axis = 1 21 | _inputs = tf.cast(tf.transpose(inputs, [0, 3, 1, 2]), tf.float32) 22 | else: 23 | channel_axis = 3 24 | _inputs = tf.cast(inputs, tf.float32) 25 | self.L = [] 26 | with arg_scope([layers.avg_pool2d], 27 | padding='VALID', data_format=self.data_format): 28 | with tf.variable_scope('SRM_preprocess'): 29 | W_SRM = tf.get_variable('W', initializer=SRM_Kernels, 30 | dtype=tf.float32, 31 | regularizer=None) 32 | b = tf.get_variable('b', 
shape=[30], dtype=tf.float32, 33 | initializer=tf.constant_initializer(0.)) 34 | self.L.append(tf.nn.bias_add( 35 | tf.nn.conv2d(_inputs, 36 | W_SRM, [1,1,1,1], 'VALID', 37 | data_format=self.data_format), b, 38 | data_format=self.data_format, name='Layer1')) 39 | self.L.append(tf.clip_by_value(self.L[-1], 40 | -self.tlu_threshold, self.tlu_threshold, 41 | name='TLU')) 42 | with tf.variable_scope('ConvNetwork'): 43 | with arg_scope([my_layers.conv2d], 44 | num_outputs=30, 45 | kernel_size=3, stride=1, padding='VALID', 46 | data_format=self.data_format, 47 | activation_fn=tf.nn.relu, 48 | weights_initializer=layers.xavier_initializer_conv2d(), 49 | weights_regularizer=layers.l2_regularizer(5e-4), 50 | biases_initializer=tf.constant_initializer(0.2), 51 | biases_regularizer=None), arg_scope([layers.batch_norm], 52 | decay=0.9, center=True, scale=True, 53 | updates_collections=None, is_training=self.is_training, 54 | fused=True, data_format=self.data_format): 55 | if self.with_bn: 56 | self.L.append(layers.batch_norm(self.L[-1], 57 | scope='Norm1')) 58 | self.L.append(my_layers.conv2d(self.L[-1], 59 | scope='Layer2')) 60 | if self.with_bn: 61 | self.L.append(layers.batch_norm(self.L[-1], 62 | scope='Norm2')) 63 | self.L.append(my_layers.conv2d(self.L[-1], 64 | scope='Layer3')) 65 | if self.with_bn: 66 | self.L.append(layers.batch_norm(self.L[-1], 67 | scope='Norm3')) 68 | self.L.append(my_layers.conv2d(self.L[-1], 69 | scope='Layer4')) 70 | if self.with_bn: 71 | self.L.append(layers.batch_norm(self.L[-1], 72 | scope='Norm4')) 73 | self.L.append(layers.avg_pool2d(self.L[-1], 74 | kernel_size=[2,2], scope='Stride1')) 75 | with arg_scope([my_layers.conv2d], kernel_size=5, 76 | num_outputs=32): 77 | self.L.append(my_layers.conv2d(self.L[-1], 78 | scope='Layer5')) 79 | if self.with_bn: 80 | self.L.append(layers.batch_norm(self.L[-1], 81 | scope='Norm5')) 82 | self.L.append(layers.avg_pool2d(self.L[-1], 83 | kernel_size=[3,3], 84 | scope='Stride2')) 85 | 
self.L.append(my_layers.conv2d(self.L[-1], 86 | scope='Layer6')) 87 | if self.with_bn: 88 | self.L.append(layers.batch_norm(self.L[-1], 89 | scope='Norm6')) 90 | self.L.append(layers.avg_pool2d(self.L[-1], 91 | kernel_size=[3,3], 92 | scope='Stride3')) 93 | self.L.append(my_layers.conv2d(self.L[-1], 94 | scope='Layer7')) 95 | if self.with_bn: 96 | self.L.append(layers.batch_norm(self.L[-1], 97 | scope='Norm7')) 98 | self.L.append(layers.avg_pool2d(self.L[-1], 99 | kernel_size=[3,3], 100 | scope='Stride4')) 101 | self.L.append(my_layers.conv2d(self.L[-1], 102 | num_outputs=16, 103 | scope='Layer8')) 104 | if self.with_bn: 105 | self.L.append(layers.batch_norm(self.L[-1], 106 | scope='Norm8')) 107 | self.L.append(my_layers.conv2d(self.L[-1], 108 | num_outputs=16, stride=3, 109 | scope='Layer9')) 110 | if self.with_bn: 111 | self.L.append(layers.batch_norm(self.L[-1], 112 | scope='Norm9')) 113 | self.L.append(layers.flatten(self.L[-1])) 114 | self.L.append(layers.fully_connected(self.L[-1], num_outputs=2, 115 | activation_fn=None, normalizer_fn=None, 116 | weights_initializer=tf.random_normal_initializer(mean=0., 117 | stddev=0.01), 118 | biases_initializer=tf.constant_initializer(0.), scope='ip')) 119 | self.outputs = self.L[-1] 120 | return self.outputs 121 | 122 | -------------------------------------------------------------------------------- /Implement/preprocessing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from scipy import misc 4 | from glob import glob 5 | import shutil 6 | import random 7 | from random import random as rand 8 | from random import shuffle 9 | 10 | # *数据集预处理主函数 11 | def data_transfer(source_dir, dest_dir, 12 | required_size, 13 | required_operation): 14 | """ 15 | 将source_dir中的图像依operation中定义的操作扩展至dest_dir中 16 | 包含resize,subsample和crop操作 17 | """ 18 | dest_dir = dest_dir + '/' + source_dir.split("/")[-1] 19 | 20 | # 建立数据集路径 21 | size = required_size, 
required_size 22 | if 'resize' in required_operation: 23 | dest_resize_dir = dest_dir + '_' + str(required_size) + '_resize' 24 | if os.path.exists(dest_resize_dir + '/') is False: 25 | os.mkdir(dest_resize_dir + '/') 26 | os.mkdir(dest_resize_dir + '/cover/') 27 | if 'crop' in required_operation: 28 | dest_crop_dir = dest_dir + '_' + str(required_size) + '_crop' 29 | if os.path.exists(dest_crop_dir + '/') is False: 30 | os.mkdir(dest_crop_dir + '/') 31 | os.mkdir(dest_crop_dir + '/cover/') 32 | if 'subsample' in required_operation: 33 | dest_subsample_dir = dest_dir + '_' + str(required_size) + '_subsample' 34 | if os.path.exists(dest_subsample_dir + '/') is False: 35 | os.mkdir(dest_subsample_dir + '/') 36 | os.mkdir(dest_subsample_dir + '/cover/') 37 | 38 | source_img_list = glob(source_dir + '/*') 39 | for filename in source_img_list: 40 | img = misc.imread(filename) 41 | if img is None: 42 | raise OSError('Error: could not load image') 43 | if 'resize' in required_operation: 44 | img_resize = misc.imresize(img, size, interp='bicubic') 45 | save_dir = dest_resize_dir + '/cover/' + filename.split("/")[-1] 46 | misc.imsave(save_dir, img_resize) 47 | if 'crop' in required_operation: 48 | ROI_idx = (img.shape[0] - required_size) // 2 49 | img_crop = img[ROI_idx:ROI_idx+required_size, ROI_idx:ROI_idx+required_size] 50 | save_dir = dest_crop_dir + '/cover/' + filename.split("/")[-1] 51 | misc.imsave(save_dir, img_crop) 52 | if 'subsample' in required_operation: 53 | SUB_idx = img.shape[0] // required_size 54 | img_subsample = img[0:img.shape[0]:SUB_idx, 0:img.shape[1]:SUB_idx] 55 | save_dir = dest_subsample_dir + '/cover/' + filename.split("/")[-1] 56 | misc.imsave(save_dir, img_subsample) 57 | print('data transfer succeed!') 58 | 59 | # *数据集增广主函数 60 | def data_aug(source_dir, dest_dir, ratio=0.5): 61 | """ 62 | 将source_dir中的图像增广至dest_dir中 63 | 包含rotate和flip操作 64 | """ 65 | dest_dir = dest_dir + '/' + source_dir.split("/")[-2] + '_aug' 66 | if os.path.exists(dest_dir 
+ '/') is False: 67 | os.mkdir(dest_dir + '/') 68 | os.mkdir(dest_dir + '/cover/') 69 | 70 | dest_dir = dest_dir + '/cover' 71 | 72 | source_img_list = glob(source_dir + '/*') 73 | for filename in source_img_list: 74 | img = misc.imread(filename) 75 | if img is None: 76 | raise OSError('Error: could not load image') 77 | filename_split = (filename.split("/")[-1]) 78 | save_dir = dest_dir + '/' + filename_split 79 | misc.imsave(save_dir, img) 80 | 81 | rot = random.randint(1, 3) 82 | rand_op = rand() 83 | rand_flip = rand() 84 | if rand_op < ratio: 85 | img_rot = misc.imrotate(img, rot*90, interp='bicubic') 86 | save_dir = dest_dir + '/' + filename_split.split('.')[0] + '_rot.' + filename_split.split('.')[1] 87 | misc.imsave(save_dir, img_rot) 88 | else: 89 | if rand_flip < ratio: 90 | img_flip = np.flipud(img) 91 | else: 92 | img_flip = np.fliplr(img) 93 | save_dir = dest_dir + '/' + filename_split.split('.')[0] + '_flip.' + filename_split.split('.')[1] 94 | misc.imsave(save_dir, img_flip) 95 | print('data augment succeed!') 96 | 97 | 98 | # *数据集分割主函数 99 | def data_split(source_dir, dest_dir, 100 | batch_size, 101 | train_percent=0.6, 102 | valid_percent=0.2, 103 | test_percent=0.2): 104 | """ 105 | 根据传入的source_dir中cover/stego图像路径,根据各percent参数 106 | 在dest_dir路径中分割成train/valid/test数据集 107 | 抽取方式是随机的 108 | """ 109 | # *判断输入百分比是否合法 110 | if (train_percent + valid_percent + test_percent) > 1: 111 | raise ValueError('sum of train valid test percentage larger than 1') 112 | 113 | if os.path.exists(dest_dir + '/') is False: 114 | os.mkdir(dest_dir + '/') 115 | if os.path.exists(source_dir + '/') is False: 116 | raise OSError('source direction not exist') 117 | 118 | source_cover_dir = source_dir + '/cover' 119 | source_stego_dir = source_dir + '/stego' 120 | 121 | # *清理非对应文件 122 | file_clean(source_cover_dir, source_stego_dir) 123 | 124 | # *在dest_dir路径下创建train/valid/test路径 125 | dest_train_dir, dest_valid_dir, dest_test_dir = file_dir_mk_trainvalidtest_dir(dest_dir) 126 
| 127 | # *对source_dir中的文件顺序进行shuffle 128 | source_cover_list = [] 129 | for filename in os.listdir(source_cover_dir + '/'): 130 | source_cover_list.append(filename) 131 | shuffle(source_cover_list) 132 | 133 | # *计算train/valid/test数据集容量 134 | half_batch_size = batch_size // 2 135 | train_ds_capacity = ( int( len(source_cover_list)*train_percent ) // half_batch_size ) * half_batch_size 136 | valid_ds_capacity = ( int( len(source_cover_list)*valid_percent ) // half_batch_size ) * half_batch_size 137 | test_ds_capacity = ( int( len(source_cover_list)*test_percent ) // half_batch_size ) * half_batch_size 138 | 139 | for fileidx in range(train_ds_capacity): 140 | srcfile_cover = source_cover_dir + '/' + source_cover_list[fileidx] 141 | dstfile_cover = dest_train_dir + '/cover/' + source_cover_list[fileidx] 142 | shutil.copyfile(srcfile_cover, dstfile_cover) 143 | srcfile_stego = source_stego_dir + '/' + source_cover_list[fileidx] 144 | dstfile_stego = dest_train_dir + '/stego/' + source_cover_list[fileidx] 145 | shutil.copyfile(srcfile_stego, dstfile_stego) 146 | for fileidx in range(train_ds_capacity, train_ds_capacity + valid_ds_capacity): 147 | srcfile_cover = source_cover_dir + '/' + source_cover_list[fileidx] 148 | dstfile_cover = dest_valid_dir + '/cover/' + source_cover_list[fileidx] 149 | shutil.copyfile(srcfile_cover, dstfile_cover) 150 | srcfile_stego = source_stego_dir + '/' + source_cover_list[fileidx] 151 | dstfile_stego = dest_valid_dir + '/stego/' + source_cover_list[fileidx] 152 | shutil.copyfile(srcfile_stego, dstfile_stego) 153 | for fileidx in range(train_ds_capacity + valid_ds_capacity, 154 | train_ds_capacity + valid_ds_capacity + test_ds_capacity): 155 | srcfile_cover = source_cover_dir + '/' + source_cover_list[fileidx] 156 | dstfile_cover = dest_test_dir + '/cover/' + source_cover_list[fileidx] 157 | shutil.copyfile(srcfile_cover, dstfile_cover) 158 | srcfile_stego = source_stego_dir + '/' + source_cover_list[fileidx] 159 | dstfile_stego = 
dest_test_dir + '/stego/' + source_cover_list[fileidx] 160 | shutil.copyfile(srcfile_stego, dstfile_stego) 161 | print('data split succeed!') 162 | 163 | def file_clean(cover_dir, stego_dir): 164 | """ 165 | 对cover和stego里的文件进行清理,将只存在于单个文件夹的文件、后缀名不匹配的文件删除。 166 | """ 167 | cover_dir = cover_dir + '/' 168 | stego_dir = stego_dir + '/' 169 | cover_list = [] 170 | stego_list = [] 171 | for root, dirs, files in os.walk(cover_dir): 172 | for filenames in files: 173 | cover_list.append(filenames) 174 | for root, dirs, files in os.walk(stego_dir): 175 | for filenames in files: 176 | stego_list.append(filenames) 177 | diff_cover_list = set(cover_list).difference(set(stego_list)) 178 | diff_stego_list = set(stego_list).difference(set(cover_list)) 179 | print('Start file cleaning...') 180 | print('About to delete: ', len(diff_cover_list), 'files in ', cover_dir) 181 | for filenames in diff_cover_list: 182 | os.remove(cover_dir + filenames) 183 | print('About to delete: ', len(diff_stego_list), 'files in ', stego_dir) 184 | for filenames in diff_stego_list: 185 | os.remove(stego_dir + filenames) 186 | 187 | def file_dir_mk_trainvalidtest_dir(dest_dir): 188 | """ 189 | 在dest_dir路径下创建train/valid/test路径 190 | """ 191 | if os.path.exists(dest_dir + '/train/') is False: 192 | os.mkdir(dest_dir + '/train/') 193 | if os.path.exists(dest_dir + '/train/cover/') is False: 194 | os.mkdir(dest_dir + '/train/cover/') 195 | if os.path.exists(dest_dir + '/train/stego/') is False: 196 | os.mkdir(dest_dir + '/train/stego/') 197 | if os.path.exists(dest_dir + '/valid/') is False: 198 | os.mkdir(dest_dir + '/valid/') 199 | if os.path.exists(dest_dir + '/valid/cover/') is False: 200 | os.mkdir(dest_dir + '/valid/cover/') 201 | if os.path.exists(dest_dir + '/valid/stego/') is False: 202 | os.mkdir(dest_dir + '/valid/stego/') 203 | if os.path.exists(dest_dir + '/test/') is False: 204 | os.mkdir(dest_dir + '/test/') 205 | if os.path.exists(dest_dir + '/test/cover/') is False: 206 | os.mkdir(dest_dir 
+ '/test/cover/') 207 | if os.path.exists(dest_dir + '/test/stego/') is False: 208 | os.mkdir(dest_dir + '/test/stego/') 209 | if os.path.exists(dest_dir + '/log/') is False: 210 | os.mkdir(dest_dir + '/log/') 211 | return dest_dir + '/train', dest_dir + '/valid', dest_dir + '/test' 212 | -------------------------------------------------------------------------------- /Implement/layers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib import layers 3 | from tensorflow.contrib.framework import add_arg_scope 4 | 5 | @add_arg_scope 6 | def conv2d(inputs, 7 | num_outputs, 8 | kernel_size, 9 | stride=1, 10 | padding='SAME', 11 | data_format=None, 12 | rate=1, 13 | activation_fn=tf.nn.relu, 14 | normalizer_fn=None, 15 | normalize_after_activation=True, 16 | normalizer_params=None, 17 | weights_initializer=layers.xavier_initializer(), 18 | weights_regularizer=None, 19 | biases_initializer=tf.zeros_initializer(), 20 | biases_regularizer=None, 21 | reuse=None, 22 | variables_collections=None, 23 | outputs_collections=None, 24 | trainable=True, 25 | scope=None): 26 | with tf.variable_scope(scope, 'Conv', reuse=reuse): 27 | if data_format == 'NHWC': 28 | num_inputs = inputs.get_shape().as_list()[3] 29 | height = inputs.get_shape().as_list()[1] 30 | width = inputs.get_shape().as_list()[2] 31 | if isinstance(stride, int): 32 | strides = [1, stride, stride, 1] 33 | elif isinstance(stride, list) or isinstance(stride, tuple): 34 | if len(stride) == 1: 35 | strides = [1] + stride * 2 + [1] 36 | else: 37 | strides = [1, stride[0], stride[1], 1] 38 | else: 39 | raise TypeError('stride is not an int, list or' 40 | + 'a tuple, is %s' % type(stride)) 41 | else: 42 | num_inputs = inputs.get_shape().as_list()[1] 43 | height = inputs.get_shape().as_list()[2] 44 | width = inputs.get_shape().as_list()[3] 45 | if isinstance(stride, int): 46 | strides = [1, 1, stride, stride] 47 | elif isinstance(stride, 
list) or isinstance(stride, tuple): 48 | if len(stride) == 1: 49 | strides = [1, 1] + stride * 2 50 | else: 51 | strides = [1, 1, stride[0], stride[1]] 52 | else: 53 | raise TypeError('stride is not an int, list or' 54 | + 'a tuple, is %s' % type(stride)) 55 | if isinstance(kernel_size, int): 56 | kernel_height = kernel_size 57 | kernel_width = kernel_size 58 | elif isinstance(kernel_size, list) or isinstance(kernel_size, tuple): 59 | kernel_height = kernel_size[0] 60 | kernel_width = kernel_size[1] 61 | else: 62 | raise ValueError('kernel_size is not an int, list or' 63 | + 'a tuple, is %s' % type(kernel_size)) 64 | weights = tf.get_variable('weights', [kernel_height, 65 | kernel_width, num_inputs, num_outputs], 66 | 'float32', weights_initializer, 67 | weights_regularizer, trainable, 68 | variables_collections) 69 | outputs = tf.nn.conv2d(inputs, weights, strides, padding, 70 | data_format=data_format) 71 | if biases_initializer is not None: 72 | biases = tf.get_variable('biases', [num_outputs], 'float32', 73 | biases_initializer, 74 | biases_regularizer, 75 | trainable, variables_collections) 76 | outputs = tf.nn.bias_add(outputs, biases, data_format) 77 | if normalizer_fn is not None and not normalize_after_activation: 78 | normalizer_params = normalizer_params or {} 79 | outputs = normalizer_fn(outputs, **normalizer_params) 80 | if activation_fn is not None: 81 | outputs = activation_fn(outputs) 82 | if normalizer_fn is not None and normalize_after_activation: 83 | normalizer_params = normalizer_params or {} 84 | outputs = normalizer_fn(outputs, **normalizer_params) 85 | return outputs 86 | 87 | """ 88 | @add_arg_scope 89 | def double_conv2d(ref_half, real_half, 90 | num_outputs, 91 | kernel_size, 92 | stride=1, 93 | padding='SAME', 94 | data_format=None, 95 | rate=1, 96 | activation_fn=tf.nn.relu, 97 | normalizer_fn=None, 98 | normalize_after_activation=True, 99 | normalizer_params=None, 100 | weights_initializer=layers.xavier_initializer(), 101 | 
weights_regularizer=None, 102 | biases_initializer=tf.zeros_initializer(), 103 | biases_regularizer=None, 104 | reuse=None, 105 | variables_collections=None, 106 | outputs_collections=None, 107 | trainable=True, 108 | scope=None): 109 | with tf.variable_scope(scope, 'Conv', reuse=reuse): 110 | if data_format == 'NHWC': 111 | num_inputs = real_half.get_shape().as_list()[3] 112 | height = real_half.get_shape().as_list()[1] 113 | width = real_half.get_shape().as_list()[2] 114 | if isinstance(stride, int): 115 | strides = [1, stride, stride, 1] 116 | elif isinstance(stride, list) or isinstance(stride, tuple): 117 | if len(stride) == 1: 118 | strides = [1] + stride * 2 + [1] 119 | else: 120 | strides = [1, stride[0], stride[1], 1] 121 | else: 122 | raise TypeError('stride is not an int, list or' 123 | + 'a tuple, is %s' % type(stride)) 124 | else: 125 | num_inputs = real_half.get_shape().as_list()[1] 126 | height = real_half.get_shape().as_list()[2] 127 | width = real_half.get_shape().as_list()[3] 128 | if isinstance(stride, int): 129 | strides = [1, 1, stride, stride] 130 | elif isinstance(stride, list) or isinstance(stride, tuple): 131 | if len(stride) == 1: 132 | strides = [1, 1] + stride * 2 133 | else: 134 | strides = [1, 1, stride[0], stride[1]] 135 | else: 136 | raise TypeError('stride is not an int, list or' \ 137 | + 'a tuple, is %s' % type(stride)) 138 | if isinstance(kernel_size, int): 139 | kernel_height = kernel_size 140 | kernel_width = kernel_size 141 | elif isinstance(kernel_size, list) \ 142 | or isinstance(kernel_size, tuple): 143 | kernel_height = kernel_size[0] 144 | kernel_width = kernel_size[1] 145 | else: 146 | raise ValueError('kernel_size is not an int, list or' 147 | + 'a tuple, is %s' % type(kernel_size)) 148 | weights = tf.get_variable('weights', [kernel_height, 149 | kernel_width, num_inputs, num_outputs], 150 | 'float32', weights_initializer, 151 | weights_regularizer, trainable, 152 | variables_collections) 153 | ref_outputs = 
tf.nn.conv2d(ref_half, weights, strides, padding, 154 | data_format=data_format) 155 | real_outputs = tf.nn.conv2d(real_half, weights, strides, padding, 156 | data_format=data_format) 157 | if biases_initializer is not None: 158 | biases = tf.get_variable('biases', [num_outputs], 'float32', 159 | biases_initializer, 160 | biases_regularizer, 161 | trainable, variables_collections) 162 | ref_outputs = tf.nn.bias_add(ref_outputs, biases, data_format) 163 | real_outputs = tf.nn.bias_add(real_outputs, biases, data_format) 164 | if normalizer_fn is not None and not normalize_after_activation: 165 | normalizer_params = normalizer_params or {} 166 | ref_outputs, real_outputs = normalizer_fn(ref_outputs, 167 | real_outputs, 168 | **normalizer_params) 169 | if activation_fn is not None: 170 | ref_outputs = activation_fn(ref_outputs) 171 | real_outputs = activation_fn(real_outputs) 172 | if normalizer_fn is not None and normalize_after_activation: 173 | normalizer_params = normalizer_params or {} 174 | ref_outputs, real_outputs = normalizer_fn(ref_outputs, 175 | real_outputs, 176 | **normalizer_params) 177 | return ref_outputs, real_outputs 178 | 179 | class Vbn_double(object): 180 | def __init__(self, x, epsilon=1e-5, scope=None): 181 | shape = x.get_shape().as_list() 182 | needs_reshape = len(shape) != 4 183 | if needs_reshape: 184 | orig_shape = shape 185 | if len(shape) == 2: 186 | if data_format == 'NCHW': 187 | x = tf.reshape(x, [shape[0], shape[1], 0, 0]) 188 | else: 189 | x = tf.reshape(x, [shape[0], 1, 1, shape[1]]) 190 | elif len(shape) == 1: 191 | x = tf.reshape(x, [shape[0], 1, 1, 1]) 192 | else: 193 | assert False, shape 194 | shape = x.get_shape().as_list() 195 | with tf.variable_scope(scope): 196 | self.epsilon = epsilon 197 | self.scope = scope 198 | self.mean, self.var = tf.nn.moments(x, [0,2,3], \ 199 | keep_dims=True) 200 | self.inv_std = tf.rsqrt(self.var + epsilon) 201 | self.batch_size = int(x.get_shape()[0]) 202 | out = self._normalize(x, self.mean, 
self.inv_std) 203 | if needs_reshape: 204 | out = tf.reshape(out, orig_shape) 205 | self.reference_output = out 206 | 207 | def __call__(self, x): 208 | shape = x.get_shape().as_list() 209 | needs_reshape = len(shape) != 4 210 | if needs_reshape: 211 | orig_shape = shape 212 | if len(shape) == 2: 213 | if self.data_format == 'NCHW': 214 | x = tf.reshape(x, [shape[0], shape[1], 0, 0]) 215 | else: 216 | x = tf.reshape(x, [shape[0], 1, 1, shape[1]]) 217 | elif len(shape) == 1: 218 | x = tf.reshape(x, [shape[0], 1, 1, 1]) 219 | else: 220 | assert False, shape 221 | with tf.variable_scope(self.scope, reuse=True): 222 | out = self._normalize(x, self.mean, self.inv_std) 223 | if needs_reshape: 224 | out = tf.reshape(out, orig_shape) 225 | return out 226 | 227 | def _normalize(self, x, mean, inv_std): 228 | shape = x.get_shape().as_list() 229 | assert len(shape) == 4 230 | gamma = tf.get_variable("gamma", [1,shape[1],1,1], 231 | initializer=tf.constant_initializer(1.)) 232 | beta = tf.get_variable("beta", [1,shape[1],1,1], 233 | initializer=tf.constant_initializer(0.)) 234 | coeff = gamma * inv_std 235 | return (x * coeff) + (beta - mean * coeff) 236 | 237 | @add_arg_scope 238 | def vbn_double(ref_half, real_half, center=True, scale=True, epsilon=1e-5, \ 239 | data_format='NCHW', instance_norm=True, scope=None, \ 240 | reuse=None): 241 | assert isinstance(epsilon, float) 242 | shape = real_half.get_shape().as_list() 243 | batch_size = int(real_half.get_shape()[0]) 244 | with tf.variable_scope(scope, 'VBN', reuse=reuse): 245 | if data_format == 'NCHW': 246 | if scale: 247 | gamma = tf.get_variable("gamma", [1,shape[1],1,1], 248 | initializer=tf.constant_initializer(1.)) 249 | if center: 250 | beta = tf.get_variable("beta", [1,shape[1],1,1], 251 | initializer=tf.constant_initializer(0.)) 252 | ref_mean, ref_var = tf.nn.moments(ref_half, [0,2,3], \ 253 | keep_dims=True) 254 | else: 255 | if scale: 256 | gamma = tf.get_variable("gamma", [1,1,1,shape[-1]], 257 | 
initializer=tf.constant_initializer(1.)) 258 | if center: 259 | beta = tf.get_variable("beta", [1,1,1,shape[-1]], 260 | initializer=tf.constant_initializer(0.)) 261 | ref_mean, ref_var = tf.nn.moments(ref_half, [0,1,2], \ 262 | keep_dims=True) 263 | def _normalize(x, mean, var): 264 | inv_std = tf.rsqrt(var + epsilon) 265 | if scale: 266 | coeff = inv_std * gamma 267 | else: 268 | coeff = inv_std 269 | if center: 270 | return (x * coeff) + (beta - mean * coeff) 271 | else: 272 | return (x - mean) * coeff 273 | if instance_norm: 274 | if data_format == 'NCHW': 275 | real_mean, real_var = tf.nn.moments(real_half, [2,3], \ 276 | keep_dims=True) 277 | else: 278 | real_mean, real_var = tf.nn.moments(real_half, [1,2], \ 279 | keep_dims=True) 280 | real_coeff = 1. / (batch_size + 1.) 281 | ref_coeff = 1. - real_coeff 282 | new_mean = real_coeff * real_mean + ref_coeff * ref_mean 283 | new_var = real_coeff * real_var + ref_coeff * ref_var 284 | ref_output = _normalize(ref_half, ref_mean, ref_var) 285 | real_output = _normalize(real_half, new_mean, new_var) 286 | else: 287 | ref_output = _normalize(ref_half, ref_mean, ref_var) 288 | real_output = _normalize(real_half, ref_mean, ref_var) 289 | return ref_output, real_output 290 | 291 | 292 | @add_arg_scope 293 | def vbn_single(x, center=True, scale=True, \ 294 | epsilon=1e-5, data_format='NCHW', \ 295 | instance_norm=True, scope=None, \ 296 | reuse=None): 297 | assert isinstance(epsilon, float) 298 | shape = x.get_shape().as_list() 299 | if shape[0] is None: 300 | half_size = x.shape[0] // 2 301 | else: 302 | half_size = shape[0] // 2 303 | needs_reshape = len(shape) != 4 304 | if needs_reshape: 305 | orig_shape = shape 306 | if len(shape) == 2: 307 | if data_format == 'NCHW': 308 | x = tf.reshape(x, [shape[0], shape[1], 0, 0]) 309 | else: 310 | x = tf.reshape(x, [shape[0], 1, 1, shape[1]]) 311 | elif len(shape) == 1: 312 | x = tf.reshape(x, [shape[0], 1, 1, 1]) 313 | else: 314 | assert False, shape 315 | shape = 
x.get_shape().as_list() 316 | batch_size = int(x.get_shape()[0]) 317 | with tf.variable_scope(scope, 'VBN', reuse=reuse): 318 | ref_half = tf.slice(x, [0,0,0,0], [half_size, shape[1], \ 319 | shape[2], shape[3]]) 320 | if data_format == 'NCHW': 321 | if scale: 322 | gamma = tf.get_variable("gamma", [1,shape[1],1,1], 323 | initializer=tf.constant_initializer(1.)) 324 | if center: 325 | beta = tf.get_variable("beta", [1,shape[1],1,1], 326 | initializer=tf.constant_initializer(0.)) 327 | ref_mean, ref_var = tf.nn.moments(ref_half, [0,2,3], \ 328 | keep_dims=True) 329 | else: 330 | if scale: 331 | gamma = tf.get_variable("gamma", [1,1,1,shape[-1]], 332 | initializer=tf.constant_initializer(1.)) 333 | if center: 334 | beta = tf.get_variable("beta", [1,1,1,shape[-1]], 335 | initializer=tf.constant_initializer(0.)) 336 | ref_mean, ref_var = tf.nn.moments(ref_half, [0,1,2], \ 337 | keep_dims=True) 338 | def _normalize(x, mean, var): 339 | inv_std = tf.rsqrt(var + epsilon) 340 | if scale: 341 | coeff = inv_std * gamma 342 | else: 343 | coeff = inv_std 344 | if center: 345 | return (x * coeff) + (beta - mean * coeff) 346 | else: 347 | return (x - mean) * coeff 348 | if instance_norm: 349 | real_half = tf.slice(x, [half_size,0,0,0], \ 350 | [half_size, shape[1], shape[2], shape[3]]) 351 | if data_format == 'NCHW': 352 | real_mean, real_var = tf.nn.moments(real_half, [2,3], \ 353 | keep_dims=True) 354 | else: 355 | real_mean, real_var = tf.nn.moments(real_half, [1,2], \ 356 | keep_dims=True) 357 | real_coeff = 1. / (batch_size + 1.) 358 | ref_coeff = 1. 
- real_coeff 359 | new_mean = real_coeff * real_mean + ref_coeff * ref_mean 360 | new_var = real_coeff * real_var + ref_coeff * ref_var 361 | ref_output = _normalize(ref_half, ref_mean, ref_var) 362 | real_output = _normalize(real_half, new_mean, new_var) 363 | return tf.concat([ref_output, real_output], axis=0) 364 | else: 365 | return _normalize(x, ref_mean, ref_var) 366 | """ 367 | -------------------------------------------------------------------------------- /Implement/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from scipy import misc, io 4 | import time 5 | from glob import glob 6 | from generator import * 7 | 8 | # *包含loss与acc变量及操作的average_summary类 9 | class average_summary(object): 10 | def __init__(self, variable, name, num_iterations): 11 | # sum_variable:内在累加器,用于累加每次的loss/acc 12 | self.sum_variable = tf.get_variable(name, shape=[], 13 | initializer=tf.constant_initializer(0.), 14 | dtype='float32', 15 | trainable=False, 16 | collections=[tf.GraphKeys.LOCAL_VARIABLES]) 17 | # 每个batch调用一次increment_op,累加每次的loss/acc 18 | with tf.control_dependencies([variable]): 19 | self.increment_op = tf.assign_add(self.sum_variable, variable) 20 | # 当increment_op操作调用了num_iterations次之后,可进行下列操作 21 | self.mean_variable = self.sum_variable / float(num_iterations) # 求平均的loss和acc 22 | self.summary = tf.summary.scalar(name, self.mean_variable) # 将loss和acc存入tf全局图 23 | with tf.control_dependencies([self.summary]): 24 | self.reset_variable_op = tf.assign(self.sum_variable, 0) # 当summary完成后,可进行reset 25 | # 外部调用,将loss/acc存入tf全局图 26 | def add_summary(self, sess, writer, step): 27 | s, _ = sess.run([self.summary, self.reset_variable_op]) 28 | writer.add_summary(s, step) 29 | 30 | # *用于挂载Net的结构,包含__build_model和__build_loss的操作 31 | class Model(object): 32 | def __init__(self, is_training=None, data_format='NCHW'): 33 | self.data_format = data_format 34 | if is_training is None: 35 | 
self.is_training = tf.get_variable('is_training', dtype=tf.bool, 36 | initializer=tf.constant_initializer(True), 37 | trainable=False) 38 | else: 39 | self.is_training = is_training 40 | 41 | def _build_model(self, inputs): 42 | raise NotImplementedError('Here is your model definition') 43 | 44 | def _build_losses(self, labels): 45 | self.labels = tf.cast(labels, tf.int64) 46 | with tf.variable_scope('loss'): 47 | oh = tf.one_hot(self.labels, 2) # 这里定义了2分类的输出 48 | # *除softmax cross entropy之外,还可更换其他函数 49 | self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits( 50 | labels=oh, logits=self.outputs)) 51 | with tf.variable_scope('accuracy'): 52 | am = tf.argmax(self.outputs, 1) 53 | equal = tf.equal(am, self.labels) 54 | self.accuracy = tf.reduce_mean(tf.cast(equal, tf.float32)) 55 | return self.loss, self.accuracy 56 | 57 | # *训练主函数 58 | def train(model_class, use_batch_norm, use_shuf_pair, 59 | train_cover_dir, train_stego_dir, 60 | valid_cover_dir, valid_stego_dir, 61 | batch_size, train_ds_size, valid_ds_size, 62 | log_interval, max_epochs, lr, 63 | log_path, load_path=None): 64 | # *清除默认图的堆栈,设置全局图为默认图 65 | tf.reset_default_graph() 66 | 67 | # *is_training用于判断训练处于train或者valid状态 68 | is_training = tf.get_variable('is_training', dtype=tf.bool, 69 | initializer=True, trainable=False) 70 | 71 | # *定义train_op操作和valid_op操作,将is_training和batch_size设置为对应的状态 72 | disable_training_op = tf.assign(is_training, False) 73 | enable_training_op = tf.assign(is_training, True) 74 | 75 | # *模型初始化 76 | # 设置占位符 77 | temp_cover_list = glob(train_cover_dir + '/*') 78 | temp_img = misc.imread(temp_cover_list[0]) 79 | temp_img_shape = temp_img.shape 80 | img_batch = tf.placeholder(tf.float32, 81 | [batch_size, temp_img_shape[0], temp_img_shape[1], 1], 82 | name='input_image_batch') 83 | label_batch = tf.placeholder(tf.int32, [batch_size, ], name="input_label_batch") 84 | # 使用占位符初始化模型 85 | model = model_class(is_training=is_training, data_format='NCHW', 86 | 
with_bn=use_batch_norm, tlu_threshold=3) 87 | model._build_model(img_batch) 88 | loss, accuracy = model._build_losses(label_batch) 89 | 90 | # *设置需要最小化的loss函数 91 | regularization_losses = tf.get_collection( 92 | tf.GraphKeys.REGULARIZATION_LOSSES) 93 | regularized_loss = tf.add_n([loss] + regularization_losses) 94 | # 定义train中使用的基于loss/acc的类(运行次数:log_interval) 95 | train_loss_s = average_summary(loss, 'train_loss', log_interval) 96 | train_accuracy_s = average_summary(accuracy, 'train_accuracy', log_interval) 97 | # 定义valid中使用的基于loss/acc的类(运行次数:valid_ds_size / valid_batch_size) 98 | valid_loss_s = average_summary(loss, 'valid_loss', 99 | float(valid_ds_size) / float(batch_size)) 100 | valid_accuracy_s = average_summary(accuracy, 'valid_accuracy', 101 | float(valid_ds_size) / float(batch_size)) 102 | 103 | # *全局变量global_step,从0开始进行计数 104 | global_step = tf.Variable(0, trainable=False) 105 | # *定义核心optimizer 106 | # 定义learning_rate的decay操作 107 | init_learning_rate = lr 108 | decay_steps, decay_rate = 2000, 0.95 109 | learning_rate = learning_rate_decay(init_learning_rate=init_learning_rate, 110 | decay_method="exponential", 111 | global_step=global_step, 112 | decay_steps=decay_steps, 113 | decay_rate=decay_rate) 114 | optimizer = tf.train.AdadeltaOptimizer(learning_rate) 115 | 116 | # *定义train及valid过程中需要用到的操作 117 | # 核心操作:最小化loss 118 | minimize_op = optimizer.minimize(loss=regularized_loss, global_step=global_step) 119 | # 训练操作(每个iteration都要用):最小化loss;train_loss累加;train_acc累加 120 | train_op = tf.group(minimize_op, train_loss_s.increment_op, 121 | train_accuracy_s.increment_op) 122 | # 验证操作(一个epoch结束后,每个valid中的iteration都要用):valid_loss累加;valid_acc累加 123 | valid_op = tf.group(valid_loss_s.increment_op, 124 | valid_accuracy_s.increment_op) 125 | # 初始化操作:初始化所有的全局变量和局部变量 126 | init_op = tf.group(tf.global_variables_initializer(), 127 | tf.local_variables_initializer()) 128 | 129 | # *定义模型保存变量,最大存储max_to_keep个模型 130 | saver = tf.train.Saver(max_to_keep=max_epochs) 131 | 
global_valid_accuracy = 0 # 全局valid_acc最大值 132 | 133 | # *会话开始 134 | with tf.Session() as sess: 135 | # 初始化所有的全局变量和局部变量 136 | sess.run(init_op) 137 | # 重载模型 138 | if load_path is not None: 139 | loader = tf.train.Saver(reshape=True) 140 | loader.restore(sess, load_path) 141 | # 定义模型及参数保存位置 142 | writer = tf.summary.FileWriter(log_path + '/LogFile/', sess.graph) 143 | 144 | # 初始化train/valid的loss和acc变量 145 | sess.run([valid_loss_s.reset_variable_op, 146 | valid_accuracy_s.reset_variable_op, 147 | train_loss_s.reset_variable_op, 148 | train_accuracy_s.reset_variable_op]) 149 | 150 | # *训练开始:train/valid 151 | print('Start training...') 152 | global_train_batch = 0 # 全局batch计数 153 | for epoch in range(max_epochs): 154 | start_time = time.time() 155 | # 加载test路径下的img及label列表 156 | train_img_list, train_label_list = get_files(train_cover_dir, 157 | train_stego_dir, 158 | use_shuf_pair=use_shuf_pair) 159 | # 加载valid路径下的img及label列表 160 | valid_img_list, valid_label_list = get_files(valid_cover_dir, 161 | valid_stego_dir, 162 | use_shuf_pair=use_shuf_pair) 163 | 164 | # *训练开始:train 165 | sess.run(enable_training_op) # 转换为train训练状态 166 | local_train_batch = 0 # 局部batch计数 167 | for train_img_minibatch_list, train_label_minibatch_list in \ 168 | get_minibatches(train_img_list, train_label_list, batch_size): 169 | # minibatch数据读取 170 | train_img_batch = get_minibatches_content_img(train_img_minibatch_list, 171 | temp_img_shape[0], 172 | temp_img_shape[1]) 173 | 174 | # train操作及指标显示 175 | sess.run(train_op, feed_dict={img_batch: train_img_batch, 176 | label_batch: train_label_minibatch_list}) 177 | 178 | global_train_batch += 1 179 | local_train_batch += 1 180 | 181 | # 每log_interval个batch后,对train_loss/acc进行存储 182 | # 这是由于train_loss/acc的average_summary以log_interval为基准定义 183 | if global_train_batch % log_interval == 0: 184 | # 注意:loginterval决定了每20输出一次,而不是每个batch存储loss/acc一次 185 | # train_loss/acc显示 186 | local_train_loss = train_loss_s.mean_variable 187 | local_train_accuracy = 
train_accuracy_s.mean_variable 188 | local_train_loss_value = local_train_loss.eval(session=sess) 189 | local_train_accuracy_value = local_train_accuracy.eval(session=sess) 190 | print('-TRAIN- epoch: %d batch: %d | train_loss: %f train_acc: %f' 191 | % (epoch, local_train_batch, local_train_loss_value, local_train_accuracy_value)) 192 | # train_loss/acc存储 193 | train_loss_s.add_summary(sess, writer, global_train_batch) 194 | train_accuracy_s.add_summary(sess, writer, global_train_batch) 195 | 196 | # *训练开始:validation 197 | sess.run(disable_training_op) 198 | local_valid_loss, local_valid_accuracy = 0, 0 # 本epoch中valid_loss和valid_acc值 199 | for valid_img_minibatch_list, valid_label_minibatch_list in \ 200 | get_minibatches(valid_img_list, valid_label_list, batch_size): 201 | # minibatch数据读取 202 | valid_img_batch = get_minibatches_content_img(valid_img_minibatch_list, 203 | temp_img_shape[0], 204 | temp_img_shape[1]) 205 | 206 | # valid操作及指标显示 207 | sess.run(valid_op, feed_dict={img_batch: valid_img_batch, 208 | label_batch: valid_label_minibatch_list}) 209 | 210 | # 每个epoch中所有batch运行完后,对valid_loss/acc进行显示和存储 211 | # 这是由于valid_loss/acc的average_summary以(valid_ds_size/batch_size)为基准定义 212 | # valid_loss/acc显示 213 | local_valid_loss = valid_loss_s.mean_variable 214 | local_valid_accuracy = valid_accuracy_s.mean_variable 215 | local_valid_loss_value = local_valid_loss.eval(session=sess) 216 | local_valid_accuracy_value = local_valid_accuracy.eval(session=sess) 217 | print('-VALID- epoch: %d | valid_loss: %f valid_acc: %f' 218 | % (epoch, local_valid_loss_value, local_valid_accuracy_value)) 219 | # valid_loss/acc存储 220 | valid_loss_s.add_summary(sess, writer, global_train_batch) 221 | valid_accuracy_s.add_summary(sess, writer, global_train_batch) 222 | 223 | # *模型保存:如果valid_acc大于全局valid_acc,则保存 224 | if local_valid_accuracy_value > global_valid_accuracy or (max_epochs - epoch) < 5: 225 | global_valid_accuracy = local_valid_accuracy_value 226 | saver.save(sess, log_path + 
'/Model_' + str(epoch) + '.ckpt') 227 | print('---EPOCH:%d--- model has been saved' % epoch) 228 | 229 | # *本epoch中train及valid过程均完毕,记录时间 230 | end_time = time.time() 231 | print('--EPOCH:%d-- runtime: %.2fs ' % (epoch, end_time - start_time), 232 | ' learning rate: ', sess.run(learning_rate), '\n') 233 | 234 | # *测试主函数,查找最佳模型 235 | def test_dataset_findbest(model_class, use_shuf_pair, 236 | test_cover_dir, test_stego_dir, max_epochs, 237 | batch_size, ds_size, log_path): 238 | tf.reset_default_graph() 239 | 240 | # *模型初始化 241 | # 设置占位符 242 | temp_cover_list = glob(test_cover_dir + '/*') 243 | temp_img = misc.imread(temp_cover_list[0]) 244 | temp_img_shape = temp_img.shape 245 | img_batch = tf.placeholder(tf.float32, 246 | [batch_size, temp_img_shape[0], temp_img_shape[1], 1], 247 | name='input_image_batch') 248 | label_batch = tf.placeholder(tf.int32, [batch_size, ], name="input_label_batch") 249 | # 使用占位符初始化模型 250 | model = model_class(is_training=False, data_format='NCHW', with_bn=True, tlu_threshold=3) 251 | model._build_model(img_batch) 252 | loss, accuracy = model._build_losses(label_batch) 253 | 254 | # *设置需要计算的loss函数,test_loss/acc与valid_loss/acc的功用类似 255 | # 定义valid中使用的基于loss/acc的类(运行次数:valid_ds_size / valid_batch_size) 256 | test_loss_s = average_summary(loss, 'test_loss', 257 | float(ds_size) / float(batch_size)) 258 | test_accuracy_s = average_summary(accuracy, 'test_accuracy', 259 | float(ds_size) / float(batch_size)) 260 | # 验证操作(一个epoch结束后,每个valid中的iteration都要用):valid_loss累加;valid_acc累加 261 | test_op = tf.group(test_loss_s.increment_op, 262 | test_accuracy_s.increment_op) 263 | 264 | # 初始化操作:初始化所有的全局变量和局部变量 265 | init_op = tf.group(tf.global_variables_initializer(), 266 | tf.local_variables_initializer()) 267 | 268 | # *定义模型保存变量,最大存储max_to_keep个模型 269 | saver = tf.train.Saver() 270 | 271 | # *记录每次test后得到的loss和acc 272 | test_loss_arr = [] 273 | test_accuracy_arr = [] 274 | 275 | # *对load_data_path_s列表中的所有模型进行test操作 276 | print('Start testing...') 277 | 
# 在log路径下搜寻所有可加载文件 278 | load_model_path_s = glob(log_path + '/*.data*') 279 | for load_model_path in load_model_path_s: 280 | start_time = time.time() 281 | # *会话开始 282 | with tf.Session() as sess: 283 | # 初始化所有的全局变量和局部变量 284 | sess.run(init_op) 285 | # 重载模型,去掉结尾的.data-000... 286 | trunc_str = '.data-' 287 | load_model_path_trunc = load_model_path[0:load_model_path.find(trunc_str)] 288 | saver.restore(sess, load_model_path_trunc) 289 | # 初始化test的loss和acc变量 290 | sess.run([test_loss_s.reset_variable_op, 291 | test_accuracy_s.reset_variable_op]) 292 | # 加载test路径下的img及label列表 293 | test_img_list, test_label_list = get_files(test_cover_dir, 294 | test_stego_dir, 295 | use_shuf_pair=use_shuf_pair) 296 | # *对当前load_data_path的模型进行test操作 297 | for test_img_minibatch_list, test_label_minibatch_list in \ 298 | get_minibatches(test_img_list, test_label_list, batch_size): 299 | # minibatch数据读取 300 | test_img_batch = get_minibatches_content_img(test_img_minibatch_list, 301 | temp_img_shape[0], 302 | temp_img_shape[1]) 303 | # 对每次minibatch中test后得到的loss和acc进行累加 304 | sess.run(test_op, feed_dict={img_batch: test_img_batch, 305 | label_batch: test_label_minibatch_list}) 306 | # *记录当前load_data_path模型test操作后得到的loss和acc 307 | test_mean_loss, test_mean_accuracy = sess.run([test_loss_s.mean_variable, 308 | test_accuracy_s.mean_variable]) 309 | test_loss_arr.append(test_mean_loss) 310 | test_accuracy_arr.append(test_mean_accuracy) 311 | end_time = time.time() 312 | print(load_model_path.split("/")[-1]) 313 | print('-TEST- test_loss: %f test_acc: %f | runtime: %.2fs \n' 314 | % (test_loss_arr[-1], test_accuracy_arr[-1], end_time - start_time)) 315 | 316 | # *寻找最佳test_acc对应的模型索引 317 | load_best_model_idx = np.argmax(test_accuracy_arr) 318 | print('-BEST TEST- best_path: ', load_model_path_s[load_best_model_idx]) 319 | print('-BEST TEST- best_loss: %f best_acc: %f \n' 320 | % (test_loss_arr[load_best_model_idx], test_accuracy_arr[load_best_model_idx])) 321 | 322 | return 
load_model_path_s[load_best_model_idx] 323 | 324 | 325 | # *学习率下降函数,包含各类学习率下降方法 326 | def learning_rate_decay(init_learning_rate, global_step, decay_steps, decay_rate, 327 | decay_method="exponential", staircase=False, 328 | end_learning_rate=0.0001, power=1.0, cycle=False,): 329 | """ 330 | 传入初始learning_rate,根据参数及选项运用不同decay策略更新learning_rate 331 | learning_rate : 初始的learning rate 332 | global_step : 全局的step,与 decay_step 和 decay_rate一起决定了 learning rate的变化 333 | staircase : 如果为 True global_step/decay_step 向下取整 334 | end_learning_rate,power,cycle:只在polynomial_decay方法中使用 335 | """ 336 | if decay_method == 'constant': 337 | decayed_learning_rate = init_learning_rate 338 | elif decay_method == 'exponential': 339 | decayed_learning_rate = tf.train.exponential_decay(init_learning_rate, global_step, 340 | decay_steps, decay_rate, staircase) 341 | elif decay_method == 'inverse_time': 342 | decayed_learning_rate = tf.train.inverse_time_decay(init_learning_rate, global_step, 343 | decay_steps, decay_rate, staircase) 344 | elif decay_method == 'natural_exp': 345 | decayed_learning_rate = tf.train.natural_exp_decay(init_learning_rate, global_step, 346 | decay_steps, decay_rate, staircase) 347 | elif decay_method == 'polynomial': 348 | decayed_learning_rate = tf.train.polynomial_decay(init_learning_rate, global_step, 349 | decay_steps, decay_rate, 350 | end_learning_rate, power, cycle) 351 | else: 352 | decayed_learning_rate = init_learning_rate 353 | 354 | return decayed_learning_rate 355 | 356 | -------------------------------------------------------------------------------- /Implement/testfiles/utils_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from scipy import misc, io 4 | import time 5 | from glob import glob 6 | from generator import * 7 | 8 | # *包含loss与acc变量及操作的average_summary类 9 | class average_summary(object): 10 | def __init__(self, variable, name, num_iterations): 11 | 
# sum_variable:内在累加器,用于累加每次的loss/acc 12 | self.sum_variable = tf.get_variable(name, shape=[], 13 | initializer=tf.constant_initializer(0.), 14 | dtype='float32', 15 | trainable=False, 16 | collections=[tf.GraphKeys.LOCAL_VARIABLES]) 17 | # 每个batch调用一次increment_op,累加每次的loss/acc 18 | with tf.control_dependencies([variable]): 19 | self.increment_op = tf.assign_add(self.sum_variable, variable) 20 | # 当increment_op操作调用了num_iterations次之后,可进行下列操作 21 | self.mean_variable = self.sum_variable / float(num_iterations) # 求平均的loss和acc 22 | self.summary = tf.summary.scalar(name, self.mean_variable) # 将loss和acc存入tf全局图 23 | with tf.control_dependencies([self.summary]): 24 | self.reset_variable_op = tf.assign(self.sum_variable, 0) # 当summary完成后,可进行reset 25 | # 外部调用,将loss/acc存入tf全局图 26 | def add_summary(self, sess, writer, step): 27 | s, _ = sess.run([self.summary, self.reset_variable_op]) 28 | writer.add_summary(s, step) 29 | 30 | # *用于挂载Net的结构,包含__build_model和__build_loss的操作 31 | class Model(object): 32 | def __init__(self, is_training=None, data_format='NCHW'): 33 | self.data_format = data_format 34 | if is_training is None: 35 | self.is_training = tf.get_variable('is_training', dtype=tf.bool, 36 | initializer=tf.constant_initializer(True), 37 | trainable=False) 38 | else: 39 | self.is_training = is_training 40 | 41 | def _build_model(self, inputs): 42 | raise NotImplementedError('Here is your model definition') 43 | 44 | def _build_losses(self, labels): 45 | self.labels = tf.cast(labels, tf.int64) 46 | with tf.variable_scope('loss'): 47 | oh = tf.one_hot(self.labels, 2) # 这里定义了2分类的输出 48 | # *除softmax cross entropy之外,还可更换其他函数 49 | self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits( 50 | labels=oh, logits=self.outputs)) 51 | with tf.variable_scope('accuracy'): 52 | am = tf.argmax(self.outputs, 1) 53 | equal = tf.equal(am, self.labels) 54 | self.accuracy = tf.reduce_mean(tf.cast(equal, tf.float32)) 55 | return self.loss, self.accuracy 56 | 57 | # *训练主函数 58 | def 
train(model_class, use_batch_norm, use_shuf_pair, 59 | train_cover_dir, train_stego_dir, 60 | valid_cover_dir, valid_stego_dir, 61 | batch_size, train_ds_size, valid_ds_size, 62 | log_interval, max_epochs, lr, 63 | log_path, load_path=None): 64 | # *清除默认图的堆栈,设置全局图为默认图 65 | tf.reset_default_graph() 66 | 67 | # *is_training用于判断训练处于train或者valid状态 68 | is_training = tf.get_variable('is_training', dtype=tf.bool, 69 | initializer=True, trainable=False) 70 | 71 | # *定义train_op操作和valid_op操作,将is_training和batch_size设置为对应的状态 72 | disable_training_op = tf.assign(is_training, False) 73 | enable_training_op = tf.assign(is_training, True) 74 | 75 | # *模型初始化 76 | # 设置占位符 77 | temp_cover_list = glob(train_cover_dir + '/*') 78 | temp_img = misc.imread(temp_cover_list[0]) 79 | temp_img_shape = temp_img.shape 80 | img_batch = tf.placeholder(tf.float32, 81 | [batch_size, temp_img_shape[0], temp_img_shape[1], 1], 82 | name='input_image_batch') 83 | label_batch = tf.placeholder(tf.int32, [batch_size, ], name="input_label_batch") 84 | # 使用占位符初始化模型 85 | model = model_class(is_training, 'NCHW', with_bn=use_batch_norm, tlu_threshold=3) 86 | model._build_model(img_batch) 87 | loss, accuracy = model._build_losses(label_batch) 88 | 89 | # *设置需要最小化的loss函数 90 | regularization_losses = tf.get_collection( 91 | tf.GraphKeys.REGULARIZATION_LOSSES) 92 | regularized_loss = tf.add_n([loss] + regularization_losses) 93 | # 定义train中使用的基于loss/acc的类(运行次数:log_interval) 94 | train_loss_s = average_summary(loss, 'train_loss', log_interval) 95 | train_accuracy_s = average_summary(accuracy, 'train_accuracy', log_interval) 96 | # 定义valid中使用的基于loss/acc的类(运行次数:valid_ds_size / valid_batch_size) 97 | valid_loss_s = average_summary(loss, 'valid_loss', 98 | float(valid_ds_size) / float(batch_size)) 99 | valid_accuracy_s = average_summary(accuracy, 'valid_accuracy', 100 | float(valid_ds_size) / float(batch_size)) 101 | 102 | # *全局变量global_step,从0开始进行计数 103 | global_step = tf.Variable(0, trainable=False) 104 | # 
*定义核心optimizer 105 | # 定义learning_rate的decay操作 106 | init_learning_rate = lr 107 | decay_steps, decay_rate = 2000, 0.95 108 | learning_rate = learning_rate_decay(init_learning_rate=init_learning_rate, 109 | decay_method="exponential", 110 | global_step=global_step, 111 | decay_steps=decay_steps, 112 | decay_rate=decay_rate) 113 | optimizer = tf.train.AdadeltaOptimizer(learning_rate) 114 | 115 | # *定义train及valid过程中需要用到的操作 116 | # 核心操作:最小化loss 117 | minimize_op = optimizer.minimize(loss=regularized_loss, global_step=global_step) 118 | # 训练操作(每个iteration都要用):最小化loss;train_loss累加;train_acc累加 119 | train_op = tf.group(minimize_op, train_loss_s.increment_op, 120 | train_accuracy_s.increment_op) 121 | # 验证操作(一个epoch结束后,每个valid中的iteration都要用):valid_loss累加;valid_acc累加 122 | valid_op = tf.group(valid_loss_s.increment_op, 123 | valid_accuracy_s.increment_op) 124 | # 初始化操作:初始化所有的全局变量和局部变量 125 | init_op = tf.group(tf.global_variables_initializer(), 126 | tf.local_variables_initializer()) 127 | 128 | # *定义模型保存变量,最大存储max_to_keep个模型 129 | saver = tf.train.Saver(max_to_keep=max_epochs+20) 130 | global_valid_accuracy = 0 # 全局valid_acc最大值 131 | 132 | # *会话开始 133 | with tf.Session() as sess: 134 | # 初始化所有的全局变量和局部变量 135 | sess.run(init_op) 136 | # 重载模型 137 | if load_path is not None: 138 | loader = tf.train.Saver(reshape=True) 139 | loader.restore(sess, load_path) 140 | # 定义模型及参数保存位置 141 | writer = tf.summary.FileWriter(log_path + '/LogFile/', sess.graph) 142 | 143 | # 初始化train/valid的loss和acc变量 144 | sess.run([valid_loss_s.reset_variable_op, 145 | valid_accuracy_s.reset_variable_op, 146 | train_loss_s.reset_variable_op, 147 | train_accuracy_s.reset_variable_op]) 148 | 149 | # *训练开始:train/valid 150 | print('Start training...') 151 | global_train_batch = 0 # 全局batch计数 152 | for epoch in range(max_epochs): 153 | start_time = time.time() 154 | train_img_list, train_label_list = get_files(train_cover_dir, 155 | train_stego_dir, 156 | use_shuf_pair=use_shuf_pair) 157 | valid_img_list, 
valid_label_list = get_files(valid_cover_dir, 158 | valid_stego_dir, 159 | use_shuf_pair=use_shuf_pair) 160 | 161 | # *训练开始:train 162 | sess.run(enable_training_op) # 转换为train训练状态 163 | local_train_batch = 0 # 局部batch计数 164 | for train_img_minibatch_list, train_label_minibatch_list in \ 165 | get_minibatches(train_img_list, train_label_list, batch_size): 166 | # minibatch数据读取 167 | train_img_batch = get_minibatches_content_img(train_img_minibatch_list, 168 | temp_img_shape[0], 169 | temp_img_shape[1]) 170 | 171 | # train操作及指标显示 172 | sess.run(train_op, feed_dict={img_batch: train_img_batch, 173 | label_batch: train_label_minibatch_list}) 174 | 175 | global_train_batch += 1 176 | local_train_batch += 1 177 | 178 | # 每log_interval个batch后,对train_loss/acc进行存储 179 | # 这是由于train_loss/acc的average_summary以log_interval为基准定义 180 | if global_train_batch % log_interval == 0: 181 | # 注意:loginterval决定了每20输出一次,而不是每个batch存储loss/acc一次 182 | # train_loss/acc显示 183 | local_train_loss = train_loss_s.mean_variable 184 | local_train_accuracy = train_accuracy_s.mean_variable 185 | local_train_loss_value = local_train_loss.eval(session=sess) 186 | local_train_accuracy_value = local_train_accuracy.eval(session=sess) 187 | print('-TRAIN- epoch: %d batch: %d | train_loss: %f train_acc: %f' 188 | % (epoch, local_train_batch, local_train_loss_value, local_train_accuracy_value)) 189 | # train_loss/acc存储 190 | train_loss_s.add_summary(sess, writer, global_train_batch) 191 | train_accuracy_s.add_summary(sess, writer, global_train_batch) 192 | 193 | # 对最后20个模型进行存储 194 | if ((train_ds_size // batch_size) * max_epochs - global_train_batch) < 20: 195 | saver.save(sess, log_path + '/Model_' + str(epoch) + '.ckpt') 196 | print('---EPOCH:%d LAST:%d--- model has been saved' 197 | % (epoch, (train_ds_size // batch_size) * max_epochs - global_train_batch + 1)) 198 | 199 | # *训练开始:validation 200 | sess.run(disable_training_op) 201 | local_valid_loss, local_valid_accuracy = 0, 0 # 
本epoch中valid_loss和valid_acc值 202 | for valid_img_minibatch_list, valid_label_minibatch_list in \ 203 | get_minibatches(valid_img_list, valid_label_list, batch_size): 204 | # minibatch数据读取 205 | valid_img_batch = get_minibatches_content_img(valid_img_minibatch_list, 206 | temp_img_shape[0], 207 | temp_img_shape[1]) 208 | 209 | # valid操作及指标显示 210 | sess.run(valid_op, feed_dict={img_batch: valid_img_batch, 211 | label_batch: valid_label_minibatch_list}) 212 | 213 | # 每个epoch中所有batch运行完后,对valid_loss/acc进行显示和存储 214 | # 这是由于valid_loss/acc的average_summary以(valid_ds_size/batch_size)为基准定义 215 | # valid_loss/acc显示 216 | local_valid_loss = valid_loss_s.mean_variable 217 | local_valid_accuracy = valid_accuracy_s.mean_variable 218 | local_valid_loss_value = local_valid_loss.eval(session=sess) 219 | local_valid_accuracy_value = local_valid_accuracy.eval(session=sess) 220 | print('-VALID- epoch: %d | valid_loss: %f valid_acc: %f' 221 | % (epoch, local_valid_loss_value, local_valid_accuracy_value)) 222 | # valid_loss/acc存储 223 | valid_loss_s.add_summary(sess, writer, global_train_batch) 224 | valid_accuracy_s.add_summary(sess, writer, global_train_batch) 225 | 226 | # *模型保存:如果valid_acc大于全局valid_acc,则保存 227 | if local_valid_accuracy_value > global_valid_accuracy: 228 | global_valid_accuracy = local_valid_accuracy_value 229 | saver.save(sess, log_path + '/Model_' + str(epoch) + '.ckpt') 230 | print('---EPOCH:%d--- model has been saved' % (epoch)) 231 | 232 | # *本epoch中train及valid过程均完毕,记录时间 233 | end_time = time.time() 234 | print('--EPOCH:%d-- runtime: %.2fs ' % (epoch, end_time - start_time), 235 | ' learning rate: ', sess.run(learning_rate), '\n') 236 | 237 | # *测试主函数,查找最佳模型 238 | def test_dataset_findbest(model_class, use_batch_norm, use_shuf_pair, 239 | test_cover_dir, test_stego_dir, max_epochs, 240 | batch_size, ds_size, log_path): 241 | tf.reset_default_graph() 242 | 243 | # *模型初始化 244 | # 设置占位符 245 | temp_cover_list = glob(test_cover_dir + '/*') 246 | temp_img = 
misc.imread(temp_cover_list[0]) 247 | temp_img_shape = temp_img.shape 248 | img_batch = tf.placeholder(tf.float32, 249 | [batch_size, temp_img_shape[0], temp_img_shape[1], 1], 250 | name='input_image_batch') 251 | label_batch = tf.placeholder(tf.int32, [batch_size, ], name="input_label_batch") 252 | # 使用占位符初始化模型 253 | model = model_class(is_training=False, data_format='NCHW', 254 | with_bn=use_batch_norm, tlu_threshold=3) 255 | model._build_model(img_batch) 256 | loss, accuracy = model._build_losses(label_batch) 257 | 258 | # *设置需要计算的loss函数,test_loss/acc与valid_loss/acc的功用类似 259 | # 定义valid中使用的基于loss/acc的类(运行次数:valid_ds_size / valid_batch_size) 260 | test_loss_s = average_summary(loss, 'test_loss', 261 | float(ds_size) / float(batch_size)) 262 | test_accuracy_s = average_summary(accuracy, 'test_accuracy', 263 | float(ds_size) / float(batch_size)) 264 | # 验证操作(一个epoch结束后,每个valid中的iteration都要用):valid_loss累加;valid_acc累加 265 | test_op = tf.group(test_loss_s.increment_op, 266 | test_accuracy_s.increment_op) 267 | 268 | # *全局变量global_step,从0开始进行计数 269 | global_step = tf.Variable(0, trainable=False) 270 | 271 | # 初始化操作:初始化所有的全局变量和局部变量 272 | init_op = tf.group(tf.global_variables_initializer(), 273 | tf.local_variables_initializer()) 274 | 275 | # *定义模型保存变量,最大存储max_to_keep个模型 276 | saver = tf.train.Saver(max_to_keep=max_epochs) 277 | 278 | # *记录每次test后得到的loss和acc 279 | test_loss_arr = [] 280 | test_accuracy_arr = [] 281 | 282 | # *对load_data_path_s列表中的所有模型进行test操作 283 | print('Start testing...') 284 | # 在log路径下搜寻所有可加载文件 285 | load_model_path_s = sorted(glob(log_path + '/*.data*')) 286 | for load_model_path in load_model_path_s: 287 | start_time = time.time() 288 | # *会话开始 289 | with tf.Session() as sess: 290 | # 初始化所有的全局变量和局部变量 291 | sess.run(init_op) 292 | # 重载模型 293 | saver.restore(sess, load_model_path) 294 | # 初始化test的loss和acc变量 295 | sess.run([test_loss_s.reset_variable_op, 296 | test_accuracy_s.reset_variable_op]) 297 | # 加载test路径下的img及label列表 298 | test_img_list, 
test_label_list = get_files(test_cover_dir, 299 | test_stego_dir, 300 | use_shuf_pair=use_shuf_pair) 301 | # *对当前load_data_path的模型进行test操作 302 | for test_img_minibatch_list, test_label_minibatch_list in \ 303 | get_minibatches(test_img_list, test_label_list, batch_size): 304 | # minibatch数据读取 305 | test_img_batch = get_minibatches_content_img(test_img_minibatch_list, 306 | temp_img_shape[0], 307 | temp_img_shape[1]) 308 | # 对每次minibatch中test后得到的loss和acc进行累加 309 | sess.run(test_op, feed_dict={img_batch: test_img_batch, 310 | label_batch: test_label_minibatch_list}) 311 | # *记录当前load_data_path模型test操作后得到的loss和acc 312 | test_mean_loss, test_mean_accuracy = sess.run([test_loss_s.mean_variable, 313 | test_accuracy_s.mean_variable]) 314 | test_loss_arr.append(test_mean_loss) 315 | test_accuracy_arr.append(test_mean_accuracy) 316 | end_time = time.time() 317 | print(load_model_path.split("/")[-1]) 318 | print('-TEST- test_loss: %f test_acc: %f | runtime: %.2fs \n' 319 | % (test_loss_arr[-1], test_accuracy_arr[-1], end_time - start_time)) 320 | 321 | # *寻找最佳test_acc对应的模型索引 322 | load_best_model_idx = np.argmax(test_accuracy_arr) 323 | print('-BEST TEST- best_path: ', load_model_path_s[load_best_model_idx]) 324 | print('-BEST TEST- best_loss: %f best_acc: %f \n' 325 | % (test_loss_arr[load_best_model_idx], test_accuracy_arr[load_best_model_idx])) 326 | 327 | return load_model_path_s[load_best_model_idx] 328 | 329 | 330 | # *学习率下降函数,包含各类学习率下降方法 331 | def learning_rate_decay(init_learning_rate, global_step, decay_steps, decay_rate, 332 | decay_method="exponential", staircase=False, 333 | end_learning_rate=0.0001, power=1.0, cycle=False,): 334 | """ 335 | 传入初始learning_rate,根据参数及选项运用不同decay策略更新learning_rate 336 | learning_rate : 初始的learning rate 337 | global_step : 全局的step,与 decay_step 和 decay_rate一起决定了 learning rate的变化 338 | staircase : 如果为 True global_step/decay_step 向下取整 339 | end_learning_rate,power,cycle:只在polynomial_decay方法中使用 340 | """ 341 | if decay_method == 
'constant': 342 | decayed_learning_rate = init_learning_rate 343 | elif decay_method == 'exponential': 344 | decayed_learning_rate = tf.train.exponential_decay(init_learning_rate, global_step, decay_steps, decay_rate, staircase) 345 | elif decay_method == 'inverse_time': 346 | decayed_learning_rate = tf.train.inverse_time_decay(init_learning_rate, global_step, decay_steps, decay_rate, staircase) 347 | elif decay_method == 'natural_exp': 348 | decayed_learning_rate = tf.train.natural_exp_decay(init_learning_rate, global_step, decay_steps, decay_rate, staircase) 349 | elif decay_method == 'polynomial': 350 | decayed_learning_rate = tf.train.polynomial_decay(init_learning_rate, global_step, decay_steps, decay_rate, end_learning_rate, power, cycle) 351 | else: 352 | decayed_learning_rate = init_learning_rate 353 | 354 | return decayed_learning_rate 355 | 356 | 357 | 358 | 359 | def find_best(model_class, valid_gen, test_gen, valid_batch_size, \ 360 | test_batch_size, valid_ds_size, test_ds_size, load_paths): 361 | tf.reset_default_graph() 362 | valid_runner = GeneratorRunner(valid_gen, valid_batch_size * 30) 363 | img_batch, label_batch = valid_runner.get_batched_inputs(valid_batch_size) 364 | model = model_class(False, 'NCHW') 365 | model._build_model(img_batch) 366 | loss, accuracy = model._build_losses(label_batch) 367 | loss_summary = average_summary(loss, 'loss', \ 368 | float(valid_ds_size) \ 369 | / float(valid_batch_size)) 370 | accuracy_summary = average_summary(accuracy, 'accuracy', \ 371 | float(valid_ds_size) \ 372 | / float(valid_batch_size)) 373 | increment_op = tf.group(loss_summary.increment_op, \ 374 | accuracy_summary.increment_op) 375 | global_step = tf.get_variable('global_step', dtype=tf.int32, shape=[], \ 376 | initializer=tf.constant_initializer(0), \ 377 | trainable=False) 378 | init_op = tf.group(tf.global_variables_initializer(), \ 379 | tf.local_variables_initializer()) 380 | saver = tf.train.Saver(max_to_keep=10000) 381 | accuracy_arr = [] 
382 | loss_arr = [] 383 | print("validation") 384 | for load_path in load_paths: 385 | with tf.Session() as sess: 386 | sess.run(init_op) 387 | saver.restore(sess, load_path) # load_path = './model/checkpoint/model.ckpt' 388 | valid_runner.start_threads(sess, 1) 389 | _time = time.time() 390 | for j in range(0, valid_ds_size, valid_batch_size): 391 | sess.run(increment_op) 392 | mean_loss, mean_accuracy = sess.run([loss_summary.mean_variable ,\ 393 | accuracy_summary.mean_variable]) 394 | accuracy_arr.append(mean_accuracy) 395 | loss_arr.append(mean_loss) 396 | print(load_path) 397 | print("Accuracy:", accuracy_arr[-1], "| Loss:", loss_arr[-1], \ 398 | "in", time.time() - _time, "seconds.") 399 | argmax = np.argmax(accuracy_arr) 400 | print("best savestate:", load_paths[argmax], "with", \ 401 | accuracy_arr[argmax], "accuracy and", loss_arr[argmax], \ 402 | "loss on validation") 403 | print("test:") 404 | test_dataset(model_class, test_gen, test_batch_size, test_ds_size, \ 405 | load_paths[argmax]) 406 | return argmax, accuracy_arr, loss_arr 407 | """按照train方式改动 408 | """ -------------------------------------------------------------------------------- /Implement/.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 94 | 95 | 96 | 97 | global_st 98 | learning_rate_decay 99 | # 100 | ## 101 | global_step 102 | ''' 103 | \ 104 | local_learning_rate 105 | reduce 106 | learn 107 | test_dataset 108 | time 109 | dataaug 110 | glob 111 | misc 112 | random 113 | rand 114 | source_cover_list_shuf 115 | tlu_threshold 116 | 
split 117 | 118 | 119 | # * 120 | # 121 | """ 122 | 123 | source_cover_list 124 | 125 | 126 | 127 | 139 | 140 | 141 | 142 | 143 | true 144 | DEFINITION_ORDER 145 | 146 | 147 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 |