├── ADJSCC_test.py ├── ADJSCC_train.py ├── DeepJSCC_V_test.py ├── DeepJSCC_V_train_CIFAR10.py ├── DeepJSCC_V_train_ImageNet.py ├── GDN.py ├── OracleNet.py ├── OracleNet_test_Kodak.py ├── OracleNet_test_Kodak_image_level.py ├── Oracle_test.py ├── Oracle_train_CIFAR10.py ├── Oracle_train_ImageNet.py ├── README.md ├── data_loader.py ├── models.py └── utils.py /ADJSCC_test.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import os 6 | 7 | from utils import * 8 | from models import * 9 | 10 | 11 | BATCH_SIZE = 256 12 | EPOCHS = 150 13 | LEARNING_RATE = 1e-3 14 | PRINT_RREQ = 250 15 | 16 | CHANNEL = 'Fading' # Choose AWGN or Fading 17 | # if CHANNEL == 'AWGN': 18 | # # CR_INDEX = torch.Tensor([ 6, 3, 2, 1]).int() 19 | # # CR_INDEX = torch.Tensor([10, 9, 8, 7, 6, 5, 4, 3, 2, 1]).int() 20 | # CR_INDEX = [6, 3, 2] 21 | # elif CHANNEL == 'Fading': 22 | # # CR_INDEX = [3, 3/2] 23 | # # CR_INDEX = torch.Tensor([3, 3/2]).int() 24 | # CR_INDEX = torch.Tensor([10, 9, 8, 7, 6, 5, 4, 3, 2, 1]).int() 25 | 26 | IMG_SIZE = [3, 32, 32] 27 | N_channels = 256 28 | kernel_sz = 5 29 | 30 | enc_shape = [32, 8, 8] 31 | CR = 96//enc_shape[0] # The real compression ratio R = 1/CR 32 | 33 | _, x_test = Load_cifar100_data() 34 | test_dataset = DatasetFolder(x_test) 35 | test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True) 36 | 37 | KSZ = '_'+str(kernel_sz)+'x'+str(kernel_sz)+'_' 38 | PSNR_ave = np.zeros((10, 10)) 39 | if __name__ == '__main__': 40 | for m in range(0, 10): 41 | # enc_shape = [96//CR_INDEX[m], 8, 8] 42 | DeepJSCC = ADJSCC(enc_shape, kernel_sz, N_channels).cuda() 43 | # DeepJSCC = nn.DataParallel(DeepJSCC) 44 | 45 | DeepJSCC.load_state_dict(torch.load('./JSCC_models/DeepJSCC'+KSZ+CHANNEL+'_'+str(CR)+'_'+str(N_channels)+'.pth.tar')['state_dict']) 46 | 47 | for k in range(0, 10): 48 | print('Evaluating DeepJSCC with CR = '+str(CR)+' and SNR = '+str(3*k-3)+'dB') 49 | total_psnr = 0 50 | DeepJSCC.eval() 51 | with torch.no_grad(): 52 | for i, test_input in enumerate(test_loader): 53 | SNR = 3*(k-1)*torch.ones((test_input.shape[0], 1)) 54 | test_input = test_input.cuda() 55 | 56 | test_rec = DeepJSCC(test_input, SNR, CHANNEL) 57 | 58 | test_input = Img_transform(test_input) 59 | test_rec = Img_transform(test_rec) 60 | psnr_ave = Compute_batch_PSNR(test_input, test_rec) 61 | total_psnr += psnr_ave 62 | averagePSNR = total_psnr / len(test_loader) # average over the number of batches, not the last loop index 63 | print('PSNR = ' + str(averagePSNR)) 64 | 65 | PSNR_ave[m, k] = averagePSNR 66 | 67 | 68 | 69 | 70 | 71 | 72 | --------------------------------------------------------------------------------
/ADJSCC_train.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | import torch.nn as nn 5 | import os 6 | 7 | from utils import * 8 | from models import * 9 | 10 | 11 | BATCH_SIZE = 128 12 | EPOCHS = 200 13 | LEARNING_RATE = 1e-4 14 | PRINT_RREQ = 250 15 | 16 | 17 | CHANNEL = 'AWGN' # Choose AWGN or Fading 18 | IMG_SIZE = [3, 32, 32] 19 | N_channels = 256 20 | Kernel_sz = 5 21 | 22 | # Parameter enc_out_shape[0] specifies the compression ratio 23 | enc_out_shape = [32, IMG_SIZE[1]//4, IMG_SIZE[2]//4] 24 | 25 | CR = 96//enc_out_shape[0] 26 | 27 | x_train, x_test = Load_cifar10_data() 28 | train_dataset = DatasetFolder(x_train) 29 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0,
pin_memory=True) 30 | test_dataset = DatasetFolder(x_test) 31 | test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True) 32 | 33 | current_epoch = 0 34 | CONTINUE_TRAINING = False 35 | 36 | KSZ = str(Kernel_sz)+'x'+str(Kernel_sz)+'_' 37 | if __name__ == '__main__': 38 | 39 | DeepJSCC = ADJSCC(enc_out_shape, Kernel_sz, N_channels).cuda() 40 | # DeepJSCC = nn.DataParallel(DeepJSCC) 41 | 42 | criterion = nn.MSELoss().cuda() 43 | optimizer = torch.optim.Adam(DeepJSCC.parameters(), lr=LEARNING_RATE) 44 | 45 | bestLoss = 1e3 46 | if CONTINUE_TRAINING == True: 47 | DeepJSCC.load_state_dict(torch.load('./JSCC_models/DeepJSCC_'+KSZ+CHANNEL+'_'+str(CR)+'_'+str(N_channels)+'.pth.tar')['state_dict']) 48 | current_epoch = 0 49 | bestLoss = 1 50 | 51 | 52 | for epoch in range(current_epoch, EPOCHS): 53 | DeepJSCC.train() 54 | print('========================') 55 | print('lr:%.4e'%optimizer.param_groups[0]['lr']) 56 | 57 | # Model training 58 | for i, x_input in enumerate(train_loader): 59 | x_input = x_input.cuda() 60 | 61 | SNR_TRAIN = torch.randint(0, 28, (x_input.shape[0], 1)).cuda() 62 | x_rec = DeepJSCC(x_input, SNR_TRAIN, CHANNEL) 63 | loss = criterion(x_input, x_rec) 64 | loss = loss.mean() 65 | 66 | optimizer.zero_grad() 67 | loss.backward() 68 | optimizer.step() 69 | if i % PRINT_RREQ == 0: 70 | print('Epoch: [{0}][{1}/{2}]\t' 'Loss {loss:.4f}\t'.format(epoch, i, len(train_loader), loss=loss.item())) 71 | 72 | # Model Evaluation 73 | DeepJSCC.eval() 74 | totalLoss = 0 75 | with torch.no_grad(): 76 | for i, test_input in enumerate(test_loader): 77 | test_input = test_input.cuda() 78 | SNR_TEST = torch.randint(0, 28, (test_input.shape[0], 1)).cuda() 79 | test_rec = DeepJSCC(test_input, SNR_TEST, CHANNEL) 80 | totalLoss += criterion(test_rec, test_input).item() * test_input.size(0) 81 | averageLoss = totalLoss / (len(test_dataset)) 82 | print('averageLoss=', averageLoss) 83 | if averageLoss < bestLoss: 84 | # Model saving 85 | if not os.path.exists('./JSCC_models'): 86 | os.makedirs('./JSCC_models') 87 | torch.save({'state_dict': DeepJSCC.state_dict(), }, './JSCC_models/DeepJSCC_'+KSZ+CHANNEL+'_'+str(CR)+'_'+str(N_channels)+'.pth.tar') 88 | print('Model saved') 89 | bestLoss = averageLoss 90 | 91 | print('Training for DeepJSCC_'+str(CR)+' is finished!') 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | -------------------------------------------------------------------------------- /DeepJSCC_V_test.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import os 6 | 7 | from utils import * 8 | from models import * 9 | 10 | from data_loader import train_data_loader, test_data_loader 11 | 12 | BATCH_SIZE = 256 13 | EPOCHS = 150 14 | LEARNING_RATE = 1e-3 15 | PRINT_RREQ = 250 16 | 17 | CHANNEL = 'Fading' # Choose AWGN or Fading 18 | if CHANNEL == 'AWGN': 19 | # CR_INDEX = torch.Tensor([ 6, 3, 2, 1]).int() 20 | # CR_INDEX = torch.Tensor([10, 9, 8, 7, 6, 5, 4, 3, 2, 1]).int() 21 | CR_INDEX = torch.Tensor([6, 3, 2]).int() 22 | elif CHANNEL == 'Fading': 23 | # CR_INDEX = [3, 3/2] 24 | # CR_INDEX = torch.Tensor([3, 3/2]).int() 25 | CR_INDEX = torch.Tensor([10, 9, 8, 7, 6, 5, 4, 3, 2, 1]).int() 26 | # CR_INDEX = torch.Tensor([2, 1]).int() 27 | 28 | IMG_SIZE = [32, 32, 32] 29 | N_channels = 256 30 
| kernel_sz = 5 31 | KSZ = str(kernel_sz)+'x'+str(kernel_sz)+'_' 32 | 33 | 34 | _, x_test = Load_cifar100_data() 35 | test_dataset = DatasetFolder(x_test) 36 | test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True) 37 | 38 | 39 | IMGZ = 32 40 | c_max = 48 41 | enc_shape = [c_max, IMGZ//4, IMGZ//4] 42 | 43 | PSNR_ave = np.zeros((10, 10)) 44 | if __name__ == '__main__': 45 | for m in range(0, 10): 46 | # cr = 1/CR_INDEX[m] 47 | cr = 2/3 # a fixed bandwidth ratio; uncomment the line above to sweep over CR_INDEX 48 | DeepJSCC_V = ADJSCC_V(enc_shape, kernel_sz, N_channels).cuda() 49 | # DeepJSCC_V = nn.DataParallel(DeepJSCC_V) 50 | 51 | DeepJSCC_V.load_state_dict(torch.load('./JSCC_models/DeepJSCC_VLC_'+KSZ+CHANNEL+'_'+str(N_channels)+'_20_cifar10.pth.tar')['state_dict']) 52 | 53 | for k in range(0, 10): 54 | print('Evaluating DeepJSCC_VLC with CR = '+str(cr)+' and SNR = '+str(3*k-3)+'dB') # report the CR actually used 55 | total_psnr = 0 56 | DeepJSCC_V.eval() 57 | with torch.no_grad(): 58 | # for i, (test_input,_) in enumerate(test_loader): 59 | for i, test_input in enumerate(test_loader): 60 | SNR = 3*(k-1)*torch.ones((test_input.shape[0], 1)) 61 | CR = cr*torch.ones((test_input.shape[0], 1)) 62 | SNR = SNR.cuda() 63 | CR = CR.cuda() 64 | test_input = test_input.cuda() 65 | 66 | test_rec = DeepJSCC_V(test_input, SNR, CR, CHANNEL) 67 | 68 | test_input = Img_transform(test_input) 69 | test_rec = Img_transform(test_rec) 70 | psnr_ave = Compute_batch_PSNR(test_input, test_rec) 71 | total_psnr += psnr_ave 72 | averagePSNR = total_psnr / len(test_loader) # average over the number of batches, not the last loop index 73 | print('PSNR = ' + str(averagePSNR)) 74 | 75 | PSNR_ave[m, k] = averagePSNR 76 | 77 | 78 | 79 | 80 | 81 | 82 | --------------------------------------------------------------------------------
/DeepJSCC_V_train_CIFAR10.py: -------------------------------------------------------------------------------- 1 | 2 | # import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import os 6 | 7 | from utils import Load_cifar10_data, DatasetFolder 8 | from models import ADJSCC_V 9 | 10 | 11 | BATCH_SIZE = 128 12 | EPOCHS = 400 13 | LEARNING_RATE = 1e-4 14 | PRINT_RREQ = 150 15 | 16 | CHANNEL = 'AWGN' # Choose AWGN or Fading 17 | IMG_SIZE = [3, 32, 32] 18 | N_channels = 256 19 | Kernel_sz = 5 20 | 21 | x_train, x_test = Load_cifar10_data() 22 | 23 | train_dataset = DatasetFolder(x_train) 24 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True) 25 | test_dataset = DatasetFolder(x_test) 26 | test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True) 27 | 28 | current_epoch = 0 29 | CONTINUE_TRAINING = False 30 | 31 | enc_out_shape = [48, IMG_SIZE[1]//4, IMG_SIZE[2]//4] 32 | KSZ = str(Kernel_sz)+'x'+str(Kernel_sz)+'_' 33 | if __name__ == '__main__': 34 | 35 | DeepJSCC_V = ADJSCC_V(enc_out_shape, Kernel_sz, N_channels).cuda() 36 | # DeepJSCC_V = nn.DataParallel(DeepJSCC_V) 37 | 38 | criterion = nn.MSELoss().cuda() 39 | optimizer = torch.optim.Adam(DeepJSCC_V.parameters(), lr=LEARNING_RATE) 40 | 41 | bestLoss = 1e3 42 | if CONTINUE_TRAINING == True: 43 | DeepJSCC_V.load_state_dict(torch.load('./JSCC_models/DeepJSCC_VLC_'+KSZ+CHANNEL+'_'+str(N_channels)+'_20.pth.tar')['state_dict']) 44 | current_epoch = 204 45 | 46 | # bestLoss = 1e3 47 | for epoch in range(current_epoch, EPOCHS): 48 | DeepJSCC_V.train() 49 | print('========================') 50 | print('lr:%.4e'%optimizer.param_groups[0]['lr']) 51 | 52 | # Model training 53 | for i, x_input in
enumerate(train_loader): 54 | x_input = x_input.cuda() 55 | 56 | SNR_TRAIN = torch.randint(0, 28, (x_input.shape[0], 1)).cuda() 57 | CR = 0.1+0.9*torch.rand(x_input.shape[0], 1).cuda() 58 | x_rec = DeepJSCC_V(x_input, SNR_TRAIN, CR, CHANNEL) 59 | 60 | loss = criterion(x_input, x_rec) 61 | loss = loss.mean() 62 | 63 | optimizer.zero_grad() 64 | loss.backward() 65 | optimizer.step() 66 | 67 | if i % PRINT_RREQ == 0: 68 | print('Epoch: [{0}][{1}/{2}]\t' 'Loss {loss:.4f}\t'.format(epoch, i, len(train_loader), loss=loss.item())) 69 | 70 | 71 | # Model Evaluation 72 | DeepJSCC_V.eval() 73 | totalLoss = 0 74 | with torch.no_grad(): 75 | for i, test_input in enumerate(test_loader): 76 | test_input = test_input.cuda() 77 | SNR_TEST = torch.randint(0, 28, (test_input.shape[0], 1)).cuda() 78 | CR = 0.1+0.9*torch.rand(test_input.shape[0], 1).cuda() 79 | test_rec = DeepJSCC_V(test_input, SNR_TEST, CR, CHANNEL) 80 | 81 | totalLoss += criterion(test_input, test_rec).item() * test_input.size(0) 82 | averageLoss = totalLoss / (len(test_dataset)) 83 | print('averageLoss=', averageLoss) 84 | if averageLoss < bestLoss: 85 | # Model saving 86 | if not os.path.exists('./JSCC_models'): 87 | os.makedirs('./JSCC_models') 88 | torch.save({'state_dict': DeepJSCC_V.state_dict(), }, './JSCC_models/DeepJSCC_VLC_'+KSZ+CHANNEL+'_'+str(N_channels)+'_20.pth.tar') 89 | print('Model saved') 90 | bestLoss = averageLoss 91 | 92 | print('Training for DeepJSCC_V is finished!') 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | -------------------------------------------------------------------------------- /DeepJSCC_V_train_ImageNet.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | import torch.nn as nn 5 | import os 6 | 7 | from models import ADJSCC_V 8 | 9 | from data_loader import train_data_loader, test_data_loader 10 | 11 | 12 | BATCH_SIZE = 32 13 | EPOCHS = 50 14 | LEARNING_RATE = 1e-4 15 | PRINT_RREQ = 50 16 | SAVE_RREQ = 500 17 | 18 | CHANNEL = 'AWGN' # Choose AWGN or Fading 19 | N_channels = 256 20 | Kernel_sz = 5 21 | 22 | IMGZ = 128 23 | train_loader = train_data_loader(batch_size = BATCH_SIZE, imgz = IMGZ, workers = 2) 24 | test_loader = test_data_loader(batch_size = BATCH_SIZE, imgz = IMGZ, workers = 2) 25 | 26 | current_epoch = 0 27 | CONTINUE_TRAINING = True 28 | LOAD_PRETRAIN = False 29 | 30 | enc_out_shape = [48, IMGZ//4, IMGZ//4] 31 | KSZ = str(Kernel_sz)+'x'+str(Kernel_sz)+'_' 32 | if __name__ == '__main__': 33 | 34 | DeepJSCC_V = ADJSCC_V(enc_out_shape, Kernel_sz, N_channels).cuda() 35 | # DeepJSCC_V = nn.DataParallel(DeepJSCC_V) 36 | 37 | criterion = nn.MSELoss().cuda() 38 | optimizer = torch.optim.Adam(DeepJSCC_V.parameters(), lr=LEARNING_RATE) 39 | 40 | if LOAD_PRETRAIN == True: 41 | DeepJSCC_V.load_state_dict(torch.load('./JSCC_models/DeepJSCC_VLC_'+KSZ+CHANNEL+'_'+str(N_channels)+'_20_cifar10.pth.tar')['state_dict']) 42 | 43 | bestLoss = 1e3 44 | if CONTINUE_TRAINING == True: 45 | DeepJSCC_V.load_state_dict(torch.load('./JSCC_models/DeepJSCC_VLC_'+KSZ+CHANNEL+'_'+str(N_channels)+'_20_ImageNet.pth.tar')['state_dict']) 46 | current_epoch = 0 47 | # LEARNING_RATE = 0.5*1e-4 48 | 49 | # bestLoss = 1e3 50 | for epoch in range(current_epoch, EPOCHS): 51 | DeepJSCC_V.train() 52 | print('========================') 53 | print('lr:%.4e'%optimizer.param_groups[0]['lr']) 54 | 55 | 56 | # if epoch == 40: 57 | # 
optimizer.param_groups[0]['lr'] = 0.5*1e-4 58 | 59 | # Model training 60 | for i, (x_input, _) in enumerate(train_loader): 61 | # print(i)% 62 | x_input = x_input.cuda() 63 | 64 | SNR_TRAIN = torch.randint(0, 28, (x_input.shape[0], 1)).cuda() 65 | CR = 0.1+0.9*torch.rand(x_input.shape[0], 1).cuda() 66 | x_rec = DeepJSCC_V(x_input, SNR_TRAIN, CR, CHANNEL) 67 | 68 | loss = criterion(x_input, x_rec) 69 | loss = loss.mean() 70 | 71 | optimizer.zero_grad() 72 | loss.backward() 73 | optimizer.step() 74 | 75 | if i % PRINT_RREQ == 0: 76 | print('Epoch: [{0}][{1}/{2}]\t' 'Loss {loss:.4f}\t'.format(epoch, i, len(train_loader), loss=loss.item())) 77 | 78 | # if i % SAVE_RREQ == 0: 79 | # if not os.path.exists('./JSCC_models'): 80 | # os.makedirs('./JSCC_models') 81 | # torch.save({'state_dict': DeepJSCC_V.state_dict(), }, './JSCC_models/DeepJSCC_VLC_'+KSZ+CHANNEL+'_'+str(N_channels)+'_20_ImageNet.pth.tar') 82 | # print('Model saved') 83 | 84 | # Model Evaluation 85 | DeepJSCC_V.eval() 86 | totalLoss = 0 87 | with torch.no_grad(): 88 | for i, (test_input, _) in enumerate(test_loader): 89 | test_input = test_input.cuda() 90 | SNR_TEST = torch.randint(0, 28, (test_input.shape[0], 1)).cuda() 91 | CR = 0.1+0.9*torch.rand(test_input.shape[0], 1).cuda() 92 | test_rec = DeepJSCC_V(test_input, SNR_TEST, CR, CHANNEL) 93 | 94 | totalLoss += criterion(test_input, test_rec).item() * test_input.size(0) 95 | averageLoss = totalLoss / 5000 96 | print('averageLoss=', averageLoss) 97 | if averageLoss < bestLoss: 98 | # Model saving 99 | if not os.path.exists('./JSCC_models'): 100 | os.makedirs('./JSCC_models') 101 | torch.save({'state_dict': DeepJSCC_V.state_dict(), }, './JSCC_models/DeepJSCC_VLC_'+KSZ+CHANNEL+'_'+str(N_channels)+'_20_ImageNet.pth.tar') 102 | print('Model saved') 103 | bestLoss = averageLoss 104 | 105 | # print('Training for DeepJSCC_V is finished!') 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | -------------------------------------------------------------------------------- /GDN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | from torch import nn, optim 4 | from torch.nn import functional as F 5 | # from torchvision import datasets, transforms 6 | # from torchvision.utils import save_image 7 | from torch.autograd import Function 8 | 9 | 10 | class LowerBound(Function): 11 | @staticmethod 12 | def forward(ctx, inputs, bound): 13 | b = torch.ones_like(inputs) * bound 14 | ctx.save_for_backward(inputs, b) 15 | return torch.max(inputs, b) 16 | 17 | @staticmethod 18 | def backward(ctx, grad_output): 19 | inputs, b = ctx.saved_tensors 20 | pass_through_1 = inputs >= b 21 | pass_through_2 = grad_output < 0 22 | 23 | pass_through = pass_through_1 | pass_through_2 24 | return pass_through.type(grad_output.dtype) * grad_output, None 25 | 26 | 27 | class GDN(nn.Module): 28 | """Generalized divisive normalization layer. 
29 | y[i] = x[i] / sqrt(beta[i] + sum_j(gamma[j, i] * x[j]^2)) 30 | """ 31 | 32 | def __init__(self, 33 | ch, 34 | inverse=False, 35 | beta_min=1e-6, 36 | gamma_init=0.1, 37 | reparam_offset=2**-18): 38 | super(GDN, self).__init__() 39 | self.inverse = inverse 40 | self.beta_min = beta_min 41 | self.gamma_init = gamma_init 42 | self.reparam_offset = reparam_offset 43 | 44 | self.build(ch) 45 | 46 | def build(self, ch): 47 | self.pedestal = self.reparam_offset**2 48 | self.beta_bound = ((self.beta_min + self.reparam_offset**2)**0.5) 49 | self.gamma_bound = self.reparam_offset 50 | 51 | # Create beta param 52 | beta = torch.sqrt(torch.ones(ch)+self.pedestal) 53 | self.beta = nn.Parameter(beta) 54 | 55 | # Create gamma param 56 | eye = torch.eye(ch) 57 | g = self.gamma_init*eye 58 | g = g + self.pedestal 59 | gamma = torch.sqrt(g) 60 | 61 | self.gamma = nn.Parameter(gamma) 62 | self.pedestal = self.pedestal 63 | 64 | def forward(self, inputs): 65 | unfold = False 66 | if inputs.dim() == 5: 67 | unfold = True 68 | bs, ch, d, w, h = inputs.size() 69 | inputs = inputs.view(bs, ch, d*w, h) 70 | 71 | _, ch, _, _ = inputs.size() 72 | 73 | # Beta bound and reparam 74 | beta = LowerBound.apply(self.beta, self.beta_bound) 75 | beta = beta**2 - self.pedestal 76 | 77 | # Gamma bound and reparam 78 | gamma = LowerBound.apply(self.gamma, self.gamma_bound) 79 | gamma = gamma**2 - self.pedestal 80 | gamma = gamma.view(ch, ch, 1, 1) 81 | 82 | # Norm pool calc 83 | norm_ = nn.functional.conv2d(inputs**2, gamma, beta) 84 | norm_ = torch.sqrt(norm_) 85 | 86 | # Apply norm 87 | if self.inverse: 88 | outputs = inputs * norm_ 89 | else: 90 | outputs = inputs / norm_ 91 | 92 | if unfold: 93 | outputs = outputs.view(bs, ch, d, w, h) 94 | return outputs 95 | --------------------------------------------------------------------------------
/OracleNet.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch 5 | 6 | 7 | 8 | class fc_ResBlock(nn.Module): 9 | def __init__(self, Nin, Nout): 10 | super(fc_ResBlock, self).__init__() 11 | Nh = Nin*2 12 | self.use_fc3 = False 13 | self.fc1 = nn.Linear(Nin, Nh) 14 | self.fc2 = nn.Linear(Nh, Nout) 15 | self.relu = nn.ReLU() 16 | if Nin != Nout: 17 | self.use_fc3 = True 18 | self.fc3 = nn.Linear(Nin, Nout) 19 | def forward(self, x): 20 | out = self.fc1(x) 21 | out = self.relu(out) 22 | out = self.fc2(out) 23 | if self.use_fc3 == True: 24 | x = self.fc3(x) 25 | out = out+x 26 | out = self.relu(out) 27 | return out 28 | 29 | # The oracle network for predicting the PSNR of the transmitted images 30 | class OracleNet(nn.Module): 31 | def __init__(self, Nc_max): 32 | super(OracleNet, self).__init__() 33 | self.fc1 = fc_ResBlock(Nc_max*2+2, Nc_max) 34 | self.fc2 = fc_ResBlock(Nc_max, Nc_max) 35 | self.fc3 = nn.Linear(Nc_max, 1) 36 | self.relu = nn.ReLU() 37 | def forward(self, x, snr, cr): 38 | N_out = torch.round(48*cr).int() # number of active channels; 48 is the Nc_max used throughout this repo 39 | if snr.shape[0]==1: 40 | snr = snr.unsqueeze(1) 41 | N_out = N_out.unsqueeze(1) 42 | std_feat = torch.std(x, (2, 3)) 43 | mean_feat = torch.mean(x, (2, 3)) 44 | out = torch.cat((mean_feat, std_feat, snr, N_out), 1) 45 | out = self.fc1(out) 46 | out = self.fc2(out) 47 | out = self.fc3(out) 48 | out = self.relu(out) 49 | return out 50 | 51 | 52 | 53 | 54 | --------------------------------------------------------------------------------
/OracleNet_test_Kodak.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import
matplotlib.pyplot as plt 4 | # import matplotlib.image as mpimg 5 | import numpy as np 6 | from PIL import Image 7 | 8 | import torch 9 | from utils import * 10 | from models import * 11 | from OracleNet import OracleNet 12 | 13 | from skimage.metrics import peak_signal_noise_ratio as compute_pnsr 14 | 15 | CR = 1/10 16 | 17 | CHANNEL = 'AWGN' # Choose AWGN or Fading 18 | Rep_N = 1 19 | 20 | 21 | # def kodak_test(CR): 22 | PSNR_Kodak = np.zeros((24, 9)) 23 | PSNR_Kodak_pred = np.zeros((24, 9)) 24 | for k in range(0, 24): 25 | print('Image ' + str(k)) 26 | if k<9: 27 | img_id = '0'+str(k+1) 28 | else: 29 | img_id = str(k+1) 30 | img_file = './data/Kodak24/kodim'+img_id+'.png' 31 | img = Image.open(img_file) 32 | 33 | img = np.transpose(img, (2, 0, 1)) 34 | img = img.astype('float32') / 255 35 | 36 | img = torch.Tensor(img).cuda() 37 | img = img.unsqueeze(0) 38 | 39 | 40 | input_shape = np.shape(img)[1:4] 41 | kernel_sz = 5 42 | N_channels = 256 43 | KSZ = str(kernel_sz)+'x'+str(kernel_sz)+'_' 44 | enc_shape = [48, input_shape[1]//4, input_shape[2]//4] 45 | 46 | DeepJSCC_V = ADJSCC_V(enc_shape, kernel_sz, N_channels).cuda() 47 | # DeepJSCC_V = nn.DataParallel(DeepJSCC_V) 48 | 49 | DeepJSCC_V.load_state_dict(torch.load('./JSCC_models/DeepJSCC_VLC_'+KSZ+CHANNEL+'_'+str(N_channels)+'_20_ImageNet.pth.tar')['state_dict']) 50 | DeepJSCC_V.eval() 51 | 52 | OraNet = OracleNet(enc_shape[0]).cuda() 53 | OraNet.load_state_dict(torch.load('./JSCC_models/OracleNet_'+CHANNEL+'_ImageNet.pth.tar')['state_dict']) 54 | OraNet.eval() 55 | 56 | for snr_index in range(0, 9): 57 | PSNR_REP = np.zeros((Rep_N, 1)) 58 | with torch.no_grad(): 59 | for i in range(0, Rep_N): 60 | # print(i) 61 | # snr = torch.Tensor([snr_index*3, ]).cuda() 62 | # cr = torch.Tensor([CR, ]).cuda() 63 | 64 | snr = snr_index*3*torch.ones((5, 1)).cuda() 65 | cr = CR*torch.ones((5, 1)).cuda() 66 | 67 | img_i = img.tile(5, 1, 1, 1) 68 | img_rec = DeepJSCC_V(img_i, snr, cr, CHANNEL) 69 | 70 | img0 = Img_transform(img_i) 71 | img_rec = Img_transform(img_rec) 72 | PSNR = Compute_batch_PSNR(img0, img_rec) 73 | 74 | PSNR_REP[i, 0] = PSNR 75 | 76 | snr1 = torch.Tensor([snr_index*3, ]).cuda() 77 | cr1 = torch.Tensor([CR, ]).cuda() 78 | # z = DeepJSCC_V.module.encoder(img, snr1) 79 | z = DeepJSCC_V.encoder(img, snr1) 80 | z = z.view(-1, enc_shape[0], input_shape[1]//4, input_shape[2]//4) 81 | PSNR_pred = OraNet(z, snr1, cr1) 82 | PSNR_pred = PSNR_pred[0].cpu().detach().numpy() 83 | 84 | PSNR_Kodak[k, snr_index] = np.mean(PSNR) 85 | PSNR_Kodak_pred[k, snr_index] = PSNR_pred 86 | 87 | print(np.mean(PSNR)) 88 | print(PSNR_pred) 89 | 90 | ave1 = np.mean(PSNR_Kodak, 0) 91 | ave2 = np.mean(PSNR_Kodak_pred, 0) 92 | 93 | 94 | # return ave1, ave2 95 | 96 | 97 | # CR_all = torch.Tensor([10, 9, 8, 7, 6, 5, 4, 3, 2, 3/2]).int() 98 | 99 | # # CR = 1/8 100 | 101 | # CHANNEL = 'Fading' 102 | # Rep_N = 2 103 | 104 | # psnr_all = np.zeros((10, 9)) 105 | # psnr_pred_all = np.zeros((10, 9)) 106 | # for i in range(0, 10): 107 | # CR = 1/CR_all[i] 108 | # ave1, ave2 = kodak_test(CR) 109 | # psnr_all[i,:] = ave1 110 | # psnr_pred_all[i,:] = ave2 111 | 112 | # print('Evaluation ' + str(CR) + 'finished...') 113 | 114 | 115 | 116 | # plt.imshow(img) 117 | # plt.axis('off') 118 | # plt.show() 119 | 120 | # plt.imshow(img_rec[0, :]) 121 | # plt.axis('off') 122 | # plt.show() 123 | 124 | 125 | -------------------------------------------------------------------------------- /OracleNet_test_Kodak_image_level.py: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | import matplotlib.pyplot as plt # plt 用于显示图片 4 | # import matplotlib.image as mpimg # mpimg 用于读取图片 5 | import numpy as np 6 | from PIL import Image 7 | 8 | import torch 9 | from utils import * 10 | from models import * 11 | from OracleNet import OracleNet 12 | 13 | from skimage.metrics import peak_signal_noise_ratio as compute_pnsr 14 | 15 | 16 | 17 | CHANNEL = 'AWGN' 18 | Rep_N = 1 19 | 20 | # CR_LEVELS = [20,18,16,14,12,10,8,6,4,3] 21 | CR_LEVELS = [20,19,18,17,16,15,14,13,12,11,10,9,8,7,6,5,4,3] 22 | 23 | # select the target Kodak image 24 | k = 13 25 | 26 | 27 | if k<10: 28 | img_id = '0'+str(k) 29 | else: 30 | img_id = str(k) 31 | img_file = './data/Kodak24/kodim'+img_id+'.png' 32 | img = Image.open(img_file) 33 | 34 | plt.imshow(img) 35 | plt.axis('off') 36 | plt.show() 37 | 38 | img = np.transpose(img, (2, 0, 1)) 39 | img = img.astype('float32') / 255 40 | 41 | img = torch.Tensor(img).cuda() 42 | img = img.unsqueeze(0) 43 | 44 | 45 | input_shape = np.shape(img)[1:4] 46 | kernel_sz = 5 47 | N_channels = 256 48 | KSZ = str(kernel_sz)+'x'+str(kernel_sz)+'_' 49 | enc_shape = [48, input_shape[1]//4, input_shape[2]//4] 50 | 51 | DeepJSCC_V = ADJSCC_V(enc_shape, kernel_sz, N_channels).cuda() 52 | # DeepJSCC_V = nn.DataParallel(DeepJSCC_V) 53 | 54 | DeepJSCC_V.load_state_dict(torch.load('./JSCC_models/DeepJSCC_VLC_'+KSZ+CHANNEL+'_'+str(N_channels)+'_20_ImageNet.pth.tar')['state_dict']) 55 | DeepJSCC_V.eval() 56 | 57 | OraNet = OracleNet(enc_shape[0]).cuda() 58 | OraNet.load_state_dict(torch.load('./JSCC_models/OracleNet_'+CHANNEL+'_ImageNet.pth.tar')['state_dict']) 59 | OraNet.eval() 60 | 61 | PSNR_ALL = np.zeros((18, 9)) 62 | PSNR_PRED_ALL = np.zeros((18, 9)) 63 | 64 | N_batch = 32 65 | for cr_index in range(0, 18): 66 | cr0 = 2/CR_LEVELS[cr_index] 67 | cr = cr0*torch.ones((N_batch, 1)).cuda() 68 | 69 | for snr_index in range(0, 9): 70 | snr = snr_index*3*torch.ones((N_batch, 1)).cuda() 71 | 72 | PSNR_REP = np.zeros((Rep_N, 1)) 73 | with torch.no_grad(): 74 | for i in range(0, Rep_N): 75 | # print(i) 76 | img_i = img.tile(N_batch, 1, 1, 1) 77 | img_rec = DeepJSCC_V(img_i, snr, cr, CHANNEL) 78 | 79 | img0 = Img_transform(img_i) 80 | img_rec = Img_transform(img_rec) 81 | PSNR = Compute_batch_PSNR(img0, img_rec) 82 | 83 | PSNR_REP[i, 0] = PSNR 84 | 85 | snr1 = torch.Tensor([snr_index*3, ]).cuda() 86 | cr1 = torch.Tensor([cr0, ]).cuda() 87 | # z = DeepJSCC_V.module.encoder(img, snr1) 88 | z = DeepJSCC_V.encoder(img, snr1) 89 | 90 | z = z.view(-1, enc_shape[0], input_shape[1]//4, input_shape[2]//4) 91 | PSNR_pred = OraNet(z, snr1, cr1) 92 | PSNR_pred = PSNR_pred[0].cpu().detach().numpy() 93 | 94 | PSNR_ALL[cr_index, snr_index] = np.mean(PSNR) 95 | PSNR_PRED_ALL[cr_index, snr_index] = PSNR_pred 96 | 97 | print(np.mean(PSNR)) 98 | print(PSNR_pred) 99 | 100 | 101 | 102 | # ave1 = np.mean(PSNR_Kodak, 0) 103 | # ave2 = np.mean(PSNR_Kodak_pred, 0) 104 | 105 | 106 | # plt.imshow(img) 107 | # plt.axis('off') 108 | # plt.show() 109 | 110 | # plt.imshow(img_rec[0, :]) 111 | # plt.axis('off') 112 | # plt.show() 113 | 114 | 115 | -------------------------------------------------------------------------------- /Oracle_test.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import os 6 | 7 | from utils import * 8 | from models import * 9 | from OracleNet import OracleNet 10 | 11 | 12 | BATCH_SIZE = 128 
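# Oracle_test.py sweeps the bandwidth ratio cr = 1/CR_INDEX[m] over {1/10, ..., 1} and the
# test SNR from -3 dB to 24 dB in 3 dB steps, and reports the MSE between the PSNR that
# OracleNet predicts and the PSNR measured image-by-image on CIFAR-100 test batches.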
13 | EPOCHS = 150 14 | LEARNING_RATE = 1e-4 15 | PRINT_RREQ = 150 16 | 17 | 18 | _, x_test = Load_cifar100_data() 19 | # train_dataset = DatasetFolder(x_train) 20 | # train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True) 21 | test_dataset = DatasetFolder(x_test) 22 | test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True) 23 | 24 | 25 | CHANNEL = 'Fading' # Choose AWGN or Fading 26 | CR_INDEX = torch.Tensor([10, 9, 8, 7, 6, 5, 4, 3, 2, 1]).int() 27 | 28 | IMG_SIZE = [3, 32, 32] 29 | N_channels = 256 30 | kernel_sz = 5 31 | enc_shape = [48, 8, 8] 32 | KSZ = str(kernel_sz)+'x'+str(kernel_sz)+'_' 33 | 34 | DeepJSCC_V = ADJSCC_V(enc_shape, kernel_sz, N_channels).cuda() 35 | # DeepJSCC_V = nn.DataParallel(DeepJSCC_V) 36 | 37 | DeepJSCC_V.load_state_dict(torch.load('./JSCC_models/DeepJSCC_VLC_'+KSZ+CHANNEL+'_'+str(N_channels)+'_20_cifar10.pth.tar')['state_dict']) 38 | DeepJSCC_V.eval() 39 | 40 | OraNet = OracleNet(enc_shape[0]).cuda() 41 | OraNet.load_state_dict(torch.load('./JSCC_models/OracleNet_'+CHANNEL+'_Res.pth.tar')['state_dict']) 42 | OraNet.eval() 43 | 44 | criterion = nn.MSELoss().cuda() 45 | MSE_pred = np.zeros((10, 10)) 46 | if __name__ == '__main__': 47 | # Model Evaluation 48 | for m in range(0, 10): 49 | cr = 1/CR_INDEX[m] 50 | for k in range(0, 10): 51 | totalLoss = 0 52 | with torch.no_grad(): 53 | for i, test_input in enumerate(test_loader): 54 | SNR_TEST = 3*(k-1)*torch.ones((test_input.shape[0], 1)).cuda() 55 | CR = cr*torch.ones((test_input.shape[0], 1)).cuda() 56 | 57 | test_input = torch.Tensor(test_input).cuda() 58 | test_rec = DeepJSCC_V(test_input, SNR_TEST, CR, CHANNEL) 59 | z = DeepJSCC_V.encoder(test_input, SNR_TEST) # .module.encoder is only needed when the model is wrapped in nn.DataParallel 60 | 61 | test_input = Img_transform(test_input) 62 | test_rec = Img_transform(test_rec) 63 | psnr_batch = Compute_IMG_PSNR(test_input, test_rec) 64 | psnr_batch = torch.Tensor(psnr_batch).cuda() 65 | 66 | z = z.view(-1, enc_shape[0], 8, 8) 67 | psnr_pred = OraNet(z, SNR_TEST, CR) 68 | 69 | totalLoss += criterion(psnr_batch, psnr_pred).item() * psnr_batch.size(0) 70 | averageLoss = totalLoss / (len(test_dataset)) 71 | print('CR = '+str(cr.item())+ ', SNR = '+ str(3*(k-1)) +', MSE =', averageLoss) 72 | 73 | MSE_pred[m, k] = averageLoss 74 | 75 | 76 | # a = psnr_batch.cpu().numpy() 77 | # b = psnr_pred.cpu().numpy() 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | --------------------------------------------------------------------------------
/Oracle_train_CIFAR10.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import os 5 | 6 | from utils import * 7 | from models import * 8 | from OracleNet import OracleNet 9 | 10 | 11 | BATCH_SIZE = 128 12 | EPOCHS = 150 13 | LEARNING_RATE = 1e-4 14 | PRINT_RREQ = 150 15 | 16 | 17 | x_train, x_test = Load_cifar10_data() 18 | train_dataset = DatasetFolder(x_train) 19 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True) 20 | test_dataset = DatasetFolder(x_test) 21 | test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True) 22 | 23 | 24 | CHANNEL = 'Fading' # Choose AWGN or Fading 25 | IMG_SIZE = [3, 32, 32] 26 | N_channels = 256 27 | kernel_sz = 5 28 | enc_shape = [48, 8, 8] 29 | KSZ = str(kernel_sz)+'x'+str(kernel_sz)+'_' 30 | # OracleNet is trained below against the frozen DeepJSCC_V: the encoder's feature statistics, the SNR and the CR are regressed onto the PSNR measured on each reconstructed image (see OracleNet.py) 31 |
DeepJSCC_V = ADJSCC_V(enc_shape, kernel_sz, N_channels).cuda() 32 | # DeepJSCC_V = nn.DataParallel(DeepJSCC_V) 33 | 34 | DeepJSCC_V.load_state_dict(torch.load('./JSCC_models/DeepJSCC_VLC_'+KSZ+CHANNEL+'_'+str(N_channels)+'_20_cifar10.pth.tar')['state_dict']) 35 | DeepJSCC_V.eval() 36 | 37 | OraNet = OracleNet(enc_shape[0]).cuda() 38 | # OraNet = nn.DataParallel(OraNet) 39 | criterion = nn.MSELoss().cuda() 40 | optimizer = torch.optim.Adam(OraNet.parameters(), lr=LEARNING_RATE) 41 | 42 | if __name__ == '__main__': 43 | bestLoss = 1e3 44 | for epoch in range(EPOCHS): 45 | OraNet.train() 46 | for i, x_input in enumerate(train_loader): 47 | SNR = torch.randint(0, 28, (x_input.shape[0], 1)).cuda() 48 | CR = 0.1+0.9*torch.rand(x_input.shape[0], 1).cuda() 49 | 50 | x_input = torch.Tensor(x_input).cuda() 51 | x_rec = DeepJSCC_V(x_input, SNR, CR, CHANNEL) 52 | # z = DeepJSCC_V.module.encoder(x_input, SNR) 53 | z = DeepJSCC_V.encoder(x_input, SNR) 54 | 55 | x_input = Img_transform(x_input) 56 | x_rec = Img_transform(x_rec) 57 | psnr_batch = Compute_IMG_PSNR(x_input, x_rec) 58 | psnr_batch = torch.Tensor(psnr_batch).cuda() 59 | 60 | z = z.view(-1, enc_shape[0], 8, 8) 61 | psnr_pred = OraNet(z, SNR, CR) 62 | 63 | loss = criterion(psnr_batch, psnr_pred) 64 | loss = loss.mean() 65 | 66 | optimizer.zero_grad() 67 | loss.backward() 68 | optimizer.step() 69 | 70 | if i % PRINT_RREQ == 0: 71 | print('Epoch: [{0}][{1}/{2}]\t' 'Loss {loss:.4f}\t'.format(epoch, i, len(train_loader), loss=loss.item())) 72 | 73 | 74 | # Model Evaluation 75 | OraNet.eval() 76 | totalLoss = 0 77 | with torch.no_grad(): 78 | for i, test_input in enumerate(test_loader): 79 | SNR_TEST = torch.randint(0, 28, (test_input.shape[0], 1)).cuda() 80 | CR = 0.1+0.9*torch.rand(test_input.shape[0], 1).cuda() 81 | 82 | test_input = torch.Tensor(test_input).cuda() 83 | test_rec = DeepJSCC_V(test_input, SNR_TEST, CR, CHANNEL) 84 | # z = DeepJSCC_V.module.encoder(test_input, SNR_TEST) 85 | z = DeepJSCC_V.encoder(test_input, SNR_TEST) 86 | 87 | test_input = Img_transform(test_input) 88 | test_rec = Img_transform(test_rec) 89 | psnr_batch = Compute_IMG_PSNR(test_input, test_rec) 90 | psnr_batch = torch.Tensor(psnr_batch).cuda() 91 | 92 | z = z.view(-1, enc_shape[0], 8, 8) 93 | psnr_pred = OraNet(z, SNR_TEST, CR) 94 | 95 | totalLoss += criterion(psnr_batch, psnr_pred).item() * psnr_batch.size(0) 96 | averageLoss = totalLoss / (len(test_dataset)) 97 | print('averageLoss=', averageLoss) 98 | if averageLoss < bestLoss: 99 | # Model saving 100 | if not os.path.exists('./JSCC_models'): 101 | os.makedirs('./JSCC_models') 102 | torch.save({'state_dict': OraNet.state_dict(), }, './JSCC_models/OracleNet_'+CHANNEL+'_Res.pth.tar') 103 | print('Model saved') 104 | bestLoss = averageLoss 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | -------------------------------------------------------------------------------- /Oracle_train_ImageNet.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import os 5 | 6 | from utils import * 7 | from models import * 8 | from OracleNet import OracleNet 9 | 10 | 11 | from data_loader import train_data_loader, test_data_loader 12 | 13 | BATCH_SIZE = 32 14 | EPOCHS = 50 15 | LEARNING_RATE = 1e-4 16 | PRINT_RREQ = 150 17 | 18 | 19 | IMGZ = 128 20 | train_loader = train_data_loader(batch_size = BATCH_SIZE, imgz = IMGZ, workers = 2) 21 | test_loader = test_data_loader(batch_size = BATCH_SIZE, imgz = IMGZ, workers 
= 2) 22 | 23 | 24 | CHANNEL = 'AWGN' # Choose AWGN or Fading 25 | IMG_SIZE = [3, IMGZ, IMGZ] 26 | N_channels = 256 27 | kernel_sz = 5 28 | enc_shape = [48, IMGZ//4, IMGZ//4] 29 | KSZ = str(kernel_sz)+'x'+str(kernel_sz)+'_' 30 | 31 | DeepJSCC_V = ADJSCC_V(enc_shape, kernel_sz, N_channels).cuda() 32 | # DeepJSCC_V = nn.DataParallel(DeepJSCC_V) 33 | DeepJSCC_V.load_state_dict(torch.load('./JSCC_models/DeepJSCC_VLC_'+KSZ+CHANNEL+'_'+str(N_channels)+'_20_ImageNet.pth.tar')['state_dict']) 34 | DeepJSCC_V.eval() 35 | 36 | OraNet = OracleNet(enc_shape[0]).cuda() 37 | # OraNet = nn.DataParallel(OraNet) 38 | criterion = nn.MSELoss().cuda() 39 | optimizer = torch.optim.Adam(OraNet.parameters(), lr=LEARNING_RATE) 40 | 41 | if __name__ == '__main__': 42 | bestLoss = 1e3 43 | for epoch in range(EPOCHS): 44 | OraNet.train() 45 | for i, (x_input, _) in enumerate(train_loader): 46 | SNR = torch.randint(0, 28, (x_input.shape[0], 1)).cuda() 47 | CR = 0.1+0.9*torch.rand(x_input.shape[0], 1).cuda() 48 | 49 | x_input = torch.Tensor(x_input).cuda() 50 | x_rec = DeepJSCC_V(x_input, SNR, CR, CHANNEL) 51 | # z = DeepJSCC_V.module.encoder(x_input, SNR) 52 | 53 | z = DeepJSCC_V.encoder(x_input, SNR) 54 | 55 | x_input = Img_transform(x_input) 56 | x_rec = Img_transform(x_rec) 57 | psnr_batch = Compute_IMG_PSNR(x_input, x_rec) 58 | psnr_batch = torch.Tensor(psnr_batch).cuda() 59 | 60 | z = z.view(-1, enc_shape[0], IMGZ//4, IMGZ//4) 61 | psnr_pred = OraNet(z, SNR, CR) 62 | 63 | loss = criterion(psnr_batch, psnr_pred) 64 | loss = loss.mean() 65 | 66 | optimizer.zero_grad() 67 | loss.backward() 68 | optimizer.step() 69 | 70 | if i % PRINT_RREQ == 0: 71 | print('Epoch: [{0}][{1}/{2}]\t' 'Loss {loss:.4f}\t'.format(epoch, i, len(train_loader), loss=loss.item())) 72 | 73 | 74 | # Model Evaluation 75 | OraNet.eval() 76 | totalLoss = 0 77 | with torch.no_grad(): 78 | for i, (test_input, _) in enumerate(test_loader): 79 | SNR_TEST = torch.randint(0, 28, (test_input.shape[0], 1)).cuda() 80 | CR = 0.1+0.9*torch.rand(test_input.shape[0], 1).cuda() 81 | 82 | test_input = torch.Tensor(test_input).cuda() 83 | test_rec = DeepJSCC_V(test_input, SNR_TEST, CR, CHANNEL) 84 | # z = DeepJSCC_V.module.encoder(test_input, SNR_TEST) 85 | z = DeepJSCC_V.encoder(test_input, SNR_TEST) 86 | 87 | test_input = Img_transform(test_input) 88 | test_rec = Img_transform(test_rec) 89 | psnr_batch = Compute_IMG_PSNR(test_input, test_rec) 90 | psnr_batch = torch.Tensor(psnr_batch).cuda() 91 | 92 | z = z.view(-1, enc_shape[0], IMGZ//4, IMGZ//4) 93 | psnr_pred = OraNet(z, SNR_TEST, CR) 94 | 95 | totalLoss += criterion(psnr_batch, psnr_pred).item() * psnr_batch.size(0) 96 | averageLoss = totalLoss / 5000 # assumes a 5000-image test split, matching DeepJSCC_V_train_ImageNet.py 97 | print('averageLoss=', averageLoss) 98 | if averageLoss < bestLoss: 99 | # Model saving 100 | if not os.path.exists('./JSCC_models'): 101 | os.makedirs('./JSCC_models') 102 | torch.save({'state_dict': OraNet.state_dict(), }, './JSCC_models/OracleNet_'+CHANNEL+'_ImageNet.pth.tar') 103 | print('Model saved') 104 | bestLoss = averageLoss 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | --------------------------------------------------------------------------------
/README.md: -------------------------------------------------------------------------------- 1 | # PADC: Predictive-and-Adaptive-Deep-Coding-for-Wireless-Image-Transmission-in-Semantic-Communication 2 | Pytorch code for IEEE TWC paper "Predictive and Adaptive Deep Coding for Wireless Image Transmission in Semantic Communication" 3 | 4 | For more details, please refer to the following paper: 5 | Zhang W, Zhang H, Ma H, et al. Predictive and Adaptive Deep Coding for Wireless Image Transmission in Semantic Communication. IEEE Transactions on Wireless Communications, 2023. 6 | --------------------------------------------------------------------------------
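The scripts in this repository supply the two ingredients of PADC: a variable-rate DeepJSCC_V codec and an OracleNet that predicts, from the encoder output, the PSNR a given (SNR, CR) pair would achieve. A minimal sketch of the resulting adaptive-rate loop follows. It is an illustration, not a script from the repository: select_min_cr and psnr_target are hypothetical names, the CR grid mirrors the 0.1-1.0 range used during training, the 48-channel encoder output and the 1-D snr/cr calling convention follow OracleNet_test_Kodak.py, and a CUDA device is assumed as in the rest of the code.

import torch
from models import ADJSCC_V
from OracleNet import OracleNet

def select_min_cr(jscc_v, oracle, img, snr_db, psnr_target, c_max=48):
    # Run the frozen encoder once, then query OracleNet for increasing
    # bandwidth ratios and keep the smallest CR predicted to reach the target.
    snr = torch.Tensor([snr_db, ]).cuda()               # 1-D SNR, as in OracleNet_test_Kodak.py
    psnr_hat = 0.0
    with torch.no_grad():
        z = jscc_v.encoder(img, snr)                    # flattened encoder features
        z = z.view(-1, c_max, img.shape[2]//4, img.shape[3]//4)
        for cr in [k/20 for k in range(2, 21)]:         # candidate CRs 0.10, 0.15, ..., 1.00
            psnr_hat = oracle(z, snr, torch.Tensor([cr, ]).cuda())[0].item()
            if psnr_hat >= psnr_target:
                return cr, psnr_hat                     # cheapest CR predicted to suffice
    return 1.0, psnr_hat                                # otherwise fall back to full bandwidth

The selected cr can then be broadcast to a per-image tensor and passed to DeepJSCC_V(img, snr, cr, CHANNEL) for the actual transmission, exactly as the test scripts do.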
/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torchvision.transforms as transforms 5 | import torchvision.datasets as datasets 6 | 7 | 8 | def train_data_loader(batch_size = 64, imgz = 128, workers = 0, pin_memory = True): 9 | # Please specify your own dataset path here 10 | data_dir = os.path.join('E:/Code_Wenyu/Dataset/ILSVRC2012_img_val') 11 | dataset = datasets.ImageFolder( 12 | data_dir, 13 | transforms.Compose([ 14 | transforms.CenterCrop(imgz), 15 | transforms.RandomHorizontalFlip(p=0.5), 16 | transforms.RandomVerticalFlip(p=0.5), 17 | transforms.ToTensor(), 18 | ]) 19 | ) 20 | train_data_loader = torch.utils.data.DataLoader( 21 | dataset, 22 | batch_size = batch_size, 23 | shuffle = True, 24 | num_workers = workers, 25 | pin_memory = pin_memory 26 | ) 27 | return train_data_loader 28 | 29 | def test_data_loader(batch_size = 64, imgz = 128, workers = 0, pin_memory = True): 30 | # Please specify your own dataset path here 31 | data_dir = os.path.join('E:/Code_Wenyu/Dataset/test_data') 32 | dataset = datasets.ImageFolder( 33 | data_dir, 34 | transforms.Compose([ 35 | transforms.CenterCrop(imgz), 36 | transforms.ToTensor(), 37 | ]) 38 | ) 39 | test_data_loader = torch.utils.data.DataLoader( 40 | dataset, 41 | batch_size = batch_size, 42 | shuffle = False, 43 | num_workers = workers, 44 | pin_memory = pin_memory 45 | ) 46 | return test_data_loader 47 | 48 | 49 | 50 | 51 | --------------------------------------------------------------------------------
/models.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch 6 | 7 | from GDN import GDN 8 | 9 | def conv(in_channels, out_channels, kernel_size=3, stride=1, padding=1): 10 | return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) 11 | 12 | def deconv(in_channels, out_channels, kernel_size=3, stride=1, padding=1, output_padding = 0): 13 | return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, output_padding = output_padding,bias=False) 14 | 15 | 16 | class conv_block(nn.Module): 17 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1): 18 | super(conv_block, self).__init__() 19 | self.conv = conv(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) 20 | self.gdn = GDN(out_channels) # GDN comes from GDN.py; torch.nn provides no GDN layer 21 | self.prelu = nn.PReLU() 22 | def forward(self, x): 23 | out = self.conv(x) 24 | out = self.gdn(out) 25 | out = self.prelu(out) 26 | return out 27 | 28 | class deconv_block(nn.Module): 29 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, output_padding = 0): 30 | super(deconv_block, self).__init__() 31 | self.deconv = deconv(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, output_padding = output_padding) 32 | self.gdn = GDN(out_channels) # GDN comes from GDN.py 33 | self.prelu = nn.PReLU() 34 | self.sigmoid = nn.Sigmoid() 35 | def forward(self, x, activate_func='prelu'): 36 | out = self.deconv(x) 37 | out = self.gdn(out) 38 | if
activate_func=='prelu': 39 | out = self.prelu(out) 40 | elif activate_func=='sigmoid': 41 | out = self.sigmoid(out) 42 | return out 43 | 44 | class AF_block(nn.Module): 45 | def __init__(self, Nin, Nh, No): 46 | super(AF_block, self).__init__() 47 | self.fc1 = nn.Linear(Nin+1, Nh) 48 | self.fc2 = nn.Linear(Nh, No) 49 | self.relu = nn.ReLU() 50 | self.sigmoid = nn.Sigmoid() 51 | def forward(self, x, snr): 52 | # out = F.adaptive_avg_pool2d(x, (1,1)) 53 | # out = torch.squeeze(out) 54 | # out = torch.cat((out, snr), 1) 55 | if snr.shape[0]>1: 56 | snr = snr.squeeze() 57 | snr = snr.unsqueeze(1) 58 | mu = torch.mean(x, (2, 3)) 59 | out = torch.cat((mu, snr), 1) 60 | out = self.fc1(out) 61 | out = self.relu(out) 62 | out = self.fc2(out) 63 | out = self.sigmoid(out) 64 | out = out.unsqueeze(2) 65 | out = out.unsqueeze(3) 66 | out = out*x 67 | return out 68 | 69 | 70 | class conv_ResBlock(nn.Module): 71 | def __init__(self, in_channels, out_channels, use_conv1x1=False, kernel_size=3, stride=1, padding=1): 72 | super(conv_ResBlock, self).__init__() 73 | self.conv1 = conv(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) 74 | self.conv2 = conv(out_channels, out_channels, kernel_size=1, stride = 1, padding=0) 75 | self.gdn1 = GDN(out_channels) 76 | self.gdn2 = GDN(out_channels) 77 | self.prelu = nn.PReLU() 78 | self.use_conv1x1 = use_conv1x1 79 | if use_conv1x1 == True: 80 | self.conv3 = conv(in_channels, out_channels, kernel_size=1, stride=stride, padding=0) 81 | def forward(self, x): 82 | out = self.conv1(x) 83 | out = self.gdn1(out) 84 | out = self.prelu(out) 85 | out = self.conv2(out) 86 | out = self.gdn2(out) 87 | if self.use_conv1x1 == True: 88 | x = self.conv3(x) 89 | out = out+x 90 | out = self.prelu(out) 91 | return out 92 | 93 | 94 | class deconv_ResBlock(nn.Module): 95 | def __init__(self, in_channels, out_channels, use_deconv1x1=False, kernel_size=3, stride=1, padding=1, output_padding=0): 96 | super(deconv_ResBlock, self).__init__() 97 | self.deconv1 = deconv(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding) 98 | self.deconv2 = deconv(out_channels, out_channels, kernel_size=1, stride = 1, padding=0, output_padding=0) 99 | self.gdn1 = GDN(out_channels) 100 | self.gdn2 = GDN(out_channels) 101 | self.prelu = nn.PReLU() 102 | self.sigmoid = nn.Sigmoid() 103 | self.use_deconv1x1 = use_deconv1x1 104 | if use_deconv1x1 == True: 105 | self.deconv3 = deconv(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, output_padding=output_padding) 106 | def forward(self, x, activate_func='prelu'): 107 | out = self.deconv1(x) 108 | out = self.gdn1(out) 109 | out = self.prelu(out) 110 | out = self.deconv2(out) 111 | out = self.gdn2(out) 112 | if self.use_deconv1x1 == True: 113 | x = self.deconv3(x) 114 | out = out+x 115 | if activate_func=='prelu': 116 | out = self.prelu(out) 117 | elif activate_func=='sigmoid': 118 | out = self.sigmoid(out) 119 | return out 120 | 121 | # The Encoder model with attention feature blocks 122 | class Encoder(nn.Module): 123 | def __init__(self, enc_shape, kernel_sz, Nc_conv): 124 | super(Encoder, self).__init__() 125 | enc_N = enc_shape[0] 126 | Nh_AF = Nc_conv//2 127 | padding_L = (kernel_sz-1)//2 128 | self.conv1 = conv_ResBlock(3, Nc_conv, use_conv1x1=True, kernel_size = kernel_sz, stride = 2, padding=padding_L) 129 | self.conv2 = conv_ResBlock(Nc_conv, Nc_conv, use_conv1x1=True, kernel_size = kernel_sz, stride = 2, padding=padding_L) 130 | self.conv3 = 
conv_ResBlock(Nc_conv, Nc_conv, kernel_size = kernel_sz, stride = 1, padding=padding_L) 131 | self.conv4 = conv_ResBlock(Nc_conv, Nc_conv, kernel_size = kernel_sz, stride = 1, padding=padding_L) 132 | self.conv5 = conv_ResBlock(Nc_conv, enc_N, use_conv1x1=True, kernel_size = kernel_sz, stride = 1, padding=padding_L) 133 | self.AF1 = AF_block(Nc_conv, Nh_AF, Nc_conv) 134 | self.AF2 = AF_block(Nc_conv, Nh_AF, Nc_conv) 135 | self.AF3 = AF_block(Nc_conv, Nh_AF, Nc_conv) 136 | self.AF4 = AF_block(Nc_conv, Nh_AF, Nc_conv) 137 | self.AF5 = AF_block(enc_N, enc_N//2, enc_N) 138 | self.flatten = nn.Flatten() 139 | def forward(self, x, snr): 140 | out = self.conv1(x) 141 | out = self.AF1(out, snr) 142 | out = self.conv2(out) 143 | out = self.AF2(out, snr) 144 | out = self.conv3(out) 145 | out = self.AF3(out, snr) 146 | out = self.conv4(out) 147 | out = self.AF4(out, snr) 148 | out = self.conv5(out) 149 | out = self.AF5(out, snr) 150 | out = self.flatten(out) 151 | return out 152 | 153 | # The Decoder model with attention feature blocks 154 | class Decoder(nn.Module): 155 | def __init__(self, enc_shape, kernel_sz, Nc_deconv): 156 | super(Decoder, self).__init__() 157 | self.enc_shape = enc_shape 158 | Nh_AF1 = enc_shape[0]//2 159 | Nh_AF = Nc_deconv//2 160 | padding_L = (kernel_sz-1)//2 161 | self.deconv1 = deconv_ResBlock(self.enc_shape[0], Nc_deconv, use_deconv1x1=True, kernel_size = kernel_sz, stride = 2, padding=padding_L, output_padding = 1) 162 | self.deconv2 = deconv_ResBlock(Nc_deconv, Nc_deconv, use_deconv1x1=True, kernel_size = kernel_sz, stride = 2, padding=padding_L, output_padding = 1) 163 | self.deconv3 = deconv_ResBlock(Nc_deconv, Nc_deconv, kernel_size=kernel_sz, stride=1, padding=padding_L) 164 | self.deconv4 = deconv_ResBlock(Nc_deconv, Nc_deconv, kernel_size=kernel_sz, stride=1, padding=padding_L) 165 | self.deconv5 = deconv_ResBlock(Nc_deconv, 3, use_deconv1x1=True, kernel_size=kernel_sz, stride=1, padding=padding_L) 166 | self.AF1 = AF_block(self.enc_shape[0], Nh_AF1, self.enc_shape[0]) 167 | self.AF2 = AF_block(Nc_deconv, Nh_AF, Nc_deconv) 168 | self.AF3 = AF_block(Nc_deconv, Nh_AF, Nc_deconv) 169 | self.AF4 = AF_block(Nc_deconv, Nh_AF, Nc_deconv) 170 | self.AF5 = AF_block(Nc_deconv, Nh_AF, Nc_deconv) 171 | def forward(self, x, snr): 172 | out = x.view(-1, self.enc_shape[0], self.enc_shape[1], self.enc_shape[2]) 173 | out = self.AF1(out, snr) 174 | out = self.deconv1(out) 175 | out = self.AF2(out, snr) 176 | out = self.deconv2(out) 177 | out = self.AF3(out, snr) 178 | out = self.deconv3(out) 179 | out = self.AF4(out, snr) 180 | out = self.deconv4(out) 181 | out = self.AF5(out, snr) 182 | out = self.deconv5(out, 'sigmoid') 183 | return out 184 | 185 | 186 | # # The complexities of the following Encoder and Decoder models are smaller 187 | # class Encoder(nn.Module): 188 | # def __init__(self, enc_shape, kernel_sz, Nc_conv): 189 | # super(Encoder, self).__init__() 190 | # enc_N = enc_shape[0] 191 | # Nh_AF = Nc_conv//2 192 | # padding_L = (kernel_sz-1)//2 193 | # self.conv1 = conv_ResBlock(3, Nc_conv//2, use_conv1x1=True, kernel_size = kernel_sz, stride = 2, padding=padding_L) 194 | # self.conv2 = conv_ResBlock(Nc_conv//2, Nc_conv, use_conv1x1=True, kernel_size = kernel_sz, stride = 2, padding=padding_L) 195 | # self.conv3 = conv_ResBlock(Nc_conv, Nc_conv, kernel_size = kernel_sz, stride = 1, padding=padding_L) 196 | # self.conv4 = conv_ResBlock(Nc_conv, Nc_conv, kernel_size = kernel_sz, stride = 1, padding=padding_L) 197 | # self.conv5 = conv_ResBlock(Nc_conv, enc_N, 
use_conv1x1=True, kernel_size = kernel_sz, stride = 1, padding=padding_L) 198 | # self.AF1 = AF_block(Nc_conv//2, Nh_AF//2, Nc_conv//2) 199 | # self.AF2 = AF_block(Nc_conv, Nh_AF, Nc_conv) 200 | # self.AF3 = AF_block(Nc_conv, Nh_AF, Nc_conv) 201 | # self.AF4 = AF_block(Nc_conv, Nh_AF, Nc_conv) 202 | # self.AF5 = AF_block(enc_N, enc_N//2, enc_N) 203 | # self.flatten = nn.Flatten() 204 | # def forward(self, x, snr): 205 | # out = self.conv1(x) 206 | # out = self.AF1(out, snr) 207 | # out = self.conv2(out) 208 | # out = self.AF2(out, snr) 209 | # out = self.conv3(out) 210 | # out = self.AF3(out, snr) 211 | # out = self.conv4(out) 212 | # out = self.AF4(out, snr) 213 | # out = self.conv5(out) 214 | # out = self.AF5(out, snr) 215 | # out = self.flatten(out) 216 | # return out 217 | 218 | # class Decoder(nn.Module): 219 | # def __init__(self, enc_shape, kernel_sz, Nc_deconv): 220 | # super(Decoder, self).__init__() 221 | # self.enc_shape = enc_shape 222 | # Nh_AF1 = enc_shape[0]//2 223 | # Nh_AF = Nc_deconv//2 224 | # padding_L = (kernel_sz-1)//2 225 | # self.deconv1 = deconv_ResBlock(self.enc_shape[0], Nc_deconv, use_deconv1x1=True, kernel_size = kernel_sz, stride = 2, padding=padding_L, output_padding = 1) 226 | # self.deconv2 = deconv_ResBlock(Nc_deconv, Nc_deconv, use_deconv1x1=True, kernel_size = kernel_sz, stride = 2, padding=padding_L, output_padding = 1) 227 | # self.deconv3 = deconv_ResBlock(Nc_deconv, Nc_deconv, kernel_size=kernel_sz, stride=1, padding=padding_L) 228 | # self.deconv4 = deconv_ResBlock(Nc_deconv, Nc_deconv//2, use_deconv1x1=True, kernel_size=kernel_sz, stride=1, padding=padding_L) 229 | # self.deconv5 = deconv_ResBlock(Nc_deconv//2, 3, use_deconv1x1=True, kernel_size=kernel_sz, stride=1, padding=padding_L) 230 | # self.AF1 = AF_block(self.enc_shape[0], Nh_AF1, self.enc_shape[0]) 231 | # self.AF2 = AF_block(Nc_deconv, Nh_AF, Nc_deconv) 232 | # self.AF3 = AF_block(Nc_deconv, Nh_AF, Nc_deconv) 233 | # self.AF4 = AF_block(Nc_deconv, Nh_AF, Nc_deconv) 234 | # self.AF5 = AF_block(Nc_deconv//2, Nh_AF//2, Nc_deconv//2) 235 | # def forward(self, x, snr): 236 | # out = x.view(-1, self.enc_shape[0], self.enc_shape[1], self.enc_shape[2]) 237 | # out = self.AF1(out, snr) 238 | # out = self.deconv1(out) 239 | # out = self.AF2(out, snr) 240 | # out = self.deconv2(out) 241 | # out = self.AF3(out, snr) 242 | # out = self.deconv3(out) 243 | # out = self.AF4(out, snr) 244 | # out = self.deconv4(out) 245 | # out = self.AF5(out, snr) 246 | # out = self.deconv5(out, 'sigmoid') 247 | # return out 248 | 249 | 250 | 251 | # Power normalization before transmission 252 | # Note: if P = 1, the symbol power is 2 253 | # If you want to set the average power as 1, please change P as P=1/np.sqrt(2) 254 | def Power_norm(z, P = 1): 255 | batch_size, z_dim = z.shape 256 | z_power = torch.sqrt(torch.sum(z**2, 1)) 257 | z_M = z_power.repeat(z_dim, 1) 258 | return np.sqrt(P*z_dim)*z/z_M.t() 259 | 260 | def Power_norm_complex(z, P = 1): 261 | batch_size, z_dim = z.shape 262 | z_com = torch.complex(z[:, 0:z_dim:2], z[:, 1:z_dim:2]) 263 | z_com_conj = torch.complex(z[:, 0:z_dim:2], -z[:, 1:z_dim:2]) 264 | z_power = torch.sum(z_com*z_com_conj, 1).real 265 | z_M = z_power.repeat(z_dim//2, 1) 266 | z_nlz = np.sqrt(P*z_dim)*z_com/torch.sqrt(z_M.t()) 267 | z_out = torch.zeros(batch_size, z_dim).cuda() 268 | z_out[:, 0:z_dim:2] = z_nlz.real 269 | z_out[:, 1:z_dim:2] = z_nlz.imag 270 | return z_out 271 | 272 | # The (real) AWGN channel 273 | def AWGN_channel(x, snr, P = 1): 274 | batch_size, length = x.shape 275 | 
gamma = 10 ** (snr / 10.0) 276 | noise = torch.sqrt(P/gamma)*torch.randn(batch_size, length).cuda() 277 | y = x+noise 278 | return y 279 | 280 | def AWGN_complex(x, snr, Ps = 1): 281 | batch_size, length = x.shape 282 | gamma = 10 ** (snr / 10.0) 283 | n_I = torch.sqrt(Ps/gamma)*torch.randn(batch_size, length).cuda() 284 | n_R = torch.sqrt(Ps/gamma)*torch.randn(batch_size, length).cuda() 285 | noise = torch.complex(n_I, n_R) 286 | y = x + noise 287 | return y 288 | 289 | # Please set the symbol power if it is not a default value 290 | def Fading_channel(x, snr, P = 1): 291 | gamma = 10 ** (snr / 10.0) 292 | [batch_size, feature_length] = x.shape 293 | K = feature_length//2 294 | 295 | h_I = torch.randn(batch_size, K).cuda() 296 | h_R = torch.randn(batch_size, K).cuda() 297 | h_com = torch.complex(h_I, h_R) 298 | x_com = torch.complex(x[:, 0:feature_length:2], x[:, 1:feature_length:2]) 299 | y_com = h_com*x_com 300 | 301 | n_I = torch.sqrt(P/gamma)*torch.randn(batch_size, K).cuda() 302 | n_R = torch.sqrt(P/gamma)*torch.randn(batch_size, K).cuda() 303 | noise = torch.complex(n_I, n_R) 304 | 305 | y_add = y_com + noise 306 | y = y_add/h_com 307 | 308 | y_out = torch.zeros(batch_size, feature_length).cuda() 309 | y_out[:, 0:feature_length:2] = y.real 310 | y_out[:, 1:feature_length:2] = y.imag 311 | return y_out 312 | 313 | 314 | 315 | # Note: if P = 1, the symbol power is 2 316 | # If you want to set the average power as 1, please change P as P=1/np.sqrt(2) 317 | def Power_norm_VLC(z, cr, P = 1): 318 | batch_size, z_dim = z.shape 319 | Kv = torch.ceil(z_dim*cr).int() 320 | z_power = torch.sqrt(torch.sum(z**2, 1)) 321 | z_M = z_power.repeat(z_dim, 1).cuda() 322 | return torch.sqrt(Kv*P)*z/z_M.t() 323 | 324 | 325 | def AWGN_channel_VLC(x, snr, cr, P = 1): 326 | batch_size, length = x.shape 327 | gamma = 10 ** (snr / 10.0) 328 | mask = mask_gen(length, cr).cuda() 329 | noise = torch.sqrt(P/gamma)*torch.randn(1, length).cuda() 330 | noise = noise*mask 331 | y = x+noise 332 | return y 333 | 334 | 335 | def Fading_channel_VLC(x, snr, cr, P = 1): 336 | gamma = 10 ** (snr / 10.0) 337 | [batch_size, feature_length] = x.shape 338 | K = feature_length//2 339 | 340 | mask = mask_gen(K, cr).cuda() 341 | h_I = torch.randn(batch_size, K).cuda() 342 | h_R = torch.randn(batch_size, K).cuda() 343 | h_com = torch.complex(h_I, h_R) 344 | x_com = torch.complex(x[:, 0:feature_length:2], x[:, 1:feature_length:2]) 345 | y_com = h_com*x_com 346 | 347 | n_I = torch.sqrt(P/gamma)*torch.randn(batch_size, K).cuda() 348 | n_R = torch.sqrt(P/gamma)*torch.randn(batch_size, K).cuda() 349 | noise = torch.complex(n_I, n_R)*mask 350 | 351 | y_add = y_com + noise 352 | y = y_add/h_com 353 | 354 | y_out = torch.zeros(batch_size, feature_length).cuda() 355 | y_out[:, 0:feature_length:2] = y.real 356 | y_out[:, 1:feature_length:2] = y.imag 357 | return y_out 358 | 359 | 360 | def Channel(z, snr, channel_type = 'AWGN'): 361 | z = Power_norm(z) 362 | if channel_type == 'AWGN': 363 | z = AWGN_channel(z, snr) 364 | elif channel_type == 'Fading': 365 | z = Fading_channel(z, snr) 366 | return z 367 | 368 | 369 | def Channel_VLC(z, snr, cr, channel_type = 'AWGN'): 370 | z = Power_norm_VLC(z, cr) 371 | if channel_type == 'AWGN': 372 | z = AWGN_channel_VLC(z, snr, cr) 373 | elif channel_type == 'Fading': 374 | z = Fading_channel_VLC(z, snr, cr) 375 | return z 376 | 377 | 378 | def mask_gen(N, cr, ch_max = 48): 379 | MASK = torch.zeros(cr.shape[0], N).int() 380 | nc = N//ch_max 381 | for i in range(0, cr.shape[0]): 382 | L_i = 
nc*torch.round(ch_max*cr[i]).int() 383 | MASK[i, 0:L_i] = 1 384 | return MASK 385 | 386 | 387 | class ADJSCC(nn.Module): 388 | def __init__(self, enc_shape, Kernel_sz, Nc): 389 | super(ADJSCC, self).__init__() 390 | self.encoder = Encoder(enc_shape, Kernel_sz, Nc) 391 | self.decoder = Decoder(enc_shape, Kernel_sz, Nc) 392 | def forward(self, x, snr, channel_type = 'AWGN'): 393 | z = self.encoder(x, snr) 394 | z = Channel(z, snr, channel_type) 395 | out = self.decoder(z, snr) 396 | return out 397 | 398 | # The DeepJSCC_V model, also called ADJSCC_V 399 | class ADJSCC_V(nn.Module): 400 | def __init__(self, enc_shape, Kernel_sz, Nc): 401 | super(ADJSCC_V, self).__init__() 402 | self.encoder = Encoder(enc_shape, Kernel_sz, Nc) 403 | self.decoder = Decoder(enc_shape, Kernel_sz, Nc) 404 | def forward(self, x, snr, cr, channel_type = 'AWGN'): 405 | z = self.encoder(x, snr) 406 | z = z*mask_gen(z.shape[1], cr).cuda() 407 | z = Channel_VLC(z, snr, cr, channel_type) 408 | out = self.decoder(z, snr) 409 | return out 410 | 411 | 412 | 413 | 414 | 415 | 416 | 417 | 418 | 419 | 420 | 421 | 422 | 423 | 424 | 425 | 426 | 427 | 428 | 429 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | # import torch 4 | # from torchvision import datasets, transforms 5 | from torch.utils.data import Dataset 6 | 7 | # import matplotlib.pyplot as plt 8 | # from skimage.metrics import structural_similarity as compute_ssim 9 | from skimage.metrics import peak_signal_noise_ratio as compute_pnsr 10 | 11 | # from models import * 12 | 13 | 14 | # Note that the original data is downloaded from keras.datasets, not from torch.utils.data 15 | def Load_cifar10_data(): 16 | x_train = np.load('data/CIFAR10_raw/x_train.npy') 17 | x_test = np.load('data/CIFAR10_raw/x_test.npy') 18 | # from keras.datasets import cifar10 19 | # (x_train, y_train_), (x_test, y_test_) = cifar10.load_data() 20 | x_train = np.transpose(x_train, (0, 3, 1, 2)) 21 | x_test = np.transpose(x_test, (0, 3, 1, 2)) 22 | x_train = x_train.astype('float32') / 255 23 | x_test = x_test.astype('float32') / 255 24 | return x_train, x_test 25 | 26 | 27 | # Note that the original data is downloaded from keras.datasets, not from torch.utils.data 28 | def Load_cifar100_data(): 29 | x_train = np.load('data/CIFAR100_raw/x_train.npy') 30 | x_test = np.load('data/CIFAR100_raw/x_test.npy') 31 | # from keras.datasets import cifar10 32 | # (x_train, y_train_), (x_test, y_test_) = cifar10.load_data() 33 | x_train = np.transpose(x_train, (0, 3, 1, 2)) 34 | x_test = np.transpose(x_test, (0, 3, 1, 2)) 35 | x_train = x_train.astype('float32') / 255 36 | x_test = x_test.astype('float32') / 255 37 | return x_train, x_test 38 | 39 | 40 | # def Plot_CIFAR10_img(x): 41 | # digit_size = 32 42 | # n = 5 43 | # figure = np.zeros((digit_size*n, digit_size * n, 3)) 44 | # for i in range (n): 45 | # x_i = x[i * n: (i + 1) * n, :] 46 | # for j in range(n): 47 | # digit = x_i[j].reshape(digit_size, digit_size, 3) 48 | # figure[i * digit_size: (i + 1) * digit_size, 49 | # j * digit_size: (j + 1) * digit_size, :] = digit 50 | 51 | # plt.figure(figsize=(10, 10)) 52 | # plt.imshow(figure, cmap='Greys_r') 53 | # plt.axis('off') 54 | # plt.show() 55 | 56 | 57 | def Img_transform(test_rec): 58 | test_rec = test_rec.permute(0, 2, 3, 1) 59 | test_rec = test_rec.cpu().detach().numpy() 60 | test_rec = test_rec*255 61 | test_rec = test_rec.astype(np.uint8) 62 | 
return test_rec 63 | 64 | def Compute_batch_PSNR(test_input, test_rec): 65 | psnr_i1 = np.zeros((test_input.shape[0])) 66 | for j in range(0, test_input.shape[0]): 67 | psnr_i1[j] = compute_pnsr(test_input[j, :], test_rec[j, :]) 68 | psnr_ave = np.mean(psnr_i1) 69 | return psnr_ave 70 | 71 | 72 | def Compute_IMG_PSNR(test_input, test_rec): 73 | psnr_i1 = np.zeros((test_input.shape[0], 1)) 74 | for j in range(0, test_input.shape[0]): 75 | psnr_i1[j] = compute_pnsr(test_input[j, :], test_rec[j, :]) 76 | return psnr_i1 77 | 78 | # Data Loader 79 | class DatasetFolder(Dataset): 80 | def __init__(self, matData): 81 | self.matdata = matData 82 | def __getitem__(self, index): 83 | return self.matdata[index] 84 | def __len__(self): 85 | return self.matdata.shape[0] 86 | 87 | # Using one of the following learning-rate schedulers may help improve the training quality 88 | def lr_schedular(cur_epoch, warmup_epoch, epochs, lr_max): 89 | lr_min = 1e-6 90 | kappa = (lr_max-lr_min)/warmup_epoch 91 | if cur_epoch < warmup_epoch: 92 | lr = lr_min + kappa*cur_epoch 93 | else: 94 | lr = lr_min + 0.5 * (lr_max - lr_min) * (1 + np.cos(np.pi * (cur_epoch-warmup_epoch) / epochs)) 95 | return lr 96 | 97 | def lr_schedular_step(epoch, warmup_epoch, EPOCHS, lr_max): 98 | lr_min = 1e-6 99 | kappa = (lr_max-lr_min)/warmup_epoch 100 | if epoch < warmup_epoch: 101 | lr = lr_min + kappa*epoch 102 | else: 103 | eta = EPOCHS/100 104 | if epoch<=25*eta: 105 | lr = lr_max 106 | elif epoch>25*eta and epoch<=50*eta: 107 | lr = lr_max/2 108 | elif epoch>50*eta and epoch<=80*eta: 109 | lr = lr_max/4 110 | elif epoch>80*eta and epoch<=95*eta: 111 | lr = lr_max/8 112 | elif epoch>95*eta: 113 | lr = lr_max/16 114 | return lr 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | --------------------------------------------------------------------------------
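A short usage note for the two schedulers above: they return a learning rate for the current epoch, which the caller writes back into the optimizer. The sketch below shows the warmup-plus-cosine variant; the model, EPOCHS, WARMUP and lr_max values are placeholders rather than settings taken from the paper.

import torch
from utils import lr_schedular

model = torch.nn.Linear(8, 8)                        # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
EPOCHS, WARMUP = 200, 10
for epoch in range(EPOCHS):
    # linear warmup for WARMUP epochs, then cosine decay towards lr_min = 1e-6
    optimizer.param_groups[0]['lr'] = lr_schedular(epoch, WARMUP, EPOCHS, lr_max=1e-4)
    # ... run one training epoch and evaluation here ...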