├── Data
│   └── BOSSBase_512
│       ├── 1.pgm
│       ├── 10.pgm
│       ├── 2.pgm
│       ├── 3.pgm
│       ├── 4.pgm
│       ├── 5.pgm
│       ├── 6.pgm
│       ├── 7.pgm
│       ├── 8.pgm
│       └── 9.pgm
├── Implement
│   ├── .idea
│   │   ├── Implement.iml
│   │   ├── misc.xml
│   │   ├── modules.xml
│   │   └── workspace.xml
│   ├── SRM_Kernels.npy
│   ├── YeNet.py
│   ├── generator.py
│   ├── layers.py
│   ├── main.py
│   ├── preprocessing.py
│   ├── testfiles
│   │   ├── command.sh
│   │   ├── test_data_split.py
│   │   └── utils_test.py
│   └── utils.py
├── README.md
├── command_BOSS.sh
├── command_BOSSTEST.sh
├── command_SUNI_0.4_15000_No_1.sh
└── command_SUNI_0.4_15000_No_2.sh
/Data/BOSSBase_512/1.pgm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/changshihyoung/TensorFlow-YeNet/ffa3778c350b1636f174d7fd0ab4bfec5ec79d83/Data/BOSSBase_512/1.pgm
--------------------------------------------------------------------------------
/Data/BOSSBase_512/10.pgm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/changshihyoung/TensorFlow-YeNet/ffa3778c350b1636f174d7fd0ab4bfec5ec79d83/Data/BOSSBase_512/10.pgm
--------------------------------------------------------------------------------
/Data/BOSSBase_512/3.pgm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/changshihyoung/TensorFlow-YeNet/ffa3778c350b1636f174d7fd0ab4bfec5ec79d83/Data/BOSSBase_512/3.pgm
--------------------------------------------------------------------------------
/Data/BOSSBase_512/4.pgm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/changshihyoung/TensorFlow-YeNet/ffa3778c350b1636f174d7fd0ab4bfec5ec79d83/Data/BOSSBase_512/4.pgm
--------------------------------------------------------------------------------
/Data/BOSSBase_512/5.pgm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/changshihyoung/TensorFlow-YeNet/ffa3778c350b1636f174d7fd0ab4bfec5ec79d83/Data/BOSSBase_512/5.pgm
--------------------------------------------------------------------------------
/Data/BOSSBase_512/6.pgm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/changshihyoung/TensorFlow-YeNet/ffa3778c350b1636f174d7fd0ab4bfec5ec79d83/Data/BOSSBase_512/6.pgm
--------------------------------------------------------------------------------
/Data/BOSSBase_512/7.pgm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/changshihyoung/TensorFlow-YeNet/ffa3778c350b1636f174d7fd0ab4bfec5ec79d83/Data/BOSSBase_512/7.pgm
--------------------------------------------------------------------------------
/Data/BOSSBase_512/8.pgm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/changshihyoung/TensorFlow-YeNet/ffa3778c350b1636f174d7fd0ab4bfec5ec79d83/Data/BOSSBase_512/8.pgm
--------------------------------------------------------------------------------
/Data/BOSSBase_512/9.pgm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/changshihyoung/TensorFlow-YeNet/ffa3778c350b1636f174d7fd0ab4bfec5ec79d83/Data/BOSSBase_512/9.pgm
--------------------------------------------------------------------------------
/Implement/.idea/Implement.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/Implement/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/Implement/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/Implement/.idea/workspace.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/Implement/SRM_Kernels.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/changshihyoung/TensorFlow-YeNet/ffa3778c350b1636f174d7fd0ab4bfec5ec79d83/Implement/SRM_Kernels.npy
--------------------------------------------------------------------------------
/Implement/YeNet.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.contrib import layers
3 | from tensorflow.contrib.framework import add_arg_scope, arg_scope, arg_scoped_arguments
4 | import layers as my_layers
5 | from utils import *
6 | import numpy as np
7 | SRM_Kernels = np.load('/home/carlchang/YeNetTensorflow/Implement/SRM_Kernels.npy')
8 |
9 | class YeNet(Model):
10 | def __init__(self, is_training=None, data_format='NCHW',
11 | with_bn=False, tlu_threshold=3):
12 | super(YeNet, self).__init__(is_training=is_training,
13 | data_format=data_format)
14 | self.with_bn = with_bn
15 | self.tlu_threshold = tlu_threshold
16 |
17 | def _build_model(self, inputs):
18 | self.inputs = inputs
19 | if self.data_format == 'NCHW':
20 | channel_axis = 1
21 | _inputs = tf.cast(tf.transpose(inputs, [0, 3, 1, 2]), tf.float32)
22 | else:
23 | channel_axis = 3
24 | _inputs = tf.cast(inputs, tf.float32)
25 | self.L = []
26 | with arg_scope([layers.avg_pool2d],
27 | padding='VALID', data_format=self.data_format):
28 | with tf.variable_scope('SRM_preprocess'):
29 | W_SRM = tf.get_variable('W', initializer=SRM_Kernels,
30 | dtype=tf.float32,
31 | regularizer=None)
32 | b = tf.get_variable('b', shape=[30], dtype=tf.float32,
33 | initializer=tf.constant_initializer(0.))
34 | self.L.append(tf.nn.bias_add(
35 | tf.nn.conv2d(_inputs,
36 | W_SRM, [1,1,1,1], 'VALID',
37 | data_format=self.data_format), b,
38 | data_format=self.data_format, name='Layer1'))
39 | self.L.append(tf.clip_by_value(self.L[-1],
40 | -self.tlu_threshold, self.tlu_threshold,
41 | name='TLU'))
42 | with tf.variable_scope('ConvNetwork'):
43 | with arg_scope([my_layers.conv2d],
44 | num_outputs=30,
45 | kernel_size=3, stride=1, padding='VALID',
46 | data_format=self.data_format,
47 | activation_fn=tf.nn.relu,
48 | weights_initializer=layers.xavier_initializer_conv2d(),
49 | weights_regularizer=layers.l2_regularizer(5e-4),
50 | biases_initializer=tf.constant_initializer(0.2),
51 | biases_regularizer=None), arg_scope([layers.batch_norm],
52 | decay=0.9, center=True, scale=True,
53 | updates_collections=None, is_training=self.is_training,
54 | fused=True, data_format=self.data_format):
55 | if self.with_bn:
56 | self.L.append(layers.batch_norm(self.L[-1],
57 | scope='Norm1'))
58 | self.L.append(my_layers.conv2d(self.L[-1],
59 | scope='Layer2'))
60 | if self.with_bn:
61 | self.L.append(layers.batch_norm(self.L[-1],
62 | scope='Norm2'))
63 | self.L.append(my_layers.conv2d(self.L[-1],
64 | scope='Layer3'))
65 | if self.with_bn:
66 | self.L.append(layers.batch_norm(self.L[-1],
67 | scope='Norm3'))
68 | self.L.append(my_layers.conv2d(self.L[-1],
69 | scope='Layer4'))
70 | if self.with_bn:
71 | self.L.append(layers.batch_norm(self.L[-1],
72 | scope='Norm4'))
73 | self.L.append(layers.avg_pool2d(self.L[-1],
74 | kernel_size=[2,2], scope='Stride1'))
75 | with arg_scope([my_layers.conv2d], kernel_size=5,
76 | num_outputs=32):
77 | self.L.append(my_layers.conv2d(self.L[-1],
78 | scope='Layer5'))
79 | if self.with_bn:
80 | self.L.append(layers.batch_norm(self.L[-1],
81 | scope='Norm5'))
82 | self.L.append(layers.avg_pool2d(self.L[-1],
83 | kernel_size=[3,3],
84 | scope='Stride2'))
85 | self.L.append(my_layers.conv2d(self.L[-1],
86 | scope='Layer6'))
87 | if self.with_bn:
88 | self.L.append(layers.batch_norm(self.L[-1],
89 | scope='Norm6'))
90 | self.L.append(layers.avg_pool2d(self.L[-1],
91 | kernel_size=[3,3],
92 | scope='Stride3'))
93 | self.L.append(my_layers.conv2d(self.L[-1],
94 | scope='Layer7'))
95 | if self.with_bn:
96 | self.L.append(layers.batch_norm(self.L[-1],
97 | scope='Norm7'))
98 | self.L.append(layers.avg_pool2d(self.L[-1],
99 | kernel_size=[3,3],
100 | scope='Stride4'))
101 | self.L.append(my_layers.conv2d(self.L[-1],
102 | num_outputs=16,
103 | scope='Layer8'))
104 | if self.with_bn:
105 | self.L.append(layers.batch_norm(self.L[-1],
106 | scope='Norm8'))
107 | self.L.append(my_layers.conv2d(self.L[-1],
108 | num_outputs=16, stride=3,
109 | scope='Layer9'))
110 | if self.with_bn:
111 | self.L.append(layers.batch_norm(self.L[-1],
112 | scope='Norm9'))
113 | self.L.append(layers.flatten(self.L[-1]))
114 | self.L.append(layers.fully_connected(self.L[-1], num_outputs=2,
115 | activation_fn=None, normalizer_fn=None,
116 | weights_initializer=tf.random_normal_initializer(mean=0.,
117 | stddev=0.01),
118 | biases_initializer=tf.constant_initializer(0.), scope='ip'))
119 | self.outputs = self.L[-1]
120 | return self.outputs
121 |
122 |
--------------------------------------------------------------------------------
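
A minimal usage sketch for the class above, mirroring the call pattern in /Implement/testfiles/utils_test.py (the batch size and the 256x256 input size are illustrative assumptions; TensorFlow 1.x with tf.contrib is assumed):

    import tensorflow as tf
    from YeNet import YeNet

    # Placeholders for a batch of 32 grayscale 256x256 images and their labels (0 = cover, 1 = stego)
    img_batch = tf.placeholder(tf.float32, [32, 256, 256, 1], name='input_image_batch')
    label_batch = tf.placeholder(tf.int32, [32, ], name='input_label_batch')
    is_training = tf.get_variable('is_training', dtype=tf.bool, initializer=True, trainable=False)

    # Build the graph: SRM preprocessing + TLU clipping, the convolutional stack, then loss/accuracy
    model = YeNet(is_training, 'NCHW', with_bn=True, tlu_threshold=3)
    outputs = model._build_model(img_batch)
    loss, accuracy = model._build_losses(label_batch)
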
/Implement/generator.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import tensorflow as tf
4 | from scipy import misc, io
5 | from random import shuffle
6 |
7 | def get_files(cover_dir, stego_dir, use_shuf_pair=False):
8 | """
9 |     Collect image paths from the cover and stego folders and return them for batch construction.
10 |     use_shuf_pair decides whether cover and stego images stay paired when a batch is formed.
11 | """
12 | file = []
13 | for filename in os.listdir(cover_dir + '/'):
14 | file.append(filename)
15 | shuffle(file)
16 |     file_shuf1 = file[:]  # copy the list so the second shuffle below (use_shuf_pair branch) does not reorder it too
17 |
18 | img = []
19 | img_label = []
20 | if use_shuf_pair:
21 | shuffle(file)
22 | file_shuf2 = file
23 | for file_idx in range(len(file_shuf1)):
24 | img.append(cover_dir + '/' + file_shuf1[file_idx])
25 | img_label.append(0)
26 | img.append(stego_dir + '/' + file_shuf2[file_idx])
27 | img_label.append(1)
28 | else:
29 | for filename in file_shuf1:
30 | img.append(cover_dir + '/' + filename)
31 | img_label.append(0)
32 | img.append(stego_dir + '/' + filename)
33 | img_label.append(1)
34 |
35 |     # Write img and img_label to img_label_list.txt under the cover directory
36 | #with open(cover_dir + '/' + 'img_label_list.txt', 'w') as f:
37 | # for img_idx in range(len(img)):
38 | # f.write(img[img_idx]+' '+str(img_label[img_idx])+'\n')
39 |
40 | return img, img_label
41 |
42 | def get_minibatches(img, img_label, batch_size):
43 | """
44 |     Replaces get_batches: reads the data in batches, yielding batch_size items at a time.
45 | """
46 | for start_idx in range(0, len(img) - batch_size + 1, batch_size):
47 | excerpt = slice(start_idx, start_idx + batch_size)
48 | img_minibatch = img[excerpt]
49 | img_label_minibatch = img_label[excerpt]
50 | yield img_minibatch, img_label_minibatch
51 |
52 | def get_minibatches_content_img(train_img_minibatch_list, img_height, img_width):
53 | """
54 |     Read the contents of the image paths returned by get_minibatches and return the actual pixel data as a batch.
55 | """
56 | img_num = len(train_img_minibatch_list)
57 | image_minibatch_content = np.zeros([img_num, img_height, img_width, 1], dtype=np.float32)
58 |
59 | i = 0
60 | for img_file in train_img_minibatch_list:
61 | content = misc.imread(img_file)
62 | image_minibatch_content[i, :, :, 0] = content
63 | i = i + 1
64 |
65 | return image_minibatch_content
66 |
67 | """
68 | def get_batches(img, img_label, batch_size, capacity):
69 | #
70 |     # Build training batches from the image list and label list returned by get_files
71 |     # Note: the input images must all have the same height and width
72 | #
73 | img = tf.cast(img, tf.string)
74 | img_label = tf.cast(img_label, tf.int32)
75 |
76 |     # Build the input queue; TensorFlow has several ways to do this, here image and label are kept separate
77 | input_queue = tf.train.slice_input_producer([img, img_label])
78 |
79 |     # Read label and image from the queue (the image needs to be decoded)
80 | label = input_queue[1]
81 |     image_contents = tf.read_file(input_queue[0]) # this does not work for pgm images
82 |     image = tf.image.decode_image(image_contents, channels=1) # this does not work for pgm images
83 |     ## dataset augmentation part
84 |
85 |     # Resize/standardize the images; tf.image offers many ops such as random flips
86 | #image = tf.image.resize_image_with_crop_or_pad(image, image_W, image_H)
87 | #image = tf.image.per_image_standardization(image)
88 |
89 |     #[image, label] are tensor variables
90 | image_batch, label_batch = tf.train.batch([image, label],
91 | batch_size=batch_size,
92 | num_threads=64,
93 | capacity=capacity)
94 | label_batch = tf.reshape(label_batch, [batch_size])
95 |
96 | return image_batch, label_batch
97 | """
--------------------------------------------------------------------------------
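
A minimal sketch of how the three helpers above chain into an input pipeline, following the training loop in /Implement/testfiles/utils_test.py (the directory paths and image sizes are placeholder assumptions):

    from generator import get_files, get_minibatches, get_minibatches_content_img

    # Paired cover/stego paths with labels (0 = cover, 1 = stego); files share the same names
    img_list, label_list = get_files('Experiment/train/cover', 'Experiment/train/stego',
                                     use_shuf_pair=False)

    for img_paths, labels in get_minibatches(img_list, label_list, batch_size=32):
        # Load the pixel data for this minibatch: shape [32, img_height, img_width, 1]
        img_batch = get_minibatches_content_img(img_paths, img_height=256, img_width=256)
        # feed img_batch and labels into the network here
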
/Implement/layers.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.contrib import layers
3 | from tensorflow.contrib.framework import add_arg_scope
4 |
5 | @add_arg_scope
6 | def conv2d(inputs,
7 | num_outputs,
8 | kernel_size,
9 | stride=1,
10 | padding='SAME',
11 | data_format=None,
12 | rate=1,
13 | activation_fn=tf.nn.relu,
14 | normalizer_fn=None,
15 | normalize_after_activation=True,
16 | normalizer_params=None,
17 | weights_initializer=layers.xavier_initializer(),
18 | weights_regularizer=None,
19 | biases_initializer=tf.zeros_initializer(),
20 | biases_regularizer=None,
21 | reuse=None,
22 | variables_collections=None,
23 | outputs_collections=None,
24 | trainable=True,
25 | scope=None):
26 | with tf.variable_scope(scope, 'Conv', reuse=reuse):
27 | if data_format == 'NHWC':
28 | num_inputs = inputs.get_shape().as_list()[3]
29 | height = inputs.get_shape().as_list()[1]
30 | width = inputs.get_shape().as_list()[2]
31 | if isinstance(stride, int):
32 | strides = [1, stride, stride, 1]
33 | elif isinstance(stride, list) or isinstance(stride, tuple):
34 | if len(stride) == 1:
35 | strides = [1] + stride * 2 + [1]
36 | else:
37 | strides = [1, stride[0], stride[1], 1]
38 | else:
39 |                 raise TypeError('stride is not an int, list or '
40 | + 'a tuple, is %s' % type(stride))
41 | else:
42 | num_inputs = inputs.get_shape().as_list()[1]
43 | height = inputs.get_shape().as_list()[2]
44 | width = inputs.get_shape().as_list()[3]
45 | if isinstance(stride, int):
46 | strides = [1, 1, stride, stride]
47 | elif isinstance(stride, list) or isinstance(stride, tuple):
48 | if len(stride) == 1:
49 | strides = [1, 1] + stride * 2
50 | else:
51 | strides = [1, 1, stride[0], stride[1]]
52 | else:
53 |                 raise TypeError('stride is not an int, list or '
54 | + 'a tuple, is %s' % type(stride))
55 | if isinstance(kernel_size, int):
56 | kernel_height = kernel_size
57 | kernel_width = kernel_size
58 | elif isinstance(kernel_size, list) or isinstance(kernel_size, tuple):
59 | kernel_height = kernel_size[0]
60 | kernel_width = kernel_size[1]
61 | else:
62 |             raise ValueError('kernel_size is not an int, list or '
63 | + 'a tuple, is %s' % type(kernel_size))
64 | weights = tf.get_variable('weights', [kernel_height,
65 | kernel_width, num_inputs, num_outputs],
66 | 'float32', weights_initializer,
67 | weights_regularizer, trainable,
68 | variables_collections)
69 | outputs = tf.nn.conv2d(inputs, weights, strides, padding,
70 | data_format=data_format)
71 | if biases_initializer is not None:
72 | biases = tf.get_variable('biases', [num_outputs], 'float32',
73 | biases_initializer,
74 | biases_regularizer,
75 | trainable, variables_collections)
76 | outputs = tf.nn.bias_add(outputs, biases, data_format)
77 | if normalizer_fn is not None and not normalize_after_activation:
78 | normalizer_params = normalizer_params or {}
79 | outputs = normalizer_fn(outputs, **normalizer_params)
80 | if activation_fn is not None:
81 | outputs = activation_fn(outputs)
82 | if normalizer_fn is not None and normalize_after_activation:
83 | normalizer_params = normalizer_params or {}
84 | outputs = normalizer_fn(outputs, **normalizer_params)
85 | return outputs
86 |
87 | """
88 | @add_arg_scope
89 | def double_conv2d(ref_half, real_half,
90 | num_outputs,
91 | kernel_size,
92 | stride=1,
93 | padding='SAME',
94 | data_format=None,
95 | rate=1,
96 | activation_fn=tf.nn.relu,
97 | normalizer_fn=None,
98 | normalize_after_activation=True,
99 | normalizer_params=None,
100 | weights_initializer=layers.xavier_initializer(),
101 | weights_regularizer=None,
102 | biases_initializer=tf.zeros_initializer(),
103 | biases_regularizer=None,
104 | reuse=None,
105 | variables_collections=None,
106 | outputs_collections=None,
107 | trainable=True,
108 | scope=None):
109 | with tf.variable_scope(scope, 'Conv', reuse=reuse):
110 | if data_format == 'NHWC':
111 | num_inputs = real_half.get_shape().as_list()[3]
112 | height = real_half.get_shape().as_list()[1]
113 | width = real_half.get_shape().as_list()[2]
114 | if isinstance(stride, int):
115 | strides = [1, stride, stride, 1]
116 | elif isinstance(stride, list) or isinstance(stride, tuple):
117 | if len(stride) == 1:
118 | strides = [1] + stride * 2 + [1]
119 | else:
120 | strides = [1, stride[0], stride[1], 1]
121 | else:
122 | raise TypeError('stride is not an int, list or'
123 | + 'a tuple, is %s' % type(stride))
124 | else:
125 | num_inputs = real_half.get_shape().as_list()[1]
126 | height = real_half.get_shape().as_list()[2]
127 | width = real_half.get_shape().as_list()[3]
128 | if isinstance(stride, int):
129 | strides = [1, 1, stride, stride]
130 | elif isinstance(stride, list) or isinstance(stride, tuple):
131 | if len(stride) == 1:
132 | strides = [1, 1] + stride * 2
133 | else:
134 | strides = [1, 1, stride[0], stride[1]]
135 | else:
136 | raise TypeError('stride is not an int, list or' \
137 | + 'a tuple, is %s' % type(stride))
138 | if isinstance(kernel_size, int):
139 | kernel_height = kernel_size
140 | kernel_width = kernel_size
141 | elif isinstance(kernel_size, list) \
142 | or isinstance(kernel_size, tuple):
143 | kernel_height = kernel_size[0]
144 | kernel_width = kernel_size[1]
145 | else:
146 | raise ValueError('kernel_size is not an int, list or'
147 | + 'a tuple, is %s' % type(kernel_size))
148 | weights = tf.get_variable('weights', [kernel_height,
149 | kernel_width, num_inputs, num_outputs],
150 | 'float32', weights_initializer,
151 | weights_regularizer, trainable,
152 | variables_collections)
153 | ref_outputs = tf.nn.conv2d(ref_half, weights, strides, padding,
154 | data_format=data_format)
155 | real_outputs = tf.nn.conv2d(real_half, weights, strides, padding,
156 | data_format=data_format)
157 | if biases_initializer is not None:
158 | biases = tf.get_variable('biases', [num_outputs], 'float32',
159 | biases_initializer,
160 | biases_regularizer,
161 | trainable, variables_collections)
162 | ref_outputs = tf.nn.bias_add(ref_outputs, biases, data_format)
163 | real_outputs = tf.nn.bias_add(real_outputs, biases, data_format)
164 | if normalizer_fn is not None and not normalize_after_activation:
165 | normalizer_params = normalizer_params or {}
166 | ref_outputs, real_outputs = normalizer_fn(ref_outputs,
167 | real_outputs,
168 | **normalizer_params)
169 | if activation_fn is not None:
170 | ref_outputs = activation_fn(ref_outputs)
171 | real_outputs = activation_fn(real_outputs)
172 | if normalizer_fn is not None and normalize_after_activation:
173 | normalizer_params = normalizer_params or {}
174 | ref_outputs, real_outputs = normalizer_fn(ref_outputs,
175 | real_outputs,
176 | **normalizer_params)
177 | return ref_outputs, real_outputs
178 |
179 | class Vbn_double(object):
180 | def __init__(self, x, epsilon=1e-5, scope=None):
181 | shape = x.get_shape().as_list()
182 | needs_reshape = len(shape) != 4
183 | if needs_reshape:
184 | orig_shape = shape
185 | if len(shape) == 2:
186 | if data_format == 'NCHW':
187 | x = tf.reshape(x, [shape[0], shape[1], 0, 0])
188 | else:
189 | x = tf.reshape(x, [shape[0], 1, 1, shape[1]])
190 | elif len(shape) == 1:
191 | x = tf.reshape(x, [shape[0], 1, 1, 1])
192 | else:
193 | assert False, shape
194 | shape = x.get_shape().as_list()
195 | with tf.variable_scope(scope):
196 | self.epsilon = epsilon
197 | self.scope = scope
198 | self.mean, self.var = tf.nn.moments(x, [0,2,3], \
199 | keep_dims=True)
200 | self.inv_std = tf.rsqrt(self.var + epsilon)
201 | self.batch_size = int(x.get_shape()[0])
202 | out = self._normalize(x, self.mean, self.inv_std)
203 | if needs_reshape:
204 | out = tf.reshape(out, orig_shape)
205 | self.reference_output = out
206 |
207 | def __call__(self, x):
208 | shape = x.get_shape().as_list()
209 | needs_reshape = len(shape) != 4
210 | if needs_reshape:
211 | orig_shape = shape
212 | if len(shape) == 2:
213 | if self.data_format == 'NCHW':
214 | x = tf.reshape(x, [shape[0], shape[1], 0, 0])
215 | else:
216 | x = tf.reshape(x, [shape[0], 1, 1, shape[1]])
217 | elif len(shape) == 1:
218 | x = tf.reshape(x, [shape[0], 1, 1, 1])
219 | else:
220 | assert False, shape
221 | with tf.variable_scope(self.scope, reuse=True):
222 | out = self._normalize(x, self.mean, self.inv_std)
223 | if needs_reshape:
224 | out = tf.reshape(out, orig_shape)
225 | return out
226 |
227 | def _normalize(self, x, mean, inv_std):
228 | shape = x.get_shape().as_list()
229 | assert len(shape) == 4
230 | gamma = tf.get_variable("gamma", [1,shape[1],1,1],
231 | initializer=tf.constant_initializer(1.))
232 | beta = tf.get_variable("beta", [1,shape[1],1,1],
233 | initializer=tf.constant_initializer(0.))
234 | coeff = gamma * inv_std
235 | return (x * coeff) + (beta - mean * coeff)
236 |
237 | @add_arg_scope
238 | def vbn_double(ref_half, real_half, center=True, scale=True, epsilon=1e-5, \
239 | data_format='NCHW', instance_norm=True, scope=None, \
240 | reuse=None):
241 | assert isinstance(epsilon, float)
242 | shape = real_half.get_shape().as_list()
243 | batch_size = int(real_half.get_shape()[0])
244 | with tf.variable_scope(scope, 'VBN', reuse=reuse):
245 | if data_format == 'NCHW':
246 | if scale:
247 | gamma = tf.get_variable("gamma", [1,shape[1],1,1],
248 | initializer=tf.constant_initializer(1.))
249 | if center:
250 | beta = tf.get_variable("beta", [1,shape[1],1,1],
251 | initializer=tf.constant_initializer(0.))
252 | ref_mean, ref_var = tf.nn.moments(ref_half, [0,2,3], \
253 | keep_dims=True)
254 | else:
255 | if scale:
256 | gamma = tf.get_variable("gamma", [1,1,1,shape[-1]],
257 | initializer=tf.constant_initializer(1.))
258 | if center:
259 | beta = tf.get_variable("beta", [1,1,1,shape[-1]],
260 | initializer=tf.constant_initializer(0.))
261 | ref_mean, ref_var = tf.nn.moments(ref_half, [0,1,2], \
262 | keep_dims=True)
263 | def _normalize(x, mean, var):
264 | inv_std = tf.rsqrt(var + epsilon)
265 | if scale:
266 | coeff = inv_std * gamma
267 | else:
268 | coeff = inv_std
269 | if center:
270 | return (x * coeff) + (beta - mean * coeff)
271 | else:
272 | return (x - mean) * coeff
273 | if instance_norm:
274 | if data_format == 'NCHW':
275 | real_mean, real_var = tf.nn.moments(real_half, [2,3], \
276 | keep_dims=True)
277 | else:
278 | real_mean, real_var = tf.nn.moments(real_half, [1,2], \
279 | keep_dims=True)
280 | real_coeff = 1. / (batch_size + 1.)
281 | ref_coeff = 1. - real_coeff
282 | new_mean = real_coeff * real_mean + ref_coeff * ref_mean
283 | new_var = real_coeff * real_var + ref_coeff * ref_var
284 | ref_output = _normalize(ref_half, ref_mean, ref_var)
285 | real_output = _normalize(real_half, new_mean, new_var)
286 | else:
287 | ref_output = _normalize(ref_half, ref_mean, ref_var)
288 | real_output = _normalize(real_half, ref_mean, ref_var)
289 | return ref_output, real_output
290 |
291 |
292 | @add_arg_scope
293 | def vbn_single(x, center=True, scale=True, \
294 | epsilon=1e-5, data_format='NCHW', \
295 | instance_norm=True, scope=None, \
296 | reuse=None):
297 | assert isinstance(epsilon, float)
298 | shape = x.get_shape().as_list()
299 | if shape[0] is None:
300 | half_size = x.shape[0] // 2
301 | else:
302 | half_size = shape[0] // 2
303 | needs_reshape = len(shape) != 4
304 | if needs_reshape:
305 | orig_shape = shape
306 | if len(shape) == 2:
307 | if data_format == 'NCHW':
308 | x = tf.reshape(x, [shape[0], shape[1], 0, 0])
309 | else:
310 | x = tf.reshape(x, [shape[0], 1, 1, shape[1]])
311 | elif len(shape) == 1:
312 | x = tf.reshape(x, [shape[0], 1, 1, 1])
313 | else:
314 | assert False, shape
315 | shape = x.get_shape().as_list()
316 | batch_size = int(x.get_shape()[0])
317 | with tf.variable_scope(scope, 'VBN', reuse=reuse):
318 | ref_half = tf.slice(x, [0,0,0,0], [half_size, shape[1], \
319 | shape[2], shape[3]])
320 | if data_format == 'NCHW':
321 | if scale:
322 | gamma = tf.get_variable("gamma", [1,shape[1],1,1],
323 | initializer=tf.constant_initializer(1.))
324 | if center:
325 | beta = tf.get_variable("beta", [1,shape[1],1,1],
326 | initializer=tf.constant_initializer(0.))
327 | ref_mean, ref_var = tf.nn.moments(ref_half, [0,2,3], \
328 | keep_dims=True)
329 | else:
330 | if scale:
331 | gamma = tf.get_variable("gamma", [1,1,1,shape[-1]],
332 | initializer=tf.constant_initializer(1.))
333 | if center:
334 | beta = tf.get_variable("beta", [1,1,1,shape[-1]],
335 | initializer=tf.constant_initializer(0.))
336 | ref_mean, ref_var = tf.nn.moments(ref_half, [0,1,2], \
337 | keep_dims=True)
338 | def _normalize(x, mean, var):
339 | inv_std = tf.rsqrt(var + epsilon)
340 | if scale:
341 | coeff = inv_std * gamma
342 | else:
343 | coeff = inv_std
344 | if center:
345 | return (x * coeff) + (beta - mean * coeff)
346 | else:
347 | return (x - mean) * coeff
348 | if instance_norm:
349 | real_half = tf.slice(x, [half_size,0,0,0], \
350 | [half_size, shape[1], shape[2], shape[3]])
351 | if data_format == 'NCHW':
352 | real_mean, real_var = tf.nn.moments(real_half, [2,3], \
353 | keep_dims=True)
354 | else:
355 | real_mean, real_var = tf.nn.moments(real_half, [1,2], \
356 | keep_dims=True)
357 | real_coeff = 1. / (batch_size + 1.)
358 | ref_coeff = 1. - real_coeff
359 | new_mean = real_coeff * real_mean + ref_coeff * ref_mean
360 | new_var = real_coeff * real_var + ref_coeff * ref_var
361 | ref_output = _normalize(ref_half, ref_mean, ref_var)
362 | real_output = _normalize(real_half, new_mean, new_var)
363 | return tf.concat([ref_output, real_output], axis=0)
364 | else:
365 | return _normalize(x, ref_mean, ref_var)
366 | """
367 |
--------------------------------------------------------------------------------
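
For reference, a small sketch of how this conv2d wrapper is consumed in YeNet.py via tf.contrib.framework.arg_scope (the input tensor shape and hyper-parameters here are illustrative assumptions):

    import tensorflow as tf
    from tensorflow.contrib.framework import arg_scope
    import layers as my_layers

    # An NCHW float input: batch of 16 maps of 30x252x252 (e.g. the output of the SRM layer)
    inputs = tf.placeholder(tf.float32, [16, 30, 252, 252])

    with arg_scope([my_layers.conv2d],
                   num_outputs=30, kernel_size=3, stride=1, padding='VALID',
                   data_format='NCHW', activation_fn=tf.nn.relu):
        net = my_layers.conv2d(inputs, scope='Layer2')  # 3x3 conv, 30 filters, ReLU
        net = my_layers.conv2d(net, scope='Layer3')     # defaults come from the arg_scope
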
/Implement/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 | import tensorflow as tf
4 | from glob import glob
5 | from preprocessing import *
6 | from generator import *
7 | from utils import *
8 | from YeNet import YeNet
9 |
10 | # *Define the command-line arguments
11 | parser = argparse.ArgumentParser(description='Tensorflow implementation of YeNet')
12 |
13 | # *Define different command-line arguments depending on the chosen operation
14 | operation_set = ['train', 'test', 'datatransfer', 'dataaug', 'datasplit']
15 | print('Operation set ', operation_set)
16 | input_operation = input('The operation you want to perform: ')
17 | if 'train' in input_operation:
18 | parser.add_argument('train_cover_dir', type=str, metavar='PATH',
19 | help='directory of training cover images')
20 | parser.add_argument('train_stego_dir', type=str, metavar='PATH',
21 | help='directory of training stego images')
22 | parser.add_argument('valid_cover_dir', type=str, metavar='PATH',
23 | help='directory of validation cover images')
24 | parser.add_argument('valid_stego_dir', type=str, metavar='PATH',
25 | help='directory of validation stego images')
26 | if 'test' in input_operation:
27 | parser.add_argument('test_cover_dir', type=str, metavar='PATH',
28 | help='directory of testing cover images')
29 | parser.add_argument('test_stego_dir', type=str, metavar='PATH',
30 | help='directory of testing stego images')
31 | if 'datatransfer' in input_operation:
32 | parser.add_argument('source_dir', type=str, metavar='PATH',
33 | help='directory of source images')
34 | parser.add_argument('dest_dir', type=str, metavar='PATH',
35 | help='directory of destination images')
36 | parser.add_argument('--required-size', type=int, default=256, metavar='N',
37 | help='required size of destination images (default: 256)')
38 | parser.add_argument('--required-operation', type=str, default='resize,crop,subsample', metavar='S',
39 | help='transfer operation for source image (default: resize,crop,subsample)')
40 | if 'dataaug' in input_operation:
41 | parser.add_argument('source_dir', type=str, metavar='PATH',
42 | help='directory of source images')
43 | parser.add_argument('dest_dir', type=str, metavar='PATH',
44 | help='directory of destination images')
45 | parser.add_argument('--ratio-rot', type=float, default=0.5, metavar='F',
46 | help='percentage of dataset augmented by rotation (default: 0.5)')
47 | if 'datasplit' in input_operation:
48 | parser.add_argument('source_dir', type=str, metavar='PATH',
49 | help='directory of source cover and stego images')
50 | parser.add_argument('dest_dir', type=str, metavar='PATH',
51 | help='directory of separated dataset')
52 | parser.add_argument('--train-percent', type=float, default=0.6, metavar='F',
53 | help='percentage of dataset used for training (default: 0.6)')
54 | parser.add_argument('--valid-percent', type=float, default=0.2, metavar='F',
55 | help='percentage of dataset used for validation (default: 0.2)')
56 | parser.add_argument('--test-percent', type=float, default=0.2, metavar='F',
57 | help='percentage of dataset used for testing (default: 0.2)')
58 |
59 | if input_operation not in operation_set:
60 | raise NotImplementedError('invalid operation')
61 |
62 | # *Define the remaining optional command-line arguments
63 | parser.add_argument('--use-shuf-pair', action='store_true', default=False,
64 | help='matching cover and stego image when batch is constructed (default: False)')
65 | parser.add_argument('--use-batch-norm', action='store_true', default=False,
66 | help='use batch normalization after each activation (default: False)')
67 | parser.add_argument('--batch-size', type=int, default=32, metavar='N',
68 | help='input batch size for training, testing and validation (default: 32)')
69 | parser.add_argument('--max-epochs', type=int, default=200, metavar='N',
70 | help='number of epochs to train (default: 200)')
71 | parser.add_argument('--lr', type=float, default=4e-1, metavar='F',
72 | help='learning rate (default: 4e-1)')
73 | parser.add_argument('--gpu', type=str, default='0', metavar='S',
74 | help='index of gpu used (default: 0)')
75 | parser.add_argument('--tfseed', type=int, default=1, metavar='S',
76 | help='random seed (default: 1)')
77 | parser.add_argument('--log-interval', type=int, default=20, metavar='N',
78 | help='number of batches before logging training status')
79 | parser.add_argument('--log-path', type=str, default='logs/',
80 | metavar='PATH', help='directory of log file')
81 |
82 | args = parser.parse_args()
83 |
84 | import os
85 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
86 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
87 |
88 | # *Set the TensorFlow random seed
89 | tf.set_random_seed(args.tfseed)
90 |
91 | # *Run the corresponding function according to the chosen operation
92 | if 'datatransfer' in input_operation:
93 |     # *Dataset preprocessing main routine
94 | data_transfer(args.source_dir, args.dest_dir,
95 | required_size=args.required_size,
96 | required_operation=args.required_operation)
97 |
98 | if 'dataaug' in input_operation:
99 |     # *Dataset augmentation main routine
100 | data_aug(args.source_dir, args.dest_dir, args.ratio_rot)
101 |
102 | if 'datasplit' in input_operation:
103 |     # *Dataset split main routine
104 | data_split(args.source_dir, args.dest_dir,
105 | args.batch_size,
106 | train_percent=args.train_percent,
107 | valid_percent=args.valid_percent,
108 | test_percent=args.test_percent)
109 |
110 | if 'train' in input_operation:
111 |     # *Compute the train/valid dataset sizes
112 | train_ds_size = len(glob(args.train_cover_dir + '/*')) * 2
113 | if train_ds_size % args.batch_size != 0:
114 | raise ValueError('change batch size for training')
115 | valid_ds_size = len(glob(args.valid_cover_dir + '/*')) * 2
116 | if valid_ds_size % args.batch_size != 0:
117 | raise ValueError('change batch size for validation')
118 |     # *Training main routine
119 | train(YeNet, args.use_batch_norm, args.use_shuf_pair,
120 | args.train_cover_dir, args.train_stego_dir,
121 | args.valid_cover_dir, args.valid_stego_dir,
122 | args.batch_size, train_ds_size, valid_ds_size,
123 | args.log_interval, args.max_epochs, args.lr,
124 | args.log_path)
125 |
126 | if 'test' in input_operation:
127 |     # *Compute the test dataset size
128 | test_ds_size = len(glob(args.test_cover_dir + '/*')) * 2
129 | if test_ds_size % args.batch_size != 0:
130 | raise ValueError('change batch size for testing')
131 |     # *Find the best model on the test set
132 | test_dataset_findbest(YeNet, args.use_shuf_pair,
133 | args.test_cover_dir, args.test_stego_dir, args.max_epochs,
134 | args.batch_size, test_ds_size, args.log_path)
135 |
136 |
137 |
--------------------------------------------------------------------------------
/Implement/preprocessing.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from scipy import misc
4 | from glob import glob
5 | import shutil
6 | import random
7 | from random import random as rand
8 | from random import shuffle
9 |
10 | # *Dataset preprocessing main function
11 | def data_transfer(source_dir, dest_dir,
12 | required_size,
13 | required_operation):
14 | """
15 |     Transform the images in source_dir into dest_dir according to the operations given in required_operation.
16 |     Supports the resize, subsample and crop operations.
17 | """
18 | dest_dir = dest_dir + '/' + source_dir.split("/")[-1]
19 |
20 |     # Create the dataset directories
21 | size = required_size, required_size
22 | if 'resize' in required_operation:
23 | dest_resize_dir = dest_dir + '_' + str(required_size) + '_resize'
24 | if os.path.exists(dest_resize_dir + '/') is False:
25 | os.mkdir(dest_resize_dir + '/')
26 | os.mkdir(dest_resize_dir + '/cover/')
27 | if 'crop' in required_operation:
28 | dest_crop_dir = dest_dir + '_' + str(required_size) + '_crop'
29 | if os.path.exists(dest_crop_dir + '/') is False:
30 | os.mkdir(dest_crop_dir + '/')
31 | os.mkdir(dest_crop_dir + '/cover/')
32 | if 'subsample' in required_operation:
33 | dest_subsample_dir = dest_dir + '_' + str(required_size) + '_subsample'
34 | if os.path.exists(dest_subsample_dir + '/') is False:
35 | os.mkdir(dest_subsample_dir + '/')
36 | os.mkdir(dest_subsample_dir + '/cover/')
37 |
38 | source_img_list = glob(source_dir + '/*')
39 | for filename in source_img_list:
40 | img = misc.imread(filename)
41 | if img is None:
42 | raise OSError('Error: could not load image')
43 | if 'resize' in required_operation:
44 | img_resize = misc.imresize(img, size, interp='bicubic')
45 | save_dir = dest_resize_dir + '/cover/' + filename.split("/")[-1]
46 | misc.imsave(save_dir, img_resize)
47 | if 'crop' in required_operation:
48 | ROI_idx = (img.shape[0] - required_size) // 2
49 | img_crop = img[ROI_idx:ROI_idx+required_size, ROI_idx:ROI_idx+required_size]
50 | save_dir = dest_crop_dir + '/cover/' + filename.split("/")[-1]
51 | misc.imsave(save_dir, img_crop)
52 | if 'subsample' in required_operation:
53 | SUB_idx = img.shape[0] // required_size
54 | img_subsample = img[0:img.shape[0]:SUB_idx, 0:img.shape[1]:SUB_idx]
55 | save_dir = dest_subsample_dir + '/cover/' + filename.split("/")[-1]
56 | misc.imsave(save_dir, img_subsample)
57 |     print('data transfer succeeded!')
58 |
59 | # *Dataset augmentation main function
60 | def data_aug(source_dir, dest_dir, ratio=0.5):
61 | """
62 |     Augment the images from source_dir into dest_dir.
63 |     Supports the rotate and flip operations.
64 | """
65 | dest_dir = dest_dir + '/' + source_dir.split("/")[-2] + '_aug'
66 | if os.path.exists(dest_dir + '/') is False:
67 | os.mkdir(dest_dir + '/')
68 | os.mkdir(dest_dir + '/cover/')
69 |
70 | dest_dir = dest_dir + '/cover'
71 |
72 | source_img_list = glob(source_dir + '/*')
73 | for filename in source_img_list:
74 | img = misc.imread(filename)
75 | if img is None:
76 | raise OSError('Error: could not load image')
77 | filename_split = (filename.split("/")[-1])
78 | save_dir = dest_dir + '/' + filename_split
79 | misc.imsave(save_dir, img)
80 |
81 | rot = random.randint(1, 3)
82 | rand_op = rand()
83 | rand_flip = rand()
84 | if rand_op < ratio:
85 | img_rot = misc.imrotate(img, rot*90, interp='bicubic')
86 | save_dir = dest_dir + '/' + filename_split.split('.')[0] + '_rot.' + filename_split.split('.')[1]
87 | misc.imsave(save_dir, img_rot)
88 | else:
89 | if rand_flip < ratio:
90 | img_flip = np.flipud(img)
91 | else:
92 | img_flip = np.fliplr(img)
93 | save_dir = dest_dir + '/' + filename_split.split('.')[0] + '_flip.' + filename_split.split('.')[1]
94 | misc.imsave(save_dir, img_flip)
95 |     print('data augment succeeded!')
96 |
97 |
98 | # *Dataset split main function
99 | def data_split(source_dir, dest_dir,
100 | batch_size,
101 | train_percent=0.6,
102 | valid_percent=0.2,
103 | test_percent=0.2):
104 | """
105 |     Split the cover/stego images under source_dir into train/valid/test datasets
106 |     under dest_dir according to the percentage arguments.
107 |     Files are drawn at random.
108 | """
109 |     # *Check that the input percentages are valid
110 | if (train_percent + valid_percent + test_percent) > 1:
111 |         raise ValueError('sum of train/valid/test percentages is larger than 1')
112 |
113 | if os.path.exists(dest_dir + '/') is False:
114 | os.mkdir(dest_dir + '/')
115 | if os.path.exists(source_dir + '/') is False:
116 |         raise OSError('source directory does not exist')
117 |
118 | source_cover_dir = source_dir + '/cover'
119 | source_stego_dir = source_dir + '/stego'
120 |
121 |     # *Remove files that have no matching counterpart
122 | file_clean(source_cover_dir, source_stego_dir)
123 |
124 |     # *Create the train/valid/test directories under dest_dir
125 | dest_train_dir, dest_valid_dir, dest_test_dir = file_dir_mk_trainvalidtest_dir(dest_dir)
126 |
127 |     # *Shuffle the order of the files in source_dir
128 | source_cover_list = []
129 | for filename in os.listdir(source_cover_dir + '/'):
130 | source_cover_list.append(filename)
131 | shuffle(source_cover_list)
132 |
133 |     # *Compute the train/valid/test dataset capacities
134 | half_batch_size = batch_size // 2
135 | train_ds_capacity = ( int( len(source_cover_list)*train_percent ) // half_batch_size ) * half_batch_size
136 | valid_ds_capacity = ( int( len(source_cover_list)*valid_percent ) // half_batch_size ) * half_batch_size
137 | test_ds_capacity = ( int( len(source_cover_list)*test_percent ) // half_batch_size ) * half_batch_size
138 |
139 | for fileidx in range(train_ds_capacity):
140 | srcfile_cover = source_cover_dir + '/' + source_cover_list[fileidx]
141 | dstfile_cover = dest_train_dir + '/cover/' + source_cover_list[fileidx]
142 | shutil.copyfile(srcfile_cover, dstfile_cover)
143 | srcfile_stego = source_stego_dir + '/' + source_cover_list[fileidx]
144 | dstfile_stego = dest_train_dir + '/stego/' + source_cover_list[fileidx]
145 | shutil.copyfile(srcfile_stego, dstfile_stego)
146 | for fileidx in range(train_ds_capacity, train_ds_capacity + valid_ds_capacity):
147 | srcfile_cover = source_cover_dir + '/' + source_cover_list[fileidx]
148 | dstfile_cover = dest_valid_dir + '/cover/' + source_cover_list[fileidx]
149 | shutil.copyfile(srcfile_cover, dstfile_cover)
150 | srcfile_stego = source_stego_dir + '/' + source_cover_list[fileidx]
151 | dstfile_stego = dest_valid_dir + '/stego/' + source_cover_list[fileidx]
152 | shutil.copyfile(srcfile_stego, dstfile_stego)
153 | for fileidx in range(train_ds_capacity + valid_ds_capacity,
154 | train_ds_capacity + valid_ds_capacity + test_ds_capacity):
155 | srcfile_cover = source_cover_dir + '/' + source_cover_list[fileidx]
156 | dstfile_cover = dest_test_dir + '/cover/' + source_cover_list[fileidx]
157 | shutil.copyfile(srcfile_cover, dstfile_cover)
158 | srcfile_stego = source_stego_dir + '/' + source_cover_list[fileidx]
159 | dstfile_stego = dest_test_dir + '/stego/' + source_cover_list[fileidx]
160 | shutil.copyfile(srcfile_stego, dstfile_stego)
161 |     print('data split succeeded!')
162 |
163 | def file_clean(cover_dir, stego_dir):
164 | """
165 |     Clean the cover and stego directories: delete files that exist in only one of the two folders or whose names do not match.
166 | """
167 | cover_dir = cover_dir + '/'
168 | stego_dir = stego_dir + '/'
169 | cover_list = []
170 | stego_list = []
171 | for root, dirs, files in os.walk(cover_dir):
172 | for filenames in files:
173 | cover_list.append(filenames)
174 | for root, dirs, files in os.walk(stego_dir):
175 | for filenames in files:
176 | stego_list.append(filenames)
177 | diff_cover_list = set(cover_list).difference(set(stego_list))
178 | diff_stego_list = set(stego_list).difference(set(cover_list))
179 | print('Start file cleaning...')
180 | print('About to delete: ', len(diff_cover_list), 'files in ', cover_dir)
181 | for filenames in diff_cover_list:
182 | os.remove(cover_dir + filenames)
183 | print('About to delete: ', len(diff_stego_list), 'files in ', stego_dir)
184 | for filenames in diff_stego_list:
185 | os.remove(stego_dir + filenames)
186 |
187 | def file_dir_mk_trainvalidtest_dir(dest_dir):
188 | """
189 |     Create the train/valid/test directories under dest_dir
190 | """
191 | if os.path.exists(dest_dir + '/train/') is False:
192 | os.mkdir(dest_dir + '/train/')
193 | if os.path.exists(dest_dir + '/train/cover/') is False:
194 | os.mkdir(dest_dir + '/train/cover/')
195 | if os.path.exists(dest_dir + '/train/stego/') is False:
196 | os.mkdir(dest_dir + '/train/stego/')
197 | if os.path.exists(dest_dir + '/valid/') is False:
198 | os.mkdir(dest_dir + '/valid/')
199 | if os.path.exists(dest_dir + '/valid/cover/') is False:
200 | os.mkdir(dest_dir + '/valid/cover/')
201 | if os.path.exists(dest_dir + '/valid/stego/') is False:
202 | os.mkdir(dest_dir + '/valid/stego/')
203 | if os.path.exists(dest_dir + '/test/') is False:
204 | os.mkdir(dest_dir + '/test/')
205 | if os.path.exists(dest_dir + '/test/cover/') is False:
206 | os.mkdir(dest_dir + '/test/cover/')
207 | if os.path.exists(dest_dir + '/test/stego/') is False:
208 | os.mkdir(dest_dir + '/test/stego/')
209 | if os.path.exists(dest_dir + '/log/') is False:
210 | os.mkdir(dest_dir + '/log/')
211 | return dest_dir + '/train', dest_dir + '/valid', dest_dir + '/test'
212 |
--------------------------------------------------------------------------------
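
A minimal sketch of how these routines are called (main.py drives them from the command line; the paths below are placeholder assumptions, and data_split expects source_dir to contain matching cover/ and stego/ subdirectories):

    from preprocessing import data_transfer, data_split

    # Resize/crop/subsample the 512x512 covers down to 256x256
    data_transfer('Data/BOSSBase_512', 'Data',
                  required_size=256, required_operation='resize,crop,subsample')

    # Randomly split matched cover/stego pairs into train/valid/test,
    # each capacity a multiple of half the batch size
    data_split('Data/SUNI_0.4_15000', 'Experiment/SUNI_0.4_15000_No_1',
               32, train_percent=0.6, valid_percent=0.2, test_percent=0.2)
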
/Implement/testfiles/command.sh:
--------------------------------------------------------------------------------
1 | #train
2 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Experiment/TryCaenorstStructure/train/cover /home/carlchang/YeNetTensorflow/Experiment/TryCaenorstStructure/train/stego /home/carlchang/YeNetTensorflow/Experiment/TryCaenorstStructure/valid/cover /home/carlchang/YeNetTensorflow/Experiment/TryCaenorstStructure/valid/stego --log-path="/home/carlchang/YeNetTensorflow/Experiment/TryCaenorstStructure/log" --use-batch-norm --lr 4e-1 --max-epochs=300 --log-interval=24 --gpu="0,1,2,3"
3 |
4 | #test
5 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Experiment/TryCaenorstStructure/test/cover /home/carlchang/YeNetTensorflow/Experiment/TryCaenorstStructure/test/stego --log-path="/home/carlchang/YeNetTensorflow/Experiment/TryCaenorstStructure/log" --gpu="0,1,2,3"
6 |
7 | #data_split
8 | python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Data/SUNI_0.4_15000 /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2 --train-percent=0.6 --valid-percent=0.2 --test-percent=0.2 --gpu=" "
9 |
--------------------------------------------------------------------------------
/Implement/testfiles/test_data_split.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from random import shuffle
4 |
5 | # *Dataset split main function
6 | def data_split(source_dir, dest_dir,
7 | batch_size,
8 | train_percent=0.6,
9 | valid_percent=0.2,
10 | test_percent=0.2):
11 | """
12 |     Split the cover/stego images under source_dir into train/valid/test datasets
13 |     under dest_dir according to the percentage arguments.
14 |     Files are drawn at random.
15 | """
16 |     # *Check that the input percentages are valid
17 | if (train_percent + valid_percent + test_percent) > 1:
18 |         raise ValueError('sum of train/valid/test percentages is larger than 1')
19 |
20 | if os.path.exists(dest_dir + '/') is False:
21 | os.mkdir(dest_dir + '/')
22 | if os.path.exists(source_dir + '/') is False:
23 |         raise OSError('source directory does not exist')
24 |
25 | source_cover_dir = source_dir + '/cover'
26 | source_stego_dir = source_dir + '/stego'
27 |
28 |     # *Remove files that have no matching counterpart
29 | file_clean(source_cover_dir, source_stego_dir)
30 |
31 |     # *Create the train/valid/test directories under dest_dir
32 | dest_train_dir, dest_valid_dir, dest_test_dir = file_dir_mk_trainvalidtest_dir(dest_dir)
33 |
34 |     # *Shuffle the order of the files in source_dir
35 | source_cover_list = []
36 | for filename in os.listdir(source_cover_dir + '/'):
37 | source_cover_list.append(filename)
38 | shuffle(source_cover_list)
39 |
40 |     # *Compute the train/valid/test dataset capacities
41 | half_batch_size = batch_size // 2
42 | train_ds_capacity = ( int( len(source_cover_list)*train_percent ) // half_batch_size ) * half_batch_size
43 | valid_ds_capacity = ( int( len(source_cover_list)*valid_percent ) // half_batch_size ) * half_batch_size
44 | test_ds_capacity = ( int( len(source_cover_list)*test_percent ) // half_batch_size ) * half_batch_size
45 |
46 | for fileidx in range(train_ds_capacity):
47 | srcfile_cover = source_cover_dir + '/' + source_cover_list[fileidx]
48 | dstfile_cover = dest_train_dir + '/cover/' + source_cover_list[fileidx]
49 | shutil.copyfile(srcfile_cover, dstfile_cover)
50 | srcfile_stego = source_stego_dir + '/' + source_cover_list[fileidx]
51 | dstfile_stego = dest_train_dir + '/stego/' + source_cover_list[fileidx]
52 | shutil.copyfile(srcfile_stego, dstfile_stego)
53 | for fileidx in range(train_ds_capacity, train_ds_capacity + valid_ds_capacity):
54 | srcfile_cover = source_cover_dir + '/' + source_cover_list[fileidx]
55 | dstfile_cover = dest_valid_dir + '/cover/' + source_cover_list[fileidx]
56 | shutil.copyfile(srcfile_cover, dstfile_cover)
57 | srcfile_stego = source_stego_dir + '/' + source_cover_list[fileidx]
58 | dstfile_stego = dest_valid_dir + '/stego/' + source_cover_list[fileidx]
59 | shutil.copyfile(srcfile_stego, dstfile_stego)
60 | for fileidx in range(train_ds_capacity + valid_ds_capacity,
61 | train_ds_capacity + valid_ds_capacity + test_ds_capacity):
62 | srcfile_cover = source_cover_dir + '/' + source_cover_list[fileidx]
63 | dstfile_cover = dest_test_dir + '/cover/' + source_cover_list[fileidx]
64 | shutil.copyfile(srcfile_cover, dstfile_cover)
65 | srcfile_stego = source_stego_dir + '/' + source_cover_list[fileidx]
66 | dstfile_stego = dest_test_dir + '/stego/' + source_cover_list[fileidx]
67 | shutil.copyfile(srcfile_stego, dstfile_stego)
68 |
69 | def file_dir_mk_trainvalidtest_dir(dest_dir):
70 | """
71 |     Create the train/valid/test directories under dest_dir
72 | """
73 | if os.path.exists(dest_dir + '/train/') is False:
74 | os.mkdir(dest_dir + '/train/')
75 | if os.path.exists(dest_dir + '/train/cover/') is False:
76 | os.mkdir(dest_dir + '/train/cover/')
77 | if os.path.exists(dest_dir + '/train/stego/') is False:
78 | os.mkdir(dest_dir + '/train/stego/')
79 | if os.path.exists(dest_dir + '/valid/') is False:
80 | os.mkdir(dest_dir + '/valid/')
81 | if os.path.exists(dest_dir + '/valid/cover/') is False:
82 | os.mkdir(dest_dir + '/valid/cover/')
83 | if os.path.exists(dest_dir + '/valid/stego/') is False:
84 | os.mkdir(dest_dir + '/valid/stego/')
85 | if os.path.exists(dest_dir + '/test/') is False:
86 | os.mkdir(dest_dir + '/test/')
87 | if os.path.exists(dest_dir + '/test/cover/') is False:
88 | os.mkdir(dest_dir + '/test/cover/')
89 | if os.path.exists(dest_dir + '/test/stego/') is False:
90 | os.mkdir(dest_dir + '/test/stego/')
91 | if os.path.exists(dest_dir + '/log/') is False:
92 | os.mkdir(dest_dir + '/log/')
93 | return dest_dir + '/train', dest_dir + '/valid', dest_dir + '/test'
94 |
95 | def file_clean(cover_dir, stego_dir):
96 | """
97 |     Clean the cover and stego directories: delete files that exist in only one of the two folders or whose names do not match.
98 | """
99 | cover_dir = cover_dir + '/'
100 | stego_dir = stego_dir + '/'
101 | cover_list = []
102 | stego_list = []
103 | for root, dirs, files in os.walk(cover_dir):
104 | for filenames in files:
105 | cover_list.append(filenames)
106 | for root, dirs, files in os.walk(stego_dir):
107 | for filenames in files:
108 | stego_list.append(filenames)
109 | diff_cover_list = set(cover_list).difference(set(stego_list))
110 | diff_stego_list = set(stego_list).difference(set(cover_list))
111 | print('About to delete: ', len(diff_cover_list), 'files in ', cover_dir, 'Continue?')
112 | os.system('pause')
113 | for filenames in diff_cover_list:
114 | os.remove(cover_dir + filenames)
115 | print('About to delete: ', len(diff_stego_list), 'files in ', stego_dir, 'Continue?')
116 | os.system('pause')
117 | for filenames in diff_stego_list:
118 | os.remove(stego_dir + filenames)
119 | print('file_clean process has completed.')
120 |
121 | if __name__ == '__main__':
122 |     source_dir = r'E:\@ChangShihyoung\TensorFlow-YeNet\Data\SUNI_13_0.4'
123 |     dest_dir = r'E:\@ChangShihyoung\TensorFlow-YeNet\Experiment\SUNI_13_0.4_No_1'
124 | data_split(source_dir, dest_dir,
125 | 4,
126 | train_percent=0.6,
127 | valid_percent=0.2,
128 | test_percent=0.2)
--------------------------------------------------------------------------------
/Implement/testfiles/utils_test.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | from scipy import misc, io
4 | import time
5 | from glob import glob
6 | from generator import *
7 |
8 | # *average_summary class holding the loss/acc accumulator variables and their ops
9 | class average_summary(object):
10 | def __init__(self, variable, name, num_iterations):
11 |         # sum_variable: internal accumulator that sums the loss/acc from every step
12 | self.sum_variable = tf.get_variable(name, shape=[],
13 | initializer=tf.constant_initializer(0.),
14 | dtype='float32',
15 | trainable=False,
16 | collections=[tf.GraphKeys.LOCAL_VARIABLES])
17 |         # increment_op is called once per batch to accumulate that batch's loss/acc
18 | with tf.control_dependencies([variable]):
19 | self.increment_op = tf.assign_add(self.sum_variable, variable)
20 |         # Once increment_op has been run num_iterations times, the ops below can be used
21 |         self.mean_variable = self.sum_variable / float(num_iterations) # mean loss/acc
22 |         self.summary = tf.summary.scalar(name, self.mean_variable) # record the mean loss/acc as a TF summary
23 | with tf.control_dependencies([self.summary]):
24 |             self.reset_variable_op = tf.assign(self.sum_variable, 0) # reset once the summary has been written
25 |     # Called from outside to write the loss/acc summary to the TF log
26 | def add_summary(self, sess, writer, step):
27 | s, _ = sess.run([self.summary, self.reset_variable_op])
28 | writer.add_summary(s, step)
29 |
30 | # *Base class the network structure is attached to, providing the _build_model and _build_losses operations
31 | class Model(object):
32 | def __init__(self, is_training=None, data_format='NCHW'):
33 | self.data_format = data_format
34 | if is_training is None:
35 | self.is_training = tf.get_variable('is_training', dtype=tf.bool,
36 | initializer=tf.constant_initializer(True),
37 | trainable=False)
38 | else:
39 | self.is_training = is_training
40 |
41 | def _build_model(self, inputs):
42 | raise NotImplementedError('Here is your model definition')
43 |
44 | def _build_losses(self, labels):
45 | self.labels = tf.cast(labels, tf.int64)
46 | with tf.variable_scope('loss'):
47 |             oh = tf.one_hot(self.labels, 2) # two-class output is defined here
48 |             # *other loss functions can be substituted for softmax cross entropy
49 | self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
50 | labels=oh, logits=self.outputs))
51 | with tf.variable_scope('accuracy'):
52 | am = tf.argmax(self.outputs, 1)
53 | equal = tf.equal(am, self.labels)
54 | self.accuracy = tf.reduce_mean(tf.cast(equal, tf.float32))
55 | return self.loss, self.accuracy
56 |
57 | # *Training main function
58 | def train(model_class, use_batch_norm, use_shuf_pair,
59 | train_cover_dir, train_stego_dir,
60 | valid_cover_dir, valid_stego_dir,
61 | batch_size, train_ds_size, valid_ds_size,
62 | log_interval, max_epochs, lr,
63 | log_path, load_path=None):
64 |     # *Clear the default graph stack and make the global graph the default
65 | tf.reset_default_graph()
66 |
67 |     # *is_training indicates whether we are in the train or the valid phase
68 | is_training = tf.get_variable('is_training', dtype=tf.bool,
69 | initializer=True, trainable=False)
70 |
71 |     # *Ops that switch is_training into the state required for training or validation
72 | disable_training_op = tf.assign(is_training, False)
73 | enable_training_op = tf.assign(is_training, True)
74 |
75 |     # *Model initialization
76 |     # Set up the placeholders
77 | temp_cover_list = glob(train_cover_dir + '/*')
78 | temp_img = misc.imread(temp_cover_list[0])
79 | temp_img_shape = temp_img.shape
80 | img_batch = tf.placeholder(tf.float32,
81 | [batch_size, temp_img_shape[0], temp_img_shape[1], 1],
82 | name='input_image_batch')
83 | label_batch = tf.placeholder(tf.int32, [batch_size, ], name="input_label_batch")
84 |     # Initialize the model using the placeholders
85 | model = model_class(is_training, 'NCHW', with_bn=use_batch_norm, tlu_threshold=3)
86 | model._build_model(img_batch)
87 | loss, accuracy = model._build_losses(label_batch)
88 |
89 |     # *Define the loss function to be minimized
90 | regularization_losses = tf.get_collection(
91 | tf.GraphKeys.REGULARIZATION_LOSSES)
92 | regularized_loss = tf.add_n([loss] + regularization_losses)
93 |     # loss/acc summaries used during training (averaged over log_interval runs)
94 | train_loss_s = average_summary(loss, 'train_loss', log_interval)
95 | train_accuracy_s = average_summary(accuracy, 'train_accuracy', log_interval)
96 |     # loss/acc summaries used during validation (averaged over valid_ds_size / batch_size runs)
97 | valid_loss_s = average_summary(loss, 'valid_loss',
98 | float(valid_ds_size) / float(batch_size))
99 | valid_accuracy_s = average_summary(accuracy, 'valid_accuracy',
100 | float(valid_ds_size) / float(batch_size))
101 |
102 |     # *Global variable global_step, counting from 0
103 | global_step = tf.Variable(0, trainable=False)
104 |     # *Define the core optimizer
105 |     # learning-rate decay schedule
106 | init_learning_rate = lr
107 | decay_steps, decay_rate = 2000, 0.95
108 | learning_rate = learning_rate_decay(init_learning_rate=init_learning_rate,
109 | decay_method="exponential",
110 | global_step=global_step,
111 | decay_steps=decay_steps,
112 | decay_rate=decay_rate)
113 | optimizer = tf.train.AdadeltaOptimizer(learning_rate)
114 |
115 |     # *Define the ops needed during training and validation
116 |     # Core op: minimize the loss
117 | minimize_op = optimizer.minimize(loss=regularized_loss, global_step=global_step)
118 |     # Training op (run every iteration): minimize the loss; accumulate train_loss and train_acc
119 | train_op = tf.group(minimize_op, train_loss_s.increment_op,
120 | train_accuracy_s.increment_op)
121 |     # Validation op (run for every valid iteration after an epoch ends): accumulate valid_loss and valid_acc
122 | valid_op = tf.group(valid_loss_s.increment_op,
123 | valid_accuracy_s.increment_op)
124 |     # Init op: initialize all global and local variables
125 | init_op = tf.group(tf.global_variables_initializer(),
126 | tf.local_variables_initializer())
127 |
128 |     # *Model saver, keeping at most max_to_keep checkpoints
129 | saver = tf.train.Saver(max_to_keep=max_epochs+20)
130 |     global_valid_accuracy = 0 # best validation accuracy seen so far
131 |
132 |     # *Start the session
133 | with tf.Session() as sess:
134 |         # Initialize all global and local variables
135 | sess.run(init_op)
136 |         # Reload a saved model
137 | if load_path is not None:
138 | loader = tf.train.Saver(reshape=True)
139 | loader.restore(sess, load_path)
140 |         # Where the model and summaries will be saved
141 | writer = tf.summary.FileWriter(log_path + '/LogFile/', sess.graph)
142 |
143 |         # Reset the train/valid loss and acc accumulators
144 | sess.run([valid_loss_s.reset_variable_op,
145 | valid_accuracy_s.reset_variable_op,
146 | train_loss_s.reset_variable_op,
147 | train_accuracy_s.reset_variable_op])
148 |
149 | # *Training loop: train/valid per epoch
150 | print('Start training...')
151 | global_train_batch = 0 # global batch counter
152 | for epoch in range(max_epochs):
153 | start_time = time.time()
154 | train_img_list, train_label_list = get_files(train_cover_dir,
155 | train_stego_dir,
156 | use_shuf_pair=use_shuf_pair)
157 | valid_img_list, valid_label_list = get_files(valid_cover_dir,
158 | valid_stego_dir,
159 | use_shuf_pair=use_shuf_pair)
160 |
161 | # *Train phase
162 | sess.run(enable_training_op) # switch to train mode
163 | local_train_batch = 0 # per-epoch batch counter
164 | for train_img_minibatch_list, train_label_minibatch_list in \
165 | get_minibatches(train_img_list, train_label_list, batch_size):
166 | # Read the minibatch contents
167 | train_img_batch = get_minibatches_content_img(train_img_minibatch_list,
168 | temp_img_shape[0],
169 | temp_img_shape[1])
170 |
171 | # Run the train op (loss minimization plus metric accumulation)
172 | sess.run(train_op, feed_dict={img_batch: train_img_batch,
173 | label_batch: train_label_minibatch_list})
174 |
175 | global_train_batch += 1
176 | local_train_batch += 1
177 |
178 | # Every log_interval batches, report and store train_loss/acc
179 | # (their average_summary objects are defined over log_interval iterations)
180 | if global_train_batch % log_interval == 0:
181 | # Note: log_interval sets how often loss/acc are reported, not a per-batch store
182 | # Print train_loss/acc
183 | local_train_loss = train_loss_s.mean_variable
184 | local_train_accuracy = train_accuracy_s.mean_variable
185 | local_train_loss_value = local_train_loss.eval(session=sess)
186 | local_train_accuracy_value = local_train_accuracy.eval(session=sess)
187 | print('-TRAIN- epoch: %d batch: %d | train_loss: %f train_acc: %f'
188 | % (epoch, local_train_batch, local_train_loss_value, local_train_accuracy_value))
189 | # Write train_loss/acc summaries
190 | train_loss_s.add_summary(sess, writer, global_train_batch)
191 | train_accuracy_s.add_summary(sess, writer, global_train_batch)
192 |
193 | # Save checkpoints during the last 20 training batches
194 | if ((train_ds_size // batch_size) * max_epochs - global_train_batch) < 20:
195 | saver.save(sess, log_path + '/Model_' + str(epoch) + '.ckpt')
196 | print('---EPOCH:%d LAST:%d--- model has been saved'
197 | % (epoch, (train_ds_size // batch_size) * max_epochs - global_train_batch + 1))
198 |
199 | # *Validation phase
200 | sess.run(disable_training_op)
201 | local_valid_loss, local_valid_accuracy = 0, 0 # valid_loss and valid_acc for this epoch
202 | for valid_img_minibatch_list, valid_label_minibatch_list in \
203 | get_minibatches(valid_img_list, valid_label_list, batch_size):
204 | # Read the minibatch contents
205 | valid_img_batch = get_minibatches_content_img(valid_img_minibatch_list,
206 | temp_img_shape[0],
207 | temp_img_shape[1])
208 |
209 | # Run the valid op and accumulate metrics
210 | sess.run(valid_op, feed_dict={img_batch: valid_img_batch,
211 | label_batch: valid_label_minibatch_list})
212 |
213 | # After all valid batches of the epoch, report and store valid_loss/acc
214 | # (their average_summary objects are defined over valid_ds_size / batch_size iterations)
215 | # Print valid_loss/acc
216 | local_valid_loss = valid_loss_s.mean_variable
217 | local_valid_accuracy = valid_accuracy_s.mean_variable
218 | local_valid_loss_value = local_valid_loss.eval(session=sess)
219 | local_valid_accuracy_value = local_valid_accuracy.eval(session=sess)
220 | print('-VALID- epoch: %d | valid_loss: %f valid_acc: %f'
221 | % (epoch, local_valid_loss_value, local_valid_accuracy_value))
222 | # Write valid_loss/acc summaries
223 | valid_loss_s.add_summary(sess, writer, global_train_batch)
224 | valid_accuracy_s.add_summary(sess, writer, global_train_batch)
225 |
226 | # *Checkpoint: save the model if valid_acc beats the best so far
227 | if local_valid_accuracy_value > global_valid_accuracy:
228 | global_valid_accuracy = local_valid_accuracy_value
229 | saver.save(sess, log_path + '/Model_' + str(epoch) + '.ckpt')
230 | print('---EPOCH:%d--- model has been saved' % (epoch))
231 |
232 | # *Train and valid for this epoch are done; record the elapsed time
233 | end_time = time.time()
234 | print('--EPOCH:%d-- runtime: %.2fs ' % (epoch, end_time - start_time),
235 | ' learning rate: ', sess.run(learning_rate), '\n')
236 |
237 | # *Main test function: evaluate saved models and find the best one
238 | def test_dataset_findbest(model_class, use_batch_norm, use_shuf_pair,
239 | test_cover_dir, test_stego_dir, max_epochs,
240 | batch_size, ds_size, log_path):
241 | tf.reset_default_graph()
242 |
243 | # *Model initialization
244 | # Set up the input placeholders
245 | temp_cover_list = glob(test_cover_dir + '/*')
246 | temp_img = misc.imread(temp_cover_list[0])
247 | temp_img_shape = temp_img.shape
248 | img_batch = tf.placeholder(tf.float32,
249 | [batch_size, temp_img_shape[0], temp_img_shape[1], 1],
250 | name='input_image_batch')
251 | label_batch = tf.placeholder(tf.int32, [batch_size, ], name="input_label_batch")
252 | # Build the model on the placeholders
253 | model = model_class(is_training=False, data_format='NCHW',
254 | with_bn=use_batch_norm, tlu_threshold=3)
255 | model._build_model(img_batch)
256 | loss, accuracy = model._build_losses(label_batch)
257 |
258 | # *Losses to evaluate; test_loss/acc play the same role as valid_loss/acc
259 | # Running loss/acc summaries for test (averaged over ds_size / batch_size iterations)
260 | test_loss_s = average_summary(loss, 'test_loss',
261 | float(ds_size) / float(batch_size))
262 | test_accuracy_s = average_summary(accuracy, 'test_accuracy',
263 | float(ds_size) / float(batch_size))
264 | # Test op (run for every test iteration): accumulate test_loss/test_acc
265 | test_op = tf.group(test_loss_s.increment_op,
266 | test_accuracy_s.increment_op)
267 |
268 | # *Global step counter, starting from 0
269 | global_step = tf.Variable(0, trainable=False)
270 |
271 | # Init op: initialize all global and local variables
272 | init_op = tf.group(tf.global_variables_initializer(),
273 | tf.local_variables_initializer())
274 |
275 | # *Saver that keeps at most max_to_keep checkpoints
276 | saver = tf.train.Saver(max_to_keep=max_epochs)
277 |
278 | # *Record the loss and acc obtained from each test run
279 | test_loss_arr = []
280 | test_accuracy_arr = []
281 |
282 | # *Run the test pass for every model in load_model_path_s
283 | print('Start testing...')
284 | # Search the log path for all loadable checkpoint files
285 | load_model_path_s = sorted(glob(log_path + '/*.data*'))
286 | for load_model_path in load_model_path_s:
287 | start_time = time.time()
288 | # *Start the session
289 | with tf.Session() as sess:
290 | # Initialize all global and local variables
291 | sess.run(init_op)
292 | # Reload the model; strip the trailing .data-000... suffix to get the checkpoint prefix
293 | saver.restore(sess, load_model_path[0:load_model_path.find('.data-')])
294 | # Reset the test loss and acc accumulators
295 | sess.run([test_loss_s.reset_variable_op,
296 | test_accuracy_s.reset_variable_op])
297 | # Load the img and label lists from the test directories
298 | test_img_list, test_label_list = get_files(test_cover_dir,
299 | test_stego_dir,
300 | use_shuf_pair=use_shuf_pair)
301 | # *Run the test pass for the current checkpoint
302 | for test_img_minibatch_list, test_label_minibatch_list in \
303 | get_minibatches(test_img_list, test_label_list, batch_size):
304 | # Read the minibatch contents
305 | test_img_batch = get_minibatches_content_img(test_img_minibatch_list,
306 | temp_img_shape[0],
307 | temp_img_shape[1])
308 | # Accumulate the loss and acc from each test minibatch
309 | sess.run(test_op, feed_dict={img_batch: test_img_batch,
310 | label_batch: test_label_minibatch_list})
311 | # *Record the mean loss and acc for the current checkpoint
312 | test_mean_loss, test_mean_accuracy = sess.run([test_loss_s.mean_variable,
313 | test_accuracy_s.mean_variable])
314 | test_loss_arr.append(test_mean_loss)
315 | test_accuracy_arr.append(test_mean_accuracy)
316 | end_time = time.time()
317 | print(load_model_path.split("/")[-1])
318 | print('-TEST- test_loss: %f test_acc: %f | runtime: %.2fs \n'
319 | % (test_loss_arr[-1], test_accuracy_arr[-1], end_time - start_time))
320 |
321 | # *Index of the model with the best test_acc
322 | load_best_model_idx = np.argmax(test_accuracy_arr)
323 | print('-BEST TEST- best_path: ', load_model_path_s[load_best_model_idx])
324 | print('-BEST TEST- best_loss: %f best_acc: %f \n'
325 | % (test_loss_arr[load_best_model_idx], test_accuracy_arr[load_best_model_idx]))
326 |
327 | return load_model_path_s[load_best_model_idx]
328 |
329 |
330 | # *Learning-rate decay helper supporting several decay methods
331 | def learning_rate_decay(init_learning_rate, global_step, decay_steps, decay_rate,
332 | decay_method="exponential", staircase=False,
333 | end_learning_rate=0.0001, power=1.0, cycle=False,):
334 | """
335 | 传入初始learning_rate,根据参数及选项运用不同decay策略更新learning_rate
336 | learning_rate : 初始的learning rate
337 | global_step : 全局的step,与 decay_step 和 decay_rate一起决定了 learning rate的变化
338 | staircase : 如果为 True global_step/decay_step 向下取整
339 | end_learning_rate,power,cycle:只在polynomial_decay方法中使用
340 | """
341 | if decay_method == 'constant':
342 | decayed_learning_rate = init_learning_rate
343 | elif decay_method == 'exponential':
344 | decayed_learning_rate = tf.train.exponential_decay(init_learning_rate, global_step, decay_steps, decay_rate, staircase)
345 | elif decay_method == 'inverse_time':
346 | decayed_learning_rate = tf.train.inverse_time_decay(init_learning_rate, global_step, decay_steps, decay_rate, staircase)
347 | elif decay_method == 'natural_exp':
348 | decayed_learning_rate = tf.train.natural_exp_decay(init_learning_rate, global_step, decay_steps, decay_rate, staircase)
349 | elif decay_method == 'polynomial':
350 | decayed_learning_rate = tf.train.polynomial_decay(init_learning_rate, global_step, decay_steps, end_learning_rate, power, cycle) # polynomial_decay takes no decay_rate argument
351 | else:
352 | decayed_learning_rate = init_learning_rate
353 |
354 | return decayed_learning_rate
355 |
356 |
357 |
358 |
359 | def find_best(model_class, valid_gen, test_gen, valid_batch_size, \
360 | test_batch_size, valid_ds_size, test_ds_size, load_paths):
361 | tf.reset_default_graph()
362 | valid_runner = GeneratorRunner(valid_gen, valid_batch_size * 30)
363 | img_batch, label_batch = valid_runner.get_batched_inputs(valid_batch_size)
364 | model = model_class(False, 'NCHW')
365 | model._build_model(img_batch)
366 | loss, accuracy = model._build_losses(label_batch)
367 | loss_summary = average_summary(loss, 'loss', \
368 | float(valid_ds_size) \
369 | / float(valid_batch_size))
370 | accuracy_summary = average_summary(accuracy, 'accuracy', \
371 | float(valid_ds_size) \
372 | / float(valid_batch_size))
373 | increment_op = tf.group(loss_summary.increment_op, \
374 | accuracy_summary.increment_op)
375 | global_step = tf.get_variable('global_step', dtype=tf.int32, shape=[], \
376 | initializer=tf.constant_initializer(0), \
377 | trainable=False)
378 | init_op = tf.group(tf.global_variables_initializer(), \
379 | tf.local_variables_initializer())
380 | saver = tf.train.Saver(max_to_keep=10000)
381 | accuracy_arr = []
382 | loss_arr = []
383 | print("validation")
384 | for load_path in load_paths:
385 | with tf.Session() as sess:
386 | sess.run(init_op)
387 | saver.restore(sess, load_path) # load_path = './model/checkpoint/model.ckpt'
388 | valid_runner.start_threads(sess, 1)
389 | _time = time.time()
390 | for j in range(0, valid_ds_size, valid_batch_size):
391 | sess.run(increment_op)
392 | mean_loss, mean_accuracy = sess.run([loss_summary.mean_variable ,\
393 | accuracy_summary.mean_variable])
394 | accuracy_arr.append(mean_accuracy)
395 | loss_arr.append(mean_loss)
396 | print(load_path)
397 | print("Accuracy:", accuracy_arr[-1], "| Loss:", loss_arr[-1], \
398 | "in", time.time() - _time, "seconds.")
399 | argmax = np.argmax(accuracy_arr)
400 | print("best savestate:", load_paths[argmax], "with", \
401 | accuracy_arr[argmax], "accuracy and", loss_arr[argmax], \
402 | "loss on validation")
403 | print("test:")
404 | test_dataset(model_class, test_gen, test_batch_size, test_ds_size, \
405 | load_paths[argmax])
406 | return argmax, accuracy_arr, loss_arr
407 | """按照train方式改动
408 | """
--------------------------------------------------------------------------------
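
Editor's note: train() above uses the exponential schedule with decay_steps=2000 and decay_rate=0.95, which tf.train.exponential_decay evaluates as init_lr * decay_rate ** (global_step / decay_steps) (the exponent is floored when staircase=True). A minimal pure-Python sketch of that arithmetic, with the illustrative lr of 4e-1 taken from the example commands:

def exponential_decay_sketch(init_lr, global_step, decay_steps=2000,
                             decay_rate=0.95, staircase=False):
    # Mirrors the formula applied by tf.train.exponential_decay.
    exponent = global_step // decay_steps if staircase else global_step / float(decay_steps)
    return init_lr * decay_rate ** exponent

# After 10000 global batches: 0.4 * 0.95**5 ~= 0.3095
print(exponential_decay_sketch(0.4, 10000))
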
/Implement/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | from scipy import misc, io
4 | import time
5 | from glob import glob
6 | from generator import *
7 |
8 | # *average_summary class holding the loss/acc accumulator variables and ops
9 | class average_summary(object):
10 | def __init__(self, variable, name, num_iterations):
11 | # sum_variable: internal accumulator that sums the loss/acc of each iteration
12 | self.sum_variable = tf.get_variable(name, shape=[],
13 | initializer=tf.constant_initializer(0.),
14 | dtype='float32',
15 | trainable=False,
16 | collections=[tf.GraphKeys.LOCAL_VARIABLES])
17 | # increment_op is run once per batch to add the current loss/acc to the sum
18 | with tf.control_dependencies([variable]):
19 | self.increment_op = tf.assign_add(self.sum_variable, variable)
20 | # Once increment_op has been run num_iterations times, the ops below apply
21 | self.mean_variable = self.sum_variable / float(num_iterations) # mean loss/acc
22 | self.summary = tf.summary.scalar(name, self.mean_variable) # scalar summary of the mean
23 | with tf.control_dependencies([self.summary]):
24 | self.reset_variable_op = tf.assign(self.sum_variable, 0) # reset the accumulator after the summary is written
25 | # Called externally to write the loss/acc summary and reset the accumulator
26 | def add_summary(self, sess, writer, step):
27 | s, _ = sess.run([self.summary, self.reset_variable_op])
28 | writer.add_summary(s, step)
29 |
30 | # *Base class a network hangs its structure on; provides the _build_model and _build_losses hooks
31 | class Model(object):
32 | def __init__(self, is_training=None, data_format='NCHW'):
33 | self.data_format = data_format
34 | if is_training is None:
35 | self.is_training = tf.get_variable('is_training', dtype=tf.bool,
36 | initializer=tf.constant_initializer(True),
37 | trainable=False)
38 | else:
39 | self.is_training = is_training
40 |
41 | def _build_model(self, inputs):
42 | raise NotImplementedError('Here is your model definition')
43 |
44 | def _build_losses(self, labels):
45 | self.labels = tf.cast(labels, tf.int64)
46 | with tf.variable_scope('loss'):
47 | oh = tf.one_hot(self.labels, 2) # two-class one-hot labels
48 | # *Other loss functions could be substituted for softmax cross entropy
49 | self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
50 | labels=oh, logits=self.outputs))
51 | with tf.variable_scope('accuracy'):
52 | am = tf.argmax(self.outputs, 1)
53 | equal = tf.equal(am, self.labels)
54 | self.accuracy = tf.reduce_mean(tf.cast(equal, tf.float32))
55 | return self.loss, self.accuracy
56 |
57 | # *Main training function
58 | def train(model_class, use_batch_norm, use_shuf_pair,
59 | train_cover_dir, train_stego_dir,
60 | valid_cover_dir, valid_stego_dir,
61 | batch_size, train_ds_size, valid_ds_size,
62 | log_interval, max_epochs, lr,
63 | log_path, load_path=None):
64 | # *Clear the default graph stack and start from a fresh default graph
65 | tf.reset_default_graph()
66 |
67 | # *is_training indicates whether the run is in the train or the valid phase
68 | is_training = tf.get_variable('is_training', dtype=tf.bool,
69 | initializer=True, trainable=False)
70 |
71 | # *Ops that toggle is_training between the train and valid phases
72 | disable_training_op = tf.assign(is_training, False)
73 | enable_training_op = tf.assign(is_training, True)
74 |
75 | # *Model initialization
76 | # Set up the input placeholders
77 | temp_cover_list = glob(train_cover_dir + '/*')
78 | temp_img = misc.imread(temp_cover_list[0])
79 | temp_img_shape = temp_img.shape
80 | img_batch = tf.placeholder(tf.float32,
81 | [batch_size, temp_img_shape[0], temp_img_shape[1], 1],
82 | name='input_image_batch')
83 | label_batch = tf.placeholder(tf.int32, [batch_size, ], name="input_label_batch")
84 | # Build the model on the placeholders
85 | model = model_class(is_training=is_training, data_format='NCHW',
86 | with_bn=use_batch_norm, tlu_threshold=3)
87 | model._build_model(img_batch)
88 | loss, accuracy = model._build_losses(label_batch)
89 |
90 | # *Assemble the loss to be minimized
91 | regularization_losses = tf.get_collection(
92 | tf.GraphKeys.REGULARIZATION_LOSSES)
93 | regularized_loss = tf.add_n([loss] + regularization_losses)
94 | # Running loss/acc summaries for train (averaged over log_interval iterations)
95 | train_loss_s = average_summary(loss, 'train_loss', log_interval)
96 | train_accuracy_s = average_summary(accuracy, 'train_accuracy', log_interval)
97 | # Running loss/acc summaries for valid (averaged over valid_ds_size / batch_size iterations)
98 | valid_loss_s = average_summary(loss, 'valid_loss',
99 | float(valid_ds_size) / float(batch_size))
100 | valid_accuracy_s = average_summary(accuracy, 'valid_accuracy',
101 | float(valid_ds_size) / float(batch_size))
102 |
103 | # *Global step counter, starting from 0
104 | global_step = tf.Variable(0, trainable=False)
105 | # *Define the optimizer
106 | # Learning-rate decay schedule
107 | init_learning_rate = lr
108 | decay_steps, decay_rate = 2000, 0.95
109 | learning_rate = learning_rate_decay(init_learning_rate=init_learning_rate,
110 | decay_method="exponential",
111 | global_step=global_step,
112 | decay_steps=decay_steps,
113 | decay_rate=decay_rate)
114 | optimizer = tf.train.AdadeltaOptimizer(learning_rate)
115 |
116 | # *Ops used during the train and valid phases
117 | # Core op: minimize the regularized loss
118 | minimize_op = optimizer.minimize(loss=regularized_loss, global_step=global_step)
119 | # Train op (run every iteration): minimize the loss and accumulate train_loss/train_acc
120 | train_op = tf.group(minimize_op, train_loss_s.increment_op,
121 | train_accuracy_s.increment_op)
122 | # Valid op (run for every valid iteration after each epoch): accumulate valid_loss/valid_acc
123 | valid_op = tf.group(valid_loss_s.increment_op,
124 | valid_accuracy_s.increment_op)
125 | # Init op: initialize all global and local variables
126 | init_op = tf.group(tf.global_variables_initializer(),
127 | tf.local_variables_initializer())
128 |
129 | # *Saver that keeps at most max_to_keep checkpoints
130 | saver = tf.train.Saver(max_to_keep=max_epochs)
131 | global_valid_accuracy = 0 # best valid_acc seen so far
132 |
133 | # *Start the session
134 | with tf.Session() as sess:
135 | # Initialize all global and local variables
136 | sess.run(init_op)
137 | # Reload a previously saved model if a load path is given
138 | if load_path is not None:
139 | loader = tf.train.Saver(reshape=True)
140 | loader.restore(sess, load_path)
141 | # Summary writer for TensorBoard logs under log_path
142 | writer = tf.summary.FileWriter(log_path + '/LogFile/', sess.graph)
143 |
144 | # 初始化train/valid的loss和acc变量
145 | sess.run([valid_loss_s.reset_variable_op,
146 | valid_accuracy_s.reset_variable_op,
147 | train_loss_s.reset_variable_op,
148 | train_accuracy_s.reset_variable_op])
149 |
150 | # *Training loop: train/valid per epoch
151 | print('Start training...')
152 | global_train_batch = 0 # global batch counter
153 | for epoch in range(max_epochs):
154 | start_time = time.time()
155 | # Load the img and label lists from the train directories
156 | train_img_list, train_label_list = get_files(train_cover_dir,
157 | train_stego_dir,
158 | use_shuf_pair=use_shuf_pair)
159 | # Load the img and label lists from the valid directories
160 | valid_img_list, valid_label_list = get_files(valid_cover_dir,
161 | valid_stego_dir,
162 | use_shuf_pair=use_shuf_pair)
163 |
164 | # *Train phase
165 | sess.run(enable_training_op) # switch to train mode
166 | local_train_batch = 0 # per-epoch batch counter
167 | for train_img_minibatch_list, train_label_minibatch_list in \
168 | get_minibatches(train_img_list, train_label_list, batch_size):
169 | # Read the minibatch contents
170 | train_img_batch = get_minibatches_content_img(train_img_minibatch_list,
171 | temp_img_shape[0],
172 | temp_img_shape[1])
173 |
174 | # Run the train op (loss minimization plus metric accumulation)
175 | sess.run(train_op, feed_dict={img_batch: train_img_batch,
176 | label_batch: train_label_minibatch_list})
177 |
178 | global_train_batch += 1
179 | local_train_batch += 1
180 |
181 | # Every log_interval batches, report and store train_loss/acc
182 | # (their average_summary objects are defined over log_interval iterations)
183 | if global_train_batch % log_interval == 0:
184 | # Note: log_interval sets how often loss/acc are reported, not a per-batch store
185 | # Print train_loss/acc
186 | local_train_loss = train_loss_s.mean_variable
187 | local_train_accuracy = train_accuracy_s.mean_variable
188 | local_train_loss_value = local_train_loss.eval(session=sess)
189 | local_train_accuracy_value = local_train_accuracy.eval(session=sess)
190 | print('-TRAIN- epoch: %d batch: %d | train_loss: %f train_acc: %f'
191 | % (epoch, local_train_batch, local_train_loss_value, local_train_accuracy_value))
192 | # Write train_loss/acc summaries
193 | train_loss_s.add_summary(sess, writer, global_train_batch)
194 | train_accuracy_s.add_summary(sess, writer, global_train_batch)
195 |
196 | # *Validation phase
197 | sess.run(disable_training_op)
198 | local_valid_loss, local_valid_accuracy = 0, 0 # valid_loss and valid_acc for this epoch
199 | for valid_img_minibatch_list, valid_label_minibatch_list in \
200 | get_minibatches(valid_img_list, valid_label_list, batch_size):
201 | # Read the minibatch contents
202 | valid_img_batch = get_minibatches_content_img(valid_img_minibatch_list,
203 | temp_img_shape[0],
204 | temp_img_shape[1])
205 |
206 | # Run the valid op and accumulate metrics
207 | sess.run(valid_op, feed_dict={img_batch: valid_img_batch,
208 | label_batch: valid_label_minibatch_list})
209 |
210 | # After all valid batches of the epoch, report and store valid_loss/acc
211 | # (their average_summary objects are defined over valid_ds_size / batch_size iterations)
212 | # Print valid_loss/acc
213 | local_valid_loss = valid_loss_s.mean_variable
214 | local_valid_accuracy = valid_accuracy_s.mean_variable
215 | local_valid_loss_value = local_valid_loss.eval(session=sess)
216 | local_valid_accuracy_value = local_valid_accuracy.eval(session=sess)
217 | print('-VALID- epoch: %d | valid_loss: %f valid_acc: %f'
218 | % (epoch, local_valid_loss_value, local_valid_accuracy_value))
219 | # Write valid_loss/acc summaries
220 | valid_loss_s.add_summary(sess, writer, global_train_batch)
221 | valid_accuracy_s.add_summary(sess, writer, global_train_batch)
222 |
223 | # *Checkpoint: save if valid_acc beats the best so far, or during the last few epochs
224 | if local_valid_accuracy_value > global_valid_accuracy or (max_epochs - epoch) < 5:
225 | global_valid_accuracy = local_valid_accuracy_value
226 | saver.save(sess, log_path + '/Model_' + str(epoch) + '.ckpt')
227 | print('---EPOCH:%d--- model has been saved' % epoch)
228 |
229 | # *Train and valid for this epoch are done; record the elapsed time
230 | end_time = time.time()
231 | print('--EPOCH:%d-- runtime: %.2fs ' % (epoch, end_time - start_time),
232 | ' learning rate: ', sess.run(learning_rate), '\n')
233 |
234 | # *Main test function: evaluate saved models and find the best one
235 | def test_dataset_findbest(model_class, use_shuf_pair,
236 | test_cover_dir, test_stego_dir, max_epochs,
237 | batch_size, ds_size, log_path):
238 | tf.reset_default_graph()
239 |
240 | # *Model initialization
241 | # Set up the input placeholders
242 | temp_cover_list = glob(test_cover_dir + '/*')
243 | temp_img = misc.imread(temp_cover_list[0])
244 | temp_img_shape = temp_img.shape
245 | img_batch = tf.placeholder(tf.float32,
246 | [batch_size, temp_img_shape[0], temp_img_shape[1], 1],
247 | name='input_image_batch')
248 | label_batch = tf.placeholder(tf.int32, [batch_size, ], name="input_label_batch")
249 | # Build the model on the placeholders
250 | model = model_class(is_training=False, data_format='NCHW', with_bn=True, tlu_threshold=3)
251 | model._build_model(img_batch)
252 | loss, accuracy = model._build_losses(label_batch)
253 |
254 | # *Losses to evaluate; test_loss/acc play the same role as valid_loss/acc
255 | # Running loss/acc summaries for test (averaged over ds_size / batch_size iterations)
256 | test_loss_s = average_summary(loss, 'test_loss',
257 | float(ds_size) / float(batch_size))
258 | test_accuracy_s = average_summary(accuracy, 'test_accuracy',
259 | float(ds_size) / float(batch_size))
260 | # Test op (run for every test iteration): accumulate test_loss/test_acc
261 | test_op = tf.group(test_loss_s.increment_op,
262 | test_accuracy_s.increment_op)
263 |
264 | # Init op: initialize all global and local variables
265 | init_op = tf.group(tf.global_variables_initializer(),
266 | tf.local_variables_initializer())
267 |
268 | # *Saver used to restore the saved checkpoints
269 | saver = tf.train.Saver()
270 |
271 | # *Record the loss and acc obtained from each test run
272 | test_loss_arr = []
273 | test_accuracy_arr = []
274 |
275 | # *Run the test pass for every model in load_model_path_s
276 | print('Start testing...')
277 | # Search the log path for all loadable checkpoint files
278 | load_model_path_s = glob(log_path + '/*.data*')
279 | for load_model_path in load_model_path_s:
280 | start_time = time.time()
281 | # *Start the session
282 | with tf.Session() as sess:
283 | # Initialize all global and local variables
284 | sess.run(init_op)
285 | # Reload the model; strip the trailing .data-000... suffix to get the checkpoint prefix
286 | trunc_str = '.data-'
287 | load_model_path_trunc = load_model_path[0:load_model_path.find(trunc_str)]
288 | saver.restore(sess, load_model_path_trunc)
289 | # Reset the test loss and acc accumulators
290 | sess.run([test_loss_s.reset_variable_op,
291 | test_accuracy_s.reset_variable_op])
292 | # Load the img and label lists from the test directories
293 | test_img_list, test_label_list = get_files(test_cover_dir,
294 | test_stego_dir,
295 | use_shuf_pair=use_shuf_pair)
296 | # *Run the test pass for the current checkpoint
297 | for test_img_minibatch_list, test_label_minibatch_list in \
298 | get_minibatches(test_img_list, test_label_list, batch_size):
299 | # Read the minibatch contents
300 | test_img_batch = get_minibatches_content_img(test_img_minibatch_list,
301 | temp_img_shape[0],
302 | temp_img_shape[1])
303 | # Accumulate the loss and acc from each test minibatch
304 | sess.run(test_op, feed_dict={img_batch: test_img_batch,
305 | label_batch: test_label_minibatch_list})
306 | # *Record the mean loss and acc for the current checkpoint
307 | test_mean_loss, test_mean_accuracy = sess.run([test_loss_s.mean_variable,
308 | test_accuracy_s.mean_variable])
309 | test_loss_arr.append(test_mean_loss)
310 | test_accuracy_arr.append(test_mean_accuracy)
311 | end_time = time.time()
312 | print(load_model_path.split("/")[-1])
313 | print('-TEST- test_loss: %f test_acc: %f | runtime: %.2fs \n'
314 | % (test_loss_arr[-1], test_accuracy_arr[-1], end_time - start_time))
315 |
316 | # *Index of the model with the best test_acc
317 | load_best_model_idx = np.argmax(test_accuracy_arr)
318 | print('-BEST TEST- best_path: ', load_model_path_s[load_best_model_idx])
319 | print('-BEST TEST- best_loss: %f best_acc: %f \n'
320 | % (test_loss_arr[load_best_model_idx], test_accuracy_arr[load_best_model_idx]))
321 |
322 | return load_model_path_s[load_best_model_idx]
323 |
324 |
325 | # *Learning-rate decay helper supporting several decay methods
326 | def learning_rate_decay(init_learning_rate, global_step, decay_steps, decay_rate,
327 | decay_method="exponential", staircase=False,
328 | end_learning_rate=0.0001, power=1.0, cycle=False,):
329 | """
330 | 传入初始learning_rate,根据参数及选项运用不同decay策略更新learning_rate
331 | learning_rate : 初始的learning rate
332 | global_step : 全局的step,与 decay_step 和 decay_rate一起决定了 learning rate的变化
333 | staircase : 如果为 True global_step/decay_step 向下取整
334 | end_learning_rate,power,cycle:只在polynomial_decay方法中使用
335 | """
336 | if decay_method == 'constant':
337 | decayed_learning_rate = init_learning_rate
338 | elif decay_method == 'exponential':
339 | decayed_learning_rate = tf.train.exponential_decay(init_learning_rate, global_step,
340 | decay_steps, decay_rate, staircase)
341 | elif decay_method == 'inverse_time':
342 | decayed_learning_rate = tf.train.inverse_time_decay(init_learning_rate, global_step,
343 | decay_steps, decay_rate, staircase)
344 | elif decay_method == 'natural_exp':
345 | decayed_learning_rate = tf.train.natural_exp_decay(init_learning_rate, global_step,
346 | decay_steps, decay_rate, staircase)
347 | elif decay_method == 'polynomial':
348 | decayed_learning_rate = tf.train.polynomial_decay(init_learning_rate, global_step,
349 | decay_steps, end_learning_rate,
350 | power, cycle) # polynomial_decay takes no decay_rate argument
351 | else:
352 | decayed_learning_rate = init_learning_rate
353 |
354 | return decayed_learning_rate
355 |
356 |
--------------------------------------------------------------------------------
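
Editor's note: the average_summary class above is the bookkeeping core of train() and test_dataset_findbest(): increment_op is run once per batch, mean_variable divides the running sum by the expected number of iterations, and add_summary() writes the mean and resets the accumulator. A minimal pure-Python analogue of that contract, for illustration only (the TensorFlow version above is what the code actually uses):

class AverageSummarySketch(object):
    """Sum a scalar over num_iterations calls, expose the mean, then reset."""
    def __init__(self, num_iterations):
        self.num_iterations = float(num_iterations)
        self.total = 0.0

    def increment(self, value):   # analogue of increment_op, run once per batch
        self.total += value

    def mean(self):               # analogue of mean_variable
        return self.total / self.num_iterations

    def add_summary(self):        # analogue of add_summary: report the mean, then reset
        mean = self.mean()
        self.total = 0.0
        return mean

s = AverageSummarySketch(num_iterations=4)
for batch_loss in (0.9, 0.8, 0.7, 0.6):
    s.increment(batch_loss)
print(s.add_summary())  # 0.75
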
/README.md:
--------------------------------------------------------------------------------
1 | # TensorFlow-YeNet
2 | A TensorFlow implementation of "Deep Learning Hierarchical Representations for Image Steganalysis"
3 |
4 | ## Usage
5 | Example commands (*.sh) can be found in the root directory.
6 |
7 | ## Publication
8 | Ye, Jian, J. Ni, and Y. Yi.
9 | "Deep Learning Hierarchical Representations for Image Steganalysis."
10 | IEEE Transactions on Information Forensics & Security 12.11 (2017): 2545-2557.
11 | [**publication page**](http://ieeexplore.ieee.org/document/7937836/)
12 |
--------------------------------------------------------------------------------
/command_BOSS.sh:
--------------------------------------------------------------------------------
1 | ####data_transfer
2 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Data/BOSS /home/carlchang/YeNetTensorflow/DataTransfer --required-size 256 --required-operation="resize,crop,subsample"
3 | ####data_aug
4 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/DataTransfer/BOSS_256_resize/cover /home/carlchang/YeNetTensorflow/DataAug --ratio-rot=0.5
5 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/DataTransfer/BOSS_256_crop/cover /home/carlchang/YeNetTensorflow/DataAug --ratio-rot=0.5
6 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/DataTransfer/BOSS_256_subsample/cover /home/carlchang/YeNetTensorflow/DataAug --ratio-rot=0.5
7 | ####data_split
8 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Data/SUNI_0.4_15000 /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2 --train-percent=0.6 --valid-percent=0.2 --test-percent=0.2 --gpu=" "
9 | ####train
10 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/train/cover /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/train/stego /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/valid/cover /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/valid/stego --log-path="/home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/log" --use-batch-norm --lr 4e-1 --max-epochs=200 --log-interval=20 --gpu="2,3"
11 | ####test
12 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/test/cover /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/test/stego --log-path="/home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/log" --gpu="2,3"
13 |
--------------------------------------------------------------------------------
/command_BOSSTEST.sh:
--------------------------------------------------------------------------------
1 | #data_transfer
2 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Data/BOSSTEST /home/carlchang/YeNetTensorflow/DataTransfer --required-size 256 --required-operation="resize,crop,subsample"
3 |
4 | #data_aug
5 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/DataTransfer/BOSSTEST_256_resize/cover /home/carlchang/YeNetTensorflow/DataAug --ratio-rot=0.5
6 |
7 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/DataTransfer/BOSSTEST_256_crop/cover /home/carlchang/YeNetTensorflow/DataAug --ratio-rot=0.5
8 |
9 | python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/DataTransfer/BOSSTEST_256_subsample/cover /home/carlchang/YeNetTensorflow/DataAug --ratio-rot=0.5
10 |
11 | #train
12 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/train/cover /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/train/stego /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/valid/cover /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/valid/stego --log-path="/home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/log" --use-batch-norm --lr 4e-1 --max-epochs=200 --log-interval=20 --gpu="2,3"
13 |
14 | #test
15 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/test/cover /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/test/stego --log-path="/home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/log" --gpu="2,3"
16 |
17 | #data_split
18 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Data/SUNI_0.4_15000 /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2 --train-percent=0.6 --valid-percent=0.2 --test-percent=0.2 --gpu=" "
19 |
--------------------------------------------------------------------------------
/command_SUNI_0.4_15000_No_1.sh:
--------------------------------------------------------------------------------
1 | #train
2 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_1/train/cover /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_1/train/stego /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_1/valid/cover /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_1/valid/stego --log-path="/home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_1/log" --use-batch-norm --lr 4e-1 --max-epochs=200 --log-interval=20 --gpu="0,1"
3 |
4 | #test
5 | python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_1/test/cover /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_1/test/stego --log-path="/home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_1/log" --gpu="0,1"
6 |
7 | #data_split
8 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Data/SUNI_0.4_15000 /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2 --train-percent=0.6 --valid-percent=0.2 --test-percent=0.2 --gpu=" "
9 |
--------------------------------------------------------------------------------
/command_SUNI_0.4_15000_No_2.sh:
--------------------------------------------------------------------------------
1 | #train
2 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/train/cover /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/train/stego /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/valid/cover /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/valid/stego --log-path="/home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/log" --use-batch-norm --lr 4e-1 --max-epochs=200 --log-interval=20 --gpu="2,3"
3 |
4 | #test
5 | python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/test/cover /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/test/stego --log-path="/home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2/log" --gpu="2,3"
6 |
7 | #data_split
8 | #python /home/carlchang/YeNetTensorflow/Implement/main.py /home/carlchang/YeNetTensorflow/Data/SUNI_0.4_15000 /home/carlchang/YeNetTensorflow/Experiment/SUNI_0.4_15000_No_2 --train-percent=0.6 --valid-percent=0.2 --test-percent=0.2 --gpu=" "
9 |
--------------------------------------------------------------------------------