├── .idea
│   ├── CAECNNcode.iml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── 1_Cover.pgm
├── 1_stego wow_0.4.pgm
├── README.md
├── data
│   ├── 1.jpg
│   ├── S-UNIWARD0.2.png
│   ├── WOW0.5random_CNN.png
│   ├── coverstego.jpg
│   ├── readme.md
│   └── subtraction.jpg
├── easy work
│   ├── data
│   │   ├── C_3.pgm
│   │   ├── S_1.pgm
│   │   └── readme.md
│   ├── main.py
│   ├── readme.md
│   └── yijianyunxing.py
├── fliter.py
├── input_data.py
├── model.py
├── new version
│   ├── convert2tfrecord.py
│   ├── input_data1.py
│   ├── model1.py
│   ├── readme.md
│   └── train2.py
├── onehot.py
├── rename.py
├── tfrecord.py
├── train.py
└── train1.py
/1_Cover.pgm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangszzzzz/CAECNNcode/582954fd57f390f153c8a677f2f73988eb90dc3b/1_Cover.pgm
--------------------------------------------------------------------------------
/1_stego wow_0.4.pgm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangszzzzz/CAECNNcode/582954fd57f390f153c8a677f2f73988eb90dc3b/1_stego wow_0.4.pgm
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Deep Learning for Steganalysis

***
## Steganography and Steganalysis

Steganography is the science of concealing secret messages in images by slightly modifying pixel values. Content-adaptive steganographic schemes, which embed the message in complex (textured) regions to evade detection, are currently among the most secure methods. Examples in the spatial domain include HUGO, WOW, and S-UNIWARD.
Steganalysis is the counterpart of steganography: the art of detecting hidden data in images. The task is usually formulated as a binary classification problem that distinguishes cover images from stego images.
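In `easy work/main.py` this binary formulation becomes label 0 = cover and label 1 = stego, trained with a sparse softmax cross-entropy loss (`losses()`) and scored with top-1 accuracy (`evaluation()`). The framework-free NumPy sketch below computes the same two quantities by hand to show what they measure; it is an illustration rather than the repository's code, the logits are made up, and `argmax` agrees with `tf.nn.in_top_k(..., 1)` except on exact ties.

```python
import numpy as np


def sparse_softmax_cross_entropy(logits, labels):
    """Mean cross-entropy for integer labels (0 = cover, 1 = stego)."""
    # Subtract the row-wise maximum before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class for every sample.
    return -log_probs[np.arange(len(labels)), labels].mean()


def top1_accuracy(logits, labels):
    """Fraction of samples whose highest-scoring class equals the label."""
    return (logits.argmax(axis=1) == labels).mean()


# Made-up logits for three images: columns are [cover score, stego score].
logits = np.array([[2.0, -1.0],
                   [0.5, 1.5],
                   [1.0, 0.0]])
labels = np.array([0, 1, 1])  # the third image is stego but scored as cover

print(sparse_softmax_cross_entropy(logits, labels))
print(top1_accuracy(logits, labels))
```

The misclassified third image contributes most of the loss, which is why the loss can keep falling while the accuracy stays flat.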

## LSB steganography: cover and stego

* 1: cover (left) and stego (right)

* 2: the difference between cover and stego (small payload)
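The cover/stego difference can be reproduced in spirit with the NumPy sketch below. It assumes plain LSB *replacement* at randomly chosen pixels (the README does not say which LSB variant produced the figures) and uses a random 256x256 array as a stand-in cover; `lsb_embed` is a hypothetical helper, not part of this repository. Because only the least significant bit changes, the stego minus cover difference contains nothing but -1, 0 and +1, a very weak signal for a classifier to pick up.

```python
import numpy as np


def lsb_embed(cover, payload, seed=0):
    """Hypothetical LSB-replacement embedder (illustration only).

    payload is the fraction of pixels that carry one message bit (bits per pixel).
    """
    rng = np.random.default_rng(seed)
    stego = cover.copy()
    flat = stego.reshape(-1)  # view onto the contiguous copy
    n_bits = int(payload * flat.size)
    positions = rng.choice(flat.size, size=n_bits, replace=False)
    bits = rng.integers(0, 2, size=n_bits, dtype=np.uint8)
    # Clear the least significant bit, then set it to the message bit.
    flat[positions] = (flat[positions] & 0xFE) | bits
    return stego


rng = np.random.default_rng(1)
cover = rng.integers(0, 256, size=(256, 256), dtype=np.uint8)
stego = lsb_embed(cover, payload=0.1)
# Subtract in a signed type: uint8 arithmetic would wrap -1 around to 255.
diff = stego.astype(np.int16) - cover.astype(np.int16)
print(np.unique(diff))         # only values from {-1, 0, 1} can occur
print(np.count_nonzero(diff))  # about half the embedded bits already matched the LSB
```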

## J-UNIWARD steganography: cover and stego (cat image)

* 3: the difference between cover and stego (payload = 0.3)

## Deep learning for steganalysis

***
Unlike traditional computer vision tasks, image steganalysis must detect an embedding operation that adds only an extremely low-amplitude noise to the cover. For this reason my network contains no max-pooling layer, since pooling could destroy the small signals (features) introduced by steganography.
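A toy NumPy illustration of the pooling argument follows. The image is random noise and the ±1 changes are simulated rather than produced by a real embedding algorithm, so this only shows the mechanism, not the behavior of the trained network. A 2x2 max-pooled value changes only when a modification alters the maximum of its block, so most embedding changes vanish from the pooled map, whereas a stride-1 `'SAME'` convolution (as used in `easy work/main.py`) keeps every pixel position.

```python
import numpy as np


def max_pool_2x2(img):
    """Non-overlapping 2x2 max pooling (height and width must be even)."""
    h, w = img.shape
    return img.reshape(h // 2, 2, w // 2, 2).max(axis=(1, 3))


rng = np.random.default_rng(0)
cover = rng.integers(1, 255, size=(64, 64)).astype(np.int16)
# Simulate sparse +/-1 embedding changes on roughly 5% of the pixels.
mask = rng.random(cover.shape) < 0.05
noise = rng.choice([-1, 1], size=cover.shape) * mask
stego = cover + noise

changed_pixels = np.count_nonzero(stego != cover)
changed_after_pool = np.count_nonzero(max_pool_2x2(stego) != max_pool_2x2(cover))
print(changed_pixels, changed_after_pool)
```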

## Some results

* 4: the training process; the network begins to converge at 50,000 steps (5 epochs).

* 5: training and validation accuracy for WOW0.5random_CNN. The validation loss shows that the model is not overfitting, and the model fits the data remarkably well.
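Steps and epochs are related by steps per epoch = training images / batch size. The arithmetic sketch below makes one assumption: the batch size is `BATCH_SIZE = 16` from `easy work/main.py`, which may not be the value used for the plotted run. Under that assumption, "50,000 steps = 5 epochs" implies 10,000 steps, or 160,000 images, per epoch, and `MAX_STEP = 250000` would correspond to 25 epochs. The function names are illustrative.

```python
def implied_images_per_epoch(total_steps, epochs, batch_size):
    """Images per epoch implied by 'total_steps steps == epochs epochs'."""
    steps_per_epoch = total_steps / epochs
    return steps_per_epoch * batch_size


def epochs_at_step(step, images_per_epoch, batch_size):
    """Number of passes over the training set after `step` optimizer steps."""
    return step * batch_size / images_per_epoch


# Assumption: batch size 16, as in easy work/main.py.
n = implied_images_per_epoch(50_000, 5, 16)
print(n)                               # 160000.0
print(epochs_at_step(250_000, n, 16))  # 25.0
```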
--------------------------------------------------------------------------------
/data/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangszzzzz/CAECNNcode/582954fd57f390f153c8a677f2f73988eb90dc3b/data/1.jpg
--------------------------------------------------------------------------------
/data/S-UNIWARD0.2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangszzzzz/CAECNNcode/582954fd57f390f153c8a677f2f73988eb90dc3b/data/S-UNIWARD0.2.png
--------------------------------------------------------------------------------
/data/WOW0.5random_CNN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangszzzzz/CAECNNcode/582954fd57f390f153c8a677f2f73988eb90dc3b/data/WOW0.5random_CNN.png
--------------------------------------------------------------------------------
/data/coverstego.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangszzzzz/CAECNNcode/582954fd57f390f153c8a677f2f73988eb90dc3b/data/coverstego.jpg
--------------------------------------------------------------------------------
/data/readme.md:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/data/subtraction.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangszzzzz/CAECNNcode/582954fd57f390f153c8a677f2f73988eb90dc3b/data/subtraction.jpg
--------------------------------------------------------------------------------
/easy work/data/C_3.pgm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangszzzzz/CAECNNcode/582954fd57f390f153c8a677f2f73988eb90dc3b/easy work/data/C_3.pgm
--------------------------------------------------------------------------------
/easy work/data/S_1.pgm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangszzzzz/CAECNNcode/582954fd57f390f153c8a677f2f73988eb90dc3b/easy work/data/S_1.pgm
--------------------------------------------------------------------------------
/easy work/data/readme.md:
--------------------------------------------------------------------------------
# Some PGM images
--------------------------------------------------------------------------------
/easy work/main.py:
--------------------------------------------------------------------------------
#! /usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2018/6/7 0007 14:23
# @Author : jsz
# @Software: PyCharm

import tensorflow as tf  # TensorFlow 1.x API (queues, tf.Session)
import numpy as np
import os
import matplotlib.pyplot as plt
import skimage.io as io
# scipy.misc.imread/imresize exist only in old SciPy releases (removed in 1.2/1.3)
from scipy.misc import imread, imresize
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

def rename():
    count = 1
    path = '.\\data\\'

    for dirpath, dirnames, filenames in os.walk(path):
        # for filename in filenames:
        #     print(os.path.join(dirpath, filename))

        for files in filenames:
            # if files.endswith('.pgm'):
            name = files.split('_')
            # Cover files ('C_*') keep their names; files already prefixed with 'S_'
            # are skipped so that running the script twice does not produce 'S_S_*'.
            if name[0] in ('C', 'S'):
                continue
            Olddir = os.path.join(dirpath, files)
            if os.path.isdir(Olddir):
                continue
            filename = os.path.splitext(files)[0]
            filetype = os.path.splitext(files)[1]

            # rename directly
            # Newdir = os.path.join(path, 'S' + filetype)

            # automatically prepend 'S_' to the file name
            Newdir = os.path.join(dirpath, ('S_' + filename) + filetype)

            # number the files sequentially
            # Newdir = os.path.join(path, str(count) + filetype)

            # batch: take the part of the name before / after a separator (___)
            # if filename.find('---') >= 0:  # if the file name contains '---'
            #
            #     Newdir = os.path.join(direc, filename.split('---')[0] + filetype)
            #
            #     # take the part before '---'; use filename.split('---')[1] for the part after
            #
            # if not os.path.isfile(Newdir):

            os.rename(Olddir, Newdir)

            count += 1

def get_file(file_dir):
    # file_dir must end with a path separator: paths are built as file_dir + file
    cover = []
    label_cover = []
    stego = []
    label_stego = []
    # assign labels from the file-name prefix: 'C_*' -> 0 (cover), 'S_*' -> 1 (stego)
    for file in os.listdir(file_dir):
        # if file.endswith('0') or file.startswith('.'):
        #     continue  # Skip!
        name = file.split('_')
        if name[0] == 'C':
            cover.append(file_dir + file)
            label_cover.append(0)
        if name[0] == 'S':
            stego.append(file_dir + file)
            label_stego.append(1)
    print("There are %d cover images\nThere are %d stego images"
          % (len(cover), len(stego)))
    # shuffle the file order (paths and labels stay paired)
    image_list = np.hstack((cover, stego))
    label_list = np.hstack((label_cover, label_stego))
    temp = np.array([image_list, label_list])
    temp = temp.transpose()
    np.random.shuffle(temp)

    image_list = list(temp[:, 0])
    label_list = list(temp[:, 1])
    label_list = [int(i) for i in label_list]

    return image_list, label_list

def int64_feature(value):
    """Wrapper for inserting int64 features into Example proto."""
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def convert_to_tfrecord(images, labels, save_dir, name):
    '''Convert all images and labels to one TFRecord file.
    Args:
        images: list of image paths, string type
        labels: list of labels, int type
        save_dir: the directory to save the TFRecord file, e.g.: '/home/folder1/'
        name: the name of the TFRecord file, string type, e.g.: 'train'
    Return:
        no return
    Note:
        converting takes some time, be patient...
    '''

    filename = os.path.join(save_dir, name + '.tfrecords')
    n_samples = len(labels)

    # images is a Python list, so compare lengths with len() (a list has no .shape)
    if len(images) != n_samples:
        raise ValueError('Images size %d does not match label size %d.' % (len(images), n_samples))

    # The conversion time grows with the size of the dataset.
    writer = tf.python_io.TFRecordWriter(filename)
    print('\nTransform start......')
    for i in np.arange(0, n_samples):
        try:
            # image = imread(image[i])

            image = io.imread(images[i])  # type(image) must be array!
            image = imresize(image, (256, 256))  # imresize returns a uint8 array
            image_raw = image.tobytes()
            label = int(labels[i])
            example = tf.train.Example(features=tf.train.Features(feature={
                'label': int64_feature(label),
                'image_raw': bytes_feature(image_raw)}))
            writer.write(example.SerializeToString())
        except IOError as e:
            print('Could not read:', images[i])
            print('error: %s' % e)
            print('Skip it!\n')
    writer.close()
    print('Transform done!')


def read_and_decode(tfrecords_file, batch_size):
    '''Read and decode a TFRecord file, generate (image, label) batches.
    Args:
        tfrecords_file: the path of the TFRecord file
        batch_size: number of images in each batch
    Returns:
        image: 4D tensor - [batch_size, height, width, channel]
        label: 1D tensor - [batch_size]
    '''
    # make an input queue from the TFRecord file
    filename_queue = tf.train.string_input_producer([tfrecords_file])

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    img_features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string),
        })
    image = tf.decode_raw(img_features['image_raw'], tf.uint8)

    ##########################################################
    # you can put data augmentation here, I didn't use it
    ##########################################################
    # convert_to_tfrecord() stores 256x256 grayscale images; change this shape
    # if you use a dataset with a different image size.

    image = tf.reshape(image, [256, 256, 1])
    # image = tf.reshape(image, [256, 256])  # for plot
    label = tf.cast(img_features['label'], tf.int32)
    image_batch, label_batch = tf.train.shuffle_batch([image, label],
                                                      batch_size=batch_size,
                                                      num_threads=64,
                                                      capacity=2000,
                                                      min_after_dequeue=20)
    return image_batch, tf.reshape(label_batch, [batch_size])


def plot_images(images, labels):
    '''Plot one batch (uses the module-level BATCH_SIZE; the 5x5 grid holds at most 25 images).
    '''
    for i in np.arange(0, BATCH_SIZE):
        plt.subplot(5, 5, i + 1)
        plt.axis('off')
        # plt.title(chr(ord('D') + labels[i] - 1), fontsize=14)

        if labels[i] == 1:
            plt.title('Stego', fontsize=14)
        else:
            plt.title('Cover', fontsize=14)

        plt.subplots_adjust(top=1.5)
        # drop the trailing channel axis: imshow cannot display (H, W, 1) arrays
        plt.imshow(images[i].squeeze(), cmap='gray')
    plt.show()

#####
# for test tfrecords
#####
# BATCH_SIZE = 25
# BATCH_SIZE1 = 25
# image_batch, label_batch = read_and_decode(tfrecords_file, batch_size=BATCH_SIZE)
# image_batch1, label_batch1 = read_and_decode(tfrecords_file1, batch_size=BATCH_SIZE1)
#
# with tf.Session() as sess:
#     i = 0
#     coord = tf.train.Coordinator()
#     threads = tf.train.start_queue_runners(coord=coord)
#
#     try:
#         while not coord.should_stop() and i < 1:
#             # just plot one batch size
#             image, label = sess.run([image_batch, label_batch])
#             plot_images(image, label)
#
#             image, label = sess.run([image_batch, label_batch])
#             plot_images(image, label)
#
#             image, label = sess.run([image_batch1, label_batch1])
#             plot_images(image, label)
#
#             image, label = sess.run([image_batch1, label_batch1])
#             plot_images(image, label)
#
#             i += 1
#
#     except tf.errors.OutOfRangeError:
#         print('done!')
#     finally:
#         coord.request_stop()
#         coord.join(threads)

# model

def inference(images, batch_size, n_classes):
    with tf.variable_scope('conv1') as scope:
        weights = tf.get_variable('weights',
                                  # kernel size, kernel size, channels, kernel number
                                  shape=[3, 3, 1, 32],
                                  dtype=tf.float32,
                                  initializer=tf.truncated_normal_initializer(stddev=0.1, dtype=tf.float32))
        biases = tf.get_variable('biases',
                                 shape=[32],
                                 dtype=tf.float32,
                                 initializer=tf.constant_initializer(0.1))
        conv = tf.nn.conv2d(images, weights, strides=[1, 1, 1, 1], padding='SAME')
        pre_activation = tf.nn.bias_add(conv, biases)
        conv1 = tf.nn.relu(pre_activation, name=scope.name)

    # with tf.variable_scope('pooling1_lrn') as scope:
    #     pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME', name='pooling1')
    #     norm1 = tf.nn.lrn(pool1, depth_radius=4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm1')

    with tf.variable_scope('conv2') as scope:
        weights = tf.get_variable('weights',
                                  shape=[3, 3, 32, 16],
                                  dtype=tf.float32,
                                  initializer=tf.truncated_normal_initializer(stddev=0.1, dtype=tf.float32))
        biases = tf.get_variable('biases',
                                 shape=[16],
                                 dtype=tf.float32,
                                 initializer=tf.constant_initializer(0.1))
        conv = tf.nn.conv2d(conv1, weights, strides=[1, 1, 1, 1], padding='SAME')
        pre_activation = tf.nn.bias_add(conv, biases)
        conv2 = tf.nn.relu(pre_activation, name='conv2')

    # pool2 and norm2
    # with tf.variable_scope('pooling2_lrn') as scope:
    #     norm2 = tf.nn.lrn(conv2, depth_radius=4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm2')
    #     pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], strides=[1, 1, 1, 1], padding='SAME', name='pooling2')

    with tf.variable_scope('local3') as scope:
        reshape = tf.reshape(conv2, shape=[batch_size, -1])
        dim = reshape.get_shape()[1].value
        weights = tf.get_variable('weights',
                                  shape=[dim, 256],
                                  dtype=tf.float32,
                                  initializer=tf.truncated_normal_initializer(stddev=0.005, dtype=tf.float32))
        biases = tf.get_variable('biases',
                                 shape=[256],
                                 dtype=tf.float32,
                                 initializer=tf.constant_initializer(0.1))
        local3 = tf.nn.relu(tf.matmul(reshape, weights) + biases, name=scope.name)

    # local4
    with tf.variable_scope('local4') as scope:
        weights = tf.get_variable('weights',
                                  shape=[256, 256],
                                  dtype=tf.float32,
                                  initializer=tf.truncated_normal_initializer(stddev=0.005, dtype=tf.float32))
        biases = tf.get_variable('biases',
                                 shape=[256],
                                 dtype=tf.float32,
                                 initializer=tf.constant_initializer(0.1))
        local4 = tf.nn.relu(tf.matmul(local3, weights) + biases, name='local4')

    # softmax
    with tf.variable_scope('softmax_linear') as scope:
        weights = tf.get_variable('softmax_linear',
                                  shape=[256, n_classes],
                                  dtype=tf.float32,
                                  initializer=tf.truncated_normal_initializer(stddev=0.005, dtype=tf.float32))
        biases = tf.get_variable('biases',
                                 shape=[n_classes],
                                 dtype=tf.float32,
                                 initializer=tf.constant_initializer(0.1))
        softmax_linear = tf.add(tf.matmul(local4, weights), biases, name='softmax_linear')

    return softmax_linear

# logits is the return value of inference(); labels is the ground truth
def losses(logits, labels):
    with tf.variable_scope('loss') as scope:
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=labels, name='xentropy_per_example')
        loss = tf.reduce_mean(cross_entropy, name='loss')
        tf.summary.scalar(scope.name + '/loss', loss)
    return loss


def trainning(loss, learning_rate):
    with tf.name_scope('optimizer'):
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        global_step = tf.Variable(0, name='global_step', trainable=False)
        train_op = optimizer.minimize(loss, global_step=global_step)
    return train_op


def evaluation(logits, labels):
    with tf.variable_scope('accuracy') as scope:
        correct = tf.nn.in_top_k(logits, labels, 1)
        correct = tf.cast(correct, tf.float16)
        accuracy = tf.reduce_mean(correct)
        tf.summary.scalar(scope.name + '/accuracy', accuracy)
    return accuracy


# training settings
N_CLASSES = 2  # cover and stego
IMG_W = 256  # resize
IMG_H = 256
BATCH_SIZE = 16
CAPACITY = 300
MAX_STEP = 250000  # usually more than 10K
learning_rate = 0.00001  # usually less than 0.0001


def run_training(tfrecords_file, tfrecords_file1):

    logs_train_dir = '.\\logs\\train'
    # logs_val_dir = 'H:\\dataWOW_0.05random\\logs\\val'
    # Saver.save() fails if the checkpoint directory does not exist yet.
    os.makedirs(logs_train_dir, exist_ok=True)

    tfrecords_traindir = tfrecords_file
    tfrecords_valdir = tfrecords_file1

    # get training and validation batches from the TFRecord files
    train_batch, train_label_batch = read_and_decode(tfrecords_traindir, BATCH_SIZE)
    val_batch, val_label_batch = read_and_decode(tfrecords_valdir, BATCH_SIZE)

    x = tf.placeholder(tf.float32, shape=[BATCH_SIZE, 256, 256, 1])
    y_ = tf.placeholder(tf.int32, shape=[BATCH_SIZE])

    logits = inference(x, BATCH_SIZE, N_CLASSES)
    loss = losses(logits, y_)
    acc = evaluation(logits, y_)
    train_op = trainning(loss, learning_rate)

    sess = tf.Session()

    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
        for step in np.arange(MAX_STEP):
            if coord.should_stop():
                break

            tra_images, tra_labels = sess.run([train_batch, train_label_batch])

            _, tra_loss, tra_acc = sess.run([train_op, loss, acc],
                                            feed_dict={x: tra_images, y_: tra_labels})
            if step % 2 == 0:
                print(tfrecords_traindir)
                print('Step %d, train loss = %.2f, train accuracy = %.2f%%' % (step, tra_loss, tra_acc * 100.0))

            if step % 4 == 0 or (step + 1) == MAX_STEP:
                val_images, val_labels = sess.run([val_batch, val_label_batch])
                val_loss, val_acc = sess.run([loss, acc],
                                             feed_dict={x: val_images, y_: val_labels})
                print(tfrecords_valdir)
                print('  ** Step %d, val loss = %.2f, val accuracy = %.2f%% **' % (step, val_loss, val_acc * 100.0))

            if step % 2000 == 0 or (step + 1) == MAX_STEP:
                checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=step)

    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')
    finally:
        coord.request_stop()
        coord.join(threads)
        sess.close()

--------------------------------------------------------------------------------
/easy work/readme.md:
--------------------------------------------------------------------------------
# Code for steganalysis experiments
Modify the relevant parameters in `yijianyunxing.py` ("one-click run") and run it.

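The pipeline relies on a file-naming convention: `main.rename()` prefixes non-cover files with `S_`, and `main.get_file()` then labels `C_*` files as 0 (cover) and `S_*` files as 1 (stego), ignoring everything else. The sketch below restates that labeling rule without TensorFlow so you can check a data folder before converting it to TFRecords. `list_labeled_files` is a hypothetical helper, not part of the repository; unlike `get_file()` it builds paths with `os.path.join`, so `file_dir` does not need a trailing separator.

```python
import os
import random
import tempfile


def list_labeled_files(file_dir, seed=None):
    """Return shuffled (path, label) pairs: 'C_*' -> 0 (cover), 'S_*' -> 1 (stego).

    Files whose name matches neither prefix are ignored, as in main.get_file().
    """
    pairs = []
    for name in sorted(os.listdir(file_dir)):
        prefix = name.split('_')[0]
        if prefix == 'C':
            pairs.append((os.path.join(file_dir, name), 0))
        elif prefix == 'S':
            pairs.append((os.path.join(file_dir, name), 1))
    random.Random(seed).shuffle(pairs)  # shuffle while keeping path and label together
    return pairs


# Usage on a throw-away directory with hypothetical file names.
with tempfile.TemporaryDirectory() as demo_dir:
    for fname in ('C_1.pgm', 'S_1.pgm', 'readme.md'):
        open(os.path.join(demo_dir, fname), 'w').close()
    pairs = list_labeled_files(demo_dir, seed=0)
    labels_by_name = {os.path.basename(path): label for path, label in pairs}

print(labels_by_name)
```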
--------------------------------------------------------------------------------
/easy work/yijianyunxing.py:
--------------------------------------------------------------------------------
#! /usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2018/6/7 0007 14:51
# @Author : jsz
# @Software: PyCharm

import main

# prefix the non-cover files under .\data\ with 'S_' (stego label)
main.rename()

name_test = 'testrandomtrain'
name_test1 = 'testrandomval'
tfrecords_file = '.\\testrandomtrain.tfrecords'
tfrecords_file1 = '.\\testrandomval.tfrecords'

test_dir = '.\\data\\train\\'
save_dir = '.\\'
test_dir1 = '.\\data\\val\\'
save_dir1 = '.\\'

# build the training and validation TFRecord files
images, labels = main.get_file(test_dir)
main.convert_to_tfrecord(images, labels, save_dir, name_test)
images1, labels1 = main.get_file(test_dir1)
main.convert_to_tfrecord(images1, labels1, save_dir1, name_test1)

# train on the first file, validate on the second
main.run_training(tfrecords_file, tfrecords_file1)
--------------------------------------------------------------------------------
/fliter.py:
--------------------------------------------------------------------------------
import tensorlayer as tl
import tensorflow as tf
import matplotlib.pyplot as plt
import cv2
from tensorlayer.layers import *
import numpy as np

sess = tf.InteractiveSession()
x = tf.placeholder(tf.float32, [None, 512, 512, 1])
# 5x5 KV high-pass kernel normalized by 12; every row sums to zero, so flat
# image content is suppressed and the high-frequency residual remains.
F0 = np.array([[-1, 2, -2, 2, -1],
               [2, -6, 8, -6, 2],
               [-2, 8, -12, 8, -2],
               [2, -6, 8, -6, 2],
               [-1, 2, -2, 2, -1]], dtype=np.float32)
F0 = F0 / 12.
high_pass_filter = tf.constant_initializer(value=F0, dtype=tf.float32)
net = InputLayer(x, name='inputlayer')
# one 5x5 convolution initialized with the kernel above, no activation
net = Conv2d(net, 1, (5, 5), (1, 1), act=tf.identity,
             padding='SAME', W_init=high_pass_filter, name='HighPass')
y = net.outputs
tl.layers.initialize_global_variables(sess)

# read the 512x512 cover as grayscale; the file name is case-sensitive on Linux/macOS
img = cv2.imread('1_Cover.pgm', 0)
if img is None:
    raise IOError('Could not read 1_Cover.pgm')
img = img.astype(np.float32).reshape([1, 512, 512, 1])

img_after = y.eval(feed_dict={x: img})


if __name__ == '__main__':
    # plt.imshow(img.reshape([512, 512]))
    # plt.imshow(img_after.reshape([512, 512]))
    #
    # mark the pixels whose filter residual exceeds 10
    pgm_info = np.where(img_after > 10, 1, 0)
    plt.imshow(pgm_info.reshape([512, 512]))
    plt.show()

--------------------------------------------------------------------------------
/input_data.py:
--------------------------------------------------------------------------------
#! /usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2018/3/12 0012 15:56
# @Author : jsz
# @Software: PyCharm

import tensorflow as tf
import numpy as np
import os

# Read function: returns two lists. image_list holds the image paths (strings);
# label_list holds the labels 0 (cover) and 1 (stego).

def get_files(file_dir):
    # file_dir must end with a path separator: paths are built as file_dir + file
    cover = []
    label_cover = []
    stego = []
    label_stego = []
    # assign labels from the file-name prefix: 'C0_*' -> 0, 'S1_*' -> 1
    for file in os.listdir(file_dir):
        # if file.endswith('0') or file.startswith('.'):
        #     continue  # Skip!
        name = file.split('_')
        if name[0] == 'C0':
            cover.append(file_dir + file)
            label_cover.append(0)
        if name[0] == 'S1':
            stego.append(file_dir + file)
            label_stego.append(1)
    print("There are %d cover images\nThere are %d stego images"
          % (len(cover), len(stego)))
    # shuffle the file order (paths and labels stay paired)
    image_list = np.hstack((cover, stego))
    label_list = np.hstack((label_cover, label_stego))
    temp = np.array([image_list, label_list])
    temp = temp.transpose()
    np.random.shuffle(temp)

    image_list = list(temp[:, 0])
    label_list = list(temp[:, 1])
    label_list = [int(i) for i in label_list]

    return image_list, label_list

# batch function
def get_batch(image, label,
              image_W, image_H,
              batch_size, capacity):

    # convert the Python lists into tensors TensorFlow can consume
    label = tf.cast(label, tf.int32)
    image = tf.cast(image, tf.string)

    input_queue = tf.train.slice_input_producer([image, label])

    label = input_queue[1]

    image_contents = tf.read_file(input_queue[0])

    print(input_queue[0])

    # the files are decoded as PNG; channels=0 keeps the channel count stored in the file
    image = tf.image.decode_png(image_contents, channels=0)

    # TensorFlow image tensors are [height, width, channels]
    image = tf.reshape(image, [image_H, image_W, 1])
    image = tf.image.per_image_standardization(image)

    image_batch, label_batch = tf.train.batch([image, label],
                                              batch_size=batch_size,
                                              num_threads=64,
                                              capacity=capacity,
                                              )

    label_batch = tf.reshape(label_batch, [batch_size])
    image_batch = tf.cast(image_batch, tf.float32)

    return image_batch, label_batch

def read_and_decode(tfrecords_file, batch_size):
    '''Read and decode a TFRecord file, generate (image, label) batches.
    Args:
        tfrecords_file: the path of the TFRecord file
        batch_size: number of images in each batch
    Returns:
        image: 4D tensor - [batch_size, height, width, channel]
        label: 1D tensor - [batch_size]
    '''
    # make an input queue from the TFRecord file
    filename_queue = tf.train.string_input_producer([tfrecords_file])

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    img_features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string),
        })
    image = tf.decode_raw(img_features['image_raw'], tf.uint8)

    ##########################################################
    # you can put data augmentation here, I didn't use it
    ##########################################################
    # The TFRecords read here hold 512x512 grayscale images; change this shape
    # if you use a dataset with a different image size.

    image = tf.reshape(image, [512, 512, 1])
    label = tf.cast(img_features['label'], tf.int32)
    image_batch, label_batch = tf.train.batch([image, label],
                                              batch_size=batch_size,
                                              num_threads=64,
                                              capacity=2000)

    image_batch = tf.cast(image_batch, tf.float32)

    return image_batch, tf.reshape(label_batch, [batch_size])

# file_dir = 'F://CAE_CNN//data//pgm_coverstego//'
# file_dir = 'F://CAE_CNN//data//train//'
# get_files(file_dir)
# file_dir = 'G://PGMtoPNG//train_imgs//'
#
# import matplotlib.pyplot as plt
#
# BATCH_SIZE = 2
# CAPACITY = 256
# IMG_W = 256
# IMG_H = 256
#
# image_list, label_list = get_files(file_dir)
# image_batch, label_batch = get_batch(image_list, label_list, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)
#
# with tf.Session() as sess:
#     i = 0
#     coord = tf.train.Coordinator()
#     threads = tf.train.start_queue_runners(coord=coord)
#     try:
#         while not coord.should_stop() and i < 2:
#             img, label = sess.run([image_batch, label_batch])
#
#             for j in np.arange(BATCH_SIZE):
#                 print("label: %d" % label[j])
#
#                 # drop the channel axis: imshow cannot display (H, W, 1) arrays
#                 plt.imshow(img[j].squeeze(), cmap='gray')
#                 # plt.imshow('F://CAE_CNN//data//pgm_cover//Cover.1.pgm')
#
#                 plt.show()
#             # print(img.eval())
#             i += 1
#     except tf.errors.OutOfRangeError:
#         print("done!")
#     finally:
#         coord.request_stop()
#         coord.join(threads)
#

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
#! /usr/bin/python3
# -*- coding: utf-8 -*-
# @Time : 2018/3/13 0013 15:32
# @Author : jsz
# @Software: PyCharm

import tensorflow as tf

def inference(images, batch_size, n_classes):
    with tf.variable_scope('conv1') as scope:
        weights = tf.get_variable('weights',
                                  # 1 input channel (grayscale)
                                  shape=[3, 3, 1, 16],
                                  dtype=tf.float32,
                                  initializer=tf.truncated_normal_initializer(stddev=0.1, dtype=tf.float32))
        biases = tf.get_variable('biases',
                                 shape=[16],
                                 dtype=tf.float32,
                                 initializer=tf.constant_initializer(0.1))
        conv = tf.nn.conv2d(images, weights, strides=[1, 1, 1, 1], padding='SAME')
        pre_activation = tf.nn.bias_add(conv, biases)
        conv1 = tf.nn.relu(pre_activation, name=scope.name)

    with tf.variable_scope('pooling1_lrn') as scope:
        pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME', name='pooling1')
        norm1 = tf.nn.lrn(pool1, depth_radius=4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm1')

    # with tf.variable_scope('conv2') as scope:
    #     weights = tf.get_variable('weights',
    #                               shape=[3, 3, 16, 16],
    #                               dtype=tf.float32,
    #                               initializer=tf.truncated_normal_initializer(stddev=0.1, dtype=tf.float32))
    #     biases = tf.get_variable('biases',
    #                              shape=[16],
    #                              dtype=tf.float32,
    #                              initializer=tf.constant_initializer(0.1))
    #     conv = tf.nn.conv2d(norm1, weights, strides=[1, 1, 1, 1], padding='SAME')
    #     pre_activation = tf.nn.bias_add(conv, biases)
    #     conv2 = tf.nn.relu(pre_activation, name='conv2')
    #
    # # pool2 and norm2
    # with tf.variable_scope('pooling2_lrn') as scope:
    #     norm2 = tf.nn.lrn(conv2, depth_radius=4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm2')
    #     pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], strides=[1, 1, 1, 1], padding='SAME', name='pooling2')

    with tf.variable_scope('local3') as scope:
        # note: local3 flattens pool1 directly; norm1 above is computed but not used
        reshape = tf.reshape(pool1, shape=[batch_size, -1])
        dim = reshape.get_shape()[1].value
        weights = tf.get_variable('weights',
                                  shape=[dim, 128],
                                  dtype=tf.float32,
                                  initializer=tf.truncated_normal_initializer(stddev=0.005, dtype=tf.float32))
        biases = tf.get_variable('biases',
                                 shape=[128],
                                 dtype=tf.float32,
                                 initializer=tf.constant_initializer(0.1))
        local3 = tf.nn.relu(tf.matmul(reshape, weights) + biases, name=scope.name)

    # local4
    with tf.variable_scope('local4') as scope:
        weights = tf.get_variable('weights',
                                  shape=[128, 128],
                                  dtype=tf.float32,
                                  initializer=tf.truncated_normal_initializer(stddev=0.005, dtype=tf.float32))
        biases = tf.get_variable('biases',
                                 shape=[128],
                                 dtype=tf.float32,
                                 initializer=tf.constant_initializer(0.1))
        local4 = tf.nn.relu(tf.matmul(local3, weights) + biases, name='local4')

    # softmax
    with tf.variable_scope('softmax_linear') as scope:
        weights = tf.get_variable('softmax_linear',
                                  shape=[128, n_classes],
                                  dtype=tf.float32,
                                  initializer=tf.truncated_normal_initializer(stddev=0.005, dtype=tf.float32))
        biases = tf.get_variable('biases',
                                 shape=[n_classes],
                                 dtype=tf.float32,
                                 initializer=tf.constant_initializer(0.1))
        softmax_linear = tf.add(tf.matmul(local4, weights), biases, name='softmax_linear')

    return softmax_linear
85 |
86 | def losses(logits, labels):
87 | with tf.variable_scope('loss') as scope:
88 | cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits \
89 | (logits=logits, labels=labels, name='xentropy_per_example')
90 | loss = tf.reduce_mean(cross_entropy, name='loss')
91 | tf.summary.scalar(scope.name + '/loss', loss)
92 | return loss
93 |
94 |
95 | def trainning(loss, learning_rate):
96 | with tf.name_scope('optimizer'):
97 | optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
98 | global_step = tf.Variable(0, name='global_step', trainable=False)
99 | train_op = optimizer.minimize(loss, global_step=global_step)
100 | return train_op
101 |
102 |
103 | def evaluation(logits, labels):
104 | with tf.variable_scope('accuracy') as scope:
105 | correct = tf.nn.in_top_k(logits, labels, 1)
106 | correct = tf.cast(correct, tf.float32)
107 | accuracy = tf.reduce_mean(correct)
108 | tf.summary.scalar(scope.name + '/accuracy', accuracy)
109 | return accuracy
--------------------------------------------------------------------------------
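For reference, `losses()` averages `tf.nn.sparse_softmax_cross_entropy_with_logits` over the batch, and `evaluation()` uses `tf.nn.in_top_k(logits, labels, 1)` to count top-1 hits. The pure-Python sketch below recomputes both quantities for a tiny batch so that values in the training log can be checked by hand. The function names are illustrative and are not part of the repository.

```python
import math


def sparse_softmax_cross_entropy(logits, labels):
    """Mean cross-entropy for integer class labels (what losses() reports).

    logits: one list of raw scores per example; labels: one class index per example.
    """
    total = 0.0
    for row, label in zip(logits, labels):
        # log-sum-exp with the row maximum subtracted, for numerical stability
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[label]
    return total / len(labels)


def top1_accuracy(logits, labels):
    """Fraction of examples whose true class has the largest score.

    Ties count as hits, mirroring how in_top_k treats ties at the k boundary.
    """
    hits = sum(1 for row, label in zip(logits, labels) if row[label] == max(row))
    return hits / len(labels)


batch_logits = [[2.0, 0.0], [1.0, 3.0]]  # two examples, two classes (cover, stego)
batch_labels = [0, 0]                    # both examples are cover images
loss = sparse_softmax_cross_entropy(batch_logits, batch_labels)
acc = top1_accuracy(batch_logits, batch_labels)
print(round(loss, 4), acc)  # 1.1269 0.5
```

With two classes, a loss that hovers around ln 2 ≈ 0.693 means the network is doing no better than chance at separating cover from stego.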
/new version/convert2tfrecord.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2018/5/4 0012 10:45
4 | # @Author : jsz
5 | # @Software: PyCharm
6 |
7 | import tensorflow as tf
8 | import numpy as np
9 | import os
10 | import matplotlib.pyplot as plt
11 | import skimage.io as io
12 | from scipy.misc import imread, imresize
13 |
14 |
15 | def get_file(file_dir):
16 | cover = []
17 | label_cover = []
18 | stego = []
19 | label_stego = []
20 | # Assign labels
21 | for file in os.listdir(file_dir):
22 | # if file.endswith('0') or file.startswith('.'):
23 | # continue # Skip!
24 | name = file.split('_')
25 | if name[0] == 'C':
26 | cover.append(file_dir + file)
27 | label_cover.append(0)
28 | if name[0] == 'S':
29 | stego.append(file_dir + file)
30 | label_stego.append(1)
31 | print("There are %d cover images\nThere are %d stego images"
32 | % (len(cover), len(stego)))
33 | # Shuffle the file order
34 | image_list = np.hstack((cover, stego))
35 | label_list = np.hstack((label_cover, label_stego))
36 | temp = np.array([image_list, label_list])
37 | temp = temp.transpose()
38 | np.random.shuffle(temp)
39 |
40 | image_list = list(temp[:, 0])
41 | label_list = list(temp[:, 1])
42 | label_list = [int(i) for i in label_list]
43 |
44 | return image_list, label_list
45 |
46 |
47 | # %%
48 |
49 | def int64_feature(value):
50 | """Wrapper for inserting int64 features into Example proto."""
51 | if not isinstance(value, list):
52 | value = [value]
53 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
54 |
55 |
56 | def bytes_feature(value):
57 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
58 |
59 |
60 | # %%
61 |
62 | def convert_to_tfrecord(images, labels, save_dir, name):
63 | '''convert all images and labels to one tfrecord file.
64 | Args:
65 | images: list of image directories, string type
66 | labels: list of labels, int type
67 | save_dir: the directory to save tfrecord file, e.g.: '/home/folder1/'
68 | name: the name of tfrecord file, string type, e.g.: 'train'
69 | Return:
70 | no return
71 | Note:
72 | converting needs some time, be patient...
73 | '''
74 |
75 | filename = os.path.join(save_dir, name + '.tfrecords')
76 | n_samples = len(labels)
77 |
78 | if np.shape(images)[0] != n_samples:
79 | raise ValueError('Images size %d does not match label size %d.' % (len(images), n_samples))
80 |
81 | # wait some time here: the conversion time depends on the size of your data.
82 | writer = tf.python_io.TFRecordWriter(filename)
83 | print('\nTransform start......')
84 | for i in np.arange(0, n_samples):
85 | try:
86 | # image = imread(images[i])
87 |
88 | image = io.imread(images[i]) # type(image) must be array!
89 | image = imresize(image, (256, 256))
90 | image_raw = image.tobytes()
91 | label = int(labels[i])
92 | example = tf.train.Example(features=tf.train.Features(feature={
93 | 'label': int64_feature(label),
94 | 'image_raw': bytes_feature(image_raw)}))
95 | writer.write(example.SerializeToString())
96 | except IOError as e:
97 | print('Could not read:', images[i])
98 | print('error: %s' % e)
99 | print('Skip it!\n')
100 | writer.close()
101 | print('Transform done!')
102 |
103 |
104 | # %%
105 |
106 | def read_and_decode(tfrecords_file, batch_size):
107 | '''read and decode tfrecord file, generate (image, label) batches
108 | Args:
109 | tfrecords_file: the directory of tfrecord file
110 | batch_size: number of images in each batch
111 | Returns:
112 | image: 3D tensor - [batch_size, height, width]
113 | label: 1D tensor - [batch_size]
114 | '''
115 | # make an input queue from the tfrecord file
116 | filename_queue = tf.train.string_input_producer([tfrecords_file])
117 |
118 | reader = tf.TFRecordReader()
119 | _, serialized_example = reader.read(filename_queue)
120 | img_features = tf.parse_single_example(
121 | serialized_example,
122 | features={
123 |
124 | 'label': tf.FixedLenFeature([], tf.int64),
125 | 'image_raw': tf.FixedLenFeature([], tf.string),
126 | })
127 | image = tf.decode_raw(img_features['image_raw'], tf.uint8)
128 |
129 | ##########################################################
130 | # you can put data augmentation here, I didn't use it
131 | ##########################################################
132 | # images were resized to 256*256 during conversion; change this size if you use another dataset.
133 |
134 |
135 | image = tf.reshape(image, [256, 256])
136 | label = tf.cast(img_features['label'], tf.int32)
137 | image_batch, label_batch = tf.train.shuffle_batch([image, label],
138 | batch_size=batch_size,
139 | num_threads=64,
140 | capacity=2000,
141 | min_after_dequeue=20)
142 | return image_batch, tf.reshape(label_batch, [batch_size])
143 |
144 |
145 | # %% Convert data to TFRecord
146 |
147 | #test_dir = 'F://CAE_CNN//data//catdogtest//'
148 | test_dir = 'G:\\dataS-UNIWARD0.4\\val\\'
149 | #save_dir = 'F://CAE_CNN//data//'
150 | save_dir = 'G:\\dataS-UNIWARD0.4\\'
151 | BATCH_SIZE = 25
152 |
153 | # Convert test data: you just need to run it ONCE !
154 | name_test = 'S_UNIWARD0.4val'
155 |
156 | #images, labels = get_file(test_dir)
157 | #convert_to_tfrecord(images, labels, save_dir, name_test)
158 |
159 | # %% TO test train.tfrecord file
160 |
161 | def plot_images(images, labels):
162 | '''plot one batch size
163 | '''
164 | for i in np.arange(0, BATCH_SIZE):
165 | plt.subplot(5, 5, i + 1)
166 | plt.axis('off')
167 | plt.title(chr(ord('D') + labels[i] - 1), fontsize=14)
168 | plt.subplots_adjust(top=1.5)
169 | plt.imshow(images[i])
170 | plt.show()
171 |
172 |
173 | # tfrecords_file = 'C://Users//Windows7//Documents//Python Scripts//notMNIST//test.tfrecords'
174 | tfrecords_file = 'G:\\dataS-UNIWARD0.4\\S_UNIWARD0.4val.tfrecords'
175 |
176 | image_batch, label_batch = read_and_decode(tfrecords_file, batch_size=BATCH_SIZE)
177 |
178 | with tf.Session() as sess:
179 | i = 0
180 | coord = tf.train.Coordinator()
181 | threads = tf.train.start_queue_runners(coord=coord)
182 |
183 | try:
184 | while not coord.should_stop() and i < 1:
185 | # just plot one batch size
186 | image, label = sess.run([image_batch, label_batch])
187 | plot_images(image, label)
188 | i += 1
189 |
190 | except tf.errors.OutOfRangeError:
191 | print('done!')
192 | finally:
193 | coord.request_stop()
194 | coord.join(threads)
195 |
196 |
197 | # %%
--------------------------------------------------------------------------------
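`get_file()` labels each file by the part of its name before the first underscore (`C_` for cover, `S_` for stego), then shuffles paths and labels together by stacking them into a two-column NumPy array. Because that array mixes strings and integers, NumPy stores everything as strings, which is why the labels have to be converted back with `int()` afterwards. A minimal stdlib sketch of the same idea shuffles `(path, label)` tuples instead, so the labels keep their type. The helper names and the optional `seed` argument are assumptions made for illustration.

```python
import random


def label_files(filenames, file_dir=""):
    """Return (path, label) pairs: 'C_*' -> 0 (cover), 'S_*' -> 1 (stego); other files are skipped."""
    pairs = []
    for name in filenames:
        prefix = name.split('_')[0]
        if prefix == 'C':
            pairs.append((file_dir + name, 0))
        elif prefix == 'S':
            pairs.append((file_dir + name, 1))
    return pairs


def shuffle_pairs(pairs, seed=None):
    """Shuffle paths and labels together, then split them into two aligned lists."""
    pairs = list(pairs)
    random.Random(seed).shuffle(pairs)
    image_list = [path for path, _ in pairs]
    label_list = [label for _, label in pairs]
    return image_list, label_list


images, labels = shuffle_pairs(
    label_files(['C_3.pgm', 'S_1.pgm', 'C_4.pgm', 'readme.md'], 'data/'), seed=0)
print(len(images), sorted(labels))  # 3 [0, 0, 1]
```

Shuffling the pairs as single units makes it impossible for a path and its label to drift apart, whatever the shuffle order.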
/new version/input_data1.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2018/4/27 0027 10:39
4 | # @Author : jsz
5 | # @Software: PyCharm
6 |
7 | import tensorflow as tf
8 | import numpy as np
9 | import os
10 | import math
11 |
12 | def read_and_decode(tfrecords_file, batch_size):
13 | '''read and decode tfrecord file, generate (image, label) batches
14 | Args:
15 | tfrecords_file: the directory of tfrecord file
16 | batch_size: number of images in each batch
17 | Returns:
18 | image: 4D float32 tensor - [batch_size, height, width, channel]
19 | label: 1D tensor - [batch_size]
20 | '''
21 | # make an input queue from the tfrecord file
22 | filename_queue = tf.train.string_input_producer([tfrecords_file])
23 |
24 | reader = tf.TFRecordReader()
25 | _, serialized_example = reader.read(filename_queue)
26 | img_features = tf.parse_single_example(
27 | serialized_example,
28 | features={
29 | 'label': tf.FixedLenFeature([], tf.int64),
30 | 'image_raw': tf.FixedLenFeature([], tf.string),
31 | })
32 | image = tf.decode_raw(img_features['image_raw'], tf.uint8)
33 |
34 |
35 | ##########################################################
36 | # you can put data augmentation here, I didn't use it
37 | ##########################################################
38 | # images are stored as 256*256 grayscale; change the size if you use another dataset.
39 |
40 |
41 | image = tf.reshape(image, [256, 256, 1])
42 | label = tf.cast(img_features['label'], tf.int32)
43 | image_batch, label_batch = tf.train.batch([image, label],
44 | batch_size=batch_size,
45 | num_threads=64,
46 | capacity=2000)
47 |
48 | image_batch = tf.cast(image_batch, tf.float32)
49 |
50 | return image_batch, tf.reshape(label_batch, [batch_size])
51 |
52 |
--------------------------------------------------------------------------------
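`read_and_decode()` builds its filename queue with `tf.train.string_input_producer([tfrecords_file])`, whose `num_epochs` defaults to `None`, so the file is re-read indefinitely and `tf.train.batch` keeps producing batches of exactly `batch_size` records, with some batches spanning the end of one pass and the start of the next. As a result, the `except tf.errors.OutOfRangeError` branch in `train2.py` is not reached with this pipeline; the loop length is set by `MAX_STEP` alone. The stdlib sketch below imitates that behaviour with `itertools.cycle`. It models only which records land in each fixed-size batch, not the queue's threads (`num_threads=64`) or capacity, and `endless_batches` is a hypothetical name.

```python
from itertools import cycle, islice


def endless_batches(records, batch_size):
    """Yield lists of exactly batch_size records forever, wrapping around the data set.

    Roughly mirrors string_input_producer(num_epochs=None) feeding tf.train.batch,
    ignoring threading and queue capacity.
    """
    stream = cycle(records)
    while True:
        yield list(islice(stream, batch_size))


batches = endless_batches(['r1', 'r2', 'r3', 'r4', 'r5'], batch_size=2)
first_three = [next(batches) for _ in range(3)]
print(first_three)  # [['r1', 'r2'], ['r3', 'r4'], ['r5', 'r1']]
```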
/new version/model1.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2018/5/4 0004 10:19
4 | # @Author : jsz
5 | # @Software: PyCharm
6 |
7 | import tensorflow as tf
8 |
9 | def inference(images, batch_size, n_classes):
10 | with tf.variable_scope('conv1') as scope:
11 | weights = tf.get_variable('weights',
12 | # [kernel height, kernel width, input channels, number of kernels]
13 | shape=[3, 3, 1, 32],
14 | dtype=tf.float32,
15 | initializer=tf.truncated_normal_initializer(stddev=0.1, dtype=tf.float32))
16 | biases = tf.get_variable('biases',
17 | shape=[32],
18 | dtype=tf.float32,
19 | initializer=tf.constant_initializer(0.1))
20 | conv = tf.nn.conv2d(images, weights, strides=[1, 1, 1, 1], padding='SAME')
21 | pre_activation = tf.nn.bias_add(conv, biases)
22 | conv1 = tf.nn.relu(pre_activation, name=scope.name)
23 |
24 | # with tf.variable_scope('pooling1_lrn') as scope:
25 | # pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME', name='pooling1')
26 | # norm1 = tf.nn.lrn(pool1, depth_radius=4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm1')
27 |
28 |
29 | with tf.variable_scope('conv2') as scope:
30 | weights = tf.get_variable('weights',
31 | shape=[3, 3, 32, 16],
32 | dtype=tf.float32,
33 | initializer=tf.truncated_normal_initializer(stddev=0.1, dtype=tf.float32))
34 | biases = tf.get_variable('biases',
35 | shape=[16],
36 | dtype=tf.float32,
37 | initializer=tf.constant_initializer(0.1))
38 | conv = tf.nn.conv2d(conv1, weights, strides=[1, 1, 1, 1], padding='SAME')
39 | pre_activation = tf.nn.bias_add(conv, biases)
40 | conv2 = tf.nn.relu(pre_activation, name='conv2')
41 |
42 | # pool2 and norm2
43 | # with tf.variable_scope('pooling2_lrn') as scope:
44 | # norm2 = tf.nn.lrn(conv2, depth_radius=4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, name='norm2')
45 | # pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], strides=[1, 1, 1, 1], padding='SAME', name='pooling2')
46 |
47 | with tf.variable_scope('local3') as scope:
48 | reshape = tf.reshape(conv2, shape=[batch_size, -1])
49 | dim = reshape.get_shape()[1].value
50 | weights = tf.get_variable('weights',
51 | shape=[dim, 256],
52 | dtype=tf.float32,
53 | initializer=tf.truncated_normal_initializer(stddev=0.005, dtype=tf.float32))
54 | biases = tf.get_variable('biases',
55 | shape=[256],
56 | dtype=tf.float32,
57 | initializer=tf.constant_initializer(0.1))
58 | local3 = tf.nn.relu(tf.matmul(reshape, weights) + biases, name=scope.name)
59 |
60 | # local4
61 | with tf.variable_scope('local4') as scope:
62 | weights = tf.get_variable('weights',
63 | shape=[256, 256],
64 | dtype=tf.float32,
65 | initializer=tf.truncated_normal_initializer(stddev=0.005, dtype=tf.float32))
66 | biases = tf.get_variable('biases',
67 | shape=[256],
68 | dtype=tf.float32,
69 | initializer=tf.constant_initializer(0.1))
70 | local4 = tf.nn.relu(tf.matmul(local3, weights) + biases, name='local4')
71 |
72 | # softmax
73 | with tf.variable_scope('softmax_linear') as scope:
74 | weights = tf.get_variable('softmax_linear',
75 | shape=[256, n_classes],
76 | dtype=tf.float32,
77 | initializer=tf.truncated_normal_initializer(stddev=0.005, dtype=tf.float32))
78 | biases = tf.get_variable('biases',
79 | shape=[n_classes],
80 | dtype=tf.float32,
81 | initializer=tf.constant_initializer(0.1))
82 | softmax_linear = tf.add(tf.matmul(local4, weights), biases, name='softmax_linear')
83 |
84 | return softmax_linear
85 |
86 | # logits is the return value of inference(); labels is the ground truth
87 | def losses(logits, labels):
88 | with tf.variable_scope('loss') as scope:
89 |
90 | cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits \
91 | (logits=logits, labels=labels, name='xentropy_per_example')
92 | loss = tf.reduce_mean(cross_entropy, name='loss')
93 | tf.summary.scalar(scope.name + '/loss', loss)
94 | return loss
95 |
96 |
97 | def trainning(loss, learning_rate):
98 | with tf.name_scope('optimizer'):
99 | optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
100 | global_step = tf.Variable(0, name='global_step', trainable=False)
101 | train_op = optimizer.minimize(loss, global_step=global_step)
102 | return train_op
103 |
104 |
105 | def evaluation(logits, labels):
106 | with tf.variable_scope('accuracy') as scope:
107 | correct = tf.nn.in_top_k(logits, labels, 1)
108 | correct = tf.cast(correct, tf.float32)
109 | accuracy = tf.reduce_mean(correct)
110 | tf.summary.scalar(scope.name + '/accuracy', accuracy)
111 | return accuracy
--------------------------------------------------------------------------------
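`model1.inference()` has no active pooling, so `local3` flattens the full 256x256x16 output of `conv2` before its dense layer. Counting the parameters of each layer (kernel or weight matrix plus biases) shows where the model's size comes from. The helper names below are illustrative; the shapes are the ones passed to `tf.get_variable` above, with `N_CLASSES = 2` as in `train2.py`.

```python
def conv_params(k_h, k_w, in_ch, out_ch):
    """Weights plus biases of a conv layer with a [k_h, k_w, in_ch, out_ch] kernel."""
    return k_h * k_w * in_ch * out_ch + out_ch


def dense_params(in_dim, out_dim):
    """Weights plus biases of a fully connected layer."""
    return in_dim * out_dim + out_dim


# 'SAME' padding with stride 1 keeps the 256x256 spatial size through both conv layers
flat_dim = 256 * 256 * 16  # per-example length after tf.reshape(conv2, [batch_size, -1])

layers = {
    'conv1': conv_params(3, 3, 1, 32),
    'conv2': conv_params(3, 3, 32, 16),
    'local3': dense_params(flat_dim, 256),
    'local4': dense_params(256, 256),
    'softmax_linear': dense_params(256, 2),
}
total = sum(layers.values())
print(layers['local3'], total)                 # 268435712 268506962
print(round(layers['local3'] * 4 / 2**30, 2))  # 1.0 (GiB of float32 weights)
```

So `local3` holds about 99.97% of all parameters, roughly 1 GiB in float32, and Adam keeps two extra slot variables of the same shape for every trainable variable. Each stride-2 pooling step before the flatten (like the commented-out `pooling1_lrn` block, which uses `strides=[1, 2, 2, 1]`) would divide `flat_dim`, and therefore the size of `local3`, by four.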
/new version/readme.md:
--------------------------------------------------------------------------------
1 | # new version
2 | A new version of the network training process:
3 | training from TFRecord files, with a validation pass.
4 |
5 | # merge all
6 | The `tf.merge_all_summaries()` function (replaced by `tf.summary.merge_all()` in TensorFlow 1.x, which `train2.py` uses) is convenient but somewhat dangerous: it merges every summary in the default graph, including summaries added by earlier, apparently unconnected code that also built nodes in that graph. If an old summary node depends on an old placeholder, running the merged op fails with an error asking you to feed a value for that placeholder. The same applies to current summaries: in `train2.py` the loss and accuracy summaries depend on the placeholders `x` and `y_`, so `summary_op` must be run with a `feed_dict`.
7 |
--------------------------------------------------------------------------------
/new version/train2.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2018/5/4 0004 10:19
4 | # @Author : jsz
5 | # @Software: PyCharm
6 |
7 | import os
8 | import numpy as np
9 | import tensorflow as tf
10 | import input_data1
11 | import model1
12 |
13 | N_CLASSES = 2 # cover and stego
14 | IMG_W = 256 # resize
15 | IMG_H = 256
16 | BATCH_SIZE = 32
17 | CAPACITY = 300
18 | MAX_STEP = 15000 # usually more than 10K
19 | learning_rate = 0.0001 # usually less than 0.0001
20 |
21 |
22 | def run_training():
23 |
24 | logs_train_dir = 'G:\\dataS-UNIWARD0.4\\logs\\train'
25 | logs_val_dir = 'G:\\dataS-UNIWARD0.4\\logs\\val'
26 |
27 | tfrecords_traindir = 'G:\\dataS-UNIWARD0.4\\S_UNIWARD0.4train.tfrecords'
28 | tfrecords_valdir = 'G:\\dataS-UNIWARD0.4\\S_UNIWARD0.4val.tfrecords'
29 |
30 | # Get batches through the TFRecord pipeline
31 | train_batch, train_label_batch = input_data1.read_and_decode(tfrecords_traindir, BATCH_SIZE)
32 | val_batch, val_label_batch = input_data1.read_and_decode(tfrecords_valdir, BATCH_SIZE)
33 |
34 |
35 | x = tf.placeholder(tf.float32, shape=[BATCH_SIZE, 256, 256, 1])
36 | y_ = tf.placeholder(tf.int32, shape=[BATCH_SIZE])
37 |
38 |
39 | logits = model1.inference(x, BATCH_SIZE, N_CLASSES)
40 | loss = model1.losses(logits, y_)
41 | acc = model1.evaluation(logits, y_)
42 | train_op = model1.trainning(loss, learning_rate)
43 |
44 |
45 | sess = tf.Session()
46 | saver = tf.train.Saver()
47 | sess.run(tf.global_variables_initializer())
48 | coord = tf.train.Coordinator()
49 | threads = tf.train.start_queue_runners(sess=sess, coord=coord)
50 |
51 |
52 | summary_op = tf.summary.merge_all()
53 | train_writer = tf.summary.FileWriter(logs_train_dir, sess.graph)
54 | val_writer = tf.summary.FileWriter(logs_val_dir, sess.graph)
55 |
56 | try:
57 | for step in np.arange(MAX_STEP):
58 | if coord.should_stop():
59 | break
60 |
61 | tra_images, tra_labels = sess.run([train_batch, train_label_batch])
62 | _, tra_loss, tra_acc = sess.run([train_op, loss, acc],
63 | feed_dict={x: tra_images, y_: tra_labels})
64 | if step % 2 == 0:
65 | print('Step %d, train loss = %.2f, train accuracy = %.2f%%' % (step, tra_loss, tra_acc * 100.0))
66 | # summary_str = sess.run(summary_op, feed_dict={x: tra_images, y_: tra_labels})
67 | # train_writer.add_summary(summary_str, step)
68 |
69 | if step % 200 == 0 or (step + 1) == MAX_STEP:
70 | val_images, val_labels = sess.run([val_batch, val_label_batch])
71 | val_loss, val_acc = sess.run([loss, acc],
72 | feed_dict={x: val_images, y_: val_labels})
73 | print('** Step %d, val loss = %.2f, val accuracy = %.2f%% **' % (step, val_loss, val_acc * 100.0))
74 | # summary_str = sess.run(summary_op, feed_dict={x: val_images, y_: val_labels})
75 | # val_writer.add_summary(summary_str, step)
76 |
77 | if step % 2000 == 0 or (step + 1) == MAX_STEP:
78 | checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')
79 | saver.save(sess, checkpoint_path, global_step=step)
80 |
81 | except tf.errors.OutOfRangeError:
82 | print('Done training -- epoch limit reached')
83 | finally:
84 | coord.request_stop()
85 | coord.join(threads)
86 |
87 |
88 | run_training()
89 |
90 |
91 |
92 |
--------------------------------------------------------------------------------
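The loop in `run_training()` prints training metrics every 2 steps, evaluates one validation batch every 200 steps, and saves a checkpoint every 2000 steps plus on the final step. The sketch below only enumerates those step numbers for `MAX_STEP = 15000`; `schedule` is a hypothetical helper. One consequence worth knowing: `tf.train.Saver()` keeps only the five most recent checkpoints by default (`max_to_keep=5`), so the early ones are deleted as training proceeds.

```python
def schedule(max_step, val_every=200, ckpt_every=2000):
    """Steps at which train2.py runs validation and saves a checkpoint."""
    val_steps = [s for s in range(max_step) if s % val_every == 0 or s + 1 == max_step]
    ckpt_steps = [s for s in range(max_step) if s % ckpt_every == 0 or s + 1 == max_step]
    return val_steps, ckpt_steps


val_steps, ckpt_steps = schedule(15000)
print(len(val_steps))   # 76
print(ckpt_steps)       # [0, 2000, 4000, 6000, 8000, 10000, 12000, 14000, 14999]
print(ckpt_steps[-5:])  # [8000, 10000, 12000, 14000, 14999] (kept with max_to_keep=5)
```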
/onehot.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2018/4/3 0003 10:23
4 | # @Author :
5 | # @Software: PyCharm
6 |
7 | # One-hot encoding: a feature with m possible values becomes m mutually exclusive binary features, which makes the data sparse
8 | #
9 |
10 | from sklearn.preprocessing import OneHotEncoder
11 |
12 | enc = OneHotEncoder()
13 |
14 | # Fit the encoder on the ten samples below (how many samples need to be fitted in practice??)
15 | # enc.categories_ lists, for each feature column, the distinct values seen during fit
16 | # (older scikit-learn exposed n_values_, active_features_ and feature_indices_ instead; they were deprecated in 0.20 and later removed)
17 | enc.fit([[0, 0, 9], [1, 1, 3],[1,0,8],
18 | [0,0,8],[0,0,4],[0,0,6],
19 | [0,0,5],[0,0,7],
20 | [0, 2, 1],[1, 0, 2]])
21 |
22 |
23 | print("enc.categories_ is:", enc.categories_)
24 |
25 |
26 | print(enc.transform([[0, 1, 7]]).toarray())
--------------------------------------------------------------------------------
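`OneHotEncoder` with the default `categories='auto'` learns the sorted distinct values of each column during `fit`, and `transform` emits one 0/1 slot per learned value. The dependency-free sketch below applies that mapping to the same ten samples and reproduces the 14-column row (2 + 3 + 9 slots) for `[0, 1, 7]`. The helper names are hypothetical. Unlike scikit-learn, whose default `handle_unknown='error'` raises on unseen values, this sketch simply leaves all of a column's slots at 0.

```python
def fit_categories(rows):
    """Sorted distinct values of each column, like OneHotEncoder().categories_."""
    return [sorted({row[j] for row in rows}) for j in range(len(rows[0]))]


def one_hot(row, categories):
    """Concatenate one 0/1 slot per learned category for every column."""
    encoded = []
    for value, cats in zip(row, categories):
        encoded.extend(1 if value == c else 0 for c in cats)
    return encoded


samples = [[0, 0, 9], [1, 1, 3], [1, 0, 8],
           [0, 0, 8], [0, 0, 4], [0, 0, 6],
           [0, 0, 5], [0, 0, 7],
           [0, 2, 1], [1, 0, 2]]
cats = fit_categories(samples)
print(cats)                      # [[0, 1], [0, 1, 2], [1, 2, 3, 4, 5, 6, 7, 8, 9]]
print(one_hot([0, 1, 7], cats))  # [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]
```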
/rename.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2018/3/12 0012 16:17
4 | # @Author : jsz
5 | # @Software: PyCharm
6 | import os
7 |
8 |
9 | def rename():
10 | count=1
11 | path = r'F:\CAE_CNN\data\lldata'
12 | filelist = os.listdir(path)
13 | for files in filelist:
14 | Olddir = os.path.join(path, files)
15 | if os.path.isdir(Olddir):
16 | continue
17 | filename = os.path.splitext(files)[0]
18 | filetype = os.path.splitext(files)[1]
19 |
20 | # Rename directly to a fixed name
21 | # Newdir = os.path.join(path, 'S' + filetype)
22 |
23 | # Prepend 'Stego.' to the file name
24 | # Newdir = os.path.join(path, ('Stego.'+filename) + filetype)
25 |
26 | # Number the files sequentially (the active option; comment it out and enable another one if needed)
27 | Newdir = os.path.join(path, str(count) + filetype)
28 |
29 |
30 |
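31 | # Batch-take the part of the name before / after the separator (---)
32 | # if filename.find('---') >= 0: # if the file name contains ---
33 | #
34 | # Newdir = os.path.join(path, filename.split('---')[0] + filetype)
35 | #
36 | # # take the characters before ---; to take the characters after it, use filename.split('---')[1]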
37 | #
38 | # if not os.path.isfile(Newdir):
39 |
40 |
41 |
42 | os.rename(Olddir, Newdir)
43 |
44 | count+= 1
45 |
46 | rename()
47 |
--------------------------------------------------------------------------------
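`rename()` renames files one at a time in `os.listdir()` order, which is arbitrary, and `os.rename` raises `FileExistsError` on Windows (and silently replaces the target on POSIX) if the new name is already taken, for example when the folder already contains `1.pgm`. A safer pattern is to compute the whole plan first and check it before touching the disk. Below is a minimal sketch for the sequential-numbering option. `plan_sequential_names` and `conflicts` are hypothetical helpers, and sorting the names first is a choice made here for reproducibility.

```python
import os


def plan_sequential_names(filenames, start=1):
    """Map each file name to '<n><ext>', numbering from start in sorted order."""
    plan = []
    for count, name in enumerate(sorted(filenames), start=start):
        ext = os.path.splitext(name)[1]
        plan.append((name, str(count) + ext))
    return plan


def conflicts(plan):
    """New names that already belong to a different existing file."""
    old_names = {old for old, _ in plan}
    return sorted(new for old, new in plan if new in old_names and new != old)


plan = plan_sequential_names(['b.pgm', 'a.pgm', '2.pgm'])
print(plan)             # [('2.pgm', '1.pgm'), ('a.pgm', '2.pgm'), ('b.pgm', '3.pgm')]
print(conflicts(plan))  # ['2.pgm']
```

When `conflicts()` is non-empty, renaming through temporary names first, or into an empty directory, avoids the clash.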
/tfrecord.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2018/4/12 0012 11:02
4 | # @Author :
5 | # @Software: PyCharm
6 |
7 | import tensorflow as tf
8 | import numpy as np
9 | import os
10 | import matplotlib.pyplot as plt
11 | import skimage.io as io
12 |
13 | def get_file(file_dir):
14 | cover = []
15 | label_cover = []
16 | stego = []
17 | label_stego = []
18 | # Assign labels
19 | for file in os.listdir(file_dir):
20 | # if file.endswith('0') or file.startswith('.'):
21 | # continue # Skip!
22 | name = file.split('.')
23 | if name[0] == 'Cover':
24 | cover.append(file_dir + file)
25 | label_cover.append(0)
26 | if name[0] == 'Stego':
27 | stego.append(file_dir + file)
28 | label_stego.append(1)
29 | print("There are %d cover images\nThere are %d stego images"
30 | % (len(cover), len(stego)))
31 | # Shuffle the file order
32 | image_list = np.hstack((cover, stego))
33 | label_list = np.hstack((label_cover, label_stego))
34 | temp = np.array([image_list, label_list])
35 | temp = temp.transpose()
36 | np.random.shuffle(temp)
37 |
38 | image_list = list(temp[:, 0])
39 | label_list = list(temp[:, 1])
40 | label_list = [int(i) for i in label_list]
41 |
42 | return image_list, label_list
43 |
44 | # %%
45 |
46 | def int64_feature(value):
47 | """Wrapper for inserting int64 features into Example proto."""
48 | if not isinstance(value, list):
49 | value = [value]
50 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
51 |
52 |
53 | def bytes_feature(value):
54 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
55 |
56 |
57 | # %%
58 |
59 | def convert_to_tfrecord(images, labels, save_dir, name):
60 | '''convert all images and labels to one tfrecord file.
61 | Args:
62 | images: list of image directories, string type
63 | labels: list of labels, int type
64 | save_dir: the directory to save tfrecord file, e.g.: '/home/folder1/'
65 | name: the name of tfrecord file, string type, e.g.: 'train'
66 | Return:
67 | no return
68 | Note:
69 | converting needs some time, be patient...
70 | '''
71 |
72 | filename = os.path.join(save_dir, name + '.tfrecords')
73 | n_samples = len(labels)
74 |
75 | if np.shape(images)[0] != n_samples:
76 | raise ValueError('Images size %d does not match label size %d.' % (len(images), n_samples))
77 |
78 | # wait some time here, transforming need some time based on the size of your data.
79 | writer = tf.python_io.TFRecordWriter(filename)
80 | print('\nTransform start......')
81 | for i in np.arange(0, n_samples):
82 | try:
83 | image = io.imread(images[i]) # type(image) must be array!
84 | image_raw = image.tobytes()
85 | label = int(labels[i])
86 | example = tf.train.Example(features=tf.train.Features(feature={
87 | 'label': int64_feature(label),
88 | 'image_raw': bytes_feature(image_raw)}))
89 | writer.write(example.SerializeToString())
90 | except IOError as e:
91 | print('Could not read:', images[i])
92 | print('error: %s' % e)
93 | print('Skip it!\n')
94 | writer.close()
95 | print('Transform done!')
96 |
97 |
98 | # %%
99 |
100 | def read_and_decode(tfrecords_file, batch_size):
101 | '''read and decode tfrecord file, generate (image, label) batches
102 | Args:
103 | tfrecords_file: the directory of tfrecord file
104 | batch_size: number of images in each batch
105 | Returns:
106 | image: 3D tensor - [batch_size, height, width]
107 | label: 1D tensor - [batch_size]
108 | '''
109 | # make an input queue from the tfrecord file
110 | filename_queue = tf.train.string_input_producer([tfrecords_file])
111 |
112 | reader = tf.TFRecordReader()
113 | _, serialized_example = reader.read(filename_queue)
114 | img_features = tf.parse_single_example(
115 | serialized_example,
116 | features={
117 | 'label': tf.FixedLenFeature([], tf.int64),
118 | 'image_raw': tf.FixedLenFeature([], tf.string),
119 | })
120 | image = tf.decode_raw(img_features['image_raw'], tf.uint8)
121 |
122 | ##########################################################
123 | # you can put data augmentation here, I didn't use it
124 | ##########################################################
125 | # images are not resized during conversion, so every image must be 512*512; change this size if you use another dataset.
126 |
127 | image = tf.reshape(image, [512, 512])
128 | label = tf.cast(img_features['label'], tf.int32)
129 | image_batch, label_batch = tf.train.batch([image, label],
130 | batch_size=batch_size,
131 | num_threads=64,
132 | capacity=2000)
133 | return image_batch, tf.reshape(label_batch, [batch_size])
134 |
135 |
136 | # %% Convert data to TFRecord
137 |
138 | # test_dir = 'C://Users//Windows7//Documents//Python Scripts//notMNIST//notMNIST_small//'
139 | test_dir = 'F://CAE_CNN//data//pgm_coverstego//'
140 |
141 | # save_dir = 'C://Users//Windows7//Documents//Python Scripts//notMNIST//'
142 | save_dir = 'F://CAE_CNN//data//'
143 |
144 | BATCH_SIZE = 25
145 |
146 | # Convert test data: you just need to run it ONCE !
147 | name_test = 'test'
148 | images, labels = get_file(test_dir)
149 | convert_to_tfrecord(images, labels, save_dir, name_test)
150 |
151 |
152 | # %% TO test train.tfrecord file
153 |
154 | def plot_images(images, labels):
155 | '''plot one batch size
156 | '''
157 | for i in np.arange(0, BATCH_SIZE):
158 | plt.subplot(5, 5, i + 1)
159 | plt.axis('off')
160 | plt.title('Cover' if labels[i] == 0 else 'Stego', fontsize=14)
161 | plt.subplots_adjust(top=1.5)
162 | plt.imshow(images[i], cmap='gray')
163 | plt.show()
164 |
165 |
166 | # tfrecords_file = 'C://Users//Windows7//Documents//Python Scripts//notMNIST//test.tfrecords'
167 | tfrecords_file = 'F://CAE_CNN//data//test.tfrecords'
168 |
169 | image_batch, label_batch = read_and_decode(tfrecords_file, batch_size=BATCH_SIZE)
170 |
171 | with tf.Session() as sess:
172 | i = 0
173 | coord = tf.train.Coordinator()
174 | threads = tf.train.start_queue_runners(coord=coord)
175 |
176 | try:
177 | while not coord.should_stop() and i < 1:
178 | # just plot one batch size
179 | image, label = sess.run([image_batch, label_batch])
180 | plot_images(image, label)
181 | i += 1
182 |
183 | except tf.errors.OutOfRangeError:
184 | print('done!')
185 | finally:
186 | coord.request_stop()
187 | coord.join(threads)
188 |
189 |
190 | # %%
--------------------------------------------------------------------------------
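`convert_to_tfrecord()` stores each image as its raw `uint8` bytes, and `read_and_decode()` turns them back into pixels with `tf.decode_raw(..., tf.uint8)` followed by `tf.reshape`. The record does not store the image shape, so the reshape only works when every image has exactly the hard-coded size: 512x512 in this script, which does not resize, and 256x256 in `new version/convert2tfrecord.py`, which resizes first. The stdlib sketch below mimics that round trip with the `array` module to show the constraint; the function names are illustrative.

```python
from array import array


def encode_gray(pixels):
    """Flatten rows of 0..255 ints into raw bytes, like ndarray.tobytes() on a uint8 image."""
    return array('B', (p for row in pixels for p in row)).tobytes()


def decode_gray(raw, height, width):
    """Rebuild rows from raw bytes, like tf.decode_raw(raw, tf.uint8) + tf.reshape([height, width])."""
    if len(raw) != height * width:
        raise ValueError('cannot reshape %d bytes into [%d, %d]' % (len(raw), height, width))
    return [list(raw[r * width:(r + 1) * width]) for r in range(height)]


image = [[0, 255, 7], [12, 34, 56]]  # a tiny 2x3 "image"
raw = encode_gray(image)
print(len(raw), decode_gray(raw, 2, 3) == image)  # 6 True

try:
    decode_gray(bytes(256 * 256), 512, 512)  # a 256x256 image read back as 512x512
except ValueError as err:
    print(err)  # cannot reshape 65536 bytes into [512, 512]
```

A size mismatch fails loudly, but a wrong height and width with the same pixel count (for example `decode_gray(raw, 3, 2)`) silently scrambles the image instead.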
/train.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2018/3/13 0013 15:34
4 | # @Author : jsz
5 | # @Software: PyCharm
6 |
7 | import os
8 | import numpy as np
9 | import tensorflow as tf
10 | import input_data
11 | import model
12 |
13 | N_CLASSES = 2 # cover and stego
14 | IMG_W = 256 # resize
15 | IMG_H = 256
16 | BATCH_SIZE = 16
17 | CAPACITY = 2000
18 | MAX_STEP = 15000 # usually more than 10K
19 | learning_rate = 0.001 # usually less than 0.0001
20 |
21 |
22 | def run_training():
23 |
24 | train_dir= 'F://CAE_CNN//data//train_imgs//'
25 | # Writes log files that can be viewed with TensorBoard
26 | logs_train_dir = 'F://CAE_CNN//log//train//'
27 |
28 |
29 |
30 | # Read the data
31 | train, train_label = input_data.get_files(train_dir)
32 | # Get batches
33 | train_batch, train_label_batch = input_data.get_batch(train,
34 | train_label,
35 | IMG_W,
36 | IMG_H,
37 | BATCH_SIZE,
38 | CAPACITY)
39 | # Build the model: pass the batch and parameters in
40 | train_logits = model.inference(train_batch, BATCH_SIZE, N_CLASSES)
41 |
42 | train_loss = model.losses(train_logits, train_label_batch)
43 | # Training
44 | train_op = model.trainning(train_loss, learning_rate)
45 |
46 | train__acc = model.evaluation(train_logits, train_label_batch)
47 | # Merge the summaries together?
48 | summary_op = tf.summary.merge_all() # merged summary op for the logs
49 |
50 | # Create a session
51 | sess = tf.Session()
52 | # Create a writer for the log files
53 | train_writer = tf.summary.FileWriter(logs_train_dir, sess.graph)
54 | # Create a saver to store the trained model
55 | saver = tf.train.Saver()
56 | # Initialize all variables
57 | sess.run(tf.global_variables_initializer())
58 |
59 | # Queue monitoring
60 | coord = tf.train.Coordinator()
61 | threads = tf.train.start_queue_runners(sess=sess, coord=coord)
62 |
63 | for step in np.arange(MAX_STEP):
64 | _, tra_loss, tra_acc = sess.run([train_op, train_loss, train__acc])
65 | # Every 2 steps, print the current loss and accuracy and write a summary through the writer
66 | if step % 2 == 0:
67 | print('Step %d, train loss = %.2f, train accuracy = %.2f%%' % (step, tra_loss, tra_acc * 100.0))
68 | summary_str = sess.run(summary_op)
69 | train_writer.add_summary(summary_str, step)
70 | # Every 2000 steps, save the trained model
71 | if step % 2000 == 0 or (step + 1) == MAX_STEP:
72 | checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')
73 | saver.save(sess, checkpoint_path, global_step=step)
74 | # try:
75 | # # Train for MAX_STEP steps, one batch per step
76 | # for step in np.arange(MAX_STEP):
77 | # # if coord.should_stop():
78 | # # break
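79 | # # Run the ops below. Question: why is train_logits not run here? (It is computed implicitly, because train_op, train_loss and train__acc all depend on it.)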
80 | # _, tra_loss, tra_acc = sess.run([train_op, train_loss, train__acc])
81 | # # Every 2 steps, print the current loss and accuracy and write a summary through the writer
82 | # if step % 2 == 0:
83 | # print('Step %d, train loss = %.2f, train accuracy = %.2f%%' % (step, tra_loss, tra_acc * 100.0))
84 | # summary_str = sess.run(summary_op)
85 | # train_writer.add_summary(summary_str, step)
86 | # # Every 2000 steps, save the trained model
87 | # if step % 2000 == 0 or (step + 1) == MAX_STEP:
88 | # checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')
89 | # saver.save(sess, checkpoint_path, global_step=step)
90 | #
91 | # except tf.errors.OutOfRangeError:
92 | # print('Done training -- epoch limit reached')
93 | # finally:
94 | # coord.request_stop()
95 | # sess.close()
96 |
97 | run_training()
98 |
--------------------------------------------------------------------------------
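Unlike `train1.py` and the commented-out block above, the active loop in `train.py` never calls `coord.request_stop()` and `coord.join(threads)`, so the queue-runner threads are still running when the loop ends. The pattern those calls implement is: signal the background producers to stop in a `finally` block, then wait for them to exit. The sketch below shows that shutdown order with the standard `threading` and `queue` modules standing in for `tf.train.Coordinator` and the queue runners. It is an illustration of the pattern only, not TensorFlow code, and `run_with_coordinator` is a hypothetical name.

```python
import queue
import threading


def run_with_coordinator(num_workers=2, max_step=10):
    """Consume max_step items from producer threads, then stop and join them."""
    stop = threading.Event()    # stands in for tf.train.Coordinator
    q = queue.Queue(maxsize=4)  # stands in for the input queue

    def producer(worker_id):
        i = 0
        while not stop.is_set():  # like checking coord.should_stop()
            try:
                q.put((worker_id, i), timeout=0.05)
                i += 1
            except queue.Full:
                continue          # re-check the stop flag while the queue is full

    threads = [threading.Thread(target=producer, args=(w,), daemon=True)
               for w in range(num_workers)]
    for t in threads:
        t.start()

    consumed = 0
    try:
        for _ in range(max_step):
            q.get(timeout=1.0)    # one "batch" per training step
            consumed += 1
    finally:
        stop.set()                # like coord.request_stop()
        for t in threads:
            t.join(timeout=1.0)   # like coord.join(threads)
    return consumed, all(not t.is_alive() for t in threads)


print(run_with_coordinator())  # (10, True)
```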
/train1.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 2018/4/12 0012 13:50
4 | # @Author :
5 | # @Software: PyCharm
6 | # ! /usr/bin/python3
7 | # -*- coding: utf-8 -*-
8 | # @Time : 2018/3/13 0013 15:34
9 | # @Author : jsz
10 | # @Software: PyCharm
11 |
12 | import os
13 | import numpy as np
14 | import tensorflow as tf
15 | import input_data
16 | import model
17 |
18 | N_CLASSES = 2 # cover and stego
19 | IMG_W = 256 # resize
20 | IMG_H = 256
21 | BATCH_SIZE = 16
22 | CAPACITY = 300
23 | MAX_STEP = 15000 # usually more than 10K
24 | learning_rate = 0.0001 # usually less than 0.0001
25 |
26 |
27 | def run_training():
28 |
29 | logs_train_dir = 'F://CAE_CNN//log//train1//'
30 | tfrecords_dir = 'F://CAE_CNN//data//test.tfrecords'
31 |
32 | # Get batches through the TFRecord pipeline
33 | train_batch, train_label_batch = input_data.read_and_decode(tfrecords_dir, BATCH_SIZE)
34 |
35 | train_logits = model.inference(train_batch, BATCH_SIZE, N_CLASSES)
36 |
37 | train_loss = model.losses(train_logits, train_label_batch)
38 | # Training
39 | train_op = model.trainning(train_loss, learning_rate)
40 |
41 | train__acc = model.evaluation(train_logits, train_label_batch)
42 | # Merge the summaries together?
43 | summary_op = tf.summary.merge_all() # merged summary op for the logs
44 |
45 | # Create a session
46 | sess = tf.Session()
47 | # Create a writer for the log files
48 | train_writer = tf.summary.FileWriter(logs_train_dir, sess.graph)
49 | # Create a saver to store the trained model
50 | saver = tf.train.Saver()
51 | # Initialize all variables
52 | sess.run(tf.global_variables_initializer())
53 |
54 | # Queue monitoring
55 | coord = tf.train.Coordinator()
56 | threads = tf.train.start_queue_runners(sess=sess, coord=coord)
57 |
58 | # for step in np.arange(MAX_STEP):
59 | # _, tra_loss, tra_acc = sess.run([train_op, train_loss, train__acc])
60 | # # Every 2 steps, print the current loss and accuracy and write a summary through the writer
61 | # if step % 2 == 0:
62 | # print('Step %d, train loss = %.2f, train accuracy = %.2f%%' % (step, tra_loss, tra_acc * 100.0))
63 | # summary_str = sess.run(summary_op)
64 | # train_writer.add_summary(summary_str, step)
65 | # # Every 2000 steps, save the trained model
66 | # if step % 2000 == 0 or (step + 1) == MAX_STEP:
67 | # checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')
68 | # saver.save(sess, checkpoint_path, global_step=step)
69 | try:
70 | # Train for MAX_STEP steps, one batch per step
71 | for step in np.arange(MAX_STEP):
72 | if coord.should_stop():
73 | break
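74 | # Run the ops below. Question: why is train_logits not run here? (It is computed implicitly, because train_op, train_loss and train__acc all depend on it.)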
75 | _, tra_loss, tra_acc = sess.run([train_op, train_loss, train__acc])
76 | # Every 2 steps, print the current loss and accuracy and write a summary through the writer
77 | if step % 2 == 0:
78 | print('Step %d, train loss = %.2f, train accuracy = %.2f%%' % (step, tra_loss, tra_acc * 100.0))
79 | summary_str = sess.run(summary_op)
80 | train_writer.add_summary(summary_str, step)
81 | # Every 2000 steps, save the trained model
82 | if step % 2000 == 0 or (step + 1) == MAX_STEP:
83 | checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')
84 | saver.save(sess, checkpoint_path, global_step=step)
85 |
86 | except tf.errors.OutOfRangeError:
87 | print('Done training -- epoch limit reached')
88 | finally:
89 | coord.request_stop()
90 |
91 | coord.join(threads)
92 | sess.close()
93 |
94 |
95 | run_training()
96 |
--------------------------------------------------------------------------------