├── CRC.py ├── README.md ├── Simulation.py ├── construction.py ├── construction_method ├── GA.py ├── GA_function.py ├── PW.py └── ZW.py ├── decoder.py ├── function.py ├── plotbler.py ├── polarcodes.py ├── snr_generate.py ├── test.py └── writeintxt.py /CRC.py: -------------------------------------------------------------------------------- 1 | class CRC: 2 | """循环冗余检验 3 | parameters 4 | ---------- 5 | info : list 6 | 需要被编码的信息 7 | crc_n : int, default: 32 8 | 生成多项式的阶数 9 | p : list 10 | 生成多项式 11 | q : list 12 | crc后得到的商 13 | check_code : list 14 | crc后得到的余数,即计算得到的校验码 15 | code : list 16 | 最终的编码 17 | """ 18 | 19 | def __init__(self, info, crc_n): 20 | self.info = info 21 | 22 | # 初始化生成多项式p 23 | if crc_n == 8: 24 | loc = [8, 2, 1, 0] 25 | elif crc_n ==12: 26 | loc = [12, 11, 3, 1, 0] 27 | elif crc_n == 16: 28 | loc = [16, 15, 2, 0] 29 | elif crc_n == 32: 30 | loc = [32, 26, 23, 22, 16, 12, 11, 10, 8, 7, 5, 2, 1, 0] 31 | elif crc_n == 4: 32 | loc = [4, 1, 0] 33 | else: 34 | print('crc_n have not been added in CRC !') 35 | quit() 36 | p = [0 for i in range(crc_n + 1)] 37 | for i in loc: 38 | p[i] = 1 39 | p=p[::-1] 40 | 41 | info = self.info.copy() 42 | times = len(info) 43 | 44 | # 左移补零 45 | for i in range(crc_n): 46 | info.append(0) 47 | # 除 48 | q = [] 49 | for i in range(times): 50 | if info[i] == 1: 51 | q.append(1) 52 | for j in range(crc_n+1): 53 | info[j + i] = info[j + i] ^ p[j] 54 | else: 55 | q.append(0) 56 | 57 | # 余数 58 | check_code = info[-crc_n::] 59 | 60 | # 生成编码 61 | code = self.info.copy() 62 | for i in check_code: 63 | code.append(i) 64 | 65 | self.crc_n = crc_n 66 | self.p = p 67 | self.q = q 68 | self.check_code = check_code 69 | self.code = code 70 | 71 | def detection(self): 72 | codes=self.info.copy() 73 | crc_n=self.crc_n 74 | info_len=len(codes)-crc_n 75 | info=codes[0:info_len] 76 | codes_1=CRC(info,crc_n) 77 | recodes=codes_1.code 78 | compare=[codes[i]^recodes[i] for i in range(len(codes))] 79 | cc=sum(compare) 80 | if cc == 0: 81 | value = 1 82 | else: 83 | value = 0 84 | return value 85 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PolarCodesPython 2 | PolarCodes inplemented in Python, including Construction Method (GA, ZW, PW, HPW), Decoder (SC, SC-List, Fast-SC, SCFlip and BP, some early stopping methods in BP decoder), CRC-Aided. 3 | -------------------------------------------------------------------------------- /Simulation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | 4 | import polarcodes 5 | import construction 6 | import writeintxt 7 | import plotbler 8 | import snr_generate 9 | 10 | 11 | # 初始值预设 12 | run_num_min=1e3 13 | run_num_max=1e6 14 | error_max=100 15 | 16 | n = 10 17 | K=2**(n-1) 18 | SNR_para = [2.5,0.5,2.5] # snr在bsc为转移概率,在bec为删除概率,都为float64,在awgn为Eb/N0信噪比(dB),为列表,[start,step,end] 19 | crc_n=16 #有0,4,8,12,16,32,0表示无crc,其他的多项式在CRC里面改 20 | 21 | frozen_bit = 0 22 | construction_method = 'ga' # 有zw,ga,pw, hpw 23 | channel_con = 'awgn' # 有awgn,bsc和bec 24 | decode_method = 'scl' # 有sc,scl, scf, fsc 和 bp 注意bp用的是带扩展因子s的,具体去function.f_hf_SMS看 25 | para1= 8 26 | para2= 'hf' 27 | # para1 para2 28 | # SC / 忽略 / 忽略 29 | # BP 最大迭代次数,int,一般为 30-70 提前终止方法,str,有 max_iter, g_matrix, crc_es 30 | # SCL 列表数量, int, 华为推荐 8 路径度量计算方法,str, 一般有 hf, exact 31 | # SCF 最大翻转次数,int, 推荐 16 翻转方法,目前只翻转1阶,str,只有 llr_based 32 | # FSC / 忽略 节点识别,有regular_node, type_node, 目前只有regular_node 33 | #初始值预设完毕 34 | 35 | decode_para=[para1,para2] 36 | N=2**n 37 | rate=K/N 38 | SNR = snr_generate.SNR_array(SNR_para) 39 | len_snr=SNR.size 40 | SNR_list=list(SNR) 41 | BER_list=[] 42 | BLER_list=[] 43 | time_start=time.time() 44 | 45 | for i in range(len_snr): 46 | SNRi=SNR[i] 47 | run = 1 48 | error = np.array([0, 0]) 49 | information_pos = construction.construction(N, SNRi, channel_con, construction_method, K + crc_n, rate) 50 | 51 | while run <= run_num_min or (run <= run_num_max and error[0] <= error_max): 52 | #print('The run_num is : ' + str(run), end='\r') 53 | value=polarcodes.polarcodes(n, rate, SNRi, channel_con, decode_method, decode_para, information_pos, frozen_bit,crc_n) 54 | error += value 55 | run += 1 56 | 57 | 58 | ber_bler=np.array([error[0] / (K * (run-1)), error[1] / (run-1)]) 59 | print('When SNR= ',SNRi,'dB, '+'the BER is : ',ber_bler[0],', the BLER is : ',ber_bler[1],'. Run loop : ',run-1) 60 | BER_list.append(ber_bler[0]) 61 | BLER_list.append(ber_bler[1]) 62 | 63 | time_end=time.time() 64 | time_cost=int(time_end-time_start) 65 | print('Time cost : ',time_cost,' s') 66 | 67 | #写入文件ber_bler内 68 | writeintxt.writetxt(N,K,SNR_para,crc_n,decode_method,BER_list,BLER_list,time_cost,construction_method,decode_para) 69 | 70 | #绘图 71 | plotbler.plotfig(SNR_list,BER_list,BLER_list) 72 | -------------------------------------------------------------------------------- /construction.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('construction_method') 3 | 4 | import ZW 5 | import GA 6 | import PW 7 | 8 | 9 | def construction(N,snr,channel_con,construction_method,K,rate): 10 | if construction_method == 'zw': 11 | if channel_con == 'awgn': 12 | print('When the channel is AWGN, do not use the ZW construction method !') 13 | quit() 14 | else: 15 | information_pos = ZW.zw(N, snr, K, channel_con) 16 | elif construction_method == 'ga': 17 | if channel_con == 'bec' or channel_con == 'bsc': 18 | print('When the channel is BEC or BSC, do not use the GA construction method !') 19 | quit() 20 | else: 21 | information_pos = GA.ga(N, snr, K, rate) 22 | elif construction_method == 'pw': 23 | information_pos = PW.pw(N, K) 24 | elif construction_method == 'hpw': 25 | information_pos = PW.hpw(N, K) 26 | else: 27 | print('This is not a construction method or s have not added it to the construction_method !') 28 | quit() 29 | # print(information_pos) 30 | return information_pos -------------------------------------------------------------------------------- /construction_method/GA.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import GA_function 4 | 5 | def ga(N,snr,information_num,rate): 6 | n=np.log2(N) 7 | sigma2=(1/(2*rate))*(10**(-snr/10)) 8 | llr=2/sigma2 9 | 10 | llri = [0 for i in range(N)] 11 | llrcopy = llri.copy() 12 | llri[0] = llr 13 | m = 1 14 | while m <= N / 2: 15 | for k in range(m): 16 | llr_temp = llri[k] 17 | llrcopy[k*2] = GA_function.phi_inverse(1 - (1 - GA_function.phi(llr_temp))**2) 18 | llrcopy[2*k + 1] = llr_temp * 2 19 | llri = llrcopy.copy() 20 | #print(llri) 21 | m = 2 * m 22 | 23 | 24 | llri_array = np.array(llri) 25 | llri_sorted = np.argsort(llri_array) 26 | llri_sort = list(llri_sorted) 27 | information_pos = llri_sort[-information_num:] 28 | information_pos = sorted(information_pos) 29 | 30 | return information_pos -------------------------------------------------------------------------------- /construction_method/GA_function.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def phi(x): 4 | if x >= 0 and x <= 10: 5 | y=np.exp(-0.4527*(x**0.859)+0.0218) 6 | elif x >= 10: 7 | y=np.sqrt(np.pi/x)*np.exp(-x/4)*(1-(10/(7*x))) 8 | else: 9 | print('x is not a positive number, so the function: phi can not pass !') 10 | quit() 11 | return y 12 | 13 | def phi_derivative(x): 14 | if x >= 0 and x <= 10: 15 | dx = -0.4527 * 0.86 * (x ** (-0.14)) * phi(x) 16 | else: 17 | dx = np.sqrt(np.pi) * np.exp(-x / 4) * ((15 / 7) * (x ** (-5 / 2)) - (1 / 7) * (x ** (-3 / 2)) - (1 / 4) * (x ** (-1 / 2))) 18 | 19 | return dx 20 | 21 | def phi_inverse(x): 22 | #用的是Newton's method,也就是牛顿迭代法求零点 23 | if x >= 0.0388 and x <= 1.0221: 24 | y= ((0.0218-np.log(x))/0.4527)**(1/0.86) 25 | else: 26 | x0=0.0388 27 | x1=x0-((phi(x0)-x)/phi_derivative(x0)) 28 | delta=np.abs(x1-x0) 29 | gap=1e-3 30 | 31 | while delta >= gap: 32 | x0=x1 33 | x1=x1-((phi(x1)-x)/phi_derivative(x1)) 34 | if x1 > 1e2: 35 | gap = 10 36 | delta=np.abs(x1-x0) 37 | y=x1 38 | return y 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /construction_method/PW.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def pw(N,information_num): 4 | n=int(np.log2(N)) 5 | beta=2**0.25 6 | channels=np.zeros(N) 7 | for i in range(N): 8 | bin_seq=np.binary_repr(i,n) 9 | sum=0 10 | for j in range(n): 11 | if bin_seq[j] == '1': 12 | sum+= beta**(n-j-1) 13 | channels[i]=sum 14 | 15 | pw_array = np.array(channels) 16 | pw_sorted = np.argsort(pw_array) 17 | pw_sort = list(pw_sorted) 18 | information_pos = pw_sort[-information_num:] 19 | information_pos = sorted(information_pos) 20 | return information_pos 21 | 22 | def hpw(N,information_num): 23 | n = int(np.log2(N)) 24 | beta = 2 ** 0.25 25 | channels = np.zeros(N) 26 | for i in range(N): 27 | bin_seq = np.binary_repr(i, n) 28 | sum = 0 29 | for j in range(n): 30 | if bin_seq[j] == '1': 31 | sum += beta ** (n - j - 1)+ 0.25 * (beta**(0.25*(n-j-1))) 32 | channels[i] = sum 33 | 34 | hpw_array = np.array(channels) 35 | hpw_sorted = np.argsort(hpw_array) 36 | hpw_sort = list(hpw_sorted) 37 | information_pos = hpw_sort[-information_num:] 38 | information_pos = sorted(information_pos) 39 | return information_pos -------------------------------------------------------------------------------- /construction_method/ZW.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def zw(N,snr,information_num,channel_con): 4 | n=np.log2(N) 5 | if channel_con == 'bec': 6 | zw=snr 7 | else: 8 | zw=2*np.sqrt(snr*(1-snr)) 9 | zwi=[0 for i in range(N)] 10 | zwi[0]=zw 11 | m=1 12 | while m <= N/2: 13 | for k in range(1,m+1,1): 14 | z_temp=zwi[k-1] 15 | zwi[k-1]=2*z_temp-z_temp**2 16 | zwi[k+m-1]=z_temp**2 17 | m=2*m 18 | zwi_array=np.array(zwi) 19 | zwi_sorted=np.argsort(zwi_array) 20 | zwi_sort=list(zwi_sorted) 21 | information_pos=zwi_sort[0:information_num] 22 | information_pos=sorted(information_pos) 23 | 24 | return information_pos -------------------------------------------------------------------------------- /decoder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import function 4 | import CRC 5 | 6 | def decoder(y,decode_method,decode_para,information_pos,frozen_bit,channel_con,snr,rate,crc_n): 7 | N = y.size 8 | n = np.log2(N) 9 | #G = function.generate_matrix(n) 10 | if channel_con == 'awgn': 11 | sigma=(1/np.sqrt(2*rate))*(10**(-snr/20)) 12 | y_llr = (2 * y) / (sigma**2) 13 | #print(y_llr) 14 | elif channel_con == 'bec': 15 | y_1=y.copy() 16 | y_1[y_1 == 0] = np.infty 17 | y_1[y_1 == 1] = -np.infty 18 | y_1[y_1 == -1] = 0 19 | y_llr=y_1.copy() 20 | #print(y_llr) 21 | elif channel_con == 'bsc': 22 | a=np.log((1-snr)/snr) 23 | b=-a 24 | y_1=y.copy() 25 | y_1[y_1 == 0] = a 26 | y_1[y_1 == 1] = b 27 | y_llr=y_1.copy() 28 | else: 29 | print('This is not a channel name or s have not added it to the function: channel!') 30 | quit() 31 | #print(y_llr) 32 | #至此求得解码前需要的llr值,下面开始译码 33 | if decode_method == 'sc': 34 | if crc_n != 0: 35 | print('Notice: SC decoder can not use CRC !') 36 | u_d=sc_decoder(y_llr,information_pos,frozen_bit) 37 | u_d=u_d[0] 38 | elif decode_method == 'scl': 39 | if crc_n == 0: 40 | print('Notice: There is no CRC to check and select the right path !') 41 | if type(decode_para[0]) != int or decode_para[0] <= 0: 42 | print('Please input the right L(e.g. the number of list, para1) ') 43 | quit() 44 | u_d = scl_decoder(y_llr, information_pos, frozen_bit,decode_para,crc_n) 45 | elif decode_method == 'bp': 46 | if crc_n != 0 and decode_para[1] != 'crc_es': 47 | print('Notice: This BP early stopping method can not use CRC !') 48 | elif crc_n == 0 and decode_para[1] == 'crc_es': 49 | print('There is no CRC to early stop !') 50 | quit() 51 | u_d=bp_decoder(y_llr,information_pos,frozen_bit,decode_para,crc_n) 52 | elif decode_method == 'scf': 53 | if crc_n == 0: 54 | print('There is no CRC, so SCF decoder can not run !') 55 | quit() 56 | else: 57 | u_d = scf_decoder(y_llr, information_pos, frozen_bit,decode_para,crc_n) 58 | elif decode_method == 'fsc': 59 | if crc_n != 0: 60 | print('Notice: FSC decoder can not use CRC !') 61 | if decode_para[1] != 'regular_node' and decode_para[1] != 'type_node': 62 | print('The node process method is not right, revise the para2 !') 63 | quit() 64 | u_d=fsc_decoder(y_llr,information_pos,frozen_bit,decode_para) 65 | else: 66 | print('s have not added the decode_method in the function of decoder.decoder !') 67 | quit() 68 | return u_d 69 | 70 | def sc_decoder(y_llr,information_pos,frozen_bit): 71 | N = y_llr.size 72 | #print(N) 73 | n = int(np.log2(N)) 74 | llr_matrix=np.ones((n+1,N)) 75 | llr_matrix[llr_matrix == 1] = float('nan') 76 | bit_matrix=llr_matrix.copy() 77 | llr_matrix[0]=y_llr 78 | position=[0,0,n,N] 79 | while function.all_num(bit_matrix[n]) == 0: 80 | up_llr=llr_matrix[position[0]][position[1]:position[1]+2**(position[2]-position[0])] 81 | up_bit=bit_matrix[position[0]][position[1]:position[1]+2**(position[2]-position[0])] 82 | left_llr=llr_matrix[position[0]+1][position[1]:position[1]+2**(position[2]-position[0]-1)] 83 | left_bit=bit_matrix[position[0]+1][position[1]:position[1]+2**(position[2]-position[0]-1)] 84 | right_llr=llr_matrix[position[0]+1][position[1]+2**(position[2]-position[0]-1):position[1]+2**(position[2]-position[0])] 85 | right_bit=bit_matrix[position[0]+1][position[1]+2**(position[2]-position[0]-1):position[1]+2**(position[2]-position[0])] 86 | #print(left_llr) 87 | 88 | if function.all_num(up_bit) == 1: 89 | position=function.up(position) 90 | else: 91 | if function.all_num(right_bit) == 1: 92 | up_bit=function.get_up_bit(left_bit,right_bit) 93 | bit_matrix[position[0]][position[1]:position[1] + 2 ** (position[2] - position[0])]=up_bit.copy() 94 | else: 95 | if function.all_num(right_llr) == 1: 96 | if position[0] == position[2]-1: 97 | right_bit_pos=position[1]+1 98 | right_bit=function.get_right_bit(right_llr,information_pos,frozen_bit,right_bit_pos) 99 | bit_matrix[position[0]+1][position[1]+2**(position[2]-position[0]-1):position[1]+2**(position[2]-position[0])]=right_bit 100 | else: 101 | position=function.rightdown(position) 102 | else: 103 | if function.all_num(left_bit) == 1: 104 | right_llr=function.get_right_llr(left_bit,up_llr) 105 | llr_matrix[position[0]+1][position[1]+2**(position[2]-position[0]-1):position[1]+2**(position[2]-position[0])]=right_llr 106 | else: 107 | if function.all_num(left_llr) == 0: 108 | left_llr = function.get_left_llr(up_llr) 109 | llr_matrix[position[0] + 1][position[1]:position[1] + 2 ** (position[2] - position[0] - 1)] = left_llr 110 | else: 111 | if position[0] == position[2]-1: 112 | left_bit_pos=position[1] 113 | left_bit=function.get_left_bit(left_llr,information_pos,frozen_bit,left_bit_pos) 114 | bit_matrix[position[0]+1][position[1]:position[1]+2**(position[2]-position[0]-1)]=left_bit 115 | else: 116 | position = function.leftdown(position) 117 | 118 | u_d=[bit_matrix[n],llr_matrix[n]] 119 | return u_d 120 | 121 | def scl_decoder(y_llr, information_pos, frozen_bit, decode_para, crc_n): 122 | N = y_llr.size 123 | n = int(np.log2(N)) 124 | llr_matrix = np.ones((n + 1, N)) 125 | llr_matrix[llr_matrix == 1] = float('nan') 126 | bit_matrix = llr_matrix.copy() 127 | llr_matrix[0] = y_llr 128 | list_max= decode_para[0] 129 | pm_method=decode_para[1] 130 | split_pos=information_pos #采用传统的方法,在所有信息位分裂,未裁剪和减复杂度 131 | llr_list = [llr_matrix] 132 | bit_list = [bit_matrix] 133 | #print(llr_list) 134 | #print(bit_list) 135 | pm_list = [0] 136 | split_loc=0 137 | split_len=len(split_pos) 138 | l_now=1 139 | 140 | while split_len-1 >= split_loc: 141 | for i in range(l_now): 142 | llr_matrix_temp = llr_list[i] 143 | bit_matrix_temp = bit_list[i] 144 | pm_temp = pm_list[i] 145 | 146 | matrix_temp = sc_stepping_decoder(llr_matrix_temp, bit_matrix_temp, information_pos, frozen_bit, 147 | split_pos[split_loc]) 148 | 149 | llr_list[i] = matrix_temp[0] 150 | bit_list[i] = matrix_temp[1] 151 | right_pm_update = function.get_pm_update( 152 | matrix_temp[0][n][split_pos[split_loc - 1] + 1:split_pos[split_loc] + 1], 153 | matrix_temp[1][n][split_pos[split_loc - 1] + 1:split_pos[split_loc] + 1], pm_method) 154 | pm_list[i] = pm_temp + right_pm_update 155 | llr_list.append(matrix_temp[0].copy()) 156 | bit_matrix_wrong = matrix_temp[1].copy() 157 | bit_matrix_wrong[n][split_pos[split_loc]] = 1 - bit_matrix_wrong[n][split_pos[split_loc]] 158 | bit_list.append(bit_matrix_wrong.copy()) 159 | wrong_pm_update = function.get_pm_update( 160 | matrix_temp[0][n][split_pos[split_loc - 1] + 1:split_pos[split_loc] + 1], 161 | bit_matrix_wrong[n][split_pos[split_loc - 1] + 1:split_pos[split_loc] + 1], pm_method) 162 | pm_list.append(pm_temp + wrong_pm_update) 163 | 164 | if l_now > list_max/2: 165 | pm_list_arg=np.argsort(pm_list) 166 | del_list_arg=pm_list_arg[list_max:] 167 | #删减多余的列表分支 168 | pm_list = [pm_list[i] for i in range(len(pm_list)) if i not in del_list_arg] 169 | llr_list = [llr_list[i] for i in range(len(llr_list)) if i not in del_list_arg] 170 | bit_list = [bit_list[i] for i in range(len(bit_list)) if i not in del_list_arg] 171 | l_now = len(pm_list) 172 | split_loc += 1 173 | 174 | if split_pos[-1] != N-1: 175 | for i in range(l_now): 176 | llr_matrix_temp = llr_list[i] 177 | bit_matrix_temp = bit_list[i] 178 | pm_temp = pm_list[i] 179 | matrix_temp = sc_stepping_decoder(llr_matrix_temp, bit_matrix_temp, information_pos, frozen_bit,N-1) 180 | llr_list[i] = matrix_temp[0] 181 | bit_list[i] = matrix_temp[1] 182 | right_pm_update = function.get_pm_update( 183 | matrix_temp[0][n][split_pos[split_loc - 1]+1:N], 184 | matrix_temp[1][n][split_pos[split_loc - 1]+1:N], pm_method) 185 | pm_list[i] = pm_temp + right_pm_update 186 | 187 | #接下来提取list中的译码值经过crc校验,采取的是排序pm,从最小的开始经过CRC,第一个通过CRC的即作为译码值输出 188 | pm_argsort=np.argsort(pm_list) 189 | for i in pm_argsort: 190 | u_d_temp=bit_list[i][n] 191 | u_d_temp = np.array([0 if u_d_temp[i] == 0 else 1 for i in range(u_d_temp.size)]) 192 | u_d_temp_info = u_d_temp[information_pos] 193 | u_d_temp_info = list(u_d_temp_info) 194 | crc_c = CRC.CRC(u_d_temp_info, crc_n) 195 | flag = crc_c.detection() 196 | if flag == 1: 197 | u_d=u_d_temp 198 | break 199 | elif flag == 0 and i == pm_argsort[-1]: 200 | u_d_temp=bit_list[pm_argsort[0]][n] 201 | u_d_temp = np.array([0 if u_d_temp[i] == 0 else 1 for i in range(u_d_temp.size)]) 202 | u_d=u_d_temp 203 | 204 | return u_d 205 | 206 | def sc_stepping_decoder(llr_matrix,bit_matrix,information_pos,frozen_bit,split_pos): #运行到判决split_pos的位置,注意是split_pos判决完成后返回的 207 | N = int(bit_matrix[0].size) 208 | n = int(np.log2(N)) 209 | loc=function.get_up_loc(bit_matrix) 210 | position = [loc[0], loc[1], n, N] 211 | 212 | while bit_matrix[n][split_pos] != 0 and bit_matrix[n][split_pos] != 1: 213 | up_llr=llr_matrix[position[0]][position[1]:position[1]+2**(position[2]-position[0])] 214 | up_bit=bit_matrix[position[0]][position[1]:position[1]+2**(position[2]-position[0])] 215 | left_llr=llr_matrix[position[0]+1][position[1]:position[1]+2**(position[2]-position[0]-1)] 216 | left_bit=bit_matrix[position[0]+1][position[1]:position[1]+2**(position[2]-position[0]-1)] 217 | right_llr=llr_matrix[position[0]+1][position[1]+2**(position[2]-position[0]-1):position[1]+2**(position[2]-position[0])] 218 | right_bit=bit_matrix[position[0]+1][position[1]+2**(position[2]-position[0]-1):position[1]+2**(position[2]-position[0])] 219 | 220 | if function.all_num(up_bit) == 1: 221 | position = function.up(position) 222 | else: 223 | if function.all_num(right_bit) == 1: 224 | up_bit = function.get_up_bit(left_bit, right_bit) 225 | bit_matrix[position[0]][position[1]:position[1] + 2 ** (position[2] - position[0])] = up_bit.copy() 226 | else: 227 | if function.all_num(right_llr) == 1: 228 | if position[0] == position[2] - 1: #右底 229 | right_bit_pos = position[1] + 1 230 | right_bit = function.get_right_bit(right_llr, information_pos, frozen_bit, right_bit_pos) 231 | bit_matrix[position[0] + 1][position[1] + 2 ** (position[2] - position[0] - 1):position[1] + 2 ** ( 232 | position[2] - position[0])] = right_bit 233 | else: 234 | position = function.rightdown(position) 235 | else: 236 | if function.all_num(left_bit) == 1: 237 | right_llr = function.get_right_llr(left_bit, up_llr) 238 | llr_matrix[position[0] + 1][position[1] + 2 ** (position[2] - position[0] - 1):position[1] + 2 ** ( 239 | position[2] - position[0])] = right_llr 240 | else: 241 | if function.all_num(left_llr) == 0: 242 | left_llr = function.get_left_llr(up_llr) 243 | llr_matrix[position[0] + 1][ 244 | position[1]:position[1] + 2 ** (position[2] - position[0] - 1)] = left_llr 245 | else: 246 | if position[0] == position[2] - 1: #左底 247 | left_bit_pos = position[1] 248 | left_bit = function.get_left_bit(left_llr, information_pos, frozen_bit, left_bit_pos) 249 | bit_matrix[position[0] + 1][ 250 | position[1]:position[1] + 2 ** (position[2] - position[0] - 1)] = left_bit 251 | else: 252 | position = function.leftdown(position) 253 | return [llr_matrix,bit_matrix] 254 | 255 | def fsc_decoder(y_llr,information_pos,frozen_bit,decode_para): 256 | N = y_llr.size 257 | n = int(np.log2(N)) 258 | node_list = function.node_identify(N,information_pos,decode_para[1]) 259 | 260 | llr_matrix = np.ones((n + 1, N)) 261 | llr_matrix[llr_matrix == 1] = float('nan') 262 | bit_matrix = llr_matrix.copy() 263 | llr_matrix[0] = y_llr 264 | position = [0, 0, n, N] 265 | while function.all_num(bit_matrix[n]) == 0: 266 | node_col = int(np.floor(position[1]/(2**(n-position[0])))) 267 | if node_list[position[0]][node_col] == 'nan': 268 | up_llr = llr_matrix[position[0]][position[1]:position[1] + 2 ** (position[2] - position[0])] 269 | up_bit = bit_matrix[position[0]][position[1]:position[1] + 2 ** (position[2] - position[0])] 270 | left_llr = llr_matrix[position[0] + 1][position[1]:position[1] + 2 ** (position[2] - position[0] - 1)] 271 | left_bit = bit_matrix[position[0] + 1][position[1]:position[1] + 2 ** (position[2] - position[0] - 1)] 272 | right_llr = llr_matrix[position[0] + 1][ 273 | position[1] + 2 ** (position[2] - position[0] - 1):position[1] + 2 ** (position[2] - position[0])] 274 | right_bit = bit_matrix[position[0] + 1][ 275 | position[1] + 2 ** (position[2] - position[0] - 1):position[1] + 2 ** (position[2] - position[0])] 276 | 277 | if function.all_num(up_bit) == 1: 278 | position = function.up(position) 279 | else: 280 | if function.all_num(right_bit) == 1: 281 | up_bit = function.get_up_bit(left_bit, right_bit) 282 | bit_matrix[position[0]][position[1]:position[1] + 2 ** (position[2] - position[0])] = up_bit.copy() 283 | else: 284 | if function.all_num(right_llr) == 1: 285 | if position[0] == position[2] - 1: 286 | right_bit_pos = position[1] + 1 287 | right_bit = function.get_right_bit(right_llr, information_pos, frozen_bit, right_bit_pos) 288 | bit_matrix[position[0] + 1][ 289 | position[1] + 2 ** (position[2] - position[0] - 1):position[1] + 2 ** ( 290 | position[2] - position[0])] = right_bit 291 | else: 292 | position = function.rightdown(position) 293 | else: 294 | if function.all_num(left_bit) == 1: 295 | right_llr = function.get_right_llr(left_bit, up_llr) 296 | llr_matrix[position[0] + 1][ 297 | position[1] + 2 ** (position[2] - position[0] - 1):position[1] + 2 ** ( 298 | position[2] - position[0])] = right_llr 299 | else: 300 | if function.all_num(left_llr) == 0: 301 | left_llr = function.get_left_llr(up_llr) 302 | llr_matrix[position[0] + 1][ 303 | position[1]:position[1] + 2 ** (position[2] - position[0] - 1)] = left_llr 304 | else: 305 | if position[0] == position[2] - 1: 306 | left_bit_pos = position[1] 307 | left_bit = function.get_left_bit(left_llr, information_pos, frozen_bit, left_bit_pos) 308 | bit_matrix[position[0] + 1][ 309 | position[1]:position[1] + 2 ** (position[2] - position[0] - 1)] = left_bit 310 | else: 311 | position = function.leftdown(position) 312 | 313 | else: 314 | up_llr = llr_matrix[position[0]][position[1]:position[1] + 2 ** (position[2] - position[0])] 315 | node_result = function.node_process(up_llr,node_list[position[0]][node_col]) 316 | bit_matrix[position[0]][position[1]:position[1] + 2 ** (position[2] - position[0])] = node_result[0] 317 | bit_matrix[n][position[1]:position[1] + 2 ** (position[2] - position[0])] = node_result[1] 318 | position = function.up(position) 319 | u_d = bit_matrix[n] 320 | return u_d 321 | 322 | 323 | 324 | 325 | 326 | 327 | def sc_flip1_decoder(y_llr,information_pos,frozen_bit,flip_pos): 328 | N = y_llr.size 329 | #print(N) 330 | n = int(np.log2(N)) 331 | llr_matrix=np.ones((n+1,N)) 332 | llr_matrix[llr_matrix == 1] = float('nan') 333 | bit_matrix=llr_matrix.copy() 334 | llr_matrix[0]=y_llr 335 | position=[0,0,n,N] 336 | while function.all_num(bit_matrix[n]) == 0: 337 | up_llr=llr_matrix[position[0]][position[1]:position[1]+2**(position[2]-position[0])] 338 | up_bit=bit_matrix[position[0]][position[1]:position[1]+2**(position[2]-position[0])] 339 | left_llr=llr_matrix[position[0]+1][position[1]:position[1]+2**(position[2]-position[0]-1)] 340 | left_bit=bit_matrix[position[0]+1][position[1]:position[1]+2**(position[2]-position[0]-1)] 341 | right_llr=llr_matrix[position[0]+1][position[1]+2**(position[2]-position[0]-1):position[1]+2**(position[2]-position[0])] 342 | right_bit=bit_matrix[position[0]+1][position[1]+2**(position[2]-position[0]-1):position[1]+2**(position[2]-position[0])] 343 | 344 | if function.all_num(up_bit) == 1: 345 | position=function.up(position) 346 | else: 347 | if function.all_num(right_bit) == 1: 348 | up_bit=function.get_up_bit(left_bit,right_bit) 349 | bit_matrix[position[0]][position[1]:position[1] + 2 ** (position[2] - position[0])]=up_bit.copy() 350 | else: 351 | if function.all_num(right_llr) == 1: 352 | if position[0] == position[2]-1: 353 | right_bit_pos=position[1]+1 354 | right_bit=function.get_right_bit_flip1(right_llr,information_pos,frozen_bit,right_bit_pos,flip_pos) 355 | bit_matrix[position[0]+1][position[1]+2**(position[2]-position[0]-1):position[1]+2**(position[2]-position[0])]=right_bit 356 | else: 357 | position=function.rightdown(position) 358 | else: 359 | if function.all_num(left_bit) == 1: 360 | right_llr=function.get_right_llr(left_bit,up_llr) 361 | llr_matrix[position[0]+1][position[1]+2**(position[2]-position[0]-1):position[1]+2**(position[2]-position[0])]=right_llr 362 | else: 363 | if function.all_num(left_llr) == 0: 364 | left_llr = function.get_left_llr(up_llr) 365 | llr_matrix[position[0] + 1][position[1]:position[1] + 2 ** (position[2] - position[0] - 1)] = left_llr 366 | else: 367 | if position[0] == position[2]-1: 368 | left_bit_pos=position[1] 369 | left_bit=function.get_left_bit_flip1(left_llr,information_pos,frozen_bit,left_bit_pos,flip_pos) 370 | bit_matrix[position[0]+1][position[1]:position[1]+2**(position[2]-position[0]-1)]=left_bit 371 | else: 372 | position = function.leftdown(position) 373 | 374 | u_d=[bit_matrix[n],llr_matrix[n]] 375 | return u_d 376 | 377 | def scf_decoder(y_llr, information_pos, frozen_bit,decode_para,crc_n): 378 | # 第一次用SC译码的过程 379 | N = y_llr.size 380 | n = int(np.log2(N)) 381 | u_d_1_list = sc_decoder(y_llr,information_pos,frozen_bit) 382 | u_d_1_llr=u_d_1_list[1].copy() 383 | u_d_1=u_d_1_list[0].copy() 384 | u_d_1 = np.array([0 if u_d_1[i] == 0 else 1 for i in range(u_d_1.size)]) 385 | u_d_1_info=u_d_1[information_pos] 386 | u_d_1_llr_info=u_d_1_llr[information_pos] 387 | #print(u_d_1_llr_info) 388 | u_d_1_info=list(u_d_1_info) 389 | #print(u_d_1) 390 | crc_c = CRC.CRC(u_d_1_info, crc_n) 391 | flag = crc_c.detection() 392 | #print(flag) 393 | flip_info_pos1 = list(np.argsort(np.abs(u_d_1_llr_info))) 394 | flip_info_pos = flip_info_pos1[0:decode_para[0]] 395 | information_pos_array=np.array(information_pos) 396 | flip_pos=information_pos_array[flip_info_pos] 397 | #print(flip_pos) 398 | if flag == 0: 399 | T=1 400 | while T <= decode_para[0]: 401 | #print(T) 402 | u_d_2_list=sc_flip1_decoder(y_llr,information_pos,frozen_bit,flip_pos[T-1]) 403 | u_d_2_llr = u_d_2_list[1].copy() 404 | u_d_2 = u_d_2_list[0].copy() 405 | u_d_2 = np.array([0 if u_d_2[i] == 0 else 1 for i in range(u_d_2.size)]) 406 | u_d_2_info=u_d_2[information_pos] 407 | u_d_2_info=list(u_d_2_info) 408 | crc_c = CRC.CRC(u_d_2_info, crc_n) 409 | flag1 = crc_c.detection() 410 | if flag1 == 1: 411 | u_d = np.array(u_d_2) 412 | #print('The scf decoder flip: ',T,' times and it works.') 413 | break 414 | else: 415 | T += 1 416 | u_d = np.array(u_d_1) 417 | else: 418 | #print('no flip and success') 419 | u_d=np.array(u_d_1) 420 | return u_d 421 | 422 | def bp_decoder(y_llr,information_pos,frozen_bit,decode_para,crc_n): 423 | N = y_llr.size 424 | n = int(np.log2(N)) 425 | 426 | if decode_para[1] == 'max_iter': 427 | 428 | bp_max_iter=int(decode_para[0]) 429 | #初始化左传播和右传播矩阵 430 | left_matrix=np.zeros([N,n+1]) 431 | right_matrix=np.zeros([N,n+1]) 432 | left_matrix[:,n]=y_llr 433 | temp_value=(1-2*frozen_bit)*np.infty 434 | temp=[temp_value if i not in information_pos else 0 for i in range(N)] 435 | right_matrix[:,0]=temp 436 | 437 | for iter in range(bp_max_iter): 438 | for i in range(n): 439 | left_matrix[:,n-i-1]=function.bp_update_left(left_matrix[:,n-i],right_matrix[:,n-i-1],n-i) 440 | 441 | for i in range(n): 442 | right_matrix[:,i+1]=function.bp_update_right(left_matrix[:,i+1],right_matrix[:,i],i+1) 443 | 444 | 445 | u_d_llr=left_matrix[:,0]+right_matrix[:,0] 446 | u_d=[0 if u_d_llr[i] >= 0 else 1 for i in range(N)] 447 | u_d=np.array(u_d) 448 | 449 | 450 | elif decode_para[1] == 'g_matrix': 451 | 452 | # 初始化左传播和右传播矩阵 453 | left_matrix = np.zeros([N, n + 1]) 454 | right_matrix = np.zeros([N, n + 1]) 455 | left_matrix[:, n] = y_llr 456 | temp_value = (1 - 2 * frozen_bit) * np.infty 457 | temp = [temp_value if i not in information_pos else 0 for i in range(N)] 458 | right_matrix[:, 0] = temp 459 | flag=0 460 | iter_num=1 461 | 462 | while flag == 0 and iter_num <= decode_para[0]: 463 | for i in range(n): 464 | left_matrix[:, n - i - 1] = function.bp_update_left(left_matrix[:, n - i], right_matrix[:, n - i - 1],n - i) 465 | 466 | for i in range(n): 467 | right_matrix[:, i + 1] = function.bp_update_right(left_matrix[:, i + 1], right_matrix[:, i], i + 1) 468 | 469 | u_d_llr = left_matrix[:, 0] + right_matrix[:, 0] 470 | u_d = [0 if u_d_llr[i] >= 0 else 1 for i in range(N)] 471 | x_d_llr = left_matrix[:, n] + right_matrix[:, n] 472 | x_d = [0 if x_d_llr[i] >= 0 else 1 for i in range(N)] 473 | G = function.generate_matrix(n) 474 | x_g = u_d * G 475 | x_g = np.array(x_g % 2) 476 | #print(x_g) 477 | cor = [0 if x_g[0][i]==x_d[i] else 1 for i in range(N)] 478 | if sum(cor) == 0: 479 | flag = 1 480 | else: 481 | flag = 0 482 | iter_num += 1 483 | u_d = np.array(u_d) 484 | 485 | 486 | elif decode_para[1] == 'crc_es': 487 | 488 | # 初始化左传播和右传播矩阵 489 | left_matrix = np.zeros([N, n + 1]) 490 | right_matrix = np.zeros([N, n + 1]) 491 | left_matrix[:, n] = y_llr 492 | temp_value = (1 - 2 * frozen_bit) * np.infty 493 | temp = [temp_value if i not in information_pos else 0 for i in range(N)] 494 | right_matrix[:, 0] = temp 495 | flag = 0 496 | iter_num = 1 497 | 498 | while flag == 0 and iter_num <= decode_para[0]: 499 | for i in range(n): 500 | left_matrix[:, n - i - 1] = function.bp_update_left(left_matrix[:, n - i], right_matrix[:, n - i - 1],n - i) 501 | 502 | for i in range(n): 503 | right_matrix[:, i + 1] = function.bp_update_right(left_matrix[:, i + 1], right_matrix[:, i], i + 1) 504 | 505 | u_d_llr = left_matrix[:, 0] + right_matrix[:, 0] 506 | u_d = np.array([0 if u_d_llr[i] >= 0 else 1 for i in range(N)]) 507 | u_d_info = u_d[information_pos] 508 | u_d_info = list(u_d_info) 509 | crc_c=CRC.CRC(u_d_info,crc_n) 510 | flag=crc_c.detection() 511 | # if flag == 1: 512 | # print('It work!',iter_num) 513 | 514 | iter_num += 1 515 | u_d = np.array(u_d) 516 | else: 517 | print('This is not a early stopping method of BP decoder or s have not added it to the endecoder.bp_decoder !') 518 | quit() 519 | return u_d 520 | 521 | 522 | 523 | 524 | 525 | 526 | 527 | 528 | 529 | 530 | 531 | 532 | 533 | -------------------------------------------------------------------------------- /function.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def generate_matrix(n): 4 | F=np.matrix([[1,0],[1,1]]) 5 | n_i=1 6 | G=F 7 | while n_i < n: 8 | G=np.kron(G,F) 9 | n_i+=1 10 | return G 11 | 12 | def channel(x,channel_con,snr,rate): 13 | if channel_con == 'bsc': 14 | y_1=np.random.choice([0, 1],size=x.size,p=[1-snr,snr]) 15 | y=(y_1+x)%2 16 | elif channel_con == 'bec': 17 | y_1=np.random.choice([-2, 0], size=x.size,p=[snr,1-snr]) 18 | y_2=y_1+x 19 | y_2[y_2 == -2] = -1 20 | y=y_2.copy() 21 | elif channel_con == 'awgn': 22 | sigma=(1/np.sqrt(2*rate))*(10**(-snr/20)) 23 | #print(sigma) 24 | #print(x) 25 | #print(1-2*x) 26 | y=(1-2*x)+np.random.randn(x.size)*sigma 27 | #在这里用bpsk将0映射为1,1映射为-1 28 | else: 29 | print('This is not a channel name or s have not added it to the function: channel!') 30 | quit() 31 | return y 32 | 33 | def all_num(x): 34 | length=x.size 35 | flag=1 36 | for i in range(length): 37 | if np.isnan(x[i]) == True: 38 | flag = 0 39 | break 40 | return flag 41 | 42 | def leftdown(position): 43 | p0=position[0]+1 44 | p1=position[1] 45 | p2=position[2] 46 | p3=position[3] 47 | p=[p0,p1,p2,p3] 48 | return p 49 | 50 | def rightdown(position): 51 | p0=position[0]+1 52 | p1=position[1]+2**(position[2]-1-position[0]) 53 | p2=position[2] 54 | p3=position[3] 55 | p=[p0,p1,p2,p3] 56 | return p 57 | 58 | def up(position): 59 | p0=position[0]-1 60 | p1_t=np.floor(position[1]/(2**(position[2]-position[0]+1)))*(2**(position[2]-position[0]+1)) 61 | p1=int(p1_t) 62 | p2=position[2] 63 | p3=position[3] 64 | p=[p0,p1,p2,p3] 65 | return p 66 | 67 | def get_up_bit(left_bit,right_bit): 68 | length=left_bit.size 69 | #print(left_bit) 70 | #print(right_bit) 71 | temp=np.array([(left_bit+right_bit)%2,right_bit]) 72 | temp.resize((1,2*length)) 73 | return temp 74 | 75 | def get_right_bit(right_llr,information_pos,frozen_bit,right_bit_pos): 76 | if right_bit_pos in information_pos: 77 | if right_llr > 0: 78 | temp=0 79 | else: 80 | temp=1 81 | else: 82 | temp=frozen_bit 83 | return temp 84 | 85 | def get_right_bit_flip1(right_llr,information_pos,frozen_bit,right_bit_pos,flip_pos): 86 | if right_bit_pos in information_pos: 87 | if right_bit_pos == flip_pos: 88 | if right_llr > 0: 89 | temp=1 90 | else: 91 | temp=0 92 | else: 93 | if right_llr > 0: 94 | temp = 0 95 | else: 96 | temp = 1 97 | else: 98 | temp=frozen_bit 99 | return temp 100 | 101 | def get_right_llr(left_bit,up_llr): 102 | length=int(left_bit.size) 103 | temp=np.array([g(up_llr[i],up_llr[i+length],left_bit[i]) for i in range(length)]) 104 | return temp 105 | 106 | def get_left_bit(left_llr,information_pos,frozen_bit,left_bit_pos): 107 | if left_bit_pos in information_pos: 108 | #print(left_llr) 109 | if left_llr >= 0: 110 | temp = 0 111 | else: 112 | temp = 1 113 | else: 114 | temp = frozen_bit 115 | return temp 116 | 117 | def get_left_bit_flip1(left_llr,information_pos,frozen_bit,left_bit_pos,flip_pos): 118 | if left_bit_pos in information_pos: 119 | if left_bit_pos == flip_pos: 120 | if left_llr >= 0: 121 | temp = 1 122 | else: 123 | temp = 0 124 | else: 125 | if left_llr >= 0: 126 | temp = 0 127 | else: 128 | temp = 1 129 | else: 130 | temp = frozen_bit 131 | return temp 132 | 133 | def get_left_llr(up_llr): 134 | length=int(up_llr.size/2) 135 | temp=np.array([f_hf(up_llr[i],up_llr[i+length]) for i in range(length)]) 136 | return temp 137 | 138 | def f(L1,L2): 139 | temp=np.log((1+np.exp(L1+L2))/(np.exp(L1)+np.exp(L2))) 140 | return temp 141 | 142 | def f_hf(L1,L2): 143 | #硬件友好型Hard-Friendly的 144 | #print(L1) 145 | s1=np.sign(L1) 146 | s2=np.sign(L2) 147 | if s1 == 0: 148 | s1=1 149 | if s2 == 0: 150 | s2=1 151 | temp=s1*s2*np.min([np.abs(L1),np.abs(L2)]) 152 | return temp 153 | 154 | def f_hf_SMS(L1,L2): 155 | # 硬件友好型Hard-Friendly的,有一个扩展因子s 156 | s=0.9375 157 | s1 = np.sign(L1) 158 | s2 = np.sign(L2) 159 | if s1 == 0: 160 | s1 = 1 161 | if s2 == 0: 162 | s2 = 1 163 | temp = s*s1 * s2 * np.min([np.abs(L1), np.abs(L2)]) 164 | return temp 165 | 166 | def g(L1,L2,U1): 167 | temp=(1-2*U1)*L1+L2 168 | return temp 169 | 170 | def element_update_left(left,right): 171 | value=np.zeros(2) 172 | value[0]=f_hf_SMS(right[1]+left[1],left[0]) 173 | value[1]=f_hf_SMS(left[0],right[0])+left[1] 174 | return value 175 | 176 | def element_update_right(left,right): 177 | value=np.zeros(2) 178 | value[0]=f_hf_SMS(right[1]+left[1],right[0]) 179 | value[1]=f_hf_SMS(left[0],right[0])+right[1] 180 | return value 181 | 182 | 183 | def bp_update_left(left_array,right_array,left_array_n): 184 | N=left_array.size 185 | interval=2**(left_array_n-1) 186 | num=int(N/(interval*2)) 187 | value=np.zeros(N) 188 | for i in range(num): 189 | for j in range(interval): 190 | left_ele=np.zeros(2) 191 | right_ele=np.zeros(2) 192 | left_ele[0]=left_array[2*i*interval+j] 193 | left_ele[1]=left_array[2*i*interval+j+interval] 194 | right_ele[0] = right_array[2 * i * interval + j] 195 | right_ele[1] = right_array[2 * i * interval + j + interval] 196 | get_value=element_update_left(left_ele,right_ele) 197 | value[2*i*interval+j]=get_value[0] 198 | value[2*i*interval+j+interval]=get_value[1] 199 | return value 200 | 201 | def bp_update_right(left_array,right_array,left_array_n): 202 | N=left_array.size 203 | interval=2**(left_array_n-1) 204 | num=int(N/(interval*2)) 205 | value=np.zeros(N) 206 | for i in range(num): 207 | for j in range(interval): 208 | left_ele=np.zeros(2) 209 | right_ele=np.zeros(2) 210 | left_ele[0]=left_array[2*i*interval+j] 211 | left_ele[1]=left_array[2*i*interval+j+interval] 212 | right_ele[0] = right_array[2 * i * interval + j] 213 | right_ele[1] = right_array[2 * i * interval + j + interval] 214 | get_value=element_update_right(left_ele,right_ele) 215 | value[2*i*interval+j]=get_value[0] 216 | value[2*i*interval+j+interval]=get_value[1] 217 | #print('\n',get_value) 218 | 219 | return value 220 | 221 | def get_up_loc(bit_matrix): 222 | N=int(bit_matrix[0].size) 223 | n=int(np.log2(N)) 224 | detect_array=bit_matrix[n] 225 | for i in range(N): 226 | if detect_array[i] == 1 or detect_array[i] == 0: 227 | pass 228 | else: 229 | detect = i-1 230 | break 231 | if detect%2 == 0: 232 | loc_row = n - 1 233 | loc_col=detect 234 | else: 235 | loc_row = n - 1 236 | loc_col = detect-1 237 | if detect == -1: 238 | loc_row=0 239 | loc_col=0 240 | #print(loc_row) 241 | #print(loc_col) 242 | return [loc_row,loc_col] 243 | 244 | def get_pm_update(llr_array,bit_array,pm_method): 245 | l=llr_array.size 246 | pm=0 247 | if pm_method == 'exact': 248 | for i in range(l): 249 | pm=pm+np.log(1+np.exp(-1*(1-2*bit_array[i])*llr_array[i])) 250 | elif pm_method == 'hf': 251 | for i in range(l): 252 | if np.sign(llr_array[i]) != np.sign(1-2*bit_array[i]): 253 | pm += np.abs(llr_array[i]) 254 | else: 255 | print('Notice: This is not a PM compute method or s have not added it to the function.get_pm_update !') 256 | return pm 257 | 258 | def node_identify(N,information_pos,node_method): 259 | n = int(np.log2(N)) 260 | init=['nan'] 261 | node_list = [init] 262 | for i in range(n): 263 | node_list.append(node_list[-1]*2) 264 | node_down=[1 if i in information_pos else 0 for i in range(N)] 265 | node_down2=node_list[-2] 266 | for i in range(int(N/2)): 267 | if node_down[i*2] == 0 and node_down[i*2+1] == 0: 268 | node_down2[i]= 'rate0' 269 | elif node_down[i*2] == 0 and node_down[i*2+1] == 1: 270 | node_down2[i]= 'rep' 271 | elif node_down[i*2] == 1 and node_down[i*2+1] == 1: 272 | node_down2[i]= 'rate1' 273 | else: 274 | pass 275 | node_list[-1]=node_down 276 | node_list[-2] = node_down2 277 | 278 | if node_method == 'regular_node': 279 | depth=list(range(n-1)) 280 | depth.reverse() 281 | for i in depth: 282 | for j in range(2**i): 283 | if node_list[i+1][j*2] == 'rate0' and node_list[i+1][j*2+1] == 'rate0': 284 | node_list[i][j] = 'rate0' 285 | elif node_list[i+1][j*2] == 'rate0' and node_list[i+1][j*2+1] == 'rep': 286 | node_list[i][j] = 'rep' 287 | elif node_list[i + 1][j * 2] == 'rate1' and node_list[i + 1][j * 2 + 1] == 'rate1': 288 | node_list[i][j] = 'rate1' 289 | elif node_list[i+1][j*2] == 'spc' and node_list[i+1][j*2+1] == 'rate1': 290 | node_list[i][j] = 'spc' 291 | elif node_list[i+1][j*2] == 'rep' and node_list[i+1][j*2+1] == 'rate1': 292 | if i == depth[0]: 293 | node_list[i][j] = 'spc' 294 | else: 295 | pass 296 | else: 297 | print('s have not added the Type_1-5 nodes in function.node_process') 298 | quit() 299 | 300 | return node_list 301 | 302 | def node_process(up_llr,node_id): 303 | N = up_llr.size 304 | #print(N) 305 | n = int(np.log2(N)) 306 | if node_id == 'rate0': 307 | bit_depth_i=np.zeros(N) 308 | bit_depth_n=bit_depth_i 309 | elif node_id == 'rep': 310 | llr_sum = np.sum(up_llr) 311 | if llr_sum >= 0: 312 | temp = 0 313 | bit_depth_i = np.zeros(N) 314 | else: 315 | temp = 1 316 | bit_depth_i = np.ones(N) 317 | bit_depth_n = np.zeros(N) 318 | bit_depth_n[-1] = temp 319 | elif node_id == 'spc': 320 | bit_depth_i = np.zeros(N) 321 | for i in range(N): 322 | if up_llr[i] >= 0: 323 | bit_depth_i[i] = 0 324 | else: 325 | bit_depth_i[i] = 1 326 | #print(bit_depth_i) 327 | #print(i) 328 | if np.sum(bit_depth_i)%2 != 0: 329 | flip_index=np.argmin(np.abs(up_llr)) 330 | bit_depth_i[flip_index] = (bit_depth_i[flip_index]+1)%2 331 | bit_depth_n = bit_depth_i*generate_matrix(n) %2 332 | bit_depth_n = np.array(bit_depth_n) 333 | elif node_id == 'rate1': 334 | bit_depth_i = np.zeros(N) 335 | for i in range(N): 336 | if up_llr[i] >= 0: 337 | bit_depth_i[i] = 0 338 | else: 339 | bit_depth_i[i] = 1 340 | bit_depth_n = bit_depth_i * generate_matrix(n) % 2 341 | bit_depth_n = np.array(bit_depth_n) 342 | 343 | return [bit_depth_i,bit_depth_n] 344 | 345 | 346 | -------------------------------------------------------------------------------- /plotbler.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | def plotfig(SNR_list,BER_list,BLER_list): 4 | plt.semilogy(SNR_list,BER_list, linewidth=1,label='BER', linestyle='-',c='b') 5 | plt.semilogy(SNR_list,BLER_list, linewidth=1,label='BLER', linestyle='--',c='b') 6 | plt.legend() 7 | plt.grid() 8 | plt.show() 9 | plt.title('BER', fontsize=15) 10 | plt.xlabel('SNR(dB)', fontsize=14) 11 | plt.ylabel('BER/BLER', fontsize=14) 12 | plt.tick_params(axis='x', labelsize=10) 13 | plt.tick_params(axis='y', labelsize=10, width=2) 14 | # x = range(0, 10, 1) 15 | # y = range(0, 12, 1) 16 | # plt.xticks(x) 17 | # plt.yticks(y) -------------------------------------------------------------------------------- /polarcodes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | sys.path.append('construction_method') 4 | 5 | import function 6 | import decoder 7 | import CRC 8 | 9 | def polarcodes(n,rate,snr,channel_con,decode_method,decode_para,information_pos,frozen_bit,crc_n): 10 | N = 2 ** n 11 | information_num = int(N * rate) 12 | # 信息-information_bit生成 13 | information_bit = np.random.randint(0, 2, size=(1, information_num)) 14 | 15 | if crc_n == 0: 16 | pass 17 | else: 18 | informationbit=information_bit.copy() 19 | informationbit.resize(information_num,) 20 | information_bit_list=list(informationbit) 21 | crc_info=CRC.CRC(information_bit_list,crc_n) 22 | crc_information_bit=crc_info.code 23 | crc_information_bit=np.array(crc_information_bit) 24 | crc_information_bit.resize(1,information_num+crc_n) 25 | 26 | # 编码前序列-u生成 27 | u = np.ones((1, N)) * frozen_bit 28 | j = 0 29 | #print(u.size) 30 | #print(information_bit.size) 31 | if crc_n == 0: 32 | for i in information_pos: 33 | u[0][i] = information_bit[0][j] 34 | j += 1 35 | else: 36 | for i in information_pos: 37 | u[0][i] = crc_information_bit[0][j] 38 | j += 1 39 | 40 | #print(information_pos) 41 | # 生成矩阵-G生成 42 | G = function.generate_matrix(n) 43 | 44 | # 编码比特-x生成 45 | x = u*G 46 | x = np.array(x % 2) 47 | 48 | # 经过信道生成y 49 | y = function.channel(x, channel_con, snr, rate) 50 | 51 | # y进入译码器生成u_d 52 | u_d = decoder.decoder(y, decode_method, decode_para, information_pos, frozen_bit, channel_con, snr, rate, crc_n) 53 | #print(u_d) 54 | # 计算错误数 55 | information_pos=information_pos[0:information_num] 56 | information_bit_d = u_d[information_pos] 57 | error_num = int(np.sum((information_bit_d + information_bit) % 2)) 58 | if error_num != 0: 59 | decode_fail=1 60 | else: 61 | decode_fail=0 62 | r_value=np.array([error_num,decode_fail]) 63 | 64 | 65 | return r_value 66 | 67 | 68 | -------------------------------------------------------------------------------- /snr_generate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def SNR_array(SNR): 4 | if len(SNR) == 3: 5 | value=np.linspace(SNR[0],SNR[2],int((SNR[2]-SNR[0])/SNR[1]+1)) 6 | value1=[round(value[i],2) for i in range(value.size)] 7 | value1=np.array(value1) 8 | elif len(SNR) == 2: 9 | value = np.linspace(SNR[0],SNR[1],int((SNR[1]-SNR[0])/0.1+1)) 10 | value1=[round(value[i],2) for i in range(value.size)] 11 | value1 = np.array(value1) 12 | elif len(SNR) == 1: 13 | value=np.array([SNR[0]]) 14 | value1=[round(value[i],2) for i in range(value.size)] 15 | value1 = np.array(value1) 16 | else: 17 | print('Input the right SNR value') 18 | quit() 19 | return value1 -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | 4 | import polarcodes 5 | import construction 6 | import plotbler 7 | import snr_generate 8 | 9 | 10 | # 初始值预设 11 | run_num=1000 12 | 13 | n = 8 14 | K=2**(n-1) 15 | SNR_para = [2.5,0.5,2.5] # snr在bsc为转移概率,在bec为删除概率,都为float64,在awgn为Eb/N0信噪比(dB),为列表,[start,step,end] 16 | crc_n=0 #有0,4,8,12,16,32,0表示无crc,其他的多项式在CRC里面改 17 | 18 | frozen_bit = 0 19 | construction_method = 'ga' # 有zw,ga,pw, hpw 20 | channel_con = 'awgn' # 有awgn,bsc和bec 21 | decode_method = 'sc' # 有sc,scl, scf, fsc 和 bp 注意bp用的是带扩展因子s的,具体去function.f_hf_SMS看 22 | para1= 4 23 | para2= 'hf' 24 | # para1 para2 25 | # SC / 忽略 / 忽略 26 | # BP 最大迭代次数,int,一般为 30-70 提前终止方法,str,有 max_iter, g_matrix, crc_es 27 | # SCL 列表数量, int, 华为推荐 8 路径度量计算方法,str, 一般有 hf, exact 28 | # SCF 最大翻转次数,int, 推荐 16 翻转方法,目前只翻转1阶,str,只有 llr_based 29 | # FSC / 忽略 节点识别,有regular_node, type_node, 目前只有regular_node 30 | #初始值预设完毕 31 | 32 | decode_para=[para1,para2] 33 | N=2**n 34 | rate=K/N 35 | SNR = snr_generate.SNR_array(SNR_para) 36 | len_snr=SNR.size 37 | SNR_list=list(SNR) 38 | BER_list=[] 39 | BLER_list=[] 40 | time_start=time.time() 41 | 42 | for i in range(len_snr): 43 | SNRi=SNR[i] 44 | run = 1 45 | error = np.array([0, 0]) 46 | information_pos = construction.construction(N,SNRi,channel_con,construction_method,K+crc_n,rate) 47 | 48 | while run <= run_num: 49 | #print('The run_num is : ' + str(run), end='\r') 50 | value=polarcodes.polarcodes(n, rate, SNRi, channel_con, decode_method, decode_para, information_pos, frozen_bit,crc_n) 51 | error += value 52 | run += 1 53 | 54 | ber_bler=np.array([error[0] / (N * rate * run_num), error[1] / run_num]) 55 | print('When SNR= ',SNRi,'dB, ') 56 | print('the BER is : ',ber_bler[0],', and the BLER is : ',ber_bler[1]) 57 | BER_list.append(ber_bler[0]) 58 | BLER_list.append(ber_bler[1]) 59 | 60 | time_end=time.time() 61 | time_cost=int(time_end-time_start) 62 | print('Time cost : ',time_cost,' s') 63 | 64 | #绘图 65 | plotbler.plotfig(SNR_list,BER_list,BLER_list) -------------------------------------------------------------------------------- /writeintxt.py: -------------------------------------------------------------------------------- 1 | def writetxt(N,K,SNR_para,crc_n,decode_method,BER_list,BLER_list,time_cost,construction_method,decode_para): 2 | if crc_n == 0: 3 | txt_line1 = 'Polar('+ str(N) + ','+ str(K)+','+ construction_method.upper()+',' + decode_method.upper() +'), '+ 'decode_para='+'['+str(decode_para[0])+','+str(decode_para[1])+']. '+'SNR= '+ str( 4 | SNR_para[0]) + ' : ' + str(SNR_para[1]) + ' : ' + str(SNR_para[2])+ '. No CRC.' + ' Time Cost: '+str(time_cost)+'s'+'\n' 5 | else: 6 | txt_line1 = 'Polar('+ str(N) + ','+ str(K)+','+ construction_method.upper()+',' + decode_method.upper() +'), '+ 'decode_para='+'['+str(decode_para[0])+','+str(decode_para[1])+']. '+'SNR= '+ str( 7 | SNR_para[0]) + ' : ' + str(SNR_para[1]) + ' : ' + str(SNR_para[2])+ '. CRC-' + str(crc_n) +'. Time Cost: '+str(time_cost)+'s'+ '\n' 8 | text = 'BER: ' + str(BER_list) + '\n' + 'BLER: ' + str(BLER_list) + '\n' 9 | 10 | with open('ber_bler.txt','a') as txtfile: 11 | txtfile.write(txt_line1) 12 | txtfile.write(text) --------------------------------------------------------------------------------