├── figs
│   ├── mae.png
│   └── clic.png
├── train.sh
├── test.sh
├── README.md
└── mae.py
/figs/mae.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FireFYF/modulatedautoencoder/HEAD/figs/mae.png
--------------------------------------------------------------------------------
/figs/clic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FireFYF/modulatedautoencoder/HEAD/figs/clic.png
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Train a single MAE over three R-D tradeoffs (lambda = 128, 512, 2048) with a
# shared autoencoder of 192 filters per layer; the condition is normalized by
# 2048.0 (here the largest lambda).
CUDA_VISIBLE_DEVICES=0 python mae.py -v --train_glob="/dataset/*.png" train --patchsize 240 --num_filters 192 192 192 --filters_offset 0 0 0 --lambda 128 512 2048 --condition_norm 2048.0 --checkpoint_dir /models/mae --last_step 1200000
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Evaluate the trained MAE at one operating point: --model_ID selects the
# entropy model, and --condition sets the lambda fed to the modulation network.
CUDA_VISIBLE_DEVICES=0 python mae.py --num_filters 192 192 192 --filters_offset 0 0 0 --lambda 128 512 2048 --model_ID 0 --condition 128 --condition_norm 2048.0 --checkpoint_dir /models/mae/ --inputPath /dataset/ --evaluation_name mae evaluate
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Variable Rate Deep Image Compression with Modulated Autoencoders
### [[paper]](https://ieeexplore.ieee.org/document/8977394)

# Abstract:
Variable rate is a requirement for flexible and adaptable image and video compression. However, deep image compression (DIC) methods are optimized for a single fixed rate-distortion (R-D) tradeoff. While this can be addressed by training multiple models for different tradeoffs, the memory requirements increase proportionally to the number of models. Scaling the bottleneck representation of a shared autoencoder can provide variable rate compression with a single model. However, the R-D performance of this simple mechanism degrades at low bitrates, and the effective range of bitrates shrinks. To address these limitations, we formulate the problem of variable R-D optimization for DIC and propose modulated autoencoders (MAEs), where the representations of a shared autoencoder are adapted to the specific R-D tradeoff via a modulation network. Jointly training the shared autoencoder and the modulation network provides an effective way to navigate the R-D operational curve. Our experiments show that the proposed method can achieve almost the same R-D performance as independent models with significantly fewer parameters.

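For intuition, here is a minimal NumPy sketch of the core mechanism (an illustration, not the code used in the paper): a small network maps the tradeoff λ to a positive per-channel scaling vector, which modulates the shared encoder's features. The hidden width (50) and the normalization constant mirror `scale_layer` and `--condition_norm` in `mae.py`; the weights below are random placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)

def modulation_vector(lmbda, w1, b1, w2, b2, norm=2048.0):
    """Map a tradeoff lambda to positive per-channel scales:
    normalize -> linear -> ReLU -> linear -> exp (cf. scale_layer in mae.py)."""
    h = np.maximum((lmbda / norm) * w1 + b1, 0.0)  # hidden layer of 50 units
    return np.exp(h @ w2 + b2)                     # one positive scale per channel

# Toy feature map from a shared conv layer, modulated for two different lambdas:
feats = rng.standard_normal((8, 8, 192))
w1, b1 = rng.standard_normal(50) * 0.1, np.zeros(50)
w2, b2 = rng.standard_normal((50, 192)) * 0.1, np.zeros(192)
for lam in (128.0, 2048.0):
    modulated = feats * modulation_vector(lam, w1, b1, w2, b2)  # channel-wise scaling
```
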
# Dependencies
- NumPy, SciPy, NVIDIA GPU
- **Data compression library:** https://github.com/tensorflow/compression, thanks to Johannes Ballé, Sung Jin Hwang, and Nick Johnston

# Installation
- Install version 1.1 of the compression library (https://github.com/tensorflow/compression/releases/tag/v1.1).
  (In our paper we use version 1.1 for the MAE method without a hyperprior, and version 1.2 for the variant with a hyperprior.)

# Framework



# Results



# Main references
Our work relies heavily on the following projects:
- \[1\] 'Lossy Image Compression with Compressive Autoencoders' by Theis et al., https://arxiv.org/abs/1703.00395
- \[2\] 'End-to-end Optimized Image Compression' by Ballé et al., https://arxiv.org/abs/1611.01704
- \[3\] 'Variational Image Compression with a Scale Hyperprior' by Ballé et al., https://arxiv.org/abs/1802.01436

Familiarity with these projects will make this one easier to understand.
# Contact

If you run into any problems with this code, please submit a bug report on the GitHub page of the project. For other inquiries, please contact me at fyang@cvc.uab.es.

--------------------------------------------------------------------------------
/mae.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import glob
import os
from os import listdir
import pdb
import scipy.io
import numpy as np

import tensorflow as tf
import tensorflow_compression as tfc

def load_image(filename):
  """Loads a PNG image file."""
  string = tf.read_file(filename)
  image = tf.image.decode_image(string, channels=3)
  image = tf.cast(image, tf.float32)
  image /= 255
  return image

def quantize_image(image):
  """Clips to [0, 1] and converts to 8-bit pixel values."""
  image = tf.clip_by_value(image, 0, 1)
  image = tf.round(image * 255)
  image = tf.cast(image, tf.uint8)
  return image

def save_image(filename, image):
  """Saves an image to a PNG file."""
  image = quantize_image(image)
  string = tf.image.encode_png(image)
  return tf.write_file(filename, string)

#------------------------------------------------------------#
#------------------- Modulation Network -------------------#
#------------------------------------------------------------#
def modulation_network(num_filters, filters_offsets, condition, modulation_init_an=None, modulation_init_syn=None):
  with tf.variable_scope("modulation_network"):
    num_layers_a, num_layers_s = 3, 3  # adjust these to match the number of autoencoder layers
    num = np.array(num_filters)
    off = np.array(filters_offsets)
    last = num + off

    if args.condition_norm is not None:
      condition = condition / args.condition_norm  # condition normalization
    condition_tf = tf.convert_to_tensor(condition, dtype=tf.float32)
    condition_tf = tf.expand_dims(condition_tf, 0)
    condition_tf = tf.expand_dims(condition_tf, 0)

    modulation_analysis, modulation_synthesis = list(), list()

    if modulation_init_an is None:
      for i in range(num_layers_a):
        with tf.variable_scope("modulation_layer_ana%d" % i):
          vector = scale_layer(condition_tf, last[-1])
          modulation_analysis.append(vector)
    else:
      for i in range(num_layers_a):
        with tf.variable_scope("modulation_layer_ana%d" % i):
          vector = scale_layer(condition_tf, last[-1], modulation_init_an[i])
          modulation_analysis.append(vector)

    if modulation_init_syn is None:
      for i in range(num_layers_s):
        with tf.variable_scope("gating_layer_syn%d" % i):
          vector = scale_layer(condition_tf, last[-1])
          modulation_synthesis.append(vector)
    else:
      for i in range(num_layers_s):
        with tf.variable_scope("gating_layer_syn%d" % i):
          vector = scale_layer(condition_tf, last[-1], modulation_init_syn[i])
          modulation_synthesis.append(vector)

  return modulation_analysis, modulation_synthesis, last[-1]

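# modulation_network returns one positive scaling vector per autoencoder layer
# for both the analysis (encoder) and synthesis (decoder) paths, conditioned on
# the normalized tradeoff lambda. last[-1] = num_filters[-1] + filters_offset[-1]
# is the shared channel count, so each vector matches the feature maps it scales.
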
def scale_layer(condition, channel, init=None, reuse=False):
  # Two-layer MLP: linear (50 units) -> ReLU -> linear (channel units) -> exp,
  # so the resulting per-channel scales are always positive.
  x = linear(condition, 50, scope='linear_1')
  x = tf.nn.relu(x)
  if init is None:
    x = linear(x, channel)
  else:
    x = linear(x, channel, init)
  x = tf.math.exp(x)
  return x

def linear(x, units, init=None, use_bias=True, scope='linear'):
  if args.regularizer == "L2":
    regular = tf.contrib.layers.l2_regularizer(scale=0.1)
  elif args.regularizer == "L1":
    regular = tf.contrib.layers.l1_regularizer(scale=0.1)
  else:
    regular = None
  with tf.variable_scope(scope):
    if init is None:
      x = tf.layers.dense(x, units=units, use_bias=use_bias, kernel_regularizer=regular)
    else:
      init_w = tf.constant_initializer(init)
      x = tf.layers.dense(x, units=units, kernel_initializer=init_w, use_bias=use_bias, kernel_regularizer=regular)
  return x

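# Shape note: `condition` enters as a scalar and becomes a (1, 1) tensor after
# the two expand_dims calls in modulation_network, so each scale_layer output has
# shape (1, channels) and broadcasts channel-wise against the
# (batch, height, width, channels) feature maps it multiplies.
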
#------------------------------------------------------------#
#----------------- Modulated Autoencoders -----------------#
#------------------------------------------------------------#
def modulated_analysis_transform(tensor, conds, total_filters_num):
  """Builds the modulated analysis transform."""

  with tf.variable_scope("analysis"):
    with tf.variable_scope("layer_0"):
      layer = tfc.SignalConv2D(
          total_filters_num, (9, 9), corr=True, strides_down=4, padding="same_zeros",
          use_bias=True, activation=None)
      tensor = layer(tensor)
      vector = conds[0]
      modulated_tensor = tensor * vector

    with tf.variable_scope("gnd_an_0"):
      tensor_gdn_0 = tfc.GDN()(modulated_tensor)

    with tf.variable_scope("layer_1"):
      layer = tfc.SignalConv2D(
          total_filters_num, (5, 5), corr=True, strides_down=2, padding="same_zeros",
          use_bias=True, activation=None)
      tensor = layer(tensor_gdn_0)
      vector = conds[1]
      modulated_tensor = tensor * vector

    with tf.variable_scope("gnd_an_1"):
      tensor_gdn_1 = tfc.GDN()(modulated_tensor)

    with tf.variable_scope("layer_2"):
      layer = tfc.SignalConv2D(
          total_filters_num, (5, 5), corr=True, strides_down=2, padding="same_zeros",
          use_bias=False, activation=None)
      tensor = layer(tensor_gdn_1)
      vector = conds[2]
      modulated_tensor = tensor * vector

    with tf.variable_scope("gnd_an_2"):
      tensor_gdn_2 = tfc.GDN()(modulated_tensor)

  return tensor_gdn_2

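# Each analysis layer follows the same pattern: a strided SignalConv2D shared
# across all tradeoffs, a channel-wise multiplication by the modulation vector
# conds[i] for the current lambda, then GDN normalization. With every vector
# fixed to ones this reduces to a standard (non-modulated) analysis transform.
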
def demodulated_synthesis_transform(tensor, conds, total_filters_num):
  """Builds the demodulated synthesis transform."""

  with tf.variable_scope("synthesis"):
    with tf.variable_scope("layer_0"):
      with tf.variable_scope("gnd_sy_0"):
        tensor_igdn_0 = tfc.GDN(inverse=True)(tensor)
        vector = conds[0]
        demodulated_tensor = tensor_igdn_0 * vector

      layer = tfc.SignalConv2D(
          total_filters_num, (5, 5), corr=False, strides_up=2, padding="same_zeros",
          use_bias=True, activation=None)
      tensor = layer(demodulated_tensor)

    with tf.variable_scope("layer_1"):
      with tf.variable_scope("gnd_sy_1"):
        tensor_igdn_1 = tfc.GDN(inverse=True)(tensor)
        vector = conds[1]
        demodulated_tensor = tensor_igdn_1 * vector

      layer = tfc.SignalConv2D(
          total_filters_num, (5, 5), corr=False, strides_up=2, padding="same_zeros",
          use_bias=True, activation=None)
      tensor = layer(demodulated_tensor)

    with tf.variable_scope("layer_2"):
      with tf.variable_scope("gnd_sy_2"):
        tensor_igdn_2 = tfc.GDN(inverse=True)(tensor)
        vector = conds[2]
        demodulated_tensor = tensor_igdn_2 * vector

      layer = tfc.SignalConv2D(
          3, (9, 9), corr=False, strides_up=4, padding="same_zeros",
          use_bias=True, activation=None)
      tensor = layer(demodulated_tensor)

  return tensor

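# The synthesis path mirrors the analysis in reverse: inverse GDN, channel-wise
# demodulation by conds[i], then an upsampling (transposed) SignalConv2D. The
# final layer maps back to 3 output channels (RGB).
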
#----------------- training -----------------#
def train():
  """Trains the model."""
  if args.verbose:
    tf.logging.set_verbosity(tf.logging.INFO)

  # Create the input data pipeline.
  with tf.device('/cpu:0'):
    train_files = glob.glob(args.train_glob)
    train_dataset = tf.data.Dataset.from_tensor_slices(train_files)
    train_dataset = train_dataset.shuffle(buffer_size=len(train_files)).repeat()
    train_dataset = train_dataset.map(
        load_image, num_parallel_calls=args.preprocess_threads)
    train_dataset = train_dataset.map(
        lambda x: tf.random_crop(x, (args.patchsize, args.patchsize, 3)))
    train_dataset = train_dataset.batch(args.batchsize)
    train_dataset = train_dataset.prefetch(32)

  num_pixels = args.batchsize * args.patchsize ** 2

  # Get training patches from the dataset.
  x = train_dataset.make_one_shot_iterator().get_next()

  # Lists that keep the per-lambda tensors and losses.
  y, y_tilde, entropy_bottlenecks, likelihoods, x_tilde = list(), list(), list(), list(), list()
  train_bpp, train_mse, train_loss = list(), list(), list()

  # One forward pass per R-D tradeoff; autoencoder weights are shared via reuse.
  for i, _lmbda in enumerate(args.lmbda):
    with tf.variable_scope("modulation_network", reuse=(i > 0)):
      cond_an, cond_syn, total_filters_num = modulation_network(args.num_filters, args.filters_offset, args.lmbda[i])

    with tf.variable_scope("analysis", reuse=(i > 0)):  # reuse variables when i > 0 for sharing
      _y = modulated_analysis_transform(x, cond_an, total_filters_num)
      y.append(_y)

    # A separate entropy model is trained for each tradeoff.
    entropy_bottlenecks.append(tfc.EntropyBottleneck())
    _y_tilde, _likelihoods = entropy_bottlenecks[i](_y, training=True)
    y_tilde.append(_y_tilde)
    likelihoods.append(_likelihoods)

    with tf.variable_scope("synthesis", reuse=(i > 0)):  # reuse variables when i > 0 for sharing
      _x_tilde = demodulated_synthesis_transform(y_tilde[i], cond_syn, total_filters_num)
      x_tilde.append(_x_tilde)

    # Total number of bits divided by number of pixels.
    train_bpp.append(tf.reduce_sum(tf.log(likelihoods[i])) / (-np.log(2) * num_pixels))

    # Mean squared error across pixels.
    train_mse.append(tf.reduce_mean(tf.squared_difference(x, x_tilde[i])))

    # The rate-distortion cost for this tradeoff.
    train_loss.append(_lmbda * train_mse[i] + train_bpp[i])

  # The total loss sums the R-D costs of all tradeoffs.
  total_train_loss = tf.add_n(train_loss)

  step = tf.train.create_global_step()
  # learning_rate_placeholder_cnn = tf.placeholder(tf.float32, [], name='learning_rate_cnn')
  # learning_rate_placeholder_rate = tf.placeholder(tf.float32, [], name='learning_rate_rate')
  # main_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate_placeholder_cnn)
  main_optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
  main_step = main_optimizer.minimize(total_train_loss, global_step=step)

  # One auxiliary optimizer per entropy bottleneck.
  aux_optimizers = list()
  list_ops = [main_step]
  for i, entropy_bottleneck in enumerate(entropy_bottlenecks):
    aux_optimizers.append(tf.train.AdamOptimizer(learning_rate=1e-3))
    # aux_optimizers.append(tf.train.AdamOptimizer(learning_rate=learning_rate_placeholder_rate))
    list_ops.append(aux_optimizers[i].minimize(entropy_bottleneck.losses[0]))
    list_ops.append(entropy_bottleneck.updates[0])
  train_op = tf.group(list_ops)

  # Summaries.
  for i, _lmbda in enumerate(args.lmbda):
    tf.summary.scalar("loss_%d" % i, train_loss[i])
    tf.summary.scalar("bpp_%d" % i, train_bpp[i])
    tf.summary.scalar("mse_%d" % i, train_mse[i] * 255 ** 2)  # rescaled to the 8-bit pixel range
    tf.summary.histogram("hist_y_%d" % i, y[i])
    # tf.summary.image("reconstruction_%d" % i, quantize_image(x_tilde[i]))

  tf.summary.scalar("total_loss", total_train_loss)

  hooks = [
      tf.train.StopAtStepHook(last_step=args.last_step),
      tf.train.NanTensorHook(total_train_loss),
  ]

  with tf.train.MonitoredTrainingSession(
      hooks=hooks, checkpoint_dir=args.checkpoint_dir,
      save_checkpoint_secs=900, save_summaries_secs=600) as sess:
    while not sess.should_stop():
      sess.run(train_op)

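# Training summary: a single shared autoencoder plus modulation network is
# optimized jointly over all lambdas; the total objective is
#   sum_i (lambda_i * MSE_i + bpp_i),
# with one EntropyBottleneck (and one auxiliary optimizer) per tradeoff point.
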
#----------------- evaluate -----------------#
def evaluate():
  """Evaluates the model on a test dataset."""
  # Process all the images in inputPath.
  imagesList = listdir(args.inputPath)
  # Initialize the accumulated metric scores.
  bpp_actual_total, bpp_estimate_total, mse_total, psnr_total, msssim_total, msssim_db_total = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
  # Instantiate all entropy models and select the requested one.
  if args.model_ID is not None:
    entropy_bottlenecks = list()
    for i, _lmbda in enumerate(args.lmbda):
      entropy_bottlenecks.append(tfc.EntropyBottleneck())
    entropy_bottleneck = entropy_bottlenecks[args.model_ID]
    with tf.variable_scope("modulation_network", reuse=tf.AUTO_REUSE):
      cond_an, cond_syn, total_filters_num = modulation_network(args.num_filters, args.filters_offset, args.condition)
  else:
    raise ValueError("model_ID is required to select one specific entropy model.")

  for image in imagesList:
    x = load_image(args.inputPath + image)
    x = tf.expand_dims(x, 0)
    x.set_shape([1, None, None, 3])

    with tf.variable_scope("analysis", reuse=tf.AUTO_REUSE):
      y = modulated_analysis_transform(x, cond_an, total_filters_num)

    string = entropy_bottleneck.compress(y)
    string = tf.squeeze(string, axis=0)
    y_hat, likelihoods = entropy_bottleneck(y, training=False)

    with tf.variable_scope("synthesis", reuse=tf.AUTO_REUSE):
      x_hat_first = demodulated_synthesis_transform(y_hat, cond_syn, total_filters_num)

    num_pixels = tf.to_float(tf.reduce_prod(tf.shape(x)[:-1]))
    # Total number of bits divided by number of pixels.
    eval_bpp = tf.reduce_sum(tf.log(likelihoods)) / (-np.log(2) * num_pixels)
    # Bring both images back to the 0..255 range.
    x *= 255
    x_hat = tf.clip_by_value(x_hat_first, 0, 1)
    x_hat = tf.round(x_hat * 255)
    x_hat = tf.slice(x_hat, [0, 0, 0, 0], [1, tf.shape(x)[1], tf.shape(x)[2], 3])

    mse = tf.reduce_mean(tf.squared_difference(x, x_hat))
    psnr = tf.squeeze(tf.image.psnr(x_hat, x, 255))
    msssim = tf.squeeze(tf.image.ssim_multiscale(x_hat, x, 255))

    if args.save_reconstruction:
      x_shape = tf.shape(x)
      x_hat_first = x_hat_first[0, :x_shape[1], :x_shape[2], :]
      if os.path.isdir(args.outputPath):
        print(args.outputPath + ': exists.')
      else:
        os.makedirs(args.outputPath)
        print(args.outputPath + ': created.')
      op = save_image(args.outputPath + image, x_hat_first)

      with tf.Session() as sess:
        # Load the latest model checkpoint and get the evaluation results.
        latest = tf.train.latest_checkpoint(checkpoint_dir=args.checkpoint_dir)
        tf.train.Saver().restore(sess, save_path=latest)

        string, eval_bpp, mse, psnr, msssim, num_pixels = sess.run(
            [string, eval_bpp, mse, psnr, msssim, num_pixels])

        sess.run(op)

        # The actual bits per pixel, including overhead.
        bpp = (8 + len(string)) * 8 / num_pixels

        print("Mean squared error: {:0.4f}".format(mse))
        print("PSNR (dB): {:0.2f}".format(psnr))
        print("Multiscale SSIM: {:0.4f}".format(msssim))
        print("Multiscale SSIM (dB): {:0.2f}".format(-10 * np.log10(1 - msssim)))
        print("Information content in bpp: {:0.4f}".format(eval_bpp))
        print("Actual bits per pixel: {:0.4f}".format(bpp))

        with open(args.outputPath + image[:-4] + '.txt', 'w') as f:
          f.write('Avg_bpp_actual: ' + str(bpp) + '\n')
          f.write('Avg_bpp_estimate: ' + str(eval_bpp) + '\n')
          f.write('Avg_mse: ' + str(mse) + '\n')
          f.write('Avg_psnr: ' + str(psnr) + '\n')
          f.write('Avg_msssim: ' + str(msssim) + '\n')
          f.write('Avg_msssim_db: ' + str(-10 * np.log10(1 - msssim)) + '\n')
    else:
      with tf.Session() as sess:
        # Load the latest model checkpoint and get the compressed string and
        # the tensor shapes.
        latest = tf.train.latest_checkpoint(checkpoint_dir=args.checkpoint_dir)
        tf.train.Saver().restore(sess, save_path=latest)

        string, eval_bpp, mse, psnr, msssim, num_pixels = sess.run(
            [string, eval_bpp, mse, psnr, msssim, num_pixels])

        # The actual bits per pixel, including overhead.
        bpp = (8 + len(string)) * 8 / num_pixels

        print("Mean squared error: {:0.4f}".format(mse))
        print("PSNR (dB): {:0.2f}".format(psnr))
        print("Multiscale SSIM: {:0.4f}".format(msssim))
        print("Multiscale SSIM (dB): {:0.2f}".format(-10 * np.log10(1 - msssim)))
        print("Information content in bpp: {:0.4f}".format(eval_bpp))
        print("Actual bits per pixel: {:0.4f}".format(bpp))

    # Accumulate metrics across images.
    bpp_actual_total += bpp
    bpp_estimate_total += eval_bpp
    mse_total += mse
    psnr_total += psnr
    msssim_total += msssim
    msssim_db_total += (-10 * np.log10(1 - msssim))

  if args.evaluation_name is not None:
    Avg_bpp_actual, Avg_bpp_estimate = bpp_actual_total / len(imagesList), bpp_estimate_total / len(imagesList)
    Avg_mse, Avg_psnr = mse_total / len(imagesList), psnr_total / len(imagesList)
    Avg_msssim, Avg_msssim_db = msssim_total / len(imagesList), msssim_db_total / len(imagesList)
    with open(args.evaluation_name + '.txt', 'w') as f:
      f.write('Avg_bpp_actual: ' + str(Avg_bpp_actual) + '\n')
      f.write('Avg_bpp_estimate: ' + str(Avg_bpp_estimate) + '\n')
      f.write('Avg_mse: ' + str(Avg_mse) + '\n')
      f.write('Avg_psnr: ' + str(Avg_psnr) + '\n')
      f.write('Avg_msssim: ' + str(Avg_msssim) + '\n')
      f.write('Avg_msssim_db: ' + str(Avg_msssim_db) + '\n')

if __name__ == "__main__":
  parser = argparse.ArgumentParser(
      formatter_class=argparse.ArgumentDefaultsHelpFormatter)

  parser.add_argument(
      "command", choices=["train", "evaluate"],
      help="'train' loads training data and trains (or continues to train) "
           "a new model. 'evaluate' computes the R-D curves from the "
           "pretrained model.")

  parser.add_argument(
      "--verbose", "-v", action="store_true",
      help="Report bitrate and distortion when training or compressing.")
  parser.add_argument(
      "--num_filters", nargs="+", type=int, default=[128],
      help="Number of filters per layer (per R-D tradeoff point).")
  parser.add_argument(
      "--checkpoint_dir", default="train",
      help="Directory where to save/load model checkpoints.")
  parser.add_argument(
      "--train_glob", default="images/*.png",
      help="Glob pattern identifying training data. This pattern must expand "
           "to a list of RGB images in PNG format.")
  parser.add_argument(
      "--batchsize", type=int, default=8,
      help="Batch size for training.")
  parser.add_argument(
      "--patchsize", type=int, default=256,
      help="Size of image patches for training.")
  parser.add_argument(
      "--lambda", nargs="+", type=float, default=[512], dest="lmbda",
      help="Lambdas for the rate-distortion tradeoff points.")
  parser.add_argument(
      "--last_step", type=int, default=1000000,
      help="Train up to this number of steps.")
  parser.add_argument(
      "--preprocess_threads", type=int, default=6,
      help="Number of CPU threads to use for parallel decoding of training "
           "images.")
  parser.add_argument(
      "--modulation_init", action="store_true",
      help="Initialize the modulation network with the default vectors.")
  parser.add_argument(
      "--filters_offset", nargs="+", type=int, default=[0],
      help="Filter offsets (per R-D tradeoff point).")
  parser.add_argument(
      "--save_reconstruction", action="store_true",
      help="Save the reconstructed images during evaluation.")
  parser.add_argument(
      "--model_ID", type=int, default=None,
      help="Index of the entropy model to use for compression/decompression.")
  parser.add_argument(
      "--condition", type=int, default=None,
      help="Condition (lambda) for the desired R-D tradeoff.")
  parser.add_argument(
      "--condition_norm", type=float, default=None,
      help="Normalization constant for the condition values.")
  parser.add_argument(
      "--evaluation_name", type=str, default='results',
      help="Name of the evaluation results txt file.")
  parser.add_argument(
      "--inputPath", type=str, default=None,
      help="Directory of the evaluation dataset.")
  parser.add_argument(
      "--outputPath", type=str, default=None,
      help="Directory where to save the reconstructed images.")
  parser.add_argument(
      "--regularizer", type=str, default=None,
      help="Regularizer for the modulation network (L1 or L2).")

  args = parser.parse_args()

  if args.command == "train":
    # Check consistency between lambda, num_filters, and filters_offset.
    if len(args.lmbda) != len(args.num_filters):
      raise ValueError("lambda and num_filters must have the same length.")
    if len(args.num_filters) != len(args.filters_offset):
      raise ValueError("num_filters and filters_offset must have the same length.")
    train()
  elif args.command == "evaluate":
    if args.inputPath is None:
      raise ValueError("An input path is required for evaluation.")
    evaluate()
--------------------------------------------------------------------------------