├── .gitignore
├── README.md
├── bpg_ldpc.py
├── classification.py
├── original_implement.py
├── resnet.py
├── resources
│   ├── experiments.png
│   ├── performance_awgn.png
│   ├── performance_fading.png
│   ├── training_valid.png
│   ├── validation_awgn_snr00_c0.04_e1000.png
│   ├── validation_awgn_snr00_c0.09_e1000.png
│   ├── validation_awgn_snr00_c0.17_e1000.png
│   ├── validation_awgn_snr00_c0.25_e1000.png
│   ├── validation_awgn_snr00_c0.33_e1000.png
│   ├── validation_awgn_snr00_c0.42_e1000.png
│   ├── validation_awgn_snr00_c0.49_e1000.png
│   ├── validation_awgn_snr10_c0.04_e1000.png
│   ├── validation_awgn_snr10_c0.09_e1000.png
│   ├── validation_awgn_snr10_c0.17_e1000.png
│   ├── validation_awgn_snr10_c0.25_e1000.png
│   ├── validation_awgn_snr10_c0.33_e1000.png
│   ├── validation_awgn_snr10_c0.42_e1000.png
│   ├── validation_awgn_snr10_c0.49_e1000.png
│   ├── validation_awgn_snr20_c0.04_e1000.png
│   ├── validation_awgn_snr20_c0.09_e1000.png
│   ├── validation_awgn_snr20_c0.17_e1000.png
│   ├── validation_awgn_snr20_c0.25_e1000.png
│   ├── validation_awgn_snr20_c0.33_e1000.png
│   ├── validation_awgn_snr20_c0.42_e1000.png
│   └── validation_awgn_snr20_c0.49_e1000.png
├── torch_impl.py
├── visualization.md
└── wideresnet.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .git/
2 | .vscode/
3 | checkpoints/
4 | data/
5 | validation_imgs/
6 | train_logs/
7 | validation.mp4
8 | generate_videos.py
9 | test.py
10 | push2github.sh
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
14 |
20 |
21 |
22 |
23 |
24 | # Launch Records
25 |
26 | ## Introduction
27 |
28 | A reimplementation of [Deep Joint Source-Channel Coding for Wireless Image Transmission](https://arxiv.org/abs/1809.01733) in PyTorch.
29 |
30 | ![awgn_performance](resources/performance_awgn.png)
31 | ![slowfading_performance](resources/performance_fading.png)
32 |
33 | Thanks to [irdanish11's implementation](https://github.com/irdanish11/DJSCC-for-Wireless-Image-Transmission) and [Ahmedest61's implementation](https://github.com/Ahmedest61/D-JSCC).
34 |
35 | ## Technical Solution
36 |
37 | An `AutoEncoder` compresses each image from `[b, 3, H, W]` to feature maps of shape `[b, c, h, w]`; after a power constraint, the feature maps are fed through a channel (`AWGN` or `Slow Fading`) and then recovered by the decoder.
38 |
39 |
40 | ## Experimental setup
41 |
42 | Use the `Adam` optimizer with `batch size` set to `64` and `learning rate` set to `1e-3`, decayed to `1e-4` after the `640`-th epoch. Train for `1000` epochs in total.
43 |
44 | A separate model is trained for each combination of `SNR` and `compression rate`: `SNR` varies over `[0, 10, 20]` and `compression rate` over `[0.04, 0.09, 0.17, 0.25, 0.33, 0.42, 0.49]`, i.e. `channel width` varies over `[2, 4, 8, 12, 16, 20, 24]`.
45 |
46 |
47 |
48 |
49 | ## Model Metric
50 |
51 | - Loss Function: `MSE Loss`
52 |
53 | - Performance Metric: `PSNR`
54 |
55 | - Computational Cost: `20s * 1000 epochs / 3600 ~= 5.6h` on a single `4090Ti`
56 |
57 | ## Experimental results
58 |
59 | Validation loss during training.
60 |
61 | ![Training](resources/training_valid.png)
62 |
63 |
64 | The prefix "EXP" marks the experimental results of this reimplementation; "REP" marks the performance reported in the paper.
65 |
66 | ![exp_performance](resources/experiments.png)
67 |
68 |
69 | See [Visualization](visualization.md) for details.
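For reference, the channel stage described above reduces to two steps: normalize the encoder output so its average symbol power is 1 (the power constraint), then corrupt it at the chosen `SNR`. A minimal PyTorch sketch of the `AWGN` case (function and variable names are illustrative, not the exact API of `torch_impl.py`):

```python
import torch

def awgn_channel(z: torch.Tensor, snr_db: float) -> torch.Tensor:
    # z: encoder output of shape [b, c, h, w]; treat each sample as one block of symbols
    b = z.shape[0]
    z_flat = z.reshape(b, -1)
    k = z_flat.shape[1]
    # power constraint: rescale so each sample has average symbol power 1
    z_norm = (k ** 0.5) * z_flat / z_flat.norm(dim=1, keepdim=True)
    # with unit signal power, SNR (dB) fixes the noise std: sqrt(10^(-SNR/10))
    noise_std = 10 ** (-snr_db / 20)
    y = z_norm + noise_std * torch.randn_like(z_norm)
    return y.reshape(z.shape)
```

The slow-fading case is handled analogously, with the latent interpreted as complex symbols and multiplied by a per-sample complex gain before the noise is added.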
70 | 71 | # BPG-LDPC simulation 72 | 73 | ``` 74 | SNR=0, bw=0.083333, k=3072, n=6144, m=02, PSNR=19.61, SSIM=0.59 75 | SNR=0, bw=0.083333, k=3072, n=6144, m=04, PSNR=6.57, SSIM=0.12 76 | SNR=0, bw=0.083333, k=3072, n=6144, m=16, PSNR=6.57, SSIM=0.12 77 | SNR=0, bw=0.083333, k=3072, n=6144, m=64, PSNR=6.57, SSIM=0.12 78 | SNR=0, bw=0.083333, k=3072, n=4608, m=02, PSNR=7.79, SSIM=0.10 79 | SNR=0, bw=0.083333, k=3072, n=4608, m=04, PSNR=6.57, SSIM=0.12 80 | SNR=0, bw=0.083333, k=3072, n=4608, m=16, PSNR=6.57, SSIM=0.12 81 | SNR=0, bw=0.083333, k=3072, n=4608, m=64, PSNR=6.57, SSIM=0.12 82 | SNR=0, bw=0.083333, k=1536, n=4608, m=02, PSNR=19.61, SSIM=0.59 83 | SNR=0, bw=0.083333, k=1536, n=4608, m=04, PSNR=20.09, SSIM=0.62 84 | SNR=0, bw=0.083333, k=1536, n=4608, m=16, PSNR=6.57, SSIM=0.12 85 | SNR=0, bw=0.083333, k=1536, n=4608, m=64, PSNR=6.57, SSIM=0.12 86 | SNR=0, bw=0.166667, k=3072, n=6144, m=02, PSNR=21.56, SSIM=0.71 87 | SNR=0, bw=0.166667, k=3072, n=6144, m=04, PSNR=6.57, SSIM=0.12 88 | SNR=0, bw=0.166667, k=3072, n=6144, m=16, PSNR=6.57, SSIM=0.12 89 | SNR=0, bw=0.166667, k=3072, n=6144, m=64, PSNR=6.57, SSIM=0.12 90 | SNR=0, bw=0.166667, k=3072, n=4608, m=02, PSNR=7.64, SSIM=0.09 91 | SNR=0, bw=0.166667, k=3072, n=4608, m=04, PSNR=6.57, SSIM=0.12 92 | SNR=0, bw=0.166667, k=3072, n=4608, m=16, PSNR=6.57, SSIM=0.12 93 | SNR=0, bw=0.166667, k=3072, n=4608, m=64, PSNR=6.57, SSIM=0.12 94 | SNR=0, bw=0.166667, k=1536, n=4608, m=02, PSNR=20.09, SSIM=0.62 95 | SNR=0, bw=0.166667, k=1536, n=4608, m=04, PSNR=22.40, SSIM=0.75 96 | SNR=0, bw=0.166667, k=1536, n=4608, m=16, PSNR=6.57, SSIM=0.12 97 | SNR=0, bw=0.166667, k=1536, n=4608, m=64, PSNR=6.57, SSIM=0.12 98 | SNR=0, bw=0.250000, k=3072, n=6144, m=02, PSNR=22.90, SSIM=0.77 99 | SNR=0, bw=0.250000, k=3072, n=6144, m=04, PSNR=6.57, SSIM=0.12 100 | SNR=0, bw=0.250000, k=3072, n=6144, m=16, PSNR=6.57, SSIM=0.12 101 | SNR=0, bw=0.250000, k=3072, n=6144, m=64, PSNR=6.57, SSIM=0.12 102 | SNR=0, bw=0.250000, k=3072, n=4608, m=02, PSNR=7.60, SSIM=0.08 103 | SNR=0, bw=0.250000, k=3072, n=4608, m=04, PSNR=6.57, SSIM=0.12 104 | SNR=0, bw=0.250000, k=3072, n=4608, m=16, PSNR=6.57, SSIM=0.12 105 | SNR=0, bw=0.250000, k=3072, n=4608, m=64, PSNR=6.57, SSIM=0.12 106 | SNR=0, bw=0.250000, k=1536, n=4608, m=02, PSNR=21.56, SSIM=0.71 107 | SNR=0, bw=0.250000, k=1536, n=4608, m=04, PSNR=24.17, SSIM=0.82 108 | SNR=0, bw=0.250000, k=1536, n=4608, m=16, PSNR=6.57, SSIM=0.12 109 | SNR=0, bw=0.250000, k=1536, n=4608, m=64, PSNR=6.57, SSIM=0.12 110 | SNR=0, bw=0.333333, k=3072, n=6144, m=02, PSNR=24.17, SSIM=0.82 111 | SNR=0, bw=0.333333, k=3072, n=6144, m=04, PSNR=6.57, SSIM=0.12 112 | SNR=0, bw=0.333333, k=3072, n=6144, m=16, PSNR=6.57, SSIM=0.12 113 | SNR=0, bw=0.333333, k=3072, n=6144, m=64, PSNR=6.57, SSIM=0.12 114 | SNR=0, bw=0.333333, k=3072, n=4608, m=02, PSNR=7.54, SSIM=0.08 115 | SNR=0, bw=0.333333, k=3072, n=4608, m=04, PSNR=6.57, SSIM=0.12 116 | SNR=0, bw=0.333333, k=3072, n=4608, m=16, PSNR=6.57, SSIM=0.12 117 | SNR=0, bw=0.333333, k=3072, n=4608, m=64, PSNR=6.57, SSIM=0.12 118 | SNR=0, bw=0.333333, k=1536, n=4608, m=02, PSNR=22.40, SSIM=0.75 119 | SNR=0, bw=0.333333, k=1536, n=4608, m=04, PSNR=25.49, SSIM=0.86 120 | SNR=0, bw=0.333333, k=1536, n=4608, m=16, PSNR=6.57, SSIM=0.12 121 | SNR=0, bw=0.333333, k=1536, n=4608, m=64, PSNR=6.57, SSIM=0.12 122 | SNR=0, bw=0.500000, k=3072, n=6144, m=02, PSNR=26.19, SSIM=0.88 123 | SNR=0, bw=0.500000, k=3072, n=6144, m=04, PSNR=6.57, SSIM=0.12 124 | SNR=0, bw=0.500000, k=3072, n=6144, m=16, PSNR=6.57, SSIM=0.12 
125 | SNR=0, bw=0.500000, k=3072, n=6144, m=64, PSNR=6.57, SSIM=0.12 126 | SNR=0, bw=0.500000, k=3072, n=4608, m=02, PSNR=7.50, SSIM=0.07 127 | SNR=0, bw=0.500000, k=3072, n=4608, m=04, PSNR=6.57, SSIM=0.12 128 | SNR=0, bw=0.500000, k=3072, n=4608, m=16, PSNR=6.57, SSIM=0.12 129 | SNR=0, bw=0.500000, k=3072, n=4608, m=64, PSNR=6.57, SSIM=0.12 130 | SNR=0, bw=0.500000, k=1536, n=4608, m=02, PSNR=24.17, SSIM=0.82 131 | SNR=0, bw=0.500000, k=1536, n=4608, m=04, PSNR=27.94, SSIM=0.91 132 | SNR=0, bw=0.500000, k=1536, n=4608, m=16, PSNR=6.57, SSIM=0.12 133 | SNR=0, bw=0.500000, k=1536, n=4608, m=64, PSNR=6.57, SSIM=0.12 134 | SNR=10, bw=0.083333, k=3072, n=6144, m=02, PSNR=19.61, SSIM=0.59 135 | SNR=10, bw=0.083333, k=3072, n=6144, m=04, PSNR=21.56, SSIM=0.71 136 | SNR=10, bw=0.083333, k=3072, n=6144, m=16, PSNR=24.17, SSIM=0.82 137 | SNR=10, bw=0.083333, k=3072, n=6144, m=64, PSNR=6.57, SSIM=0.12 138 | SNR=10, bw=0.083333, k=3072, n=4608, m=02, PSNR=20.09, SSIM=0.62 139 | SNR=10, bw=0.083333, k=3072, n=4608, m=04, PSNR=22.40, SSIM=0.75 140 | SNR=10, bw=0.083333, k=3072, n=4608, m=16, PSNR=25.12, SSIM=0.85 141 | SNR=10, bw=0.083333, k=3072, n=4608, m=64, PSNR=6.57, SSIM=0.12 142 | SNR=10, bw=0.083333, k=1536, n=4608, m=02, PSNR=19.61, SSIM=0.59 143 | SNR=10, bw=0.083333, k=1536, n=4608, m=04, PSNR=20.09, SSIM=0.62 144 | SNR=10, bw=0.083333, k=1536, n=4608, m=16, PSNR=22.40, SSIM=0.75 145 | SNR=10, bw=0.083333, k=1536, n=4608, m=64, PSNR=24.17, SSIM=0.82 146 | SNR=10, bw=0.166667, k=3072, n=6144, m=02, PSNR=21.56, SSIM=0.71 147 | SNR=10, bw=0.166667, k=3072, n=6144, m=04, PSNR=24.17, SSIM=0.82 148 | SNR=10, bw=0.166667, k=3072, n=6144, m=16, PSNR=27.94, SSIM=0.91 149 | SNR=10, bw=0.166667, k=3072, n=6144, m=64, PSNR=6.57, SSIM=0.12 150 | SNR=10, bw=0.166667, k=3072, n=4608, m=02, PSNR=22.40, SSIM=0.75 151 | SNR=10, bw=0.166667, k=3072, n=4608, m=04, PSNR=25.49, SSIM=0.86 152 | SNR=10, bw=0.166667, k=3072, n=4608, m=16, PSNR=29.75, SSIM=0.94 153 | SNR=10, bw=0.166667, k=3072, n=4608, m=64, PSNR=6.57, SSIM=0.12 154 | SNR=10, bw=0.166667, k=1536, n=4608, m=02, PSNR=20.09, SSIM=0.62 155 | SNR=10, bw=0.166667, k=1536, n=4608, m=04, PSNR=22.40, SSIM=0.75 156 | SNR=10, bw=0.166667, k=1536, n=4608, m=16, PSNR=25.49, SSIM=0.86 157 | SNR=10, bw=0.166667, k=1536, n=4608, m=64, PSNR=27.94, SSIM=0.91 158 | SNR=10, bw=0.250000, k=3072, n=6144, m=02, PSNR=22.90, SSIM=0.77 159 | SNR=10, bw=0.250000, k=3072, n=6144, m=04, PSNR=26.19, SSIM=0.88 160 | SNR=10, bw=0.250000, k=3072, n=6144, m=16, PSNR=30.55, SSIM=0.95 161 | SNR=10, bw=0.250000, k=3072, n=6144, m=64, PSNR=6.59, SSIM=0.12 162 | SNR=10, bw=0.250000, k=3072, n=4608, m=02, PSNR=24.17, SSIM=0.82 163 | SNR=10, bw=0.250000, k=3072, n=4608, m=04, PSNR=27.94, SSIM=0.91 164 | SNR=10, bw=0.250000, k=3072, n=4608, m=16, PSNR=32.85, SSIM=0.97 165 | SNR=10, bw=0.250000, k=3072, n=4608, m=64, PSNR=6.57, SSIM=0.12 166 | SNR=10, bw=0.250000, k=1536, n=4608, m=02, PSNR=21.56, SSIM=0.71 167 | SNR=10, bw=0.250000, k=1536, n=4608, m=04, PSNR=24.17, SSIM=0.82 168 | SNR=10, bw=0.250000, k=1536, n=4608, m=16, PSNR=27.94, SSIM=0.91 169 | SNR=10, bw=0.250000, k=1536, n=4608, m=64, PSNR=30.55, SSIM=0.95 170 | SNR=10, bw=0.333333, k=3072, n=6144, m=02, PSNR=24.17, SSIM=0.82 171 | SNR=10, bw=0.333333, k=3072, n=6144, m=04, PSNR=27.94, SSIM=0.91 172 | SNR=10, bw=0.333333, k=3072, n=6144, m=16, PSNR=32.85, SSIM=0.97 173 | SNR=10, bw=0.333333, k=3072, n=6144, m=64, PSNR=6.57, SSIM=0.12 174 | SNR=10, bw=0.333333, k=3072, n=4608, m=02, PSNR=25.49, SSIM=0.86 175 | SNR=10, bw=0.333333, 
k=3072, n=4608, m=04, PSNR=29.75, SSIM=0.94 176 | SNR=10, bw=0.333333, k=3072, n=4608, m=16, PSNR=35.25, SSIM=0.98 177 | SNR=10, bw=0.333333, k=3072, n=4608, m=64, PSNR=6.57, SSIM=0.12 178 | SNR=10, bw=0.333333, k=1536, n=4608, m=02, PSNR=22.40, SSIM=0.75 179 | SNR=10, bw=0.333333, k=1536, n=4608, m=04, PSNR=25.49, SSIM=0.86 180 | SNR=10, bw=0.333333, k=1536, n=4608, m=16, PSNR=29.75, SSIM=0.94 181 | SNR=10, bw=0.333333, k=1536, n=4608, m=64, PSNR=32.85, SSIM=0.97 182 | SNR=10, bw=0.500000, k=3072, n=6144, m=02, PSNR=26.19, SSIM=0.88 183 | SNR=10, bw=0.500000, k=3072, n=6144, m=04, PSNR=30.55, SSIM=0.95 184 | SNR=10, bw=0.500000, k=3072, n=6144, m=16, PSNR=36.48, SSIM=0.98 185 | SNR=10, bw=0.500000, k=3072, n=6144, m=64, PSNR=6.57, SSIM=0.12 186 | SNR=10, bw=0.500000, k=3072, n=4608, m=02, PSNR=27.94, SSIM=0.91 187 | SNR=10, bw=0.500000, k=3072, n=4608, m=04, PSNR=32.85, SSIM=0.97 188 | SNR=10, bw=0.500000, k=3072, n=4608, m=16, PSNR=38.58, SSIM=0.97 189 | SNR=10, bw=0.500000, k=3072, n=4608, m=64, PSNR=6.57, SSIM=0.12 190 | SNR=10, bw=0.500000, k=1536, n=4608, m=02, PSNR=24.17, SSIM=0.82 191 | SNR=10, bw=0.500000, k=1536, n=4608, m=04, PSNR=27.94, SSIM=0.91 192 | SNR=10, bw=0.500000, k=1536, n=4608, m=16, PSNR=32.85, SSIM=0.97 193 | SNR=10, bw=0.500000, k=1536, n=4608, m=64, PSNR=36.48, SSIM=0.98 194 | ``` 195 | 196 | 198 | 199 | 220 | -------------------------------------------------------------------------------- /bpg_ldpc.py: -------------------------------------------------------------------------------- 1 | from tqdm.auto import tqdm 2 | import os 3 | import math 4 | import numpy as np 5 | 6 | from PIL import Image 7 | import tensorflow as tf 8 | from tensorflow import keras 9 | 10 | from sionna.mapping import Constellation, Mapper, Demapper 11 | from sionna.fec.ldpc import LDPC5GEncoder, LDPC5GDecoder 12 | from sionna.utils import ebnodb2no 13 | from sionna.channel import AWGN, FlatFadingChannel 14 | 15 | 16 | def imBatchtoImage(batch_images): 17 | ''' 18 | turns b, 32, 32, 3 images into single sqrt(b) * 32, sqrt(b) * 32, 3 image. 
19 |     '''
20 |     batch, h, w, c = batch_images.shape
21 |     b = int(batch ** 0.5)
22 |
23 |     divisor = b  # largest divisor of batch that is <= sqrt(batch), giving a near-square grid
24 |     while batch % divisor != 0:
25 |         divisor -= 1
26 |
27 |     image = tf.reshape(batch_images, (-1, batch//divisor, h, w, c))
28 |     image = tf.transpose(image, [0, 2, 1, 3, 4])
29 |     image = tf.reshape(image, (-1, batch//divisor*w, c))
30 |     return image
31 |
32 |
33 | class BPGEncoder():
34 |     def __init__(self, working_directory='./analysis/temp'):
35 |         '''
36 |         working_directory: directory to save temp files;
37 |         do not include a trailing '/'
38 |         '''
39 |         self.working_directory = working_directory
40 |
41 |     def run_bpgenc(self, qp, input_dir, output_dir='temp.bpg'):
42 |         if os.path.exists(output_dir):
43 |             os.remove(output_dir)
44 |         os.system(f'bpgenc {input_dir} -q {qp} -o {output_dir} -f 444')
45 |
46 |         # returns the encoded file size in bytes, or -1 if encoding failed
47 |         if os.path.exists(output_dir):
48 |             return os.path.getsize(output_dir)
49 |         else:
50 |             return -1
51 |
52 |     def get_qp(self, input_dir, byte_threshold, output_dir='temp.bpg'):
53 |         '''
54 |         iteratively finds the quality parameter that maximizes quality under the byte_threshold constraint
55 |         '''
56 |         # rate-match algorithm: binary search over quality (qp = 51 - quality, so higher quality means a larger file)
57 |         quality_max = 51
58 |         quality_min = 0
59 |         quality = (quality_max - quality_min) // 2
60 |
61 |         while True:
62 |             qp = 51 - quality
63 |             n_bytes = self.run_bpgenc(qp, input_dir, output_dir)  # renamed from `bytes` to avoid shadowing the builtin
64 |             if quality == 0 or quality == quality_min or quality == quality_max:
65 |                 break
66 |             elif n_bytes > byte_threshold and quality_min != quality - 1:
67 |                 quality_max = quality
68 |                 quality -= (quality - quality_min) // 2
69 |             elif n_bytes > byte_threshold and quality_min == quality - 1:
70 |                 quality_max = quality
71 |                 quality -= 1
72 |             elif n_bytes < byte_threshold and quality_max > quality:
73 |                 quality_min = quality
74 |                 quality += (quality_max - quality) // 2
75 |             else:
76 |                 break
77 |
78 |         return qp
79 |
80 |     def encode(self, image_array, max_bytes, header_bytes=22):
81 |         '''
82 |         image_array: uint8 numpy array with shape (b, h, w, c)
83 |         max_bytes: int, maximum bytes of the encoded image file (excluding header bytes)
84 |         header_bytes: the size of the BPG header (excluded from the image file size calculation)
85 |         '''
86 |
87 |         input_dir = f'{self.working_directory}/temp_enc.png'
88 |         output_dir = f'{self.working_directory}/temp_enc.bpg'
89 |
90 |         im = Image.fromarray(image_array, 'RGB')
91 |         im.save(input_dir)
92 |
93 |         qp = self.get_qp(input_dir, max_bytes + header_bytes, output_dir)
94 |
95 |         if self.run_bpgenc(qp, input_dir, output_dir) < 0:
96 |             raise RuntimeError("BPG encoding failed")
97 |
98 |         # read the bitstream and convert it to a numpy array of 0./1. floats
99 |         return np.unpackbits(np.fromfile(output_dir, dtype=np.uint8)).astype(np.float32)
100 |
101 |
102 | class LDPCTransmitter():
103 |     '''
104 |     Transmits given bits (float array of '0' and '1') with LDPC.
105 |     '''
106 |
107 |     def __init__(self, k, n, m, esno_db, channel='AWGN'):
108 |         '''
109 |         k: data bits per codeword (in LDPC)
110 |         n: total codeword bits (in LDPC)
111 |         m: modulation order (in m-QAM)
112 |         esno_db: channel SNR (Es/N0, in dB)
113 |         channel: 'AWGN' or 'Rayleigh'
114 |         '''
115 |         self.k = k
116 |         self.n = n
117 |         self.num_bits_per_symbol = round(math.log2(m))
118 |
119 |         constellation_type = 'qam' if m != 2 else 'pam'  # BPSK (m=2) is realized as 1-bit PAM
120 |         self.constellation = Constellation(
121 |             constellation_type, num_bits_per_symbol=self.num_bits_per_symbol)
122 |         self.mapper = Mapper(constellation=self.constellation)
123 |         self.demapper = Demapper('app', constellation=self.constellation)
124 |         self.channel = AWGN() if channel == 'AWGN' else FlatFadingChannel(1, 1)  # fix: was the bare class, not an instance; 1x1 SISO assumed -- only 'AWGN' is exercised in this script
125 |         self.encoder = LDPC5GEncoder(k=self.k, n=self.n)
126 |         self.decoder = LDPC5GDecoder(self.encoder, num_iter=20)
127 |         self.esno_db = esno_db
128 |
129 |     def send(self, source_bits):
130 |         '''
131 |         source_bits: float np array of '0' and '1'; zero-padded internally to a multiple of lcm(k, bits per symbol)
132 |         '''
133 |         lcm = np.lcm(self.k, self.num_bits_per_symbol)
134 |         source_bits_pad = tf.pad(
135 |             source_bits, [[0, math.ceil(len(source_bits)/lcm)*lcm - len(source_bits)]])
136 |         u = np.reshape(source_bits_pad, (-1, self.k))
137 |
138 |         no = ebnodb2no(self.esno_db, num_bits_per_symbol=1, coderate=1)  # with these args this converts Es/N0 (dB) directly to noise power
139 |         c = self.encoder(u)              # LDPC encode
140 |         x = self.mapper(c)               # map coded bits to constellation symbols
141 |         y = self.channel([x, no])        # add channel noise
142 |         llr_ch = self.demapper([y, no])  # per-bit LLRs
143 |         u_hat = self.decoder(llr_ch)     # LDPC decode
144 |
145 |         return tf.reshape(u_hat, (-1))[:len(source_bits)]  # drop the padding
146 |
147 |
148 | class BPGDecoder():
149 |     def __init__(self, working_directory='./analysis/temp'):
150 |         '''
151 |         working_directory: directory to save temp files;
152 |         do not include a trailing '/'
153 |         '''
154 |         self.working_directory = working_directory
155 |
156 |     def run_bpgdec(self, input_dir, output_dir='temp.png'):
157 |         if os.path.exists(output_dir):
158 |             os.remove(output_dir)
159 |         os.system(f'bpgdec {input_dir} -o {output_dir}')
160 |
161 |         if os.path.exists(output_dir):
162 |             return os.path.getsize(output_dir)
163 |         else:
164 |             return -1
165 |
166 |     def decode(self, bit_array, image_shape):
167 |         '''
168 |         returns the decoded result for the given bit_array.
169 |         if bit_array is not decodable, returns the mean CIFAR-10 pixel values instead.
170 |
171 |         bit_array: float array of '0' and '1'
172 |         image_shape: used to generate an image of mean pixel values if bit_array is not decodable
173 |         '''
174 |         input_dir = f'{self.working_directory}/temp_dec.bpg'
175 |         output_dir = f'{self.working_directory}/temp_dec.png'
176 |
177 |         byte_array = np.packbits(bit_array.astype(np.uint8))
178 |         with open(input_dir, "wb") as binary_file:
179 |             binary_file.write(byte_array.tobytes())
180 |
181 |         cifar_mean = np.array(
182 |             [0.4913997551666284, 0.48215855929893703, 0.4465309133731618]) * 255
183 |         cifar_mean = np.reshape(
184 |             cifar_mean, [1] * (len(image_shape) - 1) + [3]).astype(np.uint8)
185 |
186 |         if self.run_bpgdec(input_dir, output_dir) < 0:
            # print('warning: Decode failed.
Returning mean pixel value') 187 | return 0 * np.ones(image_shape) + cifar_mean 188 | else: 189 | x = np.array(Image.open(output_dir).convert('RGB')) 190 | if x.shape != image_shape: 191 | return 0 * np.ones(image_shape) + cifar_mean 192 | return x 193 | 194 | 195 | (train_images, train_labels), (test_images, test_labels) = keras.datasets.cifar10.load_data() 196 | 197 | # BPG + LDPC 198 | bpgencoder = BPGEncoder() 199 | bpgdecoder = BPGDecoder() 200 | 201 | bw_ratio = [1/12, 1/6, 1/4, 1/3, 1/2] 202 | snrs = [0, 10] 203 | mcs = [(k, n, m) for k, n in [(3072, 6144), (3072, 4608), (1536, 4608)] 204 | for m in (2, 4, 16, 64)] 205 | batchsize = 256 206 | ''' 207 | (3072, 6144), (3072, 4608), (1536, 4608) 208 | BPSK, 4-QAM, 16-QAM, 64-QAM 209 | ''' 210 | 211 | 212 | for esno_db in snrs: 213 | for bw in bw_ratio: 214 | for k, n, m in mcs: 215 | i = 0 216 | psnr = 0 217 | ssim = 0 218 | total_images = 0 219 | ldpctransmitter = LDPCTransmitter(k, n, m, esno_db, 'AWGN') 220 | # for image, _ in tqdm(trainloader): 221 | # for image, _ in [(train_images, train_labels)]: 222 | for start in range(len(train_images)//batchsize): 223 | end = (start+1)*batchsize 224 | start = start*batchsize 225 | image = train_images[start:end] 226 | b, _, _, _ = image.shape 227 | image = tf.cast(imBatchtoImage(image), tf.uint8) 228 | max_bytes = b * 32 * 32 * 3 * bw * math.log2(m) * k / n / 8 229 | src_bits = bpgencoder.encode(image.numpy(), max_bytes) 230 | rcv_bits = ldpctransmitter.send(src_bits) 231 | 232 | decoded_image = bpgdecoder.decode( 233 | rcv_bits.numpy(), image.shape) 234 | total_images += b 235 | psnr = (total_images - b) / (total_images) * psnr + float(b * 236 | tf.image.psnr(decoded_image, image, max_val=255)) / (total_images) 237 | ssim = (total_images - b) / (total_images) * ssim + float(b * tf.image.ssim(tf.cast( 238 | decoded_image, dtype=tf.float32), tf.cast(image, dtype=tf.float32), max_val=255)) / (total_images) 239 | 240 | print( 241 | f'[res] SNR={esno_db}, bw={bw:.6f}, k={k}, n={n}, m={m}, PSNR={psnr:.2f}, SSIM={ssim:.2f}') 242 | 243 | 244 | 245 | 246 | sss = """ 247 | SNR=00, bw=0.083333, k=3072, n=6144, m=02, PSNR=19.61, SSIM=0.59 248 | SNR=00, bw=0.083333, k=3072, n=6144, m=04, PSNR=6.57, SSIM=0.12 249 | SNR=00, bw=0.083333, k=3072, n=6144, m=16, PSNR=6.57, SSIM=0.12 250 | SNR=00, bw=0.083333, k=3072, n=6144, m=64, PSNR=6.57, SSIM=0.12 251 | SNR=00, bw=0.083333, k=3072, n=4608, m=02, PSNR=7.79, SSIM=0.10 252 | SNR=00, bw=0.083333, k=3072, n=4608, m=04, PSNR=6.57, SSIM=0.12 253 | SNR=00, bw=0.083333, k=3072, n=4608, m=16, PSNR=6.57, SSIM=0.12 254 | SNR=00, bw=0.083333, k=3072, n=4608, m=64, PSNR=6.57, SSIM=0.12 255 | SNR=00, bw=0.083333, k=1536, n=4608, m=02, PSNR=19.61, SSIM=0.59 256 | SNR=00, bw=0.083333, k=1536, n=4608, m=04, PSNR=20.09, SSIM=0.62 257 | SNR=00, bw=0.083333, k=1536, n=4608, m=16, PSNR=6.57, SSIM=0.12 258 | SNR=00, bw=0.083333, k=1536, n=4608, m=64, PSNR=6.57, SSIM=0.12 259 | SNR=00, bw=0.166667, k=3072, n=6144, m=02, PSNR=21.56, SSIM=0.71 260 | SNR=00, bw=0.166667, k=3072, n=6144, m=04, PSNR=6.57, SSIM=0.12 261 | SNR=00, bw=0.166667, k=3072, n=6144, m=16, PSNR=6.57, SSIM=0.12 262 | SNR=00, bw=0.166667, k=3072, n=6144, m=64, PSNR=6.57, SSIM=0.12 263 | SNR=00, bw=0.166667, k=3072, n=4608, m=02, PSNR=7.64, SSIM=0.09 264 | SNR=00, bw=0.166667, k=3072, n=4608, m=04, PSNR=6.57, SSIM=0.12 265 | SNR=00, bw=0.166667, k=3072, n=4608, m=16, PSNR=6.57, SSIM=0.12 266 | SNR=00, bw=0.166667, k=3072, n=4608, m=64, PSNR=6.57, SSIM=0.12 267 | SNR=00, bw=0.166667, k=1536, n=4608, m=02, 
PSNR=20.09, SSIM=0.62 268 | SNR=00, bw=0.166667, k=1536, n=4608, m=04, PSNR=22.40, SSIM=0.75 269 | SNR=00, bw=0.166667, k=1536, n=4608, m=16, PSNR=6.57, SSIM=0.12 270 | SNR=00, bw=0.166667, k=1536, n=4608, m=64, PSNR=6.57, SSIM=0.12 271 | SNR=00, bw=0.250000, k=3072, n=6144, m=02, PSNR=22.90, SSIM=0.77 272 | SNR=00, bw=0.250000, k=3072, n=6144, m=04, PSNR=6.57, SSIM=0.12 273 | SNR=00, bw=0.250000, k=3072, n=6144, m=16, PSNR=6.57, SSIM=0.12 274 | SNR=00, bw=0.250000, k=3072, n=6144, m=64, PSNR=6.57, SSIM=0.12 275 | SNR=00, bw=0.250000, k=3072, n=4608, m=02, PSNR=7.60, SSIM=0.08 276 | SNR=00, bw=0.250000, k=3072, n=4608, m=04, PSNR=6.57, SSIM=0.12 277 | SNR=00, bw=0.250000, k=3072, n=4608, m=16, PSNR=6.57, SSIM=0.12 278 | SNR=00, bw=0.250000, k=3072, n=4608, m=64, PSNR=6.57, SSIM=0.12 279 | SNR=00, bw=0.250000, k=1536, n=4608, m=02, PSNR=21.56, SSIM=0.71 280 | SNR=00, bw=0.250000, k=1536, n=4608, m=04, PSNR=24.17, SSIM=0.82 281 | SNR=00, bw=0.250000, k=1536, n=4608, m=16, PSNR=6.57, SSIM=0.12 282 | SNR=00, bw=0.250000, k=1536, n=4608, m=64, PSNR=6.57, SSIM=0.12 283 | SNR=00, bw=0.333333, k=3072, n=6144, m=02, PSNR=24.17, SSIM=0.82 284 | SNR=00, bw=0.333333, k=3072, n=6144, m=04, PSNR=6.57, SSIM=0.12 285 | SNR=00, bw=0.333333, k=3072, n=6144, m=16, PSNR=6.57, SSIM=0.12 286 | SNR=00, bw=0.333333, k=3072, n=6144, m=64, PSNR=6.57, SSIM=0.12 287 | SNR=00, bw=0.333333, k=3072, n=4608, m=02, PSNR=7.54, SSIM=0.08 288 | SNR=00, bw=0.333333, k=3072, n=4608, m=04, PSNR=6.57, SSIM=0.12 289 | SNR=00, bw=0.333333, k=3072, n=4608, m=16, PSNR=6.57, SSIM=0.12 290 | SNR=00, bw=0.333333, k=3072, n=4608, m=64, PSNR=6.57, SSIM=0.12 291 | SNR=00, bw=0.333333, k=1536, n=4608, m=02, PSNR=22.40, SSIM=0.75 292 | SNR=00, bw=0.333333, k=1536, n=4608, m=04, PSNR=25.49, SSIM=0.86 293 | SNR=00, bw=0.333333, k=1536, n=4608, m=16, PSNR=6.57, SSIM=0.12 294 | SNR=00, bw=0.333333, k=1536, n=4608, m=64, PSNR=6.57, SSIM=0.12 295 | SNR=00, bw=0.500000, k=3072, n=6144, m=02, PSNR=26.19, SSIM=0.88 296 | SNR=00, bw=0.500000, k=3072, n=6144, m=04, PSNR=6.57, SSIM=0.12 297 | SNR=00, bw=0.500000, k=3072, n=6144, m=16, PSNR=6.57, SSIM=0.12 298 | SNR=00, bw=0.500000, k=3072, n=6144, m=64, PSNR=6.57, SSIM=0.12 299 | SNR=00, bw=0.500000, k=3072, n=4608, m=02, PSNR=7.50, SSIM=0.07 300 | SNR=00, bw=0.500000, k=3072, n=4608, m=04, PSNR=6.57, SSIM=0.12 301 | SNR=00, bw=0.500000, k=3072, n=4608, m=16, PSNR=6.57, SSIM=0.12 302 | SNR=00, bw=0.500000, k=3072, n=4608, m=64, PSNR=6.57, SSIM=0.12 303 | SNR=00, bw=0.500000, k=1536, n=4608, m=02, PSNR=24.17, SSIM=0.82 304 | SNR=00, bw=0.500000, k=1536, n=4608, m=04, PSNR=27.94, SSIM=0.91 305 | SNR=00, bw=0.500000, k=1536, n=4608, m=16, PSNR=6.57, SSIM=0.12 306 | SNR=00, bw=0.500000, k=1536, n=4608, m=64, PSNR=6.57, SSIM=0.12 307 | SNR=10, bw=0.083333, k=3072, n=6144, m=02, PSNR=19.61, SSIM=0.59 308 | SNR=10, bw=0.083333, k=3072, n=6144, m=04, PSNR=21.56, SSIM=0.71 309 | SNR=10, bw=0.083333, k=3072, n=6144, m=16, PSNR=24.17, SSIM=0.82 310 | SNR=10, bw=0.083333, k=3072, n=6144, m=64, PSNR=6.57, SSIM=0.12 311 | SNR=10, bw=0.083333, k=3072, n=4608, m=02, PSNR=20.09, SSIM=0.62 312 | SNR=10, bw=0.083333, k=3072, n=4608, m=04, PSNR=22.40, SSIM=0.75 313 | SNR=10, bw=0.083333, k=3072, n=4608, m=16, PSNR=25.12, SSIM=0.85 314 | SNR=10, bw=0.083333, k=3072, n=4608, m=64, PSNR=6.57, SSIM=0.12 315 | SNR=10, bw=0.083333, k=1536, n=4608, m=02, PSNR=19.61, SSIM=0.59 316 | SNR=10, bw=0.083333, k=1536, n=4608, m=04, PSNR=20.09, SSIM=0.62 317 | SNR=10, bw=0.083333, k=1536, n=4608, m=16, PSNR=22.40, SSIM=0.75 318 | SNR=10, 
bw=0.083333, k=1536, n=4608, m=64, PSNR=24.17, SSIM=0.82 319 | SNR=10, bw=0.166667, k=3072, n=6144, m=02, PSNR=21.56, SSIM=0.71 320 | SNR=10, bw=0.166667, k=3072, n=6144, m=04, PSNR=24.17, SSIM=0.82 321 | SNR=10, bw=0.166667, k=3072, n=6144, m=16, PSNR=27.94, SSIM=0.91 322 | SNR=10, bw=0.166667, k=3072, n=6144, m=64, PSNR=6.57, SSIM=0.12 323 | SNR=10, bw=0.166667, k=3072, n=4608, m=02, PSNR=22.40, SSIM=0.75 324 | SNR=10, bw=0.166667, k=3072, n=4608, m=04, PSNR=25.49, SSIM=0.86 325 | SNR=10, bw=0.166667, k=3072, n=4608, m=16, PSNR=29.75, SSIM=0.94 326 | SNR=10, bw=0.166667, k=3072, n=4608, m=64, PSNR=6.57, SSIM=0.12 327 | SNR=10, bw=0.166667, k=1536, n=4608, m=02, PSNR=20.09, SSIM=0.62 328 | SNR=10, bw=0.166667, k=1536, n=4608, m=04, PSNR=22.40, SSIM=0.75 329 | SNR=10, bw=0.166667, k=1536, n=4608, m=16, PSNR=25.49, SSIM=0.86 330 | SNR=10, bw=0.166667, k=1536, n=4608, m=64, PSNR=27.94, SSIM=0.91 331 | SNR=10, bw=0.250000, k=3072, n=6144, m=02, PSNR=22.90, SSIM=0.77 332 | SNR=10, bw=0.250000, k=3072, n=6144, m=04, PSNR=26.19, SSIM=0.88 333 | SNR=10, bw=0.250000, k=3072, n=6144, m=16, PSNR=30.55, SSIM=0.95 334 | SNR=10, bw=0.250000, k=3072, n=6144, m=64, PSNR=6.59, SSIM=0.12 335 | SNR=10, bw=0.250000, k=3072, n=4608, m=02, PSNR=24.17, SSIM=0.82 336 | SNR=10, bw=0.250000, k=3072, n=4608, m=04, PSNR=27.94, SSIM=0.91 337 | SNR=10, bw=0.250000, k=3072, n=4608, m=16, PSNR=32.85, SSIM=0.97 338 | SNR=10, bw=0.250000, k=3072, n=4608, m=64, PSNR=6.57, SSIM=0.12 339 | SNR=10, bw=0.250000, k=1536, n=4608, m=02, PSNR=21.56, SSIM=0.71 340 | SNR=10, bw=0.250000, k=1536, n=4608, m=04, PSNR=24.17, SSIM=0.82 341 | SNR=10, bw=0.250000, k=1536, n=4608, m=16, PSNR=27.94, SSIM=0.91 342 | SNR=10, bw=0.250000, k=1536, n=4608, m=64, PSNR=30.55, SSIM=0.95 343 | SNR=10, bw=0.333333, k=3072, n=6144, m=02, PSNR=24.17, SSIM=0.82 344 | SNR=10, bw=0.333333, k=3072, n=6144, m=04, PSNR=27.94, SSIM=0.91 345 | SNR=10, bw=0.333333, k=3072, n=6144, m=16, PSNR=32.85, SSIM=0.97 346 | SNR=10, bw=0.333333, k=3072, n=6144, m=64, PSNR=6.57, SSIM=0.12 347 | SNR=10, bw=0.333333, k=3072, n=4608, m=02, PSNR=25.49, SSIM=0.86 348 | SNR=10, bw=0.333333, k=3072, n=4608, m=04, PSNR=29.75, SSIM=0.94 349 | SNR=10, bw=0.333333, k=3072, n=4608, m=16, PSNR=35.25, SSIM=0.98 350 | SNR=10, bw=0.333333, k=3072, n=4608, m=64, PSNR=6.57, SSIM=0.12 351 | SNR=10, bw=0.333333, k=1536, n=4608, m=02, PSNR=22.40, SSIM=0.75 352 | SNR=10, bw=0.333333, k=1536, n=4608, m=04, PSNR=25.49, SSIM=0.86 353 | SNR=10, bw=0.333333, k=1536, n=4608, m=16, PSNR=29.75, SSIM=0.94 354 | SNR=10, bw=0.333333, k=1536, n=4608, m=64, PSNR=32.85, SSIM=0.97 355 | SNR=10, bw=0.500000, k=3072, n=6144, m=02, PSNR=26.19, SSIM=0.88 356 | SNR=10, bw=0.500000, k=3072, n=6144, m=04, PSNR=30.55, SSIM=0.95 357 | SNR=10, bw=0.500000, k=3072, n=6144, m=16, PSNR=36.48, SSIM=0.98 358 | SNR=10, bw=0.500000, k=3072, n=6144, m=64, PSNR=6.57, SSIM=0.12 359 | SNR=10, bw=0.500000, k=3072, n=4608, m=02, PSNR=27.94, SSIM=0.91 360 | SNR=10, bw=0.500000, k=3072, n=4608, m=04, PSNR=32.85, SSIM=0.97 361 | SNR=10, bw=0.500000, k=3072, n=4608, m=16, PSNR=38.58, SSIM=0.97 362 | SNR=10, bw=0.500000, k=3072, n=4608, m=64, PSNR=6.57, SSIM=0.12 363 | SNR=10, bw=0.500000, k=1536, n=4608, m=02, PSNR=24.17, SSIM=0.82 364 | SNR=10, bw=0.500000, k=1536, n=4608, m=04, PSNR=27.94, SSIM=0.91 365 | SNR=10, bw=0.500000, k=1536, n=4608, m=16, PSNR=32.85, SSIM=0.97 366 | SNR=10, bw=0.500000, k=1536, n=4608, m=64, PSNR=36.48, SSIM=0.98 367 | """ -------------------------------------------------------------------------------- 
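The `max_bytes` budget in the sweep above follows directly from the bandwidth ratio, the modulation order, and the LDPC code rate: `bw` fixes the number of channel symbols per image, each symbol carries `log2(m)` coded bits, and the rate-`k/n` code scales those down to the information bits available for the BPG bitstream. A small sketch restating that arithmetic (the helper name is illustrative, not part of the repo):

```python
import math

def bpg_budget_bytes(batch, bw, k, n, m, h=32, w=32, c=3):
    """Source-byte budget so that, after rate-k/n LDPC coding and m-ary
    modulation, the transmission fits within the bandwidth ratio bw."""
    symbols = batch * h * w * c * bw      # channel uses available
    coded_bits = symbols * math.log2(m)   # bits carried by the constellation
    info_bits = coded_bits * k / n        # information bits after LDPC
    return info_bits / 8                  # bytes handed to the BPG encoder

# e.g. a 256-image CIFAR-10 batch at bw=1/6 with (k, n, m) = (3072, 6144, 4):
# 131072 symbols -> 262144 coded bits -> 131072 info bits -> 16384 bytes
print(bpg_budget_bytes(256, 1/6, 3072, 6144, 4))  # 16384.0
```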
/classification.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import argparse
4 | import torch
5 | import time
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torchvision
9 | import torch.optim as optim
10 | from torchvision import datasets, transforms
11 | from tiny_imagenet import TinyImageNet
12 | from resnet import *
13 | from wideresnet import *
14 | import logging
15 | # from bpg_ldpc import LDPCTransmitter, BPGEncoder, BPGDecoder
16 | from torch_impl import JSCC, Calculate_filters
17 |
18 | os.environ["CUDA_VISIBLE_DEVICES"] = '0, 1, 2'
19 |
20 | """
21 |
22 | nohup python classification.py --lr 0.1 --wd 2e-4 --epochs 100 --data CIFAR-10 --arch ResNet18 --aug --seed 1 > ./logs/output.log 2>&1 &
23 | tail -f ./logs/output.log
24 |
25 | """
26 |
27 | parser = argparse.ArgumentParser(description='PyTorch CIFAR classification training')
28 | parser.add_argument('--batch-size', type=int, default=128, metavar='N',
29 |                     help='input batch size for training (default: 128)')
30 | parser.add_argument('--test-batch-size', type=int, default=128, metavar='N',
31 |                     help='input batch size for testing (default: 128)')
32 | parser.add_argument('--data', default='Tiny-ImageNet', choices=['Tiny-ImageNet', 'CIFAR-10', 'CIFAR-100', 'ImageNet'], help='dataset')
33 | parser.add_argument('--arch', default='ResNet18', choices=['ResNet18', 'WideResNet34'], help='model')
34 | parser.add_argument('--epochs', type=int, default=100, metavar='N',
35 |                     help='number of epochs to train')
36 | parser.add_argument('--weight-decay', '--wd', default=2e-4,  # 5e-4
37 |                     type=float, metavar='W')
38 | parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
39 |                     help='learning rate')
40 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
41 |                     help='SGD momentum')
42 | parser.add_argument('--no-cuda', action='store_true', default=False,
43 |                     help='disables CUDA training')
44 | parser.add_argument('--aug', action='store_true', default=False,
45 |                     help='data augmentation')
46 | parser.add_argument('--seed', type=int, default=1, metavar='S',
47 |                     help='random seed (default: 1)')
48 | parser.add_argument('--log-interval', type=int, default=100, metavar='N',
49 |                     help='how many batches to wait before logging training status')
50 | parser.add_argument('--model-dir', default='./checkpoint/baseline/',
51 |                     help='directory of model for saving checkpoint')
52 | parser.add_argument('--save-freq', '-s', default=5, type=int, metavar='N',
53 |                     help='save frequency')
54 |
55 | args = parser.parse_args()
56 |
57 | # settings
58 | model_dir = args.model_dir + time.strftime('%Y-%m-%d-%H-%M-%S-', time.localtime()) + args.arch + '-Standard-' + args.data + '-aug-' + str(args.aug)
59 | if not os.path.exists(model_dir):
60 |     os.makedirs(model_dir)
61 |
62 | logger = logging.getLogger(__name__)
63 | logging.basicConfig(
64 |     format='[%(asctime)s] - %(message)s',
65 |     datefmt='%Y/%m/%d %H:%M:%S',
66 |     level=logging.INFO,
67 |     filename=os.path.join(model_dir, 'train.log'))
68 | logger.info(args)
69 |
70 |
71 | use_cuda = not args.no_cuda and torch.cuda.is_available()
72 | torch.manual_seed(args.seed)
73 | kwargs = {'num_workers': 4, 'pin_memory': False} if use_cuda else {}
74 | # BPG-LDPC
75 | # SNR=10, bw=0.500000, k=3072, n=4608, m=16, PSNR=38.58, SSIM=0.97
76 | # bw = 1 / 2
77 | # esno_db = 10
78 | # k, n, m = 3072, 4608, 16
79 | # b = 256
80 | # max_bytes = b * 32 * 32 * 3 * bw * math.log2(m) *
k / n / 8 81 | # ldpctransmitter = LDPCTransmitter(k, n, m, esno_db, 'AWGN') 82 | # bpgencoder = BPGEncoder() 83 | # bpgdecoder = BPGDecoder() 84 | 85 | # JSCC 86 | SNR = 20 87 | CHANNEL_TYPE = "awgn" 88 | COMPRESSION_RATIO = 0.04 89 | EPOCHS = 1000 90 | NUM_WORKERS = 4 91 | LEARNING_RATE = 0.001 92 | CHANNEL_SNR_TRAIN = 10 93 | TRAIN_IMAGE_NUM = 50000 94 | TEST_IMAGE_NUM = 10000 95 | TRAIN_BS = 64 96 | TEST_BS = 4096 97 | K = Calculate_filters(COMPRESSION_RATIO) 98 | net = JSCC(K, snr_db=SNR).cuda() 99 | net.load_state_dict(torch.load("/media/bohnsix/djscc/checkpoints/jscc_model_17")) 100 | 101 | 102 | 103 | 104 | class Normalize(nn.Module): 105 | def __init__(self, mean, std): 106 | super(Normalize, self).__init__() 107 | self.register_buffer('mean', torch.Tensor(mean).to("cuda")) 108 | self.register_buffer('std', torch.Tensor(std).to("cuda")) 109 | 110 | def forward(self, input): 111 | # Broadcasting 112 | mean = self.mean.reshape(1, 3, 1, 1) 113 | std = self.std.reshape(1, 3, 1, 1) 114 | return (input - mean) / std 115 | 116 | 117 | 118 | 119 | 120 | 121 | # setup data loader 122 | if args.data == 'ImageNet': 123 | transform_train = transforms.Compose([ 124 | transforms.Resize(256), 125 | transforms.RandomResizedCrop(224), 126 | transforms.RandomHorizontalFlip(), 127 | transforms.ToTensor(), 128 | # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 129 | ]) 130 | transform_test = transforms.Compose([ 131 | transforms.Resize(256), 132 | transforms.CenterCrop(224), 133 | transforms.ToTensor(), 134 | # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 135 | ]) 136 | trainset = datasets.ImageFolder('/data/ZNY/data/ImageNet/train', transform_train) 137 | testset = datasets.ImageFolder("/data/ZNY/data/ImageNet/val", transform_test) 138 | norm_layer = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 139 | class_number = 1000 140 | size = 224 141 | elif args.data == 'Tiny-ImageNet': 142 | transform_train = transforms.Compose([ 143 | transforms.RandomCrop(64, padding=8, padding_mode="reflect"), 144 | transforms.RandomHorizontalFlip(), 145 | transforms.ToTensor(), 146 | #transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 147 | ]) 148 | transform_test = transforms.Compose([ 149 | transforms.ToTensor(), 150 | #transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 151 | ]) 152 | trainset = TinyImageNet('../data/tiny-imagenet-200', train=True, transform=transform_train) 153 | testset = TinyImageNet('../data/tiny-imagenet-200', train=False, transform=transform_test) 154 | norm_layer = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 155 | class_number = 200 156 | size = 64 157 | elif args.data == 'CIFAR-10': 158 | transform_train = transforms.Compose([ 159 | transforms.RandomCrop(32, padding=4), 160 | transforms.RandomHorizontalFlip(), 161 | transforms.ToTensor(), 162 | # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)), 163 | ]) 164 | transform_test = transforms.Compose([ 165 | transforms.ToTensor(), 166 | # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)), 167 | ]) 168 | trainset = torchvision.datasets.CIFAR10(root='../data', train=True, download=True, transform=transform_train) 169 | testset = torchvision.datasets.CIFAR10(root='../data', train=False, download=True, transform=transform_test) 170 | norm_layer = Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2471, 0.2435, 0.2616]) 171 | class_number = 10 172 | size = 32 173 | elif args.data == 'CIFAR-100': 
174 | transform_train = transforms.Compose([ 175 | transforms.RandomCrop(32, padding=4), 176 | transforms.RandomHorizontalFlip(), 177 | transforms.ToTensor(), 178 | # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)), 179 | ]) 180 | transform_test = transforms.Compose([ 181 | transforms.ToTensor(), 182 | # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2471, 0.2435, 0.2616)), 183 | ]) 184 | trainset = torchvision.datasets.CIFAR100(root='../data', train=True, download=True, transform=transform_train) 185 | testset = torchvision.datasets.CIFAR100(root='../data', train=False, download=True, transform=transform_test) 186 | norm_layer = Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2471, 0.2435, 0.2616]) 187 | class_number = 100 188 | size = 32 189 | 190 | 191 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, **kwargs) 192 | test_loader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=False, **kwargs) 193 | 194 | 195 | 196 | 197 | 198 | 199 | def train(args, model, train_loader, optimizer, epoch): 200 | model.train() 201 | for batch_idx, (data, target) in enumerate(train_loader): 202 | # BPG-LDPC 203 | # image = imBatchtoImage(data) 204 | # src_bits = bpgencoder.encode(image.numpy(), max_bytes) 205 | # rcv_bits = ldpctransmitter.send(src_bits) 206 | # decoded_image = bpgdecoder.decode(rcv_bits.numpy(), image.shape) 207 | # data = decoded_image 208 | 209 | #JSCC 210 | decoded_img, chn_out = net(data.cuda()) 211 | data = decoded_img 212 | 213 | data, target = data.cuda(), target.cuda() 214 | optimizer.zero_grad() 215 | loss = F.cross_entropy(model(data), target) 216 | loss.backward() 217 | optimizer.step() 218 | 219 | # print progress 220 | if batch_idx % args.log_interval == 0: 221 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 222 | epoch, batch_idx * len(data), len(train_loader.dataset), 223 | 100. * batch_idx / len(train_loader), loss.item())) 224 | logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 225 | epoch, batch_idx * len(data), len(train_loader.dataset), 226 | 100. * batch_idx / len(train_loader), loss.item())) 227 | 228 | 229 | def eval_train(model, train_loader): 230 | model.eval() 231 | train_loss = 0 232 | correct = 0 233 | with torch.no_grad(): 234 | for data, target in train_loader: 235 | # BPG-LDPC 236 | # image = imBatchtoImage(data) 237 | # src_bits = bpgencoder.encode(image.numpy(), max_bytes) 238 | # rcv_bits = ldpctransmitter.send(src_bits) 239 | # decoded_image = bpgdecoder.decode(rcv_bits.numpy(), image.shape) 240 | # data = decoded_image 241 | 242 | # JSCC 243 | decoded_img, chn_out = net(data.cuda()) 244 | data = decoded_img 245 | 246 | data, target = data.cuda(), target.cuda() 247 | output = model(data) 248 | train_loss += F.cross_entropy(output, target, reduction='mean').item() 249 | pred = output.max(1, keepdim=True)[1] 250 | correct += pred.eq(target.view_as(pred)).sum().item() 251 | train_loss /= len(train_loader.dataset) 252 | print('Training: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format( 253 | train_loss, correct, len(train_loader.dataset), 254 | 100. * correct / len(train_loader.dataset))) 255 | logger.info('Training: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format( 256 | train_loss, correct, len(train_loader.dataset), 257 | 100. 
* correct / len(train_loader.dataset))) 258 | training_accuracy = correct / len(train_loader.dataset) 259 | return train_loss, training_accuracy 260 | 261 | 262 | def eval_test(model, test_loader): 263 | model.eval() 264 | test_loss = 0 265 | correct = 0 266 | with torch.no_grad(): 267 | for data, target in test_loader: 268 | # BPG-LDPC 269 | # image = imBatchtoImage(data) 270 | # src_bits = bpgencoder.encode(image.numpy(), max_bytes) 271 | # rcv_bits = ldpctransmitter.send(src_bits) 272 | # decoded_image = bpgdecoder.decode(rcv_bits.numpy(), image.shape) 273 | # data = decoded_image 274 | 275 | # JSCC 276 | decoded_img, chn_out = net(data.cuda()) 277 | data = decoded_img 278 | 279 | data, target = data.cuda(), target.cuda() 280 | output = model(data) 281 | test_loss += F.cross_entropy(output, target, reduction='mean').item() 282 | pred = output.max(1, keepdim=True)[1] 283 | correct += pred.eq(target.view_as(pred)).sum().item() 284 | test_loss /= len(test_loader.dataset) 285 | print('Test: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format( 286 | test_loss, correct, len(test_loader.dataset), 287 | 100. * correct / len(test_loader.dataset))) 288 | logger.info('Test: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format( 289 | test_loss, correct, len(test_loader.dataset), 290 | 100. * correct / len(test_loader.dataset))) 291 | test_accuracy = correct / len(test_loader.dataset) 292 | return test_loss, test_accuracy 293 | 294 | 295 | 296 | 297 | 298 | 299 | def adjust_learning_rate(optimizer, epoch): 300 | """decrease the learning rate""" 301 | lr = args.lr 302 | if epoch >= 75: 303 | lr = args.lr * 0.1 304 | if epoch >= 90: 305 | lr = args.lr * 0.01 306 | if epoch >= 100: 307 | lr = args.lr * 0.001 308 | for param_group in optimizer.param_groups: 309 | param_group['lr'] = lr 310 | 311 | 312 | def imBatchtoImage(batch_images): 313 | ''' 314 | turns b, 32, 32, 3 images into single sqrt(b) * 32, sqrt(b) * 32, 3 image. 
315 |     '''
316 |     batch, h, w, c = batch_images.shape
317 |     b = int(batch ** 0.5)
318 |
319 |     divisor = b
320 |     while batch % divisor != 0:
321 |         divisor -= 1
322 |
323 |     image = batch_images.reshape(-1, batch//divisor, h, w, c)
324 |     image = image.permute(0, 2, 1, 3, 4)  # fix: torch.Tensor.transpose takes two dims; multi-dim reordering needs permute
325 |     image = image.reshape(-1, batch//divisor*w, c)
326 |
327 |     return torch.round(image)
328 |
329 |
330 | def main():
331 |     if args.arch == 'ResNet18':
332 |         model = ResNet18(size, class_number).cuda()
333 |     elif args.arch == 'WideResNet34':
334 |         model = WideResNet(image_size=size, depth=34, num_classes=class_number, widen_factor=10).cuda()
335 |
336 |     if args.aug:
337 |         model = nn.Sequential(norm_layer, model).cuda()
338 |     model = torch.nn.DataParallel(model).cuda()
339 |     optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
340 |
341 |     for epoch in range(1, args.epochs + 1):
342 |         adjust_learning_rate(optimizer, epoch)
343 |         start_time = time.time()
344 |         train(args, model, train_loader, optimizer, epoch)
345 |         print('using time:', time.time() - start_time)
346 |         logger.info('using time: {}'.format(time.time() - start_time))
347 |
348 |         _, _ = eval_train(model, train_loader)
349 |         _, _ = eval_test(model, test_loader)
350 |
351 |         logger.info('================================================================')
352 |
353 |         torch.save(optimizer.state_dict(),
354 |                    os.path.join(model_dir, 'opt-last.tar'))
355 |         torch.save(model.module.state_dict(),
356 |                    os.path.join(model_dir, 'model-last.pth'))
357 |
358 |         # save checkpoint
359 |         if epoch % args.save_freq == 0:
360 |             torch.save(model.module.state_dict(),
361 |                        os.path.join(model_dir, 'model-{}.pth'.format(epoch)))
362 |
363 | if __name__ == '__main__':
364 |     main()
365 |
366 |
--------------------------------------------------------------------------------
/original_implement.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
4 | import glob
5 | import time
6 | from datetime import datetime
7 | import tensorflow as tf
8 | import numpy as np
9 | import configargparse
10 | from tensorflow.keras import layers
11 | from tensorflow.keras import datasets
12 | import tensorflow_compression as tfc
13 |
14 |
15 | def psnr_metric(x_in, x_out):
16 |     if type(x_in) is list:
17 |         img_in = x_in[0]
18 |     else:
19 |         img_in = x_in
20 |     return tf.image.psnr(img_in, x_out, max_val=1.0)
21 |
22 |
23 | class Encoder(layers.Layer):
24 |     """Build encoder arch"""
25 |
26 |     def __init__(self, conv_depth, name="encoder", **kwargs):
27 |         super(Encoder, self).__init__(name=name, **kwargs)
28 |         self.data_format = "channels_last"
29 |         self.sublayers = [
30 |             tfc.SignalConv2D(
31 |                 16,
32 |                 (5, 5),
33 |                 name="conv_1",
34 |                 corr=True,
35 |                 strides_down=2,
36 |                 padding="same_zeros",
37 |                 use_bias=True,
38 |             ),
39 |             layers.PReLU(shared_axes=[1, 2]),
40 |             tfc.SignalConv2D(
41 |                 32,
42 |                 (5, 5),
43 |                 name="conv_2",
44 |                 corr=True,
45 |                 strides_down=2,
46 |                 padding="same_zeros",
47 |                 use_bias=True,
48 |             ),
49 |             layers.PReLU(shared_axes=[1, 2]),
50 |             tfc.SignalConv2D(
51 |                 32,
52 |                 (5, 5),
53 |                 name="conv_3",
54 |                 corr=True,
55 |                 strides_down=1,
56 |                 padding="same_zeros",
57 |                 use_bias=True,
58 |             ),
59 |             layers.PReLU(shared_axes=[1, 2]),
60 |             tfc.SignalConv2D(
61 |                 32,
62 |                 (5, 5),
63 |                 name="conv_4",
64 |                 corr=True,
65 |                 strides_down=1,
66 |                 padding="same_zeros",
67 |                 use_bias=True,
68 |             ),
69 |             layers.PReLU(shared_axes=[1, 2]),
70 |             tfc.SignalConv2D(
71 |                 conv_depth,
72 |                 (5, 5),
73 |                 name="conv_5",
74 |                 corr=True,
75 |                 strides_down=1,
76 |                 padding="same_zeros",
77 |                 use_bias=True,
78 |                 activation=None,
79 |             ),
80 |         ]
81 |
82 |     def call(self, x):
83 |         for sublayer in self.sublayers:
84 |             x = sublayer(x)
85 |         return x
86 |
87 |
88 | class Decoder(layers.Layer):
89 |     """Build decoder arch"""
90 |
91 |     def __init__(self, n_channels, name="decoder", **kwargs):
92 |         super(Decoder, self).__init__(name=name, **kwargs)
93 |         self.data_format = "channels_last"
94 |         self.sublayers = [
95 |             tfc.SignalConv2D(
96 |                 32,
97 |                 (5, 5),
98 |                 name="conv_1",
99 |                 corr=False,
100 |                 strides_up=1,
101 |                 padding="same_zeros",
102 |                 use_bias=True,
103 |             ),
104 |             layers.PReLU(shared_axes=[1, 2]),
105 |             tfc.SignalConv2D(
106 |                 32,
107 |                 (5, 5),
108 |                 name="conv_2",
109 |                 corr=False,
110 |                 strides_up=1,
111 |                 padding="same_zeros",
112 |                 use_bias=True,
113 |             ),
114 |             layers.PReLU(shared_axes=[1, 2]),
115 |             tfc.SignalConv2D(
116 |                 32,
117 |                 (5, 5),
118 |                 name="conv_3",
119 |                 corr=False,
120 |                 strides_up=1,
121 |                 padding="same_zeros",
122 |                 use_bias=True,
123 |             ),
124 |             layers.PReLU(shared_axes=[1, 2]),
125 |             tfc.SignalConv2D(
126 |                 16,
127 |                 (5, 5),
128 |                 name="conv_4",
129 |                 corr=False,
130 |                 strides_up=2,
131 |                 padding="same_zeros",
132 |                 use_bias=True,
133 |             ),
134 |             layers.PReLU(shared_axes=[1, 2]),
135 |             tfc.SignalConv2D(
136 |                 n_channels,
137 |                 (5, 5),
138 |                 name="conv_5",
139 |                 corr=False,
140 |                 strides_up=2,
141 |                 padding="same_zeros",
142 |                 use_bias=True,
143 |                 activation=tf.nn.sigmoid,
144 |             ),
145 |         ]
146 |
147 |     def call(self, x):
148 |         for sublayer in self.sublayers:
149 |             x = sublayer(x)
150 |         return x
151 |
152 |
153 | def max_Rate(k, n, snr):
154 |     """Implements the maximum rate R (bandwidth of the channel).
155 |     Args:
156 |         k: channel bandwidth
157 |         n: image dimension (source bandwidth)
158 |         snr: channel signal-to-noise ratio
159 |     Returns:
160 |         Rmax: max bit rate, (k/n) * log2(1 + SNR)
161 |     """
162 |     Rmax = np.divide(k, n) * math.log2(1 + (10**(snr/10)))
163 |
164 |     return Rmax
165 |
166 |
167 | def real_awgn(x, stddev):
168 |     """Implements the real additive white gaussian noise channel.
169 |     Args:
170 |         x: channel input symbols
171 |         stddev: standard deviation of noise
172 |     Returns:
173 |         y: noisy channel output symbols
174 |     """
175 |     # additive white gaussian noise
176 |     awgn = tf.random.normal(tf.shape(x), 0, stddev, dtype=tf.float32)
177 |     y = x + awgn
178 |
179 |     return y
180 |
181 |
182 | def fading(x, stddev, h=None):
183 |     """Implements the fading channel with multiplicative fading and
184 |     additive white gaussian noise.
185 | Args: 186 | x: channel input symbols 187 | stddev: standard deviation of noise 188 | Returns: 189 | y: noisy channel output symbols 190 | """ 191 | # channel gain 192 | if h is None: 193 | h = tf.complex( 194 | tf.random.normal([tf.shape(x)[0], 1], 0, 1 / np.sqrt(2)), 195 | tf.random.normal([tf.shape(x)[0], 1], 0, 1 / np.sqrt(2)), 196 | ) 197 | 198 | # additive white gaussian noise 199 | awgn = tf.complex( 200 | tf.random.normal(tf.shape(x), 0, 1 / np.sqrt(2)), 201 | tf.random.normal(tf.shape(x), 0, 1 / np.sqrt(2)), 202 | ) 203 | 204 | return (h * x + stddev * awgn), h 205 | 206 | 207 | class Channel(layers.Layer): 208 | def __init__(self, channel_type, channel_snr, name="channel", **kwargs): 209 | super(Channel, self).__init__(name=name, **kwargs) 210 | self.channel_type = channel_type 211 | self.channel_snr = channel_snr 212 | 213 | def call(self, inputs): 214 | (encoded_img, prev_h) = inputs 215 | inter_shape = tf.shape(encoded_img) 216 | # reshape array to [-1, dim_z] 217 | z = layers.Flatten()(encoded_img) 218 | # convert from snr to std 219 | print("channel_snr: {}".format(self.channel_snr)) 220 | noise_stddev = np.sqrt(10 ** (-self.channel_snr / 10)) 221 | 222 | # Add channel noise 223 | if self.channel_type == "awgn": 224 | dim_z = tf.shape(z)[1] 225 | # normalize latent vector so that the average power is 1 226 | z_in = tf.sqrt(tf.cast(dim_z, dtype=tf.float32)) * tf.nn.l2_normalize( 227 | z, axis=1) 228 | z_out = real_awgn(z_in, noise_stddev) 229 | h = tf.ones_like(z_in) # h just makes sense on fading channels 230 | 231 | elif self.channel_type == "fading": 232 | dim_z = tf.shape(z)[1] // 2 233 | # convert z to complex representation 234 | z_in = tf.complex(z[:, :dim_z], z[:, dim_z:]) 235 | # normalize the latent vector so that the average power is 1 236 | z_norm = tf.reduce_sum( 237 | tf.math.real(z_in * tf.math.conj(z_in)), axis=1, keepdims=True 238 | ) 239 | z_in = z_in * tf.complex( 240 | tf.sqrt(tf.cast(dim_z, dtype=tf.float32) / z_norm), 0.0 241 | ) 242 | z_out, h = fading(z_in, noise_stddev, prev_h) 243 | # convert back to real 244 | z_out = tf.concat([tf.math.real(z_out), tf.math.imag(z_out)], 1) 245 | 246 | # convert signal back to intermediate shape 247 | z_out = tf.reshape(z_out, inter_shape) 248 | 249 | return z_out, h 250 | 251 | class D_JSCC(layers.Layer): 252 | """Build D-JSCC arch""" 253 | def __init__( 254 | self, 255 | channel_snr, 256 | conv_depth, 257 | channel_type, 258 | name="deep_jscc", 259 | **kwargs 260 | ): 261 | super(D_JSCC, self).__init__(name=name, **kwargs) 262 | 263 | n_channels = 3 # For RGB, change this if working with BW images 264 | self.encoder = Encoder(conv_depth) 265 | self.decoder = Decoder(n_channels, name="decoder_output") 266 | self.channel = Channel(channel_type, channel_snr, name="channel_output") 267 | 268 | def call(self, inputs): 269 | 270 | # inputs is just the original image 271 | img_in = img = inputs 272 | prev_chn_gain = None 273 | 274 | chn_in = self.encoder(img_in) 275 | chn_out, chn_gain = self.channel((chn_in, prev_chn_gain)) 276 | 277 | decoded_img = self.decoder(chn_out) 278 | 279 | # keep track of some metrics 280 | self.add_metric( 281 | tf.image.psnr(img, decoded_img, max_val=1.0), 282 | aggregation="mean", 283 | name="psnr", 284 | ) 285 | 286 | self.add_metric( 287 | tf.reduce_mean(tf.math.square(img - decoded_img)), 288 | aggregation="mean", 289 | name="mse", 290 | ) 291 | 292 | return (decoded_img, chn_out, chn_gain) 293 | 294 | def change_channel_snr(self, channel_snr): 295 | self.channel.channel_snr = 
channel_snr 296 | 297 | def change_feedback_snr(self, feedback_snr): 298 | self.feedback_snr = feedback_snr 299 | 300 | 301 | def main(args): 302 | 303 | # get train and test CIFAR dataset 304 | x_train, x_test = get_dataset(args.number_of_train_image,args.number_of_test_image) 305 | 306 | if args.delete_previous_model and tf.io.gfile.exists(args.model_dir): 307 | print("Deleting previous model files at {}".format(args.model_dir)) 308 | tf.io.gfile.rmtree(args.model_dir) 309 | tf.io.gfile.makedirs(args.model_dir) 310 | else: 311 | print("Starting new model at {}".format(args.model_dir)) 312 | tf.io.gfile.makedirs(args.model_dir) 313 | 314 | # load model 315 | prev_layer_out = None 316 | # add input placeholder to please keras 317 | img = tf.keras.Input(shape=(None, None, 3)) 318 | 319 | channel_snr = args.channel_snr_train 320 | 321 | # Max R (bit rate/bandwidth) of the AWGN Channel given CIFAR dataset 322 | image_dim = 32 * 32 * 3 323 | channel_Rmax = max_Rate(args.conv_depth, image_dim, channel_snr) 324 | 325 | # checkpoint 326 | ckpt_file = os.path.join(args.model_dir, "ckpt") 327 | 328 | # D-JSCC model object 329 | ae_layer = D_JSCC( 330 | channel_snr, 331 | int(args.conv_depth), 332 | args.channel, 333 | ) 334 | 335 | layer_output = ae_layer(img) 336 | 337 | ( 338 | decoded_img, 339 | _chn_out, 340 | _chn_gain, 341 | ) = layer_output 342 | 343 | model = tf.keras.Model(inputs=img, outputs=decoded_img) 344 | 345 | model_metrics = [ 346 | tf.keras.metrics.MeanSquaredError(), 347 | psnr_metric, 348 | ] 349 | 350 | model.compile( 351 | optimizer=tf.keras.optimizers.Adam(learning_rate=args.learn_rate), 352 | loss="mse", 353 | metrics=model_metrics, 354 | ) 355 | 356 | print(model.summary()) 357 | 358 | checkpoint_path = os.path.join(args.model_dir, "ckpt") 359 | checkpoint_dir = os.path.dirname(checkpoint_path) 360 | # Create a callback that saves the model's weights 361 | cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,\ 362 | save_weights_only=True, verbose=1) 363 | model.load_weights("/media/bohnsix/D-JSCC/train_logs/checkpoint") 364 | 365 | model.fit( 366 | x_train, 367 | x_train, 368 | # epochs=args.train_epochs, 369 | epochs=1000, 370 | callbacks=[cp_callback], 371 | verbose=2, 372 | batch_size=args.batch_size_train, 373 | ) 374 | 375 | model.trainable = False 376 | 377 | 378 | print("<----------EVALUATION--------->") 379 | # eval the model 380 | out_eval = model.evaluate(x_test,x_test, verbose=2,batch_size=args.batch_size_test) 381 | for m, v in zip(model.metrics_names, out_eval): 382 | met_name = "_".join(["eval", m]) 383 | print("{}={}".format(met_name, v), end=" ") 384 | print("\n") 385 | 386 | 387 | def get_dataset(no_of_train_images,no_of_test_images): 388 | 389 | # load train and test images of CIFAR-10 dataset 390 | (train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data() 391 | train_labels, test_labels = train_labels, test_labels 392 | # Normalize pixel values to be between 0 and 1 for all the images 393 | x_train, x_tst = train_images[:no_of_train_images] / 255.0, test_images[:no_of_test_images] / 255.0 394 | 395 | return x_train, x_tst 396 | 397 | 398 | if __name__ == "__main__": 399 | # parse args 400 | p = configargparse.ArgParser() 401 | p.add( 402 | "-c", 403 | "--my-config", 404 | required=False, 405 | is_config_file=True, 406 | help="config file path", 407 | ) 408 | p.add( 409 | "--conv_depth", 410 | type=float, 411 | default=8, 412 | help=( 413 | "Number of channels of last conv layer, used to define the " 414 
| "compression rate: k/n=c_out/(16*3)" 415 | ), 416 | 417 | ) 418 | p.add( 419 | "--channel", 420 | type=str, 421 | default="fading", 422 | choices=["awgn", "fading"], 423 | help="Model of channel used (awgn, fading)", 424 | ) 425 | p.add( 426 | "--model_dir", 427 | type=str, 428 | default="./train_logs", 429 | help=("The location of the model checkpoint files."), 430 | ) 431 | p.add( 432 | "--delete_previous_model", 433 | action="store_true", 434 | default=False, 435 | help=("If model_dir has checkpoints, delete it before" "starting new run"), 436 | ) 437 | p.add( 438 | "--channel_snr_train", 439 | type=float, 440 | default=10, 441 | help="target SNR of channel during training (dB)", 442 | ) 443 | p.add( 444 | "--number_of_train_image", 445 | type=int, 446 | default=5000, 447 | help="Number of training images during training ", 448 | ) 449 | p.add( 450 | "--number_of_test_image", 451 | type=int, 452 | default=1000, 453 | help="Number of test images during testing ", 454 | ) 455 | p.add( 456 | "--learn_rate", 457 | type=float, 458 | default=0.001, 459 | help="Learning rate for Adam optimizer", 460 | ) 461 | p.add( 462 | "--train_epochs", 463 | type=int, 464 | default=2500, 465 | help=( 466 | "The number of epochs used to train (each epoch goes over the whole dataset)" 467 | ), 468 | ) 469 | p.add("--batch_size_train", type=int, default=64, help="Batch size for training") 470 | p.add("--batch_size_test", type=int, default=64, help="Batch size for testing") 471 | 472 | args = p.parse_args() 473 | 474 | print("##############D-JSCC#########################") 475 | for arg, value in sorted(vars(args).items()): 476 | print("{}: {}".format(arg, value)) 477 | print("#############################################") 478 | main(args) 479 | 480 | 481 | 482 | """ 483 | conda activate deepjscc_bohnsix 484 | 485 | python -u main.py --channel awgn --batch_size_train 512 --batch_size_test 512 > log2000.log 486 | """ 487 | -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | '''ResNet in PyTorch. 5 | 6 | For Pre-activation ResNet, see 'preact_resnet.py'. 7 | 8 | Reference: 9 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 10 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 11 | ''' 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | 17 | class BasicBlockRC(nn.Module): 18 | expansion = 1 19 | 20 | def __init__(self, in_planes, planes, stride=1, dropout_rate=0.0): 21 | super(BasicBlockRC, self).__init__() 22 | self.conv1 = nn.Conv2d( 23 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 24 | self.bn1 = nn.BatchNorm2d(planes) 25 | self.dropout = nn.Dropout(p=dropout_rate) 26 | 27 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 28 | stride=1, padding=1, bias=False) 29 | self.bn2 = nn.BatchNorm2d(planes) 30 | 31 | self.shortcut = nn.Sequential() 32 | if stride != 1 or in_planes != self.expansion * planes: 33 | self.shortcut = nn.Sequential( 34 | nn.Conv2d(in_planes, self.expansion * planes, 35 | kernel_size=1, stride=stride, bias=False), 36 | nn.BatchNorm2d(self.expansion * planes) 37 | ) 38 | 39 | def forward(self, x): 40 | out = self.dropout(F.relu(self.bn1(self.conv1(x)))) 41 | out = self.bn2(self.conv2(out)) 42 | out = F.relu(out) 43 | return out 44 | 45 | 46 | class BasicBlock(nn.Module): 47 | expansion = 1 48 | 49 | def __init__(self, in_planes, planes, stride=1, dropout_rate=0.0): 50 | super(BasicBlock, self).__init__() 51 | self.conv1 = nn.Conv2d( 52 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 53 | self.bn1 = nn.BatchNorm2d(planes) 54 | # self.dropout = nn.Dropout(p=dropout_rate) 55 | # print("Using Dropout") 56 | 57 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 58 | stride=1, padding=1, bias=False) 59 | self.bn2 = nn.BatchNorm2d(planes) 60 | 61 | self.shortcut = nn.Sequential() 62 | if stride != 1 or in_planes != self.expansion * planes: 63 | self.shortcut = nn.Sequential( 64 | nn.Conv2d(in_planes, self.expansion * planes, 65 | kernel_size=1, stride=stride, bias=False), 66 | nn.BatchNorm2d(self.expansion * planes) 67 | ) 68 | 69 | def forward(self, x): 70 | # out = self.dropout(F.relu(self.bn1(self.conv1(x)))) 71 | 72 | out = F.relu(self.bn1(self.conv1(x))) 73 | 74 | out = self.bn2(self.conv2(out)) 75 | out += self.shortcut(x) 76 | out = F.relu(out) 77 | return out 78 | 79 | 80 | class Bottleneck(nn.Module): 81 | expansion = 4 82 | 83 | def __init__(self, in_planes, planes, stride=1, dropout_rate=0.0): 84 | super(Bottleneck, self).__init__() 85 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 86 | self.bn1 = nn.BatchNorm2d(planes) 87 | self.dropout1 = nn.Dropout(p=dropout_rate) 88 | 89 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 90 | stride=stride, padding=1, bias=False) 91 | self.bn2 = nn.BatchNorm2d(planes) 92 | self.dropout2 = nn.Dropout(p=dropout_rate) 93 | 94 | self.conv3 = nn.Conv2d(planes, self.expansion * 95 | planes, kernel_size=1, bias=False) 96 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 97 | self.dropout3 = nn.Dropout(p=dropout_rate) 98 | 99 | self.shortcut = nn.Sequential() 100 | if stride != 1 or in_planes != self.expansion * planes: 101 | self.shortcut = nn.Sequential( 102 | nn.Conv2d(in_planes, self.expansion * planes, 103 | kernel_size=1, stride=stride, bias=False), 104 | nn.BatchNorm2d(self.expansion * planes) 105 | ) 106 | 107 | def forward(self, x): 108 | out = self.dropout1(F.relu(self.bn1(self.conv1(x)))) 109 | out = self.dropout2(F.relu(self.bn2(self.conv2(out)))) 110 | out = self.bn3(self.conv3(out)) 111 | out += self.shortcut(x) 112 | out = self.dropout3(F.relu(out)) 113 | return out 114 | 115 | 116 | class ResNet(nn.Module): 117 | def __init__(self, input_size, 
block, num_blocks, num_classes=10, dropout=0.0): 118 | super(ResNet, self).__init__() 119 | self.in_planes = 64 120 | self.dropout = dropout 121 | 122 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 123 | self.bn1 = nn.BatchNorm2d(64) 124 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 125 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 126 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 127 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 128 | self.linear = nn.Linear(512 * block.expansion * ((input_size // 32) ** 2), num_classes) 129 | 130 | 131 | # print(512*block.expansion*((input_size//32)**2)) 132 | 133 | def _make_layer(self, block, planes, num_blocks, stride): 134 | strides = [stride] + [1] * (num_blocks - 1) 135 | layers = [] 136 | for stride in strides: 137 | layers.append(block(self.in_planes, planes, stride, self.dropout)) 138 | self.in_planes = planes * block.expansion 139 | return nn.Sequential(*layers) 140 | 141 | def forward(self, x, prejection=False): 142 | # print(x.shape) 143 | out = F.relu(self.bn1(self.conv1(x))) 144 | # print(out.shape) 145 | out = self.layer1(out) 146 | # print(out.shape) 147 | out = self.layer2(out) 148 | # print(out.shape) 149 | out = self.layer3(out) 150 | # print(out.shape) 151 | out = self.layer4(out) 152 | # print(out.shape) 153 | out = F.avg_pool2d(out, 4) 154 | # print(out.shape) 155 | # out = out.view(out.size(0), -1) 156 | out = out.reshape(out.size(0), -1) 157 | # print(out.shape) 158 | # out = self.linear(out) 159 | # print(out.shape) 160 | # exit() 161 | # return out 162 | if prejection == True: 163 | return self.linear(out), out 164 | else: 165 | return self.linear(out) 166 | 167 | 168 | 169 | 170 | 171 | 172 | class ResNet_FS(nn.Module): 173 | def __init__(self, input_size, block, num_blocks, num_classes=10, dropout=0.0): 174 | super(ResNet_FS, self).__init__() 175 | self.in_planes = 64 176 | self.dropout = dropout 177 | 178 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 179 | self.bn1 = nn.BatchNorm2d(64) 180 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 181 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 182 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 183 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 184 | size = 512 * block.expansion * ((input_size // 32) ** 2) 185 | self.linear = nn.Linear(size, num_classes) 186 | self.linear1 = nn.Linear(size, size, bias=False) 187 | self.linear2 = nn.Linear(size, size, bias=False) 188 | self.linear1.weight = torch.nn.parameter.Parameter(torch.eye(size)) 189 | self.linear2.weight = torch.nn.parameter.Parameter(torch.eye(size)) 190 | 191 | # print(512*block.expansion*((input_size//32)**2)) 192 | 193 | def _make_layer(self, block, planes, num_blocks, stride): 194 | strides = [stride] + [1] * (num_blocks - 1) 195 | layers = [] 196 | for stride in strides: 197 | layers.append(block(self.in_planes, planes, stride, self.dropout)) 198 | self.in_planes = planes * block.expansion 199 | return nn.Sequential(*layers) 200 | 201 | def forward(self, x, prejection=False, logit=False): 202 | # print(x.shape) 203 | out = F.relu(self.bn1(self.conv1(x))) 204 | # print(out.shape) 205 | out = self.layer1(out) 206 | # print(out.shape) 207 | out = self.layer2(out) 208 | # print(out.shape) 209 | out = self.layer3(out) 210 | # print(out.shape) 211 | out = self.layer4(out) 212 | # 
print(out.shape)
213 |         out = F.avg_pool2d(out, 4)
214 |         # print(out.shape)
215 |         # # out = out.view(out.size(0), -1)
216 |         # out = out.reshape(out.size(0), -1)
217 |         # # print(out.shape)
218 |         # out = self.linear(out)
219 |         # # print(out.shape)
220 |         # # exit()
221 |         # return out
222 |         out_1 = out.view(out.size(0), -1)
223 |         out = self.linear1(out_1)
224 |         if self.training or logit:
225 |             logit_1 = self.linear(out)
226 |             out2 = self.linear2(out_1.detach())  # second head fed with detached features
227 |             logit_2 = self.linear(out2)
228 |             out_2 = (logit_1, logit_2)
229 |             out = (out, out2)
230 |         else:
231 |             out_2 = self.linear(out)
232 |         if prejection == True:
233 |             return out_2, out
234 |         else:
235 |             return out_2
236 | 
237 | 
238 | 
239 | 
240 | def ResNet18(input_size, num_classes):
241 |     return ResNet(input_size, BasicBlock, [2, 2, 2, 2], num_classes=num_classes)
242 | 
243 | 
244 | def ResNet18_FS(input_size, num_classes):
245 |     return ResNet_FS(input_size, BasicBlock, [2, 2, 2, 2], num_classes=num_classes)
246 | 
247 | 
248 | def ResNet34(input_size, num_classes):
249 |     return ResNet(input_size, BasicBlock, [3, 4, 6, 3], num_classes=num_classes)
250 | 
251 | 
252 | def ResNet50(input_size, num_classes):
253 |     return ResNet(input_size, Bottleneck, [3, 4, 6, 3], num_classes=num_classes)
254 | 
255 | def ResNet50_FS(input_size, num_classes):
256 |     return ResNet_FS(input_size, Bottleneck, [3, 4, 6, 3], num_classes=num_classes)
257 | 
258 | 
259 | def ResNet101(input_size, num_classes):
260 |     return ResNet(input_size, Bottleneck, [3, 4, 23, 3], num_classes=num_classes)
261 | 
262 | 
263 | def ResNet152(input_size, num_classes):
264 |     return ResNet(input_size, Bottleneck, [3, 8, 36, 3], num_classes=num_classes)
265 | 
266 | 
267 | def test():
268 |     net = ResNet18(input_size=32, num_classes=10)
269 |     y = net(torch.randn(1, 3, 32, 32))
270 |     print(y.size())
271 | 
272 | # '''ResNet in PyTorch.
273 | 
274 | # For Pre-activation ResNet, see 'preact_resnet.py'.
275 | 
276 | # Reference:
277 | # [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
278 | #     Deep Residual Learning for Image Recognition.
arXiv:1512.03385 279 | # ''' 280 | # import torch 281 | # import torch.nn as nn 282 | # import torch.nn.functional as F 283 | 284 | 285 | # class BasicBlock(nn.Module): 286 | # expansion = 1 287 | 288 | # def __init__(self, in_planes, planes, stride=1): 289 | # super(BasicBlock, self).__init__() 290 | # self.conv1 = nn.Conv2d( 291 | # in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 292 | # # self.bn1 = nn.BatchNorm2d(planes) 293 | # self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 294 | # stride=1, padding=1, bias=False) 295 | # # self.bn2 = nn.BatchNorm2d(planes) 296 | 297 | # self.shortcut = nn.Sequential() 298 | # if stride != 1 or in_planes != self.expansion*planes: 299 | # self.shortcut = nn.Sequential( 300 | # nn.Conv2d(in_planes, self.expansion*planes, 301 | # kernel_size=1, stride=stride, bias=False), 302 | # nn.BatchNorm2d(self.expansion*planes) 303 | # ) 304 | 305 | # def forward(self, x): 306 | # out = F.relu(self.conv1(x)) 307 | # out = self.conv2(out) 308 | # out += self.shortcut(x) 309 | # out = F.relu(out) 310 | # return out 311 | 312 | 313 | # class Bottleneck(nn.Module): 314 | # expansion = 4 315 | 316 | # def __init__(self, in_planes, planes, stride=1): 317 | # super(Bottleneck, self).__init__() 318 | # self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 319 | # self.bn1 = nn.BatchNorm2d(planes) 320 | # self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 321 | # stride=stride, padding=1, bias=False) 322 | # self.bn2 = nn.BatchNorm2d(planes) 323 | # self.conv3 = nn.Conv2d(planes, self.expansion * 324 | # planes, kernel_size=1, bias=False) 325 | # self.bn3 = nn.BatchNorm2d(self.expansion*planes) 326 | 327 | # self.shortcut = nn.Sequential() 328 | # if stride != 1 or in_planes != self.expansion*planes: 329 | # self.shortcut = nn.Sequential( 330 | # nn.Conv2d(in_planes, self.expansion*planes, 331 | # kernel_size=1, stride=stride, bias=False), 332 | # nn.BatchNorm2d(self.expansion*planes) 333 | # ) 334 | 335 | # def forward(self, x): 336 | # out = F.relu(self.bn1(self.conv1(x))) 337 | # out = F.relu(self.bn2(self.conv2(out))) 338 | # out = self.bn3(self.conv3(out)) 339 | # out += self.shortcut(x) 340 | # out = F.relu(out) 341 | # return out 342 | 343 | 344 | # class ResNet(nn.Module): 345 | # def __init__(self, input_size, block, num_blocks, num_classes=10): 346 | # super(ResNet, self).__init__() 347 | # self.in_planes = 64 348 | 349 | # self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 350 | # # self.bn1 = nn.BatchNorm2d(64) 351 | # self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 352 | # self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 353 | # self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 354 | # self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 355 | # self.linear = nn.Linear(512*block.expansion*((input_size//32)**2), num_classes) 356 | # # print(512*block.expansion*((input_size//32)**2)) 357 | 358 | # def _make_layer(self, block, planes, num_blocks, stride): 359 | # strides = [stride] + [1]*(num_blocks-1) 360 | # layers = [] 361 | # for stride in strides: 362 | # layers.append(block(self.in_planes, planes, stride)) 363 | # self.in_planes = planes * block.expansion 364 | # return nn.Sequential(*layers) 365 | 366 | # def forward(self, x): 367 | # # print(x.shape) 368 | # out = F.relu(self.conv1(x)) 369 | # # print(out.shape) 370 | # out = self.layer1(out) 371 | # # print(out.shape) 372 | # out = self.layer2(out) 373 | # # 
print(out.shape) 374 | # out = self.layer3(out) 375 | # # print(out.shape) 376 | # out = self.layer4(out) 377 | # # print(out.shape) 378 | # out = F.avg_pool2d(out, 4) 379 | # # print(out.shape) 380 | # out = out.view(out.size(0), -1) 381 | # # print(out.shape) 382 | # out = self.linear(out) 383 | # return out 384 | 385 | 386 | # def ResNet18(input_size, num_classes): 387 | # return ResNet(input_size, BasicBlock, [2, 2, 2, 2], num_classes=num_classes) 388 | 389 | 390 | # def ResNet34(input_size, num_classes): 391 | # return ResNet(input_size, BasicBlock, [3, 4, 6, 3], num_classes=num_classes) 392 | 393 | 394 | # def ResNet50(input_size, num_classes): 395 | # return ResNet(input_size, Bottleneck, [3, 4, 6, 3], num_classes=num_classes) 396 | 397 | 398 | # def ResNet101(): 399 | # return ResNet(Bottleneck, [3, 4, 23, 3]) 400 | 401 | 402 | # def ResNet152(): 403 | # return ResNet(Bottleneck, [3, 8, 36, 3]) 404 | 405 | 406 | # def test(): 407 | # net = ResNet18() 408 | # y = net(torch.randn(1, 3, 32, 32)) 409 | # print(y.size()) 410 | 411 | # # test() -------------------------------------------------------------------------------- /resources/experiments.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BohnSix/deepjscc/8c460a71670d8ce4bd2255b9d3494b8476bc2406/resources/experiments.png -------------------------------------------------------------------------------- /resources/performance_awgn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BohnSix/deepjscc/8c460a71670d8ce4bd2255b9d3494b8476bc2406/resources/performance_awgn.png -------------------------------------------------------------------------------- /resources/performance_fading.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BohnSix/deepjscc/8c460a71670d8ce4bd2255b9d3494b8476bc2406/resources/performance_fading.png -------------------------------------------------------------------------------- /resources/training_valid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BohnSix/deepjscc/8c460a71670d8ce4bd2255b9d3494b8476bc2406/resources/training_valid.png -------------------------------------------------------------------------------- /resources/validation_awgn_snr00_c0.04_e1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BohnSix/deepjscc/8c460a71670d8ce4bd2255b9d3494b8476bc2406/resources/validation_awgn_snr00_c0.04_e1000.png -------------------------------------------------------------------------------- /resources/validation_awgn_snr00_c0.09_e1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BohnSix/deepjscc/8c460a71670d8ce4bd2255b9d3494b8476bc2406/resources/validation_awgn_snr00_c0.09_e1000.png -------------------------------------------------------------------------------- /resources/validation_awgn_snr00_c0.17_e1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BohnSix/deepjscc/8c460a71670d8ce4bd2255b9d3494b8476bc2406/resources/validation_awgn_snr00_c0.17_e1000.png -------------------------------------------------------------------------------- /resources/validation_awgn_snr00_c0.25_e1000.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BohnSix/deepjscc/8c460a71670d8ce4bd2255b9d3494b8476bc2406/resources/validation_awgn_snr00_c0.25_e1000.png -------------------------------------------------------------------------------- /resources/validation_awgn_snr00_c0.33_e1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BohnSix/deepjscc/8c460a71670d8ce4bd2255b9d3494b8476bc2406/resources/validation_awgn_snr00_c0.33_e1000.png -------------------------------------------------------------------------------- /resources/validation_awgn_snr00_c0.42_e1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BohnSix/deepjscc/8c460a71670d8ce4bd2255b9d3494b8476bc2406/resources/validation_awgn_snr00_c0.42_e1000.png -------------------------------------------------------------------------------- /resources/validation_awgn_snr00_c0.49_e1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BohnSix/deepjscc/8c460a71670d8ce4bd2255b9d3494b8476bc2406/resources/validation_awgn_snr00_c0.49_e1000.png -------------------------------------------------------------------------------- /resources/validation_awgn_snr10_c0.04_e1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BohnSix/deepjscc/8c460a71670d8ce4bd2255b9d3494b8476bc2406/resources/validation_awgn_snr10_c0.04_e1000.png -------------------------------------------------------------------------------- /resources/validation_awgn_snr10_c0.09_e1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BohnSix/deepjscc/8c460a71670d8ce4bd2255b9d3494b8476bc2406/resources/validation_awgn_snr10_c0.09_e1000.png -------------------------------------------------------------------------------- /resources/validation_awgn_snr10_c0.17_e1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BohnSix/deepjscc/8c460a71670d8ce4bd2255b9d3494b8476bc2406/resources/validation_awgn_snr10_c0.17_e1000.png -------------------------------------------------------------------------------- /resources/validation_awgn_snr10_c0.25_e1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BohnSix/deepjscc/8c460a71670d8ce4bd2255b9d3494b8476bc2406/resources/validation_awgn_snr10_c0.25_e1000.png -------------------------------------------------------------------------------- /resources/validation_awgn_snr10_c0.33_e1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BohnSix/deepjscc/8c460a71670d8ce4bd2255b9d3494b8476bc2406/resources/validation_awgn_snr10_c0.33_e1000.png -------------------------------------------------------------------------------- /resources/validation_awgn_snr10_c0.42_e1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BohnSix/deepjscc/8c460a71670d8ce4bd2255b9d3494b8476bc2406/resources/validation_awgn_snr10_c0.42_e1000.png -------------------------------------------------------------------------------- /resources/validation_awgn_snr10_c0.49_e1000.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BohnSix/deepjscc/8c460a71670d8ce4bd2255b9d3494b8476bc2406/resources/validation_awgn_snr10_c0.49_e1000.png -------------------------------------------------------------------------------- /resources/validation_awgn_snr20_c0.04_e1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BohnSix/deepjscc/8c460a71670d8ce4bd2255b9d3494b8476bc2406/resources/validation_awgn_snr20_c0.04_e1000.png -------------------------------------------------------------------------------- /resources/validation_awgn_snr20_c0.09_e1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BohnSix/deepjscc/8c460a71670d8ce4bd2255b9d3494b8476bc2406/resources/validation_awgn_snr20_c0.09_e1000.png -------------------------------------------------------------------------------- /resources/validation_awgn_snr20_c0.17_e1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BohnSix/deepjscc/8c460a71670d8ce4bd2255b9d3494b8476bc2406/resources/validation_awgn_snr20_c0.17_e1000.png -------------------------------------------------------------------------------- /resources/validation_awgn_snr20_c0.25_e1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BohnSix/deepjscc/8c460a71670d8ce4bd2255b9d3494b8476bc2406/resources/validation_awgn_snr20_c0.25_e1000.png -------------------------------------------------------------------------------- /resources/validation_awgn_snr20_c0.33_e1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BohnSix/deepjscc/8c460a71670d8ce4bd2255b9d3494b8476bc2406/resources/validation_awgn_snr20_c0.33_e1000.png -------------------------------------------------------------------------------- /resources/validation_awgn_snr20_c0.42_e1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BohnSix/deepjscc/8c460a71670d8ce4bd2255b9d3494b8476bc2406/resources/validation_awgn_snr20_c0.42_e1000.png -------------------------------------------------------------------------------- /resources/validation_awgn_snr20_c0.49_e1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BohnSix/deepjscc/8c460a71670d8ce4bd2255b9d3494b8476bc2406/resources/validation_awgn_snr20_c0.49_e1000.png -------------------------------------------------------------------------------- /torch_impl.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 3 | import cv2 4 | import json 5 | import time 6 | import torch 7 | import pickle 8 | import datetime 9 | import matplotlib 10 | import torchvision 11 | import numpy as np 12 | from torch import nn 13 | from glob import glob 14 | from tqdm import tqdm 15 | from pprint import pprint 16 | from matplotlib import pyplot as plt 17 | import torchvision.transforms as transforms 18 | from torch.utils.tensorboard import SummaryWriter 19 | 20 | 21 | class Encoder(nn.Module): 22 | def __init__(self, conv_depth): 23 | super().__init__() 24 | self.conv_depth = conv_depth 25 | self.sublayers = nn.ModuleList([ 26 | nn.Conv2d(3, 16, 5, 2, 
2), 27 | nn.PReLU(), 28 | nn.Conv2d(16, 32, 5, 2, 2), 29 | nn.PReLU(), 30 | nn.Conv2d(32, 32, 5, 1, 2), 31 | nn.PReLU(), 32 | nn.Conv2d(32, 32, 5, 1, 2), 33 | nn.PReLU(), 34 | nn.Conv2d(32, conv_depth*2, 5, 1, 2), 35 | nn.PReLU(), 36 | ]) 37 | 38 | def forward(self, x): 39 | for layer in self.sublayers: 40 | x = layer(x) 41 | # return x.type(torch.complex64) 42 | return torch.complex(x[: , :self.conv_depth], x[: , self.conv_depth:]) 43 | 44 | class Decoder(nn.Module): 45 | def __init__(self, conv_depth): 46 | super().__init__() 47 | self.sublayers = nn.ModuleList([ 48 | nn.ConvTranspose2d(conv_depth*2, 32, 5, 1, 2), 49 | nn.PReLU(), 50 | nn.ConvTranspose2d(32, 32, 5, 1, 2), 51 | nn.PReLU(), 52 | nn.ConvTranspose2d(32, 32, 5, 1, 2), 53 | nn.PReLU(), 54 | nn.ConvTranspose2d(32, 16, 5, 2, 2, output_padding=1), 55 | nn.PReLU(), 56 | nn.ConvTranspose2d(16, 3, 5, 2, 2, output_padding=1), 57 | nn.Sigmoid(), 58 | ]) 59 | 60 | def forward(self, x): 61 | x = torch.concat([torch.real(x), torch.imag(x)], 1).float() 62 | for layer in self.sublayers: 63 | x = layer(x) 64 | return x 65 | 66 | class Channel(nn.Module): 67 | def __init__(self, channel_type, channel_snr): 68 | super().__init__() 69 | self.channel_type = channel_type 70 | self.channel_snr = channel_snr 71 | self.snr = 10**(self.channel_snr/10.0) 72 | 73 | def awgn(self, channel_input, stddev): 74 | cmplx_dist = np.random.normal(loc=0, scale=np.sqrt(2)/2, size=(2*len(channel_input))).view(np.complex128) 75 | cmplx_dist = torch.from_numpy(cmplx_dist).cuda() 76 | noise = cmplx_dist * stddev 77 | return channel_input + noise, torch.ones_like(channel_input) 78 | 79 | def fading(self, x, stddev, h=None): 80 | z = torch.real(x) 81 | z_dim = len(z) // 2 82 | z_in = torch.complex(z[:z_dim], z[z_dim:]) 83 | 84 | if h is None: 85 | h = torch.complex(torch.from_numpy(np.random.normal(0, np.sqrt(2)/2, z_in.shape)), 86 | torch.from_numpy(np.random.normal(0, np.sqrt(2)/2, z_in.shape))) 87 | noise = torch.complex(torch.from_numpy(np.random.normal(0, np.sqrt(2)/2, z_in.shape)), 88 | torch.from_numpy(np.random.normal(0, np.sqrt(2)/2, z_in.shape))) 89 | h, noise = h.cuda(), noise.cuda() 90 | z_out = h * z_in + noise * stddev 91 | 92 | z_out = torch.concat([torch.real(z_out), torch.imag(z_out)], 0) 93 | 94 | return z_out, h 95 | 96 | def forward(self, channel_input): 97 | # print("channel_snr: {}".format(self.channel_snr)) 98 | 99 | signl_pwr = torch.mean(torch.square(torch.abs(channel_input))) 100 | noise_pwr = signl_pwr / self.snr 101 | noise_stddev = torch.sqrt(noise_pwr) 102 | 103 | if self.channel_type == "awgn": 104 | channal_output, H = self.awgn(channel_input, noise_stddev) 105 | elif self.channel_type == "fading": 106 | channal_output, H = self.fading(channel_input, noise_stddev) 107 | 108 | return channal_output, H 109 | 110 | class JSCC(nn.Module): 111 | def __init__(self, conv_depth, snr_db=10): 112 | super().__init__() 113 | self.encoder = Encoder(conv_depth) 114 | self.channel = Channel("awgn", snr_db) 115 | self.decoder = Decoder(conv_depth) 116 | 117 | def powerConstraint(self, channel_input, P): 118 | # norm by total power instead of average power 119 | enery = torch.sum(torch.square(torch.abs(channel_input))) 120 | normalization_factor = np.sqrt(len(channel_input)*P) / torch.sqrt(enery) 121 | channel_input = channel_input * normalization_factor 122 | # the average power of output should be about P 123 | 124 | return channel_input 125 | 126 | def forward(self, inputs, snr_db=10, P=1): 127 | prev_chn_gain = None 128 | chn_in = 
self.encoder(inputs) 129 | lst = list(chn_in.shape) 130 | 131 | chn_in = chn_in.flatten() 132 | chn_in = self.powerConstraint(chn_in, P) 133 | # print(torch.mean(torch.square(torch.abs(chn_in)))) 134 | 135 | chn_out, h = self.channel(chn_in) 136 | 137 | chn_out = chn_out.reshape(lst) 138 | 139 | decoded_img = self.decoder(chn_out) 140 | 141 | return decoded_img, chn_out 142 | 143 | def Calculate_filters(comp_ratio, F=8, n=3072): 144 | K = (comp_ratio*n)/F**2 145 | return round(K) 146 | 147 | # ############################################################### 148 | # compression_ratios = [0.04, 0.09, 0.17, 0.25, 0.33, 0.42, 0.49] 149 | # filter_size = [] 150 | # for comp_ratio in compression_ratios: 151 | # K = Calculate_filters(comp_ratio) 152 | # filter_size.append(K) 153 | 154 | # print(filter_size) # [2, 4, 8, 12, 16, 20, 24] 155 | # ############################################################### 156 | 157 | SNR = 20 158 | CHANNEL_TYPE = "awgn" 159 | COMPRESSION_RATIO = 0.04 160 | 161 | """ 162 | rm checkpoints/* 163 | rm -r train_logs/* 164 | rm validation_imgs/* 165 | 166 | nohup python -u torch_impl.py > train_logs/awgn_snr20_c04.log 2>&1 & 167 | """ 168 | 169 | EPOCHS = 1000 170 | NUM_WORKERS = 4 171 | LEARNING_RATE = 0.001 172 | CHANNEL_SNR_TRAIN = 10 173 | TRAIN_IMAGE_NUM = 50000 174 | TEST_IMAGE_NUM = 10000 175 | TRAIN_BS = 64 176 | TEST_BS = 4096 177 | K = Calculate_filters(COMPRESSION_RATIO) 178 | 179 | transform = transforms.Compose([transforms.ToTensor()]) 180 | 181 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) 182 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform) 183 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=TRAIN_BS, shuffle=True, num_workers=NUM_WORKERS) 184 | testloader = torch.utils.data.DataLoader(testset, batch_size=TEST_BS, shuffle=False, num_workers=NUM_WORKERS) 185 | 186 | model = JSCC(K, snr_db=SNR).cuda() 187 | 188 | # model.load_state_dict(torch.load("/media/bohnsix/djscc/checkpoints/jscc_model_17")) 189 | optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) 190 | loss_fn = nn.MSELoss() 191 | 192 | def train_one_epoch(epoch_index): 193 | running_loss = 0. 194 | 195 | for i, data in enumerate(trainloader): 196 | inputs, labels = data 197 | b, _, _, _ = inputs.shape 198 | inputs = inputs.cuda() 199 | optimizer.zero_grad() 200 | decoded_img, chn_out = model(inputs) 201 | 202 | loss = loss_fn(decoded_img, inputs) 203 | loss.backward() 204 | optimizer.step() 205 | 206 | running_loss += loss.item() * b 207 | 208 | 209 | return running_loss / TRAIN_IMAGE_NUM 210 | 211 | writer = SummaryWriter(f'train_logs/deepjscc_{CHANNEL_TYPE}_snr{SNR:02d}_c{COMPRESSION_RATIO}') 212 | 213 | best_vloss = 1. 
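# Hypothetical sanity check (a sketch; not called anywhere in this script):
# after powerConstraint, the mean per-symbol power of the flattened channel
# input should sit near P, which is what Channel assumes when scaling noise.
def _check_avg_power(jscc_model, images, P=1):
    with torch.no_grad():
        z = jscc_model.encoder(images.cuda()).flatten()
        z = jscc_model.powerConstraint(z, P)
        return torch.mean(torch.square(torch.abs(z))).item()  # should be ~= P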
214 | change_lr_flag = True
215 | 
216 | print(f"""Training on CHANNEL {CHANNEL_TYPE}, Compression Ratio {COMPRESSION_RATIO} and SNR {SNR:02d} dB at {datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}.\n\n\n""")
217 | 
218 | for epoch in range(1, EPOCHS+1):
219 |     if epoch > 640 and change_lr_flag:
220 |         LEARNING_RATE = LEARNING_RATE / 10
221 |         optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
222 |         change_lr_flag = False
223 |         print(f"Update LR to {LEARNING_RATE}\n")
224 | 
225 |     cur = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
226 |     print(f'EPOCH {epoch:03d} starts at {cur}')
227 | 
228 |     model.train(True)
229 |     avg_loss = train_one_epoch(epoch)
230 | 
231 |     running_vloss = 0.0
232 |     model.eval()
233 | 
234 |     with torch.no_grad():
235 |         if epoch > 640:
236 |             val_times = 10
237 |         else:
238 |             val_times = 1
239 |         for _ in range(val_times):
240 |             for i, vdata in enumerate(testloader):
241 |                 vinputs, vlabels = vdata
242 |                 b, _, _, _ = vinputs.shape
243 | 
244 |                 vinputs = vinputs.cuda()
245 |                 decoded_img, chn_out = model(vinputs)
246 |                 vloss = loss_fn(decoded_img, vinputs)
247 | 
248 |                 running_vloss += vloss * b
249 | 
250 |         orig_grid = vinputs[:128].detach().cpu().numpy().reshape(16, 8, 3, 32, 32).transpose(0, 1, 3, 4, 2)
251 |         recon_grid = decoded_img[:128].detach().cpu().numpy().reshape(16, 8, 3, 32, 32).transpose(0, 1, 3, 4, 2)
252 |         canvas = (np.hstack(np.hstack(np.concatenate([orig_grid, recon_grid], 3)))[..., ::-1] * 255).astype(np.uint8)  # RGB -> BGR for cv2
253 |         cv2.imwrite(f"validation_imgs/validation_{CHANNEL_TYPE}_snr{SNR:02d}_c{COMPRESSION_RATIO}_e{epoch:04d}.png", canvas)
254 | 
255 |     avg_vloss = running_vloss / TEST_IMAGE_NUM / val_times
256 |     print(f'LOSS train {avg_loss:.8f} valid {avg_vloss:.8f}')
257 |     print(f'LOSS valid PSNR {10 * np.log10(1/avg_vloss.item()):.2f} dB. \n')
258 | 
259 |     writer.add_scalars('Training vs. Validation Loss',
260 |                        { 'Training' : avg_loss, 'Validation' : avg_vloss },
261 |                        epoch)
262 |     writer.flush()
263 | 
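    # Sketch of the conversion used in the print above: for pixels in [0, 1],
    # PSNR = 10 * log10(MAX_I**2 / MSE) with MAX_I = 1; e.g. an MSE of 1e-3
    # corresponds to 10 * np.log10(1 / 1e-3) = 30.0 dB.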
264 |     # Track best performance, and save the model's state
265 |     if avg_vloss < best_vloss:
266 |         best_vloss = avg_vloss
267 |         model_path = f'checkpoints/deepjscc_{CHANNEL_TYPE}_snr{SNR:02d}_c{COMPRESSION_RATIO}_e{epoch:03d}.ckpt'
268 |         torch.save(model.state_dict(), model_path)
269 | 
--------------------------------------------------------------------------------
/visualization.md:
--------------------------------------------------------------------------------
1 | AWGN 00dB 0.04 15.07dB
2 | ![Validation result](resources/validation_awgn_snr00_c0.04_e1000.png)
3 | 
4 | AWGN 00dB 0.09 16.99dB
5 | ![Validation result](resources/validation_awgn_snr00_c0.09_e1000.png)
6 | 
7 | AWGN 00dB 0.17 19.0dB
8 | ![Validation result](resources/validation_awgn_snr00_c0.17_e1000.png)
9 | 
10 | AWGN 00dB 0.25 20.18dB
11 | ![Validation result](resources/validation_awgn_snr00_c0.25_e1000.png)
12 | 
13 | AWGN 00dB 0.33 21.08dB
14 | ![Validation result](resources/validation_awgn_snr00_c0.33_e1000.png)
15 | 
16 | AWGN 00dB 0.42 21.74dB
17 | ![Validation result](resources/validation_awgn_snr00_c0.42_e1000.png)
18 | 
19 | AWGN 00dB 0.49 22.29dB
20 | ![Validation result](resources/validation_awgn_snr00_c0.49_e1000.png)
21 | 
22 | AWGN 10dB 0.04 19.36dB
23 | ![Validation result](resources/validation_awgn_snr10_c0.04_e1000.png)
24 | 
25 | AWGN 10dB 0.09 22.29dB
26 | ![Validation result](resources/validation_awgn_snr10_c0.09_e1000.png)
27 | 
28 | AWGN 10dB 0.17 25.23dB
29 | ![Validation result](resources/validation_awgn_snr10_c0.17_e1000.png)
30 | 
31 | AWGN 10dB 0.25 26.99dB
32 | ![Validation result](resources/validation_awgn_snr10_c0.25_e1000.png)
33 | 
34 | AWGN 10dB 0.33 27.21dB
35 | ![Validation result](resources/validation_awgn_snr10_c0.33_e1000.png)
36 | 
37 | AWGN 10dB 0.42 28.24dB
38 | ![Validation result](resources/validation_awgn_snr10_c0.42_e1000.png)
39 | 
40 | AWGN 10dB 0.49 28.54dB
41 | ![Validation result](resources/validation_awgn_snr10_c0.49_e1000.png)
42 | 
43 | AWGN 20dB 0.04 21.43dB
44 | ![Validation result](resources/validation_awgn_snr20_c0.04_e1000.png)
45 | 
46 | AWGN 20dB 0.09 25.09dB
47 | 
48 | ![Validation result](resources/validation_awgn_snr20_c0.09_e1000.png)
49 | 
50 | AWGN 20dB 0.17 29.59dB
51 | ![Validation result](resources/validation_awgn_snr20_c0.17_e1000.png)
52 | 
53 | AWGN 20dB 0.25 32.22dB
54 | ![Validation result](resources/validation_awgn_snr20_c0.25_e1000.png)
55 | 
56 | AWGN 20dB 0.33 33.01dB
57 | ![Validation result](resources/validation_awgn_snr20_c0.33_e1000.png)
58 | 
59 | AWGN 20dB 0.42 33.01dB
60 | ![Validation result](resources/validation_awgn_snr20_c0.42_e1000.png)
61 | 
62 | AWGN 20dB 0.49 33.01dB
63 | ![Validation result](resources/validation_awgn_snr20_c0.49_e1000.png)
--------------------------------------------------------------------------------
/wideresnet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
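# WideResNet in PyTorch, after Zagoruyko & Komodakis, "Wide Residual
# Networks" (arXiv:1605.07146). The defaults below (depth=34, widen_factor=10)
# correspond to the WRN-34-10 configuration commonly used on CIFAR-10.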
3 | 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | class BasicBlock(nn.Module): 9 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0, activation='ReLU', softplus_beta=1): 10 | super(BasicBlock, self).__init__() 11 | self.bn1 = nn.BatchNorm2d(in_planes) 12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(out_planes) 15 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 16 | padding=1, bias=False) 17 | if activation == 'ReLU': 18 | self.relu1 = nn.ReLU(inplace=True) 19 | self.relu2 = nn.ReLU(inplace=True) 20 | print('R') 21 | elif activation == 'Softplus': 22 | self.relu1 = nn.Softplus(beta=softplus_beta, threshold=20) 23 | self.relu2 = nn.Softplus(beta=softplus_beta, threshold=20) 24 | print('S') 25 | elif activation == 'GELU': 26 | self.relu1 = nn.GELU() 27 | self.relu2 = nn.GELU() 28 | print('G') 29 | elif activation == 'ELU': 30 | self.relu1 = nn.ELU(alpha=1.0, inplace=True) 31 | self.relu2 = nn.ELU(alpha=1.0, inplace=True) 32 | print('E') 33 | 34 | self.droprate = dropRate 35 | self.equalInOut = (in_planes == out_planes) 36 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 37 | padding=0, bias=False) or None 38 | 39 | def forward(self, x): 40 | if not self.equalInOut: 41 | x = self.relu1(self.bn1(x)) 42 | else: 43 | out = self.relu1(self.bn1(x)) 44 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 45 | if self.droprate > 0: 46 | out = F.dropout(out, p=self.droprate, training=self.training) 47 | out = self.conv2(out) 48 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 49 | 50 | 51 | class NetworkBlock(nn.Module): 52 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0, activation='ReLU', softplus_beta=1): 53 | super(NetworkBlock, self).__init__() 54 | self.activation = activation 55 | self.softplus_beta = softplus_beta 56 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 57 | 58 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 59 | layers = [] 60 | for i in range(int(nb_layers)): 61 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate, 62 | self.activation, self.softplus_beta)) 63 | return nn.Sequential(*layers) 64 | 65 | def forward(self, x): 66 | return self.layer(x) 67 | 68 | 69 | class WideResNet(nn.Module): 70 | def __init__(self, image_size=32, depth=34, num_classes=10, widen_factor=10, dropRate=0.0, normalize=False, activation='ReLU', softplus_beta=1): 71 | super(WideResNet, self).__init__() 72 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 73 | assert ((depth - 4) % 6 == 0) 74 | n = (depth - 4) / 6 75 | block = BasicBlock 76 | self.normalize = normalize 77 | #self.scale = scale 78 | # 1st conv before any network block 79 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 80 | padding=1, bias=False) 81 | # 1st block 82 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate, activation=activation, softplus_beta=softplus_beta) 83 | # 1st sub-block 84 | self.sub_block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate, activation=activation, softplus_beta=softplus_beta) 85 | # 2nd block 86 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate, 
activation=activation, softplus_beta=softplus_beta) 87 | # 3rd block 88 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate, activation=activation, softplus_beta=softplus_beta) 89 | # global average pooling and classifier 90 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 91 | 92 | if activation == 'ReLU': 93 | self.relu = nn.ReLU(inplace=True) 94 | elif activation == 'Softplus': 95 | self.relu = nn.Softplus(beta=softplus_beta, threshold=20) 96 | elif activation == 'GELU': 97 | self.relu = nn.GELU() 98 | elif activation == 'ELU': 99 | self.relu = nn.ELU(alpha=1.0, inplace=True) 100 | print('Use activation of ' + activation) 101 | 102 | if self.normalize: 103 | self.fc = nn.Linear(nChannels[3] * (image_size // 32) ** 2, num_classes, bias = False) 104 | else: 105 | self.fc = nn.Linear(nChannels[3] * (image_size // 32) ** 2, num_classes) 106 | self.nChannels = nChannels[3] * (image_size // 32) ** 2 107 | 108 | 109 | for m in self.modules(): 110 | if isinstance(m, nn.Conv2d): 111 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 112 | m.weight.data.normal_(0, math.sqrt(2. / n)) 113 | elif isinstance(m, nn.BatchNorm2d): 114 | m.weight.data.fill_(1) 115 | m.bias.data.zero_() 116 | elif isinstance(m, nn.Linear) and not self.normalize: 117 | m.bias.data.zero_() 118 | 119 | def forward(self, x, prejection=False): 120 | out = self.conv1(x) 121 | out = self.block1(out) 122 | out = self.block2(out) 123 | out = self.block3(out) 124 | out = self.relu(self.bn1(out)) 125 | out = F.avg_pool2d(out, 8) 126 | out = out.view(-1, self.nChannels) 127 | if self.normalize: 128 | out = F.normalize(out, p=2, dim=1) 129 | for _, module in self.fc.named_modules(): 130 | if isinstance(module, nn.Linear): 131 | module.weight.data = F.normalize(module.weight, p=2, dim=1) 132 | if prejection == True: 133 | return self.fc(out), out 134 | else: 135 | return self.fc(out) 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | class WideResNet_FS(nn.Module): 145 | def __init__(self, image_size=32, depth=34, num_classes=10, widen_factor=10, dropRate=0.0, normalize=False, activation='ReLU', softplus_beta=1): 146 | super(WideResNet_FS, self).__init__() 147 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 148 | assert ((depth - 4) % 6 == 0) 149 | n = (depth - 4) / 6 150 | block = BasicBlock 151 | self.normalize = normalize 152 | #self.scale = scale 153 | # 1st conv before any network block 154 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 155 | padding=1, bias=False) 156 | # 1st block 157 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate, activation=activation, softplus_beta=softplus_beta) 158 | # 1st sub-block 159 | self.sub_block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate, activation=activation, softplus_beta=softplus_beta) 160 | # 2nd block 161 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate, activation=activation, softplus_beta=softplus_beta) 162 | # 3rd block 163 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate, activation=activation, softplus_beta=softplus_beta) 164 | # global average pooling and classifier 165 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 166 | 167 | if activation == 'ReLU': 168 | self.relu = nn.ReLU(inplace=True) 169 | elif activation == 'Softplus': 170 | self.relu = nn.Softplus(beta=softplus_beta, threshold=20) 171 | elif activation == 'GELU': 172 | self.relu = nn.GELU() 173 | elif activation == 'ELU': 174 | self.relu = 
nn.ELU(alpha=1.0, inplace=True) 175 | print('Use activation of ' + activation) 176 | 177 | if self.normalize: 178 | self.fc = nn.Linear(nChannels[3] * (image_size // 32) ** 2, num_classes, bias = False) 179 | else: 180 | self.fc = nn.Linear(nChannels[3] * (image_size // 32) ** 2, num_classes) 181 | self.nChannels = nChannels[3] * (image_size // 32) ** 2 182 | 183 | for m in self.modules(): 184 | if isinstance(m, nn.Conv2d): 185 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 186 | m.weight.data.normal_(0, math.sqrt(2. / n)) 187 | elif isinstance(m, nn.BatchNorm2d): 188 | m.weight.data.fill_(1) 189 | m.bias.data.zero_() 190 | elif isinstance(m, nn.Linear) and not self.normalize: 191 | m.bias.data.zero_() 192 | 193 | self.fc1 = nn.Linear(self.nChannels, self.nChannels, bias=False) 194 | self.fc2 = nn.Linear(self.nChannels, self.nChannels, bias=False) 195 | self.fc1.weight = torch.nn.parameter.Parameter(torch.eye(self.nChannels)) 196 | self.fc2.weight = torch.nn.parameter.Parameter(torch.eye(self.nChannels)) 197 | 198 | def forward(self, x, logit=False, prejection=False): 199 | out = self.conv1(x) 200 | out = self.block1(out) 201 | out = self.block2(out) 202 | out = self.block3(out) 203 | out = self.relu(self.bn1(out)) 204 | out = F.avg_pool2d(out, 8) 205 | # out = out.view(-1, self.nChannels) 206 | # if self.normalize: 207 | # out = F.normalize(out, p=2, dim=1) 208 | # for _, module in self.fc.named_modules(): 209 | # if isinstance(module, nn.Linear): 210 | # module.weight.data = F.normalize(module.weight, p=2, dim=1) 211 | # return self.fc(out) 212 | out_1 = out.view(-1, self.nChannels) 213 | out = self.fc1(out_1) 214 | if self.training or logit: 215 | logit_1 = self.fc(out) 216 | out2 = self.fc2(out_1.detach()) # 217 | logit_2 = self.fc(out2) 218 | out_2 = (logit_1, logit_2) 219 | out = (out, out2) 220 | else: 221 | out_2 = self.fc(out) 222 | if prejection == True: 223 | return out_2, out 224 | else: 225 | return out_2 226 | 227 | 228 | 229 | 230 | --------------------------------------------------------------------------------