├── figs ├── mae.png └── clic.png ├── train.sh ├── test.sh ├── README.md └── mae.py /figs/mae.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FireFYF/modulatedautoencoder/HEAD/figs/mae.png -------------------------------------------------------------------------------- /figs/clic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FireFYF/modulatedautoencoder/HEAD/figs/clic.png -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | CUDA_VISIBLE_DEVICES=0 python mae.py -v --train_glob="/dataset/*.png" train --patchsize 240 --num_filters 192 192 192 --filters_offset 0 0 0 --lambda 128 512 2048 --condition_norm 2048.0 --checkpoint_dir /models/mae --last_step 1200000 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | CUDA_VISIBLE_DEVICES=0 python mae.py --num_filters 192 192 192 --filters_offset 0 0 0 --lambda 128 512 2048 --model_ID 0 --condition 128 --condition_norm 2048.0 --checkpoint_dir /models/mae/ --inputPath /dataset/ --evaluation_name mae evaluate 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Variable Rate Deep Image Compression with Modulated Autoencoders 2 | ### [[paper]](https://ieeexplore.ieee.org/document/8977394) 3 | 4 | # Abstract: 5 | Variable rate is a requirement for flexible and adaptable image and video compression. However, deep image compression methods (DIC) are optimized for a single fixed rate-distortion (R-D) tradeoff. 
While this can be addressed by training multiple models for different tradeoffs, the memory requirements increase proportionally to the number of models. Scaling the bottleneck representation of a shared autoencoder can provide variable rate compression with a single shared autoencoder. However, the R-D performance using this simple mechanism degrades in low bitrates, and also shrinks the effective range of bitrates. To address these limitations, we formulate the problem of variable R-D optimization for DIC, and propose modulated autoencoders (MAEs), where the representations of a shared autoencoder are adapted to the specific R-D tradeoff via a modulation network. Jointly training this modulated autoencoder and the modulation network provides an effective way to navigate the R-D operational curve. Our experiments show that the proposed method can achieve almost the same R-D performance of independent models with significantly fewer parameters. 6 | 7 | # Dependencies 8 | - NumPy, SciPy, NVIDIA GPU 9 | - **Data Compression Library:** (https://github.com/tensorflow/compression), thanks to Johannes Ballé, Sung Jin Hwang, and Nick Johnston 10 | 11 | # Installation 12 | - Install version 1.1 of the compression library (https://github.com/tensorflow/compression/releases/tag/v1.1). 13 | (In our paper, we use version 1.1 for the MAE method without a hyperprior and version 1.2 for the MAE method with a hyperprior.) 14 | 15 | # Framework 16 |
17 |
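As a rough illustration of the modulation idea (not the trained model), the sketch below mirrors the structure of `scale_layer` in `mae.py`: the tradeoff value lambda is divided by `condition_norm`, passed through a small two-layer network, and exponentiated so that every per-channel scale is strictly positive. The function names `scale_vector` and `modulate`, and all weight values, are hypothetical placeholders chosen for illustration; in the actual model the weights are learned jointly with the autoencoder and the hidden layer has 50 units.

```python
import math

def scale_vector(lmbda, w1, b1, w2, b2, condition_norm=2048.0):
    """Map a scalar tradeoff lambda to positive per-channel scales.

    Structure follows scale_layer in mae.py: dense -> ReLU -> dense -> exp.
    All weights are hypothetical placeholders, not trained values.
    """
    c = lmbda / condition_norm                              # condition normalization
    hidden = [max(0.0, w * c + b) for w, b in zip(w1, b1)]  # dense layer + ReLU
    out = [sum(h * w for h, w in zip(hidden, row)) + b      # second dense layer
           for row, b in zip(w2, b2)]
    return [math.exp(v) for v in out]                       # exp keeps scales > 0

def modulate(features, scales):
    """Multiply each channel of a feature vector by its scale."""
    return [f * s for f, s in zip(features, scales)]

# Toy setup: 2 hidden units and 3 output channels.
w1, b1 = [1.0, -1.0], [0.0, 0.5]
w2 = [[0.5, 0.0], [0.0, 0.5], [1.0, 1.0]]
b2 = [0.0, 0.0, 0.0]

for lmbda in (128, 512, 2048):
    scales = scale_vector(lmbda, w1, b1, w2, b2)
    print(lmbda, [round(s, 4) for s in scales])
```

Because the same shared features are multiplied by a different scale vector for each lambda, one set of autoencoder weights can serve several R-D tradeoffs.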

18 | 19 | # Results 20 |
21 |
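For reference, the `evaluate` command in `mae.py` reports PSNR, MS-SSIM in decibels, and the actual bits per pixel of the compressed string. The following is a minimal sketch of those three formulas in plain Python; the function names are hypothetical, and the 8-byte overhead is taken from the expression `(8 + len(string)) * 8 / num_pixels` used in the script.

```python
import math

def psnr(mse, max_val=255.0):
    """Peak signal-to-noise ratio in dB for a nonzero mean squared error."""
    return 10.0 * math.log10(max_val ** 2 / mse)

def msssim_to_db(msssim):
    """Convert an MS-SSIM score below 1 to decibels: -10 * log10(1 - msssim)."""
    return -10.0 * math.log10(1.0 - msssim)

def actual_bpp(num_bytes, height, width, overhead_bytes=8):
    """Bits per pixel of a compressed string plus a fixed byte overhead."""
    return (overhead_bytes + num_bytes) * 8 / (height * width)

# Illustrative inputs only; these are not results from the paper.
print(psnr(65.025))
print(msssim_to_db(0.99))
print(actual_bpp(12280, 512, 768))
```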

22 | 23 | # Main references 24 | Our work relies heavily on the following projects: 25 | - \[1\] 'Lossy Image Compression with Compressive Autoencoders' by Theis et al., https://arxiv.org/abs/1703.00395 26 | - \[2\] 'End-to-end Optimized Image Compression' by Ballé et al., https://arxiv.org/abs/1611.01704 27 | - \[3\] 'Variational image compression with a scale hyperprior' by Ballé et al., https://arxiv.org/abs/1802.01436 28 | 29 | Familiarity with the above projects will make this project easier to understand. 30 | # Contact 31 | 32 | If you run into any problems with this code, please submit a bug report on the GitHub site of the project. For other inquiries, please contact me: fyang@cvc.uab.es 33 | -------------------------------------------------------------------------------- /mae.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import glob 7 | import os 8 | from os import listdir 9 | import pdb 10 | import scipy.io 11 | import numpy as np 12 | 13 | import tensorflow as tf 14 | import tensorflow_compression as tfc 15 | 16 | def load_image(filename): 17 | """Loads a PNG image file.""" 18 | string = tf.read_file(filename) 19 | image = tf.image.decode_image(string, channels=3) 20 | image = tf.cast(image, tf.float32) 21 | image /= 255 22 | return image 23 | 24 | def quantize_image(image): 25 | image = tf.clip_by_value(image, 0, 1) 26 | image = tf.round(image * 255) 27 | image = tf.cast(image, tf.uint8) 28 | return image 29 | 30 | def save_image(filename, image): 31 | """Saves an image to a PNG file.""" 32 | image = quantize_image(image) 33 | string = tf.image.encode_png(image) 34 | return tf.write_file(filename, string) 35 | 36 | #------------------------------------------------------------# 37 | #------------------- Modulation Network -------------------# 38 | 
#------------------------------------------------------------# 39 | def modulation_network(num_filters, filters_offsets, condition, modulation_init_an=None, modulation_init_syn=None): 40 | with tf.variable_scope("modulation_network"): 41 | num_layers_a, num_layers_s = 3, 3 # change these to match the number of layers in the autoencoder 42 | num = np.array(num_filters) 43 | off = np.array(filters_offsets) 44 | last = num + off 45 | 46 | if args.condition_norm is not None: 47 | condition = condition / args.condition_norm # condition normalization 48 | condition_tf = tf.convert_to_tensor(condition, dtype=tf.float32) 49 | condition_tf = tf.expand_dims(condition_tf, 0) 50 | condition_tf = tf.expand_dims(condition_tf, 0) 51 | 52 | modulation_analysis, modulation_synthesis = list(), list() 53 | 54 | if modulation_init_an is None: 55 | for i in range(num_layers_a): 56 | with tf.variable_scope("modulation_layer_ana%d" % i): 57 | vector = scale_layer(condition_tf, last[-1]) 58 | modulation_analysis.append(vector) 59 | else: 60 | for i in range(num_layers_a): 61 | with tf.variable_scope("modulation_layer_ana%d" % i): 62 | vector = scale_layer(condition_tf, last[-1], modulation_init_an[i]) 63 | modulation_analysis.append(vector) 64 | 65 | if modulation_init_syn is None: 66 | for i in range(num_layers_s): 67 | with tf.variable_scope("gating_layer_syn%d" % i): 68 | vector = scale_layer(condition_tf, last[-1]) 69 | modulation_synthesis.append(vector) 70 | else: 71 | for i in range(num_layers_s): 72 | with tf.variable_scope("gating_layer_syn%d" % i): 73 | vector = scale_layer(condition_tf, last[-1], modulation_init_syn[i]) 74 | modulation_synthesis.append(vector) 75 | 76 | return modulation_analysis, modulation_synthesis, last[-1] 77 | 78 | def scale_layer(condition, channel, init=None, reuse=False): 79 | x = linear(condition, 50, scope='linear_1') 80 | x = tf.nn.relu(x) 81 | if init is None: 82 | x = linear(x, channel) 83 | else: 84 | x = linear(x, channel, init) 85 | x = 
tf.math.exp(x) 86 | return x 87 | 88 | def linear(x, units, init=None, use_bias=True, scope='linear'): 89 | if args.regularizer == "L2": 90 | regular = tf.contrib.layers.l2_regularizer(scale=0.1) 91 | elif args.regularizer == "L1": 92 | regular = tf.contrib.layers.l1_regularizer(scale=0.1) 93 | else: 94 | regular = None 95 | with tf.variable_scope(scope): 96 | if init is None: 97 | x = tf.layers.dense(x, units=units, use_bias=use_bias, kernel_regularizer=regular) 98 | else: 99 | init_w = tf.constant_initializer(init) 100 | x = tf.layers.dense(x, units=units, kernel_initializer=init_w, use_bias=use_bias, kernel_regularizer=regular) 101 | return x 102 | 103 | #------------------------------------------------------------# 104 | #----------------- Modulated Autoencoders -----------------# 105 | #------------------------------------------------------------# 106 | def modulated_analysis_transform(tensor, conds, total_filters_num): 107 | """Builds the modulated analysis transform.""" 108 | 109 | with tf.variable_scope("analysis"): 110 | with tf.variable_scope("layer_0"): 111 | layer = tfc.SignalConv2D( 112 | total_filters_num, (9, 9), corr=True, strides_down=4, padding="same_zeros", 113 | use_bias=True, activation=None) 114 | tensor = layer(tensor) 115 | vector = conds[0] 116 | modulated_tensor = tensor * vector 117 | 118 | with tf.variable_scope("gnd_an_0"): 119 | tensor_gdn_0 = tfc.GDN()(modulated_tensor) 120 | 121 | with tf.variable_scope("layer_1"): 122 | layer = tfc.SignalConv2D( 123 | total_filters_num, (5, 5), corr=True, strides_down=2, padding="same_zeros", 124 | use_bias=True, activation=None) 125 | tensor = layer(tensor_gdn_0) 126 | vector = conds[1] 127 | modulated_tensor = tensor * vector 128 | 129 | with tf.variable_scope("gnd_an_1"): 130 | tensor_gdn_1 = tfc.GDN()(modulated_tensor) 131 | 132 | with tf.variable_scope("layer_2"): 133 | layer = tfc.SignalConv2D( 134 | total_filters_num, (5, 5), corr=True, strides_down=2, padding="same_zeros", 135 | 
use_bias=False, activation=None) 136 | tensor = layer(tensor_gdn_1) 137 | vector = conds[2] 138 | modulated_tensor = tensor * vector 139 | 140 | with tf.variable_scope("gnd_an_2"): 141 | tensor_gdn_2 = tfc.GDN()(modulated_tensor) 142 | 143 | return tensor_gdn_2 144 | 145 | def demodulated_synthesis_transform(tensor, conds, total_filters_num): 146 | """Builds the demodulated synthesis transform.""" 147 | 148 | with tf.variable_scope("synthesis"): 149 | with tf.variable_scope("layer_0"): 150 | with tf.variable_scope("gnd_sy_0"): 151 | tensor_igdn_0 = tfc.GDN(inverse=True)(tensor) 152 | vector = conds[0] 153 | demodulated_tensor = tensor_igdn_0 * vector 154 | 155 | layer = tfc.SignalConv2D( 156 | total_filters_num, (5, 5), corr=False, strides_up=2, padding="same_zeros", 157 | use_bias=True, activation=None) 158 | tensor = layer(demodulated_tensor) 159 | 160 | with tf.variable_scope("layer_1"): 161 | with tf.variable_scope("gnd_sy_1"): 162 | tensor_igdn_1 = tfc.GDN(inverse=True)(tensor) 163 | vector = conds[1] 164 | demodulated_tensor = tensor_igdn_1 * vector 165 | 166 | layer = tfc.SignalConv2D( 167 | total_filters_num, (5, 5), corr=False, strides_up=2, padding="same_zeros", 168 | use_bias=True, activation=None) 169 | tensor = layer(demodulated_tensor) 170 | 171 | with tf.variable_scope("layer_2"): 172 | with tf.variable_scope("gnd_sy_2"): 173 | tensor_igdn_2 = tfc.GDN(inverse=True)(tensor) 174 | vector = conds[2] 175 | demodulated_tensor = tensor_igdn_2 * vector 176 | 177 | layer = tfc.SignalConv2D( 178 | 3, (9, 9), corr=False, strides_up=4, padding="same_zeros", 179 | use_bias=True, activation=None) 180 | tensor = layer(demodulated_tensor) 181 | 182 | return tensor 183 | 184 | #----------------- training -----------------# 185 | def train(): 186 | """Trains the model.""" 187 | if args.verbose: 188 | tf.logging.set_verbosity(tf.logging.INFO) 189 | 190 | # Create input data pipeline. 
191 | with tf.device('/cpu:0'): 192 | train_files = glob.glob(args.train_glob) 193 | train_dataset = tf.data.Dataset.from_tensor_slices(train_files) 194 | train_dataset = train_dataset.shuffle(buffer_size=len(train_files)).repeat() 195 | train_dataset = train_dataset.map( 196 | load_image, num_parallel_calls=args.preprocess_threads) 197 | train_dataset = train_dataset.map( 198 | lambda x: tf.random_crop(x, (args.patchsize, args.patchsize, 3))) 199 | train_dataset = train_dataset.batch(args.batchsize) 200 | train_dataset = train_dataset.prefetch(32) 201 | 202 | num_pixels = args.batchsize * args.patchsize ** 2 203 | 204 | # Get training patch from dataset. 205 | x = train_dataset.make_one_shot_iterator().get_next() 206 | 207 | # lists to keep loss for each lambda 208 | y, y_tilde, entropy_bottlenecks, likelihoods, x_tilde = list(), list(), list(), list(), list() 209 | train_bpp, train_mse, train_loss = list(), list(), list() 210 | 211 | # Forward pass for each RD tradeoff 212 | for i, _lmbda in enumerate(args.lmbda): 213 | with tf.variable_scope("modulation_network", reuse=(i>0)): 214 | cond_an, cond_syn, total_filters_num = modulation_network(args.num_filters, args.filters_offset, args.lmbda[i]) 215 | 216 | with tf.variable_scope("analysis", reuse=(i>0)): # Reuse variables when i>0 for sharing 217 | _y = modulated_analysis_transform(x, cond_an, total_filters_num) 218 | y.append(_y) 219 | 220 | entropy_bottlenecks.append(tfc.EntropyBottleneck()) 221 | _y_tilde, _likelihoods = entropy_bottlenecks[i](_y, training=True) 222 | y_tilde.append(_y_tilde) 223 | likelihoods.append(_likelihoods) 224 | 225 | with tf.variable_scope("synthesis", reuse=(i > 0)): # Reuse variable when i>0 for sharing 226 | _x_tilde = demodulated_synthesis_transform(y_tilde[i], cond_syn, total_filters_num) 227 | x_tilde.append(_x_tilde) 228 | 229 | # Total number of bits divided by number of pixels. 
230 | train_bpp.append(tf.reduce_sum(tf.log(likelihoods[i])) / (-np.log(2) * num_pixels)) 231 | 232 | # Mean squared error across pixels. 233 | train_mse.append(tf.reduce_mean(tf.squared_difference(x, x_tilde[i]))) 234 | 235 | # The rate-distortion cost. 236 | train_loss.append(_lmbda * train_mse[i] + train_bpp[i]) 237 | 238 | total_train_loss = tf.add_n(train_loss) 239 | 240 | step = tf.train.create_global_step() 241 | # learning_rate_placeholder_cnn = tf.placeholder(tf.float32, [], name='learning_rate_cnn') 242 | # learning_rate_placeholder_rate = tf.placeholder(tf.float32, [], name='learning_rate_rate') 243 | # main_optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate_placeholder_cnn) 244 | main_optimizer = tf.train.AdamOptimizer(learning_rate=1e-4) 245 | main_step = main_optimizer.minimize(total_train_loss, global_step=step) 246 | 247 | aux_optimizers = list() 248 | list_ops = [main_step] 249 | for i, entropy_bottleneck in enumerate(entropy_bottlenecks): 250 | aux_optimizers.append(tf.train.AdamOptimizer(learning_rate=1e-3)) 251 | # aux_optimizers.append(tf.train.AdamOptimizer(learning_rate=learning_rate_placeholder_rate)) 252 | list_ops.append(aux_optimizers[i].minimize(entropy_bottleneck.losses[0])) 253 | list_ops.append(entropy_bottleneck.updates[0]) 254 | train_op = tf.group(list_ops) 255 | 256 | # Summaries 257 | for i, _lmbda in enumerate(args.lmbda): 258 | tf.summary.scalar("loss_%d" % i, train_loss[i]) 259 | tf.summary.scalar("bpp_%d" % i, train_bpp[i]) 260 | tf.summary.scalar("mse_%d" % i, train_mse[i]* 255 ** 2) # Rescaled 261 | # tf.summary.histogram("hist_layer_a0_%d" % i, features_an[i][0]) 262 | # tf.summary.histogram("hist_layer_a1_%d" % i, features_an[i][1]) 263 | tf.summary.histogram("hist_y_%d" % i, y[i]) 264 | # tf.summary.image("reconstruction_%d" % i, quantize_image(x_tilde[i])) 265 | 266 | tf.summary.scalar("total_loss", total_train_loss) 267 | 268 | hooks = [ 269 | tf.train.StopAtStepHook(last_step=args.last_step), 270 | 
tf.train.NanTensorHook(total_train_loss), 271 | ] 272 | 273 | with tf.train.MonitoredTrainingSession( 274 | hooks=hooks, checkpoint_dir=args.checkpoint_dir, 275 | save_checkpoint_secs=900, save_summaries_secs=600) as sess: 276 | while not sess.should_stop(): 277 | # learning_rate_cnn = 4e-4 if step < 400000 else 2e-4 278 | # learning_rate_rate = 2e-3 if step < 400000 else 1e-3 279 | # pdb.set_trace() 280 | sess.run(step) 281 | sess.run(train_op) 282 | # sess.run(train_op, feed_dict={learning_rate_placeholder_cnn:learning_rate_cnn, learning_rate_placeholder_rate: learning_rate_rate}) 283 | 284 | #----------------- evaluate -----------------# 285 | def evaluate(): 286 | """Evaluates the model on the test dataset.""" 287 | # process all the images in input_path 288 | imagesList = listdir(args.inputPath) 289 | # Initialize metric scores 290 | bpp_actual_total, bpp_estimate_total, mse_total, psnr_total, msssim_total, msssim_db_total = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 291 | # get all the entropy models or one entropy model 292 | if args.model_ID is not None: 293 | entropy_bottlenecks = list() 294 | for i, _lmbda in enumerate(args.lmbda): 295 | entropy_bottlenecks.append(tfc.EntropyBottleneck()) 296 | entropy_bottleneck = entropy_bottlenecks[args.model_ID] 297 | with tf.variable_scope("modulation_network", reuse=tf.AUTO_REUSE): 298 | cond_an, cond_syn, total_filters_num = modulation_network(args.num_filters, args.filters_offset, args.condition) 299 | else: 300 | raise ValueError('model_ID is required to select one specific entropy model') 301 | 302 | for image in imagesList: 303 | x = load_image(os.path.join(args.inputPath, image)) 304 | x = tf.expand_dims(x, 0) 305 | x.set_shape([1, None, None, 3]) 306 | 307 | with tf.variable_scope("analysis", reuse=tf.AUTO_REUSE): # AUTO_REUSE shares the variables across images 308 | y = modulated_analysis_transform(x, cond_an, total_filters_num) 309 | 310 | string = entropy_bottleneck.compress(y) 311 | string = tf.squeeze(string, axis=0) 312 | y_hat, likelihoods = 
entropy_bottleneck(y, training=False) 313 | 314 | with tf.variable_scope("synthesis", reuse=tf.AUTO_REUSE): 315 | x_hat_first = demodulated_synthesis_transform(y_hat, cond_syn, total_filters_num) 316 | 317 | num_pixels = tf.to_float(tf.reduce_prod(tf.shape(x)[:-1])) 318 | # Total number of bits divided by number of pixels. 319 | eval_bpp = tf.reduce_sum(tf.log(likelihoods)) / (-np.log(2) * num_pixels) 320 | # Bring both images back to 0..255 range. 321 | x *= 255 322 | x_hat = tf.clip_by_value(x_hat_first, 0, 1) 323 | x_hat = tf.round(x_hat * 255) 324 | x_hat = tf.slice(x_hat, [0, 0, 0, 0], [1,tf.shape(x)[1], tf.shape(x)[2], 3]) 325 | 326 | mse = tf.reduce_mean(tf.squared_difference(x, x_hat)) 327 | psnr = tf.squeeze(tf.image.psnr(x_hat, x, 255)) 328 | msssim = tf.squeeze(tf.image.ssim_multiscale(x_hat, x, 255)) 329 | 330 | if args.save_reconstruction: 331 | x_shape = tf.shape(x) 332 | x_hat_first = x_hat_first[0, :x_shape[1], :x_shape[2], :] 333 | if os.path.isdir(args.outputPath): 334 | print(args.outputPath + ':exists.') 335 | else: 336 | os.makedirs(args.outputPath) 337 | print(args.outputPath + ':created.') 338 | op = save_image(args.outputPath + image, x_hat_first) 339 | 340 | with tf.Session() as sess: 341 | # Load the latest model checkpoint, get the evaluation results. 
342 | latest = tf.train.latest_checkpoint(checkpoint_dir=args.checkpoint_dir) 343 | tf.train.Saver().restore(sess, save_path=latest) 344 | 345 | string, eval_bpp, mse, psnr, msssim, num_pixels = sess.run( 346 | [string, eval_bpp, mse, psnr, msssim, num_pixels]) 347 | 348 | sess.run(op) 349 | 350 | # The actual bits per pixel including overhead 351 | bpp = (8 + len(string)) * 8 / num_pixels 352 | 353 | print("Mean squared error: {:0.4f}".format(mse)) 354 | print("PSNR (dB): {:0.2f}".format(psnr)) 355 | print("Multiscale SSIM: {:0.4f}".format(msssim)) 356 | print("Multiscale SSIM (dB): {:0.2f}".format(-10 * np.log10(1 - msssim))) 357 | print("Information content in bpp: {:0.4f}".format(eval_bpp)) 358 | print("Actual bits per pixel: {:0.4f}".format(bpp)) 359 | 360 | with open (args.outputPath + image[:-4] + '.txt', 'w') as f: 361 | f.write('Avg_bpp_actual: '+str(bpp)+'\n') 362 | f.write('Avg_bpp_estimate: '+str(eval_bpp)+'\n') 363 | f.write('Avg_mse: '+str(mse)+'\n') 364 | f.write('Avg_psnr: '+str(psnr)+'\n') 365 | f.write('Avg_msssim: '+str(msssim)+'\n') 366 | f.write('Avg_msssim_db: '+str(-10 * np.log10(1 - msssim))+'\n') 367 | else: 368 | with tf.Session() as sess: 369 | # Load the latest model checkpoint, get the compressed string and the tensor 370 | # shapes. 
371 | latest = tf.train.latest_checkpoint(checkpoint_dir=args.checkpoint_dir) 372 | tf.train.Saver().restore(sess, save_path=latest) 373 | 374 | string, eval_bpp, mse, psnr, msssim, num_pixels = sess.run( 375 | [string, eval_bpp, mse, psnr, msssim, num_pixels]) 376 | 377 | # The actual bits per pixel including overhead 378 | bpp = (8 + len(string)) * 8 / num_pixels 379 | 380 | print("Mean squared error: {:0.4f}".format(mse)) 381 | print("PSNR (dB): {:0.2f}".format(psnr)) 382 | print("Multiscale SSIM: {:0.4f}".format(msssim)) 383 | print("Multiscale SSIM (dB): {:0.2f}".format(-10 * np.log10(1 - msssim))) 384 | print("Information content in bpp: {:0.4f}".format(eval_bpp)) 385 | print("Actual bits per pixel: {:0.4f}".format(bpp)) 386 | 387 | bpp_actual_total += bpp 388 | bpp_estimate_total += eval_bpp 389 | mse_total += mse 390 | psnr_total += psnr 391 | msssim_total += msssim 392 | msssim_db_total += (-10 * np.log10(1 - msssim)) 393 | 394 | if args.evaluation_name is not None: 395 | Avg_bpp_actual, Avg_bpp_estimate = bpp_actual_total / len(imagesList), bpp_estimate_total / len(imagesList) 396 | Avg_mse, Avg_psnr = mse_total / len(imagesList), psnr_total / len(imagesList) 397 | Avg_msssim, Avg_msssim_db = msssim_total / len(imagesList), msssim_db_total / len(imagesList) 398 | with open (args.evaluation_name + '.txt', 'w') as f: 399 | f.write('Avg_bpp_actual: '+str(Avg_bpp_actual)+'\n') 400 | f.write('Avg_bpp_estimate: '+str(Avg_bpp_estimate)+'\n') 401 | f.write('Avg_mse: '+str(Avg_mse)+'\n') 402 | f.write('Avg_psnr: '+str(Avg_psnr)+'\n') 403 | f.write('Avg_msssim: '+str(Avg_msssim)+'\n') 404 | f.write('Avg_msssim_db: '+str(Avg_msssim_db)+'\n') 405 | 406 | if __name__ == "__main__": 407 | parser = argparse.ArgumentParser( 408 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 409 | 410 | parser.add_argument( 411 | "command", choices=["train", "evaluate"], 412 | help="'train' loads training data and trains (or continues " 413 | "to train) a new model. 
'evaluate' can get the RD curves from " 414 | "the pretrained model.") 415 | 416 | parser.add_argument( 417 | "--verbose", "-v", action="store_true", 418 | help="Report bitrate and distortion when training or compressing.") 419 | parser.add_argument( 420 | "--num_filters", nargs="+", type=int, default=[128], 421 | help="Number of filters per layer (per R-D tradeoff point).") 422 | parser.add_argument( 423 | "--checkpoint_dir", default="train", 424 | help="Directory where to save/load model checkpoints.") 425 | parser.add_argument( 426 | "--train_glob", default="images/*.png", 427 | help="Glob pattern identifying training data. This pattern must expand " 428 | "to a list of RGB images in PNG format.") 429 | parser.add_argument( 430 | "--batchsize", type=int, default=8, 431 | help="Batch size for training.") 432 | parser.add_argument( 433 | "--patchsize", type=int, default=256, 434 | help="Size of image patches for training.") 435 | parser.add_argument( 436 | "--lambda", nargs="+", type=float, default=[512], dest="lmbda", 437 | help="Lambdas for rate-distortion tradeoff points.") 438 | parser.add_argument( 439 | "--last_step", type=int, default=1000000, 440 | help="Train up to this number of steps.") 441 | parser.add_argument( 442 | "--preprocess_threads", type=int, default=6, 443 | help="Number of CPU threads to use for parallel decoding of training " 444 | "images.") 445 | parser.add_argument( 446 | "--modulation_init", action="store_true", 447 | help="Initialize the modulation network by using the default vectors.") 448 | parser.add_argument( 449 | "--filters_offset", nargs="+", type=int, default=[0], 450 | help="Offset filters (per R-D tradeoff point)") 451 | parser.add_argument( 452 | "--save_reconstruction", action="store_true", 453 | help="Save reconstructed images during evaluation.") 454 | parser.add_argument( 455 | "--model_ID", type=int, default=0, 456 | help="Index of the entropy model (one per lambda) to use for compression/decompression.") 457 | parser.add_argument( 
458 | "--condition", type=int, default=None, 459 | help="Condition (lambda value) selecting the R-D tradeoff.") 460 | parser.add_argument( 461 | "--condition_norm", type=float, default=None, 462 | help="Normalization of condition values.") 463 | parser.add_argument( 464 | "--evaluation_name", type=str, default='results', 465 | help="Name of the evaluation results txt file (without extension).") 466 | parser.add_argument( 467 | "--inputPath", type=str, default=None, 468 | help="Directory containing the evaluation dataset.") 469 | parser.add_argument( 470 | "--outputPath", type=str, default=None, 471 | help="Directory where to save reconstructed images.") 472 | parser.add_argument( 473 | "--regularizer", type=str, default=None, 474 | help="Regularizer of the modulation network (L1 or L2).") 475 | 476 | args = parser.parse_args() 477 | 478 | if args.command == "train": 479 | # Check consistency between lambda, num_filters, filters_offset 480 | if len(args.lmbda) != len(args.num_filters): 481 | raise ValueError("The length of lambda and num_filters should be the same.") 482 | if len(args.num_filters) != len(args.filters_offset): 483 | raise ValueError("The length of num_filters and filters_offset should be the same.") 484 | train() 485 | elif args.command == "evaluate": 486 | if args.inputPath is None: 487 | raise ValueError("Need input path for evaluation.") 488 | evaluate() 489 | --------------------------------------------------------------------------------
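
The training objective built in `train()` in `mae.py` can be summarized outside TensorFlow. The sketch below is a plain-Python restatement under simplifying assumptions: likelihoods are given as a flat list, and the function names are hypothetical. It estimates the rate as the negative sum of log2-likelihoods per pixel and sums `lambda * MSE + bpp` over all tradeoff points, which is what `total_train_loss` does with `tf.add_n`.

```python
import math

def estimated_bpp(likelihoods, num_pixels):
    """Estimated bits per pixel: -sum(log2 p) / num_pixels."""
    return sum(math.log(p) for p in likelihoods) / (-math.log(2) * num_pixels)

def joint_rd_loss(lambdas, mses, bpps):
    """Sum of the per-tradeoff rate-distortion costs lambda * MSE + bpp."""
    return sum(l * d + r for l, d, r in zip(lambdas, mses, bpps))

# Illustrative numbers only. In mae.py the MSE is measured on images in [0, 1].
rate = estimated_bpp([0.5, 0.25, 0.25, 0.5], num_pixels=4)
loss = joint_rd_loss([128, 512], [0.01, 0.005], [0.3, 0.6])
print(rate, loss)
```

Summing the costs means every gradient step updates the shared autoencoder and the modulation network for all lambdas at once, rather than training one model per tradeoff.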