├── .idea
│   ├── CAECNNcode.iml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── 1_Cover.pgm
├── 1_stego wow_0.4.pgm
├── README.md
├── data
│   ├── 1.jpg
│   ├── S-UNIWARD0.2.png
│   ├── WOW0.5random_CNN.png
│   ├── coverstego.jpg
│   ├── readme.md
│   └── subtraction.jpg
├── easy work
│   ├── data
│   │   ├── C_3.pgm
│   │   ├── S_1.pgm
│   │   └── readme.md
│   ├── main.py
│   ├── readme.md
│   └── yijianyunxing.py
├── fliter.py
├── input_data.py
├── model.py
├── new version
│   ├── convert2tfrecord.py
│   ├── input_data1.py
│   ├── model1.py
│   ├── readme.md
│   └── train2.py
├── onehot.py
├── rename.py
├── tfrecord.py
├── train.py
└── train1.py
/1_Cover.pgm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiangszzzzz/CAECNNcode/582954fd57f390f153c8a677f2f73988eb90dc3b/1_Cover.pgm -------------------------------------------------------------------------------- /1_stego wow_0.4.pgm: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangszzzzz/CAECNNcode/582954fd57f390f153c8a677f2f73988eb90dc3b/1_stego wow_0.4.pgm -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Learning for Steganalysis 2 | 3 | *** 4 | ## Steganography and Steganalysis 5 | 6 | 7 | 8 | Steganography is the science of concealing secret messages in images by slightly modifying their pixel values. Content-adaptive steganographic schemes, which embed the message in complex image regions to escape detection, are currently the most secure methods; examples in the spatial domain include HUGO, WOW, and S-UNIWARD. 9 | Steganalysis is the counterpart of steganography: the art of detecting hidden data in images. The task is usually formulated as a binary classification problem that distinguishes cover images from stego images. 10 | 11 | 12 | ## LSB steganography: cover and stego 13 | 14 | * 1: cover (left) and stego (right) 15 |
16 | 17 | 18 | 19 | * 2: difference between cover and stego (small payload) 20 |
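To make the LSB figures concrete, here is a minimal NumPy sketch of LSB replacement and of the cover/stego difference image. It assumes an 8-bit grayscale image and embeds one message bit per pixel in raster order; `lsb_embed` and `lsb_extract` are hypothetical helpers written for illustration, not code from this repository.

```python
import numpy as np

def lsb_embed(cover, bits):
    """Hypothetical helper: LSB replacement in the first len(bits) pixels (raster order)."""
    bits = np.asarray(bits, dtype=np.uint8)
    stego = cover.copy().ravel()
    if bits.size > stego.size:
        raise ValueError("message is longer than the cover capacity")
    # Clear each pixel's least significant bit, then write one message bit into it.
    stego[:bits.size] = (stego[:bits.size] & 0xFE) | bits
    return stego.reshape(cover.shape)

def lsb_extract(stego, n_bits):
    """Read the message back from the first n_bits pixels."""
    return stego.ravel()[:n_bits] & 1

rng = np.random.default_rng(0)
cover = rng.integers(0, 256, size=(8, 8), dtype=np.uint8)
message = rng.integers(0, 2, size=16, dtype=np.uint8)  # 16 bits / 64 pixels = 0.25 bpp
stego = lsb_embed(cover, message)

# Signed difference image, as in the subtraction figure: every entry is -1, 0 or +1.
diff = stego.astype(np.int16) - cover.astype(np.int16)
print(np.unique(diff))
print("changed pixels:", np.count_nonzero(diff), "of", message.size, "embedded bits")
```

Only pixels whose LSB already differed from the message bit change, which is why the difference image is so sparse and so faint.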
21 | 22 | ## J-UNIWARD steganography: cat image, cover and stego 23 | 24 | * 3: difference between cover and stego (payload = 0.3) 25 |
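JPEG-domain schemes such as J-UNIWARD usually state the payload in bits per non-zero AC DCT coefficient (bpnzAC). Assuming that "payload = 0.3" above follows this convention, the message length for one image is the payload times its count of non-zero AC coefficients. The sketch below assumes the quantized coefficients are tiled in 8x8 blocks in a 2-D array with each block's DC term at the block's top-left corner, and it rounds to the nearest bit; `message_length_bpnzac` is a hypothetical helper.

```python
import numpy as np

def message_length_bpnzac(coeffs, payload):
    """Hypothetical helper: bits to embed = payload * (number of non-zero AC coefficients).

    coeffs: 2-D integer array of quantized DCT coefficients tiled in 8x8 blocks,
    with each block's DC coefficient at position (0, 0) of the block.
    """
    coeffs = np.asarray(coeffs)
    rows, cols = np.indices(coeffs.shape)
    is_dc = (rows % 8 == 0) & (cols % 8 == 0)  # DC terms are excluded from the count
    nz_ac = np.count_nonzero((coeffs != 0) & ~is_dc)
    return int(round(payload * nz_ac)), nz_ac

# Toy 16x16 coefficient array (four 8x8 blocks).
coeffs = np.zeros((16, 16), dtype=np.int32)
coeffs[0, 0] = 50   # DC of block 0: not counted
coeffs[0, 1] = -3   # AC
coeffs[1, 0] = 2    # AC
coeffs[8, 8] = 40   # DC of block 3: not counted
coeffs[9, 10] = 1   # AC
bits, nz_ac = message_length_bpnzac(coeffs, payload=0.3)
print(nz_ac, bits)  # -> 3 1
```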
26 | 27 | ## Deep learning for steganalysis 28 | 29 | *** 30 | Unlike traditional computer vision tasks, image steganalysis must detect an embedding operation that adds only extremely low-amplitude noise to the cover. That is why my network contains no max-pooling layers, which could destroy the small features introduced by steganography. 31 | 32 | 33 | ## Some results 34 | 35 | * 4: The training process; the net begins to converge at about 50,000 steps (5 epochs). 36 | ![Training process](https://github.com/jiangszzzzz/CAECNNcode/blob/master/data/S-UNIWARD0.2.png?raw=true) 37 | 38 | * 5: WOW0.5random_CNN training and validation accuracy. The validation loss shows that the model is not overfitting. 39 | 40 |
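The argument against max pooling can be checked on a toy example: when a ±1 embedding change hits a pixel that is not the maximum of its pooling window, the pooled output is identical for cover and stego, so later layers never see the change. The NumPy sketch below uses a hypothetical 2x2, stride-2 max pool on hand-picked pixel values; it only illustrates the information loss and is not the network in this repository.

```python
import numpy as np

def max_pool_2x2(x):
    """2x2 max pooling with stride 2 on a 2-D array whose sides are even."""
    h, w = x.shape
    return x.reshape(h // 2, 2, w // 2, 2).max(axis=(1, 3))

cover = np.array([[10, 12, 50, 51],
                  [11, 13, 52, 49],
                  [30, 31, 70, 71],
                  [29, 32, 72, 68]], dtype=np.int16)

stego = cover.copy()
stego[0, 0] += 1  # +1 embedding change on a non-maximum pixel
stego[3, 3] -= 1  # -1 embedding change on a non-maximum pixel

print(np.count_nonzero(stego - cover))                           # -> 2
print(np.array_equal(max_pool_2x2(cover), max_pool_2x2(stego)))  # -> True
```

Both changes are present in the input but absent after pooling, which is the loss the paragraph above describes.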
41 | -------------------------------------------------------------------------------- /data/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiangszzzzz/CAECNNcode/582954fd57f390f153c8a677f2f73988eb90dc3b/data/1.jpg -------------------------------------------------------------------------------- /data/S-UNIWARD0.2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiangszzzzz/CAECNNcode/582954fd57f390f153c8a677f2f73988eb90dc3b/data/S-UNIWARD0.2.png -------------------------------------------------------------------------------- /data/WOW0.5random_CNN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiangszzzzz/CAECNNcode/582954fd57f390f153c8a677f2f73988eb90dc3b/data/WOW0.5random_CNN.png -------------------------------------------------------------------------------- /data/coverstego.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiangszzzzz/CAECNNcode/582954fd57f390f153c8a677f2f73988eb90dc3b/data/coverstego.jpg -------------------------------------------------------------------------------- /data/readme.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data/subtraction.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiangszzzzz/CAECNNcode/582954fd57f390f153c8a677f2f73988eb90dc3b/data/subtraction.jpg -------------------------------------------------------------------------------- /easy work/data/C_3.pgm: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangszzzzz/CAECNNcode/582954fd57f390f153c8a677f2f73988eb90dc3b/easy work/data/C_3.pgm -------------------------------------------------------------------------------- /easy work/data/S_1.pgm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiangszzzzz/CAECNNcode/582954fd57f390f153c8a677f2f73988eb90dc3b/easy work/data/S_1.pgm -------------------------------------------------------------------------------- /easy work/data/readme.md: -------------------------------------------------------------------------------- 1 | # some PGM images 2 | -------------------------------------------------------------------------------- /easy work/main.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2018/6/7 0007 14:23 4 | # @Author : jsz 5 | # @Software: PyCharm 6 | 7 | import tensorflow as tf 8 | import numpy as np 9 | import os 10 | import matplotlib.pyplot as plt 11 | import skimage.io as io 12 | from scipy.misc import imread, imresize # removed from recent SciPy releases; this script needs an old SciPy (and TensorFlow 1.x) 13 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 14 | 15 | def rename(): 16 | count=1 17 | path = '.\\data\\' 18 | 19 | for dirpath, dirnames, filenames in os.walk(path): 20 | # for filename in filenames: 21 | # print(os.path.join(dirpath, filename)) 22 | 23 | for files in filenames: 24 | # if files.endswith('.pgm'): 25 | name = files.split('_') 26 | if name[0] in ('C', 'S'): # skip covers and files already prefixed with S_, so repeated runs do not produce S_S_ names 27 | pass 28 | else: 29 | Olddir = os.path.join(dirpath, files) 30 | if os.path.isdir(Olddir): 31 | continue 32 | filename = os.path.splitext(files)[0] 33 | filetype = os.path.splitext(files)[1] 34 | 35 | # rename directly 36 | # Newdir = os.path.join(path, 'S' + filetype) 37 | 38 | # prefix the filename with S_ 39 | Newdir = os.path.join(dirpath, ('S_'+filename) + filetype) 40 | 41 | # number the files sequentially 42 | # Newdir = os.path.join(path, str(count) + filetype) 43 | 44 | 45 | 46 | # batch: keep the part of the name before / after the separator (---) 47 | # if
filename.find('---')>=0: # if the filename contains --- 48 | # 49 | # Newdir=os.path.join(direc,filename.split('---')[0]+filetype); 50 | # 51 | # # keep the characters before ---; use filename.split('---')[1] to keep the characters after it 52 | # 53 | # if not os.path.isfile(Newdir): 54 | 55 | os.rename(Olddir, Newdir) 56 | 57 | count+= 1 58 | 59 | def get_file(file_dir): 60 | cover = [] 61 | label_cover = [] 62 | stego = [] 63 | label_stego = [] 64 | # assign labels: files named C_* are covers (0), S_* are stegos (1) 65 | for file in os.listdir(file_dir): 66 | # if file.endswith('0') or file.startswith('.'): 67 | # continue # Skip! 68 | name = file.split('_') 69 | if name[0] == 'C': 70 | cover.append(file_dir + file) 71 | label_cover.append(0) 72 | if name[0] == 'S': 73 | stego.append(file_dir + file) 74 | label_stego.append(1) 75 | print("There are %d cover images\nThere are %d stego images" 76 | % (len(cover), len(stego))) 77 | # shuffle the file order 78 | image_list = np.hstack((cover, stego)) 79 | label_list = np.hstack((label_cover, label_stego)) 80 | temp = np.array([image_list, label_list]) 81 | temp = temp.transpose() 82 | np.random.shuffle(temp) 83 | 84 | image_list = list(temp[:, 0]) 85 | label_list = list(temp[:, 1]) 86 | label_list = [int(i) for i in label_list] 87 | 88 | return image_list, label_list 89 | 90 | def int64_feature(value): 91 | """Wrapper for inserting int64 features into Example proto.""" 92 | if not isinstance(value, list): 93 | value = [value] 94 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 95 | 96 | 97 | def bytes_feature(value): 98 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 99 | 100 | def convert_to_tfrecord(images, labels, save_dir, name): 101 | '''convert all images and labels to one tfrecord file. 102 | Args: 103 | images: list of image file paths, string type 104 | labels: list of labels, int type 105 | save_dir: the directory to save tfrecord file, e.g.: '/home/folder1/' 106 | name: the name of tfrecord file, string type, e.g.: 'train' 107 | Return: 108 | no return 109 | Note: 110 | converting needs some time, be patient...
111 | ''' 112 | 113 | filename = os.path.join(save_dir, name + '.tfrecords') 114 | n_samples = len(labels) 115 | 116 | if np.shape(images)[0] != n_samples: 117 | raise ValueError('Images size %d does not match label size %d.' % (np.shape(images)[0], n_samples)) 118 | 119 | # conversion can take a while, depending on the size of your dataset. 120 | writer = tf.python_io.TFRecordWriter(filename) 121 | print('\nTransform start......') 122 | for i in np.arange(0, n_samples): 123 | try: 124 | # image = imread(image[i]) 125 | 126 | image = io.imread(images[i]) # type(image) must be array! 127 | image = imresize(image, (256, 256)) # if the source is not 256x256, interpolation can weaken the low-amplitude stego signal 128 | image_raw = image.tobytes() 129 | label = int(labels[i]) 130 | example = tf.train.Example(features=tf.train.Features(feature={ 131 | 'label': int64_feature(label), 132 | 'image_raw': bytes_feature(image_raw)})) 133 | writer.write(example.SerializeToString()) 134 | except IOError as e: 135 | print('Could not read:', images[i]) 136 | print('error: %s' % e) 137 | print('Skip it!\n') 138 | writer.close() 139 | print('Transform done!') 140 | 141 | 142 | def read_and_decode(tfrecords_file, batch_size): 143 | '''read and decode tfrecord file, generate (image, label) batches 144 | Args: 145 | tfrecords_file: the directory of tfrecord file 146 | batch_size: number of images in each batch 147 | Returns: 148 | image: 4D tensor - [batch_size, width, height, channel] 149 | label: 1D tensor - [batch_size] 150 | ''' 151 | # make an input queue from the tfrecord file 152 | filename_queue = tf.train.string_input_producer([tfrecords_file]) 153 | 154 | reader = tf.TFRecordReader() 155 | _, serialized_example = reader.read(filename_queue) 156 | img_features = tf.parse_single_example( 157 | serialized_example, 158 | features={ 159 | 160 | 'label': tf.FixedLenFeature([], tf.int64), 161 | 'image_raw': tf.FixedLenFeature([], tf.string), 162 | }) 163 | image = tf.decode_raw(img_features['image_raw'], tf.uint8) 164 | 165 |
########################################################## 166 | # you can put data augmentation here, I didn't use it 167 | ########################################################## 168 | # images were resized to 256*256 in convert_to_tfrecord; change the shape below if you use another size. 169 | 170 | 171 | image = tf.reshape(image, [256, 256, 1]) 172 | # image = tf.reshape(image, [256, 256]) #for plot 173 | label = tf.cast(img_features['label'], tf.int32) 174 | image_batch, label_batch = tf.train.shuffle_batch([image, label], 175 | batch_size=batch_size, 176 | num_threads=64, 177 | capacity=2000, 178 | min_after_dequeue=20) 179 | return image_batch, tf.reshape(label_batch, [batch_size]) 180 | 181 | 182 | def plot_images(images, labels): 183 | '''plot one batch (at most 25 images on a 5x5 grid) 184 | ''' 185 | for i in np.arange(0, min(len(images), 25)): 186 | plt.subplot(5, 5, i + 1) 187 | plt.axis('off') 188 | # plt.title(chr(ord('D') + labels[i] - 1), fontsize=14) 189 | 190 | if labels[i] == 1: 191 | plt.title(str('Stego'), fontsize=14) 192 | else: 193 | plt.title(str('Cover'), fontsize=14) 194 | 195 | plt.subplots_adjust(top=1.5) 196 | plt.imshow(np.squeeze(images[i]), cmap='gray') # drop the channel axis: (256, 256, 1) -> (256, 256) 197 | plt.show() 198 | 199 | ##### 200 | # for test tfrecords 201 | ##### 202 | # BATCH_SIZE = 25 203 | # BATCH_SIZE1 = 25 204 | # image_batch, label_batch = read_and_decode(tfrecords_file, batch_size=BATCH_SIZE) 205 | # image_batch1, label_batch1 = read_and_decode(tfrecords_file1, batch_size=BATCH_SIZE1) 206 | # 207 | # with tf.Session() as sess: 208 | # i = 0 209 | # coord = tf.train.Coordinator() 210 | # threads = tf.train.start_queue_runners(coord=coord) 211 | # 212 | # try: 213 | # while not coord.should_stop() and i < 1: 214 | # # just plot one batch size 215 | # image, label = sess.run([image_batch, label_batch]) 216 | # plot_images(image, label) 217 | # 218 | # image, label = sess.run([image_batch, label_batch]) 219 | # plot_images(image, label) 220 | # 221 | # image, label = sess.run([image_batch1, label_batch1]) 222 | #
plot_images(image, label) 223 | # 224 | # image, label = sess.run([image_batch1, label_batch1]) 225 | # plot_images(image, label) 226 | # 227 | # i += 1 228 | # 229 | # except tf.errors.OutOfRangeError: 230 | # print('done!') 231 | # finally: 232 | # coord.request_stop() 233 | # coord.join(threads) 234 | 235 | 236 | 237 | # model 238 | 239 | def inference(images, batch_size, n_classes): 240 | with tf.variable_scope('conv1') as scope: 241 | weights = tf.get_variable('weights', 242 | #kernel size, kernel size, channels, kernel number 243 | shape=[3, 3, 1, 32], 244 | dtype=tf.float32, 245 | initializer=tf.truncated_normal_initializer(stddev=0.1, dtype=tf.float32)) 246 | biases = tf.get_variable('biases', 247 | shape=[32], 248 | dtype=tf.float32, 249 | initializer=tf.constant_initializer(0.1)) 250 | conv = tf.nn.conv2d(images, weights, strides=[1, 1, 1, 1], padding='SAME') 251 | pre_activation = tf.nn.bias_add(conv, biases) 252 | conv1 = tf.nn.relu(pre_activation, name=scope.name) 253 | 254 | # with tf.variable_scope('pooling1_lrn') as scope: 255 | # pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME', name='pooling1') 256 | # norm1 = tf.nn.lrn(pool1, depth_radius=4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm1') 257 | 258 | 259 | with tf.variable_scope('conv2') as scope: 260 | weights = tf.get_variable('weights', 261 | shape=[3, 3, 32, 16], 262 | dtype=tf.float32, 263 | initializer=tf.truncated_normal_initializer(stddev=0.1, dtype=tf.float32)) 264 | biases = tf.get_variable('biases', 265 | shape=[16], 266 | dtype=tf.float32, 267 | initializer=tf.constant_initializer(0.1)) 268 | conv = tf.nn.conv2d(conv1, weights, strides=[1, 1, 1, 1], padding='SAME') 269 | pre_activation = tf.nn.bias_add(conv, biases) 270 | conv2 = tf.nn.relu(pre_activation, name='conv2') 271 | 272 | # pool2 and norm2 273 | # with tf.variable_scope('pooling2_lrn') as scope: 274 | # norm2 = tf.nn.lrn(conv2, depth_radius=4, bias=1.0, alpha=0.001 / 9.0, 
beta=0.75, name='norm2') 275 | # pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], strides=[1, 1, 1, 1], padding='SAME', name='pooling2') 276 | 277 | with tf.variable_scope('local3') as scope: 278 | reshape = tf.reshape(conv2, shape=[batch_size, -1]) 279 | dim = reshape.get_shape()[1].value # 256*256*16 = 1,048,576 for 256x256 inputs (SAME padding, stride 1) 280 | weights = tf.get_variable('weights', 281 | shape=[dim, 256], 282 | dtype=tf.float32, 283 | initializer=tf.truncated_normal_initializer(stddev=0.005, dtype=tf.float32)) 284 | biases = tf.get_variable('biases', 285 | shape=[256], 286 | dtype=tf.float32, 287 | initializer=tf.constant_initializer(0.1)) 288 | local3 = tf.nn.relu(tf.matmul(reshape, weights) + biases, name=scope.name) 289 | 290 | # local4 291 | with tf.variable_scope('local4') as scope: 292 | weights = tf.get_variable('weights', 293 | shape=[256, 256], 294 | dtype=tf.float32, 295 | initializer=tf.truncated_normal_initializer(stddev=0.005, dtype=tf.float32)) 296 | biases = tf.get_variable('biases', 297 | shape=[256], 298 | dtype=tf.float32, 299 | initializer=tf.constant_initializer(0.1)) 300 | local4 = tf.nn.relu(tf.matmul(local3, weights) + biases, name='local4') 301 | 302 | # softmax 303 | with tf.variable_scope('softmax_linear') as scope: 304 | weights = tf.get_variable('softmax_linear', 305 | shape=[256, n_classes], 306 | dtype=tf.float32, 307 | initializer=tf.truncated_normal_initializer(stddev=0.005, dtype=tf.float32)) 308 | biases = tf.get_variable('biases', 309 | shape=[n_classes], 310 | dtype=tf.float32, 311 | initializer=tf.constant_initializer(0.1)) 312 | softmax_linear = tf.add(tf.matmul(local4, weights), biases, name='softmax_linear') 313 | 314 | return softmax_linear 315 | 316 | # logits is the return value of inference(); labels is the ground truth 317 | def losses(logits, labels): 318 | with tf.variable_scope('loss') as scope: 319 | 320 | cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits \ 321 | (logits=logits, labels=labels, name='xentropy_per_example') 322 | loss = tf.reduce_mean(cross_entropy, name='loss') 323
| tf.summary.scalar(scope.name + '/loss', loss) 324 | return loss 325 | 326 | 327 | def trainning(loss, learning_rate): 328 | with tf.name_scope('optimizer'): 329 | optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) 330 | global_step = tf.Variable(0, name='global_step', trainable=False) 331 | train_op = optimizer.minimize(loss, global_step=global_step) 332 | return train_op 333 | 334 | 335 | def evaluation(logits, labels): 336 | with tf.variable_scope('accuracy') as scope: 337 | correct = tf.nn.in_top_k(logits, labels, 1) 338 | correct = tf.cast(correct, tf.float16) 339 | accuracy = tf.reduce_mean(correct) 340 | tf.summary.scalar(scope.name + '/accuracy', accuracy) 341 | return accuracy 342 | 343 | 344 | # training settings 345 | N_CLASSES = 2 # cover and stego 346 | IMG_W = 256 # resize 347 | IMG_H = 256 348 | BATCH_SIZE = 16 349 | CAPACITY = 300 350 | MAX_STEP = 250000 # usually more than 10K 351 | learning_rate = 0.00001 # usually less than 0.0001 352 | 353 | def run_training(tfrecords_file,tfrecords_file1): 354 | 355 | logs_train_dir = '.\\logs\\train' 356 | # logs_val_dir = 'H:\\dataWOW_0.05random\\logs\\val' 357 | os.makedirs(logs_train_dir, exist_ok=True) # saver.save() fails if the directory does not exist 358 | tfrecords_traindir = tfrecords_file 359 | tfrecords_valdir = tfrecords_file1 360 | 361 | # build input batches from the TFRecord files 362 | train_batch, train_label_batch = read_and_decode(tfrecords_traindir, BATCH_SIZE) 363 | val_batch, val_label_batch = read_and_decode(tfrecords_valdir, BATCH_SIZE) 364 | 365 | 366 | x = tf.placeholder(tf.float32, shape=[BATCH_SIZE, 256, 256, 1]) 367 | y_ = tf.placeholder(tf.int32, shape=[BATCH_SIZE]) 368 | 369 | 370 | logits = inference(x, BATCH_SIZE, N_CLASSES) 371 | loss = losses(logits, y_) 372 | acc = evaluation(logits, y_) 373 | train_op = trainning(loss, learning_rate) 374 | 375 | 376 | sess = tf.Session() 377 | 378 | saver = tf.train.Saver() 379 | sess.run(tf.global_variables_initializer()) 380 | coord = tf.train.Coordinator() 381 | threads = tf.train.start_queue_runners(sess=sess, coord=coord) 382 | 383 | try: 384 | for step in np.arange(MAX_STEP):
385 | 386 | if coord.should_stop(): 387 | break 388 | 389 | tra_images, tra_labels = sess.run([train_batch, train_label_batch]) 390 | 391 | 392 | _, tra_loss, tra_acc = sess.run([train_op, loss, acc], 393 | feed_dict={ 394 | x: tra_images, 395 | y_: tra_labels}) 396 | if step % 2 == 0: 397 | print(tfrecords_traindir) 398 | print('Step %d, train loss = %.2f, train accuracy = %.2f%%' % (step, tra_loss, tra_acc * 100.0)) 399 | 400 | 401 | 402 | if step % 4 == 0 or (step + 1) == MAX_STEP: 403 | val_images, val_labels = sess.run([val_batch, val_label_batch]) 404 | val_loss, val_acc = sess.run([loss, acc], 405 | feed_dict={ 406 | x: val_images, 407 | y_: val_labels}) 408 | print(tfrecords_valdir) 409 | print(' ** Step %d, val loss = %.2f, val accuracy = %.2f%% **' % (step, val_loss, val_acc * 100.0)) 410 | 411 | 412 | if step % 2000 == 0 or (step + 1) == MAX_STEP: 413 | checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt') 414 | saver.save(sess, checkpoint_path, global_step=step) 415 | 416 | except tf.errors.OutOfRangeError: 417 | print('Done training -- epoch limit reached') 418 | finally: 419 | coord.request_stop() 420 | coord.join(threads) 421 | sess.close() 422 | -------------------------------------------------------------------------------- /easy work/readme.md: -------------------------------------------------------------------------------- 1 | # Some code for steganalysis experiments. 2 | Modify the related parameters in yijianyunxing.py ("one-click run") and run it! 3 | -------------------------------------------------------------------------------- /easy work/yijianyunxing.py: -------------------------------------------------------------------------------- 1 | #!
/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2018/6/7 0007 14:51 4 | # @Author : jsz 5 | # @Software: PyCharm 6 | 7 | import main 8 | 9 | main.rename() 10 | 11 | name_test = 'testrandomtrain' 12 | name_test1 = 'testrandomval' 13 | tfrecords_file = '.\\testrandomtrain.tfrecords' 14 | tfrecords_file1 = '.\\testrandomval.tfrecords' 15 | 16 | test_dir = '.\\data\\train\\' 17 | save_dir = '.\\' 18 | test_dir1 = '.\\data\\val\\' 19 | save_dir1= '.\\' 20 | 21 | images, labels = main.get_file(test_dir) 22 | main.convert_to_tfrecord(images, labels, save_dir, name_test) 23 | images1, labels1 = main.get_file(test_dir1) 24 | main.convert_to_tfrecord(images1, labels1, save_dir1, name_test1) 25 | 26 | main.run_training(tfrecords_file,tfrecords_file1) -------------------------------------------------------------------------------- /fliter.py: -------------------------------------------------------------------------------- 1 | import tensorlayer as tl 2 | import tensorflow as tf 3 | import matplotlib.pyplot as plt 4 | import cv2 5 | from tensorlayer.layers import * 6 | import numpy as np 7 | 8 | sess = tf.InteractiveSession() 9 | x = tf.placeholder(tf.float32, [None, 512, 512, 1]) 10 | F0 = np.array([[-1, 2, -2, 2, -1], 11 | [2, -6, 8, -6, 2], 12 | [-2, 8, -12, 8, -2], 13 | [2, -6, 8, -6, 2], 14 | [-1, 2, -2, 2, -1]], dtype=np.float32) 15 | F0 = F0 / 12. 
16 | high_pass_filter = tf.constant_initializer(value=F0, dtype=tf.float32) 17 | net = InputLayer(x, name='inputlayer') 18 | net = Conv2d(net, 1, (5, 5), (1, 1), act=tf.identity, 19 | padding='SAME', W_init=high_pass_filter, name='HighPass') 20 | y = net.outputs 21 | tl.layers.initialize_global_variables(sess) 22 | # HighPass is initialized with F0, the 5x5 KV high-pass kernel, to suppress image content and expose the stego noise 23 | img = cv2.imread('1_Cover.pgm',0).astype(np.float32).reshape([1,512,512,1]) 24 | 25 | img_after = y.eval(feed_dict = {x:img}) 26 | 27 | 28 | 29 | 30 | if __name__ == '__main__': 31 | # plt.imshow(img.reshape([512,512])) 32 | # plt.imshow(img_after.reshape([512,512])) 33 | # 34 | pgm_info = np.where(img_after > 10, 1, 0) 35 | plt.imshow(pgm_info.reshape([512,512])) 36 | plt.show() 37 | 38 | -------------------------------------------------------------------------------- /input_data.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2018/3/12 0012 15:56 4 | # @Author : jsz 5 | # @Software: PyCharm 6 | 7 | import tensorflow as tf 8 | import numpy as np 9 | import os 10 | 11 | # reading function: returns two lists; image_list holds the image paths (strings), label_list holds the 0/1 labels 12 | 13 | def get_files(file_dir): 14 | cover = [] 15 | label_cover = [] 16 | stego = [] 17 | label_stego = [] 18 | # assign labels 19 | for file in os.listdir(file_dir): 20 | # if file.endswith('0') or file.startswith('.'): 21 | # continue # Skip!
22 | name = file.split('_') 23 | if name[0] == 'C0': 24 | cover.append(file_dir + file) 25 | label_cover.append(0) 26 | if name[0] == 'S1' : 27 | stego.append(file_dir + file) 28 | label_stego.append(1) 29 | print("There are %d cover images\nThere are %d stego images" 30 | % (len(cover), len(stego))) 31 | # shuffle the file order 32 | image_list = np.hstack((cover,stego)) 33 | label_list = np.hstack((label_cover, label_stego)) 34 | temp = np.array([image_list, label_list]) 35 | temp = temp.transpose() 36 | np.random.shuffle(temp) 37 | 38 | image_list = list(temp[:, 0]) 39 | label_list = list(temp[:, 1]) 40 | label_list = [int(i) for i in label_list] 41 | 42 | return image_list , label_list 43 | 44 | # batch function 45 | def get_batch(image, label, 46 | image_W, image_H, 47 | batch_size, capacity): 48 | 49 | # # convert the Python lists into a format TensorFlow can use 50 | 51 | label = tf.cast(label, tf.int32) 52 | image = tf.cast(image, tf.string) 53 | 54 | input_queue = tf.train.slice_input_producer([image, label]) 55 | 56 | 57 | label = input_queue[1] 58 | 59 | image_contents = tf.read_file(input_queue[0]) 60 | 61 | print(input_queue[0]) 62 | # channels=0 keeps the PNG's own channel count; PGM files must be converted to PNG first 63 | image = tf.image.decode_png(image_contents, channels=0) 64 | 65 | image = tf.reshape(image, [ 256, 256, 1]) 66 | image = tf.image.per_image_standardization(image) 67 | 68 | image_batch, label_batch = tf.train.batch([image, label], 69 | batch_size = batch_size, 70 | num_threads=64, 71 | capacity=capacity, 72 | ) 73 | 74 | label_batch = tf.reshape(label_batch, [batch_size]) 75 | image_batch = tf.cast(image_batch, tf.float32) 76 | 77 | return image_batch, label_batch 78 | 79 | def read_and_decode(tfrecords_file, batch_size): 80 | '''read and decode tfrecord file, generate (image, label) batches 81 | Args: 82 | tfrecords_file: the directory of tfrecord file 83 | batch_size: number of images in each batch 84 | Returns: 85 | image: 4D tensor - [batch_size, width, height, channel] 86 | label: 1D tensor - [batch_size] 87 | ''' 88 | # make an input queue from the tfrecord file 89 | filename_queue =
tf.train.string_input_producer([tfrecords_file]) 90 | 91 | reader = tf.TFRecordReader() 92 | _, serialized_example = reader.read(filename_queue) 93 | img_features = tf.parse_single_example( 94 | serialized_example, 95 | features={ 96 | 'label': tf.FixedLenFeature([], tf.int64), 97 | 'image_raw': tf.FixedLenFeature([], tf.string), 98 | }) 99 | image = tf.decode_raw(img_features['image_raw'], tf.uint8) 100 | 101 | ########################################################## 102 | # you can put data augmentation here, I didn't use it 103 | ########################################################## 104 | # this reader expects 512*512 single-channel images; change the shape below if your TFRecords store another size. 105 | 106 | image = tf.reshape(image, [512, 512, 1]) 107 | label = tf.cast(img_features['label'], tf.int32) 108 | image_batch, label_batch = tf.train.batch([image, label], 109 | batch_size=batch_size, 110 | num_threads=64, 111 | capacity=2000) 112 | 113 | image_batch = tf.cast(image_batch, tf.float32) 114 | 115 | return image_batch, tf.reshape(label_batch, [batch_size]) 116 | 117 | # file_dir = 'F://CAE_CNN//data//pgm_coverstego//' 118 | # file_dir = 'F://CAE_CNN//data//train//' 119 | # get_files(file_dir) 120 | # file_dir = 'G://PGMtoPNG//train_imgs//' 121 | # 122 | # import matplotlib.pyplot as plt 123 | # 124 | # BATCH_SIZE = 2 125 | # CAPACITY = 256 126 | # IMG_W = 256 127 | # IMG_H = 256 128 | # 129 | # image_list, label_list = get_files(file_dir) 130 | # image_batch, label_batch = get_batch(image_list, label_list, IMG_W, IMG_H, BATCH_SIZE, CAPACITY) 131 | # 132 | # with tf.Session() as sess: 133 | # i = 0 134 | # coord = tf.train.Coordinator() 135 | # threads = tf.train.start_queue_runners(coord=coord) 136 | # try: 137 | # while not coord.should_stop() and i < 2: 138 | # img, label = sess.run([image_batch, label_batch]) 139 | # 140 | # for j in np.arange(BATCH_SIZE): 141 | # print("label: %d" % label[j]) 142 | # 143 | # plt.imshow(img[j]) 144 | # #
plt.imshow('F://CAE_CNN//data//pgm_cover//Cover.1.pgm') 145 | # 146 | # plt.show() 147 | # # print(img.eval()) 148 | # i += 1 149 | # except tf.errors.OutOfRangeError: 150 | # print("done!") 151 | # finally: 152 | # coord.request_stop() 153 | # coord.join(threads) 154 | # 155 | 156 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2018/3/13 0013 15:32 4 | # @Author : jsz 5 | # @Software: PyCharm 6 | 7 | import tensorflow as tf 8 | 9 | def inference(images, batch_size, n_classes): 10 | with tf.variable_scope('conv1') as scope: 11 | weights = tf.get_variable('weights', 12 | # 3x3 kernels, 1 input channel, 16 filters 13 | shape=[3, 3, 1, 16], 14 | dtype=tf.float32, 15 | initializer=tf.truncated_normal_initializer(stddev=0.1, dtype=tf.float32)) 16 | biases = tf.get_variable('biases', 17 | shape=[16], 18 | dtype=tf.float32, 19 | initializer=tf.constant_initializer(0.1)) 20 | conv = tf.nn.conv2d(images, weights, strides=[1, 1, 1, 1], padding='SAME') 21 | pre_activation = tf.nn.bias_add(conv, biases) 22 | conv1 = tf.nn.relu(pre_activation, name=scope.name) 23 | 24 | with tf.variable_scope('pooling1_lrn') as scope: 25 | pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME', name='pooling1') 26 | norm1 = tf.nn.lrn(pool1, depth_radius=4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm1') 27 | # note: norm1 is not used below; local3 flattens pool1 28 | 29 | # with tf.variable_scope('conv2') as scope: 30 | # weights = tf.get_variable('weights', 31 | # shape=[3, 3, 16, 16], 32 | # dtype=tf.float32, 33 | # initializer=tf.truncated_normal_initializer(stddev=0.1, dtype=tf.float32)) 34 | # biases = tf.get_variable('biases', 35 | # shape=[16], 36 | # dtype=tf.float32, 37 | # initializer=tf.constant_initializer(0.1)) 38 | # conv = tf.nn.conv2d(norm1, weights, strides=[1, 1, 1, 1], padding='SAME') 39 | # pre_activation = tf.nn.bias_add(conv, biases) 40
| # conv2 = tf.nn.relu(pre_activation, name='conv2') 41 | # 42 | # # pool2 and norm2 43 | # with tf.variable_scope('pooling2_lrn') as scope: 44 | # norm2 = tf.nn.lrn(conv2, depth_radius=4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm2') 45 | # pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], strides=[1, 1, 1, 1], padding='SAME', name='pooling2') 46 | 47 | with tf.variable_scope('local3') as scope: 48 | reshape = tf.reshape(pool1, shape=[batch_size, -1]) 49 | dim = reshape.get_shape()[1].value 50 | weights = tf.get_variable('weights', 51 | shape=[dim, 128], 52 | dtype=tf.float32, 53 | initializer=tf.truncated_normal_initializer(stddev=0.005, dtype=tf.float32)) 54 | biases = tf.get_variable('biases', 55 | shape=[128], 56 | dtype=tf.float32, 57 | initializer=tf.constant_initializer(0.1)) 58 | local3 = tf.nn.relu(tf.matmul(reshape, weights) + biases, name=scope.name) 59 | 60 | # local4 61 | with tf.variable_scope('local4') as scope: 62 | weights = tf.get_variable('weights', 63 | shape=[128, 128], 64 | dtype=tf.float32, 65 | initializer=tf.truncated_normal_initializer(stddev=0.005, dtype=tf.float32)) 66 | biases = tf.get_variable('biases', 67 | shape=[128], 68 | dtype=tf.float32, 69 | initializer=tf.constant_initializer(0.1)) 70 | local4 = tf.nn.relu(tf.matmul(local3, weights) + biases, name='local4') 71 | 72 | # softmax 73 | with tf.variable_scope('softmax_linear') as scope: 74 | weights = tf.get_variable('softmax_linear', 75 | shape=[128, n_classes], 76 | dtype=tf.float32, 77 | initializer=tf.truncated_normal_initializer(stddev=0.005, dtype=tf.float32)) 78 | biases = tf.get_variable('biases', 79 | shape=[n_classes], 80 | dtype=tf.float32, 81 | initializer=tf.constant_initializer(0.1)) 82 | softmax_linear = tf.add(tf.matmul(local4, weights), biases, name='softmax_linear') 83 | 84 | return softmax_linear 85 | 86 | def losses(logits, labels): 87 | with tf.variable_scope('loss') as scope: 88 | cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits \ 89 | 
(logits=logits, labels=labels, name='xentropy_per_example') 90 | loss = tf.reduce_mean(cross_entropy, name='loss') 91 | tf.summary.scalar(scope.name + '/loss', loss) 92 | return loss 93 | 94 | 95 | def trainning(loss, learning_rate): 96 | with tf.name_scope('optimizer'): 97 | optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) 98 | global_step = tf.Variable(0, name='global_step', trainable=False) 99 | train_op = optimizer.minimize(loss, global_step=global_step) 100 | return train_op 101 | 102 | 103 | def evaluation(logits, labels): 104 | with tf.variable_scope('accuracy') as scope: 105 | correct = tf.nn.in_top_k(logits, labels, 1) 106 | correct = tf.cast(correct, tf.float16) 107 | accuracy = tf.reduce_mean(correct) 108 | tf.summary.scalar(scope.name + '/accuracy', accuracy) 109 | return accuracy -------------------------------------------------------------------------------- /new version/convert2tfrecord.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2018/5/4 0012 10:45 4 | # @Author : jsz 5 | # @Software: PyCharm 6 | 7 | import tensorflow as tf 8 | import numpy as np 9 | import os 10 | import matplotlib.pyplot as plt 11 | import skimage.io as io 12 | from scipy.misc import imread, imresize 13 | 14 | 15 | def get_file(file_dir): 16 | cover = [] 17 | label_cover = [] 18 | stego = [] 19 | label_stego = [] 20 | # 打标签 21 | for file in os.listdir(file_dir): 22 | # if file.endswith('0') or file.startswith('.'): 23 | # continue # Skip! 
        name = file.split('_')
        if name[0] == 'C':
            cover.append(file_dir + file)
            label_cover.append(0)
        if name[0] == 'S':
            stego.append(file_dir + file)
            label_stego.append(1)
    print("There are %d cover images\nThere are %d stego images"
          % (len(cover), len(stego)))
    # shuffle the file order
    image_list = np.hstack((cover, stego))
    label_list = np.hstack((label_cover, label_stego))
    temp = np.array([image_list, label_list])
    temp = temp.transpose()
    np.random.shuffle(temp)

    image_list = list(temp[:, 0])
    label_list = list(temp[:, 1])
    # np.array() above turned the labels into strings; convert them back to int
    label_list = [int(i) for i in label_list]

    return image_list, label_list


# %%

def int64_feature(value):
    """Wrapper for inserting int64 features into Example proto."""
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


# %%

def convert_to_tfrecord(images, labels, save_dir, name):
    '''Convert all images and labels to one TFRecord file.
    Args:
        images: list of image file paths, string type
        labels: list of labels, int type
        save_dir: the directory to save the TFRecord file, e.g.: '/home/folder1/'
        name: the name of the TFRecord file, string type, e.g.: 'train'
    Return:
        no return
    Note:
        converting takes some time, be patient...
    '''

    filename = os.path.join(save_dir, name + '.tfrecords')
    n_samples = len(labels)

    if np.shape(images)[0] != n_samples:
        raise ValueError('Images size %d does not match label size %d.' % (len(images), n_samples))

    # wait some time here; the conversion time depends on the size of your data.
    writer = tf.python_io.TFRecordWriter(filename)
    print('\nTransform start......')
    for i in np.arange(0, n_samples):
        try:
            image = io.imread(images[i])  # type(image) must be array!
            image = imresize(image, (256, 256))
            image_raw = image.tobytes()
            label = int(labels[i])
            example = tf.train.Example(features=tf.train.Features(feature={
                'label': int64_feature(label),
                'image_raw': bytes_feature(image_raw)}))
            writer.write(example.SerializeToString())
        except IOError as e:
            print('Could not read:', images[i])
            print('error: %s' % e)
            print('Skip it!\n')
    writer.close()
    print('Transform done!')


# %%

def read_and_decode(tfrecords_file, batch_size):
    '''Read and decode a TFRecord file, generate (image, label) batches.
    Args:
        tfrecords_file: the path of the TFRecord file
        batch_size: number of images in each batch
    Returns:
        image: 3D tensor - [batch_size, height, width]
        label: 1D tensor - [batch_size]
    '''
    # make an input queue from the tfrecord file
    filename_queue = tf.train.string_input_producer([tfrecords_file])

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    img_features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string),
        })
    image = tf.decode_raw(img_features['image_raw'], tf.uint8)

    ##########################################################
    # you can put data augmentation here, I didn't use it
    ##########################################################
    # convert_to_tfrecord() resized every image to 256*256; change the size here if you use another dataset.
    image = tf.reshape(image, [256, 256])
    label = tf.cast(img_features['label'], tf.int32)
    image_batch, label_batch = tf.train.shuffle_batch([image, label],
                                                      batch_size=batch_size,
                                                      num_threads=64,
                                                      capacity=2000,
                                                      min_after_dequeue=20)
    return image_batch, tf.reshape(label_batch, [batch_size])


# %% Convert data to TFRecord

# test_dir = 'F://CAE_CNN//data//catdogtest//'
test_dir = 'G:\\dataS-UNIWARD0.4\\val\\'
# save_dir = 'F://CAE_CNN//data//'
save_dir = 'G:\\dataS-UNIWARD0.4\\'
BATCH_SIZE = 25

# Convert test data: you just need to run it ONCE !
name_test = 'S_UNIWARD0.4val'

# images, labels = get_file(test_dir)
# convert_to_tfrecord(images, labels, save_dir, name_test)

# %% TO test train.tfrecord file

def plot_images(images, labels):
    '''Plot one batch.
    '''
    for i in np.arange(0, BATCH_SIZE):
        plt.subplot(5, 5, i + 1)
        plt.axis('off')
        # label 0 = cover, label 1 = stego
        plt.title('stego' if labels[i] == 1 else 'cover', fontsize=14)
        plt.subplots_adjust(top=1.5)
        plt.imshow(images[i], cmap='gray')
    plt.show()


# tfrecords_file = 'C://Users//Windows7//Documents//Python Scripts//notMNIST//test.tfrecords'
tfrecords_file = 'G:\\dataS-UNIWARD0.4\\S_UNIWARD0.4val.tfrecords'

image_batch, label_batch = read_and_decode(tfrecords_file, batch_size=BATCH_SIZE)

with tf.Session() as sess:
    i = 0
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    try:
        while not coord.should_stop() and i < 1:
            # just plot one batch
            image, label = sess.run([image_batch, label_batch])
            plot_images(image, label)
            i += 1

    except tf.errors.OutOfRangeError:
        print('done!')
    finally:
        coord.request_stop()
    coord.join(threads)


# %%
--------------------------------------------------------------------------------
/new version/input_data1.py:
--------------------------------------------------------------------------------
#! /usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2018/4/27 0027 10:39
# @Author : jsz
# @Software: PyCharm

import tensorflow as tf
import numpy as np
import os
import math

def read_and_decode(tfrecords_file, batch_size):
    '''Read and decode a TFRecord file, generate (image, label) batches.
    Args:
        tfrecords_file: the path of the TFRecord file
        batch_size: number of images in each batch
    Returns:
        image: 4D float32 tensor - [batch_size, height, width, channel]
        label: 1D tensor - [batch_size]
    '''
    # make an input queue from the tfrecord file
    filename_queue = tf.train.string_input_producer([tfrecords_file])

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    img_features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string),
        })
    image = tf.decode_raw(img_features['image_raw'], tf.uint8)

    ##########################################################
    # you can put data augmentation here, I didn't use it
    ##########################################################
    # convert2tfrecord.py stores 256*256 images; change the size here if you use another dataset.
    image = tf.reshape(image, [256, 256, 1])
    label = tf.cast(img_features['label'], tf.int32)
    image_batch, label_batch = tf.train.batch([image, label],
                                              batch_size=batch_size,
                                              num_threads=64,
                                              capacity=2000)

    image_batch = tf.cast(image_batch, tf.float32)

    return image_batch, tf.reshape(label_batch, [batch_size])

--------------------------------------------------------------------------------
/new version/model1.py:
--------------------------------------------------------------------------------
#! /usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2018/5/4 0004 10:19
# @Author : jsz
# @Software: PyCharm

import tensorflow as tf

def inference(images, batch_size, n_classes):
    with tf.variable_scope('conv1') as scope:
        weights = tf.get_variable('weights',
                                  # [kernel height, kernel width, input channels, number of kernels]
                                  shape=[3, 3, 1, 32],
                                  dtype=tf.float32,
                                  initializer=tf.truncated_normal_initializer(stddev=0.1, dtype=tf.float32))
        biases = tf.get_variable('biases',
                                 shape=[32],
                                 dtype=tf.float32,
                                 initializer=tf.constant_initializer(0.1))
        conv = tf.nn.conv2d(images, weights, strides=[1, 1, 1, 1], padding='SAME')
        pre_activation = tf.nn.bias_add(conv, biases)
        conv1 = tf.nn.relu(pre_activation, name=scope.name)

    # with tf.variable_scope('pooling1_lrn') as scope:
    #     pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME', name='pooling1')
    #     norm1 = tf.nn.lrn(pool1, depth_radius=4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm1')

    with tf.variable_scope('conv2') as scope:
        weights = tf.get_variable('weights',
                                  shape=[3, 3, 32, 16],
                                  dtype=tf.float32,
                                  initializer=tf.truncated_normal_initializer(stddev=0.1, dtype=tf.float32))
        biases = tf.get_variable('biases',
                                 shape=[16],
                                 dtype=tf.float32,
                                 initializer=tf.constant_initializer(0.1))
        conv = tf.nn.conv2d(conv1, weights, strides=[1, 1, 1, 1], padding='SAME')
        pre_activation = tf.nn.bias_add(conv, biases)
        conv2 = tf.nn.relu(pre_activation, name='conv2')

    # pool2 and norm2
    # with tf.variable_scope('pooling2_lrn') as scope:
    #     norm2 = tf.nn.lrn(conv2, depth_radius=4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm2')
    #     pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], strides=[1, 1, 1, 1], padding='SAME', name='pooling2')

    with tf.variable_scope('local3') as scope:
        # No pooling before this layer: for 256x256 inputs, dim = 256*256*16 = 1,048,576,
        # so this weight matrix alone holds 1,048,576 * 256 (about 268M) parameters.
        reshape = tf.reshape(conv2, shape=[batch_size, -1])
        dim = reshape.get_shape()[1].value
        weights = tf.get_variable('weights',
                                  shape=[dim, 256],
                                  dtype=tf.float32,
                                  initializer=tf.truncated_normal_initializer(stddev=0.005, dtype=tf.float32))
        biases = tf.get_variable('biases',
                                 shape=[256],
                                 dtype=tf.float32,
                                 initializer=tf.constant_initializer(0.1))
        local3 = tf.nn.relu(tf.matmul(reshape, weights) + biases, name=scope.name)

    # local4
    with tf.variable_scope('local4') as scope:
        weights = tf.get_variable('weights',
                                  shape=[256, 256],
                                  dtype=tf.float32,
                                  initializer=tf.truncated_normal_initializer(stddev=0.005, dtype=tf.float32))
        biases = tf.get_variable('biases',
                                 shape=[256],
                                 dtype=tf.float32,
                                 initializer=tf.constant_initializer(0.1))
        local4 = tf.nn.relu(tf.matmul(local3, weights) + biases, name='local4')

    # softmax_linear: unnormalized logits; the softmax itself is applied inside the loss
    with tf.variable_scope('softmax_linear') as scope:
        weights = tf.get_variable('softmax_linear',
                                  shape=[256, n_classes],
                                  dtype=tf.float32,
                                  initializer=tf.truncated_normal_initializer(stddev=0.005, dtype=tf.float32))
        biases = tf.get_variable('biases',
                                 shape=[n_classes],
                                 dtype=tf.float32,
                                 initializer=tf.constant_initializer(0.1))
        softmax_linear = tf.add(tf.matmul(local4, weights), biases, name='softmax_linear')

    return softmax_linear


# logits is the return value of inference(); labels is the ground truth
def losses(logits, labels):
    with tf.variable_scope('loss') as scope:
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels, name='xentropy_per_example')
        loss = tf.reduce_mean(cross_entropy, name='loss')
        tf.summary.scalar(scope.name + '/loss', loss)
    return loss


def trainning(loss, learning_rate):
    with tf.name_scope('optimizer'):
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        global_step = tf.Variable(0, name='global_step', trainable=False)
        train_op = optimizer.minimize(loss, global_step=global_step)
    return train_op


def evaluation(logits, labels):
    with tf.variable_scope('accuracy') as scope:
        correct = tf.nn.in_top_k(logits, labels, 1)
        # float32 avoids float16 rounding when averaging over larger batches
        correct = tf.cast(correct, tf.float32)
        accuracy = tf.reduce_mean(correct)
        tf.summary.scalar(scope.name + '/accuracy', accuracy)
    return accuracy
--------------------------------------------------------------------------------
/new version/readme.md:
--------------------------------------------------------------------------------
# new version
A new version of the network training process:
training with TFRecord input, plus a validation pass.

# merge all
The `tf.merge_all_summaries()` function (renamed `tf.summary.merge_all()` in TensorFlow 0.12) is convenient, but also somewhat dangerous: it merges every summary in the default graph, including summaries added by earlier, apparently unrelated code that also put summary nodes into that graph. If any of those summary nodes depends on a placeholder that is not fed when the merged op runs, TensorFlow raises an error saying that the placeholder must be fed. In `train2.py`, for example, the loss and accuracy summaries depend on the placeholders `x` and `y_`, so the merged summary op has to be run with the same `feed_dict` as the training step.
--------------------------------------------------------------------------------
/new version/train2.py:
--------------------------------------------------------------------------------
#! /usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2018/5/4 0004 10:19
# @Author : jsz
# @Software: PyCharm

import os
import numpy as np
import tensorflow as tf
import input_data1
import model1

N_CLASSES = 2  # cover vs. stego
IMG_W = 256  # resize
IMG_H = 256
BATCH_SIZE = 32
CAPACITY = 300
MAX_STEP = 15000  # typically more than 10K
learning_rate = 0.0001  # typically below 0.0001


def run_training():

    logs_train_dir = 'G:\\dataS-UNIWARD0.4\\logs\\train'
    logs_val_dir = 'G:\\dataS-UNIWARD0.4\\logs\\val'

    tfrecords_traindir = 'G:\\dataS-UNIWARD0.4\\S_UNIWARD0.4train.tfrecords'
    tfrecords_valdir = 'G:\\dataS-UNIWARD0.4\\S_UNIWARD0.4val.tfrecords'

    # get batches from the TFRecord files
    train_batch, train_label_batch = input_data1.read_and_decode(tfrecords_traindir, BATCH_SIZE)
    val_batch, val_label_batch = input_data1.read_and_decode(tfrecords_valdir, BATCH_SIZE)

    # training and validation batches are both fed through these placeholders
    x = tf.placeholder(tf.float32, shape=[BATCH_SIZE, 256, 256, 1])
    y_ = tf.placeholder(tf.int32, shape=[BATCH_SIZE])

    logits = model1.inference(x, BATCH_SIZE, N_CLASSES)
    loss = model1.losses(logits, y_)
    acc = model1.evaluation(logits, y_)
    train_op = model1.trainning(loss, learning_rate)

    sess = tf.Session()
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    summary_op = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(logs_train_dir, sess.graph)
    val_writer = tf.summary.FileWriter(logs_val_dir, sess.graph)

    try:
        for step in np.arange(MAX_STEP):
            if coord.should_stop():
                break

            tra_images, tra_labels = sess.run([train_batch, train_label_batch])
            _, tra_loss, tra_acc = sess.run([train_op, loss, acc],
                                            feed_dict={x: tra_images, y_: tra_labels})
            if step % 2 == 0:
                print('Step %d, train loss = %.2f, train accuracy = %.2f%%' % (step, tra_loss, tra_acc * 100.0))
                # the summaries depend on x and y_, so they must be fed as well
                summary_str = sess.run(summary_op, feed_dict={x: tra_images, y_: tra_labels})
                train_writer.add_summary(summary_str, step)

            if step % 200 == 0 or (step + 1) == MAX_STEP:
                val_images, val_labels = sess.run([val_batch, val_label_batch])
                val_loss, val_acc = sess.run([loss, acc],
                                             feed_dict={x: val_images, y_: val_labels})
                print('** Step %d, val loss = %.2f, val accuracy = %.2f%% **' % (step, val_loss, val_acc * 100.0))
                summary_str = sess.run(summary_op, feed_dict={x: val_images, y_: val_labels})
                val_writer.add_summary(summary_str, step)

            if step % 2000 == 0 or (step + 1) == MAX_STEP:
                checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')
    finally:
        coord.request_stop()
        coord.join(threads)


run_training()

--------------------------------------------------------------------------------
/onehot.py:
--------------------------------------------------------------------------------
#! /usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2018/4/3 0003 10:23
# @Author :
# @Software: PyCharm

# One-hot encoding: a feature with m possible values is turned into m mutually
# exclusive binary features, which makes the data sparse.
#

from sklearn.preprocessing import OneHotEncoder

enc = OneHotEncoder()

# Fit on the samples below to obtain the two attributes printed further down
# (open question: how many samples need to be fitted in practice?)
# enc.n_values_: number of possible values of each feature
# enc.feature_indices_: cumulative sum of n_values_, i.e. the column where each feature starts
# (both attributes were deprecated in scikit-learn 0.20 and removed in later releases;
#  newer versions expose enc.categories_ instead)
enc.fit([[0, 0, 9], [1, 1, 3], [1, 0, 8],
         [0, 0, 8], [0, 0, 4], [0, 0, 6],
         [0, 0, 5], [0, 0, 7],
         [0, 2, 1], [1, 0, 2]])


print("enc.n_values_ is:", enc.n_values_)
print("enc.feature_indices_ is:", enc.feature_indices_)

print(enc.transform([[0, 1, 7]]).toarray())
--------------------------------------------------------------------------------
/rename.py:
--------------------------------------------------------------------------------
#! /usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2018/3/12 0012 16:17
# @Author : jsz
# @Software: PyCharm
import os


def rename():
    count = 1
    path = r'F:\CAE_CNN\data\lldata'
    filelist = os.listdir(path)
    for files in filelist:
        Olddir = os.path.join(path, files)
        if os.path.isdir(Olddir):
            continue
        filename = os.path.splitext(files)[0]
        filetype = os.path.splitext(files)[1]

        # Enable exactly one of the naming schemes below: Newdir must be
        # defined before os.rename() is called.

        # rename directly
        # Newdir = os.path.join(path, 'S' + filetype)

        # prepend 'Stego.' to the file name
        # Newdir = os.path.join(path, ('Stego.' + filename) + filetype)

        # number the files sequentially (currently enabled)
        Newdir = os.path.join(path, str(count) + filetype)

        # batch-take the part of the name before / after the separator (---)
        # if filename.find('---') >= 0:  # if the file name contains ---
        #
        #     Newdir = os.path.join(path, filename.split('---')[0] + filetype)
        #
        #     # takes the characters before ---; use filename.split('---')[1] for the part after
        #
        #     if not os.path.isfile(Newdir):

        os.rename(Olddir, Newdir)

        count += 1

rename()

--------------------------------------------------------------------------------
/tfrecord.py:
--------------------------------------------------------------------------------
#! /usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2018/4/12 0012 11:02
# @Author :
# @Software: PyCharm

import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt
import skimage.io as io

def get_file(file_dir):
    cover = []
    label_cover = []
    stego = []
    label_stego = []
    # assign labels
    for file in os.listdir(file_dir):
        # if file.endswith('0') or file.startswith('.'):
        #     continue  # Skip!
        name = file.split('.')
        if name[0] == 'Cover':
            cover.append(file_dir + file)
            label_cover.append(0)
        if name[0] == 'Stego':
            stego.append(file_dir + file)
            label_stego.append(1)
    print("There are %d cover images\nThere are %d stego images"
          % (len(cover), len(stego)))
    # shuffle the file order
    image_list = np.hstack((cover, stego))
    label_list = np.hstack((label_cover, label_stego))
    temp = np.array([image_list, label_list])
    temp = temp.transpose()
    np.random.shuffle(temp)

    image_list = list(temp[:, 0])
    label_list = list(temp[:, 1])
    # np.array() above turned the labels into strings; convert them back to int
    label_list = [int(i) for i in label_list]

    return image_list, label_list

# %%

def int64_feature(value):
    """Wrapper for inserting int64 features into Example proto."""
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


# %%

def convert_to_tfrecord(images, labels, save_dir, name):
    '''Convert all images and labels to one TFRecord file.
    Args:
        images: list of image file paths, string type
        labels: list of labels, int type
        save_dir: the directory to save the TFRecord file, e.g.: '/home/folder1/'
        name: the name of the TFRecord file, string type, e.g.: 'train'
    Return:
        no return
    Note:
        converting takes some time, be patient...
    '''

    filename = os.path.join(save_dir, name + '.tfrecords')
    n_samples = len(labels)

    if np.shape(images)[0] != n_samples:
        raise ValueError('Images size %d does not match label size %d.' % (len(images), n_samples))

    # wait some time here; the conversion time depends on the size of your data.
    writer = tf.python_io.TFRecordWriter(filename)
    print('\nTransform start......')
    for i in np.arange(0, n_samples):
        try:
            image = io.imread(images[i])  # type(image) must be array!
            image_raw = image.tobytes()
            label = int(labels[i])
            example = tf.train.Example(features=tf.train.Features(feature={
                'label': int64_feature(label),
                'image_raw': bytes_feature(image_raw)}))
            writer.write(example.SerializeToString())
        except IOError as e:
            print('Could not read:', images[i])
            print('error: %s' % e)
            print('Skip it!\n')
    writer.close()
    print('Transform done!')


# %%

def read_and_decode(tfrecords_file, batch_size):
    '''Read and decode a TFRecord file, generate (image, label) batches.
    Args:
        tfrecords_file: the path of the TFRecord file
        batch_size: number of images in each batch
    Returns:
        image: 3D tensor - [batch_size, height, width]
        label: 1D tensor - [batch_size]
    '''
    # make an input queue from the tfrecord file
    filename_queue = tf.train.string_input_producer([tfrecords_file])

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    img_features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string),
        })
    image = tf.decode_raw(img_features['image_raw'], tf.uint8)

    ##########################################################
    # you can put data augmentation here, I didn't use it
    ##########################################################
    # convert_to_tfrecord() here stores images at their original size (no resizing),
    # so this reshape assumes 512*512 inputs; change it if you use another dataset.

    image = tf.reshape(image, [512, 512])
    label = tf.cast(img_features['label'], tf.int32)
    image_batch, label_batch = tf.train.batch([image, label],
                                              batch_size=batch_size,
                                              num_threads=64,
                                              capacity=2000)
    return image_batch, tf.reshape(label_batch, [batch_size])


# %% Convert data to TFRecord

# test_dir = 'C://Users//Windows7//Documents//Python Scripts//notMNIST//notMNIST_small//'
test_dir = 'F://CAE_CNN//data//pgm_coverstego//'

# save_dir = 'C://Users//Windows7//Documents//Python Scripts//notMNIST//'
save_dir = 'F://CAE_CNN//data//'

BATCH_SIZE = 25

# Convert test data: you just need to run it ONCE !
name_test = 'test'
images, labels = get_file(test_dir)
convert_to_tfrecord(images, labels, save_dir, name_test)


# %% TO test train.tfrecord file

def plot_images(images, labels):
    '''Plot one batch.
    '''
    for i in np.arange(0, BATCH_SIZE):
        plt.subplot(5, 5, i + 1)
        plt.axis('off')
        # label 0 = cover, label 1 = stego
        plt.title('stego' if labels[i] == 1 else 'cover', fontsize=14)
        plt.subplots_adjust(top=1.5)
        plt.imshow(images[i], cmap='gray')
    plt.show()


# tfrecords_file = 'C://Users//Windows7//Documents//Python Scripts//notMNIST//test.tfrecords'
tfrecords_file = 'F://CAE_CNN//data//test.tfrecords'

image_batch, label_batch = read_and_decode(tfrecords_file, batch_size=BATCH_SIZE)

with tf.Session() as sess:
    i = 0
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    try:
        while not coord.should_stop() and i < 1:
            # just plot one batch
            image, label = sess.run([image_batch, label_batch])
            plot_images(image, label)
            i += 1

    except tf.errors.OutOfRangeError:
        print('done!')
    finally:
        coord.request_stop()
    coord.join(threads)


# %%

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
#! /usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2018/3/13 0013 15:34
# @Author : jsz
# @Software: PyCharm

import os
import numpy as np
import tensorflow as tf
import input_data
import model

N_CLASSES = 2  # cover vs. stego
IMG_W = 256  # resize
IMG_H = 256
BATCH_SIZE = 16
CAPACITY = 2000
MAX_STEP = 15000  # typically more than 10K
learning_rate = 0.001  # typically below 0.0001


def run_training():

    train_dir = 'F://CAE_CNN//data//train_imgs//'
    # log files written here can be viewed with TensorBoard
    logs_train_dir = 'F://CAE_CNN//log//train//'

    # read the data
    train, train_label = input_data.get_files(train_dir)
    # build the input batches
    train_batch, train_label_batch = input_data.get_batch(train,
                                                          train_label,
                                                          IMG_W,
                                                          IMG_H,
                                                          BATCH_SIZE,
                                                          CAPACITY)
    # feed the batches into the network
    train_logits = model.inference(train_batch, BATCH_SIZE, N_CLASSES)

    train_loss = model.losses(train_logits, train_label_batch)
    # training op
    train_op = model.trainning(train_loss, learning_rate)

    train__acc = model.evaluation(train_logits, train_label_batch)
    # merge all summaries together
    summary_op = tf.summary.merge_all()  # summary op for the logs

    # create a session
    sess = tf.Session()
    # create a writer for the log files
    train_writer = tf.summary.FileWriter(logs_train_dir, sess.graph)
    # create a saver to store the trained model
    saver = tf.train.Saver()
    # initialize all variables
    sess.run(tf.global_variables_initializer())

    # coordinate the input queue threads
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    for step in np.arange(MAX_STEP):
        _, tra_loss, tra_acc = sess.run([train_op, train_loss, train__acc])
        # every 2 steps, print the current loss and accuracy and write a summary to the log
        if step % 2 == 0:
            print('Step %d, train loss = %.2f, train accuracy = %.2f%%' % (step, tra_loss, tra_acc * 100.0))
            summary_str = sess.run(summary_op)
            train_writer.add_summary(summary_str, step)
        # save the trained model every 2000 steps
        if step % 2000 == 0 or (step + 1) == MAX_STEP:
            checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')
            saver.save(sess, checkpoint_path, global_step=step)

    # stop the input queue threads once training is finished
    coord.request_stop()
    coord.join(threads)
    sess.close()

    # try:
    #     # train for MAX_STEP steps, one batch per step
    #     for step in np.arange(MAX_STEP):
    #         # if coord.should_stop():
    #         #     break
    #         # run the ops below. Question: why is train_logits not run here?
    #         # (It is evaluated anyway: train_op, train_loss and train__acc all depend on it.)
    #         _, tra_loss, tra_acc = sess.run([train_op, train_loss, train__acc])
    #         # every 2 steps, print the current loss and accuracy and write a summary to the log
    #         if step % 2 == 0:
    #             print('Step %d, train loss = %.2f, train accuracy = %.2f%%' % (step, tra_loss, tra_acc * 100.0))
    #             summary_str = sess.run(summary_op)
    #             train_writer.add_summary(summary_str, step)
    #         # save the trained model every 2000 steps
    #         if step % 2000 == 0 or (step + 1) == MAX_STEP:
    #             checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')
    #             saver.save(sess, checkpoint_path, global_step=step)
    #
    # except tf.errors.OutOfRangeError:
    #     print('Done training -- epoch limit reached')
    # finally:
    #     coord.request_stop()
    #     sess.close()

run_training()

--------------------------------------------------------------------------------
/train1.py:
--------------------------------------------------------------------------------
#! /usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2018/4/12 0012 13:50
# @Author :
# @Software: PyCharm
# ! /usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2018/3/13 0013 15:34
# @Author : jsz
# @Software: PyCharm

import os
import numpy as np
import tensorflow as tf
import input_data
import model

N_CLASSES = 2  # cover vs. stego
IMG_W = 256  # resize
IMG_H = 256
BATCH_SIZE = 16
CAPACITY = 300
MAX_STEP = 15000  # typically more than 10K
learning_rate = 0.0001  # typically below 0.0001


def run_training():

    logs_train_dir = 'F://CAE_CNN//log//train1//'
    tfrecords_dir = 'F://CAE_CNN//data//test.tfrecords'

    # get batches from the TFRecord file
    train_batch, train_label_batch = input_data.read_and_decode(tfrecords_dir, BATCH_SIZE)

    train_logits = model.inference(train_batch, BATCH_SIZE, N_CLASSES)

    train_loss = model.losses(train_logits, train_label_batch)
    # training op
    train_op = model.trainning(train_loss, learning_rate)

    train__acc = model.evaluation(train_logits, train_label_batch)
    # merge all summaries together
    summary_op = tf.summary.merge_all()  # summary op for the logs

    # create a session
    sess = tf.Session()
    # create a writer for the log files
    train_writer = tf.summary.FileWriter(logs_train_dir, sess.graph)
    # create a saver to store the trained model
    saver = tf.train.Saver()
    # initialize all variables
    sess.run(tf.global_variables_initializer())

    # coordinate the input queue threads
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # for step in np.arange(MAX_STEP):
    #     _, tra_loss, tra_acc = sess.run([train_op, train_loss, train__acc])
    #     # every 2 steps, print the current loss and accuracy and write a summary to the log
    #     if step % 2 == 0:
    #         print('Step %d, train loss = %.2f, train accuracy = %.2f%%' % (step, tra_loss, tra_acc * 100.0))
    #         summary_str = sess.run(summary_op)
    #         train_writer.add_summary(summary_str, step)
    #     # save the trained model every 2000 steps
    #     if step % 2000 == 0 or (step + 1) == MAX_STEP:
    #         checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')
    #         saver.save(sess, checkpoint_path, global_step=step)
    try:
        # train for MAX_STEP steps, one batch per step
        for step in np.arange(MAX_STEP):
            if coord.should_stop():
                break
            # run the ops below. Question: why is train_logits not run here?
            # (It is evaluated anyway: train_op, train_loss and train__acc all depend on it.)
            _, tra_loss, tra_acc = sess.run([train_op, train_loss, train__acc])
            # every 2 steps, print the current loss and accuracy and write a summary to the log
            if step % 2 == 0:
                print('Step %d, train loss = %.2f, train accuracy = %.2f%%' % (step, tra_loss, tra_acc * 100.0))
                summary_str = sess.run(summary_op)
                train_writer.add_summary(summary_str, step)
            # save the trained model every 2000 steps
            if step % 2000 == 0 or (step + 1) == MAX_STEP:
                checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')
    finally:
        coord.request_stop()

    coord.join(threads)
    sess.close()


run_training()
--------------------------------------------------------------------------------
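The `local3` layers in `model.py` and `new version/model1.py` flatten the last feature map, so the size of their weight matrices follows directly from the input size and the strides. The sketch below reproduces that arithmetic in plain Python, assuming 256x256 single-channel inputs (the size set by `IMG_W`/`IMG_H` and by the placeholder in `train2.py`) and TensorFlow's `'SAME'` padding rule, under which each spatial dimension becomes `ceil(size / stride)`. The helper names are illustrative and not part of the repository.

```python
import math


def same_padding_out(size, stride):
    """Spatial output size of a conv/pool op with padding='SAME'."""
    return math.ceil(size / stride)


def flattened_dim(height, width, channels, strides):
    """Length of each row produced by tf.reshape(x, [batch_size, -1])."""
    for s in strides:
        height = same_padding_out(height, s)
        width = same_padding_out(width, s)
    return height * width * channels


# model.py: conv1 (stride 1, 16 maps) -> pool1 (stride 2) -> local3 (128 units)
dim_model = flattened_dim(256, 256, 16, strides=[1, 2])
# model1.py: conv1 (stride 1) -> conv2 (stride 1, 16 maps) -> local3 (256 units)
dim_model1 = flattened_dim(256, 256, 16, strides=[1, 1])

print(dim_model, dim_model * 128)    # 262144 33554432
print(dim_model1, dim_model1 * 256)  # 1048576 268435456
```

Without any pooling, the first dense layer of `model1.py` holds exactly eight times as many weights as the one in `model.py` (about 268M versus 34M floats, roughly 1 GB in float32 before Adam's extra slot variables), which is worth keeping in mind when sizing GPU memory.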
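`convert_to_tfrecord()` stores each image as the raw bytes of a `uint8` array, and `read_and_decode()` turns them back with `tf.decode_raw` followed by `tf.reshape`, so the reshape target must contain exactly as many elements as the stored image: 256x256 after the `imresize` call in `new version/convert2tfrecord.py`, but the original image size in `tfrecord.py`, which does not resize. The NumPy sketch below mimics that round trip without TensorFlow; `encode_raw` and `decode_raw` are hypothetical helpers, with `np.frombuffer` standing in for `tf.decode_raw`.

```python
import numpy as np


def encode_raw(image):
    """Bytes as written under 'image_raw' (uint8 buffer, row-major order)."""
    return np.ascontiguousarray(image, dtype=np.uint8).tobytes()


def decode_raw(raw, shape):
    """NumPy stand-in for tf.decode_raw(raw, tf.uint8) followed by a reshape."""
    flat = np.frombuffer(raw, dtype=np.uint8)
    expected = int(np.prod(shape))
    if flat.size != expected:
        raise ValueError('buffer holds %d bytes but shape %s needs %d'
                         % (flat.size, shape, expected))
    return flat.reshape(shape)


# A deterministic 256x256 grayscale test image
img = (np.arange(256 * 256) % 256).astype(np.uint8).reshape(256, 256)
raw = encode_raw(img)

restored = decode_raw(raw, (256, 256, 1))  # the shape used by input_data1.py
print(len(raw), restored.shape)  # 65536 (256, 256, 1)
```

Decoding the same buffer with `(512, 512)` fails the size check, which is the NumPy counterpart of the reshape error TensorFlow reports when the decode shape does not match the stored images.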
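`get_file()` shuffles by stacking the paths and labels into a single NumPy array; because that array has a string dtype, the labels come back as strings and must be converted with `int()` afterwards. A sketch of an alternative that draws one index permutation and applies it to both lists, so each path stays paired with its label and labels keep their integer type, is shown below. `shuffle_pairs` is a hypothetical helper, and the sketch assumes NumPy 1.17 or later for `np.random.default_rng`.

```python
import numpy as np


def shuffle_pairs(image_list, label_list, seed=None):
    """Shuffle paths and labels together with one shared permutation."""
    if len(image_list) != len(label_list):
        raise ValueError('got %d images but %d labels'
                         % (len(image_list), len(label_list)))
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(image_list))
    images = [image_list[i] for i in order]
    labels = [int(label_list[i]) for i in order]
    return images, labels


# File names follow the C_* / S_* convention checked in convert2tfrecord.py
cover = ['C_%d.pgm' % i for i in range(3)]
stego = ['S_%d.pgm' % i for i in range(3)]
images, labels = shuffle_pairs(cover + stego, [0] * 3 + [1] * 3, seed=0)
for name, label in zip(images, labels):
    print(name, label)
```

Passing a fixed `seed` makes the shuffle reproducible between runs, which helps when the same file list is converted to TFRecord more than once.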