├── .gitignore ├── LICENSE ├── README.md ├── config.py ├── eval.py ├── loader.py ├── models │   ├── autoencoders.py │   └── bfcnn.py ├── train_bf_cnn.py ├── train_deep_jscc.py ├── utils.py └── visualization └── sample.jpg /.gitignore: -------------------------------------------------------------------------------- 1 | # input data, saved logs, checkpoints 2 | data/* 3 | **/*.pyc 4 | .vscode/ 5 | __MACOSX/ 6 | **/.DS_Store 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Changwoo Lee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Joint Source-Channel Coding with Iterative Source Error Correction 2 | 3 | The official PyTorch implementation of [Deep Joint Source-Channel Coding with Iterative Source Error Correction (AISTATS 2023)](https://arxiv.org/abs/2302.09174) 4 | 5 | ## Summary 6 | 7 | ![sample](visualization/sample.jpg) 8 | 9 | Deep joint source-channel coding (Deep JSCC) transmits a source through a noisy channel using deep neural network encoders and decoders. In our [AISTATS paper](https://arxiv.org/abs/2302.09174), we introduce an *iterative source error correction (ISEC)* algorithm that corrects the source error in the *codeword (latent) space* of the Deep JSCC encoder and decoder. 10 | This is achieved by maximizing a *modified maximum a posteriori (MAP)* probability, which comprises both the likelihood and the prior probability, using gradient ascent. 11 | While the likelihood is simple to obtain, estimating the prior probability over the Deep JSCC codeword space is highly challenging. 12 | To address this, we use a *bias-free CNN denoiser* to predict the gradient of the log prior probability. 13 | Empirical evidence indicates that images decoded with ISEC contain less distortion and more detail than those from one-shot decoding, especially when the training and testing channel conditions are mismatched.
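At a high level, each ISEC iteration performs one gradient-ascent step on the modified MAP objective directly in codeword space. Below is a minimal sketch of a single update under simplified names: `encoder`/`decoder` are the trained Deep JSCC networks, `denoiser` is the negated bias-free CNN output (as wired up in `eval.py`), `y` is the received codeword, and `stddev` is the assumed channel noise level. The full loop in `eval.py` additionally scales the prior weight and the step size by the test/train noise-variance ratio.

```python
import torch

def isec_step(encoder, decoder, denoiser, zt, y, stddev, alpha, lr):
    # One gradient-ascent step on the modified MAP objective in codeword space.
    zt = zt.detach().clone().requires_grad_()
    # Likelihood term: re-encoding the decoded image should match the
    # received noisy codeword y under the channel noise level stddev.
    nll = ((encoder(decoder(zt)) - y) ** 2).sum() / (2 * stddev ** 2)
    nll.backward()
    with torch.no_grad():
        # -zt.grad is the gradient of the log-likelihood; the denoiser output
        # approximates the gradient of the log prior of the codeword.
        step = -zt.grad + alpha * denoiser(zt)
        return zt + lr * step
```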
14 | 15 | ## Getting Started 16 | 17 | ### Dependencies 18 | 19 | Python 3.8+, PyTorch 1.9+, matplotlib, tensorboard, argparse, [pytorch-msssim](https://github.com/VainF/pytorch-msssim), [lpips](https://github.com/richzhang/PerceptualSimilarity), [pytorch-fid](https://github.com/mseitzer/pytorch-fid) 20 | 21 | 22 | ## Data Preparation 23 | 24 | ### Kodak images 25 | 26 | Kodak images can be downloaded from [http://r0k.us/graphics/kodak/](http://r0k.us/graphics/kodak/) 27 | 28 | ### OpenImages 29 | 30 | 1. Install [awscli](https://aws.amazon.com/cli/) 31 | 2. Download the first two training data files and the validation set. 32 | 33 | ```bash 34 | aws s3 --no-sign-request cp s3://open-images-dataset/tar/train_0.tar.gz [target_dir/train] 35 | aws s3 --no-sign-request cp s3://open-images-dataset/tar/train_1.tar.gz [target_dir/train] 36 | aws s3 --no-sign-request sync s3://open-images-dataset/validation [target_dir/val] 37 | ``` 38 | 39 | 3. Extract the files 40 | 41 | ```bash 42 | tar -xzf [target_dir/train]/train_0.tar.gz -C PATH_TO_DATA_DIR/train/ 43 | tar -xzf [target_dir/train]/train_1.tar.gz -C PATH_TO_DATA_DIR/train/ 44 | tar -xzf [target_dir/val]/validation.tar.gz -C PATH_TO_DATA_DIR/val/ 45 | ``` 46 | 47 | ### CIFAR-10 48 | 49 | CIFAR-10 images will be downloaded automatically if needed. 50 | 51 | 52 | ## Pretrained Models 53 | 54 | Pretrained models on CIFAR-10 and OpenImages are available 55 | [here](https://drive.google.com/drive/folders/1nFAosEok5TBPY8DjudHLvs79WUCLzYiQ?usp=sharing). 56 | 57 | #### List of available models 58 | 59 | * Channel-per-pixel (CPP) = `0.5 * len(codeword) / len(input_image)`. For example, a 32×32×3 CIFAR-10 image encoded with `-nc=16` yields an 8×8×16 codeword, so CPP = 0.5 · (8·8·16) / (32·32·3) = 1/6. 60 | 61 | | Dataset | Channel-per-pixel (CPP) [`-nc` option] | Signal-to-noise ratio (SNR, dB) | 62 | |------------|---------------------------------|-----------------------------| 63 | | CIFAR-10 | 1/6 [`-nc=16`], 1/12 [`-nc=8`] | 0, 5, 10 | 64 | | OpenImages | 1/6 [`-nc=16`], 1/16 [`-nc=6`] | 1, 7, 13 | 65 | 66 | 67 | 68 | ## Evaluation 69 | 70 | 71 | 72 | ### CIFAR-10 73 | 74 | ```bash 75 | # trained SNR=5dB, test SNR=0,5,10dB, CPP=1/6 76 | python eval.py -s 0 -st 5 -lr 0.002 --alpha=1.0 --delta=1.0 -jmp "saved_models/cifar/cpp_1_6/snr_5/cifar_16.pb" -bmp "saved_models/cifar/cpp_1_6/snr_5/cifar_16_bfcnn.pb" -nc 16 --print_freq=10 -bs 256 --num_conv_blocks=2 --num_residual_blocks=2 -dset "cifar" -ni 50 -mb 100 -ne 1 --gpu=0 --save_images 77 | python eval.py -s 5 -st 5 -lr 0.002 --alpha=1.0 --delta=1.0 -jmp "saved_models/cifar/cpp_1_6/snr_5/cifar_16.pb" -bmp "saved_models/cifar/cpp_1_6/snr_5/cifar_16_bfcnn.pb" -nc 16 --print_freq=10 -bs 256 --num_conv_blocks=2 --num_residual_blocks=2 -dset "cifar" -ni 50 -mb 100 -ne 1 --gpu=0 --save_images 78 | python eval.py -s 10 -st 5 -lr 0.002 --alpha=1.0 --delta=1.0 -jmp "saved_models/cifar/cpp_1_6/snr_5/cifar_16.pb" -bmp "saved_models/cifar/cpp_1_6/snr_5/cifar_16_bfcnn.pb" -nc 16 --print_freq=10 -bs 256 --num_conv_blocks=2 --num_residual_blocks=2 -dset "cifar" -ni 50 -mb 100 -ne 1 --gpu=0 --save_images 79 | ``` 80 | 81 | ### Kodak 82 | 83 | ```bash 84 | # trained SNR=7dB, test SNR=1,7,13dB, CPP=1/6 85 | python eval.py -s 1 -st 7 -lr 0.001 --alpha=4.0 --delta=1.0 -jmp "saved_models/openimages/cpp_1_6/snr_7/openimages_16.pb" -bmp "saved_models/openimages/cpp_1_6/snr_7/openimages_16_bfcnn.pb" -nc 16 --print_freq=10 -bs 1 --num_conv_blocks=2 --num_residual_blocks=3 -dset "kodak" --data_dir="./data/kodak" --num_hidden=128 -ni 50 -mb 100 -ne 10 --gpu=0 --save_images 86 | python eval.py -s 7 -st 7 -lr 0.001 --alpha=4.0 --delta=1.0 -jmp
"saved_models/openimages/cpp_1_6/snr_7/openimages_16.pb" -bmp "saved_models/openimages/cpp_1_6/snr_7/openimages_16_bfcnn.pb" -nc 16 --print_freq=10 -bs 1 --num_conv_blocks=2 --num_residual_blocks=3 -dset "kodak" --data_dir="./data/kodak" --num_hidden=128 -ni 50 -mb 100 -ne 10 --gpu=0 --save_images 87 | python eval.py -s 13 -st 7 -lr 0.001 --alpha=4.0 --delta=1.0 -jmp "saved_models/openimages/cpp_1_6/snr_7/openimages_16.pb" -bmp "saved_models/openimages/cpp_1_6/snr_7/openimages_16_bfcnn.pb" -nc 16 --print_freq=10 -bs 1 --num_conv_blocks=2 --num_residual_blocks=3 -dset "kodak" --data_dir="./data/kodak" --num_hidden=128 -ni 50 -mb 100 -ne 10 --gpu=0 --save_images 88 | ``` 89 | 90 | ### Kodak-with additive white Laplace noise 91 | 92 | ```bash 93 | python eval.py -dist "Laplace" -s 7 -st 7 -lr 0.001 --alpha=4.0 --delta=1.0 -jmp "saved_models/openimages/cpp_1_6/snr_7/openimages_16.pb" -bmp "saved_models/openimages/cpp_1_6/snr_7/openimages_16_bfcnn.pb" -nc 16 --print_freq=10 -bs 1 --num_conv_blocks=2 --num_residual_blocks=3 -dset "kodak" --data_dir="./data/kodak" --num_hidden=128 -ni 50 -mb 100 -ne 10 --gpu=0 94 | ``` 95 | 96 | 97 | ## Training 98 | 99 | ### Deep JSCC Models on CIFAR-10 100 | 101 | ```bash 102 | # SNR=5, CPP=1/6 103 | python train_deep_jscc.py --gpu=0 --model_path=PATH_TO_MODEL_DIR -nc 16 -e 300 -bs 64 -lr 2e-4 --num_conv_blocks=2 --num_residual_blocks=2 --snr 5 --print_freq=100 -dset "cifar" 104 | ``` 105 | 106 | ### Bias-free Codeword Denoisers on CIFAR-10 107 | 108 | ```bash 109 | # SNR=5, CPP=1/6 110 | python train_bf_cnn.py --gpu=0 --model_path=PATH_TO_MODEL_DIR -nc 16 -e 300 -bs 64 -lr 2e-4 --num_conv_blocks=2 --num_residual_blocks=2 --snr=5 --print_freq=100 -dset "cifar" --pretrained_model_path="./saved_models/cifar/cpp_1_6/snr_5/cifar_16.pb" 111 | ``` 112 | 113 | 114 | 115 | ### Deep JSCC Models on OpenImages 116 | 117 | ```bash 118 | # SNR=7, CPP=1/16 119 | python train_deep_jscc.py --gpu=0 --model_path=PATH_TO_MODEL_DIR --data_dir=PATH_TO_DATA_DIR -nc 6 -e 45 -bs 128 -lr 1e-3 --num_conv_blocks=2 --num_residual_blocks=3 --snr 7 --print_freq=100 -dset "openimages" --image_size=128 --display_freq=1 --num_hidden=128 --save_freq=6 --test_freq=3 120 | ``` 121 | 122 | 123 | ### Bias-free Codeword Denoisers on OpenImages 124 | 125 | ```bash 126 | # SNR=7, CPP=1/16 127 | python train_bf_cnn.py --gpu=0 --model_path=PATH_TO_MODEL_DIR --data_dir=PATH_TO_DATA_DIR -nc 6 -e 25 -bs 64 -lr 2e-4 --num_conv_blocks=2 --num_residual_blocks=3 --snr 7 --print_freq=100 -dset "openimages" --image_size=128 --display_freq=1 --num_hidden=128 --save_freq=6 --test_freq=3 --pretrained_model_path="./saved_models/openimages/cpp_1_16/snr_7/openimages_6.pb" 128 | ``` 129 | 130 | ## Citation 131 | 132 | If you use this code or find our work valuable, please cite: 133 | ``` 134 | @inproceedings{lee2023deep, 135 | title={Deep Joint Source-Channel Coding with Iterative Source Error Correction}, 136 | author={Lee, Changwoo and Hu, Xiao and Kim, Hun-Seok}, 137 | booktitle={International Conference on Artificial Intelligence and Statistics}, 138 | year={2023} 139 | } 140 | ``` 141 | 142 | 143 | 144 | ## License 145 | 146 | This project is licensed under the MIT License - see the LICENSE.md file for details 147 | 148 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_common_parser(parser=None): 4 | if parser is None: 5 | parser = 
argparse.ArgumentParser() 6 | 7 | ### Common Arguments ### 8 | parser.add_argument('--num_channels', '-nc', type=int, default=6, help='number of channels') 9 | parser.add_argument('--gpu', type=int, default=-1, help='gpu id. -1: cpu') 10 | parser.add_argument('--seed', type=int, default=1234, help='random seed') 11 | parser.add_argument('--batch_size', '-bs', type=int, default=128, help='batch size') 12 | parser.add_argument('--model_name', type=str, default='openimages', help='model name') 13 | parser.add_argument('--data_dir', '-dd', type=str, default='./data', help='data directory') 14 | parser.add_argument('--display_freq', '-df', type=int, default=10, help='display freq') 15 | parser.add_argument('--print_freq', '-pf', type=int, default=50, help='print freq') 16 | parser.add_argument('--image_size', '-is', type=int, default=128, help='image size') 17 | 18 | parser.add_argument('--snr', '-s', type=float, default=0., help='SNR in dB') 19 | parser.add_argument('--lr', '-lr', type=float, default=0.0005, help='Learning rate') 20 | 21 | parser.add_argument('--dataset', '-dset', type=str, default="openimages", help='Specify dataset: "cifar", "openimages", or "kodak"') 22 | 23 | ### Model Configuration ### 24 | parser.add_argument('--num_hidden', type=int, default=32, help='Number of hidden channels.') 25 | parser.add_argument('--power_norm', type=str, default='hard', help='Power normalization type. "hard|soft|none", default: "hard".') 26 | parser.add_argument('--num_conv_blocks', type=int, default=2, help='Number of convolutional blocks.') 27 | parser.add_argument('--num_residual_blocks', type=int, default=3, help='Number of residual blocks.') 28 | 29 | parser.add_argument('--debug', action='store_true', help='Debug mode') 30 | 31 | return parser 32 | 33 | 34 | def get_train_parser(parser=None): 35 | if parser is None: 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('--epochs', '-e', type=int, default=150, help='number of epochs') 38 | parser.add_argument('--test_batch_size', type=int, default=128, help='test batch size') 39 | parser.add_argument('--save_freq', type=int, default=25, help='save freq') 40 | parser.add_argument('--test_freq', type=int, default=5, help='test freq') 41 | parser.add_argument('--show_outputs', action='store_true', help='Show test outputs if true.') 42 | 43 | parser.add_argument('--model_path', type=str, default='saved_models/imagenet/', help='model path') 44 | 45 | parser.add_argument('--eval', action='store_true', help='Eval mode') 46 | parser.add_argument('--pretrained_model_path', type=str, default=None, help='pretrained model path') 47 | parser.add_argument('--weight_decay', type=float, default=0.0, help='weight decay parameter') 48 | 49 | parser.add_argument('--train_image_size', type=int, default=128, help='training image size') 50 | 51 | return parser 52 | 53 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from utils import tensor2im, PSNR, save_image_collections, save_to_json 7 | 8 | from pytorch_msssim import ms_ssim, ssim 9 | from pytorch_fid.fid_score import calculate_fid_given_paths 10 | import lpips 11 | 12 | 13 | import loader 14 | import config 15 | import models.autoencoders as ae 16 | from models.bfcnn import BF_CNN 17 | 18 | parser = config.get_common_parser() 19 | 20 | parser.add_argument('--jscc_model_path', '-jmp', type=str,
default=None, help='model path') 21 | parser.add_argument('--bfcnn_model_path', '-bmp', type=str, default=None, help='model path') 22 | parser.add_argument('--loss_type', type=str, default='l2', help='l2|l1 default=l2') 23 | 24 | 25 | parser.add_argument('--num_iter', '-ni', type=int, default=100, help="Number of SEC iterations.") 26 | parser.add_argument('--save_images', action='store_true', help='Save output images') 27 | parser.add_argument('--max_batch', '-mb', type=int, default=100, help='Number of maximum batch') 28 | parser.add_argument('--img_prefix',type=str, default="", help='Saved images prefix') 29 | 30 | parser.add_argument('--save_json', action='store_true', help='Save JSON file') 31 | parser.add_argument('--json_file_path',type=str, default="", help='path for JSON file') 32 | parser.add_argument('--snr_train', '-st', type=int, default=0, help="Trained SNR") 33 | 34 | parser.add_argument('--output_dir', '-od', type=str, default='./outputs', help='output directory') 35 | parser.add_argument('--no_denoiser', action='store_true', help='Do not use denoiser') 36 | parser.add_argument('--alpha', '-al', type=float, default=0.0, help='modified MAP parameter') 37 | parser.add_argument('--delta', '-de', type=float, default=1.5, help='delta') 38 | parser.add_argument('--stop_ratio', '-sr', type=float, default=0.0, help='Stopping Criterion') 39 | parser.add_argument('--num_experiment', '-ne', type=int, default=10, help="Number of experiment") 40 | parser.add_argument('--distribution', '-dist', type=str, default='Gaussian', help='Noise distribution. |Gaussian (default)|Laplace|') 41 | 42 | 43 | args = parser.parse_args() 44 | if args.debug: 45 | args.print_freq = 1 46 | args.display_freq = 1 47 | 48 | dev = "cuda:{}".format(args.gpu) if args.gpu>=0 else "cpu" 49 | 50 | device = torch.device(dev) 51 | np.random.seed(args.seed) 52 | torch.manual_seed(args.seed) 53 | 54 | test_dataloader = loader.get_test_dataloader(args) 55 | image_range=(-1, 1) 56 | print(len(test_dataloader)) 57 | 58 | if args.loss_type == 'l2': 59 | criterion = nn.MSELoss(reduction='sum') 60 | elif args.loss_type == 'l1': 61 | criterion = nn.L1Loss(reduction='sum') 62 | else: 63 | raise NotImplementedError() 64 | 65 | loss_fn_vgg = lpips.LPIPS(net='vgg').to(device) 66 | 67 | def print_update(i, 68 | t, 69 | scaled_denoiser_sqnorm, 70 | obj, 71 | B, 72 | outputs, 73 | inputs, 74 | psnr_orig, 75 | msssim_orig, 76 | avg_psnr, 77 | avg_msssim, 78 | lpips_orig=None, 79 | avg_lpips=None, 80 | last_iter=False): 81 | with torch.no_grad(): 82 | lpips = loss_fn_vgg(outputs, inputs).mean() 83 | inputs_255 = tensor2im(inputs, *image_range, 'torch') 84 | psnr = PSNR(reduction='sum')(outputs, inputs, *image_range, offset=0) 85 | try: 86 | msssim = ms_ssim(tensor2im(outputs, *image_range,'torch'), 87 | inputs_255, data_range=255, size_average=True) 88 | label = "MS-SSIM" 89 | except AssertionError: 90 | msssim = ssim(tensor2im(outputs, *image_range, 'torch'), 91 | inputs_255, data_range=255, size_average=True) 92 | label = "SSIM" 93 | 94 | if last_iter: 95 | avg_psnr += psnr 96 | avg_msssim += msssim 97 | avg_lpips += lpips 98 | message = "[{:4d}, {:4d}] sigma_t^2: {:.4f} Obj: {:.4f} PSNR: {:.2f} PSNR Orig: {:.2f} {}: {:.4f} {} Orig: {:.4f}".format( 99 | i+1, t+1, scaled_denoiser_sqnorm.item(), obj.item(), psnr / B, psnr_orig / B, label, msssim, label, msssim_orig) 100 | message += " LPIPS: {:.4f} LPIPS Orig: {:.4f}".format(lpips, lpips_orig) 101 | print(message) 102 | return avg_psnr, avg_msssim, avg_lpips 103 | 104 | 105 | def 
print_avg(i, 106 | count, 107 | avg_psnr, 108 | avg_psnr_orig, 109 | avg_msssim, 110 | avg_msssim_orig, 111 | avg_lpips, 112 | avg_lpips_orig): 113 | 114 | stats = {'PSNR': avg_psnr.item()/count, 115 | 'PSNR Orig': avg_psnr_orig.item()/count, 116 | 'PSNR Gain': (avg_psnr - avg_psnr_orig).item()/count, 117 | 'MS SSIM': avg_msssim.item()/i, 118 | 'MS SSIM Orig': avg_msssim_orig.item()/i, 119 | 'MS SSIM Gain': (avg_msssim - avg_msssim_orig).item()/i, 120 | 'lpips': avg_lpips.item()/i, 121 | 'lpips Orig': avg_lpips_orig.item()/i, 122 | 'lpips Gain': (-avg_lpips + avg_lpips_orig).item()/i, 123 | } 124 | message = '[{:4d}] Average'.format(i) 125 | for name, val in stats.items(): 126 | message += " - {}: {:.4f}".format(name, val) 127 | print(message) 128 | return stats 129 | 130 | 131 | def test_latent(net, stddev=0., saved_dir=None, writer=None, epoch=0): 132 | net.eval() 133 | avg_psnr_orig = 0. 134 | avg_psnr = 0. 135 | avg_msssim_orig = 0. 136 | avg_msssim = 0. 137 | avg_lpips_orig = 0. 138 | avg_lpips = 0. 139 | avg_fid = 0. 140 | avg_fid_orig = 0. 141 | count = 0. 142 | 143 | base_var = 10**(-0.1*args.snr_train) 144 | print(args.distribution) 145 | 146 | if args.distribution == 'Gaussian' or args.distribution == 'Fading': 147 | dist = torch.distributions.normal.Normal(0.0, stddev) 148 | elif args.distribution == 'Laplace': 149 | dist = torch.distributions.laplace.Laplace(0.0, stddev/np.sqrt(2)) 150 | else: 151 | raise NotImplementedError() 152 | 153 | 154 | decoder = net.decoder 155 | 156 | L = args.num_experiment 157 | psnr_vals = np.zeros(L) 158 | psnr_orig_vals = np.zeros(L) 159 | ssim_vals = np.zeros(L) 160 | ssim_orig_vals = np.zeros(L) 161 | lpips_vals = np.zeros(L) 162 | lpips_orig_vals = np.zeros(L) 163 | fid_vals = np.zeros(L) 164 | fid_orig_vals = np.zeros(L) 165 | for e in range(L): 166 | for i, data in enumerate(test_dataloader): 167 | if i==args.max_batch: 168 | break 169 | with torch.no_grad(): 170 | # Sample Test Image 171 | inputs = data[0].to(device) 172 | B,C,H,W = inputs.size() 173 | # Encode 174 | codeword = net.encoder(inputs) 175 | _,zC,zH,zW = codeword.size() 176 | # Corrupted codeword 177 | noise = dist.sample(codeword.size()).to(device) 178 | if args.distribution == 'Fading': 179 | h = torch.randn(B, 1, 1, 1) 180 | codeword = h * codeword 181 | y = codeword + noise 182 | # One-shot Decoding 183 | outputs = decoder(y) 184 | 185 | 186 | psnr_orig = PSNR(reduction='sum')(outputs, inputs, *image_range, offset=0) 187 | inputs_255 = tensor2im(inputs, *image_range, 'torch') 188 | try: 189 | msssim_orig = ms_ssim(tensor2im(outputs, *image_range, 'torch'), 190 | inputs_255, data_range=255, size_average=True) 191 | label = "MS-SSIM" 192 | except AssertionError: 193 | msssim_orig = ssim(tensor2im(outputs, *image_range, 'torch'), 194 | inputs_255, data_range=255, size_average=True) 195 | label = "SSIM" 196 | 197 | lpips_orig = loss_fn_vgg(outputs, inputs).sum() 198 | 199 | avg_psnr_orig += psnr_orig 200 | avg_msssim_orig += msssim_orig 201 | avg_lpips_orig += lpips_orig / B 202 | count += B 203 | outputs_orig = outputs.clone() 204 | 205 | # Initialize zt with y 206 | with torch.no_grad(): 207 | init_p = y 208 | zt = init_p.detach().clone().requires_grad_() 209 | 210 | with torch.no_grad(): 211 | # Logging Purpose, sqnorm of denoiser output 212 | dt = net.denoiser(zt) 213 | vart = torch.sum(dt**2, dim=(1,2,3), keepdim=True)/(zC*zH*zW) 214 | var_scale = stddev**2 / vart.mean() 215 | vart *= var_scale 216 | varL = args.stop_ratio * vart.mean().item() 217 | 218 | # Compute delta 
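            # ISEC update performed below: one gradient-ascent step per iteration on
            # the modified MAP objective,
            #     z_{t+1} = z_t + lr * ( -grad_NLL(z_t) + alpha * s * d_t ),
            # where d_t = -bfcnn(z_t) approximates the gradient of the log prior.
            # Both the prior weight s = max(0.1, (stddev^2/base_var)^2) and the step
            # size lr are scaled by the test/train noise-variance ratio, so decoding
            # relies more on the prior when the channel mismatch is larger.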
219 | delta = args.delta if (stddev**2/base_var) > 1 else 1.0 220 | for t in range(args.num_iter): 221 | with torch.no_grad(): 222 | # Logging Purpose, sqnorm of denoiser output 223 | dt = net.denoiser(zt) 224 | vart = torch.sum(dt**2, dim=(1,2,3), keepdim=True)/(zC*zH*zW) 225 | vart *= var_scale 226 | scaled_denoiser_sqnorm = vart.mean() 227 | 228 | # Evaluate NLL 229 | z_p = net.encoder(decoder(zt)) 230 | obj = 1/(2*(stddev**2)) * criterion(z_p, y) 231 | obj.backward() 232 | 233 | with torch.no_grad(): 234 | # Gradient of NLL 235 | zt_grad = -zt.grad 236 | # Scale the output of the denoiser (approximate gradient of the log prior) 237 | zt_grad += args.alpha * max(0.1, (stddev**2/base_var)**2) * dt 238 | lr = args.lr / max(0.1, (stddev**2/base_var)**delta) 239 | zt.data = zt + lr * zt_grad 240 | zt.grad.zero_() 241 | 242 | if t % args.print_freq == args.print_freq - 1 or t == args.num_iter - 1 or t==0: 243 | outputs = decoder(zt) 244 | avg_psnr, avg_msssim, avg_lpips = print_update(i, 245 | t, scaled_denoiser_sqnorm, obj, B, outputs, inputs, 246 | psnr_orig, msssim_orig, 247 | avg_psnr, avg_msssim, 248 | lpips_orig/B, avg_lpips, last_iter=t==args.num_iter-1) 249 | 250 | 251 | stats = print_avg(e*min(len(test_dataloader), args.max_batch) + i+1, count, avg_psnr, avg_psnr_orig, avg_msssim, avg_msssim_orig, avg_lpips, avg_lpips_orig) 252 | psnr_vals[e] = avg_psnr/count 253 | psnr_orig_vals[e] = avg_psnr_orig/count 254 | ssim_vals[e] = avg_msssim/(e+1) 255 | ssim_orig_vals[e] = avg_msssim_orig/(e+1) 256 | lpips_vals[e] = avg_lpips/(e+1) 257 | lpips_orig_vals[e] = avg_lpips_orig/(e+1) 258 | if args.save_images: 259 | output_dir = args.output_dir 260 | subdirs = ['targets', 'orig', 'updated'] 261 | for sd in subdirs: 262 | if not os.path.exists(os.path.join(output_dir, args.img_prefix, sd, "files")): 263 | os.makedirs(os.path.join(output_dir, args.img_prefix, sd, "files")) 264 | targets = tensor2im(inputs, *image_range) 265 | orig = tensor2im(outputs_orig, *image_range) 266 | updated = tensor2im(outputs, *image_range) 267 | for b in range(B): 268 | plt.imsave(os.path.join(output_dir, 269 | args.img_prefix, "targets", "files", "targets{:04d}.png".format(i*args.batch_size + b)), targets[b,:,:,:]) 270 | plt.imsave(os.path.join(output_dir, args.img_prefix, "orig", "files", "orig{:04d}.png".format(i*args.batch_size + b)), orig[b,:,:,:]) 271 | plt.imsave(os.path.join(output_dir, args.img_prefix, "updated", "files", "updated{:04d}.png".format(i*args.batch_size + b)), updated[b,:,:,:]) 272 | 273 | if args.save_images and args.image_size > 0: 274 | save_image_collections(args.img_prefix, np.minimum(36, i*args.batch_size), output_dir, nrow=6) 275 | 276 | try: 277 | fid_orig = calculate_fid_given_paths( 278 | ["{}/{}/orig/files/".format(output_dir, args.img_prefix), "{}/{}/targets/files/".format(output_dir, args.img_prefix)], 279 | min(args.batch_size, 16), 280 | device, 281 | 2048, 282 | 3) 283 | fid_updated = calculate_fid_given_paths( 284 | ["{}/{}/updated/files/".format(output_dir, args.img_prefix), "{}/{}/targets/files/".format(output_dir, args.img_prefix)], 285 | min(args.batch_size, 16), 286 | device, 287 | 2048, 288 | 3) 289 | 290 | avg_fid += fid_updated 291 | avg_fid_orig += fid_orig 292 | #stats["FID Gain"] += fid_updated - fid_orig 293 | print("FID: {:.2f}, FID Orig: {:.2f}, FID Gain: {:.2f}".format(fid_updated, fid_orig, fid_orig - fid_updated)) 294 | 295 | fid_vals[e] = fid_updated 296 | fid_orig_vals[e] = fid_orig 297 | 298 | except RuntimeError as e: 299 | print("FID unavailable") 300 | 
print(e) 301 | pass 302 | 303 | try: 304 | stats["FID"] = avg_fid / L 305 | stats["FID Orig"] = avg_fid_orig / L 306 | stats["FID Gain"] = -stats["FID"] + stats["FID Orig"] 307 | stats['PSNR Numpy'] = psnr_vals.tolist() 308 | stats['PSNR Orig Numpy'] = psnr_orig_vals.tolist() 309 | stats['SSIM Numpy'] = ssim_vals.tolist() 310 | stats['SSIM Orig Numpy'] = ssim_orig_vals.tolist() 311 | stats['LPIPS Numpy'] = lpips_vals.tolist() 312 | stats['LPIPS Orig Numpy'] = lpips_orig_vals.tolist() 313 | stats['FID Numpy'] = fid_vals.tolist() 314 | stats['FID Orig Numpy'] = fid_orig_vals.tolist() 315 | except KeyError: 316 | pass 317 | 318 | print(stats) 319 | return stats 320 | 321 | 322 | def main(): 323 | 324 | if 'cifar' in args.dataset: 325 | Enc = ae.Encoder_CIFAR 326 | Dec = ae.Decoder_CIFAR 327 | else: 328 | Enc = ae.Encoder 329 | Dec = ae.Decoder 330 | encoder = Enc(num_out=args.num_channels, 331 | num_hidden=args.num_hidden, 332 | num_conv_blocks=args.num_conv_blocks, 333 | num_residual_blocks=args.num_residual_blocks, 334 | normalization=nn.BatchNorm2d, 335 | activation=nn.PReLU, 336 | power_norm=args.power_norm) 337 | 338 | decoder = Dec(num_in=args.num_channels, 339 | num_hidden=args.num_hidden, 340 | num_conv_blocks=args.num_conv_blocks, 341 | num_residual_blocks=args.num_residual_blocks, 342 | normalization=nn.BatchNorm2d, 343 | activation=nn.PReLU, 344 | no_tanh=False) 345 | 346 | net = ae.Generator(encoder, decoder) 347 | print(args) 348 | 349 | try: 350 | filepath = args.jscc_model_path 351 | print("Try loading "+filepath) 352 | net.load_state_dict(torch.load(filepath, map_location=dev)) 353 | except Exception as e: 354 | print(e) 355 | print("Loading Failed...") 356 | exit() 357 | 358 | bfcnn = BF_CNN(1, 64, 3, 20, args.num_channels) 359 | try: 360 | filepath = args.bfcnn_model_path 361 | print("Try loading "+filepath) 362 | bfcnn.load_state_dict(torch.load(filepath, map_location=dev)) 363 | except Exception as e: 364 | print(e) 365 | print("Loading Failed...") 366 | exit() 367 | 368 | bfcnn.to(device) 369 | net.denoiser = lambda z_: -bfcnn(z_) 370 | net.to(device) 371 | 372 | if args.save_json: 373 | args.img_prefix = os.path.join(args.dataset, 374 | "{}_snr_train_{:.1f}_snr_{:.1f}_nc_{}".format(args.img_prefix, 375 | args.snr_train, 376 | args.snr, 377 | args.num_channels)) 378 | 379 | stats = test_latent(net, 10**(-0.05*args.snr)) 380 | 381 | if args.save_json: 382 | save_to_json(stats, args) 383 | 384 | 385 | if __name__=='__main__': 386 | main() 387 | -------------------------------------------------------------------------------- /loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision 4 | import torchvision.transforms as transforms 5 | 6 | 7 | def get_train_dataloader(args): 8 | if 'cifar' in args.dataset: 9 | args.train_image_size = 32 10 | transform_train = transforms.Compose( 11 | [transforms.RandomCrop(args.train_image_size), 12 | transforms.RandomHorizontalFlip(), 13 | transforms.ToTensor(), 14 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),]) 15 | 16 | if 'cifar' in args.dataset: 17 | trainset = torchvision.datasets.CIFAR10(args.data_dir, train=True, transform=transform_train, download=True) 18 | args.image_size=32 19 | elif 'openimages' in args.dataset: 20 | trainset = torchvision.datasets.ImageFolder(os.path.join(args.data_dir, 'train'), transform=transform_train) 21 | else: 22 | raise NotImplementedError() 23 | 24 | train_dataloader = torch.utils.data.DataLoader(trainset, 
batch_size=args.batch_size, 25 | shuffle=True, num_workers=3, 26 | drop_last=True, 27 | pin_memory=True, ) 28 | 29 | return train_dataloader 30 | 31 | 32 | def get_test_dataloader(args): 33 | if 'cifar' in args.dataset: 34 | args.image_size = 32 35 | transform_list = [] 36 | if args.image_size > -1: 37 | transform_list += [transforms.Resize(args.image_size), transforms.CenterCrop(args.image_size)] 38 | transform_list += [transforms.ToTensor()] 39 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 40 | if 'cifar' in args.dataset: 41 | testset = torchvision.datasets.CIFAR10(args.data_dir, train=False, transform=transforms.Compose(transform_list), download=True) 42 | elif 'openimages' in args.dataset: 43 | testset = torchvision.datasets.ImageFolder(os.path.join(args.data_dir, 'val'), transform=transforms.Compose(transform_list)) 44 | elif 'kodak' in args.dataset: 45 | testset = torchvision.datasets.ImageFolder(args.data_dir, transform=transforms.Compose(transform_list)) 46 | else: 47 | raise NotImplementedError() 48 | test_dataloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, 49 | shuffle=False, num_workers=1) 50 | return test_dataloader 51 | 52 | 53 | 54 | -------------------------------------------------------------------------------- /models/autoencoders.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | class UnitNorm(nn.Module): 6 | def __init__(self, max_power=1, eps=1e-5): 7 | super(UnitNorm, self).__init__() 8 | self.max_power = max_power 9 | 10 | def forward(self, x): 11 | x_shape = x.size() 12 | x = x.reshape(x_shape[0], -1) 13 | multiplier = self.max_power * np.sqrt(x.size(1)) 14 | proj_idx = torch.norm(x, p=2, dim=1) > multiplier 15 | x[proj_idx] = multiplier * x[proj_idx] / torch.norm(x[proj_idx], p=2, dim=1, keepdim=True) 16 | return x.reshape(*x_shape) 17 | 18 | 19 | def unitnorm(x, max_power=1.0): 20 | x_shape = x.size() 21 | x = x.reshape(x_shape[0], -1) 22 | multiplier = max_power * np.sqrt(x.size(1)) 23 | proj_idx = torch.norm(x, p=2, dim=1) > multiplier 24 | x[proj_idx] = multiplier * x[proj_idx] / torch.norm(x[proj_idx], p=2, dim=1, keepdim=True) 25 | return x.reshape(*x_shape) 26 | 27 | 28 | class ResBlock(nn.Module): 29 | def __init__(self, dim1, dim2, normalization=nn.BatchNorm2d, activation=nn.ReLU, r=2, conv=None, kernel_size=5): 30 | super(ResBlock, self).__init__() 31 | 32 | if conv is None: 33 | conv = nn.Conv2d 34 | self.block1 = nn.Sequential(conv(dim1, dim1//r, kernel_size, stride=1, padding=(kernel_size-1)//2), normalization(dim1//r), activation()) 35 | self.block2 = nn.Sequential(conv(dim1//r, dim2, kernel_size, stride=1, padding=(kernel_size-1)//2), normalization(dim2)) 36 | if dim1 != dim2: 37 | self.shortcut = conv(dim1, dim2, 1, bias=False) 38 | else: 39 | self.shortcut = nn.Identity() 40 | def forward(self, x): 41 | x_short = self.shortcut(x) 42 | x = self.block2(self.block1(x)) 43 | return x+x_short 44 | 45 | 46 | 47 | 48 | class Encoder(nn.Sequential): 49 | """Encoder Module""" 50 | def __init__(self, 51 | num_out, 52 | num_hidden, 53 | num_conv_blocks=2, 54 | num_residual_blocks=3, 55 | normalization=nn.BatchNorm2d, 56 | activation=nn.PReLU, 57 | power_norm="hard", 58 | primary_latent_power=1.0, 59 | r=2, 60 | conv_kernel_size=5, 61 | residual=True): 62 | conv = nn.Conv2d 63 | layers = [conv(3, num_hidden, 7, stride=2, padding=3), 64 | normalization(num_hidden), 65 | activation()] 66 | 67 | 
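        # The stride-2 stem plus the stride-2 convolutions below downsample the
        # input by a factor of 2**num_conv_blocks per spatial dimension, so the
        # channel-per-pixel rate is CPP = 0.5 * num_out / (3 * 4**num_conv_blocks);
        # e.g., num_out=16 with num_conv_blocks=2 gives CPP = 1/6 (cf. the README).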
channels = num_hidden 68 | for _ in range(num_conv_blocks-1): 69 | layers += [conv(channels, channels, conv_kernel_size, stride=2, padding=(conv_kernel_size-1)//2)] 70 | layers += [normalization(channels), 71 | activation()] 72 | if residual: 73 | layers += [ResBlock(channels, channels, normalization, activation, conv=conv, r=r, kernel_size=conv_kernel_size), activation()] 74 | 75 | for _ in range(num_residual_blocks-1): 76 | if residual: 77 | layers += [ResBlock(channels, channels, normalization, activation, conv=conv, r=r, kernel_size=conv_kernel_size), activation()] 78 | else: 79 | layers += [conv(channels, channels, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2), normalization(channels), activation()] 80 | 81 | layers += [conv(channels, num_out, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2), 82 | normalization(num_out)] 83 | if power_norm == "hard": 84 | layers += [UnitNorm()] 85 | elif power_norm == "soft": 86 | layers += [nn.BatchNorm2d(num_out, affine=False)] 87 | else: 88 | raise NotImplementedError() 89 | 90 | super(Encoder, self).__init__(*layers) 91 | 92 | 93 | class Encoder_CIFAR(nn.Sequential): 94 | """Encoder Module""" 95 | def __init__(self, 96 | num_out, 97 | num_hidden, 98 | num_conv_blocks=2, 99 | num_residual_blocks=2, 100 | conv = nn.Conv2d, 101 | normalization=nn.BatchNorm2d, 102 | activation=nn.PReLU, 103 | power_norm="hard", 104 | primary_latent_power=1.0, 105 | r=2, 106 | max_channels=512, 107 | first_conv_size=7, 108 | conv_kernel_size=3, 109 | residual=True, 110 | **kwargs): 111 | 112 | bias = normalization==nn.Identity 113 | layers = [conv(3, num_hidden, first_conv_size, stride=1, padding=(first_conv_size-1)//2, bias=bias), 114 | normalization(num_hidden), 115 | activation()] 116 | 117 | channels = num_hidden 118 | for _ in range(num_conv_blocks): 119 | channels *= 2 120 | layers += [conv(np.minimum(channels//2, max_channels), np.minimum(channels, max_channels), 3, stride=2, padding=1, bias=bias)] 121 | layers += [normalization(np.minimum(channels, max_channels)), 122 | activation()] 123 | if residual: 124 | layers += [ResBlock(np.minimum(channels, max_channels), np.minimum(channels, max_channels), normalization, activation, conv=conv, r=r, kernel_size=3), activation()] 125 | else: 126 | layers += [conv(np.minimum(channels, max_channels), np.minimum(channels, max_channels), 3, stride=1, padding=1), normalization(np.minimum(channels, max_channels)), activation()] 127 | 128 | for _ in range(num_residual_blocks): 129 | if residual: 130 | layers += [ResBlock(np.minimum(channels, max_channels), np.minimum(channels, max_channels), normalization, activation, conv=conv, r=r, kernel_size=3), activation()] 131 | else: 132 | layers += [conv(np.minimum(channels, max_channels),np.minimum(channels, max_channels), conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2), 133 | normalization(np.minimum(channels, max_channels)), 134 | activation()] 135 | 136 | if residual: 137 | layers += [ResBlock(np.minimum(channels, max_channels), num_out, normalization, activation, conv=conv, r=r, kernel_size=3)] 138 | else: 139 | layers += [conv(np.minimum(channels, max_channels), num_out, 3, stride=1, padding=1)] 140 | if power_norm == "hard": 141 | layers += [UnitNorm()] 142 | elif power_norm == "soft": 143 | layers += [nn.BatchNorm2d(num_out, affine=False)] 144 | elif power_norm == "none": 145 | pass 146 | else: 147 | raise NotImplementedError() 148 | 149 | super(Encoder_CIFAR, self).__init__(*layers) 150 | 151 | 152 | class Decoder(nn.Sequential): 153 | 
def __init__(self, 154 | num_in, 155 | num_hidden, 156 | num_conv_blocks=2, 157 | num_residual_blocks=3, 158 | normalization=nn.BatchNorm2d, 159 | activation=nn.PReLU, 160 | no_tanh=False, 161 | r=2, 162 | conv_kernel_size=5, 163 | residual=True): 164 | 165 | channels = num_hidden 166 | 167 | layers = [nn.Conv2d(num_in, channels, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2), 168 | normalization(channels), 169 | activation()] 170 | 171 | for _ in range(num_residual_blocks-1): 172 | if residual: 173 | layers += [ResBlock(channels, channels, normalization, activation, r=r, kernel_size=conv_kernel_size), activation()] 174 | else: 175 | layers += [nn.Conv2d(channels, channels, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2), normalization(channels), activation()] 176 | 177 | for _ in range(num_conv_blocks-1): 178 | layers += [nn.Upsample(scale_factor=(2,2), mode='bilinear')] 179 | layers += [nn.Conv2d(channels, channels, conv_kernel_size, 1, padding=(conv_kernel_size-1)//2)] 180 | layers += [normalization(channels), activation()] 181 | if residual: 182 | layers += [ResBlock(channels, channels, normalization, activation, r=r, kernel_size=conv_kernel_size), activation()] 183 | 184 | layers += [nn.Upsample(scale_factor=(2,2), mode='bilinear'), 185 | nn.Conv2d(num_hidden, 3, 7, stride=1, padding=3)] 186 | layers += [normalization(3)] 187 | if not no_tanh: 188 | layers += [nn.Tanh()] 189 | super(Decoder, self).__init__(*layers) 190 | 191 | 192 | class Decoder_CIFAR(nn.Sequential): 193 | def __init__(self, 194 | num_in, 195 | num_hidden, 196 | num_conv_blocks=2, 197 | num_residual_blocks=2, 198 | normalization=nn.BatchNorm2d, 199 | activation=nn.PReLU, 200 | no_tanh=False, 201 | bias_free=False, 202 | r=2, 203 | residual=True, 204 | max_channels=512, 205 | last_conv_size=5, 206 | normalize_first=False, 207 | conv_kernel_size=3, 208 | **kwargs): 209 | 210 | channels = num_hidden * (2**num_conv_blocks) 211 | 212 | layers = [nn.Conv2d(num_in, min(max_channels, channels), 3, stride=1, padding=1, bias=False), 213 | normalization(channels), 214 | activation()] 215 | 216 | for _ in range(num_residual_blocks): 217 | if residual: 218 | layers += [ResBlock(min(max_channels, channels), min(max_channels, channels), normalization, activation, r=r, kernel_size=conv_kernel_size), activation()] 219 | else: 220 | layers += [nn.Conv2d(min(max_channels, channels), min(max_channels, channels), conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2), normalization(channels), activation()] 221 | 222 | for _ in range(num_conv_blocks): 223 | channels = channels // 2 224 | layers += [nn.Upsample(scale_factor=(2,2), mode='bilinear'), 225 | nn.Conv2d(min(max_channels, channels*2), min(max_channels, channels), 3, 1, 1, bias=False), 226 | normalization(min(max_channels, channels)), 227 | activation()] 228 | if residual: 229 | layers += [ResBlock(min(max_channels, channels), min(max_channels, channels), normalization, activation, r=r, kernel_size=3), activation()] 230 | else: 231 | layers += [nn.Conv2d(min(max_channels, channels), min(max_channels, channels), conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2), normalization(channels), activation()] 232 | 233 | layers += [nn.Conv2d(num_hidden, 3, last_conv_size, stride=1, padding=(last_conv_size-1)//2, bias=False)] 234 | 235 | if not normalize_first: 236 | layers += [normalization(3)] 237 | if not no_tanh: 238 | layers += [nn.Tanh()] 239 | 240 | super(Decoder_CIFAR, self).__init__(*layers) 241 | 242 | 243 | class Generator(nn.Module): 
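    # Wraps an encoder/decoder pair and simulates the AWGN channel in between:
    # y = encoder(x) + n with n ~ N(0, stddev^2 I). Since the codewords are
    # power-normalized to unit average power, the training scripts pass
    # stddev = 10**(-snr_db/20).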
244 | def __init__(self, encoder, decoder): 245 | super(Generator, self).__init__() 246 | self.encoder = encoder 247 | self.decoder = decoder 248 | 249 | def forward(self, inp, stddev, return_latent=False): 250 | code = self.encoder(inp) 251 | chan_noise = torch.randn_like(code) * stddev 252 | y = code + chan_noise 253 | reconst = self.decoder(y) 254 | 255 | if return_latent: 256 | return reconst, (y, code) 257 | else: 258 | return reconst 259 | 260 | 261 | -------------------------------------------------------------------------------- /models/bfcnn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code from https://github.com/LabForComputationalVision/universal_inverse_problem/blob/master/code/network.py 3 | """ 4 | 5 | import torch.nn as nn 6 | import torch 7 | 8 | 9 | 10 | ################################################# network class ################################################# 11 | 12 | 13 | class BF_CNN(nn.Module): 14 | 15 | def __init__(self, padding, num_kernels, kernel_size, num_layers, num_channels): 16 | super(BF_CNN, self).__init__() 17 | 18 | self.padding = padding 19 | self.num_kernels = num_kernels 20 | self.kernel_size = kernel_size 21 | self.num_layers = num_layers 22 | self.num_channels = num_channels 23 | 24 | self.conv_layers = nn.ModuleList([]) 25 | self.running_sd = nn.ParameterList([]) 26 | self.gammas = nn.ParameterList([]) 27 | 28 | 29 | self.conv_layers.append(nn.Conv2d(self.num_channels,self.num_kernels, self.kernel_size, padding=self.padding , bias=False)) 30 | 31 | for l in range(1,self.num_layers-1): 32 | self.conv_layers.append(nn.Conv2d(self.num_kernels ,self.num_kernels, self.kernel_size, padding=self.padding , bias=False)) 33 | self.running_sd.append( nn.Parameter(torch.ones(1,self.num_kernels,1,1), requires_grad=False) ) 34 | g = (torch.randn( (1,self.num_kernels,1,1) )*(2./9./64.)).clamp_(-0.025,0.025) 35 | self.gammas.append(nn.Parameter(g, requires_grad=True) ) 36 | 37 | self.conv_layers.append(nn.Conv2d(self.num_kernels,self.num_channels, self.kernel_size, padding=self.padding , bias=False)) 38 | 39 | 40 | def forward(self, x): 41 | relu = nn.ReLU(inplace=True) 42 | x = relu(self.conv_layers[0](x)) 43 | for l in range(1,self.num_layers-1): 44 | x = self.conv_layers[l](x) 45 | # BF_BatchNorm 46 | sd_x = torch.sqrt(x.var(dim=(0,2,3) ,keepdim = True, unbiased=False)+ 1e-05) 47 | 48 | if self.conv_layers[l].training: 49 | x = x / sd_x.expand_as(x) 50 | self.running_sd[l-1].data = (1-.1) * self.running_sd[l-1].data + .1 * sd_x 51 | x = x * self.gammas[l-1].expand_as(x) 52 | 53 | else: 54 | x = x / self.running_sd[l-1].expand_as(x) 55 | x = x * self.gammas[l-1].expand_as(x) 56 | 57 | x = relu(x) 58 | 59 | x = self.conv_layers[-1](x) 60 | 61 | return x 62 | 63 | 64 | -------------------------------------------------------------------------------- /train_bf_cnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import os 5 | import time 6 | import matplotlib.pyplot as plt 7 | from PIL import Image 8 | import json 9 | from torch.utils.tensorboard import SummaryWriter 10 | import torch.optim as optim 11 | import argparse 12 | 13 | import config 14 | from loader import get_train_dataloader, get_test_dataloader 15 | from models.autoencoders import Generator, Encoder_CIFAR, Decoder_CIFAR, Encoder, Decoder 16 | from models.bfcnn import BF_CNN 17 | from utils import batch2im, PSNR 18 | 19 | parser = 
config.get_common_parser() 20 | parser = config.get_train_parser(parser) 21 | 22 | args = parser.parse_args() 23 | dev = "cuda:{}".format(args.gpu) if args.gpu>=0 else "cpu" 24 | device = torch.device(dev) 25 | 26 | 27 | np.random.seed(args.seed) 28 | torch.manual_seed(args.seed) 29 | 30 | train_dataloader = get_train_dataloader(args) 31 | test_dataloader = get_test_dataloader(args) 32 | 33 | print(len(train_dataloader)) 34 | print(len(test_dataloader)) 35 | 36 | criterion = nn.MSELoss() 37 | 38 | def test(net, bfcnn, stddev=0., saved_dir=None, writer=None, epoch=0): 39 | net.eval() 40 | bfcnn.eval() 41 | avg_psnr = 0. 42 | avg_loss = 0. 43 | 44 | count = 0. 45 | batch_count = 0. 46 | with torch.no_grad(): 47 | for i, data in enumerate(test_dataloader): 48 | inputs = data[0].to(device) 49 | B = inputs.size(0) 50 | 51 | code = net.encoder(inputs) 52 | noise = torch.randn_like(code)*stddev 53 | residual = bfcnn(code + noise) 54 | loss = criterion(residual, noise) 55 | outputs = net.decoder(code+noise-residual) 56 | 57 | avg_psnr += PSNR(reduction='sum')(outputs, inputs, -1, 1, offset=0) 58 | avg_loss += loss.item() 59 | count += inputs.size(0) 60 | batch_count += 1 61 | if args.show_outputs: 62 | plt.imsave('test.png', batch2im(outputs, 8, 8, 63 | im_height=args.image_size, im_width=args.image_size)) 64 | plt.imsave('test_target.png', batch2im(inputs, 8, 8, 65 | im_height=args.image_size, im_width=args.image_size)) 66 | break 67 | 68 | print('Average PSNR: {:.4f}, Average loss: {:.4f}'.format(avg_psnr/count, avg_loss / batch_count) ) 69 | if writer is not None: 70 | writer.add_scalar('PSNR/test', avg_psnr/count, epoch+1) 71 | 72 | 73 | def train(net, bfcnn, optimizer, num_epoch, stddev=0., 74 | saved_dir=None, model_name=None, which_epoch=0, writer=None, clip_val=5): 75 | 76 | if not os.path.exists(saved_dir): 77 | os.makedirs(saved_dir) 78 | 79 | 80 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.epochs // 2, gamma=0.1) 81 | for epoch in range(which_epoch, num_epoch): # loop over the dataset multiple times 82 | start_time = time.time() 83 | net.eval() 84 | bfcnn.train() 85 | running_loss = 0. 
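        # Residual learning: the bias-free CNN is trained to predict the noise that
        # was added to the frozen encoder's codeword, so the denoised codeword is
        # (code + noise) - bfcnn(code + noise). Only the bfcnn parameters are updated.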
86 | 87 | for i, data in enumerate(train_dataloader): 88 | if args.debug and i==10: 89 | break 90 | optimizer.zero_grad() 91 | 92 | inputs = data[0].to(device) 93 | B = inputs.size(0) 94 | with torch.no_grad(): 95 | code = net.encoder(inputs) 96 | noise = torch.randn_like(code)*stddev 97 | 98 | residual = bfcnn(code + noise) 99 | loss_codeword = criterion(residual, noise) 100 | loss = loss_codeword 101 | loss.backward() 102 | optimizer.step() 103 | 104 | 105 | # print statistics 106 | running_loss += loss.item() 107 | 108 | if i % args.print_freq == args.print_freq-1 or i == len(train_dataloader)-1: 109 | with torch.no_grad(): 110 | code = code[:B,:,:,:] 111 | noise = noise[:B,:,:,:] 112 | residual = residual[:B,:,:,:] 113 | mse_y = criterion(net.decoder(code + noise), inputs) 114 | mse_z_star = criterion(net.decoder(code), inputs) 115 | mse_dn = criterion(net.decoder(code+noise-residual), inputs) 116 | 117 | log_message = "[{:4d}, {:5d}] loss: {:.5f}, MSE y: {:.5f}, MSE z*: {:.5f}, MSE denoised: {:.5f}, ".format(epoch+1, i+1, 118 | running_loss / (i+1), 119 | mse_y.item(), 120 | mse_z_star.item(), 121 | mse_dn.item() 122 | ) 123 | print(log_message) 124 | 125 | scheduler.step() 126 | if writer is not None: 127 | writer.add_scalar("Loss/train", running_loss / (i+1), epoch+1) 128 | if epoch % args.display_freq == args.display_freq-1: 129 | with torch.no_grad(): 130 | outputs_y = net.decoder(code+noise) 131 | outputs_z_star = net.decoder(code) 132 | outputs_dn = net.decoder(code+noise-residual) 133 | 134 | targets = Image.fromarray(batch2im(inputs, 2, 2, 135 | im_height=args.train_image_size, im_width=args.train_image_size)) 136 | targets.save(os.path.join(saved_dir, "e{:03d}_targets.png".format(epoch+1))) 137 | outputs = Image.fromarray(batch2im(outputs_y, 2, 2, 138 | im_height=args.train_image_size, im_width=args.train_image_size)) 139 | outputs.save(os.path.join(saved_dir, "e{:03d}_outputs_y.png".format(epoch+1))) 140 | outputs = Image.fromarray(batch2im(outputs_z_star, 2, 2, 141 | im_height=args.train_image_size, im_width=args.train_image_size)) 142 | outputs.save(os.path.join(saved_dir, "e{:03d}_outputs_z_star.png".format(epoch+1))) 143 | outputs = Image.fromarray(batch2im(outputs_dn, 2, 2, 144 | im_height=args.train_image_size, im_width=args.train_image_size)) 145 | outputs.save(os.path.join(saved_dir, "e{:03d}_outputs_dn.png".format(epoch+1))) 146 | 147 | if epoch % args.save_freq == args.save_freq-1: 148 | torch.save(net.state_dict(), 149 | os.path.join(saved_dir,"{}_{}_e{:03d}.pb".format(model_name, 150 | args.num_channels, 151 | epoch+1))) 152 | torch.save(bfcnn.state_dict(), 153 | os.path.join(saved_dir,"{}_{}_e{:03d}_bfcnn.pb".format(model_name, 154 | args.num_channels, 155 | epoch+1))) 156 | 157 | if epoch % args.test_freq == args.test_freq - 1: 158 | test(net, bfcnn, stddev, saved_dir=saved_dir, writer=writer, epoch=epoch) 159 | 160 | print('Time Taken: %d sec' % (time.time() - start_time)) 161 | 162 | print('Finished Training') 163 | test(net, bfcnn, stddev, saved_dir=saved_dir) 164 | 165 | 166 | def train_model(net, bfcnn, epoch=30, stddev=0., wd=0., model_name="", saved_dir=None, writer=None): 167 | params = bfcnn.parameters() 168 | optimizer = optim.Adam(params, lr=args.lr, betas=(0.9, 0.999), weight_decay=args.weight_decay) 169 | train(net, bfcnn, optimizer, 170 | epoch, 171 | stddev=stddev, 172 | saved_dir=saved_dir, 173 | model_name=model_name, 174 | writer=writer) 175 | 176 | 177 | def main(): 178 | if 'cifar' in args.dataset: 179 | Enc = Encoder_CIFAR 180 | Dec = 
Decoder_CIFAR 181 | else: 182 | Enc = Encoder 183 | Dec = Decoder 184 | encoder = Enc(num_out=args.num_channels, 185 | num_hidden=args.num_hidden, 186 | num_conv_blocks=args.num_conv_blocks, 187 | num_residual_blocks=args.num_residual_blocks, 188 | normalization=nn.BatchNorm2d, 189 | activation=nn.PReLU, 190 | power_norm=args.power_norm) 191 | 192 | decoder = Dec(num_in=args.num_channels, 193 | num_hidden=args.num_hidden, 194 | num_conv_blocks=args.num_conv_blocks, 195 | num_residual_blocks=args.num_residual_blocks, 196 | normalization=nn.BatchNorm2d, 197 | activation=nn.PReLU, 198 | no_tanh=False) 199 | 200 | 201 | net = Generator(encoder, decoder) 202 | 203 | bfcnn = BF_CNN(1, 64, 3, 20, args.num_channels) 204 | 205 | print(args) 206 | 207 | if args.pretrained_model_path is not None: 208 | try: 209 | filepath = args.pretrained_model_path 210 | print("Try loading "+filepath) 211 | net.load_state_dict(torch.load(filepath, map_location=dev)) 212 | except Exception as e: 213 | print(e) 214 | print("Loading failed. A pretrained Deep JSCC model is required. Exiting...") 215 | exit() 216 | 217 | net.to(device) 218 | bfcnn.to(device) 219 | 220 | if args.eval: 221 | test(net, bfcnn, 10**(-0.05*args.snr)) 222 | exit() 223 | 224 | saved_dir = args.model_path 225 | if not os.path.exists(saved_dir): 226 | os.makedirs(saved_dir) 227 | 228 | writer = SummaryWriter(saved_dir) 229 | 230 | with open(os.path.join(saved_dir, 'args.txt'), 'w') as f: 231 | json.dump(vars(args), f, indent=4) 232 | 233 | train_model(net, 234 | bfcnn, 235 | epoch=args.epochs, 236 | stddev=10**(-0.05*args.snr), 237 | model_name=args.model_name, 238 | saved_dir=saved_dir, 239 | writer=writer) 240 | 241 | 242 | torch.save(net.state_dict(), 243 | os.path.join(saved_dir,args.model_name+"_{}.pb".format(args.num_channels))) 244 | torch.save(bfcnn.state_dict(), 245 | os.path.join(saved_dir,args.model_name+"_{}_bfcnn.pb".format(args.num_channels))) 246 | 247 | 248 | 249 | if __name__=='__main__': 250 | main() 251 | -------------------------------------------------------------------------------- /train_deep_jscc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import os 5 | import time 6 | import matplotlib.pyplot as plt 7 | from PIL import Image 8 | import json 9 | from torch.utils.tensorboard import SummaryWriter 10 | import torch.optim as optim 11 | 12 | import loader 13 | import models.autoencoders as ae 14 | import utils 15 | import config 16 | 17 | parser = config.get_common_parser() 18 | parser = config.get_train_parser(parser) 19 | 20 | 21 | args = parser.parse_args() 22 | 23 | dev = "cuda:{}".format(args.gpu) if args.gpu>=0 else "cpu" 24 | 25 | device = torch.device(dev) 26 | np.random.seed(args.seed) 27 | torch.manual_seed(args.seed) 28 | 29 | train_dataloader = loader.get_train_dataloader(args) 30 | test_dataloader = loader.get_test_dataloader(args) 31 | 32 | print(len(train_dataloader)) 33 | print(len(test_dataloader)) 34 | 35 | criterion = nn.MSELoss() 36 | 37 | def test(net, stddev=0., saved_dir=None, writer=None, epoch=0): 38 | net.eval() 39 | avg_psnr = 0. 40 | count = 0.
41 | start_time = time.time() 42 | with torch.no_grad(): 43 | for i, data in enumerate(test_dataloader): 44 | inputs = data[0].to(device) 45 | outputs = net(inputs, stddev) 46 | avg_psnr += utils.PSNR(reduction='sum')(outputs, inputs, -1, 1, cuda=True) 47 | count += inputs.size(0) 48 | if args.show_outputs: 49 | plt.imsave('test.png', utils.batch2im(outputs, 3, 3, -1, 1, im_height=args.image_size, im_width=args.image_size)) 50 | break 51 | 52 | if i%10==9: 53 | print('[{:4d} / {:4d}] - Average PSNR: {}, time taken: {:.2f} sec'.format(i+1, 54 | len(test_dataloader), avg_psnr/count, time.time()-start_time)) 55 | 56 | print('[{:4d} / {:4d}] - Average PSNR: {}, time taken: {:.2f} sec'.format(i+1, 57 | len(test_dataloader), avg_psnr/count, time.time()-start_time)) 58 | if writer is not None: 59 | writer.add_scalar('PSNR/test', avg_psnr/count, epoch+1) 60 | 61 | 62 | def train(net, 63 | optimizer_G, 64 | num_epoch, 65 | stddev=0., 66 | saved_dir=None, 67 | model_name=None, 68 | which_epoch=0, 69 | writer=None): 70 | 71 | if not os.path.exists(saved_dir): 72 | os.makedirs(saved_dir) 73 | 74 | if 'cifar' in args.dataset: 75 | scheduler_G = optim.lr_scheduler.StepLR(optimizer_G, step_size=args.epochs // 2, gamma=0.1) 76 | elif 'openimages' in args.dataset: 77 | scheduler_G = optim.lr_scheduler.StepLR(optimizer_G, step_size=30, gamma=0.1) 78 | else: 79 | raise NotImplementedError() 80 | 81 | for epoch in range(which_epoch, num_epoch): # loop over the dataset multiple times 82 | start_time = time.time() 83 | #train_iter = iter(train_dataloader) 84 | running_loss_mse = 0.0 85 | running_loss = 0.0 86 | net.train() 87 | 88 | for i, data in enumerate(train_dataloader): 89 | if args.debug and i==10: 90 | break 91 | inputs= data[0].to(device) 92 | 93 | # zero the parameter gradients 94 | optimizer_G.zero_grad() 95 | # forward 96 | # stddev: max(0, 𝜎±eps) where eps~N(0, 0.01*I). 
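            # Note: in this implementation stddev is fixed for the whole run from the
            # training SNR (the jittered noise level sketched in the comment above is
            # not used); with unit-power codewords, stddev = 10**(-snr_db/20).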
97 | 98 | outputs, z = net(inputs, stddev, return_latent=True) 99 | mse_loss = criterion(outputs, inputs) 100 | 101 | loss = mse_loss 102 | loss.backward() 103 | optimizer_G.step() 104 | 105 | # print statistics 106 | running_loss_mse += mse_loss.item() 107 | running_loss += loss.item() 108 | if i % args.print_freq == args.print_freq-1 or i == len(train_dataloader)-1: 109 | print('[%d, %5d] loss: %.4f, mse: %.4f' % 110 | (epoch + 1, i + 1, running_loss / (i+1), running_loss_mse/(i+1))) 111 | 112 | if writer is not None: 113 | writer.add_scalar("Loss/train", running_loss / (i+1), epoch+1) 114 | writer.add_scalar("MSE/train", running_loss_mse / (i+1), epoch+1) 115 | if epoch % args.display_freq == args.display_freq-1: 116 | with torch.no_grad(): 117 | targets = Image.fromarray(utils.batch2im(inputs, 2, 2, -1, 1, 118 | im_height=args.train_image_size, im_width=args.train_image_size)) 119 | targets.save(os.path.join(saved_dir, "e{:03d}_targets.png".format(epoch+1))) 120 | outputs = Image.fromarray(utils.batch2im(outputs, 2, 2, -1, 1, 121 | im_height=args.train_image_size, im_width=args.train_image_size)) 122 | outputs.save(os.path.join(saved_dir, "e{:03d}_outputs.png".format(epoch+1))) 123 | 124 | scheduler_G.step() 125 | print('Time Taken: %d sec' % (time.time() - start_time)) 126 | 127 | if epoch % args.save_freq == args.save_freq-1: 128 | torch.save(net.state_dict(), 129 | os.path.join(saved_dir,"{}_{}_e{:03d}.pb".format(model_name, 130 | args.num_channels, 131 | epoch+1))) 132 | 133 | if epoch % args.test_freq == args.test_freq - 1: 134 | test(net, stddev, saved_dir=saved_dir, writer=writer, epoch=epoch) 135 | 136 | print('Finished Training') 137 | 138 | test(net, stddev, saved_dir=saved_dir) 139 | 140 | def train_model(net, epoch=30, stddev=0., wd=0., model_name="", saved_dir=None, writer=None): 141 | optimizer_G = optim.Adam(net.parameters(), 142 | lr=args.lr, 143 | betas=(0.0, 0.9), 144 | weight_decay=args.weight_decay) 145 | train(net, 146 | optimizer_G, 147 | epoch, 148 | stddev=stddev, 149 | saved_dir=saved_dir, 150 | model_name=model_name, 151 | writer=writer) 152 | 153 | 154 | 155 | 156 | def main(): 157 | if 'cifar' in args.dataset: 158 | Enc = ae.Encoder_CIFAR 159 | Dec = ae.Decoder_CIFAR 160 | else: 161 | Enc = ae.Encoder 162 | Dec = ae.Decoder 163 | encoder = Enc(num_out=args.num_channels, 164 | num_hidden=args.num_hidden, 165 | num_conv_blocks=args.num_conv_blocks, 166 | num_residual_blocks=args.num_residual_blocks, 167 | normalization=nn.BatchNorm2d, 168 | activation=nn.PReLU, 169 | power_norm=args.power_norm) 170 | 171 | decoder = Dec(num_in=args.num_channels, 172 | num_hidden=args.num_hidden, 173 | num_conv_blocks=args.num_conv_blocks, 174 | num_residual_blocks=args.num_residual_blocks, 175 | normalization=nn.BatchNorm2d, 176 | activation=nn.PReLU, 177 | no_tanh=False) 178 | 179 | print(encoder) 180 | print(decoder) 181 | print(args) 182 | net = ae.Generator(encoder, decoder) 183 | 184 | if args.pretrained_model_path is not None: 185 | try: 186 | filepath = args.pretrained_model_path 187 | print("Try loading "+filepath) 188 | net.load_state_dict(torch.load(filepath, map_location=dev)) 189 | except Exception as e: 190 | print(e) 191 | print("Loading Failed. 
Initializing Networks...") 192 | pass 193 | 194 | net.to(device) 195 | 196 | if args.eval: 197 | test(net, 10**(-0.05*args.snr)) 198 | exit() 199 | 200 | saved_dir = args.model_path 201 | if not os.path.exists(saved_dir): 202 | os.makedirs(saved_dir) 203 | 204 | writer = SummaryWriter(saved_dir) 205 | 206 | with open(os.path.join(saved_dir, 'args.txt'), 'w') as f: 207 | json.dump(vars(args), f, indent=4) 208 | 209 | train_model(net, 210 | epoch=args.epochs, 211 | stddev=10**(-0.05*args.snr), 212 | model_name=args.model_name, 213 | saved_dir=saved_dir, 214 | writer=writer) 215 | torch.save(net.state_dict(), 216 | os.path.join(saved_dir,args.model_name+"_{}.pb".format(args.num_channels))) 217 | 218 | 219 | 220 | if __name__=='__main__': 221 | main() 222 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from torchvision.io import read_image, ImageReadMode 5 | from torchvision.utils import make_grid 6 | import matplotlib.pyplot as plt 7 | import torch.autograd as autograd 8 | 9 | import json 10 | 11 | 12 | def save_to_json(stats, args): 13 | params = { 14 | 'nc': args.num_channels, 15 | 'lr': args.lr, 16 | 'stats': stats 17 | } 18 | 19 | if os.path.exists(args.json_file_path): 20 | with open(args.json_file_path) as f: 21 | data = json.load(f) 22 | else: 23 | data = {} 24 | 25 | ind = "snr_train_{:.1f}_snr_{:.1f}_nc_{}".format(args.snr_train, args.snr, args.num_channels) 26 | data.update({ind: params}) 27 | with open(args.json_file_path, "w") as f: 28 | json.dump(data, f, indent=4) 29 | 30 | 31 | 32 | def load_images(files): 33 | images = [] 34 | for f in sorted(files): 35 | images.append(read_image(f, ImageReadMode.RGB)) 36 | return images 37 | 38 | def save_image_collections(img_prefix, num_images, output_dir, nrow=4): 39 | print("Saving image collections...") 40 | target_files = [os.path.join(output_dir, img_prefix, "targets","files", "targets{:04d}.png".format(f_ind)) for f_ind in range(num_images)] 41 | orig_files = [os.path.join(output_dir, img_prefix, "orig","files", "orig{:04d}.png".format(f_ind)) for f_ind in range(num_images)] 42 | updated_files = [os.path.join(output_dir, img_prefix, "updated","files", "updated{:04d}.png".format(f_ind)) for f_ind in range(num_images)] 43 | target_ims = make_grid(load_images(target_files), nrow=nrow).permute(1,2,0).numpy() 44 | orig_ims = make_grid(load_images(orig_files), nrow=nrow).permute(1,2,0).numpy() 45 | updated_ims = make_grid(load_images(updated_files), nrow=nrow).permute(1,2,0).numpy() 46 | plt.imsave(os.path.join(output_dir, img_prefix, "targets_collection.png"), target_ims) 47 | plt.imsave(os.path.join(output_dir, img_prefix, "orig_collection.png"), orig_ims) 48 | plt.imsave(os.path.join(output_dir, img_prefix, "updated_collection.png"), updated_ims) 49 | 50 | 51 | 52 | def tensor2im(x, min_val=0., max_val=1., return_type='numpy'): 53 | with torch.no_grad(): 54 | #x = torch.clamp(x, min_val, max_val) 55 | x = (x - min_val) / (max_val - min_val) * 255. 
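        # Linear map from [min_val, max_val] to [0, 255]; values outside the range
        # are not clamped here (the clamp above is commented out).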
56 | if return_type == 'numpy': 57 | return x.permute(0,2,3,1).clone().detach().cpu().numpy().round().astype('uint8') 58 | elif return_type == 'torch': 59 | return x.round() 60 | 61 | def batch2im(x, n_row, n_col, min_val=-1., max_val=1., im_height=32, im_width=32): 62 | x = tensor2im(x, min_val, max_val) 63 | 64 | img = np.zeros((im_height*n_row, im_width*n_col, x.shape[-1]), dtype=np.uint8) 65 | 66 | for r in range(n_row): 67 | for c in range(n_col): 68 | img[r*im_height:(r+1)*im_height, c*im_width:(c+1)*im_width, :] = x[r*n_col + c, :, :, :] 69 | return img.squeeze() 70 | 71 | 72 | class PSNR: 73 | """Peak Signal to Noise Ratio 74 | img1 and img2 have range [0, 255]""" 75 | 76 | def __init__(self, reduction='mean'): 77 | self.name = "PSNR" 78 | self.reduction = reduction 79 | 80 | #@staticmethod 81 | def __call__(self, img1, img2, min_val=0., max_val=1., mean_weight=1., cuda=False, offset=0): 82 | if cuda: 83 | with torch.no_grad(): 84 | img1 = torch.clamp(img1, min_val, max_val) 85 | img1 = (img1 - min_val) / (max_val - min_val) * 255. 86 | img2 = torch.clamp(img2, min_val, max_val) 87 | img2 = (img2 - min_val) / (max_val - min_val) * 255. 88 | img1, img2 = img1.round(), img2.round() 89 | mse = torch.mean((img1 - img2)**2, dim=(1,2,3)) 90 | psnr = 20 * torch.log10(255.0 / torch.sqrt(mse)) 91 | psnr = psnr.detach().cpu().numpy() 92 | else: 93 | img1, img2 = tensor2im(img1, min_val, max_val).astype('float64'), tensor2im(img2, min_val, max_val).astype('float64') 94 | if offset > 0: 95 | img1 = img1[:,offset:-offset,offset:-offset,:] 96 | img2 = img2[:,offset:-offset,offset:-offset,:] 97 | mse = np.mean((img1 - img2)**2, axis=(1,2,3)) 98 | psnr = 10 * np.log10(255.0**2 / mse)#np.sqrt(mse)) 99 | if self.reduction == 'mean': 100 | return np.mean(psnr)*mean_weight 101 | elif self.reduction == 'sum': 102 | return np.sum(psnr) 103 | elif self.reduction == 'none': 104 | return psnr 105 | else: 106 | raise NotImplementedError() 107 | 108 | -------------------------------------------------------------------------------- /visualization/sample.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/changwoolee/isec-deep-jscc/43d1a27beaeb89b7057f02bd2fa2dd05c46093e1/visualization/sample.jpg --------------------------------------------------------------------------------