├── README.md
├── E2E_DC_ECCT.py
└── DC_ECCT.py

/README.md:
--------------------------------------------------------------------------------
# Learning Linear Block Error Correction Codes

Implementation of the end-to-end Deep Coding Error Correction Code Transformer described in ["Learning Linear Block Error Correction Codes" (ICML 2024)](https://arxiv.org/abs/2405.04050).

The decoder implementation is related to the Foundation Error Correction Codes model published in ["A Foundation Model for Error Correction Codes" (ICLR 2024)](https://openreview.net/forum?id=7KDuQPrAF3).


## Abstract
Error correction codes are a crucial part of the physical communication layer, ensuring the reliable transfer of data over noisy channels. The design of optimal linear block codes capable of being efficiently decoded is of major concern, especially for short block lengths. While neural decoders have recently demonstrated their advantage over classical decoding techniques, the neural design of the codes remains a challenge. In this work, we propose for the first time a unified encoder-decoder training of binary linear block codes. To this end, we adapt the coding setting to support efficient and differentiable training of the code for end-to-end optimization over the order-two Galois field. We also propose a novel Transformer model in which the self-attention masking is performed in a differentiable fashion for the efficient backpropagation of the code gradient. Our results show that (i) the proposed decoder outperforms existing neural decoders on conventional codes, (ii) the suggested framework generates codes that outperform the analogous conventional codes, and (iii) the codes we developed not only excel with our decoder but also show enhanced performance with traditional decoding techniques.

## Install
- PyTorch

## Script
Use the following command to train a toy example. All settings (code parameters, model size, training hyperparameters) can be modified in the main function.

`python E2E_DC_ECCT.py`
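
## Example: differentiable GF(2) encoding
As a quick sanity check of the differentiable binary arithmetic used for end-to-end training, the sketch below (an illustrative snippet, not part of the repository) verifies that `diff_gener` from `E2E_DC_ECCT.py` reproduces standard modulo-2 encoding. It assumes the snippet is run from the repository root so the module can be imported.

```python
import torch
from E2E_DC_ECCT import diff_gener

k, n, bs = 4, 7, 8
G = torch.randint(0, 2, (k, n)).float()   # random binary generator matrix
m = torch.randint(0, 2, (bs, k)).float()  # random binary messages

x_soft = diff_gener(G, m)                 # mod-2 encoding via products in the +/-1 domain
x_ref = (m @ G) % 2                       # reference mod-2 encoding
assert torch.allclose(x_soft, x_ref)
```

Because each mod-2 sum is computed as a product of ±1 values rather than with a hard XOR, the encoding stays differentiable with respect to the generator matrix.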

--------------------------------------------------------------------------------
/E2E_DC_ECCT.py:
--------------------------------------------------------------------------------
"""
Deep Coding for Linear Block Error Correction
"""
import torch
import torch.nn as nn

def sign_to_bin(x):
    # Map {+1, -1} to {0, 1}.
    return 0.5 * (1 - x)

def bin_to_sign(x):
    # Map {0, 1} to {+1, -1}.
    return 1 - 2 * x

def diff_syndrome(H, x):
    # Differentiable syndrome s = Hx (mod 2): each mod-2 sum is computed as a
    # product over the corresponding +/-1 values.
    tmp = bin_to_sign(H.unsqueeze(0) * x.unsqueeze(1))
    tmp = torch.prod(tmp, 2)
    return sign_to_bin(tmp)

def diff_gener(G, m):
    # Differentiable encoding x = mG (mod 2), computed the same way.
    tmp = bin_to_sign(G.unsqueeze(0) * m.unsqueeze(2))
    tmp = torch.prod(tmp, 1)
    return sign_to_bin(tmp)

class Binarization(torch.autograd.Function):
    # Straight-through estimator: hard sign in the forward pass, identity
    # gradient on [-1, 1] in the backward pass.
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return ((input >= 0) * 1. - (input < 0) * 1.).float()

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return grad_output * (torch.abs(x) <= 1)

class E2E_DC_ECC_Transformer(nn.Module):
    def __init__(self, args, decoder):
        super(E2E_DC_ECC_Transformer, self).__init__()
        ####
        self.args = args
        code = args.code
        self.n = code.n
        self.k = code.k
        self.bin = Binarization.apply
        with torch.no_grad():
            P_matrix = (torch.randint(0, 2, (code.n - code.k, code.k))).float()
            P_matrix = bin_to_sign(P_matrix) * 0.01
        self.P_matrix = nn.Parameter(P_matrix)
        # self.register_buffer('P_matrix', P_matrix)
        self.register_buffer('I_matrix_H', torch.eye(code.n - code.k))
        self.register_buffer('I_matrix_G', torch.eye(code.k))
        #
        self.decoder = decoder
        ########

    def forward(self, m, z):
        # Encode with the current (learned) generator matrix, modulate to +/-1,
        # and pass the word through the AWGN channel.
        x = diff_gener(self.get_generator_matrix(), m)
        x = bin_to_sign(x)
        z_mul = ((x + z) * x).detach()  # multiplicative noise, the decoder's target
        y = x * z_mul                   # received word y = x + z
        syndrome = bin_to_sign(diff_syndrome(self.get_pc_matrix(), sign_to_bin(self.bin(y))))
        magnitude = torch.abs(y)
        emb, loss, x_pred = self.decoder(magnitude, syndrome, self.get_pc_matrix(), z_mul, y, self.get_pc_matrix())
        return loss, x_pred, sign_to_bin(x)

    def get_pc_matrix(self):
        # Systematic parity-check matrix H = [I_{n-k} | P], with P binarized via the STE.
        bin_P = sign_to_bin(self.bin(self.P_matrix))
        return torch.cat([self.I_matrix_H, bin_P], 1)

    def get_generator_matrix(self):
        # Matching systematic generator matrix G = [P^T | I_k].
        bin_P = sign_to_bin(self.bin(self.P_matrix))
        return torch.cat([bin_P, self.I_matrix_G], 0).transpose(0, 1)

############################################################
############################################################
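
# --- Illustrative helper (not part of the original script). Because H = [I | P] and
# G = [P^T | I] share the same P, H @ G^T = P + P = 0 (mod 2), so every encoded word
# is a valid codeword of the learned code. The hypothetical function below makes this
# check explicit for a model instance.
def _check_code_consistency(model):
    H = model.get_pc_matrix()
    G = model.get_generator_matrix()
    return bool(torch.all((H @ G.transpose(0, 1)) % 2 == 0))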

if __name__ == '__main__':
    from DC_ECCT import DC_ECC_Transformer
    import numpy as np

    class Code():
        pass

    def EbN0_to_std(EbN0, rate):
        snr = EbN0 + 10. * np.log10(2 * rate)
        return np.sqrt(1. / (10. ** (snr / 10.)))

    code = Code()
    code.k = 16
    code.n = 31

    args = Code()
    args.code = code
    args.d_model = 32
    args.h = 8
    args.N_dec = 2
    args.dropout_attn = 0
    args.dropout = 0

    bs = 1024
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = E2E_DC_ECC_Transformer(args, DC_ECC_Transformer(args)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


    EbNo_range_train = range(2, 8)
    std_train = torch.tensor([EbN0_to_std(ebn0, code.k / code.n) for ebn0 in EbNo_range_train]).float()

    m = torch.ones((bs, code.k)).long().to(device)
    H0 = model.get_pc_matrix().detach().clone()
    for iter in range(10000):
        model.zero_grad()
        stds = std_train[torch.randperm(bs) % len(std_train)]
        loss, x_pred, x = model(m, (torch.randn(bs, code.n) * stds.unsqueeze(-1)).to(device))
        loss.backward()
        optimizer.step()
        if iter % 1000 == 0:
            print(f'iter {iter}: loss = {loss.item()} BER = {torch.mean((x_pred != x).float()).item()} ||H_t-H0||_1 = {torch.sum((H0 - model.get_pc_matrix()).abs())}')
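
    # --- Illustrative addition (not in the original script): a quick post-training
    # evaluation on random messages at a fixed Eb/N0. Assumes the variables defined
    # above (model, code, bs, device, EbN0_to_std) are in scope.
    with torch.no_grad():
        test_std = float(EbN0_to_std(4.0, code.k / code.n))
        m_test = torch.randint(0, 2, (bs, code.k)).float().to(device)
        _, x_pred, x_true = model(m_test, (torch.randn(bs, code.n) * test_std).to(device))
        print(f'eval @ 4 dB: BER = {torch.mean((x_pred != x_true).float()).item():.3e}')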

--------------------------------------------------------------------------------
/DC_ECCT.py:
--------------------------------------------------------------------------------
"""
Deep Coding for Linear Block Error Correction
"""
from torch.nn import LayerNorm
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import copy

def sign_to_bin(x):
    return 0.5 * (1 - x)

def clones(module, N):
    # Produce N identical (deep-copied) layers.
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class Encoder(nn.Module):
    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)
        if N > 1:
            self.norm2 = LayerNorm(layer.size)

    def forward(self, x, mask):
        for idx, layer in enumerate(self.layers, start=1):
            x = layer(x, mask)
            if idx == len(self.layers) // 2 and len(self.layers) > 1:
                x = self.norm2(x)
        return self.norm(x)


class SublayerConnection(nn.Module):
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        # Pre-norm residual connection.
        return x + self.dropout(sublayer(self.norm(x)))


class EncoderLayer(nn.Module):
    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        return self.sublayer[1](x, self.feed_forward)


class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, args, dropout=0.1):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        self.args = args
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)
        ###
        d_hidden = 50
        self.one_d_mapping = nn.Sequential(nn.Linear(1, d_hidden), nn.ReLU(), nn.Linear(d_hidden, 1))

    def get_mask_from_pc_matrix(self, pc_matrix):
        mask_nk_nk = pc_matrix @ pc_matrix.T
        mask_n_n = pc_matrix.T @ pc_matrix
        tmp1 = torch.cat([mask_n_n, pc_matrix.T], 1)
        tmp2 = torch.cat([pc_matrix, mask_nk_nk], 1)
        return self.one_d_mapping(torch.cat([tmp1, tmp2], 0).unsqueeze(-1)).squeeze().unsqueeze(0).unsqueeze(0)

    def forward(self, query, key, value, mask=None):
        nbatches = query.size(0)
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]

        x, self.attn = self.attention(query, key, value, mask=mask)

        x = x.transpose(1, 2).contiguous() \
            .view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)

    def attention(self, query, key, value, mask=None):
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        # 'mask' here is the parity-check matrix (see the note below).
        scores = scores * self.get_mask_from_pc_matrix(mask)
        # scores.masked_fill_(torch.eye(scores.size(-1)).unsqueeze(0).unsqueeze(0).bool().to(scores.device), -1e9)
        p_attn = F.softmax(scores, dim=-1)
        if self.dropout is not None:
            p_attn = self.dropout(p_attn)
        return torch.matmul(p_attn, value), None
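
# --- Illustrative note (not part of the original module): the attention "mask" above
# is not binary. get_mask_from_pc_matrix builds the (2n-k) x (2n-k) block matrix
#     [[H^T H, H^T],
#      [H,     H H^T]]
# from the parity-check matrix and passes each entry through a small 1-D MLP; the
# result multiplicatively rescales the attention scores, so the code gradient can
# reach the parity-check matrix through the decoder.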

class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(F.gelu(self.w_1(x))))

############################################################


class DC_ECC_Transformer(nn.Module):
    def __init__(self, args):
        super(DC_ECC_Transformer, self).__init__()
        ####
        self.args = args
        c = copy.deepcopy
        attn = MultiHeadedAttention(args.h, args.d_model, args, dropout=args.dropout_attn)
        ff = PositionwiseFeedForward(args.d_model, args.d_model * 4, args.dropout)
        #
        self.src_embed_synd = torch.nn.Embedding(5, args.d_model)
        self.src_embed_magn = torch.nn.Parameter(torch.empty(
            (1, args.d_model)))
        self.decoder = Encoder(EncoderLayer(
            args.d_model, c(attn), c(ff), args.dropout), args.N_dec)
        self.oned_final_embed = torch.nn.Sequential(
            *[nn.Linear(args.d_model, 1)])
        self.synd_to_mag = nn.Linear(args.d_model, args.d_model)
        self.mag_to_mag = nn.Linear(args.d_model, args.d_model)

        for name, p in self.named_parameters():
            if p.dim() > 1 and 'mask_emb' not in name and 'src_embed_synd' not in name:
                nn.init.xavier_uniform_(p)


    def forward(self, magnitude, syndrome, mask, z_mul, y, dict_batch):
        # 'mask' and 'dict_batch' both carry the parity-check matrix H.
        emb_magn = self.src_embed_magn.unsqueeze(0) * magnitude.unsqueeze(-1)
        emb_synd = self.src_embed_synd(sign_to_bin(syndrome).long())
        emb = torch.cat([emb_magn, emb_synd], 1)
        emb = self.decoder(emb, mask)
        loss, x_pred = self.loss(-emb, z_mul, y, dict_batch)
        return emb, loss.unsqueeze(-1), x_pred

    def loss(self, z_pred, z2, y, pc_matrix):
        # Predict the multiplicative noise z2 from the decoder embeddings and train
        # with a binary cross-entropy loss; the codeword estimate is obtained by
        # flipping the hard decisions on y wherever noise is predicted.
        n_max = z2.size(1)
        tmp = z_pred[:, n_max:].unsqueeze(-1) * pc_matrix.unsqueeze(0).unsqueeze(2)
        z_pred = self.mag_to_mag(z_pred[:, :n_max]) + self.synd_to_mag(tmp.permute(0, 1, 3, 2)).sum(1)
        z_pred = self.oned_final_embed(z_pred).squeeze()
        #
        z_pred = z_pred[:, :z2.size(1)]

        loss = (F.binary_cross_entropy_with_logits(z_pred, sign_to_bin(torch.sign(z2)), reduction='none')).mean(-1).mean()
        x_pred = sign_to_bin(torch.sign(-z_pred * torch.sign(y)))
        return loss, x_pred
############################################################
############################################################

if __name__ == '__main__':
    pass
--------------------------------------------------------------------------------