├── LICENSE
├── README.md
├── assets
│   └── block_diagram.png
└── train
    ├── datasets.py
    ├── main.py
    ├── model.py
    ├── my_log.py
    └── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Introduction
2 |
3 | This repository contains the source code for the following paper [1].
4 |
5 | [1] [Y. Zhang et al., "AI Empowered Channel Semantic Acquisition for 6G Integrated Sensing and Communication Networks," in IEEE Network, doi: 10.1109/MNET.2024.3354264.](https://ieeexplore.ieee.org/document/10400499)
6 |
7 | Considering the correlation between C&S channels, we propose a two-stage frame structure along with the channel-semantic-learning-based JCASCasterNet, which yields gains in both the extraction and the reconstruction of channel semantics.
8 |
9 |
10 |
11 |
12 |
13 | For more information, please visit our website: [gaozhen16.github.io](https://gaozhen16.github.io/)
14 |
15 | If you have any questions, you can contact me by email: yifeizhang2000@foxmail.com
16 |
17 | ## Training
18 |
19 | Run the main function in `train/main.py` to start training. The core code of JCASCasterNet is in `train/model.py`.
20 |
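For a quick sanity check of what the core model produces, here is a minimal sketch (not part of the repository) of a single forward pass, assuming it is run from inside `train/` with the dependencies imported there (PyTorch, tqdm, hdf5storage, scipy, matplotlib) installed; the hyperparameters match `train/main.py` and the shapes follow the comments in `train/model.py`:

```python
import torch
from model import JCASCasterNet

# Same hyperparameters as train/main.py (K equals N_RF_com there).
net = JCASCasterNet(N_C=32, M_com=32, N_RF_com=2, Q_bsc=2, Q_ehc=2,
                    K=2, B=16, SNR_com=10, SNR_rad=10, d_model=256).eval()

b = 2  # illustrative batch size
H_com = torch.randn(b, 32, 2, 32, dtype=torch.complex64)  # b, N_C, K, M_com
H_rad = torch.randn(b, 32, 1, 32, dtype=torch.complex64)  # b, N_C, 1, M_com

with torch.no_grad():
    F, H_rad_est = net(H_com=H_com, H_rad=H_rad)

print(F.shape)          # (b, 32, 32, 2): hybrid precoder per subcarrier
print(H_rad_est.shape)  # (b, 32, 1, 32): reconstructed sensing channel
```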
21 | ## License
22 |
23 | All code in this repository is released under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).
24 |
--------------------------------------------------------------------------------
/assets/block_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tocaloid/JCASCasterNet/5fd45528c14e5917ebde69e7527296faf5025062/assets/block_diagram.png
--------------------------------------------------------------------------------
/train/datasets.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn
6 | from torch.fft import fft, ifft
7 | from torch.utils.data import Dataset
8 |
9 | import numpy as np
10 | from numpy import pi
11 | from numpy.random import choice
12 |
13 | import hdf5storage
14 |
15 | import os
16 | import time
17 | import random
18 | import shutil
19 | import copy
20 | from tqdm import tqdm
21 | from scipy import io
22 | from math import sqrt, log10
23 |
24 | from utils import env_init
25 |
26 | from my_log import logger, rmLogger
27 |
28 |
29 |
30 | def genH(
31 | N_gen ,
32 | b ,
33 | N_C ,
34 | K ,
35 | M_com ,
36 | L_share ,
37 | L_com ,
38 | L_rad ,
39 | R ,
40 | device ,
41 | ):
42 | theta_share=torch.rand(b, L_share, 1).to(device)*pi; # b, L_share, 1
43 | theta_com=torch.rand(b, L_com, 1).to(device)*pi; # b, L_com, 1
44 | theta_com[:, 0:L_share, :]=theta_share.clone().detach();
45 | theta_rad=torch.rand(b, L_rad, 1).to(device)*pi; # b, L_rad, 1
46 | theta_rad[:, 0:L_share, :]=theta_share.clone().detach();
47 |
48 | r_share=torch.rand(b, L_share, 1).to(device)*R; # b, L_share, 1
49 | r_com=torch.rand(b, L_com, 1).to(device)*R; # b, L_com, 1
50 | r_com[:, 0:L_share, :]=r_share.clone().detach();
51 | r_rad=torch.rand(b, L_rad, 1).to(device)*R; # b, L_rad, 1
52 | r_rad[:, 0:L_share, :]=r_share.clone().detach();
53 |
54 | theta_UE=torch.rand(b, K, 1).to(device)*pi; # b, K, 1
55 | r_UE=torch.rand(b, K, 1).to(device)*R; # b, K, 1
56 | def genBatch(
57 | b,
58 | ):
59 | # H_com
60 | theta_f=torch.rand(b*K, L_com+1, 1).to(device)*2*pi; # b*K, L_com+1, 1
61 | rg_f=torch.arange(N_C).to(device).reshape(1, 1, N_C).float(); # 1, 1, N_C
62 | v_f=torch.exp(1j*(theta_f@rg_f)).unsqueeze(-1); # b*K, L_com+1, N_C, 1
63 |
64 | theta_a=torch.cos(torch.cat(
65 | (theta_com.repeat(K, 1, 1), # b*K, L_com, 1
66 | theta_UE.reshape(b*K, 1, 1), ), # b*K, 1, 1
67 | 1,
68 | ))*pi; # b*K, L_com+1, 1
69 | rg_a=torch.arange(M_com).to(device).reshape(1, 1, M_com).float(); # 1, 1, M_com
70 | v_a=torch.exp(1j*(theta_a@rg_a)).unsqueeze(-2); # b*K, L_com+1, 1, M_com
71 |
72 | gain=torch.randn(b*K, L_com+1, 1, 1, dtype=torch.complex64).to(device); # b*K, L_com+1, 1, 1
73 |
74 | H_com=(gain*(v_f@v_a)).sum(1)/sqrt(L_com+1); # b*K, N_C, M_com
75 | H_com=H_com.reshape(b, K, N_C, M_com).transpose(1, 2); # b, N_C, K, M
76 |
77 | # H_rad
78 | theta_f=torch.rand(b, L_rad, 1).to(device)*2*pi; # b, L_rad, 1
79 | rg_f=torch.arange(N_C).to(device).reshape(1, 1, N_C).float(); # 1, 1, N_C
80 | v_f=torch.exp(1j*(theta_f@rg_f)).unsqueeze(-1); # b, L_rad, N_C, 1
81 |
82 | theta_Tx=torch.cos(theta_rad)*pi; # b, L_rad, 1
83 | rg_Tx=torch.arange(M_com).to(device).reshape(1, 1, M_com).float(); # 1, 1, M_com
84 | v_Tx=torch.exp(1j*(theta_Tx@rg_Tx)).unsqueeze(-2); # b, L_rad, 1, M_com
85 |
86 | gain=torch.randn(b, L_rad, 1, 1, dtype=torch.complex64).to(device); # b, L_rad, 1, 1
87 |
88 | H_rad=(gain*(v_f@v_Tx)).sum(1)/sqrt(L_rad); # b, N_C, M_com
89 | H_rad=H_rad.unsqueeze(-2); # b, N_C, 1, M_com
90 |
91 | return H_com, \
92 | H_rad;
93 |
94 | H_com=torch.zeros(N_gen, N_C, K, M_com, dtype=torch.complex64);
95 | H_rad=torch.zeros(N_gen, N_C, 1, M_com, dtype=torch.complex64);
96 | cnt=0;
97 | for _ in range(N_gen//b):
98 | vcnt=cnt+b;
99 |
100 | H_com[cnt:vcnt], \
101 | H_rad[cnt:vcnt]=genBatch(b);
102 |
103 | cnt=vcnt;
104 |
105 |     if cnt!=N_gen: # leftover samples: generate one more full batch and keep only the first N_gen-cnt
106 |         H_com_last, H_rad_last=genBatch(b);
107 |         H_com[cnt:N_gen], H_rad[cnt:N_gen]=H_com_last[:N_gen-cnt], H_rad_last[:N_gen-cnt];
108 |
109 | return H_com .cpu(), \
110 | H_rad .cpu(), \
111 | theta_share .cpu(), \
112 | theta_com .cpu(), \
113 | theta_rad .cpu(), \
114 | r_share .cpu(), \
115 | r_com .cpu(), \
116 | r_rad .cpu(), \
117 | theta_UE .cpu(), \
118 | r_UE .cpu();
119 |
120 |
121 |
122 | def main():
123 | env_init(42);
124 | H_com , \
125 | H_rad , \
126 | theta_share , \
127 | theta_com , \
128 | theta_rad , \
129 | r_share , \
130 | r_com , \
131 | r_rad , \
132 | theta_UE , \
133 | r_UE =genH(
134 | N_gen =2560 ,
135 | b =256 ,
136 | N_C =32 ,
137 | K =4 ,
138 | M_com =32 ,
139 | L_share =6 ,
140 | L_com =8 ,
141 | L_rad =12 ,
142 | R =100 ,
143 | device =0 ,
144 | );
145 |
146 |
147 |
148 | if __name__=="__main__":
149 | logger("datasets.py");
150 | main();
151 | rmLogger();
152 |
153 |
154 |
155 |
--------------------------------------------------------------------------------
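As a usage note for `genH` above (a sketch, not part of the repository; the sizes here are illustrative and smaller than those used in `main()`), the first two returned tensors are the communication and sensing channels, followed by the angle/range geometry, all moved to the CPU:

```python
import torch
from datasets import genH
from utils import env_init

env_init(0)
H_com, H_rad, *geometry = genH(
    N_gen=8, b=8, N_C=32, K=4, M_com=32,
    L_share=2, L_com=4, L_rad=6, R=100, device="cpu",
)
print(H_com.shape, H_com.dtype)  # torch.Size([8, 32, 4, 32]) torch.complex64 -> N_gen, N_C, K, M_com
print(H_rad.shape, H_rad.dtype)  # torch.Size([8, 32, 1, 32]) torch.complex64 -> N_gen, N_C, 1, M_com
print(len(geometry))             # 8 geometry tensors: theta_share/com/rad, r_share/com/rad, theta_UE, r_UE
```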
/train/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | from torch import nn
5 | import torch.nn.functional as F
6 | from torch.linalg import svd, eig, inv, pinv
7 | from torch.fft import fft, ifft
8 |
9 | import numpy as np
10 |
11 | import math
12 | from math import sqrt, pi
13 |
14 | import os
15 | from copy import deepcopy
16 |
17 | from model import JCASCasterNet, ComCasterNet, RadCasterNet
18 |
19 | from datasets import genH
20 |
21 | from utils import env_init , \
22 | calMeanPower , \
23 | SELoss , \
24 | RMSELoss , \
25 | NMSELoss , \
26 | CosLoss , \
27 | theta2label
28 |
29 | from my_log import logger, rmLogger
30 |
31 | from tqdm import tqdm
32 |
33 |
34 |
35 | def main():
36 | env_init(42);
37 | N_C =32 ;
38 | M_com =32 ;
39 | N_RF_com =2 ;
40 | Q_bsc =2 ;
41 | Q_ehc =2 ;
42 | K =N_RF_com ;
43 | B =16 ;
44 | SNR_com =10 ;
45 | SNR_rad =10 ;
46 | d_model =256 ;
47 | device =0 ;
48 | net=JCASCasterNet(
49 | N_C =N_C ,
50 | M_com =M_com ,
51 | N_RF_com =N_RF_com ,
52 | Q_bsc =Q_bsc ,
53 | Q_ehc =Q_ehc ,
54 | K =K ,
55 | B =B ,
56 | SNR_com =SNR_com ,
57 | SNR_rad =SNR_rad ,
58 | d_model =d_model ,
59 | ).to(device);
60 | opt=torch.optim.AdamW(
61 | net.parameters(),
62 | lr=1e-4,
63 | );
64 |
65 | b =256 ;
66 | L_share =6 ;
67 | L_com =3*K ;
68 | L_rad =6 ;
69 | R =100 ;
70 | SNR =10 ;
71 |
72 | acc_num=8;
73 | interval=100;
74 | for i in tqdm(range(1_000_000)):
75 | opt.zero_grad();
76 | net.train();
77 |
78 | for _ in range(acc_num):
79 | H_com , \
80 | H_rad , \
81 | theta_share , \
82 | theta_com , \
83 | theta_rad , \
84 | r_share , \
85 | r_com , \
86 | r_rad , \
87 | theta_UE , \
88 | r_UE =genH(
89 | N_gen =b ,
90 | b =b ,
91 | N_C =N_C ,
92 | K =K ,
93 | M_com =M_com ,
94 | L_share =L_share ,
95 | L_com =L_com ,
96 | L_rad =L_rad ,
97 | R =R ,
98 | device =device ,
99 | );
100 |
101 | H_com=H_com.to(device);
102 | H_rad=H_rad.to(device);
103 |
104 | F, H_rad_est=net(
105 | H_com=H_com,
106 | H_rad=H_rad,
107 | );
108 |
109 | loss=0;
110 |
111 | loss_SE=SELoss(H_com, F, SNR);
112 | loss=loss+loss_SE;
113 |
114 | loss_Cos=CosLoss(H_rad_est, H_rad);
115 | loss=loss+loss_Cos;
116 |
117 | loss=loss/acc_num;
118 | loss.backward();
119 |
120 | opt.step();
121 |
122 | if (i+1)%interval==0:
123 | net.eval();
124 | N_eval=2;
125 |
126 | loss_SE_acc=0;
127 | loss_Cos_acc=0;
128 | for _ in range(N_eval):
129 | H_com , \
130 | H_rad , \
131 | theta_share , \
132 | theta_com , \
133 | theta_rad , \
134 | r_share , \
135 | r_com , \
136 | r_rad , \
137 | theta_UE , \
138 | r_UE =genH(
139 | N_gen =10*b ,
140 | b =10*b ,
141 | N_C =N_C ,
142 | K =K ,
143 | M_com =M_com ,
144 | L_share =L_share ,
145 | L_com =L_com ,
146 | L_rad =L_rad ,
147 | R =R ,
148 | device =device ,
149 | );
150 |
151 | H_com=H_com.to(device);
152 | H_rad=H_rad.to(device);
153 |
154 | with torch.no_grad():
155 | F, H_rad_est=net(
156 | H_com=H_com,
157 | H_rad=H_rad,
158 | );
159 |
160 | loss_SE_acc+=SELoss(H_com, F, SNR).item()/N_eval;
161 | loss_Cos_acc+=CosLoss(H_rad_est, H_rad).item()/N_eval;
162 |
163 |
164 | logger("SE: {}".format(-loss_SE_acc));
165 | logger("Cos: {}".format(-loss_Cos_acc));
166 | logger("lr: {}".format(opt.param_groups[0]['lr']));
167 |
168 | folder_path="./BaseCKPT_Cos";
169 | if not os.path.exists(folder_path):
170 | os.makedirs(folder_path);
171 |
172 | save_path=folder_path+"/ckpt"+str(i+1)+".pth.tar";
173 | torch.save(
174 | {'state_dict': net.state_dict(),
175 | 'opt_dict': opt.state_dict(),
176 | }, save_path);
177 |
178 |
179 |
180 | if __name__=="__main__":
181 | logger("main.py");
182 | main();
183 | rmLogger();
--------------------------------------------------------------------------------
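A sketch (not part of the repository) of restoring one of the checkpoints that `main()` writes; the file name below is an example of the `./BaseCKPT_Cos/ckpt<step>.pth.tar` pattern, and the hyperparameters must match those used in `main()`:

```python
import torch
from model import JCASCasterNet

net = JCASCasterNet(N_C=32, M_com=32, N_RF_com=2, Q_bsc=2, Q_ehc=2,
                    K=2, B=16, SNR_com=10, SNR_rad=10, d_model=256)

ckpt = torch.load("./BaseCKPT_Cos/ckpt100.pth.tar", map_location="cpu")
net.load_state_dict(ckpt["state_dict"])  # the file also stores "opt_dict" for the AdamW optimizer
net.eval()
```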
/train/model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | from torch import nn
5 | import torch.nn.functional as F
6 | from torch.linalg import svd, eig, inv, pinv
7 | from torch.fft import fft, ifft
8 |
9 | import numpy as np
10 |
11 | import math
12 | from math import sqrt, pi
13 |
14 | from utils import calMeanPower, awgn, env_init, SELoss
15 |
16 | from datasets import genH
17 |
18 | from my_log import logger, rmLogger
19 |
20 |
21 |
22 | #==================================================================================#
23 | # utils
24 | #==================================================================================#
25 | class TransposeLayer(nn.Module):
26 | def __init__(
27 | self,
28 | i,
29 | j,
30 | ):
31 | super().__init__();
32 | self.i=i;
33 | self.j=j;
34 |
35 | def forward(
36 | self,
37 | x,
38 | ):
39 | return x.transpose(self.i, self.j);
40 |
41 |
42 |
43 | class MergeLayer(nn.Module):
44 | def __init__(
45 | self,
46 | i,
47 | ):
48 | super().__init__();
49 | self.i=i;
50 |
51 | def forward(
52 | self,
53 | x,
54 | ):
55 | assert self.i!=-1;
56 | assert self.i!=len(x.shape)-1;
57 | if self.i!=-2:
58 | return x.reshape(*x.shape[:self.i], -1, *x.shape[self.i+2:]);
59 | else:
60 | return x.reshape(*x.shape[:self.i], -1);
61 |
62 |
63 |
64 | class SplitLayer(nn.Module):
65 | def __init__(
66 | self,
67 | i,
68 | n1=None,
69 | n2=None,
70 | ):
71 | super().__init__();
72 | assert (n1 is not None and n2 is None) or (n1 is None and n2 is not None);
73 | self.i=i;
74 | self.n1=n1;
75 | self.n2=n2;
76 |
77 | def forward(
78 | self,
79 | x,
80 | ):
81 | if self.i!=-1:
82 | if self.n2 is None:
83 | return x.reshape(*x.shape[:self.i], self.n1, -1, *x.shape[self.i+1:]);
84 | else:
85 | return x.reshape(*x.shape[:self.i], -1, self.n2, *x.shape[self.i+1:]);
86 | else:
87 | if self.n2 is None:
88 | return x.reshape(*x.shape[:self.i], self.n1, -1);
89 | else:
90 | return x.reshape(*x.shape[:self.i], -1, self.n2);
91 |
92 |
93 |
94 | class PermuteLayer(nn.Module):
95 | def __init__(
96 | self,
97 | permute_order,
98 | ):
99 | super().__init__();
100 | self.permute_order=permute_order;
101 |
102 | def forward(
103 | self,
104 | x,
105 | ):
106 | return x.permute(self.permute_order);
107 |
108 |
109 |
110 | class ReshapeLayer(nn.Module):
111 | def __init__(
112 | self,
113 | shape,
114 | ):
115 | super().__init__();
116 | self.shape=shape;
117 |
118 | def forward(
119 | self,
120 | x,
121 | ):
122 | return x.reshape(self.shape);
123 |
124 |
125 |
126 | class RepeatLayer(nn.Module):
127 | def __init__(
128 | self,
129 | times,
130 | ):
131 | super().__init__();
132 | self.times=times;
133 |
134 | def forward(
135 | self,
136 | x,
137 | ):
138 | return x.repeat(self.times);
139 |
140 |
141 |
142 | class Complex2realLayer(nn.Module):
143 | def __init__(
144 | self,
145 | i,
146 | ):
147 | super().__init__();
148 | self.i=i;
149 |
150 | def forward(
151 | self,
152 | x,
153 | ):
154 | return torch.cat((x.real, x.imag), dim=self.i);
155 |
156 |
157 |
158 | class Real2complexLayer(nn.Module):
159 | def __init__(
160 | self,
161 | i,
162 | ):
163 | super().__init__();
164 | self.i=i;
165 |
166 | def forward(
167 | self,
168 | x,
169 | ):
170 | x=x.transpose(0, self.i);
171 | x=x[0]+1j*x[1];
172 | x=x.unsqueeze(0);
173 | x=x.transpose(0, self.i);
174 | if self.i==-1 or self.i==len(x.shape)-1:
175 | i_stable=self.i-1;
176 | else:
177 | i_stable=self.i;
178 |
179 | x=MergeLayer(i_stable)(x);
180 | return x;
181 |
182 |
183 |
184 | class NormLayer(nn.Module):
185 | def __init__(
186 | self,
187 | d,
188 | c_index=-1,
189 | ):
190 | super().__init__();
191 | self.norm=nn.Sequential(
192 | TransposeLayer(c_index, -1),
193 | nn.LayerNorm(d),
194 | TransposeLayer(c_index, -1),
195 | );
196 |
197 | def forward(
198 | self,
199 | x,
200 | ):
201 | return self.norm(x);
202 |
203 |
204 |
205 | #==================================================================================#
206 | # model
207 | #==================================================================================#
208 | feedback_bits=2;
209 |
210 | def calF_RF(
211 | Theta, # *, 1, M, N_RF
212 | ):
213 | F_RF=torch.exp(1j*Theta);
214 |
215 | return F_RF;
216 |
217 |
218 |
219 | def normF_BB(
220 | F_RF, # *, 1, M, N_RF
221 | F_BB, # *, N_C, N_RF, _
222 | ):
223 | N_C, N_RF=F_BB.shape[-3:-1];
224 |
225 | F=F_RF@F_BB; # *, N_C, M, _
226 | power=(F.abs()**2).sum([-3, -2, -1], keepdim=True); # *, 1, 1, 1
227 | F_BB=F_BB/power.sqrt()*sqrt(N_C*N_RF);
228 |
229 | return F_BB;
230 |
231 |
232 |
233 | def calPilot(
234 | pilot_Theta, # *, Q, 1 , M , N_RF
235 | pilot_BB , # *, Q, N_C, N_RF, 1
236 | ):
237 | _, Q, _ , M , N_RF=pilot_Theta.shape;
238 | _, Q, N_C, N_RF, _ =pilot_BB .shape;
239 |
240 | pilot_RF=calF_RF(pilot_Theta); # *, Q, 1, M, N_RF
241 | pilot_BB=normF_BB(pilot_RF, pilot_BB); # *, Q, N_C, N_RF, 1
242 | pilot=pilot_RF@pilot_BB; # *, Q, N_C, M, 1
243 | pilot=pilot.transpose(1, 4).reshape(-1, N_C, M, Q); # *, N_C, M, Q
244 |
245 | return pilot;
246 |
247 |
248 |
249 | class PositionalEncoding(nn.Module):
250 | def __init__(
251 | self,
252 | d_model,
253 | dropout,
254 | max_len,
255 | ):
256 | super().__init__()
257 | position=torch.arange(max_len).unsqueeze(1);
258 | div_term=torch.exp(torch.arange(0, d_model, 2)*(-math.log(10000.0)/d_model));
259 | pe=torch.zeros(max_len, 1, d_model);
260 | pe[:, 0, 0::2]=torch.sin(position*div_term);
261 | pe[:, 0, 1::2]=torch.cos(position*div_term);
262 | self.register_buffer('pe', pe);
263 | self.dropout=nn.Dropout(p=dropout);
264 |
265 | def forward(
266 | self,
267 | x,
268 | ):
269 | x=x+self.pe[:x.size(0)];
270 | return self.dropout(x);
271 |
272 |
273 |
274 | #=======================================================================================================================
275 | #=======================================================================================================================
276 | # Number to Bit Function Defining
277 | def Num2Bit(Num, B):
278 | Num_ = Num.type(torch.uint8)
279 | def integer2bit(integer, num_bits=B):
280 | dtype = integer.type()
281 | device = integer.device;
282 | exponent_bits = -torch.arange(1-B, 1).type(dtype).to(device);
283 | exponent_bits = exponent_bits.repeat(integer.shape + (1,))
284 | out = integer.unsqueeze(-1) // 2 ** exponent_bits
285 | return out% 2
286 | bit = integer2bit(Num_)
287 | bit = bit.reshape(-1, Num_.shape[1] * B)
288 | return bit.type(torch.float32)
289 | def Bit2Num(Bit, B):
290 | device=Bit.device;
291 | Bit_ = Bit.type(torch.float32)
292 | Bit_ = torch.reshape(Bit_, [-1, int(Bit_.shape[1] / B), B])
293 | num = torch.zeros(Bit_[:, :, 0].shape).to(device)
294 | for i in range(B):
295 | num = num + Bit_[:, :, i] * 2 ** (B - 1 - i)
296 | return num
297 | #=======================================================================================================================
298 | #=======================================================================================================================
299 | # Quantization and Dequantization Layers Defining
300 | class Quantization(torch.autograd.Function):
301 | @staticmethod
302 | def forward(ctx, x, B):
303 | ctx.constant = B
304 | step = 2**B-1e-6;
305 | out = torch.round(x*step-0.5);
306 | out = Num2Bit(out, B)
307 | return out.type(x.type())
308 | @staticmethod
309 | def backward(ctx, grad_output):
310 | # return as many input gradients as there were arguments.
311 | # Gradients of constant arguments to forward must be None.
312 | # Gradient of a number is the sum of its B bits.
313 | b, _ = grad_output.shape
314 | grad_num = torch.sum(grad_output.reshape(b, -1, ctx.constant), dim=2) / ctx.constant
315 | return grad_num, None
316 | class Dequantization(torch.autograd.Function):
317 | @staticmethod
318 | def forward(ctx, x, B):
319 | ctx.constant = B
320 | step = 2**B;
321 | out = Bit2Num(x, B)
322 | out = (out+0.5)/step;
323 | return out.type(x.type())
324 | @staticmethod
325 | def backward(ctx, grad_output):
326 | # return as many input gradients as there were arguments.
327 | # Gradients of non-Tensor arguments to forward must be None.
328 | # repeat the gradient of a Num for B time.
329 | b, c = grad_output.shape
330 | grad_output = grad_output.unsqueeze(2) / ctx.constant
331 | grad_bit = grad_output.expand(b, c, ctx.constant)
332 | return torch.reshape(grad_bit, (-1, c * ctx.constant)), None
333 | class QuantizationLayer(nn.Module):
334 | def __init__(self, B):
335 | super(QuantizationLayer, self).__init__()
336 | self.B = B
337 | def forward(self, x):
338 | out = Quantization.apply(x, self.B)
339 | return out
340 | class DequantizationLayer(nn.Module):
341 | def __init__(self, B):
342 | super(DequantizationLayer, self).__init__()
343 | self.B = B
344 | def forward(self, x):
345 | out = Dequantization.apply(x, self.B)
346 | return out
347 |
348 |
349 |
350 | class SemanticEncoder(nn.Module):
351 | def __init__(
352 | self ,
353 | N_C ,
354 | M_com ,
355 | N_RF_com ,
356 | Q_bsc ,
357 | Q_ehc ,
358 | K ,
359 | B ,
360 | SNR_com ,
361 | SNR_rad ,
362 | d_model ,
363 | ):
364 | super().__init__();
365 |
366 | Q=Q_bsc+Q_ehc;
367 |
368 | self.register_buffer("N_C" , torch.tensor(N_C ));
369 | self.register_buffer("M_com" , torch.tensor(M_com ));
370 | self.register_buffer("N_RF_com" , torch.tensor(N_RF_com ));
371 | self.register_buffer("Q_bsc" , torch.tensor(Q_bsc ));
372 | self.register_buffer("Q_ehc" , torch.tensor(Q_ehc ));
373 | self.register_buffer("K" , torch.tensor(K ));
374 | self.register_buffer("B" , torch.tensor(B ));
375 | self.register_buffer("SNR_com" , torch.tensor(SNR_com ));
376 | self.register_buffer("SNR_rad" , torch.tensor(SNR_rad ));
377 | self.register_buffer("d_model" , torch.tensor(d_model ));
378 |
379 | d_model=d_model//K;
380 |
381 | self.embed=nn.Sequential( # b, N_C, K, Q
382 | TransposeLayer(1, 2), # b, K, N_C, Q
383 | ReshapeLayer([-1, N_C, Q]), # b*K, N_C, Q
384 | Complex2realLayer(-1), # b*K, N_C, 2*Q
385 | nn.Linear(2*Q, d_model), # b*K, N_C, d_model
386 | ); # b*K, N_C, d_model
387 | self.backbone=nn.Sequential( # b*K, N_C, d_model
388 | TransposeLayer(0, 1), # N_C, b*K, d_model
389 | PositionalEncoding(d_model, 0., N_C),
390 | nn.TransformerEncoder(
391 | encoder_layer=nn.TransformerEncoderLayer(
392 | d_model=d_model,
393 | nhead=8,
394 | dim_feedforward=4*d_model,
395 | dropout=0.,
396 | activation="gelu",
397 | norm_first=False,
398 | ),
399 | num_layers=6,
400 | ),
401 | TransposeLayer(0, 1), # b*K, N_C, d_model
402 | ); # b*K, N_C, d_model
403 | self.head=nn.Sequential( # b*K, N_C, d_model
404 | ReshapeLayer([-1, N_C*d_model]), # b*K, N_C*d_model
405 | nn.Linear(N_C*d_model, B//feedback_bits), # b*K, B//feedback_bits
406 | nn.BatchNorm1d(B//feedback_bits),
407 | nn.Sigmoid(),
408 | QuantizationLayer(feedback_bits),
409 | ReshapeLayer([-1, K, B]), # b, K, B
410 | ); # b, K, B
411 |
412 | def forward(
413 | self ,
414 | measurement , # b, N_C, K, Q
415 | ):
416 | x=self.embed(measurement); # b*K, N_C, d_model
417 | x=self.backbone(x);
418 | feedback=self.head(x); # b, K, B
419 |
420 | return feedback;
421 |
422 |
423 |
424 | class SemanticDecoder(nn.Module):
425 | def __init__(
426 | self ,
427 | N_C ,
428 | M_com ,
429 | N_RF_com ,
430 | Q_bsc ,
431 | Q_ehc ,
432 | K ,
433 | B ,
434 | SNR_com ,
435 | SNR_rad ,
436 | d_model ,
437 | ):
438 | super().__init__();
439 |
440 | Q=Q_bsc+Q_ehc;
441 |
442 | self.register_buffer("N_C" , torch.tensor(N_C ));
443 | self.register_buffer("M_com" , torch.tensor(M_com ));
444 | self.register_buffer("N_RF_com" , torch.tensor(N_RF_com ));
445 | self.register_buffer("Q_bsc" , torch.tensor(Q_bsc ));
446 | self.register_buffer("Q_ehc" , torch.tensor(Q_ehc ));
447 | self.register_buffer("K" , torch.tensor(K ));
448 | self.register_buffer("B" , torch.tensor(B ));
449 | self.register_buffer("SNR_com" , torch.tensor(SNR_com ));
450 | self.register_buffer("SNR_rad" , torch.tensor(SNR_rad ));
451 | self.register_buffer("d_model" , torch.tensor(d_model ));
452 |
453 | d_model=d_model//K;
454 |
455 | self.embed_pilot_Theta=nn.Sequential( # b, N_C, M_com, Q
456 | Complex2realLayer(-1), # b, N_C, M_com, 2*Q
457 | TransposeLayer(1, 2), # b, M_com, N_C, 2*Q
458 | ReshapeLayer([-1, M_com, N_C*2*Q]), # b, M_com, N_C*2*Q
459 | nn.Linear(N_C*2*Q, d_model), # b, M_com, d_model
460 | ReshapeLayer([-1, 1, M_com, d_model]), # b, 1, M_com, d_model
461 | RepeatLayer([1, K, 1, 1]), # b, K, M_com, d_model
462 | ReshapeLayer([-1, M_com, d_model]), # b*K, M_com, d_model
463 | ); # b*K, M_com, d_model
464 | self.embed_echo_Theta=nn.Sequential( # b, N_C, 1, Q
465 | TransposeLayer(1, 3), # b, Q, 1, N_C
466 | Complex2realLayer(-1), # b, Q, 1, 2*N_C
467 | ReshapeLayer([-1, Q, 1*2*N_C]), # b, Q, 1*2*N_C
468 | nn.Linear(1*2*N_C, d_model), # b, Q, d_model
469 | nn.Conv1d(Q, M_com, 1), # b, M_com, d_model
470 | ReshapeLayer([-1, 1, M_com, d_model]), # b, 1, M_com, d_model
471 | RepeatLayer([1, K, 1, 1]), # b, K, M_com, d_model
472 | ReshapeLayer([-1, M_com, d_model]), # b*K, M_com, d_model
473 | ); # b*K, M_com, d_model
474 | self.embed_feedback_Theta=nn.Sequential( # b, K, B
475 | ReshapeLayer([-1, B]), # b*K, B
476 | DequantizationLayer(feedback_bits),
477 | nn.Linear(B//feedback_bits, M_com*d_model), # b*K, M_com*d_model
478 | ReshapeLayer([-1, M_com, d_model]), # b*K, M_com, d_model
479 | ); # b*K, M_com, d_model
480 | self.backbone_Theta=nn.Sequential( # b*K, M_com, d_model
481 | TransposeLayer(0, 1), # M_com, b*K, d_model
482 | PositionalEncoding(d_model, 0., M_com),
483 | nn.TransformerEncoder(
484 | encoder_layer=nn.TransformerEncoderLayer(
485 | d_model=d_model,
486 | nhead=8,
487 | dim_feedforward=4*d_model,
488 | dropout=0.,
489 | activation="gelu",
490 | norm_first=False,
491 | ),
492 | num_layers=6,
493 | ),
494 | TransposeLayer(0, 1), # b*K, M_com, d_model
495 | ); # b*K, M_com, d_model
496 | self.F_Theta_head=nn.Sequential( # b*K, M_com, d_model
497 | nn.Linear(d_model, 1), # b*K, M_com, 1
498 | ReshapeLayer([-1, K, M_com, 1]), # b, K, M_com, 1
499 | TransposeLayer(1, 3), # b, 1, M_com, K
500 | ); # b, 1, M_com, N_RF_com (K==N_RF_com)
501 |
502 | d_model=d_model*K;
503 |
504 | self.embed_pilot_BB=nn.Sequential( # b, N_C, M_com, Q
505 | Complex2realLayer(-1), # b, N_C, M_com, 2*Q
506 | ReshapeLayer([-1, N_C, M_com*2*Q]), # b, N_C, M_com*2*Q
507 | nn.Linear(M_com*2*Q, d_model), # b, N_C, d_model
508 | ); # b, N_C, d_model
509 | self.embed_echo_BB=nn.Sequential( # b, N_C, 1, Q
510 | Complex2realLayer(-1), # b, N_C, 1, 2*Q
511 | ReshapeLayer([-1, N_C, 1*2*Q]), # b, N_C, 1*2*Q
512 | nn.Linear(1*2*Q, d_model), # b, N_C, d_model
513 | ); # b, N_C, d_model
514 | self.embed_feedback_BB=nn.Sequential( # b, K, B
515 | ReshapeLayer([-1, B]), # b*K, B
516 | DequantizationLayer(feedback_bits),
517 | ReshapeLayer([-1, K*B//feedback_bits]), # b, K*B//feedback_bits
518 | nn.Linear(K*B//feedback_bits, N_C*d_model), # b, N_C*d_model
519 | ReshapeLayer([-1, N_C, d_model]), # b, N_C, d_model
520 | ); # b, N_C, d_model
521 | self.embed_F_RF=nn.Sequential( # b, 1, M_com, N_RF_com
522 | Complex2realLayer(-1), # b, 1, M_com, 2*N_RF_com
523 | ReshapeLayer([-1, 1, M_com*2*N_RF_com]), # b, 1, M_com*2*N_RF_com
524 | nn.Linear(M_com*2*N_RF_com, d_model), # b, 1, d_model
525 | ); # b, 1, d_model
526 | self.backbone_BB=nn.Sequential( # b, N_C+1, d_model
527 | TransposeLayer(0, 1), # N_C+1, b, d_model
528 | PositionalEncoding(d_model, 0., N_C+1),
529 | nn.TransformerEncoder(
530 | encoder_layer=nn.TransformerEncoderLayer(
531 | d_model=d_model,
532 | nhead=8,
533 | dim_feedforward=4*d_model,
534 | dropout=0.,
535 | activation="gelu",
536 | norm_first=False,
537 | ),
538 | num_layers=6,
539 | ),
540 | TransposeLayer(0, 1), # b, N_C+1, d_model
541 | ); # b, N_C+1, d_model
542 | self.F_BB_head=nn.Sequential( # b, N_C, d_model
543 | nn.Linear(d_model, N_RF_com*K*2), # b, N_C, N_RF_com*K*2
544 | ReshapeLayer([-1, N_C, N_RF_com, K, 2]), # b, N_C, N_RF_com, K, 2
545 | Real2complexLayer(-1), # b, N_C, N_RF_com, K
546 | ); # b, N_C, N_RF_com, K
547 |
548 | def forward(
549 | self ,
550 | pilot , # b, N_C, M_com, Q
551 | echo , # b, N_C, 1 , Q
552 | feedback , # b, K, B
553 | ):
554 | b, N_C, M_com, Q=pilot.shape;
555 |
556 | z_pilot_Theta=self.embed_pilot_Theta(pilot.detach()); # b*K, M_com, d_model
557 | z_echo_Theta=self.embed_echo_Theta(echo.detach()); # b*K, M_com, d_model
558 | z_feedback_Theta=self.embed_feedback_Theta(feedback); # b*K, M_com, d_model
559 |
560 | z_Theta=z_pilot_Theta+z_echo_Theta+z_feedback_Theta; # b*K, M_com, d_model
561 | z_Theta=self.backbone_Theta(z_Theta);
562 | F_RF=calF_RF(self.F_Theta_head(z_Theta[:, 0:M_com])); # b, 1, M_com, N_RF_com
563 |
564 | z_pilot_BB=self.embed_pilot_BB(pilot); # b, N_C, d_model
565 | z_echo_BB=self.embed_echo_BB(echo); # b, N_C, d_model
566 | z_feedback_BB=self.embed_feedback_BB(feedback); # b, N_C, d_model
567 | z_F_RF=self.embed_F_RF(F_RF); # b, 1, d_model
568 |
569 | z_BB=torch.cat((
570 | z_pilot_BB+z_echo_BB+z_feedback_BB,
571 | z_F_RF,
572 | ), dim=1); # b, N_C+1, d_model
573 | z_BB=self.backbone_BB(z_BB);
574 | F_BB=normF_BB(F_RF, self.F_BB_head(z_BB[:, 0:N_C])); # b, N_C, N_RF_com, K
575 |
576 | F=F_RF@F_BB; # b, N_C, M_com, K
577 |
578 | return F;
579 |
580 |
581 |
582 | class ComCasterNet(nn.Module):
583 | def __init__(
584 | self ,
585 | N_C ,
586 | M_com ,
587 | N_RF_com ,
588 | Q_bsc ,
589 | Q_ehc ,
590 | K ,
591 | B ,
592 | SNR_com ,
593 | SNR_rad ,
594 | d_model ,
595 | ):
596 | super().__init__();
597 |
598 | Q=Q_bsc+Q_ehc;
599 |
600 | self.register_buffer("N_C" , torch.tensor(N_C ));
601 | self.register_buffer("M_com" , torch.tensor(M_com ));
602 | self.register_buffer("N_RF_com" , torch.tensor(N_RF_com ));
603 | self.register_buffer("Q_bsc" , torch.tensor(Q_bsc ));
604 | self.register_buffer("Q_ehc" , torch.tensor(Q_ehc ));
605 | self.register_buffer("K" , torch.tensor(K ));
606 | self.register_buffer("B" , torch.tensor(B ));
607 | self.register_buffer("SNR_com" , torch.tensor(SNR_com ));
608 | self.register_buffer("SNR_rad" , torch.tensor(SNR_rad ));
609 | self.register_buffer("d_model" , torch.tensor(d_model ));
610 |
611 | self.register_buffer("echo", torch.zeros(1, N_C, 1, Q).to(torch.complex64));
612 |
613 | self.pilot_Theta=nn.Parameter(torch.rand (1, Q, 1 , M_com , N_RF_com)*2*pi);
614 | self.pilot_BB =nn.Parameter(torch.randn(1, Q, N_C, N_RF_com, 1 , dtype=torch.complex64));
615 |
616 | self.SEnc=SemanticEncoder(
617 | N_C =N_C ,
618 | M_com =M_com ,
619 | N_RF_com =N_RF_com ,
620 | Q_bsc =Q_bsc ,
621 | Q_ehc =Q_ehc ,
622 | K =K ,
623 | B =B ,
624 | SNR_com =SNR_com ,
625 | SNR_rad =SNR_rad ,
626 | d_model =d_model ,
627 | );
628 |
629 | self.SDec=SemanticDecoder(
630 | N_C =N_C ,
631 | M_com =M_com ,
632 | N_RF_com =N_RF_com ,
633 | Q_bsc =Q_bsc ,
634 | Q_ehc =Q_ehc ,
635 | K =K ,
636 | B =B ,
637 | SNR_com =SNR_com ,
638 | SNR_rad =SNR_rad ,
639 | d_model =d_model ,
640 | );
641 |
642 | def forward(
643 | self ,
644 | H_com , # b, N_C, K, M_com
645 | ):
646 | b=H_com.shape[0];
647 |
648 | SNR_com=self.SNR_com;
649 |
650 | pilot=calPilot(
651 | pilot_Theta=self.pilot_Theta,
652 | pilot_BB =self.pilot_BB ,
653 | ); # 1, N_C, M_com, Q
654 |
655 | measurement=H_com@pilot; # b, N_C, K, Q
656 | measurement=awgn(measurement, 1, pilot, SNR_com);
657 |
658 | feedback=self.SEnc(measurement); # b, K, B
659 | F=self.SDec(
660 | pilot.repeat(b, 1, 1, 1), # b, N_C, M_com, Q
661 | self.echo.repeat(b, 1, 1, 1), # b, N_C, 1, Q
662 | feedback,
663 | ); # b, N_C, M_com, N_RF_com
664 |
665 | return F;
666 |
667 |
668 |
669 | class EnhancedPilotDesignNet(nn.Module):
670 | def __init__(
671 | self ,
672 | N_C ,
673 | M_com ,
674 | N_RF_com ,
675 | Q_bsc ,
676 | Q_ehc ,
677 | K ,
678 | B ,
679 | SNR_com ,
680 | SNR_rad ,
681 | d_model ,
682 | ):
683 | super().__init__();
684 |
685 | self.register_buffer("N_C" , torch.tensor(N_C ));
686 | self.register_buffer("M_com" , torch.tensor(M_com ));
687 | self.register_buffer("N_RF_com" , torch.tensor(N_RF_com ));
688 | self.register_buffer("Q_bsc" , torch.tensor(Q_bsc ));
689 | self.register_buffer("Q_ehc" , torch.tensor(Q_ehc ));
690 | self.register_buffer("K" , torch.tensor(K ));
691 | self.register_buffer("B" , torch.tensor(B ));
692 | self.register_buffer("SNR_com" , torch.tensor(SNR_com ));
693 | self.register_buffer("SNR_rad" , torch.tensor(SNR_rad ));
694 | self.register_buffer("d_model" , torch.tensor(d_model ));
695 |
696 | self.embed_pilot=nn.Sequential( # b, N_C, M_com, Q_bsc
697 | Complex2realLayer(-1), # b, N_C, M_com, 2*Q_bsc
698 | ReshapeLayer([-1, N_C, M_com*2*Q_bsc]), # b, N_C, M_com*2*Q_bsc
699 | nn.Linear(M_com*2*Q_bsc, d_model), # b, N_C, d_model
700 | ); # b, N_C, d_model
701 | self.embed_echo=nn.Sequential( # b, N_C, 1, Q_bsc
702 | Complex2realLayer(-1), # b, N_C, 1, 2*Q_bsc
703 | ReshapeLayer([-1, N_C, 1*2*Q_bsc]), # b, N_C, 1*2*Q_bsc
704 | nn.Linear(1*2*Q_bsc, d_model), # b, N_C, d_model
705 | ); # b, N_C, d_model
706 | self.backbone=nn.Sequential( # b, N_C, d_model
707 | TransposeLayer(0, 1), # N_C, b, d_model
708 | PositionalEncoding(d_model, 0., N_C),
709 | nn.TransformerEncoder(
710 | encoder_layer=nn.TransformerEncoderLayer(
711 | d_model=d_model,
712 | nhead=8,
713 | dim_feedforward=4*d_model,
714 | dropout=0.,
715 | activation="gelu",
716 | norm_first=False,
717 | ),
718 | num_layers=6,
719 | ),
720 | TransposeLayer(0, 1), # b, N_C, d_model
721 | ); # b, N_C, d_model
722 | self.p_Theta_head=nn.Sequential( # b, N_C, d_model
723 | ReshapeLayer([-1, N_C*d_model]), # b, N_C*d_model
724 | nn.Linear(N_C*d_model, Q_ehc*1*M_com*N_RF_com), # b, Q_ehc*1*M_com*N_RF_com
725 | ReshapeLayer([-1, Q_ehc, 1, M_com, N_RF_com]), # b, Q_ehc, 1, M_com, N_RF_com
726 | ); # b, Q_ehc, 1, M_com, N_RF_com
727 | self.p_BB_head=nn.Sequential( # b, N_C, d_model
728 | ReshapeLayer([-1, N_C*d_model]), # b, N_C*d_model
729 | nn.Linear(N_C*d_model, Q_ehc*N_C*N_RF_com*1*2), # b, Q_ehc*N_C*N_RF_com*1*2
730 | ReshapeLayer([-1, Q_ehc, N_C, N_RF_com, 1, 2]), # b, Q_ehc, N_C, N_RF_com, 1, 2
731 | Real2complexLayer(-1), # b, Q_ehc, N_C, N_RF_com, 1
732 | ); # b, Q_ehc, N_C, N_RF_com, 1
733 |
734 | def forward(
735 | self ,
736 | bsc_pilot , # 1, N_C, M_com, Q_bsc
737 | bsc_echo , # b, N_C, 1 , Q_bsc
738 | ):
739 | z_pilot=self.embed_pilot(bsc_pilot.detach()); # b, N_C, d_model
740 | z_echo=self.embed_echo(bsc_echo); # b, N_C, d_model
741 |
742 | z=z_pilot+z_echo;
743 | z=self.backbone(z);
744 | ehc_pilot=calPilot(
745 | pilot_Theta=self.p_Theta_head(z), # b, Q_ehc, 1 , M_com , N_RF_com
746 | pilot_BB =self.p_BB_head (z), # b, Q_ehc, N_C, N_RF_com, 1
747 | ); # b, N_C, M_com, Q_ehc
748 |
749 | return ehc_pilot;
750 |
751 |
752 |
753 | class ChannelSemanticReconstructionNet(nn.Module):
754 | def __init__(
755 | self ,
756 | N_C ,
757 | M_com ,
758 | N_RF_com ,
759 | Q_bsc ,
760 | Q_ehc ,
761 | K ,
762 | B ,
763 | SNR_com ,
764 | SNR_rad ,
765 | d_model ,
766 | ):
767 | super().__init__();
768 |
769 | Q=Q_bsc+Q_ehc;
770 |
771 | self.register_buffer("N_C" , torch.tensor(N_C ));
772 | self.register_buffer("M_com" , torch.tensor(M_com ));
773 | self.register_buffer("N_RF_com" , torch.tensor(N_RF_com ));
774 | self.register_buffer("Q_bsc" , torch.tensor(Q_bsc ));
775 | self.register_buffer("Q_ehc" , torch.tensor(Q_ehc ));
776 | self.register_buffer("K" , torch.tensor(K ));
777 | self.register_buffer("B" , torch.tensor(B ));
778 | self.register_buffer("SNR_com" , torch.tensor(SNR_com ));
779 | self.register_buffer("SNR_rad" , torch.tensor(SNR_rad ));
780 | self.register_buffer("d_model" , torch.tensor(d_model ));
781 |
782 | self.embed_pilot=nn.Sequential( # b, N_C, M_com, Q
783 | Complex2realLayer(-1), # b, N_C, M_com, 2*Q
784 | ReshapeLayer([-1, N_C, M_com*2*Q]), # b, N_C, M_com*2*Q
785 | nn.Linear(M_com*2*Q, d_model), # b, N_C, d_model
786 | ); # b, N_C, d_model
787 | self.embed_echo=nn.Sequential( # b, N_C, 1, Q
788 | Complex2realLayer(-1), # b, N_C, 1, 2*Q
789 | ReshapeLayer([-1, N_C, 1*2*Q]), # b, N_C, 1*2*Q
790 | nn.Linear(1*2*Q, d_model), # b, N_C, d_model
791 | ); # b, N_C, d_model
792 | self.embed_feedback=nn.Sequential( # b, K, B
793 | ReshapeLayer([-1, B]), # b*K, B
794 | DequantizationLayer(feedback_bits),
795 | ReshapeLayer([-1, K*B//feedback_bits]), # b, K*B//feedback_bits
796 | nn.Linear(K*B//feedback_bits, N_C*d_model), # b, N_C*d_model
797 | ReshapeLayer([-1, N_C, d_model]), # b, N_C, d_model
798 | ); # b, N_C, d_model
799 | self.backbone=nn.Sequential( # b, N_C, d_model
800 | TransposeLayer(0, 1), # N_C, b, d_model
801 | PositionalEncoding(d_model, 0., N_C),
802 | nn.TransformerEncoder(
803 | encoder_layer=nn.TransformerEncoderLayer(
804 | d_model=d_model,
805 | nhead=8,
806 | dim_feedforward=4*d_model,
807 | dropout=0.,
808 | activation="gelu",
809 | norm_first=False,
810 | ),
811 | num_layers=6,
812 | ),
813 | TransposeLayer(0, 1), # b, N_C, d_model
814 | ); # b, N_C, d_model
815 | self.head=nn.Sequential( # b, N_C, d_model
816 | nn.Linear(d_model, 1*M_com*2), # b, N_C, 1*M_com*2
817 | ReshapeLayer([-1, N_C, 1, M_com, 2]), # b, N_C, 1, M_com, 2
818 | Real2complexLayer(-1), # b, N_C, 1, M_com
819 | ); # b, N_C, 1, M_com
820 |
821 | def forward(
822 | self ,
823 | pilot , # b, N_C, M_com, Q
824 | echo , # b, N_C, 1 , Q
825 | feedback , # b, K, B
826 | ):
827 | b, N_C, M_com, Q=pilot.shape;
828 |
829 | z_pilot=self.embed_pilot(pilot.detach()); # b, N_C, d_model
830 | z_echo=self.embed_echo(echo); # b, N_C, d_model
831 | z_feedback=self.embed_feedback(feedback.detach()); # b, N_C, d_model
832 |
833 |         z=z_pilot+z_echo+z_feedback; # b, N_C, d_model
834 | z=self.backbone(z[:, 0:N_C]);
835 | H_rad_est=self.head(z); # b, N_C, 1, M_com
836 |
837 | return H_rad_est;
838 |
839 |
840 |
841 | class RadCasterNet(nn.Module):
842 | def __init__(
843 | self ,
844 | N_C ,
845 | M_com ,
846 | N_RF_com ,
847 | Q_bsc ,
848 | Q_ehc ,
849 | K ,
850 | B ,
851 | SNR_com ,
852 | SNR_rad ,
853 | d_model ,
854 | ):
855 | super().__init__();
856 |
857 | self.register_buffer("N_C" , torch.tensor(N_C ));
858 | self.register_buffer("M_com" , torch.tensor(M_com ));
859 | self.register_buffer("N_RF_com" , torch.tensor(N_RF_com ));
860 | self.register_buffer("Q_bsc" , torch.tensor(Q_bsc ));
861 | self.register_buffer("Q_ehc" , torch.tensor(Q_ehc ));
862 | self.register_buffer("K" , torch.tensor(K ));
863 | self.register_buffer("B" , torch.tensor(B ));
864 | self.register_buffer("SNR_com" , torch.tensor(SNR_com ));
865 | self.register_buffer("SNR_rad" , torch.tensor(SNR_rad ));
866 | self.register_buffer("d_model" , torch.tensor(d_model ));
867 |
868 | self.register_buffer("feedback", torch.zeros(1, K, B).float());
869 |
870 | self.bsc_pilot_Theta=nn.Parameter(torch.rand (1, Q_bsc, 1 , M_com , N_RF_com)*2*pi);
871 | self.bsc_pilot_BB =nn.Parameter(torch.randn(1, Q_bsc, N_C, N_RF_com, 1 , dtype=torch.complex64));
872 |
873 | self.EPDN=EnhancedPilotDesignNet(
874 | N_C =N_C ,
875 | M_com =M_com ,
876 | N_RF_com =N_RF_com ,
877 | Q_bsc =Q_bsc ,
878 | Q_ehc =Q_ehc ,
879 | K =K ,
880 | B =B ,
881 | SNR_com =SNR_com ,
882 | SNR_rad =SNR_rad ,
883 | d_model =d_model ,
884 | );
885 |
886 | self.CSRN=ChannelSemanticReconstructionNet(
887 | N_C =N_C ,
888 | M_com =M_com ,
889 | N_RF_com =N_RF_com ,
890 | Q_bsc =Q_bsc ,
891 | Q_ehc =Q_ehc ,
892 | K =K ,
893 | B =B ,
894 | SNR_com =SNR_com ,
895 | SNR_rad =SNR_rad ,
896 | d_model =d_model ,
897 | );
898 |
899 | def bsc_pilot(
900 | self ,
901 | ):
902 | return calPilot(
903 | pilot_Theta=self.bsc_pilot_Theta,
904 | pilot_BB =self.bsc_pilot_BB ,
905 | ); # 1, N_C, M_com, Q_bsc
906 |
907 | def forward(
908 | self ,
909 | H_rad , # b, N_C, 1, M_com
910 | ):
911 | b=H_rad.shape[0];
912 |
913 | SNR_rad=self.SNR_rad;
914 |
915 | # stage 1
916 | bsc_pilot=self.bsc_pilot(); # 1, N_C, M_com, Q_bsc
917 |
918 | bsc_echo=H_rad@bsc_pilot; # b, N_C, 1, Q_bsc
919 | bsc_echo=awgn(bsc_echo, 1, bsc_pilot, SNR_rad);
920 |
921 | # stage 2
922 | ehc_pilot=self.EPDN(bsc_pilot, bsc_echo); # b, N_C, M_com, Q_ehc
923 |
924 | ehc_echo=H_rad@ehc_pilot; # b, N_C, 1, Q_ehc
925 | ehc_echo=awgn(ehc_echo, 1, ehc_pilot, SNR_rad);
926 |
927 | # Cat
928 | pilot=torch.cat((bsc_pilot.repeat(b, 1, 1, 1), ehc_pilot), dim=-1); # b, N_C, M_com, Q
929 |         echo=torch.cat((bsc_echo, ehc_echo), dim=-1); # b, N_C, 1, Q
930 |
931 | H_rad_est=self.CSRN(pilot, echo, self.feedback.repeat(b, 1, 1)); # b, N_C, 1, M_com
932 |
933 | return H_rad_est;
934 |
935 |
936 |
937 | class JCASCasterNet(nn.Module):
938 | def __init__(
939 | self ,
940 | N_C ,
941 | M_com ,
942 | N_RF_com ,
943 | Q_bsc ,
944 | Q_ehc ,
945 | K ,
946 | B ,
947 | SNR_com ,
948 | SNR_rad ,
949 | d_model ,
950 | ):
951 | super().__init__();
952 |
953 | self.register_buffer("N_C" , torch.tensor(N_C ));
954 | self.register_buffer("M_com" , torch.tensor(M_com ));
955 | self.register_buffer("N_RF_com" , torch.tensor(N_RF_com ));
956 | self.register_buffer("Q_bsc" , torch.tensor(Q_bsc ));
957 | self.register_buffer("Q_ehc" , torch.tensor(Q_ehc ));
958 | self.register_buffer("K" , torch.tensor(K ));
959 | self.register_buffer("B" , torch.tensor(B ));
960 | self.register_buffer("SNR_com" , torch.tensor(SNR_com ));
961 | self.register_buffer("SNR_rad" , torch.tensor(SNR_rad ));
962 | self.register_buffer("d_model" , torch.tensor(d_model ));
963 |
964 | self.CN=ComCasterNet(
965 | N_C =N_C ,
966 | M_com =M_com ,
967 | N_RF_com =N_RF_com ,
968 | Q_bsc =Q_bsc ,
969 | Q_ehc =Q_ehc ,
970 | K =K ,
971 | B =B ,
972 | SNR_com =SNR_com ,
973 | SNR_rad =SNR_rad ,
974 | d_model =d_model ,
975 | );
976 |
977 | self.SN=RadCasterNet(
978 | N_C =N_C ,
979 | M_com =M_com ,
980 | N_RF_com =N_RF_com ,
981 | Q_bsc =Q_bsc ,
982 | Q_ehc =Q_ehc ,
983 | K =K ,
984 | B =B ,
985 | SNR_com =SNR_com ,
986 | SNR_rad =SNR_rad ,
987 | d_model =d_model ,
988 | );
989 |
990 | def forward(
991 | self ,
992 | H_com , # b, N_C, K, M_com
993 | H_rad , # b, N_C, 1, M_com
994 | ):
995 | b=H_rad.shape[0];
996 |
997 | SNR_com=self.CN.SNR_com;
998 | SNR_rad=self.SN.SNR_rad;
999 |
1000 | # stage 1
1001 | bsc_pilot=self.SN.bsc_pilot(); # 1, N_C, M_com, Q_bsc
1002 |
1003 | bsc_echo=H_rad@bsc_pilot; # b, N_C, 1, Q_bsc
1004 | bsc_echo=awgn(bsc_echo, 1, bsc_pilot, SNR_rad);
1005 |
1006 | bsc_measurement=H_com@bsc_pilot; # b, N_C, K, Q_bsc
1007 | bsc_measurement=awgn(bsc_measurement, 1, bsc_pilot, SNR_com);
1008 |
1009 | # stage 2
1010 | ehc_pilot=self.SN.EPDN(bsc_pilot, bsc_echo); # b, N_C, M_com, Q_ehc
1011 |
1012 | ehc_echo=H_rad@ehc_pilot; # b, N_C, 1, Q_ehc
1013 | ehc_echo=awgn(ehc_echo, 1, ehc_pilot, SNR_rad);
1014 |
1015 | ehc_measurement=H_com@ehc_pilot; # b, N_C, K, Q_ehc
1016 | ehc_measurement=awgn(ehc_measurement, 1, ehc_pilot, SNR_com);
1017 |
1018 | # Feedback
1019 | measurement=torch.cat((bsc_measurement, ehc_measurement), dim=-1); # b, N_C, K, Q
1020 | feedback=self.CN.SEnc(measurement); # b, K, B
1021 |
1022 | # Cat
1023 | pilot=torch.cat((bsc_pilot.repeat(b, 1, 1, 1), ehc_pilot), dim=-1); # b, N_C, M_com, Q
1024 |         echo=torch.cat((bsc_echo, ehc_echo), dim=-1); # b, N_C, 1, Q
1025 |
1026 | # Com
1027 | F=self.CN.SDec(pilot, echo, feedback); # b, N_C, M_com, N_RF_com
1028 |
1029 | # Rad
1030 | H_rad_est=self.SN.CSRN(pilot, echo, feedback); # b, N_C, 1, M_com
1031 |
1032 | return F, H_rad_est;
1033 |
1034 |
1035 |
1036 | class MUSIC(nn.Module):
1037 | def __init__(
1038 | self,
1039 | ):
1040 | super().__init__();
1041 |
1042 | def forward(
1043 | self ,
1044 | H , # b, N_C, 1, M
1045 | eps , # [deg]
1046 | ):
1047 | with torch.no_grad():
1048 | b, N_C, _, M=H.shape;
1049 | device=H.device;
1050 | L=6;
1051 |
1052 | X=H.squeeze().transpose(-2, -1); # b, M, N_C
1053 | if b==1:
1054 | X=X.unsqueeze(0);
1055 |
1056 | R=X@X.mH; # b, M, M
1057 | e, V=eig(R);
1058 | idx=e.abs().argsort(dim=-1, descending=True); # b, M
1059 |
1060 | idx_n=idx[:, L:].unsqueeze(-2).repeat(1, M, 1); # b, M, M-L
1061 | V_n=V.gather( # b, M, M
1062 | dim=-1,
1063 | index=idx_n,
1064 | ); # b, M, M-L
1065 |
1066 | R_n=V_n@V_n.mH; # b, M, M
1067 | R_n=R_n.unsqueeze(1); # b, 1, M, M
1068 |
1069 | theta=torch.arange(0, 180, eps).to(device); # [deg]
1070 | theta=theta/180*pi; # [rad]
1071 | theta=torch.cos(theta)*pi;
1072 | theta=theta.unsqueeze(0).unsqueeze(-1); # 1, Num, 1
1073 | rg=torch.arange(M).to(device).reshape(1, 1, M).float(); # 1, 1, M
1074 | A=torch.exp(1j*(theta@rg)).unsqueeze(-1); # 1, Num, M, 1
1075 |
1076 | P=1/(A.mH@R_n@A).squeeze(); # b, Num
1077 |
1078 | return P.real;
1079 |
1080 |
1081 |
1082 |
1083 |
1084 |
1085 |
1086 |
1087 |
1088 |
1089 |
1090 |
1091 |
1092 |
1093 |
1094 |
1095 |
1096 |
1097 |
1098 |
1099 |
--------------------------------------------------------------------------------
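A round-trip sanity check (a sketch, not part of the repository) for the B-bit feedback quantization defined in `model.py` (`feedback_bits=2` there): values in [0, 1) are mapped to bit vectors and back, so the reconstruction error stays below one quantization step:

```python
import torch
from model import QuantizationLayer, DequantizationLayer

B = 2
x = torch.rand(4, 8)                  # 4 feedback vectors with 8 entries each, in [0, 1)
bits = QuantizationLayer(B)(x)        # shape (4, 8*B), entries in {0., 1.}
x_hat = DequantizationLayer(B)(bits)  # shape (4, 8), dequantized values in (0, 1)
print((x - x_hat).abs().max())        # stays below one quantization step (2**-B = 0.25)
```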
/train/my_log.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import logging
4 |
5 | def logger(s):
6 | tmp=logging.getLogger('log');
7 | if not tmp.handlers:
8 | tmp.setLevel(level=logging.INFO);
9 | handler=logging.FileHandler("./log.txt");
10 | handler.setLevel(logging.INFO);
11 | formatter=logging.Formatter('[%(asctime)s]: %(message)s');
12 | handler.setFormatter(formatter);
13 |
14 | console=logging.StreamHandler();
15 | console.setLevel(logging.INFO);
16 | console.setFormatter(formatter);
17 |
18 | tmp.addHandler(handler);
19 | tmp.addHandler(console);
20 |
21 | tmp.info(s);
22 |
23 | def rmLogger():
24 | logging.shutdown();
--------------------------------------------------------------------------------
/train/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn
6 | from torch import pi
7 | from torch.fft import fft, ifft
8 |
9 | import numpy as np
10 | from numpy.random import choice
11 |
12 | import os
13 | import random
14 | import shutil
15 | from scipy import io
16 | from math import sqrt, log10
17 |
18 | import matplotlib.pyplot as plt
19 | from mpl_toolkits.mplot3d import Axes3D
20 | from matplotlib.pyplot import stem,plot
21 | from math import log, log2, log10, sqrt
22 |
23 | from my_log import logger, rmLogger
24 |
25 |
26 |
27 | def SELoss(
28 | H, # b, N_C, K, M
29 | F, # (b,) N_C, M, K
30 | SNR,
31 | ):
32 | SNR=10**(SNR/10);
33 |
34 | Y=H@F; # b, N_C, K, K
35 | b, N_C, K, _=Y.shape;
36 |
37 | power_noise=calMeanPower(F)/SNR;
38 |
39 | Y=Y.reshape(b*N_C, K, K).abs().square(); # b*N_C, K, K
40 | idx=torch.arange(K);
41 | power_signal=Y[:, idx, idx]; # b*N_C, K
42 | power_inter=Y.sum(-1)-power_signal; # b*N_C, K
43 | SINR=power_signal/(power_inter+power_noise);
44 | res=torch.log(1+SINR)/log(2);
45 |
46 | return -1*res.mean(); # 1
47 |
48 |
49 |
50 | def RMSELoss(
51 | x,
52 | y,
53 | ):
54 | b=x.shape[0];
55 |
56 | x=x.reshape(b, -1);
57 | y=y.reshape(b, -1);
58 |
59 | res=(x-y).abs().square().mean(-1).sqrt().mean();
60 | return res;
61 |
62 |
63 |
64 | def NMSELoss(
65 | x,
66 | y,
67 | ):
68 | b=x.shape[0];
69 |
70 | x=x.reshape(b, -1);
71 | y=y.reshape(b, -1);
72 |
73 | res=(x-y).abs().square().mean(-1);
74 | res=res/y.abs().square().mean(-1);
75 |
76 | return res.mean();
77 |
78 |
79 |
80 | def CosLoss(
81 | x, # b, N_C, 1, M
82 | y, # b, N_C, 1, M
83 | ):
84 | tmp_xy=(x@y.mH).abs(); # b, N_C, 1, 1
85 | tmp_xx=(x@x.mH).abs().sqrt(); # b, N_C, 1, 1
86 | tmp_yy=(y@y.mH).abs().sqrt(); # b, N_C, 1, 1
87 |
88 | cos=tmp_xy/tmp_xx/tmp_yy;
89 |
90 | return -cos.mean();
91 |
92 |
93 |
94 | def theta2label(
95 | theta,
96 | grid_num,
97 | ):
98 | b=theta.shape[0];
99 |
100 | idx=torch.round((theta/180)*grid_num).long()[:, :, 0];
101 | label=torch.zeros(b, grid_num);
102 | label.scatter_(
103 | dim=-1,
104 | index=idx,
105 | src=torch.ones_like(idx).float(),
106 | );
107 |
108 | return label;
109 |
110 |
111 | def calMeanPower(x: torch.Tensor):
112 | return x.abs().square().mean();
113 |
114 |
115 |
116 | def awgn(
117 | x,
118 | ofdm_dim_idx,
119 | s,
120 | SNR,
121 | ):
122 | device=x.device;
123 | SNR=10**(SNR/10);
124 |
125 | npower=calMeanPower(s)/SNR;
126 | noise=torch.randn(*x.shape, dtype=torch.complex64).to(device)*torch.sqrt(npower);
127 |
128 | return x+fft(noise, dim=ofdm_dim_idx)/sqrt(noise.shape[ofdm_dim_idx]);
129 |
130 |
131 |
132 | def env_init(seed: int):
133 | random.seed(seed);
134 |     os.environ['PYTHONHASHSEED']=str(seed);
135 | np.random.seed(seed);
136 | torch.manual_seed(seed);
137 | torch.cuda.manual_seed(seed);
138 | torch.backends.cudnn.deterministic=True;
139 |
140 |
141 |
142 | def main():
143 | pass;
144 |
145 |
146 |
147 | if __name__=="__main__":
148 | logger("utils.py");
149 | main();
150 | rmLogger();
--------------------------------------------------------------------------------
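A toy usage sketch (not part of the repository) for the loss helpers above, assuming the plotting and logging dependencies imported at the top of `utils.py` are installed; the shapes follow the comments in the code, and for a tensor compared against itself `CosLoss` reaches its minimum of -1 while `NMSELoss` is 0:

```python
import torch
from utils import SELoss, CosLoss, NMSELoss

b, N_C, K, M = 4, 8, 2, 16
H = torch.randn(b, N_C, K, M, dtype=torch.complex64)  # downlink channels
F = torch.randn(b, N_C, M, K, dtype=torch.complex64)  # precoders
print(SELoss(H, F, 10))        # negative mean spectral efficiency at 10 dB SNR (minimized during training)

H_rad = torch.randn(b, N_C, 1, M, dtype=torch.complex64)
print(CosLoss(H_rad, H_rad))   # tensor(-1.) up to rounding
print(NMSELoss(H_rad, H_rad))  # tensor(0.)
```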