├── .gitignore
├── README.md
├── complexity_analysis
│   ├── complexityAnalysis.m
│   └── scr
│       ├── gramMatrix.m
│       ├── iddFiltering.m
│       ├── initialLMMSEFiltering.m
│       └── preprocessing.m
├── scr
│   └── simulation_script.py
└── source
    ├── customOFDMChannel.py
    ├── dampenedLdpc5gDecoder.py
    ├── decoder_v1.py
    ├── mmsePIC.py
    └── simulationFunctions.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Ignore all .files and .folders
.*
!.gitignore
fig
*.png
*pycache*
nvvm
.idea
.idea/*
*.asv
*.xml
./results
*.csv
*.pickle
*.mat

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DUIDD: Deep-Unfolded Interleaved Detection and Decoding for MIMO Wireless Systems

This simulator implements the experiments of the paper “DUIDD: Deep-Unfolded Interleaved Detection and Decoding for MIMO Wireless Systems,”
*R. Wiesmayr, C. Dick, J. Hoydis, and C. Studer, Proc. Asilomar Conf. Signals, Syst., Comput., Oct. 2022, available at https://arxiv.org/abs/2212.07816*

The simulations are implemented with [NVIDIA Sionna](https://github.com/NVlabs/sionna) Release v0.11 and our own extensions.

Parts of the code are also based on
- *R. Wiesmayr, G. Marti, C. Dick, H. Song, and C. Studer, “Bit Error and Block Error Rate Training for ML-Assisted Communication,” arXiv:2210.14103, 2022*, available at https://arxiv.org/abs/2210.14103
- *C. Studer, S. Fateh, and D. Seethaler, “ASIC Implementation of Soft-Input Soft-Output MIMO Detection Using MMSE Parallel Interference Cancellation,” IEEE Journal of Solid-State Circuits, vol. 46, no. 7, pp. 1754–1765, July 2011, available at https://www.csl.cornell.edu/~studer/papers/11JSSC-mmsepic.pdf*

If you are using this simulator (or parts of it) for a publication, you must cite the above-mentioned references and clearly mention this in your paper.

## Running simulations
Please have your Python environment ready with NVIDIA Sionna v0.11, as the code was developed and tested for this version.

The main simulation script `simulation_script.py` is located in `./scr` and contains multiple simulation parameters that can be modified at will.
The script trains the specified signal processing models and then evaluates their performance in a benchmark. At the end, bit error rate and block error rate curves are plotted and saved.

> The ray-tracing channels utilized by our script can be downloaded from [here](https://iis-nextcloud.ee.ethz.ch/s/PMDAyWzc6kXwqMS).

Before running the simulations, the following directories have to be created (see the sketch below):
- `./data/weights/` for saving the trained model weights
- `./results` for the simulation results (BER and BLER curves), which are saved as `.csv` and `.pickle` files
- If you want to use the ray-tracing channels, download and place them under `./data/channels/`
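A minimal setup sketch, assuming you run it from the repository root (plain `mkdir` commands work just as well):

```python
import os

# directories expected by simulation_script.py (see the list above)
for d in ("data/weights", "results", "data/channels"):
    os.makedirs(d, exist_ok=True)
```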
## Version history

- Version 0.1: [wiesmayr@iis.ee.ethz.ch](mailto:wiesmayr@iis.ee.ethz.ch) - initial version for GitHub release

--------------------------------------------------------------------------------
/complexity_analysis/complexityAnalysis.m:
--------------------------------------------------------------------------------
% Computational Complexity Analysis Script
% Scenarios with M_R = 4, 8, 16, 32, 64, and 128 BS antennas
% Sweeps over N_T = 1:M_R
% Three time-slot settings T = 1, 10, 14, plus a sweep over T = 1:1000
%
% This analysis counts the number of real-valued multiplications for detection
%

clear all
close all

addpath ./scr

QR_BACKSUB = "QR-BACKSUB";
QR_EXPLICIT = "QR-EXPLICIT";
MF_BACKSUB = "MF-BACKSUB";
EXPL_INV = "EXPL-INV";
EXPL_FILTER = "EXPL-FILTER";

MMSE_PIC = "MMSE-PIC";
MF_LMMSE = "MF-LMMSE";

Q = -4;

%% Analyzing Initial LMMSE Detection (to determine which method is optimal)

for T = [1, 10, 14]
    for M_R = [4, 8, 16, 32, 64, 128]
        figure
        hold on
        title(sprintf('Initial Filtering: UE-Sweep w. T = %d M_R = %d', T, M_R))
        N_T = 1:M_R;
        plot(N_T, initialLMMSEFiltering(QR_BACKSUB, M_R, N_T, T, Q),'b-','Linewidth',2)
        plot(N_T, initialLMMSEFiltering(QR_EXPLICIT, M_R, N_T, T, Q),'c-','Linewidth',2)
        plot(N_T, initialLMMSEFiltering(MF_BACKSUB, M_R, N_T, T, Q),'r-','Linewidth',2)
        plot(N_T, initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q),'m-','Linewidth',2)
        plot(N_T, initialLMMSEFiltering(EXPL_FILTER, M_R, N_T, T, Q),'g-','Linewidth',2)
        legend(QR_BACKSUB, QR_EXPLICIT, "Chol. " + MF_BACKSUB, "Chol. " + EXPL_INV, "Chol. " + EXPL_FILTER, Location="northwest")
        xlabel("N_T")
        ylabel("Real Valued Multiplications")
        hold off
    end
end

% Sweeping over T
T=1:1000;
settings = {[4,4], [16,4], [16,8], [32,8], [64,16], [128,32]};
for i=1:length(settings)
    M_R= settings{i}(1);
    N_T= settings{i}(2);
    figure
    loglog(T, initialLMMSEFiltering(QR_BACKSUB, M_R, N_T, T, Q),'b-','Linewidth',2)
    title(sprintf('Initial Filtering: T-Sweep M_R = %d N_T = %d', M_R, N_T))
    hold on
    loglog(T, initialLMMSEFiltering(QR_EXPLICIT, M_R, N_T, T, Q),'c-','Linewidth',2)
    loglog(T, initialLMMSEFiltering(MF_BACKSUB, M_R, N_T, T, Q),'r-','Linewidth',2)
    loglog(T, initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q),'m-','Linewidth',2)
    loglog(T, initialLMMSEFiltering(EXPL_FILTER, M_R, N_T, T, Q),'g-','Linewidth',2)
    legend(QR_BACKSUB, QR_EXPLICIT, "Chol. " + MF_BACKSUB, "Chol. " + EXPL_INV, "Chol. " + EXPL_FILTER, Location="northwest")
    xlabel("T")
    ylabel("Real Valued Multiplications")
    hold off
end

%% Analyzing the IDD Detection Complexity Increase

for T = [1, 10, 14]
    for M_R = [16, 32, 64, 128]
        figure
        hold on
        title(sprintf('IDD and Baseline EXPL-INV: UE-Sweep w.
T = %d M_R = %d', T, M_R)) 72 | N_T = 1:M_R; 73 | plot(N_T, (iddFiltering(MF_LMMSE, EXPL_INV, M_R, N_T, T, Q, 1) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q))./initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q),'b-','Linewidth',2) 74 | plot(N_T, (iddFiltering(MMSE_PIC, EXPL_INV, M_R, N_T, T, Q, 1) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q))./initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q),'r-','Linewidth',2) 75 | plot(N_T, (iddFiltering(MF_LMMSE, EXPL_INV, M_R, N_T, T, Q, 2) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q))./initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q),'b-.','Linewidth',2) 76 | plot(N_T, (iddFiltering(MMSE_PIC, EXPL_INV, M_R, N_T, T, Q, 2) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q))./initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q),'r-.','Linewidth',2) 77 | plot(N_T, (iddFiltering(MF_LMMSE, EXPL_INV, M_R, N_T, T, Q, 3) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q))./initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q),'b--','Linewidth',2) 78 | plot(N_T, (iddFiltering(MMSE_PIC, EXPL_INV, M_R, N_T, T, Q, 3) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q))./initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q),'r--','Linewidth',2) 79 | plot(N_T, (iddFiltering(MF_LMMSE, EXPL_INV, M_R, N_T, T, Q, 4) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q))./initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q),'b:','Linewidth',2) 80 | plot(N_T, (iddFiltering(MMSE_PIC, EXPL_INV, M_R, N_T, T, Q, 4) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q))./initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q),'r:','Linewidth',2) 81 | legend(MF_LMMSE+"I=1", MMSE_PIC+"I=1", MF_LMMSE+"I=2", MMSE_PIC+"I=2", MF_LMMSE+"I=3", MMSE_PIC+"I=3", MF_LMMSE+"I=4", MMSE_PIC+"I=4", Location="northwest") 82 | xlabel("N_T") 83 | ylabel("(#MUL IDD Iter)/(#MUL Baseline LMMSE)") 84 | hold off 85 | end 86 | end 87 | 88 | for T = [1, 10, 14] 89 | for M_R = [16, 32, 64, 128] 90 | figure 91 | hold on 92 | title(sprintf('IDD w/ EXPL-INV, Baseline MF-BACKSUB: UE-Sweep w. 
T = %d M_R = %d', T, M_R)) 93 | N_T = 1:M_R; 94 | plot(N_T, (iddFiltering(MF_LMMSE, EXPL_INV, M_R, N_T, T, Q, 1) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q))./initialLMMSEFiltering(MF_BACKSUB, M_R, N_T, T, Q),'b-','Linewidth',2) 95 | plot(N_T, (iddFiltering(MMSE_PIC, EXPL_INV, M_R, N_T, T, Q, 1) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q))./initialLMMSEFiltering(MF_BACKSUB, M_R, N_T, T, Q),'r-','Linewidth',2) 96 | plot(N_T, (iddFiltering(MF_LMMSE, EXPL_INV, M_R, N_T, T, Q, 2) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q))./initialLMMSEFiltering(MF_BACKSUB, M_R, N_T, T, Q),'b-.','Linewidth',2) 97 | plot(N_T, (iddFiltering(MMSE_PIC, EXPL_INV, M_R, N_T, T, Q, 2) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q))./initialLMMSEFiltering(MF_BACKSUB, M_R, N_T, T, Q),'r-.','Linewidth',2) 98 | plot(N_T, (iddFiltering(MF_LMMSE, EXPL_INV, M_R, N_T, T, Q, 3) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q))./initialLMMSEFiltering(MF_BACKSUB, M_R, N_T, T, Q),'b--','Linewidth',2) 99 | plot(N_T, (iddFiltering(MMSE_PIC, EXPL_INV, M_R, N_T, T, Q, 3) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q))./initialLMMSEFiltering(MF_BACKSUB, M_R, N_T, T, Q),'r--','Linewidth',2) 100 | plot(N_T, (iddFiltering(MF_LMMSE, EXPL_INV, M_R, N_T, T, Q, 4) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q))./initialLMMSEFiltering(MF_BACKSUB, M_R, N_T, T, Q),'b:','Linewidth',2) 101 | plot(N_T, (iddFiltering(MMSE_PIC, EXPL_INV, M_R, N_T, T, Q, 4) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q))./initialLMMSEFiltering(MF_BACKSUB, M_R, N_T, T, Q),'r:','Linewidth',2) 102 | legend(MF_LMMSE+"I=1", MMSE_PIC+"I=1", MF_LMMSE+"I=2", MMSE_PIC+"I=2", MF_LMMSE+"I=3", MMSE_PIC+"I=3", MF_LMMSE+"I=4", MMSE_PIC+"I=4", Location="northwest") 103 | xlabel("N_T") 104 | ylabel("(#MUL IDD Iter)/(#MUL Baseline LMMSE)") 105 | hold off 106 | end 107 | end 108 | 109 | % Sweeping over T 110 | T=1:1000; 111 | settings = {[4,4], [16,4], [16,8], [32,8], [64,16], [128,32]}; 112 | for i=1:length(settings) 113 | M_R= settings{i}(1); 114 | N_T= settings{i}(2); 115 | figure 116 | semilogx(T, (iddFiltering(MF_LMMSE, EXPL_INV, M_R, N_T, T, Q, 1) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q))./initialLMMSEFiltering(MF_BACKSUB, M_R, N_T, T, Q),'b-','Linewidth',2) 117 | hold on 118 | semilogx(T, (iddFiltering(MMSE_PIC, EXPL_INV, M_R, N_T, T, Q, 1) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q))./initialLMMSEFiltering(MF_BACKSUB, M_R, N_T, T, Q),'r-','Linewidth',2) 119 | semilogx(T, (iddFiltering(MF_LMMSE, EXPL_INV, M_R, N_T, T, Q, 2) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q))./initialLMMSEFiltering(MF_BACKSUB, M_R, N_T, T, Q),'b-.','Linewidth',2) 120 | semilogx(T, (iddFiltering(MMSE_PIC, EXPL_INV, M_R, N_T, T, Q, 2) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q))./initialLMMSEFiltering(MF_BACKSUB, M_R, N_T, T, Q),'r-.','Linewidth',2) 121 | semilogx(T, (iddFiltering(MF_LMMSE, EXPL_INV, M_R, N_T, T, Q, 3) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q))./initialLMMSEFiltering(MF_BACKSUB, M_R, N_T, T, Q),'b--','Linewidth',2) 122 | semilogx(T, (iddFiltering(MMSE_PIC, EXPL_INV, M_R, N_T, T, Q, 3) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q))./initialLMMSEFiltering(MF_BACKSUB, M_R, N_T, T, Q),'r--','Linewidth',2) 123 | semilogx(T, (iddFiltering(MF_LMMSE, EXPL_INV, M_R, N_T, T, Q, 4) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q))./initialLMMSEFiltering(MF_BACKSUB, M_R, N_T, T, Q),'b:','Linewidth',2) 124 | semilogx(T, (iddFiltering(MMSE_PIC, EXPL_INV, M_R, N_T, T, Q, 4) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, 
Q))./initialLMMSEFiltering(MF_BACKSUB, M_R, N_T, T, Q),'r:','Linewidth',2) 125 | title(sprintf('IDD-Complexity/Initial Filtering: T-Sweep M_R = %d N_T = %d', M_R, N_T)) 126 | legend(MF_LMMSE+"I=1", MMSE_PIC+"I=1", MF_LMMSE+"I=2", MMSE_PIC+"I=2", MF_LMMSE+"I=3", MMSE_PIC+"I=3", MF_LMMSE+"I=4", MMSE_PIC+"I=4", Location="northwest") 127 | xlabel("T") 128 | ylabel("(#MUL IDD Iter)/(#MUL Baseline LMMSE)") 129 | hold off 130 | end 131 | 132 | %% Complexity of considered scenario 133 | M_R = 16; 134 | N_T = 4; 135 | T_arr = [1, 10, 14]; 136 | for T=T_arr 137 | baseline_lmmse = initialLMMSEFiltering(MF_BACKSUB, M_R, N_T, T, Q); 138 | mmse_pic_1 = (iddFiltering(MMSE_PIC, EXPL_INV, M_R, N_T, T, Q, 1) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q)); 139 | mmse_pic_2 = (iddFiltering(MMSE_PIC, EXPL_INV, M_R, N_T, T, Q, 2) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q)); 140 | mmse_pic_3 = (iddFiltering(MMSE_PIC, EXPL_INV, M_R, N_T, T, Q, 3) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q)); 141 | mmse_pic_4 = (iddFiltering(MMSE_PIC, EXPL_INV, M_R, N_T, T, Q, 4) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q)); 142 | lmmse_mf_1 = (iddFiltering(MF_LMMSE, EXPL_INV, M_R, N_T, T, Q, 1) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q)); 143 | lmmse_mf_2 = (iddFiltering(MF_LMMSE, EXPL_INV, M_R, N_T, T, Q, 2) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q)); 144 | lmmse_mf_3 = (iddFiltering(MF_LMMSE, EXPL_INV, M_R, N_T, T, Q, 3) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q)); 145 | lmmse_mf_4 = (iddFiltering(MF_LMMSE, EXPL_INV, M_R, N_T, T, Q, 4) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q)); 146 | 147 | data_table_ = table(baseline_lmmse, mmse_pic_1, mmse_pic_2, mmse_pic_3, mmse_pic_4, lmmse_mf_1, lmmse_mf_2, lmmse_mf_3, lmmse_mf_4); 148 | data_table_.Properties.VariableNames = ["LMMSE", MMSE_PIC+"1", MMSE_PIC+"2", MMSE_PIC+"3", MMSE_PIC+"4", MF_LMMSE+"1", MF_LMMSE+"2", MF_LMMSE+"3", MF_LMMSE+"4"]; 149 | dataset_title = sprintf('IDD-Complexity M_R = %d N_T = %d T=%d', M_R, N_T, T); 150 | writetable(data_table_,"./data/"+string(dataset_title)+".csv",'Delimiter',',','QuoteStrings',true) 151 | end 152 | 153 | M_R = 8; 154 | N_T = 4; 155 | T_arr = [10]; 156 | for T=T_arr 157 | baseline_lmmse = initialLMMSEFiltering(QR_EXPLICIT, M_R, N_T, T, Q); % QR Explicit more efficient for 8x4 and T=10 158 | mmse_pic_1 = (iddFiltering(MMSE_PIC, EXPL_INV, M_R, N_T, T, Q, 1) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q)); 159 | mmse_pic_2 = (iddFiltering(MMSE_PIC, EXPL_INV, M_R, N_T, T, Q, 2) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q)); 160 | mmse_pic_3 = (iddFiltering(MMSE_PIC, EXPL_INV, M_R, N_T, T, Q, 3) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q)); 161 | mmse_pic_4 = (iddFiltering(MMSE_PIC, EXPL_INV, M_R, N_T, T, Q, 4) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q)); 162 | lmmse_mf_1 = (iddFiltering(MF_LMMSE, EXPL_INV, M_R, N_T, T, Q, 1) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q)); 163 | lmmse_mf_2 = (iddFiltering(MF_LMMSE, EXPL_INV, M_R, N_T, T, Q, 2) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q)); 164 | lmmse_mf_3 = (iddFiltering(MF_LMMSE, EXPL_INV, M_R, N_T, T, Q, 3) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q)); 165 | lmmse_mf_4 = (iddFiltering(MF_LMMSE, EXPL_INV, M_R, N_T, T, Q, 4) + initialLMMSEFiltering(EXPL_INV, M_R, N_T, T, Q)); 166 | 167 | data_table_ = table(baseline_lmmse, mmse_pic_1, mmse_pic_2, mmse_pic_3, mmse_pic_4, lmmse_mf_1, lmmse_mf_2, lmmse_mf_3, lmmse_mf_4); 168 | data_table_.Properties.VariableNames = ["LMMSE", MMSE_PIC+"1", MMSE_PIC+"2", 
MMSE_PIC+"3", MMSE_PIC+"4", MF_LMMSE+"1", MF_LMMSE+"2", MF_LMMSE+"3", MF_LMMSE+"4"]; 169 | dataset_title = sprintf('IDD-Complexity M_R = %d N_T = %d T=%d', M_R, N_T, T); 170 | writetable(data_table_,"./data/"+string(dataset_title)+".csv",'Delimiter',',','QuoteStrings',true) 171 | end 172 | 173 | -------------------------------------------------------------------------------- /complexity_analysis/scr/gramMatrix.m: -------------------------------------------------------------------------------- 1 | function [mul_compl] = gramMatrix(M_R, N_T) 2 | %GRAMMATRIX calculating the Gram matrix: 3 | % we only need to calculate the upper 4 | % triangular part becaues the Gram is hermitian 5 | mul_compl = M_R*(N_T+1).*N_T/2; 6 | end 7 | 8 | -------------------------------------------------------------------------------- /complexity_analysis/scr/iddFiltering.m: -------------------------------------------------------------------------------- 1 | function [rv_mul] = iddFiltering(Detector, Scenario, M_R, N_T, T, Q, L) 2 | %{ 3 | iddFiltering calculates the multiplication count of the IDD Detection, 4 | considering an efficient implementation of textbook algorithms. 5 | 6 | Inputs: 7 | Detector "Low-Complexity-MF-LMMSE" or "MMSE-PIC" 8 | Scenario "EXPL-INV" or "EXPL-FILTER" 9 | M_R Number of BS antennas 10 | N_T Number of Transmitters 11 | T Number of time slots per coherence interval (block fading) 12 | Q number of bits, positive is general case, negative values 13 | apply low-complexity gray-labeling QAM trick 14 | Outputs: 15 | rv_mul Number of real valued multiplications 16 | 17 | %} 18 | 19 | init_h 20 | 21 | %% Soft Symbol Calculation (applying the Gray labeling trick, best 22 | % implemenetation) 23 | rv_mul = 0; 24 | 25 | % p0 = 0.5 * (1 - tf.math.tanh(0.5 * llr_a)) 26 | rv_mul = rv_mul + N_T*T*Q*2; 27 | switch Q 28 | case -1 29 | rv_mul = rv_mul + 1*N_T*T; 30 | case -2 31 | rv_mul = rv_mul + 4*N_T*T; 32 | case -4 33 | rv_mul = rv_mul + 10*N_T*T; 34 | case -6 35 | rv_mul = rv_mul + 22*N_T*T; 36 | otherwise 37 | % calculate symbol probability 38 | rv_mul = rv_mul + N_T*T*(2^Q)*Q; 39 | % s_hat = tf.reduce_sum(points_reshaped * tf.cast(P_C, 40 | % tf.complex64), axis=-1) complex!! 41 | rv_mul = rv_mul + 4*T*N_T*2^Q; 42 | % Calculate Squared Error 43 | rv_mul = rv_mul + 2*4*2^Q * N_T * T; 44 | % Calculate Error Variance 45 | rv_mul = rv_mul + 2^Q * N_T * T; 46 | end 47 | 48 | %% Parallel Interference Cancellation 49 | compl_mul = 0; 50 | switch Scenario 51 | case EXPL_INV 52 | % y_MF, G already calculated 53 | % efficient implementation of PIC with y_MF 54 | compl_mul = compl_mul + T*N_T.^2; 55 | case EXPL_FILTER 56 | % no y_MF and no G... 

--------------------------------------------------------------------------------
/complexity_analysis/scr/iddFiltering.m:
--------------------------------------------------------------------------------
function [rv_mul] = iddFiltering(Detector, Scenario, M_R, N_T, T, Q, L)
%{
iddFiltering calculates the multiplication count of the IDD detection,
considering an efficient implementation of textbook algorithms.

Inputs:
    Detector    "MF-LMMSE" (low-complexity MF-LMMSE) or "MMSE-PIC"
    Scenario    "EXPL-INV" or "EXPL-FILTER"
    M_R         Number of BS antennas
    N_T         Number of Transmitters
    T           Number of time slots per coherence interval (block fading)
    Q           Number of bits; positive is the general case, negative values
                apply the low-complexity Gray-labeling QAM trick
    L           Number of IDD iterations (i.e., L-1 extra filtering passes)
Outputs:
    rv_mul      Number of real-valued multiplications
%}

init_h

%% Soft Symbol Calculation (applying the Gray labeling trick, best
% implementation)
rv_mul = 0;

% p0 = 0.5 * (1 - tf.math.tanh(0.5 * llr_a))
rv_mul = rv_mul + N_T*T*Q*2;
switch Q
    case -1
        rv_mul = rv_mul + 1*N_T*T;
    case -2
        rv_mul = rv_mul + 4*N_T*T;
    case -4
        rv_mul = rv_mul + 10*N_T*T;
    case -6
        rv_mul = rv_mul + 22*N_T*T;
    otherwise
        % calculate symbol probability
        rv_mul = rv_mul + N_T*T*(2^Q)*Q;
        % s_hat = tf.reduce_sum(points_reshaped * tf.cast(P_C, tf.complex64), axis=-1), complex!
        rv_mul = rv_mul + 4*T*N_T*2^Q;
        % Calculate Squared Error
        rv_mul = rv_mul + 2*4*2^Q * N_T * T;
        % Calculate Error Variance
        rv_mul = rv_mul + 2^Q * N_T * T;
end

%% Parallel Interference Cancellation
compl_mul = 0;
switch Scenario
    case EXPL_INV
        % y_MF, G already calculated
        % efficient implementation of PIC with y_MF
        compl_mul = compl_mul + T*N_T.^2;
    case EXPL_FILTER
        % no y_MF and no G...
        % really perform PIC on y (not y_MF)
        compl_mul = compl_mul + T*N_T*M_R;
    otherwise
end

%% Equalization
switch Detector
    case MMSE_PIC
        % has to calculate everything (also the inverse) for every t
        switch Scenario
            case EXPL_INV
                % needs an LU decomposition instead of Cholesky because
                % A is not Hermitian

                % calculation of A
                % G*diag(error_var) (Gram matrix is Hermitian, only
                % calculate upper-triangular muls)
                compl_mul = compl_mul + T*1/2*N_T.*(N_T+1);

                % LU decomposition
                compl_mul = compl_mul + T*(1/3*N_T.^3 + 1/4*N_T.^2 + 3/4*N_T - 1/2);

                % Forward substitution of L (invert L explicitly, diag is one)
                compl_mul = compl_mul + T*(1/6*N_T.^3 - 1/6*N_T);

                % back-substitution of U (N_T times vector backsub, nothing is real)
                compl_mul = compl_mul + T*N_T.*((1/2)*(N_T.^2 + N_T));

                % calculate mu_i
                % mu_i = a_i^H g_i = 1 - \sigma_n^2 (A^-1)_ii (real)
                compl_mul = compl_mul + T*1/4*N_T;

                % calculate MMSE output
                % Filtering A^-1 y_MF
                compl_mul = compl_mul + T*N_T.^2;

                % NPI calculation (rho_i = mu_i/(1-mu_i))
                compl_mul = compl_mul + (T/4)*N_T;
            case EXPL_FILTER
                error('not implemented')
            otherwise
                error('not implemented')
        end
    case MF_LMMSE
        switch Scenario
            case EXPL_INV
                % y_MF, G, mu_i and A^-1 already calculated
                % calculating filter matrix Z_bar
                compl_mul = compl_mul + N_T.^2 + 2*N_T;

                % NPI calculation (assuming SNR heuristic 4)
                rv_mul = rv_mul + T*(N_T.^2);

                % Filtering
                compl_mul = compl_mul + T*N_T.^2;
            case EXPL_FILTER
                error('not implemented')
            otherwise
                error('not implemented')
        end
    otherwise
        error('not implemented')
end

%% LLR calculation
switch Q
    case -1
        rv_mul = rv_mul + 1*N_T*T;
    case -2
        rv_mul = rv_mul + 2*N_T*T;
    case -4
        rv_mul = rv_mul + 4*N_T*T;
    case -6
        rv_mul = rv_mul + 6*N_T*T;
    otherwise
        rv_mul = rv_mul + 3*T*N_T*2^Q;
end
% Multiply lambda_b by rho_i
rv_mul = rv_mul + T*Q*N_T;

rv_mul = rv_mul + 4*compl_mul;

%% L IDD iterations (i.e., L-1 times this filtering)
rv_mul = (L-1)*rv_mul;

% include initial squared and abs Gram matrix calculation
if Scenario == EXPL_INV && Detector == MF_LMMSE && L >= 2
    rv_mul = rv_mul + 2*N_T.^2;
end

% rv_mul = int64(rv_mul);

end
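% --- Hypothetical sanity check (not part of the original script): for the
% considered 16x4 scenario with T = 14 and 16-QAM (Q = -4), the relative
% cost of one extra MMSE-PIC IDD pass over the baseline LMMSE detector can
% be evaluated as
%   extra = iddFiltering("MMSE-PIC", "EXPL-INV", 16, 4, 14, -4, 2);
%   base  = initialLMMSEFiltering("EXPL-INV", 16, 4, 14, -4);
%   ratio = (extra + base)/base;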

--------------------------------------------------------------------------------
/complexity_analysis/scr/initialLMMSEFiltering.m:
--------------------------------------------------------------------------------
function [rv_mul] = initialLMMSEFiltering(Scenario, M_R, N_T, T, Q)
%{
initialLMMSEFiltering calculates the multiplication count of the initial
LMMSE detection, considering an efficient implementation of textbook
algorithms, applying several low-complexity tricks.

Inputs:
    Scenario    "MF-BACKSUB" for matched filtering w/ back-substitution,
                "EXPL-INV" for explicitly calculating the inverse of A,
                "EXPL-FILTER" for explicitly calculating the filter,
                "QR-BACKSUB"/"QR-EXPLICIT" for the QR-based variants
    M_R         Number of BS antennas
    N_T         Number of Transmitters
    T           Number of time slots per coherence interval (block fading)
    Q           Number of bits; positive is the general case, a negative value
                applies the Gray-labeling QAM trick (negative Q \in {-1,-2,-4,-6})
Outputs:
    rv_mul      Number of real-valued multiplications
%}
init_h

mul = 0;
if Scenario == QR_BACKSUB || Scenario == QR_EXPLICIT
    % calculate reduced QR decomposition of [H; \sigma_n I_M_R]
    % considering that the lower part corresponding to \sigma_n I_M_R
    % is lower triangular with real values on the diagonal.
    mul = M_R.*(N_T.^2) + 1/3*N_T.^3 - 1/4*N_T.^2 - 5/12*N_T;

    % \tilde{Q}_a is M_R x N_T (part that corresponds to H), 1/sigma_n Q_b
    % is R^-1
    % R_red^-1 (\tilde{Q}_a^H y)

    if Scenario == QR_BACKSUB
        % \tilde{Q}_a^H y
        mul = mul + N_T*M_R*T;
        % R_red^-1 (equivalent: \tilde{Q}_b) w/ back-substitution
        % (backsub has the same complexity as multiplying an
        % upper-triangular matrix with a vector; diag R_red is real)
        mul = mul + T*1/2*(N_T.^2);
    elseif Scenario == QR_EXPLICIT
        % 1/sigma_n Q_b
        mul = mul + N_T.^2/2;
        % W = (1/sigma_n Q_b) \tilde{Q}_a^H
        mul = mul + 1/2*N_T.^2*M_R;
        % filtering W*y
        mul = mul + N_T*M_R*T;
    end

    % NPI calculation: mu = diag(I - Q_b Q_b^H) (Q_b upper triangular, diag is real)
    mul = mul + 1/4*(N_T.^2);
else
    % Gram matrix calculation
    mul = mul + gramMatrix(M_R, N_T);
    % Cholesky factorization of A = (G + n0*eye(N_T)) = L D L^H
    mul = mul + 1/6*N_T.^3 - 1/2*N_T.^2 - 1/3*N_T;
    % A = L D L^H
end

% Filtering x_hat = A \ H' y
switch Scenario
    case MF_BACKSUB
        % matched filtering y_MF = H' * y
        mul = mul + T*N_T*M_R;
        % Forward substitution L^-1 * y_MF (same complexity as directly
        % multiplying y with L^-1, diag of L is 1)
        mul = mul + T*(1/2)*N_T.*(N_T-1);
        % Scale w/ D^-1 (real)
        mul = mul + 1/2*N_T*T;
        % Back-substitution L^-H (diag of L is one)
        mul = mul + T*(1/2)*N_T.*(N_T-1);

        % Calculate mu (for filter normalization and NPI calculation)
        % mu_i = diag(L^-H (D^-1 (L^-1 G)))
        % by inverting L explicitly (diag L is one)
        mul = mul + 1/6*N_T.^3 - 1/6*N_T;
        % calculate all mu_i
        % mu_i = 1 - \sigma_n^2 \sum_{j=i}^{N_T} D_jj^-1 |L^-1_ji|^2
        % (diag L is one)
        mul = mul + 1/4*(3/2 * N_T.^2 - 1/2*N_T);
    case EXPL_INV
        % matched filtering y_MF = H' * y
        mul = mul + T*N_T*M_R;
        % Forward substitution (calc inverse) L^-1 (diag of L is one)
        mul = mul + 1/6*N_T.^3 - 1/6*N_T;
        % Scaling D^-1 L^-1 (D is real, diag L is one)
        mul = mul + N_T.*(N_T-1)/4;
        % Multiplication of two triangular matrices L^-H (D L^-1) (diagonal
        % elements of the right factor are real, of the left factor one; we
        % only need the upper-triangular part because A is Hermitian)
        mul = mul + 1/6*N_T.^3 + 1/4*N_T.^2 - 5/12*N_T;
        % Now, we have A^-1

        % Filtering A^-1 y_MF
        mul = mul + T*N_T.^2;

        % Calculate mu (for filter normalization and NPI calculation)
        % mu_i = 1 - \sigma_n^2 (A^-1)_ii (real)
        mul = mul + 1/4*N_T;
    case EXPL_FILTER
        % No matched filtering!
        % W = L^-H (D^-1 (L^-1 H^H))
        % forward substitution M_R times (diag L is one)
        mul = mul + 1/2*(N_T.*(N_T-1))*M_R;
        % scale with D^-1 (real)
        mul = mul + N_T*M_R/2;
        % backsub L^-H M_R times (diag L is one)
        mul = mul + 1/2*(N_T.*(N_T-1))*M_R;

        % filtering W y
        mul = mul + N_T*M_R*T;

        % calculate mu_i = w_i' * h_i
        mul = mul + N_T*M_R;
    case QR_BACKSUB
    case QR_EXPLICIT
    otherwise
        error('Unknown Scenario')
end

% scale equalized symbols for unbiasedness (i.e., \tilde{x}_i =
% \hat{x}_i/mu_i) and NPI calculation (rho_i = mu_i/(1-mu_i))
mul = mul + (T+1/4)*N_T;

rv_mul = 4*mul;

% LLR calculation
switch Q
    case -1
        rv_mul = rv_mul + 1*N_T*T;
    case -2
        rv_mul = rv_mul + 2*N_T*T;
    case -4
        rv_mul = rv_mul + 4*N_T*T;
    case -6
        rv_mul = rv_mul + 6*N_T*T;
    otherwise
        rv_mul = rv_mul + 3*T*N_T*2^Q;
end
% Multiply lambda_b by rho_i
rv_mul = rv_mul + T*Q*N_T;

% rv_mul = int64(rv_mul);

end

--------------------------------------------------------------------------------
/complexity_analysis/scr/preprocessing.m:
--------------------------------------------------------------------------------
function [rv_mul] = preprocessing(Scenario, M_R, N_T, T, IDD)
%{
preprocessing calculates the multiplication count of the preprocessing
Inputs:
    Scenario    "MF-BACKSUB" for matched filtering w/ back-substitution,
                "EXPL-INV" for explicitly calculating the inverse of A,
                "EXPL-FILTER" for explicitly calculating the filter
    M_R         Number of BS antennas
    N_T         Number of Transmitters
    T           Number of time slots per coherence interval (block fading)
    IDD         boolean, true if IDD
Outputs:
    rv_mul      Number of real-valued multiplications
%}
mul = 0;

% calculating the Gram matrix
mul = mul + M_R*N_T^2;

rv_mul = 4*mul;

end

--------------------------------------------------------------------------------
/source/customOFDMChannel.py:
--------------------------------------------------------------------------------
"""
Layer for sampling OFDM MIMO channels from a set of SIMO vectors (based on Sionna's OFDM channel class)
"""
import sionna.ofdm
import tensorflow as tf
from sionna.channel import subcarrier_frequencies
from sionna.utils import expand_to_rank
from tensorflow.keras.layers import Layer
from sionna.channel.apply_ofdm_channel import ApplyOFDMChannel
import numpy as np

class customOFDMChannel(Layer):
    # pylint: disable=line-too-long
    r"""Custom OFDM channel.
    Applies OFDM channels generated by REMCOM.
    Randomly selects UE positions.
    OFDM frequency responses are preprocessed by a MATLAB script.
    """

    def __init__(self, channel_set, resource_grid, add_awgn=True,
                 normalize_channel=False, return_channel=False, chanIdxComb=None, randomSubSamplingChanIdx=False,
                 dtype=tf.complex64, **kwargs):
        super().__init__(trainable=False, dtype=dtype, **kwargs)

        self._channels = tf.cast(channel_set[:, :, :resource_grid.fft_size], dtype=dtype)
        self._rg = resource_grid
        self._add_awgn = add_awgn
        self._normalize_channel = normalize_channel
        self._return_channel = return_channel

        # keep None (random channel selection in call()); otherwise cast the
        # provided index combinations
        self._chanIdxComb = None if chanIdxComb is None else tf.constant(chanIdxComb, dtype=tf.int32)
        self._randomSubSamplingChanIdx = randomSubSamplingChanIdx

        if normalize_channel:
            # Normalization is performed such that for each batch example and
            # link, the mean energy per resource grid is one.
            # Average over TX antennas, RX antennas, and subcarriers.
            c = tf.reduce_mean(tf.square(tf.abs(self._channels)), axis=(1, 2), keepdims=True)
            self._channels = self._channels / tf.complex(tf.sqrt(c), tf.constant(0., dtype.real_dtype))

    def build(self, input_shape):  # pylint: disable=unused-argument
        self._apply_channel = ApplyOFDMChannel(self._add_awgn, tf.as_dtype(self.dtype))

    def call(self, inputs):

        if self._add_awgn:
            x, no = inputs
        else:
            x = inputs

        batchsize = tf.shape(x)[0]
        n_tx = self._rg.num_tx
        num_rx_ant = tf.shape(self._channels)[1]

        # randomly select channel
        if self._chanIdxComb is None:
            # Limited-randomness strategy (mirrors the MATLAB code):
            # itx = num_tx * randi((batch_size, num_tx), 0, int(num_channels/num_tx)-1) + (0:(num_tx-1))
            chan_batch_ue_idx = n_tx*tf.random.uniform((batchsize, n_tx), 0, int(tf.shape(self._channels)[0]/n_tx) - 1,
                                                       tf.int32)+tf.range(n_tx, dtype=tf.int32)
        else:
            if self._randomSubSamplingChanIdx:
                # randomly sample rows of channel index combinations
                # tensorflow seed is reset whenever the model graph is re-traced (so all models apply the same channel indices)
                numIdx = int(tf.shape(self._chanIdxComb)[0])
                rows = tf.reshape(tf.range(0, numIdx, dtype=tf.int64), [numIdx, 1])
                rows = tf.random.shuffle(rows)[:batchsize]
                # tf.print(rows[0])
                # [rows, _, _] = tf.random.uniform_candidate_sampler(cands, numIdx, batchsize, True, numIdx)
                chan_batch_ue_idx = tf.gather(self._chanIdxComb, tf.squeeze(rows), axis=0)
            else:
                chan_batch_ue_idx = self._chanIdxComb
        # select channels for all batch elements
        h_freq = tf.gather(self._channels, chan_batch_ue_idx, axis=0)
        # h_freq is [batch_size, num_tx, num_rx_ant, num_ofdm_channels]
        # expand to dimensions [batch_size, 1, num_tx, 1, num_rx_ant, 1, num_ofdm_channels]
        h_freq = tf.expand_dims(h_freq, axis=1)
        h_freq = tf.expand_dims(h_freq, axis=3)
        h_freq = tf.expand_dims(h_freq, axis=-2)
        # h_freq is [batch_size, 1, num_tx, 1, num_rx_ant, 1, num_ofdm_channels]
        h_freq = tf.transpose(h_freq, perm=[0, 1, 4, 2, 3, 5, 6])
        # block-fading channel: same channel over the whole OFDM block (i.e., for all OFDM symbols)
        h_freq = tf.tile(h_freq, [1, 1, 1, 1, 1, self._rg.num_ofdm_symbols, 1])
        # h_freq: [batch size, num_rx, num_rx_ant, num_tx, num_tx_ant, num_ofdm_symbols, fft_size]
        # reshape h_freq (force shape to be [batch size, num_rx, num_rx_ant, num_tx, num_tx_ant, num_ofdm_symbols, fft_size])
        chan_shape = tf.concat([[batchsize], [1], [num_rx_ant], [n_tx], [1], [self._rg.num_ofdm_symbols], [self._rg.fft_size]], 0)
        h_freq = tf.reshape(h_freq, chan_shape)

        if self._add_awgn:
            y = self._apply_channel([x, h_freq, no])
        else:
            y = self._apply_channel([x, h_freq])

        if self._return_channel:
            return y, h_freq
        else:
            return y
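# --- Hypothetical usage sketch (not part of the original file); assumes a
# Sionna v0.11 ResourceGrid `rg` and precomputed SIMO frequency responses
# `channel_set` of shape [num_channels, num_rx_ant, fft_size]:
#
#   channel = customOFDMChannel(channel_set, rg, add_awgn=True,
#                               normalize_channel=True, return_channel=True)
#   y, h_freq = channel([x, no])  # x: resource-grid signal, no: noise variance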

--------------------------------------------------------------------------------
/source/dampenedLdpc5gDecoder.py:
--------------------------------------------------------------------------------

#
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
"""Layers for channel decoding and utility functions."""

## Modified by R. Wiesmayr in October 2022:
# Extended the Sionna implementation of the LDPC BP decoder with message damping

import tensorflow as tf
from sionna.fec.ldpc.encoding import LDPC5GEncoder
from sionna.fec.utils import llr2mi

from source.decoder_v1 import LDPC5GDecoder1


class dampenedLDPC5GDecoder(LDPC5GDecoder1):
    # pylint: disable=line-too-long
    r"""dampenedLDPC5GDecoder(encoder, trainable=False, cn_type='boxplus-phi', hard_out=True, track_exit=False,
    return_infobits=True, prune_pcm=True, num_iter=20, stateful=False, output_dtype=tf.float32, alpha0=0, beta0=0,
    trainDamping=False, constrainAlpha=True, constrainBeta=True, **kwargs)

    (Iterative) belief propagation decoder for 5G NR LDPC codes, extended with message damping.

    Inherits from :class:`~sionna.fec.ldpc.decoding.LDPCBPDecoder` and provides
    a wrapper for 5G compatibility, i.e., it automatically handles puncturing and
    shortening according to [3GPPTS38212_LDPC]_.

    Note that for full 5G 3GPP NR compatibility, the correct puncturing and
    shortening patterns must be applied and, thus, the encoder object is
    required as input.

    If required, the decoder can be made trainable and is differentiable
    (the training of some check node types may not be supported) following the
    concept of "weighted BP" [Nachmani]_.

    The class inherits from the Keras layer class and can be used as a layer in a
    Keras model.

    Parameters
    ----------
    encoder: LDPC5GEncoder
        An instance of :class:`~sionna.fec.ldpc.encoding.LDPC5GEncoder`
        containing the correct code parameters.

    trainable: bool
        Defaults to False. If True, every outgoing variable node message is
        scaled with a trainable scalar.

    cn_type: str
        A string, defaults to '"boxplus-phi"'. One of
        {`"boxplus"`, `"boxplus-phi"`, `"minsum"`} where
        '"boxplus"' implements the single-parity-check APP decoding rule.
        '"boxplus-phi"' implements the numerically more stable version of
        boxplus [Ryan]_.
        '"minsum"' implements the min-approximation of the CN
        update rule [Ryan]_.

    hard_out: bool
        Defaults to True. If True, the decoder provides hard-decided
        codeword bits instead of soft-values.

    track_exit: bool
        Defaults to False. If True, the decoder tracks EXIT characteristics.
        Note that this requires the all-zero CW as input.

    return_infobits: bool
        Defaults to True. If True, only the `k` info bits (soft or
        hard-decided) are returned. Otherwise all `n` positions are
        returned.

    prune_pcm: bool
        Defaults to True. If True, all punctured degree-1 VNs and
        connected check nodes are removed from the decoding graph (see
        [Cammerer]_ for details). Besides numerical differences, this should
        yield the same decoding result but improves the decoding throughput
        and reduces the memory footprint.

    num_iter: int
        Defines the number of decoder iterations (no early stopping used at
        the moment!).

    stateful: bool
        Defaults to False. If True, the internal VN messages ``msg_vn``
        from the last decoding iteration are returned, and ``msg_vn`` or
        `None` needs to be given as a second input when calling the decoder.
        This is required for iterative demapping and decoding.

    output_dtype: tf.DType
        Defaults to tf.float32.
Defines the output datatype of the layer 91 | (internal precision remains tf.float32). 92 | 93 | Input 94 | ----- 95 | llrs_ch or (llrs_ch, msg_vn): 96 | Tensor or Tuple (only required if ``stateful`` is True): 97 | 98 | llrs_ch: [...,n], tf.float32 99 | 2+D tensor containing the channel logits/llr values. 100 | 101 | msg_vn: None or RaggedTensor, tf.float32 102 | Ragged tensor of VN messages. 103 | Required only if ``stateful`` is True. 104 | 105 | Output 106 | ------ 107 | : [...,n] or [...,k], tf.float32 108 | 2+D Tensor of same shape as ``inputs`` containing 109 | bit-wise soft-estimates (or hard-decided bit-values) of all 110 | codeword bits. If ``return_infobits`` is True, only the `k` 111 | information bits are returned. 112 | 113 | : RaggedTensor, tf.float32: 114 | Tensor of VN messages. 115 | Returned only if ``stateful`` is set to True. 116 | Raises 117 | ------ 118 | ValueError 119 | If the shape of ``pcm`` is invalid or contains other 120 | values than `0` or `1`. 121 | 122 | AssertionError 123 | If ``trainable`` is not `bool`. 124 | 125 | AssertionError 126 | If ``track_exit`` is not `bool`. 127 | 128 | AssertionError 129 | If ``hard_out`` is not `bool`. 130 | 131 | AssertionError 132 | If ``return_infobits`` is not `bool`. 133 | 134 | AssertionError 135 | If ``encoder`` is not an instance of 136 | :class:`~sionna.fec.ldpc.encoding.LDPC5GEncoder`. 137 | 138 | ValueError 139 | If ``output_dtype`` is not {tf.float16, tf.float32, tf. 140 | float64}. 141 | 142 | ValueError 143 | If ``inputs`` is not of shape `[batch_size, n]`. 144 | 145 | ValueError 146 | If ``num_iter`` is not an integer greater (or equal) `0`. 147 | 148 | InvalidArgumentError 149 | When rank(``inputs``)<2. 150 | 151 | Note 152 | ---- 153 | As decoding input logits 154 | :math:`\operatorname{log} \frac{p(x=1)}{p(x=0)}` are assumed for 155 | compatibility with the learning framework, but 156 | internally llrs with definition 157 | :math:`\operatorname{log} \frac{p(x=0)}{p(x=1)}` are used. 158 | 159 | The decoder is not (particularly) optimized for Quasi-cyclic (QC) LDPC 160 | codes and, thus, supports arbitrary parity-check matrices. 161 | 162 | The decoder is implemented by using '"ragged Tensors"' [TF_ragged]_ to 163 | account for arbitrary node degrees. To avoid a performance degradation 164 | caused by a severe indexing overhead, the batch-dimension is shifted to 165 | the last dimension during decoding. 166 | 167 | If the decoder is made trainable [Nachmani]_, for performance 168 | improvements only variable to check node messages are scaled as the VN 169 | operation is linear and, thus, would not increase the expressive power 170 | of the weights. 
171 | """ 172 | 173 | def __init__(self, 174 | encoder, 175 | trainable=False, 176 | cn_type='boxplus-phi', 177 | hard_out=True, 178 | track_exit=False, 179 | return_infobits=True, 180 | prune_pcm=True, 181 | num_iter=20, 182 | stateful=False, 183 | output_dtype=tf.float32, 184 | alpha0=0, 185 | beta0=0, 186 | trainDamping=False, 187 | constrainAlpha=True, 188 | constrainBeta=True, 189 | **kwargs): 190 | 191 | super().__init__(encoder, 192 | trainable=trainable, 193 | cn_type=cn_type, 194 | hard_out=hard_out, 195 | track_exit=track_exit, 196 | return_infobits=return_infobits, 197 | prune_pcm=prune_pcm, 198 | num_iter=num_iter, 199 | stateful=stateful, 200 | output_dtype=output_dtype, 201 | **kwargs) 202 | 203 | self._trainable = trainDamping or trainable 204 | # with constraints 205 | if constrainAlpha: 206 | self._alpha = tf.Variable(alpha0*tf.ones([num_iter]), dtype=output_dtype, trainable=trainDamping, name="alpha_damping", constraint=lambda x: tf.clip_by_value(x, 0.0,1.0)) 207 | else: 208 | self._alpha = tf.Variable(alpha0 * tf.ones([num_iter]), dtype=output_dtype, trainable=trainDamping, 209 | name="alpha_damping") 210 | if constrainBeta: 211 | self._beta = tf.Variable(beta0*tf.ones([num_iter]), dtype=output_dtype, trainable=trainDamping, name="beta_damping", constraint=lambda x: tf.clip_by_value(x, 0.0,1.0)) 212 | else: 213 | self._beta = tf.Variable(beta0 * tf.ones([num_iter]), dtype=output_dtype, trainable=trainDamping, 214 | name="beta_damping") 215 | 216 | 217 | ######################################### 218 | # Public methods and properties 219 | ######################################### 220 | 221 | @property 222 | def alpha(self): 223 | """Alpha values for dampening.""" 224 | return self._alpha 225 | 226 | @property 227 | def beta(self): 228 | """Alpha values for dampening.""" 229 | return self._beta 230 | 231 | 232 | # def build(self, input_shape): 233 | # """Build model.""" 234 | # if self._stateful: 235 | # assert(len(input_shape)==2), \ 236 | # "For stateful decoding, a tuple of two inputs is expected." 237 | # input_shape = input_shape[0] 238 | # 239 | # # check input dimensions for consistency 240 | # assert (input_shape[-1]==self.encoder.n), \ 241 | # 'Last dimension must be of length n.' 242 | # assert (len(input_shape)>=2), 'The inputs must have at least rank 2.' 243 | # 244 | # self._old_shape_5g = input_shape 245 | # # self._alpha = tf.Variable(self._alpha, dtype=self._output_dtype, trainable=self._trainDamping, name="alpha_damping") 246 | 247 | 248 | def super_call(self, inputs): 249 | """Iterative BP decoding function. 250 | 251 | This function performs ``num_iter`` belief propagation decoding 252 | iterations and returns the estimated codeword. 253 | 254 | Args: 255 | inputs (tf.float32): Tensor of shape `[...,n]` containing the 256 | channel logits/llr values. 257 | 258 | Returns: 259 | `tf.float32`: Tensor of shape `[...,n]` containing 260 | bit-wise soft-estimates (or hard-decided bit-values) of all 261 | codeword bits. 262 | 263 | Raises: 264 | ValueError: If ``inputs`` is not of shape `[batch_size, n]`. 265 | 266 | InvalidArgumentError: When rank(``inputs``)<2. 
267 | """ 268 | 269 | # Extract inputs 270 | if self._stateful: 271 | llr_ch, msg_vn = inputs 272 | else: 273 | llr_ch = inputs 274 | 275 | tf.debugging.assert_type(llr_ch, self.dtype, 'Invalid input dtype.') 276 | 277 | # internal calculations still in tf.float32 278 | llr_ch = tf.cast(llr_ch, tf.float32) 279 | 280 | # last dim must be of length n 281 | tf.debugging.assert_equal(tf.shape(llr_ch)[-1], 282 | self._num_vns, 283 | 'Last dimension must be of length n.') 284 | 285 | llr_ch_shape = llr_ch.get_shape().as_list() 286 | new_shape = [-1, self._num_vns] 287 | llr_ch_reshaped = tf.reshape(llr_ch, new_shape) 288 | 289 | # must be done during call, as XLA fails otherwise due to ragged 290 | # indices placed on the CPU device. 291 | # create permutation index from cn perspective 292 | self._cn_mask_tf = tf.ragged.constant(self._gen_node_mask(self._cn_con), 293 | row_splits_dtype=tf.int32) 294 | 295 | # batch dimension is last dimension due to ragged tensor representation 296 | llr_ch = tf.transpose(llr_ch_reshaped, (1,0)) 297 | 298 | llr_ch = -1. * llr_ch # logits are converted into "true" llrs 299 | 300 | # init internal decoder state if not explicitly 301 | # provided (e.g., required to restore decoder state for iterative 302 | # detection and decoding) 303 | # load internal state from previous iteration 304 | # required for iterative det./dec. 305 | if not self._stateful or msg_vn is None: 306 | msg_shape = tf.stack([tf.constant(self._num_edges), 307 | tf.shape(llr_ch)[1]], 308 | axis=0) 309 | msg_vn = tf.zeros(msg_shape, dtype=tf.float32) 310 | else: 311 | msg_vn = msg_vn.flat_values 312 | 313 | # track exit decoding trajectory; requires all-zero cw? 314 | if self._track_exit: 315 | self._ie_c = tf.zeros(self._num_iter+1) 316 | self._ie_v = tf.zeros(self._num_iter+1) 317 | 318 | # perform one decoding iteration 319 | # Remark: msg_vn cannot be ragged as input for tf.while_loop as 320 | # otherwise XLA will not be supported (with TF 2.5) 321 | def dec_iter(llr_ch, msg_vn, it): 322 | it += 1 323 | # msg_vn_old are the cn2vn messages from the previous iteration 324 | msg_vn_old = tf.RaggedTensor.from_row_splits( 325 | values=msg_vn, 326 | row_splits=tf.constant(self._vn_row_splits, tf.int32)) 327 | # variable node update 328 | # msg_vn are now the vn2cn messages from the vn perspective 329 | msg_vn = self._vn_update(msg_vn_old, llr_ch) 330 | 331 | # track exit decoding trajectory; requires all-zero cw 332 | if self._track_exit: 333 | # neg values as different llr def is expected 334 | mi = llr2mi(-1. * msg_vn.flat_values) 335 | self._ie_v = tf.tensor_scatter_nd_add(self._ie_v, 336 | tf.reshape(it, (1, 1)), 337 | tf.reshape(mi, (1))) 338 | 339 | # scale outgoing vn messages (weighted BP); only if activated 340 | if self._has_weights: 341 | msg_vn = tf.ragged.map_flat_values(self._mult_weights, 342 | msg_vn) 343 | # permute edges into CN perspective 344 | msg_cn = tf.gather(msg_vn.flat_values, self._cn_mask_tf, axis=None) 345 | 346 | # check node update using the pre-defined function 347 | msg_cn = self._cn_update(msg_cn) 348 | 349 | # track exit decoding trajectory; requires all-zero cw? 
            if self._track_exit:
                # neg values as different llr def is expected
                mi = llr2mi(-1.*msg_cn.flat_values)
                # update pos i+1 such that first iter is stored as 0
                self._ie_c = tf.tensor_scatter_nd_add(self._ie_c,
                                                      tf.reshape(it, (1, 1)),
                                                      tf.reshape(mi, (1)))

            # re-permute edges to the variable-node perspective and apply message
            # damping: convex combination of the new cn2vn messages, the current
            # vn2cn messages, and the cn2vn messages of the previous iteration
            msg_vn = (1-self.alpha[it-1]-self.beta[it-1])*tf.gather(msg_cn.flat_values, self._ind_cn_inv, axis=None) + \
                     self.alpha[it-1]*msg_vn.flat_values + \
                     self.beta[it-1]*msg_vn_old.flat_values
            return llr_ch, msg_vn, it

        # stopping condition (required for tf.while_loop)
        def dec_stop(llr_ch, msg_vn, it):  # pylint: disable=W0613
            return tf.less(it, self._num_iter)

        # start decoding iterations
        it = tf.constant(0)
        # maximum_iterations required for XLA
        _, msg_vn, _ = tf.while_loop(dec_stop,
                                     dec_iter,
                                     (llr_ch, msg_vn, it),
                                     parallel_iterations=1,
                                     maximum_iterations=self._num_iter)


        # raggedTensor for final marginalization
        msg_vn = tf.RaggedTensor.from_row_splits(
            values=msg_vn,
            row_splits=tf.constant(self._vn_row_splits, tf.int32))

        # marginalize and remove ragged Tensor
        x_hat = tf.add(llr_ch, tf.reduce_sum(msg_vn, axis=1))

        # restore batch dimension to first dimension
        x_hat = tf.transpose(x_hat, (1,0))

        x_hat = -1. * x_hat  # convert llrs back into logits

        if self._hard_out:  # hard decide decoder output if required
            x_hat = tf.cast(tf.less(0.0, x_hat), self._output_dtype)

        # Reshape c_short so that it matches the original input dimensions
        output_shape = llr_ch_shape
        output_shape[0] = -1  # overwrite batch dim (can be None in Keras)

        x_reshaped = tf.reshape(x_hat, output_shape)

        # cast output to output_dtype
        x_out = tf.cast(x_reshaped, self._output_dtype)

        if not self._stateful:
            return x_out
        else:
            return x_out, msg_vn

    def call(self, inputs):
        """Iterative BP decoding function.

        This function performs ``num_iter`` belief propagation decoding
        iterations and returns the estimated codeword.

        Args:
            inputs (tf.float32): Tensor of shape `[...,n]` containing the
                channel logits/llr values.

        Returns:
            `tf.float32`: Tensor of shape `[...,n]` or `[...,k]`
            (``return_infobits`` is True) containing bit-wise soft-estimates
            (or hard-decided bit-values) of all codeword bits (or info
            bits, respectively).

        Raises:
            ValueError: If ``inputs`` is not of shape `[batch_size, n]`.

            ValueError: If ``num_iter`` is not an integer greater (or equal)
                `0`.

            InvalidArgumentError: When rank(``inputs``)<2.
431 | """ 432 | # Modified from sionna code: super().call ==> implements other signature for vn_update function (also takes in iterations count) 433 | 434 | # Extract inputs 435 | if self._stateful: 436 | llr_ch, msg_vn = inputs 437 | else: 438 | llr_ch = inputs 439 | 440 | tf.debugging.assert_type(llr_ch, self.dtype, 'Invalid input dtype.') 441 | 442 | llr_ch_shape = llr_ch.get_shape().as_list() 443 | new_shape = [-1, llr_ch_shape[-1]] 444 | llr_ch_reshaped = tf.reshape(llr_ch, new_shape) 445 | batch_size = tf.shape(llr_ch_reshaped)[0] 446 | 447 | # invert if rate-matching output interleaver was applied as defined in 448 | # Sec. 5.4.2.2 in 38.212 449 | if self._encoder.num_bits_per_symbol is not None: 450 | llr_ch_reshaped = tf.gather(llr_ch_reshaped, 451 | self._encoder.out_int_inv, 452 | axis=-1) 453 | 454 | # undo puncturing of the first 2*Z bit positions 455 | llr_5g = tf.concat( 456 | [tf.zeros([batch_size, 2 * self.encoder.z], self._output_dtype), 457 | llr_ch_reshaped], 458 | 1) 459 | 460 | # undo puncturing of the last positions 461 | # total length must be n_ldpc, while llr_ch has length n 462 | # first 2*z positions are already added 463 | # -> add n_ldpc - n - 2Z punctured positions 464 | k_filler = self.encoder.k_ldpc - self.encoder.k # number of filler bits 465 | nb_punc_bits = ((self.encoder.n_ldpc - k_filler) 466 | - self.encoder.n - 2 * self.encoder.z) 467 | 468 | llr_5g = tf.concat([llr_5g, 469 | tf.zeros([batch_size, nb_punc_bits - self._nb_pruned_nodes], 470 | self._output_dtype)], 471 | 1) 472 | 473 | # undo shortening (= add 0 positions after k bits, i.e. LLR=LLR_max) 474 | # the first k positions are the systematic bits 475 | x1 = tf.slice(llr_5g, [0, 0], [batch_size, self.encoder.k]) 476 | 477 | # parity part 478 | nb_par_bits = (self.encoder.n_ldpc - k_filler 479 | - self.encoder.k - self._nb_pruned_nodes) 480 | x2 = tf.slice(llr_5g, 481 | [0, self.encoder.k], 482 | [batch_size, nb_par_bits]) 483 | 484 | # negative sign due to logit definition 485 | z = -self._llr_max * tf.ones([batch_size, k_filler], self._output_dtype) 486 | 487 | llr_5g = tf.concat([x1, z, x2], 1) 488 | 489 | # and execute the decoder (modified super-call because of damping) 490 | if not self._stateful: 491 | x_hat = self.super_call(llr_5g) 492 | else: 493 | x_hat, msg_vn = self.super_call([llr_5g, msg_vn]) 494 | 495 | if self._return_infobits: # return only info bits 496 | # reconstruct u_hat # code is systematic 497 | u_hat = tf.slice(x_hat, [0, 0], [batch_size, self.encoder.k]) 498 | # Reshape u_hat so that it matches the original input dimensions 499 | output_shape = llr_ch_shape[0:-1] + [self.encoder.k] 500 | # overwrite first dimension as this could be None (Keras) 501 | output_shape[0] = -1 502 | u_reshaped = tf.reshape(u_hat, output_shape) 503 | 504 | # enable other output datatypes than tf.float32 505 | u_out = tf.cast(u_reshaped, self._output_dtype) 506 | 507 | if not self._stateful: 508 | return u_out 509 | else: 510 | return u_out, msg_vn 511 | 512 | else: # return all codeword bits 513 | # the transmitted CW bits are not the same as used during decoding 514 | # cf. 
last parts of the 5G encoding function

            # remove last dim
            x = tf.reshape(x_hat, [batch_size, self._n_pruned])

            # remove filler bits at pos (k, k_ldpc)
            x_no_filler1 = tf.slice(x, [0, 0], [batch_size, self.encoder.k])

            x_no_filler2 = tf.slice(x,
                                    [0, self.encoder.k_ldpc],
                                    [batch_size,
                                     self._n_pruned - self.encoder.k_ldpc])

            x_no_filler = tf.concat([x_no_filler1, x_no_filler2], 1)

            # shorten the first 2*Z positions and end after n bits
            x_short = tf.slice(x_no_filler,
                               [0, 2 * self.encoder.z],
                               [batch_size, self.encoder.n])

            # if used, apply rate-matching output interleaver again as in
            # Sec. 5.4.2.2 in 38.212
            if self._encoder.num_bits_per_symbol is not None:
                x_short = tf.gather(x_short, self._encoder.out_int, axis=-1)

            # Reshape x_short so that it matches the original input dimensions
            # overwrite first dimension as this could be None (Keras)
            llr_ch_shape[0] = -1
            x_short = tf.reshape(x_short, llr_ch_shape)

            # enable other output datatypes than tf.float32
            x_out = tf.cast(x_short, self._output_dtype)

            if not self._stateful:
                return x_out
            else:
                return x_out, msg_vn

class llrTradeOffDampenedLDPC5GDecoder(dampenedLDPC5GDecoder):
    def __init__(self,
                 encoder,
                 trainableWeights=False,
                 cn_type='boxplus-phi',
                 hard_out=True,
                 track_exit=False,
                 return_infobits=True,
                 prune_pcm=True,
                 num_iter=20,
                 stateful=False,
                 output_dtype=tf.float32,
                 alpha0=0,
                 beta0=0,
                 trainDamping=False,
                 constrainAlpha=True,
                 constrainBeta=True,
                 trainLLRTradeOff=False,
                 alpha_llr_0=1,
                 beta_llr_0=0,
                 **kwargs):

        super().__init__(encoder,
                         trainable=trainableWeights,
                         cn_type=cn_type,
                         hard_out=False,
                         track_exit=track_exit,
                         return_infobits=return_infobits,
                         prune_pcm=prune_pcm,
                         num_iter=num_iter,
                         stateful=stateful,
                         constrainAlpha=constrainAlpha,
                         constrainBeta=constrainBeta,
                         output_dtype=output_dtype,
                         trainDamping=trainDamping,
                         alpha0=alpha0, beta0=beta0,
                         **kwargs)

        self._hard_out_ = hard_out
        self._trainLlrTradeOff = trainLLRTradeOff
        self._trainable = trainDamping or trainableWeights or trainLLRTradeOff
        self._alpha_llr = tf.Variable(alpha_llr_0, dtype=output_dtype, trainable=trainLLRTradeOff, name="alpha_llr_tradeoff")
        self._beta_llr = tf.Variable(beta_llr_0, dtype=output_dtype, trainable=trainLLRTradeOff, name="beta_llr_tradeoff")

    @property
    def alpha_llr(self):
        """Alpha value for the intrinsic/extrinsic LLR trade-off."""
        return self._alpha_llr

    @property
    def beta_llr(self):
        """Beta value for the intrinsic/extrinsic LLR trade-off."""
        return self._beta_llr

    def call(self, inputs):
        """Iterative BP decoding function.

        This function performs ``num_iter`` belief propagation decoding
        iterations and returns the estimated codeword.

        Args:
            inputs (tf.float32): Tensor of shape `[...,n]` containing the
                channel logits/llr values.

        Returns:
            `tf.float32`: Tensor of shape `[...,n]` or `[...,k]`
            (``return_infobits`` is True) containing bit-wise soft-estimates
            (or hard-decided bit-values) of all codeword bits (or info
            bits, respectively).
621 | 622 | Raises: 623 | ValueError: If ``inputs`` is not of shape `[batch_size, n]`. 624 | 625 | ValueError: If ``num_iter`` is not an integer greater (or equal) 626 | `0`. 627 | 628 | InvalidArgumentError: When rank(``inputs``)<2. 629 | """ 630 | # Modified from sionna code: super().call ==> implements other signature for vn_update function (also takes in iterations count) 631 | 632 | # Extract inputs 633 | if self._stateful: 634 | llr_ch, msg_vn = inputs 635 | else: 636 | llr_ch = inputs 637 | 638 | tf.debugging.assert_type(llr_ch, self.dtype, 'Invalid input dtype.') 639 | 640 | llr_ch_shape = llr_ch.get_shape().as_list() 641 | new_shape = [-1, llr_ch_shape[-1]] 642 | llr_ch_reshaped = tf.reshape(llr_ch, new_shape) 643 | batch_size = tf.shape(llr_ch_reshaped)[0] 644 | 645 | # invert if rate-matching output interleaver was applied as defined in 646 | # Sec. 5.4.2.2 in 38.212 647 | if self._encoder.num_bits_per_symbol is not None: 648 | llr_ch_reshaped = tf.gather(llr_ch_reshaped, 649 | self._encoder.out_int_inv, 650 | axis=-1) 651 | 652 | # undo puncturing of the first 2*Z bit positions 653 | llr_5g = tf.concat( 654 | [tf.zeros([batch_size, 2 * self.encoder.z], self._output_dtype), 655 | llr_ch_reshaped], 656 | 1) 657 | 658 | # undo puncturing of the last positions 659 | # total length must be n_ldpc, while llr_ch has length n 660 | # first 2*z positions are already added 661 | # -> add n_ldpc - n - 2Z punctured positions 662 | k_filler = self.encoder.k_ldpc - self.encoder.k # number of filler bits 663 | nb_punc_bits = ((self.encoder.n_ldpc - k_filler) 664 | - self.encoder.n - 2 * self.encoder.z) 665 | 666 | llr_5g = tf.concat([llr_5g, 667 | tf.zeros([batch_size, nb_punc_bits - self._nb_pruned_nodes], 668 | self._output_dtype)], 669 | 1) 670 | 671 | # undo shortening (= add 0 positions after k bits, i.e. 
LLR=LLR_max)
        # the first k positions are the systematic bits
        x1 = tf.slice(llr_5g, [0, 0], [batch_size, self.encoder.k])

        # parity part
        nb_par_bits = (self.encoder.n_ldpc - k_filler
                       - self.encoder.k - self._nb_pruned_nodes)
        x2 = tf.slice(llr_5g,
                      [0, self.encoder.k],
                      [batch_size, nb_par_bits])

        # negative sign due to logit definition
        z = -self._llr_max * tf.ones([batch_size, k_filler], self._output_dtype)

        llr_5g = tf.concat([x1, z, x2], 1)

        # and execute the decoder (modified super-call because of damping)
        if not self._stateful:
            x_hat = self.super_call(llr_5g)
        else:
            x_hat, msg_vn = self.super_call([llr_5g, msg_vn])
        # Intrinsic/Extrinsic LLR trade-off
        x_hat = self._alpha_llr * x_hat - self._beta_llr * llr_5g

        if self._hard_out_:  # hard decide decoder output if required
            x_hat = tf.cast(tf.less(0.0, x_hat), self._output_dtype)

        if self._return_infobits:  # return only info bits
            # reconstruct u_hat # code is systematic
            u_hat = tf.slice(x_hat, [0, 0], [batch_size, self.encoder.k])
            # Reshape u_hat so that it matches the original input dimensions
            output_shape = llr_ch_shape[0:-1] + [self.encoder.k]
            # overwrite first dimension as this could be None (Keras)
            output_shape[0] = -1
            u_reshaped = tf.reshape(u_hat, output_shape)

            # enable other output datatypes than tf.float32
            u_out = tf.cast(u_reshaped, self._output_dtype)

            if not self._stateful:
                return u_out
            else:
                return u_out, msg_vn

        else:  # return all codeword bits
            # the transmitted CW bits are not the same as used during decoding
            # cf. last parts of the 5G encoding function

            # remove last dim
            x = tf.reshape(x_hat, [batch_size, self._n_pruned])

            # remove filler bits at pos (k, k_ldpc)
            x_no_filler1 = tf.slice(x, [0, 0], [batch_size, self.encoder.k])

            x_no_filler2 = tf.slice(x,
                                    [0, self.encoder.k_ldpc],
                                    [batch_size,
                                     self._n_pruned - self.encoder.k_ldpc])

            x_no_filler = tf.concat([x_no_filler1, x_no_filler2], 1)

            # shorten the first 2*Z positions and end after n bits
            x_short = tf.slice(x_no_filler,
                               [0, 2 * self.encoder.z],
                               [batch_size, self.encoder.n])

            # if used, apply rate-matching output interleaver again as in
            # Sec. 5.4.2.2 in 38.212
            if self._encoder.num_bits_per_symbol is not None:
                x_short = tf.gather(x_short, self._encoder.out_int, axis=-1)

            # Reshape x_short so that it matches the original input dimensions
            # overwrite first dimension as this could be None (Keras)
            llr_ch_shape[0] = -1
            x_short = tf.reshape(x_short, llr_ch_shape)

            # enable other output datatypes than tf.float32
            x_out = tf.cast(x_short, self._output_dtype)

            if not self._stateful:
                return x_out
            else:
                return x_out, msg_vn
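# --- Hypothetical usage sketch (not part of the original file); assumes a
# 5G LDPC code with k=512 info bits and n=1024 coded bits:
#
#   from sionna.fec.ldpc.encoding import LDPC5GEncoder
#   enc = LDPC5GEncoder(k=512, n=1024)
#   dec = dampenedLDPC5GDecoder(enc, num_iter=8, stateful=True, hard_out=False,
#                               alpha0=0.25, beta0=0.1, trainDamping=True)
#   llr_out, msg_vn = dec([llr_ch, None])  # pass msg_vn back in at the next
#                                          # IDD iteration to warm-start BP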
Wiesmayr in October 2022: 8 | # Set message clipping from 20 to 1000 to avoid problems with decoding at high SNR 9 | 10 | 11 | import tensorflow as tf 12 | import numpy as np 13 | import scipy as sp # for sparse H matrix computations 14 | from tensorflow.keras.layers import Layer 15 | from sionna.fec.ldpc.encoding import LDPC5GEncoder 16 | from sionna.fec.utils import llr2mi 17 | import matplotlib.pyplot as plt 18 | 19 | class LDPCBPDecoder1(Layer): 20 | # pylint: disable=line-too-long 21 | r"""LDPCBPDecoder(pcm, trainable=False, cn_type='boxplus-phi', hard_out=True, track_exit=False, num_iter=20, 22 | stateful=False,output_dtype=tf.float32, **kwargs) 23 | 24 | Iterative belief propagation decoder for low-density parity-check (LDPC) 25 | codes and other `codes on graphs`. 26 | 27 | This class defines a generic belief propagation decoder for decoding 28 | with arbitrary parity-check matrices. It can be used to iteratively 29 | estimate/recover the transmitted codeword (or information bits) based on the 30 | LLR-values of the received noisy codeword observation. 31 | 32 | The decoder implements the flooding SPA algorithm [Ryan]_, i.e., all nodes 33 | are updated in a parallel fashion. Different check node update functions are 34 | available 35 | 36 | (1) `boxplus` 37 | 38 | .. math:: 39 | y_{j \to i} = 2 \operatorname{tanh}^{-1} \left( \prod_{i' \in \mathcal{N}_(j) \setminus i} 40 | \operatorname{tanh} \left( \frac{x_{i' \to j}}{2} \right) \right) 41 | 42 | (2) `boxplus-phi` 43 | 44 | .. math:: 45 | y_{j \to i} = \alpha_{j \to i} \cdot \phi \left( \sum_{i' \in \mathcal{N}_(j) \setminus i} \phi 46 | \left( |x_{i' \to j}|\right) \right) 47 | 48 | with :math:`\phi(x)=-\operatorname{log}(\operatorname{tanh} \left(\frac{x}{2}) \right)` 49 | 50 | (3) `minsum` 51 | 52 | .. math:: 53 | \qquad y_{j \to i} = \alpha_{j \to i} \cdot {min}_{i' \in \mathcal{N}_(j) \setminus i} \left(|x_{i' \to j}|\right) 54 | 55 | where :math:`y_{j \to i}` denotes the message from check node (CN) *j* to 56 | variable node (VN) *i* and :math:`x_{i \to j}` from VN *i* to CN *j*, 57 | respectively. Further, :math:`\mathcal{N}_(j)` denotes all indices of 58 | connected VNs to CN *j* and 59 | 60 | .. math:: 61 | \alpha_{j \to i} = \prod_{i' \in \mathcal{N}_(j) \setminus i} \operatorname{sign}(x_{i' \to j}) 62 | 63 | is the sign of the outgoing message. For further details we refer to 64 | [Ryan]_. 65 | 66 | Note that for full 5G 3GPP NR compatibility, the correct puncturing and 67 | shortening patterns must be applied (cf. [Richardson]_ for details), this 68 | can be done by :class:`~sionna.fec.ldpc.decoding.LDPC5GEncoder` and 69 | :class:`~sionna.fec.ldpc.decoding.LDPC5GDecoder`, respectively. 70 | 71 | If required, the decoder can be made trainable and is fully differentiable 72 | by following the concept of `weighted BP` [Nachmani]_ as shown in Fig. 1 73 | leading to 74 | 75 | .. math:: 76 | y_{j \to i} = 2 \operatorname{tanh}^{-1} \left( \prod_{i' \in \mathcal{N}_(j) \setminus i} \operatorname{tanh} 77 | \left( \frac{\textcolor{red}{w_{i' \to j}} \cdot x_{i' \to j}}{2} \right) \right) 78 | 79 | where :math:`w_{i \to j}` denotes the trainable weight of message :math:`x_{i \to j}`. 80 | Please note that the training of some check node types may be not supported. 81 | 82 | .. figure:: ../figures/weighted_bp.png 83 | 84 | Fig. 1: Weighted BP as proposed in [Nachmani]_. 85 | 86 | 87 | The class inherits from the Keras layer class and can be used as layer in a 88 | Keras model. 
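    A minimal construction sketch (illustrative only; the toy parity-check
    matrix and the LLR input below are assumptions and not taken from the
    original code):

    .. code-block:: python

        pcm = np.array([[1., 1., 1., 0.],
                        [0., 1., 1., 1.]])  # toy binary parity-check matrix
        decoder = LDPCBPDecoder1(pcm, num_iter=10)
        # llr_ch: tf.float32 tensor of shape [batch_size, 4] (n = num_vns = 4)
        x_hat = decoder(llr_ch)  # hard-decided bits, as hard_out=True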
89 | 90 | Parameters 91 | ---------- 92 | pcm: ndarray 93 | An ndarray of shape `[n-k, n]` defining the parity-check matrix 94 | consisting only of `0` or `1` entries. Can be also of type `scipy. 95 | sparse.csr_matrix` or `scipy.sparse.csc_matrix`. 96 | 97 | trainable: bool 98 | Defaults to False. If True, every outgoing variable node message is 99 | scaled with a trainable scalar. 100 | 101 | cn_type: str 102 | A string defaults to '"boxplus-phi"'. One of 103 | {`"boxplus"`, `"boxplus-phi"`, `"minsum"`} where 104 | '"boxplus"' implements the single-parity-check APP decoding rule. 105 | '"boxplus-phi"' implements the numerical more stable version of 106 | boxplus [Ryan]_. 107 | '"minsum"' implements the min-approximation of the CN 108 | update rule [Ryan]_. 109 | 110 | hard_out: bool 111 | Defaults to True. If True, the decoder provides hard-decided 112 | codeword bits instead of soft-values. 113 | 114 | track_exit: bool 115 | Defaults to False. If True, the decoder tracks EXIT 116 | characteristics. Note that this requires the all-zero 117 | CW as input. 118 | 119 | num_iter: int 120 | Defining the number of decoder iteration (no early stopping used at 121 | the moment!). 122 | 123 | stateful: bool 124 | Defaults to False. If True, the internal VN messages ``msg_vn`` 125 | from the last decoding iteration are returned, and ``msg_vn`` or 126 | `None` needs to be given as a second input when calling the decoder. 127 | This is required for iterative demapping and decoding. 128 | 129 | output_dtype: tf.DType 130 | Defaults to tf.float32. Defines the output datatype of the layer 131 | (internal precision remains tf.float32). 132 | 133 | Input 134 | ----- 135 | llrs_ch or (llrs_ch, msg_vn): 136 | Tensor or Tuple (only required if ``stateful`` is True): 137 | 138 | llrs_ch: [...,n], tf.float32 139 | 2+D tensor containing the channel logits/llr values. 140 | 141 | msg_vn: None or RaggedTensor, tf.float32 142 | Ragged tensor of VN messages. 143 | Required only if ``stateful`` is True. 144 | 145 | Output 146 | ------ 147 | : [...,n], tf.float32 148 | 2+D Tensor of same shape as ``inputs`` containing 149 | bit-wise soft-estimates (or hard-decided bit-values) of all 150 | codeword bits. 151 | 152 | : RaggedTensor, tf.float32: 153 | Tensor of VN messages. 154 | Returned only if ``stateful`` is set to True. 155 | 156 | Attributes 157 | ---------- 158 | pcm: ndarray 159 | An ndarray of shape `[n-k, n]` defining the parity-check matrix 160 | consisting only of `0` or `1` entries. Can be also of type `scipy. 161 | sparse.csr_matrix` or `scipy.sparse.csc_matrix`. 162 | 163 | num_cns: int 164 | Defining the number of check nodes. 165 | 166 | num_vns: int 167 | Defining the number of variable nodes. 168 | 169 | num_edges: int 170 | Defining the total number of edges. 171 | 172 | trainable: bool 173 | If True, the decoder uses trainable weights. 174 | 175 | _atanh_clip_value: float 176 | Defining the internal clipping value before the atanh is applied 177 | (relates to the CN update). 178 | 179 | _cn_type: str 180 | Defining the CN update function type. 181 | 182 | _cn_update: 183 | A function defining the CN update. 184 | 185 | _hard_out: bool 186 | If True, the decoder outputs hard-decided bits. 187 | 188 | _cn_con: ndarray 189 | An ndarray of shape `[num_edges]` defining all edges from check 190 | node perspective. 191 | 192 | _vn_con: ndarray 193 | An ndarray of shape `[num_edges]` defining all edges from variable 194 | node perspective. 
195 | 196 | _vn_mask_tf: tf.float32 197 | A ragged Tensor of shape `[num_vns, None]` defining the incoming 198 | message indices per VN. The second dimension is ragged and depends 199 | on the node degree. 200 | 201 | _cn_mask_tf: tf.float32 202 | A ragged Tensor of shape `[num_cns, None]` defining the incoming 203 | message indices per CN. The second dimension is ragged and depends 204 | on the node degree. 205 | 206 | _ind_cn: ndarray 207 | An ndarray of shape `[num_edges]` defining the permutation index to 208 | rearrange messages from variable into check node perspective. 209 | 210 | _ind_cn_inv: ndarray 211 | An ndarray of shape `[num_edges]` defining the permutation index to 212 | rearrange messages from check into variable node perspective. 213 | 214 | _vn_row_splits: ndarray 215 | An ndarray of shape `[num_vns+1]` defining the row split positions 216 | of a 1D vector consisting of all edges messages. Used to build a 217 | ragged Tensor of incoming VN messages. 218 | 219 | _cn_row_splits: ndarray 220 | An ndarray of shape `[num_cns+1]` defining the row split positions 221 | of a 1D vector consisting of all edges messages. Used to build a 222 | ragged Tensor of incoming CN messages. 223 | 224 | _edge_weights: tf.float32 225 | A Tensor of shape `[num_edges]` defining a (trainable) weight per 226 | outgoing VN message. 227 | 228 | Raises: 229 | ValueError 230 | If the shape of ``pcm`` is invalid or contains other values than 231 | `0` or `1` or dtype is not `tf.float32`. 232 | 233 | AssertionError 234 | If ``trainable`` is not `bool`. 235 | 236 | AssertionError 237 | If ``track_exit`` is not `bool`. 238 | 239 | AssertionError 240 | If ``hard_out`` is not `bool`. 241 | 242 | AssertionError 243 | If ``stateful`` is not `bool`. 244 | 245 | AssertionError 246 | If ``cn_type`` is not `str`. 247 | 248 | ValueError 249 | If ``num_iter`` is not an integer greater (or equal) `0`. 250 | 251 | ValueError 252 | If ``output_dtype`` is not 253 | {tf.float16, tf.float32, tf.float64}. 254 | 255 | ValueError 256 | If ``inputs`` is not of shape `[batch_size, n]`. 257 | 258 | InvalidArgumentError 259 | When rank(``inputs``)<2. 260 | Note 261 | ---- 262 | As decoding input logits 263 | :math:`\operatorname{log} \frac{p(x=1)}{p(x=0)}` are 264 | assumed for compatibility with the learning framework, but internally 265 | log-likelihood ratios (LLRs) with definition :math:`\operatorname{log} \frac{p(x=0)}{p(x=1)}` are used. 266 | 267 | The decoder is not (particularly) optimized for quasi-cyclic (QC) LDPC 268 | codes and, thus, supports arbitrary parity-check matrices. 269 | 270 | The decoder is implemented by using '"ragged Tensors"' [TF_ragged]_ to 271 | account for arbitrary node degrees. To avoid a performance degradation 272 | caused by a severe indexing overhead, the batch-dimension is shifted to 273 | the last dimension during decoding. 274 | 275 | If the decoder is made trainable [Nachmani]_, for performance 276 | improvements only variable to check node messages are scaled as the VN 277 | operation is linear and, thus, would not increase the expressive power 278 | of the weights. 279 | 280 | """ 281 | 282 | def __init__(self, 283 | pcm, 284 | trainable=False, 285 | cn_type='boxplus-phi', 286 | hard_out=True, 287 | track_exit=False, 288 | num_iter=20, 289 | stateful=False, 290 | output_dtype=tf.float32, 291 | **kwargs): 292 | 293 | super().__init__(dtype=output_dtype, **kwargs) 294 | 295 | assert isinstance(trainable, bool), 'trainable must be bool.' 
296 | assert isinstance(hard_out, bool), 'hard_out must be bool.' 297 | assert isinstance(track_exit, bool), 'track_exit must be bool.' 298 | assert isinstance(cn_type, str) , 'cn_type must be str.' 299 | assert isinstance(num_iter, int), 'num_iter must be int.' 300 | assert num_iter>=0, 'num_iter cannot be negative.' 301 | assert isinstance(stateful, bool), 'stateful must be bool.' 302 | assert isinstance(output_dtype, tf.DType), \ 303 | 'output_dtype must be tf.Dtype.' 304 | 305 | if isinstance(pcm, np.ndarray): 306 | assert np.array_equal(pcm, pcm.astype(bool)), 'PC matrix \ 307 | must be binary.' 308 | elif isinstance(pcm, sp.sparse.csr_matrix): 309 | assert np.array_equal(pcm.data, pcm.data.astype(bool)), \ 310 | 'PC matrix must be binary.' 311 | elif isinstance(pcm, sp.sparse.csc_matrix): 312 | assert np.array_equal(pcm.data, pcm.data.astype(bool)), \ 313 | 'PC matrix must be binary.' 314 | else: 315 | raise TypeError("Unsupported dtype of pcm.") 316 | 317 | if output_dtype not in (tf.float16, tf.float32, tf.float64): 318 | raise ValueError( 319 | 'output_dtype must be {tf.float16, tf.float32, tf.float64}.') 320 | 321 | if output_dtype is not tf.float32: 322 | print('Note: decoder uses tf.float32 for internal calculations.') 323 | 324 | # init decoder parameters 325 | self._pcm = pcm 326 | self._trainable = trainable 327 | self._cn_type = cn_type 328 | self._hard_out = hard_out 329 | self._track_exit = track_exit 330 | self._num_iter = tf.constant(num_iter, dtype=tf.int32) 331 | self._stateful = stateful 332 | self._output_dtype = output_dtype 333 | 334 | # clipping value for the atanh function is applied (tf.float32 is used) 335 | self._atanh_clip_value = 1 - 1e-7 336 | # clipping for min-sum decoding 337 | self._llr_max_minsum = 1000 338 | 339 | # init code parameters 340 | self._num_cns = pcm.shape[0] # total number of check nodes 341 | self._num_vns = pcm.shape[1] # total number of variable nodes 342 | 343 | # make pcm sparse first if ndarray is provided 344 | if isinstance(pcm, np.ndarray): 345 | pcm = sp.sparse.csr_matrix(pcm) 346 | # find all edges from variable and check node perspective 347 | self._cn_con, self._vn_con, _ = sp.sparse.find(pcm) 348 | 349 | # number of edges equals number of non-zero elements in the 350 | # parity-check matrix 351 | self._num_edges = len(self._vn_con) 352 | 353 | # permutation index to rearrange messages into check node perspective 354 | self._ind_cn = np.argsort(self._cn_con) 355 | 356 | # inverse permutation index to rearrange messages back into variable 357 | # node perspective 358 | self._ind_cn_inv = np.argsort(self._ind_cn) 359 | 360 | # generate row masks (array of integers defining the row split pos.) 
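        # Illustrative example: for a (sorted) connection array
        # con = [0, 0, 1, 2, 2, 2], _gen_node_mask_row returns [0, 2, 3, 6],
        # i.e., node 0 owns the flat message positions [0, 2), node 1 owns
        # [2, 3), and node 2 owns [3, 6). These row splits are later used
        # to build the ragged per-node message tensors.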
361 | self._vn_row_splits = self._gen_node_mask_row(self._vn_con) 362 | self._cn_row_splits = self._gen_node_mask_row( 363 | self._cn_con[self._ind_cn]) 364 | # pre-load the CN function for performance reasons 365 | if self._cn_type=='boxplus': 366 | # check node update using the tanh function 367 | self._cn_update = self._cn_update_tanh 368 | elif self._cn_type=='boxplus-phi': 369 | # check node update using the "_phi" function 370 | self._cn_update = self._cn_update_phi 371 | elif self._cn_type=='minsum': 372 | # check node update using the min-sum approximation 373 | self._cn_update = self._cn_update_minsum 374 | else: 375 | raise ValueError('Unknown node type.') 376 | 377 | # init trainable weights if needed 378 | self._has_weights = False # indicates if trainable weights exist 379 | if self._trainable: 380 | self._has_weights = True 381 | self._edge_weights = tf.Variable(tf.ones(self._num_edges), 382 | trainable=self._trainable, 383 | dtype=tf.float32) 384 | 385 | # track mutual information during decoding 386 | self._ie_c = 0 387 | self._ie_v = 0 388 | 389 | ######################################### 390 | # Public methods and properties 391 | ######################################### 392 | 393 | @property 394 | def pcm(self): 395 | """Parity-check matrix of LDPC code.""" 396 | return self._pcm 397 | 398 | @property 399 | def num_cns(self): 400 | """Number of check nodes.""" 401 | return self._num_cns 402 | 403 | @property 404 | def num_vns(self): 405 | """Number of variable nodes.""" 406 | return self._num_vns 407 | 408 | @property 409 | def num_edges(self): 410 | """Number of edges in decoding graph.""" 411 | return self._num_edges 412 | 413 | @property 414 | def has_weights(self): 415 | """Indicates if decoder has trainable weights.""" 416 | return self._has_weights 417 | 418 | @property 419 | def edge_weights(self): 420 | """Trainable weights of the BP decoder.""" 421 | if not self._has_weights: 422 | return [] 423 | else: 424 | return self._edge_weights 425 | 426 | @property 427 | def output_dtype(self): 428 | """Output dtype of decoder.""" 429 | return self._output_dtype 430 | 431 | @property 432 | def ie_c(self): 433 | "Extrinsic mutual information at check node." 434 | return self._ie_c 435 | 436 | @property 437 | def ie_v(self): 438 | "Extrinsic mutual information at variable node." 439 | return self._ie_v 440 | 441 | @property 442 | def num_iter(self): 443 | "Number of decoding iterations." 444 | return self._num_iter 445 | 446 | @num_iter.setter 447 | def num_iter(self, num_iter): 448 | "Number of decoding iterations." 449 | assert isinstance(num_iter, int), 'num_iter must be int.' 450 | assert num_iter>=0, 'num_iter cannot be negative.' 451 | self._num_iter = tf.constant(num_iter, dtype=tf.int32) 452 | 453 | def show_weights(self, size=7): 454 | """Show histogram of trainable weights. 455 | 456 | Input 457 | ----- 458 | size: float 459 | Figure size of the matplotlib figure. 
460 | 461 | """ 462 | # only plot if weights exist 463 | if self._has_weights: 464 | weights = self._edge_weights.numpy() 465 | 466 | plt.figure(figsize=(size,size)) 467 | plt.hist(weights, density=True, bins=20, align='mid') 468 | plt.xlabel('weight value') 469 | plt.ylabel('density') 470 | plt.grid(True, which='both', axis='both') 471 | plt.title('Weight Distribution') 472 | else: 473 | print("No weights to show.") 474 | 475 | ######################### 476 | # Utility methods 477 | ######################### 478 | 479 | def _gen_node_mask(self, con): 480 | """ Generates internal node masks indicating which msg index belongs 481 | to which node index. 482 | """ 483 | ind = np.argsort(con) 484 | con = con[ind] 485 | 486 | node_mask = [] 487 | 488 | cur_node = 0 489 | cur_mask = [] 490 | for i in range(self._num_edges): 491 | if con[i] == cur_node: 492 | cur_mask.append(ind[i]) 493 | else: 494 | node_mask.append(cur_mask) 495 | cur_mask = [ind[i]] 496 | cur_node += 1 497 | node_mask.append(cur_mask) 498 | return node_mask 499 | 500 | def _gen_node_mask_row(self, con): 501 | """ Defining the row split positions of a 1D vector consisting of all 502 | edges messages. 503 | 504 | Used to build a ragged Tensor of incoming node messages. 505 | """ 506 | node_mask = [0] # the first element indicates the first node index (=0) 507 | 508 | cur_node = 0 509 | for i in range(self._num_edges): 510 | if con[i] != cur_node: 511 | node_mask.append(i) 512 | cur_node += 1 513 | node_mask.append(self._num_edges) # last element must be the number of 514 | # elements (delimiter) 515 | return node_mask 516 | 517 | def _vn_update(self, msg, llr_ch): 518 | """ Variable node update function. 519 | 520 | This function implements the (extrinsic) variable node update 521 | function. It takes the sum over all incoming messages ``msg`` excluding 522 | the intrinsic (= outgoing) message itself. 523 | 524 | Additionally, the channel LLR ``llr_ch`` is added to each message. 525 | """ 526 | # aggregate all incoming messages per node 527 | x = tf.reduce_sum(msg, axis=1) 528 | x = tf.add(x, llr_ch) 529 | 530 | # TF2.9 does not support XLA for the addition of ragged tensors 531 | # the following code provides a workaround that supports XLA 532 | 533 | # subtract extrinsic message from node value 534 | # x = tf.expand_dims(x, axis=1) 535 | # x = tf.add(-msg, x) 536 | x = tf.ragged.map_flat_values(lambda x, y, row_ind : 537 | x + tf.gather(y, row_ind), 538 | -1.*msg, 539 | x, 540 | msg.value_rowids()) 541 | return x 542 | 543 | def _extrinsic_min(self, msg): 544 | """ Provides the extrinsic min operation for the minsum approximation 545 | of the CN function. 546 | 547 | This function implements the extrinsic min operation, i.e., 548 | the min is taken over all values excluding the value at the current 549 | index. 550 | 551 | Note that the input is expected to be a Tensor and NOT a ragged Tensor. 552 | """ 553 | num_val = tf.shape(msg)[0] 554 | msg = tf.transpose(msg, (1,0)) 555 | msg = tf.expand_dims(msg, axis=1) 556 | id_mat = tf.eye(num_val) 557 | 558 | msg = (tf.tile(msg, (1, num_val, 1)) # create outgoing tensor per value 559 | + 1000. * id_mat) # "ignore" intrinsic msg by adding large const. 
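        # Illustrative example: for incoming magnitudes [3., 1., 5.] at one
        # node, the extrinsic mins are [1., 3., 1.], i.e., each output is
        # the min over all inputs except the one at its own position (which
        # is masked out above by the large diagonal constant).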
560 | 561 | 562 | msg = tf.math.reduce_min(msg, axis=2) 563 | msg = tf.transpose(msg, (1,0)) 564 | return msg 565 | 566 | def _where_ragged(self, msg): 567 | """Helper to replace 0 elements from ragged tensor (called with 568 | map_flat_values).""" 569 | return tf.where(tf.equal(msg, 0), tf.ones_like(msg) * 1e-12, msg) 570 | 571 | def _where_ragged_inv(self, msg): 572 | """Helper to replace small elements from ragged tensor (called with 573 | map_flat_values) with exact `0`.""" 574 | msg_mod = tf.where(tf.less(tf.abs(msg), 1e-7), 575 | tf.zeros_like(msg), 576 | msg) 577 | return msg_mod 578 | 579 | def _cn_update_tanh(self, msg): 580 | """Check node update function implementing the exact boxplus operation. 581 | 582 | This function implements the (extrinsic) check node update 583 | function. It calculates the boxplus function over all incoming messages 584 | "msg" excluding the intrinsic (=outgoing) message itself. 585 | The exact boxplus function is implemented by using the tanh function. 586 | 587 | The input is expected to be a ragged Tensor of shape 588 | `[num_cns, None, batch_size]`. 589 | 590 | Note that for numerical stability clipping is applied. 591 | """ 592 | 593 | msg = msg / 2 594 | # tanh is not overloaded for ragged tensors 595 | msg = tf.ragged.map_flat_values(tf.tanh, msg) # tanh is not overloaded 596 | 597 | # for ragged tensors; map to flat tensor first 598 | msg = tf.ragged.map_flat_values(self._where_ragged, msg) 599 | 600 | msg_prod = tf.reduce_prod(msg, axis=1) 601 | 602 | # TF2.9 does not support XLA for the multiplication of ragged tensors 603 | # the following code provides a workaround that supports XLA 604 | 605 | # ^-1 to avoid division 606 | # Note this is (potentially) numerically unstable 607 | # msg = msg**-1 * tf.expand_dims(msg_prod, axis=1) # remove own edge 608 | 609 | msg = tf.ragged.map_flat_values(lambda x, y, row_ind : 610 | x * tf.gather(y, row_ind), 611 | msg**-1, 612 | msg_prod, 613 | msg.value_rowids()) 614 | 615 | # Overwrite small (numerical zeros) message values with exact zero 616 | # these are introduced by the previous "_where_ragged" operation 617 | # this is required to keep the product stable (cf. _phi_update for log 618 | # sum implementation) 619 | msg = tf.ragged.map_flat_values(self._where_ragged_inv, msg) 620 | 621 | msg = tf.clip_by_value(msg, 622 | clip_value_min=-self._atanh_clip_value, 623 | clip_value_max=self._atanh_clip_value) 624 | 625 | # atanh is not overloaded for ragged tensors 626 | msg = 2 * tf.ragged.map_flat_values(tf.atanh, msg) 627 | return msg 628 | 629 | def _phi(self, x): 630 | """Helper function for the check node update. 631 | 632 | This function implements the (element-wise) `"_phi"` function as defined 633 | in [Ryan]_. 634 | """ 635 | # the clipping values are optimized for tf.float32 636 | x = tf.clip_by_value(x, clip_value_min=8.5e-8, clip_value_max=16.635532) 637 | return tf.math.log(tf.math.exp(x)+1) - tf.math.log(tf.math.exp(x)-1) 638 | 639 | def _cn_update_phi(self, msg): 640 | """Check node update function implementing the exact boxplus operation. 641 | 642 | This function implements the (extrinsic) check node update function 643 | based on the numerically more stable `"_phi"` function (cf. [Ryan]_). 644 | It calculates the boxplus function over all incoming messages ``msg`` 645 | excluding the intrinsic (=outgoing) message itself. 646 | The exact boxplus function is implemented by using the `"_phi"` function 647 | as in [Ryan]_. 
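        Note that :math:`\phi(x)` is an involution, i.e.,

        .. math::
            \phi\left(\phi(x)\right) = x \quad \text{for } x>0,

        which is why the same function can be applied both to the
        individual incoming messages and to the aggregated node sums below.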
648 | 649 | The input is expected to be a ragged Tensor of shape 650 | `[num_cns, None, batch_size]`. 651 | 652 | Note that for numerical stability clipping is applied. 653 | """ 654 | 655 | sign_val = tf.sign(msg) 656 | 657 | sign_val = tf.where(tf.equal(sign_val, 0), 658 | tf.ones_like(sign_val), 659 | sign_val) 660 | 661 | sign_node = tf.reduce_prod(sign_val, axis=1) 662 | 663 | # TF2.9 does not support XLA for the multiplication of ragged tensors 664 | # the following code provides a workaround that supports XLA 665 | 666 | # sign_val = sign_val * tf.expand_dims(sign_node, axis=1) 667 | sign_val = tf.ragged.map_flat_values(lambda x, y, row_ind : 668 | x * tf.gather(y, row_ind), 669 | sign_val, 670 | sign_node, 671 | sign_val.value_rowids()) 672 | 673 | msg = tf.ragged.map_flat_values(tf.abs, msg) # remove sign 674 | 675 | # apply _phi element-wise (does not support ragged Tensors) 676 | msg = tf.ragged.map_flat_values(self._phi, msg) 677 | msg_sum = tf.reduce_sum(msg, axis=1) 678 | 679 | # TF2.9 does not support XLA for the addition of ragged tensors 680 | # the following code provides a workaround that supports XLA 681 | 682 | # msg = tf.add( -msg, tf.expand_dims(msg_sum, axis=1)) # remove own edge 683 | msg = tf.ragged.map_flat_values(lambda x, y, row_ind : 684 | x + tf.gather(y, row_ind), 685 | -1.*msg, 686 | msg_sum, 687 | msg.value_rowids()) 688 | 689 | # apply _phi element-wise (does not support ragged Tensors) 690 | msg = self._stop_ragged_gradient(sign_val) * tf.ragged.map_flat_values( 691 | self._phi, msg) 692 | return msg 693 | 694 | def _stop_ragged_gradient(self, rt): 695 | """Helper function as TF 2.5 does not support ragged gradient 696 | stopping""" 697 | return rt.with_flat_values(tf.stop_gradient(rt.flat_values)) 698 | 699 | def _sign_val_minsum(self, msg): 700 | """Helper to replace find sign-value during min-sum decoding. 701 | Must be called with `map_flat_values`.""" 702 | 703 | sign_val = tf.sign(msg) 704 | sign_val = tf.where(tf.equal(sign_val, 0), 705 | tf.ones_like(sign_val), 706 | sign_val) 707 | return sign_val 708 | 709 | def _cn_update_minsum_mapfn(self, msg): 710 | """ Check node update function implementing the min-sum approximation. 711 | 712 | This function approximates the (extrinsic) check node update 713 | function based on the min-sum approximation (cf. [Ryan]_). 714 | It calculates the "extrinsic" min function over all incoming messages 715 | ``msg`` excluding the intrinsic (=outgoing) message itself. 716 | 717 | The input is expected to be a ragged Tensor of shape 718 | `[num_vns, None, batch_size]`. 719 | 720 | This function uses tf.map_fn() to call the CN updates. 721 | It is currently not used, but can be used as template to implement 722 | modified CN functions (e.g., offset-corrected minsum). 723 | Please note that tf.map_fn lowers the throughput significantly. 
724 | """ 725 | 726 | sign_val = tf.ragged.map_flat_values(self._sign_val_minsum, msg) 727 | 728 | sign_node = tf.reduce_prod(sign_val, axis=1) 729 | sign_val = self._stop_ragged_gradient(sign_val) * tf.expand_dims( 730 | sign_node, axis=1) 731 | 732 | msg = tf.ragged.map_flat_values(tf.abs, msg) # remove sign 733 | 734 | # calculate extrinsic messages and include the sign 735 | msg_e = tf.map_fn(self._extrinsic_min, msg, infer_shape=False) 736 | 737 | # ensure shape after map_fn 738 | msg_fv = msg_e.flat_values 739 | msg_fv = tf.ensure_shape(msg_fv, msg.flat_values.shape) 740 | msg_e = msg.with_flat_values(msg_fv) 741 | 742 | msg = sign_val * msg_e 743 | 744 | return msg 745 | 746 | def _cn_update_minsum(self, msg): 747 | """ Check node update function implementing the min-sum approximation. 748 | 749 | This function approximates the (extrinsic) check node update 750 | function based on the min-sum approximation (cf. [Ryan]_). 751 | It calculates the "extrinsic" min function over all incoming messages 752 | ``msg`` excluding the intrinsic (=outgoing) message itself. 753 | 754 | The input is expected to be a ragged Tensor of shape 755 | `[num_vns, None, batch_size]`. 756 | """ 757 | # a constant used overwrite the first min 758 | LARGE_VAL = 100000. # pylint: disable=invalid-name 759 | 760 | # clip values for numerical stability 761 | msg = tf.clip_by_value(msg, 762 | clip_value_min=-self._llr_max_minsum, 763 | clip_value_max=self._llr_max_minsum) 764 | 765 | # calculate sign of outgoing msg 766 | sign_val = tf.ragged.map_flat_values(self._sign_val_minsum, msg) 767 | 768 | sign_node = tf.reduce_prod(sign_val, axis=1) 769 | 770 | # TF2.9 does not support XLA for the multiplication of ragged tensors 771 | # the following code provides a workaround that supports XLA 772 | 773 | # sign_val = self._stop_ragged_gradient(sign_val) \ 774 | # * tf.expand_dims(sign_node, axis=1) 775 | sign_val = tf.ragged.map_flat_values( 776 | lambda x, y, row_ind: 777 | tf.multiply(x, tf.gather(y, row_ind)), 778 | self._stop_ragged_gradient(sign_val), 779 | sign_node, 780 | sign_val.value_rowids()) 781 | 782 | msg = tf.ragged.map_flat_values(tf.abs, msg) # remove sign 783 | 784 | # Calculate the extrinsic minimum per CN, i.e., for each message of 785 | # index i, find the smallest and the second smallest value. 786 | # However, in some cases the second smallest value may equal the 787 | # smallest value (multiplicity of mins). 788 | # Please note that this needs to be applied to raggedTensors, e.g., 789 | # tf.top_k() is currently not supported and the ops must support graph 790 | # # mode. 
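        # Illustrative example: for incoming magnitudes [2., 5., 7.] at one
        # CN, min_val = 2 and min_val2 = 5, so the extrinsic outputs are
        # [5., 2., 2.]: the min position receives the second-smallest value,
        # all other positions receive the min. For a duplicated minimum such
        # as [2., 2., 7.], every output equals 2.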
791 | 792 | # find min_value per node 793 | min_val = tf.reduce_min(msg, axis=1, keepdims=True) 794 | 795 | # TF2.9 does not support XLA for the subtraction of ragged tensors 796 | # the following code provides a workaround that supports XLA 797 | 798 | # and subtract min; the new array contains zero at the min positions 799 | # benefits from broadcasting; all other values are positive 800 | # msg_min1 = msg - min_val 801 | msg_min1 = tf.ragged.map_flat_values(lambda x, y, row_ind: 802 | x - tf.gather(y, row_ind), 803 | msg, 804 | tf.squeeze(min_val, axis=1), 805 | msg.value_rowids()) 806 | 807 | # replace 0 (=min positions) with a large value to ignore them for 808 | # further min calculations 809 | msg = tf.ragged.map_flat_values(lambda x: 810 | tf.where(tf.equal(x, 0), LARGE_VAL, x), 811 | msg_min1) 812 | 813 | # find the second smallest element (we add min_val as this has been 814 | # subtracted before) 815 | min_val2 = tf.reduce_min(msg, axis=1, keepdims=True) + min_val 816 | 817 | # Detect duplicated minima (i.e., min_val occurs at two incoming 818 | # messages). As the LLRs per node are clipped to < LLR_MAX and at 819 | # least the one position carrying min_val has been overwritten with 820 | # LARGE_VAL, the node sum can only exceed 2*LARGE_VAL if LARGE_VAL 821 | # occurs at least twice, i.e., the multiplicity of the min is at least 2. 822 | node_sum = tf.reduce_sum(msg, axis=1, keepdims=True) - (2*LARGE_VAL-1.) 823 | # indicator: double_min = 0 if a duplicated min was detected (per node) 824 | double_min = 0.5*(1-tf.sign(node_sum)) 825 | 826 | # if a duplicate min occurred, both edges must have min_val, otherwise 827 | # the second smallest value is taken 828 | min_val_e = (1-double_min) * min_val + (double_min) * min_val2 829 | 830 | # replace all values with min_val except the position where the min 831 | # occurred (=extrinsic min). 832 | msg_e = tf.where(msg==LARGE_VAL, min_val_e, min_val) 833 | 834 | # it seems like tf.where does not set the shape of tf.ragged properly 835 | # we need to ensure the shape manually 836 | msg_e = tf.ragged.map_flat_values( 837 | lambda x: 838 | tf.ensure_shape(x, msg.flat_values.shape), 839 | msg_e) 840 | 841 | # TF2.9 does not support XLA for the multiplication of ragged tensors 842 | # the following code provides a workaround that supports XLA 843 | 844 | # and apply sign 845 | #msg = sign_val * msg_e 846 | msg = tf.ragged.map_flat_values(tf.multiply, 847 | sign_val, 848 | msg_e) 849 | 850 | return msg 851 | 852 | def _mult_weights(self, x): 853 | """Multiply messages with trainable weights for weighted BP.""" 854 | # transpose for simpler broadcasting of training variables 855 | x = tf.transpose(x, (1, 0)) 856 | x = tf.math.multiply(x, self._edge_weights) 857 | x = tf.transpose(x, (1, 0)) 858 | return x 859 | 860 | ######################### 861 | # Keras layer functions 862 | ######################### 863 | 864 | def build(self, input_shape): 865 | # Raise AssertionError if shape of x is invalid 866 | if self._stateful: 867 | assert(len(input_shape)==2), \ 868 | "For stateful decoding, a tuple of two inputs is expected." 869 | input_shape = input_shape[0] 870 | 871 | assert (input_shape[-1]==self._num_vns), \ 872 | 'Last dimension must be of length n.' 873 | assert (len(input_shape)>=2), 'The inputs must have at least rank 2.' 874 | 875 | def call(self, inputs): 876 | """Iterative BP decoding function. 877 | 878 | This function performs ``num_iter`` belief propagation decoding 879 | iterations and returns the estimated codeword. 880 | 881 | Args: 882 | llr_ch or (llr_ch, msg_vn): 883 | 884 | llr_ch (tf.float32): Tensor of shape `[...,n]` containing the 885 | channel logits/llr values.
886 | 887 | msg_vn (tf.float32) : Ragged tensor containing the VN 888 | messages, or None. Required if ``stateful`` is set to True. 889 | 890 | Returns: 891 | `tf.float32`: Tensor of shape `[...,n]` containing 892 | bit-wise soft-estimates (or hard-decided bit-values) of all 893 | codeword bits. 894 | 895 | Raises: 896 | ValueError: If ``inputs`` is not of shape `[batch_size, n]`. 897 | 898 | InvalidArgumentError: When rank(``inputs``)<2. 899 | """ 900 | 901 | # Extract inputs 902 | if self._stateful: 903 | llr_ch, msg_vn = inputs 904 | else: 905 | llr_ch = inputs 906 | 907 | tf.debugging.assert_type(llr_ch, self.dtype, 'Invalid input dtype.') 908 | 909 | # internal calculations still in tf.float32 910 | llr_ch = tf.cast(llr_ch, tf.float32) 911 | 912 | # last dim must be of length n 913 | tf.debugging.assert_equal(tf.shape(llr_ch)[-1], 914 | self._num_vns, 915 | 'Last dimension must be of length n.') 916 | 917 | llr_ch_shape = llr_ch.get_shape().as_list() 918 | new_shape = [-1, self._num_vns] 919 | llr_ch_reshaped = tf.reshape(llr_ch, new_shape) 920 | 921 | # must be done during call, as XLA fails otherwise due to ragged 922 | # indices placed on the CPU device. 923 | # create permutation index from cn perspective 924 | self._cn_mask_tf = tf.ragged.constant(self._gen_node_mask(self._cn_con), 925 | row_splits_dtype=tf.int32) 926 | 927 | # batch dimension is last dimension due to ragged tensor representation 928 | llr_ch = tf.transpose(llr_ch_reshaped, (1,0)) 929 | 930 | llr_ch = -1. * llr_ch # logits are converted into "true" llrs 931 | 932 | # init internal decoder state if not explicitly 933 | # provided (e.g., required to restore decoder state for iterative 934 | # detection and decoding) 935 | # load internal state from previous iteration 936 | # required for iterative det./dec. 937 | if not self._stateful or msg_vn is None: 938 | msg_shape = tf.stack([tf.constant(self._num_edges), 939 | tf.shape(llr_ch)[1]], 940 | axis=0) 941 | msg_vn = tf.zeros(msg_shape, dtype=tf.float32) 942 | else: 943 | msg_vn = msg_vn.flat_values 944 | 945 | # track exit decoding trajectory; requires all-zero cw? 946 | if self._track_exit: 947 | self._ie_c = tf.zeros(self._num_iter+1) 948 | self._ie_v = tf.zeros(self._num_iter+1) 949 | 950 | # perform one decoding iteration 951 | # Remark: msg_vn cannot be ragged as input for tf.while_loop as 952 | # otherwise XLA will not be supported (with TF 2.5) 953 | def dec_iter(llr_ch, msg_vn, it): 954 | it += 1 955 | 956 | msg_vn = tf.RaggedTensor.from_row_splits( 957 | values=msg_vn, 958 | row_splits=tf.constant(self._vn_row_splits, tf.int32)) 959 | # variable node update 960 | msg_vn = self._vn_update(msg_vn, llr_ch) 961 | 962 | # track exit decoding trajectory; requires all-zero cw 963 | if self._track_exit: 964 | # neg values as different llr def is expected 965 | mi = llr2mi(-1. * msg_vn.flat_values) 966 | self._ie_v = tf.tensor_scatter_nd_add(self._ie_v, 967 | tf.reshape(it, (1, 1)), 968 | tf.reshape(mi, (1))) 969 | 970 | # scale outgoing vn messages (weighted BP); only if activated 971 | if self._has_weights: 972 | msg_vn = tf.ragged.map_flat_values(self._mult_weights, 973 | msg_vn) 974 | # permute edges into CN perspective 975 | msg_cn = tf.gather(msg_vn.flat_values, self._cn_mask_tf, axis=None) 976 | 977 | # check node update using the pre-defined function 978 | msg_cn = self._cn_update(msg_cn) 979 | 980 | # track exit decoding trajectory; requires all-zero cw? 
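            # (EXIT tracking uses llr2mi to estimate the mutual information
            # exchanged between VN and CN updates in each iteration; the
            # estimate is only valid if the all-zero codeword was sent.)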
981 | if self._track_exit: 982 | # neg values as different llr def is expected 983 | mi = llr2mi(-1.*msg_cn.flat_values) 984 | # update pos i+1 such that first iter is stored as 0 985 | self._ie_c = tf.tensor_scatter_nd_add(self._ie_c, 986 | tf.reshape(it, (1, 1)), 987 | tf.reshape(mi, (1))) 988 | 989 | # re-permute edges to variable node perspective 990 | msg_vn = tf.gather(msg_cn.flat_values, self._ind_cn_inv, axis=None) 991 | return llr_ch, msg_vn, it 992 | 993 | # stopping condition (required for tf.while_loop) 994 | def dec_stop(llr_ch, msg_vn, it): # pylint: disable=W0613 995 | return tf.less(it, self._num_iter) 996 | 997 | # start decoding iterations 998 | it = tf.constant(0) 999 | # maximum_iterations required for XLA 1000 | _, msg_vn, _ = tf.while_loop(dec_stop, 1001 | dec_iter, 1002 | (llr_ch, msg_vn, it), 1003 | parallel_iterations=1, 1004 | maximum_iterations=self._num_iter) 1005 | 1006 | 1007 | # raggedTensor for final marginalization 1008 | msg_vn = tf.RaggedTensor.from_row_splits( 1009 | values=msg_vn, 1010 | row_splits=tf.constant(self._vn_row_splits, tf.int32)) 1011 | 1012 | # marginalize and remove ragged Tensor 1013 | x_hat = tf.add(llr_ch, tf.reduce_sum(msg_vn, axis=1)) 1014 | 1015 | # restore batch dimension to first dimension 1016 | x_hat = tf.transpose(x_hat, (1,0)) 1017 | 1018 | x_hat = -1. * x_hat # convert llrs back into logits 1019 | 1020 | if self._hard_out: # hard decide decoder output if required 1021 | x_hat = tf.cast(tf.less(0.0, x_hat), self._output_dtype) 1022 | 1023 | # Reshape c_short so that it matches the original input dimensions 1024 | output_shape = llr_ch_shape 1025 | output_shape[0] = -1 # overwrite batch dim (can be None in Keras) 1026 | 1027 | x_reshaped = tf.reshape(x_hat, output_shape) 1028 | 1029 | # cast output to output_dtype 1030 | x_out = tf.cast(x_reshaped, self._output_dtype) 1031 | 1032 | if not self._stateful: 1033 | return x_out 1034 | else: 1035 | return x_out, msg_vn 1036 | 1037 | class LDPC5GDecoder1(LDPCBPDecoder1): 1038 | # pylint: disable=line-too-long 1039 | r"""LDPC5GDecoder(encoder, trainable=False, cn_type='boxplus-phi', hard_out=True, track_exit=False, return_infobits=True, prune_pcm=True, num_iter=20, stateful=False, output_dtype=tf.float32, **kwargs) 1040 | 1041 | (Iterative) belief propagation decoder for 5G NR LDPC codes. 1042 | 1043 | Inherits from :class:`~sionna.fec.ldpc.decoding.LDPCBPDecoder` and provides 1044 | a wrapper for 5G compatibility, i.e., automatically handles puncturing and 1045 | shortening according to [3GPPTS38212_LDPC]_. 1046 | 1047 | Note that for full 5G 3GPP NR compatibility, the correct puncturing and 1048 | shortening patterns must be applied and, thus, the encoder object is 1049 | required as input. 1050 | 1051 | If required the decoder can be made trainable and is differentiable 1052 | (the training of some check node types may be not supported) following the 1053 | concept of "weighted BP" [Nachmani]_. 1054 | 1055 | The class inherits from the Keras layer class and can be used as layer in a 1056 | Keras model. 1057 | 1058 | Parameters 1059 | ---------- 1060 | encoder: LDPC5GEncoder 1061 | An instance of :class:`~sionna.fec.ldpc.encoding.LDPC5GEncoder` 1062 | containing the correct code parameters. 1063 | 1064 | trainable: bool 1065 | Defaults to False. If True, every outgoing variable node message is 1066 | scaled with a trainable scalar. 1067 | 1068 | cn_type: str 1069 | A string defaults to '"boxplus-phi"'. 
One of 1070 | {`"boxplus"`, `"boxplus-phi"`, `"minsum"`} where 1071 | '"boxplus"' implements the single-parity-check APP decoding rule. 1072 | '"boxplus-phi"' implements the numerical more stable version of 1073 | boxplus [Ryan]_. 1074 | '"minsum"' implements the min-approximation of the CN 1075 | update rule [Ryan]_. 1076 | 1077 | hard_out: bool 1078 | Defaults to True. If True, the decoder provides hard-decided 1079 | codeword bits instead of soft-values. 1080 | 1081 | track_exit: bool 1082 | Defaults to False. If True, the decoder tracks EXIT characteristics. 1083 | Note that this requires the all-zero CW as input. 1084 | 1085 | return_infobits: bool 1086 | Defaults to True. If True, only the `k` info bits (soft or 1087 | hard-decided) are returned. Otherwise all `n` positions are 1088 | returned. 1089 | 1090 | prune_pcm: bool 1091 | Defaults to True. If True, all punctured degree-1 VNs and 1092 | connected check nodes are removed from the decoding graph (see 1093 | [Cammerer]_ for details). Besides numerical differences, this should 1094 | yield the same decoding result but improved the decoding throughput 1095 | and reduces the memory footprint. 1096 | 1097 | num_iter: int 1098 | Defining the number of decoder iteration (no early stopping used at 1099 | the moment!). 1100 | 1101 | stateful: bool 1102 | Defaults to False. If True, the internal VN messages ``msg_vn`` 1103 | from the last decoding iteration are returned, and ``msg_vn`` or 1104 | `None` needs to be given as a second input when calling the decoder. 1105 | This is required for iterative demapping and decoding. 1106 | 1107 | output_dtype: tf.DType 1108 | Defaults to tf.float32. Defines the output datatype of the layer 1109 | (internal precision remains tf.float32). 1110 | 1111 | Input 1112 | ----- 1113 | llrs_ch or (llrs_ch, msg_vn): 1114 | Tensor or Tuple (only required if ``stateful`` is True): 1115 | 1116 | llrs_ch: [...,n], tf.float32 1117 | 2+D tensor containing the channel logits/llr values. 1118 | 1119 | msg_vn: None or RaggedTensor, tf.float32 1120 | Ragged tensor of VN messages. 1121 | Required only if ``stateful`` is True. 1122 | 1123 | Output 1124 | ------ 1125 | : [...,n] or [...,k], tf.float32 1126 | 2+D Tensor of same shape as ``inputs`` containing 1127 | bit-wise soft-estimates (or hard-decided bit-values) of all 1128 | codeword bits. If ``return_infobits`` is True, only the `k` 1129 | information bits are returned. 1130 | 1131 | : RaggedTensor, tf.float32: 1132 | Tensor of VN messages. 1133 | Returned only if ``stateful`` is set to True. 1134 | Raises 1135 | ------ 1136 | ValueError 1137 | If the shape of ``pcm`` is invalid or contains other 1138 | values than `0` or `1`. 1139 | 1140 | AssertionError 1141 | If ``trainable`` is not `bool`. 1142 | 1143 | AssertionError 1144 | If ``track_exit`` is not `bool`. 1145 | 1146 | AssertionError 1147 | If ``hard_out`` is not `bool`. 1148 | 1149 | AssertionError 1150 | If ``return_infobits`` is not `bool`. 1151 | 1152 | AssertionError 1153 | If ``encoder`` is not an instance of 1154 | :class:`~sionna.fec.ldpc.encoding.LDPC5GEncoder`. 1155 | 1156 | ValueError 1157 | If ``output_dtype`` is not {tf.float16, tf.float32, tf. 1158 | float64}. 1159 | 1160 | ValueError 1161 | If ``inputs`` is not of shape `[batch_size, n]`. 1162 | 1163 | ValueError 1164 | If ``num_iter`` is not an integer greater (or equal) `0`. 1165 | 1166 | InvalidArgumentError 1167 | When rank(``inputs``)<2. 
1168 | 1169 | Note 1170 | ---- 1171 | As decoding input logits 1172 | :math:`\operatorname{log} \frac{p(x=1)}{p(x=0)}` are assumed for 1173 | compatibility with the learning framework, but 1174 | internally llrs with definition 1175 | :math:`\operatorname{log} \frac{p(x=0)}{p(x=1)}` are used. 1176 | 1177 | The decoder is not (particularly) optimized for Quasi-cyclic (QC) LDPC 1178 | codes and, thus, supports arbitrary parity-check matrices. 1179 | 1180 | The decoder is implemented by using '"ragged Tensors"' [TF_ragged]_ to 1181 | account for arbitrary node degrees. To avoid a performance degradation 1182 | caused by a severe indexing overhead, the batch-dimension is shifted to 1183 | the last dimension during decoding. 1184 | 1185 | If the decoder is made trainable [Nachmani]_, for performance 1186 | improvements only variable to check node messages are scaled as the VN 1187 | operation is linear and, thus, would not increase the expressive power 1188 | of the weights. 1189 | """ 1190 | 1191 | def __init__(self, 1192 | encoder, 1193 | trainable=False, 1194 | cn_type='boxplus-phi', 1195 | hard_out=True, 1196 | track_exit=False, 1197 | return_infobits=True, 1198 | prune_pcm=True, 1199 | num_iter=20, 1200 | stateful=False, 1201 | output_dtype=tf.float32, 1202 | **kwargs): 1203 | 1204 | # needs the 5G Encoder to access all 5G parameters 1205 | assert isinstance(encoder, LDPC5GEncoder), 'encoder must \ 1206 | be of class LDPC5GEncoder.' 1207 | self._encoder = encoder 1208 | pcm = encoder.pcm 1209 | 1210 | assert isinstance(return_infobits, bool), 'return_info must be bool.' 1211 | self._return_infobits = return_infobits 1212 | 1213 | self._llr_max = 10000 # internal max value for LLR initialization 1214 | 1215 | assert isinstance(output_dtype, tf.DType), \ 1216 | 'output_dtype must be tf.DType.' 1217 | if output_dtype not in (tf.float16, tf.float32, tf.float64): 1218 | raise ValueError( 1219 | 'output_dtype must be {tf.float16, tf.float32, tf.float64}.') 1220 | self._output_dtype = output_dtype 1221 | 1222 | assert isinstance(stateful, bool), 'stateful must be bool.' 1223 | self._stateful = stateful 1224 | 1225 | assert isinstance(prune_pcm, bool), 'prune_pcm must be bool.' 1226 | # prune punctured degree-1 VNs and connected CNs. A punctured 1227 | # VN-1 node will always "send" llr=0 to the connected CN. Thus, this 1228 | # CN will only send 0 messages to all other VNs, i.e., does not 1229 | # contribute to the decoding process. 1230 | self._prune_pcm = prune_pcm 1231 | if prune_pcm: 1232 | # find index of first position with only degree-1 VN 1233 | dv = np.sum(pcm, axis=0) # VN degree 1234 | last_pos = encoder._n_ldpc 1235 | for idx in range(encoder._n_ldpc-1, 0, -1): 1236 | if dv[0, idx]==1: 1237 | last_pos = idx 1238 | else: 1239 | break 1240 | # number of filler bits 1241 | k_filler = self.encoder.k_ldpc - self.encoder.k 1242 | # number of punctured bits 1243 | nb_punc_bits = ((self.encoder.n_ldpc - k_filler) 1244 | - self.encoder.n - 2*self.encoder.z) 1245 | # effective codeword length after pruning of vn-1 nodes 1246 | self._n_pruned = np.max((last_pos, encoder._n_ldpc - nb_punc_bits)) 1247 | self._nb_pruned_nodes = encoder._n_ldpc - self._n_pruned 1248 | # remove last CNs and VNs from pcm 1249 | pcm = pcm[:-self._nb_pruned_nodes, :-self._nb_pruned_nodes] 1250 | 1251 | #check for consistency 1252 | assert(self._nb_pruned_nodes>=0), "Internal error: number of \ 1253 | pruned nodes must be positive." 
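            # Illustrative example (hypothetical numbers): for n_ldpc = 1000,
            # nb_punc_bits = 100, and last_pos = 920, the decoder keeps
            # n_pruned = max(920, 1000 - 100) = 920 graph columns and prunes
            # the trailing nb_pruned_nodes = 80 degree-1 VNs (and their CNs).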
1254 | else: 1255 | self._nb_pruned_nodes = 0 1256 | # no pruning; same length as before 1257 | self._n_pruned = encoder._n_ldpc 1258 | 1259 | 1260 | 1261 | super().__init__(pcm, 1262 | trainable, 1263 | cn_type, 1264 | hard_out, 1265 | track_exit, 1266 | num_iter=num_iter, 1267 | stateful=stateful, 1268 | output_dtype=output_dtype, 1269 | **kwargs) 1270 | 1271 | ######################################### 1272 | # Public methods and properties 1273 | ######################################### 1274 | 1275 | @property 1276 | def llr_max(self): 1277 | """Max LLR value used for rate-matching.""" 1278 | return self._llr_max 1279 | 1280 | @property 1281 | def encoder(self): 1282 | """LDPC Encoder used for rate-matching/recovery.""" 1283 | return self._encoder 1284 | 1285 | ######################### 1286 | # Keras layer functions 1287 | ######################### 1288 | 1289 | def build(self, input_shape): 1290 | """Build model.""" 1291 | if self._stateful: 1292 | assert(len(input_shape)==2), \ 1293 | "For stateful decoding, a tuple of two inputs is expected." 1294 | input_shape = input_shape[0] 1295 | 1296 | # check input dimensions for consistency 1297 | assert (input_shape[-1]==self.encoder.n), \ 1298 | 'Last dimension must be of length n.' 1299 | assert (len(input_shape)>=2), 'The inputs must have at least rank 2.' 1300 | 1301 | self._old_shape_5g = input_shape 1302 | 1303 | def call(self, inputs): 1304 | """Iterative BP decoding function. 1305 | 1306 | This function performs ``num_iter`` belief propagation decoding 1307 | iterations and returns the estimated codeword. 1308 | 1309 | Args: 1310 | inputs (tf.float32): Tensor of shape `[...,n]` containing the 1311 | channel logits/llr values. 1312 | 1313 | Returns: 1314 | `tf.float32`: Tensor of shape `[...,n]` or `[...,k]` 1315 | (``return_infobits`` is True) containing bit-wise soft-estimates 1316 | (or hard-decided bit-values) of all codeword bits (or info 1317 | bits, respectively). 1318 | 1319 | Raises: 1320 | ValueError: If ``inputs`` is not of shape `[batch_size, n]`. 1321 | 1322 | ValueError: If ``num_iter`` is not an integer greater (or equal) 1323 | `0`. 1324 | 1325 | InvalidArgumentError: When rank(``inputs``)<2. 1326 | """ 1327 | 1328 | # Extract inputs 1329 | if self._stateful: 1330 | llr_ch, msg_vn = inputs 1331 | else: 1332 | llr_ch = inputs 1333 | 1334 | tf.debugging.assert_type(llr_ch, self.dtype, 'Invalid input dtype.') 1335 | 1336 | llr_ch_shape = llr_ch.get_shape().as_list() 1337 | new_shape = [-1, llr_ch_shape[-1]] 1338 | llr_ch_reshaped = tf.reshape(llr_ch, new_shape) 1339 | batch_size = tf.shape(llr_ch_reshaped)[0] 1340 | 1341 | # invert if rate-matching output interleaver was applied as defined in 1342 | # Sec. 
5.4.2.2 in 38.212 1343 | if self._encoder.num_bits_per_symbol is not None: 1344 | llr_ch_reshaped = tf.gather(llr_ch_reshaped, 1345 | self._encoder.out_int_inv, 1346 | axis=-1) 1347 | 1348 | 1349 | # undo puncturing of the first 2*Z bit positions 1350 | llr_5g = tf.concat( 1351 | [tf.zeros([batch_size, 2*self.encoder.z], self._output_dtype), 1352 | llr_ch_reshaped], 1353 | 1) 1354 | 1355 | # undo puncturing of the last positions 1356 | # total length must be n_ldpc, while llr_ch has length n 1357 | # first 2*z positions are already added 1358 | # -> add n_ldpc - n - 2Z punctured positions 1359 | k_filler = self.encoder.k_ldpc - self.encoder.k # number of filler bits 1360 | nb_punc_bits = ((self.encoder.n_ldpc - k_filler) 1361 | - self.encoder.n - 2*self.encoder.z) 1362 | 1363 | 1364 | llr_5g = tf.concat([llr_5g, 1365 | tf.zeros([batch_size, nb_punc_bits - self._nb_pruned_nodes], 1366 | self._output_dtype)], 1367 | 1) 1368 | 1369 | # undo shortening (= add 0 positions after k bits, i.e. LLR=LLR_max) 1370 | # the first k positions are the systematic bits 1371 | x1 = tf.slice(llr_5g, [0,0], [batch_size, self.encoder.k]) 1372 | 1373 | # parity part 1374 | nb_par_bits = (self.encoder.n_ldpc - k_filler 1375 | - self.encoder.k - self._nb_pruned_nodes) 1376 | x2 = tf.slice(llr_5g, 1377 | [0, self.encoder.k], 1378 | [batch_size, nb_par_bits]) 1379 | 1380 | # negative sign due to logit definition 1381 | z = -self._llr_max * tf.ones([batch_size, k_filler], self._output_dtype) 1382 | 1383 | llr_5g = tf.concat([x1, z, x2], 1) 1384 | 1385 | # and execute the decoder 1386 | if not self._stateful: 1387 | x_hat = super().call(llr_5g) 1388 | else: 1389 | x_hat,msg_vn = super().call([llr_5g, msg_vn]) 1390 | 1391 | if self._return_infobits: # return only info bits 1392 | # reconstruct u_hat # code is systematic 1393 | u_hat = tf.slice(x_hat, [0,0], [batch_size, self.encoder.k]) 1394 | # Reshape u_hat so that it matches the original input dimensions 1395 | output_shape = llr_ch_shape[0:-1] + [self.encoder.k] 1396 | # overwrite first dimension as this could be None (Keras) 1397 | output_shape[0] = -1 1398 | u_reshaped = tf.reshape(u_hat, output_shape) 1399 | 1400 | # enable other output datatypes than tf.float32 1401 | u_out = tf.cast(u_reshaped, self._output_dtype) 1402 | 1403 | if not self._stateful: 1404 | return u_out 1405 | else: 1406 | return u_out, msg_vn 1407 | 1408 | else: # return all codeword bits 1409 | # the transmitted CW bits are not the same as used during decoding 1410 | # cf. last parts of 5G encoding function 1411 | 1412 | # remove last dim 1413 | x = tf.reshape(x_hat, [batch_size, self._n_pruned]) 1414 | 1415 | # remove filler bits at pos (k, k_ldpc) 1416 | x_no_filler1 = tf.slice(x, [0, 0], [batch_size, self.encoder.k]) 1417 | 1418 | x_no_filler2 = tf.slice(x, 1419 | [0, self.encoder.k_ldpc], 1420 | [batch_size, 1421 | self._n_pruned-self.encoder.k_ldpc]) 1422 | 1423 | x_no_filler = tf.concat([x_no_filler1, x_no_filler2], 1) 1424 | 1425 | # shorten the first 2*Z positions and end after n bits 1426 | x_short = tf.slice(x_no_filler, 1427 | [0, 2*self.encoder.z], 1428 | [batch_size, self.encoder.n]) 1429 | 1430 | # if used, apply rate-matching output interleaver again as 1431 | # Sec. 
5.4.2.2 in 38.212 1432 | if self._encoder.num_bits_per_symbol is not None: 1433 | x_short = tf.gather(x_short, self._encoder.out_int, axis=-1) 1434 | 1435 | # Reshape x_short so that it matches the original input dimensions 1436 | # overwrite first dimension as this could be None (Keras) 1437 | llr_ch_shape[0] = -1 1438 | x_short= tf.reshape(x_short, llr_ch_shape) 1439 | 1440 | # enable other output datatypes than tf.float32 1441 | x_out = tf.cast(x_short, self._output_dtype) 1442 | 1443 | if not self._stateful: 1444 | return x_out 1445 | else: 1446 | return x_out, msg_vn 1447 | -------------------------------------------------------------------------------- /source/mmsePIC.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------- 2 | # -- Implementation of the MMSE PIC and LoCo PIC Data Detectors 3 | # -- 4 | # -- The implementation of the MMSE PIC is based on 5 | # -- C. Studer, S. Fateh, and D. Seethaler, “ASIC Implementation of Soft-Input Soft-Output MIMO Detection Using MMSE 6 | # -- Parallel Interference Cancellation,” IEEE Journal of Solid-State Circuits, vol. 46, no. 7, pp. 1754–1765, July 2011, 7 | # -- available https://www.csl.cornell.edu/~studer/papers/11JSSC-mmsepic.pdf 8 | # -- 9 | # -- October 2022 (c) Reinhard Wiesmayr (wiesmayr@iis.ee.ethz.ch) 10 | # -- The code is based on the Sionna implementation of the LMMSE detector. 11 | # ----------------------------------------------------- 12 | 13 | import platform 14 | from tensorflow.keras.layers import Layer 15 | import sionna 16 | from sionna.mapping import * 17 | from sionna.ofdm import RemoveNulledSubcarriers 18 | from sionna.utils import split_dim, flatten_dims, expand_to_rank, flatten_last_dims, matrix_inv, matrix_sqrt_inv 19 | import numpy as np 20 | 21 | 22 | def selectDataCarryingOFDMSymbols(data_vec, rg_dim, data_ind, num_ofdm_data_symbols, num_effective_subcarriers): 23 | data_vec = flatten_dims(data_vec, 2, rg_dim) 24 | # data_ind carries indices for all data streams, we assume that they are all the same and only select the first one 25 | data_vec = tf.gather(data_vec, data_ind, axis=2) 26 | return split_dim(data_vec, [num_ofdm_data_symbols, num_effective_subcarriers], axis=rg_dim) 27 | 28 | 29 | class SisoMmsePicDetector(Layer): 30 | # pylint: disable=line-too-long 31 | """ 32 | Soft-Input Soft-Output Minimum Mean Squared Error (MMSE) Parallel Interference Cancellation Detector, based on 33 | C. Studer, S. Fateh, and D. Seethaler, “ASIC Implementation of Soft-Input Soft-Output MIMO Detection Using MMSE 34 | Parallel Interference Cancellation,” IEEE Journal of Solid-State Circuits, vol. 46, no. 7, pp. 1754–1765, July 2011, 35 | available at https://www.csl.cornell.edu/~studer/papers/11JSSC-mmsepic.pdf 36 | 37 | This implementation does NOT support XLA mode. Refer to the latest Sionna release for an implementation of MMSE PIC 38 | that supports XLA. 
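    A minimal usage sketch (illustrative only; the variable names, the
    `"maxlog"` demapping choice, and passing `None` for the a-priori LLRs
    ``llr_a`` and for the Gram-matrix input ``G`` are assumptions, not taken
    from the original code):

    .. code-block:: python

        detector = SisoMmsePicDetector(resource_grid, stream_management,
                                       demapping_method="maxlog",
                                       constellation=constellation)
        # first IDD iteration: no a-priori LLRs available yet
        out = detector((y, h_hat, err_var, no, None, None))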
39 | """ 40 | def __init__(self, 41 | resource_grid, 42 | stream_management, 43 | demapping_method, 44 | constellation: sionna.mapping.Constellation, 45 | dtype=tf.complex64, low_complexity=False, 46 | regularizationEpsilon=1e-4, data_carrying_whitened_inputs = False, 47 | hyper_parameter_err_var_num_ofdm_symbols = -1, training=False, 48 | **kwargs): 49 | super().__init__(**kwargs) 50 | assert isinstance(resource_grid, sionna.ofdm.ResourceGrid) 51 | assert isinstance(stream_management, sionna.mimo.StreamManagement) 52 | self._resource_grid = resource_grid 53 | self._stream_management = stream_management 54 | self._removed_nulled_scs = RemoveNulledSubcarriers(self._resource_grid) 55 | self._constellation = constellation 56 | self._dtype = dtype 57 | self._epsilon = regularizationEpsilon 58 | self._low_complexity = low_complexity 59 | self._removed_nulled_scs = RemoveNulledSubcarriers(self._resource_grid) 60 | self._data_carrying_whitened_inputs = data_carrying_whitened_inputs 61 | 62 | # Precompute indices to extract data symbols 63 | mask = resource_grid.pilot_pattern.mask 64 | num_data_symbols = resource_grid.pilot_pattern.num_data_symbols 65 | data_ind = tf.argsort(flatten_last_dims(mask), direction="ASCENDING") 66 | self._data_ind = data_ind[..., :num_data_symbols] 67 | 68 | num_bits_per_symbol = self._constellation.num_bits_per_symbol 69 | num_points = int(2 ** num_bits_per_symbol) 70 | a = np.zeros([num_points, num_bits_per_symbol]) 71 | for i in range(0, num_points): 72 | a[i, :] = np.array(list(np.binary_repr(i, num_bits_per_symbol)), 73 | dtype=np.int16) 74 | 75 | self._a = a 76 | self._aBool = tf.cast(self._a, tf.bool) 77 | 78 | # Compute symbol indices for which the bits are 0 or 1 79 | c0 = np.zeros([int(num_points / 2), num_bits_per_symbol]) 80 | c1 = np.zeros([int(num_points / 2), num_bits_per_symbol]) 81 | for i in range(num_bits_per_symbol - 1, -1, -1): 82 | c0[:, i] = np.where(a[:, i] == 0)[0] 83 | c1[:, i] = np.where(a[:, i] == 1)[0] 84 | self._c0 = tf.constant(c0, dtype=tf.int32) # Symbols with ith bit=0 85 | self._c1 = tf.constant(c1, dtype=tf.int32) # Symbols with ith bit=1 86 | 87 | if constellation.normalize: 88 | n = int(num_bits_per_symbol / 2) 89 | qam_var = 1 / (2 ** (n - 2)) * np.sum(np.linspace(1, 2 ** n - 1, 2 ** (n - 1)) ** 2) 90 | self._qam_normalization_factor = 1 / np.sqrt(qam_var) 91 | 92 | else: 93 | self._qam_normalization_factor = 1 94 | 95 | if demapping_method == "app": 96 | self._reduce = tf.reduce_logsumexp 97 | else: 98 | self._reduce = tf.reduce_max 99 | 100 | if hyper_parameter_err_var_num_ofdm_symbols > 0: 101 | self._eta = tf.Variable(tf.ones([1, 1, hyper_parameter_err_var_num_ofdm_symbols, 1, 1, 1]), trainable=training, dtype=tf.float32, name="eta") 102 | else: 103 | self._eta = 1 104 | 105 | def soft_symbols(self, llr_a, points_reshaped, batch_size, num_ofdm_symbols, num_effective_subcarriers, num_tx, 106 | num_streams): 107 | 108 | p0 = 0.5 * (1 - tf.math.tanh( 109 | 0.5 * llr_a)) 110 | 111 | if self._low_complexity and self._constellation._constellation_type == "qam" and self._constellation.num_bits_per_symbol in [ 112 | 1, 2, 4, 6]: 113 | p1 = 1 - p0 114 | if self._constellation.num_bits_per_symbol == 1: 115 | # BPSK 116 | s_real = (1 - 2 * tf.gather(p1, indices=0, axis=-1)) 117 | s_imag = 0 118 | 119 | c = 1 120 | d = 0 121 | elif self._constellation.num_bits_per_symbol == 2: 122 | # QPSK 123 | s_real = (1 - 2 * tf.gather(p1, indices=0, axis=-1)) 124 | s_imag = (1 - 2 * tf.gather(p1, indices=1, axis=-1)) 125 | 126 | c = 2 127 | d = 0 128 
| elif self._constellation.num_bits_per_symbol == 4: 129 | # 16-QAM 130 | s_real = (1 - 2 * tf.gather(p1, indices=0, axis=-1)) * (1 + 2 * tf.gather(p1, indices=2, axis=-1)) 131 | s_imag = (1 - 2 * tf.gather(p1, indices=1, axis=-1)) * (1 + 2 * tf.gather(p1, indices=3, axis=-1)) 132 | 133 | c = 1 + 8 * tf.gather(p1, indices=2, axis=-1) 134 | d = 1 + 8 * tf.gather(p1, indices=3, axis=-1) 135 | elif self._constellation.num_bits_per_symbol == 6: 136 | # 64-QAM 137 | raise Exception('constellation order not implemented') 138 | else: 139 | raise Exception('unsupported constellation order') 140 | 141 | s_hat = self._qam_normalization_factor * tf.complex(s_real, s_imag) 142 | # normalization can be included in previous scaling factor... 143 | error_var = self._qam_normalization_factor ** 2 * ((c + d) - tf.square(s_real) - tf.square(s_imag)) 144 | 145 | log_P_C = None 146 | else: 147 | p0 = tf.expand_dims(p0, axis=-2) 148 | p1 = 1 - p0 149 | oneBits_reshaped = tf.reshape(self._aBool, [1, 1, 1, 1, 1] + self._constellation.points.shape + 150 | self._constellation.num_bits_per_symbol) 151 | pC_bits = tf.where(oneBits_reshaped, p1, p0) 152 | 153 | P_C = tf.reduce_prod(pC_bits, axis=-1) 154 | 155 | # numerically stable way to calculate log_pC (log of constellation symbol probabilities) 156 | # following (22), (23) from C. Studer et al., "Soft–Input Soft–Output Single Tree-Search 157 | # Sphere Decoding," IEEE TRANS. ON INFORMATION THEORY, VOL. 56, NO. 10, OCTOBER 2010 158 | abs_llrs = tf.math.abs(llr_a) 159 | K_i_tilde = tf.reduce_sum(0.5 * abs_llrs + tf.math.log(1 + tf.math.exp(-abs_llrs)), axis=-1, 160 | keepdims=True) # @TODO: check axis right? 161 | 162 | x_ib = 2 * (tf.cast(oneBits_reshaped, dtype=tf.float32) - 0.5) 163 | log_P_C = - (K_i_tilde - tf.reduce_sum(0.5 * x_ib * tf.expand_dims(llr_a, axis=-2), axis=-1)) 164 | 165 | # s_hat [batch_size, num_tx, num_streams, num_ofdm_symbols, num_effective_subcarriers] 166 | s_hat = tf.reduce_sum(points_reshaped * tf.cast(P_C, tf.complex64), axis=-1) 167 | 168 | # Calculate Error Variance Estimate 169 | squared_error = tf.math.pow( 170 | tf.maximum(tf.abs(tf.expand_dims(s_hat, axis=-1) - points_reshaped), self._epsilon), 2) 171 | error_var = tf.reduce_sum(squared_error * P_C, axis=-1) 172 | 173 | # transform s_hat and error_var to [batch_size, 1, num_ofdm_symbols, num_effective_subcarriers, 174 | # num_tx*num_streams, 1] 175 | s_hat = tf.transpose(s_hat, [0, 3, 4, 1, 2]) 176 | error_var = tf.transpose(error_var, [0, 3, 4, 1, 2]) 177 | s_int_shape = tf.concat( 178 | [[batch_size], [1], [num_ofdm_symbols], [num_effective_subcarriers], [num_tx * num_streams, 1]], 0) 179 | s_hat = tf.reshape(s_hat, s_int_shape) 180 | error_var = tf.reshape(error_var, s_int_shape) 181 | 182 | return [s_hat, error_var, log_P_C] 183 | 184 | def LLLCalculation(self, z_i, rho_i, points_reshaped, log_P_C): 185 | # z_i is [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers, num_streams_per_rx, 1] 186 | if self._low_complexity and self._constellation._constellation_type == "qam" and self._constellation.num_bits_per_symbol in [ 187 | 1, 2, 4, 6]: 188 | # transform z_i to constellation w/o unit-energy scaling 189 | z_i = z_i / self._qam_normalization_factor 190 | 191 | if self._constellation.num_bits_per_symbol == 1: 192 | # BPSK 193 | lambda_b_1 = 4 * tf.math.real(z_i) 194 | lambda_b = lambda_b_1 195 | elif self._constellation.num_bits_per_symbol == 2: 196 | # QPSK 197 | lambda_b_1 = 4 * tf.math.real(z_i) 198 | lambda_b_2 = 4 * tf.math.imag(z_i) 199 | lambda_b = 
tf.concat([lambda_b_1, lambda_b_2], axis=-1) 200 | elif self._constellation.num_bits_per_symbol == 4: 201 | # 16-QAM 202 | z_i_real = tf.math.real(z_i) 203 | z_i_imag = tf.math.imag(z_i) 204 | lambda_b_1 = tf.where(tf.math.less_equal(tf.abs(z_i_real), 2), 4 * z_i_real, 205 | 8 * z_i_real - 8 * tf.sign(z_i_real)) 206 | lambda_b_2 = 8 - 4 * tf.abs(z_i_real) 207 | lambda_b_3 = tf.where(tf.math.less_equal(tf.abs(z_i_imag), 2), 4 * z_i_imag, 208 | 8 * z_i_imag - 8 * tf.sign(z_i_imag)) 209 | lambda_b_4 = 8 - 4 * tf.abs(z_i_imag) 210 | lambda_b = tf.concat([lambda_b_1, lambda_b_3, lambda_b_2, lambda_b_4], axis=-1) 211 | elif self._constellation.num_bits_per_symbol == 6: 212 | # 64-QAM 213 | raise Exception('constellation order not implemented') 214 | else: 215 | raise Exception('unsupported constellation order') 216 | 217 | lambda_b = self._qam_normalization_factor ** 2 * lambda_b 218 | llr_d = - rho_i * lambda_b # minus because of inverse LLR definition 219 | else: 220 | squared_dist = tf.math.pow(tf.math.abs(z_i - points_reshaped), 2) 221 | 222 | squared_dist = tf.maximum(squared_dist, self._epsilon ** 2) 223 | 224 | if log_P_C is not None: 225 | exponents = -squared_dist * rho_i + log_P_C # intrinsic 226 | else: 227 | exponents = -squared_dist * rho_i # extrinsic 228 | 229 | exp0 = tf.gather(exponents, self._c0, axis=-1, batch_dims=0) 230 | exp1 = tf.gather(exponents, self._c1, axis=-1, batch_dims=0) 231 | 232 | # transform 233 | # llr_d is [batch_size, 1, num_ofdm_symbols, num_effective_subcarriers, num_tx*num_streams, num_bits_per_symbol] 234 | llr_d = self._reduce(exp1, axis=-2) - self._reduce(exp0, axis=-2) 235 | 236 | return llr_d 237 | 238 | def call(self, inputs): 239 | y, h_hat, err_var, no, llr_a, G = inputs 240 | # y has shape: 241 | # [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size] 242 | 243 | # h_hat has shape: 244 | # [batch_size, num_rx, num_rx_ant, num_tx, num_streams,... 245 | # ..., num_ofdm_symbols, num_effective_subcarriers] 246 | 247 | # err_var has a shape that is broadcastable to h_hat 248 | 249 | # llr_a None | [batch_size, num_tx, num_streams, num_data_symbols*num_bits_per_symbol], tf.float 250 | 251 | # no has shape [batch_size, num_rx] => assumed constant noise var across all Rx Antennas 252 | # or just the first n dimensions of this 253 | 254 | # prepare variables for shape 255 | batch_size = tf.shape(y)[0] 256 | num_effective_subcarriers = self._resource_grid.num_effective_subcarriers 257 | num_ofdm_data_symbols = int(self._resource_grid.num_data_symbols / num_effective_subcarriers) 258 | num_bits_per_symbol = self._constellation.num_bits_per_symbol 259 | num_tx = self._resource_grid.num_tx 260 | num_points = int(self._constellation.points.shape[0]) 261 | num_streams = self._resource_grid.num_streams_per_tx 262 | num_data_symbols = int(self._resource_grid.num_data_symbols) 263 | _type_float = tf.float32 264 | data_ind = self._data_ind[0, 0, :] 265 | 266 | if not self._data_carrying_whitened_inputs: 267 | # Remove nulled subcarriers from y (guards, dc). New shape: 268 | # [batch_size, num_rx, num_rx_ant, ... 269 | # ..., num_ofdm_symbols, num_effective_subcarriers] 270 | y_eff = self._removed_nulled_scs(y) 271 | #################################################### 272 | ### Prepare the observation y for MIMO detection ### 273 | #################################################### 274 | # Transpose y_eff to put num_rx_ant last. New shape: 275 | # [batch_size, num_rx, num_ofdm_symbols,... 
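# The LLLCalculation method above implements standard (max-)log demapping:
# per bit, the LLR is the difference between the best exponent over the
# constellation points whose bit equals 1 and those whose bit equals 0,
# where the exponent is -rho*|z - s|^2 plus an optional log-prior (the
# "app" demapping method uses logsumexp instead of max as the reduction).
# A standalone sketch for a single symbol (illustrative names and Gray
# labels, not part of this repo):
import numpy as np

def maxlog_llr(z, rho, points, bit_labels):
    # points: [P] complex constellation, bit_labels: [P, B] in {0, 1}
    exponents = -rho * np.abs(z - points) ** 2
    llr = np.empty(bit_labels.shape[1])
    for b in range(bit_labels.shape[1]):
        llr[b] = (exponents[bit_labels[:, b] == 1].max()
                  - exponents[bit_labels[:, b] == 0].max())
    return llr

# QPSK with unit average energy; bit value 1 maps to a negative component
points = np.array([1 + 1j, 1 - 1j, -1 + 1j, -1 - 1j]) / np.sqrt(2)
labels = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
maxlog_llr(0.6 - 0.4j, 4.0, points, labels)   # sign pattern: [-, +]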
276 | # ..., num_effective_subcarriers, num_rx_ant]
277 | y_dt = tf.transpose(y_eff, [0, 1, 3, 4, 2])
278 | y_dt = tf.cast(y_dt, self._dtype)
279 |
280 | # Gather only data-carrying symbols
281 | # New shape:
282 | # [batch_size, num_rx, num_ofdm_data_symbols,...
283 | # ..., num_effective_subcarriers, num_rx_ant]
284 | y_dt = selectDataCarryingOFDMSymbols(y_dt, 2, data_ind, num_ofdm_data_symbols, num_effective_subcarriers)
285 |
286 | ##############################################
287 | ### Prepare the err_var for MIMO detection ###
288 | ##############################################
289 | # New shape is:
290 | # [batch_size, num_rx, num_ofdm_data_symbols,...
291 | # ..., num_effective_subcarriers, num_rx_ant, num_tx*num_streams]
292 | err_var_dt = tf.broadcast_to(err_var, tf.shape(h_hat))
293 | err_var_dt = tf.transpose(err_var_dt, [0, 1, 5, 6, 2, 3, 4])
294 | err_var_dt = flatten_last_dims(err_var_dt, 2)
295 | err_var_dt = tf.cast(err_var_dt, self._dtype)
296 | err_var_dt = selectDataCarryingOFDMSymbols(err_var_dt, 2, data_ind, num_ofdm_data_symbols,
297 | num_effective_subcarriers)
298 | err_var_dt = err_var_dt * tf.cast(self._eta, dtype=tf.complex64)
299 |
300 | ###############################
301 | ### Construct MIMO channels ###
302 | ###############################
303 |
304 | # Reshape h_hat for the construction of desired/interfering channels:
305 | # [num_rx, num_tx, num_streams_per_tx, batch_size, num_rx_ant, ...
306 | # ..., num_ofdm_symbols, num_effective_subcarriers]
307 | perm = [1, 3, 4, 0, 2, 5, 6]
308 | h_dt = tf.transpose(h_hat, perm)
309 |
310 | # Flatten first three dimensions:
311 | # [num_rx*num_tx*num_streams_per_tx, batch_size, num_rx_ant, ...
312 | # ..., num_ofdm_symbols, num_effective_subcarriers]
313 | h_dt = flatten_dims(h_dt, 3, 0)
314 |
315 | # Gather desired and undesired channels
316 | ind_desired = self._stream_management.detection_desired_ind
317 | ind_undesired = self._stream_management.detection_undesired_ind
318 | h_dt_desired = tf.gather(h_dt, ind_desired, axis=0)
319 | h_dt_undesired = tf.gather(h_dt, ind_undesired, axis=0)
320 |
321 | # Split first dimension to separate RX and TX:
322 | # [num_rx, num_streams_per_rx, batch_size, num_rx_ant, ...
323 | # ..., num_ofdm_symbols, num_effective_subcarriers]
324 | h_dt_desired = split_dim(h_dt_desired, [self._stream_management.num_rx, -1], 0)
325 | h_dt_undesired = split_dim(h_dt_undesired, [self._stream_management.num_rx, -1], 0)
326 |
327 | # Permute dims to
328 | # [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers,..
329 | # ..., num_rx_ant, num_streams_per_rx(num_Interfering_streams_per_rx)]
330 | perm = [2, 0, 4, 5, 3, 1]
331 | h_dt_desired = tf.transpose(h_dt_desired, perm)
332 | h_dt_desired = tf.cast(h_dt_desired, self._dtype)
333 | h_dt_undesired = tf.transpose(h_dt_undesired, perm)
334 | h_dt_desired = selectDataCarryingOFDMSymbols(h_dt_desired, 2, data_ind, num_ofdm_data_symbols,
335 | num_effective_subcarriers)
336 | h_dt_undesired = selectDataCarryingOFDMSymbols(h_dt_undesired, 2, data_ind, num_ofdm_data_symbols,
337 | num_effective_subcarriers)
338 |
339 | ##################################
340 | ### Prepare the noise variance ###
341 | ##################################
342 | # no is first broadcast to [batch_size, num_rx, num_rx_ant]
343 | # then the rank is expanded to that of y
344 | # then it is transposed like y to the final shape
345 | # [batch_size, num_rx, num_ofdm_symbols,...
346 | # ..., num_effective_subcarriers, num_rx_ant]
347 | no_dt = expand_to_rank(no, 3, -1)
348 | no_dt = tf.broadcast_to(no_dt, tf.shape(y)[:3])
349 | no_dt = expand_to_rank(no_dt, tf.rank(y), -1)
350 | no_dt = tf.transpose(no_dt, [0, 1, 3, 4, 2])
351 | no_dt = tf.cast(no_dt, self._dtype)
352 |
353 | ##################################################
354 | ### Compute the interference covariance matrix ###
355 | ##################################################
356 | # Covariance of undesired transmitters
357 | s_inf = tf.matmul(h_dt_undesired, h_dt_undesired, adjoint_b=True)
358 |
359 | # Thermal noise
360 | s_no = tf.linalg.diag(no_dt)
361 |
362 | # Channel estimation errors
363 | # As we have only error variance information for each element,
364 | # we simply sum them across transmitters and build a
365 | # diagonal covariance matrix from this
366 | s_csi = tf.linalg.diag(tf.reduce_sum(err_var_dt, -1))
367 |
368 | # Final covariance matrix
369 | s = s_inf + s_no + s_csi
370 | s = tf.cast(s, self._dtype)
371 |
372 | # Noise+Interference Whitening
373 | s_inv_1_2 = matrix_sqrt_inv(s)
374 |
375 | # Whiten the observation
376 | y_dt = tf.expand_dims(y_dt, -1)
377 | y_dt_whitened = tf.matmul(s_inv_1_2, y_dt)
378 |
379 | # Compute channel after whitening
380 | h_dt_desired_whitened = tf.matmul(s_inv_1_2, h_dt_desired)
381 |
382 | # Step 1: Compute Gram matrix
383 | # h_dt_desired is [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers,..
384 | # ..., num_rx_ant, num_streams_per_rx(num_Interfering_streams_per_rx)]
385 | # G is [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers, num_streams_per_rx, num_streams_per_rx]
386 | else:
387 | h_dt_desired_whitened = h_hat
388 | y_dt_whitened = y
389 | if G is None:
390 | G = tf.linalg.matmul(h_dt_desired_whitened, h_dt_desired_whitened, adjoint_a=True)
391 |
392 | # y_MF is [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers, num_streams_per_rx]
393 | y_MF = tf.linalg.matmul(h_dt_desired_whitened, y_dt_whitened, adjoint_a=True)
394 | else:
395 | y_MF = y
396 | ############################################################
397 | #### SISO LMMSE PIC ###
398 | # following Algorithm 1 from [1]
399 | ############################################################
400 |
401 | # Calculate Soft Symbols
402 | points_reshaped = tf.reshape(self._constellation.points, [1] * 5 + [num_points])
403 |
404 | if llr_a is None:
405 | # no a priori LLR => no parallel interference cancellation
406 | y_hat_i_MF = y_MF
407 | # _lambda = None
408 | _error_var_row_vec = None
409 | log_P_C = None
410 | error_var = 1
411 | llr_a_out = 0
412 | _A = G
413 | else:
414 | # Step 2: Calculate soft symbols and variances
415 |
416 | # llr_a is [batch_size, num_tx, num_streams, num_data_symbols*num_bits_per_symbol]
417 | # reshape to [batch_size, num_tx, num_streams, num_ofdm_symbols, num_effective_subcarriers, num_bits_per_symbol]
418 | llr_a_out = llr_a
419 | llr_a = tf.expand_dims(llr_a, axis=-1)
420 | llr_a = tf.expand_dims(llr_a, axis=-3)
421 | llr_int_shape = tf.concat(
422 | [tf.shape(llr_a)[:-3], [num_ofdm_data_symbols, num_effective_subcarriers, num_bits_per_symbol]], 0)
423 | llr_a = tf.reshape(llr_a, llr_int_shape)
424 |
425 | # Compute log(P(points)) from llr_a
426 | # [batch_size, num_tx, num_streams, num_ofdm_symbols, num_effective_subcarriers, num_constellation]
427 | # logPb0 = np.log(1 + np.exp(llr_a)) would overflow (exp of a large LLR is inf) => use the Jacobi logarithm instead
428 |
429 | [s_hat, error_var, log_P_C] = self.soft_symbols(llr_a,
points_reshaped, batch_size, num_ofdm_data_symbols, 430 | num_effective_subcarriers, num_tx, num_streams) 431 | 432 | # Step 3: Perform PIC 433 | # H^H y_hat_i = y_MF - sum_j!=i gj s_hat_j = y + g_i s_hat_i - sum_j g_j s_hat_j 434 | # y_MF is [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers, num_streams_per_rx] 435 | # G is [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers, num_streams_per_rx, num_streams_per_rx] 436 | _g_j_s_hat_j = tf.linalg.matmul(G, s_hat) 437 | _s_hat = tf.transpose(s_hat, [0, 1, 2, 3, 5, 4]) 438 | y_hat_i_MF = y_MF + G * _s_hat - _g_j_s_hat_j 439 | # y_hat_i_MF is [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers, num_tx*num_streams, 440 | # num_tx*num_streams] 441 | 442 | # Step 4: Compute A 443 | # Calculate MMSE Filter (efficiently) 444 | # W^H = A^-1 H^H 445 | # A = H^H H \Lambda + N_0 I_Mt 446 | # \Lambda_ii = E_i = error_var 447 | 448 | _error_var_row_vec = tf.linalg.matrix_transpose(error_var) 449 | # _lambda = tf.linalg.diag(tf.squeeze(error_var, axis=-1)) 450 | # _lambda is [batch_size, 1, num_ofdm_symbols, num_effective_subcarriers, num_tx*num_streams, num_tx*num_streams] 451 | # _A = tf.matmul(G, tf.cast(_lambda, dtype=self.dtype)) 452 | _A = G * tf.cast(_error_var_row_vec, dtype=self.dtype) 453 | 454 | # calculate LMMSE filter (unit power Tx signals, perfect CSI) 455 | # _A is [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers, num_streams_per_rx, num_streams_per_rx] 456 | # _I_NT is [1, 1, 1, 1, num_streams_per_rx, num_streams_per_rx] 457 | _I_NT = tf.linalg.eye(tf.shape(_A)[-1], dtype=self.dtype) 458 | _I_NT = tf.reshape(_I_NT, tf.concat([[1] * (_A._rank() - 2), tf.shape(_I_NT)], 0)) 459 | # thermal noise is identity after noise whitening 460 | _A = _A + _I_NT 461 | 462 | # Step 5: compute MMSE filter and outputs, calculate A\H^H 463 | # A_inv is [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers, num_streams_per_rx, num_streams_per_rx] 464 | # calculating inverse explicitly is necessary 465 | A_inv = tf.linalg.inv(_A) 466 | # A_inv_Hermitian = tf.transpose(A_inv, conjugate=True, perm=[0, 1, 2, 3, 5, 4]) 467 | 468 | # G and [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers, num_streams_per_rx, num_streams_per_rx] 469 | # mu_i = a_i^H g_i 470 | _G_trans = tf.linalg.matrix_transpose(G) 471 | mu_i = tf.math.real(tf.reduce_sum(A_inv * _G_trans, axis=-1, keepdims=True)) 472 | # mu_i is [batch_size, 1, num_ofdm_symbols, num_effective_subcarriers, num_streams_per_rx, 1] 473 | 474 | rho_i = tf.divide(mu_i, tf.maximum(1 - error_var * mu_i, self._epsilon)) 475 | # z_i = W^H y_dt = mu_i^-1 a_i^H y_hat_i_MF 476 | # y_hat_i_MF is [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers, num_tx*num_streams, 477 | # num_tx*num_streams] 478 | 479 | # h_i^H h_i 480 | # channel_strengths = tf.linalg.diag_part(G) 481 | # normalization_chan_strength = tf.linalg.diag(1 / channel_strengths) 482 | 483 | if llr_a is not None: 484 | # z_i = tf.linalg.matmul(A_inv / tf.cast(mu_i, dtype=self.dtype), y_hat_i_MF) 485 | # z_i = tf.linalg.diag_part(z_i) 486 | y_hat_i_MF_trans = tf.linalg.matrix_transpose(y_hat_i_MF) 487 | z_i = tf.squeeze( 488 | tf.reduce_sum(A_inv * y_hat_i_MF_trans, axis=-1, keepdims=True) / tf.cast(mu_i, dtype=self.dtype), 489 | axis=-1) 490 | 491 | ### LMMSE calculation done => continue with LLR calculation 492 | 493 | # Step 6: calculate LLRs 494 | 495 | # calculate exponents 496 | # Compute squared distances from y to all points 497 | 498 | # log_P_C is [batch_size, 
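# Per tone, Steps 3-5 above boil down to the following dense linear-algebra
# recipe (a numpy sketch for one subcarrier with whitened channel H and unit
# noise; illustrative names, not part of this repo):
import numpy as np

def mmse_pic_step(H, y, s_hat, e, eps=1e-4):
    # H: [M, K] whitened channel, y: [M] observation,
    # s_hat/e: [K] soft symbols and their error variances
    G = H.conj().T @ H                                 # Gram matrix
    y_mf = H.conj().T @ y                              # matched filter
    # Step 3 (PIC): column i keeps stream i, cancels all streams j != i
    y_hat = y_mf[:, None] + G * s_hat[None, :] - (G @ s_hat)[:, None]
    # Steps 4-5: A = G diag(e) + I (thermal noise is white after whitening)
    A_inv = np.linalg.inv(G * e[None, :] + np.eye(len(e)))
    mu = np.real(np.sum(A_inv * G.T, axis=-1))         # mu_i = a_i^H g_i
    z = np.sum(A_inv * y_hat.T, axis=-1) / mu          # unbiased estimates
    rho = mu / np.maximum(1.0 - e * mu, eps)           # post-equalization SINR
    return z, rho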
num_tx, num_streams, num_ofdm_symbols, num_effective_subcarriers, 499 | # num_constellation] transform log_P_C to [batch_size, 1, num_ofdm_symbols, num_effective_subcarriers, 500 | # num_tx*num_streams, num_constellation] 501 | if log_P_C is not None: 502 | log_P_C = tf.transpose(log_P_C, [0, 3, 4, 1, 2, 5]) 503 | log_P_C_int_shape = tf.concat( 504 | [[batch_size], [1], [num_ofdm_data_symbols], [num_effective_subcarriers], [num_tx * num_streams], 505 | [num_points]], 0) 506 | log_P_C = tf.reshape(log_P_C, log_P_C_int_shape) 507 | 508 | z_i = tf.expand_dims(z_i, axis=-1) 509 | else: 510 | z_i = tf.linalg.matmul(A_inv, y_hat_i_MF) / tf.cast(mu_i, dtype=self.dtype) 511 | # z_i is [batch_size, num_rx, num_ofdm_data_symbols, num_effective_subcarriers, num_streams_per_rx, 1] 512 | llr_d = self.LLLCalculation(z_i, rho_i, points_reshaped, log_P_C) 513 | 514 | # llr_d = tf.reduce_logsumexp(exp1, axis=-2) - tf.reduce_logsumexp(exp0, axis=-2) 515 | 516 | # internal llr_a shape [batch_size, num_tx, num_streams, num_ofdm_symbols, num_effective_subcarriers, 517 | # num_bits_per_symbol] 518 | # outer llr_a shape is [batch_size, num_tx, num_streams, num_data_symbols*num_bits_per_symbol] 519 | # convert llr_d to out-shape 520 | llr_d = tf.squeeze(llr_d, axis=[1]) 521 | tmp_shape = tf.concat([[batch_size], [num_ofdm_data_symbols], [num_effective_subcarriers], [num_tx], 522 | [num_streams], [num_bits_per_symbol]], 0) 523 | llr_d = tf.reshape(llr_d, tmp_shape) 524 | llr_d = tf.transpose(llr_d, [0, 3, 4, 1, 2, 5]) 525 | out_shape = tf.concat([[batch_size], [num_tx], [num_streams], [num_data_symbols * num_bits_per_symbol]], 0) 526 | llr_d = tf.reshape(llr_d, out_shape) 527 | 528 | # subtract llr_a => llr_e = llr_d - llr_a 529 | if self._low_complexity: 530 | llr_e = llr_d 531 | else: 532 | llr_e = llr_d - llr_a_out 533 | 534 | return [llr_e, y_MF, h_dt_desired_whitened, G] 535 | 536 | # SISO LoCo PIC 537 | class sisoLoCoPicDetector(SisoMmsePicDetector): 538 | # pylint: disable=line-too-long 539 | """ 540 | Soft-Input Soft-Output Low-Complexity (LoCo) Parallel Interference Cancellation Detector 541 | with trainable parameter \alpha (and \beta). 
542 | """ 543 | 544 | def __init__(self, 545 | resource_grid, 546 | stream_management, 547 | demapping_method, 548 | constellation=None, 549 | trainable=False, 550 | alpha0=1, 551 | dtype=tf.complex64, 552 | regularizationEpsilon=1e-4, 553 | data_carrying_whitened_inputs = False, 554 | low_complexity=False, 555 | two_variables=False, 556 | beta0=0, 557 | error_var_term="default"): 558 | super().__init__(resource_grid=resource_grid, stream_management=stream_management, 559 | demapping_method=demapping_method, data_carrying_whitened_inputs=data_carrying_whitened_inputs, 560 | constellation=constellation, low_complexity=low_complexity, 561 | dtype=dtype) 562 | assert isinstance(resource_grid, sionna.ofdm.ResourceGrid) 563 | assert isinstance(stream_management, sionna.mimo.StreamManagement) 564 | self._resource_grid = resource_grid 565 | self._stream_management = stream_management 566 | self._removed_nulled_scs = RemoveNulledSubcarriers(self._resource_grid) 567 | 568 | self._trainable = trainable 569 | 570 | self._two_variables = two_variables 571 | 572 | self._epsilon = regularizationEpsilon 573 | 574 | # Precompute indices to extract data symbols 575 | mask = resource_grid.pilot_pattern.mask 576 | num_data_symbols = resource_grid.pilot_pattern.num_data_symbols 577 | data_ind = tf.argsort(flatten_last_dims(mask), direction="ASCENDING") 578 | self._data_ind = data_ind[..., :num_data_symbols] 579 | 580 | num_bits_per_symbol = self._constellation.num_bits_per_symbol 581 | num_points = int(2 ** num_bits_per_symbol) 582 | a = np.zeros([num_points, num_bits_per_symbol]) 583 | for i in range(0, num_points): 584 | a[i, :] = np.array(list(np.binary_repr(i, num_bits_per_symbol)), 585 | dtype=np.int16) 586 | 587 | self._error_var_term = error_var_term 588 | 589 | self._a = a 590 | self._aBool = tf.cast(self._a, tf.bool) 591 | 592 | self._alpha0 = alpha0 593 | self._beta0 = beta0 594 | 595 | self._alpha = tf.Variable(self._alpha0, dtype=tf.as_dtype(self._dtype).real_dtype, trainable=self._trainable, 596 | name="alpha_lmmse") 597 | if two_variables: 598 | self._beta = tf.Variable(self._beta0, dtype=tf.as_dtype(self._dtype).real_dtype, trainable=self._trainable, 599 | name="beta_mf") 600 | else: 601 | self._beta = beta0 602 | 603 | @property 604 | def alpha(self): 605 | return self._alpha 606 | 607 | @property 608 | def beta(self): 609 | return self._beta 610 | 611 | def call(self, inputs): 612 | y, h_hat, err_var, no, llr_a, A_inv, G, mu_i = inputs # attention: obey right order of input variables!!! 613 | # y has shape: 614 | # [batch_size, num_rx, num_rx_ant, num_ofdm_symbols, fft_size] 615 | 616 | # h_hat has shape: 617 | # [batch_size, num_rx, num_rx_ant, num_tx, num_streams,... 
618 | # ..., num_ofdm_symbols, num_effective_subcarriers]
619 |
620 | # err_var has a shape that is broadcastable to h_hat
621 |
622 | # llr_a None | [batch_size, num_tx, num_streams, num_data_symbols*num_bits_per_symbol], tf.float
623 |
624 | # no has shape [batch_size, num_rx] => assumed constant noise var across all Rx Antennas
625 | # or just the first n dimensions of this
626 | # prepare variables for shape
627 | batch_size = tf.shape(y)[0]
628 | num_effective_subcarriers = self._resource_grid.num_effective_subcarriers
629 | num_ofdm_data_symbols = int(self._resource_grid.num_data_symbols / num_effective_subcarriers)
630 |
631 | num_bits_per_symbol = self._constellation.num_bits_per_symbol
632 | num_tx = self._resource_grid.num_tx
633 | num_points = int(self._constellation.points.shape[0])
634 | num_streams = self._resource_grid.num_streams_per_tx
635 | num_data_symbols = int(self._resource_grid.num_data_symbols)
636 | _type_float = tf.float32
637 | data_ind = self._data_ind[0, 0, :]
638 |
639 | if not self._data_carrying_whitened_inputs:
640 | # Remove nulled subcarriers from y (guards, dc). New shape:
641 | # [batch_size, num_rx, num_rx_ant, ...
642 | # ..., num_ofdm_symbols, num_effective_subcarriers]
643 | y_eff = self._removed_nulled_scs(y)
644 | ####################################################
645 | ### Prepare the observation y for MIMO detection ###
646 | ####################################################
647 | # Transpose y_eff to put num_rx_ant last. New shape:
648 | # [batch_size, num_rx, num_ofdm_symbols,...
649 | # ..., num_effective_subcarriers, num_rx_ant]
650 | y_dt = tf.transpose(y_eff, [0, 1, 3, 4, 2])
651 | y_dt = tf.cast(y_dt, self._dtype)
652 |
653 | # Gather only data-carrying symbols
654 | # New shape:
655 | # [batch_size, num_rx, num_ofdm_data_symbols,...
656 | # ..., num_effective_subcarriers, num_rx_ant]
657 | y_dt = selectDataCarryingOFDMSymbols(y_dt, 2, data_ind, num_ofdm_data_symbols, num_effective_subcarriers)
658 | ##############################################
659 | ### Prepare the err_var for MIMO detection ###
660 | ##############################################
661 | # New shape is:
662 | # [batch_size, num_rx, num_ofdm_symbols,...
663 | # ..., num_effective_subcarriers, num_rx_ant, num_tx*num_streams]
664 | err_var_dt = tf.broadcast_to(err_var, tf.shape(h_hat))
665 | err_var_dt = tf.transpose(err_var_dt, [0, 1, 5, 6, 2, 3, 4])
666 | err_var_dt = flatten_last_dims(err_var_dt, 2)
667 | err_var_dt = tf.cast(err_var_dt, self._dtype)
668 | err_var_dt = selectDataCarryingOFDMSymbols(err_var_dt, 2, data_ind, num_ofdm_data_symbols,
669 | num_effective_subcarriers)
670 |
671 | ###############################
672 | ### Construct MIMO channels ###
673 | ###############################
674 |
675 | # Reshape h_hat for the construction of desired/interfering channels:
676 | # [num_rx, num_tx, num_streams_per_tx, batch_size, num_rx_ant, ...
677 | # ..., num_ofdm_symbols, num_effective_subcarriers]
678 | perm = [1, 3, 4, 0, 2, 5, 6]
679 | h_dt = tf.transpose(h_hat, perm)
680 |
681 | # Flatten first three dimensions:
682 | # [num_rx*num_tx*num_streams_per_tx, batch_size, num_rx_ant, ...
683 | # ..., num_ofdm_symbols, num_effective_subcarriers]
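# The gathers below rely on Sionna's StreamManagement to tell desired from
# interfering streams. For a single-cell uplink in which the one receiver
# decodes every transmitter, a setup could look as follows (values
# illustrative, not part of this repo):
import numpy as np
import sionna

num_tx, num_streams_per_tx = 4, 1
rx_tx_association = np.ones([1, num_tx])  # one RX associated with all TXs
sm = sionna.mimo.StreamManagement(rx_tx_association, num_streams_per_tx)
# sm.detection_desired_ind then covers all num_tx*num_streams_per_tx streams,
# and sm.detection_undesired_ind is empty (no out-of-cell interference).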
684 | h_dt = flatten_dims(h_dt, 3, 0)
685 |
686 | # Gather desired and undesired channels
687 | ind_desired = self._stream_management.detection_desired_ind
688 | ind_undesired = self._stream_management.detection_undesired_ind
689 | h_dt_desired = tf.gather(h_dt, ind_desired, axis=0)
690 | h_dt_undesired = tf.gather(h_dt, ind_undesired, axis=0)
691 |
692 | # Split first dimension to separate RX and TX:
693 | # [num_rx, num_streams_per_rx, batch_size, num_rx_ant, ...
694 | # ..., num_ofdm_symbols, num_effective_subcarriers]
695 | h_dt_desired = split_dim(h_dt_desired, [self._stream_management.num_rx, -1], 0)
696 | h_dt_undesired = split_dim(h_dt_undesired, [self._stream_management.num_rx, -1], 0)
697 |
698 | # Permute dims to
699 | # [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers,..
700 | # ..., num_rx_ant, num_streams_per_rx(num_Interfering_streams_per_rx)]
701 | perm = [2, 0, 4, 5, 3, 1]
702 | h_dt_desired = tf.transpose(h_dt_desired, perm)
703 | h_dt_desired = tf.cast(h_dt_desired, self._dtype)
704 | h_dt_undesired = tf.transpose(h_dt_undesired, perm)
705 | h_dt_desired = selectDataCarryingOFDMSymbols(h_dt_desired, 2, data_ind, num_ofdm_data_symbols,
706 | num_effective_subcarriers)
707 | h_dt_undesired = selectDataCarryingOFDMSymbols(h_dt_undesired, 2, data_ind, num_ofdm_data_symbols,
708 | num_effective_subcarriers)
709 |
710 | ##################################
711 | ### Prepare the noise variance ###
712 | ##################################
713 | # no is first broadcast to [batch_size, num_rx, num_rx_ant]
714 | # then the rank is expanded to that of y
715 | # then it is transposed like y to the final shape
716 | # [batch_size, num_rx, num_ofdm_symbols,...
717 | # ..., num_effective_subcarriers, num_rx_ant]
718 | no_dt = expand_to_rank(no, 3, -1)
719 | no_dt = tf.broadcast_to(no_dt, tf.shape(y)[:3])
720 | no_dt = expand_to_rank(no_dt, tf.rank(y), -1)
721 | no_dt = tf.transpose(no_dt, [0, 1, 3, 4, 2])
722 | no_dt = tf.cast(no_dt, self._dtype)
723 | ##################################################
724 | ### Compute the interference covariance matrix ###
725 | ##################################################
726 | # Covariance of undesired transmitters
727 | s_inf = tf.matmul(h_dt_undesired, h_dt_undesired, adjoint_b=True)
728 |
729 | # Thermal noise
730 | s_no = tf.linalg.diag(no_dt)
731 |
732 | # Channel estimation errors
733 | # As we have only error variance information for each element,
734 | # we simply sum them across transmitters and build a
735 | # diagonal covariance matrix from this
736 | s_csi = tf.linalg.diag(tf.reduce_sum(err_var_dt, -1))
737 |
738 | # Final covariance matrix
739 | s = s_inf + s_no + s_csi
740 | s = tf.cast(s, self._dtype)
741 |
742 | # Noise+Interference Whitening
743 | s_inv_1_2 = matrix_sqrt_inv(s)
744 |
745 | # Whiten the observation
746 | y_dt = tf.expand_dims(y_dt, -1)
747 | y_dt_whitened = tf.matmul(s_inv_1_2, y_dt)
748 |
749 | # Compute channel after whitening
750 | h_dt_desired_whitened = tf.matmul(s_inv_1_2, h_dt_desired)
751 |
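# The block above whitens colored noise-plus-interference: with S the
# covariance built from the undesired channels, thermal noise, and CSI error
# variances, multiplying y and H by S^(-1/2) makes the effective noise white.
# A standalone numpy check of that property (illustrative, one tone; the
# inverse matrix square root is realized here via an eigendecomposition):
import numpy as np

rng = np.random.default_rng(0)
M = 4
H_int = (rng.normal(size=(M, 2)) + 1j * rng.normal(size=(M, 2))) / np.sqrt(2)
S = H_int @ H_int.conj().T + 0.1 * np.eye(M)   # interference + thermal noise
w, V = np.linalg.eigh(S)                       # S is Hermitian positive definite
S_inv_sqrt = V @ np.diag(w ** -0.5) @ V.conj().T
np.allclose(S_inv_sqrt @ S @ S_inv_sqrt.conj().T, np.eye(M))   # -> True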
752 | # Step 1: Compute Gram matrix
753 | # h_dt_desired is [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers,..
754 | # ..., num_rx_ant, num_streams_per_rx(num_Interfering_streams_per_rx)]
755 | # G is [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers, num_streams_per_rx, num_streams_per_rx]
756 | else:
757 | h_dt_desired_whitened = h_hat
758 | y_dt_whitened = y
759 | ##################################
760 | ### Prepare the noise variance ###
761 | ##################################
762 | # no is broadcast to [batch_size, num_rx, num_rx_ant]
763 | # no_dt = expand_to_rank(no, 3, -1)
764 | # no_dt = tf.broadcast_to(no_dt, tf.shape(y)[:3])
765 |
766 | ############################################################
767 | #### SISO LMMSE PIC ###
768 | ############################################################
769 | if G is None:
770 | G = tf.linalg.matmul(h_dt_desired_whitened, h_dt_desired_whitened, adjoint_a=True)
771 | # y_MF is [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers, num_streams_per_rx]
772 | y_MF = tf.linalg.matmul(h_dt_desired_whitened, y_dt_whitened, adjoint_a=True)
773 | else:
774 | y_MF = y
775 |
776 |
777 | ## following Algorithm 1 from [1]
778 | # Step 1: Compute Gram matrix
779 | # h_dt_desired is [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers,..
780 | # ..., num_rx_ant, num_streams_per_rx(num_Interfering_streams_per_rx)]
781 | # G is [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers, num_streams_per_rx, num_streams_per_rx]
782 |
783 | # _I_Nt is [1, 1, 1, 1, num_streams_per_rx, num_streams_per_rx]
784 | _I_Nt = tf.linalg.eye(tf.shape(G)[-1], dtype=self.dtype)
785 | _I_Nt = tf.reshape(_I_Nt, tf.concat([[1] * (G._rank() - 2), tf.shape(_I_Nt)], 0))
786 |
787 | # Calculate Soft Symbols
788 | points_reshaped = tf.reshape(self._constellation.points, [1] * 5 + [num_points])
789 |
790 | if llr_a is None:
791 | # no a priori LLR => no parallel interference cancellation
792 | y_hat_i_MF = y_MF
793 | log_pC = None # tf.math.log(tf.cast(1 / num_points, dtype=tf.float32))
794 | error_var = 1
795 | llr_a_out = 0
796 | _lambda = _I_Nt
797 | else:
798 | # Step 2: Calculate soft symbols and variances
799 |
800 | # llr_a is [batch_size, num_tx, num_streams, num_data_symbols*num_bits_per_symbol]
801 | # reshape to [batch_size, num_tx, num_streams, num_ofdm_symbols, num_effective_subcarriers, num_bits_per_symbol]
802 | llr_a_out = llr_a
803 | llr_a = tf.expand_dims(llr_a, axis=-1)
804 | llr_a = tf.expand_dims(llr_a, axis=-3)
805 | llr_int_shape = tf.concat(
806 | [tf.shape(llr_a)[:-3], [num_ofdm_data_symbols, num_effective_subcarriers, num_bits_per_symbol]], 0)
807 | llr_a = tf.reshape(llr_a, llr_int_shape)
808 |
809 | [s_hat, error_var, log_pC] = self.soft_symbols(llr_a, points_reshaped, batch_size, num_ofdm_data_symbols,
810 | num_effective_subcarriers, num_tx, num_streams)
811 |
812 | # Step 3: Perform PIC
813 | # H^H y_hat_i = y_MF - sum_j!=i gj s_hat_j = y + g_i s_hat_i - sum_j g_j s_hat_j
814 | # y_MF is [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers, num_streams_per_rx]
815 | # G is [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers, num_streams_per_rx, num_streams_per_rx]
816 | # _g_j_s_hat_j = tf.linalg.matmul(G, s_hat)
817 | _s_hat = tf.linalg.matrix_transpose(s_hat)
818 | # y_hat_i_MF_old = tf.expand_dims(y_MF, axis=-1) + G * _s_hat - _g_j_s_hat_j
819 |
820 | _G_times_s_hat = G * _s_hat
821 | _g_j_s_hat_j = tf.reduce_sum(_G_times_s_hat, axis=-1, keepdims=True)
822 | y_hat_i_MF = y_MF + _G_times_s_hat - _g_j_s_hat_j # @TODO Debug and verify
823 |
824 | # y_hat_i_MF is [batch_size, num_rx,
num_ofdm_symbols, num_effective_subcarriers, num_tx*num_streams, 825 | # num_tx*num_streams] 826 | 827 | # Step 4: Compute A 828 | # Calculate MMSE MF Filter (efficiently) 829 | # W = (alpha * (1/diag(A^-1 G)) A^-1 + beta * (1/diag(G))) * H^H 830 | # Z_bar = alpha * (1/diag(A^-1 G)) A^-1 + beta * (1/diag(G)) 831 | 832 | # \Lambda_ii = E_i = error_var 833 | _lambda = tf.linalg.diag(tf.squeeze(error_var, axis=-1)) 834 | # _lambda is [batch_size, 1, num_ofdm_symbols, num_effective_subcarriers, num_tx*num_streams, num_tx*num_streams] 835 | 836 | no_dt = 1 # expand_to_rank(no, tf.rank(_I_Nt), -1) 837 | no_dt_complex = tf.cast(no_dt, dtype=self.dtype) 838 | if A_inv is None: 839 | # calculate LMMSE filter (unit power Tx signals, perfect CSI) 840 | # _A is [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers, num_streams_per_rx, num_streams_per_rx] 841 | _A = G 842 | _A = _A + no_dt_complex * _I_Nt 843 | 844 | # Step 5: compute MMSE filter and outputs, calculate A\H^H 845 | # A_inv is [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers, num_streams_per_rx, num_streams_per_rx] 846 | A_inv = matrix_inv(_A) 847 | # A_inv_Hermitian = tf.transpose(A_inv, conjugate=True, perm=[0, 1, 2, 3, 5, 4]) 848 | 849 | 850 | # G is [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers, num_streams_per_rx, num_streams_per_rx] 851 | # Calculate normalized partial filter 852 | # normalizing_lmmse = tf.linalg.diag(1 / tf.linalg.diag_part(tf.matmul(A_inv, G))) 853 | 854 | # preprocessing 855 | if mu_i is None: 856 | mu_i = tf.reduce_sum(A_inv * tf.linalg.matrix_transpose(G), axis=-1, keepdims=True) 857 | mu_i_real = tf.math.real(mu_i) 858 | normalizing_lmmse_col_vec = 1 / mu_i 859 | 860 | normalizing_mf_vec = 1 / tf.linalg.diag_part(G) 861 | # normalizing_mf = tf.linalg.diag(normalizing_mf_vec) 862 | # normalizing_mf_col_vec = tf.expand_dims(1 / tf.linalg.diag_part(G), axis=-1) 863 | 864 | # z_i = W^H y_dt = mu_i^-1 a_i^H y_hat_i_MF 865 | # if PIC: y_hat_i_MF is [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers, num_tx*num_streams, 866 | # num_tx*num_streams]; else:[batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers,num_tx*num_streams, 1] 867 | alpha = tf.cast(self.alpha, dtype=self.dtype) 868 | if self._two_variables: 869 | beta = tf.cast(self.beta, dtype=self.dtype) 870 | else: 871 | beta = 1 - alpha 872 | # beta = tf.cast(self.beta, dtype=self.dtype) 873 | Z_bar = (alpha * normalizing_lmmse_col_vec) * A_inv + tf.linalg.diag(beta * normalizing_mf_vec) 874 | # rho_i = 1 875 | if llr_a is not None: # i.e. 
PIC
876 | # calculate error variance
877 | # error_i is [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers, num_streams_per_rx, num_streams_per_rx]
878 |
879 | # if PIC: s_hat_i is [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers, num_tx*num_streams,
880 | # num_tx*num_streams]; else: [batch_size, num_rx, num_ofdm_symbols, num_effective_subcarriers, num_tx*num_streams, 1]
881 | G_abs = tf.math.abs(G)
882 | G_uu = tf.expand_dims(tf.linalg.diag_part(G_abs), axis=-1)
883 | if self._error_var_term == "default":
884 | _lambda = tf.cast(_lambda, dtype=self.dtype)
885 | theta_i = tf.linalg.diag_part(tf.matmul(Z_bar @ ((G @ _lambda + no_dt_complex * _I_Nt) @ G), Z_bar,
886 | adjoint_b=True) - _lambda)
887 | theta_i = tf.math.real(theta_i)
888 | rho_i = tf.expand_dims(1 / theta_i, axis=-1)
889 | elif self._error_var_term == "sinr_heuristic":
890 | rho_i = G_uu / (no_dt + G_abs @ error_var)
891 | elif self._error_var_term == "sinr_heuristic2":
892 | rho_i = G_uu / (no_dt + G_uu * error_var)
893 | elif self._error_var_term == "sinr_heuristic3":  # heuristic found by inspection, after PIC and matched filtering
894 | rho_i = G_uu / (no_dt + (G_abs - tf.linalg.diag(tf.squeeze(G_uu, axis=-1))) @ error_var)
895 | elif self._error_var_term == "sinr_heuristic4":  # variance found by inspection, after PIC and matched filtering
896 | G_uu_squared = tf.square(G_uu)
897 | G_abs_squared = tf.square(G_abs)
898 | rho_i = G_uu_squared / (G_uu * no_dt + (G_abs_squared - tf.linalg.diag(
899 | tf.squeeze(G_uu_squared, axis=-1))) @ error_var)
900 | elif self._error_var_term == "ocd_paper":
901 | mu_tilde_i = G_uu / (G_uu + no_dt)
902 | rho_i = mu_tilde_i / (1 - mu_tilde_i)
903 | elif self._error_var_term == "ocd_paper2":
904 | rho_i = G_uu / no_dt
905 | elif self._error_var_term == "lmmse":  # MMSE PIC NPI expression - suboptimal here (only exact for the MMSE PIC)
906 | rho_i = tf.divide(mu_i_real, tf.maximum(1 - error_var * mu_i_real, self._epsilon))
907 | else:
908 | raise Exception('unsupported error variance term')
909 | # print(str(np.min(llr_a_out.numpy())))
910 | # filtering
911 | s_hat_i = tf.reduce_sum(Z_bar * tf.linalg.matrix_transpose(y_hat_i_MF), axis=-1)
912 |
913 | # Step 6: calculate LLRs
914 |
915 | # calculate exponents
916 | # Compute squared distances from y to all points
917 |
918 | # pC is [batch_size, num_tx, num_streams, num_ofdm_symbols, num_effective_subcarriers,
919 | # num_constellation]; transform pC to [batch_size, 1, num_ofdm_symbols, num_effective_subcarriers,
920 | # num_tx*num_streams, num_constellation]
921 | if log_pC is not None:  # is None in low-complexity mode
922 | log_pC = tf.transpose(log_pC, [0, 3, 4, 1, 2, 5])
923 | log_pC_int_shape = tf.concat(
924 | [[batch_size], [1], [num_ofdm_data_symbols], [num_effective_subcarriers], [num_tx * num_streams],
925 | [num_points]], 0)
926 | log_pC = tf.reshape(log_pC, log_pC_int_shape)
927 |
928 | s_hat_i = tf.expand_dims(s_hat_i, axis=-1)
929 | # squared_dist = tf.math.pow(tf.math.abs(tf.expand_dims(s_hat_i, axis=-1) - points_reshaped), 2)
930 | else:  # preprocessing / first IDD iteration
931 | # initially (i.e., for plain LMMSE filtering) --> apply the LMMSE error variance
932 | # rho_i = tf.divide(mu_i_real, tf.maximum(1 - error_var * mu_i_real, self._epsilon))
933 | rho_i = tf.divide(mu_i_real, tf.maximum(1 - mu_i_real, self._epsilon))
934 | s_hat_i = tf.linalg.matmul(Z_bar, y_hat_i_MF)
935 | # squared_dist = tf.math.pow(tf.math.abs(s_hat_i - points_reshaped), 2)
936 |
937 | llr_d = self.LLLCalculation(s_hat_i, rho_i, points_reshaped, log_pC)
938 |
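# The LoCo filter above blends a normalized LMMSE inverse with a normalized
# matched filter through the trainable alpha (and beta, which defaults to
# 1 - alpha unless two_variables=True); learning these per IDD iteration is
# the low-complexity, deep-unfolded knob of this detector. A one-tone numpy
# sketch (illustrative names, not part of this repo):
import numpy as np

def loco_filter(G, alpha, no=1.0):
    # G: [K, K] Gram matrix of the whitened channel
    A_inv = np.linalg.inv(G + no * np.eye(G.shape[-1]))
    mu = np.real(np.sum(A_inv * G.T, axis=-1))     # diag(A^{-1} G)
    Z_lmmse = A_inv / mu[:, None]                  # unbiased LMMSE part
    Z_mf = np.diag(1.0 / np.real(np.diag(G)))      # normalized matched filter
    return alpha * Z_lmmse + (1.0 - alpha) * Z_mf  # Z_bar as in Step 4 above

# alpha = 1 recovers the unbiased LMMSE filter, alpha = 0 a pure matched filter.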
939 | # internal llr_a shape [batch_size, num_tx, num_streams, num_ofdm_symbols, num_effective_subcarriers,
940 | # num_bits_per_symbol]
941 | # outer llr_a shape is [batch_size, num_tx, num_streams, num_data_symbols*num_bits_per_symbol]
942 | # convert llr_d to out-shape
943 | llr_d = tf.squeeze(llr_d, axis=[1])
944 | tmp_shape = tf.concat([[batch_size], [num_ofdm_data_symbols], [num_effective_subcarriers], [num_tx], [num_streams],
945 | [num_bits_per_symbol]], 0)
946 | llr_d = tf.reshape(llr_d, tmp_shape)
947 | llr_d = tf.transpose(llr_d, [0, 3, 4, 1, 2, 5])
948 | out_shape = tf.concat([[batch_size], [num_tx], [num_streams], [num_data_symbols * num_bits_per_symbol]], 0)
949 | llr_d = tf.reshape(llr_d, out_shape)
950 |
951 | # subtract llr_a => llr_e = llr_d - llr_a
952 | if log_pC is not None:
953 | llr_e = llr_d - llr_a_out  # extrinsic LLRs
954 | else:
955 | llr_e = llr_d
956 |
957 | return [llr_e, A_inv, G, y_MF, mu_i, h_dt_desired_whitened]
958 |
--------------------------------------------------------------------------------
/source/simulationFunctions.py:
--------------------------------------------------------------------------------
1 | # -- a priori input sweep I_A
2 | import pickle
3 | import time
4 |
5 | import numpy as np
6 | from datetime import datetime
7 | import pandas as pd
8 | import tensorflow as tf
9 |
10 | ########################################
11 | ## This file contains utility functions
12 | ########################################
13 |
14 |
15 | def save_data(sim_title, plot_data, sim_params=None, path="./fig/data/simulationResults/"):
16 | try:
17 | filename = datetime.now().strftime("%Y-%m-%d %H-%M ") + sim_title.replace("&", "").replace(".", "").replace(" ",
18 | "")
19 | file = open(path + filename + ".csv", "w")
20 | df = pd.DataFrame.from_dict(plot_data)
21 | df.to_csv(file, line_terminator='\n')
22 | file.close()
23 |
24 | with open(path + filename + ".pickle", 'wb') as f:
25 | pickle.dump(plot_data, f)
26 |
27 | if sim_params is not None:
28 | file = open(path + filename + '_params.csv', "w")
29 | df = pd.DataFrame.from_dict(sim_params)
30 | df.to_csv(file, line_terminator='\n')
31 | file.close()
32 |
33 | except Exception as e:
34 | print(e)
35 |
36 | def load_data(filename, path="./fig/data/simulationResults/"):
37 | with open(path + filename, "rb") as f:
38 | data = pickle.load(f)
39 | return data
40 |
41 | def genRandIdx(numChan, numUes, batchSize):
42 | idx = np.arange(numChan, dtype=int)
43 | idx_ar = np.zeros((batchSize, numUes), dtype=int)
44 | for i in range(batchSize):
45 | idx_ar[i, :] = np.random.choice(idx, numUes, replace=False)
46 | return idx_ar
47 |
48 | # Utility function for saving model weights
49 | def save_weights(model, model_weights_path):
50 | weights = model.get_weights()
51 | with open(model_weights_path, 'wb') as f:
52 | pickle.dump(weights, f)
53 |
54 | # Utility function for loading model weights
55 | def load_weights(model, model_weights_path):
56 | with open(model_weights_path, 'rb') as f:
57 | weights = pickle.load(f)
58 | model.set_weights(weights)
59 | return model
60 |
61 | def train_model(model, ebno_db_min, ebno_db_max, num_training_iterations, training_batch_size):
62 | # The Adam optimizer is used to apply the gradients
63 |
64 | # optimizer = tf.keras.optimizers.SGD(momentum=0.2, learning_rate=0.05)
65 | optimizer = tf.keras.optimizers.Adam()
66 | for i in range(num_training_iterations):
67 | # Sampling a batch of SNRs
68 | ebno_db = tf.random.uniform(shape=[training_batch_size], minval=ebno_db_min, maxval=ebno_db_max)
69 | #
Forward pass 70 | with tf.GradientTape() as tape: 71 | bce = model(training_batch_size, ebno_db) 72 | loss_value = bce 73 | # Computing and applying gradients 74 | weights = model.trainable_weights 75 | # print(weights) 76 | grads = tape.gradient(loss_value, weights) 77 | optimizer.apply_gradients(zip(grads, weights)) 78 | # Periodically printing the progress 79 | if i % 5 == 0: 80 | print('Iteration {}/{} BCE: {:.4f}'.format(i, num_training_iterations, bce.numpy())) 81 | for s in zip(weights, grads): 82 | # print(np.mean(np.abs(s[0])), np.mean(np.abs(s[1]))) 83 | print(s[0], s[1]) 84 | # print("Weight: %.4f Gradient: %.4f" % (weights[1].numpy(), grads[1].numpy())) 85 | 86 | def train_model_deweighting_SNR(model, snr_dB_min, snr_dB_max, training_batch_size, num_training_epochs, num_iter_per_epoch): 87 | # Optimizer Adam used to apply gradients 88 | 89 | # optimizer = tf.keras.optimizers.SGD(momentum=0.2, learning_rate=0.05) 90 | optimizer = tf.keras.optimizers.Adam() 91 | deweighting_weights = tf.Variable(tf.ones((training_batch_size))/training_batch_size, trainable=False, dtype=tf.float32, name="deweighting weights") 92 | ebno_db = tf.cast(tf.linspace(snr_dB_min, snr_dB_max, training_batch_size), dtype=tf.float32) 93 | for i_e in range(num_training_epochs): 94 | sum_loss=0 95 | print("Epoch {}/{} Weights: \n".format(i_e, num_training_epochs) + str(deweighting_weights.numpy().transpose() )) 96 | for i_iter in range(num_iter_per_epoch): 97 | # Sampling a batch of SNRs 98 | # Forward pass 99 | with tf.GradientTape() as tape: 100 | bce = model(training_batch_size, ebno_db) 101 | sum_loss = sum_loss + bce 102 | loss_value = tf.reduce_sum(bce * deweighting_weights) 103 | # Computing and applying gradients 104 | weights = model.trainable_weights 105 | # print(weights) 106 | grads = tape.gradient(loss_value, weights) 107 | optimizer.apply_gradients(zip(grads, weights)) 108 | # sum_bce = sum_bce + bce 109 | # Periodically printing the progress 110 | if i_iter % 5 == 0: 111 | print('Iteration {}/{} BCE: {:.4f}'.format(i_iter, num_iter_per_epoch, loss_value.numpy())) 112 | for s in zip(weights, grads): 113 | # print(np.mean(np.abs(s[0])), np.mean(np.abs(s[1]))) 114 | print(s[0], s[1]) 115 | # print("Weight: %.4f Gradient: %.4f" % (weights[1].numpy(), grads[1].numpy())) 116 | deweighting_weights.assign(1/(sum_loss + 1e-4)) 117 | deweighting_weights.assign(deweighting_weights/tf.reduce_sum(deweighting_weights)) 118 | print(sum_loss) 119 | 120 | def train_model_BLER(model_BCE, model_BLER, ebno_db_min, ebno_db_max, num_pretraining_iterations, num_BLER_training_iterations, training_batch_size): 121 | # Optimizer Adam used to apply gradients 122 | train_model(model_BCE, ebno_db_min, ebno_db_max, num_pretraining_iterations, training_batch_size) 123 | model_BLER.set_weights(model_BCE.get_weights()) 124 | train_model(model_BLER, ebno_db_min, ebno_db_max, num_BLER_training_iterations, training_batch_size) 125 | 126 | def train_model_BLER_SNR_deweighting(model_BCE, model_BLER, ebno_db_min, ebno_db_max, num_pretraining_iterations, num_snr_training_epochs, num_iter_per_epoch, training_batch_size): 127 | train_model(model_BCE, ebno_db_min, ebno_db_max, num_pretraining_iterations, training_batch_size) 128 | model_BLER.set_weights(model_BCE.get_weights()) 129 | train_model_deweighting_SNR(model_BLER, ebno_db_min, ebno_db_max, training_batch_size, num_snr_training_epochs, num_iter_per_epoch) --------------------------------------------------------------------------------
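# A typical use of the training helpers above, assuming `model` is a
# Keras-style end-to-end model that maps (batch_size, ebno_db) to a BCE
# loss, as set up in scr/simulation_script.py (names and values here are
# illustrative):
#
#   train_model(model, ebno_db_min=0.0, ebno_db_max=5.0,
#               num_training_iterations=200, training_batch_size=64)
#   save_weights(model, "./data/weights/duidd_model.pickle")
#   ...
#   model = load_weights(model, "./data/weights/duidd_model.pickle")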