├── LICENSE
├── README.md
├── checkpoint
│   ├── filled
│   │   ├── best_model_checkpoint.pth
│   │   └── final_model.pth
│   └── sparse
│       ├── best_model_checkpoint.pth
│       └── final_model.pth
├── conda_env.txt
├── data
│   ├── train
│   │   ├── label
│   │   │   ├── labels_snr_0dB.pt
│   │   │   ├── labels_snr_10dB.pt
│   │   │   ├── labels_snr_15dB.pt
│   │   │   ├── labels_snr_20dB.pt
│   │   │   ├── labels_snr_25dB.pt
│   │   │   ├── labels_snr_30dB.pt
│   │   │   └── labels_snr_5dB.pt
│   │   └── signal
│   │       ├── signals_snr_0dB.pt
│   │       ├── signals_snr_10dB.pt
│   │       ├── signals_snr_15dB.pt
│   │       ├── signals_snr_20dB.pt
│   │       ├── signals_snr_25dB.pt
│   │       ├── signals_snr_30dB.pt
│   │       └── signals_snr_5dB.pt
│   └── val
│       ├── label
│       │   ├── labels_snr_0dB.pt
│       │   ├── labels_snr_10dB.pt
│       │   ├── labels_snr_15dB.pt
│       │   ├── labels_snr_20dB.pt
│       │   ├── labels_snr_25dB.pt
│       │   ├── labels_snr_30dB.pt
│       │   └── labels_snr_5dB.pt
│       └── signal
│           ├── signals_snr_0dB.pt
│           ├── signals_snr_10dB.pt
│           ├── signals_snr_15dB.pt
│           ├── signals_snr_20dB.pt
│           ├── signals_snr_25dB.pt
│           ├── signals_snr_30dB.pt
│           └── signals_snr_5dB.pt
├── fig
│   ├── Accuracy1_SLA.png
│   ├── Accuracy1_ULA.png
│   ├── Accuracy2_SLA.png
│   ├── Accuracy2_ULA.png
│   ├── Example_SLA.png
│   ├── Example_SLA_real.png
│   ├── Example_ULA.png
│   ├── Example_ULA_real.png
│   ├── Network.png
│   ├── Separate_SLA.png
│   └── Separate_ULA.png
├── models
│   ├── SADOANet.py
│   └── __pycache__
│       └── SADOANet.cpython-312.pyc
├── real_World_DOA_dataset
│   ├── README.md
│   ├── data.mat
│   └── fig
│       ├── DOA data.png
│       ├── example.gif
│       ├── multiExamples.png
│       └── platform.png
├── scr
│   ├── __pycache__
│   │   ├── eval_fun.cpython-312.pyc
│   │   ├── helpers.cpython-311.pyc
│   │   └── helpers.cpython-312.pyc
│   ├── dataset_gen.py
│   ├── eval_fun.py
│   ├── helpers.py
│   ├── realData_demo.mat
│   └── run_eval.py
└── train.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 ruxin
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep-Learning-Enabled-Robust-DOA-Estimation-with-Single-Snapshot-Sparse-Arrays
2 | This is the code for the paper "Antenna Failure Resilience: Deep Learning-Enabled Robust DOA Estimation with Single Snapshot Sparse Arrays".
3 |
4 | ## Simulated dataset generation for training and validation
5 | ``` sh
6 | python scr/dataset_gen.py --output_dir './' --num_samples_val 1024 --num_samples_train 100000 --N 10 --max_targets 3
7 | ```
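To sanity-check the generated files, they can be loaded directly with `torch.load` (a minimal sketch; the exact shapes depend on `helpers.generate_data` and the arguments above):

```python
import torch

# Hypothetical inspection of one validation split.
signals = torch.load('data/val/signal/signals_snr_30dB.pt')
labels = torch.load('data/val/label/labels_snr_30dB.pt')
print(signals.shape, signals.dtype)  # expected: (num_samples_val, N), complex
print(labels.shape, labels.dtype)    # expected: (num_samples_val, 61), float
```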
8 |
9 | ## Network architectures
10 | 
11 | (Figure: overall network architecture, see fig/Network.png)
12 | 
13 | 
14 | 
15 | ## Training
16 | Without sparse augmentation model:
17 |
18 | ```sh
19 | python train.py --data_path './data' --checkpoint_path './checkpoint' --number_elements 10 --output_size 61 --sparsity 0.3 --use_sparse False --learning_rate 0.0001 --batch_size 1024 --epochs 300
20 | ```
21 |
22 | With sparse augmentation model:
23 |
24 | ``` sh
25 | python train.py --data_path './data' --checkpoint_path './checkpoint' --number_elements 10 --output_size 61 --sparsity 0.3 --use_sparse True --learning_rate 0.0001 --batch_size 1024 --epochs 300
26 | ```
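Training writes `best_model_checkpoint.pth` and `final_model.pth` into the checkpoint directory. The provided checkpoints were saved from an `nn.DataParallel` model, so the `"module."` key prefix must be stripped when loading them into a plain model; a minimal loading sketch mirroring `load_sadoanet` in `scr/eval_fun.py`:

```python
import torch
from collections import OrderedDict
from models.SADOANet import SADOANet

# Load the sparse-augmented model on CPU; map_location avoids requiring a GPU.
model = SADOANet(number_element=10, output_size=61, max_sparsity=0.3, is_sparse=True)
state_dict = torch.load('checkpoint/sparse/best_model_checkpoint.pth', map_location='cpu')
model.load_state_dict(OrderedDict((k.replace('module.', ''), v) for k, v in state_dict.items()))
model.eval()
```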
27 |
28 | ## Evaluation
29 | Evaluation can be run immediately using the pretrained weights we provide in the 'checkpoint' directory.
30 | Before running the following steps, change into the directory that contains the evaluation scripts:
31 |
32 | ``` sh
33 | cd scr
34 | ```
35 |
36 | ### Single-target accuracy
37 |
38 | ``` sh
39 | python run_eval.py --num_simulations 1000 --num_antennas 10 --evaluation_mode 'accuracy1'
40 | ```
41 |
42 | Expected outputs: ULA (left), SLA (right)
43 | 
44 | (Figures: fig/Accuracy1_ULA.png, fig/Accuracy1_SLA.png)
45 | 
46 | 
47 | 
48 | ### Two-target accuracy
49 |
50 | ``` sh
51 | python run_eval.py --num_simulations 1000 --num_antennas 10 --evaluation_mode 'accuracy2'
52 | ```
53 |
54 | Expected outputs: ULA (left), SLA (right)
55 | 
56 | (Figures: fig/Accuracy2_ULA.png, fig/Accuracy2_SLA.png)
57 | 
58 | 
59 | 
60 | ### Separability
61 |
62 | ``` sh
63 | python run_eval.py --num_simulations 1000 --num_antennas 10 --evaluation_mode 'separate'
64 | ```
65 |
66 | Expected outputs: ULA (left), SLA (right)
67 | 
68 | (Figures: fig/Separate_ULA.png, fig/Separate_SLA.png)
69 | 
70 | 
71 | 
72 | ### Complexity
73 | ``` sh
74 | python run_eval.py --num_simulations 1000 --num_antennas 10 --evaluation_mode 'complexity'
75 | ```
76 | Expected outputs:
77 |
78 | Total trainable parameters in MLP model: 2848829
79 |
80 | Total trainable parameters in Ours model: 4106301
81 |
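These counts follow from the layer sizes in `models/SADOANet.py`: DOANet's hidden stack is [2048, 1024, 512, 256, 128], and SALayer's 2N-to-512 embedding is constructed in both variants (only the MLP input size differs). A minimal sketch reproducing the numbers:

```python
def linear_params(n_in, n_out):
    # weights plus biases of one nn.Linear(n_in, n_out)
    return n_in * n_out + n_out

hidden = [2048, 1024, 512, 256, 128]

def doanet_params(input_size, output_size=61):
    sizes = [input_size] + hidden + [output_size]
    return sum(linear_params(a, b) for a, b in zip(sizes, sizes[1:]))

salayer = linear_params(10 * 2, 512)          # SALayer embedding, built in both models
print(doanet_params(10 * 2) + salayer)        # 2848829 (MLP: raw 2N-dim input)
print(doanet_params(512 + 61 * 2) + salayer)  # 4106301 (Ours: concatenated features)
```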
82 | ### Results examples
83 | #### With simulated data
84 |
85 | ``` sh
86 | python run_eval.py --evaluation_mode 'examples'
87 | ```
88 |
89 | Expected outputs:
90 | 
91 | (Figures: fig/Example_ULA.png, fig/Example_SLA.png)
92 | 
93 | 
94 | 
95 | 
96 | 
97 | 
98 | #### With real world data
99 |
100 | ``` sh
101 | python run_eval.py --evaluation_mode 'examples' --real True
102 | ```
103 |
104 | Expected outputs:
105 | 
106 | (Figures: fig/Example_ULA_real.png, fig/Example_SLA_real.png)
107 | 
108 | 
109 | 
110 | 
111 | ## Real World dataset
112 | Please refer to the README in the folder 'real_World_DOA_dataset'.
113 |
114 | ## Environment
115 | The Conda environment for this project is specified in 'conda_env.txt', which lists all required Python packages and their versions to ensure compatibility and reproducibility.
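For example (the environment name `doa` is arbitrary):

```sh
conda create --name doa --file conda_env.txt
conda activate doa
```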
116 |
117 | If you find this project helpful for your research, please consider citing:
118 |
119 | ```BibTex
120 | @article{zheng2024antenna,
121 | title={Antenna Failure Resilience: Deep Learning-Enabled Robust DOA Estimation with Single Snapshot Sparse Arrays},
122 | author={Zheng, Ruxin and Sun, Shunqiao and Liu, Hongshan and Chen, Honglei and Soltanalian, Mojtaba and Li, Jian},
123 | journal={arXiv preprint arXiv:2405.02788},
124 | year={2024}
125 | }
126 | ```
127 |
128 |
129 |
130 |
131 |
132 |
133 |
--------------------------------------------------------------------------------
/checkpoint/filled/best_model_checkpoint.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/checkpoint/filled/best_model_checkpoint.pth
--------------------------------------------------------------------------------
/checkpoint/filled/final_model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/checkpoint/filled/final_model.pth
--------------------------------------------------------------------------------
/checkpoint/sparse/best_model_checkpoint.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/checkpoint/sparse/best_model_checkpoint.pth
--------------------------------------------------------------------------------
/checkpoint/sparse/final_model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/checkpoint/sparse/final_model.pth
--------------------------------------------------------------------------------
/conda_env.txt:
--------------------------------------------------------------------------------
1 | # This file may be used to create an environment using:
2 | # $ conda create --name <env> --file <this file>
3 | # platform: linux-64
4 | _libgcc_mutex=0.1=main
5 | _openmp_mutex=5.1=1_gnu
6 | alabaster=0.7.12=pyhd3eb1b0_0
7 | arrow=1.2.3=py312h06a4308_1
8 | astroid=2.14.2=py312h06a4308_0
9 | asttokens=2.0.5=pyhd3eb1b0_0
10 | atomicwrites=1.4.0=py_0
11 | attrs=23.1.0=py312h06a4308_0
12 | autopep8=2.0.4=pyhd3eb1b0_0
13 | babel=2.11.0=py312h06a4308_0
14 | beautifulsoup4=4.12.2=py312h06a4308_0
15 | binaryornot=0.4.4=pyhd3eb1b0_1
16 | black=24.3.0=py312h06a4308_0
17 | blas=1.0=mkl
18 | bleach=4.1.0=pyhd3eb1b0_0
19 | brotli=1.0.9=h5eee18b_8
20 | brotli-bin=1.0.9=h5eee18b_8
21 | brotli-python=1.0.9=py312h6a678d5_8
22 | bzip2=1.0.8=h5eee18b_6
23 | ca-certificates=2024.3.11=h06a4308_0
24 | certifi=2024.2.2=py312h06a4308_0
25 | cffi=1.16.0=py312h5eee18b_1
26 | chardet=4.0.0=py312h06a4308_1003
27 | charset-normalizer=2.0.4=pyhd3eb1b0_0
28 | click=8.1.7=py312h06a4308_0
29 | cloudpickle=2.2.1=py312h06a4308_0
30 | colorama=0.4.6=py312h06a4308_0
31 | comm=0.2.1=py312h06a4308_0
32 | contourpy=1.2.0=py312hdb19cb5_0
33 | cookiecutter=2.6.0=py312h06a4308_0
34 | cryptography=42.0.5=py312hdda0065_0
35 | cuda-cudart=11.8.89=0
36 | cuda-cupti=11.8.87=0
37 | cuda-libraries=11.8.0=0
38 | cuda-nvcc=12.4.131=0
39 | cuda-nvrtc=11.8.89=0
40 | cuda-nvtx=11.8.86=0
41 | cuda-runtime=11.8.0=0
42 | cudatoolkit=11.8.0=h6a678d5_0
43 | cycler=0.11.0=pyhd3eb1b0_0
44 | cyrus-sasl=2.1.28=h52b45da_1
45 | dbus=1.13.18=hb2f20db_0
46 | debugpy=1.6.7=py312h6a678d5_0
47 | decorator=5.1.1=pyhd3eb1b0_0
48 | defusedxml=0.7.1=pyhd3eb1b0_0
49 | diff-match-patch=20200713=pyhd3eb1b0_0
50 | dill=0.3.7=py312h06a4308_0
51 | docstring-to-markdown=0.11=py312h06a4308_0
52 | docutils=0.18.1=py312h06a4308_3
53 | executing=0.8.3=pyhd3eb1b0_0
54 | expat=2.6.2=h6a678d5_0
55 | ffmpeg=4.3=hf484d3e_0
56 | filelock=3.13.1=py312h06a4308_0
57 | flake8=7.0.0=py312h06a4308_0
58 | fontconfig=2.14.1=h4c34cd2_2
59 | fonttools=4.51.0=py312h5eee18b_0
60 | freetype=2.12.1=h4a9f257_0
61 | glib=2.78.4=h6a678d5_0
62 | glib-tools=2.78.4=h6a678d5_0
63 | gmp=6.2.1=h295c915_3
64 | gnutls=3.6.15=he1e5248_0
65 | gst-plugins-base=1.14.1=h6a678d5_1
66 | gstreamer=1.14.1=h5eee18b_1
67 | icu=73.1=h6a678d5_0
68 | idna=3.7=py312h06a4308_0
69 | imagesize=1.4.1=py312h06a4308_0
70 | importlib-metadata=7.0.1=py312h06a4308_0
71 | inflection=0.5.1=py312h06a4308_1
72 | intel-openmp=2023.1.0=hdb19cb5_46306
73 | intervaltree=3.1.0=pyhd3eb1b0_0
74 | ipykernel=6.28.0=py312h06a4308_0
75 | ipython=8.20.0=py312h06a4308_0
76 | isort=5.9.3=pyhd3eb1b0_0
77 | jaraco.classes=3.2.1=pyhd3eb1b0_0
78 | jedi=0.18.1=py312h06a4308_1
79 | jeepney=0.7.1=pyhd3eb1b0_0
80 | jellyfish=1.0.1=py312hb02cf49_0
81 | jinja2=3.1.3=py312h06a4308_0
82 | jpeg=9e=h5eee18b_1
83 | jsonschema=4.19.2=py312h06a4308_0
84 | jsonschema-specifications=2023.7.1=py312h06a4308_0
85 | jupyter_client=8.6.0=py312h06a4308_0
86 | jupyter_core=5.5.0=py312h06a4308_0
87 | jupyterlab_pygments=0.2.2=py312h06a4308_0
88 | keyring=24.3.1=py312h06a4308_0
89 | kiwisolver=1.4.4=py312h6a678d5_0
90 | krb5=1.20.1=h143b758_1
91 | lame=3.100=h7b6447c_0
92 | lazy-object-proxy=1.10.0=py312h5eee18b_0
93 | lcms2=2.12=h3be6417_0
94 | ld_impl_linux-64=2.38=h1181459_1
95 | lerc=3.0=h295c915_0
96 | libbrotlicommon=1.0.9=h5eee18b_8
97 | libbrotlidec=1.0.9=h5eee18b_8
98 | libbrotlienc=1.0.9=h5eee18b_8
99 | libclang=14.0.6=default_hc6dbbc7_1
100 | libclang13=14.0.6=default_he11475f_1
101 | libcublas=11.11.3.6=0
102 | libcufft=10.9.0.58=0
103 | libcufile=1.9.1.3=0
104 | libcups=2.4.2=h2d74bed_1
105 | libcurand=10.3.5.147=0
106 | libcusolver=11.4.1.48=0
107 | libcusparse=11.7.5.86=0
108 | libdeflate=1.17=h5eee18b_1
109 | libedit=3.1.20230828=h5eee18b_0
110 | libevent=2.1.12=hdbd6064_1
111 | libffi=3.4.4=h6a678d5_1
112 | libgcc-ng=11.2.0=h1234567_1
113 | libgfortran-ng=11.2.0=h00389a5_1
114 | libgfortran5=11.2.0=h1234567_1
115 | libglib=2.78.4=hdc74915_0
116 | libgomp=11.2.0=h1234567_1
117 | libiconv=1.16=h5eee18b_3
118 | libidn2=2.3.4=h5eee18b_0
119 | libjpeg-turbo=2.0.0=h9bf148f_0
120 | libllvm14=14.0.6=hdb19cb5_3
121 | libnpp=11.8.0.86=0
122 | libnvjpeg=11.9.0.86=0
123 | libpng=1.6.39=h5eee18b_0
124 | libpq=12.17=hdbd6064_0
125 | libsodium=1.0.18=h7b6447c_0
126 | libspatialindex=1.9.3=h2531618_0
127 | libstdcxx-ng=11.2.0=h1234567_1
128 | libtasn1=4.19.0=h5eee18b_0
129 | libtiff=4.5.1=h6a678d5_0
130 | libunistring=0.9.10=h27cfd23_0
131 | libuuid=1.41.5=h5eee18b_0
132 | libwebp-base=1.3.2=h5eee18b_0
133 | libxcb=1.15=h7f8727e_0
134 | libxkbcommon=1.0.1=h5eee18b_1
135 | libxml2=2.10.4=hfdd30dd_2
136 | llvm-openmp=14.0.6=h9e868ea_0
137 | lz4-c=1.9.4=h6a678d5_1
138 | markdown-it-py=2.2.0=py312h06a4308_1
139 | markupsafe=2.1.3=py312h5eee18b_0
140 | matplotlib=3.8.4=py312h06a4308_0
141 | matplotlib-base=3.8.4=py312h526ad5a_0
142 | matplotlib-inline=0.1.6=py312h06a4308_0
143 | mccabe=0.7.0=pyhd3eb1b0_0
144 | mdurl=0.1.0=py312h06a4308_0
145 | mistune=2.0.4=py312h06a4308_0
146 | mkl=2023.1.0=h213fc3f_46344
147 | mkl-service=2.4.0=py312h5eee18b_1
148 | mkl_fft=1.3.8=py312h5eee18b_0
149 | mkl_random=1.2.4=py312hdb19cb5_0
150 | more-itertools=10.1.0=py312h06a4308_0
151 | mpmath=1.3.0=py312h06a4308_0
152 | mypy_extensions=1.0.0=py312h06a4308_0
153 | mysql=5.7.24=h721c034_2
154 | nbclient=0.8.0=py312h06a4308_0
155 | nbconvert=7.10.0=py312h06a4308_0
156 | nbformat=5.9.2=py312h06a4308_0
157 | ncurses=6.4=h6a678d5_0
158 | nest-asyncio=1.6.0=py312h06a4308_0
159 | nettle=3.7.3=hbbd107a_1
160 | networkx=3.1=py312h06a4308_0
161 | nspr=4.35=h6a678d5_0
162 | nss=3.89.1=h6a678d5_0
163 | numpy=1.26.4=py312hc5e2394_0
164 | numpy-base=1.26.4=py312h0da6c21_0
165 | numpydoc=1.5.0=py312h06a4308_0
166 | openh264=2.1.1=h4ff587b_0
167 | openjpeg=2.4.0=h3ad879b_0
168 | openssl=3.0.13=h7f8727e_1
169 | packaging=23.2=py312h06a4308_0
170 | pandocfilters=1.5.0=pyhd3eb1b0_0
171 | parso=0.8.3=pyhd3eb1b0_0
172 | pathspec=0.10.3=py312h06a4308_0
173 | pcre2=10.42=hebb0a14_0
174 | pexpect=4.8.0=pyhd3eb1b0_3
175 | pickleshare=0.7.5=pyhd3eb1b0_1003
176 | pillow=10.3.0=py312h5eee18b_0
177 | pip=23.3.1=py312h06a4308_0
178 | platformdirs=3.10.0=py312h06a4308_0
179 | pluggy=1.0.0=py312h06a4308_1
180 | ply=3.11=py312h06a4308_1
181 | prompt-toolkit=3.0.43=py312h06a4308_0
182 | prompt_toolkit=3.0.43=hd3eb1b0_0
183 | psutil=5.9.0=py312h5eee18b_0
184 | ptyprocess=0.7.0=pyhd3eb1b0_2
185 | pure_eval=0.2.2=pyhd3eb1b0_0
186 | pybind11-abi=5=hd3eb1b0_0
187 | pycodestyle=2.11.1=py312h06a4308_0
188 | pycparser=2.21=pyhd3eb1b0_0
189 | pydocstyle=6.3.0=py312h06a4308_0
190 | pyflakes=3.2.0=py312h06a4308_0
191 | pygments=2.15.1=py312h06a4308_1
192 | pylint=2.16.2=py312h06a4308_0
193 | pylint-venv=3.0.3=py312h06a4308_0
194 | pyls-spyder=0.4.0=pyhd3eb1b0_0
195 | pyparsing=3.0.9=py312h06a4308_0
196 | pyqt=5.15.10=py312h6a678d5_0
197 | pyqt5-sip=12.13.0=py312h5eee18b_0
198 | pyqtwebengine=5.15.10=py312h6a678d5_0
199 | pysocks=1.7.1=py312h06a4308_0
200 | python=3.12.3=h996f2a0_0
201 | python-dateutil=2.8.2=pyhd3eb1b0_0
202 | python-fastjsonschema=2.16.2=py312h06a4308_0
203 | python-lsp-black=2.0.0=py312h06a4308_0
204 | python-lsp-jsonrpc=1.1.2=pyhd3eb1b0_0
205 | python-lsp-server=1.10.0=py312h06a4308_0
206 | python-slugify=5.0.2=pyhd3eb1b0_0
207 | pytoolconfig=1.2.6=py312h06a4308_0
208 | pytorch=2.3.0=py3.12_cuda11.8_cudnn8.7.0_0
209 | pytorch-cuda=11.8=h7e8668a_5
210 | pytorch-mutex=1.0=cuda
211 | pytz=2024.1=py312h06a4308_0
212 | pyxdg=0.27=pyhd3eb1b0_0
213 | pyyaml=6.0.1=py312h5eee18b_0
214 | pyzmq=25.1.2=py312h6a678d5_0
215 | qdarkstyle=3.2.3=pyhd3eb1b0_0
216 | qstylizer=0.2.2=py312h06a4308_0
217 | qt-main=5.15.2=h53bd1ea_10
218 | qt-webengine=5.15.9=h9ab4d14_7
219 | qtawesome=1.2.2=py312h06a4308_0
220 | qtconsole=5.5.1=py312h06a4308_0
221 | qtpy=2.4.1=py312h06a4308_0
222 | readline=8.2=h5eee18b_0
223 | referencing=0.30.2=py312h06a4308_0
224 | requests=2.31.0=py312h06a4308_1
225 | rich=13.3.5=py312h06a4308_1
226 | rope=1.12.0=py312h06a4308_0
227 | rpds-py=0.10.6=py312hb02cf49_0
228 | rtree=1.0.1=py312h06a4308_0
229 | scipy=1.13.0=py312hc5e2394_0
230 | secretstorage=3.3.1=py312h06a4308_1
231 | setuptools=68.2.2=py312h06a4308_0
232 | sip=6.7.12=py312h6a678d5_0
233 | six=1.16.0=pyhd3eb1b0_1
234 | snowballstemmer=2.2.0=pyhd3eb1b0_0
235 | sortedcontainers=2.4.0=pyhd3eb1b0_0
236 | soupsieve=2.5=py312h06a4308_0
237 | sphinx=5.0.2=py312h06a4308_0
238 | sphinxcontrib-applehelp=1.0.2=pyhd3eb1b0_0
239 | sphinxcontrib-devhelp=1.0.2=pyhd3eb1b0_0
240 | sphinxcontrib-htmlhelp=2.0.0=pyhd3eb1b0_0
241 | sphinxcontrib-jsmath=1.0.1=pyhd3eb1b0_0
242 | sphinxcontrib-qthelp=1.0.3=pyhd3eb1b0_0
243 | sphinxcontrib-serializinghtml=1.1.5=pyhd3eb1b0_0
244 | spyder=5.5.1=py312h06a4308_0
245 | spyder-kernels=2.5.0=py312h06a4308_0
246 | sqlite=3.45.3=h5eee18b_0
247 | stack_data=0.2.0=pyhd3eb1b0_0
248 | sympy=1.12=py312h06a4308_0
249 | tbb=2021.8.0=hdb19cb5_0
250 | text-unidecode=1.3=pyhd3eb1b0_0
251 | textdistance=4.2.1=pyhd3eb1b0_0
252 | three-merge=0.1.1=pyhd3eb1b0_0
253 | tinycss2=1.2.1=py312h06a4308_0
254 | tk=8.6.12=h1ccaba5_0
255 | tomli=2.0.1=py312h06a4308_1
256 | tomlkit=0.11.1=py312h06a4308_0
257 | torchaudio=2.3.0=py312_cu118
258 | torchvision=0.18.0=py312_cu118
259 | tornado=6.3.3=py312h5eee18b_0
260 | tqdm=4.66.2=py312he106c6f_0
261 | traitlets=5.7.1=py312h06a4308_0
262 | typing_extensions=4.9.0=py312h06a4308_1
263 | tzdata=2024a=h04d1e81_0
264 | ujson=5.4.0=py312h6a678d5_0
265 | unicodedata2=15.1.0=py312h5eee18b_0
266 | unidecode=1.2.0=pyhd3eb1b0_0
267 | urllib3=2.1.0=py312h06a4308_1
268 | watchdog=2.1.6=py312h06a4308_0
269 | wcwidth=0.2.5=pyhd3eb1b0_0
270 | webencodings=0.5.1=py312h06a4308_2
271 | whatthepatch=1.0.2=py312h06a4308_0
272 | wheel=0.41.2=py312h06a4308_0
273 | wrapt=1.14.1=py312h5eee18b_0
274 | wurlitzer=3.0.2=py312h06a4308_0
275 | xz=5.4.6=h5eee18b_1
276 | yaml=0.2.5=h7b6447c_0
277 | yapf=0.40.2=py312h06a4308_0
278 | zeromq=4.3.5=h6a678d5_0
279 | zipp=3.17.0=py312h06a4308_0
280 | zlib=1.2.13=h5eee18b_1
281 | zstd=1.5.5=hc292b87_1
282 |
--------------------------------------------------------------------------------
/data/train/label/labels_snr_0dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/train/label/labels_snr_0dB.pt
--------------------------------------------------------------------------------
/data/train/label/labels_snr_10dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/train/label/labels_snr_10dB.pt
--------------------------------------------------------------------------------
/data/train/label/labels_snr_15dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/train/label/labels_snr_15dB.pt
--------------------------------------------------------------------------------
/data/train/label/labels_snr_20dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/train/label/labels_snr_20dB.pt
--------------------------------------------------------------------------------
/data/train/label/labels_snr_25dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/train/label/labels_snr_25dB.pt
--------------------------------------------------------------------------------
/data/train/label/labels_snr_30dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/train/label/labels_snr_30dB.pt
--------------------------------------------------------------------------------
/data/train/label/labels_snr_5dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/train/label/labels_snr_5dB.pt
--------------------------------------------------------------------------------
/data/train/signal/signals_snr_0dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/train/signal/signals_snr_0dB.pt
--------------------------------------------------------------------------------
/data/train/signal/signals_snr_10dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/train/signal/signals_snr_10dB.pt
--------------------------------------------------------------------------------
/data/train/signal/signals_snr_15dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/train/signal/signals_snr_15dB.pt
--------------------------------------------------------------------------------
/data/train/signal/signals_snr_20dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/train/signal/signals_snr_20dB.pt
--------------------------------------------------------------------------------
/data/train/signal/signals_snr_25dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/train/signal/signals_snr_25dB.pt
--------------------------------------------------------------------------------
/data/train/signal/signals_snr_30dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/train/signal/signals_snr_30dB.pt
--------------------------------------------------------------------------------
/data/train/signal/signals_snr_5dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/train/signal/signals_snr_5dB.pt
--------------------------------------------------------------------------------
/data/val/label/labels_snr_0dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/val/label/labels_snr_0dB.pt
--------------------------------------------------------------------------------
/data/val/label/labels_snr_10dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/val/label/labels_snr_10dB.pt
--------------------------------------------------------------------------------
/data/val/label/labels_snr_15dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/val/label/labels_snr_15dB.pt
--------------------------------------------------------------------------------
/data/val/label/labels_snr_20dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/val/label/labels_snr_20dB.pt
--------------------------------------------------------------------------------
/data/val/label/labels_snr_25dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/val/label/labels_snr_25dB.pt
--------------------------------------------------------------------------------
/data/val/label/labels_snr_30dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/val/label/labels_snr_30dB.pt
--------------------------------------------------------------------------------
/data/val/label/labels_snr_5dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/val/label/labels_snr_5dB.pt
--------------------------------------------------------------------------------
/data/val/signal/signals_snr_0dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/val/signal/signals_snr_0dB.pt
--------------------------------------------------------------------------------
/data/val/signal/signals_snr_10dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/val/signal/signals_snr_10dB.pt
--------------------------------------------------------------------------------
/data/val/signal/signals_snr_15dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/val/signal/signals_snr_15dB.pt
--------------------------------------------------------------------------------
/data/val/signal/signals_snr_20dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/val/signal/signals_snr_20dB.pt
--------------------------------------------------------------------------------
/data/val/signal/signals_snr_25dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/val/signal/signals_snr_25dB.pt
--------------------------------------------------------------------------------
/data/val/signal/signals_snr_30dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/val/signal/signals_snr_30dB.pt
--------------------------------------------------------------------------------
/data/val/signal/signals_snr_5dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/val/signal/signals_snr_5dB.pt
--------------------------------------------------------------------------------
/fig/Accuracy1_SLA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/fig/Accuracy1_SLA.png
--------------------------------------------------------------------------------
/fig/Accuracy1_ULA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/fig/Accuracy1_ULA.png
--------------------------------------------------------------------------------
/fig/Accuracy2_SLA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/fig/Accuracy2_SLA.png
--------------------------------------------------------------------------------
/fig/Accuracy2_ULA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/fig/Accuracy2_ULA.png
--------------------------------------------------------------------------------
/fig/Example_SLA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/fig/Example_SLA.png
--------------------------------------------------------------------------------
/fig/Example_SLA_real.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/fig/Example_SLA_real.png
--------------------------------------------------------------------------------
/fig/Example_ULA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/fig/Example_ULA.png
--------------------------------------------------------------------------------
/fig/Example_ULA_real.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/fig/Example_ULA_real.png
--------------------------------------------------------------------------------
/fig/Network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/fig/Network.png
--------------------------------------------------------------------------------
/fig/Separate_SLA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/fig/Separate_SLA.png
--------------------------------------------------------------------------------
/fig/Separate_ULA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/fig/Separate_ULA.png
--------------------------------------------------------------------------------
/models/SADOANet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import sys
4 | sys.path.append('../')
5 | from scr.helpers import steering_vector
6 |
7 |
8 | class DOANet(nn.Module):
9 | def __init__(self, number_element=20, output_size=61):
10 | super(DOANet, self).__init__()
11 | # Layer configurations
12 | self.input_size = number_element
13 | self.output_size = output_size
14 | hidden_sizes = [2048, 1024, 512, 256, 128] # Adjustable hidden layer sizes
15 |
16 | # Network layers
17 | layers = []
18 | for h1, h2 in zip([self.input_size] + hidden_sizes[:-1], hidden_sizes):
19 | layers.append(nn.Linear(h1, h2))
20 | layers.append(nn.ReLU())
21 | layers.append(nn.Linear(hidden_sizes[-1], self.output_size))
22 | layers.append(nn.Sigmoid())
23 |
24 | self.layers = nn.Sequential(*layers)
25 |
26 | def forward(self, x):
27 | return self.layers(x)
28 |
29 |
30 | class SparseLayer(nn.Module):
31 | def __init__(self, input_size=10, max_sparsity=0.5):
32 | super(SparseLayer, self).__init__()
33 | self.input_size = input_size
34 | self.max_zeros = int(input_size * max_sparsity)
35 |
36 | def forward(self, x):
37 | if self.training:
38 |             batch_size, N = x.size()  # Batch size and number of array elements
39 | sparsity = torch.zeros(batch_size, dtype=torch.long).to(x.device) # Tensor to store the number of zeros
40 | # Generate a random mask for each example in the batch
41 | masks = torch.ones((batch_size, self.input_size)).to(x.device) # Start with all ones
42 | NN = torch.ones(batch_size) * N
43 | num_zeros = torch.randint(0, self.max_zeros + 1, (batch_size,))
44 | for i in range(batch_size):
45 | zero_indices = torch.randperm(self.input_size)[:num_zeros[i]] # Random indices to be zeroed
46 | masks[i, zero_indices] = 0 # Set selected indices to zero
47 |                 sparsity = NN - num_zeros  # Number of remaining non-zero elements per example
48 | x_sparse = x * masks
49 |
50 | else:
51 | x_sparse = x
52 | sparsity, masks = self.thresholding(x)
53 |
54 | return x_sparse, sparsity, masks.to(x_sparse.dtype).to(x_sparse.device)
55 |
56 |
57 | def thresholding(self, x):
58 | threshold = 0.001
59 | mask = (torch.abs(x) > threshold).float()
60 | return torch.sum(mask, dim=1), mask
61 |
62 |
63 | class SALayer(nn.Module):
64 | def __init__(self, number_element=10, output_size=61, max_sparsity=0.5):
65 | super(SALayer, self).__init__()
66 | # Initialize SparseLayer
67 | self.sparselayer = SparseLayer(number_element, max_sparsity)
68 | self.output_size = output_size
69 | self.hidden_size = 512
70 |
71 | # Calculate steering vectors
72 | a_theta = steering_vector(number_element, torch.arange(-30, 31)).conj()
73 | self.AH = torch.transpose(a_theta, 0, 1)
74 |
75 | # Define the network layers
76 | self.linear_layer = nn.Linear(number_element * 2, self.hidden_size)
77 | self.relu = nn.ReLU() # ReLU activation function
78 |
79 | def forward(self, x):
80 | batch_size, N = x.size()
81 | # Sparse representation and FFT
82 | x_sparse, sparsity, masks = self.sparselayer(x)
83 | x_fft = self.apply_fft(x_sparse, sparsity, batch_size)
84 | masks_fft = self.apply_fft(masks, sparsity, batch_size)
85 |
86 | # Flatten the sparse input and apply linear transformation
87 | x_flat = torch.view_as_real(x_sparse).view(batch_size, -1)
88 | embedded_values = self.relu(self.linear_layer(x_flat))
89 | normalized_values = self.normalize(embedded_values, sparsity)
90 |
91 | # Concatenate features from different processing streams
92 | output = torch.cat((masks_fft, x_fft, normalized_values), dim=1)
93 | return output
94 |
95 | def normalize(self, x, sparsity):
96 | """Normalize the data by the sparsity-derived factor."""
97 | normalization_factor = sparsity.unsqueeze(-1).to(x.device)
98 | return x / normalization_factor
99 |
100 | def apply_fft(self, x, sparsity, batch_size):
101 |         """Project x onto the steering-vector grid (|A^H x|) and normalize by sparsity."""
102 | AH_batched = self.AH.unsqueeze(0).repeat(batch_size, 1, 1)
103 | x_expanded = x.unsqueeze(-1)
104 | fft_output = torch.abs(torch.matmul(AH_batched.to(x_expanded.device), x_expanded)).squeeze(-1)
105 | return self.normalize(fft_output,sparsity)
106 |
107 |
108 | class SADOANet(nn.Module):
109 | def __init__(self, number_element=10, output_size=61, max_sparsity=0.5, is_sparse=True):
110 | super(SADOANet, self).__init__()
111 | self.salayer = SALayer(number_element, output_size, max_sparsity)
112 | input_size = 512 + output_size * 2 if is_sparse else number_element * 2
113 | self.doanet = DOANet(input_size, output_size)
114 | self.is_sparse = is_sparse
115 |
116 | def forward(self, x):
117 | if len(x.size()) == 3: # signal pass in as real value
118 | x = torch.view_as_complex(x)
119 | if self.is_sparse:
120 | x = self.salayer(x)
121 | else:
122 | x = torch.view_as_real(x).view(x.size(0), -1)
123 | return self.doanet(x)
124 |
125 |
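# Example usage (a sketch; assumes the default 10-element array and the
# 61-point angular grid over [-30, 30] degrees used throughout this repo):
if __name__ == '__main__':
    net = SADOANet(number_element=10, output_size=61, max_sparsity=0.5, is_sparse=True)
    x = torch.randn(4, 10, dtype=torch.cfloat)  # batch of 4 complex single snapshots
    print(net(x).shape)  # torch.Size([4, 61])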
--------------------------------------------------------------------------------
/models/__pycache__/SADOANet.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/models/__pycache__/SADOANet.cpython-312.pyc
--------------------------------------------------------------------------------
/real_World_DOA_dataset/README.md:
--------------------------------------------------------------------------------
1 | # Real World dataset for Direction of Arrival (DOA) Estimation
2 |
3 | ## Motivation
4 | In the field of DOA estimation, the absence of publicly available real-world datasets has been a significant barrier for advancing and validating DOA estimation technologies. Historically, researchers have relied on simulated datasets to train and evaluate their models. To bridge this gap and contribute a practical resource to the community, we have developed a DOA estimation dataset gathered under real-world conditions.
5 |
6 | ## Data acquisition vehicle platform
7 | A data acquisition vehicle platform, a Lexus RX450h equipped with high-resolution imaging radar, LiDAR, and stereo cameras, was used to carry out field experiments at the University of Alabama.
8 |
9 |
10 | (Figure: data acquisition platform, see fig/platform.png)
11 |
12 | ## Field experiment
13 | This dataset was generated in a parking-lot scenario where a stationary vehicle equipped with a TI Cascade Imaging Radar collected data. The vehicle captured signals from a corner reflector placed 15 meters away, covering the full range of directions of arrival.
14 |
15 |
16 | (Animation: data collection example, see fig/example.gif)
17 |
18 |
19 | This comprehensive data collection resulted in 195 high-SNR signals representing unique angles of arrival from a single target.
20 |
21 |
22 | (Figure: collected signals and ground-truth DOAs, see "fig/DOA data.png")
23 |
24 |
25 | To enhance the complexity and usability of the dataset, we superimposed these signals to simulate multi-target scenarios. Below are some example FFT spectra of multi-target signals.
26 |
27 | (Figure: FFT spectra of multi-target examples, see fig/multiExamples.png)
28 |
29 |
30 | ## Dataset structure
31 | data.mat contains:
32 | - ang_list (1 x 195): the ground-truth DOA of each signal
33 | - bv_list (195 x 86): 195 raw single-snapshot signals from an 86-antenna array
34 |
35 | ## How to use
36 | MATLAB:
37 | ``` matlab
38 | load('data.mat')
39 | ```
40 | Python:
41 | ``` python
42 | import scipy.io
43 | data = scipy.io.loadmat('data.mat')
44 | ```
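The multi-target examples above were produced by superimposing single-target snapshots, which is also an easy way to build custom test cases (a sketch; the indices 10 and 60 are arbitrary):

```python
import scipy.io

data = scipy.io.loadmat('data.mat')
ang = data['ang_list'].ravel()  # (195,) ground-truth DOAs
bv = data['bv_list']            # (195, 86) single-snapshot array signals
two_targets = bv[10] + bv[60]   # superimpose two single-target snapshots
print(ang[10], ang[60], two_targets.shape)
```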
45 |
46 | If this dataset contributes to your research, please acknowledge its use with the following citation:
47 | ```BibTex
48 | @ARTICLE{10348517,
49 | author={Zheng, Ruxin and Sun, Shunqiao and Liu, Hongshan and Chen, Honglei and Li, Jian},
50 | journal={IEEE Sensors Journal},
51 | title={Interpretable and Efficient Beamforming-Based Deep Learning for Single Snapshot DOA Estimation},
52 | year={2023},
53 | volume={},
54 | number={},
55 | pages={1-1},
56 | keywords={Direction-of-arrival estimation;Estimation;Deep learning;Covariance matrices;Sensors;Mathematical models;Array signal processing;Single snapshot DOA estimation;array signal processing;automotive radar;interpretability;deep learning},
57 | doi={10.1109/JSEN.2023.3338575}}
58 | ```
59 |
--------------------------------------------------------------------------------
/real_World_DOA_dataset/data.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/real_World_DOA_dataset/data.mat
--------------------------------------------------------------------------------
/real_World_DOA_dataset/fig/DOA data.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/real_World_DOA_dataset/fig/DOA data.png
--------------------------------------------------------------------------------
/real_World_DOA_dataset/fig/example.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/real_World_DOA_dataset/fig/example.gif
--------------------------------------------------------------------------------
/real_World_DOA_dataset/fig/multiExamples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/real_World_DOA_dataset/fig/multiExamples.png
--------------------------------------------------------------------------------
/real_World_DOA_dataset/fig/platform.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/real_World_DOA_dataset/fig/platform.png
--------------------------------------------------------------------------------
/scr/__pycache__/eval_fun.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/scr/__pycache__/eval_fun.cpython-312.pyc
--------------------------------------------------------------------------------
/scr/__pycache__/helpers.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/scr/__pycache__/helpers.cpython-311.pyc
--------------------------------------------------------------------------------
/scr/__pycache__/helpers.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/scr/__pycache__/helpers.cpython-312.pyc
--------------------------------------------------------------------------------
/scr/dataset_gen.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from helpers import generate_data
4 |
5 |
6 | def main(args):
7 | print("Generating validation data ... ...")
8 | generate_data(
9 | N=args.N,
10 | num_samples=args.num_samples_val,
11 | max_targets=args.max_targets,
12 | folder_path=os.path.join(args.output_dir, 'data/val')
13 | )
14 |
15 | print("Generating training data ... ...")
16 | generate_data(
17 | N=args.N,
18 | num_samples=args.num_samples_train,
19 | max_targets=args.max_targets,
20 | folder_path=os.path.join(args.output_dir, 'data/train')
21 | )
22 |
23 | if __name__ == "__main__":
24 | parser = argparse.ArgumentParser(description='Generate dataset for antenna signal processing.')
25 | parser.add_argument('--output_dir', type=str, default='./',
26 | help='Base directory for output data')
27 | parser.add_argument('--num_samples_val', type=int, default=1024,
28 | help='Number of validation samples to generate')
29 | parser.add_argument('--num_samples_train', type=int, default=100000,
30 | help='Number of training samples to generate')
31 | parser.add_argument('--N', type=int, default=10,
32 | help='Number of antenna elements')
33 | parser.add_argument('--max_targets', type=int, default=3,
34 | help='Maximum number of targets per sample')
35 | args = parser.parse_args()
36 | main(args)
--------------------------------------------------------------------------------
/scr/eval_fun.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import argparse
3 | import matplotlib.pyplot as plt
4 | from scipy.signal import find_peaks
5 | import time
6 | from helpers import *
7 | import sys
8 | sys.path.append('../')
9 | from models.SADOANet import SADOANet
10 | from tqdm import tqdm
11 | import torch
12 | from collections import OrderedDict
13 |
14 | def load_sadoanet(num_elements, output_size, sparsity, is_sparse, device, model_path):
15 | model = SADOANet(num_elements, output_size, sparsity, is_sparse).to(device)
16 | state_dict = torch.load(model_path)
17 | # model is trained using nn.DataParallel
18 | # need to rename key, if the model not in DataParallel
19 | new_state_dict = OrderedDict()
20 | for k, v in state_dict.items():
21 | new_state_dict[k.replace("module.","")] = v
22 | model.load_state_dict(new_state_dict)
23 |
24 | return model.eval()
25 |
26 | def randSparse(signal,sparsity):
27 | sparseSignal = signal.clone()
28 | sparseInd = torch.randperm(signal.numel())[:int(signal.numel() * sparsity)]
29 | sparseSignal[sparseInd] = 0
30 | return sparseSignal
31 |
32 | def FFT(signal, num_antennas=10, ang_min = -30, ang_max = 31, ang_step = 1):
33 | ang_list = torch.arange(ang_min,ang_max,ang_step)
34 | a_theta = steering_vector(num_antennas,ang_list ).conj()
35 | AH = torch.transpose(a_theta, 0, 1)
36 | spec = torch.abs(torch.matmul(AH,signal)/num_antennas).squeeze().numpy()
37 | return ang_list, spec
38 |
39 | def IAA(y, Niter=15):
40 | """
41 |     Iterative Adaptive Approach (IAA): alternately re-estimates the angular power spectrum Pk and the covariance model R = A diag(Pk) A^H.
42 | """
43 | N = y.shape[0]
44 | ang_list = torch.arange(-90,91,1)
45 | A = steering_vector(N,ang_list)
46 | AH = torch.transpose(A.conj(), 0, 1)
47 | N, K = A.shape
48 | Ns = y.shape[1]
49 | Pk = np.zeros(K, dtype=np.complex128)
50 |
51 | y = y.squeeze().numpy()
52 | A = A.numpy()
53 | AH = AH.resolve_conj().numpy()
54 | # Initial computation of power
55 |
56 | Pk = (AH @ y / N) ** 2
57 | P = np.diag(Pk)
58 | R = A @ P @ AH
59 |
60 | # Main iteration
61 | for _ in range(Niter):
62 |         R += 0e-3 * np.eye(N)  # diagonal loading (factor currently 0; e.g. 1e-3 can stabilize an ill-conditioned R)
63 | ak_R = AH @ np.linalg.pinv(R)
64 | T = ak_R @ y
65 | B = ak_R @ A
66 | b = B.diagonal()
67 | sk = T/np.abs(b)
68 | Pk = np.abs(sk) ** 2
69 | P = np.diag(Pk)
70 | R = A @ P @ A.conj().T
71 |     # keep only the -30:1:30 degree portion of the full -90:1:90 grid
72 |     spec = Pk[60:121]
73 | ang_list = torch.arange(-30,31,1)
74 | return ang_list, spec
75 |
76 | def DLapproach(signal,model,device):
77 | model_result = model(signal.squeeze().unsqueeze(0).to(device))
78 | model_result = model_result.squeeze().cpu().detach().numpy()
79 | ang_list = torch.arange(-30,31,1)
80 | spec = model_result
81 | return ang_list, spec
82 |
83 | def estimate_doa(ang_list, spec, scale = 0.7):
84 | """
85 |     Estimate DOAs from a spectrum by peak picking; peaks below scale * max(spec) are discarded.
86 | """
87 | max_height = np.max(spec)
88 | min_peak_height = (scale * max_height)
89 | peaks, properties = find_peaks(spec, height=min_peak_height)
90 | # Sort peaks by their magnitudes in descending order
91 | sorted_indices = np.argsort(properties['peak_heights'])[::-1] # Get indices to sort in descending order
92 | sorted_peaks = peaks[sorted_indices]
93 | sorted_peak_heights = properties['peak_heights'][sorted_indices]
94 | doa = ang_list[sorted_peaks]
95 |
96 | return doa
97 |
98 | def plot_results(snr_levels, mse_metrics):
99 | plt.figure(figsize=(10, 6))
100 | markers = {'fft': 'o', 'iaa': 's', 'mlp': '^', 'sparse': 'd', 'grid': 'x'}
101 | colors = {'fft': 'blue', 'iaa': 'green', 'mlp': 'red', 'sparse': 'cyan', 'grid': 'black'}
102 | for key, color in colors.items():
103 | plt.semilogy(snr_levels, mse_metrics[key], label=key.upper(), color=color, marker=markers[key])
104 |
105 | plt.title('MSE vs SNR for Different DOA Estimation Methods')
106 | plt.xlabel('SNR (dB)')
107 | plt.ylabel('Mean Squared Error (deg^2)')
108 | plt.legend()
109 | plt.grid(True)
110 | plt.show()
111 |
112 | def plot_HR(snr_levels, HR_metrics):
113 | plt.figure(figsize=(10, 6))
114 | markers = {'fft': 'o', 'iaa': 's', 'mlp': '^', 'sparse': 'd'}
115 | colors = {'fft': 'blue', 'iaa': 'green', 'mlp': 'red', 'sparse': 'cyan'}
116 | for key, color in colors.items():
117 | plt.plot(snr_levels, HR_metrics[key], label=key.upper(), color=color, marker=markers[key])
118 |
119 | plt.title('Hit Rate vs SNR for Different DOA Estimation Methods')
120 | plt.xlabel('SNR (dB)')
121 | plt.ylabel('Hit Rate (%)')
122 | plt.legend()
123 | plt.grid(True)
124 | plt.show()
125 |
126 | def handle_estimated_angles(tmp):
127 | """ Handle the sorting and selection of estimated angles. """
128 | if tmp.nelement() == 0:
129 | return np.array([30, 30]) # Default error value when no angles are resolved
130 | elif tmp.nelement() == 1:
131 | return np.array([tmp[0].item(), tmp[0].item()])
132 | else:
133 | tmp = tmp[:2]
134 | sorted_tmp = torch.sort(tmp)[0].numpy()
135 | return sorted_tmp
136 |
137 | def calculate_mse(actual_angles_1, actual_angles_2, estimates):
138 | """ Calculate the mean squared error for estimated angles. """
139 | mse_1 = np.mean((actual_angles_1 - estimates[:, 0]) ** 2)
140 | mse_2 = np.mean((actual_angles_2 - estimates[:, 1]) ** 2)
141 | return np.mean([mse_1, mse_2])
142 |
143 | def count_parameters(model):
144 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
145 |
146 | #################################################################
147 | ## MONTE CARLO TEST FUNCTIONS ##
148 | #################################################################
149 | def run_monte_carlo_accuracy(num_simulations, num_antennas=10, sparse_flag = True):
150 | snr_levels = np.arange(-5, 35, 5) # SNR levels from -5 dB to 30 dB in 5 dB steps
151 | mse_metrics = {'fft': [], 'iaa': [], 'mlp': [], 'sparse': [], 'grid': []}
152 |
153 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
154 | # Load models
155 | mlp_model = load_sadoanet(10, 61, 0.3, False, device, '../checkpoint/filled/best_model_checkpoint.pth')
156 | sparse_model = load_sadoanet(10, 61, 0.3, True, device, '../checkpoint/sparse/best_model_checkpoint.pth')
157 |
158 | methods = {
159 | 'fft': lambda sig: FFT(sig),
160 | 'iaa': lambda sig: IAA(sig),
161 | 'mlp': lambda sig: DLapproach(sig, mlp_model, device),
162 | 'sparse': lambda sig: DLapproach(sig, sparse_model, device)
163 | }
164 |
165 | for snr_db in tqdm(snr_levels, desc='SNR levels', unit='snr'):
166 | actual_angles = np.random.uniform(-30, 30, num_simulations)
167 | estimates = {key: np.zeros(num_simulations) for key in mse_metrics}
168 |
169 | for i in range(num_simulations):
170 | signal = generate_complex_signal(num_antennas, snr_db, torch.tensor([actual_angles[i]]))
171 | if sparse_flag:
172 | signal = randSparse(signal, 0.3)
173 |
174 | for method_name, method in methods.items():
175 | ang_list, spec = method(signal)
176 | tmp = estimate_doa(ang_list, spec)
177 | estimates[method_name][i] = actual_angles[i] if tmp.nelement() == 0 else tmp[0].item()
178 |
179 | for key in mse_metrics:
180 | if key != 'grid':
181 | mse_metrics[key].append(np.mean((actual_angles - estimates[key]) ** 2))
182 | else :
183 | mse_metrics[key].append(np.mean((actual_angles - np.rint(actual_angles)) ** 2))
184 | return snr_levels, mse_metrics
185 |
186 |
187 | ##################################################################
188 | def run_monte_carlo_accuracy2(num_simulations, num_antennas=10, sparse_flag = True):
189 | """
190 | Run Monte Carlo simulations to evaluate various DOA estimation methods across SNR levels.
191 | """
192 | snr_levels = np.arange(-5, 35, 5) # SNR levels from -5 dB to 30 dB in 5 dB steps
193 | mse_metrics = {'fft': [], 'iaa': [], 'mlp': [], 'sparse': [], 'grid': []}
194 |
195 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
196 | # Load models
197 | mlp_model = load_sadoanet(10, 61, 0.3, False, device, '../checkpoint/filled/best_model_checkpoint.pth')
198 | sparse_model = load_sadoanet(10, 61, 0.3, True, device, '../checkpoint/sparse/best_model_checkpoint.pth')
199 |
200 | methods = {
201 | 'fft': lambda sig: FFT(sig),
202 | 'iaa': lambda sig: IAA(sig),
203 | 'mlp': lambda sig: DLapproach(sig, mlp_model, device),
204 | 'sparse': lambda sig: DLapproach(sig, sparse_model, device)
205 | }
206 |
207 | for snr_db in tqdm(snr_levels, desc='SNR levels', unit='snr'):
208 | actual_angles_1 = np.random.uniform(-0.6, 0.4, num_simulations)
209 | actual_angles_2 = np.random.uniform(9.6, 10.4, num_simulations)
210 | actual_angles = np.column_stack((actual_angles_1, actual_angles_2)).flatten()
211 |
212 | estimates = {key: np.zeros((num_simulations, 2)) for key in mse_metrics}
213 |
214 | for i in range(num_simulations):
215 | signal = generate_complex_signal(num_antennas, snr_db, torch.tensor([actual_angles_1[i], actual_angles_2[i]]))
216 | if sparse_flag:
217 | signal = randSparse(signal, 0.3)
218 |
219 | for method_name, method in methods.items():
220 | ang_list, spec = method(signal)
221 | tmp = estimate_doa(ang_list, spec,0)
222 | estimates[method_name][i] = handle_estimated_angles(tmp)
223 |
224 | # Calculate MSE for each method
225 | for key in mse_metrics:
226 | if key != 'grid':
227 | mse_metrics[key].append(calculate_mse(actual_angles_1, actual_angles_2, estimates[key]))
228 | else:
229 | tmp1 = np.mean((actual_angles_1- np.rint(actual_angles_1)) ** 2)
230 | tmp2 = np.mean((actual_angles_2- np.rint(actual_angles_2)) ** 2)
231 | mse_metrics[key].append(np.mean(tmp1+tmp2))
232 | return snr_levels, mse_metrics
233 |
234 |
235 |
236 | ##################################################################
237 | def run_monte_carlo_sep(num_simulations, num_antennas=10, sparse_flag = True):
238 | """
239 | Run Monte Carlo simulations to evaluate DOA estimation accuracy across different separation angles.
240 | """
241 | sep_angles = np.arange(2, 30, 2) # Separation angles from 2 to 28 degrees in 2 degree steps
242 | HR_metrics = {'fft': [], 'iaa': [], 'mlp': [], 'sparse': []}
243 |
244 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
245 | # Load models
246 | mlp_model = load_sadoanet(10, 61, 0.3, False, device, '../checkpoint/filled/best_model_checkpoint.pth')
247 | sparse_model = load_sadoanet(10, 61, 0.3, True, device, '../checkpoint/sparse/best_model_checkpoint.pth')
248 |
249 | methods = {
250 | 'fft': lambda sig: FFT(sig),
251 | 'iaa': lambda sig: IAA(sig),
252 | 'mlp': lambda sig: DLapproach(sig, mlp_model, device),
253 | 'sparse': lambda sig: DLapproach(sig, sparse_model, device)
254 | }
255 |
256 |
257 | for sep in tqdm(sep_angles, desc='Separation', unit='deg', ncols=86):
258 | actual_angles = np.array([-sep / 2, sep / 2]) # Centered angles
259 | tmp_rates = {key: np.zeros(num_simulations) for key in HR_metrics}
260 |
261 | for i in range(num_simulations):
262 | signal = generate_complex_signal(num_antennas, 40, torch.from_numpy(actual_angles))
263 | if sparse_flag:
264 | signal = randSparse(signal, 0.3)
265 |
266 | for method_name, method in methods.items():
267 | ang_list, spec = method(signal)
268 | estimated_angles = estimate_doa(ang_list, spec, 0.7 if method_name in ['fft', 'iaa'] else 0.2)
269 | estimated_angles, _ = torch.sort(estimated_angles)
270 | if estimated_angles.nelement() == 2 and np.allclose(estimated_angles.numpy(), actual_angles, atol=2):
271 | tmp_rates[method_name][i] = 1
272 |
273 | for key in HR_metrics:
274 | HR_metrics[key].append(np.mean(tmp_rates[key]))
275 |
276 | return sep_angles, HR_metrics
277 |
278 |
279 | ##################################################################
280 | def run_examples(signal,num_antennas=10,sparse_flag = True):
281 | # load model
282 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
283 | mlp_model = load_sadoanet(10, 61, 0.3, False, device, '../checkpoint/filled/best_model_checkpoint.pth')
284 | sparse_model = load_sadoanet(10, 61, 0.3, True, device, '../checkpoint/sparse/best_model_checkpoint.pth')
285 | signal = torch.from_numpy(signal).to(torch.cfloat)
286 | sparsity= 0.3
287 | antennaPos = np.arange(num_antennas)
288 | if sparse_flag:
289 | signal = randSparse(signal, 0.3)
290 | zero_indices = torch.where(signal == 0)[0].numpy()
291 | print(zero_indices)
292 | antennaPos = np.arange(num_antennas)
293 | antennaPos = np.delete(antennaPos, zero_indices)
294 |
295 | # Create a 1x4 subplot grid
296 | fig, axs = plt.subplots(1, 4, figsize=(20, 5)) # Adjust the figsize to ensure all subplots are visible and not squished
297 |
298 | # Plot the first subplot similar to the uploaded image
299 | axs[0].stem(antennaPos, np.ones_like(antennaPos), linefmt='blue', markerfmt='bo', basefmt="r-")
300 | if sparse_flag:
301 | axs[0].set_title('Sparse Linear Array')
302 | else:
303 | axs[0].set_title('Uniform Linear Array')
304 | axs[0].set_xlabel('Horizontal [Half Wavelength]')
305 | axs[0].set_ylabel('Amplitude')
306 | axs[0].set_ylim([0, 2]) # Setting the y-axis limits to match the uploaded image
307 | axs[0].set_xlim([0, num_antennas-1]) # Setting the x-axis limits to match the uploaded image
308 | axs[0].set_yticks([0, 1, 2])
309 | axs[0].set_xticks(np.arange(num_antennas))
310 | axs[0].grid(True)
311 |
312 | #FFT
313 | ang_list, spec0 = FFT(signal,num_antennas)
314 | spec0 = spec0/np.max(spec0)
315 | #IAA
316 | ang_list, spec1 = IAA(signal)
317 | spec1 = spec1/np.max(spec1)
318 | axs[1].plot(ang_list, spec0, label='DBF', color='blue',linewidth=2)
319 | axs[1].plot(ang_list, spec1, label='IAA', color='red',linewidth=2)
320 | axs[1].axvline(0, color='black', linestyle='--', linewidth=1.5)
321 | axs[1].axvline(7, color='black', linestyle='--', linewidth=1.5)
322 | axs[1].set_title('DBF vs IAA')
323 | axs[1].set_xlabel('Angle [degree]')
324 | axs[1].set_ylabel('Magnitude')
325 |     axs[1].set_ylim([0, 1])  # Fixed axis limits so panels are comparable
326 |     axs[1].set_xlim([-30, 30])
327 | axs[1].set_xticks(np.arange(-30, 31, 10))
328 | axs[1].set_yticks([0, 0.5, 1])
329 | axs[1].legend()
330 | axs[1].grid(True)
331 |
332 |     # MLP (filled-array model)
333 |     ang_list, spec2 = DLapproach(signal, mlp_model, device)
334 |     spec2 = spec2 / np.max(spec2)
335 |     axs[2].plot(ang_list, spec2, color='blue', linewidth=2)
336 | axs[2].axvline(0, color='black', linestyle='--', linewidth=1.5)
337 | axs[2].axvline(7, color='black', linestyle='--', linewidth=1.5)
338 | axs[2].set_title('MLP')
339 | axs[2].set_xlabel('Angle [degree]')
340 | axs[2].set_ylabel('Magnitude')
341 |     axs[2].set_ylim([0, 1])  # Fixed axis limits so panels are comparable
342 |     axs[2].set_xlim([-30, 30])
343 | axs[2].set_xticks(np.arange(-30, 31, 10))
344 | axs[2].set_yticks([0, 0.5, 1])
345 | axs[2].grid(True)
346 |
347 |     # Sparse (proposed model)
348 |     ang_list, spec3 = DLapproach(signal, sparse_model, device)
349 |     spec3 = spec3 / np.max(spec3)
350 |     axs[3].plot(ang_list, spec3, color='blue', linewidth=2)
351 | axs[3].axvline(0, color='black', linestyle='--', linewidth=1.5)
352 | axs[3].axvline(7, color='black', linestyle='--', linewidth=1.5)
353 | axs[3].set_title('Ours')
354 | axs[3].set_xlabel('Angle [degree]')
355 | axs[3].set_ylabel('Magnitude')
356 |     axs[3].set_ylim([0, 1])  # Fixed axis limits so panels are comparable
357 |     axs[3].set_xlim([-30, 30])
358 | axs[3].set_xticks(np.arange(-30, 31, 10))
359 | axs[3].set_yticks([0, 0.5, 1])
360 | axs[3].grid(True)
361 | plt.show()
362 | return ang_list, spec0, spec1, spec2, spec3
363 |
364 | ##################################################################
365 | def model_complexity(num_elements=10, output_size=61, sparsity=0.3):
366 | mlp_model = SADOANet(num_elements, output_size, sparsity, False)
367 | num_params = count_parameters(mlp_model)
368 | sparse_model = SADOANet(num_elements, output_size, sparsity, True)
369 | num_params_sparse = count_parameters(sparse_model)
370 | print(f"Total trainable parameters in MLP model: {num_params}")
371 | print(f"Total trainable parameters in Ours model: {num_params_sparse}")
372 | return num_params, num_params_sparse
373 |
--------------------------------------------------------------------------------
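
A note on the separability metric used in `run_monte_carlo_sep` above: each Monte-Carlo trial is scored by a simple tolerance test on the sorted angle estimates. A minimal stand-alone sketch of that criterion (the helper name `is_hit` is ours, not part of the repo):

```python
# A trial counts as a "hit" only if exactly two peaks are found and both
# sorted estimates fall within 2 degrees of the sorted ground truth.
import numpy as np
import torch

def is_hit(estimated_angles: torch.Tensor, actual_angles: np.ndarray, atol: float = 2.0) -> bool:
    estimated_angles, _ = torch.sort(estimated_angles)
    return (estimated_angles.nelement() == 2
            and np.allclose(estimated_angles.numpy(), actual_angles, atol=atol))

sep = 5.0
actual = np.array([-sep / 2, sep / 2])            # Centered pair, as in the simulation
print(is_hit(torch.tensor([-2.0, 3.1]), actual))  # True: both estimates within 2 deg
print(is_hit(torch.tensor([0.5]), actual))        # False: only one peak found
```
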
/scr/helpers.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from tqdm.auto import tqdm
4 | from torch.utils.data import Dataset, DataLoader
5 |
6 |
7 | def steering_vector(N, deg):
8 | """
9 |     Compute the steering vector of a half-wavelength-spaced uniform linear array.
10 |
11 | Args:
12 | N (int): Number of antenna elements.
13 |         deg (torch.Tensor): Angle(s) of arrival in degrees.
14 |
15 | Returns:
16 |         torch.Tensor: Complex steering vector(s) of shape (N, number of angles).
17 | """
18 | d = 0.5 # Element spacing (in units of wavelength)
19 | wavelength = 1.0 # Wavelength of the signal (same units as d)
20 | k = 2 * torch.pi / wavelength # Wavenumber
21 | n = torch.arange(0, N).view(N, 1) # Antenna element indices [0, 1, ..., N-1]
22 | theta = deg * torch.pi / 180 # Convert degrees to radians
23 | phases = k * d * n * torch.sin(theta) # Phase shift for each element
24 |
25 | return torch.exp(1j * phases) # Complex exponential for each phase shift
26 |
27 |
28 |
29 | def generate_complex_signal(N=10, snr_db=10, deg=torch.tensor([30])):
30 | """
31 | Generates a complex-valued signal for an array of N antenna elements.
32 |
33 | Args:
34 | N (int): Number of antenna elements.
35 | snr_db (float): Signal-to-Noise Ratio in decibels.
36 | deg (tensor): Angle of arrival in degrees.
37 |
38 | Returns:
39 | torch.Tensor: Complex-valued tensor of shape (N, 1) representing the received signals.
40 | """
41 | a_theta = steering_vector(N, deg)
42 |     phase = torch.exp(2j * torch.pi * torch.randn(a_theta.size()[1])).view(-1, 1)  # Random complex phase per target
43 | signal = torch.matmul(a_theta.to(phase.dtype), phase)
44 | signal_power = torch.mean(torch.abs(signal)**2)
45 | snr_linear = 10**(snr_db / 10)
46 |
47 | noise_power = signal_power / snr_linear
48 | noise_real = torch.sqrt(noise_power / 2) * torch.randn_like(signal.real)
49 | noise_imag = torch.sqrt(noise_power / 2) * torch.randn_like(signal.imag)
50 | noise = torch.complex(noise_real, noise_imag)
51 |
52 | return signal + noise
53 |
54 |
55 | def generate_label(degrees, min_angle=-30, max_angle=30):
56 | """
57 | Generate one-hot encoded labels for the given degrees.
58 |
59 | Args:
60 |         degrees (tensor): Target angles in degrees.
61 |         min_angle (int), max_angle (int): Bounds of the angle grid in degrees.
62 | Returns:
63 | torch.Tensor: One-hot encoded labels.
64 | """
65 | labels = torch.zeros(max_angle - min_angle + 1)
66 | indices = degrees - min_angle
67 | labels[indices.long()] = 1
68 | return labels
69 |
70 | def generate_data(N, num_samples=1, max_targets=3, folder_path='/content/drive/MyDrive/Asilomar2024/data/'):
71 | """
72 | Generate dataset with random number of targets and varying SNR levels.
73 |
74 | Args:
75 | N (int): Number of antenna elements.
76 | num_samples (int): Number of samples to generate for each SNR level.
77 | max_targets (int): Maximum number of targets.
78 | folder_path (str): Base folder path for saving data.
79 |
80 | Returns:
81 |         None. Data is saved under the specified directory.
82 | """
83 | angles = torch.arange(-30, 31, 1)
84 | signal_folder = os.path.join(folder_path, 'signal')
85 | label_folder = os.path.join(folder_path, 'label')
86 | os.makedirs(signal_folder, exist_ok=True)
87 | os.makedirs(label_folder, exist_ok=True)
88 |
89 | for snr_db in tqdm(range(0, 35, 5), desc='SNR levels', unit='snr', dynamic_ncols=True):
90 | all_signals, all_labels = [], []
91 | for _ in range(num_samples):
92 | num_targets = torch.randint(1, max_targets + 1, (1,)).item()
93 | deg_indices = torch.randperm(len(angles))[:num_targets]
94 | degs = angles[deg_indices]
95 | label = generate_label(degs)
96 | noisy_signal = generate_complex_signal(N=N, snr_db=snr_db, deg=degs)
97 | all_signals.append(noisy_signal)
98 | all_labels.append(label)
99 | torch.save(all_signals, os.path.join(signal_folder, f'signals_snr_{snr_db}dB.pt'))
100 | torch.save(all_labels, os.path.join(label_folder, f'labels_snr_{snr_db}dB.pt'))
101 | return None
102 |
103 |
104 | class SignalDataset(Dataset):
105 | def __init__(self, file_paths, label_paths):
106 | """
107 | Initializes a dataset containing signals and their corresponding labels.
108 |
109 | Args:
110 | file_paths (list): Paths to files containing signals.
111 | label_paths (list): Paths to files containing labels.
112 | """
113 | self.signals = [torch.stack(torch.load(file), dim=0) for file in file_paths]
114 | self.labels = [torch.stack(torch.load(label), dim=0) for label in label_paths]
115 | self.signals = torch.cat(self.signals, dim=0)
116 | self.labels = torch.cat(self.labels, dim=0)
117 |
118 | def __len__(self):
119 | return len(self.signals)
120 |
121 | def __getitem__(self, idx):
122 | return self.signals[idx], self.labels[idx]
123 |
124 | def create_dataloader(data_path, batch_size=32, shuffle=True):
125 | """
126 | Create a DataLoader for batching and shuffling the dataset.
127 |
128 | Args:
129 | data_path (str): Path to the directory containing the data files.
130 | batch_size (int): Number of samples per batch.
131 | shuffle (bool): Whether to shuffle the data.
132 |
133 | Returns:
134 | DataLoader: Configured DataLoader for the dataset.
135 | """
136 | signal_dir_path = os.path.join(data_path, "signal")
137 | label_dir_path = os.path.join(data_path, "label")
138 | signal_files = [os.path.join(signal_dir_path, f) for f in os.listdir(signal_dir_path) if 'signals' in f]
139 | label_files = [os.path.join(label_dir_path, f) for f in os.listdir(label_dir_path) if 'labels' in f]
140 | dataset = SignalDataset(sorted(signal_files), sorted(label_files))
141 | return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
--------------------------------------------------------------------------------
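
For orientation, a minimal usage sketch of the helpers above (the `./tmp_data` path is illustrative; run from the repo root):

```python
# Generate a tiny dataset with generate_data, then batch it with
# create_dataloader; shapes follow the defaults used in this repo.
import torch
from scr.helpers import generate_data, create_dataloader

generate_data(N=10, num_samples=16, max_targets=3, folder_path='./tmp_data')
loader = create_dataloader('./tmp_data', batch_size=8)

signals, labels = next(iter(loader))
print(signals.shape, signals.dtype)  # torch.Size([8, 10, 1]) torch.complex64
print(labels.shape)                  # torch.Size([8, 61]): one-hot grid over -30..30 deg
```
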
/scr/realData_demo.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/scr/realData_demo.mat
--------------------------------------------------------------------------------
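
This `.mat` file is consumed by `run_eval.py` below; a minimal sketch of that loading step, assuming the script is run from the repo root:

```python
# Load the 'bv_final' matrix, keep the first 10 channels, and transpose so
# the first axis is the 10-element array (mirrors run_eval.py).
import scipy.io
import torch

mat_data = scipy.io.loadmat('scr/realData_demo.mat')
bv = mat_data['bv_final']
signal = torch.from_numpy(bv[:, 0:10].T).to(torch.cfloat).numpy()
print(signal.shape)  # (10, ...): one row per antenna element
```
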
/scr/run_eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import scipy.io
4 | from eval_fun import run_monte_carlo_accuracy, run_monte_carlo_accuracy2, run_monte_carlo_sep, plot_results, plot_HR, model_complexity, generate_complex_signal, run_examples
5 | def main():
6 | parser = argparse.ArgumentParser(description="Run Monte Carlo simulations for DOA estimation accuracy")
7 | parser.add_argument('--num_simulations', type=int, default=100, help="Number of Monte Carlo simulations to run")
8 | parser.add_argument('--num_antennas', type=int, default=10, help="Number of antennas in the array")
9 | parser.add_argument('--evaluation_mode', type=str, default='accuracy1', help="Evaluation mode: accuracy1, accuracy2, separate, examples, complexity")
10 |     parser.add_argument('--real', type=lambda s: s.lower() == 'true', default=True, help="use the real-world data demo ('true'/'false')")
11 | args = parser.parse_args()
12 |
13 | if args.evaluation_mode == 'accuracy1':
14 | # Accuracy - single target
15 |         print('Monte Carlo simulation on the single-target case with a ULA\n')
16 | snr_levels, mse_metrics = run_monte_carlo_accuracy(args.num_simulations, args.num_antennas, False)
17 | plot_results(snr_levels, mse_metrics)
18 |         print('Monte Carlo simulation on the single-target case with an SLA\n')
19 | snr_levels, mse_metrics = run_monte_carlo_accuracy(args.num_simulations, args.num_antennas)
20 | plot_results(snr_levels, mse_metrics)
21 |
22 | elif args.evaluation_mode == 'accuracy2':
23 | # Accuracy - two targets
24 |         print('Monte Carlo simulation on the two-target case with a ULA\n')
25 | snr_levels, mse_metrics = run_monte_carlo_accuracy2(args.num_simulations, args.num_antennas, False)
26 | plot_results(snr_levels, mse_metrics)
27 |         print('Monte Carlo simulation on the two-target case with an SLA\n')
28 | snr_levels, mse_metrics = run_monte_carlo_accuracy2(args.num_simulations, args.num_antennas)
29 | plot_results(snr_levels, mse_metrics)
30 |
31 | elif args.evaluation_mode == 'separate':
32 |         # Separability
33 |         print('Monte Carlo simulation on separability with a ULA\n')
34 | sep_angles, HR_metrics = run_monte_carlo_sep(args.num_simulations, args.num_antennas, False)
35 | plot_HR(sep_angles, HR_metrics)
36 |         print('Monte Carlo simulation on separability with an SLA\n')
37 | sep_angles, HR_metrics = run_monte_carlo_sep(args.num_simulations, args.num_antennas)
38 | plot_HR(sep_angles, HR_metrics)
39 |
40 | elif args.evaluation_mode == 'examples':
41 | # Generate estimation example on real data
42 | if args.real:
43 | file_path = 'realData_demo.mat'
44 | mat_data = scipy.io.loadmat(file_path)
45 | bv = mat_data['bv_final']
46 | bv = bv[:,0:10].T
47 | signal = torch.from_numpy(bv).to(torch.cfloat).numpy()
48 | else:
49 | signal = generate_complex_signal(10, 40, torch.tensor([0, 7])).numpy()
50 | run_examples(signal, num_antennas=10, sparse_flag=False)
51 | run_examples(signal, num_antennas=10, sparse_flag=True)
52 |
53 | elif args.evaluation_mode == 'complexity':
54 | # Counts total trainable parameters
55 | model_complexity()
56 |
57 | if __name__ == "__main__":
58 | main()
--------------------------------------------------------------------------------
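
One pitfall worth flagging in this script: `argparse` with `type=bool` calls `bool()` on the raw string, and any non-empty string (including `'False'`) is truthy. That is why `--real` above parses the string explicitly. A minimal demonstration:

```python
# Demonstration of the argparse type=bool pitfall: bool('False') is True,
# so a plain type=bool flag cannot be turned off from the command line.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--buggy', type=bool, default=True)
parser.add_argument('--fixed', type=lambda s: s.lower() == 'true', default=True)

args = parser.parse_args(['--buggy', 'False', '--fixed', 'False'])
print(args.buggy)  # True  -- any non-empty string is truthy
print(args.fixed)  # False -- explicit string comparison works
```
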
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | from tqdm.auto import tqdm
7 |
8 | from scr.helpers import *
9 | from models.SADOANet import SADOANet
10 |
11 | def validate_model(model, dataloader, criterion, device):
12 | """Perform validation and return the average loss."""
13 | model.eval() # Set the model to evaluation mode
14 | total_loss = 0
15 | with torch.no_grad():
16 | for signals, labels in dataloader:
17 | signals = signals.to(device).squeeze(-1)
18 | labels = labels.to(device)
19 | outputs = model(signals)
20 | loss = criterion(outputs, labels)
21 | total_loss += loss.item()
22 | return total_loss / len(dataloader)
23 |
24 | def main(args):
25 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26 | print(f"Using device: {device}")
27 |
28 | # Initialize the model
29 | model = SADOANet(args.number_elements, args.output_size, args.sparsity, args.use_sparse)
30 |
31 | # Check if multiple GPUs are available and wrap the model using DataParallel
32 |     if torch.cuda.device_count() >= 4:
33 |         print("Using 4 GPUs!")
34 |         model = nn.DataParallel(model, device_ids=list(range(4)))  # Adjust device_ids to change the number of GPUs used
35 | model = model.to(device)
36 |
37 | # Loss function and optimizer
38 | criterion = nn.BCELoss()
39 | optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
40 |
41 | train_loader = create_dataloader(os.path.join(args.data_path, 'train'), batch_size=args.batch_size)
42 | val_loader = create_dataloader(os.path.join(args.data_path, 'val'), batch_size=args.batch_size)
43 |
44 |     # Sparse and filled variants keep separate checkpoint subfolders
45 |     if args.use_sparse:
46 |         save_checkpoint_path = os.path.join(args.checkpoint_path, 'sparse')
47 |     else:
48 |         save_checkpoint_path = os.path.join(args.checkpoint_path, 'filled')
49 |     os.makedirs(save_checkpoint_path, exist_ok=True)
50 |
51 | checkpoint_path = os.path.join(save_checkpoint_path, 'best_model_checkpoint.pth')
52 | final_model_path = os.path.join(save_checkpoint_path, 'final_model.pth')
53 |     # Optional: resume training with model.load_state_dict(torch.load(checkpoint_path), strict=False)
54 | # Training loop
55 | best_val_loss = float('inf')
56 | for epoch in range(args.epochs):
57 | model.train()
58 | train_loss = 0
59 | for signals, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1} [Training]"):
60 | signals = signals.to(device).squeeze(-1)
61 | labels = labels.to(device)
62 | outputs = model(signals)
63 | loss = criterion(outputs, labels)
64 |
65 | optimizer.zero_grad()
66 | loss.backward()
67 | optimizer.step()
68 | train_loss += loss.item()
69 |
70 | average_train_loss = train_loss / len(train_loader)
71 | print(f"Epoch {epoch + 1} Training Loss: {average_train_loss}")
72 |
73 | if (epoch + 1) % 5 == 0:
74 | val_loss = validate_model(model, val_loader, criterion, device)
75 | print(f"Epoch {epoch + 1} Validation Loss: {val_loss}")
76 | if val_loss < best_val_loss:
77 | best_val_loss = val_loss
78 | torch.save(model.state_dict(), checkpoint_path)
79 | print(f"Best model saved with validation loss {best_val_loss} at {checkpoint_path}")
80 |
81 | torch.save(model.state_dict(), final_model_path)
82 | print(f"Final model saved at {final_model_path}")
83 |
84 | if __name__ == "__main__":
85 | parser = argparse.ArgumentParser(description="Train a DOA estimation model")
86 | parser.add_argument('--data_path', type=str, default='./data',
87 | help='Path to training and validation data directory')
88 | parser.add_argument('--checkpoint_path', type=str, default='./checkpoint',
89 | help='Path where to save model checkpoints')
90 | parser.add_argument('--number_elements', type=int, default=10,
91 | help='Number of array elements in the model')
92 | parser.add_argument('--output_size', type=int, default=61,
93 | help='Output size of the model')
94 | parser.add_argument('--sparsity', type=float, default=0.3,
95 | help='Sparsity level used in the model')
96 |     parser.add_argument('--use_sparse', action='store_true',
97 |                         help='Enable the sparse augmentation layer in the model')
98 | parser.add_argument('--learning_rate', type=float, default=0.0001,
99 | help='Learning rate for the optimizer')
100 | parser.add_argument('--batch_size', type=int, default=1024,
101 | help='Batch size for training and validation')
102 | parser.add_argument('--epochs', type=int, default=300,
103 | help='Number of epochs to train the model')
104 |
105 | args = parser.parse_args()
106 | main(args)
107 |
--------------------------------------------------------------------------------
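
Before launching a long run, the training setup can be smoke-tested on random data. A minimal sketch, assuming (as the training loop above implies) that `SADOANet` takes a complex `(batch, elements)` tensor and emits per-angle probabilities suitable for `BCELoss`:

```python
# Smoke test: forward one fake batch and check the outputs are valid
# probabilities over the 61-point angle grid. This is a sketch for quick
# sanity checking, not part of the repo.
import torch
import torch.nn as nn
from models.SADOANet import SADOANet

model = SADOANet(10, 61, 0.3, False)
signals = torch.randn(4, 10, dtype=torch.cfloat)  # Fake single-snapshot batch
labels = torch.zeros(4, 61)
labels[:, 30] = 1                                 # All targets at 0 degrees

outputs = model(signals)
assert outputs.shape == (4, 61)
assert torch.all((outputs >= 0) & (outputs <= 1)), "BCELoss needs probabilities"
print(nn.BCELoss()(outputs, labels).item())
```
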