├── README.md
├── data_gen
│   └── data_set.py
├── main.py
├── nets
│   ├── auto_encoder.py
│   ├── auto_encoder_for_vis.py
│   └── meta_net.py
├── run
│   ├── realistic
│   │   ├── vanilla
│   │   │   ├── fixed_init.sh
│   │   │   ├── joint.sh
│   │   │   └── meta.sh
│   │   └── with_RTN
│   │       ├── fixed_init.sh
│   │       ├── joint.sh
│   │       └── meta.sh
│   └── toy
│       ├── vanilla
│       │   ├── fixed_init.sh
│       │   ├── joint.sh
│       │   └── meta.sh
│       └── with_RTN
│           ├── fixed_init.sh
│           ├── joint.sh
│           └── meta.sh
├── saved_data
│   ├── realistic
│   │   ├── channels
│   │   │   ├── meta_training_channels
│   │   │   │   └── training_channels.pckl
│   │   │   └── new_channels
│   │   │       └── test_channels.pckl
│   │   └── nets
│   │       ├── vanilla
│   │       │   ├── joint
│   │       │   │   └── init_net
│   │       │   └── meta
│   │       │       └── init_net
│   │       └── with_RTN
│   │           ├── joint
│   │           │   └── init_net
│   │           └── meta
│   │               └── init_net
│   └── toy
│       ├── after_adapted_to_new_channels_nets
│       │   ├── vanilla
│       │   │   └── meta
│       │   │       └── after_adapt
│       │   │           └── 1_adapt_steps
│       │   │               ├── 0th_adapted_net
│       │   │               └── 1th_adapted_net
│       │   └── with_RTN
│       │       └── meta
│       │           └── after_adapt
│       │               └── 1_adapt_steps
│       │                   ├── 0th_adapted_net
│       │                   └── 1th_adapted_net
│       └── nets
│           ├── vanilla
│           │   ├── joint
│           │   │   └── init_net
│           │   └── meta
│           │       └── init_net
│           └── with_RTN
│               ├── joint
│               │   └── init_net
│               └── meta
│                   └── init_net
├── toy_visualization.ipynb
├── training
│   ├── meta_train.py
│   ├── test.py
│   └── train.py
└── utils
    └── funcs.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Meta-Autoencoder

This repository contains code for "[Meta-Learning to Communicate: Fast End-to-End Training for Fading Channels](https://arxiv.org/abs/1910.09945)" by
Sangwoo Park, Osvaldo Simeone, and Joonhyuk Kang.

### Dependencies

This program is written in Python 3.7 and uses PyTorch 1.2.0 and SciPy.
TensorBoard for PyTorch is used for visualization (see https://pytorch.org/docs/stable/tensorboard.html).
- `pip install tensorboard` and `pip install scipy` may be needed.

### Basic Usage

- Train and test a model:

To train the autoencoder with the default settings, execute
```
python main.py
```
For the default settings and the other argument options, see the top of `main.py`.

Once training is complete, testing starts automatically with the trained model.

### Toy Example

- In the `run/toy` folder, the _meta-learning_, _joint training_, and _fixed initialization_ schemes can be tested with the two pretrained autoencoder architectures (vanilla autoencoder and autoencoder with RTN).

To train from scratch, remove the `--path_for_meta_trained_net` argument.

In the `saved_data/toy/nets` folder, the trained models used to generate Fig. 2 can be found (the proper paths are given in the shell scripts).

To regenerate Fig. 3, `toy_visualization.ipynb` may be useful.

### A More Realistic Scenario

- In the `run/realistic` folder, the _meta-learning_, _joint training_, and _fixed initialization_ schemes can be tested with the two pretrained autoencoder architectures (vanilla autoencoder and autoencoder with RTN).

To train from scratch, remove the `--path_for_meta_trained_net` argument.

To train and/or test with new channels, remove the `--path_for_meta_training_channels` and/or `--path_for_test_channels` arguments.

In the `saved_data/realistic/nets` folder, the trained models used to generate Fig. 4 can be found (the proper paths are given in the shell scripts).
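For example, to meta-train the vanilla autoencoder on the realistic scenario from scratch while reusing the saved channel realizations, the following command may serve as a template (a sketch based on `run/realistic/vanilla/meta.sh`, assuming it is executed from the repository root; adjust the relative paths if run from elsewhere):
```
python main.py --if_RBF --num_epochs_meta_train 10000 --if_adam_after_sgd --path_for_common_dir 'realistic/vanilla/meta/' --path_for_meta_training_channels 'saved_data/realistic/channels/meta_training_channels' --path_for_test_channels 'saved_data/realistic/channels/new_channels'
```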

--------------------------------------------------------------------------------
/data_gen/data_set.py:
--------------------------------------------------------------------------------
import torch
import numpy as np

def message_gen(k, mb_size):
    tot_message_num = pow(2, k)
    m = torch.zeros(mb_size, tot_message_num)
    label = torch.zeros(mb_size)
    for ind_mb in range(mb_size):
        if ind_mb % tot_message_num == 0:
            rand_lst = torch.randperm(tot_message_num)
        ind_one_rand_lst = ind_mb % tot_message_num
        ind_one = rand_lst[ind_one_rand_lst]
        m[ind_mb, ind_one] = 1
        label[ind_mb] = ind_one
    return m, label

def channel_set_gen(num_channels, tap_num, if_toy):
    channel_list = []
    for ind_channels in range(num_channels):
        if if_toy:
            assert tap_num == 1
            if ind_channels % 2 == 0:
                h_toy = torch.zeros(2 * tap_num)
                h_toy[0] = 1 * np.cos(np.pi/4)
                h_toy[1] = 1 * np.sin(np.pi/4)
            else:
                h_toy = torch.zeros(2 * tap_num)
                h_toy[0] = 1 * np.cos((3*np.pi) / 4)
                h_toy[1] = 1 * np.sin((3*np.pi) / 4)
            channel_list.append(h_toy)
        else:
            chan_var = 1 / (2 * tap_num)  # real and imaginary parts are generated independently, so each of the 2L real entries carries half of the per-tap variance 1/L
            Chan = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(2 * tap_num),
                                                                              chan_var * torch.eye(2 * tap_num))
            h = Chan.sample()
            channel_list.append(h)
    return channel_list
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import torch
import argparse
from data_gen.data_set import channel_set_gen
from training.train import test_training
from training.test import test_per_channel_per_snr
from nets.auto_encoder import dnn
from torch.utils.tensorboard import SummaryWriter
from training.meta_train import multi_task_learning
import pickle
import scipy.io as sio
import datetime
import os

def parse_args():
    parser = argparse.ArgumentParser(description='end_to_end-meta')

    # bit num (k), channel uses (n), tap number (L), number of pilots (P), Eb/N0
    parser.add_argument('--bit_num', type=int, default=4, help='number of bits')
    parser.add_argument('--channel_num', type=int, default=4, help='number of channel uses')
    parser.add_argument('--tap_num', type=int, default=3, help='number of channel taps (L)')
    parser.add_argument('--mb_size', type=int, default=16, help='minibatch size')
    parser.add_argument('--mb_size_meta_train', type=int, default=16,
                        help='minibatch size during meta-training (this can be useful for decreasing pilots)')
    parser.add_argument('--mb_size_meta_test', type=int, default=16,
                        help='minibatch size for query set (this can be useful for decreasing pilots)')
    parser.add_argument('--Eb_over_N_db', type=float, default=15,
                        help='energy per bit to noise power spectral density ratio')

    # paths
    parser.add_argument('--path_for_common_dir', dest='path_for_common_dir',
                        default='default_folder/default_subfolder/', type=str)
    parser.add_argument('--path_for_meta_training_channels', dest='path_for_meta_training_channels', default=None,
                        type=str)
    parser.add_argument('--path_for_test_channels', dest='path_for_test_channels', default=None, type=str)
    parser.add_argument('--path_for_meta_trained_net',
                        dest='path_for_meta_trained_net', default=None, type=str)

    # neural network architecture (number of neurons for hidden layer)
    parser.add_argument('--num_neurons_encoder', type=int, default=None, help='number of neurons in the hidden layer of the encoder')
    parser.add_argument('--num_neurons_decoder', type=int, default=None, help='number of neurons in the hidden layer of the decoder')
    # whether to use bias and relu (if not relu: tanh)
    parser.add_argument('--if_not_bias', dest='if_bias', action='store_false', default=True)
    parser.add_argument('--if_not_relu', dest='if_relu', action='store_false', default=True)
    # RTN
    parser.add_argument('--if_RTN', dest='if_RTN', action='store_true', default=False)
    # in case of running on gpu, index for cuda device
    parser.add_argument('--cuda_ind', type=int, default=0, help='index for cuda device')

    # experiment details (hyperparameters, amount of data for calculating performance, and meta-training settings)
    parser.add_argument('--lr_testtraining', type=float, default=0.001, help='lr for adaptation to a new channel')
    parser.add_argument('--lr_meta_update', type=float, default=0.01, help='lr during meta-training: outer-loop (initialization update) lr')
    parser.add_argument('--lr_meta_inner', type=float, default=0.1, help='lr during meta-training: inner-loop (local adaptation) lr')
    parser.add_argument('--test_size', type=int, default=1000000, help='number of messages to calculate BLER for test (new channel)')
    parser.add_argument('--num_channels_meta', type=int, default=100, help='number of meta-training channels (K)')
    parser.add_argument('--num_channels_test', type=int, default=20, help='number of new channels for test (to average the BLER)')
    parser.add_argument('--tasks_per_metaupdate', type=int, default=20, help='number of meta-training channels considered in one meta-update')
    parser.add_argument('--num_meta_local_updates', type=int, default=1, help='number of local adaptation steps during meta-training')
    parser.add_argument('--num_epochs_meta_train', type=int, default=10000,
                        help='number of epochs for meta-training')

    # if set, run joint training; if false: meta-learning
    parser.add_argument('--if_joint_training', dest='if_joint_training', action='store_true', default=False)  # else: meta-learning for multi-task learning
    # whether to use Adam optimizer to adapt to a new channel
    parser.add_argument('--if_not_test_training_adam', dest='if_test_training_adam', action='store_false',
                        default=True)
    # if run on toy example (Fig. 2 and 3)
    parser.add_argument('--if_toy', dest='if_toy', action='store_true',
                        default=False)
    # to run on a more realistic example (Fig. 4)
    parser.add_argument('--if_RBF', dest='if_RBF', action='store_true',
                        default=False)
    parser.add_argument('--test_per_adapt_fixed_Eb_over_N_value', type=int, default=15,
                        help='Eb/N0 in dB for test')
    # designed for MAML: SGD during args.num_meta_local_updates steps with args.lr_meta_inner, then the Adam optimizer with args.lr_testtraining
    parser.add_argument('--if_adam_after_sgd', dest='if_adam_after_sgd', action='store_true',
                        default=False)

    args = parser.parse_args()

    args.device = torch.device("cuda:" + str(args.cuda_ind) if torch.cuda.is_available() else "cpu")
    if args.num_neurons_encoder is None:  # unless specified, set the number of hidden neurons to the number of possible messages
        args.num_neurons_encoder = pow(2, args.bit_num)
    if args.num_neurons_decoder is None:
        args.num_neurons_decoder = pow(2, args.bit_num)

    if not args.if_test_training_adam:
        args.if_adam_after_sgd = False

    if args.if_toy:
        print('running for toy scenario')
        args.bit_num = 2
        args.channel_num = 1
        args.tap_num = 1
        args.mb_size = 4
        args.mb_size_meta_train = 4
        args.mb_size_meta_test = 4
        args.num_channels_meta = 20
        args.num_neurons_encoder = 4
        args.num_neurons_decoder = 4
    elif args.if_RBF:
        print('running for a more realistic scenario')
        args.bit_num = 4
        args.channel_num = 4
        args.tap_num = 3
        args.mb_size = 16
        args.mb_size_meta_train = 16
        args.mb_size_meta_test = 16
        args.num_channels_meta = 100
        args.num_neurons_encoder = 16
        args.num_neurons_decoder = 16
    else:
        print('running on custom environment')
    print('Running on device: {}'.format(args.device))
    return args

if __name__ == '__main__':
    args = parse_args()
    print('Called with args:')
    print(args)

    curr_time = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    common_dir = './' + args.path_for_common_dir + curr_time + '/'

    PATH_before_adapt = common_dir + 'saved_model/' + 'before_adapt/' + 'init_net'
    PATH_meta_intermediate = common_dir + 'saved_model/' + 'during_meta_training/' + 'epochs/'

    os.makedirs(common_dir + 'saved_model/' + 'before_adapt/')
    os.makedirs(common_dir + 'saved_model/' + 'after_adapt/')
    os.makedirs(PATH_meta_intermediate)

    os.makedirs(common_dir + 'meta_training_channels/')
    os.makedirs(common_dir + 'test_channels/')
    os.makedirs(common_dir + 'test_result/')

    dir_meta_training = common_dir + 'TB/' + 'meta_training'
    writer_meta_training = SummaryWriter(dir_meta_training)
    dir_during_adapt = common_dir + 'TB/' + 'during_adapt/'

    test_Eb_over_N_range = [args.test_per_adapt_fixed_Eb_over_N_value]
    test_adapt_range = [0, 1, 2, 5, 10, 100, 200, 1000, 10000]

    if len(test_Eb_over_N_range) > 1:
        assert len(test_adapt_range) == 1
    if len(test_adapt_range) > 1:
        assert len(test_Eb_over_N_range) == 1

    test_result_all_PATH = common_dir + 'test_result/' + 'test_result.mat'
    save_test_result_dict = {}

    actual_channel_num = args.channel_num * 2

    net = dnn(M=pow(2, args.bit_num), num_neurons_encoder=args.num_neurons_encoder, n=actual_channel_num, n_inv_filter=args.tap_num,
              num_neurons_decoder=args.num_neurons_decoder, if_bias=args.if_bias, if_relu=args.if_relu, if_RTN=args.if_RTN)
    if
torch.cuda.is_available(): 155 | net = net.to(args.device) 156 | net_for_testtraining = dnn(M=pow(2, args.bit_num), num_neurons_encoder=args.num_neurons_encoder, n=actual_channel_num, n_inv_filter = args.tap_num, 157 | num_neurons_decoder=args.num_neurons_decoder, if_bias=args.if_bias, if_relu=args.if_relu, if_RTN=args.if_RTN) 158 | if torch.cuda.is_available(): 159 | net_for_testtraining = net_for_testtraining.to(args.device) 160 | 161 | Eb_over_N = pow(10, (args.Eb_over_N_db/10)) 162 | R = args.bit_num/args.channel_num 163 | noise_var = 1 / (2 * R * Eb_over_N) 164 | Noise = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(actual_channel_num), noise_var * torch.eye(actual_channel_num)) 165 | if args.path_for_meta_training_channels is None: 166 | print('generate meta-training channels') 167 | h_list_meta = channel_set_gen(args.num_channels_meta, args.tap_num, args.if_toy) 168 | h_list_meta_path = common_dir + 'meta_training_channels/' + 'training_channels.pckl' 169 | f_meta_channels = open(h_list_meta_path, 'wb') 170 | pickle.dump(h_list_meta, f_meta_channels) 171 | f_meta_channels.close() 172 | else: 173 | print('load previously generated channels') 174 | h_list_meta_path = args.path_for_meta_training_channels + '/' + 'training_channels.pckl' 175 | f_meta_channels = open(h_list_meta_path, 'rb') 176 | h_list_meta = pickle.load(f_meta_channels) 177 | f_meta_channels.close() 178 | 179 | if args.path_for_meta_trained_net is None: 180 | if args.if_joint_training: 181 | print('start joint training') 182 | else: 183 | print('start meta-training') 184 | multi_task_learning(args, net, h_list_meta, writer_meta_training, Noise) 185 | torch.save(net.state_dict(), PATH_before_adapt) 186 | else: 187 | print('load previously saved autoencoder') 188 | PATH_before_adapt = args.path_for_meta_trained_net 189 | 190 | if args.path_for_test_channels is None: 191 | print('generate test channels') 192 | h_list_test = channel_set_gen(args.num_channels_test, args.tap_num, args.if_toy) 193 | h_list_test_path = common_dir + 'test_channels/' + 'test_channels.pckl' 194 | f_test_channels = open(h_list_test_path, 'wb') 195 | pickle.dump(h_list_test, f_test_channels) 196 | f_test_channels.close() 197 | else: 198 | print('load previously generated channels') 199 | h_list_test_path = args.path_for_test_channels + '/' + 'test_channels.pckl' 200 | f_test_channels = open(h_list_test_path, 'rb') 201 | h_list_test = pickle.load(f_test_channels) 202 | f_test_channels.close() 203 | 204 | if len(h_list_test) > args.num_channels_test: 205 | h_list_test = h_list_test[:args.num_channels_test] 206 | print('used test channels', h_list_test) 207 | 208 | dir_test = common_dir + 'TB/' + 'test' 209 | writer_test = SummaryWriter(dir_test) 210 | 211 | print('start adaptation with test set') 212 | if_val = False 213 | total_block_error_rate = torch.zeros(args.num_channels_test, len(test_Eb_over_N_range), len(test_adapt_range)) 214 | ind_adapt_steps = 0 215 | for adapt_steps in test_adapt_range: 216 | print('curr adaptation: ', adapt_steps) 217 | os.mkdir(common_dir + 'saved_model/' + 'after_adapt/' + str(adapt_steps) + '_adapt_steps/') 218 | os.mkdir(common_dir + 'test_result/' + str(adapt_steps) + '_adapt_steps/') 219 | test_result_per_adapt_steps = common_dir + 'test_result/' + str(adapt_steps) + '_adapt_steps/' + 'test_result.mat' 220 | save_test_result_dict_per_adapt_steps = {} 221 | 222 | block_error_rate = torch.zeros(args.num_channels_test, len(test_Eb_over_N_range)) 223 | ind_h = 0 224 | for h in h_list_test: 225 
| print('current channel ind', ind_h) 226 | PATH_after_adapt = common_dir + 'saved_model/' + 'after_adapt/' + str(adapt_steps) + '_adapt_steps/'+ str(ind_h) + 'th_adapted_net' 227 | writer_per_test_channel = [] 228 | test_training(args, h, net_for_testtraining, Noise, PATH_before_adapt, PATH_after_adapt, adapt_steps) 229 | # test 230 | ind_snr = 0 231 | for test_snr in test_Eb_over_N_range: 232 | block_error_rate_per_snr_per_channel = test_per_channel_per_snr(args, h, net_for_testtraining, test_snr, actual_channel_num, PATH_after_adapt, if_val) 233 | block_error_rate[ind_h, ind_snr] = block_error_rate_per_snr_per_channel 234 | total_block_error_rate[ind_h, ind_snr, ind_adapt_steps] = block_error_rate_per_snr_per_channel 235 | ind_snr += 1 236 | ind_h += 1 237 | ind_snr = 0 238 | save_test_result_dict_per_adapt_steps['block_error_rate'] = block_error_rate.data.numpy() 239 | sio.savemat(test_result_per_adapt_steps, save_test_result_dict_per_adapt_steps) 240 | writer_test.add_scalar('average (h) block error rate per adaptation steps', torch.mean(block_error_rate[:, :]), adapt_steps) 241 | ind_adapt_steps += 1 242 | 243 | save_test_result_dict['block_error_rate_total'] = total_block_error_rate.data.numpy() 244 | sio.savemat(test_result_all_PATH, save_test_result_dict) -------------------------------------------------------------------------------- /nets/auto_encoder.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.nn as nn 4 | from utils.funcs import complex_mul_taps, complex_conv_transpose 5 | 6 | 7 | class basic_DNN(nn.Module): 8 | def __init__(self, M, num_neurons_encoder, n, n_inv_filter, num_neurons_decoder, if_bias, if_relu, if_RTN): 9 | super(basic_DNN, self).__init__() 10 | self.enc_fc1 = nn.Linear(M, num_neurons_encoder, bias=if_bias) 11 | self.enc_fc2 = nn.Linear(num_neurons_encoder, n, bias=if_bias) 12 | 13 | ### norm, nothing to train 14 | ### channel, nothing to train 15 | 16 | num_inv_filter = 2 * n_inv_filter 17 | if if_RTN: 18 | self.rtn_1 = nn.Linear(n, n, bias=if_bias) 19 | self.rtn_2 = nn.Linear(n, n, bias=if_bias) 20 | self.rtn_3 = nn.Linear(n, num_inv_filter, bias=if_bias) 21 | else: 22 | pass 23 | 24 | self.dec_fc1 = nn.Linear(n, num_neurons_decoder, bias=if_bias) 25 | self.dec_fc2 = nn.Linear(num_neurons_decoder, M, bias=if_bias) 26 | if if_relu: 27 | self.activ = nn.ReLU() 28 | else: 29 | self.activ = nn.Tanh() 30 | self.tanh = nn.Tanh() 31 | def forward(self, x, h, noise_dist, device, if_RTN): 32 | x = self.enc_fc1(x) 33 | x = self.activ(x) 34 | x = self.enc_fc2(x) 35 | # normalize 36 | x_norm = torch.norm(x, dim=1) 37 | x_norm = x_norm.unsqueeze(1) 38 | x = pow(x.shape[1], 0.5) * pow(0.5, 0.5) * x / x_norm # since each has ^2 norm as 0.5 -> complex 1 39 | # channel 40 | x = complex_mul_taps(h, x) 41 | x = x.to(device) 42 | # noise 43 | n = torch.zeros(x.shape[0], x.shape[1]) 44 | for noise_batch_ind in range(x.shape[0]): 45 | n[noise_batch_ind] = noise_dist.sample() 46 | n = n.type(torch.FloatTensor).to(device) 47 | x = x + n # noise insertion 48 | 49 | # RTN 50 | if if_RTN: 51 | h_inv = self.rtn_1(x) 52 | h_inv = self.tanh(h_inv) 53 | h_inv = self.rtn_2(h_inv) 54 | h_inv = self.tanh(h_inv) 55 | h_inv = self.rtn_3(h_inv) # no activation for the final rtn (linear activation without weights) 56 | x = complex_conv_transpose(h_inv, x) 57 | x = x.to(device) 58 | else: 59 | pass 60 | x = self.dec_fc1(x) 61 | x = self.activ(x) 62 | x = self.dec_fc2(x) # softmax 
taken at loss function 63 | return x 64 | 65 | def dnn(**kwargs): 66 | net = basic_DNN(**kwargs) 67 | return net 68 | -------------------------------------------------------------------------------- /nets/auto_encoder_for_vis.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.nn as nn 4 | from utils.funcs import complex_mul_taps, complex_conv_transpose 5 | 6 | # code for visualizing toy example (Fig. 3) 7 | 8 | class basic_DNN(nn.Module): 9 | def __init__(self, M, num_neurons_encoder, n, n_inv_filter, num_neurons_decoder, if_bias, if_relu, if_RTN): 10 | super(basic_DNN, self).__init__() 11 | self.enc_fc1 = nn.Linear(M, num_neurons_encoder, bias=if_bias) 12 | self.enc_fc2 = nn.Linear(num_neurons_encoder, n, bias=if_bias) 13 | 14 | ### norm, nothing to train 15 | ### channel, nothing to train 16 | num_inv_filter = 2 * n_inv_filter 17 | if if_RTN: 18 | self.rtn_1 = nn.Linear(n, n, bias=if_bias) 19 | self.rtn_2 = nn.Linear(n, n, bias=if_bias) 20 | self.rtn_3 = nn.Linear(n, num_inv_filter, bias=if_bias) 21 | else: 22 | pass 23 | self.dec_fc1 = nn.Linear(n, num_neurons_decoder, bias=if_bias) 24 | self.dec_fc2 = nn.Linear(num_neurons_decoder, M, bias=if_bias) 25 | if if_relu: 26 | self.activ = nn.ReLU() 27 | else: 28 | self.activ = nn.Tanh() 29 | self.tanh = nn.Tanh() 30 | def forward(self, x, h, noise_dist, device, if_RTN, artificial_recieved_signal): 31 | x = self.enc_fc1(x) 32 | x = self.activ(x) 33 | x = self.enc_fc2(x) 34 | # normalize 35 | x_norm = torch.norm(x, dim=1) 36 | x_norm = x_norm.unsqueeze(1) 37 | x = pow(x.shape[1], 0.5) * pow(0.5, 0.5) * x / x_norm # since each has ^2 norm as 0.5 -> complex 1 38 | modulated_symbol = x 39 | # channel 40 | x = complex_mul_taps(h, x) 41 | x = x.to(device) 42 | # noise 43 | n = torch.zeros(x.shape[0], x.shape[1]) 44 | for noise_batch_ind in range(x.shape[0]): 45 | n[noise_batch_ind] = noise_dist.sample() 46 | n = n.type(torch.FloatTensor).to(device) 47 | x = x + n # noise insertion 48 | 49 | received_signal = x 50 | if artificial_recieved_signal is not None: 51 | x = artificial_recieved_signal 52 | modulated_symbol = None 53 | received_signal = None 54 | #### RTN 55 | if if_RTN: 56 | h_inv = self.rtn_1(x) 57 | h_inv = self.tanh(h_inv) 58 | h_inv = self.rtn_2(h_inv) 59 | h_inv = self.tanh(h_inv) 60 | h_inv = self.rtn_3(h_inv) # no activation for the final rtn (linear activation without weights) 61 | x = complex_conv_transpose(h_inv, x) 62 | x = x.to(device) 63 | else: 64 | pass 65 | 66 | x = self.dec_fc1(x) 67 | x = self.activ(x) 68 | x = self.dec_fc2(x) # softmax taken at loss function 69 | 70 | return x, modulated_symbol, received_signal 71 | 72 | def dnn_vis(**kwargs): 73 | net = basic_DNN(**kwargs) 74 | return net 75 | -------------------------------------------------------------------------------- /nets/meta_net.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch.nn as nn 3 | import torch 4 | from torch.nn import functional as F 5 | from utils.funcs import complex_mul_taps, complex_conv_transpose 6 | 7 | class meta_Net_DNN(nn.Module): 8 | def __init__(self, if_relu): # it only gets paramters from other network's parameters 9 | super(meta_Net_DNN, self).__init__() 10 | if if_relu: 11 | self.activ = nn.ReLU() 12 | else: 13 | self.activ = nn.Tanh() 14 | self.tanh = nn.Tanh() 15 | 16 | def forward(self, x, var, if_bias, h, device, noise_dist, if_RTN): 
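        # A functional forward pass for meta-learning: the layer weights are not the
        # module's own parameters but are read sequentially from the list `var`
        # (weight tensors, or weight/bias pairs when if_bias is True), so that
        # inner-loop-adapted parameters can be evaluated without copying them into
        # a module. The computation mirrors basic_DNN.forward: encoder -> power
        # normalization -> channel h -> additive noise -> optional RTN equalization -> decoder.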
17 | idx_init = 0 18 | if if_bias: 19 | gap = 2 20 | else: 21 | gap = 1 22 | idx = idx_init 23 | while idx < len(var): 24 | if idx > idx_init: # no activation from the beginning 25 | if idx == gap * 2+idx_init: # after last layer of encoder 26 | pass 27 | else: 28 | x = self.activ(x) 29 | if idx == idx_init: 30 | if if_bias: 31 | w1, b1 = var[idx], var[idx + 1] # weight and bias 32 | x = F.linear(x, w1, b1) 33 | idx += 2 34 | else: 35 | w1 = var[idx] # weight 36 | x = F.linear(x, w1) 37 | idx += 1 38 | elif idx == gap * 1+idx_init: 39 | if if_bias: 40 | w2, b2 = var[idx], var[idx + 1] # weight and bias 41 | x = F.linear(x, w2, b2) 42 | idx += 2 43 | else: 44 | w2 = var[idx] # weight and bias 45 | x = F.linear(x, w2) 46 | idx += 1 47 | elif idx == gap * 2+idx_init: 48 | #### now we need to normalize and then pass the channel 49 | x_norm = torch.norm(x, dim=1) 50 | x_norm = x_norm.unsqueeze(1) 51 | x = pow(x.shape[1], 0.5) * pow(0.5, 0.5) * x / x_norm 52 | x = complex_mul_taps(h, x) 53 | x = x.to(device) 54 | # noise 55 | n = torch.zeros(x.shape[0], x.shape[1]) 56 | for noise_batch_ind in range(x.shape[0]): 57 | n[noise_batch_ind] = noise_dist.sample() 58 | n = n.type(torch.FloatTensor).to(device) 59 | x = x + n 60 | 61 | if if_RTN: 62 | if if_bias: 63 | w_rtn_1, b_rtn_1 = var[idx], var[idx+1] 64 | h_inv = F.linear(x, w_rtn_1, b_rtn_1) 65 | h_inv = self.tanh(h_inv) 66 | w_rtn_2, b_rtn_2 = var[idx+2], var[idx + 3] 67 | h_inv = F.linear(h_inv, w_rtn_2, b_rtn_2) 68 | h_inv = self.tanh(h_inv) 69 | w_rtn_3, b_rtn_3 = var[idx + 4], var[idx + 5] 70 | h_inv = F.linear(h_inv, w_rtn_3, b_rtn_3) 71 | rtn_gap = 6 72 | else: 73 | w_rtn_1 = var[idx] 74 | h_inv = F.linear(x, w_rtn_1) 75 | h_inv = self.tanh(h_inv) 76 | w_rtn_2 = var[idx+1] 77 | h_inv = F.linear(h_inv, w_rtn_2) 78 | h_inv = self.tanh(h_inv) 79 | w_rtn_3 = var[idx+2] 80 | h_inv = F.linear(h_inv, w_rtn_3) 81 | rtn_gap = 3 82 | x = complex_conv_transpose(h_inv, x) 83 | x = x.to(device) 84 | else: 85 | rtn_gap = 0 86 | ############## from now, demodulator 87 | if if_bias: 88 | w3, b3 = var[idx+ rtn_gap], var[idx + rtn_gap + 1] # weight and bias 89 | x = F.linear(x, w3, b3) 90 | idx += (2 + rtn_gap) 91 | else: 92 | w3 = var[idx + rtn_gap] # weight 93 | x = F.linear(x, w3) 94 | idx += (1 + rtn_gap) 95 | elif idx == gap * 3+rtn_gap+idx_init: 96 | if if_bias: 97 | w4, b4 = var[idx], var[idx + 1] # weight and bias 98 | x = F.linear(x, w4, b4) 99 | idx += 2 100 | else: 101 | w4 = var[idx] # weight 102 | x = F.linear(x, w4) 103 | idx += 1 104 | return x 105 | 106 | def meta_dnn(**kwargs): 107 | net = meta_Net_DNN(**kwargs) 108 | return net 109 | -------------------------------------------------------------------------------- /run/realistic/vanilla/fixed_init.sh: -------------------------------------------------------------------------------- 1 | python ../../../main.py --if_RBF --num_epochs_meta_train 0 --path_for_common_dir 'realistic/vanilla/fixed_init/' --path_for_meta_training_channels '../../../saved_data/realistic/channels/meta_training_channels' --path_for_test_channels '../../../saved_data/realistic/channels/new_channels' -------------------------------------------------------------------------------- /run/realistic/vanilla/joint.sh: -------------------------------------------------------------------------------- 1 | python ../../../main.py --if_RBF --num_epochs_meta_train 20000 --if_joint_training --path_for_common_dir 'realistic/vanilla/joint/' --path_for_meta_training_channels '../../../saved_data/realistic/channels/meta_training_channels' 
--path_for_test_channels '../../../saved_data/realistic/channels/new_channels' --path_for_meta_trained_net '../../../saved_data/realistic/nets/vanilla/joint/init_net' -------------------------------------------------------------------------------- /run/realistic/vanilla/meta.sh: -------------------------------------------------------------------------------- 1 | python ../../../main.py --if_RBF --num_epochs_meta_train 10000 --if_adam_after_sgd --path_for_common_dir 'realistic/vanilla/meta/' --path_for_meta_training_channels '../../../saved_data/realistic/channels/meta_training_channels' --path_for_test_channels '../../../saved_data/realistic/channels/new_channels' --path_for_meta_trained_net '../../../saved_data/realistic/nets/vanilla/meta/init_net' -------------------------------------------------------------------------------- /run/realistic/with_RTN/fixed_init.sh: -------------------------------------------------------------------------------- 1 | python ../../../main.py --if_RBF --num_epochs_meta_train 0 --if_RTN --path_for_common_dir 'realistic/with_RTN/fixed_init/' --path_for_meta_training_channels '../../../saved_data/realistic/channels/meta_training_channels' --path_for_test_channels '../../../saved_data/realistic/channels/new_channels' -------------------------------------------------------------------------------- /run/realistic/with_RTN/joint.sh: -------------------------------------------------------------------------------- 1 | python ../../../main.py --if_RBF --num_epochs_meta_train 20000 --if_joint_training --if_RTN --path_for_common_dir 'realistic/with_RTN/joint/' --path_for_meta_training_channels '../../../saved_data/realistic/channels/meta_training_channels' --path_for_test_channels '../../../saved_data/realistic/channels/new_channels' --path_for_meta_trained_net '../../../saved_data/realistic/nets/with_RTN/joint/init_net' -------------------------------------------------------------------------------- /run/realistic/with_RTN/meta.sh: -------------------------------------------------------------------------------- 1 | python ../../../main.py --if_RBF --num_epochs_meta_train 10000 --if_RTN --if_adam_after_sgd --path_for_common_dir 'realistic/with_RTN/meta/' --path_for_meta_training_channels '../../../saved_data/realistic/channels/meta_training_channels' --path_for_test_channels '../../../saved_data/realistic/channels/new_channels' --path_for_meta_trained_net '../../../saved_data/realistic/nets/with_RTN/meta/init_net' -------------------------------------------------------------------------------- /run/toy/vanilla/fixed_init.sh: -------------------------------------------------------------------------------- 1 | python ../../../main.py --if_toy --num_epochs_meta_train 0 --path_for_common_dir 'toy/vanilla/fixed_init/' -------------------------------------------------------------------------------- /run/toy/vanilla/joint.sh: -------------------------------------------------------------------------------- 1 | python ../../../main.py --if_toy --num_epochs_meta_train 20000 --if_joint_training --path_for_common_dir 'toy/vanilla/joint/' --path_for_meta_trained_net '../../../saved_data/toy/nets/vanilla/joint/init_net' -------------------------------------------------------------------------------- /run/toy/vanilla/meta.sh: -------------------------------------------------------------------------------- 1 | python ../../../main.py --if_toy --num_epochs_meta_train 10000 --if_adam_after_sgd --path_for_common_dir 'toy/vanilla/meta/' --path_for_meta_trained_net 
'../../../saved_data/toy/nets/vanilla/meta/init_net' -------------------------------------------------------------------------------- /run/toy/with_RTN/fixed_init.sh: -------------------------------------------------------------------------------- 1 | python ../../../main.py --if_toy --num_epochs_meta_train 0 --if_RTN --path_for_common_dir 'toy/with_RTN/fixed_init/' -------------------------------------------------------------------------------- /run/toy/with_RTN/joint.sh: -------------------------------------------------------------------------------- 1 | python ../../../main.py --if_toy --num_epochs_meta_train 20000 --if_joint_training --if_RTN --path_for_common_dir 'toy/with_RTN/joint/' --path_for_meta_trained_net '../../../saved_data/toy/nets/with_RTN/joint/init_net' -------------------------------------------------------------------------------- /run/toy/with_RTN/meta.sh: -------------------------------------------------------------------------------- 1 | python ../../../main.py --if_toy --num_epochs_meta_train 10000 --if_RTN --if_adam_after_sgd --path_for_common_dir 'toy/with_RTN/meta/' --path_for_meta_trained_net '../../../saved_data/toy/nets/with_RTN/meta/init_net' -------------------------------------------------------------------------------- /saved_data/realistic/channels/meta_training_channels/training_channels.pckl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kclip/meta-autoencoder/f0a47ae67c9c7938bbd4007a3a09d013b10e05c7/saved_data/realistic/channels/meta_training_channels/training_channels.pckl -------------------------------------------------------------------------------- /saved_data/realistic/channels/new_channels/test_channels.pckl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kclip/meta-autoencoder/f0a47ae67c9c7938bbd4007a3a09d013b10e05c7/saved_data/realistic/channels/new_channels/test_channels.pckl -------------------------------------------------------------------------------- /saved_data/realistic/nets/vanilla/joint/init_net: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kclip/meta-autoencoder/f0a47ae67c9c7938bbd4007a3a09d013b10e05c7/saved_data/realistic/nets/vanilla/joint/init_net -------------------------------------------------------------------------------- /saved_data/realistic/nets/vanilla/meta/init_net: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kclip/meta-autoencoder/f0a47ae67c9c7938bbd4007a3a09d013b10e05c7/saved_data/realistic/nets/vanilla/meta/init_net -------------------------------------------------------------------------------- /saved_data/realistic/nets/with_RTN/joint/init_net: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kclip/meta-autoencoder/f0a47ae67c9c7938bbd4007a3a09d013b10e05c7/saved_data/realistic/nets/with_RTN/joint/init_net -------------------------------------------------------------------------------- /saved_data/realistic/nets/with_RTN/meta/init_net: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kclip/meta-autoencoder/f0a47ae67c9c7938bbd4007a3a09d013b10e05c7/saved_data/realistic/nets/with_RTN/meta/init_net -------------------------------------------------------------------------------- 
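The `.pckl` files above store plain pickled lists of channel tap tensors, and each `init_net` is a PyTorch `state_dict`. A minimal sketch for inspecting them offline, mirroring the loading code in `main.py` (the paths and the CPU `map_location` are illustrative assumptions):
```
import pickle
import torch
from nets.auto_encoder import dnn

# the saved channel list is a pickled list of torch tensors (2L real entries per channel)
with open('saved_data/realistic/channels/meta_training_channels/training_channels.pckl', 'rb') as f:
    h_list_meta = pickle.load(f)
print(len(h_list_meta), h_list_meta[0])

# rebuild the realistic-scenario network (k=4 bits -> M=16, n=4 complex uses -> 8 real dims, L=3 taps)
net = dnn(M=16, num_neurons_encoder=16, n=8, n_inv_filter=3,
          num_neurons_decoder=16, if_bias=True, if_relu=True, if_RTN=False)
# load a meta-trained initialization; map_location makes this work on CPU-only machines
net.load_state_dict(torch.load('saved_data/realistic/nets/vanilla/meta/init_net', map_location='cpu'))
```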
/saved_data/toy/after_adapted_to_new_channels_nets/vanilla/meta/after_adapt/1_adapt_steps/0th_adapted_net: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kclip/meta-autoencoder/f0a47ae67c9c7938bbd4007a3a09d013b10e05c7/saved_data/toy/after_adapted_to_new_channels_nets/vanilla/meta/after_adapt/1_adapt_steps/0th_adapted_net -------------------------------------------------------------------------------- /saved_data/toy/after_adapted_to_new_channels_nets/vanilla/meta/after_adapt/1_adapt_steps/1th_adapted_net: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kclip/meta-autoencoder/f0a47ae67c9c7938bbd4007a3a09d013b10e05c7/saved_data/toy/after_adapted_to_new_channels_nets/vanilla/meta/after_adapt/1_adapt_steps/1th_adapted_net -------------------------------------------------------------------------------- /saved_data/toy/after_adapted_to_new_channels_nets/with_RTN/meta/after_adapt/1_adapt_steps/0th_adapted_net: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kclip/meta-autoencoder/f0a47ae67c9c7938bbd4007a3a09d013b10e05c7/saved_data/toy/after_adapted_to_new_channels_nets/with_RTN/meta/after_adapt/1_adapt_steps/0th_adapted_net -------------------------------------------------------------------------------- /saved_data/toy/after_adapted_to_new_channels_nets/with_RTN/meta/after_adapt/1_adapt_steps/1th_adapted_net: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kclip/meta-autoencoder/f0a47ae67c9c7938bbd4007a3a09d013b10e05c7/saved_data/toy/after_adapted_to_new_channels_nets/with_RTN/meta/after_adapt/1_adapt_steps/1th_adapted_net -------------------------------------------------------------------------------- /saved_data/toy/nets/vanilla/joint/init_net: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kclip/meta-autoencoder/f0a47ae67c9c7938bbd4007a3a09d013b10e05c7/saved_data/toy/nets/vanilla/joint/init_net -------------------------------------------------------------------------------- /saved_data/toy/nets/vanilla/meta/init_net: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kclip/meta-autoencoder/f0a47ae67c9c7938bbd4007a3a09d013b10e05c7/saved_data/toy/nets/vanilla/meta/init_net -------------------------------------------------------------------------------- /saved_data/toy/nets/with_RTN/joint/init_net: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kclip/meta-autoencoder/f0a47ae67c9c7938bbd4007a3a09d013b10e05c7/saved_data/toy/nets/with_RTN/joint/init_net -------------------------------------------------------------------------------- /saved_data/toy/nets/with_RTN/meta/init_net: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kclip/meta-autoencoder/f0a47ae67c9c7938bbd4007a3a09d013b10e05c7/saved_data/toy/nets/with_RTN/meta/init_net -------------------------------------------------------------------------------- /toy_visualization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 37, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | 
"import argparse\n", 11 | "from data_gen.data_set import message_gen, channel_set_gen\n", 12 | "from training.train import test_training\n", 13 | "from training.test import test_per_channel_per_snr\n", 14 | "from nets.auto_encoder import dnn\n", 15 | "from torch.utils.tensorboard import SummaryWriter\n", 16 | "from nets.auto_encoder_for_vis import dnn_vis\n", 17 | "from training.meta_train import multi_task_learning\n", 18 | "import pickle\n", 19 | "import matplotlib.pyplot as plt\n", 20 | "import scipy.io as sio\n", 21 | "import datetime\n", 22 | "import numpy\n", 23 | "import os" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 38, 29 | "metadata": {}, 30 | "outputs": [ 31 | { 32 | "data": { 33 | "text/plain": [ 34 | "_StoreTrueAction(option_strings=['--if_adam_after_sgd'], dest='if_adam_after_sgd', nargs=0, const=True, default=False, type=None, choices=None, help=None, metavar=None)" 35 | ] 36 | }, 37 | "execution_count": 38, 38 | "metadata": {}, 39 | "output_type": "execute_result" 40 | } 41 | ], 42 | "source": [ 43 | "parser = argparse.ArgumentParser(description='end_to_end-meta')\n", 44 | "\n", 45 | "# bit num (k), channel uses (n), tap number (L), number of pilots (P), Eb/N0\n", 46 | "parser.add_argument('--bit_num', type=int, default=4, help='number of bits')\n", 47 | "parser.add_argument('--channel_num', type=int, default=4, help='number of channel uses')\n", 48 | "parser.add_argument('--tap_num', type=int, default=3, help='..')\n", 49 | "parser.add_argument('--mb_size', type=int, default=16, help='minibatch size')\n", 50 | "parser.add_argument('--mb_size_meta_train', type=int, default=16,\n", 51 | " help='minibatch size during meta-training (this can be useful for decreasing pilots)')\n", 52 | "parser.add_argument('--mb_size_meta_test', type=int, default=16,\n", 53 | " help='minibatch size for query set (this can be useful for decreasing pilots)')\n", 54 | "parser.add_argument('--Eb_over_N_db', type=float, default=15,\n", 55 | " help='energy per bit to noise power spectral density ratio')\n", 56 | "\n", 57 | "# paths\n", 58 | "parser.add_argument('--path_for_common_dir', dest='path_for_common_dir',\n", 59 | " default='default_folder/default_subfolder/', type=str)\n", 60 | "parser.add_argument('--path_for_meta_training_channels', dest='path_for_meta_training_channels', default=None,\n", 61 | " type=str)\n", 62 | "parser.add_argument('--path_for_test_channels', dest='path_for_test_channels', default=None, type=str)\n", 63 | "parser.add_argument('--path_for_meta_trained_net', dest='path_for_meta_trained_net', default=None, type=str)\n", 64 | "\n", 65 | "# neural network architecture (number of neurons for hidden layer)\n", 66 | "parser.add_argument('--num_neurons_encoder', type=int, default=None, help='number of neuron in hidden layer in encoder')\n", 67 | "parser.add_argument('--num_neurons_decoder', type=int, default=None, help='number of neuron in hidden layer in decoder')\n", 68 | "# whether to use bias and relu (if not relu: tanh)\n", 69 | "parser.add_argument('--if_not_bias', dest='if_bias', action='store_false', default=True)\n", 70 | "parser.add_argument('--if_not_relu', dest='if_relu', action='store_false', default=True)\n", 71 | "# RTN\n", 72 | "parser.add_argument('--if_RTN', dest='if_RTN', action='store_true', default=False)\n", 73 | "# in case of running on gpu, index for cuda device\n", 74 | "parser.add_argument('--cuda_ind', type=int, default=0, help='index for cuda device')\n", 75 | "\n", 76 | "# experiment details (hyperparameters, number of 
data for calculating performance and for meta-training\n", 77 | "parser.add_argument('--lr_testtraining', type=float, default=0.001, help='lr for adaptation to new channel')\n", 78 | "parser.add_argument('--lr_meta_update', type=float, default=0.01, help='lr during meta-training: outer loop (update initialization) lr')\n", 79 | "parser.add_argument('--lr_meta_inner', type=float, default=0.1, help='lr during meta-training: inner loop (local adaptation) lr')\n", 80 | "parser.add_argument('--test_size', type=int, default=1000000, help='number of messages to calculate BLER for test (new channel)')\n", 81 | "parser.add_argument('--num_channels_meta', type=int, default=100, help='number of meta-training channels (K)')\n", 82 | "parser.add_argument('--num_channels_test', type=int, default=20, help='number of new channels for test (to get average over BLER)')\n", 83 | "parser.add_argument('--tasks_per_metaupdate', type=int, default=20, help='number of meta-training channels considered in one meta-update')\n", 84 | "parser.add_argument('--num_meta_local_updates', type=int, default=1, help='number of local adaptation in meta-training')\n", 85 | "parser.add_argument('--num_epochs_meta_train', type=int, default=10000,\n", 86 | " help='number epochs for meta-training')\n", 87 | "\n", 88 | "# if run for joint training, if false: meta-learning\n", 89 | "parser.add_argument('--if_joint_training', dest='if_joint_training', action='store_true', default=False) # else: meta-learning for multi-task learning\n", 90 | "# whether to use Adam optimizer to adapt to a new channel\n", 91 | "parser.add_argument('--if_not_test_training_adam', dest='if_test_training_adam', action='store_false',\n", 92 | " default=True)\n", 93 | "# if run on toy example (Fig. 2 and 3)\n", 94 | "parser.add_argument('--if_toy', dest='if_toy', action='store_true',\n", 95 | " default=False)\n", 96 | "# to run on a more realistic example (Fig. 
4)\n", 97 | "parser.add_argument('--if_RBF', dest='if_RBF', action='store_true',\n", 98 | " default=False)\n", 99 | "parser.add_argument('--test_per_adapt_fixed_Eb_over_N_value', type=int, default=15,\n", 100 | " help='Eb/N0 in db for test')\n", 101 | "# desinged for maml: sgd during args.num_meta_local_updates with args.lr_meta_inner and then follow Adam optimizer with args.lr_testtraining\n", 102 | "parser.add_argument('--if_adam_after_sgd', dest='if_adam_after_sgd', action='store_true',\n", 103 | " default=False)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 39, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "# common setting for toy\n", 113 | "run_script = \"--if_toy --num_epochs_meta_train 0 --path_for_common_dir toy/jupyter/\"\n" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 40, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "name": "stdout", 123 | "output_type": "stream", 124 | "text": [ 125 | "running for toy scenario\n", 126 | "Running on device: cuda:0\n" 127 | ] 128 | } 129 | ], 130 | "source": [ 131 | "if __name__ == '__main__':\n", 132 | " args = parser.parse_args(run_script.split())\n", 133 | " args.device = torch.device(\"cuda:\" + str(args.cuda_ind) if torch.cuda.is_available() else \"cpu\")\n", 134 | " if args.num_neurons_encoder == None: # unless specified, set number of hidden neurons to be same as the number of possible messages\n", 135 | " args.num_neurons_encoder = pow(2,args.bit_num)\n", 136 | " if args.num_neurons_decoder == None:\n", 137 | " args.num_neurons_decoder = pow(2, args.bit_num)\n", 138 | "\n", 139 | " if args.if_test_training_adam == False:\n", 140 | " args.if_adam_after_sgd = False\n", 141 | "\n", 142 | " if args.if_toy == True:\n", 143 | " print('running for toy scenario')\n", 144 | " args.bit_num = 2\n", 145 | " args.channel_num = 1\n", 146 | " args.tap_num = 1\n", 147 | " args.mb_size = 4\n", 148 | " args.mb_size_meta_train = 4\n", 149 | " args.mb_size_meta_test = 4\n", 150 | " args.num_channels_meta = 20\n", 151 | " args.num_neurons_encoder = 4\n", 152 | " args.num_neurons_decoder = 4\n", 153 | " elif args.if_RBF == True:\n", 154 | " print('running for a more realistic scenario')\n", 155 | " args.bit_num = 4\n", 156 | " args.channel_num = 4\n", 157 | " args.tap_num = 3\n", 158 | " args.mb_size = 16\n", 159 | " args.mb_size_meta_train = 16\n", 160 | " args.mb_size_meta_test = 16\n", 161 | " args.num_channels_meta = 100\n", 162 | " args.num_neurons_encoder = 16\n", 163 | " args.num_neurons_decoder = 16\n", 164 | " else:\n", 165 | " print('running on custom environment')\n", 166 | " print('Running on device: {}'.format(args.device))" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 41, 172 | "metadata": {}, 173 | "outputs": [ 174 | { 175 | "name": "stdout", 176 | "output_type": "stream", 177 | "text": [ 178 | "generate meta-training channels\n", 179 | "start meta-training\n", 180 | "generate test channels\n", 181 | "used test channels [tensor([0.7071, 0.7071]), tensor([-0.7071, 0.7071]), tensor([0.7071, 0.7071]), tensor([-0.7071, 0.7071]), tensor([0.7071, 0.7071]), tensor([-0.7071, 0.7071]), tensor([0.7071, 0.7071]), tensor([-0.7071, 0.7071]), tensor([0.7071, 0.7071]), tensor([-0.7071, 0.7071]), tensor([0.7071, 0.7071]), tensor([-0.7071, 0.7071]), tensor([0.7071, 0.7071]), tensor([-0.7071, 0.7071]), tensor([0.7071, 0.7071]), tensor([-0.7071, 0.7071]), tensor([0.7071, 0.7071]), tensor([-0.7071, 0.7071]), tensor([0.7071, 0.7071]), 
tensor([-0.7071, 0.7071])]\n" 182 | ] 183 | } 184 | ], 185 | "source": [ 186 | " curr_time = datetime.datetime.now().strftime(\"%Y_%m_%d_%H_%M_%S\")\n", 187 | " common_dir = './' + args.path_for_common_dir + curr_time + '/'\n", 188 | "\n", 189 | " PATH_before_adapt = common_dir + 'saved_model/' + 'before_adapt/' + 'init_net'\n", 190 | " PATH_meta_intermediate = common_dir + 'saved_model/' + 'during_meta_training/' + 'epochs/'\n", 191 | "\n", 192 | " os.makedirs(common_dir + 'saved_model/' + 'before_adapt/')\n", 193 | " os.makedirs(common_dir + 'saved_model/' + 'after_adapt/')\n", 194 | " os.makedirs(PATH_meta_intermediate)\n", 195 | "\n", 196 | " os.makedirs(common_dir + 'meta_training_channels/')\n", 197 | " os.makedirs(common_dir + 'test_channels/')\n", 198 | " os.makedirs(common_dir + 'test_result/')\n", 199 | "\n", 200 | " dir_meta_training = common_dir + 'TB/' + 'meta_training'\n", 201 | " writer_meta_training = SummaryWriter(dir_meta_training)\n", 202 | " dir_during_adapt = common_dir + 'TB/' + 'during_adapt/'\n", 203 | "\n", 204 | " test_Eb_over_N_range = [args.test_per_adapt_fixed_Eb_over_N_value]\n", 205 | " test_adapt_range = [0, 1, 2, 5, 10, 100, 200, 1000, 10000]\n", 206 | "\n", 207 | " if len(test_Eb_over_N_range) > 1:\n", 208 | " assert len(test_adapt_range) == 1\n", 209 | " if len(test_adapt_range) > 1:\n", 210 | " assert len(test_Eb_over_N_range) == 1\n", 211 | "\n", 212 | " test_result_all_PATH = common_dir + 'test_result/' + 'test_result.mat'\n", 213 | " save_test_result_dict = {}\n", 214 | "\n", 215 | " actual_channel_num = args.channel_num * 2\n", 216 | "\n", 217 | " net = dnn(M=pow(2, args.bit_num), num_neurons_encoder=args.num_neurons_encoder, n=actual_channel_num, n_inv_filter = args.tap_num,\n", 218 | " num_neurons_decoder=args.num_neurons_decoder, if_bias=args.if_bias, if_relu=args.if_relu, if_RTN=args.if_RTN)\n", 219 | " if torch.cuda.is_available():\n", 220 | " net = net.to(args.device)\n", 221 | " net_for_testtraining = dnn(M=pow(2, args.bit_num), num_neurons_encoder=args.num_neurons_encoder, n=actual_channel_num, n_inv_filter = args.tap_num,\n", 222 | " num_neurons_decoder=args.num_neurons_decoder, if_bias=args.if_bias, if_relu=args.if_relu, if_RTN=args.if_RTN)\n", 223 | " if torch.cuda.is_available():\n", 224 | " net_for_testtraining = net_for_testtraining.to(args.device)\n", 225 | " net_for_vis = dnn(M=pow(2, args.bit_num), num_neurons_encoder=args.num_neurons_encoder, n=actual_channel_num,\n", 226 | " n_inv_filter = args.tap_num,\n", 227 | " num_neurons_decoder=args.num_neurons_decoder, if_bias=args.if_bias,\n", 228 | " if_relu=args.if_relu, if_RTN=args.if_RTN)\n", 229 | " if torch.cuda.is_available():\n", 230 | " net_for_vis = net_for_vis.to(args.device)\n", 231 | "\n", 232 | " Eb_over_N = pow(10, (args.Eb_over_N_db/10))\n", 233 | " R = args.bit_num/args.channel_num\n", 234 | " noise_var = 1 / (2 * R * Eb_over_N)\n", 235 | " Noise = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(actual_channel_num), noise_var * torch.eye(actual_channel_num))\n", 236 | " if args.path_for_meta_training_channels is None:\n", 237 | " print('generate meta-training channels')\n", 238 | " h_list_meta = channel_set_gen(args.num_channels_meta, args.tap_num, args.if_toy)\n", 239 | " h_list_meta_path = common_dir + 'meta_training_channels/' + 'training_channels.pckl'\n", 240 | " f_meta_channels = open(h_list_meta_path, 'wb')\n", 241 | " pickle.dump(h_list_meta, f_meta_channels)\n", 242 | " f_meta_channels.close()\n", 243 | " else:\n", 244 | " print('load 
previously generated channels')\n", 245 | " h_list_meta_path = args.path_for_meta_training_channels + '/' + 'training_channels.pckl'\n", 246 | " f_meta_channels = open(h_list_meta_path, 'rb')\n", 247 | " h_list_meta = pickle.load(f_meta_channels)\n", 248 | " f_meta_channels.close()\n", 249 | "\n", 250 | " if args.path_for_meta_trained_net is None:\n", 251 | " if args.if_joint_training:\n", 252 | " print('start joint training')\n", 253 | " else:\n", 254 | " print('start meta-training')\n", 255 | " multi_task_learning(args, net, h_list_meta, writer_meta_training, Noise)\n", 256 | " torch.save(net.state_dict(), PATH_before_adapt)\n", 257 | " else:\n", 258 | " PATH_before_adapt = args.path_for_meta_trained_net\n", 259 | "\n", 260 | " if args.path_for_test_channels is None:\n", 261 | " print('generate test channels')\n", 262 | " h_list_test = channel_set_gen(args.num_channels_test, args.tap_num, args.if_toy)\n", 263 | " h_list_test_path = common_dir + 'test_channels/' + 'test_channels.pckl'\n", 264 | " f_test_channels = open(h_list_test_path, 'wb')\n", 265 | " pickle.dump(h_list_test, f_test_channels)\n", 266 | " f_test_channels.close()\n", 267 | " else:\n", 268 | " print('load previously generated channels')\n", 269 | " h_list_test_path = args.path_for_test_channels + '/' + 'test_channels.pckl'\n", 270 | " f_test_channels = open(h_list_test_path, 'rb')\n", 271 | " h_list_test = pickle.load(f_test_channels)\n", 272 | " f_test_channels.close()\n", 273 | "\n", 274 | " if len(h_list_test) > args.num_channels_test:\n", 275 | " h_list_test = h_list_test[:args.num_channels_test]\n", 276 | " print('used test channels', h_list_test)" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": 152, 282 | "metadata": {}, 283 | "outputs": [], 284 | "source": [ 285 | "num_adapt = 1 # training iteration number\n", 286 | "ind_channel = 0 # index of channel (wheter phase pi/4 (even idx) or 3pi/4 (odd idx))\n", 287 | "if_RTN = False # need to change this in order to switch between vanilla (False) and autoencoder with RTN (True)" 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": 153, 293 | "metadata": {}, 294 | "outputs": [], 295 | "source": [ 296 | "### vanilla\n", 297 | "# meta\n", 298 | "PATH_before_adapt_meta_vanilla = 'saved_data/toy/nets/vanilla/meta/init_net'\n", 299 | "PATH_after_adapt_meta_vanilla = 'saved_data/toy/after_adapted_to_new_channels_nets/vanilla/meta/after_adapt/' + str(num_adapt) + '_adapt_steps/' + str(ind_channel) + 'th_adapted_net' \n", 300 | "\n", 301 | "### with rtn\n", 302 | "# meta\n", 303 | "PATH_before_adapt_meta_rtn = 'saved_data/toy/nets/with_RTN/meta/init_net'\n", 304 | "PATH_after_adapt_meta_rtn = 'saved_data/toy/after_adapted_to_new_channels_nets/with_RTN/meta/after_adapt/' + str(num_adapt) + '_adapt_steps/' + str(ind_channel) + 'th_adapted_net' \n" 305 | ] 306 | }, 307 | { 308 | "cell_type": "code", 309 | "execution_count": 154, 310 | "metadata": {}, 311 | "outputs": [], 312 | "source": [ 313 | "# choose appropriate path for visualization\n", 314 | "current_PATH = PATH_before_adapt_meta_vanilla\n", 315 | "#current_PATH = PATH_after_adapt_meta_vanilla\n", 316 | "#current_PATH = PATH_before_adapt_meta_rtn\n", 317 | "#current_PATH = PATH_after_adapt_meta_rtn\n", 318 | " \n", 319 | "net_for_vis = dnn_vis(M=pow(2, args.bit_num), num_neurons_encoder=args.num_neurons_encoder, n=actual_channel_num, n_inv_filter = args.tap_num,\n", 320 | " num_neurons_decoder=args.num_neurons_decoder, if_bias=args.if_bias, if_relu=args.if_relu, 
if_RTN=if_RTN)\n", 321 | "if torch.cuda.is_available():\n", 322 | " net_for_vis = net_for_vis.to(args.device)\n" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": 155, 328 | "metadata": {}, 329 | "outputs": [ 330 | { 331 | "data": { 332 | "text/plain": [ 333 | "" 334 | ] 335 | }, 336 | "execution_count": 155, 337 | "metadata": {}, 338 | "output_type": "execute_result" 339 | } 340 | ], 341 | "source": [ 342 | "# load saved autoencoder\n", 343 | "net_for_vis.load_state_dict(torch.load(current_PATH))\n" 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": 156, 349 | "metadata": {}, 350 | "outputs": [], 351 | "source": [ 352 | "def message_gen_tmp(k, mb_size):\n", 353 | " tot_message_num = pow(2,k)\n", 354 | " m = torch.zeros(mb_size, tot_message_num)\n", 355 | " label = torch.zeros(mb_size)\n", 356 | " for ind_mb in range(mb_size):\n", 357 | " if ind_mb % tot_message_num == 0:\n", 358 | " pass\n", 359 | " #rand_lst = torch.randperm(tot_message_num) # remove randomness for simpleness.\n", 360 | " ind_one_rand_lst = ind_mb % tot_message_num\n", 361 | " ind_one = ind_one_rand_lst\n", 362 | " m[ind_mb, ind_one] = 1\n", 363 | " label[ind_mb] = ind_one\n", 364 | " return m, label" 365 | ] 366 | }, 367 | { 368 | "cell_type": "code", 369 | "execution_count": 157, 370 | "metadata": {}, 371 | "outputs": [], 372 | "source": [ 373 | "# visualizing constellation points\n", 374 | "\n", 375 | "num_datapoints_transmitted = 1000\n", 376 | "h = h_list_test[ind_channel]\n", 377 | "artificial_input = None # for decision region, now should be None\n", 378 | "m_test, label_test = message_gen_tmp(args.bit_num, num_datapoints_transmitted)\n", 379 | "m_test = m_test.type(torch.FloatTensor).to(args.device)\n", 380 | "label_test = label_test.type(torch.LongTensor).to(args.device)\n", 381 | "net_for_vis.zero_grad()\n", 382 | "out_test, modulated_symbol, received_signal = net_for_vis(m_test, h, Noise, args.device,if_RTN, artificial_input)\n" 383 | ] 384 | }, 385 | { 386 | "cell_type": "code", 387 | "execution_count": 158, 388 | "metadata": {}, 389 | "outputs": [ 390 | { 391 | "name": "stdout", 392 | "output_type": "stream", 393 | "text": [ 394 | "h tensor([0.7071, 0.7071])\n" 395 | ] 396 | } 397 | ], 398 | "source": [ 399 | "print('h', h)\n", 400 | "modulated_symbol = modulated_symbol.data.cpu().numpy()\n", 401 | "m_test = m_test.data.cpu().numpy()\n", 402 | "received_signal = received_signal.data.cpu().numpy()" 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "execution_count": null, 408 | "metadata": {}, 409 | "outputs": [], 410 | "source": [] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": 159, 415 | "metadata": {}, 416 | "outputs": [ 417 | { 418 | "data": { 419 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYQAAAF1CAYAAADoc51vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAYpElEQVR4nO3de5CldX3n8fcHZgai43pjIggoEsmsJFsxQqHExGrEbCGVhRh1g1W7gUQzmsRsVRKrFktXFks3miXJ6mrWzBpWdDdc4gad6BgUtYsYVwVSEEEcHAmGQQJyEdOroWfku3+cX7vHtrunp89zbt3vV9WpeW7n+f2+z5l+Pue5nHNSVUiSdNi4OyBJmgwGgiQJMBAkSY2BIEkCDARJUmMgSJIAA0EbTJJK8sw2/L4kbxlgXXNJTuyud4ObxD5pehgImlpJTmg7+E0jaGs2yav6p1XV1qq6Y9htH4pD6VN/OEpgIEiSGgNBnUpyfJI/T/KNJA8keVebfliSNyb5WpL7krw/yePbvIV3+ucn+fsk9yd5Q986T0tyQ5JvJbk3yR+0Wde1f7/ZTpWc3pb/lSS3JXkoyTVJnr6Kfj8xyUdavx9qw8e1eW8FfgZ4V2tnoab+00+PbzV9o9X4xiSHtXkXJPlMkkvauv8uyYtX6MudSV6f5Ett+f+R5Mi++b+aZG+SB5PsSvLUvnmLT4m9O8lHk/xjks8n+ZE2b2Hb3dxq+sUkR7W6v9nW/VcLNWiDqCofPjp5AIcDNwN/CDwWOBL46TbvV4C9wInAVuDPgQ+0eScABfx34IeAnwAeAZ7V5v8f4N+24a3A8xY9b1NfH85t7TwL2AS8Efhs3/wCntmG3we8pQ0/GXgp8BjgccCfAR/qe94s8KpF9fav6/3Ah9tzTwBuB17Z5l0A7Ad+tW2jXwO+DmSZ7XgncAtwPPAk4K/7+vlC4H7gOcARwH8FrluhvgeA09q2+F/AFUst28Z/F3gPsLk9fma5PvpYn4+xd8DH+nkApwPf6N9B9837JPDrfePb205yU9+O/bi++V8AzmvD1wEXA0ctWudSgfCxhR1xGz8M+Dbw9Da+ZCAs0d9nAw/1jS8bCG0nPw+c3Dfv1cBsG74A2Ns37zHtuUcv0/adwGv6xs8GvtqG/wT4vb55W9t2PGGZ+t67aD1fXtz/vvE30wu1Zy7VLx/r/+HhoLp0PPC1qjqwxLynAl/rG/8avTB4St+0f+gb/ja9nR3AK4EfBb6c5PokP7dCH54OvKOd9vgm8CAQ4NiVOp7kMUn+uJ3u+Ra9EHpCksNXel5zFL131Ivr62/ze7VV1bfb4FaWd9eidS2cFvq+7VhVc/SOAparb7ltupT/TO/o6uNJ7khy4QrLah0yENSlu4CnLXPXz9fp7awXPA04ANx7sJVW1Veq6hXADwNvBz6Y5LH03uEu1YdXV9UT+h4/VFWfPUgzv0PvqOW5VfXPgBe06VnoxgrPvZ/eu/TF9d19kDZXcvyidX29DX/fdmzb4ckDtgVAVf1jVf1OVZ0InAP8dpIzB12vpoeBoC59AbgHeFuSxyY5Msnz27zLgd9K8owkW4H/BFy5zNHE90nyb5Jsq6pHgW+2yY/SOz31KL3rEgveA7w+yY+15z4+yctX0ffHAd+hd4H6ScBFi+bfu6id76mq7wJXAW9N8rh2Efu3gf+5inaX8xtJjmt9eQNwZZt+OfDLSZ6d5Ah62/HzVXXnGtr4vpqS/FySZyYJ8DDwXXrbVxuEgaDOtB3jv6J3Xv3vgX3AL7bZlwIfoHcq5u+AfwJ+c5WrPgu4Nckc8A561xa+0069vBX463aK6HlVdTW9o4gr2qmfW4Bl7+jp81/oXdC+H/gc8JeL5r8DeFm76+edSzz/N4H/C9wBfAb401bzWv0p8PG2vq8CbwGoqmuB/wD8b3rh+yPAeWts4z8Cl7Vt96+Bk4BrgTl6F/L/qKo+PUANmjKp8gdypEmS5E56F7CvHXdftLF4hCBJAjoKhCSXtg8b3bLM/JkkDye5qT3e1EW7kqTudPUdMO8D3kXvwznL+auqWul2QUlAVZ0w7j5oY+rkCKGqrqN3v7ckaUqN8hrC6UluTvKxhVsCJUmTY+hfG9z8Db2vDphLcjbwIXq3uP2AJDuAHQBHHnnkKU972tNG1MXRevTRRznssPV7Td/6ppv1Ta/bb7/9/qratpbndnbbaZITgI9U1Y+vYtk7gVOr6v6Vltu+fXvt2bOnk/5NmtnZWWZmZsbdjaGxvulmfdMryY1VdepanjuSiExydPv0I0lOa+0+MIq2JUmr08kpoySXAzPAUUn20fvY/2aAqnoP8DLg15IcoPf1AOeVn4iTpInSSSC0Lx5baf676N2WKkmaUOvzqook6ZAZCJIkwECQJDUGgiQJMBAkSY2BIEkCDARJUmMgSJIAA0GS1BgIkiTAQJAkNQaCJAkwECRJjYEgSQIMBElSYyBIkgADQZLUGAiSJMBAkCQ1BoIkCTAQJEmNgSBJAgwESVJjIEiSAANBktQYCJIkwECQJDUGgiQJMBAkSY2BIEkCDARJUmMgSJIAA0GS1BgIkiTAQJAkNQaCJAkwECRJjYEgSQIMBElSYyBIkgADQZLUbBp3B7TBJatftmp4/Vgn3JwahIGg8TiUPdfi57gn+wFuTnXBQNBorWXPtdw63JO5OdUpryFoNLZs6Wbv1S/prXcDcnNqGAwEDd+WLbB//3DWvX//htuLuTk1LJ0EQpJLk9yX5JZl5ifJO5PsTfK3SZ7TRbuaAsPcey0Y9voniJtTw9TVEcL7gLNWmP9i4KT22AH8t47a1aQb1d6l6/MnE2pUm/PGG0fTjiZLJ4FQVdcBD66wyLnA+6vnc8ATkhzTRduaYKPeSa/zUHBzathSHd1akOQE4CNV9eNLzPsI8Laq+kwb/yTw76vqhiWW3UHvKIJt27adctVVV3XSv0kzNzfH1q1bx92NoZmbm2Prnj2jb/iUU0bSzDhev1G+az/uuDn27ds6qs05cuv57++MM864sapOXctzJ+6206raCewE2L59e83MzIy3Q0MyOzvLeq0NYPb3f5+Z171uPI2P4P7JUb9+o363fskls7zudTPA+rwddb3//a3VqO4yuhs4vm/8uDZNkjQhRhUIu4BfancbPQ94uKruGVHbGrVxn3wed/sdG3c5425fo9PJKaMklwMzwFFJ9gEXAZsBquo9wG7gbGAv8G3gl7toV5LUnU4CoapecZD5BfxGF21JkobDTypLkgADQZLUGAiSJMBAkCQ1BoIkCTAQJEmNgSBJAgwESVJjIKh74/42tHG337FxlzPu9jU6BoIkCTAQNCzj+iL9dfp2dlxlrdPNqWUYCJIkwEDQMI367eU6fzvr5tSwGQgars2bR9POBtl7jWpzrtefztTKDAQN1/z88Pdio9pLTgA3p4bJQNDwDXMvVtVb/wbi5tSwGAgajfn57k/rbJDTREtxc2oYOvnFNGnVFvY6g/xQr3uu73FzqksGgsZjLXsy91zLcnOqCwaCxs
u9UqfcnBqE1xAkSYCBIElqDARJEmAgSJIaA0GSBBgIkqTGQJAkAQaCJKkxECRJgIEgSWoMBEkSYCBIkhoDQZIEGAiSpMZAkCQBBoIkqTEQJEmAgSBJagwESRJgIEiSGgNBkgQYCJKkxkCQJAEGgiSpMRAkSUBHgZDkrCR7kuxNcuES8y9I8o0kN7XHq7poV5LUnU2DriDJ4cC7gZ8F9gHXJ9lVVV9atOiVVfXaQduTJA1HF0cIpwF7q+qOqpoHrgDO7WC9kqQRGvgIATgWuKtvfB/w3CWWe2mSFwC3A79VVXctsQxJdgA7ALZt28bs7GwHXZw8c3Nz67Y2sL5pZ30bUxeBsBp/AVxeVY8keTVwGfDCpRasqp3AToDt27fXzMzMiLo4WrOzs6zX2sD6pp31bUxdnDK6Gzi+b/y4Nu17quqBqnqkjb4XOKWDdiVJHeoiEK4HTkryjCRbgPOAXf0LJDmmb/Qc4LYO2pUkdWjgU0ZVdSDJa4FrgMOBS6vq1iRvBm6oql3Av0tyDnAAeBC4YNB214Mkq162qobYE0nq6BpCVe0Gdi+a9qa+4dcDr++irfUgCZdccskhPwcMBknDM6qLyuLQjggOtg6DQVLX/OqKEdiyZUsnYdAvCVu2bOl0nZI2NgNhyLZs2cL+/fuHsu79+/cbCpI6YyAM0TDDYMGw1y9p4zAQhmhUO+uuT0dJ2pgMhCEZ9U7aUJA0KANBkgQYCEMxrnfrHiVIGoSBIEkCDITOjftd+rjblzS9DARJEmAgSJIaA0GSBBgIkqTGQJAkAQaCJKkxECRJgIEgSWoMBEkSYCBIkhoDoWPj/q3jcbcvaXoZCJIkwEAYinG9S/foQNIgDARJEmAgDM2o3617dCBpUAbCEG3evHkk7RgGkrpgIAzR/Pz80ENhVKEjaf0zEIZsmKFQVczPzw9l3ZI2HgNhBObn5zs/reNpIkld2zTuDmwkCzvxQX732CCQNCweIYxBVXHKKacc8nMMA0nD5BHCGLmDlzRJPEKQJAEGgiSpMRAkSYCBIElqDARJEmAgSJIaA0GSBBgIkqTGQJAkAQaCJKkxECRJgIEgSWoMBEkS0FEgJDkryZ4ke5NcuMT8I5Jc2eZ/PskJXbQrSerOwIGQ5HDg3cCLgZOBVyQ5edFirwQeqqpnAn8IvH3QdiVJ3eriCOE0YG9V3VFV88AVwLmLljkXuKwNfxA4M4P8bJgkqXNdBMKxwF194/vatCWXqaoDwMPAkztoW5LUkYn7xbQkO4AdANu2bWN2dna8HRqSubm5dVsbWN+0s76NqYtAuBs4vm/8uDZtqWX2JdkEPB54YKmVVdVOYCfA9u3ba2ZmpoMuTp7Z2VnWa21gfdPO+jamLk4ZXQ+clOQZSbYA5wG7Fi2zCzi/Db8M+FT5g8KSNFEGPkKoqgNJXgtcAxwOXFpVtyZ5M3BDVe0C/gT4QJK9wIP0QkOSNEE6uYZQVbuB3Yumvalv+J+Al3fRliRpOPyksiQJMBAkSY2BIEkCDARJUmMgSJIAA0GS1BgIkiTAQJAkNQaCJAkwECRJjYEgSQIMBElSYyBIkgADQZLUGAiSJMBAkCQ1BoIkCTAQJEmNgSBJAgwESVJjIEiSAANBktQYCJIkwECQJDUGgiQJMBAkSY2BIEkCDARJUmMgSJIAA0GS1BgIkiTAQJAkNQaCJAkwECRJjYEgSQIMBElSYyBIkgADQZLUGAiSJMBAkCQ1BoIkCTAQJEmNgSBJAgwESVJjIEiSAANBktQMFAhJnpTkE0m+0v594jLLfTfJTe2xa5A2JUnDMegRwoXAJ6vqJOCTbXwp36mqZ7fHOQO2KUkagkED4VzgsjZ8GfDzA65PkjQmqaq1Pzn5ZlU9oQ0HeGhhfNFyB4CbgAPA26rqQyuscwewA2Dbtm2nXHXVVWvu3ySbm5tj69at4+7G0FjfdLO+6XXGGWfcWFWnruW5Bw2EJNcCRy8x6w3AZf0BkOShqvqB6whJjq2qu5OcCHwKOLOqvnqwzm3fvr327NlzsMWm0uzsLDMzM+PuxtBY33SzvumVZM2BsOlgC1TVi1Zo+N4kx1TVPUmOAe5bZh13t3/vSDIL/CRw0ECQJI3OoNcQdgHnt+HzgQ8vXiDJE5Mc0YaPAp4PfGnAdiVJHRs0EN4G/GySrwAvauMkOTXJe9syzwJuSHIz8Gl61xAMBEmaMAc9ZbSSqnoAOHOJ6TcAr2rDnwX+xSDtSJKGz08qS5IAA0GS1BgIkiTAQJAkNQaCJAkwECRJjYEgSQIMBElSYyBIkgADQZLUGAiSJMBAkCQ1BoIkCTAQJEmNgSBJAgwESVJjIEiSAANBktQYCJIkwECQJDUGgiQJMBAkSY2BIEkCDARJUmMgSJIAA0GS1BgIkiTAQJAkNQaCJAmATePugLQWuTirXrYuqiH2RFo/DARNlUMJgsXPMRiklRkImgprCYLl1mEwSEszEDTRtly8hf3s73SduThsZjPzF813ul5p2nlRWRNrGGGwYD/72XLxlqGsW5pWBoIm0jDDYMGw1y9NGwNBE2lUO+surk1I64WBoIkz6p20oSD1GAiSJMBA0IQZ17t1jxIkA0GS1BgImhjjfpc+7valcTMQJEmAgSBJagwESRJgIEiSmoECIcnLk9ya5NEkp66w3FlJ9iTZm+TCQdqUJA3HoEcItwC/AFy33AJJDgfeDbwYOBl4RZKTB2xXktSxgb7+uqpuA0hWvF3vNGBvVd3Rlr0COBf40iBtS5K6NYrfQzgWuKtvfB/w3OUWTrID2AGwbds2Zmdnh9q5cZmbm1u3tcHa6rvkRy8ZTmcOwWr77Os33dZ7fWt10EBIci1w9BKz3lBVH+66Q1W1E9gJsH379pqZmem6iYkwOzvLeq0N1lbfGRefMZzOHIJ6xep+Tc3Xb7qt9/rW6qCBUFUvGrCNu4Hj+8aPa9MkSRNkFLedXg+clOQZSbYA5wG7RtCupsy4f+t43O1L4zbobacvSbIPOB34aJJr2vSnJtkNUFUHgNcC1wC3AVdV1a2DdVuS1LVB7zK6Grh6ielfB87uG98N7B6kLW0MdVGN5UvmPDqQ/KSyJKkxEDRxRv1u3aMDqcdA0ETazOaRtGMYSP+fgaCJNH/R/NBDYVShI00LA0ETa5ihUBcV8xfND2Xd0rQyEDTR5i+a7/y0jqeJpKWN4ruMpIEt7MQHuSXVIJBWZiBoqqwlGAwCaXUMBE0ld/JS97yGIEkCDARJUmMgSJIAA0GS1BgIkiTAQJAkNQaCJAkwECRJjYEgSQIMBElSYyBIkgADQZLUGAiSJMBAkCQ1BoIkCTAQJEmNgSBJAgwESVJjIEiSAANBktQYCJIkwECQJDUGgiQJMBAkSY2BIEkCDARJUmMgSJIAA0GS1BgIkiTAQJAkNQaCJAkwECRJjYEgSQIMBElSYyBIkgADQZLUDBQISV6e5NYkjyY5dYXl7kzyxSQ3JblhkDYlScOxacDn3wL8AvDHq1j2jKq6f8D2JElDMlAgVNVtAEm66Y0kaWxGdQ2hgI8nuTHJjhG1KUk6B
Ac9QkhyLXD0ErPeUFUfXmU7P11Vdyf5YeATSb5cVdct094OYCE0HklyyyrbmDZHAev5FJr1TTfrm17b1/rEgwZCVb1orSvvW8fd7d/7klwNnAYsGQhVtRPYCZDkhqpa9mL1NFvPtYH1TTvrm16D3Lgz9FNGSR6b5HELw8C/pHcxWpI0QQa97fQlSfYBpwMfTXJNm/7UJLvbYk8BPpPkZuALwEer6i8HaVeS1L1B7zK6Grh6ielfB85uw3cAP7HGJnauvXcTbz3XBtY37axveq25tlRVlx2RJE0pv7pCkgRMUCCs96/BOIT6zkqyJ8neJBeOso+DSPKkJJ9I8pX27xOXWe677bW7KcmuUffzUB3s9UhyRJIr2/zPJzlh9L1cm1XUdkGSb/S9Xq8aRz/XKsmlSe5b7tb19Lyz1f+3SZ4z6j6u1Spqm0nycN9r96ZVrbiqJuIBPIve/bOzwKkrLHcncNS4+zuM+oDDga8CJwJbgJuBk8fd91XW93vAhW34QuDtyyw3N+6+HkJNB309gF8H3tOGzwOuHHe/O6ztAuBd4+7rADW+AHgOcMsy888GPgYEeB7w+XH3ucPaZoCPHOp6J+YIoapuq6o94+7HsKyyvtOAvVV1R1XNA1cA5w6/d504F7isDV8G/PwY+9KV1bwe/XV/EDgz0/FdLtP8f21Vqvfh1wdXWORc4P3V8zngCUmOGU3vBrOK2tZkYgLhEKznr8E4Frirb3xfmzYNnlJV97Thf6B3u/FSjkxyQ5LPJZn00FjN6/G9ZarqAPAw8OSR9G4wq/2/9tJ2OuWDSY4fTddGZpr/3lbj9CQ3J/lYkh9bzRMG/bbTQzLqr8EYtY7qm1gr1dc/UlWVZLnb157eXr8TgU8l+WJVfbXrvqoTfwFcXlWPJHk1vSOhF465T1qdv6H3tzaX5GzgQ8BJB3vSSAOhRvw1GKPWQX13A/3vwo5r0ybCSvUluTfJMVV1Tzvsvm+ZdSy8fnckmQV+kt657Em0mtdjYZl9STYBjwceGE33BnLQ2qqqv4730rtOtJ5M9N/bIKrqW33Du5P8UZKj6iA/QTBVp4w2wNdgXA+clOQZSbbQu0g58XfiNLuA89vw+cAPHBEleWKSI9rwUcDzgS+NrIeHbjWvR3/dLwM+Ve2q3oQ7aG2LzqefA9w2wv6Nwi7gl9rdRs8DHu477TnVkhy9cC0ryWn09vUHf6My7qvlfVfFX0LvHN4jwL3ANW36U4HdbfhEendD3AzcSu9UzNj73lV9bfxs4HZ675qnqb4nA58EvgJcCzypTT8VeG8b/ingi+31+yLwynH3exV1/cDrAbwZOKcNHwn8GbCX3leznDjuPndY2++2v7ObgU8D/3zcfT7E+i4H7gH2t7+9VwKvAV7T5gd4d6v/i6xwd+OkPVZR22v7XrvPAT+1mvX6SWVJEjBlp4wkScNjIEiSAANBktQYCJIkwECQJDUGgiQJMBAkSY2BIEkC4P8BSvn3yBxv0MwAAAAASUVORK5CYII=\n", 420 | "text/plain": [ 421 | "
" 422 | ] 423 | }, 424 | "metadata": { 425 | "needs_background": "light" 426 | }, 427 | "output_type": "display_data" 428 | } 429 | ], 430 | "source": [ 431 | "\n", 432 | "fig, ax = plt.subplots(figsize=(6,6))\n", 433 | "xmin = -1.5\n", 434 | "xmax = 1.5\n", 435 | "ymin = -1.5\n", 436 | "ymax = 1.5\n", 437 | "ax.set_xlim([xmin,xmax])\n", 438 | "ax.set_ylim([ymin,ymax])\n", 439 | "color_list = ['r', 'g', 'b', 'k']\n", 440 | "for j in range(num_datapoints_transmitted):\n", 441 | " c_ind = j%4\n", 442 | " ax.scatter(modulated_symbol[j,0], modulated_symbol[j,1], c=color_list[c_ind], label = m_test[c_ind], s = 1000)\n", 443 | "ax.grid(True)\n", 444 | "plt.title('constellation points')\n", 445 | "plt.show()\n" 446 | ] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "execution_count": 160, 451 | "metadata": {}, 452 | "outputs": [ 453 | { 454 | "data": { 455 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYQAAAF1CAYAAADoc51vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3dfXRcd33n8fdXihUsHBykuDzY0Si0TtqUFEpSIIXTdTBtEy0kbRY40EnqBFjVyXJqdsu2aWcLC622dDdnF3MgMd40qcEqxMtjCKKUOHHpEy1OS3AgDZjgcWygYCm4UexGsvTdP+699mh07zxo7szcK31e58zRPNy593cl+37v7/f9PZi7IyIi0tPtAoiISDYoIIiICKCAICIiIQUEEREBFBBERCSkgCAiIoACgqwgZrbDzH6/Dfv972a2ewnf+7qZbUq7PDHH2Wdmb233cST/zup2AUQ6xd23drsMldz9p7tdBpFKqiFIppiZblJEukQBQbrOzA6Z2e+Y2deAp8zsLDN7vpl9wsx+aGbfMbPfrNi+18x+z8y+bWZPmtmDZnZ++NlPmtkXzWzKzB41szdUfO9PzewPw+ePmNlrKj47KzzWS8LXLzezvzWzH5nZQ5VNO2Z2gZn9ZXjsLwLn1Ti388zs3nA/U2b2V2bWU3Herw6frzazXWb2RFi23zazI1W/o3eY2dfM7LiZ3W1mzwg/e3Z4jB+G37/XzDa0+GeRFUgBQbLiTcC/B84F5oHPAg8B64HNwNvN7JfDbf9LuP0I8CzgzcAJM3sm8EXgz4AfA94I3GZmF8cc76PhPiK/DBxz9380s/XA54A/BAaAdwCfMLN14bZ/BjxIEAj+ANhS47x+CzgCrAOeA/weEDdfzLuAYeAFwC8C18Vs8wbgSuAC4GeAG8L3e4C7gAIwBJwEPlCjTCKxFBAkK97v7o+7+0ng54B17v4ed59x98eA/0twgQd4K/Df3P1RDzzk7pPAa4BD7n6Xu59y938CPgG8PuZ4fwZcbWb94etfIwgSEFyMJ9x9wt3n3f2LwH5gxMyGwvL9vrs/7e5fIgheSWaB5wEFd59197/y+AnE3gD8D3d/wt2PAO9P+B19192nwmO+GMDdJ939E+5+wt2fBMaAf1ejTCKxFBAkKx6veF4Anh82s/zIzH5EcGf9nPDz84Fvx+yjALys6ntF4LnVG7r7QeAR4LVhULiaIEhE+3l91X5eSXBhfz7whLs/VbG7co3z+l/AQeAvzOwxM7slYbvnV/0OHo/Z5vsVz08AawDMrN/MPmRmZTP7V+BLwLlm1lujXCKLKIEnWVF51/w48B1335iw7ePAjwMPx7z/l+7+iw0eM2o26gG+EQaJaD8fcff/WP0FMysAzzazZ1YEhSHim4EI79h/C/gtM3shcL+ZfcXd91Zt+j1gA/CN8PX5DZ4D4f4vAl7m7t83sxcD/wRYE/sQUQ1BMukfgCfDRPPqMIn8QjP7ufDzO4A/MLONFvgZMxsE7gUuNLPrzWxV+Pg5M/uphON8DPgl4CbO1A4AdhPUHH45PPYzzGyTmW1w9zJB89G7zazPzF4JvDbpRMzsNWb2E2ZmwHFgjiBHUm0P8Lthgng98LZGf1nAOQR5gx+Z2QBBPkKkaQoIkjnuPkeQD3gx8B3gGEEQWBtu8r8JLqB/Afwr8CfA6vBu/JcIcg3fJWhi+WPg7ITjfA/4O+Dngbsr3n8cuIagmeqHBDWG/8qZ/y+/BrwMmCK4+H64xulsBO4DpsNj3ebuD8Rs9x6C5PN3wu0/DjxdY7+V3gesJvg9fRn48wa/J7KAaYEckewxs5uAN7q7ksPSMaohiGSAmT3PzF5hZj1mdhFBXuBT3S6XrCypBAQzu9PMfmBm1Um+6PNN4WCar4aPd6ZxXJFlpA/4EPAkcD/wGeC2rpZIVpxUmozM7BcI2kg/7O4vjPl8E/AOd39N9WciIpINqdQQwsE5U2nsS0REuqOTOYTLwzlhPm9mmuVRRCRjOjUw7R8Jhu5Pm9kI8GmC7niLmNkoMArwjGc849KhoaEOFbGz5ufn6elZvjl9nV++6fzy65vf/OYxd19Xf8vFUut2ambDwL1xOYSYbQ8Bl7n7sVrbXXTRRf7oo4+mUr6s2bdvH5s2bep2MdpG55dvOr/8MrMH3f2ypXy3IyHSzJ4bjtTEzF4aHneyE8cWEZHGpNJkZGYfBTYB54VzuL8LWAXg7juA1wE3mdkpgiH2b0yY8VFERLoklYDg7m+q8/kH0PzsIiKZtjyzKiIi0jQFBBERARQQREQkpIAgIiKAAoKIiIQUEEREBFBAEBGRkAKCiIgACggiIhJSQBAREUABQUREQgoIIiICKCCIiEhIAUFERAAFBBERCSkgiIgIoIAgIiIhBQQREQEUEEREJKSAICIigAKCiIiEFBBERARQQBARkZACgoiIAAoI0kHj4zA8DD09wc/x8W6XSEQqndXtAsjKMD4ON94Is7PB63I5eA1QLHavXCJyhmoI0hHbtp0JBpHZ2eB9EckGBQTpiMnJ5t4Xkc5TQJDcUA5CpL0UEKQjBgcbfz/uwj8+DqOjQe7BPfg5OqqgIJImBQTpiO3boa9v4Xt9fcH7lZIu/Nu2wYkTC7c9cQJKpfaWW2QlUUCQjigW4c47oV
AAs+DnnXcu7mFUKsVf+JNyDYcPqylJJC3qdiodUyzW72J6+HBz+xwYCGoQURCJahTR8USkcaohSKYMDcW/PzgI/f0L34teqylJJB0KCJIpY2PxF/7t22HnzoVNTjt3wtRU/H6arWmIiAKCdFl1+z/EX/ij5qZDh2B+PvhZLCbXKJLeF5FkCgjSNUk9imDxhT9JUo1ibKxdpRZZvhQQpGuSehRt23am1nDeecGjekxC9HmpBFu2xNcoRKQ5CgjSNeVy/PuTk2dqDZOTwSOqQdxwQzApXmWtYteuoEbQSI0ioq6qIospIEjbJY08Nmt+X6dOLZ4kr16vourj33yzRj2LxFFAkFRUX3Sj3j/j4/DmNy+8+L75zUGzkHt6xy+XFzctRcevvvjv2KGuqiJxNDBNWhZddCsHh5XLwZ347bcv3n5mpj2znEb7rExOx+UpkgKRuqrKSqcagjQlrvkn7qI7Px8fDDoluuNv5iKvrqqy0qmGIA2LqwlUvs6aw4eDi3xc8tpsYU1BXVVFVEOQJiR1E+3t7U556hkaSh6nsHWruqqKVFMNQRqW1PwyNxdcZLNUU4ju+KOLfNR8FAUJXfxFFkulhmBmd5rZD8zs4YTPzczeb2YHzexrZvaSNI4rnZXUxr5mzdK6kKatJ/zXXH3HHzflhYgsllaT0Z8CV9b4/CpgY/gYBbqYbpSlGhtbvMgNwPQ0PPVU58tTbX5+cc1ARBqXSkBw9y8BCfNOAnAN8GEPfBk418yel8axpXOKRTjnnG6XojaNJxBZOvOURgeZ2TBwr7u/MOaze4H3uvtfh6/3Ar/j7vtjth0lqEWwbt26S/fs2ZNK+bJmenqaNWvWdLsYTXvwwca227BhmiNH2nd+Z50VjFpO0tcXjHfo64P164OFdNKU179fo3R++XXFFVc86O6XLenL7p7KAxgGHk747F7glRWv9wKX1dvnhRde6MvVAw880O0iLEmh4B502Kz9uPXWBxrarhOP/n733bvT/T3k9e/XKJ1ffgH7fYnX8U51Oz0KnF/xekP4nuTM2BisWtXtUjRHzUgijelUQLgH+PWwt9HLgePu/r0OHVtSVCzCs57V7VI0T9NSiNSXyjgEM/sosAk4z8yOAO8CVgG4+w5gAhgBDgIngBvTOK50R9KylVmmaSlE6kslILj7m+p87sB/SuNY0n1J00FklaalEGmMpq6Qpo2NZWMgWiM0LYVI4xQQpGnFYjAXUNaDQm9vkDsolbT4jUgjFBBkSV7xivT79qdtbk4rook0QwFBmhZNg92ORW7a5cQJ2LJFayiL1KLZTqVpcdNg58HcXPCzckU15RZEzlANQRpSuVJannoYJdFgNZHFVEOQuqpXSlsuNFhNZCHVEKSuvDYR1aPBaiILKSBIXbWaiKIlKOPWSciS6i6yGqwmspgCgtSVtGZyb2+wKM3YWDDVdJa5aw1lkXqUQ5C6ot45Se/nITlbKATLZ4pIMtUQpK5Cofb7WU/OqnlIpDEKCFLX2FhwUa1UeZHNcnJWzUMijVNAkLqKxeCimtQGPzLS3fLF6e+H3buDZiIFA5HGKIcgDSkWky+sExOdLUuczZvh4MGg+WpoKKi9KBCINEcBQVqWhRzCwYNKGou0Sk1G0rIs5BCyEJRE8k4BQVoWl3SOJI1hSFsWgpJI3ikgSMviks7r1gU/k8YwpEndSkXSoYAgqSgWgzb8aOTy5GT9WVF7e4OeQM985tKPOziobqUiaVFAkNSVSkFgqGd+PriQNzNxXk/PmVrI7t1w7JiCgUha1MtIUtdogjdq9x8aanyNhWc/OwgCIpI+1RAkdY0keM2CIDA8HAxsS0pKV5uaWrhYj5bDFEmPAoKkbmwsuFhX6usL2vshCAbuwfNyGXbsgMsvX5iUjratNjAQLNZTLgf7iJbDVFAQaZ0CgqSuWAwu6pUX+DvvDJp6CoUzwSDiDnv3wvQ0fOQjQXJ6+/b4+ZNgcc5By2GKpEM5BGmLgYH4kcO18guTk8HdPpxJFJdKC6ejuP76+O9qYJpI61RDkI6ql1+ovNuv7MoaTVKX9H0NTBNpnQKCdNTY2OLlLKvVutuvNxW3iCydAoJ0VLEIW7fWDgq17vbrTcUtIkungCAdd9ttQfI4ridRI3f7cU1JItI6BQTpimIx6HW0e7fu9kWyQr2MpKtqLbwjIp2lGoKIiAAKCCIiElJAyANN3iMiHaCA0AmtXNDHxzV5j4h0hAJCuyVd0KemGvt+qaTJe0SkIxQQ2i3pgn70aGPfT1ooQJP3iEjKFBDaLenCPTNT/7vj48lDeluZvEc5CRGJoYDQbkkX7r6++t8tlRbPFQ1BkFjq5D3KSYhIAgWEdkuajW39+vrfTapduC8czVXvjr/y8y1blJMQkVgKCO2WNBvbwED97z7zmfXfr3fHX/353Fz8PpWTEFnxFBA6oVgMagpDQ8GFt1RKXhw4es8sWEIszlNPnbngJyWtt21L/jyOFhQQWfE0l1EnRHfp0YW5XA6m6bzlFpidPfPeddc1vs9SKQg0SXf2k5PBcRu589eCAiKCagidEXeX7n4mGCxFdKGvdWe/bVvy5729mmJURBZIJSCY2ZVm9qiZHTSzW2I+v8HMfmhmXw0fb03juLnRjvb56EJf685+chJGRuKT2rt2aUEBEVmg5YBgZr3AB4GrgIuBN5nZxTGb3u3uLw4fd7R63Eyrzg00kkBuRtTEMz5ev3fQHXcEPYu06ICI1JFGDuGlwEF3fwzAzD4GXAN8I4V9509cvmDVKjjrLDh1Kp1jrF4Nf/M3wV1+vYTx7Czs2ROsRiMiUoN53MCnZnZg9jrgSnd/a/j6euBl7v62im1uAP4I+CHwTeA/u/vjCfsbBUYB1q1bd+mePXtaKl9bTU0FU1DMzAQDzdavP/O6jukNG1hz5EgHChm69NLOHQuYnp5mzZo1HT1mJ+n88m05n98VV1zxoLtftqQvu3tLD+B1wB0Vr68HPlC1zSBwdvj8N4D7G9n3hRde6Jm1e7d7f797kB4OHtWvazweuPXWhrdN5VEouJsFP3fvbvuv54EHHmj7MbpJ55dvy/n8gP2+xOt5Gknlo8D5Fa83hO9VBp1Jd386fHkH0Nnb1XZI6v+fVZqqQkTqSCMgfAXYaGYXmFkf8EbgnsoNzOx5FS+vBh5J4bjdleeRvZqqQkRitJxUdvdTZvY24AtAL3Cnu3/dzN5DUHW5B/hNM7saOAVMATe0etyuGxpKnpo6D/Ic0ESkLVIZh+DuE+5+obv/uLuPhe+9MwwGuPvvuvtPu/uL3P0Kd//nNI7bVXGT1uVJT4+ajURkAU1dsVRRP/4tW5InjMuyuTl485uD5xqTICJo6orWFIvBWIC81hRmZs5MgiciK54CQquqp7fOm8nJ4KdWURNZ8RQQ0lAsBnMCfeQj3S7J0mgVNRFBASE90UU1bwYHk8dUqGuqyIqigJCWbduyPTAtzqpVsH17chdUdU0VWVEUENIwPn6mLT7LNm9eOOvpXXcFzV1JayZoFTWRFUUBIQ15aFrZvBluvDH+s7gxFVpFT
WTF0TiENGS5aaW3N+gaC4un5Y5yHtE4hFIpOJehoSAYaHyCyIqigJCGLE9jMToaXNiHh5MTx8XimYeIrFhqMmpF1He/XM7uGITbbz9TxjhZrt2ISEcpICxVZd99CPrvR0FhcLB75YpTK2ApcSwiIQWERsSN4o3ru+8e9N7J4kpMlQErosSxiFRQDqGeuDWSK19Xy3ITTBSwlDgWkRgKCPUkjeLt7Y2f5TRqgslikrlQCKbYEBGJoYBQT9Id/9xc0ORSGSwqm2BuvBFmZ9tfvkapeUhE6lAOoZ6kpGuhsHCW0+h11H3zrruyk1yuLJuISAIFhHpqjeKNZjmdnz/TFBMln0ulYJ6g3bu7u17C4GBQNgUDEalDTUb1NDqKNy75fN11wQX58sth797OljuShzmWRCQTFBAa0cgo3rjkMwQX5Pvvb0+5RERSpCajtNTqbureuXJUy0oeQ0QyTwEhLVkd8bt9e7dLICI5oYCQlrjkcz19fe0pSyUlk0WkQQoIaSkWg66djTbRmMFb3hJ0CW1VT8KfMY19i8iKoYCQpmIRjh0Lupr29tbe1h0mJpZWs6g2P68FbkSkZQoI7VAsBhfpeg4fTu6d1Ixag+RERBqkbqft0siiOUND6UyGF42LUAAQkRaohtAu9ZqCenqCbVrtnTQ4qEAgIqlQDaFdqkc4DwwEr6emgiBQKMC11wbvVU+nbRY/dqH6/f5+dSsVkdSohtBOlXMdHTsWPKJ5j6IAEfVOqmz/37p1ce3CDF71KuUJRKRtVEPIgqT2/x07ztQI3OHv/k5BQETaRjWErJqYWNxsdOJE0AQlItIGCghZldT7KO79uDWfRUSapICQVUm9j6rfj6bdLpeDGkW05rOCgog0SQEhq2otzFMpac1nNS2JSJMUELIqrvdRXEK5maYlEZEa1MsoyxoZfZw0Ijqr03GLSGaphpB3jTYtiYjUoYCQd402LYmI1KEmo+VAE9uJSApUQxAREUABQUREQgoIIiICKCCIiEhIAUFERICUAoKZXWlmj5rZQTO7Jebzs83s7vDzvzez4TSOKyIi6Wk5IJhZL/BB4CrgYuBNZnZx1WZvAZ5w958A/g/wx60eV0RE0pVGDeGlwEF3f8zdZ4CPAddUbXMNsCt8/nFgs5lZCscWEZGUpBEQ1gOPV7w+Er4Xu427nwKOA4MpHFtERFKSuZHKZjYKjAKsW7eOffv2dbdAbTI9Pb1szw10fnmn81uZ0ggIR4HzK15vCN+L2+aImZ0FrAUm43bm7juBnQAXXXSRb9q0KYUiZs++fftYrucGOr+80/mtTGk0GX0F2GhmF5hZH/BG4J6qbe4BtoTPXwfc7169YLCIiHRTyzUEdz9lZm8DvgD0Ane6+9fN7D3Afne/B/gT4CNmdhCYIggaIiKSIankENx9Apioeu+dFc//DXh9GscSEZH20EhlEREBFBBERCSkgCAiIoACgoiIhBQQREQEUEAQEZGQAoKIiAAKCCIiElJAEBERQAFBRERCCgg5NT4+zvDwMD09PQwPDzM+Pt7tIolIzmVuPQSpb3x8nNHRUU6cOAFAuVxmdHQUgGKx2M2iiUiOqYaQQ6VS6XQwiJw4cYJSqdSlEonIcqCAkEPlcjn2/cOHD3e4JCKynCggdEh1m//U1NSS92NmsZ8NDQ21UkQRWeEUEDogavMvl8u4O+VymXK5vKREcKlUIm6xOTNjbGwsjeKKyAqlgNABcW3+8/PzS2rzT2oWcncllEWkJQoIHZB0EV9Km39Ss1ChUGh6XyIilRQQOiDpIt7T09P0OIKRkZGa72t8gogslQJCB4yNjdHf37/o/bm5udM5heuvv56bb7657r4mJiYS34/LVYyOjiooiEhDFBA6oFgssnPnTgqFAmZGb2/vom3cnR07dtS9eNdqftL4BBFphQJChxSLRQ4dOsT8/Dzz8/Ox27g727ZtS2zyGR8fp6cn/k8W1QjiaHyCiDRCAaELao0XmJycjG3yiZqD5ubmUj0eKO8gIgHNZdQFY2NjfP/7329o28omn+rmoEb09/fXHJ+geZFEJKIaQhcUi0XOOeechrc/fPhw080+ZkahUGDnzp01L+zKO4hIRDWELhgfH+epp55qePuBgQHWrFmTmCOoVigUOHToUEPbpjlGQkTyTTWEDqhuo9+2bVtiYjnOE088wcjISOIcRtXK5XLDuYCk/ILmRRJZeRQQ2ixubMDk5GRT+5ifn2fPnj1s3bq1qaDQyBiEuDES9fIOIrI8KSC0WVwb/VJMTk4yMTERO7FdkkZyAdVjJBrJO4jI8qSA0GZptsU3mkOoPn69bqWVYyQOHTqkYCCyQikgtEHlBThpIFmn9Pf3azoLEWmIAkLKqnMGSxlIlqannnpK3UpFpCEKCClLK2fQbupWKiLVFBBStpR2/m4YGhrSlBUisoAGpqUoWu+4mZ5A3dDf38/IyIimrBCRBVRDSFHSesdZMjg4yM6dO5mYmFBuQUQWUEBoUdTsYma5aC568skn2bZtm6bKFpFFFBCaVNnuft5553HjjTfmIhBEZmZmao6U1pQVIiuXcghNqJ4qutkpKLJOU1aIrGyqITQhL11Kl0JTVoiIaghNWK7t681Mly0iy5dqCE1YDu3rvb29C16rmUhEIgoITYibKjpO9UU3S84991wKhQIQlDPqaqpBaSKigNCEyqmikxQKhaYWv+m0qamp04EtmmdJE96JCCggNC2aKnr37t2JC8tkuWlpYGBA6yiLSCwFhCWqtbDMyMhIt4uX6Mknn9SgNBGJ1VIvIzMbAO4GhoFDwBvc/YmY7eaAA+HLw+5+dSvHzYpisbiom+b4+Di7du3qUonqm5mZobe3N3Za7izXbESk/VqtIdwC7HX3jcDe8HWck+7+4vCxLIJBkjyMVZibm9M6yiKySKsB4Roguh3eBfxKi/vLvTw0u0TNW9XNXYCmwxZZwayV2TnN7Efufm743IAnotdV250CvgqcAt7r7p+usc9RYBRg3bp1l+7Zs2fJ5euGAwcOMDMzE/tZ5dTYGzZs4MiRI50s2ukyDA8PMzAwsOD9qakpyuXygh5SPT09FAqFRds2Ynp6mjVr1rRc3qzS+eXbcj6/K6644kF3v2xJX3b3mg/gPuDhmMc1wI+qtn0iYR/rw58vIMg1/Hi947o7F154oefN7t27vb+/34EFDzPzm266yXfv3u2Dg4N+6623LtqmE4/BwcHYchcKhdjtC4XCkn4PDzzwwNJ/iTmg88u35Xx+wH5v4Poa96jbZOTur3b3F8Y8PgP8i5k9DyD8+YOEfRwNfz4G7AN+tt5x86RyBtRSqcSWLVsYHBxcsI27c/vtt3PXXXdx7Ngx+vr6ulLWpAn5kpq68tAEJiLpaDWHcA+wJXy+BfhM9QZm9mwzOzt8fh7wCuAbLR43M6IZUMvlMu5OuVyu2cto79693Hzzzaxfv76DpVwoLkeQ1MNIPY9EVo5WA8J7gV80s28Brw5fY2aXmdkd4TY/Bew3s4eABwhyCMsmICQN8qo1NXaUwO2WKHBVjk6Om5ZDPY9E
VpaWAoK7T7r7ZnffGDYtTYXv73f3t4bP/9bdL3H3F4U//ySNgmfFUppU5ubmOHr0aBtK05zK0cm1BtqJyMqg6a9bNDQ0FDvyd3BwMLGW0Nvbm9gTqdMqA1rcQDsRWTk0dUWLkppatm/fzubNm2O/s3r1anp66v/qzzqr/fFaOQIRiSggtKhWU8t9993HTTfdRDBE44zp6emGZkQ9depUQ4Gjev+NbqccgYhUUkBIQTQD6vz8PIcOHVrQ7HLbbbe1dBfeSOAYGBhY1M212qpVq9i6datyBCKSSDmEDmh3X/5aPZoiz3rWs7jtttvaWg7JvvED45T2ljh8/DBDa4cY2zxG8RLdFEhANYQOWGoNob+/v+HmoHqmpqZS2Y/k1/iBcUY/O0r5eBnHKR8vM/rZUcYPaM4qCSggdECjS29Wipp0vIW5piopeSylvSVOzFaNmZk9wbbPb2P4fcP0vLuH4fcNK0CsYAoIHVCdeK7X3m9mp3MRtZbrbJSSxwJw+Hh80+XkyclFtYapk6pRrkQKCB1SmXg+duwYF1xwAb29vbHbVt7NL6V2MTg4qOSxLDK0trFa4onZExx9svsDJ6XzFBAaUDl5XVrrBAwMDLBr166600VU1y6Sgkjl97dv357Y60lWrpGNjS/tOjM3o6ajFUgBoY64yesq5wBqRaPTRVTWLuKCSGRwcFC1AUk08a2JprZXwnnlUUCoI2nyumgOoFbVGsOQtP3OnTtj8xAnT55MpUySP+MHxusmhpNyCElOzJ6gtDedf+eSDwoIdWRxnYBisRi72lOagUryo9HupI3mECqVj5fV82gFUUCoI611AqrzEK2OC8hioJLuSOpOWn13P7Z5jP5VzXVQgCAoXP/J67n5czc39b1Gai2SLQoIdaSxTkBcHqJcLreUh9CCNhJJagqK7u6jCzLAlhdtwWh+sKPj7Ni/o+GLugbB5ZMCQh1prBMQl4eYn59vqXlHC9pIJKkpyLAFF+TrPnkdH3rwQzhLG+zo+OlaR727/0ZrLZItCggNaDbxW60dzTta0EYiSU1BcRf+ea8/WWIt5eNlbv7czdz46RsXBJsbP33jgqCQVGtpNrEtnaWA0AHtat5pNVDJ8lC8pLjkpqCluH3/7czOzy54b3Z+lm2f33b6dVKtZSmJbekcBYQOiGve6enpUfOONKVWM83EtyaW3BSUlsmTway74wfGmZ6ZXvR5/6p+xjbr33yWafrrDoju3EulEocPH2ZoaIhCocC1117b5ZJJXkRJ2qhdPkrSQlBDyEpTTHU5I4OrB9l+1XZNtZ1xqiF0SHXzzsDAQLeLJDlSb6bSbtcOIBiX3h0AAAvUSURBVLjox5UTYE1fMG5G3VCzTQFBJAfqzVRaS3Qxbqe+3j62X7U9sSxRjUbdULNNAUEkB1pJxj596ukUSxLvzmvuBEhMbPdar7qh5oACgkgONDNTabXZ+Vl6rfYsua0wjOs/eT1bPrUltunKMOZ8Lva7Wcl9SEBJZZEcaHam0mpJF+Q0REEg6Ri18hvqhpotqiGI5MByvJNWN9TsUUAQyYGkO+nB1YMU1hYwrCPJ47QU1hbY+dqd6oaaMQoIIjkQNz1F/6p+tl+1nUNvP8T8u+Y5OZuP9TAM49DbDykYZJACgkgOFC8psvO1O0/XBuLusNuZJ0iT8gbZpaSySE4ULynWvKvutd7MBwXDlDfIMNUQRHIumuMoD8HgVRe8itLekkYrZ5RqCCI5ljR3UNYU1hYY2TjCrod2Jc7HJN2nGoJIjiXNHZQlURJ54lsTGq2ccQoIIjmWh/EJUe8oLZqTfQoIIjmW1GOnx7LzX/up2acYPzCuRXNyIDv/akSkaUnLZ7a6VGbaSntLiWMp1OsoOxQQRHKsenxCOyexa8Xh44cXlBXOzIBa2ltSb6OMUEAQybniJcXTo5WzVjOIDKwOFoQqXlI8XVOIuslqbYTsUEAQWUby0B6ftPqbeht1n8YhiOTU+IFxSntLHD5+mKG1Q4xsHIld3D4Lpk5OnX6u3kbZpYAgkkPVA9LKx8vcvv/2LpcqWY/10PPuHobWDjGweoDJk5OLtslD7Wa5U0AQyaE8DEirVJkvWNWzir7ePmbmZk5/rt5G2aAcgkgOJS1mnwez87Oc03dOzZlbpTtUQxDJmfED4xhWc2nKrJs8Ocmx3z7W7WJIFdUQRHKmtLeU62AAZHa8xErXUkAws9eb2dfNbN7MLqux3ZVm9qiZHTSzW1o5pshKl9XeOIOrB1nVs6qhbbM+VfdK1WoN4WHgWuBLSRuYWS/wQeAq4GLgTWZ2cYvHFVmx2tEbxzAGVw/Gflbvbr5/VT+7r93Nsd8+xl2/cteC3EDSPqPRypItLQUEd3/E3R+ts9lLgYPu/pi7zwAfA65p5bgiK1nS/EVL1b+qn62XbeXfTv3bos8Mq3k332u9CxLClaOmk3oNGcbIxpF0Ci+pMvfW2yLNbB/wDnffH/PZ64Ar3f2t4evrgZe5+9sS9jUKjAKsW7fu0j179rRcviyanp5mzZo13S5G2+j82mvq5BRHnzy6oOvmUvVYD+6+IC+x4ewNHHn6SEPf7+vtY/05609PTxGVr3y8nDiVRo/1UFhbWPCdTur236+drrjiigfdPbEJv5a6vYzM7D7guTEfldz9M0s5aC3uvhPYCXDRRRf5pk2b0j5EJuzbt4/lem6g8+uU4fcNN90FdVXPKsysZjC59cJbecc339HwPvtX9S+oKTRSrsLaAofefqjhY6QpK3+/rKkbENz91S0e4yhwfsXrDeF7ItKiWgnmwtpC7EV5dn429XJEcxFFAaGRxHdWk+MrWSe6nX4F2GhmF5hZH/BG4J4OHFdk2UtKMEd334Z1rCyVF/hGEt+aqiJ7Wu12+qtmdgS4HPicmX0hfP/5ZjYB4O6ngLcBXwAeAfa4+9dbK7aIQHyCuXIaiE5edAdWDzD8vmF63t3D9Mx0zS6omqoim1oaqezunwI+FfP+d4GRitcTwEQrxxKRxaImmupZT0t7S1z3yetil9Jc1bNqyc1GvdbLnM8tGim9qmcVT848eXrSusmTk/T19jG4epCpk1Onk8dTJ6cYWjvE2OYxTVWRQZq6QiTnipcUT19cq2dBre7lM7h6kO1XbWfb57fFzjgaXfDjVCaOq6fenp6ZXrS/mbkZ1vSt0RQVOaKpK0SWkXqzoK7pW0PxkiLbr9oe29Q0eulobFPP4OrBxPEGh95+aMF6B5WUOM4XBQSRZaTeBTj6vHot5mjG0VcMvQKzhYnovt4+tl+1vWYTT1KuQonjfFGTkcgyMrR2qGb//8oLdGVTU2T4fcOLxifMzM0s6FIaZ2zz2IKmKlDiOI9UQxBZRmpNa9HIBXqpy1sm1TiUOM4X1RBElpHKXkfl4+XTSeLC2kJDPXuSahiNNP3E1TgkXxQQRJaZVi7MUdNPJTX9rBwKCCJyWhRIph6ZwjCNGVhhFBBEZIHiJUX2Te5j/g3xM5XK8qWksoiIAAoIIiISUkAQERFAAUFEREIKCCIiAiggiIhISAFBREQABQQREQkpIIi
ICKCAICIiIQUEEREBFBBERCSkgCAiIoACgoiIhBQQREQEUEAQEZGQAoKIiAAKCCIiElJAEBERQAFBRERCCggiIgIoIIiISEgBQUREAAUEEREJKSCIiAiggCAiIiEFBBERARQQREQkpIAgIiKAAoKIiIQUEEREBFBAEBGRkAKCiIgACggiIhJSQBAREUABQUREQi0FBDN7vZl93czmzeyyGtsdMrMDZvZVM9vfyjFFRKQ9zmrx+w8D1wIfamDbK9z9WIvHExGRNmkpILj7IwBmlk5pRESkazqVQ3DgL8zsQTMb7dAxRUSkCXVrCGZ2H/DcmI9K7v6ZBo/zSnc/amY/BnzRzP7Z3b+UcLxRIAoaT5vZww0eI2/OA5ZzE5rOL990fvl10VK/WDcguPurl7rzin0cDX/+wMw+BbwUiA0I7r4T2AlgZvvdPTFZnWfL+dxA55d3Or/8aqXjTtubjMzsmWZ2TvQc+CWCZLSIiGRIq91Of9XMjgCXA58zsy+E7z/fzCbCzZ4D/LWZPQT8A/A5d//zVo4rIiLpa7WX0aeAT8W8/11gJHz+GPCiJR5i59JLl3nL+dxA55d3Or/8WvK5mbunWRAREckpTV0hIiJAhgLCcp8Go4nzu9LMHjWzg2Z2SyfL2AozGzCzL5rZt8Kfz07Ybi78233VzO7pdDmbVe/vYWZnm9nd4ed/b2bDnS/l0jRwbjeY2Q8r/l5v7UY5l8rM7jSzHyR1XbfA+8Pz/5qZvaTTZVyqBs5tk5kdr/jbvbOhHbt7Jh7ATxH0n90HXFZju0PAed0ubzvOD+gFvg28AOgDHgIu7nbZGzy//wncEj6/BfjjhO2mu13WJs6p7t8DuBnYET5/I3B3t8ud4rndAHyg22Vt4Rx/AXgJ8HDC5yPA5wEDXg78fbfLnOK5bQLubXa/makhuPsj7v5ot8vRLg2e30uBg+7+mLvPAB8Drml/6VJxDbArfL4L+JUuliUtjfw9Ks/748Bmy8dcLnn+t9YQDwa/TtXY5Brgwx74MnCumT2vM6VrTQPntiSZCQhNWM7TYKwHHq94fSR8Lw+e4+7fC59/n6C7cZxnmNl+M/uymWU9aDTy9zi9jbufAo4Dgx0pXWsa/bf2H8LmlI+b2fmdKVrH5Pn/WyMuN7OHzOzzZvbTjXyh1dlOm9LpaTA6LaXzy6xa51f5wt3dzJK6rxXCv98LgPvN7IC7fzvtskoqPgt81N2fNrPfIKgJvarLZZLG/CPB/7VpMxsBPg1srPeljgYE7/A0GJ2WwvkdBSrvwjaE72VCrfMzs38xs+e5+/fCavcPEvYR/f0eM7N9wM8StGVnUSN/j2ibI2Z2FrAWmOxM8VpS99zcvfI87iDIEy0nmf7/1gp3/9eK5xNmdpuZned1liDIVZPRCpgG4yvARjO7wMz6CJKUme+JE7oH2BI+3wIsqhGZ2bPN7Ozw+XnAK4BvdKyEzWvk71F53q8D7vcwq5dxdc+tqj39auCRDpavE+4Bfj3sbfRy4HhFs2eumdlzo1yWmb2U4Fpf/0al29nyiqz4rxK04T0N/AvwhfD95wMT4fMXEPSGeAj4OkFTTNfLntb5ha9HgG8S3DXn6fwGgb3At4D7gIHw/cuAO8LnPw8cCP9+B4C3dLvcDZzXor8H8B7g6vD5M4D/BxwkmJrlBd0uc4rn9kfh/7OHgAeAn+x2mZs8v48C3wNmw/97bwG2AlvDzw34YHj+B6jRuzFrjwbO7W0Vf7svAz/fyH41UllERICcNRmJiEj7KCCIiAiggCAiIiEFBBERARQQREQkpIAgIiKAAoKIiIQUEEREBID/D1T2Lf+XvrvCAAAAAElFTkSuQmCC\n", 456 | "text/plain": [ 457 | "
" 458 | ] 459 | }, 460 | "metadata": { 461 | "needs_background": "light" 462 | }, 463 | "output_type": "display_data" 464 | } 465 | ], 466 | "source": [ 467 | "fig, ax = plt.subplots(figsize=(6,6))\n", 468 | "xmin = -1.5\n", 469 | "xmax = 1.5\n", 470 | "ymin = -1.5\n", 471 | "ymax = 1.5\n", 472 | "ax.set_xlim([xmin,xmax])\n", 473 | "ax.set_ylim([ymin,ymax])\n", 474 | "color_list = ['r', 'g', 'b', 'k']\n", 475 | "for j in range(num_datapoints_transmitted):\n", 476 | " c_ind = j%4\n", 477 | " ax.scatter(received_signal[j,0], received_signal[j,1], c=color_list[c_ind], label = m_test[c_ind])\n", 478 | "ax.grid(True)\n", 479 | "plt.title('received signal')\n", 480 | "plt.show()\n" 481 | ] 482 | }, 483 | { 484 | "cell_type": "code", 485 | "execution_count": null, 486 | "metadata": {}, 487 | "outputs": [], 488 | "source": [] 489 | }, 490 | { 491 | "cell_type": "code", 492 | "execution_count": 161, 493 | "metadata": {}, 494 | "outputs": [ 495 | { 496 | "data": { 497 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYQAAAF1CAYAAADoc51vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAXDElEQVR4nO3de7SldX3f8ffHGblorCCD3C9OIFXM2AjIxRhDFQ1QCt6wkKaCkYykoaZdM80idRWNNVW6QkwTSXG0VNQKqAkyylACorKMgUAtZLgUGVikM8NVEAwBgQPf/vH8xmyO5zazn3PZM+/XWmed5/Lbz+/33XvO/uzntidVhSRJL5jvAUiSFgYDQZIEGAiSpMZAkCQBBoIkqTEQJEmAgaBZluT8JP9xYP43kzyQ5PEku7TfS6fZxr6t3aIZ9HdUkg19jH0YSa5IcuosbPezST66BY+b9nnuQ5J7khw92/1odiye7wFo65HkNOD0qnrDpmVVdcbA+hcCfwgcUVU3t8U/M912q+r/zaTdQlJVx873GAZV1Ug9f5of7iGoF0lm8uFiN2AH4NZZHs6MzHDM0jbDQNCkkpyV5K4kf5fktiRvH1h3WpK/TPKJJA8DlwDnA0e2wxOPtnafTfLRJD8H3NEe/miSa9r6SnJAm94xyblJ/jbJY0m+05bt39otbu3em+T2Nq67k7x/M2qqJL+V5E7gzrbslUmuSvJIkjuSvHug/YRjauuOSPLdJI8muTnJUQOP+1aS05Ns39b//MC6XZM8meTlbf74JDe1dt9N8pqBtq9N8r1W6yV0gTpZbQck+XYb5w9a+8G6Nz3PuyT5WpIfJbmhvT7fGdf2jCR3tjGdlyRt3c8muSbJw62P/5lkp5k+/1rYDARN5S7gl4CXAr8HfCHJHgPrDwfupvvk/2vAGcBfVdXPVNXz3iSq6vvAq9vsTlX1pgn6+wPgEOD1wMuA3wGem6Ddg8DxwD8C3gt8IsnBm1HX29rYD0ryYuAq4IvAy4GTgT9NctBUY0qyF3A58NG2fCXwZ0l2HVf3U8CfA6cMLH438O2qejDJa4ELgPcDuwCfAla3INkO+Crw+dbHl4F3TlHXfwL+AtgZ2Bv4k0nanQf8PbA7cGr7Ge944HXAa9p4f6UtD/AxYE/gVcA+wIenGJNGiIGgSVXVl6vq3qp6rqouoftEfdhAk3ur6k+qaqyqnhymryQvAH4d+O2q2lhVz1bVd9sb6vhxXV5Vd1Xn23Rvgr+0Gd19rKoeaWM+Hrinqv5Hq+P/AH8GnDTNmH4NWFNVa9rzcxVwI3DcBP19kS5oNvnVtgxgOfCpqrq+bf9C4CngiPbzQuCPquqZqvoKcMMUdT0D7AfsWVU/rqrvjG/QTsy/E/hQVT1RVbcBF06wrY9X1aPt/M03gV8AqKp1VXVVVT1VVQ/RnRP65SnGpBFiIGhSSd4zcCjjUeDngSUDTdb32N0SusMhd81gXMcmua4d4nmU7k14yXSPGzA47v2AwzfV2Lb3L+k+PU81pv3oQmPwcW8A9pig7TeBFyU5PMn+dG+ulw5sZ8W47exD9wl8T2BjPf8bKP92irp+h+4T/F8nuTXJr0/QZle6i0kGn4OJXsf7B6afoJ3UT7JbkouTbEzyI+ALbN5zrwXMk2qaUJL9gE8Db6Y7DPRskpvo3nA2Gf9VucN8de4PgB8DPwvcPFmjJNvTfYJ/D3BZVT2T5KvjxjWdwXGupzt885YJ+nrBFGNaD3y+qn5j2s665+5LdIeNHgC+XlV/N7Cd36+q35+g/18G9kqSgVDYl0lCs6ruB36jPfYNwNVJrq2qdQPNHgLG6A4pfb8t22e6Ggb8Z7rnb1lVPZLkbcAnN+PxWsDcQ9BkXkz3h/8QdCdy6fYQpvIAsHc79r1Zquo5umPpf5hkzySLkhzZAmDQdsD2bVxjSY4F3rq5/Q34OvBzSf5Vkhe2n9cledU0Y/oC8M+T/EpbvkO6eyD2nqSfLwL/gm7v44sDyz8NnNH2HpLkxUn+WZKXAH9F9+b9gTaud/D8Q3bPk+Skgf5/SPf6Pe8cTFU9S3dO48NJXpTklXThOlMvAR4HHmvnUf79ZjxWC5yBoAm1Y8vn0r0pPQAsA/5ymoddQ3dJ6f1JfrAF3a4E1tIdJ38EOIdx/0bbJ+sPAF+ie9P7VWD1FvQ1uL230h3jv5fuUMk5dKEz6Ziqaj1wIvAf6MJpPd2b44R/U1V1Pd2J3D2BKwaW30j3qf6TrZ51wGlt3dPAO9r8I3SB8udTlPM64Pokj9M9J79dVXdP0O5MugsF7qc7YX0R3XmLmfg94GDgMbqT6lONRyMm/gc50rYtyTnA7lXV+53VGi3uIUjbmHT3XbymHaI6DHgf/3CSW9uwXgIhyQVJHkxyyyTrj2o3y9zUfs7uo19JW+QldId6/p7uhsJzgcvmdURaEHo5ZJTkjXQnmj5XVT914jHdHZwrq+r4oTuTJM2KXvYQqupaupNekqQRNZfnEI5M930vVyR59fTNJUlzaa5uTPsesF9VPZ7kOLrvZzlwooZJlt
Pdzs/22+9wyMv33GuOhji3FgfGtuILvKxvtFnf6Np4/0M8+8Rjm3Oj5k/0dtlpuyX/6xOdQ5ig7T3AoVU15bXq+y49oF7w7v/ay/gWmhXLxjh37dZ7o7j1jTbrG133Xfhveeq+O7coEObkkFGS3Qe+Pvew1u/Dc9G3JGlmeonIJBcBRwFL0v33hR+i+5ZGqup84F3AbyYZA54ETi7viJOkBaWXQKiqU6ZZ/0n8AixJWtC8U1mSBBgIkqTGQJAkAQaCJKkxECRJgIEgSWoMBEkSYCBIkhoDQZIEGAiSpMZAkCQBBoIkqTEQJEmAgSBJagwESRJgIEiSGgNBkgQYCJKkxkCQJAEGgiSpMRAkSYCBIElqDARJEmAgSJIaA0GSBBgIkqTGQJAkAQaCJKkxECRJgIEgSWoMBEkSYCBIkhoDQZIEGAiSpMZAkCQBBoIkqTEQJEmAgSBJagwESRJgIEiSGgNBkgQYCJKkxkCQJAEGgiSpMRAkSYCBIElqegmEJBckeTDJLZOsT5I/TrIuyd8kObiPfiVJ/elrD+GzwDFTrD8WOLD9LAf+W0/9SpJ60ksgVNW1wCNTNDkR+Fx1rgN2SrJHH31LkvqxeI762QtYPzC/oS27b3zDJMvp9iJYsmRXzl42NicDnGu77QgrttLawPpGnfWNrpVDPHauAmHGqmoVsApg36UH1LlrF9wQe7Fi2Rhba21gfaPO+rZNc3WV0UZgn4H5vdsySdICMVeBsBp4T7va6Ajgsar6qcNFkqT508s+U5KLgKOAJUk2AB8CXghQVecDa4DjgHXAE8B7++hXktSfXgKhqk6ZZn0Bv9VHX5Kk2eGdypIkwECQJDUGgiQJMBAkSY2BIEkCDARJUmMgSJIAA0GS1BgIkiTAQJAkNQaCJAkwECRJjYEgSQIMBElSYyBIkgADQZLUGAiSJMBAkCQ1BoIkCTAQJEmNgSBJAgwESVJjIEiSAANBktQYCJIkwECQJDUGgiQJMBAkSY2BIEkCDARJUmMgSJIAA0GS1BgIkiTAQJAkNQaCJAkwECRJjYEgSQIMBElSYyBIkgADQZLUGAiSJMBAkCQ1BoIkCTAQJEmNgSBJAnoKhCTHJLkjybokZ02w/rQkDyW5qf2c3ke/kqT+LB52A0kWAecBbwE2ADckWV1Vt41reklVnTlsf5Kk2dHHHsJhwLqquruqngYuBk7sYbuSpDk09B4CsBewfmB+A3D4BO3emeSNwPeBf1dV6ydoQ5LlwHKAJUt25exlYz0MceHZbUdYsZXWBtY36qxvdK0c4rF9BMJMfA24qKqeSvJ+4ELgTRM1rKpVwCqAfZceUOeunashzq0Vy8bYWmsD6xt11rdt6uOQ0UZgn4H5vduyn6iqh6vqqTb7GeCQHvqVJPWoj0C4ATgwySuSbAecDKwebJBkj4HZE4Dbe+hXktSjofeZqmosyZnAlcAi4IKqujXJR4Abq2o18IEkJwBjwCPAacP2K0nqVy8H0apqDbBm3LKzB6Z/F/jdPvqSJM0O71SWJAEGgiSpMRAkSYCBIElqDARJEmAgSJIaA0GSBBgIkqTGQJAkAQaCJKkxECRJgIEgSWoMBEkSYCBIkhoDQZIEGAiSpMZAkCQBBoIkqTEQJEmAgSBJagwESRJgIEiSGgNBkgQYCJKkxkCQJAEGgiSpMRAkSYCBIElqDARJEmAgSJIaA0GSBBgIkqTGQJAkAQaCJKkxECRJgIEgSWoMBEkSYCBIkhoDQZIEGAiSpMZAkCQBBoIkqTEQJEmAgSBJagwESRLQUyAkOSbJHUnWJTlrgvXbJ7mkrb8+yf599CtJ6s/QgZBkEXAecCxwEHBKkoPGNXsf8MOqOgD4BHDOsP1KkvrVxx7CYcC6qrq7qp4GLgZOHNfmRODCNv0V4M1J0kPfkqSe9BEIewHrB+Y3tGUTtqmqMeAxYJce+pYk9WTxfA9gvCTLgeUAS5bsytnLxuZ5RLNjtx1hxVZaG1jfqLO+0bVyiMf2EQgbgX0G5vduyyZqsyHJYuClwMMTbayqVgGrAPZdekCdu3bBZVYvViwbY2utDaxv1FnftqmPQ0Y3AAcmeUWS7YCTgdXj2qwGTm3T7wKuqarqoW9JUk+GjsiqGktyJnAlsAi4oKpuTfIR4MaqWg38d+DzSdYBj9CFhiRpAelln6mq1gBrxi07e2D6x8BJffQlSZod3qksSQIMBElSYyBIkgADQZLUGAiSJMBAkCQ1BoIkCTAQJEmNgSBJAgwESVJjIEiSAANBktQYCJIkwECQJDUGgiQJMBAkSY2BIEkCDARJUmMgSJIAA0GS1BgIkiTAQJAkNQaCJAkwECRJjYEgSQIMBElSYyBIkgADQZLUGAiSJMBAkCQ1BoIkCTAQJEmNgSBJAgwESVJjIEiSAANBktQYCJIkwECQJDUGgiQJMBAkSY2BIEkCDARJUmMgSJIAA0GS1BgIkiTAQJAkNUMFQpKXJbkqyZ3t986TtHs2yU3tZ/UwfUqSZsewewhnAd+oqgOBb7T5iTxZVb/Qfk4Ysk9J0iwYNhBOBC5s0xcCbxtye5KkeZKq2vIHJ49W1U5tOsAPN82PazcG3ASMAR+vqq9Osc3lwHKAJUt2PeTsP/r0Fo9vIdttR3jgyfkexeyxvtFmfaNr5cqVPHXfndmSxy6erkGSq4HdJ1j1wcGZqqokk6XLflW1MclS4Joka6vqrokaVtUqYBXAvksPqHPXTjvEkbRi2Rhba21gfaPO+rZN0z4jVXX0ZOuSPJBkj6q6L8kewIOTbGNj+313km8BrwUmDARJ0vwY9hzCauDUNn0qcNn4Bkl2TrJ9m14C/CJw25D9SpJ6NmwgfBx4S5I7gaPbPEkOTfKZ1uZVwI1Jbga+SXcOwUCQpAVmqINoVfUw8OYJlt8InN6mvwssG6YfSdLs805lSRJgIEiSGgNBkgQYCJKkxkCQJAEGgiSpMRAkSYCBIElqDARJEmAgSJIaA0GSBBgIkqTGQJAkAQaCJKkxECRJgIEgSWoMBEkSYCBIkhoDQZIEGAiSpMZAkCQBBoIkqTEQJEmAgSBJagwESRJgIEiSGgNBkgQYCJKkxkCQJAEGgiSpMRAkSYCBIElqDARJEmAgSJIaA0GSBBgIkqTGQJAkAQaCJKkxECRJgIEgSWoMBEkSYCBIkhoDQZIEGAiSpMZAkCQBQwZCkpOS3JrkuSSHTtHumCR3JFmX5Kxh+pQkzY5h9xBuAd4BXDtZgySLgPOAY4GDgFOSHDRkv5Kkni0e5sFVdTtAkqmaHQasq6q7W9uLgROB24bpW5LUr6ECYYb2AtYPzG8ADp+scZLlwHKAJUt25exlY7M7unmy246wYiutDaxv1Fnf6Fo5xGOnDYQkVwO7T7Dqg1V12RB9T6iqVgGrAPZdekCdu3YuMmvurVg2xtZaG1jfqLO+bdO0z0hVHT1kHxuBfQbm927LJEkLyFxcdnoDcGCSVyTZDjgZWD0H/UqSNsOwl52+PckG4Ejg8iRXtuV7JlkDUFVjwJnAlcDtwJeq6tbhhi1J6tuwVxldClw6w
fJ7geMG5tcAa4bpS5I0u7xTWZIEGAiSpMZAkCQBBoIkqTEQJEmAgSBJagwESRJgIEiSGgNBkgQYCJKkxkCQJAEGgiSpMRAkSYCBIElqDARJEmAgSJIaA0GSBBgIkqTGQJAkAQaCJKkxECRJgIEgSWoMBEkSYCBIkhoDQZIEGAiSpMZAkCQBBoIkqTEQJEmAgSBJagwESRJgIEiSGgNBkgQYCJKkxkCQJAEGgiSpMRAkSYCBIElqDARJEmAgSJIaA0GSBBgIkqTGQJAkAQaCJKkxECRJwJCBkOSkJLcmeS7JoVO0uyfJ2iQ3JblxmD4lSbNj8ZCPvwV4B/CpGbT9p1X1gyH7kyTNkqECoapuB0jSz2gkSfNmrs4hFPAXSf53kuVz1KckaTOkqqZukFwN7D7Bqg9W1WWtzbeAlVU14fmBJHtV1cYkLweuAv5NVV07SdvlQBcaixYfst2u+8+skhHz7BOPsehFL53vYcwa6xtt1je6nnl4A889/eQWHbaZNhBmtJFpAmFc2w8Dj1fVH8yg7Y1VNenJ6lG2NdcG1jfqrG90DVPbrB8ySvLiJC/ZNA28le5ktCRpARn2stO3J9kAHAlcnuTKtnzPJGtas92A7yS5Gfhr4PKq+l/D9CtJ6t+wVxldClw6wfJ7gePa9N3AP9nCLlZt+egWvK25NrC+UWd9o2uLa+vlHIIkafT51RWSJGABBcLW/jUYm1HfMUnuSLIuyVlzOcZhJHlZkquS3Nl+7zxJu2fba3dTktVzPc7NNd3rkWT7JJe09dcn2X/uR7llZlDbaUkeGni9Tp+PcW6pJBckeTDJhBexpPPHrf6/SXLwXI9xS82gtqOSPDbw2p09ow1X1YL4AV4F/GPgW8ChU7S7B1gy3+OdjfqARcBdwFJgO+Bm4KD5HvsM6/svwFlt+izgnEnaPT7fY92MmqZ9PYB/DZzfpk8GLpnvcfdY22nAJ+d7rEPU+EbgYOCWSdYfB1wBBDgCuH6+x9xjbUcBX9/c7S6YPYSqur2q7pjvccyWGdZ3GLCuqu6uqqeBi4ETZ390vTgRuLBNXwi8bR7H0peZvB6DdX8FeHNG47tcRvnf2oxUd/PrI1M0ORH4XHWuA3ZKssfcjG44M6htiyyYQNgMW/PXYOwFrB+Y39CWjYLdquq+Nn0/3eXGE9khyY1Jrkuy0ENjJq/HT9pU1RjwGLDLnIxuODP9t/bOdjjlK0n2mZuhzZlR/nubiSOT3JzkiiSvnskDhv22080yk6/BmIE31MDXYCT5vzXJ12DMtZ7qW7Cmqm9wpqoqyWSXr+3XXr+lwDVJ1lbVXX2PVb34GnBRVT2V5P10e0JvmucxaWa+R/e39niS44CvAgdO96A5DYSqOrqHbWxsvx9Mcindru+CCIQe6tsIDH4K27stWxCmqi/JA0n2qKr72m73g5NsY9Prd3f7ypPX0h3LXohm8npsarMhyWLgpcDDczO8oUxbW1UN1vEZuvNEW5MF/fc2jKr60cD0miR/mmRJTfNfEIzUIaNt4GswbgAOTPKKJNvRnaRc8FfiNKuBU9v0qcBP7REl2TnJ9m16CfCLwG1zNsLNN5PXY7DudwHXVDurt8BNW9u44+knALfP4fjmwmrgPe1qoyOAxwYOe460JLtvOpeV5DC69/rpP6jM99nygbPib6c7hvcU8ABwZVu+J7CmTS+luxriZuBWukMx8z72vupr88cB36f71DxK9e0CfAO4E7gaeFlbfijwmTb9emBte/3WAu+b73HPoK6fej2AjwAntOkdgC8D6+i+mmXpfI+5x9o+1v7Obga+Cbxyvse8mfVdBNwHPNP+9t4HnAGc0dYHOK/Vv5Yprm5caD8zqO3MgdfuOuD1M9mudypLkoARO2QkSZo9BoIkCTAQJEmNgSBJAgwESVJjIEiSAANBktQYCJIkAP4/l0kglPWFAu0AAAAASUVORK5CYII=\n", 498 | "text/plain": [ 499 | "
" 500 | ] 501 | }, 502 | "metadata": { 503 | "needs_background": "light" 504 | }, 505 | "output_type": "display_data" 506 | } 507 | ], 508 | "source": [ 509 | "# generating artificial recieved signal for visualizing decision region\n", 510 | "num_data_points = 100\n", 511 | "artificial_input = torch.zeros(num_data_points*num_data_points,2)\n", 512 | "\n", 513 | "real_min = -1.5\n", 514 | "real_max = 1.5\n", 515 | "im_min = -1.5\n", 516 | "im_max = 1.5\n", 517 | "\n", 518 | "for real_input in range(num_data_points):\n", 519 | " for im_input in range(num_data_points):\n", 520 | " artificial_input[real_input*num_data_points+im_input,0] = ((real_max-real_min)/(num_data_points-1)) * real_input + real_min\n", 521 | " artificial_input[real_input*num_data_points+im_input,1] = ((im_max-im_min)/(num_data_points-1)) * im_input + im_min\n", 522 | "fig, ax = plt.subplots(figsize=(6,6))\n", 523 | "xmin = -1.5\n", 524 | "xmax = 1.5\n", 525 | "ymin = -1.5\n", 526 | "ymax = 1.5\n", 527 | "ax.set_xlim([xmin,xmax])\n", 528 | "ax.set_ylim([ymin,ymax])\n", 529 | "ax.scatter(artificial_input[:,0], artificial_input[:,1])\n", 530 | "plt.title('artificial received signal')\n", 531 | "ax.grid(True)\n", 532 | "plt.show()" 533 | ] 534 | }, 535 | { 536 | "cell_type": "code", 537 | "execution_count": 162, 538 | "metadata": {}, 539 | "outputs": [ 540 | { 541 | "data": { 542 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYQAAAF1CAYAAADoc51vAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3df7BkZZ3f8fcXGEYRBYchCtzhhyuwsi5RYcEfKUPEbJAVRl00WPiDLWE0CWu54VoZxyhmKyKb3KrdmHEjrLHEbHYEcXUHxaiglEW5/kBLHJCAIxmGO6AIAjozCgzzzR99Zujb9L3d9/bpPud0v19VXZy+59w+zzOXuZ853/M8z4nMRJKkfapugCSpHgwESRJgIEiSCgaCJAkwECRJBQNBkgQYCKqhiPhURPznAT/jvIj4ah/HfTwiPjDIuUYhItZFxCeqbofGWzgPQXUTEZ8CZjPzP1bdFmmSeIUgjUBE7Fd1G6ReDARVLiJeHBE/iIhfR8RVwNM69r82In4YEQ9HxLci4sS2fasi4u8j4hcR8WBErC++fn5E3FRsR0T8ZUTcHxG/iohNEfHCYt+c8lREXBgRmyPilxGxMSIOb9uXEfGuiPhJ0ZaPRUTM06cPRcQ1EfG3EfEr4PyI2Cci1kbET4u2Xh0RK9q+520RcXex7wMRsSUiXt32eX/bduzZEXFb0Y4bI+IFbfu2RMR0RPwoIh6JiKsiYs6fqdSNgaBKRcT+wBeA/wWsAD4L/HHb/hcDnwTeCRwCXA5sjIjlEbEv8EXgbuBo4AjgM11O84fAK4HjgIOANwEPdmnLq4CPFPsPKz638/NeC/wBcGJx3L9aoHurgWuAg4H/Dfwp8DrgnwOHAw8BHyvOfQLw18B5xbkPKvrzFBFxHLABeA9wKHAdcG3xZ7nHm4AzgGOKtp6/QDslwEBQ9V4KLAP+KjMfz8xrgO+17V8DXJ6Z38nMJzLzSuDR4vtOofWL9b2ZuSMzf5uZN3U5x+PAM4HfpXXf7PbMvK/LcecBn8zMH2Tmo8D7gJdFxNFtx1yWmQ9n5lbgG8CLFujbP2bmFzJzd2b+BngX8P7MnC0+/0PAOUU56Rzg2sy8KTMfAz4IzHeD718DX8rMr2Xm48AM8HTg5W3HfDQz783MXwLX9minBBgIqt7hwLacO7rh7rbto4CLi9LIwxHxMLCq+L5VwN2ZuWuhE2Tm14H1tP41fn9EXBERz5qnLXe3fd92WlcS7f9S/1nb9k7gwAVOfU/H+6OAz7f143bgCeA5xbn3Hp+ZO+lyFTNPO3cX37vUdkqAgaDq3Qcc0VGLP7Jt+x7gw5l5cNvrgMzcUOw7sp8btpn50cw8CTiBVunovV0Ou5fWL20AIuIZtMpU2xbdq+K0He/vAV7T0ZenZeY2Wn8OU23nfnpx7m462xm0wnGp7ZQAA0HV+0dgF/DuiFgWEW+gVQra42+Ad0XEqcXN4WdExB9FxDOB79L6RXpZ8fWnRcQrOk8QEX9QfP8yYAfwW2B3l7ZsAP4kIl4UEcuBS4HvZOaWkvr6ceDDEXFU0a5DI2J1se8a4KyIeHlxL+BDQNcb1sDVwB9FxOlFny6mVUb7Vknt1IQyEFSpol7+Blo3PX9Jqz7+9237bwYupFXyeQjYXBxLZj4BnAU8H9gKzBbf3+lZtILlIVqllgeB/9qlLdcDHwA+Rytofgc4d9A+tvlvwEbgqxHxa+DbwKnFuW+jddP5M8W5twP30/pF39nOO4C3AP8deIDWn8FZxZ+ltGROTJNqKCIOBB4Gjs3M/1d1ezQZvEKQaiIizoqIA4p7FzPAJmBLta3SJCklECLik8Wkn1vn2X9aMUHmh8Xrg2WcVxozq2ndML4XOBY4N72E1wiVUjKKiFfSqnl+OjNf2GX/acB0Zr524JNJkoailCuEzPwmrRuCkqSGGuU9hJdFxC0R8eWI+L0RnleS1IdRrcD4A+CozNweEWfSWrvm2G4HRsQaWssVsHz500469NAjux3WeMuW7ebxx8f3nr79q9CynYN/xD7LeHz34yU0pp7GuX+zP5sld+Z8c1gWVNqw02K9ly92u4fQ5dgtwMmZ+cBCx61adXzOzt5RSvvqZmbmRqanT6u6GUNj/yr03pXwjPlWvejPzHEzTN85XVKD6mes+
3c55L1LC4SR/BMnIp67Z2mCiDilOO9g/8dKkkpVSskoIjYApwErI2IWuITWCpZk5sdpreT4byJiF/AbHE4nDc8Bju/Q0pQSCJn55h7719NaekDSMCzbDo8XC5ruXDFwyUiTqaZ3xSQtyunrYNmOqluhhvM5r9I4OHU9HPAA3HCpJSMtmYEgjYNHjoQTN7Re0hJZMpLGwR1nzv/ATalPBoI0Do6/bv7H6Uh9MhCkcXDQ1qpboDFgIEjjYOeKqlugMWAgSJIAA0EaDw41VQkMBGkcWDJSCQwESRJgIEjjwZKRSmAgSOPAkpFKYCBIkgADQRoPloxUAgNBGgePjOezxzVaBoI0DlzcTiUwEKRx4OJ2KoGBII0DF7dTCQwEqYmWbZ/73mGnKoGBIDWRz1DWEBgIUhOduh7OuhAO2gLsdtipSuEzlaUm2rnCZyirdF4hSJIAA0FqJktEGgIDQWoiZyZrCAwEqYmcmawhMBCkJnJmsobAQJCayJnJGgIDQWoiZyZrCAwESRJgIEjN5LBTDYGBIDWRJSMNgYEgSQIMBKmZLBlpCAwEqYksGWkIDARJEmAgSM1kyUhDYCBITeTidhoCA0FqIhe30xAYCFITubidhsBAkJrIxe00BAaC1BTLtj+57bBTDYGBIDXF6etg2Y6qW6Extl/VDZDUp1PXwwEPwA2XOuxUQ2EgSE2xcwWcuKH1koaglJJRRHwyIu6PiFvn2R8R8dGI2BwRP4qIl5RxXklSecq6h/Ap4IwF9r8GOLZ4rQH+R0nnlSaHZSINWSmBkJnfBBb6v3U18Ols+TZwcEQcVsa5pYnh7GQNWWSWM90xIo4GvpiZL+yy74vAZZl5U/H+BuA/ZObNXY5dQ+sqgpUrDz1p7dqrS2lf3UxNbWd29sCqmzE09m8IDtoKz/jFSE41tXyK2UdnR3KuKoxz/6YvnibvzSVNW6zdTeXMvAK4AmDVquNzevq0ahs0JDMzNzKufQP7NxTvORoOvnskp5o5bobpO6dHcq4qjHv/lmpU8xC2Aava3k8VX5PUL2cna8hGFQgbgbcVo41eCjySmfeN6NzSeHB2soaslJJRRGwATgNWRsQscAmwDCAzPw5cB5wJbAZ2An9SxnklSeUpJRAy88099ifw78o4lzSxHHaqIXMtI6kpLBlpyAwESRJgIEjNYclIQ2YgSE3hTGUNmYEgNYXPUdaQGQhSU/gcZQ2ZgSA1hTOVNWQGgtQUDjvVkBkIkiTAQJCaw2GnGjIDQaqzZduf3LZkpCEzEKQ6O30dLNtRdSs0IWr3gBxJbU5dDwc8ADdcaslIQ2cgSHW2cwWcuKH1kobMkpEkCTAQpHqzTKQRMhCkOnNBO42QgSDVmQvaaYQMBKnOXNBOI2QgSHXmgnYaIQNBqjNnJ2uEDARJEmAgSPXmsFONkIEg1ZklI42QgSBJAgwEqd4sGWmEDASpzpyprBEyEKQ6c6ayRshAkOrMmcoaIQNBqjNnKmuEDASpTtqfoQwOO9VIGQhSnfgMZVXIQJDq5NT1cNaFcNAWYLfDTjVSPlNZqhOfoawKeYUgSQIMBKleLBGpQgaCVCeOKlKFDARJEmAgSPViyUgVMhCkOnExO1XIQJDqxMXsVCEDQaoTF7NThQwEqU5czE4VMhCkOnHYqSpkIEiSgJICISLOiIg7ImJzRKztsv/8iPhFRPyweF1QxnmlseOwU1Vo4MXtImJf4GPAvwRmge9FxMbM/HHHoVdl5kWDnk8aaztXwDMerLoVmlBlXCGcAmzOzLsy8zHgM8DqEj5XkjRCkTnYoOeIOAc4IzMvKN6/FTi1/WogIs4HPgL8ArgT+LPMvGeez1sDrAFYufLQk9auvXqg9tXV1NR2ZmcPrLoZQ2P/lujw75f/mUswtXyK2Udnq27G0Ixz/6YvnibvzSUNXh7V8xCuBTZk5qMR8U7gSuBV3Q7MzCuAKwBWrTo+p6dPG1ETR2tm5kbGtW9g/5bsPefDwXeX/7mLNHPcDNN3TlfdjKEZ9/4tVRklo23Aqrb3U8XX9srMBzPz0eLtJ4CTSjivNH6cqawKlREI3wOOjYhjImJ/4FxgY/sBEXFY29uzgdtLOK80fpyprAoNXDLKzF0RcRHwFWBf4JOZeVtE/Dlwc2ZuBN4dEWcDu4BfAucPel5pLDlTWRUq5R5CZl4HXNfxtQ+2bb8PeF8Z55LGzrLt8Hhxg9php6qQM5Wlqp2+DpbtqLoV0shGGUmaz6nr4YAH4IZLnamsShkIUtV2roATN7ReUoUsGUmSAANBqp5lItWEgSBVzecoqyYMBKlqzk5WTRgIUtWcnayaMBCkqjk7WTVhIEhV8znKqgkDQZIEGAhS9Rx2qpowEKSqWTJSTRgIkiTAQJCqZ8lINWEgSFWzZKSaMBAkSYCBIFXPkpFqwkCQqubidqoJA0GqmovbqSYMBKlqLm6nmjAQpKq5uJ1qwkCQqrBs+5PbDjtVTRgIUhVOXwfLdlTdCmmO/apugDSRTl0PBzwAN1zqsFPVhoEgVWHnCjhxQ+sl1YQlI0kSYCBI1bBMpBoyEKQqODtZNWQgSFVwdrJqyECQquDsZNWQgSBVwdnJqiEDQaqCs5NVQwaCJAkwEKRqOOxUNWQgSFWwZKQaMhAkSYCBIFXDkpFqyECQqmDJSDVkIEiSAANBqoYlI9WQgSBVwcXtVEMGglQFF7dTDRkIUhVc3E41ZCBIVXBxO9VQKYEQEWdExB0RsTki1nbZvzwirir2fyciji7jvFJjOexUNTRwIETEvsDHgNcAJwBvjogTOg57B/BQZj4f+EvgLwY9rySpXGVcIZwCbM7MuzLzMeAzwOqOY1YDVxbb1wCnR4QVVE0uh52qhsoIhCOAe9rezxZf63pMZu4CHgEOKeHcY+cAtnfdXmhfv8eV8Rmdx6lPyzr+3CwZqYYic7CxbxFxDnBGZl5QvH8rcGpmXtR2zK3FMbPF+58WxzzQ5fPWAGsAVq489KS1a68eqH11NTW1ndnZA5/y9VXcwzaOYDf7zNleaF+/x5XxGZ3HLbZ/42LR/XvWPfDrIyCLP7fn3gL77BpO40owtXyK2Udnq27G0Ixz/6YvnibvzSVVYPYr4fzbgFVt76eKr3U7ZjYi9gMOAh7s9mGZeQVwBcCqVcfn9PRpJTSxfmZmbqRb355gHzZwLu/nUu7iVXu3t3Iku9iv675+jyvjMzqPS4Ju4yfn69+4WHT/LtkHNp0LN1zampR2yasg6jsRYea4GabvnK66GUMz7v1bqjJKRt8Djo2IYyJif+BcYGPHMRuBtxfb5wBfz0EvTcbUVo7kPDawhWPYh9y7vZt957yfb3uh48r4jM7jjuLuqv/ImuGRI+HEDfBnx8CH9q11GGhyDRwIxT2Bi4CvALcDV2fmbRHx5xFxdnHY/wQOiYjNwL8HnjI0tU5exk08nR1d3/e7r9dx+7C763Ff5MxGTWD9MOs4YJF/NhPJmclqgFLmIWTmdZl5XGb+TmZ+uPjaBzNzY7H928x8Y2Y+PzNPycy7yjjvsPwdb+FvuJCj2EKwm6s5d+/7
v+O8vvb1Ou4o7u563OvY2KgJrOexgSu69BmY989mIn8zOjNZDTDwTeVhWrXq+JydvWPk532CfdhnyL+0bpyZ4bTp8a1hLtS/fXiCbPgk+SXdQ2hQmWjca+xj3b/LWfJN5cb8rRzlEMwHcUjgMB3J3GUbJmKIq8NM1QC1DoT9eWxvfbq9Vt35fr7tQfZpeJb6s5Q0XLUOhP3Ytbc+/W7W790Ods95P9/2QscttO8QnEU6TO33HRbzs2z0vQdnJqsBah0IQGVDMDVcgw5xbdzMaktGaoDaB4K0x1LKf5adpP4ZCGqM9lLTUkuIlZWdLBmpAQwENUpjZ1b7DGU1gIGgiTLamdU3wZ7PuOFMeKxjt7eqVDMGgibKfDOrF5p1vvTf3G8BLgS2wKYvwbXAw8XHPYEzl1U7Zax2KjXKeWzgPDYAcDRb927Pt681s3opv723AnfDns/fVLwALllq66Xh8QpB6qHfmdU8ZYjrAkNNH2nb7iwlSRUxEKQeug1x3bNa7dx966Df+w438GQQXI+hoFowEKQeus2s3rNabfuwVljP3nsG7IaFZrxv4sl7Ct9l7v0FqSIGgtSHziGuK/jlU4a1HkLSul9wDLAvPX+7bwL+qmP7kY5jvHLQCBkIUp20l5LAcpJGylFGUklKmYu8ZxTS6bSePP5d4Ddt7x2qqiEyEKSSrAAeLOOD2oentr9/D3Bw8bXHgP3LOJn0JEtGUlM4MklD5hWCVJKhL1/XXk7qLCWB5SQNzECQSnIkDH/pvPZyUvt2ezlJWiJLRlJJzqzy5O3lpK3MLSd1vpfmYSBIJbmuypO3T3T7HHMnun0WJ76pL5aMpJJs7X3IcM1XTmp/76J6WoBXCNIA2peza8RTk11UTwswEKQBLGI5u3pw6KoWYMlIGsB64AHgUkYw7LQMDl3VAgwEaQAraC1nt6HXgXXSz9BVbz5PJEtGklray0m/wnLSBDIQpAE0okzUr/ahqzuYO1TVK4aJYMlIGsBIZieP0p4S0gzOhJ5AXiFIA6h0dvIodT6nQWPJKwRpAJXOTh6lzuc0OBppLHmFIA2g8tnJo9T+yE+NJQNBGkAjZieXrfO5zxobBoKkxfF+wtjyHoI0gLEadtov7yeMLa8QpAFMZMkInryf0Fk+8sqh0QwESUvXWT5ywbxGs2QkDWAiS0btOstHnQvmWU5qFANBGsAK4MGqG1G1+R7G4+zmxrFkJGk4HI3UOF4hSAOY+JLRQhyN1DheIUgDOLLqBtTdfKORVEsGgrRINwG7i+2JWdxuUN3KRy6pXTsGgrRIb6G15PUW4EvVNqU52p+1kMATWEKqoYHuIUTECuAq4Ghafz/elJkPdTnuCZ6sKG7NzLMHOa9Upa207h0cU3VDmqZ9NNIlVTZE8xn0CmEtcENmHkvronDtPMf9JjNfVLwMAzXaxM5OLpP3FGpp0EBYDVxZbF8JvG7Az5M0CRySWkuRufQ7OxHxcGYeXGwH8NCe9x3H7QJ+COwCLsvMLyzwmWuANQCHrlx50tVr57voaLbtU1McODtbdTOGZpz7931gamqK2THtH4yof08HngXsO9zTdDO1fIrZR8fz5zd98TR5by7pDk3PQIiI64Hndtn1fuDK9gCIiIcy89ldPuOIzNwWEc8Dvg6cnpk/7dW441etyjvG9C/djTMznDY9XXUzhmac+7cSeN/MDNNj2j+AmVH2r3NG82PA/sM95cxxM0zfOaY/v8tZciD0vKmcma+eb19E/DwiDsvM+yLiMOD+eT5jW/HfuyLiRuDFQM9AkDQBbgDO4skQuB54NUMPBT3VoPcQNgJvL7bfDvxD5wER8eyIWF5srwReAfx4wPNKlXF2csk6h6R+t+O9RmbQpSsuA66OiHfQGpr9JoCIOBl4V2ZeALwAuDwidtMKoMsy00BQYzk7eQhcIK8WBgqEzHyQ1kolnV+/Gbig2P4W8PuDnEeqE2cnj1BnOUlD5UxlaZGuq7oBk6SznKShMhCkRdpadQMmzZ4F8jR0BoK0SM5Uroizm4fOQJDUDM5uHjofkCMtksNOK9L5wB1wxdSSGQjSIlkyqlD78NT2IakjmN08CSwZSWqm9hLS9VhOKoFXCNIiWTKqifYS0neB32A5aUAGgrRIzlSukfYS0nzlJPXNkpG0SM5UbgBHJC2JgSAtkjOVG8AZzktiIEiL5EzlhnCG86IZCFIP2zveO+y0YZzh3DcDQephHbCj6kZo6byf0DcDQephPXAhsAXYjcNOG8f7CX0zEKQeVgAbgGNoPQ/e3ykNtOd+Qmf5yCuHOQwESZOjs3zkDOc5nJgm9WCJaIzsmbj2Ap58fnP7DOcJn93sFYLUg6OKxswm4OfAf2p7362cNIEMBEkCRyNhyUjqyZLRhOh83sIElo+8QpB6cDG7CTLh5SMDQerBxewmULfy0QSMN7ZkJPXgYnYTqLN8tJvWJJQxZyBIPbiY3YRqf77CJVU2ZHQsGUk9OOxUc+4pjPFIJANBknqZkOc3WzKSenDYqSbl+c0GgtTDCuDBqhuh6k3A85stGUnSIMZohrNXCFIPloy0oDGa4ewVgtSDM5XV05jMcDYQpB6cqay+Nbx8ZMlI6sGZyupbw8tHXiFIPThTWYuyp3zUQAaC1MX2tm1nKmtJGvj8ZgNB6mIdsKPqRqjZGvj8Zu8hSF2sBx4ALsVhp1qizvsJDXh+s4EgdbEC2FC8pCVrn9Hc/r6ms5stGUnSqNV0eKpXCFIXlok0VDUdnuoVgtSFI4s0dDWc3WwgSFKValQ+smQkdWHJSCPTWT6CykpIBoLUxZHA3VU3QpOjJs9XGKhkFBFvjIjbImJ3RJy8wHFnRMQdEbE5ItYOck5pFFzQTpVpLyFt5anlpBzeqQe9h3Ar8Abgm/MdEBH7Ah8DXgOcALw5Ik4Y8LzSULmgnSqzCbgWeBj4XNt2Ak8w1HLSQCWjzLwdIGLBFp4CbM7Mu4pjPwOsBn48yLmlYXJBO1Vqvsd1XjLc00bm4NcfEXEjMJ2ZN3fZdw5wRmZeULx/K3BqZl40z2etAdYAHLpy5UlXrx3PCtP2qSkOnJ2tuhlD0/T+3QLsWmD/1NQUsw3uXy/2r6aeA+y78CHTF0+T9+bSriMyc8EXrSWZbu3yWt12zI3AyfN8/znAJ9revxVY3+u8mclxU1OZMJavb8zMVN4G+zf/6xBIFnjNzMwsuL/pL/tX09fvk6wj+dACr8PIfn6/dnv1LBll5qt7HdPDNmBV2/up4mtSbTnsVLU05BnOo5iY9j3g2Ig4JiL2B84FNo7gvNKSOVNZtTXEGc6DDjt9fUTMAi8DvhQRXym+fnhEXAeQmbuAi4CvALcDV2fmbYM1W5Im3BBmOA86yujzwOe7fP1e2oZyZ+Z1OJJPDWLJSLU3hPKRaxlJXRxZdQOkfpRcPjIQpC6cqaxGKal85FpGUhfWN9Uo7eWjARgIUhfOVFbjdD6ucwksGUmF7W3bDjvVJDIQpMI6YEfVjZAqZCBIhfXAhcAWHHaqyWQgSIUVwAbgGFoLx0iTxkCQJAEGgrSXZSJNOgN
BKjg7WZPOQJAKzk7WpDMQpIKzkzXpDASp4OxkTToDQSo4O1mTzkCQJAEGgrSXw0416QwEqWDJSJPOQJAkAQaCtJclI006A0EqWDLSpDMQJEmAgSDtZclIk85AkAoubqdJZyBIBRe306QzEKSCi9tp0hkIUsHF7TTpDARNrO0d7x12qklnIGhirQN2VN0IqUYMBE2s9cCFwBZgNw47lQwETawVwAbgGGBfIKttjlQ5A0GSBBgImmCWiKS5DARNLGcmS3MZCJpYzkyW5jIQNLGcmSzNZSBoYjkzWZrLQNBEaZ+d7MxkaS4DQRPF2cnS/AwETZT22ckOO5XmMhA0UdpnJzszWZrLQJAkAQaCJoxlIml+BoImirOTpfkZCBprNzF3VJGzk6X5DRQIEfHGiLgtInZHxMkLHLclIjZFxA8j4uZBziktxluY+8yDjZW2Rqq3/Qb8/luBNwCX93Hsv8jMBwY8n7QoW4G7aY0skrSwgQIhM28HiIhyWiOVbAXwYNWNkBpiVPcQEvhqRHw/ItaM6JySpEWIzIWn50TE9cBzu+x6f2b+Q3HMjcB0Zna9PxARR2Tmtoj4J8DXgD/NzG/Oc+waYA3Acjjphf32pGF+ARxadSOGqC79+37VDZAqkJlLKtv0DIS+PqRHIHQc+yFge2bO9HHszZk5783qJhvnvoH9azr711yD9G3oJaOIeEZEPHPPNvCHtG5GS5JqZNBhp6+PiFngZcCXIuIrxdcPj4g9zx95DnBTRNwCfBf4Umb+n0HOK0kq36CjjD4PfL7L1++lmAOUmXcB/3SJp7hi6a2rvXHuG9i/prN/zbXkvpVyD0GS1HwuXSFJAmoUCOO+DMYi+ndGRNwREZsjYu0o2ziIiFgREV+LiJ8U/332PMc9UfzsfhgRtV9JotfPIyKWR8RVxf7vRMTRo2/l0vTRt/Mj4hdtP68LqmjnUkXEJyPi/ojoOoglWj5a9P9HEfGSUbdxqfro22kR8Ujbz+6DfX1wZtbiBbwAOB64ETh5geO2ACurbu8w+gfsC/wUeB6wP3ALcELVbe+zf/8FWFtsrwX+Yp7jtlfd1kX0qefPA/i3wMeL7XOBq6pud4l9Ox9YX3VbB+jjK4GXALfOs/9M4MtAAC8FvlN1m0vs22nAFxf7ubW5QsjM2zPzjqrbMSx99u8UYHNm3pWZjwGfAVYPv3WlWA1cWWxfCbyuwraUpZ+fR3u/rwFOj2as5dLk/9f6kq3Jrws9AmM18Ols+TZwcEQcNprWDaaPvi1JbQJhEcZ5GYwjgHva3s8WX2uC52TmfcX2z2gNN+7maRFxc0R8OyLqHhr9/Dz2HpOZu4BHgENG0rrB9Pv/2h8X5ZRrImLVaJo2Mk3++9aPl0XELRHx5Yj4vX6+YdDVTheln2Uw+vDPsm0ZjIj4vznPMhijVlL/amuh/rW/ycyMiPmGrx1V/PyeB3w9IjZl5k/LbqtKcS2wITMfjYh30roSelXFbVJ/fkDr79r2iDgT+AJwbK9vGmkgZOarS/iMbcV/74+Iz9O69K1FIJTQv21A+7/Cpoqv1cJC/YuIn0fEYZl5X3HZff88n7Hn53dXseTJi2nVsuuon5/HnmNmI2I/4CCascBqz75lZns/PkHrPtE4qfXft0Fk5q/atq+LiL+OiJXZ4xEEjSoZTcAyGN8Djo2IYyJif1o3KWs/EqewEXh7sf124ClXRBHx7IhYXmyvBF4B/HhkLVy8fn4e7f0+B/h6Fnf1aq5n3zrq6WcDt4+wfaOwEXhbMdropcAjbWXPRouI5+65lxURp9D6Xd/7HypV3y1vuyv+elo1vEeBnwNfKb5+ONjGeqoAAAC/SURBVHBdsf08WqMhbgFuo1WKqbztZfWveH8mcCetfzU3qX+HADcAPwGuB1YUXz8Z+ESx/XJgU/Hz2wS8o+p299Gvp/w8gD8Hzi62nwZ8FthMa2mW51Xd5hL79pHi79ktwDeA3626zYvs3wbgPuDx4u/eO4B3Ae8q9gfwsaL/m1hgdGPdXn307aK2n923gZf387nOVJYkAQ0rGUmShsdAkCQBBoIkqWAgSJIAA0GSVDAQJEmAgSBJKhgIkiQA/j9gWK2VdEHe5wAAAABJRU5ErkJggg==\n", 543 | "text/plain": [ 544 | "
" 545 | ] 546 | }, 547 | "metadata": { 548 | "needs_background": "light" 549 | }, 550 | "output_type": "display_data" 551 | } 552 | ], 553 | "source": [ 554 | "fig, ax = plt.subplots(figsize=(6,6))\n", 555 | "xmin = -1.5\n", 556 | "xmax = 1.5\n", 557 | "ymin = -1.5\n", 558 | "ymax = 1.5\n", 559 | "ax.set_xlim([xmin,xmax])\n", 560 | "ax.set_ylim([ymin,ymax])\n", 561 | "\n", 562 | "artificial_input = artificial_input.type(torch.FloatTensor).to(args.device)\n", 563 | "\n", 564 | "m_test, label_test = message_gen_tmp(args.bit_num, num_data_points*num_data_points)\n", 565 | "m_test = m_test.type(torch.FloatTensor).to(args.device)\n", 566 | "label_test = label_test.type(torch.LongTensor).to(args.device)\n", 567 | "out_test, modulated_symbol, received_signal = net_for_vis(m_test, h, Noise, args.device, if_RTN, artificial_input)\n", 568 | "artificial_input = artificial_input.data.cpu().numpy()\n", 569 | "color_list = ['r', 'g', 'b', 'k']\n", 570 | "for ind_datapoints in range(num_data_points*num_data_points):\n", 571 | " ind_max = torch.argmax(out_test[ind_datapoints]).data.cpu().numpy()\n", 572 | " ax.scatter(artificial_input[ind_datapoints, 0], artificial_input[ind_datapoints, 1], c=color_list[ind_max], label = m_ref[ind_max])\n", 573 | "plt.title('decision region')\n", 574 | "ax.grid(True)\n", 575 | "plt.show()" 576 | ] 577 | }, 578 | { 579 | "cell_type": "code", 580 | "execution_count": null, 581 | "metadata": {}, 582 | "outputs": [], 583 | "source": [] 584 | }, 585 | { 586 | "cell_type": "code", 587 | "execution_count": null, 588 | "metadata": {}, 589 | "outputs": [], 590 | "source": [] 591 | }, 592 | { 593 | "cell_type": "code", 594 | "execution_count": null, 595 | "metadata": {}, 596 | "outputs": [], 597 | "source": [] 598 | }, 599 | { 600 | "cell_type": "code", 601 | "execution_count": null, 602 | "metadata": {}, 603 | "outputs": [], 604 | "source": [] 605 | }, 606 | { 607 | "cell_type": "code", 608 | "execution_count": null, 609 | "metadata": {}, 610 | "outputs": [], 611 | "source": [] 612 | } 613 | ], 614 | "metadata": { 615 | "kernelspec": { 616 | "display_name": "Python 3", 617 | "language": "python", 618 | "name": "python3" 619 | }, 620 | "language_info": { 621 | "codemirror_mode": { 622 | "name": "ipython", 623 | "version": 3 624 | }, 625 | "file_extension": ".py", 626 | "mimetype": "text/x-python", 627 | "name": "python", 628 | "nbconvert_exporter": "python", 629 | "pygments_lexer": "ipython3", 630 | "version": "3.7.4" 631 | } 632 | }, 633 | "nbformat": 4, 634 | "nbformat_minor": 2 635 | } 636 | -------------------------------------------------------------------------------- /training/meta_train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from data_gen.data_set import message_gen 3 | from nets.meta_net import meta_dnn 4 | 5 | def multi_task_learning(args, net, h_list_meta, writer_meta_training, Noise): 6 | meta_optimiser = torch.optim.Adam(net.parameters(), args.lr_meta_update) 7 | h_list_train = h_list_meta[:args.num_channels_meta] 8 | 9 | for epochs in range(args.num_epochs_meta_train): 10 | first_loss = 0 11 | second_loss = 0 12 | iter_in_sampled_device = 0 # for averaging meta-devices 13 | for ind_meta_dev in range(args.tasks_per_metaupdate): 14 | # during this, meta-gradients are accumulated 15 | channel_list_total = torch.randperm(len(h_list_train)) # sampling with replacement 16 | current_channel_ind = channel_list_total[ind_meta_dev] 17 | current_channel = h_list_train[current_channel_ind] 18 | if 
args.if_joint_training: 19 | iter_in_sampled_device, first_loss_curr, second_loss_curr = joint_training(args, iter_in_sampled_device, 20 | net, current_channel, Noise) 21 | else: # maml 22 | iter_in_sampled_device, first_loss_curr, second_loss_curr = maml(args, iter_in_sampled_device, 23 | net, current_channel, Noise) 24 | first_loss = first_loss + first_loss_curr 25 | second_loss = second_loss + second_loss_curr 26 | first_loss = first_loss / args.tasks_per_metaupdate 27 | second_loss = second_loss / args.tasks_per_metaupdate 28 | writer_meta_training.add_scalar('first loss', first_loss, epochs) 29 | writer_meta_training.add_scalar('second loss', second_loss, epochs) 30 | # meta-update 31 | meta_optimiser.zero_grad() 32 | for f in net.parameters(): 33 | f.grad = f.total_grad.clone() / args.tasks_per_metaupdate 34 | meta_optimiser.step() # Adam 35 | 36 | def maml(args, iter_in_sampled_device, net, current_channel, Noise): 37 | net.zero_grad() 38 | para_list_from_net = list(map(lambda p: p[0], zip(net.parameters()))) 39 | net_meta_intermediate = meta_dnn(if_relu = args.if_relu) 40 | 41 | for inner_loop in range(args.num_meta_local_updates): 42 | if inner_loop == 0: 43 | m, label = message_gen(args.bit_num, args.mb_size_meta_train) 44 | m = m.type(torch.FloatTensor).to(args.device) 45 | label = label.type(torch.LongTensor).to(args.device) 46 | out = net_meta_intermediate(m, para_list_from_net, args.if_bias, current_channel, args.device, Noise, args.if_RTN) 47 | loss = torch.nn.functional.cross_entropy(out, label) 48 | first_loss_curr = float(loss) 49 | grad = torch.autograd.grad(loss, para_list_from_net, create_graph=True) 50 | intermediate_updated_para_list = list(map(lambda p: p[1] - args.lr_meta_inner * p[0], zip(grad, para_list_from_net))) 51 | else: 52 | m, label = message_gen(args.bit_num, args.mb_size_meta_train) 53 | m = m.type(torch.FloatTensor).to(args.device) 54 | label = label.type(torch.LongTensor).to(args.device) 55 | out = net_meta_intermediate(m, intermediate_updated_para_list, args.if_bias, current_channel, 56 | args.device, Noise, args.if_RTN) 57 | loss = torch.nn.functional.cross_entropy(out, label) 58 | grad = torch.autograd.grad(loss, intermediate_updated_para_list, create_graph=True) 59 | intermediate_updated_para_list = list(map(lambda p: p[1] - args.lr_meta_inner * p[0], zip(grad, intermediate_updated_para_list))) 60 | ########### 61 | #### meta-update 62 | m, label = message_gen(args.bit_num, args.mb_size_meta_test) 63 | m = m.type(torch.FloatTensor).to(args.device) 64 | label = label.type(torch.LongTensor).to(args.device) 65 | out = net_meta_intermediate(m, intermediate_updated_para_list, args.if_bias, current_channel, 66 | args.device, Noise, args.if_RTN) 67 | loss = torch.nn.functional.cross_entropy(out, label) 68 | second_loss_curr = float(loss) 69 | para_list_grad = torch.autograd.grad(loss, para_list_from_net, create_graph=False) 70 | ind_f_para_list = 0 71 | for f in net.parameters(): 72 | if iter_in_sampled_device == 0: 73 | f.total_grad = para_list_grad[ind_f_para_list].data.clone() 74 | else: 75 | f.total_grad = f.total_grad + para_list_grad[ind_f_para_list].data.clone() 76 | ind_f_para_list += 1 77 | iter_in_sampled_device = iter_in_sampled_device + 1 78 | return iter_in_sampled_device, first_loss_curr, second_loss_curr 79 | 80 | 81 | def joint_training(args, iter_in_sampled_device, net, current_channel, Noise): 82 | net.zero_grad() 83 | para_list_from_net = list(map(lambda p: p[0], zip(net.parameters()))) 84 | net_meta_intermediate = meta_dnn(if_relu = 
args.if_relu) 85 | 86 | m, label = message_gen(args.bit_num, args.mb_size_meta_test) 87 | m = m.type(torch.FloatTensor).to(args.device) 88 | label = label.type(torch.LongTensor).to(args.device) 89 | out = net_meta_intermediate(m, para_list_from_net, args.if_bias, current_channel, args.device, Noise, args.if_RTN) 90 | loss = torch.nn.functional.cross_entropy(out, label) 91 | first_loss_curr = float(loss) 92 | grad = torch.autograd.grad(loss, para_list_from_net, create_graph=False) 93 | second_loss_curr = first_loss_curr 94 | 95 | ind_f_para_list = 0 96 | for f in net.parameters(): 97 | if iter_in_sampled_device == 0: 98 | f.total_grad = grad[ind_f_para_list].data.clone() 99 | else: 100 | f.total_grad = f.total_grad + grad[ind_f_para_list].data.clone() 101 | ind_f_para_list += 1 102 | iter_in_sampled_device = iter_in_sampled_device + 1 103 | return iter_in_sampled_device, first_loss_curr, second_loss_curr -------------------------------------------------------------------------------- /training/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from data_gen.data_set import message_gen 3 | 4 | def test_per_channel_per_snr(args, h, net_for_testtraining, test_snr, actual_channel_num, PATH_after_adapt, if_val): 5 | if torch.cuda.is_available(): 6 | net_for_testtraining.load_state_dict(torch.load(PATH_after_adapt)) 7 | else: 8 | net_for_testtraining.load_state_dict(torch.load(PATH_after_adapt, map_location = torch.device('cpu'))) 9 | 10 | 11 | batch_size = args.test_size 12 | success_test = 0 13 | Eb_over_N_test = pow(10, (test_snr / 10)) 14 | R = args.bit_num / args.channel_num 15 | noise_var_test = 1 / (2 * R * Eb_over_N_test) 16 | 17 | Noise_test = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(actual_channel_num), 18 | noise_var_test * torch.eye( 19 | actual_channel_num)) 20 | m_test, label_test = message_gen(args.bit_num, batch_size) 21 | m_test = m_test.type(torch.FloatTensor).to(args.device) 22 | label_test = label_test.type(torch.LongTensor).to(args.device) 23 | 24 | out_test = net_for_testtraining(m_test, h, Noise_test, args.device, args.if_RTN) 25 | for ind_mb in range(label_test.shape[0]): 26 | assert label_test.shape[0] == batch_size 27 | if torch.argmax(out_test[ind_mb]) == label_test[ind_mb]: # means correct classification 28 | success_test += 1 29 | else: 30 | pass 31 | accuracy = success_test / label_test.shape[0] 32 | if not if_val: 33 | print('for snr: ', test_snr, 'bler: ', 1 - accuracy) 34 | 35 | return 1 - accuracy 36 | 37 | 38 | -------------------------------------------------------------------------------- /training/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from data_gen.data_set import message_gen 3 | 4 | def test_training(args, h, net_for_testtraining, Noise, PATH_before_adapt, PATH_after_adapt, adapt_steps): #PATH_before_adapt can be meta-learneds 5 | # initialize network (net_for_testtraining) (net is for meta-training) 6 | if torch.cuda.is_available(): 7 | net_for_testtraining.load_state_dict(torch.load(PATH_before_adapt)) 8 | else: 9 | net_for_testtraining.load_state_dict(torch.load(PATH_before_adapt, map_location = torch.device('cpu'))) 10 | if args.if_test_training_adam and not args.if_adam_after_sgd: 11 | testtraining_optimiser = torch.optim.Adam(net_for_testtraining.parameters(), args.lr_testtraining) 12 | else: 13 | pass 14 | 15 | num_adapt = adapt_steps 16 | for epochs in range(num_adapt): 17 | m, label = 
--------------------------------------------------------------------------------
/training/train.py:
--------------------------------------------------------------------------------
import torch
from data_gen.data_set import message_gen

def test_training(args, h, net_for_testtraining, Noise, PATH_before_adapt, PATH_after_adapt, adapt_steps):
    # initialize the network from PATH_before_adapt, which may hold a
    # meta-learned, jointly trained, or fixed initialization
    # (net is used for meta-training; net_for_testtraining is adapted here)
    if torch.cuda.is_available():
        net_for_testtraining.load_state_dict(torch.load(PATH_before_adapt))
    else:
        net_for_testtraining.load_state_dict(torch.load(PATH_before_adapt, map_location=torch.device('cpu')))
    if args.if_test_training_adam and not args.if_adam_after_sgd:
        testtraining_optimiser = torch.optim.Adam(net_for_testtraining.parameters(), args.lr_testtraining)

    num_adapt = adapt_steps
    for epochs in range(num_adapt):
        m, label = message_gen(args.bit_num, args.mb_size)
        m = m.type(torch.FloatTensor).to(args.device)
        label = label.type(torch.LongTensor).to(args.device)
        for f in net_for_testtraining.parameters():
            if f.grad is not None:
                f.grad.detach()
                f.grad.zero_()

        out = net_for_testtraining(m, h, Noise, args.device, args.if_RTN)
        loss = torch.nn.functional.cross_entropy(out, label)
        # gradient calculation
        loss.backward()
        # adapt (update) the parameters
        if args.if_test_training_adam:
            if args.if_adam_after_sgd:
                # the first args.num_meta_local_updates steps use plain SGD
                # (matching the meta-training inner loop); Adam takes over after
                if epochs < args.num_meta_local_updates:
                    for f in net_for_testtraining.parameters():
                        if f.grad is not None:
                            f.data.sub_(f.grad.data * args.lr_meta_inner)
                elif epochs == args.num_meta_local_updates:
                    testtraining_optimiser = torch.optim.Adam(net_for_testtraining.parameters(),
                                                              args.lr_testtraining)
                    testtraining_optimiser.step()
                else:
                    testtraining_optimiser.step()
            else:
                testtraining_optimiser.step()
        else:
            for f in net_for_testtraining.parameters():
                if f.grad is not None:
                    f.data.sub_(f.grad.data * args.lr_testtraining)
    # save the adapted network for BLER computation
    torch.save(net_for_testtraining.state_dict(), PATH_after_adapt)
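# ---------------------------------------------------------------------------
# Illustrative aside (added for exposition; not part of the original file):
# the manual update used above, f.data.sub_(f.grad.data * lr), performs the
# same step as torch.optim.SGD. A minimal check on two identical parameters:
def _manual_sgd_equals_optim_sgd(lr=0.01):
    import torch
    w1 = torch.nn.Parameter(torch.tensor([1.0, 2.0]))
    w2 = torch.nn.Parameter(torch.tensor([1.0, 2.0]))
    for w in (w1, w2):
        loss = (w ** 2).sum()
        loss.backward()
    w1.data.sub_(w1.grad.data * lr)        # manual update, as in test_training
    torch.optim.SGD([w2], lr=lr).step()    # library update
    return torch.allclose(w1, w2)          # True
# ---------------------------------------------------------------------------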
--------------------------------------------------------------------------------
/utils/funcs.py:
--------------------------------------------------------------------------------
import torch

def complex_mul(h, x):  # h is fixed over the batch; x has multiple batch entries
    if len(h.shape) == 1:
        # h is shared across all messages (an estimated h has been averaged)
        y = torch.zeros(x.shape[0], 2, dtype=torch.float)
        y[:, 0] = x[:, 0] * h[0] - x[:, 1] * h[1]
        y[:, 1] = x[:, 0] * h[1] + x[:, 1] * h[0]
    elif len(h.shape) == 2:
        # per-example h (estimated h is not averaged)
        assert x.shape[0] == h.shape[0]
        y = torch.zeros(x.shape[0], 2, dtype=torch.float)
        y[:, 0] = x[:, 0] * h[:, 0] - x[:, 1] * h[:, 1]
        y[:, 1] = x[:, 0] * h[:, 1] + x[:, 1] * h[:, 0]
    else:
        print('length of h shape needs to be either 1 or 2')
        raise NotImplementedError
    return y


def complex_mul_taps(h, x_tensor):
    if len(h.shape) == 1:
        L = h.shape[0] // 2  # half the channel-vector length is the number of taps
    elif len(h.shape) == 2:
        L = h.shape[1] // 2
    else:
        print('length of h shape needs to be either 1 or 2')
        raise NotImplementedError
    y = torch.zeros(x_tensor.shape[0], x_tensor.shape[1], dtype=torch.float)
    assert x_tensor.shape[1] % 2 == 0
    for ind_channel_use in range(x_tensor.shape[1] // 2):
        for ind_conv in range(min(L, ind_channel_use + 1)):
            if len(h.shape) == 1:
                y[:, ind_channel_use * 2:(ind_channel_use + 1) * 2] += complex_mul(
                    h[2 * ind_conv:2 * (ind_conv + 1)],
                    x_tensor[:, (ind_channel_use - ind_conv) * 2:(ind_channel_use - ind_conv + 1) * 2])
            else:
                y[:, ind_channel_use * 2:(ind_channel_use + 1) * 2] += complex_mul(
                    h[:, 2 * ind_conv:2 * (ind_conv + 1)],
                    x_tensor[:, (ind_channel_use - ind_conv) * 2:(ind_channel_use - ind_conv + 1) * 2])
    return y


def complex_conv_transpose(h_trans, y_tensor):  # plays the role of inverse filtering
    assert len(y_tensor.shape) == 2  # batched
    assert y_tensor.shape[1] % 2 == 0
    assert h_trans.shape[0] % 2 == 0
    if len(h_trans.shape) == 1:
        L = h_trans.shape[0] // 2
    elif len(h_trans.shape) == 2:
        L = h_trans.shape[1] // 2
    else:
        print('length of h shape needs to be either 1 or 2')
        raise NotImplementedError  # without raising here, L would be undefined below

    deconv_y = torch.zeros(y_tensor.shape[0], y_tensor.shape[1] + 2 * (L - 1), dtype=torch.float)
    for ind_y in range(y_tensor.shape[1] // 2):
        ind_y_deconv = ind_y + (L - 1)
        for ind_conv in range(L):
            if len(h_trans.shape) == 1:
                deconv_y[:, 2 * (ind_y_deconv - ind_conv):2 * (ind_y_deconv - ind_conv + 1)] += complex_mul(
                    h_trans[2 * ind_conv:2 * (ind_conv + 1)], y_tensor[:, 2 * ind_y:2 * (ind_y + 1)])
            else:
                deconv_y[:, 2 * (ind_y_deconv - ind_conv):2 * (ind_y_deconv - ind_conv + 1)] += complex_mul(
                    h_trans[:, 2 * ind_conv:2 * (ind_conv + 1)], y_tensor[:, 2 * ind_y:2 * (ind_y + 1)])
    return deconv_y[:, 2 * (L - 1):]
--------------------------------------------------------------------------------
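As a quick sanity check on the helpers in `utils/funcs.py` (hypothetical; not part of the repository): for a single-tap channel, filtering with `complex_mul_taps` and then applying `complex_conv_transpose` with the complex inverse of the tap, conj(h)/|h|^2, recovers the transmitted symbols exactly.

```python
import torch
from utils.funcs import complex_mul_taps, complex_conv_transpose

x = torch.randn(4, 6)               # 4 messages, 3 complex channel uses (interleaved re/im)
h = torch.tensor([0.6, 0.8])        # single tap 0.6 + 0.8j, so |h|^2 = 1
y = complex_mul_taps(h, x)          # channel filtering
h_inv = torch.tensor([0.6, -0.8])   # conj(h) / |h|^2
x_hat = complex_conv_transpose(h_inv, y)
print(torch.allclose(x, x_hat, atol=1e-6))  # expected: True
```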