├── LICENSE ├── README.md ├── assets └── block_diagram.png └── train ├── datasets.py ├── main.py ├── model.py ├── my_log.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Introduction 2 | 3 | This repository contains the source code for the following paper [1]. 4 | 5 | [1] [Y. Zhang et al., "AI Empowered Channel Semantic Acquisition for 6G Integrated Sensing and Communication Networks," in IEEE Network, doi: 10.1109/MNET.2024.3354264.](https://ieeexplore.ieee.org/document/10400499) 6 | 7 | Considering the correlation between communication and sensing (C&S) channels, we propose a two-stage frame structure along with JCASCasterNet, a network based on channel semantic learning; together they yield gains in both the extraction and the reconstruction of channel semantics. 8 | 9 |
10 | 11 |
12 | 13 | For more information, please visit our website: [gaozhen16.github.io](https://gaozhen16.github.io/) 14 | 15 | If you have any questions, you can contact me by email: yifeizhang2000@foxmail.com 16 | 17 | ## Training 18 | 19 | Run the main function in `train/main.py` to start training. The core code of JCASCasterNet is in `train/model.py`. 20 | 21 | ## License 22 | 23 | All code within this repository is under [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0). 24 | -------------------------------------------------------------------------------- /assets/block_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tocaloid/JCASCasterNet/5fd45528c14e5917ebde69e7527296faf5025062/assets/block_diagram.png -------------------------------------------------------------------------------- /train/datasets.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | from torch.fft import fft, ifft 7 | from torch.utils.data import Dataset 8 | 9 | import numpy as np 10 | from numpy import pi 11 | from numpy.random import choice 12 | 13 | import hdf5storage 14 | 15 | import os 16 | import time 17 | import random 18 | import shutil 19 | import copy 20 | from tqdm import tqdm 21 | from scipy import io 22 | from math import sqrt, log10 23 | 24 | from utils import env_init 25 | 26 | from my_log import logger, rmLogger 27 | 28 | 29 | 30 | def genH( 31 | N_gen , 32 | b , 33 | N_C , 34 | K , 35 | M_com , 36 | L_share , 37 | L_com , 38 | L_rad , 39 | R , 40 | device , 41 | ): 42 | theta_share=torch.rand(b, L_share, 1).to(device)*pi; # b, L_share, 1 43 | theta_com=torch.rand(b, L_com, 1).to(device)*pi; # b, L_com, 1 44 | theta_com[:, 0:L_share, :]=theta_share.clone().detach(); 45 | theta_rad=torch.rand(b, L_rad, 1).to(device)*pi; # b, L_rad, 1 46 | 
theta_rad[:, 0:L_share, :]=theta_share.clone().detach(); 47 | 48 | r_share=torch.rand(b, L_share, 1).to(device)*R; # b, L_share, 1 49 | r_com=torch.rand(b, L_com, 1).to(device)*R; # b, L_com, 1 50 | r_com[:, 0:L_share, :]=r_share.clone().detach(); 51 | r_rad=torch.rand(b, L_rad, 1).to(device)*R; # b, L_rad, 1 52 | r_rad[:, 0:L_share, :]=r_share.clone().detach(); 53 | 54 | theta_UE=torch.rand(b, K, 1).to(device)*pi; # b, K, 1 55 | r_UE=torch.rand(b, K, 1).to(device)*R; # b, K, 1 56 | def genBatch( 57 | b, 58 | ): 59 | # H_com 60 | theta_f=torch.rand(b*K, L_com+1, 1).to(device)*2*pi; # b*K, L_com+1, 1 61 | rg_f=torch.arange(N_C).to(device).reshape(1, 1, N_C).float(); # 1, 1, N_C 62 | v_f=torch.exp(1j*(theta_f@rg_f)).unsqueeze(-1); # b*K, L_com+1, N_C, 1 63 | 64 | theta_a=torch.cos(torch.cat( 65 | (theta_com.repeat(K, 1, 1), # b*K, L_com, 1 66 | theta_UE.reshape(b*K, 1, 1), ), # b*K, 1, 1 67 | 1, 68 | ))*pi; # b*K, L_com+1, 1 69 | rg_a=torch.arange(M_com).to(device).reshape(1, 1, M_com).float(); # 1, 1, M_com 70 | v_a=torch.exp(1j*(theta_a@rg_a)).unsqueeze(-2); # b*K, L_com+1, 1, M_com 71 | 72 | gain=torch.randn(b*K, L_com+1, 1, 1, dtype=torch.complex64).to(device); # b*K, L_com+1, 1, 1 73 | 74 | H_com=(gain*(v_f@v_a)).sum(1)/sqrt(L_com+1); # b*K, N_C, M_com 75 | H_com=H_com.reshape(b, K, N_C, M_com).transpose(1, 2); # b, N_C, K, M 76 | 77 | # H_rad 78 | theta_f=torch.rand(b, L_rad, 1).to(device)*2*pi; # b, L_rad, 1 79 | rg_f=torch.arange(N_C).to(device).reshape(1, 1, N_C).float(); # 1, 1, N_C 80 | v_f=torch.exp(1j*(theta_f@rg_f)).unsqueeze(-1); # b, L_rad, N_C, 1 81 | 82 | theta_Tx=torch.cos(theta_rad)*pi; # b, L_rad, 1 83 | rg_Tx=torch.arange(M_com).to(device).reshape(1, 1, M_com).float(); # 1, 1, M_com 84 | v_Tx=torch.exp(1j*(theta_Tx@rg_Tx)).unsqueeze(-2); # b, L_rad, 1, M_com 85 | 86 | gain=torch.randn(b, L_rad, 1, 1, dtype=torch.complex64).to(device); # b, L_rad, 1, 1 87 | 88 | H_rad=(gain*(v_f@v_Tx)).sum(1)/sqrt(L_rad); # b, N_C, M_com 89 | 
H_rad=H_rad.unsqueeze(-2); # b, N_C, 1, M_com 90 | 91 | return H_com, \ 92 | H_rad; 93 | 94 | H_com=torch.zeros(N_gen, N_C, K, M_com, dtype=torch.complex64); 95 | H_rad=torch.zeros(N_gen, N_C, 1, M_com, dtype=torch.complex64); 96 | cnt=0; 97 | for _ in range(N_gen//b): 98 | vcnt=cnt+b; 99 | 100 | H_com[cnt:vcnt], \ 101 | H_rad[cnt:vcnt]=genBatch(b); 102 | 103 | cnt=vcnt; 104 | 105 | if cnt!=N_gen: 106 | H_com[cnt:vcnt], \ 107 | H_rad[cnt:vcnt]=genBatch(N_gen-cnt); 108 | 109 | return H_com .cpu(), \ 110 | H_rad .cpu(), \ 111 | theta_share .cpu(), \ 112 | theta_com .cpu(), \ 113 | theta_rad .cpu(), \ 114 | r_share .cpu(), \ 115 | r_com .cpu(), \ 116 | r_rad .cpu(), \ 117 | theta_UE .cpu(), \ 118 | r_UE .cpu(); 119 | 120 | 121 | 122 | def main(): 123 | env_init(42); 124 | H_com , \ 125 | H_rad , \ 126 | theta_share , \ 127 | theta_com , \ 128 | theta_rad , \ 129 | r_share , \ 130 | r_com , \ 131 | r_rad , \ 132 | theta_UE , \ 133 | r_UE =genH( 134 | N_gen =2560 , 135 | b =256 , 136 | N_C =32 , 137 | K =4 , 138 | M_com =32 , 139 | L_share =6 , 140 | L_com =8 , 141 | L_rad =12 , 142 | R =100 , 143 | device =0 , 144 | ); 145 | 146 | 147 | 148 | if __name__=="__main__": 149 | logger("datasets.py"); 150 | main(); 151 | rmLogger(); 152 | 153 | 154 | 155 | -------------------------------------------------------------------------------- /train/main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from torch.linalg import svd, eig, inv, pinv 7 | from torch.fft import fft, ifft 8 | 9 | import numpy as np 10 | 11 | import math 12 | from math import sqrt, pi 13 | 14 | import os 15 | from copy import deepcopy 16 | 17 | from model import JCASCasterNet, ComCasterNet, RadCasterNet 18 | 19 | from datasets import genH 20 | 21 | from utils import env_init , \ 22 | calMeanPower , \ 23 | SELoss , \ 24 | RMSELoss , \ 25 | NMSELoss , \ 26 | 
CosLoss , \ 27 | theta2label 28 | 29 | from my_log import logger, rmLogger 30 | 31 | from tqdm import tqdm 32 | 33 | 34 | 35 | def main(): # Train JCASCasterNet on channels freshly drawn by genH; every `interval` optimizer steps, evaluate and save a checkpoint. 36 | env_init(42); 37 | N_C =32 ; 38 | M_com =32 ; 39 | N_RF_com =2 ; 40 | Q_bsc =2 ; 41 | Q_ehc =2 ; 42 | K =N_RF_com ; 43 | B =16 ; 44 | SNR_com =10 ; 45 | SNR_rad =10 ; 46 | d_model =256 ; 47 | device =0 ; 48 | net=JCASCasterNet( 49 | N_C =N_C , 50 | M_com =M_com , 51 | N_RF_com =N_RF_com , 52 | Q_bsc =Q_bsc , 53 | Q_ehc =Q_ehc , 54 | K =K , 55 | B =B , 56 | SNR_com =SNR_com , 57 | SNR_rad =SNR_rad , 58 | d_model =d_model , 59 | ).to(device); 60 | opt=torch.optim.AdamW( 61 | net.parameters(), 62 | lr=1e-4, 63 | ); 64 | 65 | b =256 ; 66 | L_share =6 ; 67 | L_com =3*K ; 68 | L_rad =6 ; 69 | R =100 ; 70 | SNR =10 ; # SNR argument for SELoss; the network itself is built with SNR_com / SNR_rad 71 | 72 | acc_num=8; # micro-batches per optimizer step (each micro-batch loss is divided by acc_num) 73 | interval=100; # evaluate and checkpoint every `interval` steps 74 | for i in tqdm(range(1_000_000)): 75 | opt.zero_grad(); 76 | net.train(); 77 | 78 | for _ in range(acc_num): 79 | H_com , \ 80 | H_rad , \ 81 | theta_share , \ 82 | theta_com , \ 83 | theta_rad , \ 84 | r_share , \ 85 | r_com , \ 86 | r_rad , \ 87 | theta_UE , \ 88 | r_UE =genH( 89 | N_gen =b , 90 | b =b , 91 | N_C =N_C , 92 | K =K , 93 | M_com =M_com , 94 | L_share =L_share , 95 | L_com =L_com , 96 | L_rad =L_rad , 97 | R =R , 98 | device =device , 99 | ); 100 | 101 | H_com=H_com.to(device); 102 | H_rad=H_rad.to(device); 103 | 104 | F, H_rad_est=net( # NOTE(review): this local F shadows the torch.nn.functional import 105 | H_com=H_com, 106 | H_rad=H_rad, 107 | ); 108 | 109 | loss=0; 110 | 111 | loss_SE=SELoss(H_com, F, SNR); 112 | loss=loss+loss_SE; 113 | 114 | loss_Cos=CosLoss(H_rad_est, H_rad); 115 | loss=loss+loss_Cos; 116 | 117 | loss=loss/acc_num; 118 | loss.backward(); 119 | 120 | opt.step(); 121 | 122 | if (i+1)%interval==0: 123 | net.eval(); 124 | N_eval=2; 125 | 126 | loss_SE_acc=0; 127 | loss_Cos_acc=0; 128 | for _ in range(N_eval): 129 | H_com , \ 130 | H_rad , \ 131 | theta_share , \ 132 | theta_com , \ 133 | theta_rad , \ 134 | r_share , \ 135 | r_com , \ 136 | r_rad , \ 137 | theta_UE , \ 138 | r_UE =genH( 139 | N_gen =10*b , 
140 | b =10*b , 141 | N_C =N_C , 142 | K =K , 143 | M_com =M_com , 144 | L_share =L_share , 145 | L_com =L_com , 146 | L_rad =L_rad , 147 | R =R , 148 | device =device , 149 | ); 150 | 151 | H_com=H_com.to(device); 152 | H_rad=H_rad.to(device); 153 | 154 | with torch.no_grad(): 155 | F, H_rad_est=net( 156 | H_com=H_com, 157 | H_rad=H_rad, 158 | ); 159 | 160 | loss_SE_acc+=SELoss(H_com, F, SNR).item()/N_eval; 161 | loss_Cos_acc+=CosLoss(H_rad_est, H_rad).item()/N_eval; 162 | 163 | 164 | logger("SE: {}".format(-loss_SE_acc)); # sign flipped for logging; presumably SELoss/CosLoss return negated metrics (TODO confirm in utils.py) 165 | logger("Cos: {}".format(-loss_Cos_acc)); 166 | logger("lr: {}".format(opt.param_groups[0]['lr'])); 167 | 168 | folder_path="./BaseCKPT_Cos"; # NOTE(review): a new checkpoint file (named by step) is written at every evaluation 169 | if not os.path.exists(folder_path): 170 | os.makedirs(folder_path); 171 | 172 | save_path=folder_path+"/ckpt"+str(i+1)+".pth.tar"; 173 | torch.save( 174 | {'state_dict': net.state_dict(), 175 | 'opt_dict': opt.state_dict(), 176 | }, save_path); 177 | 178 | 179 | 180 | if __name__=="__main__": 181 | logger("main.py"); 182 | main(); 183 | rmLogger(); -------------------------------------------------------------------------------- /train/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from torch.linalg import svd, eig, inv, pinv 7 | from torch.fft import fft, ifft 8 | 9 | import numpy as np 10 | 11 | import math 12 | from math import sqrt, pi 13 | 14 | from utils import calMeanPower, awgn, env_init, SELoss 15 | 16 | from datasets import genH 17 | 18 | from my_log import logger, rmLogger 19 | 20 | 21 | 22 | #==================================================================================# 23 | # utils 24 | #==================================================================================# 25 | class TransposeLayer(nn.Module): # nn.Module wrapper for x.transpose(i, j), usable inside nn.Sequential 26 | def __init__( 27 | self, 28 | i, 29 | j, 30 | ): 31 | super().__init__(); 32 | self.i=i; 33 | self.j=j; 34 | 
def forward( 36 | self, 37 | x, 38 | ): 39 | return x.transpose(self.i, self.j); 40 | 41 | 42 | 43 | class MergeLayer(nn.Module): # merge dims i and i+1 into one; i must not be the last dim (asserted in forward) 44 | def __init__( 45 | self, 46 | i, 47 | ): 48 | super().__init__(); 49 | self.i=i; 50 | 51 | def forward( 52 | self, 53 | x, 54 | ): 55 | assert self.i!=-1; 56 | assert self.i!=len(x.shape)-1; 57 | if self.i!=-2: 58 | return x.reshape(*x.shape[:self.i], -1, *x.shape[self.i+2:]); 59 | else: 60 | return x.reshape(*x.shape[:self.i], -1); 61 | 62 | 63 | 64 | class SplitLayer(nn.Module): # split dim i into (n1, -1) or (-1, n2); exactly one of n1, n2 must be given 65 | def __init__( 66 | self, 67 | i, 68 | n1=None, 69 | n2=None, 70 | ): 71 | super().__init__(); 72 | assert (n1 is not None and n2 is None) or (n1 is None and n2 is not None); 73 | self.i=i; 74 | self.n1=n1; 75 | self.n2=n2; 76 | 77 | def forward( 78 | self, 79 | x, 80 | ): 81 | if self.i!=-1: 82 | if self.n2 is None: 83 | return x.reshape(*x.shape[:self.i], self.n1, -1, *x.shape[self.i+1:]); 84 | else: 85 | return x.reshape(*x.shape[:self.i], -1, self.n2, *x.shape[self.i+1:]); 86 | else: 87 | if self.n2 is None: 88 | return x.reshape(*x.shape[:self.i], self.n1, -1); 89 | else: 90 | return x.reshape(*x.shape[:self.i], -1, self.n2); 91 | 92 | 93 | 94 | class PermuteLayer(nn.Module): # nn.Module wrapper for x.permute(permute_order) 95 | def __init__( 96 | self, 97 | permute_order, 98 | ): 99 | super().__init__(); 100 | self.permute_order=permute_order; 101 | 102 | def forward( 103 | self, 104 | x, 105 | ): 106 | return x.permute(self.permute_order); 107 | 108 | 109 | 110 | class ReshapeLayer(nn.Module): # nn.Module wrapper for x.reshape(shape) 111 | def __init__( 112 | self, 113 | shape, 114 | ): 115 | super().__init__(); 116 | self.shape=shape; 117 | 118 | def forward( 119 | self, 120 | x, 121 | ): 122 | return x.reshape(self.shape); 123 | 124 | 125 | 126 | class RepeatLayer(nn.Module): # nn.Module wrapper for x.repeat(times), i.e. tiling 127 | def __init__( 128 | self, 129 | times, 130 | ): 131 | super().__init__(); 132 | self.times=times; 133 | 134 | def forward( 135 | self, 136 | x, 137 | ): 138 | return x.repeat(self.times); 139 | 140 | 141 | 142 | class Complex2realLayer(nn.Module): # concatenate real and imaginary parts along dim i (doubles that dim) 143 |
def __init__( 144 | self, 145 | i, 146 | ): 147 | super().__init__(); 148 | self.i=i; 149 | 150 | def forward( 151 | self, 152 | x, 153 | ): 154 | return torch.cat((x.real, x.imag), dim=self.i); 155 | 156 | 157 | 158 | class Real2complexLayer(nn.Module): # combine slices 0 and 1 of dim i as real + 1j*imag and drop dim i; expects dim i of size 2 (any further slices are ignored), so it does not invert Complex2realLayer's torch.cat 159 | def __init__( 160 | self, 161 | i, 162 | ): 163 | super().__init__(); 164 | self.i=i; 165 | 166 | def forward( 167 | self, 168 | x, 169 | ): 170 | x=x.transpose(0, self.i); 171 | x=x[0]+1j*x[1]; 172 | x=x.unsqueeze(0); 173 | x=x.transpose(0, self.i); 174 | if self.i==-1 or self.i==len(x.shape)-1: 175 | i_stable=self.i-1; 176 | else: 177 | i_stable=self.i; 178 | 179 | x=MergeLayer(i_stable)(x); 180 | return x; 181 | 182 | 183 | 184 | class NormLayer(nn.Module): # LayerNorm(d) over dim c_index (moved to the last dim and back) 185 | def __init__( 186 | self, 187 | d, 188 | c_index=-1, 189 | ): 190 | super().__init__(); 191 | self.norm=nn.Sequential( 192 | TransposeLayer(c_index, -1), 193 | nn.LayerNorm(d), 194 | TransposeLayer(c_index, -1), 195 | ); 196 | 197 | def forward( 198 | self, 199 | x, 200 | ): 201 | return self.norm(x); 202 | 203 | 204 | 205 | #==================================================================================# 206 | # model 207 | #==================================================================================# 208 | feedback_bits=2; # bits per quantized feedback value 209 | 210 | def calF_RF( # elementwise unit-modulus exp(1j*Theta), i.e. a phase-only matrix 211 | Theta, # *, 1, M, N_RF 212 | ): 213 | F_RF=torch.exp(1j*Theta); 214 | 215 | return F_RF; 216 | 217 | 218 | 219 | def normF_BB( # rescale F_BB so that the total power of F_RF@F_BB over the last three dims equals N_C*N_RF 220 | F_RF, # *, 1, M, N_RF 221 | F_BB, # *, N_C, N_RF, _ 222 | ): 223 | N_C, N_RF=F_BB.shape[-3:-1]; 224 | 225 | F=F_RF@F_BB; # *, N_C, M, _ 226 | power=(F.abs()**2).sum([-3, -2, -1], keepdim=True); # *, 1, 1, 1 227 | F_BB=F_BB/power.sqrt()*sqrt(N_C*N_RF); 228 | 229 | return F_BB; 230 | 231 | 232 | 233 | def calPilot( # build pilots from phases and baseband weights; expects exactly 5-D inputs, as the shape unpacking and transpose(1, 4) require 234 | pilot_Theta, # *, Q, 1 , M , N_RF 235 | pilot_BB , # *, Q, N_C, N_RF, 1 236 | ): 237 | _, Q, _ , M , N_RF=pilot_Theta.shape; 238 | _, Q, N_C, N_RF, _ =pilot_BB .shape; 239 | 240 | pilot_RF=calF_RF(pilot_Theta); # *, Q, 1, M, N_RF 241 |
pilot_BB=normF_BB(pilot_RF, pilot_BB); # *, Q, N_C, N_RF, 1 242 | pilot=pilot_RF@pilot_BB; # *, Q, N_C, M, 1 243 | pilot=pilot.transpose(1, 4).reshape(-1, N_C, M, Q); # *, N_C, M, Q 244 | 245 | return pilot; 246 | 247 | 248 | 249 | class PositionalEncoding(nn.Module): # sinusoidal encoding added along dim 0: expects x as (seq_len, batch, d_model) with seq_len <= max_len and an even d_model 250 | def __init__( 251 | self, 252 | d_model, 253 | dropout, 254 | max_len, 255 | ): 256 | super().__init__() 257 | position=torch.arange(max_len).unsqueeze(1); 258 | div_term=torch.exp(torch.arange(0, d_model, 2)*(-math.log(10000.0)/d_model)); 259 | pe=torch.zeros(max_len, 1, d_model); 260 | pe[:, 0, 0::2]=torch.sin(position*div_term); 261 | pe[:, 0, 1::2]=torch.cos(position*div_term); 262 | self.register_buffer('pe', pe); 263 | self.dropout=nn.Dropout(p=dropout); 264 | 265 | def forward( 266 | self, 267 | x, 268 | ): 269 | x=x+self.pe[:x.size(0)]; 270 | return self.dropout(x); 271 | 272 | 273 | 274 | #======================================================================================================================= 275 | #======================================================================================================================= 276 | # Number to Bit Function Defining 277 | def Num2Bit(Num, B): # (batch, n) integers in [0, 2**B) -> (batch, n*B) float32 bits, MSB first; the uint8 cast requires B <= 8 278 | Num_ = Num.type(torch.uint8) 279 | def integer2bit(integer, num_bits=B): 280 | dtype = integer.type() 281 | device = integer.device; 282 | exponent_bits = -torch.arange(1-B, 1).type(dtype).to(device); 283 | exponent_bits = exponent_bits.repeat(integer.shape + (1,)) 284 | out = integer.unsqueeze(-1) // 2 ** exponent_bits 285 | return out% 2 286 | bit = integer2bit(Num_) 287 | bit = bit.reshape(-1, Num_.shape[1] * B) 288 | return bit.type(torch.float32) 289 | def Bit2Num(Bit, B): # inverse of Num2Bit: (batch, n*B) MSB-first bits -> (batch, n) float32 integers 290 | device=Bit.device; 291 | Bit_ = Bit.type(torch.float32) 292 | Bit_ = torch.reshape(Bit_, [-1, int(Bit_.shape[1] / B), B]) 293 | num = torch.zeros(Bit_[:, :, 0].shape).to(device) 294 | for i in range(B): 295 | num = num + Bit_[:, :, i] * 2 ** (B - 1 - i) 296 | return num 297 |
#======================================================================================================================= 298 | #======================================================================================================================= 299 | # Quantization and Dequantization Layers Defining 300 | class Quantization(torch.autograd.Function): # (batch, n) values -> (batch, n*B) bits via Num2Bit 301 | @staticmethod 302 | def forward(ctx, x, B): 303 | ctx.constant = B 304 | step = 2**B-1e-6; # x presumably in [0, 1] (the encoder feeds a Sigmoid output); the -1e-6 keeps x=1 at level 2**B-1 instead of rounding up to 2**B 305 | out = torch.round(x*step-0.5); 306 | out = Num2Bit(out, B) 307 | return out.type(x.type()) 308 | @staticmethod 309 | def backward(ctx, grad_output): 310 | # return as many input gradients as there were arguments. 311 | # Gradients of constant arguments to forward must be None. 312 | # Gradient of a number is the mean (sum / B) of the gradients of its B bits. 313 | b, _ = grad_output.shape 314 | grad_num = torch.sum(grad_output.reshape(b, -1, ctx.constant), dim=2) / ctx.constant 315 | return grad_num, None 316 | class Dequantization(torch.autograd.Function): # (batch, n*B) bits -> (batch, n) values in (0, 1) 317 | @staticmethod 318 | def forward(ctx, x, B): 319 | ctx.constant = B 320 | step = 2**B; 321 | out = Bit2Num(x, B) 322 | out = (out+0.5)/step; # level n -> bin centre (n+0.5)/2**B 323 | return out.type(x.type()) 324 | @staticmethod 325 | def backward(ctx, grad_output): 326 | # return as many input gradients as there were arguments. 327 | # Gradients of non-Tensor arguments to forward must be None. 328 | # give each of the B bits its number's gradient divided by B.
329 | b, c = grad_output.shape 330 | grad_output = grad_output.unsqueeze(2) / ctx.constant 331 | grad_bit = grad_output.expand(b, c, ctx.constant) 332 | return torch.reshape(grad_bit, (-1, c * ctx.constant)), None 333 | class QuantizationLayer(nn.Module): # nn.Module wrapper around Quantization.apply(x, B) 334 | def __init__(self, B): 335 | super(QuantizationLayer, self).__init__() 336 | self.B = B 337 | def forward(self, x): 338 | out = Quantization.apply(x, self.B) 339 | return out 340 | class DequantizationLayer(nn.Module): # nn.Module wrapper around Dequantization.apply(x, B) 341 | def __init__(self, B): 342 | super(DequantizationLayer, self).__init__() 343 | self.B = B 344 | def forward(self, x): 345 | out = Dequantization.apply(x, self.B) 346 | return out 347 | 348 | 349 | 350 | class SemanticEncoder(nn.Module): # per-user measurements (b, N_C, K, Q) -> quantized feedback bits (b, K, B) 351 | def __init__( 352 | self , 353 | N_C , 354 | M_com , 355 | N_RF_com , 356 | Q_bsc , 357 | Q_ehc , 358 | K , 359 | B , 360 | SNR_com , 361 | SNR_rad , 362 | d_model , 363 | ): 364 | super().__init__(); 365 | 366 | Q=Q_bsc+Q_ehc; 367 | # hyperparameters are registered as buffers, so they are stored in state_dict 368 | self.register_buffer("N_C" , torch.tensor(N_C )); 369 | self.register_buffer("M_com" , torch.tensor(M_com )); 370 | self.register_buffer("N_RF_com" , torch.tensor(N_RF_com )); 371 | self.register_buffer("Q_bsc" , torch.tensor(Q_bsc )); 372 | self.register_buffer("Q_ehc" , torch.tensor(Q_ehc )); 373 | self.register_buffer("K" , torch.tensor(K )); 374 | self.register_buffer("B" , torch.tensor(B )); 375 | self.register_buffer("SNR_com" , torch.tensor(SNR_com )); 376 | self.register_buffer("SNR_rad" , torch.tensor(SNR_rad )); 377 | self.register_buffer("d_model" , torch.tensor(d_model )); 378 | 379 | d_model=d_model//K; # per-user width; TransformerEncoderLayer(nhead=8) needs d_model//K divisible by 8 380 | 381 | self.embed=nn.Sequential( # b, N_C, K, Q 382 | TransposeLayer(1, 2), # b, K, N_C, Q 383 | ReshapeLayer([-1, N_C, Q]), # b*K, N_C, Q 384 | Complex2realLayer(-1), # b*K, N_C, 2*Q 385 | nn.Linear(2*Q, d_model), # b*K, N_C, d_model 386 | ); # b*K, N_C, d_model 387 | self.backbone=nn.Sequential( # b*K, N_C, d_model 388 | TransposeLayer(0, 1), # N_C, b*K, d_model 389 | PositionalEncoding(d_model, 
0., N_C), 390 | nn.TransformerEncoder( 391 | encoder_layer=nn.TransformerEncoderLayer( 392 | d_model=d_model, 393 | nhead=8, 394 | dim_feedforward=4*d_model, 395 | dropout=0., 396 | activation="gelu", 397 | norm_first=False, 398 | ), 399 | num_layers=6, 400 | ), 401 | TransposeLayer(0, 1), # b*K, N_C, d_model 402 | ); # b*K, N_C, d_model 403 | self.head=nn.Sequential( # b*K, N_C, d_model 404 | ReshapeLayer([-1, N_C*d_model]), # b*K, N_C*d_model 405 | nn.Linear(N_C*d_model, B//feedback_bits), # b*K, B//feedback_bits 406 | nn.BatchNorm1d(B//feedback_bits), 407 | nn.Sigmoid(), 408 | QuantizationLayer(feedback_bits), 409 | ReshapeLayer([-1, K, B]), # b, K, B; needs B % feedback_bits == 0, since Quantization emits (B//feedback_bits)*feedback_bits bits per user 410 | ); # b, K, B 411 | 412 | def forward( 413 | self , 414 | measurement , # b, N_C, K, Q 415 | ): 416 | x=self.embed(measurement); # b*K, N_C, d_model 417 | x=self.backbone(x); 418 | feedback=self.head(x); # b, K, B 419 | 420 | return feedback; 421 | 422 | 423 | 424 | class SemanticDecoder(nn.Module): 425 | def __init__( 426 | self , 427 | N_C , 428 | M_com , 429 | N_RF_com , 430 | Q_bsc , 431 | Q_ehc , 432 | K , 433 | B , 434 | SNR_com , 435 | SNR_rad , 436 | d_model , 437 | ): 438 | super().__init__(); 439 | 440 | Q=Q_bsc+Q_ehc; 441 | 442 | self.register_buffer("N_C" , torch.tensor(N_C )); 443 | self.register_buffer("M_com" , torch.tensor(M_com )); 444 | self.register_buffer("N_RF_com" , torch.tensor(N_RF_com )); 445 | self.register_buffer("Q_bsc" , torch.tensor(Q_bsc )); 446 | self.register_buffer("Q_ehc" , torch.tensor(Q_ehc )); 447 | self.register_buffer("K" , torch.tensor(K )); 448 | self.register_buffer("B" , torch.tensor(B )); 449 | self.register_buffer("SNR_com" , torch.tensor(SNR_com )); 450 | self.register_buffer("SNR_rad" , torch.tensor(SNR_rad )); 451 | self.register_buffer("d_model" , torch.tensor(d_model )); 452 | 453 | d_model=d_model//K; 454 | 455 | self.embed_pilot_Theta=nn.Sequential( # b, N_C, M_com, Q 456 | Complex2realLayer(-1), # b, N_C, M_com, 2*Q 457 | TransposeLayer(1, 2), # b, M_com, 
N_C, 2*Q 458 | ReshapeLayer([-1, M_com, N_C*2*Q]), # b, M_com, N_C*2*Q 459 | nn.Linear(N_C*2*Q, d_model), # b, M_com, d_model 460 | ReshapeLayer([-1, 1, M_com, d_model]), # b, 1, M_com, d_model 461 | RepeatLayer([1, K, 1, 1]), # b, K, M_com, d_model 462 | ReshapeLayer([-1, M_com, d_model]), # b*K, M_com, d_model 463 | ); # b*K, M_com, d_model 464 | self.embed_echo_Theta=nn.Sequential( # b, N_C, 1, Q 465 | TransposeLayer(1, 3), # b, Q, 1, N_C 466 | Complex2realLayer(-1), # b, Q, 1, 2*N_C 467 | ReshapeLayer([-1, Q, 1*2*N_C]), # b, Q, 1*2*N_C 468 | nn.Linear(1*2*N_C, d_model), # b, Q, d_model 469 | nn.Conv1d(Q, M_com, 1), # b, M_com, d_model 470 | ReshapeLayer([-1, 1, M_com, d_model]), # b, 1, M_com, d_model 471 | RepeatLayer([1, K, 1, 1]), # b, K, M_com, d_model 472 | ReshapeLayer([-1, M_com, d_model]), # b*K, M_com, d_model 473 | ); # b*K, M_com, d_model 474 | self.embed_feedback_Theta=nn.Sequential( # b, K, B 475 | ReshapeLayer([-1, B]), # b*K, B 476 | DequantizationLayer(feedback_bits), 477 | nn.Linear(B//feedback_bits, M_com*d_model), # b*K, M_com*d_model 478 | ReshapeLayer([-1, M_com, d_model]), # b*K, M_com, d_model 479 | ); # b*K, M_com, d_model 480 | self.backbone_Theta=nn.Sequential( # b*K, M_com, d_model 481 | TransposeLayer(0, 1), # M_com, b*K, d_model 482 | PositionalEncoding(d_model, 0., M_com), 483 | nn.TransformerEncoder( 484 | encoder_layer=nn.TransformerEncoderLayer( 485 | d_model=d_model, 486 | nhead=8, 487 | dim_feedforward=4*d_model, 488 | dropout=0., 489 | activation="gelu", 490 | norm_first=False, 491 | ), 492 | num_layers=6, 493 | ), 494 | TransposeLayer(0, 1), # b*K, M_com, d_model 495 | ); # b*K, M_com, d_model 496 | self.F_Theta_head=nn.Sequential( # b*K, M_com, d_model 497 | nn.Linear(d_model, 1), # b*K, M_com, 1 498 | ReshapeLayer([-1, K, M_com, 1]), # b, K, M_com, 1 499 | TransposeLayer(1, 3), # b, 1, M_com, K 500 | ); # b, 1, M_com, N_RF_com (K==N_RF_com) 501 | 502 | d_model=d_model*K; 503 | 504 | self.embed_pilot_BB=nn.Sequential( # b,
N_C, M_com, Q 505 | Complex2realLayer(-1), # b, N_C, M_com, 2*Q 506 | ReshapeLayer([-1, N_C, M_com*2*Q]), # b, N_C, M_com*2*Q 507 | nn.Linear(M_com*2*Q, d_model), # b, N_C, d_model 508 | ); # b, N_C, d_model 509 | self.embed_echo_BB=nn.Sequential( # b, N_C, 1, Q 510 | Complex2realLayer(-1), # b, N_C, 1, 2*Q 511 | ReshapeLayer([-1, N_C, 1*2*Q]), # b, N_C, 1*2*Q 512 | nn.Linear(1*2*Q, d_model), # b, N_C, d_model 513 | ); # b, N_C, d_model 514 | self.embed_feedback_BB=nn.Sequential( # b, K, B 515 | ReshapeLayer([-1, B]), # b*K, B 516 | DequantizationLayer(feedback_bits), 517 | ReshapeLayer([-1, K*B//feedback_bits]), # b, K*B//feedback_bits 518 | nn.Linear(K*B//feedback_bits, N_C*d_model), # b, N_C*d_model 519 | ReshapeLayer([-1, N_C, d_model]), # b, N_C, d_model 520 | ); # b, N_C, d_model 521 | self.embed_F_RF=nn.Sequential( # b, 1, M_com, N_RF_com 522 | Complex2realLayer(-1), # b, 1, M_com, 2*N_RF_com 523 | ReshapeLayer([-1, 1, M_com*2*N_RF_com]), # b, 1, M_com*2*N_RF_com 524 | nn.Linear(M_com*2*N_RF_com, d_model), # b, 1, d_model 525 | ); # b, 1, d_model 526 | self.backbone_BB=nn.Sequential( # b, N_C+1, d_model 527 | TransposeLayer(0, 1), # N_C+1, b, d_model 528 | PositionalEncoding(d_model, 0., N_C+1), 529 | nn.TransformerEncoder( 530 | encoder_layer=nn.TransformerEncoderLayer( 531 | d_model=d_model, 532 | nhead=8, 533 | dim_feedforward=4*d_model, 534 | dropout=0., 535 | activation="gelu", 536 | norm_first=False, 537 | ), 538 | num_layers=6, 539 | ), 540 | TransposeLayer(0, 1), # b, N_C+1, d_model 541 | ); # b, N_C+1, d_model 542 | self.F_BB_head=nn.Sequential( # b, N_C, d_model 543 | nn.Linear(d_model, N_RF_com*K*2), # b, N_C, N_RF_com*K*2 544 | ReshapeLayer([-1, N_C, N_RF_com, K, 2]), # b, N_C, N_RF_com, K, 2 545 | Real2complexLayer(-1), # b, N_C, N_RF_com, K 546 | ); # b, N_C, N_RF_com, K 547 | 548 | def forward( 549 | self , 550 | pilot , # b, N_C, M_com, Q 551 | echo , # b, N_C, 1 , Q 552 | feedback , # b, K, B 553 | ): 554 | b, N_C, M_com, Q=pilot.shape; 555 
| 556 | z_pilot_Theta=self.embed_pilot_Theta(pilot.detach()); # b*K, M_com, d_model 557 | z_echo_Theta=self.embed_echo_Theta(echo.detach()); # b*K, M_com, d_model 558 | z_feedback_Theta=self.embed_feedback_Theta(feedback); # b*K, M_com, d_model 559 | 560 | z_Theta=z_pilot_Theta+z_echo_Theta+z_feedback_Theta; # b*K, M_com, d_model 561 | z_Theta=self.backbone_Theta(z_Theta); 562 | F_RF=calF_RF(self.F_Theta_head(z_Theta[:, 0:M_com])); # b, 1, M_com, N_RF_com 563 | 564 | z_pilot_BB=self.embed_pilot_BB(pilot); # b, N_C, d_model 565 | z_echo_BB=self.embed_echo_BB(echo); # b, N_C, d_model 566 | z_feedback_BB=self.embed_feedback_BB(feedback); # b, N_C, d_model 567 | z_F_RF=self.embed_F_RF(F_RF); # b, 1, d_model 568 | 569 | z_BB=torch.cat(( 570 | z_pilot_BB+z_echo_BB+z_feedback_BB, 571 | z_F_RF, 572 | ), dim=1); # b, N_C+1, d_model 573 | z_BB=self.backbone_BB(z_BB); 574 | F_BB=normF_BB(F_RF, self.F_BB_head(z_BB[:, 0:N_C])); # b, N_C, N_RF_com, K 575 | 576 | F=F_RF@F_BB; # b, N_C, M_com, K 577 | 578 | return F; 579 | 580 | 581 | 582 | class ComCasterNet(nn.Module): 583 | def __init__( 584 | self , 585 | N_C , 586 | M_com , 587 | N_RF_com , 588 | Q_bsc , 589 | Q_ehc , 590 | K , 591 | B , 592 | SNR_com , 593 | SNR_rad , 594 | d_model , 595 | ): 596 | super().__init__(); 597 | 598 | Q=Q_bsc+Q_ehc; 599 | 600 | self.register_buffer("N_C" , torch.tensor(N_C )); 601 | self.register_buffer("M_com" , torch.tensor(M_com )); 602 | self.register_buffer("N_RF_com" , torch.tensor(N_RF_com )); 603 | self.register_buffer("Q_bsc" , torch.tensor(Q_bsc )); 604 | self.register_buffer("Q_ehc" , torch.tensor(Q_ehc )); 605 | self.register_buffer("K" , torch.tensor(K )); 606 | self.register_buffer("B" , torch.tensor(B )); 607 | self.register_buffer("SNR_com" , torch.tensor(SNR_com )); 608 | self.register_buffer("SNR_rad" , torch.tensor(SNR_rad )); 609 | self.register_buffer("d_model" , torch.tensor(d_model )); 610 | 611 | self.register_buffer("echo", torch.zeros(1, N_C, 1, Q).to(torch.complex64)); 
612 | 613 | self.pilot_Theta=nn.Parameter(torch.rand (1, Q, 1 , M_com , N_RF_com)*2*pi); 614 | self.pilot_BB =nn.Parameter(torch.randn(1, Q, N_C, N_RF_com, 1 , dtype=torch.complex64)); 615 | 616 | self.SEnc=SemanticEncoder( 617 | N_C =N_C , 618 | M_com =M_com , 619 | N_RF_com =N_RF_com , 620 | Q_bsc =Q_bsc , 621 | Q_ehc =Q_ehc , 622 | K =K , 623 | B =B , 624 | SNR_com =SNR_com , 625 | SNR_rad =SNR_rad , 626 | d_model =d_model , 627 | ); 628 | 629 | self.SDec=SemanticDecoder( 630 | N_C =N_C , 631 | M_com =M_com , 632 | N_RF_com =N_RF_com , 633 | Q_bsc =Q_bsc , 634 | Q_ehc =Q_ehc , 635 | K =K , 636 | B =B , 637 | SNR_com =SNR_com , 638 | SNR_rad =SNR_rad , 639 | d_model =d_model , 640 | ); 641 | 642 | def forward( 643 | self , 644 | H_com , # b, N_C, K, M_com 645 | ): 646 | b=H_com.shape[0]; 647 | 648 | SNR_com=self.SNR_com; 649 | 650 | pilot=calPilot( 651 | pilot_Theta=self.pilot_Theta, 652 | pilot_BB =self.pilot_BB , 653 | ); # 1, N_C, M_com, Q 654 | 655 | measurement=H_com@pilot; # b, N_C, K, Q 656 | measurement=awgn(measurement, 1, pilot, SNR_com); 657 | 658 | feedback=self.SEnc(measurement); # b, K, B 659 | F=self.SDec( 660 | pilot.repeat(b, 1, 1, 1), # b, N_C, M_com, Q 661 | self.echo.repeat(b, 1, 1, 1), # b, N_C, 1, Q 662 | feedback, 663 | ); # b, N_C, M_com, N_RF_com 664 | 665 | return F; 666 | 667 | 668 | 669 | class EnhancedPilotDesignNet(nn.Module): 670 | def __init__( 671 | self , 672 | N_C , 673 | M_com , 674 | N_RF_com , 675 | Q_bsc , 676 | Q_ehc , 677 | K , 678 | B , 679 | SNR_com , 680 | SNR_rad , 681 | d_model , 682 | ): 683 | super().__init__(); 684 | 685 | self.register_buffer("N_C" , torch.tensor(N_C )); 686 | self.register_buffer("M_com" , torch.tensor(M_com )); 687 | self.register_buffer("N_RF_com" , torch.tensor(N_RF_com )); 688 | self.register_buffer("Q_bsc" , torch.tensor(Q_bsc )); 689 | self.register_buffer("Q_ehc" , torch.tensor(Q_ehc )); 690 | self.register_buffer("K" , torch.tensor(K )); 691 | self.register_buffer("B" , torch.tensor(B )); 
692 | self.register_buffer("SNR_com" , torch.tensor(SNR_com )); 693 | self.register_buffer("SNR_rad" , torch.tensor(SNR_rad )); 694 | self.register_buffer("d_model" , torch.tensor(d_model )); 695 | 696 | self.embed_pilot=nn.Sequential( # b, N_C, M_com, Q_bsc 697 | Complex2realLayer(-1), # b, N_C, M_com, 2*Q_bsc 698 | ReshapeLayer([-1, N_C, M_com*2*Q_bsc]), # b, N_C, M_com*2*Q_bsc 699 | nn.Linear(M_com*2*Q_bsc, d_model), # b, N_C, d_model 700 | ); # b, N_C, d_model 701 | self.embed_echo=nn.Sequential( # b, N_C, 1, Q_bsc 702 | Complex2realLayer(-1), # b, N_C, 1, 2*Q_bsc 703 | ReshapeLayer([-1, N_C, 1*2*Q_bsc]), # b, N_C, 1*2*Q_bsc 704 | nn.Linear(1*2*Q_bsc, d_model), # b, N_C, d_model 705 | ); # b, N_C, d_model 706 | self.backbone=nn.Sequential( # b, N_C, d_model 707 | TransposeLayer(0, 1), # N_C, b, d_model 708 | PositionalEncoding(d_model, 0., N_C), 709 | nn.TransformerEncoder( 710 | encoder_layer=nn.TransformerEncoderLayer( 711 | d_model=d_model, 712 | nhead=8, 713 | dim_feedforward=4*d_model, 714 | dropout=0., 715 | activation="gelu", 716 | norm_first=False, 717 | ), 718 | num_layers=6, 719 | ), 720 | TransposeLayer(0, 1), # b, N_C, d_model 721 | ); # b, N_C, d_model 722 | self.p_Theta_head=nn.Sequential( # b, N_C, d_model 723 | ReshapeLayer([-1, N_C*d_model]), # b, N_C*d_model 724 | nn.Linear(N_C*d_model, Q_ehc*1*M_com*N_RF_com), # b, Q_ehc*1*M_com*N_RF_com 725 | ReshapeLayer([-1, Q_ehc, 1, M_com, N_RF_com]), # b, Q_ehc, 1, M_com, N_RF_com 726 | ); # b, Q_ehc, 1, M_com, N_RF_com 727 | self.p_BB_head=nn.Sequential( # b, N_C, d_model 728 | ReshapeLayer([-1, N_C*d_model]), # b, N_C*d_model 729 | nn.Linear(N_C*d_model, Q_ehc*N_C*N_RF_com*1*2), # b, Q_ehc*N_C*N_RF_com*1*2 730 | ReshapeLayer([-1, Q_ehc, N_C, N_RF_com, 1, 2]), # b, Q_ehc, N_C, N_RF_com, 1, 2 731 | Real2complexLayer(-1), # b, Q_ehc, N_C, N_RF_com, 1 732 | ); # b, Q_ehc, N_C, N_RF_com, 1 733 | 734 | def forward( 735 | self , 736 | bsc_pilot , # 1, N_C, M_com, Q_bsc 737 | bsc_echo , # b, N_C, 1 , Q_bsc 
738 | ): 739 | z_pilot=self.embed_pilot(bsc_pilot.detach()); # b, N_C, d_model 740 | z_echo=self.embed_echo(bsc_echo); # b, N_C, d_model 741 | 742 | z=z_pilot+z_echo; 743 | z=self.backbone(z); 744 | ehc_pilot=calPilot( 745 | pilot_Theta=self.p_Theta_head(z), # b, Q_ehc, 1 , M_com , N_RF_com 746 | pilot_BB =self.p_BB_head (z), # b, Q_ehc, N_C, N_RF_com, 1 747 | ); # b, N_C, M_com, Q_ehc 748 | 749 | return ehc_pilot; 750 | 751 | 752 | 753 | class ChannelSemanticReconstructionNet(nn.Module): 754 | def __init__( 755 | self , 756 | N_C , 757 | M_com , 758 | N_RF_com , 759 | Q_bsc , 760 | Q_ehc , 761 | K , 762 | B , 763 | SNR_com , 764 | SNR_rad , 765 | d_model , 766 | ): 767 | super().__init__(); 768 | 769 | Q=Q_bsc+Q_ehc; 770 | 771 | self.register_buffer("N_C" , torch.tensor(N_C )); 772 | self.register_buffer("M_com" , torch.tensor(M_com )); 773 | self.register_buffer("N_RF_com" , torch.tensor(N_RF_com )); 774 | self.register_buffer("Q_bsc" , torch.tensor(Q_bsc )); 775 | self.register_buffer("Q_ehc" , torch.tensor(Q_ehc )); 776 | self.register_buffer("K" , torch.tensor(K )); 777 | self.register_buffer("B" , torch.tensor(B )); 778 | self.register_buffer("SNR_com" , torch.tensor(SNR_com )); 779 | self.register_buffer("SNR_rad" , torch.tensor(SNR_rad )); 780 | self.register_buffer("d_model" , torch.tensor(d_model )); 781 | 782 | self.embed_pilot=nn.Sequential( # b, N_C, M_com, Q 783 | Complex2realLayer(-1), # b, N_C, M_com, 2*Q 784 | ReshapeLayer([-1, N_C, M_com*2*Q]), # b, N_C, M_com*2*Q 785 | nn.Linear(M_com*2*Q, d_model), # b, N_C, d_model 786 | ); # b, N_C, d_model 787 | self.embed_echo=nn.Sequential( # b, N_C, 1, Q 788 | Complex2realLayer(-1), # b, N_C, 1, 2*Q 789 | ReshapeLayer([-1, N_C, 1*2*Q]), # b, N_C, 1*2*Q 790 | nn.Linear(1*2*Q, d_model), # b, N_C, d_model 791 | ); # b, N_C, d_model 792 | self.embed_feedback=nn.Sequential( # b, K, B 793 | ReshapeLayer([-1, B]), # b*K, B 794 | DequantizationLayer(feedback_bits), 795 | ReshapeLayer([-1, K*B//feedback_bits]), # 
b, K*B//feedback_bits 796 | nn.Linear(K*B//feedback_bits, N_C*d_model), # b, N_C*d_model 797 | ReshapeLayer([-1, N_C, d_model]), # b, N_C, d_model 798 | ); # b, N_C, d_model 799 | self.backbone=nn.Sequential( # b, N_C, d_model 800 | TransposeLayer(0, 1), # N_C, b, d_model 801 | PositionalEncoding(d_model, 0., N_C), 802 | nn.TransformerEncoder( 803 | encoder_layer=nn.TransformerEncoderLayer( 804 | d_model=d_model, 805 | nhead=8, 806 | dim_feedforward=4*d_model, 807 | dropout=0., 808 | activation="gelu", 809 | norm_first=False, 810 | ), 811 | num_layers=6, 812 | ), 813 | TransposeLayer(0, 1), # b, N_C, d_model 814 | ); # b, N_C, d_model 815 | self.head=nn.Sequential( # b, N_C, d_model 816 | nn.Linear(d_model, 1*M_com*2), # b, N_C, 1*M_com*2 817 | ReshapeLayer([-1, N_C, 1, M_com, 2]), # b, N_C, 1, M_com, 2 818 | Real2complexLayer(-1), # b, N_C, 1, M_com 819 | ); # b, N_C, 1, M_com 820 | 821 | def forward( 822 | self , 823 | pilot , # b, N_C, M_com, Q 824 | echo , # b, N_C, 1 , Q 825 | feedback , # b, K, B 826 | ): 827 | b, N_C, M_com, Q=pilot.shape; 828 | 829 | z_pilot=self.embed_pilot(pilot.detach()); # b, N_C, d_model 830 | z_echo=self.embed_echo(echo); # b, N_C, d_model 831 | z_feedback=self.embed_feedback(feedback.detach()); # b, N_C, d_model 832 | 833 | z=z_pilot+z_echo+z_feedback; # b, N_C+1, d_model 834 | z=self.backbone(z[:, 0:N_C]); 835 | H_rad_est=self.head(z); # b, N_C, 1, M_com 836 | 837 | return H_rad_est; 838 | 839 | 840 | 841 | class RadCasterNet(nn.Module): 842 | def __init__( 843 | self , 844 | N_C , 845 | M_com , 846 | N_RF_com , 847 | Q_bsc , 848 | Q_ehc , 849 | K , 850 | B , 851 | SNR_com , 852 | SNR_rad , 853 | d_model , 854 | ): 855 | super().__init__(); 856 | 857 | self.register_buffer("N_C" , torch.tensor(N_C )); 858 | self.register_buffer("M_com" , torch.tensor(M_com )); 859 | self.register_buffer("N_RF_com" , torch.tensor(N_RF_com )); 860 | self.register_buffer("Q_bsc" , torch.tensor(Q_bsc )); 861 | self.register_buffer("Q_ehc" , 
torch.tensor(Q_ehc )); 862 | self.register_buffer("K" , torch.tensor(K )); 863 | self.register_buffer("B" , torch.tensor(B )); 864 | self.register_buffer("SNR_com" , torch.tensor(SNR_com )); 865 | self.register_buffer("SNR_rad" , torch.tensor(SNR_rad )); 866 | self.register_buffer("d_model" , torch.tensor(d_model )); 867 | 868 | self.register_buffer("feedback", torch.zeros(1, K, B).float()); 869 | 870 | self.bsc_pilot_Theta=nn.Parameter(torch.rand (1, Q_bsc, 1 , M_com , N_RF_com)*2*pi); 871 | self.bsc_pilot_BB =nn.Parameter(torch.randn(1, Q_bsc, N_C, N_RF_com, 1 , dtype=torch.complex64)); 872 | 873 | self.EPDN=EnhancedPilotDesignNet( 874 | N_C =N_C , 875 | M_com =M_com , 876 | N_RF_com =N_RF_com , 877 | Q_bsc =Q_bsc , 878 | Q_ehc =Q_ehc , 879 | K =K , 880 | B =B , 881 | SNR_com =SNR_com , 882 | SNR_rad =SNR_rad , 883 | d_model =d_model , 884 | ); 885 | 886 | self.CSRN=ChannelSemanticReconstructionNet( 887 | N_C =N_C , 888 | M_com =M_com , 889 | N_RF_com =N_RF_com , 890 | Q_bsc =Q_bsc , 891 | Q_ehc =Q_ehc , 892 | K =K , 893 | B =B , 894 | SNR_com =SNR_com , 895 | SNR_rad =SNR_rad , 896 | d_model =d_model , 897 | ); 898 | 899 | def bsc_pilot( 900 | self , 901 | ): 902 | return calPilot( 903 | pilot_Theta=self.bsc_pilot_Theta, 904 | pilot_BB =self.bsc_pilot_BB , 905 | ); # 1, N_C, M_com, Q_bsc 906 | 907 | def forward( 908 | self , 909 | H_rad , # b, N_C, 1, M_com 910 | ): 911 | b=H_rad.shape[0]; 912 | 913 | SNR_rad=self.SNR_rad; 914 | 915 | # stage 1 916 | bsc_pilot=self.bsc_pilot(); # 1, N_C, M_com, Q_bsc 917 | 918 | bsc_echo=H_rad@bsc_pilot; # b, N_C, 1, Q_bsc 919 | bsc_echo=awgn(bsc_echo, 1, bsc_pilot, SNR_rad); 920 | 921 | # stage 2 922 | ehc_pilot=self.EPDN(bsc_pilot, bsc_echo); # b, N_C, M_com, Q_ehc 923 | 924 | ehc_echo=H_rad@ehc_pilot; # b, N_C, 1, Q_ehc 925 | ehc_echo=awgn(ehc_echo, 1, ehc_pilot, SNR_rad); 926 | 927 | # Cat 928 | pilot=torch.cat((bsc_pilot.repeat(b, 1, 1, 1), ehc_pilot), dim=-1); # b, N_C, M_com, Q 929 | echo=torch.cat((bsc_echo, ehc_echo), 
dim=-1); # b, N_C, N_RF_rad, Q 930 | 931 | H_rad_est=self.CSRN(pilot, echo, self.feedback.repeat(b, 1, 1)); # b, N_C, 1, M_com 932 | 933 | return H_rad_est; 934 | 935 | 936 | 937 | class JCASCasterNet(nn.Module): 938 | def __init__( 939 | self , 940 | N_C , 941 | M_com , 942 | N_RF_com , 943 | Q_bsc , 944 | Q_ehc , 945 | K , 946 | B , 947 | SNR_com , 948 | SNR_rad , 949 | d_model , 950 | ): 951 | super().__init__(); 952 | 953 | self.register_buffer("N_C" , torch.tensor(N_C )); 954 | self.register_buffer("M_com" , torch.tensor(M_com )); 955 | self.register_buffer("N_RF_com" , torch.tensor(N_RF_com )); 956 | self.register_buffer("Q_bsc" , torch.tensor(Q_bsc )); 957 | self.register_buffer("Q_ehc" , torch.tensor(Q_ehc )); 958 | self.register_buffer("K" , torch.tensor(K )); 959 | self.register_buffer("B" , torch.tensor(B )); 960 | self.register_buffer("SNR_com" , torch.tensor(SNR_com )); 961 | self.register_buffer("SNR_rad" , torch.tensor(SNR_rad )); 962 | self.register_buffer("d_model" , torch.tensor(d_model )); 963 | 964 | self.CN=ComCasterNet( 965 | N_C =N_C , 966 | M_com =M_com , 967 | N_RF_com =N_RF_com , 968 | Q_bsc =Q_bsc , 969 | Q_ehc =Q_ehc , 970 | K =K , 971 | B =B , 972 | SNR_com =SNR_com , 973 | SNR_rad =SNR_rad , 974 | d_model =d_model , 975 | ); 976 | 977 | self.SN=RadCasterNet( 978 | N_C =N_C , 979 | M_com =M_com , 980 | N_RF_com =N_RF_com , 981 | Q_bsc =Q_bsc , 982 | Q_ehc =Q_ehc , 983 | K =K , 984 | B =B , 985 | SNR_com =SNR_com , 986 | SNR_rad =SNR_rad , 987 | d_model =d_model , 988 | ); 989 | 990 | def forward( 991 | self , 992 | H_com , # b, N_C, K, M_com 993 | H_rad , # b, N_C, 1, M_com 994 | ): 995 | b=H_rad.shape[0]; 996 | 997 | SNR_com=self.CN.SNR_com; 998 | SNR_rad=self.SN.SNR_rad; 999 | 1000 | # stage 1 1001 | bsc_pilot=self.SN.bsc_pilot(); # 1, N_C, M_com, Q_bsc 1002 | 1003 | bsc_echo=H_rad@bsc_pilot; # b, N_C, 1, Q_bsc 1004 | bsc_echo=awgn(bsc_echo, 1, bsc_pilot, SNR_rad); 1005 | 1006 | bsc_measurement=H_com@bsc_pilot; # b, N_C, K, Q_bsc 1007 
| bsc_measurement=awgn(bsc_measurement, 1, bsc_pilot, SNR_com); 1008 | 1009 | # stage 2 1010 | ehc_pilot=self.SN.EPDN(bsc_pilot, bsc_echo); # b, N_C, M_com, Q_ehc 1011 | 1012 | ehc_echo=H_rad@ehc_pilot; # b, N_C, 1, Q_ehc 1013 | ehc_echo=awgn(ehc_echo, 1, ehc_pilot, SNR_rad); 1014 | 1015 | ehc_measurement=H_com@ehc_pilot; # b, N_C, K, Q_ehc 1016 | ehc_measurement=awgn(ehc_measurement, 1, ehc_pilot, SNR_com); 1017 | 1018 | # Feedback 1019 | measurement=torch.cat((bsc_measurement, ehc_measurement), dim=-1); # b, N_C, K, Q 1020 | feedback=self.CN.SEnc(measurement); # b, K, B 1021 | 1022 | # Cat 1023 | pilot=torch.cat((bsc_pilot.repeat(b, 1, 1, 1), ehc_pilot), dim=-1); # b, N_C, M_com, Q 1024 | echo=torch.cat((bsc_echo, ehc_echo), dim=-1); # b, N_C, N_RF_rad, Q 1025 | 1026 | # Com 1027 | F=self.CN.SDec(pilot, echo, feedback); # b, N_C, M_com, N_RF_com 1028 | 1029 | # Rad 1030 | H_rad_est=self.SN.CSRN(pilot, echo, feedback); # b, N_C, 1, M_com 1031 | 1032 | return F, H_rad_est; 1033 | 1034 | 1035 | 1036 | class MUSIC(nn.Module): 1037 | def __init__( 1038 | self, 1039 | ): 1040 | super().__init__(); 1041 | 1042 | def forward( 1043 | self , 1044 | H , # b, N_C, 1, M 1045 | eps , # [deg] 1046 | ): 1047 | with torch.no_grad(): 1048 | b, N_C, _, M=H.shape; 1049 | device=H.device; 1050 | L=6; 1051 | 1052 | X=H.squeeze().transpose(-2, -1); # b, M, N_C 1053 | if b==1: 1054 | X=X.unsqueeze(0); 1055 | 1056 | R=X@X.mH; # b, M, M 1057 | e, V=eig(R); 1058 | idx=e.abs().argsort(dim=-1, descending=True); # b, M 1059 | 1060 | idx_n=idx[:, L:].unsqueeze(-2).repeat(1, M, 1); # b, M, M-L 1061 | V_n=V.gather( # b, M, M 1062 | dim=-1, 1063 | index=idx_n, 1064 | ); # b, M, M-L 1065 | 1066 | R_n=V_n@V_n.mH; # b, M, M 1067 | R_n=R_n.unsqueeze(1); # b, 1, M, M 1068 | 1069 | theta=torch.arange(0, 180, eps).to(device); # [deg] 1070 | theta=theta/180*pi; # [rad] 1071 | theta=torch.cos(theta)*pi; 1072 | theta=theta.unsqueeze(0).unsqueeze(-1); # 1, Num, 1 1073 | 
rg=torch.arange(M).to(device).reshape(1, 1, M).float(); # 1, 1, M 1074 | A=torch.exp(1j*(theta@rg)).unsqueeze(-1); # 1, Num, M, 1 1075 | 1076 | P=1/(A.mH@R_n@A).squeeze(); # b, Num 1077 | 1078 | return P.real; 1079 | 1080 | 1081 | 1082 | 1083 | 1084 | 1085 | 1086 | 1087 | 1088 | 1089 | 1090 | 1091 | 1092 | 1093 | 1094 | 1095 | 1096 | 1097 | 1098 | 1099 | -------------------------------------------------------------------------------- /train/my_log.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import logging 4 | 5 | def logger(s): 6 | tmp=logging.getLogger('log'); 7 | if not tmp.handlers: 8 | tmp.setLevel(level=logging.INFO); 9 | handler=logging.FileHandler("./log.txt"); 10 | handler.setLevel(logging.INFO); 11 | formatter=logging.Formatter('[%(asctime)s]: %(message)s'); 12 | handler.setFormatter(formatter); 13 | 14 | console=logging.StreamHandler(); 15 | console.setLevel(logging.INFO); 16 | console.setFormatter(formatter); 17 | 18 | tmp.addHandler(handler); 19 | tmp.addHandler(console); 20 | 21 | tmp.info(s); 22 | 23 | def rmLogger(): 24 | logging.shutdown(); -------------------------------------------------------------------------------- /train/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | from torch import pi 7 | from torch.fft import fft, ifft 8 | 9 | import numpy as np 10 | from numpy.random import choice 11 | 12 | import os 13 | import random 14 | import shutil 15 | from scipy import io 16 | from math import sqrt, log10 17 | 18 | import matplotlib.pyplot as plt 19 | from mpl_toolkits.mplot3d import Axes3D 20 | from matplotlib.pyplot import stem,plot 21 | from math import log, log2, log10, sqrt 22 | 23 | from my_log import logger, rmLogger 24 | 25 | 26 | 27 | def SELoss( 28 | H, # b, N_C, K, M 29 | F, # (b,) N_C, M, K 30 | SNR, 31 | 
): 32 | SNR=10**(SNR/10); 33 | 34 | Y=H@F; # b, N_C, K, K 35 | b, N_C, K, _=Y.shape; 36 | 37 | power_noise=calMeanPower(F)/SNR; 38 | 39 | Y=Y.reshape(b*N_C, K, K).abs().square(); # b*N_C, K, K 40 | idx=torch.arange(K); 41 | power_signal=Y[:, idx, idx]; # b*N_C, K 42 | power_inter=Y.sum(-1)-power_signal; # b*N_C, K 43 | SINR=power_signal/(power_inter+power_noise); 44 | res=torch.log(1+SINR)/log(2); 45 | 46 | return -1*res.mean(); # 1 47 | 48 | 49 | 50 | def RMSELoss( 51 | x, 52 | y, 53 | ): 54 | b=x.shape[0]; 55 | 56 | x=x.reshape(b, -1); 57 | y=y.reshape(b, -1); 58 | 59 | res=(x-y).abs().square().mean(-1).sqrt().mean(); 60 | return res; 61 | 62 | 63 | 64 | def NMSELoss( 65 | x, 66 | y, 67 | ): 68 | b=x.shape[0]; 69 | 70 | x=x.reshape(b, -1); 71 | y=y.reshape(b, -1); 72 | 73 | res=(x-y).abs().square().mean(-1); 74 | res=res/y.abs().square().mean(-1); 75 | 76 | return res.mean(); 77 | 78 | 79 | 80 | def CosLoss( 81 | x, # b, N_C, 1, M 82 | y, # b, N_C, 1, M 83 | ): 84 | tmp_xy=(x@y.mH).abs(); # b, N_C, 1, 1 85 | tmp_xx=(x@x.mH).abs().sqrt(); # b, N_C, 1, 1 86 | tmp_yy=(y@y.mH).abs().sqrt(); # b, N_C, 1, 1 87 | 88 | cos=tmp_xy/tmp_xx/tmp_yy; 89 | 90 | return -cos.mean(); 91 | 92 | 93 | 94 | def theta2label( 95 | theta, 96 | grid_num, 97 | ): 98 | b=theta.shape[0]; 99 | 100 | idx=torch.round((theta/180)*grid_num).long()[:, :, 0]; 101 | label=torch.zeros(b, grid_num); 102 | label.scatter_( 103 | dim=-1, 104 | index=idx, 105 | src=torch.ones_like(idx).float(), 106 | ); 107 | 108 | return label; 109 | 110 | 111 | def calMeanPower(x: torch.Tensor): 112 | return x.abs().square().mean(); 113 | 114 | 115 | 116 | def awgn( 117 | x, 118 | ofdm_dim_idx, 119 | s, 120 | SNR, 121 | ): 122 | device=x.device; 123 | SNR=10**(SNR/10); 124 | 125 | npower=calMeanPower(s)/SNR; 126 | noise=torch.randn(*x.shape, dtype=torch.complex64).to(device)*torch.sqrt(npower); 127 | 128 | return x+fft(noise, dim=ofdm_dim_idx)/sqrt(noise.shape[ofdm_dim_idx]); 129 | 130 | 131 | 132 | def env_init(seed: 
int): 133 | random.seed(seed); 134 | os.environ['PYHTONHASHSEED']=str(seed); 135 | np.random.seed(seed); 136 | torch.manual_seed(seed); 137 | torch.cuda.manual_seed(seed); 138 | torch.backends.cudnn.deterministic=True; 139 | 140 | 141 | 142 | def main(): 143 | pass; 144 | 145 | 146 | 147 | if __name__=="__main__": 148 | logger("utils.py"); 149 | main(); 150 | rmLogger(); --------------------------------------------------------------------------------