├── LICENSE
├── README.md
├── Results
│   ├── Polar(64,7)
│   │   └── NN_EncFull_Skip+Dec_SC
│   │       └── Enc_snr_-0.5_Dec_snr-2.5
│   │           └── Batch_20000
│   │               └── Models
│   │                   ├── Decoder_NN_4300.pt
│   │                   └── Encoder_NN_4300.pt
│   ├── RM(6,1)
│   │   └── NN_EncFull_Skip+Dec_Dumer
│   │       └── Enc_snr_-2.0_Dec_snr-4.0
│   │           └── Batch_32768
│   │               └── Models
│   │                   ├── Decoder_NN_7400.pt
│   │                   └── Encoder_NN_7400.pt
│   ├── RM(6,1)_softmap
│   │   └── NN_EncFull_Skip+Dec_Dumer
│   │       └── Enc_snr_-2.0
│   │           └── Batch_32768
│   │               └── Models
│   │                   └── Encoder_NN_11100.pt
│   ├── RM(8,2)
│   │   └── fullNN_Enc+fullNN_Dec
│   │       └── Enc_snr_0.0_Dec_snr-2.0
│   │           └── Batch_100000
│   │               └── Models
│   │                   ├── Decoder_NN.pt
│   │                   └── Encoder_NN.pt
│   └── RM(9,2)
│       └── fullNN_Enc+fullNN_Dec
│           └── Enc_snr_-2.0_Dec_snr-4.0
│               └── Batch_40000
│                   └── Models
│                       ├── Decoder_NN.pt
│                       └── Encoder_NN.pt
├── data
│   ├── 1
│   │   ├── Mul_this_matrix_Ind_One.pt
│   │   └── Mul_this_matrix_Ind_Zero.pt
│   ├── 2
│   │   ├── CodebookIndex_this_matrix_One_MinusOne.pt
│   │   ├── CodebookIndex_this_matrix_Zero_PlusOne.pt
│   │   ├── Mul_this_matrix_Ind_One.pt
│   │   └── Mul_this_matrix_Ind_Zero.pt
│   ├── 3
│   │   ├── CodebookIndex_this_matrix_One_MinusOne.pt
│   │   ├── CodebookIndex_this_matrix_Zero_PlusOne.pt
│   │   ├── Mul_this_matrix_Ind_One.pt
│   │   └── Mul_this_matrix_Ind_Zero.pt
│   ├── 4
│   │   ├── CodebookIndex_this_matrix_One_MinusOne.pt
│   │   ├── CodebookIndex_this_matrix_Zero_PlusOne.pt
│   │   ├── Mul_this_matrix_Ind_One.pt
│   │   └── Mul_this_matrix_Ind_Zero.pt
│   ├── 5
│   │   ├── CodebookIndex_this_matrix_One_MinusOne.pt
│   │   ├── CodebookIndex_this_matrix_Zero_PlusOne.pt
│   │   ├── Mul_this_matrix_Ind_One.pt
│   │   └── Mul_this_matrix_Ind_Zero.pt
│   ├── 6
│   │   ├── CodebookIndex_this_matrix_One_MinusOne.pt
│   │   ├── CodebookIndex_this_matrix_Zero_PlusOne.pt
│   │   ├── Mul_this_matrix_Ind_One.pt
│   │   └── Mul_this_matrix_Ind_Zero.pt
│   ├── 7
│   │   ├── CodebookIndex_this_matrix_One_MinusOne.pt
│   │   ├── CodebookIndex_this_matrix_Zero_PlusOne.pt
│   │   ├── Mul_this_matrix_Ind_One.pt
│   │   └── Mul_this_matrix_Ind_Zero.pt
│   └── 8
│       ├── CodebookIndex_this_matrix_One_MinusOne.pt
│       ├── CodebookIndex_this_matrix_Zero_PlusOne.pt
│       ├── Mul_this_matrix_Ind_One.pt
│       └── Mul_this_matrix_Ind_Zero.pt
├── data_loader.py
├── reed_muller_modules
│   ├── all_functions.py
│   ├── comm_utils.py
│   ├── hadamard.py
│   ├── logging_utils.py
│   └── reedmuller_codebook.py
├── test_KO_m1_dumer.py
├── test_KO_m1_map.py
├── test_KO_m2.py
├── test_Polar_m6k7.py
├── train_KO_m1_dumer.py
├── train_KO_m1_map.py
├── train_KO_m2.py
└── train_Polar_m6k7.py

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Academic License Agreement

The (KOcode) software ("Software") has been developed by researchers at the University of Washington, the University of Michigan, and the University of Illinois ("Developers") and made available through the University of Washington ("UW") for your internal, non-profit research use.

UW and the Developers allow researchers at your institution, non-exclusively and at no cost, to run, display internally, copy and modify the Software on the following conditions:

1. You are a faculty member of an institution of higher education or a non-profit research institute.
2. The Software remains at your institution and is not published, distributed, or otherwise transferred or made available to anyone other than institution employees and students involved in research under your supervision.
3. You agree to make results generated using the Software available to other academic researchers for non-profit research purposes. If you wish to obtain the Software for any commercial purposes, including fee-based service projects, you will need to execute a separate licensing agreement with the University of Washington and pay a fee. In that case please contact: license@uw.edu.
4. You retain in the Software, and in any modifications to the Software, the copyright, trademark, or other notices pertaining to the Software as provided by UW and the Developers.
5. You provide the Developers with feedback on the use of the Software in your research, and the Developers and UW are permitted to use any information you provide in making changes to the Software. All bug reports and technical questions shall be sent to the email address: (xiyangl@cs.washington.edu).
6. You acknowledge that the Developers, UW and its licensees may develop modifications to the Software that may be substantially similar to your modifications of the Software, and that the Developers, UW and its licensees shall not be constrained in any way by you in the Developers', UW's or its licensees' use or management of such modifications. You acknowledge the right of the Developers and UW to prepare and publish modifications to the Software that may be substantially similar or functionally equivalent to your modifications and improvements, and if you obtain patent protection for any modification or improvement to the Software you agree not to allege or enjoin infringement of your patent by the Developers, UW or by any of UW's licensees obtaining modifications or improvements to the Software from UW or the Developers.
7. You agree to acknowledge the contribution the Developers and the Software make to your research, and to cite appropriate references about the Software in your publications.
8. Any risk associated with using the Software at your institution is with you and your institution. The Software is experimental in nature and is made available as a research courtesy "AS IS," without obligation by UW or the University of Illinois to provide accompanying services or support.
9. UW AND THE UNIVERSITY OF ILLINOIS, AND THE DEVELOPERS, EXPRESSLY DISCLAIM ANY AND ALL WARRANTIES REGARDING THE SOFTWARE, WHETHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO WARRANTIES PERTAINING TO NON-INFRINGEMENT, MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

## Dependencies
- numpy (1.14.1)
- pytorch (1.4)


## KO(8,2)

- To train KO(8,2) with the default settings, run:

  $ python train_KO_m2.py --gpu 0 --m 8 --enc_train_snr 0 --dec_train_snr -2 --batch_size 100000

- To test KO(8,2):

  $ python test_KO_m2.py --gpu 0 --m 8 --enc_train_snr 0 --dec_train_snr -2 --batch_size 100000

## KO(9,2)

- To train KO(9,2) with the default settings, run:

  $ python train_KO_m2.py --gpu 0 --m 9 --enc_train_snr -2 --dec_train_snr -4 --batch_size 50000

- To test KO(9,2):

  $ python test_KO_m2.py --gpu 0 --m 9 --enc_train_snr -2 --dec_train_snr -4 --batch_size 50000

## KO(6,1)

- To train KO(6,1) with the default settings, run:

  $ python train_KO_m1_dumer.py --gpu 0 --m 6

- To test KO(6,1):

  $ python test_KO_m1_dumer.py --gpu 0 --m 6

## KO(6,1) with MAP decoding

- To train KO(6,1) with MAP decoding, using the default settings, run:

  $ python train_KO_m1_map.py --gpu 0 --m 6

- To test KO(6,1) with MAP decoding:

  $ python test_KO_m1_map.py --gpu 0 --m 6

## Polar(64,7)

- To train Polar(64,7) with the default settings, run:

  $ python train_Polar_m6k7.py --gpu 0

- To test Polar(64,7):

  $ python test_Polar_m6k7.py --gpu 0

--------------------------------------------------------------------------------
/Results/Polar(64,7)/NN_EncFull_Skip+Dec_SC/Enc_snr_-0.5_Dec_snr-2.5/Batch_20000/Models/Decoder_NN_4300.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/Results/Polar(64,7)/NN_EncFull_Skip+Dec_SC/Enc_snr_-0.5_Dec_snr-2.5/Batch_20000/Models/Decoder_NN_4300.pt
--------------------------------------------------------------------------------
/Results/Polar(64,7)/NN_EncFull_Skip+Dec_SC/Enc_snr_-0.5_Dec_snr-2.5/Batch_20000/Models/Encoder_NN_4300.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/Results/Polar(64,7)/NN_EncFull_Skip+Dec_SC/Enc_snr_-0.5_Dec_snr-2.5/Batch_20000/Models/Encoder_NN_4300.pt
--------------------------------------------------------------------------------
/Results/RM(6,1)/NN_EncFull_Skip+Dec_Dumer/Enc_snr_-2.0_Dec_snr-4.0/Batch_32768/Models/Decoder_NN_7400.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/Results/RM(6,1)/NN_EncFull_Skip+Dec_Dumer/Enc_snr_-2.0_Dec_snr-4.0/Batch_32768/Models/Decoder_NN_7400.pt
--------------------------------------------------------------------------------
/Results/RM(6,1)/NN_EncFull_Skip+Dec_Dumer/Enc_snr_-2.0_Dec_snr-4.0/Batch_32768/Models/Encoder_NN_7400.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/Results/RM(6,1)/NN_EncFull_Skip+Dec_Dumer/Enc_snr_-2.0_Dec_snr-4.0/Batch_32768/Models/Encoder_NN_7400.pt
--------------------------------------------------------------------------------
/Results/RM(6,1)_softmap/NN_EncFull_Skip+Dec_Dumer/Enc_snr_-2.0/Batch_32768/Models/Encoder_NN_11100.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/Results/RM(6,1)_softmap/NN_EncFull_Skip+Dec_Dumer/Enc_snr_-2.0/Batch_32768/Models/Encoder_NN_11100.pt
--------------------------------------------------------------------------------
/Results/RM(8,2)/fullNN_Enc+fullNN_Dec/Enc_snr_0.0_Dec_snr-2.0/Batch_100000/Models/Decoder_NN.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/Results/RM(8,2)/fullNN_Enc+fullNN_Dec/Enc_snr_0.0_Dec_snr-2.0/Batch_100000/Models/Decoder_NN.pt
--------------------------------------------------------------------------------
/Results/RM(8,2)/fullNN_Enc+fullNN_Dec/Enc_snr_0.0_Dec_snr-2.0/Batch_100000/Models/Encoder_NN.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/Results/RM(8,2)/fullNN_Enc+fullNN_Dec/Enc_snr_0.0_Dec_snr-2.0/Batch_100000/Models/Encoder_NN.pt
--------------------------------------------------------------------------------
/Results/RM(9,2)/fullNN_Enc+fullNN_Dec/Enc_snr_-2.0_Dec_snr-4.0/Batch_40000/Models/Decoder_NN.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/Results/RM(9,2)/fullNN_Enc+fullNN_Dec/Enc_snr_-2.0_Dec_snr-4.0/Batch_40000/Models/Decoder_NN.pt
--------------------------------------------------------------------------------
/Results/RM(9,2)/fullNN_Enc+fullNN_Dec/Enc_snr_-2.0_Dec_snr-4.0/Batch_40000/Models/Encoder_NN.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/Results/RM(9,2)/fullNN_Enc+fullNN_Dec/Enc_snr_-2.0_Dec_snr-4.0/Batch_40000/Models/Encoder_NN.pt
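
The `Enc_snr_*` and `Dec_snr_*` components of the checkpoint paths above are training SNRs in dB. In the training code these are mapped to a noise standard deviation by `snr_db2sigma` in `reed_muller_modules/all_functions.py` (shown further down in this dump). A minimal, dependency-free sketch of that conversion and its inverse, using the repository's own `sigma = 10**(-snr/20)` formula; the example SNR values are taken from the Polar(64,7) path above:

```python
import math

def snr_db2sigma(train_snr):
    # Same formula as in reed_muller_modules/all_functions.py:
    # an SNR of x dB corresponds to a noise sigma of 10^(-x/20).
    return 10 ** (-train_snr * 1.0 / 20)

def snr_sigma2db(train_sigma):
    # Inverse mapping, matching the repository's snr_sigma2db.
    return -20.0 * math.log10(train_sigma)

# Enc_snr_-0.5 and Dec_snr_-2.5 from the Polar(64,7) checkpoint path:
print(snr_db2sigma(-0.5))   # sigma ~ 1.059
print(snr_db2sigma(-2.5))   # sigma ~ 1.334
```

Lower (more negative) training SNR thus means the models were trained under larger noise sigma.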
--------------------------------------------------------------------------------
/data/1/Mul_this_matrix_Ind_One.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/data/1/Mul_this_matrix_Ind_One.pt
--------------------------------------------------------------------------------
/data/1/Mul_this_matrix_Ind_Zero.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/data/1/Mul_this_matrix_Ind_Zero.pt
--------------------------------------------------------------------------------
/data/2/CodebookIndex_this_matrix_One_MinusOne.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/data/2/CodebookIndex_this_matrix_One_MinusOne.pt
--------------------------------------------------------------------------------
/data/2/CodebookIndex_this_matrix_Zero_PlusOne.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/data/2/CodebookIndex_this_matrix_Zero_PlusOne.pt
--------------------------------------------------------------------------------
/data/2/Mul_this_matrix_Ind_One.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/data/2/Mul_this_matrix_Ind_One.pt
--------------------------------------------------------------------------------
/data/2/Mul_this_matrix_Ind_Zero.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/data/2/Mul_this_matrix_Ind_Zero.pt
--------------------------------------------------------------------------------
/data/3/CodebookIndex_this_matrix_One_MinusOne.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/data/3/CodebookIndex_this_matrix_One_MinusOne.pt
--------------------------------------------------------------------------------
/data/3/CodebookIndex_this_matrix_Zero_PlusOne.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/data/3/CodebookIndex_this_matrix_Zero_PlusOne.pt
--------------------------------------------------------------------------------
/data/3/Mul_this_matrix_Ind_One.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/data/3/Mul_this_matrix_Ind_One.pt
--------------------------------------------------------------------------------
/data/3/Mul_this_matrix_Ind_Zero.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/data/3/Mul_this_matrix_Ind_Zero.pt
--------------------------------------------------------------------------------
/data/4/CodebookIndex_this_matrix_One_MinusOne.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/data/4/CodebookIndex_this_matrix_One_MinusOne.pt
--------------------------------------------------------------------------------
/data/4/CodebookIndex_this_matrix_Zero_PlusOne.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/data/4/CodebookIndex_this_matrix_Zero_PlusOne.pt
--------------------------------------------------------------------------------
/data/4/Mul_this_matrix_Ind_One.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/data/4/Mul_this_matrix_Ind_One.pt
--------------------------------------------------------------------------------
/data/4/Mul_this_matrix_Ind_Zero.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/data/4/Mul_this_matrix_Ind_Zero.pt
--------------------------------------------------------------------------------
/data/5/CodebookIndex_this_matrix_One_MinusOne.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/data/5/CodebookIndex_this_matrix_One_MinusOne.pt
--------------------------------------------------------------------------------
/data/5/CodebookIndex_this_matrix_Zero_PlusOne.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/data/5/CodebookIndex_this_matrix_Zero_PlusOne.pt
--------------------------------------------------------------------------------
/data/5/Mul_this_matrix_Ind_One.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/data/5/Mul_this_matrix_Ind_One.pt
--------------------------------------------------------------------------------
/data/5/Mul_this_matrix_Ind_Zero.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/data/5/Mul_this_matrix_Ind_Zero.pt
--------------------------------------------------------------------------------
/data/6/CodebookIndex_this_matrix_One_MinusOne.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/data/6/CodebookIndex_this_matrix_One_MinusOne.pt
--------------------------------------------------------------------------------
/data/6/CodebookIndex_this_matrix_Zero_PlusOne.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/data/6/CodebookIndex_this_matrix_Zero_PlusOne.pt
--------------------------------------------------------------------------------
/data/6/Mul_this_matrix_Ind_One.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/data/6/Mul_this_matrix_Ind_One.pt
--------------------------------------------------------------------------------
/data/6/Mul_this_matrix_Ind_Zero.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/data/6/Mul_this_matrix_Ind_Zero.pt
--------------------------------------------------------------------------------
/data/7/CodebookIndex_this_matrix_One_MinusOne.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/data/7/CodebookIndex_this_matrix_One_MinusOne.pt
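
Several helpers later in `reed_muller_modules/all_functions.py` build index matrices like the ones stored in these `data/*` tensors from the fixed-width binary expansion of integers (`integer_to_binary`, `all_integers_to_binary`). A dependency-free sketch of that expansion, with the numpy/torch conversion omitted (the `_bits`/`_list` names are ours, not the repository's):

```python
def integer_to_binary_bits(integer, binary_base):
    # Mirrors integer_to_binary in reed_muller_modules/all_functions.py,
    # minus the torch conversion: fixed-width binary expansion of an integer.
    return [int(each_bit) for each_bit in format(integer, '0{0}b'.format(binary_base))]

def all_integers_to_binary_list(m):
    # Mirrors all_integers_to_binary: one row of m bits per integer in [0, 2**m).
    return [integer_to_binary_bits(i, m) for i in range(2 ** m)]

print(integer_to_binary_bits(5, 4))          # [0, 1, 0, 1]
print(len(all_integers_to_binary_list(3)))   # 8 rows for m = 3
```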
--------------------------------------------------------------------------------
/data/7/CodebookIndex_this_matrix_Zero_PlusOne.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/data/7/CodebookIndex_this_matrix_Zero_PlusOne.pt
--------------------------------------------------------------------------------
/data/7/Mul_this_matrix_Ind_One.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/data/7/Mul_this_matrix_Ind_One.pt
--------------------------------------------------------------------------------
/data/7/Mul_this_matrix_Ind_Zero.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/data/7/Mul_this_matrix_Ind_Zero.pt
--------------------------------------------------------------------------------
/data/8/CodebookIndex_this_matrix_One_MinusOne.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/data/8/CodebookIndex_this_matrix_One_MinusOne.pt
--------------------------------------------------------------------------------
/data/8/CodebookIndex_this_matrix_Zero_PlusOne.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/data/8/CodebookIndex_this_matrix_Zero_PlusOne.pt
--------------------------------------------------------------------------------
/data/8/Mul_this_matrix_Ind_One.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/data/8/Mul_this_matrix_Ind_One.pt
--------------------------------------------------------------------------------
/data/8/Mul_this_matrix_Ind_Zero.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepcomm/KOcodes/c412e35b94879e6ad69ff2773d6f65a2b62b18a0/data/8/Mul_this_matrix_Ind_Zero.pt
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import torch.utils.data as data

from torchvision import datasets
from torchvision import transforms
from torch.distributions.multivariate_normal import MultivariateNormal


def get_loader(config):
    # NOTE: MNIST is single-channel, so Normalize takes 1-tuples here;
    # the original 3-tuple (0.5, 0.5, 0.5) fails on 1x28x28 tensors.
    tf = transforms.Compose([transforms.Resize(28),
                             transforms.ToTensor(),
                             transforms.Normalize((0.5,), (0.5,))])
    mnist = datasets.MNIST(root=config.mnist_path, train=True, download=True, transform=tf)
    mnist_loader = data.DataLoader(dataset=mnist, batch_size=config.batch_size,
                                   shuffle=True,
                                   num_workers=4)
    r_loader = RealDataGenerator(mnist_loader)
    mu, cov = compute_mnist_stats(mnist)
    z_loader = MVGaussianGenerator(config.batch_size, mu, cov)
    return r_loader, z_loader


class DataGenerator(object):
    """Superclass of all data generators. WARNING: never raises StopIteration, so it loops forever!"""

    def __iter__(self):
        return self

    def __next__(self):
        return self.get_batch()

    def get_batch(self):
        raise NotImplementedError()

    def float_tensor(self, batch):
        return torch.from_numpy(batch).type(torch.FloatTensor)


class StandardGaussianGenerator(DataGenerator):
    """Samples from a multivariate Gaussian."""
    def __init__(self, batch_size, mu, cov, lambda_identity=1.0):
        self.batch_size = batch_size
        # Regularize the covariance so it is positive definite.
        cov = cov + lambda_identity * torch.eye(cov.size(0)) * 1e-1
        self.generator = MultivariateNormal(mu, cov)

    def get_batch(self):
        return self.generator.sample((self.batch_size,)).view(self.batch_size, -1)


class RealDataGeneratorDummy(DataGenerator):
    """Samples from real data."""
    def __init__(self, loader):
        self.loader = loader
        self.generator = iter(self.loader)
        self.data_len = len(self.loader)
        self.count = 0

    def get_batch(self):
        if ((self.count + 1) % self.data_len) == 0:
            # Re-create the iterator once the loader is exhausted.
            del self.generator
            self.generator = iter(self.loader)
        self.count += 1
        return next(self.generator)


class MVGaussianGenerator(DataGenerator):
    """Samples from a multivariate Gaussian, shaped as 1x28x28 images."""
    def __init__(self, batch_size, mu, cov, lambda_identity=1.0):
        self.batch_size = batch_size
        self.image_size = 28
        cov = cov + lambda_identity * torch.eye(cov.size(0)) * 1e-1
        self.generator = MultivariateNormal(mu, cov)

    def get_batch(self):
        return self.generator.sample((self.batch_size,)).view(self.batch_size, 1, self.image_size, self.image_size)


class RealDataGenerator(DataGenerator):
    """Samples from real data."""
    def __init__(self, loader):
        self.loader = loader
        self.generator = iter(self.loader)
        self.data_len = len(self.loader)
        self.count = 0

    def get_batch(self):
        if ((self.count + 1) % self.data_len) == 0:
            del self.generator
            self.generator = iter(self.loader)
        self.count += 1
        # Return only the images, dropping the labels.
        return next(self.generator)[0]


def compute_mnist_stats(mnist_dataset):
    # Compute the per-pixel mean and covariance over the full 60k training set.
    loader = data.DataLoader(dataset=mnist_dataset, batch_size=60000, num_workers=8)
    mnist = next(iter(loader))[0]
    mnist = mnist.view(60000, -1).t().numpy()
    mnist_mean = np.mean(mnist, axis=1)
    mnist_cov = np.cov(mnist)
    return torch.from_numpy(mnist_mean).type(torch.FloatTensor), \
           torch.from_numpy(mnist_cov).type(torch.FloatTensor)
--------------------------------------------------------------------------------
/reed_muller_modules/all_functions.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import argparse
import torch
from torch.autograd import Variable
import torch.utils.data


from opt_einsum import contract  # Faster alternative to torch.einsum

import math
import numpy as np

import itertools
from itertools import combinations

import os


#########################
### These are imported from comm_utils
#########################

# Note: there are a few definitions of SNR. In our results, we stick to the following SNR setup.
def snr_db2sigma(train_snr):
    return 10**(-train_snr*1.0/20)


def snr_sigma2db(train_sigma):
    try:
        return -20.0 * math.log(train_sigma, 10)
    except:
        # Fall back to the elementwise torch version for tensor inputs.
        return -20.0 * torch.log10(train_sigma)


##################################


def to_var(x, can_I_use_cuda, requires_grad=False):

    """Converts a torch tensor to a Variable."""

    if can_I_use_cuda:
        x = x.cuda()

    return Variable(x, requires_grad=requires_grad)


def to_data(x):
    """Converts a Variable to numpy."""
    if torch.cuda.is_available():
        x = x.cpu()
    return x.data.numpy()


def numpy_to_torch(tensor):

    return torch.from_numpy(tensor).float()


def Return_Index(M, S, b, m):

    # All arguments are numpy arrays. Returns the indices a \in {0,1}^m such that a_S_bar = b.
    S_bar = np.setdiff1d(np.arange(m), S)

    return np.flatnonzero(((M[:, S_bar] + b) % 2).sum(1) == 0)


def integer_to_binary(integer, binary_base):

    ans = format(integer, '0{0}b'.format(binary_base))

    ans = numpy_to_torch(np.array([int(each_bit) for each_bit in ans]))

    return ans


def all_integers_to_binary(m):

    M = torch.zeros(2**m, m)

    for i in range(M.shape[0]):
        M[i, :] = integer_to_binary(i, m)

    return M


def all_binary_binary_dot_products(m, can_I_use_cuda):

    all_dot_products_matrix = to_var(all_integers_to_binary(m), can_I_use_cuda)

    return all_dot_products_matrix.mm(all_dot_products_matrix.t()) % 2


def binary_to_integer(binary_string):

    return int(binary_string, 2)


def fixed_vector_permutation_of_indices(m, integer_corresp_permutation):

    new_indices = [i^integer_corresp_permutation for i in range(2**m)]

    return torch.eye(2**m)[new_indices, :]


def all_vectors_permutation_of_indices(m):

    # The first slice is a dummy. Hence it is really (2**m - 1, 2**m, 2**m).
    all_permutations_tensor = torch.zeros(2**m, 2**m, 2**m)

    for integer_corresp_permutation in range(1, 2**m):

        all_permutations_tensor[integer_corresp_permutation, :, :] = fixed_vector_permutation_of_indices(m, integer_corresp_permutation)

    return all_permutations_tensor


def fixed_vector_coset_projection_matrix(m, integer_corresp_projection):

    M = torch.zeros(2**m, 2**(m-1))

    current_column = 0

    for integer_corresp_coset in range(2**m):
        if sum(M[integer_corresp_coset, :]) == 0:
            M[integer_corresp_coset, current_column] = 1

            remaining_element_in_coset = integer_corresp_coset^integer_corresp_projection

            M[remaining_element_in_coset, current_column] = 1

            current_column += 1

    return M


def all_vectors_coset_projection_tensor(m):

    # The first slice is a dummy. Hence it is really (2**m - 1, 2**m, 2**(m-1)).
    all_projections_tensor = torch.zeros(2**m, 2**m, 2**(m-1))

    for integer_corresp_projection in range(1, 2**m):

        all_projections_tensor[integer_corresp_projection, :, :] = fixed_vector_coset_projection_matrix(m, integer_corresp_projection)

    return all_projections_tensor

#############################
## s-dim stuff
#############################

def do_xor(list_elems, shift_elem):

    # shift_elem = 0 is the default.
    return [elem^shift_elem for elem in list_elems]


def find_xor(B, binary_string):

    assert(len(B) == len(binary_string))

    ans = 0

    for (i, basis) in enumerate(B):
        if binary_string[i] == 1:
            ans = ans^basis

    return ans


def get_basis_for_all_s_dim_subspaces(m, s, shift_elem):

    basis_all_s_dim_subspaces = [do_xor(list(i), shift_elem) for i in combinations(set([2**(i) for i in range(m)]), s)]

    return basis_all_s_dim_subspaces


def get_subspace_given_basis(B):

    # B is of size s.
    s = len(B)

    all_binary_strings_of_length_s = all_integers_to_binary(s).long()

    subspace = []

    for i in range(all_binary_strings_of_length_s.shape[0]):
        subspace.append(find_xor(B, all_binary_strings_of_length_s[i, :]))

    return subspace


def fixed_s_subspace_coset_projection_matrix(m, s, s_dim_subspace):

    # Finds the coset projection matrix for the given subspace.

    M = torch.zeros(2**m, 2**(m-s))

    current_column = 0

    for integer_corresp_coset in range(2**m):

        if sum(M[integer_corresp_coset, :]) == 0:

            all_elements_in_the_coset = [integer_corresp_coset^subspace_vector for subspace_vector in s_dim_subspace]

            M[all_elements_in_the_coset, current_column] = 1

            current_column += 1

    return M


def all_s_subspace_coset_projection_tensor(m, s, shift_elem=0):

    # s = r-1 for RM direct projection to order-1 codes.
    assert(s <= m)

    basis_all_s_dim_subspaces = get_basis_for_all_s_dim_subspaces(m, s, shift_elem)

    m_choose_s = len(basis_all_s_dim_subspaces)

    # We only have (m choose s) projections. Input LLR is of shape 2**m; output LLR is of shape 2**(m-s).
    all_cordinate_s_dim_projections_tensor = torch.zeros(m_choose_s, 2**m, 2**(m-s))

    for num_projections in range(m_choose_s):

        all_cordinate_s_dim_projections_tensor[num_projections, :, :] = fixed_s_subspace_coset_projection_matrix(m, s, get_subspace_given_basis(basis_all_s_dim_subspaces[num_projections]))

    return all_cordinate_s_dim_projections_tensor


def find_code_projection_s_dim_subspace_coset_indices(M, s):

    # Input (m choose s, 2**m, 2**(m-s)). Output (m choose s, 2**s, 2**(m-s)).
    # E.g. M is of shape (70, 256, 16); we want a tensor of shape (70, 16, 16) that stores
    # the indices of the non-zero elements along each column.

    two_power_s = 2**s
    idx = torch.zeros(M.shape[0], two_power_s, M.shape[2]).long()

    for i in range(idx.shape[0]):
        for j in range(idx.shape[2]):
            idx[i, :, j] = (M[i, :, j] == 1.).nonzero().reshape(two_power_s)

    return idx


def find_s_dim_coset_friends_message_passing_and_backprojection(Code_idx):

    # Code_idx is of shape (m choose s, 2**s, 2**(m-s)). So we want an output of shape
    # (m choose s, 2**m, 2**s - 1), where we store the other coset friends for each of
    # the 256 main indices.

    m_choose_s = Code_idx.shape[0]
    two_power_m = Code_idx.shape[1] * Code_idx.shape[2]
    two_power_s = Code_idx.shape[1]
    # s = int(math.log(two_power_s, 2))

    LLR_idx = torch.zeros(m_choose_s, two_power_m, two_power_s - 1).long()
    Coset_idx_for_bit_indices = torch.zeros(m_choose_s, two_power_m).long()

    for i in range(m_choose_s):
        mat = Code_idx[i, :, :]
        for j in range(two_power_m):
            that_column = (mat == j).nonzero().reshape(2)[1].item()
            all_coset_friends = mat[:, that_column]
            LLR_idx[i, j, :] = all_coset_friends[all_coset_friends != j]
            Coset_idx_for_bit_indices[i, j] = that_column

    return LLR_idx, Coset_idx_for_bit_indices


def llr_coset_projection_s_dim(llr, Code_idx, even_comb, odd_comb):

    # llr is of shape (1, 256)
    # Code_idx is of shape (70, 16, 16)
    # output is of shape (1, 70, 16)
    # even_comb is the set of even-sized combinations of 2**s
    # odd_comb is the set of odd-sized combinations of 2**s

    Big_LLR = llr[:, Code_idx]  # Of shape (1, 70, 16, 16)

    numerator = 1

    for each_comb in even_comb:

        numerator += Big_LLR[:, :, each_comb, :].sum(2).exp()  # (1, 70, 16)

    denominator = 0

    for each_comb in odd_comb:

        denominator += Big_LLR[:, :, each_comb, :].sum(2).exp()

    return torch.log(numerator/denominator)  # (1, 70, 16)


def llr_message_passing_aggregation_s_dim(llr, LLR_idx, odd_comb, even_comb):

    # llr is of shape (1, 256)
    # LLR_idx is of shape (70, 256, 15)
    # odd_comb is the set of odd-sized combinations of
2**s - 1 309 | # even_comb is the set of even-sized combinations of 2**s - 1 310 | 311 | Big_LLR = llr[:, LLR_idx] # shape (1, 70, 256, 15) 312 | 313 | numerator = 0 314 | 315 | for each_comb in odd_comb: 316 | 317 | numerator += Big_LLR[:, :, :, each_comb].sum(3).exp() 318 | 319 | denominator = 1 320 | 321 | for each_comb in even_comb: 322 | 323 | denominator += Big_LLR[:, :, :, each_comb].sum(3).exp() 324 | 325 | return torch.log(numerator/denominator) # (1, 70, 256) 326 | 327 | 328 | def log_sum_avoid_NaN(x, y): 329 | 330 | a = torch.max(x, y) 331 | b = torch.min(x, y) 332 | 333 | log_sum_standard = torch.log(1 + (x+y).exp()) - x - torch.log(1 + (y-x).exp() ) 334 | 335 | # print("Original one:", log_sum_standard) 336 | 337 | ## Check for NaN or infty or -infty once here. 338 | if (torch.isnan(log_sum_standard).sum() > 0) | ((log_sum_standard == float('-inf')).sum() > 0 )| ( (log_sum_standard == float('inf')).sum() > 0) : 339 | 340 | # print("Had to avoid NaNs!") 341 | # 80 for float32 and 707 for float64. 342 | big_threshold = 80. if log_sum_standard.dtype == torch.float32 else 700. 
343 |
344 | idx_1 = (x + y > big_threshold)
345 | subset_1 = idx_1 & ((x-y).abs() < big_threshold)
346 |
347 | idx_2 = (x + y < -big_threshold)
348 | subset_2 = idx_2 & ((x-y).abs() < big_threshold)
349 |
350 | idx_3 = ((x - y).abs() > big_threshold) & ( (x+y).abs() < big_threshold )
351 |
352 | # Could be sped up by vectorizing these corrections. Note: use '& ~' on the boolean masks; subtracting bool tensors is an error in recent PyTorch.
353 | if idx_1.sum() > 0 :
354 |
355 | if subset_1.sum() > 0:
356 | log_sum_standard[subset_1] = y[subset_1] - torch.log(1 + (y[subset_1] - x[subset_1]).exp() )
357 | # print("After 11 modification", log_sum_standard)
358 |
359 | if (idx_1 & ~subset_1).sum() > 0:
360 | log_sum_standard[idx_1 & ~subset_1] = b[idx_1 & ~subset_1]
361 | # print("After 12 modification", log_sum_standard)
362 |
363 | if idx_2.sum() > 0:
364 |
365 | if subset_2.sum() > 0:
366 | log_sum_standard[subset_2] = -x[subset_2] - torch.log(1 + (y[subset_2] - x[subset_2]).exp() )
367 | # print("After 21 modification", log_sum_standard)
368 |
369 | if (idx_2 & ~subset_2).sum() > 0:
370 | log_sum_standard[idx_2 & ~subset_2] = -a[idx_2 & ~subset_2]
371 | # print("After 22 modification", log_sum_standard)
372 |
373 | if idx_3.sum() > 0:
374 |
375 | log_sum_standard[idx_3] = torch.log(1 + (x[idx_3] + y[idx_3]).exp() ) - a[idx_3]
376 | # print("After 3 modification", log_sum_standard)
377 |
378 | return log_sum_standard
379 |
380 |
381 | def log_sum_avoid_zero_NaN(x, y):
382 |
383 | avoided_NaN = log_sum_avoid_NaN(x, y)
384 |
385 | zero_idx = (avoided_NaN == 0.)
386 |
387 | data_type = x.dtype
388 |
389 | if zero_idx.sum() > 0:
390 |
391 | # print("Had to avoid zeros!")
392 |
393 | x_subzero = x[zero_idx]
394 | y_subzero = y[zero_idx]
395 |
396 | nume = torch.relu(x_subzero + y_subzero)
397 | denom = torch.max(x_subzero, y_subzero)
398 | delta = 1e-7 if data_type == torch.float32 else 1e-16
399 |
400 | term_1 = 0.5 * ( (-nume).exp() + (x_subzero + y_subzero - nume).exp() )
401 | term_2 = 0.5 * ( (x_subzero - denom).exp() + (y_subzero - denom).exp() )
402 |
403 | close_1 = ((term_1 - 1).abs() < delta).to(data_type) # cast the mask; stays on the same device as term_1 instead of a hard-coded .cuda()
404 | T_1 = (term_1 - 1.) * close_1 + torch.log(term_1) * (1 - close_1)
405 |
406 | close_2 = ((term_2 - 1).abs() < delta).to(data_type)
407 | T_2 = (term_2 - 1.) * close_2 + torch.log(term_2) * (1 - close_2)
408 |
409 | corrected_ans = nume - denom + T_1 - T_2
410 |
411 | further_zero = (corrected_ans == 0.)
412 |
413 | if further_zero.sum() > 0:
414 |
415 | x_sub_subzero = x_subzero[further_zero]
416 | y_sub_subzero = y_subzero[further_zero]
417 |
418 | positive_idx = ( x_sub_subzero + y_sub_subzero > 0.)
419 | 420 | spoiled_brat = torch.min(- x_sub_subzero, - y_sub_subzero) 421 | 422 | spoiled_brat[positive_idx] = torch.min(x_sub_subzero[positive_idx], y_sub_subzero[positive_idx]) 423 | 424 | corrected_ans[further_zero] = spoiled_brat 425 | 426 | avoided_NaN[zero_idx] = corrected_ans 427 | 428 | return avoided_NaN 429 | 430 | 431 | def recursive_llr_coset_projection_s_dim(Big_LLR): 432 | 433 | # if first_half.shape[2] == 1: 434 | # numerator = 1 + first_half.add(second_half).exp() 435 | # denominator = first_half.exp().add(second_half.exp()) 436 | # return torch.log(numerator/denominator) 437 | 438 | # Big_LLR shape: (batch, 35, 8, 16) 439 | # Output shape: (batch, 25, 1, 16) 440 | 441 | if Big_LLR.shape[2] == 2: 442 | # numerator = 1 + Big_LLR.sum(2, keepdim=True).exp() 443 | # denominator = Big_LLR.exp().sum(2, keepdim=True) 444 | return log_sum_avoid_zero_NaN(Big_LLR[:, :, 0:1, :], Big_LLR[:, :, 1:2, :]) 445 | 446 | else: 447 | current_coset_length_half = Big_LLR.shape[2] // 2 448 | first_half = recursive_llr_coset_projection_s_dim(Big_LLR[:, :, :current_coset_length_half, :]) 449 | second_half = recursive_llr_coset_projection_s_dim(Big_LLR[:, :, current_coset_length_half:, :]) 450 | 451 | # numerator = 1 + first_half.add(second_half).exp() 452 | # denominator = first_half.exp().add(second_half.exp()) 453 | 454 | return log_sum_avoid_zero_NaN(first_half, second_half) #(batch, 35, 1, 16) # remember to squeeze the 2nd dimension 455 | 456 | 457 | def recursive_llr_message_passing_aggregation_s_dim(Big_LLR): 458 | 459 | # Big_LLR shape: (1, 35, 128, 8). -\infty is attached to the first column for each index z. 460 | # Output shape: (1, 35, 128, 1) 461 | 462 | ############## 463 | ### New implementation using tree. 464 | ############## 465 | 466 | return -recursive_llr_coset_projection_s_dim(Big_LLR.permute(0, 1, 3, 2)).permute(0, 1, 3, 2) 467 | 468 | ''' 469 | ####### 470 | ## Old implementation using a chain. This is correct too. 
471 | ####### 472 | 473 | # if Big_LLR.shape[3] == 2: 474 | # numerator = 1 + Big_LLR.sum(3, keepdim=True).exp() 475 | # denominator = Big_LLR.exp().sum(3, keepdim=True) 476 | # return torch.log(numerator) - torch.log(denominator) 477 | 478 | # else: 479 | # # print(Big_LLR.shape) 480 | # next_friend = Big_LLR[:, :, :, :1] 481 | # remaining_Set = recursive_llr_message_passing_aggregation_s_dim(Big_LLR[:, :, :, 1:]) 482 | 483 | # numerator = 1 + next_friend.add(remaining_Set).exp() 484 | # denominator = next_friend.exp().add(remaining_Set.exp()) 485 | # return torch.log(numerator) - torch.log(denominator) # remember to squeeze the 3nd dimension 486 | ''' 487 | 488 | 489 | def coset_to_codeword_back_projection_s_dim(decoded_cosets, Coset_idx_unsqueezed): 490 | 491 | # decoded_cosets is of shape (Batch, 28, 64) 492 | # Coset_idx_unsqueezed is of shape (Batch, 28, 256). 493 | # Output is of shape (Batch, 28, 256) 494 | 495 | return decoded_cosets.gather(2, Coset_idx_unsqueezed) 496 | 497 | 498 | 499 | 500 | 501 | ####### 2-dim stuff 502 | def fixed_2_subspace_coset_projection_matrix(m, first_non_zero_subspace, second_non_zero_subspace, third_non_zero_subspace): 503 | 504 | M = torch.zeros(2**m, 2**(m-2)) 505 | 506 | current_column = 0 507 | 508 | for integer_corresp_coset in range(2**m): 509 | 510 | if sum(M[integer_corresp_coset, :]) == 0: 511 | 512 | M[integer_corresp_coset, current_column] = 1 513 | 514 | remaining_element1_in_coset = integer_corresp_coset^first_non_zero_subspace 515 | remaining_element2_in_coset = integer_corresp_coset^second_non_zero_subspace 516 | remaining_element3_in_coset = integer_corresp_coset^third_non_zero_subspace 517 | 518 | M[remaining_element1_in_coset, current_column] = 1 519 | M[remaining_element2_in_coset, current_column] = 1 520 | M[remaining_element3_in_coset, current_column] = 1 521 | 522 | current_column += 1 523 | 524 | return M 525 | 526 | 527 | def all_pair_cordinate_directions_2_subspace_coset_projection_tensor(m): 528 | 529 
| # We only have (m 2) projections. Input LLR is of shape 2**m. Output LLR is of shape 2**(m-2) 530 | all_cordinate_2_projections_tensor = torch.zeros(int(m*(m-1)/2), 2**m, 2**(m-2)) 531 | 532 | count = 0 533 | 534 | for j in range(1, m): 535 | for i in range(j): 536 | 537 | all_cordinate_2_projections_tensor[count, :, :] = fixed_2_subspace_coset_projection_matrix(m, 2**i, 2**j, 2**i + 2**j) 538 | 539 | count += 1 540 | 541 | return all_cordinate_2_projections_tensor 542 | 543 | 544 | def find_code_projection_2_subspace_coset_indices(M): 545 | 546 | # M is of shape (28, 256, 64). We want a tensor of shape (28, 4, 64) that stores the indices of non-zero elements along each column 547 | 548 | idx = torch.zeros(M.shape[0], 4, M.shape[2]).long() 549 | 550 | for i in range(idx.shape[0]): 551 | for j in range(idx.shape[2]): 552 | idx[i, :, j] = (M[i, :, j] == 1.).nonzero().reshape(4) 553 | 554 | return idx 555 | 556 | 557 | def find_coset_friends_message_passing_and_backprojection(Code_idx): 558 | 559 | # Code_idx is of shape (28, 4, 64). 
So we want an output of shape (28, 256, 3), where we store the other 3 coset friends for each of the 256 main indices 560 | 561 | LLR_idx = torch.zeros(Code_idx.shape[0], 4*Code_idx.shape[2], 3).long() 562 | Coset_idx_for_bit_indices = torch.zeros(Code_idx.shape[0], 4*Code_idx.shape[2]).long() 563 | 564 | for i in range(LLR_idx.shape[0]): 565 | mat = Code_idx[i,:, :] 566 | for j in range(LLR_idx.shape[1]): 567 | that_column = (mat == j).nonzero().reshape(2)[1].item() 568 | all_coset_friends = mat[:, that_column] 569 | # print(all_coset_friends) 570 | LLR_idx[i, j, :] = all_coset_friends[all_coset_friends != j] 571 | Coset_idx_for_bit_indices[i, j] = that_column 572 | 573 | return LLR_idx, Coset_idx_for_bit_indices 574 | 575 | 576 | def llr_coset_projection_2_subspace_batch(llr, Code_idx): #req_Code_Projection_tensor, numer_req_Code_Projection_tensors, denom_req_Code_Projection_tensors 577 | 578 | # LLR is of shape (1, 256) 579 | # Code_idx is of shape (28, 4, 64) 580 | # output is of shape (1, 28, 64) 581 | Big_LLR = llr[:, Code_idx] # Of shape (1, 28, 4, 64) 582 | 583 | numerator = Big_LLR.sum(2).exp() + Big_LLR[:, :, [0,1], :].sum(2).exp() + Big_LLR[:, :, [0, 2], :].sum(2).exp() + Big_LLR[:, :, [0, 3], :].sum(2).exp() +\ 584 | Big_LLR[:, :, [1, 2], :].sum(2).exp() + Big_LLR[:, :, [1, 3], :].sum(2).exp() + Big_LLR[:, :, [2, 3], :].sum(2).exp() 585 | 586 | denominator = Big_LLR.exp().sum(2) + Big_LLR[:, :, [0, 1, 2], :].sum(2).exp() + Big_LLR[:, :, [0, 1, 3], :].sum(2).exp() + \ 587 | Big_LLR[:, :, [0, 2, 3], :].sum(2).exp() + Big_LLR[:, :, [1, 2, 3], :].sum(2).exp() 588 | 589 | return torch.log(1 + numerator) - torch.log(denominator) # (1, 28, 64) 590 | 591 | 592 | def llr_message_passing_aggregation(llr, LLR_idx): 593 | 594 | # llr is of shape (1, 256) 595 | # LLR_idx is of shape (28, 256, 3) 596 | 597 | Big_LLR = llr[:, LLR_idx] # shape (1, 28, 256, 3) 598 | 599 | numerator = Big_LLR.sum(3).exp() + Big_LLR.exp().sum(3) 600 | 601 | denominator = Big_LLR[:, :, :, 
[0, 1]].sum(3).exp() + Big_LLR[:, :, :, [0, 2]].sum(3).exp() + Big_LLR[:, :, :, [1, 2]].sum(3).exp() 602 | 603 | return torch.log( numerator ) - torch.log( 1 + denominator) # (1, 28, 256) 604 | 605 | 606 | def coset_to_codeword_back_projection_2_subspace_batch(decoded_cosets, Coset_idx_unsqueezed): 607 | 608 | # decoded_cosets is of shape (Batch, 28, 64) 609 | # Coset_idx_unsqueezed is of shape (Batch, 28, 256). 610 | # Output is of shape (Batch, 28, 256) 611 | 612 | return decoded_cosets.gather(2, Coset_idx_unsqueezed) 613 | 614 | 615 | 616 | 617 | ####################################################################################################### 618 | 619 | def codeword_to_all_coset_projection_and_back(codeword, BatchMul_Code_Projection_tensor): 620 | 621 | # codeword is of shape (1, 256). 622 | # Code_Projection_tensor is of shape (255, 256, 256). 623 | # output is of shape (256, 255) 624 | 625 | return contract('ij, kjm -> ikm', codeword, BatchMul_Code_Projection_tensor).reshape(BatchMul_Code_Projection_tensor.shape[0], BatchMul_Code_Projection_tensor.shape[2]).t() % 2 626 | 627 | 628 | def codeword_to_all_coset_projection(codeword, Code_Projection_tensor): 629 | 630 | # codeword is of shape (1, 256). 631 | # Code_Projection_tensor is of shape (255, 256, 128). 632 | # output is of shape (256, 255) 633 | 634 | return contract('ij, kjm -> ikm', codeword, Code_Projection_tensor).reshape(Code_Projection_tensor.shape[0], Code_Projection_tensor.shape[2]).t() % 2 635 | 636 | 637 | def coset_to_codeword_back_projection(decoded_cosets, Code_Projection_tensor): 638 | 639 | # decoded_cosets is of shape (255, 128). 640 | # Code_Projection_tensor is of shape (256, 256, 128). 
Hence we need to ignore the first component and transpose the last two components
641 | # output is of shape (256, 255)
642 |
643 |
644 | return contract('ij, ijm -> im', decoded_cosets, Code_Projection_tensor[1:, :, :].permute(0, 2, 1)).t() % 2
645 |
646 |
647 | def coset_to_codeword_back_projection_batch(decoded_cosets, req_Code_Projection_tensor, projection_choice='no_sparse', proj_indices=None):
648 |
649 | if projection_choice == 'no_sparse': # '==' rather than 'is': identity comparison against string literals is unreliable
650 |
651 | # decoded_cosets is of shape (batch, 255, 128).
652 | # Code_Projection_tensor is of shape (255, 256, 128). Hence we need to ignore the first component and transpose the last two components
653 | # output is of shape (batch, 256, 255)
654 |
655 | return contract('kij, ijm -> kim', decoded_cosets, req_Code_Projection_tensor.permute(0, 2, 1)).permute(0, 2, 1) % 2
656 |
657 | elif projection_choice == "static_sparse_proj_batch_wise":
658 |
659 | return contract('kij, kijm -> kim', decoded_cosets, req_Code_Projection_tensor[proj_indices, :, :].permute(0, 1, 3, 2)).permute(0, 2, 1) % 2
660 |
661 |
662 | def cosetLLR_to_codewordLLR_back_projection_batch(decoded_cosets, req_Code_Projection_tensor, projection_choice='no_sparse', proj_indices=None):
663 |
664 | if projection_choice == 'no_sparse':
665 |
666 | # decoded_cosets is of shape (batch, 255, 128).
667 | # Code_Projection_tensor is of shape (255, 256, 128). Hence we need to ignore the first component and transpose the last two components
668 | # output is of shape (batch, 256, 255)
669 |
670 | return contract('kij, ijm -> kim', decoded_cosets, req_Code_Projection_tensor.permute(0, 2, 1)).permute(0, 2, 1)
671 |
672 | elif projection_choice == "static_sparse_proj_batch_wise":
673 |
674 | return contract('kij, kijm -> kim', decoded_cosets, req_Code_Projection_tensor[proj_indices, :, :].permute(0, 1, 3, 2)).permute(0, 2, 1)
675 |
676 |
677 | def NNoutput_to_codeword_back_projection_batch(decoded_cosets, Code_Projection_tensor):
678 |
679 | # decoded_cosets is of shape (255, 128).
680 | # Code_Projection_tensor is of shape (256, 256, 128). Hence we need to ignore the first component and transpose the last two components
681 | # output is of shape (256, 255)
682 |
683 | return contract('kijl, ijm -> kiml', decoded_cosets, Code_Projection_tensor[1:, :, :].permute(0, 2, 1)).permute(0, 2, 1, 3)
684 |
685 |
686 | def llr_all_coset_projection(llr, req_Code_Projection_tensor):
687 |
688 | # LLR is of shape (1, 256).
689 | # Code_Projection_tensor is of shape (255, 256, 128).
690 | # output is of shape (255, 128)
691 |
692 | exp_llr = torch.exp(llr)
693 |
694 | proj_llr = contract('ij, kjm ->ikm', llr, req_Code_Projection_tensor).reshape(req_Code_Projection_tensor.shape[0],\
695 | req_Code_Projection_tensor.shape[2]).t()
696 |
697 | proj_exp_llr = contract('ij, kjm ->ikm', exp_llr, req_Code_Projection_tensor).reshape(req_Code_Projection_tensor.shape[0],\
698 | req_Code_Projection_tensor.shape[2]).t()
699 |
700 | return torch.log(1 + torch.exp(proj_llr)) - torch.log(proj_exp_llr)
701 |
702 |
703 | def llr_all_coset_projection_batch(llr, req_Code_Projection_tensor, projection_choice='no_sparse', proj_indices=None):
704 |
705 | # LLR is of shape (1, 256).
706 | # Code_Projection_tensor is of shape (255, 256, 128).
707 | # output is of shape (255, 128)
708 |
709 | if projection_choice == "no_sparse":
710 |
711 | exp_llr = torch.exp(llr)
712 |
713 | proj_llr = contract('ij, kjm ->ikm', llr, req_Code_Projection_tensor).reshape(llr.shape[0], req_Code_Projection_tensor.shape[0],\
714 | req_Code_Projection_tensor.shape[2]).permute(0, 2, 1)
715 |
716 | proj_exp_llr = contract('ij, kjm ->ikm', exp_llr, req_Code_Projection_tensor).reshape(llr.shape[0], req_Code_Projection_tensor.shape[0],\
717 | req_Code_Projection_tensor.shape[2]).permute(0, 2, 1)
718 |
719 | return torch.log(1 + torch.exp(proj_llr)) - torch.log(proj_exp_llr + 1e-8)
720 |
721 | elif projection_choice == "static_sparse_proj_batch_wise":
722 |
723 | # proj_indices is of shape (batch_size, proj_indices_for_each_batch). For example, (25, 8)
724 |
725 | modified_Code_Projection_tensor = req_Code_Projection_tensor[proj_indices, :, :] # (25, 8, 256, 128)
726 |
727 | exp_llr = torch.exp(llr)
728 |
729 | proj_llr = contract('ij, ikjm ->ikm', llr, modified_Code_Projection_tensor).reshape(llr.shape[0], modified_Code_Projection_tensor.shape[1],\
730 | modified_Code_Projection_tensor.shape[3]).permute(0, 2, 1)
731 |
732 | proj_exp_llr = contract('ij, ikjm ->ikm', exp_llr, modified_Code_Projection_tensor).reshape(llr.shape[0], modified_Code_Projection_tensor.shape[1],\
733 | modified_Code_Projection_tensor.shape[3]).permute(0, 2, 1)
734 |
735 | return torch.log(1 + torch.exp(proj_llr)) - torch.log(proj_exp_llr)
736 |
737 |
738 | def permute_a_given_llr(llr, LLR_Permutation_tensor):
739 |
740 | # llr is of shape (1, 256)
741 | # Permutation_tensor is of shape (255, 256, 256).
742 | # Output is of shape (256, 255)
743 |
744 |
745 | return contract('ij, kjm -> ikm', llr, LLR_Permutation_tensor).reshape(LLR_Permutation_tensor.shape[0], LLR_Permutation_tensor.shape[2]).t()
746 |
747 |
748 | def permute_a_given_llr_batch(llr, LLR_Permutation_tensor):
749 |
750 | # llr is of shape (1, 256)
751 | # Permutation_tensor is of shape (255, 256, 256).
752 | # Output is of shape (256, 255) 753 | 754 | 755 | return contract('ij, kjm -> ikm', llr, LLR_Permutation_tensor).reshape(llr.shape[0],LLR_Permutation_tensor.shape[0], LLR_Permutation_tensor.shape[2]).permute(0, 2, 1) 756 | 757 | 758 | def reed_muller_batch_encoding(batch_messages, Generator_Matrix): 759 | 760 | return batch_messages.mm(Generator_Matrix) % 2 761 | 762 | 763 | def awgn_channel(batch_codes, snr): 764 | 765 | noise_sigma = snr_db2sigma(snr) 766 | 767 | return (1-2*batch_codes) + noise_sigma*torch.randn(batch_codes.shape[0], batch_codes.shape[1], dtype=batch_codes.dtype) 768 | 769 | 770 | def simple_awgn_channel(batch_codes, snr, can_I_use_cuda): 771 | 772 | noise_sigma = snr_db2sigma(snr) 773 | 774 | standard_Gaussian = to_var(torch.randn(batch_codes.shape[0], batch_codes.shape[1], dtype=batch_codes.dtype), can_I_use_cuda) 775 | 776 | return batch_codes + noise_sigma*standard_Gaussian 777 | 778 | 779 | def llr_awgn_channel_affine_Plotkin(corrupted_codewords, snr, a, k_a): 780 | 781 | noise_sigma = snr_db2sigma(snr) 782 | 783 | first_column = (2./noise_sigma**2) * k_a * corrupted_codewords[:, 0:1] 784 | 785 | second_column = (2./noise_sigma**2) * k_a * a * corrupted_codewords[:, 1:2] 786 | 787 | return torch.cat([first_column, second_column], dim=1) 788 | 789 | def llr_awgn_channel_bpsk(corrupted_codes, snr): 790 | 791 | noise_sigma = snr_db2sigma(snr) 792 | 793 | return (2./noise_sigma**2) * corrupted_codes 794 | 795 | 796 | def get_codeword_llr_nn1(code_generator, llr_generator): 797 | code_data = next(code_generator) 798 | llr_data = next(llr_generator) 799 | if code_data.size() != llr_data.size(): 800 | code_data = next(code_generator) 801 | llr_data = next(llr_generator) 802 | return code_data, llr_data 803 | -------------------------------------------------------------------------------- /reed_muller_modules/comm_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 
3 | import math 4 | 5 | def errors_ber(y_true, y_pred): 6 | y_true = y_true.view(y_true.shape[0], -1, 1) 7 | y_pred = y_pred.view(y_pred.shape[0], -1, 1) 8 | 9 | myOtherTensor = torch.ne(torch.round(y_true), torch.round(y_pred)).float() 10 | res = sum(sum(myOtherTensor))/(myOtherTensor.shape[0]*myOtherTensor.shape[1]) 11 | return res 12 | 13 | 14 | def errors_ber_list(y_true, y_pred): 15 | block_len = y_true.shape[1] 16 | y_true = y_true.view(y_true.shape[0], -1) 17 | y_pred = y_pred.view(y_pred.shape[0], -1) 18 | 19 | myOtherTensor = torch.ne(torch.round(y_true), torch.round(y_pred)) 20 | res_list_tensor = torch.sum(myOtherTensor, dim = 1).type(torch.FloatTensor)/block_len 21 | 22 | return res_list_tensor 23 | 24 | 25 | def errors_ber_pos(y_true, y_pred): 26 | y_true = y_true.view(y_true.shape[0], -1, 1) 27 | y_pred = y_pred.view(y_pred.shape[0], -1, 1) 28 | 29 | myOtherTensor = torch.ne(torch.round(y_true), torch.round(y_pred)).float() 30 | 31 | tmp = myOtherTensor.sum(0)/myOtherTensor.shape[0] 32 | res = tmp.squeeze(1) 33 | return res 34 | 35 | def code_power(the_codes): 36 | the_codes = the_codes.cpu().numpy() 37 | the_codes = np.abs(the_codes)**2 38 | the_codes = the_codes.sum(2)/the_codes.shape[2] 39 | tmp = the_codes.sum(0)/the_codes.shape[0] 40 | res = tmp 41 | return res 42 | 43 | def errors_bler(y_true, y_pred): 44 | y_true = y_true.view(y_true.shape[0], -1, 1) 45 | y_pred = y_pred.view(y_pred.shape[0], -1, 1) 46 | 47 | decoded_bits = torch.round(y_pred) 48 | X_test = torch.round(y_true) 49 | tp0 = (abs(decoded_bits-X_test)).view([X_test.shape[0],X_test.shape[1]]) 50 | tp0 = tp0.cpu().numpy() 51 | bler_err_rate = sum(np.sum(tp0,axis=1)>0)*1.0/(X_test.shape[0]) 52 | return bler_err_rate 53 | 54 | # note there are a few definitions of SNR. In our result, we stick to the following SNR setup. 
55 | def snr_db2sigma(train_snr): 56 | return 10**(-train_snr*1.0/20) 57 | 58 | def snr_sigma2db(train_sigma): 59 | try: 60 | return -20.0 * math.log(train_sigma, 10) 61 | except: 62 | return -20.0 * torch.log10(train_sigma) -------------------------------------------------------------------------------- /reed_muller_modules/hadamard.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | use_hadamard_transform_cuda = True 5 | try: 6 | import hadamard_cuda 7 | # import torch.utils.cpp_extension 8 | # hadamard_cuda = torch.utils.cpp_extension.load( 9 | # name='hadamard_cuda', 10 | # sources=[ 11 | # 'hadamard_cuda/hadamard_cuda.cpp', 12 | # 'hadamard_cuda/hadamard_cuda_kernel.cu', 13 | # ], 14 | # extra_cuda_cflags=['-O2'], 15 | # verbose=False 16 | # ) 17 | except (ImportError, RuntimeError) as e: 18 | print("CUDA version of Hadamard transform isn't installed. Will use Pytorch's version, which is much slower.") 19 | use_hadamard_transform_cuda = False 20 | 21 | from scipy.linalg import hadamard 22 | 23 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 24 | 25 | 26 | def hadamard_transform_torch(u, normalize=False): 27 | """Multiply H_n @ u where H_n is the Hadamard matrix of dimension n x n. 28 | n must be a power of 2. 29 | Parameters: 30 | u: Tensor of shape (..., n) 31 | normalize: if True, divide the result by 2^{m/2} where m = log_2(n). 32 | Returns: 33 | product: Tensor of shape (..., n) 34 | """ 35 | batch_size, n = u.shape 36 | m = int(np.log2(n)) 37 | assert n == 1 << m, 'n must be a power of 2' 38 | x = u[..., np.newaxis] 39 | for d in range(m)[::-1]: 40 | x = torch.cat((x[..., ::2, :] + x[..., 1::2, :], x[..., ::2, :] - x[..., 1::2, :]), dim=-1) 41 | return x.squeeze(-2) / 2**(m / 2) if normalize else x.squeeze(-2) 42 | 43 | 44 | class HadamardTransformCuda(torch.autograd.Function): 45 | '''The unnormalized Hadamard transform (i.e. 
without dividing by sqrt(2)) 46 | ''' 47 | @staticmethod 48 | def forward(ctx, u): 49 | return hadamard_cuda.hadamard_transform(u) 50 | 51 | @staticmethod 52 | def backward(ctx, grad): 53 | return HadamardTransformCuda.apply(grad) 54 | 55 | 56 | def hadamard_transform_cuda(u, normalize=False): 57 | """Multiply H_n @ u where H_n is the Hadamard matrix of dimension n x n. 58 | n must be a power of 2. 59 | Parameters: 60 | u: Tensor of shape (..., n) 61 | normalize: if True, divide the result by 2^{m/2} where m = log_2(n). 62 | Returns: 63 | product: Tensor of shape (..., n) 64 | """ 65 | _, n = u.shape 66 | m = int(np.log2(n)) 67 | assert n == 1 << m, 'n must be a power of 2' 68 | output = HadamardTransformCuda.apply(u) 69 | return output / 2**(m / 2) if normalize else output 70 | 71 | 72 | def test_hadamard_transform(): 73 | m = 15 74 | n = 1 << m 75 | batch_size = 50 76 | u = torch.rand((batch_size, n), requires_grad=True, device=device) 77 | result_cuda = hadamard_transform_cuda(u) 78 | grad_cuda, = torch.autograd.grad(result_cuda.sum(), u, retain_graph=True) 79 | result_torch = hadamard_transform_torch(u) 80 | grad_torch, = torch.autograd.grad(result_torch.sum(), u, retain_graph=True) 81 | # Explicit construction from scipy 82 | H = torch.tensor(hadamard(n), dtype=torch.float, device=device) 83 | result_explicit = u @ H.t() 84 | print((result_cuda - result_explicit).abs().max().item()) 85 | print((result_cuda - result_explicit).abs().mean().item()) 86 | print((result_torch - result_explicit).abs().max().item()) 87 | print((result_torch - result_explicit).abs().mean().item()) 88 | print((grad_cuda - grad_torch).abs().max().item()) 89 | print((grad_cuda - grad_torch).abs().mean().item()) 90 | 91 | 92 | hadamard_transform = hadamard_transform_cuda if use_hadamard_transform_cuda else hadamard_transform_torch 93 | 94 | if __name__ == '__main__': 95 | test_hadamard_transform() 96 | -------------------------------------------------------------------------------- 
/reed_muller_modules/logging_utils.py: --------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import logging.config
4 | import shutil
5 | import pandas as pd
6 | from bokeh.io import output_file, save, show
7 | from bokeh.plotting import figure
8 | from bokeh.layouts import column
9 |
10 |
11 |
12 |
13 | def setup_logging(log_file='log.txt'):
14 | """Setup logging configuration
15 | """
16 | logging.basicConfig(level=logging.DEBUG,
17 | format="%(asctime)s - %(levelname)s - %(message)s",
18 | datefmt="%Y-%m-%d %H:%M:%S",
19 | filename=log_file,
20 | filemode='w')
21 | console = logging.StreamHandler()
22 | console.setLevel(logging.INFO)
23 | formatter = logging.Formatter('%(message)s')
24 | console.setFormatter(formatter)
25 | logging.getLogger('').addHandler(console)
26 |
27 |
28 | class ResultsLog(object):
29 |
30 | def __init__(self, path='results.csv', plot_path=None):
31 | self.path = path
32 | self.plot_path = plot_path or (self.path + '.html')
33 | self.figures = []
34 | self.results = None
35 |
36 | def add(self, **kwargs):
37 | df = pd.DataFrame([kwargs.values()], columns=kwargs.keys())
38 | if self.results is None:
39 | self.results = df
40 | else:
41 | self.results = pd.concat([self.results, df], ignore_index=True) # DataFrame.append was removed in pandas 2.0
42 |
43 | def save(self, title='Training Results'):
44 | if len(self.figures) > 0:
45 | if os.path.isfile(self.plot_path):
46 | os.remove(self.plot_path)
47 | output_file(self.plot_path, title=title)
48 | plot = column(*self.figures)
49 | save(plot)
50 | self.figures = []
51 | self.results.to_csv(self.path, index=False, index_label=False)
52 |
53 | def load(self, path=None):
54 | path = path or self.path
55 | if os.path.isfile(path):
56 | self.results = pd.read_csv(path) # assign the loaded frame; DataFrame has no read_csv method
57 |
58 | def show(self):
59 | if len(self.figures) > 0:
60 | plot = column(*self.figures)
61 | show(plot)
62 |
63 |
64 | def image(self, *kargs, **kwargs):
65 | fig = figure()
66 | fig.image(*kargs, **kwargs)
67 |
self.figures.append(fig) 68 | 69 | 70 | def save_checkpoint(state, is_best, path='.', filename='checkpoint.pth.tar', save_all=False): 71 | filename = os.path.join(path, filename) 72 | torch.save(state, filename) 73 | if is_best: 74 | shutil.copyfile(filename, os.path.join(path, 'model_best.pth.tar')) 75 | if save_all: 76 | shutil.copyfile(filename, os.path.join( 77 | path, 'checkpoint_epoch_%s.pth.tar' % state['epoch'])) 78 | 79 | 80 | class AverageMeter(object): 81 | """Computes and stores the average and current value""" 82 | 83 | def __init__(self): 84 | self.reset() 85 | 86 | def reset(self): 87 | self.val = 0 88 | self.avg = 0 89 | self.sum = 0 90 | self.count = 0 91 | 92 | def update(self, val, n=1): 93 | self.val = val 94 | self.sum += val * n 95 | self.count += n 96 | self.avg = self.sum / self.count 97 | 98 | __optimizers = { 99 | 'SGD': torch.optim.SGD, 100 | 'ASGD': torch.optim.ASGD, 101 | 'Adam': torch.optim.Adam, 102 | 'Adamax': torch.optim.Adamax, 103 | 'Adagrad': torch.optim.Adagrad, 104 | 'Adadelta': torch.optim.Adadelta, 105 | 'Rprop': torch.optim.Rprop, 106 | 'RMSprop': torch.optim.RMSprop 107 | } 108 | 109 | 110 | def adjust_optimizer(optimizer, epoch, config): 111 | """Reconfigures the optimizer according to epoch and config dict""" 112 | def modify_optimizer(optimizer, setting): 113 | if 'optimizer' in setting: 114 | optimizer = __optimizers[setting['optimizer']]( 115 | optimizer.param_groups) 116 | logging.debug('OPTIMIZER - setting method = %s' % 117 | setting['optimizer']) 118 | for param_group in optimizer.param_groups: 119 | for key in param_group.keys(): 120 | if key in setting: 121 | logging.debug('OPTIMIZER - setting %s = %s' % 122 | (key, setting[key])) 123 | param_group[key] = setting[key] 124 | return optimizer 125 | 126 | if callable(config): 127 | optimizer = modify_optimizer(optimizer, config(epoch)) 128 | else: 129 | for e in range(epoch + 1): # run over all epochs - sticky setting 130 | if e in config: 131 | optimizer = 
modify_optimizer(optimizer, config[e])
132 |
133 | return optimizer
134 |
135 |
136 | def accuracy(output, target, topk=(1,)):
137 | """Computes the precision@k for the specified values of k"""
138 | maxk = max(topk)
139 | batch_size = target.size(0)
140 |
141 | _, pred = output.float().topk(maxk, 1, True, True)
142 | pred = pred.t()
143 | correct = pred.eq(target.view(1, -1).expand_as(pred))
144 |
145 | res = []
146 | for k in topk:
147 | correct_k = correct[:k].reshape(-1).float().sum(0) # reshape: the sliced tensor is non-contiguous, so view() would fail
148 | res.append(correct_k.mul_(100.0 / batch_size))
149 | return res
150 |
151 |
-------------------------------------------------------------------------------- /reed_muller_modules/reedmuller_codebook.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """reedmuller.py
3 |
4 | Implementation of Reed-Muller codes for Python.
5 | See the class ReedMuller for the documentation."""
6 |
7 | import operator
8 | import itertools
9 | from functools import reduce
10 | import numpy as np
11 |
12 |
13 | def _binom(n, k):
14 | """Binomial coefficient n! / (k! (n-k)!)."""
15 | return reduce(operator.mul, range(n - k + 1, n + 1)) // reduce(operator.mul, range(1, k + 1))
16 |
17 |
18 | def _construct_vector(m, i):
19 | """Construct the vector for x_i of length 2^m, which has form:
20 | A string of 2^{m-i-1} 1s followed by 2^{m-i-1} 0s, repeated
21 | 2^m / (2*2^{m-i-1}) = 2^{m-1}/2^{m-i-1} = 2^i times.
22 | NOTE: we must have 0 <= i < m."""
23 | return ([1] * (2 ** (m - i - 1)) + [0] * (2 ** (m - i - 1))) * (2 ** i)
24 |
25 |
26 | def _vector_mult(*vecs):
27 | """For any number of length-n vectors, pairwise multiply the entries, e.g.
for 28 | x = (x_0, ..., x_{n-1}), y = (y_0, ..., y_{n-1}), 29 | xy = (x_0y_0, x_1y_1, ..., x_{n-1}y_{n-1}).""" 30 | assert (len(set(map(len, vecs))) == 1) 31 | return list(map(lambda a: reduce(operator.mul, a, 1), list(zip(*vecs)))) 32 | 33 | 34 | def _vector_add(*vecs): 35 | """For any number of length-n vectors, pairwise add the entries, e.g. for 36 | x = (x_0, ..., x_{n-1}), y = (y_0, ..., y_{n-1}), 37 | x+y = (x_0+y_0, x_1+y_1, ..., x_{n-1}+y_{n-1}).""" 38 | assert (len(set(map(len, vecs))) == 1) 39 | return list(map(lambda a: reduce(operator.add, a, 0), list(zip(*vecs)))) 40 | 41 | 42 | def _vector_neg(x): 43 | """Take the negation of a vector over Z_2, i.e. swap 1 and 0.""" 44 | return list(map(lambda a: 1 - a, x)) 45 | 46 | 47 | def _vector_reduce(x, modulo): 48 | """Reduce each entry of a vector modulo the value supplied.""" 49 | return list(map(lambda a: a % modulo, x)) 50 | 51 | 52 | def _dot_product(x, y): 53 | """Calculate the dot product of two vectors.""" 54 | assert (len(x) == len(y)) 55 | return sum(_vector_mult(x, y)) 56 | 57 | 58 | def _generate_all_rows(m, S): 59 | """Generate all rows over the monomials in S, e.g. if S = {0,2}, we want to generate 60 | a list of four rows, namely: 61 | phi(x_0) * phi(x_2) 62 | phi(x_0) * !phi(x_2) 63 | !phi(x_0) * phi(x_2) 64 | !phi(x_0) * !phi(x_2). 65 | 66 | We do this using recursion on S.""" 67 | 68 | if not S: 69 | return [[1] * (2 ** m)] 70 | 71 | i, Srest = S[0], S[1:] 72 | 73 | # Find all the rows over Srest. 74 | Srest_rows = _generate_all_rows(m, Srest) 75 | 76 | # Now, for both the representation of x_i and !x_i, return the rows multiplied by these. 77 | xi_row = _construct_vector(m, i) 78 | not_xi_row = _vector_neg(xi_row) 79 | return [_vector_mult(xi_row, row) for row in Srest_rows] + [_vector_mult(not_xi_row, row) for row in Srest_rows] 80 | 81 | 82 | class ReedMuller: 83 | """A class representing a Reed-Muller code RM(r,m), which encodes words of length: 84 | k = C(m,0) + C(m,1) + ...
+ C(m,r) 85 | to words of length n = 2^m. 86 | Note that C(m,0) + ... + C(m,m) = 2^m, so k <= n in all cases, as expected. 87 | The code RM(r,m) has weight 2^{m-r}, and thus, can correct up to 2^{m-r-1}-1 errors.""" 88 | 89 | def __init__(self, r, m): 90 | """Create a Reed-Muller coder / decoder for RM(r,m).""" 91 | self.r, self.m = (r, m) 92 | self._construct_matrix() 93 | self.k = len(self.M[0]) 94 | self.n = 2 ** m 95 | 96 | def strength(self): 97 | """Return the strength of the code, i.e. the number of errors we can correct.""" 98 | return 2 ** (self.m - self.r - 1) - 1 99 | 100 | def message_length(self): 101 | """The length of a message to be encoded.""" 102 | return self.k 103 | 104 | def block_length(self): 105 | """The length of a coded message.""" 106 | return self.n 107 | 108 | def _construct_matrix(self): 109 | # Construct all of the x_i rows. 110 | x_rows = [_construct_vector(self.m, i) for i in range(self.m)] 111 | 112 | # For every s-set S for all 0 <= s <= r, create the row that is the product of the x_j vectors for j in S. 113 | self.matrix_by_row = [reduce(_vector_mult, [x_rows[i] for i in S], [1] * (2 ** self.m)) 114 | for s in range(self.r + 1) 115 | for S in itertools.combinations(range(self.m), s)] 116 | 117 | # To decode, for each row of the matrix, what we need is a list of all vectors consisting of the representations 118 | # of all monomials not in the row. These are the rows that are used in voting to determine if there is a 0 or 1 119 | # in the position corresponding to the row. 120 | self.voting_rows = [_generate_all_rows(self.m, [i for i in range(self.m) if i not in S]) 121 | for s in range(self.r + 1) 122 | for S in itertools.combinations(range(self.m), s)] 123 | 124 | # Now the only thing we need are a list of the indices of the rows corresponding to monomials of degree i. 
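As an aside, the two helpers at the heart of this matrix construction are small enough to check by hand. The sketch below re-types `_binom` and `_construct_vector` (under the hypothetical names `binom` and `construct_vector`, with an explicit initializer so the snippet runs standalone) and shows their outputs for m = 3:

```python
import operator
from functools import reduce

def binom(n, k):
    # n! / (k! (n-k)!), computed as a falling factorial over k!
    return reduce(operator.mul, range(n - k + 1, n + 1), 1) // reduce(operator.mul, range(1, k + 1), 1)

def construct_vector(m, i):
    # 2^{m-i-1} ones followed by 2^{m-i-1} zeros, repeated 2^i times.
    return ([1] * (2 ** (m - i - 1)) + [0] * (2 ** (m - i - 1))) * (2 ** i)

print(construct_vector(3, 0))  # [1, 1, 1, 1, 0, 0, 0, 0]
print(construct_vector(3, 2))  # [1, 0, 1, 0, 1, 0, 1, 0]
print(binom(4, 2))             # 6
```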
125 | self.row_indices_by_degree = [0] 126 | for degree in range(1, self.r + 1): 127 | self.row_indices_by_degree.append(self.row_indices_by_degree[degree - 1] + _binom(self.m, degree)) 128 | 129 | # Now we want the transpose for the code matrix, to facilitate multiplying vectors on the right by the matrix. 130 | self.M = list(zip(*self.matrix_by_row)) 131 | 132 | self.Generator_Matrix = np.array(self.M) 133 | self.Generator_Matrix = self.Generator_Matrix.reshape(len(self.Generator_Matrix), len(self.Generator_Matrix[0])).transpose() 134 | 135 | def encode(self, word): 136 | """Encode a length-k vector to a length-n vector.""" 137 | assert (len(word) == self.k) 138 | return [_dot_product(word, col) % 2 for col in self.M] 139 | 140 | 141 | 142 | 143 | def decode(self, eword): 144 | """Decode a length-n vector back to its original length-k vector using majority logic.""" 145 | # We want to iterate over each row r of the matrix and determine if a 0 or 1 appears in 146 | # position r of the original word w using majority logic. 147 | 148 | row = self.k - 1 149 | word = [-1] * self.k 150 | 151 | for degree in range(self.r, -1, -1): 152 | # We calculate the entries for the degree. We need the range of rows of the code matrix 153 | # corresponding to degree r. 154 | upper_r = self.row_indices_by_degree[degree] 155 | lower_r = 0 if degree == 0 else self.row_indices_by_degree[degree - 1] + 1 156 | 157 | # Now iterate over these rows to determine the value of word for positions lower_r 158 | # through upper_r inclusive. 159 | for pos in range(lower_r, upper_r + 1): 160 | # We vote for the value of this position based on the vectors in voting_rows. 161 | votes = [_dot_product(eword, vrow) % 2 for vrow in self.voting_rows[pos]] 162 | 163 | # If there is a tie, there is nothing we can do. 164 | if votes.count(0) == votes.count(1): 165 | return None 166 | 167 | # Otherwise, we set the position to the winner. 
word[pos] = 0 if votes.count(0) > votes.count(1) else 1 169 | 170 | # Now we need to modify the word. We want to calculate the product of what we just 171 | # voted on with the rows of the matrix. 172 | # QUESTION: do we JUST do this with what we've calculated (word[lower_r] to word[upper_r]), 173 | # or do we do it with word[lower_r] to word[k-1]? 174 | s = [_dot_product(word[lower_r:upper_r + 1], column[lower_r:upper_r + 1]) % 2 for column in self.M] 175 | eword = _vector_reduce(_vector_add(eword, s), 2) 176 | 177 | # We have now decoded. 178 | return word 179 | 180 | def __repr__(self): 181 | return '<ReedMuller RM(%s,%s), strength=%s>' % (self.r, self.m, self.strength()) 182 | 183 | 184 | def _generate_all_vectors(n): 185 | """Generator to yield all possible length-n vectors in Z_2.""" 186 | v = [0] * n 187 | while True: 188 | yield v 189 | 190 | # Generate the next vector by adding 1 to the end. 191 | # Then keep modding by 2 and moving any excess back up the vector. 192 | v[n - 1] = v[n - 1] + 1 193 | pos = n - 1 194 | while pos >= 0 and v[pos] == 2: 195 | v[pos] = 0 196 | pos = pos - 1 197 | if pos >= 0: 198 | v[pos] += 1 199 | 200 | # Terminate if we reach the all-0 vector again. 201 | if v == [0] * n: 202 | break 203 | 204 | 205 | def _characteristic_vector(n, S): 206 | """Return the characteristic vector of the subset S of an n-set.""" 207 | return [0 if i not in S else 1 for i in range(n)] 208 | 209 | 210 | if __name__ == '__main__': 211 | # Check for correct command-line arguments and if not present, print informative message. 212 | import sys 213 | 214 | if len(sys.argv) != 3: 215 | sys.stderr.write('Usage: %s r m\n' % (sys.argv[0],)) 216 | sys.exit(1) 217 | r, m = map(int, sys.argv[1:]) 218 | if (m <= r): 219 | sys.stderr.write('We require m > r.\n') 220 | sys.exit(2) 221 | 222 | # Create the code.
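To make the generator-matrix encoding concrete before the exhaustive check runs, here is a standalone worked sketch for the small code RM(1,2) (n = 4, k = 3). The `rows` below are written out by hand rather than taken from the class, and `encode` mirrors the dot-product-per-column logic of `ReedMuller.encode`:

```python
# Rows of the RM(1,2) generator matrix: the constant monomial plus x_0, x_1.
rows = [
    [1, 1, 1, 1],  # 1
    [1, 1, 0, 0],  # x_0
    [1, 0, 1, 0],  # x_1
]
M = list(zip(*rows))  # transpose: one tuple per codeword position

def encode(word):
    # Dot the message word against every column, reduced mod 2.
    return [sum(w * c for w, c in zip(word, col)) % 2 for col in M]

print(encode([1, 0, 1]))  # [0, 1, 0, 1]
print(encode([0, 0, 0]))  # [0, 0, 0, 0]
```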
223 | rm = ReedMuller(r, m) 224 | strength = rm.strength() 225 | message_length = rm.message_length() 226 | block_length = rm.block_length() 227 | 228 | # Create a list of all possible errors up to the maximum strength. 229 | error_vectors = [_characteristic_vector(block_length, S) 230 | for numerrors in range(strength + 1) 231 | for S in itertools.combinations(range(block_length), numerrors)] 232 | 233 | # Encode every possible message of message_length. 234 | success = True 235 | for word in _generate_all_vectors(message_length): 236 | codeword = rm.encode(word) 237 | 238 | # Now produce all correctable errors and make sure we still decode to the right word. 239 | for error in error_vectors: 240 | error_codeword = _vector_reduce(_vector_add(codeword, error), 2) 241 | error_word = rm.decode(error_codeword) 242 | if error_word != word: 243 | print('ERROR: encode(%s) => %s, decode(%s+%s=%s) => %s' % (word, codeword, codeword, 244 | error, error_codeword, error_word)) 245 | success = False 246 | 247 | if success: 248 | print('RM(%s,%s): success.' 
% (r, m)) 249 | -------------------------------------------------------------------------------- /test_KO_m1_dumer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | from __future__ import print_function 5 | import argparse 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | from torchvision import datasets, transforms 11 | from torch.autograd import Variable 12 | from torchvision.utils import save_image 13 | from torchvision.utils import make_grid 14 | import torch.utils.data 15 | from data_loader import * 16 | # from IPython import display 17 | 18 | import pickle 19 | import glob 20 | import os 21 | import logging 22 | import time 23 | from datetime import datetime 24 | from ast import literal_eval 25 | import matplotlib 26 | matplotlib.use('AGG') 27 | import matplotlib.pyplot as plt 28 | import matplotlib.animation as animation 29 | from PIL import Image 30 | 31 | import reed_muller_modules 32 | from reed_muller_modules.logging_utils import * 33 | 34 | from opt_einsum import contract # This is for faster torch.einsum 35 | from reed_muller_modules.reedmuller_codebook import * 36 | from reed_muller_modules.hadamard import * 37 | from reed_muller_modules.comm_utils import * 38 | from reed_muller_modules.logging_utils import * 39 | from reed_muller_modules.all_functions import * 40 | import reed_muller_modules.reedmuller_codebook as reedmuller_codebook 41 | 42 | import pandas as pd 43 | import numpy as np 44 | from scipy.stats import norm 45 | from tqdm import tqdm 46 | from itertools import combinations 47 | 48 | 49 | parser = argparse.ArgumentParser(description='(m,1) dumer') 50 | 51 | parser.add_argument('--m', type=int, default=8, help='reed muller code parameter m') 52 | 53 | parser.add_argument('--batch_size', type=int, default=20000, help='size of the batches') 54 | parser.add_argument('--hidden_size', 
type=int, default=64, help='neural network size') 55 | 56 | parser.add_argument('--full_iterations', type=int, default=10000, help='full iterations') 57 | parser.add_argument('--enc_train_iters', type=int, default=50, help='encoder iterations') 58 | parser.add_argument('--dec_train_iters', type=int, default=500, help='decoder iterations') 59 | 60 | parser.add_argument('--enc_train_snr', type=float, default=-4., help='snr at which enc is trained') 61 | parser.add_argument('--dec_train_snr', type=float, default=-7., help='snr at which dec is trained') 62 | 63 | 64 | parser.add_argument('--power_constraint_type', type=str, default='hard_power_block', help='type of power constraint') 65 | parser.add_argument('--loss_type', type=str, default='BCE', choices=['MSE', 'BCE'], help='loss function') 66 | parser.add_argument('--model_iters', type=int, default=0, help='model Iters') 67 | 68 | parser.add_argument('--gpu', type=int, default=0, help='gpus used for training - e.g. 0,1,3') 69 | 70 | args = parser.parse_args() 71 | 72 | device = torch.device("cuda:{0}".format(args.gpu)) 73 | kwargs = {'num_workers': 4, 'pin_memory': False} 74 | 75 | def repetition_code_matrices(device, m=8): 76 | 77 | M_dict = {} 78 | 79 | for i in range(1, m): 80 | M_dict[i] = torch.ones(1, 2**i).to(device) 81 | 82 | return M_dict 83 | 84 | repetition_M_dict = repetition_code_matrices(device, args.m) 85 | 86 | print("Matrices required for repetition code are defined!") 87 | 88 | ###### 89 | ## Functions 90 | ###### 91 | 92 | def snr_db2sigma(train_snr): 93 | return 10**(-train_snr*1.0/20) 94 | 95 | 96 | def log_sum_exp(LLR_vector): 97 | 98 | sum_vector = LLR_vector.sum(dim=1, keepdim=True) 99 | sum_concat = torch.cat([sum_vector, torch.zeros_like(sum_vector)], dim=1) 100 | 101 | return torch.logsumexp(sum_concat, dim=1)- torch.logsumexp(LLR_vector, dim=1) 102 | 103 | 104 | def errors_ber(y_true, y_pred): 105 | y_true = y_true.view(y_true.shape[0], -1, 1) 106 | y_pred = y_pred.view(y_pred.shape[0], -1, 1) 107 |
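For reference, the dB-to-sigma mapping used throughout this script (`snr_db2sigma` above) is easy to sanity-check in isolation; this standalone sketch re-types it in pure Python:

```python
def snr_db2sigma(train_snr):
    # sigma = 10^(-snr/20): 0 dB gives unit noise std, +20 dB gives 0.1.
    return 10 ** (-train_snr * 1.0 / 20)

print(snr_db2sigma(0.0))    # 1.0
print(snr_db2sigma(20.0))   # 0.1
print(snr_db2sigma(-20.0))  # 10.0
```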
108 | myOtherTensor = torch.ne(torch.round(y_true), torch.round(y_pred)).float() 109 | res = sum(sum(myOtherTensor))/(myOtherTensor.shape[0]*myOtherTensor.shape[1]) 110 | return res 111 | 112 | 113 | def errors_bler(y_true, y_pred): 114 | y_true = y_true.view(y_true.shape[0], -1, 1) 115 | y_pred = y_pred.view(y_pred.shape[0], -1, 1) 116 | 117 | decoded_bits = torch.round(y_pred).cpu() 118 | X_test = torch.round(y_true).cpu() 119 | tp0 = (abs(decoded_bits-X_test)).view([X_test.shape[0],X_test.shape[1]]) 120 | tp0 = tp0.detach().cpu().numpy() 121 | bler_err_rate = sum(np.sum(tp0,axis=1)>0)*1.0/(X_test.shape[0]) 122 | return bler_err_rate 123 | 124 | 125 | 126 | 127 | 128 | class g_identity(nn.Module): 129 | def __init__(self): 130 | super(g_identity, self).__init__() 131 | self.fc = nn.Linear(1, 1, bias=False) 132 | 133 | def forward(self, y): 134 | 135 | return y 136 | 137 | class g_vector(nn.Module): 138 | def __init__(self): 139 | super(g_vector, self).__init__() 140 | self.fc = nn.Linear(16, 1, bias=True) 141 | 142 | def forward(self, y): 143 | 144 | return self.fc(y) 145 | 146 | 147 | 148 | class g_Full(nn.Module): 149 | def __init__(self, input_size, hidden_size, output_size): 150 | super(g_Full, self).__init__() 151 | 152 | self.input_size = input_size 153 | 154 | self.half_input_size = int(input_size/2) 155 | 156 | self.hidden_size = hidden_size 157 | self.output_size = output_size 158 | 159 | self.fc1 = nn.Linear(self.input_size, self.hidden_size, bias=True) 160 | self.fc2 = nn.Linear(self.hidden_size, self.hidden_size, bias=True) 161 | self.fc3 = nn.Linear(self.hidden_size, self.hidden_size, bias=True) 162 | self.fc4 = nn.Linear(self.hidden_size, self.output_size, bias=True) 163 | 164 | # self.skip = nn.Linear(3*self.half_input_size, self.hidden_size, bias=False) 165 | 166 | def forward(self, y): 167 | x = F.selu(self.fc1(y)) 168 | x = F.selu(self.fc2(x)) 169 | 170 | x = F.selu(self.fc3(x)) 171 | x = self.fc4(x)+y[:, :self.half_input_size]*y[:,
self.half_input_size:] 172 | return x 173 | 174 | class f_Full(nn.Module): 175 | def __init__(self, input_size, hidden_size, output_size): 176 | super(f_Full, self).__init__() 177 | self.input_size = input_size 178 | self.hidden_size = hidden_size 179 | self.output_size = output_size 180 | 181 | self.fc1 = nn.Linear(self.input_size, self.hidden_size, bias=True) 182 | self.fc2 = nn.Linear(self.hidden_size, self.hidden_size, bias=True) 183 | self.fc3 = nn.Linear(self.hidden_size, self.hidden_size, bias=True) 184 | self.fc4 = nn.Linear(self.hidden_size, self.output_size, bias=True) 185 | 186 | def forward(self, y): 187 | x = F.selu(self.fc1(y)) 188 | x = F.selu(self.fc2(x)) 189 | 190 | x = F.selu(self.fc3(x)) 191 | x = self.fc4(x) 192 | return x 193 | 194 | 195 | def power_constraint(codewords, gnet_top, power_constraint_type, training_mode): 196 | 197 | 198 | if power_constraint_type in ['soft_power_block','soft_power_bit']: 199 | 200 | this_mean = codewords.mean(dim=0) if power_constraint_type == 'soft_power_bit' else codewords.mean() 201 | this_std = codewords.std(dim=0) if power_constraint_type == 'soft_power_bit' else codewords.std() 202 | 203 | if training_mode == 'train': # Training 204 | power_constrained_codewords = (codewords - this_mean)*1.0 / this_std 205 | 206 | gnet_top.update_normstats_for_test(this_mean, this_std) 207 | 208 | elif training_mode == 'test': # For inference 209 | power_constrained_codewords = (codewords - gnet_top.mean_scalar)*1.0/gnet_top.std_scalar 210 | 211 | # else: # When updating the stat parameters of g2net. 
Just don't do anything 212 | # power_constrained_codewords = _ 213 | 214 | return power_constrained_codewords 215 | 216 | 217 | elif power_constraint_type == 'hard_power_block': 218 | 219 | return F.normalize(codewords, p=2, dim=1)*np.sqrt(2**args.m) 220 | 221 | 222 | else: # 'hard_power_bit' 223 | 224 | return codewords/codewords.abs() 225 | 226 | # Plotkin stuff 227 | def encoder_Plotkin(msg_bits): 228 | 229 | #msg_bits is of shape (batch, m+1) 230 | 231 | u_level0 = msg_bits[:, 0:1] 232 | v_level0 = msg_bits[:, 1:2] 233 | 234 | for i in range(2, args.m+1): 235 | 236 | u_level0 = torch.cat([ u_level0, u_level0 * v_level0], dim=1) 237 | v_level0 = msg_bits[:, i:i+1].mm(repetition_M_dict[i-1]) 238 | 239 | u_levelm = torch.cat([u_level0, u_level0 * v_level0], dim=1) 240 | 241 | return u_levelm 242 | 243 | 244 | 245 | def encoder_full(msg_bits, gnet_dict, power_constraint_type='hard_power_block', training_mode='train'): #g_avector, g_bvector, 246 | 247 | u_level0 = msg_bits[:, 0:1] 248 | v_level0 = msg_bits[:, 1:2] 249 | 250 | for i in range(2, args.m+1): 251 | 252 | u_level0 = torch.cat([ u_level0, gnet_dict[i-1](torch.cat([u_level0, v_level0], dim=1)) ], dim=1) 253 | v_level0 = msg_bits[:, i:i+1].mm(repetition_M_dict[i-1]) 254 | 255 | u_levelm = torch.cat([u_level0, gnet_dict[args.m](torch.cat([u_level0, v_level0], dim=1))], dim=1) 256 | 257 | return power_constraint(u_levelm, gnet_dict[args.m], power_constraint_type, training_mode) 258 | 259 | 260 | 261 | def awgn_channel(codewords, snr): 262 | noise_sigma = snr_db2sigma(snr) 263 | standard_Gaussian = torch.randn_like(codewords) 264 | corrupted_codewords = codewords+noise_sigma * standard_Gaussian 265 | return corrupted_codewords 266 | 267 | def decoder_dumer(corrupted_codewords, snr): 268 | 269 | noise_sigma = snr_db2sigma(snr) 270 | 271 | llrs = (2/noise_sigma**2)*corrupted_codewords 272 | Lu = llrs 273 | 274 | decoded_bits = torch.zeros(corrupted_codewords.shape[0], args.m+1).to(device) 275 | 276 | for i in 
range(args.m-1, -1, -1): 277 | 278 | Lv = log_sum_exp(torch.cat([Lu[:, :2**i].unsqueeze(2), Lu[:, 2**i:].unsqueeze(2)], dim=2).permute(0, 2, 1)).sum(dim=1, keepdim=True) 279 | 280 | v_hat = torch.sign(Lv) 281 | 282 | decoded_bits[:, i+1] = v_hat.squeeze(1) 283 | 284 | Lu = Lu[:, :2**i] + v_hat * Lu[:, 2**i:] 285 | 286 | 287 | u_1_hat = torch.sign(Lu) 288 | decoded_bits[:, 0] = u_1_hat.squeeze(1) 289 | 290 | return decoded_bits 291 | 292 | 293 | def decoder_dumer_soft(corrupted_codewords, snr): 294 | 295 | noise_sigma = snr_db2sigma(snr) 296 | 297 | llrs = (2/noise_sigma**2)*corrupted_codewords 298 | Lu = llrs 299 | 300 | decoded_bits = torch.zeros(corrupted_codewords.shape[0], args.m+1).to(device) 301 | 302 | for i in range(args.m-1, -1, -1): 303 | 304 | Lv = log_sum_exp(torch.cat([Lu[:, :2**i].unsqueeze(2), Lu[:, 2**i:].unsqueeze(2)], dim=2).permute(0, 2, 1)).sum(dim=1, keepdim=True) 305 | 306 | v_hat = torch.tanh(Lv/2) 307 | 308 | decoded_bits[:, i+1] = v_hat.squeeze(1) 309 | 310 | Lu = Lu[:, :2**i] + v_hat * Lu[:, 2**i:] 311 | 312 | 313 | u_1_hat = torch.tanh(Lu/2) 314 | decoded_bits[:, 0] = u_1_hat.squeeze(1) 315 | 316 | return decoded_bits 317 | 318 | 319 | def decoder_nn_full(corrupted_codewords, fnet_dict): 320 | 321 | Lu = corrupted_codewords 322 | 323 | decoded_llrs = torch.zeros(corrupted_codewords.shape[0], args.m+1).to(device) 324 | 325 | for i in range(args.m-1, -1 , -1): 326 | 327 | Lv = fnet_dict[2*(args.m-i)-1](Lu)+log_sum_exp(torch.cat([Lu[:, :2**i].unsqueeze(2), Lu[:, 2**i:].unsqueeze(2)], dim=2).permute(0, 2, 1)).sum(dim=1, keepdim=True) 328 | 329 | v_hat = torch.tanh(Lv/2) 330 | 331 | decoded_llrs[:, i+1] = v_hat.squeeze(1) 332 | 333 | Lu = fnet_dict[2*(args.m-i)](torch.cat([Lu[:, :2**i].unsqueeze(2), Lu[:, 2**i:].unsqueeze(2), v_hat.unsqueeze(1).repeat(1, 2**i, 1)], dim=2)).squeeze(2)+Lu[:, :2**i] + v_hat * Lu[:, 2**i:] 334 | 335 | u_1_hat = torch.tanh(Lu/2) 336 | 337 | decoded_llrs[:, 0] = u_1_hat.squeeze(1) 338 | 339 | 340 | return 
decoded_llrs 341 | 342 | 343 | def get_msg_bits_batch(data_generator): 344 | msg_bits_batch = next(data_generator) 345 | return msg_bits_batch 346 | 347 | def moving_average(a, n=3) : 348 | ret = np.cumsum(a, dtype=float) 349 | ret[n:] = ret[n:] - ret[:-n] 350 | return ret[n - 1:] / n 351 | 352 | # msg_bits = 2 * (torch.rand(args.full_iterations * args.batch_size, args.m+1) < 0.5).float() - 1 353 | # Data_Generator = torch.utils.data.DataLoader(msg_bits, batch_size=args.batch_size , shuffle=True, **kwargs) 354 | 355 | print("Data loading stuff is completed! \n") 356 | 357 | gnet_dict = {} 358 | 359 | for i in range(1, args.m+1): 360 | gnet_dict[i] = g_Full(2*2**(i-1), args.hidden_size, 2**(i-1)) 361 | 362 | 363 | fnet_dict = {} 364 | 365 | for i in range(1, args.m+1): 366 | fnet_dict[2*i-1] = f_Full(2**(args.m-i+1), args.hidden_size, 1) 367 | fnet_dict[2*i] = f_Full(1+ 1+ 1, args.hidden_size, 1) 368 | 369 | 370 | ######### 371 | ### Loading the models 372 | ######### 373 | 374 | results_load_path = './Results/RM({0},1)/NN_EncFull_Skip+Dec_Dumer/Enc_snr_{1}_Dec_snr{2}/Batch_{3}'\ 375 | .format(args.m, args.enc_train_snr,args.dec_train_snr, args.batch_size) 376 | 377 | checkpoint1 = torch.load(results_load_path +'/Models/Encoder_NN_{0}.pt'.format(args.model_iters), map_location=lambda storage, loc: storage) 378 | print(checkpoint1) 379 | for i in range(1, args.m+1): 380 | gnet_dict[i].load_state_dict(checkpoint1['g{0}'.format(i)]) 381 | 382 | checkpoint2 = torch.load(results_load_path +'/Models/Decoder_NN_{0}.pt'.format(args.model_iters), map_location=lambda storage, loc: storage) 383 | 384 | for i in range(1,args.m+1): 385 | fnet_dict[2*i-1].load_state_dict(checkpoint2['f{0}'.format(2*i-1)]) 386 | fnet_dict[2*i].load_state_dict(checkpoint2['f{0}'.format(2*i)]) 387 | 388 | # Now load them onto devices 389 | for i in range(1, args.m+1): 390 | gnet_dict[i].to(device) 391 | 392 | 393 | for i in range(1, args.m+1): 394 | fnet_dict[2*i-1].to(device) 395 | 
fnet_dict[2*i].to(device) 396 | print("Models are loaded!") 397 | 398 | 399 | 400 | ###### 401 | ## Pairwise distances 402 | ###### 403 | 404 | def bin_array(num, m): 405 | """Convert a positive integer num into an m-bit bit vector""" 406 | return np.array(list(np.binary_repr(num).zfill(m))).astype(np.float32).reshape(-1) 407 | 408 | all_msg_bits = [] 409 | 410 | for i in range(2**(args.m+1)): 411 | all_msg_bits.append(bin_array(i,args.m+1)*2-1) 412 | 413 | 414 | all_msg_bits = torch.tensor(np.array(all_msg_bits)).to(device) 415 | 416 | print(all_msg_bits) 417 | 418 | 419 | def pairwise_distances(codebook): 420 | dists = [] 421 | for row1, row2 in combinations(codebook, 2): 422 | distance = (row1-row2).pow(2).sum() 423 | dists.append(np.sqrt(distance.item())) 424 | return dists, np.min(dists) 425 | 426 | 427 | codebook_reusable_NN = encoder_full(all_msg_bits, gnet_dict, args.power_constraint_type, training_mode='test') # Just testing 428 | pairwise_dist_neural, d_min_reusable_NN = pairwise_distances(codebook_reusable_NN.data.cpu()) 429 | 430 | # codebook_neural_PlusOne = codebook_reusable_NN[PlusOneIdx] 431 | # codebook_neural_MinusOne = codebook_reusable_NN[MinusOneIdx] 432 | 433 | 434 | codebook_quantized_reuse_NN = codebook_reusable_NN.sign() 435 | _, d_min_quantized = pairwise_distances(codebook_quantized_reuse_NN.data.cpu()) 436 | 437 | 438 | codebook_plotkin = encoder_Plotkin(all_msg_bits) 439 | pairwise_dist_plotkin, d_min_plotkin = pairwise_distances(codebook_plotkin.data.cpu()) 440 | 441 | print("Neural Codebook with d_min: {0: .4f} is \n {1}".format(d_min_reusable_NN, codebook_reusable_NN.data.cpu().numpy())) 442 | 443 | print("Quantized Neural Codebook with d_min: {0: .4f} is \n {1}".format(d_min_quantized, codebook_quantized_reuse_NN.data.cpu().numpy())) 444 | 445 | print("Plotkin Codebook with d_min: {0: .4f} is \n {1}".format(d_min_plotkin, codebook_plotkin)) 446 | 447 | 448 | Gaussian_codebook = F.normalize(torch.randn(2**(args.m+1), 2**args.m), p=2, 
dim=1)*np.sqrt(2**args.m) 449 | pairwise_dist_Gaussian, d_min_Gaussian = pairwise_distances(Gaussian_codebook) 450 | print(Gaussian_codebook[1:3,:].pow(2).sum(1)) 451 | 452 | 453 | 454 | ### 455 | 456 | 457 | all_msg_bits_large = all_msg_bits.t().unsqueeze(0).repeat(1000, 1, 1).to(device) 458 | 459 | 460 | def encoder_codebook(msg_bits, codebook ): 461 | msg_bits_large = msg_bits.unsqueeze(2).repeat(1, 1, 2**(args.m+1)).to(device) 462 | diff = (msg_bits_large - all_msg_bits_large).pow(2).sum(dim=1) 463 | idx = diff.argmin(dim=1, keepdim=False) 464 | return codebook[idx,:] 465 | 466 | ########## 467 | ### Histogram stuff 468 | ######### 469 | 470 | total_pairwise_dist = len(pairwise_dist_neural) 471 | 472 | print(total_pairwise_dist) 473 | # print(total_pairwise_dist) 474 | 475 | # if m == 6: 476 | # range_histogram = (8,16) 477 | # elif m ==8: 478 | # range_histogram = (10, 25) 479 | 480 | min_stuff = np.min([np.min(pairwise_dist_neural), np.min(pairwise_dist_Gaussian)]) 481 | max_stuff = np.max([np.max(pairwise_dist_neural), np.max(pairwise_dist_Gaussian)]) 482 | 483 | bins = np.linspace(min_stuff, max_stuff, 1000) 484 | # bins = np.arange(np.floor(np.min(pairwise_dist_neural)),np.ceil(np.max(pairwise_dist_neural))) 485 | 486 | n_neural, bins_neural = np.histogram(pairwise_dist_neural, bins=bins, density=True)#, density = False, bins=100, label='Neural Code: d_min={0:.2f}'.format(d_min_reusable_NN)) 487 | 488 | # print("Neural", n_neural, "\n Bins:",bins_neural) 489 | print(n_neural.sum()) 490 | # print(np.all(np.diff(bins_neural)==1)) 491 | 492 | n_RM, bins_RM = np.histogram(pairwise_dist_plotkin, bins=bins, density=False) 493 | 494 | n_Gaussian, bins_Gaussian = np.histogram(pairwise_dist_Gaussian, bins= bins,\ 495 | density=True) 496 | # print("Gaussian", n_Gaussian, "\n Bins:",bins_Gaussian) 497 | print(n_Gaussian.sum()) 498 | # n_neural = n_neural / n_neural.sum() 499 | n_RM = (1/total_pairwise_dist)*n_RM*510/511 500 | # n_Gaussian = n_Gaussian / 
n_Gaussian.sum() 501 | 502 | 503 | # print(n_RM) 504 | 505 | from scipy.signal import savgol_filter 506 | n_Gaussian = savgol_filter(n_Gaussian, 101, 5) 507 | n_neural = savgol_filter(n_neural, 101, 5) 508 | # n_RM = savgol_filter(n_RM, 101, 5) 509 | 510 | fig, ax = plt.subplots(figsize= (10, 7)) 511 | 512 | 513 | 514 | 515 | # ax.annotate('Min dist. of RM={0:.2f}'.format(d_min_plotkin), xy=(d_min_plotkin, 0.))#, xytext=(25.,0.05), arrowprops=dict(facecolor='black', shrink=0.05)) 516 | plt.plot(bins_RM[:-1], n_RM, label='RM: Min dist = {0:.2f}'.format(d_min_plotkin), linewidth=2.0) 517 | plt.plot(bins_neural[:-1], n_neural, label='Neural RM: Min dist = {0:.2f}'.format(d_min_reusable_NN), linewidth=2.0) 518 | plt.plot(bins_Gaussian[:-1], n_Gaussian, label='Random Gaussian: Min dist = {0:.2f}'.format(d_min_Gaussian), linewidth=2.0) 519 | plt.xlabel("Pairwise distances", fontsize=16) 520 | plt.ylabel("Probability density/mass", fontsize=16) 521 | plt.legend(loc='upper right', prop={'size': 15}) 522 | plt.title("Histogram of pairwise distances", fontsize=16) 523 | plt.savefig(results_load_path+'/Histogram_.pdf') 524 | 525 | 526 | 527 | 528 | 529 | ######## 530 | ### Testing stuff 531 | ######## 532 | 533 | batch_inflated_neural_codebook = codebook_reusable_NN.t().unsqueeze(0).repeat(1000, 1, 1) 534 | 535 | batch_inflated_Plotkin_codebook = codebook_plotkin.t().unsqueeze(0).repeat(1000, 1, 1) 536 | batch_inflated_Gaussian_codebook = Gaussian_codebook.t().unsqueeze(0).repeat(1000, 1, 1).to(device) 537 | 538 | print(batch_inflated_neural_codebook.shape, batch_inflated_Plotkin_codebook.shape, batch_inflated_Gaussian_codebook.shape) 539 | 540 | def decoder_MAP(corrupted_codewords, batch_inflated_codebook): 541 | 542 | corrupted_codewords_inflated = corrupted_codewords.unsqueeze(2).repeat(1, 1, 2**(args.m+1)) #Both are of shape (batch, 256, 512) 543 | 544 | diff = (corrupted_codewords_inflated - batch_inflated_codebook).pow(2).sum(dim=1) 545 | 546 | idx = diff.argmin(dim=1, 
keepdim=False) #(batch) 547 | 548 | decoded_bits = all_msg_bits[idx, :] 549 | 550 | return decoded_bits 551 | 552 | def test_MAP(msg_bits, snr): 553 | 554 | # codewords_old_NN = encoder_nn_old(msg_bits, g1net, g2net) 555 | 556 | codewords_reuse_NN = encoder_full(msg_bits, gnet_dict, args.power_constraint_type, training_mode='test') 557 | codewords_Plotkin = encoder_Plotkin(msg_bits) 558 | 559 | noise_sigma = snr_db2sigma(snr) 560 | standard_Gaussian = torch.randn_like(codewords_reuse_NN) 561 | 562 | corrupted_codewords_reuse_NN = codewords_reuse_NN + noise_sigma * standard_Gaussian 563 | corrupted_codewords_Plotkin = codewords_Plotkin + noise_sigma * standard_Gaussian 564 | 565 | dumer_decoded_bits = decoder_MAP(corrupted_codewords_Plotkin, batch_inflated_Plotkin_codebook) 566 | 567 | nn_decoded_bits = decoder_MAP(corrupted_codewords_reuse_NN, batch_inflated_neural_codebook) 568 | 569 | 570 | ber_dumer = errors_ber(msg_bits, dumer_decoded_bits).item() 571 | 572 | 573 | ber_nn = errors_ber(msg_bits, nn_decoded_bits).item() 574 | 575 | return ber_dumer, ber_nn 576 | 577 | 578 | def test_all(msg_bits, snr): 579 | 580 | # codewords_old_NN = encoder_nn_old(msg_bits, g1net, g2net) 581 | 582 | codewords_reuse_NN = encoder_full(msg_bits, gnet_dict, args.power_constraint_type, training_mode='test') 583 | codewords_Plotkin = encoder_Plotkin(msg_bits) 584 | 585 | noise_sigma = snr_db2sigma(snr) 586 | standard_Gaussian = torch.randn_like(codewords_reuse_NN) 587 | 588 | corrupted_codewords_reuse_NN = codewords_reuse_NN + noise_sigma * standard_Gaussian 589 | corrupted_codewords_Plotkin = codewords_Plotkin + noise_sigma * standard_Gaussian 590 | 591 | dumer_decoded_bits = decoder_soft_FHT(corrupted_codewords_Plotkin, snr, args.m)[:, tree_bits_order_from_standard].sign() #_dumer 592 | nn_decoded_bits = decoder_soft_FHT(corrupted_codewords_reuse_NN, snr, args.m)[:, tree_bits_order_from_standard].sign() 593 | 594 | 595 | ber_dumer = errors_ber(msg_bits, dumer_decoded_bits).item() 596 |
ber_nn = errors_ber(msg_bits, nn_decoded_bits).item() 597 | 598 | return ber_dumer, ber_nn 599 | 600 | 601 | def test_MAP_and_all(msg_bits, snr): 602 | 603 | ## Common stuff 604 | 605 | noise_sigma = snr_db2sigma(snr) 606 | 607 | 608 | codewords_Plotkin = encoder_Plotkin(msg_bits) 609 | codewords_reuse_NN = encoder_full(msg_bits, gnet_dict, args.power_constraint_type, training_mode='test') 610 | 611 | codewords_Gaussian = encoder_codebook(msg_bits, Gaussian_codebook).to(device) 612 | 613 | standard_Gaussian = torch.randn_like(codewords_reuse_NN) 614 | 615 | corrupted_codewords_Plotkin = codewords_Plotkin + noise_sigma * standard_Gaussian 616 | corrupted_codewords_reuse_NN = codewords_reuse_NN + noise_sigma * standard_Gaussian 617 | corrupted_codewords_Gaussian = codewords_Gaussian+noise_sigma * standard_Gaussian 618 | 619 | 620 | 621 | ### MAP stuff 622 | dumer_decoded_bits = decoder_MAP(corrupted_codewords_Plotkin, batch_inflated_Plotkin_codebook) 623 | nn_decoded_bits = decoder_MAP(corrupted_codewords_reuse_NN, batch_inflated_neural_codebook) 624 | 625 | Gaussian_decoded_bits = decoder_MAP(corrupted_codewords_Gaussian, batch_inflated_Gaussian_codebook) 626 | 627 | ber_dumer_map = errors_ber(msg_bits, dumer_decoded_bits).item() 628 | ber_nn_map = errors_ber(msg_bits, nn_decoded_bits).item() 629 | 630 | ber_Gaussian_map = errors_ber(msg_bits, Gaussian_decoded_bits).item() 631 | 632 | 633 | bler_msg_dumer_map = errors_bler(msg_bits, dumer_decoded_bits).item() 634 | bler_msg_nn_map = errors_bler(msg_bits, nn_decoded_bits).item() 635 | 636 | ### Existing decoding algorithms' stuff 637 | 638 | # dumer_decoded_bits = decoder_soft_FHT(corrupted_codewords_Plotkin, snr, m)[:, tree_bits_order_from_standard].sign() #_dumer 639 | # nn_decoded_bits = first_principle_soft_MAP(corrupted_codewords_reuse_NN, \ 640 | # codebook_neural_PlusOne, codebook_neural_MinusOne).sign() 641 | 642 | 643 | dumer_decoded_bits = decoder_dumer(corrupted_codewords_Plotkin, snr) 644 | 645 | 
    nn_decoded_bits = decoder_nn_full(corrupted_codewords_reuse_NN, fnet_dict).sign()

    ber_dumer = errors_ber(msg_bits, dumer_decoded_bits).item()
    ber_nn = errors_ber(msg_bits, nn_decoded_bits).item()

    bler_dumer = errors_bler(msg_bits, dumer_decoded_bits).item()
    bler_nn = errors_bler(msg_bits, nn_decoded_bits).item()

    # bler_msg_dumer_map = errors_bler(msg_bits, dumer_decoded_bits).item()
    # bler_msg_nn_map = errors_bler(msg_bits, nn_decoded_bits).item()
    bler_msg_gaussian_map = errors_bler(msg_bits, Gaussian_decoded_bits).item()

    return ber_dumer_map, ber_nn_map, ber_dumer, ber_nn, bler_dumer, bler_nn, ber_Gaussian_map, bler_msg_dumer_map, bler_msg_nn_map, bler_msg_gaussian_map


#### Final testing stuff

snr_range = np.linspace(-6, 12, 19) if args.m <= 4 else np.linspace(-10., 0., 11)
test_size = 100000

bers_dumer_test = []
bers_nn_test = []

blers_dumer_test = []
blers_nn_test = []

bers_dumer_test_map = []
bers_nn_test_map = []

blers_msg_dumer_map_test = []
blers_msg_nn_map_test = []

bers_Gaussian_map_test = []
blers_msg_gaussian_map_test = []


bersu1_nn_test = []
bersv1_nn_test = []
bersv2_nn_test = []


os.makedirs(results_load_path, exist_ok=True)
results_file = os.path.join(results_load_path + '/ber_results.%s')
results = ResultsLog(results_file % 'csv', results_file % 'html')

#####
# Test Data
#####
Test_msg_bits = 2 * (torch.rand(test_size, args.m + 1) < 0.5).float() - 1
Test_Data_Generator = torch.utils.data.DataLoader(Test_msg_bits, batch_size=1000, shuffle=False, **kwargs)

num_test_batches = len(Test_Data_Generator)

for test_snr in tqdm(snr_range):

    bers_dumer, bers_nn = 0., 0.
    bers_dumer_map, bers_nn_map = 0., 0.

    blers_dumer, blers_nn = 0., 0.
    bers_Gaussian_map, blers_msg_dumer_map, blers_msg_nn_map, blers_msg_gaussian_map = 0., 0., 0., 0.

    for (k, msg_bits) in enumerate(Test_Data_Generator):

        msg_bits = msg_bits.to(device)

        ber_dumer_map, ber_nn_map, ber_dumer, ber_nn, bler_dumer, bler_nn, ber_Gaussian_map, bler_msg_dumer_map, bler_msg_nn_map, bler_msg_gaussian_map = test_MAP_and_all(msg_bits, snr=test_snr)

        bers_dumer_map += ber_dumer_map
        bers_nn_map += ber_nn_map
        bers_dumer += ber_dumer
        bers_nn += ber_nn
        blers_dumer += bler_dumer
        blers_nn += bler_nn

        bers_Gaussian_map += ber_Gaussian_map
        blers_msg_dumer_map += bler_msg_dumer_map
        blers_msg_nn_map += bler_msg_nn_map
        blers_msg_gaussian_map += bler_msg_gaussian_map

    bers_dumer_map /= num_test_batches
    bers_nn_map /= num_test_batches
    bers_dumer /= num_test_batches
    bers_nn /= num_test_batches

    blers_dumer /= num_test_batches
    blers_nn /= num_test_batches

    bers_Gaussian_map /= num_test_batches
    blers_msg_dumer_map /= num_test_batches
    blers_msg_nn_map /= num_test_batches
    blers_msg_gaussian_map /= num_test_batches

    bers_dumer_test.append(bers_dumer)
    bers_nn_test.append(bers_nn)

    bers_dumer_test_map.append(bers_dumer_map)
    bers_nn_test_map.append(bers_nn_map)

    blers_dumer_test.append(blers_dumer)
    blers_nn_test.append(blers_nn)

    blers_msg_dumer_map_test.append(blers_msg_dumer_map)
    blers_msg_nn_map_test.append(blers_msg_nn_map)

    bers_Gaussian_map_test.append(bers_Gaussian_map)
    blers_msg_gaussian_map_test.append(blers_msg_gaussian_map)

    results.add(Test_SNR=test_snr, NN_BER=bers_nn, Plotkin_BER=bers_dumer, NN_BLER=blers_nn, Plotkin_BLER=blers_dumer, NN_BER_MAP=bers_nn_map,
                Plotkin_BER_MAP=bers_dumer_map, RandGauss_BER_MAP=bers_Gaussian_map, RandGauss_BLER_MAP=blers_msg_gaussian_map)

    results.save()


### Plotting stuff

## BER
plt.figure(figsize=(12, 8))

ok = 1
plt.semilogy(snr_range[:-ok], bers_dumer_test[:-ok], label="RM + Dumer", marker='o', linewidth=1.5)
plt.semilogy(snr_range[:-ok], bers_nn_test[:-ok], label="Neural RM + Neural Dumer", marker='^', linewidth=1.5)
plt.semilogy(snr_range[:-ok], bers_Gaussian_map_test[:-ok], label="Random Gaussian + MAP", marker='^', linewidth=1.5)
plt.semilogy(snr_range[:-ok], bers_dumer_test_map[:-ok], label="RM + MAP (Inference)", marker='o', linewidth=1.5)
plt.semilogy(snr_range[:-ok], bers_nn_test_map[:-ok], label="Neural RM + MAP (Inference)", marker='^', linewidth=1.5)

plt.grid()
plt.xlabel("SNR (dB)", fontsize=16)
plt.ylabel("Bit Error Rate", fontsize=16)
# plt.title("Trained at Enc_SNR = {0} dB and Dec_SNR = {1} dB".format(enc_train_snr, dec_train_snr))
plt.title("BER plot of Neural RM({0},1) codes: Trained at Enc:{1}dB, Dec:{2}dB".format(args.m, args.enc_train_snr, args.dec_train_snr))
plt.legend(prop={'size': 15})
plt.savefig(results_load_path + "/{0}_BER_at_Test_SNRs.pdf".format(args.m))


### BLER
plt.figure(figsize=(12, 8))

ok = 1
plt.semilogy(snr_range[:-ok], blers_dumer_test[:-ok], label="RM + Dumer", marker='o', linewidth=1.5)
plt.semilogy(snr_range[:-ok], blers_nn_test[:-ok], label="Neural RM + Neural Dumer", marker='^', linewidth=1.5)
# plt.semilogy(snr_range[:-ok], blers_Gaussian_map_test[:-ok], label="Random Gaussian + MAP", marker='^', linewidth=1.5)
plt.semilogy(snr_range[:-ok], blers_msg_dumer_map_test[:-ok], label="RM + MAP (Inference)", marker='o', linewidth=1.5)
plt.semilogy(snr_range[:-ok], blers_msg_nn_map_test[:-ok], label="Neural RM + MAP (Inference)", marker='^',
             linewidth=1.5)

plt.grid()
plt.xlabel("SNR (dB)", fontsize=16)
plt.ylabel("Message bits-Block Error Rate", fontsize=16)
# plt.title("Trained at Enc_SNR = {0} dB and Dec_SNR = {1} dB".format(enc_train_snr, dec_train_snr))
plt.title("BLER plot of Neural RM({0},1) codes: Trained at Enc:{1}dB, Dec:{2}dB".format(args.m, args.enc_train_snr, args.dec_train_snr))
plt.legend(prop={'size': 15})
plt.savefig(results_load_path + "/{0}_BLER_at_Test_SNRs.pdf".format(args.m))

-------------------------------------------------------------------------------- /test_Polar_m6k7.py: --------------------------------------------------------------------------------

#!/usr/bin/env python
# coding: utf-8

from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
from torchvision.utils import make_grid
import torch.utils.data
from data_loader import *
# from IPython import display

import pickle
import glob
import os
import logging
import time
from datetime import datetime
from ast import literal_eval
import matplotlib

matplotlib.use('AGG')
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from PIL import Image

import reed_muller_modules
from reed_muller_modules.logging_utils import *

from opt_einsum import contract  # This is for faster torch.einsum

import pandas as pd
import numpy as np
from scipy.stats import norm
from tqdm import tqdm
from itertools import combinations


parser = argparse.ArgumentParser(description='(m,k) Polar')

parser.add_argument('--m',
                    type=int, default=6, help='number of layers in a polar code m')

parser.add_argument('--batch_size', type=int, default=20000, help='size of the batches')
parser.add_argument('--hidden_size', type=int, default=64, help='neural network size')

parser.add_argument('--full_iterations', type=int, default=20000, help='full iterations')
parser.add_argument('--enc_train_iters', type=int, default=50, help='encoder iterations')
parser.add_argument('--dec_train_iters', type=int, default=500, help='decoder iterations')

parser.add_argument('--enc_train_snr', type=float, default=-0.5, help='SNR (dB) at which the encoder is trained')
parser.add_argument('--dec_train_snr', type=float, default=-2.5, help='SNR (dB) at which the decoder is trained')

parser.add_argument('--power_constraint_type', type=str, default='hard_power_block', help='type of power constraint')
parser.add_argument('--loss_type', type=str, default='BCE', choices=['MSE', 'BCE'], help='loss function')

parser.add_argument('--gpu', type=int, default=0, help='GPU used for training, e.g. 0')

args = parser.parse_args()


device = torch.device("cuda:{0}".format(args.gpu))
# device = torch.device("cpu")
kwargs = {'num_workers': 4, 'pin_memory': False}


def repetition_code_matrices(device, m=8):
    M_dict = {}

    for i in range(1, m):
        M_dict[i] = torch.ones(1, 2 ** i).to(device)

    return M_dict


repetition_M_dict = repetition_code_matrices(device, args.m)

print("Matrices required for repetition code are defined!")


######
## Functions
######

def snr_db2sigma(train_snr):
    return 10 ** (-train_snr * 1.0 / 20)


def log_sum_exp(LLR_vector):
    sum_vector = LLR_vector.sum(dim=1, keepdim=True)
    sum_concat = torch.cat([sum_vector, torch.zeros_like(sum_vector)], dim=1)

    return torch.logsumexp(sum_concat, dim=1) - \
        torch.logsumexp(LLR_vector, dim=1)


def errors_ber(y_true, y_pred):
    y_true = y_true.view(y_true.shape[0], -1, 1)
    y_pred = y_pred.view(y_pred.shape[0], -1, 1)

    myOtherTensor = torch.ne(torch.round(y_true), torch.round(y_pred)).float()
    res = sum(sum(myOtherTensor)) / (myOtherTensor.shape[0] * myOtherTensor.shape[1])
    return res


def errors_bler(y_true, y_pred):
    y_true = y_true.view(y_true.shape[0], -1, 1)
    y_pred = y_pred.view(y_pred.shape[0], -1, 1)

    decoded_bits = torch.round(y_pred).cpu()
    X_test = torch.round(y_true).cpu()
    tp0 = (abs(decoded_bits - X_test)).view([X_test.shape[0], X_test.shape[1]])
    tp0 = tp0.detach().cpu().numpy()
    bler_err_rate = sum(np.sum(tp0, axis=1) > 0) * 1.0 / (X_test.shape[0])
    return bler_err_rate


class g_identity(nn.Module):
    def __init__(self):
        super(g_identity, self).__init__()
        self.fc = nn.Linear(1, 1, bias=False)

    def forward(self, y):
        return y


class g_vector(nn.Module):
    def __init__(self):
        super(g_vector, self).__init__()
        self.fc = nn.Linear(16, 1, bias=True)

    def forward(self, y):
        return self.fc(y)


class g_Full(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(g_Full, self).__init__()

        self.input_size = input_size
        self.half_input_size = int(input_size / 2)

        self.hidden_size = hidden_size
        self.output_size = output_size

        self.fc1 = nn.Linear(self.input_size, self.hidden_size, bias=True)
        self.fc2 = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.fc3 = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.fc4 = nn.Linear(self.hidden_size, self.output_size, bias=True)

        self.skip = nn.Linear(3 * self.half_input_size, self.hidden_size, bias=False)

    def forward(self, y):
        x = F.selu(self.fc1(y))
        x = F.selu(self.fc2(x)) + self.skip(
            torch.cat([y, y[:, :self.half_input_size] * y[:, self.half_input_size:]], dim=1))

        x = F.selu(self.fc3(x))
        x = self.fc4(x)
        return x


class f_Full(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(f_Full, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        self.fc1 = nn.Linear(self.input_size, self.hidden_size, bias=True)
        self.fc2 = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.fc3 = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.fc4 = nn.Linear(self.hidden_size, self.output_size, bias=True)

    def forward(self, y):
        x = F.selu(self.fc1(y))
        x = F.selu(self.fc2(x))

        x = F.selu(self.fc3(x))
        x = self.fc4(x)
        return x


def power_constraint(codewords, gnet_top, power_constraint_type, training_mode):
    if power_constraint_type in ['soft_power_block', 'soft_power_bit']:

        this_mean = codewords.mean(dim=0) if power_constraint_type == 'soft_power_bit' else codewords.mean()
        this_std = codewords.std(dim=0) if power_constraint_type == 'soft_power_bit' else codewords.std()

        if training_mode == 'train':  # Training
            power_constrained_codewords = (codewords - this_mean) * 1.0 / this_std

            gnet_top.update_normstats_for_test(this_mean, this_std)

        elif training_mode == 'test':  # For inference
            power_constrained_codewords = (codewords - gnet_top.mean_scalar) * 1.0 / gnet_top.std_scalar

        return power_constrained_codewords

    elif power_constraint_type == 'hard_power_block':

        return F.normalize(codewords, p=2, dim=1) * np.sqrt(2 ** args.m)

    else:

        return codewords / codewords.abs()
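The `'hard_power_block'` branch of `power_constraint` rescales each codeword so that its squared norm equals the block length 2**m. The following standalone NumPy sketch illustrates that scaling rule only — the scripts themselves operate on torch tensors via `F.normalize`, and the function name and sample values here are chosen purely for illustration:

```python
import numpy as np

def hard_power_block_np(codewords, m):
    # L2-normalize each row, then rescale so that sum(x_i^2) = 2**m
    # (total power equal to the block length, as in the torch version above).
    norms = np.linalg.norm(codewords, axis=1, keepdims=True)
    return codewords / norms * np.sqrt(2 ** m)

batch = np.random.randn(4, 2 ** 6)          # 4 codewords of length 64
constrained = hard_power_block_np(batch, 6)
print(np.allclose((constrained ** 2).sum(axis=1), 64.0))  # True: every row has power 64
```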
def encoder_Polar_Plotkin(msg_bits):

    u_level1 = torch.cat([msg_bits[:, 6:7], msg_bits[:, 6:7] * msg_bits[:, 5:6]], dim=1)
    v_level1 = torch.cat([msg_bits[:, 4:5], msg_bits[:, 4:5] * msg_bits[:, 3:4]], dim=1)

    for i in range(2, args.m - 1):
        u_level1 = torch.cat([u_level1, u_level1 * v_level1], dim=1)
        v_level1 = msg_bits[:, 4 - i:5 - i].mm(repetition_M_dict[i])

    u_level5 = torch.cat([u_level1, u_level1 * v_level1], dim=1)

    u_level6 = torch.cat([u_level5, u_level5], dim=1)

    return u_level6


def encoder_Polar_full(msg_bits, gnet_dict, power_constraint_type='hard_power_block',
                       training_mode='train'):  # g_avector, g_bvector,

    u_level1 = torch.cat([msg_bits[:, 6:7], gnet_dict[1, 'right'](torch.cat([msg_bits[:, 6:7], msg_bits[:, 5:6]], dim=1))], dim=1)
    v_level1 = torch.cat([msg_bits[:, 4:5], gnet_dict[1, 'left'](torch.cat([msg_bits[:, 4:5], msg_bits[:, 3:4]], dim=1))], dim=1)

    for i in range(2, args.m - 1):
        u_level1 = torch.cat([u_level1, gnet_dict[i](torch.cat([u_level1, v_level1], dim=1))], dim=1)
        v_level1 = msg_bits[:, 4 - i:5 - i].mm(repetition_M_dict[i])

    u_level5 = torch.cat([u_level1, gnet_dict[args.m - 1](torch.cat([u_level1, v_level1], dim=1))], dim=1)

    u_level6 = torch.cat([u_level5, u_level5], dim=1)

    return power_constraint(u_level6, gnet_dict[args.m], power_constraint_type, training_mode)


def awgn_channel(codewords, snr):
    noise_sigma = snr_db2sigma(snr)
    standard_Gaussian = torch.randn_like(codewords)
    corrupted_codewords = codewords + noise_sigma * standard_Gaussian
    return corrupted_codewords


def decoder_Polar_SC(corrupted_codewords, snr):
    noise_sigma = snr_db2sigma(snr)

    llrs = (2 / noise_sigma ** 2) * corrupted_codewords
    Lu = llrs
    Lu = Lu[:, 32:] + Lu[:, :32]

    decoded_bits = torch.zeros(corrupted_codewords.shape[0], args.m + 1).to(device)

    for i in range(args.m - 2, 1, -1):
        Lv = log_sum_exp(torch.cat([Lu[:, :2 ** i].unsqueeze(2), Lu[:, 2 ** i:].unsqueeze(2)], dim=2).permute(0, 2, 1)).sum(dim=1, keepdim=True)
        v_hat = torch.sign(Lv)
        decoded_bits[:, 4 - i] = v_hat.squeeze(1)
        Lu = Lu[:, :2 ** i] + v_hat * Lu[:, 2 ** i:]

    Lu2 = Lu
    Lv1 = log_sum_exp(torch.cat([Lu2[:, 0:2].unsqueeze(2), Lu2[:, 2:4].unsqueeze(2)], dim=2).permute(0, 2, 1))
    L_u3 = log_sum_exp(torch.cat([Lv1[:, 0:1].unsqueeze(2), Lv1[:, 1:2].unsqueeze(2)], dim=2).permute(0, 2, 1))
    u3_hat = torch.sign(L_u3)
    decoded_bits[:, 3] = u3_hat.squeeze(1)

    L_u4 = Lv1[:, 0:1] + u3_hat * Lv1[:, 1:2]
    u4_hat = torch.sign(L_u4)
    decoded_bits[:, 4] = u4_hat.squeeze(1)

    v1_hat = torch.cat([decoded_bits[:, 4:5], decoded_bits[:, 4:5] * decoded_bits[:, 3:4]], dim=1)
    Lu1 = Lu2[:, 0:2] + v1_hat * Lu2[:, 2:4]
    L_u5 = log_sum_exp(torch.cat([Lu1[:, 0:1].unsqueeze(2), Lu1[:, 1:2].unsqueeze(2)], dim=2).permute(0, 2, 1))
    u5_hat = torch.sign(L_u5)
    decoded_bits[:, 5] = u5_hat.squeeze(1)

    L_u6 = Lu1[:, 0:1] + u5_hat * Lu1[:, 1:2]
    u6_hat = torch.sign(L_u6)
    decoded_bits[:, 6] = u6_hat.squeeze(1)

    return decoded_bits


def decoder_Polar_SC_soft(corrupted_codewords, snr):
    noise_sigma = snr_db2sigma(snr)

    llrs = (2 / noise_sigma ** 2) * corrupted_codewords
    Lu = llrs
    Lu = Lu[:, 32:] + Lu[:, :32]

    decoded_bits = torch.zeros(corrupted_codewords.shape[0], args.m + 1).to(device)

    for i in range(args.m - 2, 1, -1):
        Lv = log_sum_exp(torch.cat([Lu[:, :2 ** i].unsqueeze(2), Lu[:, 2 ** i:].unsqueeze(2)], dim=2).permute(0, 2, 1)).sum(dim=1, keepdim=True)
        v_hat = torch.tanh(Lv / 2)
        decoded_bits[:, 4 - i] = v_hat.squeeze(1)
        Lu = Lu[:, :2 ** i] + v_hat * Lu[:, 2 ** i:]
    Lu2 = Lu
    Lv1 = log_sum_exp(torch.cat([Lu2[:, 0:2].unsqueeze(2), Lu2[:, 2:4].unsqueeze(2)], dim=2).permute(0, 2, 1))
    L_u3 = log_sum_exp(torch.cat([Lv1[:, 0:1].unsqueeze(2), Lv1[:, 1:2].unsqueeze(2)], dim=2).permute(0, 2, 1))
    u3_hat = torch.tanh(L_u3 / 2)
    decoded_bits[:, 3] = u3_hat.squeeze(1)

    L_u4 = Lv1[:, 0:1] + u3_hat * Lv1[:, 1:2]
    u4_hat = torch.tanh(L_u4 / 2)
    decoded_bits[:, 4] = u4_hat.squeeze(1)

    v1_hat = torch.cat([decoded_bits[:, 4:5], decoded_bits[:, 4:5] * decoded_bits[:, 3:4]], dim=1)
    Lu1 = Lu2[:, 0:2] + v1_hat * Lu2[:, 2:4]
    L_u5 = log_sum_exp(torch.cat([Lu1[:, 0:1].unsqueeze(2), Lu1[:, 1:2].unsqueeze(2)], dim=2).permute(0, 2, 1))
    u5_hat = torch.tanh(L_u5 / 2)
    decoded_bits[:, 5] = u5_hat.squeeze(1)

    L_u6 = Lu1[:, 0:1] + u5_hat * Lu1[:, 1:2]
    u6_hat = torch.tanh(L_u6 / 2)
    decoded_bits[:, 6] = u6_hat.squeeze(1)

    return decoded_bits


def decoder_Polar_nn_full(corrupted_codewords, fnet_dict):

    Lu = corrupted_codewords
    Lu = Lu[:, 32:] + Lu[:, :32]

    decoded_llrs = torch.zeros(corrupted_codewords.shape[0], args.m + 1).to(device)

    for i in range(args.m - 2, 1, -1):
        Lv = fnet_dict[i + 1, 'left'](Lu)
        decoded_llrs[:, 4 - i] = Lv.squeeze(1)
        v_hat = torch.tanh(Lv / 2)
        Lu = fnet_dict[i + 1, 'right'](torch.cat([Lu[:, :2 ** i].unsqueeze(2), Lu[:, 2 ** i:].unsqueeze(2), v_hat.unsqueeze(1).repeat(1, 2 ** i, 1)], dim=2)).squeeze(2)

    Lu2 = Lu
    Lv1 = fnet_dict[2, 'left'](Lu2)
    L_u3 = fnet_dict[1, 'left', 'left'](Lv1)
    decoded_llrs[:, 3] = L_u3.squeeze(1)
    u3_hat = torch.tanh(0.5 * L_u3)

    L_u4 = fnet_dict[1, 'left', 'right'](torch.cat([Lv1[:, 0:1].unsqueeze(2), Lv1[:, 1:2].unsqueeze(2), u3_hat.unsqueeze(1).repeat(1, 1, 1)], dim=2)).squeeze(2)
    decoded_llrs[:, 4] = L_u4.squeeze(1)
    u4_hat = torch.tanh(0.5 * L_u4)
    v1_hat = torch.cat([u4_hat, gnet_dict[1, 'left'](torch.cat([torch.sign(L_u4), torch.sign(L_u3)], dim=1))], dim=1)
    Lu1 = fnet_dict[2, 'right'](torch.cat([Lu2[:, :2].unsqueeze(2), Lu2[:, 2:].unsqueeze(2), v1_hat.unsqueeze(2)], dim=2)).squeeze(2)
    L_u5 = fnet_dict[1, 'right', 'left'](Lu1)
    decoded_llrs[:, 5] = L_u5.squeeze(1)
    u5_hat = torch.tanh(0.5 * L_u5)

    L_u6 = fnet_dict[1, 'right', 'right'](torch.cat([Lu1[:, 0:1].unsqueeze(2), Lu1[:, 1:2].unsqueeze(2), u5_hat.unsqueeze(1).repeat(1, 1, 1)], dim=2)).squeeze(2)
    decoded_llrs[:, 6] = L_u6.squeeze(1)

    return decoded_llrs


def get_msg_bits_batch(data_generator):
    msg_bits_batch = next(data_generator)
    return msg_bits_batch


def moving_average(a, n=3):
    ret = np.cumsum(a, dtype=float)
    ret[n:] = ret[n:] - ret[:-n]
    return ret[n - 1:] / n


print("Data loading stuff is completed! \n")


gnet_dict = {}
gnet_dict[1, 'left'] = g_Full(2, args.hidden_size, 1)
gnet_dict[1, 'right'] = g_Full(2, args.hidden_size, 1)
for i in range(2, args.m + 1):
    gnet_dict[i] = g_Full(2 * 2 ** (i - 1), args.hidden_size, 2 ** (i - 1))

fnet_dict = {}
for i in range(3, 6):
    fnet_dict[i, 'left'] = f_Full(2 ** i, args.hidden_size, 1)
    fnet_dict[i, 'right'] = f_Full(1 + 1 + 1, args.hidden_size, 1)

fnet_dict[2, 'left'] = f_Full(4, args.hidden_size, 2)
fnet_dict[2, 'right'] = f_Full(1 + 1 + 1, args.hidden_size, 1)

fnet_dict[1, 'left', 'left'] = f_Full(2, args.hidden_size, 1)
fnet_dict[1, 'left', 'right'] = f_Full(1 + 1 + 1, args.hidden_size, 1)

fnet_dict[1, 'right', 'left'] = f_Full(2, args.hidden_size, 1)
fnet_dict[1, 'right', 'right'] = f_Full(1 + 1 + 1, args.hidden_size, 1)

#########
### Loading the models
#########

results_load_path = './Neural_Plotkin_Results/Polar({0},{1})_2ndModel/NN_EncFull_Skip+Dec_SC/Enc_snr_{2}_Dec_snr{3}/Batch_{4}' \
    .format(2 ** args.m, args.m + 1, args.enc_train_snr, args.dec_train_snr, args.batch_size)

checkpoint1 = torch.load(results_load_path + '/Models/Encoder_NN_4300.pt', map_location=lambda storage, loc: storage)

gnet_dict[1, 'left'].load_state_dict(checkpoint1['g1_left'])
gnet_dict[1, 'right'].load_state_dict(checkpoint1['g1_right'])
for i in range(2, args.m + 1):
    gnet_dict[i].load_state_dict(checkpoint1['g{0}'.format(i)])


checkpoint2 = torch.load(results_load_path + '/Models/Decoder_NN_4300.pt', map_location=lambda storage, loc: storage)

for i in range(2, args.m):
    fnet_dict[i, 'left'].load_state_dict(checkpoint2['f{0}_left'.format(i)])
    fnet_dict[i, 'right'].load_state_dict(checkpoint2['f{0}_right'.format(i)])
fnet_dict[1, 'left', 'left'].load_state_dict(checkpoint2['f1_left_left'])
fnet_dict[1, 'left', 'right'].load_state_dict(checkpoint2['f1_left_right'])
fnet_dict[1, 'right', 'left'].load_state_dict(checkpoint2['f1_right_left'])
fnet_dict[1, 'right', 'right'].load_state_dict(checkpoint2['f1_right_right'])

gnet_dict[1, 'left'].to(device)
gnet_dict[1, 'right'].to(device)
for i in range(2, args.m + 1):
    gnet_dict[i].to(device)

for i in range(2, 6):
    fnet_dict[i, 'left'].to(device)
    fnet_dict[i, 'right'].to(device)
fnet_dict[1, 'left', 'left'].to(device)
fnet_dict[1, 'left', 'right'].to(device)
fnet_dict[1, 'right', 'left'].to(device)
fnet_dict[1, 'right', 'right'].to(device)
print("Models are loaded!")


######
## Pairwise distances
######

def bin_array(num, m):
    """Convert a positive integer num into an m-bit bit vector."""
    return np.array(list(np.binary_repr(num).zfill(m))).astype(np.float32).reshape(-1)
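`bin_array` is used just below to enumerate every message index as a bit vector, which the script then maps to the ±1 alphabet via b → 2b − 1. A small self-contained check of that mapping (the index 5 and width 4 are example values chosen here for illustration):

```python
import numpy as np

def bin_array(num, m):
    """Convert a positive integer num into an m-bit bit vector."""
    return np.array(list(np.binary_repr(num).zfill(m))).astype(np.float32).reshape(-1)

# With 4 information bits, message index 5 has binary representation 0101;
# the script maps each bit b to 2*b - 1 to get a +/-1 message vector.
bits = bin_array(5, 4)
print(bits)           # [0. 1. 0. 1.]
print(bits * 2 - 1)   # [-1.  1. -1.  1.]
```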
all_msg_bits = []

for i in range(2 ** (args.m + 1)):
    all_msg_bits.append(bin_array(i, args.m + 1) * 2 - 1)

all_msg_bits = torch.tensor(np.array(all_msg_bits)).to(device)

print(all_msg_bits)


def pairwise_distances(codebook):
    dists = []
    for row1, row2 in combinations(codebook, 2):
        distance = (row1 - row2).pow(2).sum()
        dists.append(np.sqrt(distance.item()))
    return dists, np.min(dists)


codebook_reusable_NN = encoder_Polar_full(all_msg_bits, gnet_dict, args.power_constraint_type,
                                          training_mode='test')  # Just testing
pairwise_dist_neural, d_min_reusable_NN = pairwise_distances(codebook_reusable_NN.data.cpu())


codebook_quantized_reuse_NN = codebook_reusable_NN.sign()

codebook_plotkin = encoder_Polar_Plotkin(all_msg_bits)


Gaussian_codebook = F.normalize(torch.randn(2 ** (args.m + 1), 2 ** args.m), p=2, dim=1) * np.sqrt(2 ** args.m)


all_msg_bits_large = all_msg_bits.t().unsqueeze(0).repeat(1000, 1, 1).to(device)


def encoder_codebook(msg_bits, codebook):
    msg_bits_large = msg_bits.unsqueeze(2).repeat(1, 1, 2 ** (args.m + 1)).to(device)
    diff = (msg_bits_large - all_msg_bits_large).pow(2).sum(dim=1)
    idx = diff.argmin(dim=1, keepdim=False)
    return codebook[idx, :]


########
### Testing stuff
########

batch_inflated_neural_codebook = codebook_reusable_NN.t().unsqueeze(0).repeat(1000, 1, 1)

batch_inflated_Plotkin_codebook = codebook_plotkin.t().unsqueeze(0).repeat(1000, 1, 1)
batch_inflated_Gaussian_codebook = Gaussian_codebook.t().unsqueeze(0).repeat(1000, 1, 1).to(device)

print(batch_inflated_neural_codebook.shape, batch_inflated_Plotkin_codebook.shape,
      batch_inflated_Gaussian_codebook.shape)


def decoder_MAP(corrupted_codewords,
batch_inflated_codebook): 515 | corrupted_codewords_inflated = corrupted_codewords.unsqueeze(2).repeat(1, 1, 2 ** ( 516 | args.m + 1)) 517 | diff = (corrupted_codewords_inflated - batch_inflated_codebook).pow(2).sum(dim=1) 518 | 519 | idx = diff.argmin(dim=1, keepdim=False) # (batch) 520 | 521 | decoded_bits = all_msg_bits[idx, :] 522 | 523 | return decoded_bits 524 | 525 | 526 | def test_MAP_and_all(msg_bits, snr): 527 | ## Common stuff 528 | 529 | noise_sigma = snr_db2sigma(snr) 530 | 531 | codewords_Plotkin = encoder_Polar_Plotkin(msg_bits) 532 | codewords_reuse_NN = encoder_Polar_full(msg_bits, gnet_dict, args.power_constraint_type, training_mode='test') 533 | 534 | codewords_Gaussian = encoder_codebook(msg_bits, Gaussian_codebook).to(device) 535 | 536 | standard_Gaussian = torch.randn_like(codewords_reuse_NN) 537 | 538 | corrupted_codewords_Plotkin = codewords_Plotkin + noise_sigma * standard_Gaussian 539 | corrupted_codewords_reuse_NN = codewords_reuse_NN + noise_sigma * standard_Gaussian 540 | corrupted_codewords_Gaussian = codewords_Gaussian + noise_sigma * standard_Gaussian 541 | 542 | ### MAP stuff 543 | Plotkin_decoded_bits = decoder_MAP(corrupted_codewords_Plotkin, batch_inflated_Plotkin_codebook) 544 | nn_decoded_bits = decoder_MAP(corrupted_codewords_reuse_NN, batch_inflated_neural_codebook) 545 | 546 | Gaussian_decoded_bits = decoder_MAP(corrupted_codewords_Gaussian, batch_inflated_Gaussian_codebook) 547 | 548 | ber_Plotkin_map = errors_ber(msg_bits, Plotkin_decoded_bits).item() 549 | ber_nn_map = errors_ber(msg_bits, nn_decoded_bits).item() 550 | 551 | ber_Gaussian_map = errors_ber(msg_bits, Gaussian_decoded_bits).item() 552 | 553 | bler_msg_Plotkin_map = errors_bler(msg_bits, Plotkin_decoded_bits).item() 554 | bler_msg_nn_map = errors_bler(msg_bits, nn_decoded_bits).item() 555 | 556 | SC_decoded_bits = decoder_Polar_SC(corrupted_codewords_Plotkin, snr) 557 | 558 | nn_decoded_bits = decoder_Polar_nn_full(corrupted_codewords_reuse_NN, 
                                            fnet_dict).sign()

    ber_SC = errors_ber(msg_bits, SC_decoded_bits).item()
    ber_nn = errors_ber(msg_bits, nn_decoded_bits).item()

    bler_SC = errors_bler(msg_bits, SC_decoded_bits).item()
    bler_nn = errors_bler(msg_bits, nn_decoded_bits).item()

    bler_msg_gaussian_map = errors_bler(msg_bits, Gaussian_decoded_bits).item()

    return ber_Plotkin_map, ber_nn_map, ber_SC, ber_nn, bler_SC, bler_nn, ber_Gaussian_map, bler_msg_Plotkin_map, bler_msg_nn_map, bler_msg_gaussian_map


#### Final testing stuff

snr_range = np.linspace(-6, 12, 19) if args.m <= 4 else np.linspace(-10., 2, 13)
test_size = 1000000

bers_SC_test = []
bers_nn_test = []

blers_SC_test = []
blers_nn_test = []

bers_Plotkin_test_map = []
bers_nn_test_map = []

blers_msg_Plotkin_map_test = []
blers_msg_nn_map_test = []

bers_Gaussian_map_test = []
blers_msg_gaussian_map_test = []

bersu1_nn_test = []
bersv1_nn_test = []
bersv2_nn_test = []

os.makedirs(results_load_path, exist_ok=True)
results_file = os.path.join(results_load_path + '/ber_results.%s')
results = ResultsLog(results_file % 'csv', results_file % 'html')

#####
# Test Data
#####
Test_msg_bits = 2 * (torch.rand(test_size, args.m + 1) < 0.5).float() - 1
Test_Data_Generator = torch.utils.data.DataLoader(Test_msg_bits, batch_size=1000, shuffle=False, **kwargs)

num_test_batches = len(Test_Data_Generator)

for test_snr in tqdm(snr_range):

    bers_SC, bers_nn = 0., 0.
    bers_Plotkin_map, bers_nn_map = 0., 0.

    blers_SC, blers_nn = 0., 0.
    bers_Gaussian_map, blers_msg_Plotkin_map, blers_msg_nn_map, blers_msg_gaussian_map = 0., 0., 0., 0.
    for (k, msg_bits) in enumerate(Test_Data_Generator):
        msg_bits = msg_bits.to(device)

        ber_Plotkin_map, ber_nn_map, ber_SC, ber_nn, bler_SC, bler_nn, ber_Gaussian_map, bler_msg_Plotkin_map, bler_msg_nn_map, bler_msg_gaussian_map = test_MAP_and_all(
            msg_bits, snr=test_snr)

        bers_Plotkin_map += ber_Plotkin_map
        bers_nn_map += ber_nn_map
        bers_SC += ber_SC
        bers_nn += ber_nn
        blers_SC += bler_SC
        blers_nn += bler_nn

        bers_Gaussian_map += ber_Gaussian_map
        blers_msg_Plotkin_map += bler_msg_Plotkin_map
        blers_msg_nn_map += bler_msg_nn_map
        blers_msg_gaussian_map += bler_msg_gaussian_map

    bers_Plotkin_map /= num_test_batches
    bers_nn_map /= num_test_batches
    bers_SC /= num_test_batches
    bers_nn /= num_test_batches

    blers_SC /= num_test_batches
    blers_nn /= num_test_batches

    bers_Gaussian_map /= num_test_batches
    blers_msg_Plotkin_map /= num_test_batches
    blers_msg_nn_map /= num_test_batches
    blers_msg_gaussian_map /= num_test_batches

    bers_SC_test.append(bers_SC)
    bers_nn_test.append(bers_nn)

    bers_Plotkin_test_map.append(bers_Plotkin_map)
    bers_nn_test_map.append(bers_nn_map)

    blers_SC_test.append(blers_SC)
    blers_nn_test.append(blers_nn)

    blers_msg_Plotkin_map_test.append(blers_msg_Plotkin_map)
    blers_msg_nn_map_test.append(blers_msg_nn_map)

    bers_Gaussian_map_test.append(bers_Gaussian_map)
    blers_msg_gaussian_map_test.append(blers_msg_gaussian_map)

    results.add(Test_SNR=test_snr, NN_BER=bers_nn, Plotkin_BER=bers_SC, NN_BLER=blers_nn, Plotkin_BLER=blers_SC,
                NN_BER_MAP=bers_nn_map, Plotkin_BER_MAP=bers_Plotkin_map, NN_BLER_MAP=blers_msg_nn_map, Plotkin_BLER_MAP=blers_msg_Plotkin_map,
                RandGauss_BER_MAP=bers_Gaussian_map, RandGauss_BLER_MAP=blers_msg_gaussian_map)

    results.save()
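The loop above accumulates per-batch error rates and divides by the number of batches, where `errors_ber` is the fraction of disagreeing bit positions and `errors_bler` is the fraction of blocks with at least one bit error. A NumPy sketch of those two metrics on toy ±1 data (the arrays below are illustrative values, not experiment outputs):

```python
import numpy as np

def ber(y_true, y_pred):
    # Fraction of bit positions where the hard decisions disagree.
    return np.mean(np.round(y_true) != np.round(y_pred))

def bler(y_true, y_pred):
    # Fraction of blocks (rows) containing at least one bit error.
    return np.mean(np.any(np.round(y_true) != np.round(y_pred), axis=1))

truth = np.array([[1, -1, 1], [1, 1, 1]])
guess = np.array([[1, -1, -1], [1, 1, 1]])
print(ber(truth, guess))   # 1 wrong bit out of 6 -> ~0.1667
print(bler(truth, guess))  # 1 of 2 blocks has an error -> 0.5
```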
### Plotting stuff

## BER
plt.figure(figsize=(12, 8))

ok = 1
plt.semilogy(snr_range[:-ok], bers_SC_test[:-ok], label="Polar + SC", marker='o', linewidth=1.5)
plt.semilogy(snr_range[:-ok], bers_nn_test[:-ok], label="Neural Polar + Neural SC", marker='^', linewidth=1.5)
plt.semilogy(snr_range[:-ok], bers_Gaussian_map_test[:-ok], label="Random Gaussian + MAP", marker='^', linewidth=1.5)
plt.semilogy(snr_range[:-ok], bers_Plotkin_test_map[:-ok], label="Polar + MAP (Inference)", marker='o', linewidth=1.5)
plt.semilogy(snr_range[:-ok], bers_nn_test_map[:-ok], label="Neural Polar + MAP (Inference)", marker='^', linewidth=1.5)
plt.ylim(2 * (10 ** -6), 0.4)

plt.grid()
plt.xlabel("SNR (dB)", fontsize=16)
plt.ylabel("Bit Error Rate", fontsize=16)
plt.title("BER plot of Neural Polar({0},{1}) codes: Trained at Enc:{2}dB, Dec:{3}dB with Batch Size: {4}".format(2 ** args.m, args.m + 1, args.enc_train_snr,
                                                                                                                args.dec_train_snr, args.batch_size))
plt.legend(prop={'size': 15})
plt.savefig(results_load_path + "/{0}_BER_at_Test_SNRs.pdf".format(args.m))

### BLER
plt.figure(figsize=(12, 8))

ok = 1
plt.semilogy(snr_range[:-ok], blers_SC_test[:-ok], label="Polar + SC", marker='o', linewidth=1.5)
plt.semilogy(snr_range[:-ok], blers_nn_test[:-ok], label="Neural Polar + Neural SC", marker='^', linewidth=1.5)
plt.semilogy(snr_range[:-ok], blers_msg_gaussian_map_test[:-ok], label="Random Gaussian + MAP", marker='^', linewidth=1.5)
plt.semilogy(snr_range[:-ok], blers_msg_Plotkin_map_test[:-ok], label="Polar + MAP (Inference)", marker='o', linewidth=1.5)
plt.semilogy(snr_range[:-ok], blers_msg_nn_map_test[:-ok], label="Neural Polar + MAP (Inference)", marker='^',
             linewidth=1.5)
plt.ylim(2 * (10 ** -6), 0.4)

plt.grid()
plt.xlabel("SNR (dB)", fontsize=16)
plt.ylabel("Message bits-Block Error Rate",
fontsize=16) 705 | plt.title("BLER plot of Neural Polar({0},{1}) codes: Trained at Enc:{2}dB, Dec:{3}dB with Batch Size: {4}".format(2**args.m, args.m+1, args.enc_train_snr, 706 | args.dec_train_snr, args.batch_size)) 707 | plt.legend(prop={'size': 15}) 708 | plt.savefig(results_load_path + "/{0}_BLER_at_Test_SNRs.pdf".format(args.m)) 709 | 710 | 711 | 712 | 713 | 714 | -------------------------------------------------------------------------------- /train_KO_m1_dumer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | from __future__ import print_function 5 | import argparse 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | from torchvision import datasets, transforms 11 | from torch.autograd import Variable 12 | from torchvision.utils import save_image 13 | from torchvision.utils import make_grid 14 | import torch.utils.data 15 | from data_loader import * 16 | from IPython import display 17 | 18 | import pickle 19 | import glob 20 | import os 21 | import logging 22 | import time 23 | from datetime import datetime 24 | from ast import literal_eval 25 | import matplotlib 26 | # matplotlib.use('AGG') 27 | import matplotlib.pyplot as plt 28 | import matplotlib.animation as animation 29 | from PIL import Image 30 | 31 | import reed_muller_modules 32 | from reed_muller_modules.logging_utils import * 33 | 34 | from opt_einsum import contract # This is for faster torch.einsum 35 | from reed_muller_modules.reedmuller_codebook import * 36 | from reed_muller_modules.hadamard import * 37 | from reed_muller_modules.comm_utils import * 38 | from reed_muller_modules.logging_utils import * 39 | from reed_muller_modules.all_functions import * 40 | import reed_muller_modules.reedmuller_codebook as reedmuller_codebook 41 | 42 | import pandas as pd 43 | import numpy as np 44 | from scipy.stats import norm 45 | from tqdm import tqdm 
46 | from itertools import combinations 47 | 48 | 49 | parser = argparse.ArgumentParser(description='(m,1) dumer') 50 | 51 | parser.add_argument('--m', type=int, default=8, help='reed muller code parameter m') 52 | 53 | parser.add_argument('--batch_size', type=int, default=20000, help='size of the batches') 54 | parser.add_argument('--hidden_size', type=int, default=64, help='neural network size') 55 | 56 | parser.add_argument('--full_iterations', type=int, default=10000, help='full iterations') 57 | parser.add_argument('--enc_train_iters', type=int, default=50, help='encoder iterations') 58 | parser.add_argument('--dec_train_iters', type=int, default=500, help='decoder iterations') 59 | 60 | parser.add_argument('--enc_train_snr', type=float, default=-4., help='SNR (dB) at which the encoder is trained') 61 | parser.add_argument('--dec_train_snr', type=float, default=-7., help='SNR (dB) at which the decoder is trained') 62 | 63 | 64 | 65 | parser.add_argument('--loss_type', type=str, default='BCE', choices=['MSE', 'BCE'], help='loss function') 66 | 67 | parser.add_argument('--gpu', type=int, default=0, help='GPU id used for training, e.g. 0') 68 | 69 | args = parser.parse_args() 70 | 71 | device = torch.device("cuda:{0}".format(args.gpu)) 72 | kwargs = {'num_workers': 4, 'pin_memory': False} 73 | 74 | 75 | results_save_path = './Results/RM({0},1)/NN_EncFull_Skip+Dec_Dumer/Enc_snr_{1}_Dec_snr{2}/Batch_{3}'\ 76 | .format(args.m, args.enc_train_snr, args.dec_train_snr, args.batch_size) 77 | os.makedirs(results_save_path, exist_ok=True) 78 | os.makedirs(results_save_path+'/Models', exist_ok=True) 79 | 80 | def repetition_code_matrices(device, m=8): 81 | 82 | M_dict = {} 83 | 84 | for i in range(1, m): 85 | M_dict[i] = torch.ones(1, 2**i).to(device) 86 | 87 | return M_dict 88 | 89 | repetition_M_dict = repetition_code_matrices(device, args.m) 90 | 91 | print("Matrices required for repetition code are defined!") 92 | 93 | ###### 94 | ## Functions 95 | ###### 96 | 97 | def snr_db2sigma(train_snr): 98 | 
return 10**(-train_snr*1.0/20) 99 | 100 | 101 | def log_sum_exp(LLR_vector): 102 | 103 | sum_vector = LLR_vector.sum(dim=1, keepdim=True) 104 | sum_concat = torch.cat([sum_vector, torch.zeros_like(sum_vector)], dim=1) 105 | 106 | return torch.logsumexp(sum_concat, dim=1)- torch.logsumexp(LLR_vector, dim=1) 107 | 108 | 109 | def errors_ber(y_true, y_pred): 110 | y_true = y_true.view(y_true.shape[0], -1, 1) 111 | y_pred = y_pred.view(y_pred.shape[0], -1, 1) 112 | 113 | myOtherTensor = torch.ne(torch.round(y_true), torch.round(y_pred)).float() 114 | res = sum(sum(myOtherTensor))/(myOtherTensor.shape[0]*myOtherTensor.shape[1]) 115 | return res 116 | 117 | 118 | def errors_bler(y_true, y_pred): 119 | y_true = y_true.view(y_true.shape[0], -1, 1) 120 | y_pred = y_pred.view(y_pred.shape[0], -1, 1) 121 | 122 | decoded_bits = torch.round(y_pred).cpu() 123 | X_test = torch.round(y_true).cpu() 124 | tp0 = (abs(decoded_bits-X_test)).view([X_test.shape[0],X_test.shape[1]]) 125 | tp0 = tp0.detach().cpu().numpy() 126 | bler_err_rate = sum(np.sum(tp0,axis=1)>0)*1.0/(X_test.shape[0]) 127 | return bler_err_rate 128 | 129 | 130 | 131 | 132 | 133 | class g_identity(nn.Module): 134 | def __init__(self): 135 | super(g_identity, self).__init__() 136 | self.fc = nn.Linear(1, 1, bias=False) 137 | 138 | def forward(self, y): 139 | 140 | return y 141 | 142 | class g_vector(nn.Module): 143 | def __init__(self): 144 | super(g_vector, self).__init__() 145 | self.fc = nn.Linear(16, 1, bias=True) 146 | 147 | def forward(self, y): 148 | 149 | return self.fc(y) 150 | 151 | 152 | 153 | class g_Full(nn.Module): 154 | def __init__(self, input_size, hidden_size, output_size): 155 | super(g_Full, self).__init__() 156 | 157 | self.input_size = input_size 158 | 159 | self.half_input_size = int(input_size/2) 160 | 161 | self.hidden_size = hidden_size 162 | self.output_size = output_size 163 | 164 | self.fc1 = nn.Linear(self.input_size, self.hidden_size, bias=True) 165 | self.fc2 = nn.Linear(self.hidden_size, 
self.hidden_size, bias=True) 166 | self.fc3 = nn.Linear(self.hidden_size, self.hidden_size, bias=True) 167 | self.fc4 = nn.Linear(self.hidden_size, self.output_size, bias=True) 168 | 169 | # self.skip = nn.Linear(3*self.half_input_size, self.hidden_size, bias=False) 170 | 171 | def forward(self, y): 172 | x = F.selu(self.fc1(y)) 173 | x = F.selu(self.fc2(x)) 174 | 175 | x = F.selu(self.fc3(x)) 176 | x = self.fc4(x)+y[:, :self.half_input_size]*y[:, self.half_input_size:] 177 | return x 178 | 179 | class f_Full(nn.Module): 180 | def __init__(self, input_size, hidden_size, output_size): 181 | super(f_Full, self).__init__() 182 | self.input_size = input_size 183 | self.hidden_size = hidden_size 184 | self.output_size = output_size 185 | 186 | self.fc1 = nn.Linear(self.input_size, self.hidden_size, bias=True) 187 | self.fc2 = nn.Linear(self.hidden_size, self.hidden_size, bias=True) 188 | self.fc3 = nn.Linear(self.hidden_size, self.hidden_size, bias=True) 189 | self.fc4 = nn.Linear(self.hidden_size, self.output_size, bias=True) 190 | 191 | def forward(self, y): 192 | x = F.selu(self.fc1(y)) 193 | x = F.selu(self.fc2(x)) 194 | 195 | x = F.selu(self.fc3(x)) 196 | x = self.fc4(x) 197 | return x 198 | 199 | 200 | def power_constraint(codewords, gnet_top, power_constraint_type, training_mode): 201 | 202 | 203 | if power_constraint_type in ['soft_power_block','soft_power_bit']: 204 | 205 | this_mean = codewords.mean(dim=0) if power_constraint_type == 'soft_power_bit' else codewords.mean() 206 | this_std = codewords.std(dim=0) if power_constraint_type == 'soft_power_bit' else codewords.std() 207 | 208 | if training_mode == 'train': # Training 209 | power_constrained_codewords = (codewords - this_mean)*1.0 / this_std 210 | 211 | gnet_top.update_normstats_for_test(this_mean, this_std) 212 | 213 | elif training_mode == 'test': # For inference 214 | power_constrained_codewords = (codewords - gnet_top.mean_scalar)*1.0/gnet_top.std_scalar 215 | 216 | # else: # When updating the stat 
parameters of g2net. Just don't do anything 217 | # power_constrained_codewords = _ 218 | 219 | return power_constrained_codewords 220 | 221 | 222 | elif power_constraint_type == 'hard_power_block': 223 | 224 | return F.normalize(codewords, p=2, dim=1)*np.sqrt(2**args.m) 225 | 226 | 227 | else: # 'hard_power_bit' 228 | 229 | return codewords/codewords.abs() 230 | 231 | # Plotkin stuff 232 | def encoder_Plotkin(msg_bits): 233 | 234 | #msg_bits is of shape (batch, m+1) 235 | 236 | u_level0 = msg_bits[:, 0:1] 237 | v_level0 = msg_bits[:, 1:2] 238 | 239 | for i in range(2, args.m+1): 240 | 241 | u_level0 = torch.cat([ u_level0, u_level0 * v_level0], dim=1) 242 | v_level0 = msg_bits[:, i:i+1].mm(repetition_M_dict[i-1]) 243 | 244 | u_levelm = torch.cat([u_level0, u_level0 * v_level0], dim=1) 245 | 246 | return u_levelm 247 | 248 | 249 | 250 | def encoder_full(msg_bits, gnet_dict, power_constraint_type='hard_power_block', training_mode='train'): #g_avector, g_bvector, 251 | 252 | u_level0 = msg_bits[:, 0:1] 253 | v_level0 = msg_bits[:, 1:2] 254 | 255 | for i in range(2, args.m+1): 256 | 257 | u_level0 = torch.cat([ u_level0, gnet_dict[i-1](torch.cat([u_level0, v_level0], dim=1)) ], dim=1) 258 | v_level0 = msg_bits[:, i:i+1].mm(repetition_M_dict[i-1]) 259 | 260 | u_levelm = torch.cat([u_level0, gnet_dict[args.m](torch.cat([u_level0, v_level0], dim=1))], dim=1) 261 | 262 | return power_constraint(u_levelm, gnet_dict[args.m], power_constraint_type, training_mode) 263 | 264 | 265 | 266 | def awgn_channel(codewords, snr): 267 | noise_sigma = snr_db2sigma(snr) 268 | standard_Gaussian = torch.randn_like(codewords) 269 | corrupted_codewords = codewords+noise_sigma * standard_Gaussian 270 | return corrupted_codewords 271 | 272 | def decoder_dumer(corrupted_codewords, snr): 273 | 274 | noise_sigma = snr_db2sigma(snr) 275 | 276 | llrs = (2/noise_sigma**2)*corrupted_codewords 277 | Lu = llrs 278 | 279 | decoded_bits = torch.zeros(corrupted_codewords.shape[0], args.m+1).to(device) 280 
| 281 | for i in range(args.m-1, -1, -1): 282 | 283 | Lv = log_sum_exp(torch.cat([Lu[:, :2**i].unsqueeze(2), Lu[:, 2**i:].unsqueeze(2)], dim=2).permute(0, 2, 1)).sum(dim=1, keepdim=True) 284 | 285 | v_hat = torch.sign(Lv) 286 | 287 | decoded_bits[:, i+1] = v_hat.squeeze(1) 288 | 289 | Lu = Lu[:, :2**i] + v_hat * Lu[:, 2**i:] 290 | 291 | 292 | u_1_hat = torch.sign(Lu) 293 | decoded_bits[:, 0] = u_1_hat.squeeze(1) 294 | 295 | return decoded_bits 296 | 297 | 298 | def decoder_dumer_soft(corrupted_codewords, snr): 299 | 300 | noise_sigma = snr_db2sigma(snr) 301 | 302 | llrs = (2/noise_sigma**2)*corrupted_codewords 303 | Lu = llrs 304 | 305 | decoded_bits = torch.zeros(corrupted_codewords.shape[0], args.m+1).to(device) 306 | 307 | for i in range(args.m-1, -1, -1): 308 | 309 | Lv = log_sum_exp(torch.cat([Lu[:, :2**i].unsqueeze(2), Lu[:, 2**i:].unsqueeze(2)], dim=2).permute(0, 2, 1)).sum(dim=1, keepdim=True) 310 | 311 | v_hat = torch.tanh(Lv/2) 312 | 313 | decoded_bits[:, i+1] = v_hat.squeeze(1) 314 | 315 | Lu = Lu[:, :2**i] + v_hat * Lu[:, 2**i:] 316 | 317 | 318 | u_1_hat = torch.tanh(Lu/2) 319 | decoded_bits[:, 0] = u_1_hat.squeeze(1) 320 | 321 | return decoded_bits 322 | 323 | 324 | def decoder_nn_full(corrupted_codewords, fnet_dict): 325 | 326 | Lu = corrupted_codewords 327 | 328 | decoded_llrs = torch.zeros(corrupted_codewords.shape[0], args.m+1).to(device) 329 | 330 | for i in range(args.m-1, -1 , -1): 331 | 332 | Lv = fnet_dict[2*(args.m-i)-1](Lu)+log_sum_exp(torch.cat([Lu[:, :2**i].unsqueeze(2), Lu[:, 2**i:].unsqueeze(2)], dim=2).permute(0, 2, 1)).sum(dim=1, keepdim=True) 333 | 334 | v_hat = torch.tanh(Lv/2) 335 | 336 | decoded_llrs[:, i+1] = v_hat.squeeze(1) 337 | 338 | Lu = fnet_dict[2*(args.m-i)](torch.cat([Lu[:, :2**i].unsqueeze(2), Lu[:, 2**i:].unsqueeze(2), v_hat.unsqueeze(1).repeat(1, 2**i, 1)], dim=2)).squeeze(2)+Lu[:, :2**i] + v_hat * Lu[:, 2**i:] 339 | 340 | u_1_hat = torch.tanh(Lu/2) 341 | 342 | decoded_llrs[:, 0] = u_1_hat.squeeze(1) 343 | 344 | 345 | 
return decoded_llrs 346 | 347 | 348 | def get_msg_bits_batch(data_generator): 349 | msg_bits_batch = next(data_generator) 350 | return msg_bits_batch 351 | 352 | def moving_average(a, n=3) : 353 | ret = np.cumsum(a, dtype=float) 354 | ret[n:] = ret[n:] - ret[:-n] 355 | return ret[n - 1:] / n 356 | 357 | 358 | def weights_init(m): 359 | classname = m.__class__.__name__ 360 | if classname.find('Conv') != -1: 361 | m.weight.data.normal_(0.0, 0.01) 362 | elif classname.find('BatchNorm') != -1: 363 | m.weight.data.normal_(0.0, 0.01) 364 | m.bias.data.fill_(0) 365 | elif classname.find('Linear') != -1: 366 | m.weight.data.normal_(0.0, 0.02) 367 | m.bias.data.fill_(0.) 368 | 369 | # msg_bits = 2 * (torch.rand(args.full_iterations * args.batch_size, args.m+1) < 0.5).float() - 1 370 | # Data_Generator = torch.utils.data.DataLoader(msg_bits, batch_size=args.batch_size , shuffle=True, **kwargs) 371 | 372 | print("Data loading stuff is completed! \n") 373 | 374 | gnet_dict = {} 375 | 376 | for i in range(1, args.m+1): 377 | gnet_dict[i] = g_Full(2*2**(i-1), args.hidden_size, 2**(i-1)) 378 | 379 | 380 | for i in range(1, args.m+1): 381 | gnet_dict[i].apply(weights_init) 382 | 383 | fnet_dict = {} 384 | 385 | for i in range(1, args.m+1): 386 | fnet_dict[2*i-1] = f_Full(2**(args.m-i+1), args.hidden_size, 1) 387 | fnet_dict[2*i] = f_Full(1+ 1+ 1, args.hidden_size, 1) 388 | 389 | 390 | 391 | 392 | for i in range(1, args.m+1): 393 | fnet_dict[2*i-1].apply(weights_init) 394 | fnet_dict[2*i].apply(weights_init) 395 | # Now load them onto devices 396 | for i in range(1, args.m+1): 397 | gnet_dict[i].to(device) 398 | 399 | 400 | for i in range(1, args.m+1): 401 | fnet_dict[2*i-1].to(device) 402 | fnet_dict[2*i].to(device) 403 | print("Models are loaded!") 404 | 405 | enc_params = [] 406 | for i in range(1, args.m+1): 407 | enc_params += list(gnet_dict[i].parameters()) 408 | 409 | dec_params = [] 410 | for i in range(1, args.m+1): 411 | dec_params += list(fnet_dict[2*i-1].parameters()) + 
list(fnet_dict[2*i].parameters()) 412 | 413 | enc_optimizer = optim.Adam(enc_params, lr = 1e-5)#, momentum=0.9, nesterov=True) #, amsgrad=True) 414 | dec_optimizer = optim.Adam(dec_params, lr = 1e-4)#, momentum=0.9, nesterov=True) #, amsgrad=True) 415 | 416 | criterion = nn.BCEWithLogitsLoss() if args.loss_type == 'BCE' else nn.MSELoss() # BinaryFocalLoss() #nn.BCEWithLogitsLoss() # nn.MSELoss() 417 | 418 | bers = [] 419 | losses = [] 420 | 421 | try: 422 | # for (k, msg_bits) in enumerate(Data_Generator): 423 | for k in range(args.full_iterations): 424 | start_time = time.time() 425 | msg_bits = 2 * (torch.rand(args.batch_size, args.m+1) < 0.5).float() - 1 426 | msg_bits = msg_bits.to(device) 427 | 428 | # # Train decoder 429 | for _ in range(args.dec_train_iters): 430 | 431 | # msg_bits = msg_bits.to(device) 432 | 433 | transmit_codewords = encoder_full(msg_bits, gnet_dict) 434 | corrupted_codewords = awgn_channel(transmit_codewords, args.dec_train_snr) 435 | decoded_bits = decoder_nn_full(corrupted_codewords, fnet_dict) 436 | 437 | loss = criterion(decoded_bits, 0.5*msg_bits+0.5) 438 | 439 | dec_optimizer.zero_grad() 440 | loss.backward() 441 | dec_optimizer.step() 442 | 443 | 444 | # Train Encoder 445 | for _ in range(args.enc_train_iters): 446 | 447 | # msg_bits = msg_bits.to(device) 448 | 449 | 450 | transmit_codewords = encoder_full(msg_bits, gnet_dict) 451 | corrupted_codewords = awgn_channel(transmit_codewords, args.enc_train_snr) 452 | decoded_bits = decoder_nn_full(corrupted_codewords, fnet_dict) 453 | 454 | loss = criterion(decoded_bits, 0.5*msg_bits+0.5 ) 455 | 456 | enc_optimizer.zero_grad() 457 | loss.backward() 458 | enc_optimizer.step() 459 | 460 | ber = errors_ber(msg_bits, decoded_bits.sign()).item() 461 | 462 | bers.append(ber) 463 | 464 | losses.append(loss.item()) 465 | if k % 10 == 0: 466 | print('[%d/%d] At %d dB, Loss: %.7f BER: %.7f' 467 | % (k+1, args.full_iterations, args.enc_train_snr, loss.item(), ber)) 468 | print("Time for one full 
iteration is {0:.4f} minutes".format((time.time() - start_time)/60)) 469 | 470 | 471 | # Save the model for safety 472 | if (k+1) % 10 == 0: 473 | 474 | torch.save(dict(zip(['g{0}'.format(i) for i in range(1, args.m+1)], [gnet_dict[i].state_dict() for i in range(1, args.m+1)])),\ 475 | results_save_path+'/Models/Encoder_NN_{0}.pt'.format(k+1)) 476 | 477 | torch.save(dict(zip(['f{0}'.format(i) for i in range(1,2*args.m+1)], [fnet_dict[i].state_dict() for i in range(1, 2*args.m+1)])),\ 478 | results_save_path+'/Models/Decoder_NN_{0}.pt'.format(k+1)) 479 | 480 | plt.figure() 481 | plt.plot(bers) 482 | plt.plot(moving_average(bers, n=10)) 483 | plt.savefig(results_save_path +'/training_ber.png') 484 | plt.close() 485 | 486 | plt.figure() 487 | plt.plot(losses) 488 | plt.plot(moving_average(losses, n=10)) 489 | plt.savefig(results_save_path +'/training_losses.png') 490 | plt.close() 491 | 492 | except KeyboardInterrupt: 493 | print('Graceful Exit') 494 | else: 495 | print('Finished') 496 | 497 | 498 | plt.figure() 499 | plt.plot(bers) 500 | plt.plot(moving_average(bers, n=10)) 501 | plt.savefig(results_save_path +'/training_ber.png') 502 | plt.close() 503 | 504 | plt.figure() 505 | plt.plot(losses) 506 | plt.plot(moving_average(losses, n=10)) 507 | plt.savefig(results_save_path +'/training_losses.png') 508 | plt.close() 509 | 510 | 511 | torch.save(dict(zip(['g{0}'.format(i) for i in range(1, args.m+1)], [gnet_dict[i].state_dict() for i in range(1, args.m+1)])),\ 512 | results_save_path+'/Models/Encoder_NN_{0}.pt'.format(k+1)) 513 | 514 | torch.save(dict(zip(['f{0}'.format(i) for i in range(1,2*args.m+1)], [fnet_dict[i].state_dict() for i in range(1, 2*args.m+1)])),\ 515 | results_save_path+'/Models/Decoder_NN_{0}.pt'.format(k+1)) 516 | -------------------------------------------------------------------------------- /train_KO_m1_map.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | 
from __future__ import print_function 5 | import argparse 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | from torchvision import datasets, transforms 11 | from torch.autograd import Variable 12 | from torchvision.utils import save_image 13 | from torchvision.utils import make_grid 14 | import torch.utils.data 15 | from data_loader import * 16 | from IPython import display 17 | 18 | import pickle 19 | import glob 20 | import os 21 | import logging 22 | import time 23 | from datetime import datetime 24 | from ast import literal_eval 25 | import matplotlib 26 | # matplotlib.use('AGG') 27 | import matplotlib.pyplot as plt 28 | import matplotlib.animation as animation 29 | from PIL import Image 30 | 31 | import reed_muller_modules 32 | from reed_muller_modules.logging_utils import * 33 | 34 | from opt_einsum import contract # This is for faster torch.einsum 35 | from reed_muller_modules.reedmuller_codebook import * 36 | from reed_muller_modules.hadamard import * 37 | from reed_muller_modules.comm_utils import * 38 | from reed_muller_modules.logging_utils import * 39 | from reed_muller_modules.all_functions import * 40 | import reed_muller_modules.reedmuller_codebook as reedmuller_codebook 41 | 42 | import pandas as pd 43 | import numpy as np 44 | from scipy.stats import norm 45 | from tqdm import tqdm 46 | from itertools import combinations 47 | 48 | 49 | parser = argparse.ArgumentParser(description='(m,1) dumer') 50 | 51 | parser.add_argument('--m', type=int, default=8, help='reed muller code parameter m') 52 | 53 | parser.add_argument('--batch_size', type=int, default=50000, help='size of the batches') 54 | parser.add_argument('--hidden_size', type=int, default=64, help='neural network size') 55 | 56 | parser.add_argument('--full_iterations', type=int, default=40000, help='full iterations') 57 | parser.add_argument('--enc_train_iters', type=int, default=50, help='encoder iterations') 58 | 59 | 
parser.add_argument('--enc_train_snr', type=float, default=-4., help='SNR (dB) at which the encoder is trained') 60 | 61 | 62 | 63 | parser.add_argument('--loss_type', type=str, default='BCE', choices=['MSE', 'BCE'], help='loss function') 64 | 65 | parser.add_argument('--gpu', type=int, default=0, help='GPU id used for training, e.g. 0') 66 | 67 | args = parser.parse_args() 68 | 69 | device = torch.device("cuda:{0}".format(args.gpu)) 70 | kwargs = {'num_workers': 4, 'pin_memory': False} 71 | 72 | 73 | results_save_path = './Results/RM({0},1)_softmap/NN_EncFull_Skip+Dec_Dumer/Enc_snr_{1}/Batch_{2}'\ 74 | .format(args.m, args.enc_train_snr, args.batch_size) 75 | os.makedirs(results_save_path, exist_ok=True) 76 | os.makedirs(results_save_path+'/Models', exist_ok=True) 77 | 78 | def repetition_code_matrices(device, m=8): 79 | 80 | M_dict = {} 81 | 82 | for i in range(1, m): 83 | M_dict[i] = torch.ones(1, 2**i).to(device) 84 | 85 | return M_dict 86 | 87 | repetition_M_dict = repetition_code_matrices(device, args.m) 88 | 89 | print("Matrices required for repetition code are defined!") 90 | 91 | 92 | Mul_this_Matrix_Ind_Zero = Variable(torch.load('./data/{0}/Mul_this_matrix_Ind_Zero.pt'.format(args.m))).to(device) 93 | Mul_this_Matrix_Ind_One = Variable(torch.load('./data/{0}/Mul_this_matrix_Ind_One.pt'.format(args.m))).to(device) 94 | 95 | ## Loading the MAP indices for plus and minus one 96 | PlusOneIdx = torch.load('./data/{0}/CodebookIndex_this_matrix_Zero_PlusOne.pt'.format(args.m)).long().to(device) 97 | MinusOneIdx = torch.load('./data/{0}/CodebookIndex_this_matrix_One_MinusOne.pt'.format(args.m)).long().to(device) 98 | 99 | print(PlusOneIdx, "\n", MinusOneIdx) 100 | 101 | RM_Class = reedmuller_codebook.ReedMuller(1, args.m) 102 | msg_length = RM_Class.message_length() 103 | 104 | Generator_Matrix = numpy_to_torch(RM_Class.Generator_Matrix[:, ::-1].copy()) 105 | Generator_Matrix_cuda = Generator_Matrix.to(device) 106 | 107 | 108 | ## this is important because if the standard bits 
are (u_0, u_1, u_2,...,u_m) then the bits in our Plotkin structure are 109 | ## (u_0, u_m, u_m-1,...., u_1), i.e, u_1= u_0, v_1 = u_m,....,v_m = u_1. 110 | tree_bits_order_from_standard = [0] + list(range(args.m, 0, -1)) 111 | 112 | 113 | print(Generator_Matrix_cuda, "\n", tree_bits_order_from_standard) 114 | print("Mul_One:", Mul_this_Matrix_Ind_One, "\n", "Mul_Zero:", Mul_this_Matrix_Ind_Zero ) 115 | 116 | def llr_info_bits(hadamard_transform_llr, order_of_RM1): 117 | 118 | # Load Mul_this_Matrix_Ind_One/Zero before 119 | # order_of_RM1 = 7 120 | #hadam_transf_of_llr is of shape (batch*num_sparse, 128) 121 | 122 | LLR_Info_bits = torch.zeros(hadamard_transform_llr.shape[0], order_of_RM1 + 1).to(device) 123 | 124 | # Take care of tuple 125 | max_1, _ = hadamard_transform_llr.max(1) 126 | min_1, _ = hadamard_transform_llr.min(1) 127 | 128 | LLR_Info_bits[:, 0] = max_1 + min_1 129 | 130 | # modify this tomorrow morning 131 | 132 | max_zero, _ = torch.max(contract('ij, kj -> ikj', hadamard_transform_llr.abs() , Mul_this_Matrix_Ind_Zero), 2) 133 | max_one, _ = torch.max(contract('ij, kj -> ikj', hadamard_transform_llr.abs() , Mul_this_Matrix_Ind_One), 2) 134 | 135 | LLR_Info_bits[:, 1:] = max_zero - max_one 136 | 137 | return LLR_Info_bits 138 | 139 | 140 | def modified_llr_codeword(LLR_Info_bits): 141 | 142 | # Generator matrix of shape (m+1, 2^m) for RM(m, 1) code is needed here. 
So load it before hand 143 | # LLR_Info_bits is of shape (batch*num_sparse, m + 1) 144 | 145 | required_LLR_info = contract('ij , jk ->ikj', LLR_Info_bits, Generator_Matrix_cuda) # (batch*num_sparse, 2^m, m+1) 146 | 147 | sign_matrix = (-1)**((required_LLR_info < 0).sum(2)).float() # (batch*num_sparse, 2^m+1) 148 | 149 | min_abs_LLR_info, _= torch.min(torch.where(required_LLR_info==0., torch.max(required_LLR_info.abs())+1, required_LLR_info.abs()), dim = 2) 150 | 151 | return sign_matrix * min_abs_LLR_info 152 | 153 | 154 | def compute_llr_soft_decoding(llr, order_of_RM1): 155 | # hadamard_transform_llr = hadamard_transform_cuda(llr) # shape (batch_size , 128) 156 | 157 | hadamard_transform_llr = hadamard_transform_torch(llr) # shape (batch_size , 128) 158 | # return modified_llr_codeword(llr_info_bits(hadamard_transform_llr, m)) 159 | return llr_info_bits(hadamard_transform_llr, order_of_RM1) 160 | 161 | 162 | def decoder_soft_FHT(corrupted_codewords, snr, order_of_RM1): 163 | llr = llr_awgn_channel_bpsk(corrupted_codewords, snr) 164 | predicted_llr = compute_llr_soft_decoding(llr, order_of_RM1) # This order is in standard 165 | # return predicted_llr[:, tree_bits_order_from_standard] 166 | return predicted_llr 167 | 168 | 169 | 170 | def rm_encoder(msg_bits): 171 | msg_bits = 0.5-0.5*msg_bits 172 | randomly_gen_codebook = reed_muller_batch_encoding(msg_bits, Generator_Matrix_cuda) 173 | 174 | return 1-2*randomly_gen_codebook 175 | ###### 176 | ## Functions 177 | ###### 178 | 179 | def snr_db2sigma(train_snr): 180 | return 10**(-train_snr*1.0/20) 181 | 182 | 183 | def log_sum_exp(LLR_vector): 184 | 185 | sum_vector = LLR_vector.sum(dim=1, keepdim=True) 186 | sum_concat = torch.cat([sum_vector, torch.zeros_like(sum_vector)], dim=1) 187 | 188 | return torch.logsumexp(sum_concat, dim=1)- torch.logsumexp(LLR_vector, dim=1) 189 | 190 | 191 | def errors_ber(y_true, y_pred): 192 | y_true = y_true.view(y_true.shape[0], -1, 1) 193 | y_pred = y_pred.view(y_pred.shape[0], 
-1, 1) 194 | 195 | myOtherTensor = torch.ne(torch.round(y_true), torch.round(y_pred)).float() 196 | res = sum(sum(myOtherTensor))/(myOtherTensor.shape[0]*myOtherTensor.shape[1]) 197 | return res 198 | 199 | 200 | def errors_bler(y_true, y_pred): 201 | y_true = y_true.view(y_true.shape[0], -1, 1) 202 | y_pred = y_pred.view(y_pred.shape[0], -1, 1) 203 | 204 | decoded_bits = torch.round(y_pred).cpu() 205 | X_test = torch.round(y_true).cpu() 206 | tp0 = (abs(decoded_bits-X_test)).view([X_test.shape[0],X_test.shape[1]]) 207 | tp0 = tp0.detach().cpu().numpy() 208 | bler_err_rate = sum(np.sum(tp0,axis=1)>0)*1.0/(X_test.shape[0]) 209 | return bler_err_rate 210 | 211 | 212 | 213 | 214 | 215 | class g_identity(nn.Module): 216 | def __init__(self): 217 | super(g_identity, self).__init__() 218 | self.fc = nn.Linear(1, 1, bias=False) 219 | 220 | def forward(self, y): 221 | 222 | return y 223 | 224 | class g_vector(nn.Module): 225 | def __init__(self): 226 | super(g_vector, self).__init__() 227 | self.fc = nn.Linear(16, 1, bias=True) 228 | 229 | def forward(self, y): 230 | 231 | return self.fc(y) 232 | 233 | 234 | 235 | class g_Full(nn.Module): 236 | def __init__(self, input_size, hidden_size, output_size): 237 | super(g_Full, self).__init__() 238 | 239 | self.input_size = input_size 240 | 241 | self.half_input_size = int(input_size/2) 242 | 243 | self.hidden_size = hidden_size 244 | self.output_size = output_size 245 | 246 | self.fc1 = nn.Linear(self.input_size, self.hidden_size, bias=True) 247 | self.fc2 = nn.Linear(self.hidden_size, self.hidden_size, bias=True) 248 | self.fc3 = nn.Linear(self.hidden_size, self.hidden_size, bias=True) 249 | self.fc4 = nn.Linear(self.hidden_size, self.output_size, bias=True) 250 | 251 | # self.skip = nn.Linear(3*self.half_input_size, self.hidden_size, bias=False) 252 | 253 | def forward(self, y): 254 | x = F.selu(self.fc1(y)) 255 | 256 | x = F.selu(self.fc2(x)) 257 | 258 | x = F.selu(self.fc3(x)) 259 | x = self.fc4(x) + y[:, 
:self.half_input_size]*y[:, self.half_input_size:] 260 | return x 261 | 262 | class f_Full(nn.Module): 263 | def __init__(self, input_size, hidden_size, output_size): 264 | super(f_Full, self).__init__() 265 | self.input_size = input_size 266 | self.hidden_size = hidden_size 267 | self.output_size = output_size 268 | 269 | self.fc1 = nn.Linear(self.input_size, self.hidden_size, bias=True) 270 | self.fc2 = nn.Linear(self.hidden_size, self.hidden_size, bias=True) 271 | self.fc3 = nn.Linear(self.hidden_size, self.hidden_size, bias=True) 272 | self.fc4 = nn.Linear(self.hidden_size, self.output_size, bias=True) 273 | 274 | def forward(self, y): 275 | x = F.selu(self.fc1(y)) 276 | x = F.selu(self.fc2(x)) 277 | 278 | x = F.selu(self.fc3(x)) 279 | x = self.fc4(x) 280 | return x 281 | 282 | 283 | def power_constraint(codewords, gnet_top, power_constraint_type, training_mode): 284 | 285 | 286 | if power_constraint_type in ['soft_power_block','soft_power_bit']: 287 | 288 | this_mean = codewords.mean(dim=0) if power_constraint_type == 'soft_power_bit' else codewords.mean() 289 | this_std = codewords.std(dim=0) if power_constraint_type == 'soft_power_bit' else codewords.std() 290 | 291 | if training_mode == 'train': # Training 292 | power_constrained_codewords = (codewords - this_mean)*1.0 / this_std 293 | 294 | gnet_top.update_normstats_for_test(this_mean, this_std) 295 | 296 | elif training_mode == 'test': # For inference 297 | power_constrained_codewords = (codewords - gnet_top.mean_scalar)*1.0/gnet_top.std_scalar 298 | 299 | # else: # When updating the stat parameters of g2net. 
Just don't do anything 300 | # power_constrained_codewords = _ 301 | 302 | return power_constrained_codewords 303 | 304 | 305 | elif power_constraint_type == 'hard_power_block': 306 | 307 | return F.normalize(codewords, p=2, dim=1)*np.sqrt(2**args.m) 308 | 309 | 310 | else: # 'hard_power_bit' 311 | 312 | return codewords/codewords.abs() 313 | 314 | # Plotkin stuff 315 | def encoder_Plotkin(msg_bits): 316 | 317 | #msg_bits is of shape (batch, m+1) 318 | 319 | u_level0 = msg_bits[:, 0:1] 320 | v_level0 = msg_bits[:, 1:2] 321 | 322 | for i in range(2, args.m+1): 323 | 324 | u_level0 = torch.cat([ u_level0, u_level0 * v_level0], dim=1) 325 | v_level0 = msg_bits[:, i:i+1].mm(repetition_M_dict[i-1]) 326 | 327 | u_levelm = torch.cat([u_level0, u_level0 * v_level0], dim=1) 328 | 329 | return u_levelm 330 | 331 | 332 | 333 | def encoder_full(msg_bits, gnet_dict, power_constraint_type='hard_power_block', training_mode='train'): #g_avector, g_bvector, 334 | 335 | u_level0 = msg_bits[:, 0:1] 336 | v_level0 = msg_bits[:, 1:2] 337 | 338 | for i in range(2, args.m+1): 339 | 340 | u_level0 = torch.cat([ u_level0, gnet_dict[i-1](torch.cat([u_level0, v_level0], dim=1)) ], dim=1) 341 | v_level0 = msg_bits[:, i:i+1].mm(repetition_M_dict[i-1]) 342 | 343 | u_levelm = torch.cat([u_level0, gnet_dict[args.m](torch.cat([u_level0, v_level0], dim=1))], dim=1) 344 | 345 | return power_constraint(u_levelm, gnet_dict[args.m], power_constraint_type, training_mode) 346 | 347 | 348 | 349 | def awgn_channel(codewords, snr): 350 | noise_sigma = snr_db2sigma(snr) 351 | standard_Gaussian = torch.randn_like(codewords) 352 | corrupted_codewords = codewords+noise_sigma * standard_Gaussian 353 | return corrupted_codewords 354 | 355 | def decoder_dumer(corrupted_codewords, snr): 356 | 357 | noise_sigma = snr_db2sigma(snr) 358 | 359 | llrs = (2/noise_sigma**2)*corrupted_codewords 360 | Lu = llrs 361 | 362 | decoded_bits = torch.zeros(corrupted_codewords.shape[0], args.m+1).to(device) 363 | 364 | for i in 
range(args.m-1, -1, -1): 365 | 366 | Lv = log_sum_exp(torch.cat([Lu[:, :2**i].unsqueeze(2), Lu[:, 2**i:].unsqueeze(2)], dim=2).permute(0, 2, 1)).sum(dim=1, keepdim=True) 367 | 368 | v_hat = torch.sign(Lv) 369 | 370 | decoded_bits[:, i+1] = v_hat.squeeze(1) 371 | 372 | Lu = Lu[:, :2**i] + v_hat * Lu[:, 2**i:] 373 | 374 | 375 | u_1_hat = torch.sign(Lu) 376 | decoded_bits[:, 0] = u_1_hat.squeeze(1) 377 | 378 | return decoded_bits 379 | 380 | 381 | def decoder_dumer_soft(corrupted_codewords, snr): 382 | 383 | noise_sigma = snr_db2sigma(snr) 384 | 385 | llrs = (2/noise_sigma**2)*corrupted_codewords 386 | Lu = llrs 387 | 388 | decoded_bits = torch.zeros(corrupted_codewords.shape[0], args.m+1).to(device) 389 | 390 | for i in range(args.m-1, -1, -1): 391 | 392 | Lv = log_sum_exp(torch.cat([Lu[:, :2**i].unsqueeze(2), Lu[:, 2**i:].unsqueeze(2)], dim=2).permute(0, 2, 1)).sum(dim=1, keepdim=True) 393 | 394 | v_hat = torch.tanh(Lv/2) 395 | 396 | decoded_bits[:, i+1] = v_hat.squeeze(1) 397 | 398 | Lu = Lu[:, :2**i] + v_hat * Lu[:, 2**i:] 399 | 400 | 401 | u_1_hat = torch.tanh(Lu/2) 402 | decoded_bits[:, 0] = u_1_hat.squeeze(1) 403 | 404 | return decoded_bits 405 | 406 | 407 | def decoder_nn_full(corrupted_codewords, fnet_dict): 408 | 409 | Lu = corrupted_codewords 410 | 411 | decoded_llrs = torch.zeros(corrupted_codewords.shape[0], args.m+1).to(device) 412 | 413 | for i in range(args.m-1, -1 , -1): 414 | 415 | Lv = fnet_dict[2*(args.m-i)-1](Lu) 416 | 417 | decoded_llrs[:, i+1] = Lv.squeeze(1) 418 | 419 | Lu = fnet_dict[2*(args.m-i)](torch.cat([Lu[:, :2**i].unsqueeze(2), Lu[:, 2**i:].unsqueeze(2), Lv.unsqueeze(1).repeat(1, 2**i, 1)], dim=2)).squeeze(2) 420 | 421 | decoded_llrs[:, 0] = Lu.squeeze(1) 422 | 423 | 424 | return decoded_llrs 425 | 426 | 427 | def get_msg_bits_batch(data_generator): 428 | msg_bits_batch = next(data_generator) 429 | return msg_bits_batch 430 | 431 | def moving_average(a, n=3) : 432 | ret = np.cumsum(a, dtype=float) 433 | ret[n:] = ret[n:] - ret[:-n] 434 
| return ret[n - 1:] / n 435 | 436 | 437 | def weights_init(m): 438 | classname = m.__class__.__name__ 439 | if classname.find('Conv') != -1: 440 | m.weight.data.normal_(0.0, 0.01) 441 | elif classname.find('BatchNorm') != -1: 442 | m.weight.data.normal_(0.0, 0.01) 443 | m.bias.data.fill_(0) 444 | elif classname.find('Linear') != -1: 445 | m.weight.data.normal_(0.0, 0.02) 446 | m.bias.data.fill_(0.) 447 | 448 | def first_principle_soft_MAP(corrupted_codewords, codebook_PlusOne, codebook_MinusOne): 449 | 450 | # Max-log-MAP approximation: per-bit LLR = (max correlation over codewords with that bit = +1) - (max over codewords with that bit = -1) 451 | 452 | max_PlusOne, _ = torch.max(contract('lk, ijk -> lij', corrupted_codewords, codebook_PlusOne), 2) 453 | max_MinusOne, _ = torch.max(contract('lk, ijk -> lij', corrupted_codewords, codebook_MinusOne), 2) 454 | 455 | return max_PlusOne - max_MinusOne 456 | 457 | def bin_array(num, m): 458 | """Convert a positive integer num into an m-bit bit vector""" 459 | return np.array(list(np.binary_repr(num).zfill(m))).astype(np.float32).reshape(-1) 460 | all_msg_bits = [] 461 | 462 | for i in range(2**(args.m+1)-1, -1, -1): 463 | all_msg_bits.append(bin_array(i,args.m+1)*2-1) 464 | 465 | 466 | all_msg_bits = torch.tensor(np.array(all_msg_bits)).to(device) 467 | 468 | 469 | # msg_bits = 2 * (torch.rand(args.full_iterations * args.batch_size, args.m+1) < 0.5).float() - 1 470 | # Data_Generator = torch.utils.data.DataLoader(msg_bits, batch_size=args.batch_size , shuffle=True, **kwargs) 471 | 472 | print("Data loading is complete!
\n") 473 | 474 | gnet_dict = {} 475 | 476 | for i in range(1, args.m+1): 477 | gnet_dict[i] = g_Full(2*2**(i-1), args.hidden_size, 2**(i-1)) 478 | 479 | for i in range(1, args.m+1): 480 | gnet_dict[i].apply(weights_init) 481 | 482 | 483 | # Now load them onto devices 484 | for i in range(1, args.m+1): 485 | gnet_dict[i].to(device) 486 | 487 | 488 | print("Models are loaded!") 489 | 490 | enc_params = [] 491 | for i in range(1, args.m+1): 492 | enc_params += list(gnet_dict[i].parameters()) 493 | 494 | 495 | enc_optimizer = optim.Adam(enc_params, lr = 1e-5)#, momentum=0.9, nesterov=True) #, amsgrad=True) 496 | 497 | criterion = nn.BCEWithLogitsLoss() if args.loss_type == 'BCE' else nn.MSELoss() # BinaryFocalLoss() #nn.BCEWithLogitsLoss() # nn.MSELoss() 498 | 499 | bers = [] 500 | losses = [] 501 | 502 | def pairwise_distances(codebook): 503 | dists = [] 504 | for row1, row2 in combinations(codebook, 2): 505 | distance = (row1-row2).pow(2).sum() 506 | dists.append(np.sqrt(distance.item())) 507 | return dists, np.min(dists) 508 | 509 | codebook_plotkin = encoder_Plotkin(all_msg_bits) 510 | pairwise_dist_plotkin, d_min_plotkin = pairwise_distances(codebook_plotkin.data.cpu()) 511 | Gaussian_codebook = F.normalize(torch.randn(2**(args.m+1), 2**args.m), p=2, dim=1)*np.sqrt(2**args.m) 512 | pairwise_dist_Gaussian, d_min_Gaussian = pairwise_distances(Gaussian_codebook) 513 | 514 | 515 | import scipy.stats 516 | rv = scipy.stats.chi(df=(2**(args.m)),scale=np.sqrt(2)) 517 | 518 | 519 | try: 520 | # for (k, msg_bits) in enumerate(Data_Generator): 521 | for k in range(args.full_iterations): 522 | 523 | 524 | 525 | start_time = time.time() 526 | msg_bits = 2 * (torch.rand(args.batch_size, args.m+1) < 0.5).float() - 1 527 | msg_bits = msg_bits.to(device) 528 | 529 | 530 | 531 | 532 | # Train Encoder 533 | for _ in range(args.enc_train_iters): 534 | 535 | # msg_bits = msg_bits.to(device) 536 | 537 | 538 | transmit_codewords = encoder_full(msg_bits, gnet_dict) 539 | 
corrupted_codewords = awgn_channel(transmit_codewords, args.enc_train_snr) 540 | 541 | codebook_neural = encoder_full(all_msg_bits, gnet_dict) 542 | 543 | codebook_neural_PlusOne = codebook_neural[PlusOneIdx] 544 | codebook_neural_MinusOne = codebook_neural[MinusOneIdx] 545 | 546 | decoded_bits = (1/snr_db2sigma(args.enc_train_snr)**2)*first_principle_soft_MAP(corrupted_codewords, \ 547 | codebook_neural_PlusOne, codebook_neural_MinusOne) 548 | 549 | loss = criterion(decoded_bits, 0.5*msg_bits+0.5 ) 550 | 551 | enc_optimizer.zero_grad() 552 | loss.backward() 553 | enc_optimizer.step() 554 | 555 | ber = errors_ber(msg_bits, decoded_bits.sign()).item() 556 | 557 | 558 | 559 | bers.append(ber) 560 | 561 | losses.append(loss.item()) 562 | 563 | 564 | 565 | if k % 10 == 0: 566 | print('[%d/%d] At %d dB, Loss: %.7f BER: %.7f' 567 | % (k+1, args.full_iterations, args.enc_train_snr, loss.item(), ber)) 568 | print("Time for one full iteration is {0:.4f} minutes".format((time.time() - start_time)/60)) 569 | 570 | 571 | # Save the model for safety 572 | if (k+1) % 100 == 0: 573 | 574 | torch.save(dict(zip(['g{0}'.format(i) for i in range(1, args.m+1)], [gnet_dict[i].state_dict() for i in range(1, args.m+1)])),\ 575 | results_save_path+'/Models/Encoder_NN_{0}.pt'.format(k+1)) 576 | 577 | 578 | plt.figure() 579 | plt.plot(bers) 580 | plt.plot(moving_average(bers, n=10)) 581 | plt.savefig(results_save_path +'/training_ber.png') 582 | plt.close() 583 | 584 | plt.figure() 585 | plt.plot(losses) 586 | plt.plot(moving_average(losses, n=10)) 587 | plt.savefig(results_save_path +'/training_losses.png') 588 | plt.close() 589 | 590 | except KeyboardInterrupt: 591 | print('Graceful Exit') 592 | else: 593 | print('Finished') 594 | 595 | 596 | 597 | 598 | 599 | 600 | 601 | plt.figure() 602 | plt.plot(bers) 603 | plt.plot(moving_average(bers, n=10)) 604 | plt.savefig(results_save_path +'/training_ber.png') 605 | plt.close() 606 | 607 | plt.figure() 608 | plt.plot(losses) 609 | 
plt.plot(moving_average(losses, n=10)) 610 | plt.savefig(results_save_path +'/training_losses.png') 611 | plt.close() 612 | 613 | torch.save(dict(zip(['g{0}'.format(i) for i in range(1, args.m+1)], [gnet_dict[i].state_dict() for i in range(1, args.m+1)])),\ 614 | results_save_path+'/Models/Encoder_NN_{0}.pt'.format(k+1)) 615 | -------------------------------------------------------------------------------- /train_KO_m2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | from __future__ import print_function 5 | import argparse 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | from torchvision import datasets, transforms 11 | from torch.autograd import Variable 12 | from torchvision.utils import save_image 13 | from torchvision.utils import make_grid 14 | import torch.utils.data 15 | from data_loader import * 16 | from IPython import display 17 | 18 | import pickle 19 | import glob 20 | import os 21 | import logging 22 | import time 23 | from datetime import datetime 24 | from ast import literal_eval 25 | import matplotlib 26 | # matplotlib.use('AGG') 27 | import matplotlib.pyplot as plt 28 | from PIL import Image 29 | 30 | import reed_muller_modules 31 | from reed_muller_modules.logging_utils import * 32 | 33 | from opt_einsum import contract # This is for faster torch.einsum 34 | from reed_muller_modules.reedmuller_codebook import * 35 | from reed_muller_modules.hadamard import * 36 | from reed_muller_modules.comm_utils import * 37 | from reed_muller_modules.logging_utils import * 38 | from reed_muller_modules.all_functions import * 39 | # import reed_muller_modules.reedmuller_codebook as reedmuller_codebook 40 | 41 | import pandas as pd 42 | import numpy as np 43 | from scipy.stats import norm 44 | import matplotlib.pyplot as plt 45 | from tqdm import tqdm 46 | from itertools import combinations 47 | 48 | parser 
= argparse.ArgumentParser(description='(m,2) dumer') 49 | 50 | parser.add_argument('--m', type=int, default=8, help='Reed-Muller code parameter m') 51 | 52 | parser.add_argument('--batch_size', type=int, default=50000, help='size of the batches') 53 | 54 | parser.add_argument('--small_batch_size', type=int, default=25000, help='size of the sub-batches used for gradient accumulation') 55 | 56 | parser.add_argument('--hidden_size', type=int, default=32, help='neural network hidden-layer size') 57 | 58 | parser.add_argument('--full_iterations', type=int, default=20000, help='number of full training iterations') 59 | parser.add_argument('--enc_train_iters', type=int, default=50, help='encoder iterations per full iteration') 60 | parser.add_argument('--dec_train_iters', type=int, default=500, help='decoder iterations per full iteration') 61 | 62 | parser.add_argument('--enc_train_snr', type=float, default=0., help='SNR (dB) at which the encoder is trained') 63 | parser.add_argument('--dec_train_snr', type=float, default=-2., help='SNR (dB) at which the decoder is trained') 64 | 65 | 66 | 67 | parser.add_argument('--loss_type', type=str, default='BCE', choices=['MSE', 'BCE'], help='loss function') 68 | 69 | parser.add_argument('--gpu', type=int, default=7, help='id of the single GPU used for training, e.g. 0') 70 | 71 | args = parser.parse_args() 72 | 73 | device = torch.device("cuda:{0}".format(args.gpu)) 74 | kwargs = {'num_workers': 4, 'pin_memory': False} 75 | 76 | results_save_path = './Results/RM({0},2)/fullNN_Enc+fullNN_Dec/Enc_snr_{1}_Dec_snr{2}/Batch_{3}'\ 77 | .format(args.m, args.enc_train_snr,args.dec_train_snr, args.batch_size) 78 | os.makedirs(results_save_path, exist_ok=True) 79 | os.makedirs(results_save_path+'/Models', exist_ok=True) 80 | 81 | 82 | def repetition_code_matrices(device, m=8): 83 | 84 | M_dict = {} 85 | 86 | for i in range(1, m): 87 | M_dict[i] = torch.ones(1, 2**i).to(device) 88 | 89 | return M_dict 90 | 91 | def first_order_generator_matrices(device, m=8): 92 | 93 | G_dict = {} 94 | 95 | for i in range(1, m): 96 | 97 | RM_Class = ReedMuller(1, i) # i >= 1 here, so no fallback is needed 98 | Generator_Matrix =
numpy_to_torch(RM_Class.Generator_Matrix[:, ::-1].copy()) 99 | G_dict[i] = Generator_Matrix.to(device) 100 | 101 | return G_dict 102 | 103 | 104 | def first_order_Mul_matrices(device, m=8): 105 | 106 | Mul_Ind_Zero_dict, Mul_Ind_One_dict = {}, {} 107 | 108 | for i in range(1, m): 109 | # Load the precomputed Mul index matrices from ./data (i >= 1 here, so no fallback is needed) 110 | Mul_Ind_Zero_dict[i] = Variable(torch.load('./data/{0}/Mul_this_matrix_Ind_Zero.pt'.format(i))).to(device) 111 | Mul_Ind_One_dict[i] = Variable(torch.load('./data/{0}/Mul_this_matrix_Ind_One.pt'.format(i))).to(device) 112 | 113 | 114 | return Mul_Ind_Zero_dict, Mul_Ind_One_dict 115 | 116 | def bin_array(num, m): 117 | """Convert a positive integer num into an m-bit bit vector""" 118 | return np.array(list(np.binary_repr(num).zfill(m))).astype(np.float32).reshape(-1) 119 | 120 | 121 | def first_principle_soft_MAP(corrupted_codewords, codebook_PlusOne, codebook_MinusOne): 122 | 123 | # Max-log-MAP approximation: per-bit LLR = (max correlation over codewords with that bit = +1) - (max over codewords with that bit = -1) 124 | 125 | max_PlusOne, _ = torch.max(contract('lk, ijk -> lij', corrupted_codewords, codebook_PlusOne), 2) 126 | max_MinusOne, _ = torch.max(contract('lk, ijk -> lij', corrupted_codewords, codebook_MinusOne), 2) 127 | 128 | return max_PlusOne - max_MinusOne 129 | 130 | def weights_init(m): 131 | classname = m.__class__.__name__ 132 | if classname.find('Conv') != -1: 133 | m.weight.data.normal_(0.0, 0.01) 134 | elif classname.find('BatchNorm') != -1: 135 | m.weight.data.normal_(0.0, 0.01) 136 | m.bias.data.fill_(0) 137 | elif classname.find('Linear') != -1: 138 | m.weight.data.normal_(0.0, 0.02) 139 | m.bias.data.fill_(0.)
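first_principle_soft_MAP is a max-log-MAP bit decoder: for each message bit it correlates the received word with every codeword, then takes the difference between the best correlation among codewords where that bit is +1 and the best among those where it is -1. A minimal NumPy sketch on a toy one-bit repetition code (the function and codebook names here are illustrative, not the tensors loaded from ./data):

```python
import numpy as np

def soft_map_sketch(received, codebook_plus, codebook_minus):
    # codebook_plus[j] holds the codewords whose message bit j is +1,
    # codebook_minus[j] those where it is -1 (shape: bits x codewords x n)
    corr_plus = np.einsum('lk,ijk->lij', received, codebook_plus).max(axis=2)
    corr_minus = np.einsum('lk,ijk->lij', received, codebook_minus).max(axis=2)
    return corr_plus - corr_minus  # per-bit max-log-MAP scores; sign = decision

# toy single-bit repetition code of length 2
codebook_plus = np.array([[[1.0, 1.0]]])     # bit = +1 -> codeword (+1, +1)
codebook_minus = np.array([[[-1.0, -1.0]]])  # bit = -1 -> codeword (-1, -1)
received = np.array([[0.9, 1.1], [-0.8, -1.2]])
llrs = soft_map_sketch(received, codebook_plus, codebook_minus)
# llrs[0] > 0 (decide +1), llrs[1] < 0 (decide -1)
```

In the script the same contraction is written with opt_einsum's contract and torch.max over the codeword axis; the NumPy version above is only a readable equivalent.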
140 | 141 | repetition_M_dict = repetition_code_matrices(device,args.m) 142 | 143 | first_order_generator_dict = first_order_generator_matrices(device, args.m) 144 | 145 | first_order_Mul_Ind_Zero_dict, first_order_Mul_Ind_One_dict = first_order_Mul_matrices(device, args.m) 146 | 147 | 148 | 149 | print("Matrices required for first order code are defined!") 150 | 151 | msg_lengths = [2] 152 | 153 | for i in range(2, args.m+1): 154 | msg_lengths.append(i) 155 | 156 | msg_bits_partition_indices = np.cumsum(msg_lengths) 157 | 158 | code_dimension_k = msg_bits_partition_indices[-1] 159 | 160 | print(msg_bits_partition_indices, code_dimension_k) 161 | 162 | ###### 163 | ## Normal Functions 164 | ###### 165 | 166 | def snr_db2sigma(train_snr): 167 | return 10**(-train_snr*1.0/20) 168 | 169 | 170 | def log_sum_exp(LLR_vector): 171 | 172 | sum_vector = LLR_vector.sum(dim=1, keepdim=True) 173 | sum_concat = torch.cat([sum_vector, torch.zeros_like(sum_vector)], dim=1) 174 | 175 | return torch.logsumexp(sum_concat, dim=1)- torch.logsumexp(LLR_vector, dim=1) 176 | 177 | 178 | def errors_ber(y_true, y_pred): 179 | y_true = y_true.view(y_true.shape[0], -1, 1) 180 | y_pred = y_pred.view(y_pred.shape[0], -1, 1) 181 | 182 | myOtherTensor = torch.ne(torch.round(y_true), torch.round(y_pred)).float() 183 | res = sum(sum(myOtherTensor))/(myOtherTensor.shape[0]*myOtherTensor.shape[1]) 184 | return res 185 | 186 | 187 | def errors_bler(y_true, y_pred): 188 | y_true = y_true.view(y_true.shape[0], -1, 1) 189 | y_pred = y_pred.view(y_pred.shape[0], -1, 1) 190 | 191 | decoded_bits = torch.round(y_pred).cpu() 192 | X_test = torch.round(y_true).cpu() 193 | tp0 = (abs(decoded_bits-X_test)).view([X_test.shape[0],X_test.shape[1]]) 194 | tp0 = tp0.detach().cpu().numpy() 195 | bler_err_rate = sum(np.sum(tp0,axis=1)>0)*1.0/(X_test.shape[0]) 196 | return bler_err_rate 197 | 198 | 199 | 200 | #### 201 | ## Neural Network Stuff 202 | #### 203 | 204 | 205 | 206 | 207 | class g_Full(nn.Module): 208 | def 
__init__(self, input_size, hidden_size, output_size): 209 | super(g_Full, self).__init__() 210 | 211 | self.input_size = input_size 212 | 213 | self.half_input_size = int(input_size/2) 214 | 215 | self.hidden_size = hidden_size 216 | self.output_size = output_size 217 | 218 | self.fc1 = nn.Linear(self.input_size, self.hidden_size, bias=True) 219 | self.fc2 = nn.Linear(self.hidden_size, self.hidden_size, bias=True) 220 | self.fc3 = nn.Linear(self.hidden_size, self.hidden_size, bias=True) 221 | self.fc4 = nn.Linear(self.hidden_size, self.output_size, bias=True) 222 | 223 | 224 | def forward(self, y): 225 | x = F.selu(self.fc1(y)) 226 | 227 | x = F.selu(self.fc2(x)) 228 | 229 | x = F.selu(self.fc3(x)) 230 | x = self.fc4(x) + y[:, :, :self.half_input_size]*y[:,:, self.half_input_size:] 231 | return x 232 | 233 | class f_Full(nn.Module): 234 | def __init__(self, input_size, hidden_size, output_size): 235 | super(f_Full, self).__init__() 236 | self.input_size = input_size 237 | self.hidden_size = hidden_size 238 | self.output_size = output_size 239 | 240 | self.fc1 = nn.Linear(self.input_size, self.hidden_size, bias=True) 241 | self.fc2 = nn.Linear(self.hidden_size, self.hidden_size, bias=True) 242 | self.fc3 = nn.Linear(self.hidden_size, self.hidden_size, bias=True) 243 | self.fc4 = nn.Linear(self.hidden_size, self.output_size, bias=True) 244 | 245 | def forward(self, y): 246 | x = F.selu(self.fc1(y)) 247 | x = F.selu(self.fc2(x)) 248 | 249 | x = F.selu(self.fc3(x)) 250 | x = self.fc4(x) 251 | return x 252 | 253 | 254 | def power_constraint(codewords, gnet_top, power_constraint_type, training_mode): 255 | 256 | 257 | if power_constraint_type in ['soft_power_block','soft_power_bit']: 258 | 259 | this_mean = codewords.mean(dim=0) if power_constraint_type == 'soft_power_bit' else codewords.mean() 260 | this_std = codewords.std(dim=0) if power_constraint_type == 'soft_power_bit' else codewords.std() 261 | 262 | if training_mode == 'train': # Training 263 | 
power_constrained_codewords = (codewords - this_mean)*1.0 / this_std 264 | 265 | gnet_top.update_normstats_for_test(this_mean, this_std) 266 | 267 | elif training_mode == 'test': # For inference 268 | power_constrained_codewords = (codewords - gnet_top.mean_scalar)*1.0/gnet_top.std_scalar 269 | 270 | 271 | return power_constrained_codewords 272 | 273 | 274 | elif power_constraint_type == 'hard_power_block': 275 | 276 | return F.normalize(codewords, p=2, dim=1)*np.sqrt(2**args.m) 277 | 278 | 279 | else: # 'hard_power_bit' 280 | 281 | return codewords/codewords.abs() 282 | 283 | 284 | ##### 285 | ## Order-1 Encoding & Decoding stuff 286 | ##### 287 | 288 | def rm_encoder(msg_bits, Generator_Matrix): 289 | msg_bits = 0.5-0.5*msg_bits 290 | randomly_gen_codebook = reed_muller_batch_encoding(msg_bits, Generator_Matrix) 291 | 292 | return 1-2*randomly_gen_codebook 293 | 294 | 295 | 296 | def llr_info_bits(hadamard_transform_llr, order_of_RM1): 297 | 298 | 299 | max_1, _ = hadamard_transform_llr.max(1, keepdim=True) 300 | min_1, _ = hadamard_transform_llr.min(1, keepdim=True) 301 | 302 | LLR_zero_column = max_1 + min_1 303 | 304 | 305 | max_zero, _ = torch.max(contract('ij, kj -> ikj', hadamard_transform_llr.abs() , first_order_Mul_Ind_Zero_dict[order_of_RM1]), 2) 306 | max_one, _ = torch.max(contract('ij, kj -> ikj', hadamard_transform_llr.abs() , first_order_Mul_Ind_One_dict[order_of_RM1]), 2) 307 | 308 | LLR_remaining = max_zero - max_one 309 | 310 | return torch.cat([LLR_zero_column, LLR_remaining], dim=1) 311 | 312 | 313 | def modified_llr_codeword(LLR_Info_bits, order_of_RM1): 314 | 315 | 316 | 317 | required_LLR_info = contract('ij , jk ->ikj', LLR_Info_bits, first_order_generator_dict[order_of_RM1]) 318 | 319 | sign_matrix = (-1)**((required_LLR_info < 0).sum(2)).float() 320 | 321 | min_abs_LLR_info, _= torch.min(torch.where(required_LLR_info==0., torch.max(required_LLR_info.abs())+1, required_LLR_info.abs()), dim = 2) 322 | 323 | return sign_matrix * 
min_abs_LLR_info 324 | 325 | 326 | 327 | 328 | def FirstOrder_SoftFHT_InfoBits_decoder(llr, order_of_RM1): 329 | 330 | hadamard_transform_llr = hadamard_transform_torch(llr) 331 | return llr_info_bits(hadamard_transform_llr, order_of_RM1) 332 | 333 | 334 | def FirstOrder_SoftFHT_LLR_decoder(llr, order_of_RM1, normalize=False): 335 | 336 | hadamard_transform_llr = hadamard_transform_torch(llr) 337 | 338 | return modified_llr_codeword(llr_info_bits(hadamard_transform_llr, order_of_RM1), order_of_RM1) 339 | 340 | 341 | def FirstOrder_SoftFHT_Codewords_decoder(corrupted_codewords, snr, order_of_RM1): 342 | 343 | llr = llr_awgn_channel_bpsk(corrupted_codewords, snr) 344 | 345 | predicted_llr = FirstOrder_SoftFHT_LLR_decoder(llr, order_of_RM1) 346 | 347 | return predicted_llr 348 | 349 | 350 | ##### 351 | ## Order-1 old stuff and Order-2 352 | #### 353 | tree_bits_order_from_standard_dict = {} 354 | 355 | for i in range(1, args.m): 356 | 357 | tree_bits_order_from_standard_dict[i] = [0] + list(range(i, 0, -1)) 358 | 359 | 360 | # Leaves are repetition code 361 | def first_order_encoder_Plotkin(msg_bits, order_of_RM1): 362 | 363 | #msg_bits is of shape (batch, m+1) 364 | 365 | u_level0 = msg_bits[:, 0:1] 366 | v_level0 = msg_bits[:, 1:2] 367 | 368 | for i in range(2, order_of_RM1+1): 369 | 370 | u_level0 = torch.cat([ u_level0, u_level0 * v_level0], dim=1) 371 | v_level0 = msg_bits[:, i:i+1].mm(repetition_M_dict[i-1]) 372 | 373 | u_levelm = torch.cat([u_level0, u_level0 * v_level0], dim=1) 374 | 375 | return u_levelm 376 | 377 | def first_order_decoder_dumer(corrupted_codewords, snr): 378 | 379 | noise_sigma = snr_db2sigma(snr) 380 | 381 | llrs = (2/noise_sigma**2)*corrupted_codewords 382 | Lu = llrs 383 | 384 | decoded_bits = torch.zeros(corrupted_codewords.shape[0], args.m+1).to(device) 385 | 386 | for i in range(args.m-1, -1, -1): 387 | 388 | Lv = log_sum_exp(torch.cat([Lu[:, :2**i].unsqueeze(2), Lu[:, 2**i:].unsqueeze(2)], dim=2).permute(0, 2, 1)).sum(dim=1, 
keepdim=True) 389 | 390 | v_hat = torch.sign(Lv) 391 | 392 | decoded_bits[:, i+1] = v_hat.squeeze(1) 393 | 394 | Lu = Lu[:, :2**i] + v_hat * Lu[:, 2**i:] 395 | 396 | 397 | u_1_hat = torch.sign(Lu) 398 | decoded_bits[:, 0] = u_1_hat.squeeze(1) 399 | 400 | return decoded_bits 401 | 402 | 403 | def first_order_decoder_dumer_soft(corrupted_codewords, snr): 404 | 405 | noise_sigma = snr_db2sigma(snr) 406 | 407 | llrs = (2/noise_sigma**2)*corrupted_codewords 408 | Lu = llrs 409 | 410 | decoded_bits = torch.zeros(corrupted_codewords.shape[0], args.m+1).to(device) 411 | 412 | for i in range(args.m-1, -1, -1): 413 | 414 | Lv = log_sum_exp(torch.cat([Lu[:, :2**i].unsqueeze(2), Lu[:, 2**i:].unsqueeze(2)], dim=2).permute(0, 2, 1)).sum(dim=1, keepdim=True) 415 | 416 | v_hat = torch.tanh(Lv/2) 417 | 418 | decoded_bits[:, i+1] = v_hat.squeeze(1) 419 | 420 | Lu = Lu[:, :2**i] + v_hat * Lu[:, 2**i:] 421 | 422 | 423 | u_1_hat = torch.tanh(Lu/2) 424 | decoded_bits[:, 0] = u_1_hat.squeeze(1) 425 | 426 | return decoded_bits 427 | 428 | 429 | 430 | 431 | 432 | 433 | 434 | def first_order_decoder_nn_full(corrupted_codewords, fnet_dict): 435 | 436 | Lu = corrupted_codewords 437 | 438 | decoded_llrs = torch.zeros(corrupted_codewords.shape[0], args.m+1).to(device) 439 | 440 | for i in range(args.m-1, -1, -1): 441 | 442 | Lv = fnet_dict[2*(args.m-i)-1](Lu) 443 | 444 | decoded_llrs[:, i+1] = Lv.squeeze(1) 445 | 446 | Lu = fnet_dict[2*(args.m-i)](torch.cat([Lu[:, :2**i].unsqueeze(2), Lu[:, 2**i:].unsqueeze(2), Lv.unsqueeze(1).repeat(1, 2**i, 1)], dim=2)).squeeze(2) 447 | 448 | decoded_llrs[:, 0] = Lu.squeeze(1) 449 | return decoded_llrs 450 | 451 | 452 | ########### 453 | ## Order 2 -------------------------------------------------------------------------------------------------------------------- 454 | ########### 455 | 456 | 457 | 458 | def RM_22_Plotkin_encoder(msg_bits): 459 | 460 | u_level0 = torch.cat([msg_bits[:, 0:1], msg_bits[:, 0:1] * msg_bits[:, 1:2]], dim=1) 461 | v_level0 = torch.cat([msg_bits[:, 2:3], msg_bits[:, 2:3] * msg_bits[:,
3:4]], dim=1) 462 | 463 | return torch.cat([u_level0, u_level0 * v_level0], dim=1) 464 | 465 | 466 | def correct_second_order_encoder_Plotkin_RM_leaves(msg_bits): 467 | 468 | 469 | u_level0 = RM_22_Plotkin_encoder(msg_bits[:, :4]) 470 | v_level0 = first_order_encoder_Plotkin(msg_bits[:, msg_bits_partition_indices[1]: msg_bits_partition_indices[2]][:, tree_bits_order_from_standard_dict[2]], 2) 471 | 472 | for i in range(3, args.m): 473 | 474 | u_level0 = torch.cat([ u_level0, u_level0 * v_level0], dim=1) 475 | v_level0 = \ 476 | first_order_encoder_Plotkin(msg_bits[:, msg_bits_partition_indices[i-1]: msg_bits_partition_indices[i]][:, tree_bits_order_from_standard_dict[i]], i) 477 | 478 | u_levelm = torch.cat([u_level0, u_level0 * v_level0], dim=1) 479 | 480 | return u_levelm 481 | 482 | def correct_second_order_encoder_Neural_RM_leaves(msg_bits, gnet_dict, power_constraint_type='hard_power_block', training_mode='train'): #g_avector, g_bvector, 483 | ## This denotes the RM(2,2) right most node for u and the left RM(2,1) node for v. 
484 | u_level0 = RM_22_Plotkin_encoder(msg_bits[:, :4]) 485 | v_level0 = first_order_encoder_Plotkin(msg_bits[:, msg_bits_partition_indices[1]: msg_bits_partition_indices[2]][:, tree_bits_order_from_standard_dict[2]], 2) 486 | 487 | for i in range(3, args.m): 488 | 489 | u_level0 = torch.cat([ u_level0, gnet_dict[i](torch.cat([u_level0.unsqueeze(2), v_level0.unsqueeze(2)], dim=2)).squeeze(2) ], dim=1) 490 | v_level0 = \ 491 | first_order_encoder_Plotkin(msg_bits[:, msg_bits_partition_indices[i-1]: msg_bits_partition_indices[i]][:, tree_bits_order_from_standard_dict[i]], i) 492 | 493 | u_levelm = torch.cat([u_level0, gnet_dict[args.m](torch.cat([u_level0.unsqueeze(2), v_level0.unsqueeze(2)], dim=2)).squeeze(2) ], dim=1) 494 | 495 | return power_constraint(u_levelm, None, power_constraint_type, training_mode) 496 | 497 | 498 | def correct_second_order_OnlyTop_encoder_Neural_RM_leaves(msg_bits, gnet_dict, power_constraint_type='hard_power_block', training_mode='train'): #g_avector, g_bvector, 499 | 500 | u_level0 = RM_22_Plotkin_encoder(msg_bits[:, :4]) 501 | v_level0 = first_order_encoder_Plotkin(msg_bits[:, msg_bits_partition_indices[1]: msg_bits_partition_indices[2]][:, tree_bits_order_from_standard_dict[2]], 2) 502 | 503 | for i in range(3, args.m): 504 | 505 | u_level0 = torch.cat([ u_level0, u_level0 * v_level0], dim=1) 506 | v_level0 = \ 507 | first_order_encoder_Plotkin(msg_bits[:, msg_bits_partition_indices[i-1]: msg_bits_partition_indices[i]][:, tree_bits_order_from_standard_dict[i]], i) 508 | 509 | u_levelm = torch.cat([u_level0, gnet_dict[args.m](torch.cat([u_level0.unsqueeze(2), v_level0.unsqueeze(2)], dim=2)).squeeze(2) ], dim=1) 510 | 511 | return power_constraint(u_levelm, None, power_constraint_type, training_mode) 512 | 513 | ##----------------------------- 514 | 515 | ## Loading the MAP indices for plus and minus one 516 | PlusOneIdx_22_leaf = torch.load('./data/{0}/CodebookIndex_this_matrix_Zero_PlusOne.pt'.format(3)).long().to(device) 517 | 
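The CodebookIndex tensors loaded here encode, for each message bit of the RM(2,2) leaf, which codebook rows carry that bit as +1 and which as -1, so that RM_22_SoftMAP_decoder can take its two maxima. How such an index partition looks can be sketched in NumPy (the exact layout of the stored .pt files is an assumption; the names below are illustrative):

```python
import numpy as np

def bit_index_partition(all_msg_bits):
    # for each message bit j, the row indices of messages with that bit +1 / -1
    k = all_msg_bits.shape[1]
    plus = np.stack([np.flatnonzero(all_msg_bits[:, j] == 1) for j in range(k)])
    minus = np.stack([np.flatnonzero(all_msg_bits[:, j] == -1) for j in range(k)])
    return plus, minus

# all 4-bit messages in +/-1 form, mirroring all_msg_bits_22_leaf
msgs = np.array([[2 * int(b) - 1 for b in np.binary_repr(n, 4)] for n in range(16)])
plus_idx, minus_idx = bit_index_partition(msgs)
# every bit splits the 16 messages into two halves of 8
```

Indexing the codebook with such arrays, as in RM_22_codebook[PlusOneIdx_22_leaf], then yields one (codewords x n) slab per message bit.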
MinusOneIdx_22_leaf = torch.load('./data/{0}/CodebookIndex_this_matrix_One_MinusOne.pt'.format(3)).long().to(device) 518 | 519 | 520 | all_msg_bits_22_leaf = [] 521 | 522 | for i in range(2**(4)-1, -1, -1): 523 | all_msg_bits_22_leaf.append(bin_array(i,4)*2-1) 524 | 525 | 526 | all_msg_bits_22_leaf = torch.tensor(np.array(all_msg_bits_22_leaf)).to(device) 527 | 528 | RM_22_codebook = RM_22_Plotkin_encoder(all_msg_bits_22_leaf) 529 | 530 | RM_22_codebook_PlusOne = RM_22_codebook[PlusOneIdx_22_leaf] 531 | RM_22_codebook_MinusOne = RM_22_codebook[MinusOneIdx_22_leaf] 532 | 533 | 534 | 535 | def RM_22_SoftMAP_decoder(LLR): 536 | 537 | 538 | max_PlusOne, _ = torch.max(contract('lk, ijk -> lij', LLR, RM_22_codebook_PlusOne), 2) 539 | max_MinusOne, _ = torch.max(contract('lk, ijk -> lij', LLR, RM_22_codebook_MinusOne), 2) 540 | 541 | return max_PlusOne - max_MinusOne 542 | 543 | 544 | def correct_second_order_OnlyTop_decoder_nn_full(corrupted_codewords, fnet_dict, snr): 545 | 546 | noise_sigma = snr_db2sigma(snr) 547 | 548 | llrs = (2/noise_sigma**2)*corrupted_codewords 549 | Lu = llrs 550 | 551 | 552 | decoded_llrs = torch.zeros(corrupted_codewords.shape[0], code_dimension_k).to(device) 553 | 554 | i = args.m-1 555 | 556 | Lv = fnet_dict[2*(args.m-i) - 1](torch.cat([Lu[:, :2**i].unsqueeze(2), Lu[:, 2**i:].unsqueeze(2)], dim=2)).squeeze(2) 557 | decoded_llrs[:, msg_bits_partition_indices[i-1]: msg_bits_partition_indices[i]] \ 558 | = FirstOrder_SoftFHT_InfoBits_decoder(Lv, i) 559 | 560 | v_hat = \ 561 | torch.tanh(0.5*modified_llr_codeword(decoded_llrs[:, msg_bits_partition_indices[i-1]: msg_bits_partition_indices[i]], i)) 562 | 563 | Lu = fnet_dict[2*(args.m-i)](torch.cat([Lu[:, :2**i].unsqueeze(2), Lu[:, 2**i:].unsqueeze(2), Lv.unsqueeze(2), v_hat.unsqueeze(2)], dim=2)).squeeze(2) 564 | 565 | 566 | for i in range(args.m-2, 1 , -1): 567 | 568 | Lv = log_sum_exp(torch.cat([Lu[:, :2**i].unsqueeze(2), Lu[:, 2**i:].unsqueeze(2)], dim=2).permute(0, 2, 1)) #.sum(dim=1, 
keepdim=True) 569 | 570 | 571 | decoded_llrs[:, msg_bits_partition_indices[i-1]: msg_bits_partition_indices[i]] = FirstOrder_SoftFHT_InfoBits_decoder(Lv, i) 572 | 573 | v_hat = \ 574 | torch.tanh(0.5*modified_llr_codeword(decoded_llrs[:, msg_bits_partition_indices[i-1]: msg_bits_partition_indices[i]], i)) 575 | 576 | Lu = Lu[:, :2**i] + v_hat * Lu[:, 2**i:] 577 | 578 | 579 | decoded_llrs[:, :4] = RM_22_SoftMAP_decoder(Lu) 580 | 581 | return decoded_llrs 582 | 583 | 584 | 585 | def correct_second_order_decoder_nn_full(corrupted_codewords, fnet_dict): 586 | 587 | 588 | llrs = corrupted_codewords 589 | 590 | Lu = llrs 591 | 592 | 593 | decoded_llrs = torch.zeros(corrupted_codewords.shape[0], code_dimension_k).to(device) 594 | 595 | for i in range(args.m-1, 1 , -1): 596 | 597 | Lv = fnet_dict[2*(args.m-i) - 1](torch.cat([Lu[:, :2**i].unsqueeze(2), Lu[:, 2**i:].unsqueeze(2)], dim=2)).squeeze(2)+log_sum_exp(torch.cat([Lu[:, :2**i].unsqueeze(2), Lu[:, 2**i:].unsqueeze(2)], dim=2).permute(0, 2, 1)) 598 | 599 | decoded_llrs[:, msg_bits_partition_indices[i-1]: msg_bits_partition_indices[i]]\ 600 | = FirstOrder_SoftFHT_InfoBits_decoder(Lv, i) 601 | 602 | v_hat = \ 603 | torch.tanh(0.5*modified_llr_codeword(decoded_llrs[:, msg_bits_partition_indices[i-1]: msg_bits_partition_indices[i]], i)) 604 | 605 | 606 | Lu = fnet_dict[2*(args.m-i)](torch.cat([Lu[:, :2**i].unsqueeze(2), Lu[:, 2**i:].unsqueeze(2), Lv.unsqueeze(2), v_hat.unsqueeze(2)], dim=2)).squeeze(2)+Lu[:, :2**i] + v_hat * Lu[:, 2**i:] 607 | 608 | 609 | decoded_llrs[:, :4] = RM_22_SoftMAP_decoder(Lu) 610 | 611 | return decoded_llrs 612 | 613 | 614 | def correct_second_order_decoder_dumer_soft(corrupted_codewords, snr): 615 | 616 | noise_sigma = snr_db2sigma(snr) 617 | 618 | llrs = (2/noise_sigma**2)*corrupted_codewords 619 | Lu = llrs 620 | 621 | decoded_llrs = torch.zeros(corrupted_codewords.shape[0], code_dimension_k).to(device) 622 | 623 | for i in range(args.m-1, 1 , -1): 624 | 625 | 626 | Lv = 
log_sum_exp(torch.cat([Lu[:, :2**i].unsqueeze(2), Lu[:, 2**i:].unsqueeze(2)], dim=2).permute(0, 2, 1)) #.sum(dim=1, keepdim=True) 627 | 628 | 629 | decoded_llrs[:, msg_bits_partition_indices[i-1]: msg_bits_partition_indices[i]] = FirstOrder_SoftFHT_InfoBits_decoder(Lv, i) 630 | 631 | v_hat = \ 632 | torch.tanh(0.5*modified_llr_codeword(decoded_llrs[:, msg_bits_partition_indices[i-1]: msg_bits_partition_indices[i]], i)) 633 | 634 | Lu = Lu[:, :2**i] + v_hat * Lu[:, 2**i:] 635 | 636 | decoded_llrs[:, :4] = RM_22_SoftMAP_decoder(Lu) 637 | 638 | return decoded_llrs 639 | 640 | 641 | 642 | 643 | def correct_second_order_decoder_dumer(corrupted_codewords, snr): 644 | 645 | noise_sigma = snr_db2sigma(snr) 646 | 647 | llrs = (2/noise_sigma**2)*corrupted_codewords 648 | Lu = llrs 649 | 650 | decoded_llrs = torch.zeros(corrupted_codewords.shape[0], code_dimension_k).to(device) 651 | 652 | for i in range(args.m-1, 1 , -1): 653 | 654 | 655 | Lv = log_sum_exp(torch.cat([Lu[:, :2**i].unsqueeze(2), Lu[:, 2**i:].unsqueeze(2)], dim=2).permute(0, 2, 1)) 656 | 657 | 658 | decoded_llrs[:, msg_bits_partition_indices[i-1]: msg_bits_partition_indices[i]] = FirstOrder_SoftFHT_InfoBits_decoder(Lv, i).sign() 659 | 660 | v_hat = rm_encoder(decoded_llrs[:, msg_bits_partition_indices[i-1]: msg_bits_partition_indices[i]],\ 661 | first_order_generator_dict[i]) 662 | 663 | Lu = Lu[:, :2**i] + v_hat * Lu[:, 2**i:] 664 | 665 | decoded_llrs[:, :4] = RM_22_SoftMAP_decoder(Lu).sign() 666 | 667 | return decoded_llrs 668 | 669 | #------------------------------------------------------------------------------ 670 | 671 | # Leaves are Reed Muller codes 672 | 673 | 674 | def awgn_channel(codewords, snr): 675 | noise_sigma = snr_db2sigma(snr) 676 | standard_Gaussian = torch.randn_like(codewords) 677 | corrupted_codewords = codewords+noise_sigma * standard_Gaussian 678 | return corrupted_codewords 679 | 680 | 681 | ############################ 682 | 683 | def get_msg_bits_batch(data_generator): 684 | 
msg_bits_batch = next(data_generator) 685 | return msg_bits_batch 686 | 687 | def moving_average(a, n=3) : 688 | ret = np.cumsum(a, dtype=float) 689 | ret[n:] = ret[n:] - ret[:-n] 690 | return ret[n - 1:] / n 691 | 692 | 693 | 694 | 695 | 696 | 697 | 698 | 699 | gnet_dict = {} 700 | 701 | for i in range(3, args.m+1): 702 | gnet_dict[i] = g_Full(2, args.hidden_size, 1) #g_Full(2*2**(i-1), hidden_size, 2**(i-1)) 703 | 704 | 705 | 706 | 707 | fnet_dict = {} 708 | 709 | for i in range(1, args.m-1): 710 | fnet_dict[2*i-1] = f_Full(2, args.hidden_size, 1) #f_Full(2**(m-i+1), hidden_size, 2**(m-i)) 711 | fnet_dict[2*i] = f_Full(1+ 1+ 2, args.hidden_size, 1) #f_Full(1+ 1+ 2*2**(m-i), hidden_size, 1) 712 | 713 | 714 | for i in range(3, args.m+1): 715 | gnet_dict[i].apply(weights_init) 716 | 717 | for i in range(3, args.m+1): 718 | gnet_dict[i].to(device) 719 | 720 | 721 | 722 | for i in range(1, args.m-1): 723 | fnet_dict[2*i-1].apply(weights_init) 724 | fnet_dict[2*i].apply(weights_init) 725 | 726 | for i in range(1, args.m-1): 727 | fnet_dict[2*i-1].to(device) 728 | fnet_dict[2*i].to(device) 729 | 730 | 731 | 732 | 733 | 734 | print("Models are loaded!") 735 | 736 | enc_params = [] 737 | for i in range(3, args.m+1): 738 | enc_params += list(gnet_dict[i].parameters()) 739 | 740 | dec_params = [] 741 | for i in range(1, args.m-1): 742 | dec_params += list(fnet_dict[2*i-1].parameters()) + list(fnet_dict[2*i].parameters()) 743 | 744 | 745 | enc_optimizer = optim.Adam(enc_params, lr = 1e-5)#, momentum=0.9, nesterov=True) #, amsgrad=True) 746 | dec_optimizer = optim.Adam(dec_params, lr = 1e-5)#, momentum=0.9, nesterov=True) #, amsgrad=True) 747 | criterion = nn.BCEWithLogitsLoss() if args.loss_type == 'BCE' else nn.MSELoss() # BinaryFocalLoss() #nn.BCEWithLogitsLoss() # nn.MSELoss() 748 | bers = [] 749 | losses = [] 750 | codebook_size = 1000 751 | 752 | 753 | torch.save(dict(zip(['g{0}'.format(i) for i in range(3, args.m+1)], [gnet_dict[i].state_dict() for i in range(3, 
args.m+1)])),\ 754 | results_save_path+'/Models/Encoder_NN_0.pt') 755 | 756 | torch.save(dict(zip(['f{0}'.format(i) for i in range(1,2*args.m-3)], [fnet_dict[i].state_dict() for i in range(1, 2*args.m-3)])),\ 757 | results_save_path+'/Models/Decoder_NN_0.pt') 758 | 759 | def pairwise_distances(codebook): 760 | dists = [] 761 | for row1, row2 in combinations(codebook, 2): 762 | distance = (row1-row2).pow(2).sum() 763 | dists.append(np.sqrt(distance.item())) 764 | return dists, np.min(dists) 765 | try: 766 | for k in range(args.full_iterations): 767 | start_time = time.time() 768 | msg_bits_large_batch = 2 * (torch.rand(args.batch_size, code_dimension_k) < 0.5).float() - 1 769 | 770 | num_small_batches = int(args.batch_size/args.small_batch_size) 771 | # # Train decoder 772 | for _ in range(args.dec_train_iters): 773 | dec_optimizer.zero_grad() 774 | for i in range(num_small_batches): 775 | start, end = i*args.small_batch_size, (i+1)*args.small_batch_size 776 | msg_bits = msg_bits_large_batch[start:end].to(device) 777 | transmit_codewords = correct_second_order_encoder_Neural_RM_leaves(msg_bits, gnet_dict) 778 | corrupted_codewords = awgn_channel(transmit_codewords, args.dec_train_snr) 779 | decoded_bits = correct_second_order_decoder_nn_full(corrupted_codewords, fnet_dict) 780 | 781 | loss = criterion(decoded_bits, 0.5*msg_bits+0.5)/num_small_batches 782 | 783 | loss.backward() 784 | dec_optimizer.step() 785 | 786 | 787 | # Train Encoder 788 | for _ in range(args.enc_train_iters): 789 | 790 | enc_optimizer.zero_grad() 791 | 792 | for i in range(num_small_batches): 793 | start, end = i*args.small_batch_size, (i+1)*args.small_batch_size 794 | msg_bits = msg_bits_large_batch[start:end].to(device) 795 | 796 | transmit_codewords = correct_second_order_encoder_Neural_RM_leaves(msg_bits, gnet_dict) 797 | corrupted_codewords = awgn_channel(transmit_codewords, args.enc_train_snr) 798 | decoded_bits = correct_second_order_decoder_nn_full(corrupted_codewords, fnet_dict) 799 | 
800 | loss = criterion(decoded_bits, 0.5*msg_bits+0.5 )/num_small_batches 801 | 802 | loss.backward() 803 | 804 | enc_optimizer.step() 805 | 806 | ber = errors_ber(msg_bits, decoded_bits.sign()).item() 807 | 808 | bers.append(ber) 809 | 810 | losses.append(loss.item()) 811 | if k % 10 == 0: 812 | print('[%d/%d] At %d dB, Loss: %.10f BER: %.10f' 813 | % (k+1, args.full_iterations, args.enc_train_snr, loss.item(), ber)) 814 | print("Time for one full iteration is {0:.4f} minutes".format((time.time() - start_time)/60)) 815 | 816 | 817 | # Save the model for safety 818 | if k % 10 == 0: 819 | 820 | torch.save(dict(zip(['g{0}'.format(i) for i in range(3, args.m+1)], [gnet_dict[i].state_dict() for i in range(3, args.m+1)])),\ 821 | results_save_path+'/Models/Encoder_NN_{0}.pt'.format(k+1)) 822 | 823 | torch.save(dict(zip(['f{0}'.format(i) for i in range(1,2*args.m-3)], [fnet_dict[i].state_dict() for i in range(1, 2*args.m-3)])),\ 824 | results_save_path+'/Models/Decoder_NN_{0}.pt'.format(k+1)) 825 | 826 | plt.figure() 827 | plt.plot(bers) 828 | plt.plot(moving_average(bers, n=10)) 829 | plt.savefig(results_save_path +'/training_ber.png') 830 | plt.close() 831 | 832 | plt.figure() 833 | plt.plot(losses) 834 | plt.plot(moving_average(losses, n=10)) 835 | plt.savefig(results_save_path +'/training_losses.png') 836 | plt.close() 837 | 838 | except KeyboardInterrupt: 839 | print('Graceful Exit') 840 | else: 841 | print('Finished') 842 | 843 | plt.figure() 844 | plt.plot(bers) 845 | plt.plot(moving_average(bers, n=10)) 846 | plt.savefig(results_save_path +'/training_ber.png') 847 | plt.close() 848 | 849 | plt.figure() 850 | plt.plot(losses) 851 | plt.plot(moving_average(losses, n=10)) 852 | plt.savefig(results_save_path +'/training_losses.png') 853 | plt.close() 854 | 855 | torch.save(dict(zip(['g{0}'.format(i) for i in range(3, args.m+1)], [gnet_dict[i].state_dict() for i in range(3, args.m+1)])),\ 856 | results_save_path+'/Models/Encoder_NN.pt') 857 | 858 | 
torch.save(dict(zip(['f{0}'.format(i) for i in range(1,2*args.m-3)], [fnet_dict[i].state_dict() for i in range(1, 2*args.m-3)])),\ 859 | results_save_path+'/Models/Decoder_NN.pt') 860 | 861 | 862 | 863 | -------------------------------------------------------------------------------- /train_Polar_m6k7.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | from __future__ import print_function 5 | import argparse 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | from torchvision import datasets, transforms 11 | from torch.autograd import Variable 12 | from torchvision.utils import save_image 13 | from torchvision.utils import make_grid 14 | import torch.utils.data 15 | from data_loader import * 16 | from IPython import display 17 | 18 | import pickle 19 | import glob 20 | import os 21 | import logging 22 | import time 23 | from datetime import datetime 24 | from ast import literal_eval 25 | import matplotlib 26 | # matplotlib.use('AGG') 27 | import matplotlib.pyplot as plt 28 | import matplotlib.animation as animation 29 | from PIL import Image 30 | 31 | 32 | from opt_einsum import contract # This is for faster torch.einsum 33 | 34 | 35 | import pandas as pd 36 | import numpy as np 37 | from scipy.stats import norm 38 | from tqdm import tqdm 39 | from itertools import combinations 40 | 41 | parser = argparse.ArgumentParser(description='(m,k) Polar') 42 | 43 | parser.add_argument('--m', type=int, default=6, help='number of layers in a polar code m') 44 | 45 | parser.add_argument('--batch_size', type=int, default=20000, help='size of the batches') 46 | parser.add_argument('--hidden_size', type=int, default=64, help='neural network size') 47 | 48 | parser.add_argument('--full_iterations', type=int, default=10000, help='full iterations') 49 | parser.add_argument('--enc_train_iters', type=int, default=50, help='encoder 
iterations') 50 | parser.add_argument('--dec_train_iters', type=int, default=500, help='decoder iterations') 51 | 52 | parser.add_argument('--enc_train_snr', type=float, default=-0.5, help='snr at which the encoder is trained') 53 | parser.add_argument('--dec_train_snr', type=float, default=-2.5, help='snr at which the decoder is trained') 54 | 55 | parser.add_argument('--loss_type', type=str, default='BCE', choices=['MSE', 'BCE'], help='loss function') 56 | 57 | parser.add_argument('--gpu', type=int, default=0, help='gpu used for training - e.g. 0') 58 | 59 | args = parser.parse_args() 60 | 61 | device = torch.device("cuda:{0}".format(args.gpu)) 62 | # device = torch.device("cpu") 63 | 64 | kwargs = {'num_workers': 4, 'pin_memory': False} 65 | 66 | results_save_path = './Results/Polar({0},{1})/NN_EncFull_Skip+Dec_SC/Enc_snr_{2}_Dec_snr{3}/Batch_{4}' \ 67 | .format(2**args.m, args.m+1, args.enc_train_snr, args.dec_train_snr, args.batch_size) 68 | os.makedirs(results_save_path, exist_ok=True) 69 | os.makedirs(results_save_path + '/Models', exist_ok=True) 70 | 71 | 72 | def repetition_code_matrices(device, m=8): 73 | M_dict = {} 74 | 75 | for i in range(1, m): 76 | M_dict[i] = torch.ones(1, 2 ** i).to(device) 77 | 78 | return M_dict 79 | 80 | 81 | repetition_M_dict = repetition_code_matrices(device, args.m) 82 | 83 | print("Matrices required for repetition code are defined!") 84 | 85 | 86 | ###### 87 | ## Functions 88 | ###### 89 | 90 | def snr_db2sigma(train_snr): 91 | return 10 ** (-train_snr * 1.0 / 20) 92 | 93 | 94 | def log_sum_exp(LLR_vector): 95 | sum_vector = LLR_vector.sum(dim=1, keepdim=True) 96 | sum_concat = torch.cat([sum_vector, torch.zeros_like(sum_vector)], dim=1) 97 | 98 | return torch.logsumexp(sum_concat, dim=1) - torch.logsumexp(LLR_vector, dim=1) 99 | 100 | 101 | def errors_ber(y_true, y_pred): 102 | y_true = y_true.view(y_true.shape[0], -1, 1) 103 | y_pred = y_pred.view(y_pred.shape[0], -1, 1) 104 | 105 | myOtherTensor = torch.ne(torch.round(y_true),
torch.round(y_pred)).float() 106 | res = sum(sum(myOtherTensor)) / (myOtherTensor.shape[0] * myOtherTensor.shape[1]) 107 | return res 108 | 109 | 110 | def errors_bler(y_true, y_pred): 111 | y_true = y_true.view(y_true.shape[0], -1, 1) 112 | y_pred = y_pred.view(y_pred.shape[0], -1, 1) 113 | 114 | decoded_bits = torch.round(y_pred).cpu() 115 | X_test = torch.round(y_true).cpu() 116 | tp0 = (abs(decoded_bits - X_test)).view([X_test.shape[0], X_test.shape[1]]) 117 | tp0 = tp0.detach().cpu().numpy() 118 | bler_err_rate = sum(np.sum(tp0, axis=1) > 0) * 1.0 / (X_test.shape[0]) 119 | return bler_err_rate 120 | 121 | 122 | class g_identity(nn.Module): 123 | def __init__(self): 124 | super(g_identity, self).__init__() 125 | self.fc = nn.Linear(1, 1, bias=False) 126 | 127 | def forward(self, y): 128 | return y 129 | 130 | 131 | class g_vector(nn.Module): 132 | def __init__(self): 133 | super(g_vector, self).__init__() 134 | self.fc = nn.Linear(16, 1, bias=True) 135 | 136 | def forward(self, y): 137 | return self.fc(y) 138 | 139 | 140 | class g_Full(nn.Module): 141 | def __init__(self, input_size, hidden_size, output_size): 142 | super(g_Full, self).__init__() 143 | 144 | self.input_size = input_size 145 | 146 | self.half_input_size = int(input_size / 2) 147 | 148 | self.hidden_size = hidden_size 149 | self.output_size = output_size 150 | 151 | self.fc1 = nn.Linear(self.input_size, self.hidden_size, bias=True) 152 | self.fc2 = nn.Linear(self.hidden_size, self.hidden_size, bias=True) 153 | self.fc3 = nn.Linear(self.hidden_size, self.hidden_size, bias=True) 154 | self.fc4 = nn.Linear(self.hidden_size, self.output_size, bias=True) 155 | 156 | self.skip = nn.Linear(3 * self.half_input_size, self.hidden_size, bias=False) 157 | 158 | def forward(self, y): 159 | x = F.selu(self.fc1(y)) 160 | x = F.selu(self.fc2(x)) + self.skip( 161 | torch.cat([y, y[:, :self.half_input_size] * y[:, self.half_input_size:]], dim=1)) 162 | 163 | x = F.selu(self.fc3(x)) 164 | x = self.fc4(x) 165 | return
x 166 | 167 | 168 | class f_Full(nn.Module): 169 | def __init__(self, input_size, hidden_size, output_size): 170 | super(f_Full, self).__init__() 171 | self.input_size = input_size 172 | self.hidden_size = hidden_size 173 | self.output_size = output_size 174 | 175 | self.fc1 = nn.Linear(self.input_size, self.hidden_size, bias=True) 176 | self.fc2 = nn.Linear(self.hidden_size, self.hidden_size, bias=True) 177 | self.fc3 = nn.Linear(self.hidden_size, self.hidden_size, bias=True) 178 | self.fc4 = nn.Linear(self.hidden_size, self.output_size, bias=True) 179 | 180 | def forward(self, y): 181 | x = F.selu(self.fc1(y)) 182 | x = F.selu(self.fc2(x)) 183 | 184 | x = F.selu(self.fc3(x)) 185 | x = self.fc4(x) 186 | return x 187 | 188 | 189 | def power_constraint(codewords, gnet_top, power_constraint_type, training_mode): 190 | if power_constraint_type in ['soft_power_block', 'soft_power_bit']: 191 | 192 | this_mean = codewords.mean(dim=0) if power_constraint_type == 'soft_power_bit' else codewords.mean() 193 | this_std = codewords.std(dim=0) if power_constraint_type == 'soft_power_bit' else codewords.std() 194 | 195 | if training_mode == 'train': # Training 196 | power_constrained_codewords = (codewords - this_mean) * 1.0 / this_std 197 | 198 | gnet_top.update_normstats_for_test(this_mean, this_std) 199 | 200 | elif training_mode == 'test': # For inference 201 | power_constrained_codewords = (codewords - gnet_top.mean_scalar) * 1.0 / gnet_top.std_scalar 202 | 203 | 204 | 205 | return power_constrained_codewords 206 | 207 | 208 | elif power_constraint_type == 'hard_power_block': 209 | 210 | return F.normalize(codewords, p=2, dim=1) * np.sqrt(2 ** args.m) 211 | 212 | 213 | else: 214 | 215 | return codewords / codewords.abs() 216 | 217 | 218 | 219 | ## Encoding of Polar Codes ## 220 | # The following is only for polar(n=64,k=7) 221 | 222 | def encoder_Polar_Plotkin(msg_bits): 223 | 224 | u_level1 = torch.cat([msg_bits[:, 6:7], msg_bits[:, 6:7] * msg_bits[:, 5:6]], dim=1) 225 | 
v_level1 = torch.cat([msg_bits[:, 4:5], msg_bits[:, 4:5] * msg_bits[:, 3:4]], dim=1) 226 | 227 | for i in range(2, args.m - 1): 228 | u_level1 = torch.cat([u_level1, u_level1 * v_level1], dim=1) 229 | v_level1 = msg_bits[:, 4-i:5-i].mm(repetition_M_dict[i]) 230 | 231 | u_level5 = torch.cat([u_level1, u_level1 * v_level1], dim=1) 232 | 233 | u_level6 = torch.cat([u_level5, u_level5], dim=1) 234 | 235 | return u_level6 236 | 237 | 238 | def encoder_Polar_full(msg_bits, gnet_dict, power_constraint_type='hard_power_block', 239 | training_mode='train'): # g_avector, g_bvector, 240 | 241 | u_level1 = torch.cat([msg_bits[:, 6:7], gnet_dict[1, 'right'](torch.cat([msg_bits[:, 6:7], msg_bits[:, 5:6]], dim=1)) ], dim=1) 242 | v_level1 = torch.cat([msg_bits[:, 4:5], gnet_dict[1, 'left'](torch.cat([msg_bits[:, 4:5], msg_bits[:, 3:4]], dim=1))], dim=1) 243 | 244 | for i in range(2, args.m - 1): 245 | u_level1 = torch.cat([u_level1, gnet_dict[i](torch.cat([u_level1, v_level1], dim=1)) ], dim=1) 246 | v_level1 = msg_bits[:, 4-i:5-i].mm(repetition_M_dict[i]) 247 | 248 | u_level5 = torch.cat([u_level1, gnet_dict[args.m-1](torch.cat([u_level1, v_level1], dim=1)) ], dim=1) 249 | 250 | u_level6 = torch.cat([u_level5, u_level5], dim=1) 251 | 252 | 253 | return power_constraint(u_level6, gnet_dict[args.m], power_constraint_type, training_mode) 254 | 255 | 256 | def awgn_channel(codewords, snr): 257 | noise_sigma = snr_db2sigma(snr) 258 | standard_Gaussian = torch.randn_like(codewords) 259 | corrupted_codewords = codewords + noise_sigma * standard_Gaussian 260 | return corrupted_codewords 261 | 262 | def decoder_Polar_SC(corrupted_codewords, snr): 263 | noise_sigma = snr_db2sigma(snr) 264 | 265 | llrs = (2 / noise_sigma ** 2) * corrupted_codewords 266 | Lu = llrs 267 | Lu = Lu[:, 32:] + Lu[:, :32] 268 | 269 | decoded_bits = torch.zeros(corrupted_codewords.shape[0], args.m + 1).to(device) 270 | 271 | for i in range(args.m - 2, 1, -1): 272 | Lv = log_sum_exp(torch.cat([Lu[:, :2 ** 
i].unsqueeze(2), Lu[:, 2 ** i:].unsqueeze(2)], dim=2).permute(0, 2, 1)).sum(dim=1, keepdim=True) 273 | v_hat = torch.sign(Lv) 274 | decoded_bits[:, 4 - i] = v_hat.squeeze(1) 275 | Lu = Lu[:, :2 ** i] + v_hat * Lu[:, 2 ** i:] 276 | 277 | 278 | Lu2 = Lu 279 | Lv1 = log_sum_exp(torch.cat([Lu2[:, 0:2].unsqueeze(2), Lu2[:, 2:4].unsqueeze(2)], dim=2).permute(0, 2, 1)) 280 | L_u3 = log_sum_exp(torch.cat([Lv1[:, 0:1].unsqueeze(2), Lv1[:, 1:2].unsqueeze(2)], dim=2).permute(0, 2, 1)) 281 | u3_hat = torch.sign(L_u3) 282 | decoded_bits[:, 3] = u3_hat.squeeze(1) 283 | 284 | L_u4 = Lv1[:, 0:1] + u3_hat * Lv1[:, 1:2] 285 | u4_hat = torch.sign(L_u4) 286 | decoded_bits[:, 4] = u4_hat.squeeze(1) 287 | 288 | v1_hat = torch.cat([decoded_bits[:, 4:5], decoded_bits[:, 4:5] * decoded_bits[:, 3:4]], dim=1) 289 | Lu1 = Lu2[:, 0:2] + v1_hat * Lu2[:, 2:4] 290 | L_u5 = log_sum_exp(torch.cat([Lu1[:, 0:1].unsqueeze(2), Lu1[:, 1:2].unsqueeze(2)], dim=2).permute(0, 2, 1)) 291 | u5_hat = torch.sign(L_u5) 292 | decoded_bits[:, 5] = u5_hat.squeeze(1) 293 | 294 | L_u6 = Lu1[:, 0:1] + u5_hat * Lu1[:, 1:2] 295 | u6_hat = torch.sign(L_u6) 296 | decoded_bits[:, 6] = u6_hat.squeeze(1) 297 | 298 | return decoded_bits 299 | 300 | 301 | def decoder_Polar_SC_soft(corrupted_codewords, snr): 302 | noise_sigma = snr_db2sigma(snr) 303 | 304 | llrs = (2 / noise_sigma ** 2) * corrupted_codewords 305 | Lu = llrs 306 | Lu = Lu[:, 32:] + Lu[:, :32] 307 | 308 | decoded_bits = torch.zeros(corrupted_codewords.shape[0], args.m + 1).to(device) 309 | 310 | for i in range(args.m - 2, 1, -1): 311 | Lv = log_sum_exp(torch.cat([Lu[:, :2 ** i].unsqueeze(2), Lu[:, 2 ** i:].unsqueeze(2)], dim=2).permute(0, 2, 1)).sum(dim=1, keepdim=True) 312 | v_hat = torch.tanh(Lv/2) 313 | decoded_bits[:, 4 - i] = v_hat.squeeze(1) 314 | Lu = Lu[:, :2 ** i] + v_hat * Lu[:, 2 ** i:] 315 | 316 | 317 | Lu2 = Lu 318 | Lv1 = log_sum_exp(torch.cat([Lu2[:, 0:2].unsqueeze(2), Lu2[:, 2:4].unsqueeze(2)], dim=2).permute(0, 2, 1)) 319 | L_u3 = 
log_sum_exp(torch.cat([Lv1[:, 0:1].unsqueeze(2), Lv1[:, 1:2].unsqueeze(2)], dim=2).permute(0, 2, 1)) 320 | u3_hat = torch.tanh(L_u3/2) 321 | decoded_bits[:, 3] = u3_hat.squeeze(1) 322 | 323 | L_u4 = Lv1[:, 0:1] + u3_hat * Lv1[:, 1:2] 324 | u4_hat = torch.tanh(L_u4/2) 325 | decoded_bits[:, 4] = u4_hat.squeeze(1) 326 | 327 | v1_hat = torch.cat([decoded_bits[:, 4:5], decoded_bits[:, 4:5] * decoded_bits[:, 3:4]], dim=1) 328 | Lu1 = Lu2[:, 0:2] + v1_hat * Lu2[:, 2:4] 329 | L_u5 = log_sum_exp(torch.cat([Lu1[:, 0:1].unsqueeze(2), Lu1[:, 1:2].unsqueeze(2)], dim=2).permute(0, 2, 1)) 330 | u5_hat = torch.tanh(L_u5/2) 331 | decoded_bits[:, 5] = u5_hat.squeeze(1) 332 | 333 | L_u6 = Lu1[:, 0:1] + u5_hat * Lu1[:, 1:2] 334 | u6_hat = torch.tanh(L_u6/2) 335 | decoded_bits[:, 6] = u6_hat.squeeze(1) 336 | 337 | return decoded_bits 338 | 339 | 340 | def decoder_Polar_nn_full(corrupted_codewords, fnet_dict): 341 | 342 | Lu = corrupted_codewords 343 | Lu = Lu[:, 32:] + Lu[:, :32] 344 | 345 | decoded_llrs = torch.zeros(corrupted_codewords.shape[0], args.m + 1).to(device) 346 | 347 | for i in range(args.m - 2, 1, -1): 348 | Lv = fnet_dict[i+1, 'left'](Lu) 349 | decoded_llrs[:, 4 - i] = Lv.squeeze(1) 350 | v_hat = torch.tanh(Lv/2) 351 | Lu = fnet_dict[i+1, 'right'](torch.cat([Lu[:, :2 ** i].unsqueeze(2), Lu[:, 2 ** i:].unsqueeze(2), v_hat.unsqueeze(1).repeat(1, 2 ** i, 1)],dim=2)).squeeze(2) 352 | 353 | 354 | Lu2 = Lu 355 | Lv1 = fnet_dict[2, 'left'](Lu2) 356 | L_u3 = fnet_dict[1, 'left', 'left'](Lv1) 357 | decoded_llrs[:, 3] = L_u3.squeeze(1) 358 | u3_hat = torch.tanh(0.5 * L_u3) 359 | 360 | L_u4 = fnet_dict[1, 'left', 'right'](torch.cat([Lv1[:, 0:1].unsqueeze(2), Lv1[:, 1:2].unsqueeze(2), u3_hat.unsqueeze(1).repeat(1, 1, 1)],dim=2)).squeeze(2) 361 | decoded_llrs[:, 4] = L_u4.squeeze(1) 362 | u4_hat = torch.tanh(0.5 * L_u4) 363 | 364 | v1_hat = torch.cat([u4_hat, gnet_dict[1, 'left'](torch.cat([torch.sign(L_u4), torch.sign(L_u3)], dim=1)) ], dim=1) 365 | Lu1 = fnet_dict[2, 
'right'](torch.cat([Lu2[:, :2].unsqueeze(2), Lu2[:, 2:].unsqueeze(2), v1_hat.unsqueeze(2)],dim=2)).squeeze(2) 366 | L_u5 = fnet_dict[1, 'right', 'left'](Lu1) 367 | decoded_llrs[:, 5] = L_u5.squeeze(1) 368 | u5_hat = torch.tanh(0.5 * L_u5) 369 | 370 | L_u6 = fnet_dict[1, 'right', 'right'](torch.cat([Lu1[:, 0:1].unsqueeze(2), Lu1[:, 1:2].unsqueeze(2), u5_hat.unsqueeze(1).repeat(1, 1, 1)],dim=2)).squeeze(2) 371 | decoded_llrs[:, 6] = L_u6.squeeze(1) 372 | 373 | 374 | return decoded_llrs 375 | 376 | 377 | def get_msg_bits_batch(data_generator): 378 | msg_bits_batch = next(data_generator) 379 | return msg_bits_batch 380 | 381 | 382 | def moving_average(a, n=3): 383 | ret = np.cumsum(a, dtype=float) 384 | ret[n:] = ret[n:] - ret[:-n] 385 | return ret[n - 1:] / n 386 | 387 | 388 | 389 | print("Data loading stuff is completed! \n") 390 | 391 | gnet_dict = {} 392 | gnet_dict[1, 'left'] = g_Full(2, args.hidden_size, 1) 393 | gnet_dict[1, 'right'] = g_Full(2, args.hidden_size, 1) 394 | for i in range(2, args.m + 1): 395 | gnet_dict[i] = g_Full(2 * 2 ** (i - 1), args.hidden_size, 2 ** (i - 1)) 396 | 397 | fnet_dict = {} 398 | for i in range(3, 6): 399 | fnet_dict[i, 'left'] = f_Full(2 ** i, args.hidden_size, 1) 400 | fnet_dict[i, 'right'] = f_Full(1 + 1 + 1, args.hidden_size, 1) 401 | 402 | fnet_dict[2, 'left'] = f_Full(4, args.hidden_size, 2) 403 | fnet_dict[2, 'right'] = f_Full(1 + 1 + 1, args.hidden_size, 1) 404 | 405 | fnet_dict[1, 'left', 'left'] = f_Full(2, args.hidden_size, 1) 406 | fnet_dict[1, 'left', 'right'] = f_Full(1 + 1 + 1, args.hidden_size, 1) 407 | 408 | fnet_dict[1, 'right', 'left'] = f_Full(2, args.hidden_size, 1) 409 | fnet_dict[1, 'right', 'right'] = f_Full(1 + 1 + 1, args.hidden_size, 1) 410 | 411 | # Now load them onto devices 412 | 413 | gnet_dict[1, 'left'].to(device) 414 | gnet_dict[1, 'right'].to(device) 415 | for i in range(2, args.m + 1): 416 | gnet_dict[i].to(device) 417 | 418 | for i in range(2, 6): 419 | fnet_dict[i, 'left'].to(device) 420 | 
fnet_dict[i, 'right'].to(device) 421 | fnet_dict[1, 'left', 'left'].to(device) 422 | fnet_dict[1, 'left', 'right'].to(device) 423 | fnet_dict[1, 'right', 'left'].to(device) 424 | fnet_dict[1, 'right', 'right'].to(device) 425 | 426 | print("Models are loaded!") 427 | 428 | enc_params = [] 429 | enc_params += list(gnet_dict[1, 'left'].parameters()) + list(gnet_dict[1, 'right'].parameters()) 430 | for i in range(2, args.m + 1): 431 | enc_params += list(gnet_dict[i].parameters()) 432 | 433 | dec_params = [] 434 | for i in range(2, args.m): 435 | dec_params += list(fnet_dict[i, 'left'].parameters()) + list(fnet_dict[i, 'right'].parameters()) 436 | dec_params += list(fnet_dict[1, 'left', 'left'].parameters()) + list(fnet_dict[1, 'left', 'right'].parameters()) 437 | dec_params += list(fnet_dict[1, 'right', 'left'].parameters()) + list(fnet_dict[1, 'right', 'right'].parameters()) 438 | 439 | 440 | enc_optimizer = optim.Adam(enc_params, lr=1e-5) 441 | dec_optimizer = optim.Adam(dec_params, lr=1e-4) 442 | criterion = nn.BCEWithLogitsLoss() if args.loss_type == 'BCE' else nn.MSELoss() 443 | 444 | bers = [] 445 | losses = [] 446 | 447 | try: 448 | for k in range(args.full_iterations): 449 | start_time = time.time() 450 | msg_bits = 2 * (torch.rand(args.batch_size, args.m + 1) < 0.5).float() - 1 451 | msg_bits = msg_bits.to(device) 452 | 453 | # # Train decoder 454 | for _ in range(args.dec_train_iters): 455 | 456 | transmit_codewords = encoder_Polar_full(msg_bits, gnet_dict) 457 | corrupted_codewords = awgn_channel(transmit_codewords, args.dec_train_snr) 458 | decoded_bits = decoder_Polar_nn_full(corrupted_codewords, fnet_dict) 459 | 460 | loss = criterion(decoded_bits, 0.5 * msg_bits + 0.5) 461 | 462 | dec_optimizer.zero_grad() 463 | loss.backward() 464 | dec_optimizer.step() 465 | 466 | # Train Encoder 467 | for _ in range(args.enc_train_iters): 468 | 469 | transmit_codewords = encoder_Polar_full(msg_bits, gnet_dict) 470 | corrupted_codewords = 
awgn_channel(transmit_codewords, args.enc_train_snr) 471 | decoded_bits = decoder_Polar_nn_full(corrupted_codewords, fnet_dict) 472 | 473 | loss = criterion(decoded_bits, 0.5 * msg_bits + 0.5) 474 | 475 | enc_optimizer.zero_grad() 476 | loss.backward() 477 | enc_optimizer.step() 478 | 479 | ber = errors_ber(msg_bits, decoded_bits.sign()).item() 480 | 481 | bers.append(ber) 482 | 483 | losses.append(loss.item()) 484 | if k % 10 == 0: 485 | print('[%d/%d] At %d dB, Loss: %.7f BER: %.7f' 486 | % (k + 1, args.full_iterations, args.enc_train_snr, loss.item(), ber)) 487 | print("Time for one full iteration is {0:.4f} minutes".format((time.time() - start_time) / 60)) 488 | 489 | # Save the model for safety 490 | if (k + 1) % 100 == 0: 491 | torch.save(dict(zip(['g{0}'.format(i) for i in range(2, args.m + 1)] + ['g1_left', 'g1_right'], 492 | [gnet_dict[i].state_dict() for i in range(2, args.m + 1)] + [gnet_dict[1, 'left'].state_dict(), gnet_dict[1, 'right'].state_dict()] )), \ 493 | results_save_path + '/Models/Encoder_NN_{0}.pt'.format(k + 1)) 494 | 495 | torch.save(dict(zip(['f{0}_left'.format(i) for i in range(2, 6)] + ['f{0}_right'.format(i) for i in range(2, 6)] + \ 496 | ['f1_left_left', 'f1_left_right', 'f1_right_left', 'f1_right_right'], 497 | [fnet_dict[i, 'left'].state_dict() for i in range(2, 6)] + [fnet_dict[i, 'right'].state_dict() for i in range(2, 6)] + \ 498 | [fnet_dict[1, 'left', 'left'].state_dict(), fnet_dict[1, 'left', 'right'].state_dict(), \ 499 | fnet_dict[1, 'right', 'left'].state_dict(), fnet_dict[1, 'right', 'right'].state_dict() ] )), \ 500 | results_save_path + '/Models/Decoder_NN_{0}.pt'.format(k + 1)) 501 | 502 | 503 | 504 | plt.figure() 505 | plt.plot(bers) 506 | plt.plot(moving_average(bers, n=10)) 507 | plt.savefig(results_save_path + '/training_ber.png') 508 | plt.close() 509 | 510 | plt.figure() 511 | plt.plot(losses) 512 | plt.plot(moving_average(losses, n=10)) 513 | plt.savefig(results_save_path + '/training_losses.png') 514 | 
plt.close() 515 | 516 | except KeyboardInterrupt: 517 | print('Graceful Exit') 518 | else: 519 | print('Finished') 520 | 521 | plt.figure() 522 | plt.plot(bers) 523 | plt.plot(moving_average(bers, n=10)) 524 | plt.savefig(results_save_path + '/training_ber.png') 525 | plt.close() 526 | 527 | plt.figure() 528 | plt.plot(losses) 529 | plt.plot(moving_average(losses, n=10)) 530 | plt.savefig(results_save_path + '/training_losses.png') 531 | plt.close() 532 | 533 | torch.save(dict(zip(['g{0}'.format(i) for i in range(2, args.m + 1)] + ['g1_left', 'g1_right'], 534 | [gnet_dict[i].state_dict() for i in range(2, args.m + 1)] + [gnet_dict[1, 'left'].state_dict(), gnet_dict[1, 'right'].state_dict()])), \ 535 | results_save_path + '/Models/Encoder_NN_{0}.pt'.format(k + 1)) 536 | 537 | torch.save(dict(zip(['f{0}_left'.format(i) for i in range(2, 6)] + ['f{0}_right'.format(i) for i in range(2, 6)] + \ 538 | ['f1_left_left', 'f1_left_right', 'f1_right_left', 'f1_right_right'], 539 | [fnet_dict[i, 'left'].state_dict() for i in range(2, 6)] + [fnet_dict[i, 'right'].state_dict() for i 540 | in range(2, 6)] + \ 541 | [fnet_dict[1, 'left', 'left'].state_dict(), 542 | fnet_dict[1, 'left', 'right'].state_dict(), \ 543 | fnet_dict[1, 'right', 'left'].state_dict(), fnet_dict[1, 'right', 'right'].state_dict()])), \ 544 | results_save_path + '/Models/Decoder_NN_{0}.pt'.format(k + 1)) 545 | --------------------------------------------------------------------------------
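The `log_sum_exp(torch.cat([...]).permute(0, 2, 1))` pattern repeated throughout the SC and Dumer decoders is the LLR check-node ("boxplus") rule. A minimal stdlib-only sketch of the scalar identity it computes (the helper name `boxplus` is mine, not from the repo; the decoders apply the same formula batch-wise):

```python
import math

def boxplus(l1, l2):
    # Soft-XOR (check-node) combination of two LLRs, as computed by
    # log_sum_exp: log((1 + e^(l1+l2)) / (e^l1 + e^l2))
    return math.log1p(math.exp(l1 + l2)) - math.log(math.exp(l1) + math.exp(l2))

# Equivalent tanh form, which is why the soft decoders use
# tanh(L/2) for soft bit estimates: 2*atanh(tanh(l1/2)*tanh(l2/2))
l1, l2 = 0.7, 2.1
ref = 2 * math.atanh(math.tanh(l1 / 2) * math.tanh(l2 / 2))
assert abs(boxplus(l1, l2) - ref) < 1e-9
```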
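`encoder_Polar_Plotkin` builds each level with the Plotkin (u, u+v) construction; because message bits are mapped to {+1, -1}, GF(2) addition becomes elementwise multiplication, which is why the encoder concatenates `u` with `u * v`. A small illustrative sketch (the standalone `plotkin` helper is hypothetical, written to mirror that pattern on plain lists):

```python
# Plotkin construction in the +/-1 (BPSK) domain: XOR over {0,1}
# corresponds to elementwise product over {+1,-1}.
def plotkin(u, v):
    assert len(u) == len(v)
    return u + [ui * vi for ui, vi in zip(u, v)]

u = [1, -1, 1, -1]
v = [1, 1, -1, -1]
# Codeword doubles in length: first half is u, second half is u*v.
assert plotkin(u, v) == [1, -1, 1, -1, 1, -1, -1, 1]
```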