├── tmp
│   ├── README.md
│   ├── Pilot_16
│   ├── lr_scheduler.py
│   ├── Pilot_64
│   ├── utils.py
│   └── train_multigpu.ipynb
├── Model_train.py
├── Model_define_pytorch.py
└── Model_define_pytorch_score_09945.py
/tmp/README.md:
--------------------------------------------------------------------------------
1 | 目录结构 (directory layout)
2 |
3 | - src (this folder)
4 |     - input (folder of data)
5 |         - Y_1.csv
6 |         - Y_2.csv
7 |         - H.bin
8 |         - H_val.bin
9 |
10 | run train_multigpu.ipynb
11 |
12 |
13 | !!! [place holder] 复赛结束后开源 (to be open-sourced after the second round ends)
14 |
--------------------------------------------------------------------------------
/tmp/Pilot_16:
--------------------------------------------------------------------------------
1 | 1.000000000000000000e+00
2 | 1.000000000000000000e+00
3 | 0.000000000000000000e+00
4 | 1.000000000000000000e+00
5 | 1.000000000000000000e+00
6 | 0.000000000000000000e+00
7 | 0.000000000000000000e+00
8 | 0.000000000000000000e+00
9 | 0.000000000000000000e+00
10 | 1.000000000000000000e+00
11 | 1.000000000000000000e+00
12 | 0.000000000000000000e+00
13 | 0.000000000000000000e+00
14 | 0.000000000000000000e+00
15 | 1.000000000000000000e+00
16 | 1.000000000000000000e+00
17 | 0.000000000000000000e+00
18 | 1.000000000000000000e+00
19 | 0.000000000000000000e+00
20 | 0.000000000000000000e+00
21 | 1.000000000000000000e+00
22 | 1.000000000000000000e+00
23 | 1.000000000000000000e+00
24 | 0.000000000000000000e+00
25 | 1.000000000000000000e+00
26 | 1.000000000000000000e+00
27 | 1.000000000000000000e+00
28 | 1.000000000000000000e+00
29 | 0.000000000000000000e+00
30 | 1.000000000000000000e+00
31 | 1.000000000000000000e+00
32 | 1.000000000000000000e+00
33 |
--------------------------------------------------------------------------------
/tmp/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from torch.optim.lr_scheduler import _LRScheduler
3 |
4 |
5 | class CosineAnnealingWarmUpRestarts(_LRScheduler):
6 |     """Cosine annealing with a linear warm-up phase and warm restarts.
7 |
8 |     Every cycle of length ``T_i`` first raises the learning rate linearly
9 |     from the optimizer's base lr to ``eta_max`` over ``T_up`` steps, then
10 |     follows a half cosine from ``eta_max`` back down to the base lr. At each
11 |     restart the cycle length becomes ``(T_i - T_up) * T_mult + T_up`` and
12 |     ``eta_max`` is multiplied by ``gamma``.
13 |
14 |     Args:
15 |         optimizer: wrapped optimizer; each param group's lr at construction
16 |             time becomes that group's floor (``base_lrs``).
17 |         T_0: length of the first cycle, a positive int.
18 |         T_mult: growth factor of the post-warm-up part per restart, int >= 1.
19 |         eta_max: peak learning rate of the first cycle.
20 |         T_up: number of linear warm-up steps per cycle, int >= 0.
21 |         gamma: factor applied to ``eta_max`` once per completed cycle.
22 |         last_epoch: index of the last epoch; -1 starts a fresh schedule.
23 |     """
24 |
25 |     def __init__(self, optimizer, T_0, T_mult=1, eta_max=0.1, T_up=0, gamma=1., last_epoch=-1):
26 |         if T_0 <= 0 or not isinstance(T_0, int):
27 |             raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
28 |         if T_mult < 1 or not isinstance(T_mult, int):
29 |             raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
30 |         if T_up < 0 or not isinstance(T_up, int):
31 |             # T_up == 0 (no warm-up) is accepted, so the message says non-negative
32 |             raise ValueError("Expected non-negative integer T_up, but got {}".format(T_up))
33 |         self.T_0 = T_0
34 |         self.T_mult = T_mult
35 |         self.base_eta_max = eta_max
36 |         self.eta_max = eta_max
37 |         self.T_up = T_up
38 |         self.T_i = T_0  # length of the current cycle
39 |         self.gamma = gamma
40 |         self.cycle = 0  # number of completed cycles
41 |         self.T_cur = last_epoch  # position inside the current cycle
42 |         super(CosineAnnealingWarmUpRestarts, self).__init__(optimizer, last_epoch)
43 |
44 |     def get_lr(self):
45 |         """Return the lr of every param group for the current position ``T_cur``."""
46 |         if self.T_cur == -1:
47 |             return self.base_lrs
48 |         elif self.T_cur < self.T_up:
49 |             # linear warm-up from base_lr to eta_max
50 |             return [(self.eta_max - base_lr) * self.T_cur / self.T_up + base_lr for base_lr in self.base_lrs]
51 |         else:
52 |             # half cosine from eta_max (T_cur == T_up) down towards base_lr (T_cur -> T_i)
53 |             return [base_lr + (self.eta_max - base_lr) * (1 + math.cos(math.pi * (self.T_cur - self.T_up) / (self.T_i - self.T_up))) / 2
54 |                     for base_lr in self.base_lrs]
55 |
56 |     def step(self, epoch=None):
57 |         """Advance one step, or jump directly to ``epoch`` when it is given."""
58 |         if epoch is None:
59 |             epoch = self.last_epoch + 1
60 |             self.T_cur = self.T_cur + 1
61 |             if self.T_cur >= self.T_i:
62 |                 # restart: begin the next (possibly longer) cycle
63 |                 self.cycle += 1
64 |                 self.T_cur = self.T_cur - self.T_i
65 |                 self.T_i = (self.T_i - self.T_up) * self.T_mult + self.T_up
66 |         else:
67 |             if epoch >= self.T_0:
68 |                 if self.T_mult == 1:
69 |                     self.T_cur = epoch % self.T_0
70 |                     self.cycle = epoch // self.T_0
71 |                 else:
72 |                     # closed form for geometrically growing cycles.
73 |                     # NOTE(review): unlike the incremental branch above, this
74 |                     # ignores T_up when growing T_i, so the two paths disagree
75 |                     # whenever T_up > 0.
76 |                     n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
77 |                     self.cycle = n
78 |                     self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
79 |                     self.T_i = self.T_0 * self.T_mult ** (n)
80 |             else:
81 |                 self.T_i = self.T_0
82 |                 self.T_cur = epoch
83 |
84 |         self.eta_max = self.base_eta_max * (self.gamma ** self.cycle)
85 |         self.last_epoch = math.floor(epoch)
86 |         for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
87 |             param_group['lr'] = lr
88 |         # this override bypasses _LRScheduler.step(), so record the applied
89 |         # lrs here as well; get_last_lr() returns self._last_lr
90 |         self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
--------------------------------------------------------------------------------
/tmp/Pilot_64:
--------------------------------------------------------------------------------
1 | 0.000000000000000000e+00
2 | 0.000000000000000000e+00
3 | 1.000000000000000000e+00
4 | 1.000000000000000000e+00
5 | 1.000000000000000000e+00
6 | 0.000000000000000000e+00
7 | 0.000000000000000000e+00
8 | 1.000000000000000000e+00
9 | 0.000000000000000000e+00
10 | 1.000000000000000000e+00
11 | 1.000000000000000000e+00
12 | 0.000000000000000000e+00
13 | 0.000000000000000000e+00
14 | 0.000000000000000000e+00
15 | 1.000000000000000000e+00
16 | 0.000000000000000000e+00
17 | 1.000000000000000000e+00
18 | 1.000000000000000000e+00
19 | 0.000000000000000000e+00
20 | 1.000000000000000000e+00
21 | 1.000000000000000000e+00
22 | 1.000000000000000000e+00
23 | 0.000000000000000000e+00
24 | 0.000000000000000000e+00
25 | 1.000000000000000000e+00
26 | 0.000000000000000000e+00
27 | 0.000000000000000000e+00
28 | 1.000000000000000000e+00
29 | 1.000000000000000000e+00
30 | 0.000000000000000000e+00
31 | 0.000000000000000000e+00
32 | 1.000000000000000000e+00
33 | 1.000000000000000000e+00
34 | 0.000000000000000000e+00
35 | 0.000000000000000000e+00
36 | 1.000000000000000000e+00
37 | 1.000000000000000000e+00
38 | 1.000000000000000000e+00
39 | 0.000000000000000000e+00
40 | 0.000000000000000000e+00
41 | 1.000000000000000000e+00
42 | 0.000000000000000000e+00
43 | 0.000000000000000000e+00
44 | 1.000000000000000000e+00
45 | 1.000000000000000000e+00
46 | 0.000000000000000000e+00
47 | 0.000000000000000000e+00
48 | 1.000000000000000000e+00
49 | 0.000000000000000000e+00
50 | 1.000000000000000000e+00
51 | 1.000000000000000000e+00
52 | 0.000000000000000000e+00
53 | 1.000000000000000000e+00
54 | 0.000000000000000000e+00
55 | 1.000000000000000000e+00
56 | 0.000000000000000000e+00
57 | 0.000000000000000000e+00
58 | 1.000000000000000000e+00
59 | 1.000000000000000000e+00
60 |
1.000000000000000000e+00 61 | 0.000000000000000000e+00 62 | 0.000000000000000000e+00 63 | 1.000000000000000000e+00 64 | 1.000000000000000000e+00 65 | 0.000000000000000000e+00 66 | 0.000000000000000000e+00 67 | 1.000000000000000000e+00 68 | 1.000000000000000000e+00 69 | 1.000000000000000000e+00 70 | 1.000000000000000000e+00 71 | 0.000000000000000000e+00 72 | 0.000000000000000000e+00 73 | 1.000000000000000000e+00 74 | 0.000000000000000000e+00 75 | 0.000000000000000000e+00 76 | 0.000000000000000000e+00 77 | 1.000000000000000000e+00 78 | 0.000000000000000000e+00 79 | 0.000000000000000000e+00 80 | 0.000000000000000000e+00 81 | 1.000000000000000000e+00 82 | 0.000000000000000000e+00 83 | 1.000000000000000000e+00 84 | 0.000000000000000000e+00 85 | 1.000000000000000000e+00 86 | 1.000000000000000000e+00 87 | 1.000000000000000000e+00 88 | 1.000000000000000000e+00 89 | 0.000000000000000000e+00 90 | 1.000000000000000000e+00 91 | 1.000000000000000000e+00 92 | 1.000000000000000000e+00 93 | 0.000000000000000000e+00 94 | 1.000000000000000000e+00 95 | 1.000000000000000000e+00 96 | 0.000000000000000000e+00 97 | 1.000000000000000000e+00 98 | 1.000000000000000000e+00 99 | 1.000000000000000000e+00 100 | 1.000000000000000000e+00 101 | 1.000000000000000000e+00 102 | 1.000000000000000000e+00 103 | 1.000000000000000000e+00 104 | 1.000000000000000000e+00 105 | 1.000000000000000000e+00 106 | 1.000000000000000000e+00 107 | 1.000000000000000000e+00 108 | 0.000000000000000000e+00 109 | 0.000000000000000000e+00 110 | 0.000000000000000000e+00 111 | 0.000000000000000000e+00 112 | 0.000000000000000000e+00 113 | 0.000000000000000000e+00 114 | 1.000000000000000000e+00 115 | 1.000000000000000000e+00 116 | 1.000000000000000000e+00 117 | 0.000000000000000000e+00 118 | 1.000000000000000000e+00 119 | 1.000000000000000000e+00 120 | 0.000000000000000000e+00 121 | 0.000000000000000000e+00 122 | 1.000000000000000000e+00 123 | 0.000000000000000000e+00 124 | 0.000000000000000000e+00 125 | 
1.000000000000000000e+00
126 | 0.000000000000000000e+00
127 | 0.000000000000000000e+00
128 | 1.000000000000000000e+00
129 |
--------------------------------------------------------------------------------
/Model_train.py:
--------------------------------------------------------------------------------
1 | '''
2 | seefun . Aug 2020.
3 | github.com/seefun | kaggle.com/seefun
4 | '''
5 |
6 | import numpy as np
7 | import h5py
8 | import torch
9 | import os
10 | import torch.nn as nn
11 | import random
12 |
13 | from Model_define_pytorch import AutoEncoder, DatasetFolder, NMSE_cuda, NMSELoss
14 |
15 | # Parameters for training
16 | gpu_list = '0'
17 | os.environ["CUDA_VISIBLE_DEVICES"] = gpu_list
18 |
19 |
20 | def seed_everything(seed=42):
21 |     """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic."""
22 |     random.seed(seed)
23 |     # was misspelled 'PYHTONHASHSEED'; note it only affects Python
24 |     # interpreters started after this point, not the running one
25 |     os.environ['PYTHONHASHSEED'] = str(seed)
26 |     np.random.seed(seed)
27 |     torch.manual_seed(seed)
28 |     torch.cuda.manual_seed(seed)
29 |     torch.backends.cudnn.deterministic = True
30 |     # cuDNN autotuning may select different kernels from run to run
31 |     torch.backends.cudnn.benchmark = False
32 |
33 |
34 | def _unwrap(model):
35 |     """Return the AutoEncoder itself, unwrapping nn.DataParallel if present."""
36 |     return model.module if isinstance(model, nn.DataParallel) else model
37 |
38 |
39 | def _set_quantization(model, enabled):
40 |     """Switch the encoder's quantizer and the decoder's dequantizer on or off."""
41 |     core = _unwrap(model)
42 |     core.encoder.quantization = enabled
43 |     core.decoder.quantization = enabled
44 |
45 |
46 | SEED = 42
47 | seed_everything(SEED)
48 |
49 | batch_size = 256
50 | epochs = 100
51 | learning_rate = 2e-3  # bigger to train faster
52 | num_workers = 4
53 | print_freq = 500
54 | train_test_ratio = 0.8
55 | # parameters for data
56 | feedback_bits = 128
57 | img_height = 16
58 | img_width = 32
59 | img_channels = 2
60 |
61 |
62 | # Model construction
63 | model = AutoEncoder(feedback_bits)
64 | _set_quantization(model, False)
65 |
66 | if len(gpu_list.split(',')) > 1:
67 |     model = torch.nn.DataParallel(model).cuda()  # model.module
68 | else:
69 |     model = model.cuda()
70 |
71 | criterion = NMSELoss(reduction='mean')  # nn.MSELoss()
72 | criterion_test = NMSELoss(reduction='sum')
73 |
74 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
75 |
76 |
77 | # Data loading
78 | data_load_address = './data'
79 | # read the whole dataset into memory, then let the HDF5 file close
80 | with h5py.File(data_load_address + '/H_train.mat', 'r') as mat:
81 |     data = np.transpose(mat['H_train'][()])  # shape=(320000, 1024)
82 | data = data.astype('float32')
83 | data = np.reshape(data, [len(data), img_channels, img_height, img_width])
84 | # split data for training(80%) and validation(20%)
85 | np.random.shuffle(data)
86 | start = int(data.shape[0] * train_test_ratio)
87 | x_train, x_test = data[:start], data[start:]
88 |
89 | # dataLoader for training
90 | train_dataset = DatasetFolder(x_train)
91 | train_loader = torch.utils.data.DataLoader(
92 |     train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True, drop_last=True)
93 |
94 | # dataLoader for validation
95 | test_dataset = DatasetFolder(x_test)
96 | test_loader = torch.utils.data.DataLoader(
97 |     test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
98 |
99 | # torch.save does not create missing directories
100 | os.makedirs('./Modelsave', exist_ok=True)
101 |
102 | best_loss = 100
103 | for epoch in range(epochs):
104 |     print('========================')
105 |     print('lr:%.4e' % optimizer.param_groups[0]['lr'])
106 |     # model training: the first epochs // 10 epochs run without quantization
107 |     model.train()
108 |     _set_quantization(model, epoch >= epochs // 10)
109 |
110 |     if epoch == epochs // 4 * 3:
111 |         optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr'] * 0.25
112 |
113 |     for i, batch in enumerate(train_loader):
114 |         batch = batch.cuda()
115 |         output = model(batch)
116 |
117 |         loss = criterion(output, batch)
118 |         loss.backward()
119 |         optimizer.step()
120 |         optimizer.zero_grad()
121 |
122 |         if i % print_freq == 0:
123 |             print('Epoch: [{0}][{1}/{2}]\t'
124 |                   'Loss {loss:.4f}\t'.format(
125 |                       epoch, i, len(train_loader), loss=loss.item()))
126 |
127 |     # model evaluation, always with quantization enabled
128 |     model.eval()
129 |     _set_quantization(model, True)
130 |     total_loss = 0
131 |     with torch.no_grad():
132 |         for batch in test_loader:
133 |             batch = batch.cuda()
134 |             output = model(batch)
135 |             total_loss += criterion_test(output, batch).item()
136 |     average_loss = total_loss / len(test_dataset)
137 |     print('NMSE %.4f' % average_loss)
138 |     if average_loss < best_loss:
139 |         # save encoder and decoder weights separately
140 |         core = _unwrap(model)
141 |         modelSave1 = './Modelsave/encoder.pth.tar'
142 |         torch.save({'state_dict': core.encoder.state_dict(), }, modelSave1)
143 |         modelSave2 = './Modelsave/decoder.pth.tar'
144 |         torch.save({'state_dict': core.decoder.state_dict(), }, modelSave2)
145 |         print('Model saved!')
146 |         best_loss = average_loss
147 |
--------------------------------------------------------------------------------
/Model_define_pytorch.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """An Implement of an autoencoder with pytorch.
3 | This is the template code for 2020 NIAC https://naic.pcl.ac.cn/.
4 | The code is based on the sample code with tensorflow for 2020 NIAC and it can only run with GPUS.
5 | If you have any questions, please contact me with https://github.com/xufana7/AutoEncoder-with-pytorch
6 | Author, Fan xu Aug 2020
7 |
8 | changed by seefun Aug 2020
9 | github.com/seefun | kaggle.com/seefun
10 | """
11 | import numpy as np
12 | import torch.nn as nn
13 | import torch
14 | import torch.nn.functional as F
15 | from torch.utils.data import Dataset
16 | from collections import OrderedDict
17 |
18 |
19 | # This part implements the quantization and dequantization operations.
20 | # The output of the encoder must be the bitstream.
21 | def Num2Bit(Num, B):
22 |     """Expand every value of ``Num`` into its ``B`` lowest bits, MSB first.
23 |
24 |     Args:
25 |         Num: 2-D tensor (batch, n) of integer-valued numbers.
26 |         B: number of bits per value.
27 |
28 |     Returns:
29 |         float32 tensor of shape (batch, n * B) holding 0./1. bits.
30 |     """
31 |     # int64 rather than uint8, so B is not capped by an 8-bit container
32 |     Num_ = Num.type(torch.long)
33 |     shifts = torch.arange(B - 1, -1, -1, device=Num_.device)  # B-1, ..., 1, 0
34 |     bit = (Num_.unsqueeze(-1) >> shifts) & 1
35 |     bit = bit.reshape(-1, Num_.shape[1] * B)
36 |     return bit.type(torch.float32)
37 |
38 |
39 | def Bit2Num(Bit, B):
40 |     """Inverse of Num2Bit: pack each group of ``B`` bits (MSB first) into a number.
41 |
42 |     Returns a float32 tensor of shape (batch, n_bits // B) on the same device
43 |     as ``Bit`` (this previously always allocated on CUDA).
44 |     """
45 |     Bit_ = Bit.type(torch.float32)
46 |     Bit_ = torch.reshape(Bit_, [-1, Bit_.shape[1] // B, B])
47 |     weights = 2.0 ** torch.arange(B - 1, -1, -1, device=Bit_.device, dtype=torch.float32)
48 |     return (Bit_ * weights).sum(dim=-1)
49 |
50 |
51 | class Quantization(torch.autograd.Function):
52 |     """Map values in [0, 1] to a B-bit bitstream; backward sums the bit gradients."""
53 |
54 |     @staticmethod
55 |     def forward(ctx, x, B):
56 |         ctx.constant = B
57 |         step = 2 ** B
58 |         out = torch.round(x * step - 0.5)
59 |         # x == 1.0 (e.g. a saturated sigmoid) rounds to 2**B, whose low B bits
60 |         # are all zero; clamp so it maps to the top level 2**B - 1 instead
61 |         out = torch.clamp(out, 0, step - 1)
62 |         out = Num2Bit(out, B)
63 |         return out
64 |
65 |     @staticmethod
66 |     def backward(ctx, grad_output):
67 |         # One gradient per forward argument; the constant B gets None.
68 |         # The gradient of a number is the sum of the gradients of its B bits.
69 |         b, _ = grad_output.shape
70 |         grad_num = torch.sum(grad_output.reshape(b, -1, ctx.constant), dim=2)
71 |         return grad_num, None
72 |
73 |
74 | class Dequantization(torch.autograd.Function):
75 |     """Map a B-bit bitstream back to level centres (k + 0.5) / 2**B."""
76 |
77 |     @staticmethod
78 |     def forward(ctx, x, B):
79 |         ctx.constant = B
80 |         step = 2 ** B
81 |         out = Bit2Num(x, B)
82 |         out = (out + 0.5) / step
83 |         return out
84 |
85 |     @staticmethod
86 |     def backward(ctx, grad_output):
87 |         # One gradient per forward argument; the constant B gets None.
88 |         # Each number's gradient is copied to every one of its B bits.
89 |         grad_bit = grad_output.repeat_interleave(ctx.constant, dim=1)
90 |         return grad_bit, None
91 |
92 |
93 | class QuantizationLayer(nn.Module):
94 |     """nn.Module wrapper around Quantization with a fixed bit width B."""
95 |
96 |     def __init__(self, B):
97 |         super(QuantizationLayer, self).__init__()
98 |         self.B = B
99 |
100 |     def forward(self, x):
101 |         out = Quantization.apply(x, self.B)
102 |         return out
103 |
104 |
105 | class DequantizationLayer(nn.Module):
106 |     """nn.Module wrapper around Dequantization with a fixed bit width B."""
107 |
108 |     def __init__(self, B):
109 |         super(DequantizationLayer, self).__init__()
110 |         self.B = B
111 |
112 |     def forward(self, x):
113 |         out = Dequantization.apply(x, self.B)
114 |         return out
115 |
116 |
117 | def conv3x3(in_planes, out_planes, stride=1):
118 |     """3x3 convolution with padding"""
119 |     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
120 |                      padding=1, bias=True)
121 |
122 |
123 | class ConvBN(nn.Sequential):
124 |     """Bias-free Conv2d with padding (k - 1) // 2 per axis, followed by BatchNorm2d."""
125 |
126 |     def __init__(self, in_planes, out_planes, kernel_size, stride=1, groups=1):
127 |         if not isinstance(kernel_size, int):
128 |             padding = [(i - 1) // 2 for i in kernel_size]
129 |         else:
130 |             padding = (kernel_size - 1) // 2
131 |         super(ConvBN, self).__init__(OrderedDict([
132 |             ('conv', nn.Conv2d(in_planes, out_planes, kernel_size, stride,
133 |                                padding=padding, groups=groups, bias=False)),
134 |             ('bn', nn.BatchNorm2d(out_planes))
135 |         ]))
136 |
137 |
138 | class CRBlock(nn.Module):
139 |     """Residual block: two parallel conv paths on 32 channels, fused by a 1x1 conv."""
140 |
141 |     def __init__(self):
142 |         super(CRBlock, self).__init__()
143 |         self.path1 = nn.Sequential(OrderedDict([
144 |             ('conv3x3', ConvBN(32, 32, 3)),
145 |             ('relu1', nn.LeakyReLU(negative_slope=0.3, inplace=True)),
146 |             ('conv1x9', ConvBN(32, 32, [1, 9])),
147 |             ('relu2', nn.LeakyReLU(negative_slope=0.3, inplace=True)),
148 |             ('conv9x1', ConvBN(32, 32, [9, 1])),
149 |         ]))
150 |         self.path2 = nn.Sequential(OrderedDict([
151 |             ('conv1x5', ConvBN(32, 32, [1, 5])),
152 |             ('relu', nn.LeakyReLU(negative_slope=0.3, inplace=True)),
153 |             ('conv5x1', ConvBN(32, 32, [5, 1])),
154 |         ]))
155 |         self.conv1x1 = ConvBN(32 * 2, 32, 1)
156 |         self.identity = nn.Identity()
157 |         self.relu = nn.LeakyReLU(negative_slope=0.3, inplace=True)
158 |
159 |     def forward(self, x):
160 |         identity = self.identity(x)
161 |
162 |         out1 = self.path1(x)
163 |         out2 = self.path2(x)
164 |         out = torch.cat((out1, out2), dim=1)
165 |         out = self.relu(out)
166 |         out = self.conv1x1(out)
167 |
168 |         out = self.relu(out + identity)
169 |         return out
170 |
171 |
172 | class Encoder(nn.Module):
173 |     """Compress a 2-channel input whose H * W is 512 (e.g. 16x32) to feedback bits.
174 |
175 |     With ``quantization`` off the ``feedback_bits // B`` sigmoid outputs are
176 |     returned instead of the bitstream.
177 |     """
178 |     B = 4
179 |
180 |     def __init__(self, feedback_bits, quantization=True):
181 |         super(Encoder, self).__init__()
182 |         self.encoder1 = nn.Sequential(OrderedDict([
183 |             ("conv3x3_bn", ConvBN(2, 32, 3)),
184 |             ("relu1", nn.LeakyReLU(negative_slope=0.3, inplace=True)),
185 |             ("conv1x9_bn", ConvBN(32, 32, [1, 9])),
186 |             ("relu2", nn.LeakyReLU(negative_slope=0.3, inplace=True)),
187 |             ("conv9x1_bn", ConvBN(32, 32, [9, 1])),
188 |         ]))
189 |         self.encoder2 = ConvBN(2, 32, 3)
190 |         self.encoder_conv = nn.Sequential(OrderedDict([
191 |             ("relu1", nn.LeakyReLU(negative_slope=0.3, inplace=True)),
192 |             ("conv1x1_bn", ConvBN(32*2, 2, 1)),
193 |             ("relu2", nn.LeakyReLU(negative_slope=0.3, inplace=True)),
194 |         ]))
195 |
196 |         self.fc = nn.Linear(1024, feedback_bits // self.B)
197 |         self.sig = nn.Sigmoid()
198 |         self.quantize = QuantizationLayer(self.B)
199 |         self.quantization = quantization
200 |
201 |     def forward(self, x):
202 |         encode1 = self.encoder1(x)
203 |         encode2 = self.encoder2(x)
204 |         out = torch.cat((encode1, encode2), dim=1)
205 |         out = self.encoder_conv(out)
206 |         out = out.view(-1, 1024)
207 |         out = self.fc(out)
208 |         out = self.sig(out)
209 |         if self.quantization:
210 |             out = self.quantize(out)
211 |         return out
212 |
213 |
214 | class Decoder(nn.Module):
215 |     """Reconstruct a (N, 2, 16, 32) tensor in (0, 1) from the encoder output."""
216 |     B = 4
217 |
218 |     def __init__(self, feedback_bits, quantization=True):
219 |         super(Decoder, self).__init__()
220 |         self.feedback_bits = feedback_bits
221 |         self.dequantize = DequantizationLayer(self.B)
222 |         self.fc = nn.Linear(feedback_bits // self.B, 1024)
223 |         decoder = OrderedDict([
224 |             ("conv5x5_bn", ConvBN(2, 32, 5)),
225 |             ("relu", nn.LeakyReLU(negative_slope=0.3, inplace=True)),
226 |             ("CRBlock1", CRBlock()),
227 |             ("CRBlock2", CRBlock()),
228 |         ])
229 |         self.decoder_feature = nn.Sequential(decoder)
230 |         self.out_cov = conv3x3(32, 2)
231 |         self.sig = nn.Sigmoid()
232 |         self.quantization = quantization
233 |
234 |     def forward(self, x):
235 |         if self.quantization:
236 |             out = self.dequantize(x)
237 |         else:
238 |             out = x
239 |         out = out.view(-1, self.feedback_bits // self.B)
240 |         out = self.fc(out)
241 |         out = out.view(-1, 2, 16, 32)
242 |         out = self.decoder_feature(out)
243 |         out = self.out_cov(out)
244 |         out = self.sig(out)
245 |         return out
246 |
247 |
248 | # Note: Do not modify following class and keep it in your submission.
249 | # feedback_bits is 128 by default.
250 | class AutoEncoder(nn.Module):
251 |
252 |     def __init__(self, feedback_bits):
253 |         super(AutoEncoder, self).__init__()
254 |         self.encoder = Encoder(feedback_bits)
255 |         self.decoder = Decoder(feedback_bits)
256 |
257 |     def forward(self, x):
258 |         feature = self.encoder(x)
259 |         out = self.decoder(feature)
260 |         return out
261 |
262 |
263 | def NMSE(x, x_hat):
264 |     """NumPy NMSE averaged over samples, reading real/imag from the LAST axis.
265 |
266 |     NOTE(review): NMSE_cuda and the (N, 2, H, W) training data put real/imag
267 |     on axis 1 instead; confirm the array layout before calling this.
268 |     """
269 |     x_real = np.reshape(x[:, :, :, 0], (len(x), -1))
270 |     x_imag = np.reshape(x[:, :, :, 1], (len(x), -1))
271 |     x_hat_real = np.reshape(x_hat[:, :, :, 0], (len(x_hat), -1))
272 |     x_hat_imag = np.reshape(x_hat[:, :, :, 1], (len(x_hat), -1))
273 |     # remove the 0.5 offset before forming complex values
274 |     x_C = x_real - 0.5 + 1j * (x_imag - 0.5)
275 |     x_hat_C = x_hat_real - 0.5 + 1j * (x_hat_imag - 0.5)
276 |     power = np.sum(abs(x_C) ** 2, axis=1)
277 |     mse = np.sum(abs(x_C - x_hat_C) ** 2, axis=1)
278 |     nmse = np.mean(mse / power)
279 |     return nmse
280 |
281 |
282 | def NMSE_cuda(x, x_hat):
283 |     """Per-sample NMSE for (N, 2, H, W) tensors; channels 0/1 are real/imag offset by 0.5."""
284 |     x_real = x[:, 0, :, :].view(len(x), -1) - 0.5
285 |     x_imag = x[:, 1, :, :].view(len(x), -1) - 0.5
286 |     x_hat_real = x_hat[:, 0, :, :].view(len(x_hat), -1) - 0.5
287 |     x_hat_imag = x_hat[:, 1, :, :].view(len(x_hat), -1) - 0.5
288 |     power = torch.sum(x_real**2 + x_imag**2, dim=1)
289 |     mse = torch.sum((x_real-x_hat_real)**2 + (x_imag-x_hat_imag)**2, dim=1)
290 |     nmse = mse/power
291 |     return nmse
292 |
293 |
294 | class NMSELoss(nn.Module):
295 |     """NMSE loss; reduction='mean' averages over samples, anything else sums."""
296 |
297 |     def __init__(self, reduction='sum'):
298 |         super(NMSELoss, self).__init__()
299 |         self.reduction = reduction
300 |
301 |     def forward(self, x_hat, x):
302 |         nmse = NMSE_cuda(x, x_hat)
303 |         if self.reduction == 'mean':
304 |             nmse = torch.mean(nmse)
305 |         else:
306 |             nmse = torch.sum(nmse)
307 |         return nmse
308 |
309 |
310 | def Score(NMSE):
311 |     score = 1 - NMSE
312 |     return score
313 |
314 |
315 | # dataLoader
316 | class DatasetFolder(Dataset):
317 |     """Map-style Dataset over an in-memory array; item i is row i of the array."""
318 |
319 |     def __init__(self, matData):
320 |         self.matdata = matData
321 |
322 |     def __len__(self):
323 |         return self.matdata.shape[0]
324 |
325 |     def __getitem__(self, index):
326 |         return self.matdata[index]  # the sample is both input and reconstruction target
327 |
--------------------------------------------------------------------------------
/tmp/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import numpy as np
3 | import math
4 | import os
5 |
6 | mu = 2  # bits per modulated symbol
7 | K = 256  # number of subcarriers
8 | CP = 32  # cyclic-prefix length in samples
9 |
10 |
11 | def print_something():
12 |     print('utils.py has been loaded perfectly')
13 |
14 |
15 | def Clipping(x, CL):
16 |     """Clip samples whose magnitude exceeds CL times the RMS of ``x``, keeping their phase.
17 |
18 |     Works on a copy, so the caller's array is left untouched.
19 |     """
20 |     x_clipped = np.array(x, copy=True)
21 |     sigma = np.sqrt(np.mean(np.square(np.abs(x_clipped))))
22 |     threshold = CL * sigma
23 |     clipped_idx = np.abs(x_clipped) > threshold
24 |     x_clipped[clipped_idx] = np.divide((x_clipped[clipped_idx] * threshold), np.abs(x_clipped[clipped_idx]))
25 |     return x_clipped
26 |
27 |
28 | def PAPR(x):
29 |     """Peak-to-average power ratio of ``x`` in dB."""
30 |     Power = np.abs(x) ** 2
31 |     PeakP = np.max(Power)
32 |     AvgP = np.mean(Power)
33 |     PAPR_dB = 10 * np.log10(PeakP / AvgP)
34 |     return PAPR_dB
35 |
36 |
37 | def Modulation(bits, mu):
38 |     """Map each group of ``mu`` bits to a QPSK symbol (only the first two bits are used)."""
39 |     bit_r = bits.reshape((int(len(bits) / mu), mu))
40 |     return 0.7071 * (2 * bit_r[:, 0] - 1) + 0.7071j * (2 * bit_r[:, 1] - 1)  # This is just for QAM modulation
41 |
42 |
43 | def deModulation(Q):
44 |     """Hard-decision inverse of Modulation: one bit from the sign of each of real/imag."""
45 |     Q = np.ravel(Q)
46 |     # sized from the input; this used to be hard-coded to 64 symbols
47 |     bits = np.zeros([len(Q), 2])
48 |     bits[:, 0] = np.real(Q) > 0
49 |     bits[:, 1] = np.imag(Q) > 0
50 |     return bits.reshape([-1])  # This is just for QAM modulation
51 |
52 |
53 | def Modulation1(bits, mu):
54 |     bit_r = bits.reshape((int(len(bits) / mu), mu))
55 |     return (bit_r[:, 0]) + 1j * (bit_r[:, 1])
56 |
57 |
58 | def IDFT(OFDM_data):
59 |     return np.fft.ifft(OFDM_data)
60 |
61 |
62 | def addCP(OFDM_time, CP, CP_flag, mu, K):
63 |     """Prepend a cyclic prefix of ``CP`` samples to ``OFDM_time``.
64 |
65 |     When ``CP_flag == False`` the prefix is copied from a freshly generated
66 |     random QPSK symbol of ``K`` subcarriers instead of from ``OFDM_time``.
67 |     """
68 |     if CP_flag == False:
69 |         # add noise CP
70 |         bits_noise = np.random.binomial(n=1, p=0.5, size=(K * mu,))
71 |         codeword_noise = Modulation(bits_noise, mu)
72 |         OFDM_time_noise = np.fft.ifft(codeword_noise)
73 |         cp = OFDM_time_noise[-CP:]
74 |     else:
75 |         cp = OFDM_time[-CP:]  # take the last CP samples ...
76 |     return np.hstack([cp, OFDM_time])  # ... and add them to the beginning
77 |
78 |
79 | def channel(signal, channelResponse, SNRdb):
80 |     """Convolve with ``channelResponse`` and add complex Gaussian noise set by ``SNRdb``."""
81 |     convolved = np.convolve(signal, channelResponse)
82 |     # NOTE(review): 0.35 looks like an assumed average signal power; confirm
83 |     sigma2 = 0.35 * 10 ** (-SNRdb / 10)
84 |     noise = np.sqrt(sigma2 / 2) * (np.random.randn(*convolved.shape) + 1j * np.random.randn(*convolved.shape))
85 |     return convolved + noise
86 |
87 |
88 | def removeCP(signal, CP, K):
89 |     return signal[CP:(CP + K)]
90 |
91 |
92 | def DFT(OFDM_RX):
93 |     return np.fft.fft(OFDM_RX)
94 |
95 |
96 | def equalize(OFDM_demod, Hest):
97 |     return OFDM_demod / Hest
98 |
99 |
100 | def get_payload(equalized, dataCarriers=None):
101 |     """Return the entries of ``equalized`` at the data-subcarrier indices.
102 |
103 |     The indices used to come from a module-level ``dataCarriers`` that this
104 |     module never defines, so every call raised NameError; pass them in.
105 |     """
106 |     if dataCarriers is None:
107 |         raise ValueError('get_payload needs the dataCarriers indices')
108 |     return equalized[dataCarriers]
109 |
110 |
111 | def PS(bits):
112 |     return bits.reshape((-1,))
113 |
114 |
115 | def ofdm_simulate(codeword, channelResponse, SNRdb, mu, CP_flag, K, P, CP, pilotValue, pilotCarriers, dataCarriers,
116 |                   Clipping_Flag):
117 |     """Pass a pilot-only OFDM symbol and a data OFDM symbol through ``channel``.
118 |
119 |     Returns:
120 |         ``(features, CC)``: ``features`` concatenates the real and imaginary
121 |         parts of the received pilot symbol, then those of the received data
122 |         symbol (both after CP removal and FFT); ``CC`` is the received pilot
123 |         symbol divided by the largest of its real/imaginary components.
124 |
125 |     ``P`` and ``dataCarriers`` are accepted but unused.
126 |     """
127 |     CR = 1  # clipping ratio handed to Clipping when Clipping_Flag is set
128 |
129 |     # --- training inputs: pilots only, every other subcarrier is zero ----
130 |     OFDM_data = np.zeros(K, dtype=complex)
131 |     OFDM_data[pilotCarriers] = pilotValue  # allocate the pilot subcarriers
132 |     OFDM_time = IDFT(OFDM_data)
133 |     # NOTE(review): 2 * K here (K for the data symbol below) only changes the
134 |     # length of the random symbol a noise CP is drawn from; kept unchanged.
135 |     OFDM_withCP = addCP(OFDM_time, CP, CP_flag, mu, 2 * K)
136 |     OFDM_TX = OFDM_withCP
137 |     if Clipping_Flag:
138 |         OFDM_TX = Clipping(OFDM_TX, CR)  # add clipping
139 |     OFDM_RX = channel(OFDM_TX, channelResponse, SNRdb)
140 |     OFDM_RX_noCP = removeCP(OFDM_RX, CP, K)
141 |     OFDM_RX_noCP = np.fft.fft(OFDM_RX_noCP)
142 |
143 |     # ----- target inputs: the modulated codeword on all subcarriers ---
144 |     codeword_qam = Modulation(codeword, mu)
145 |     if len(codeword_qam) != K:
146 |         print('length of code word is not equal to K, error !!')
147 |     OFDM_data_codeword = codeword_qam
148 |     OFDM_time_codeword = np.fft.ifft(OFDM_data_codeword)
149 |     OFDM_withCP_cordword = addCP(OFDM_time_codeword, CP, CP_flag, mu, K)
150 |     if Clipping_Flag:
151 |         OFDM_withCP_cordword = Clipping(OFDM_withCP_cordword, CR)  # add clipping
152 |     OFDM_RX_codeword = channel(OFDM_withCP_cordword, channelResponse, SNRdb)
153 |     OFDM_RX_noCP_codeword = removeCP(OFDM_RX_codeword, CP, K)
154 |     OFDM_RX_noCP_codeword = np.fft.fft(OFDM_RX_noCP_codeword)
155 |
156 |     AA = np.concatenate((np.real(OFDM_RX_noCP), np.imag(OFDM_RX_noCP)))
157 |     CC = OFDM_RX_noCP / np.max(AA)
158 |     BB = np.concatenate((np.real(OFDM_RX_noCP_codeword), np.imag(OFDM_RX_noCP_codeword)))
159 |
160 |     return np.concatenate((AA, BB)), CC  # sparse_mask
161 |
162 |
163 | def MIMO(X, HMIMO, SNRdb, flag, P):
164 |     """Simulate bit streams X[0], X[1] over the channel responses HMIMO[0..3].
165 |
166 |     Pilot bits are loaded from (or, on first use, written to) the file
167 |     ``'Pilot_' + str(2 * P)`` in the current working directory. ``flag``:
168 |     1 -> random-symbol CP, no clipping; 2 -> random-symbol CP with clipping;
169 |     anything else -> regular cyclic prefix without clipping.
170 |
171 |     Returns the two summed receive outputs concatenated, split into 8 equal
172 |     blocks and interleaved element-wise into one flat array.
173 |     """
174 |     P = P * 2
175 |     Pilot_file_name = 'Pilot_' + str(P)
176 |     if os.path.isfile(Pilot_file_name):
177 |         bits = np.loadtxt(Pilot_file_name, delimiter=',')
178 |     else:
179 |         bits = np.random.binomial(n=1, p=0.5, size=(P * mu,))
180 |         np.savetxt(Pilot_file_name, bits, delimiter=',')
181 |     pilotValue = Modulation(bits, mu)
182 |
183 |     if flag == 1:
184 |         cpflag, CR = 0, 0
185 |     elif flag == 2:
186 |         cpflag, CR = 0, 1
187 |     else:
188 |         cpflag, CR = 1, 0
189 |     allCarriers = np.arange(K)
190 |     pilotCarriers = np.arange(0, K, K // P)
191 |     pilot_set = set(pilotCarriers.tolist())  # O(1) membership instead of scanning the array
192 |     dataCarriers = [val for val in allCarriers if val not in pilot_set]
193 |
194 |     bits0 = X[0]
195 |     bits1 = X[1]
196 |     # even-indexed pilots go with stream 0, odd-indexed pilots with stream 1
197 |     pilotCarriers1 = pilotCarriers[0:P:2]
198 |     pilotCarriers2 = pilotCarriers[1:P:2]
199 |     signal_output00, para = ofdm_simulate(bits0, HMIMO[0, :], SNRdb, mu, cpflag, K, P, CP, pilotValue[0:P:2],
200 |                                           pilotCarriers1, dataCarriers, CR)
201 |     signal_output01, para = ofdm_simulate(bits0, HMIMO[1, :], SNRdb, mu, cpflag, K, P, CP, pilotValue[0:P:2],
202 |                                           pilotCarriers1, dataCarriers, CR)
203 |     signal_output10, para = ofdm_simulate(bits1, HMIMO[2, :], SNRdb, mu, cpflag, K, P, CP, pilotValue[1:P:2],
204 |                                           pilotCarriers2, dataCarriers, CR)
205 |     signal_output11, para = ofdm_simulate(bits1, HMIMO[3, :], SNRdb, mu, cpflag, K, P, CP, pilotValue[1:P:2],
206 |                                           pilotCarriers2, dataCarriers, CR)
207 |
208 |     # superpose both streams at each of the two receive outputs
209 |     signal_output0 = signal_output00 + signal_output10
210 |     signal_output1 = signal_output01 + signal_output11
211 |     output = np.concatenate((signal_output0, signal_output1))
212 |     output = np.transpose(np.reshape(output, [8, -1]), [1, 0])
213 |
214 |     return np.reshape(output, [-1])
215 |
--------------------------------------------------------------------------------
/Model_define_pytorch_score_09945.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """An Implement of an autoencoder with pytorch.
3 | This is the template code for 2020 NIAC https://naic.pcl.ac.cn/.
4 | The code is based on the sample code with tensorflow for 2020 NIAC and it can only run with GPUS.
5 | If you have any questions, please contact me with https://github.com/xufana7/AutoEncoder-with-pytorch
6 | Author, Fan xu Aug 2020
7 | """
8 | import numpy as np
9 | import torch.nn as nn
10 | import torch
11 | import torch.nn.functional as F
12 | from torch.utils.data import Dataset
13 | from collections import OrderedDict
14 |
15 | # from torchvision.models import resnet18
16 |
17 | channel_num = 768
18 |
19 | # This part implement the quantization and dequantization operations.
20 | # The output of the encoder must be the bitstream.
def Num2Bit(Num, B):
    """Expand integer codes into their B-bit binary representation.

    Args:
        Num: 2-D tensor ``(batch, n_codes)`` holding non-negative
            integer-valued codes. Values are cast to ``torch.uint8`` first,
            and only the low ``B`` bits of each code are kept.
        B: number of bits per code.

    Returns:
        float32 tensor ``(batch, n_codes * B)``; the bits of each code are
        laid out most-significant bit first.
    """
    Num_ = Num.type(torch.uint8).long()
    # Do the bit arithmetic in int64: evaluating powers of two inside uint8
    # overflows once the exponent reaches 8 (which broke B > 4 before).
    shifts = torch.arange(B - 1, -1, -1, device=Num_.device)
    bit = (Num_.unsqueeze(-1) // (2 ** shifts)) % 2
    return bit.reshape(-1, Num_.shape[1] * B).type(torch.float32)


def Bit2Num(Bit, B):
    """Inverse of ``Num2Bit``: pack each group of B bits back into a number.

    Args:
        Bit: 2-D tensor ``(batch, n_codes * B)`` of bits, MSB first per code.
        B: number of bits per code.

    Returns:
        float32 tensor ``(batch, n_codes)`` on the same device as ``Bit``.
    """
    Bit_ = Bit.type(torch.float32)
    Bit_ = torch.reshape(Bit_, [-1, Bit_.shape[1] // B, B])
    # Build the place values on Bit's own device rather than hard-coding
    # .cuda(), so decoding also works on CPU and on non-default GPUs.
    weights = 2.0 ** torch.arange(B - 1, -1, -1, dtype=torch.float32,
                                  device=Bit_.device)
    return (Bit_ * weights).sum(dim=-1)


class Quantization(torch.autograd.Function):
    """Quantize values to B-bit codes and emit them as a flat bitstream.

    Forward maps ``x`` to ``round(x * 2**B - 0.5)`` and expands each code
    into B bits via ``Num2Bit``. For inputs in (0, 1) (e.g. sigmoid
    outputs) the codes fall in ``[0, 2**B - 1]``.
    """

    @staticmethod
    def forward(ctx, x, B):
        ctx.constant = B
        step = 2 ** B
        out = torch.round(x * step - 0.5)
        out = Num2Bit(out, B)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: the gradient of a number is the sum of
        # the gradients of its B bits. The non-tensor argument B gets None.
        b, _ = grad_output.shape
        grad_num = torch.sum(grad_output.reshape(b, -1, ctx.constant), dim=2)
        return grad_num, None


class Dequantization(torch.autograd.Function):
    """Turn a bitstream back into values at the centres of the 2**B bins."""

    @staticmethod
    def forward(ctx, x, B):
        ctx.constant = B
        step = 2 ** B
        out = Bit2Num(x, B)
        out = (out + 0.5) / step
        return out

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: copy each number's gradient to all of
        # its B bits. The non-tensor argument B gets None.
        grad_bit = grad_output.repeat_interleave(ctx.constant, dim=1)
        return grad_bit, None


class QuantizationLayer(nn.Module):
    """Module wrapper around ``Quantization`` with a fixed bit width B."""

    def __init__(self, B):
        super(QuantizationLayer, self).__init__()
        self.B = B

    def forward(self, x):
        out = Quantization.apply(x, self.B)
        return out


class DequantizationLayer(nn.Module):
    """Module wrapper around ``Dequantization`` with a fixed bit width B."""

    def __init__(self, B):
        super(DequantizationLayer, self).__init__()
        self.B = B

    def forward(self, x):
        out = Dequantization.apply(x, self.B)
        return out


class FakeQuantOp(torch.autograd.Function):
    """Quantize-then-dequantize in one step (no bitstream), with an STE."""

    @staticmethod
    def forward(ctx, x, num_bits=4):
        step = 2 ** num_bits
        out = torch.round(x * step - 0.5)
        x = (out + 0.5) / step
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pass the gradient unchanged.
        return grad_output, None


class FakeQuantLayer(nn.Module):
    """Module wrapper around ``FakeQuantOp`` with a fixed bit width B."""

    def __init__(self, B):
        super(FakeQuantLayer, self).__init__()
        self.B = B

    def forward(self, x):
        out = FakeQuantOp.apply(x, self.B)
        return out


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding 1 and a bias term."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=True)


class SEBlock(nn.Module):
    """Feature gating for flat ``(batch, in_ch)`` vectors.

    A two-layer bottleneck MLP (reduction ``r``) produces per-feature
    weights in (0, 1) that rescale the input.
    """

    def __init__(self, in_ch, r=8):
        super(SEBlock, self).__init__()

        self.linear_1 = nn.Linear(in_ch, in_ch // r)
        self.linear_2 = nn.Linear(in_ch // r, in_ch)

    def forward(self, x):
        input_x = x
        x = F.relu(self.linear_1(x), inplace=True)
        x = self.linear_2(x)
        x = torch.sigmoid(x)
        x = input_x * x
        return x


class Mish(nn.Module):
    """Mish activation: ``x * tanh(softplus(x))``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x * (torch.tanh(F.softplus(x)))


class Swish(nn.Module):
    """Swish / SiLU activation: ``x * sigmoid(x)``, optionally in place."""

    def __init__(self, inplace=False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        if self.inplace:
            x.mul_(torch.sigmoid(x))
            return x
        else:
            return x * torch.sigmoid(x)


class ActBNConv(nn.Sequential):
    """Swish -> BatchNorm2d(in_planes) -> bias-free Conv2d.

    ``kernel_size`` may be an int or a sequence; padding is
    ``(k - 1) // 2`` per dimension, which keeps the spatial size for odd
    kernels at stride 1.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, groups=1):
        if not isinstance(kernel_size, int):
            padding = [(i - 1) // 2 for i in kernel_size]
        else:
            padding = (kernel_size - 1) // 2
        super(ActBNConv, self).__init__(OrderedDict([
            ('act', Swish()),
            ('bn', nn.BatchNorm2d(in_planes)),
            ('conv', nn.Conv2d(in_planes, out_planes, kernel_size, stride,
                               padding=padding, groups=groups, bias=False))
        ]))


class SCSEModule(nn.Module):
    """Concurrent channel (cSE) and spatial (sSE) squeeze-and-excitation.

    Returns ``x * cSE(x) + x * sSE(x)``.
    """

    def __init__(self, in_channels, reduction=8):
        super().__init__()
        self.cSE = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // reduction, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // reduction, in_channels, 1),
            nn.Sigmoid(),
        )
        self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid())

    def forward(self, x):
        return x * self.cSE(x) + x * self.sSE(x)


class BasicBlock(nn.Module):
    """Residual block: two 3x3 ActBNConv layers + scSE attention + skip."""

    def __init__(self, in_channels):
        super().__init__()
        self.conv1 = ActBNConv(in_channels, in_channels, 3)
        self.conv2 = ActBNConv(in_channels, in_channels, 3)
        self.attention = SCSEModule(in_channels)
        self.identity = nn.Identity()

    def forward(self, x):
        identity = self.identity(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.attention(x)
        x = x + identity
        return x


class BasicBlock_2(nn.Module):
    """Channel-changing block: two 3x3 ActBNConv layers + scSE, no skip."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = ActBNConv(in_channels, out_channels, 3)
        self.conv2 = ActBNConv(out_channels, out_channels, 3)
        self.attention = SCSEModule(out_channels)
        # Not used in forward (there is no residual path here); it holds no
        # parameters, so it does not affect state_dict contents.
        self.identity = nn.Identity()

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.attention(x)
        return x


class Encoder(nn.Module):
    """Encode a 2-channel map into ``feedback_bits`` feedback values.

    ``quantization`` selects the output:
      * ``'check'`` - raw sigmoid outputs, ``(batch, feedback_bits / B)``;
      * truthy      - real bitstream from ``QuantizationLayer``,
                      ``(batch, feedback_bits)``;
      * falsy       - fake-quantized values (straight-through estimator).

    The flattening ``view(-1, 1024)`` requires ``2 * H * W == 1024``;
    presumably inputs are ``(batch, 2, 16, 32)`` to match the decoder's
    reshape - TODO confirm against the data pipeline.
    """
    B = 4

    def __init__(self, feedback_bits, quantization=True):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(2, channel_num, kernel_size=3, stride=1, padding=1, bias=True)
        self.encoder1 = nn.Sequential(OrderedDict([
            ("conv3x3_bn", ActBNConv(channel_num, channel_num, 3)),
            ("conv1x9_bn", ActBNConv(channel_num, channel_num, [1, 9])),
            ("conv9x1_bn", ActBNConv(channel_num, channel_num, [9, 1])),
        ]))
        self.encoder2 = ActBNConv(channel_num, channel_num // 4 * 3, 3)
        self.encoder3 = ActBNConv(channel_num, channel_num // 4, 5)
        # Input is cat(encoder1, encoder2, encoder3, conv1 output):
        # channel_num + 3/4 + 1/4 + 1 = 3 * channel_num channels.
        self.encoder_conv = nn.Sequential(OrderedDict([
            ("conv1x1_bn_1", ActBNConv(channel_num * 3, channel_num, 1)),
            ("EncoderBlock", BasicBlock_2(channel_num, channel_num)),
            ("conv1x1_bn_2", ActBNConv(channel_num, 2, 1)),
        ]))
        self.relu = nn.LeakyReLU(negative_slope=0.3, inplace=True)
        self.fc = nn.Linear(1024, int(feedback_bits / self.B))
        self.sig = nn.Sigmoid()
        self.quantize = QuantizationLayer(self.B)
        self.quantization = quantization
        self.fake_quantize = FakeQuantLayer(self.B)

    def forward(self, x):
        x = self.conv1(x)
        encode1 = self.encoder1(x)
        encode2 = self.encoder2(x)
        encode3 = self.encoder3(x)
        out = torch.cat([encode1, encode2, encode3, x], dim=1)
        out = self.encoder_conv(out)
        out = self.relu(out)
        out = out.view(-1, 1024)
        out = self.fc(out)
        out = self.sig(out)
        if self.quantization == 'check':
            return out
        if self.quantization:
            return self.quantize(out)
        return self.fake_quantize(out)


class Decoder(nn.Module):
    """Reconstruct the 2-channel map from the encoder's feedback.

    ``quantization`` mirrors ``Encoder``: truthy (other than ``'check'``)
    dequantizes a bitstream; ``'check'`` or falsy use the input values
    directly. The output passes through a sigmoid and is shaped
    ``(batch, 2, 16, 32)`` (from the ``view`` below, at stride-1 convs).
    """
    B = 4

    def __init__(self, feedback_bits, quantization=True):
        super(Decoder, self).__init__()
        self.feedback_bits = feedback_bits
        n_codes = int(feedback_bits / self.B)
        self.dequantize = DequantizationLayer(self.B)
        # NOTE(review): there is deliberately no activation after the last
        # Linear. An earlier revision wrote ``nn.Linear(128, n, nn.Sigmoid())``,
        # which passes the Sigmoid as Linear's ``bias`` argument (truthy ->
        # bias=True), so no sigmoid ever ran. Behaviour and state_dict keys are
        # kept identical for existing checkpoints; confirm whether a sigmoid
        # was intended before adding one.
        self.offset = nn.Sequential(
            nn.Linear(n_codes, 128), nn.LeakyReLU(), nn.BatchNorm1d(128),
            nn.Linear(128, n_codes, bias=True),
        )
        self.fc = nn.Linear(n_codes, 1024)
        self.se = SEBlock(1024)
        self.bn = nn.BatchNorm1d(1024)
        self.relu = nn.LeakyReLU(negative_slope=0.3, inplace=True)
        decoder = OrderedDict([
            ("conv5x5", nn.Conv2d(2, channel_num, kernel_size=5, stride=1, padding=2, bias=True)),
            ("DecoderBlock1", BasicBlock(channel_num)),
            ("DecoderBlock2", BasicBlock(channel_num)),
            ("DecoderBlock3", BasicBlock(channel_num)),
            ("DecoderBlock4", BasicBlock(channel_num)),
            ("DecoderBlock5", BasicBlock_2(channel_num, channel_num // 4)),
            ("DecoderBlock6", BasicBlock(channel_num // 4)),
            ("DecoderBlock7", BasicBlock(channel_num // 4)),
            ("DecoderBlock8", BasicBlock(channel_num // 4)),
            ("DecoderBlock9", BasicBlock_2(channel_num // 4, channel_num // 8)),
            ("DecoderBlock10", BasicBlock(channel_num // 8)),
            ("DecoderBlock11", BasicBlock(channel_num // 8)),
        ])
        self.decoder_feature = nn.Sequential(decoder)
        self.out_cov = ActBNConv(channel_num // 8, 2, 3)
        self.sig = nn.Sigmoid()
        self.quantization = quantization

    def forward(self, x):
        # Only a real bitstream (truthy quantization other than 'check')
        # needs dequantizing; otherwise the input is already numeric.
        use_dequant = self.quantization and self.quantization != 'check'
        out = self.dequantize(x) if use_dequant else x
        out = out.view(-1, int(self.feedback_bits / self.B))
        if self.quantization:
            # Learned correction scaled by the quantization step 1 / 2**B.
            # 'check' is a truthy string, so it is applied in that mode too.
            out = out + self.offset(out) / (2 ** (self.B))
        out = self.fc(out)
        out = self.relu(out)
        out = self.se(out)
        out = self.bn(out)
        out = out.view(-1, 2, 16, 32)
        out = self.decoder_feature(out)
        out = self.out_cov(out)
        out = self.sig(out)
        return out


# Note: Do not modify following class and keep it in your submission.
# feedback_bits is 128 by default.
366 | class AutoEncoder(nn.Module): 367 | 368 | def __init__(self, feedback_bits): 369 | super(AutoEncoder, self).__init__() 370 | self.encoder = Encoder(feedback_bits) 371 | self.decoder = Decoder(feedback_bits) 372 | 373 | def forward(self, x): 374 | feature = self.encoder(x) 375 | out = self.decoder(feature) 376 | return out 377 | 378 | 379 | def NMSE(x, x_hat): 380 | x_real = np.reshape(x[:, :, :, 0], (len(x), -1)) 381 | x_imag = np.reshape(x[:, :, :, 1], (len(x), -1)) 382 | x_hat_real = np.reshape(x_hat[:, :, :, 0], (len(x_hat), -1)) 383 | x_hat_imag = np.reshape(x_hat[:, :, :, 1], (len(x_hat), -1)) 384 | x_C = x_real - 0.5 + 1j * (x_imag - 0.5) 385 | x_hat_C = x_hat_real - 0.5 + 1j * (x_hat_imag - 0.5) 386 | power = np.sum(abs(x_C) ** 2, axis=1) 387 | mse = np.sum(abs(x_C - x_hat_C) ** 2, axis=1) 388 | nmse = np.mean(mse / power) 389 | return nmse 390 | 391 | def NMSE_cuda(x, x_hat): 392 | x_real = x[:, 0, :, :].view(len(x),-1) - 0.5 393 | x_imag = x[:, 1, :, :].view(len(x),-1) - 0.5 394 | x_hat_real = x_hat[:, 0, :, :].view(len(x_hat), -1) - 0.5 395 | x_hat_imag = x_hat[:, 1, :, :].view(len(x_hat), -1) - 0.5 396 | power = torch.sum(x_real**2 + x_imag**2, axis=1) 397 | mse = torch.sum((x_real-x_hat_real)**2 + (x_imag-x_hat_imag)**2, axis=1) 398 | nmse = mse/power 399 | return nmse 400 | 401 | 402 | class NMSELoss(nn.Module): 403 | def __init__(self, reduction='sum'): 404 | super(NMSELoss, self).__init__() 405 | self.reduction = reduction 406 | 407 | def forward(self, x_hat, x): 408 | nmse = NMSE_cuda(x, x_hat) 409 | if self.reduction == 'mean': 410 | nmse = torch.mean(nmse) 411 | else: 412 | nmse = torch.sum(nmse) 413 | return nmse 414 | 415 | 416 | class MSE_NMSELoss(nn.Module): 417 | def __init__(self, alpha = 100., reduction='sum'): 418 | super(MSE_NMSELoss, self).__init__() 419 | self.reduction = reduction 420 | self.MSELoss = nn.MSELoss(reduction=self.reduction) 421 | self.alpha = alpha 422 | 423 | def forward(self, x_hat, x): 424 | nmse = NMSE_cuda(x, 
x_hat) 425 | if self.reduction == 'mean': 426 | nmse = torch.mean(nmse) 427 | else: 428 | nmse = torch.sum(nmse) 429 | loss = nmse + self.alpha * self.MSELoss(x, x_hat) 430 | return loss 431 | 432 | 433 | def Score(NMSE): 434 | score = 1 - NMSE 435 | return score 436 | 437 | 438 | # dataLoader 439 | class DatasetFolder(Dataset): 440 | 441 | def __init__(self, matData): 442 | self.matdata = matData 443 | 444 | def __len__(self): 445 | return self.matdata.shape[0] 446 | 447 | def __getitem__(self, index): 448 | return self.matdata[index] #, self.matdata[index] 449 | -------------------------------------------------------------------------------- /tmp/train_multigpu.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Libs & Settings" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import numpy as np\n", 17 | "import pandas as pd\n", 18 | "import torch\n", 19 | "import os\n", 20 | "import torch.nn as nn\n", 21 | "from torch.nn import functional as F\n", 22 | "from torch.utils.data import DataLoader, Dataset\n", 23 | "\n", 24 | "import struct\n", 25 | "from utils import *\n", 26 | "import random\n", 27 | "import math\n", 28 | "\n", 29 | "import matplotlib.pyplot as plt\n", 30 | "\n", 31 | "from apex import amp\n", 32 | "import time\n", 33 | "\n", 34 | "import scipy.special\n", 35 | "sigmoid = lambda x: scipy.special.expit(x)\n", 36 | "\n", 37 | "# Parameters for training\n", 38 | "gpu_list = '0,1'\n", 39 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = gpu_list\n", 40 | "\n", 41 | "SEED = 42" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "def seed_everything(seed=SEED):\n", 51 | " random.seed(seed)\n", 52 | " os.environ['PYHTONHASHSEED'] = str(seed)\n", 53 | " 
np.random.seed(seed)\n", 54 | " torch.manual_seed(seed)\n", 55 | " torch.cuda.manual_seed(seed)\n", 56 | " torch.backends.cudnn.deterministic = True\n", 57 | "\n", 58 | "seed_everything(SEED)" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "# parameters for training\n", 68 | "batch_size = 512\n", 69 | "epochs = 200\n", 70 | "warmup_epoch = 2\n", 71 | "learning_rate = 1e-3 \n", 72 | "lr_div_factor = 20.0\n", 73 | "num_workers = 32\n", 74 | "val_times = 10\n", 75 | "\n", 76 | "# parameters for data\n", 77 | "# mode [0,1,2]\n", 78 | "# SNRdb (8to12)\n", 79 | "Pilotnum = 32 # Y1:32 Y2:8 \n", 80 | "NO_NOISE = False #False # SNR=100 else (8to12) \n", 81 | "FIX_MODE = False # mode=MODE else random choose from [0,1,2]\n", 82 | "MODE = 0\n", 83 | "\n", 84 | "# param of training data and criterion setting\n", 85 | "RESHAPE = True # false: fc true: conv\n", 86 | "# NUM_SAMPLE = 20000\n", 87 | "MSE_or_BCE = 0 # 0 MSE 1 BCE" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": {}, 93 | "source": [ 94 | "## Dataset" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "data_train=open('../input/H.bin','rb')\n", 104 | "H1=struct.unpack('f'*2*2*2*32*320000,data_train.read(4*2*2*2*32*320000))\n", 105 | "H1=np.reshape(H1,[320000,2,4,32])\n", 106 | "H=H1[:,1,:,:]+1j*H1[:,0,:,:]\n", 107 | "H_train = H\n", 108 | "# random.shuffle(H)\n", 109 | "\n", 110 | "# H_train = H[:30000] \n", 111 | "# H_val = H[300000:]\n", 112 | "\n", 113 | "# H_train = H_train[:30000] # for debug\n", 114 | "# H_val = H_val[:2000] # for debug\n", 115 | "\n", 116 | "data_val=open('../input/H_val.bin','rb')\n", 117 | "H1=struct.unpack('f'*2*2*2*32*2000,data_val.read(4*2*2*2*32*2000))\n", 118 | "H1=np.reshape(H1,[2000,2,4,32])\n", 119 | "H_val=H1[:,1,:,:]+1j*H1[:,0,:,:]" 120 | ] 121 | }, 122 | { 123 | "cell_type": 
"code", 124 | "execution_count": null, 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "# def ifft(y):#256子载波 * mimo2 * 2(导频/数据) * 2(实部/虚部)\n", 129 | "# y_antenna0 = 20*y[:,0,:,:] #256子载波 * 2(导频/数据) * 2(实部/虚部)\n", 130 | "# y_antenna1 = 20*y[:,1,:,:] #256子载波 * 2(导频/数据) * 2(实部/虚部)\n", 131 | " \n", 132 | "# y_pilot0 = y_antenna0[:,0,0] + y_antenna0[:,0,1]*1j\n", 133 | "# y_data0 = y_antenna0[:,1,0] + y_antenna0[:,1,1]*1j\n", 134 | "# y_pilot0 = np.fft.ifft(y_pilot0)\n", 135 | "# y_data0 = np.fft.ifft(y_data0)\n", 136 | " \n", 137 | "# y_pilot1 = y_antenna1[:,0,0] + y_antenna1[:,0,1]*1j\n", 138 | "# y_data1 = y_antenna1[:,1,0] + y_antenna1[:,1,1]*1j\n", 139 | "# y_pilot1 = np.fft.ifft(y_pilot1)\n", 140 | "# y_data1 = np.fft.ifft(y_data1)\n", 141 | "\n", 142 | "# y[:,0,1,0] = y_data0.real\n", 143 | "# y[:,0,1,1] = y_data0.imag\n", 144 | "# y[:,0,0,0] = y_pilot0.real \n", 145 | "# y[:,0,0,1] = y_pilot0.imag\n", 146 | "# y[:,1,1,0] = y_data1.real\n", 147 | "# y[:,1,1,1] = y_data1.imag\n", 148 | "# y[:,1,0,0] = y_pilot1.real\n", 149 | "# y[:,1,0,1] = y_pilot1.imag\n", 150 | "# return y" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "class OFDMDataset(Dataset):\n", 160 | " def __init__(self, H, Pilotnum=32, train=True): # Pilotnum: 32(Y1) or 8(Y2)\n", 161 | " self.H = H\n", 162 | " self.Pilotnum = Pilotnum\n", 163 | " self.train = train\n", 164 | "\n", 165 | " def __len__(self):\n", 166 | " return len(self.H) #NUM_SAMPLE\n", 167 | "\n", 168 | " def __getitem__(self, idx):\n", 169 | " binomial = torch.distributions.binomial.Binomial(total_count=1, probs=0.5*torch.ones(128*4))\n", 170 | " self.bits0 = binomial.sample().numpy()\n", 171 | " self.bits1 = binomial.sample().numpy()\n", 172 | " X = [self.bits0, self.bits1]\n", 173 | " HH = self.H[idx] #self.H[torch.randint(0,len(self.H),size=(1,))]\n", 174 | " if NO_NOISE:\n", 175 | " SNRdb = 100\n", 176 | " 
elif self.train:\n", 177 | " SNRdb = torch.Tensor(1,).uniform_(7.95, 12.1).numpy() # 8 to 12\n", 178 | " else:\n", 179 | " SNRdb = torch.Tensor(1,).uniform_(8.0,12.0).numpy() # 8 to 12\n", 180 | " if not FIX_MODE:\n", 181 | " mode = torch.randint(0,3,size=(1,))\n", 182 | " else:\n", 183 | " mode = MODE \n", 184 | " YY = MIMO(X, HH, SNRdb, mode, self.Pilotnum) / 20 ###\n", 185 | " XX = np.concatenate((self.bits0, self.bits1), 0)\n", 186 | " if RESHAPE:\n", 187 | " YY = YY.reshape(256, 2, 2, 2) # 256子载波 * mimo2 * 2(导频/数据) * 2(实部/虚部)\n", 188 | " #YY = YY.transpose(2,1,3,0).reshape(2,4,256) #2(导频/数据) x 4 x 256\n", 189 | " return XX, YY" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": null, 195 | "metadata": {}, 196 | "outputs": [], 197 | "source": [ 198 | "dataset_train = OFDMDataset(H_train, Pilotnum, True)\n", 199 | "dataset_val = OFDMDataset(H_val, Pilotnum, False) # *= val_times\n", 200 | "\n", 201 | "train_loader = torch.utils.data.DataLoader(\n", 202 | " dataset_train, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)\n", 203 | "\n", 204 | "val_loader = torch.utils.data.DataLoader(\n", 205 | " dataset_val, batch_size=batch_size, shuffle=False, num_workers=num_workers)" 206 | ] 207 | }, 208 | { 209 | "cell_type": "markdown", 210 | "metadata": {}, 211 | "source": [ 212 | "## Model" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "class Mish(nn.Module):\n", 222 | " def __init__(self):\n", 223 | " super().__init__()\n", 224 | "\n", 225 | " def forward(self, x):\n", 226 | " return x *( torch.tanh(F.softplus(x)))\n", 227 | "\n", 228 | "\n", 229 | "class Swish(nn.Module):\n", 230 | " def __init__(self, inplace=False):\n", 231 | " super().__init__()\n", 232 | " self.inplace = inplace\n", 233 | "\n", 234 | " def forward(self, x):\n", 235 | " if self.inplace:\n", 236 | " x.mul_(torch.sigmoid(x))\n", 237 | " return 
x\n", 238 | " else:\n", 239 | " return x * torch.sigmoid(x)" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": null, 245 | "metadata": {}, 246 | "outputs": [], 247 | "source": [ 248 | "class ResConv(nn.Module):\n", 249 | " def __init__(self, dim=512, k=(1,1), groups=1, dropout=0.25):\n", 250 | " super().__init__()\n", 251 | " self.dense = nn.Sequential(\n", 252 | " nn.BatchNorm2d(dim),\n", 253 | " nn.ReLU(inplace=True),\n", 254 | " nn.Conv2d(dim, dim, k, bias=False, groups=groups),\n", 255 | " nn.Dropout(dropout),\n", 256 | " nn.BatchNorm2d(dim),\n", 257 | " nn.ReLU(inplace=True),\n", 258 | " nn.Conv2d(dim, dim, k, bias=False, groups=groups),\n", 259 | " )\n", 260 | " def forward(self, x):\n", 261 | " x = x + self.dense(x)\n", 262 | " return x\n", 263 | "\n", 264 | "\n", 265 | "multiple = 16\n", 266 | "class Net(nn.Module):\n", 267 | " def __init__(self):\n", 268 | " super().__init__()\n", 269 | "\n", 270 | " self.conv_pilot = nn.Sequential(\n", 271 | " nn.Conv2d(256, 1024, (2,2), bias=False),\n", 272 | " ResConv(dim=1024, k=(1,1)),\n", 273 | " ResConv(dim=1024, k=(1,1)),\n", 274 | " nn.Conv2d(1024, 1024*multiple, (1,1), bias=True),\n", 275 | " )\n", 276 | " \n", 277 | " self.conv_data = nn.Sequential(\n", 278 | " nn.Conv2d(256, 1024*multiple, (2,2), bias=False, groups=256),\n", 279 | " ResConv(dim=1024*multiple, k=(1,1), groups=256),\n", 280 | " ResConv(dim=1024*multiple, k=(1,1), groups=256),\n", 281 | " ResConv(dim=1024*multiple, k=(1,1), groups=256), \n", 282 | " nn.Conv2d(1024*multiple, 1024*multiple, (1,1), bias=True, groups=256),\n", 283 | " )\n", 284 | " \n", 285 | " self.conv = nn.Sequential(\n", 286 | " nn.Conv2d(1024*multiple, 1024*multiple, (2,1), bias=False, groups=256),\n", 287 | " ResConv(dim=1024*multiple, k=(1,1), groups=256),\n", 288 | " ResConv(dim=1024*multiple, k=(1,1), groups=256), \n", 289 | " ResConv(dim=1024*multiple, k=(1,1), groups=256), \n", 290 | " ResConv(dim=1024*multiple, k=(1,1), groups=256), \n", 291 | " 
ResConv(dim=1024*multiple, k=(1,1), groups=256), \n", 292 | " nn.Conv2d(1024*multiple, 1024, (1,1), bias=True, groups=256),\n", 293 | " )\n", 294 | " \n", 295 | " self.sigmoid = nn.Sigmoid()\n", 296 | " \n", 297 | " def forward(self, x):\n", 298 | " # input: # (bs,256,2,2,2)\n", 299 | " x_pilot = x[:,:,:,0,:] # (bs,256,2,2)\n", 300 | " x_data = x[:,:,:,1,:] # (bs,256,2,2)\n", 301 | " \n", 302 | " x_pilot = self.conv_pilot(x_pilot) # (bs,1024x,1,1)\n", 303 | " x_data = self.conv_data(x_data) # (bs,1024x,1,1)\n", 304 | " \n", 305 | " x = torch.cat([x_data,x_pilot], axis=2) # (bs,1024x,2,1)\n", 306 | " \n", 307 | " x = self.conv(x) # (bs,1024,1,1)\n", 308 | " \n", 309 | " x = torch.squeeze(x) # (bs,1024)\n", 310 | " x1 = x[:,0::2]\n", 311 | " x2 = x[:,1::2]\n", 312 | " x = torch.cat([x1,x2], axis=1) # (bs,1024)\n", 313 | " \n", 314 | " x = self.sigmoid(x)\n", 315 | " \n", 316 | " return x" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": null, 322 | "metadata": { 323 | "scrolled": true 324 | }, 325 | "outputs": [], 326 | "source": [ 327 | "model = Net()\n", 328 | "model.cuda()" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": null, 334 | "metadata": {}, 335 | "outputs": [], 336 | "source": [ 337 | "from torch.nn.modules.loss import _WeightedLoss\n", 338 | "\n", 339 | "class SmoothBCEwLogits(_WeightedLoss):\n", 340 | " def __init__(self, weight=None, reduction='mean', smoothing=0.0):\n", 341 | " super().__init__(weight=weight, reduction=reduction)\n", 342 | " self.smoothing = smoothing\n", 343 | " self.weight = weight\n", 344 | " self.reduction = reduction\n", 345 | " \n", 346 | " @staticmethod\n", 347 | " def _smooth(targets:torch.Tensor, n_labels:int, smoothing=0.0):\n", 348 | " assert 0<=smoothing<1\n", 349 | " with torch.no_grad():\n", 350 | " targets = targets * (1.0-smoothing) + 0.5*smoothing\n", 351 | " return targets\n", 352 | " \n", 353 | " def forward(self, inputs, targets):\n", 354 | " targets = 
SmoothBCEwLogits._smooth(targets, inputs.size(-1), self.smoothing)\n", 355 | " loss = F.binary_cross_entropy_with_logits(inputs, targets, self.weight)\n", 356 | " \n", 357 | " if self.reduction == 'sum':\n", 358 | " loss = loss.sum()\n", 359 | " elif self.reduction == 'mean':\n", 360 | " loss = loss.mean()\n", 361 | " else:\n", 362 | " loss = loss.mean()\n", 363 | " \n", 364 | " return loss" 365 | ] 366 | }, 367 | { 368 | "cell_type": "code", 369 | "execution_count": null, 370 | "metadata": {}, 371 | "outputs": [], 372 | "source": [ 373 | "if not MSE_or_BCE:\n", 374 | " criterion = nn.MSELoss()\n", 375 | " criterion_test = nn.MSELoss()\n", 376 | " thre = 0.5\n", 377 | "else:\n", 378 | " criterion = SmoothBCEwLogits(smoothing=0.01) #nn.BCEWithLogitsLoss()\n", 379 | " criterion_test = nn.BCEWithLogitsLoss()\n", 380 | " thre = 0.0 " 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": null, 386 | "metadata": {}, 387 | "outputs": [], 388 | "source": [ 389 | "from lr_scheduler import CosineAnnealingWarmUpRestarts\n", 390 | "\n", 391 | "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate/lr_div_factor, weight_decay=2e-5)\n", 392 | "\n", 393 | "# model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=0)\n", 394 | "\n", 395 | "if len(gpu_list.split(',')) > 1:\n", 396 | " model = torch.nn.DataParallel(model)\n", 397 | "\n", 398 | "T = len(train_loader) * epochs \n", 399 | "T_up = len(train_loader) * warmup_epoch\n", 400 | "scheduler = CosineAnnealingWarmUpRestarts(optimizer, T_0=T, T_mult=2, eta_max=learning_rate, T_up=T_up, gamma=0.1)" 401 | ] 402 | }, 403 | { 404 | "cell_type": "markdown", 405 | "metadata": {}, 406 | "source": [ 407 | "## Train" 408 | ] 409 | }, 410 | { 411 | "cell_type": "code", 412 | "execution_count": null, 413 | "metadata": {}, 414 | "outputs": [], 415 | "source": [ 416 | "best_ber = 1.0" 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": null, 422 | "metadata": 
{}, 423 | "outputs": [], 424 | "source": [ 425 | "# if len(gpu_list.split(',')) > 1:\n", 426 | "# model.module.load_state_dict(torch.load('model%d.pth'%Pilotnum))\n", 427 | "# else:\n", 428 | "# model.load_state_dict(torch.load('model%d.pth'%Pilotnum))" 429 | ] 430 | }, 431 | { 432 | "cell_type": "code", 433 | "execution_count": null, 434 | "metadata": { 435 | "scrolled": true 436 | }, 437 | "outputs": [], 438 | "source": [ 439 | "for epoch in range(epochs):\n", 440 | " start_time = time.time()\n", 441 | " print('========================')\n", 442 | " print('lr:%.4e'%optimizer.param_groups[0]['lr']) \n", 443 | " \n", 444 | " # model training\n", 445 | " model.train()\n", 446 | " losses = []\n", 447 | " for i, (X,Y) in enumerate(train_loader):\n", 448 | " X,Y = X.float().cuda(), Y.float().cuda()\n", 449 | " output = model(Y)\n", 450 | " loss = criterion(output, X)\n", 451 | " loss.backward()\n", 452 | "# with amp.scale_loss(loss, optimizer) as scaled_loss:\n", 453 | "# scaled_loss.backward()\n", 454 | " torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2)\n", 455 | " optimizer.step()\n", 456 | " optimizer.zero_grad()\n", 457 | " scheduler.step()\n", 458 | " losses.append(loss.item())\n", 459 | " \n", 460 | " avg_tr_loss = sum(losses)/len(losses)\n", 461 | " print('Epoch: [{0}]\\t'\n", 462 | " 'Loss {loss:.4f}\\t'.format(\n", 463 | " epoch, loss=avg_tr_loss))\n", 464 | " \n", 465 | " model.eval()\n", 466 | " eval_losses = []\n", 467 | " ber_list = []\n", 468 | " for _ in range(val_times):\n", 469 | " with torch.no_grad():\n", 470 | " for i, (X,Y) in enumerate(val_loader):\n", 471 | " X,Y = X.float().cuda(), Y.float().cuda()\n", 472 | " output = model(Y)\n", 473 | " eval_losses.append(criterion_test(output, X).item())\n", 474 | " ber = ((output.detach() > thre) == X.bool()).cpu().numpy().mean() #mse: 0.5 bce: 0\n", 475 | " ber_list.append(ber)\n", 476 | " avg_eval_loss = sum(eval_losses)/len(eval_losses)\n", 477 | " avg_eval_ber = 1 - 
sum(ber_list)/len(ber_list)\n", 478 | " print('Val Loss: %.4f | Val BER: %.5f'%(avg_eval_loss,avg_eval_ber))\n", 479 | " \n", 480 | " if avg_eval_ber < best_ber:\n", 481 | " \n", 482 | " if len(gpu_list.split(',')) > 1:\n", 483 | " torch.save(model.module.state_dict(), 'model%d.pth'%Pilotnum)\n", 484 | " else:\n", 485 | " torch.save(model.state_dict(), 'model%d.pth'%Pilotnum)\n", 486 | " best_ber = avg_eval_ber\n", 487 | " print('Model saved!')\n", 488 | " \n", 489 | " end_time = time.time()\n", 490 | " print('Time cost:%ds'%(round(end_time-start_time)))" 491 | ] 492 | }, 493 | { 494 | "cell_type": "code", 495 | "execution_count": null, 496 | "metadata": {}, 497 | "outputs": [], 498 | "source": [ 499 | "# change hyper-parameters\n", 500 | "# transformer or other networks\n", 501 | "# add EMA\n", 502 | "# Ensemble\n", 503 | "\n", 504 | "print(\"BEST BER: %.4f\"%best_ber)" 505 | ] 506 | }, 507 | { 508 | "cell_type": "markdown", 509 | "metadata": {}, 510 | "source": [ 511 | "## Inference" 512 | ] 513 | }, 514 | { 515 | "cell_type": "code", 516 | "execution_count": null, 517 | "metadata": { 518 | "scrolled": true 519 | }, 520 | "outputs": [], 521 | "source": [ 522 | "if len(gpu_list.split(',')) > 1:\n", 523 | " model = model.module\n", 524 | " \n", 525 | "model.load_state_dict(torch.load('model%d.pth'%Pilotnum))\n", 526 | " \n", 527 | "model.eval()" 528 | ] 529 | }, 530 | { 531 | "cell_type": "code", 532 | "execution_count": null, 533 | "metadata": {}, 534 | "outputs": [], 535 | "source": [ 536 | "if Pilotnum==32:\n", 537 | " name = '../input/Y_1.csv'\n", 538 | "elif Pilotnum==8:\n", 539 | " name = '../input/Y_2.csv'\n", 540 | "with open(name) as f:\n", 541 | " Y = f.readlines()\n", 542 | " \n", 543 | "for idx, line in enumerate(Y): \n", 544 | " Y[idx] = list(map(float, line.split(',')))\n", 545 | " \n", 546 | "Y = np.array(Y)\n", 547 | "if RESHAPE:\n", 548 | " Y = Y.reshape(Y.shape[0],256,2,2,2)" 549 | ] 550 | }, 551 | { 552 | "cell_type": "code", 553 | 
"execution_count": null, 554 | "metadata": {}, 555 | "outputs": [], 556 | "source": [ 557 | "test_batch = 100\n", 558 | "result = []\n", 559 | "save = []\n", 560 | "\n", 561 | "with torch.no_grad():\n", 562 | " for i in range(int(len(Y)/test_batch)):\n", 563 | " start = test_batch*i\n", 564 | " end = min(test_batch*(i+1), len(Y))\n", 565 | " input_Y = torch.from_numpy(Y[start:end].astype(np.float32)).cuda()\n", 566 | " output = model(input_Y)\n", 567 | " output = output.detach().cpu().numpy() \n", 568 | " if MSE_or_BCE: # bce\n", 569 | " save.append(sigmoid(output))\n", 570 | " else:\n", 571 | " save.append(output)\n", 572 | " output = output > thre\n", 573 | " result.append(output)\n", 574 | "\n", 575 | "result = np.concatenate(result, axis=0)\n", 576 | "save = np.concatenate(save, axis=0)" 577 | ] 578 | }, 579 | { 580 | "cell_type": "code", 581 | "execution_count": null, 582 | "metadata": {}, 583 | "outputs": [], 584 | "source": [ 585 | "if Pilotnum==32:\n", 586 | " np.save('X_pre_1_%.4f.npy'%best_ber, save)\n", 587 | "elif Pilotnum==8:\n", 588 | " np.save('X_pre_2_%.4f.npy'%best_ber, save)\n", 589 | "else:\n", 590 | " print('Check the Pilotnum param!!!')\n", 591 | "\n", 592 | "if Pilotnum==32:\n", 593 | " result.tofile('X_pre_1.bin')\n", 594 | "elif Pilotnum==8:\n", 595 | " result.tofile('X_pre_2.bin')\n", 596 | "else:\n", 597 | " print('Check the Pilotnum param!!!')" 598 | ] 599 | }, 600 | { 601 | "cell_type": "code", 602 | "execution_count": null, 603 | "metadata": {}, 604 | "outputs": [], 605 | "source": [] 606 | } 607 | ], 608 | "metadata": { 609 | "kernelspec": { 610 | "display_name": "Python 3", 611 | "language": "python", 612 | "name": "python3" 613 | }, 614 | "language_info": { 615 | "codemirror_mode": { 616 | "name": "ipython", 617 | "version": 3 618 | }, 619 | "file_extension": ".py", 620 | "mimetype": "text/x-python", 621 | "name": "python", 622 | "nbconvert_exporter": "python", 623 | "pygments_lexer": "ipython3", 624 | "version": "3.7.4" 625 | } 626 
| }, 627 | "nbformat": 4, 628 | "nbformat_minor": 2 629 | } 630 | --------------------------------------------------------------------------------