├── LICENSE
├── README.md
├── checkpoint
│   ├── filled
│   │   ├── best_model_checkpoint.pth
│   │   └── final_model.pth
│   └── sparse
│       ├── best_model_checkpoint.pth
│       └── final_model.pth
├── conda_env.txt
├── data
│   ├── train
│   │   ├── label
│   │   │   ├── labels_snr_0dB.pt
│   │   │   ├── labels_snr_10dB.pt
│   │   │   ├── labels_snr_15dB.pt
│   │   │   ├── labels_snr_20dB.pt
│   │   │   ├── labels_snr_25dB.pt
│   │   │   ├── labels_snr_30dB.pt
│   │   │   └── labels_snr_5dB.pt
│   │   └── signal
│   │       ├── signals_snr_0dB.pt
│   │       ├── signals_snr_10dB.pt
│   │       ├── signals_snr_15dB.pt
│   │       ├── signals_snr_20dB.pt
│   │       ├── signals_snr_25dB.pt
│   │       ├── signals_snr_30dB.pt
│   │       └── signals_snr_5dB.pt
│   └── val
│       ├── label
│       │   ├── labels_snr_0dB.pt
│       │   ├── labels_snr_10dB.pt
│       │   ├── labels_snr_15dB.pt
│       │   ├── labels_snr_20dB.pt
│       │   ├── labels_snr_25dB.pt
│       │   ├── labels_snr_30dB.pt
│       │   └── labels_snr_5dB.pt
│       └── signal
│           ├── signals_snr_0dB.pt
│           ├── signals_snr_10dB.pt
│           ├── signals_snr_15dB.pt
│           ├── signals_snr_20dB.pt
│           ├── signals_snr_25dB.pt
│           ├── signals_snr_30dB.pt
│           └── signals_snr_5dB.pt
├── fig
│   ├── Accuracy1_SLA.png
│   ├── Accuracy1_ULA.png
│   ├── Accuracy2_SLA.png
│   ├── Accuracy2_ULA.png
│   ├── Example_SLA.png
│   ├── Example_SLA_real.png
│   ├── Example_ULA.png
│   ├── Example_ULA_real.png
│   ├── Network.png
│   ├── Separate_SLA.png
│   └── Separate_ULA.png
├── models
│   ├── SADOANet.py
│   └── __pycache__
│       └── SADOANet.cpython-312.pyc
├── real_World_DOA_dataset
│   ├── README.md
│   ├── data.mat
│   └── fig
│       ├── DOA data.png
│       ├── example.gif
│       ├── multiExamples.png
│       └── platform.png
├── scr
│   ├── __pycache__
│   │   ├── eval_fun.cpython-312.pyc
│   │   ├── helpers.cpython-311.pyc
│   │   └── helpers.cpython-312.pyc
│   ├── dataset_gen.py
│   ├── eval_fun.py
│   ├── helpers.py
│   ├── realData_demo.mat
│   └── run_eval.py
└── train.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 ruxin

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Deep-Learning-Enabled-Robust-DOA-Estimation-with-Single-Snapshot-Sparse-Arrays
This is the code for the paper "Antenna Failure Resilience: Deep Learning-Enabled Robust DOA Estimation with Single Snapshot Sparse Arrays".

## Simulated dataset generation for training and validation
```sh
python scr/dataset_gen.py --output_dir './' --num_samples_val 1024 --num_samples_train 100000 --N 10 --max_targets 3
```

## Network architectures

[Figure omitted: network architecture]

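The layer widths in models/SADOANet.py determine the model sizes listed under Complexity. The snippet below is an arithmetic cross-check (a sketch, not code from this repository) that sums the weights and biases of the fully connected layers. It assumes the DOANet hidden sizes [2048, 1024, 512, 256, 128] and the 512-unit SALayer embedding. It also assumes the "MLP model" is SADOANet with `is_sparse=False`, which still builds the SALayer embedding. Under these assumptions the snippet reproduces both reported totals.

```python
# Sketch: parameter counts of the fully connected stacks in models/SADOANet.py.
# Layer widths are read from that file; ReLU and Sigmoid layers have no parameters.

def linear_params(n_in, n_out):
    """Weights plus biases of one nn.Linear(n_in, n_out)."""
    return n_in * n_out + n_out


def mlp_params(sizes):
    """Parameters of a chain of linear layers sizes[0] -> sizes[1] -> ... -> sizes[-1]."""
    return sum(linear_params(a, b) for a, b in zip(sizes[:-1], sizes[1:]))


def sadoanet_params(number_element=10, output_size=61, is_sparse=True):
    hidden = [2048, 1024, 512, 256, 128]  # DOANet hidden sizes
    embed = 512                            # SALayer.hidden_size
    # SALayer.linear_layer maps 2 * number_element real inputs to the embedding.
    # It is created in both variants, so it is always counted.
    salayer = linear_params(2 * number_element, embed)
    doanet_in = embed + 2 * output_size if is_sparse else 2 * number_element
    return salayer + mlp_params([doanet_in] + hidden + [output_size])


print("MLP model:", sadoanet_params(is_sparse=False))
print("Ours model:", sadoanet_params(is_sparse=True))
```

The sparse-aware model is larger only because its first DOANet layer takes 634 = 512 + 2 × 61 inputs instead of 20.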
## Training
Model without sparse augmentation:

```sh
python train.py --data_path './data' --checkpoint_path './checkpoint' --number_elements 10 --output_size 61 --sparsity 0.3 --use_sparse False --learning_rate 0.0001 --batch_size 1024 --epochs 300
```

Model with sparse augmentation:

```sh
python train.py --data_path './data' --checkpoint_path './checkpoint' --number_elements 10 --output_size 61 --sparsity 0.3 --use_sparse True --learning_rate 0.0001 --batch_size 1024 --epochs 300
```

## Evaluation
The models can be evaluated right away with the pretrained weights we provide in the 'checkpoint' directory.
Run the following commands from the 'scr' directory, where the evaluation scripts are located:

```sh
cd scr
```

### Single-target accuracy

```sh
python run_eval.py --num_simulations 1000 --num_antennas 10 --evaluation_mode 'accuracy1'
```

Expected outputs: ULA (left), SLA (right).

[Figures omitted: single-target accuracy, ULA and SLA]

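The 61 sigmoid outputs (`--output_size 61`) presumably correspond to the 61 angles of the grid that SALayer builds with `torch.arange(-30, 31)`: -30° to 30° in 1° steps. The sketch below reads angle estimates off such an output by keeping the strongest local maxima above a threshold. The 0.5 threshold, the `max_targets` cap and the name `pick_peaks` are assumptions, and scr/eval_fun.py may use a different detection rule.

```python
import numpy as np

# Assumed angle grid: 61 points, -30 to 30 degrees in 1-degree steps.
GRID_DEG = np.arange(-30, 31)


def pick_peaks(spectrum, max_targets=3, threshold=0.5):
    """Return up to max_targets grid angles at local maxima above threshold, sorted by angle."""
    s = np.asarray(spectrum, dtype=float)
    # Pad with -inf so the first and last bins can also be local maxima.
    padded = np.concatenate(([-np.inf], s, [-np.inf]))
    is_peak = (s > padded[:-2]) & (s >= padded[2:]) & (s > threshold)
    idx = np.flatnonzero(is_peak)
    # Keep the strongest peaks, then report them in ascending angle order.
    idx = idx[np.argsort(s[idx])[::-1][:max_targets]]
    return np.sort(GRID_DEG[idx])


out = np.zeros(61)
out[GRID_DEG == -12] = 0.9
out[GRID_DEG == 7] = 0.8
out[GRID_DEG == 20] = 0.3  # below the threshold, so ignored
print(pick_peaks(out))
```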

### Two-target accuracy

```sh
python run_eval.py --num_simulations 1000 --num_antennas 10 --evaluation_mode 'accuracy2'
```

Expected outputs: ULA (left), SLA (right).

[Figures omitted: two-target accuracy, ULA and SLA]


### Separability

```sh
python run_eval.py --num_simulations 1000 --num_antennas 10 --evaluation_mode 'separate'
```

Expected outputs: ULA (left), SLA (right).

[Figures omitted: separability, ULA and SLA]

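Separability measures how close two targets can be while still being resolved. The classical reference is the conventional beamforming spectrum |A^H x|, the same projection that `SALayer.apply_fft` computes. With half-wavelength spacing, the main lobe of this spectrum limits resolution to roughly 2/N in sin θ, which is about 11.5° near broadside for N = 10.

Below is a minimal NumPy sketch of that baseline under three assumptions:
- a half-wavelength ULA with a_n(θ) = exp(jπ n sin θ); `scr/helpers.steering_vector` may use another convention;
- noise-free, in-phase sources;
- the helper names used here.

```python
import numpy as np


def steering_matrix(n_elements, angles_deg):
    """Columns a(theta) with a_n = exp(j*pi*n*sin(theta)) (assumed half-wavelength ULA)."""
    n = np.arange(n_elements)[:, None]
    u = np.sin(np.deg2rad(np.asarray(angles_deg, dtype=float)))[None, :]
    return np.exp(1j * np.pi * n * u)


def beamforming_spectrum(x, grid_deg):
    """Conventional beamformer |A^H x|, normalized by the number of elements."""
    A = steering_matrix(len(x), grid_deg)
    return np.abs(A.conj().T @ x) / len(x)


def count_peaks(spectrum, rel_threshold=0.5):
    """Count interior local maxima that exceed rel_threshold * max(spectrum)."""
    s = np.asarray(spectrum)
    mid = s[1:-1]
    is_peak = (mid > s[:-2]) & (mid > s[2:]) & (mid > rel_threshold * s.max())
    return int(is_peak.sum())


grid = np.arange(-30, 31)  # 1-degree grid, as in SALayer
for sep in (20, 4):
    # Two unit-amplitude, in-phase sources placed symmetrically about broadside.
    x = steering_matrix(10, [-sep / 2, sep / 2]).sum(axis=1)
    print(f"{sep} deg apart -> {count_peaks(beamforming_spectrum(x, grid))} peak(s)")
```

With 10 elements, sources 20° apart produce two peaks, while sources 4° apart merge into a single main lobe.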

### Complexity
```sh
python run_eval.py --num_simulations 1000 --num_antennas 10 --evaluation_mode 'complexity'
```
Expected outputs:

Total trainable parameters in MLP model: 2848829

Total trainable parameters in Ours model: 4106301

### Example results
#### With simulated data

```sh
python run_eval.py --evaluation_mode 'examples'
```

Expected outputs:

[Figures omitted: example results on simulated data]

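The simulated data are controlled by three settings:
- the element count (`--N 10`);
- the maximum number of targets (`--max_targets 3`);
- the SNR level, which, judging by the files in data/, runs from 0 to 30 dB in 5 dB steps.

Below is a minimal sketch of one single-snapshot measurement under such settings. The signal model is an assumption: a half-wavelength ULA, on-grid angles from -30° to 30°, unit-modulus sources with random phases, and circular complex Gaussian noise. The name `simulate_snapshot` is also hypothetical, and this is not necessarily what scr/dataset_gen.py does.

```python
import numpy as np

GRID_DEG = np.arange(-30, 31)  # candidate angles, 1-degree spacing (assumed)


def simulate_snapshot(n_elements=10, max_targets=3, snr_db=20.0, rng=None):
    """Return (x, label): one noisy snapshot and a 0/1 vector over GRID_DEG."""
    rng = np.random.default_rng(rng)
    k = rng.integers(1, max_targets + 1)                    # number of targets
    angles = rng.choice(GRID_DEG, size=k, replace=False)     # distinct on-grid DOAs
    n = np.arange(n_elements)[:, None]
    A = np.exp(1j * np.pi * n * np.sin(np.deg2rad(angles))[None, :])
    s = np.exp(1j * rng.uniform(0, 2 * np.pi, size=k))       # unit-power sources
    clean = A @ s
    # Noise power chosen so that each source's per-element SNR equals snr_db.
    noise_power = 10 ** (-snr_db / 10)
    noise = np.sqrt(noise_power / 2) * (rng.standard_normal(n_elements)
                                        + 1j * rng.standard_normal(n_elements))
    label = np.isin(GRID_DEG, angles).astype(np.float32)
    return clean + noise, label


x, label = simulate_snapshot(rng=0)
print(x.shape, label.shape, int(label.sum()))
```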

#### With real-world data

```sh
python run_eval.py --evaluation_mode 'examples' --real True
```

Expected outputs:

[Figures omitted: example results on real-world data]

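The SLA panels refer to a sparse linear array, that is, one with missing elements, which the paper relates to antenna failures. `SparseLayer` in models/SADOANet.py imitates such failures during training by zeroing a random subset of at most `int(N * max_sparsity)` elements per sample. At evaluation time it instead detects the active elements by thresholding |x| > 0.001. `SALayer` then divides its beamforming features by the number of elements that remain active.

Below is a minimal NumPy sketch of those two steps. The function names and the half-wavelength ULA steering model are assumptions. The sketch mirrors the PyTorch code rather than replacing it.

```python
import numpy as np


def random_element_mask(batch, n_elements, max_sparsity=0.5, rng=None):
    """0/1 masks with 0 .. int(n_elements * max_sparsity) zeroed elements per row."""
    rng = np.random.default_rng(rng)
    max_zeros = int(n_elements * max_sparsity)
    masks = np.ones((batch, n_elements))
    for i in range(batch):
        num_zeros = rng.integers(0, max_zeros + 1)   # how many elements "fail" in this sample
        masks[i, rng.permutation(n_elements)[:num_zeros]] = 0.0
    return masks


def normalized_spectrum(x, masks, grid_deg=None):
    """|A^H (x * mask)| per row, divided by that row's number of active elements."""
    if grid_deg is None:
        grid_deg = np.arange(-30, 31)
    n = np.arange(x.shape[-1])[:, None]
    # Assumed half-wavelength ULA steering vectors, one column per grid angle.
    A = np.exp(1j * np.pi * n * np.sin(np.deg2rad(grid_deg))[None, :])
    active = masks.sum(axis=-1, keepdims=True)
    return np.abs((x * masks) @ A.conj()) / active


rng = np.random.default_rng(0)
masks = random_element_mask(4, 10, max_sparsity=0.3, rng=rng)
n = np.arange(10)
x = np.tile(np.exp(1j * np.pi * n * np.sin(np.deg2rad(10))), (4, 1))  # one source at 10 degrees
spec = normalized_spectrum(x, masks)
print(masks.sum(axis=1), spec.max(axis=1))
```

Dividing by the active-element count keeps the spectrum value at a source's true angle at 1, however many elements have failed. This is one plausible reason for normalizing this way.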
## Real World dataset
Please refer to the README in the folder 'real_World_DOA_dataset'.

## Environment
The Conda environment required for this project is specified in 'conda_env.txt', which lists every Python package and version needed to reproduce the results. Following the file's own header, the environment can be created with `conda create --name <env> --file conda_env.txt`.

If you find this project helpful for your research, please consider citing:

```BibTex
@article{zheng2024antenna,
  title={Antenna Failure Resilience: Deep Learning-Enabled Robust DOA Estimation with Single Snapshot Sparse Arrays},
  author={Zheng, Ruxin and Sun, Shunqiao and Liu, Hongshan and Chen, Honglei and Soltanalian, Mojtaba and Li, Jian},
  journal={arXiv preprint arXiv:2405.02788},
  year={2024}
}
```

--------------------------------------------------------------------------------
/checkpoint/filled/best_model_checkpoint.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/checkpoint/filled/best_model_checkpoint.pth
--------------------------------------------------------------------------------
/checkpoint/filled/final_model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/checkpoint/filled/final_model.pth
--------------------------------------------------------------------------------
/checkpoint/sparse/best_model_checkpoint.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/checkpoint/sparse/best_model_checkpoint.pth
--------------------------------------------------------------------------------
/checkpoint/sparse/final_model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/checkpoint/sparse/final_model.pth
--------------------------------------------------------------------------------
/conda_env.txt:
--------------------------------------------------------------------------------
# This file may be used to create an environment using:
# $ conda create --name <env> --file <this file>
# platform: linux-64
_libgcc_mutex=0.1=main
_openmp_mutex=5.1=1_gnu
alabaster=0.7.12=pyhd3eb1b0_0
arrow=1.2.3=py312h06a4308_1
astroid=2.14.2=py312h06a4308_0
asttokens=2.0.5=pyhd3eb1b0_0
atomicwrites=1.4.0=py_0
attrs=23.1.0=py312h06a4308_0
autopep8=2.0.4=pyhd3eb1b0_0
babel=2.11.0=py312h06a4308_0
beautifulsoup4=4.12.2=py312h06a4308_0
binaryornot=0.4.4=pyhd3eb1b0_1
black=24.3.0=py312h06a4308_0
blas=1.0=mkl
bleach=4.1.0=pyhd3eb1b0_0
brotli=1.0.9=h5eee18b_8
brotli-bin=1.0.9=h5eee18b_8
brotli-python=1.0.9=py312h6a678d5_8
bzip2=1.0.8=h5eee18b_6
ca-certificates=2024.3.11=h06a4308_0
certifi=2024.2.2=py312h06a4308_0
cffi=1.16.0=py312h5eee18b_1
chardet=4.0.0=py312h06a4308_1003
charset-normalizer=2.0.4=pyhd3eb1b0_0
click=8.1.7=py312h06a4308_0
cloudpickle=2.2.1=py312h06a4308_0
colorama=0.4.6=py312h06a4308_0
comm=0.2.1=py312h06a4308_0
contourpy=1.2.0=py312hdb19cb5_0
cookiecutter=2.6.0=py312h06a4308_0
cryptography=42.0.5=py312hdda0065_0
cuda-cudart=11.8.89=0
cuda-cupti=11.8.87=0
cuda-libraries=11.8.0=0
cuda-nvcc=12.4.131=0
cuda-nvrtc=11.8.89=0
cuda-nvtx=11.8.86=0
cuda-runtime=11.8.0=0
cudatoolkit=11.8.0=h6a678d5_0
cycler=0.11.0=pyhd3eb1b0_0
cyrus-sasl=2.1.28=h52b45da_1
dbus=1.13.18=hb2f20db_0
debugpy=1.6.7=py312h6a678d5_0
decorator=5.1.1=pyhd3eb1b0_0
defusedxml=0.7.1=pyhd3eb1b0_0
diff-match-patch=20200713=pyhd3eb1b0_0
dill=0.3.7=py312h06a4308_0
docstring-to-markdown=0.11=py312h06a4308_0
docutils=0.18.1=py312h06a4308_3
executing=0.8.3=pyhd3eb1b0_0
expat=2.6.2=h6a678d5_0
ffmpeg=4.3=hf484d3e_0
filelock=3.13.1=py312h06a4308_0
flake8=7.0.0=py312h06a4308_0
fontconfig=2.14.1=h4c34cd2_2
fonttools=4.51.0=py312h5eee18b_0
freetype=2.12.1=h4a9f257_0
glib=2.78.4=h6a678d5_0
glib-tools=2.78.4=h6a678d5_0
gmp=6.2.1=h295c915_3
gnutls=3.6.15=he1e5248_0
gst-plugins-base=1.14.1=h6a678d5_1
gstreamer=1.14.1=h5eee18b_1
icu=73.1=h6a678d5_0
idna=3.7=py312h06a4308_0
imagesize=1.4.1=py312h06a4308_0
importlib-metadata=7.0.1=py312h06a4308_0
inflection=0.5.1=py312h06a4308_1
intel-openmp=2023.1.0=hdb19cb5_46306
intervaltree=3.1.0=pyhd3eb1b0_0
ipykernel=6.28.0=py312h06a4308_0
ipython=8.20.0=py312h06a4308_0
isort=5.9.3=pyhd3eb1b0_0
jaraco.classes=3.2.1=pyhd3eb1b0_0
jedi=0.18.1=py312h06a4308_1
jeepney=0.7.1=pyhd3eb1b0_0
jellyfish=1.0.1=py312hb02cf49_0
jinja2=3.1.3=py312h06a4308_0
jpeg=9e=h5eee18b_1
jsonschema=4.19.2=py312h06a4308_0
jsonschema-specifications=2023.7.1=py312h06a4308_0
jupyter_client=8.6.0=py312h06a4308_0
jupyter_core=5.5.0=py312h06a4308_0
jupyterlab_pygments=0.2.2=py312h06a4308_0
keyring=24.3.1=py312h06a4308_0
kiwisolver=1.4.4=py312h6a678d5_0
krb5=1.20.1=h143b758_1
lame=3.100=h7b6447c_0
lazy-object-proxy=1.10.0=py312h5eee18b_0
lcms2=2.12=h3be6417_0
ld_impl_linux-64=2.38=h1181459_1
lerc=3.0=h295c915_0
libbrotlicommon=1.0.9=h5eee18b_8
libbrotlidec=1.0.9=h5eee18b_8
libbrotlienc=1.0.9=h5eee18b_8
libclang=14.0.6=default_hc6dbbc7_1
libclang13=14.0.6=default_he11475f_1
libcublas=11.11.3.6=0
libcufft=10.9.0.58=0
libcufile=1.9.1.3=0
libcups=2.4.2=h2d74bed_1
libcurand=10.3.5.147=0
libcusolver=11.4.1.48=0
libcusparse=11.7.5.86=0
libdeflate=1.17=h5eee18b_1
libedit=3.1.20230828=h5eee18b_0
libevent=2.1.12=hdbd6064_1
libffi=3.4.4=h6a678d5_1
libgcc-ng=11.2.0=h1234567_1
libgfortran-ng=11.2.0=h00389a5_1
libgfortran5=11.2.0=h1234567_1
libglib=2.78.4=hdc74915_0
libgomp=11.2.0=h1234567_1
libiconv=1.16=h5eee18b_3
libidn2=2.3.4=h5eee18b_0
libjpeg-turbo=2.0.0=h9bf148f_0
libllvm14=14.0.6=hdb19cb5_3
libnpp=11.8.0.86=0
libnvjpeg=11.9.0.86=0
libpng=1.6.39=h5eee18b_0
libpq=12.17=hdbd6064_0
libsodium=1.0.18=h7b6447c_0
libspatialindex=1.9.3=h2531618_0
libstdcxx-ng=11.2.0=h1234567_1
libtasn1=4.19.0=h5eee18b_0
libtiff=4.5.1=h6a678d5_0
libunistring=0.9.10=h27cfd23_0
libuuid=1.41.5=h5eee18b_0
libwebp-base=1.3.2=h5eee18b_0
libxcb=1.15=h7f8727e_0
libxkbcommon=1.0.1=h5eee18b_1
libxml2=2.10.4=hfdd30dd_2
llvm-openmp=14.0.6=h9e868ea_0
lz4-c=1.9.4=h6a678d5_1
markdown-it-py=2.2.0=py312h06a4308_1
markupsafe=2.1.3=py312h5eee18b_0
matplotlib=3.8.4=py312h06a4308_0
matplotlib-base=3.8.4=py312h526ad5a_0
matplotlib-inline=0.1.6=py312h06a4308_0
mccabe=0.7.0=pyhd3eb1b0_0
mdurl=0.1.0=py312h06a4308_0
mistune=2.0.4=py312h06a4308_0
mkl=2023.1.0=h213fc3f_46344
mkl-service=2.4.0=py312h5eee18b_1
mkl_fft=1.3.8=py312h5eee18b_0
mkl_random=1.2.4=py312hdb19cb5_0
more-itertools=10.1.0=py312h06a4308_0
mpmath=1.3.0=py312h06a4308_0
mypy_extensions=1.0.0=py312h06a4308_0
mysql=5.7.24=h721c034_2
nbclient=0.8.0=py312h06a4308_0
nbconvert=7.10.0=py312h06a4308_0
nbformat=5.9.2=py312h06a4308_0
ncurses=6.4=h6a678d5_0
nest-asyncio=1.6.0=py312h06a4308_0
nettle=3.7.3=hbbd107a_1
networkx=3.1=py312h06a4308_0
nspr=4.35=h6a678d5_0
nss=3.89.1=h6a678d5_0
numpy=1.26.4=py312hc5e2394_0
numpy-base=1.26.4=py312h0da6c21_0
numpydoc=1.5.0=py312h06a4308_0
openh264=2.1.1=h4ff587b_0
openjpeg=2.4.0=h3ad879b_0
openssl=3.0.13=h7f8727e_1
packaging=23.2=py312h06a4308_0
pandocfilters=1.5.0=pyhd3eb1b0_0
parso=0.8.3=pyhd3eb1b0_0
pathspec=0.10.3=py312h06a4308_0
pcre2=10.42=hebb0a14_0
pexpect=4.8.0=pyhd3eb1b0_3
pickleshare=0.7.5=pyhd3eb1b0_1003
pillow=10.3.0=py312h5eee18b_0
pip=23.3.1=py312h06a4308_0
platformdirs=3.10.0=py312h06a4308_0
pluggy=1.0.0=py312h06a4308_1
ply=3.11=py312h06a4308_1
prompt-toolkit=3.0.43=py312h06a4308_0
prompt_toolkit=3.0.43=hd3eb1b0_0
psutil=5.9.0=py312h5eee18b_0
ptyprocess=0.7.0=pyhd3eb1b0_2
pure_eval=0.2.2=pyhd3eb1b0_0
pybind11-abi=5=hd3eb1b0_0
pycodestyle=2.11.1=py312h06a4308_0
pycparser=2.21=pyhd3eb1b0_0
pydocstyle=6.3.0=py312h06a4308_0
pyflakes=3.2.0=py312h06a4308_0
pygments=2.15.1=py312h06a4308_1
pylint=2.16.2=py312h06a4308_0
pylint-venv=3.0.3=py312h06a4308_0
pyls-spyder=0.4.0=pyhd3eb1b0_0
pyparsing=3.0.9=py312h06a4308_0
pyqt=5.15.10=py312h6a678d5_0
pyqt5-sip=12.13.0=py312h5eee18b_0
pyqtwebengine=5.15.10=py312h6a678d5_0
pysocks=1.7.1=py312h06a4308_0
python=3.12.3=h996f2a0_0
python-dateutil=2.8.2=pyhd3eb1b0_0
python-fastjsonschema=2.16.2=py312h06a4308_0
python-lsp-black=2.0.0=py312h06a4308_0
python-lsp-jsonrpc=1.1.2=pyhd3eb1b0_0
python-lsp-server=1.10.0=py312h06a4308_0
python-slugify=5.0.2=pyhd3eb1b0_0
pytoolconfig=1.2.6=py312h06a4308_0
pytorch=2.3.0=py3.12_cuda11.8_cudnn8.7.0_0
pytorch-cuda=11.8=h7e8668a_5
pytorch-mutex=1.0=cuda
pytz=2024.1=py312h06a4308_0
pyxdg=0.27=pyhd3eb1b0_0
pyyaml=6.0.1=py312h5eee18b_0
pyzmq=25.1.2=py312h6a678d5_0
qdarkstyle=3.2.3=pyhd3eb1b0_0
qstylizer=0.2.2=py312h06a4308_0
qt-main=5.15.2=h53bd1ea_10
qt-webengine=5.15.9=h9ab4d14_7
qtawesome=1.2.2=py312h06a4308_0
qtconsole=5.5.1=py312h06a4308_0
qtpy=2.4.1=py312h06a4308_0
readline=8.2=h5eee18b_0
referencing=0.30.2=py312h06a4308_0
requests=2.31.0=py312h06a4308_1
rich=13.3.5=py312h06a4308_1
rope=1.12.0=py312h06a4308_0
rpds-py=0.10.6=py312hb02cf49_0
rtree=1.0.1=py312h06a4308_0
scipy=1.13.0=py312hc5e2394_0
secretstorage=3.3.1=py312h06a4308_1
setuptools=68.2.2=py312h06a4308_0
sip=6.7.12=py312h6a678d5_0
six=1.16.0=pyhd3eb1b0_1
snowballstemmer=2.2.0=pyhd3eb1b0_0
sortedcontainers=2.4.0=pyhd3eb1b0_0
soupsieve=2.5=py312h06a4308_0
sphinx=5.0.2=py312h06a4308_0
sphinxcontrib-applehelp=1.0.2=pyhd3eb1b0_0
sphinxcontrib-devhelp=1.0.2=pyhd3eb1b0_0
sphinxcontrib-htmlhelp=2.0.0=pyhd3eb1b0_0
sphinxcontrib-jsmath=1.0.1=pyhd3eb1b0_0
sphinxcontrib-qthelp=1.0.3=pyhd3eb1b0_0
sphinxcontrib-serializinghtml=1.1.5=pyhd3eb1b0_0
spyder=5.5.1=py312h06a4308_0
spyder-kernels=2.5.0=py312h06a4308_0
sqlite=3.45.3=h5eee18b_0
stack_data=0.2.0=pyhd3eb1b0_0
sympy=1.12=py312h06a4308_0
tbb=2021.8.0=hdb19cb5_0
text-unidecode=1.3=pyhd3eb1b0_0
textdistance=4.2.1=pyhd3eb1b0_0
three-merge=0.1.1=pyhd3eb1b0_0
tinycss2=1.2.1=py312h06a4308_0
tk=8.6.12=h1ccaba5_0
tomli=2.0.1=py312h06a4308_1
tomlkit=0.11.1=py312h06a4308_0
torchaudio=2.3.0=py312_cu118
torchvision=0.18.0=py312_cu118
tornado=6.3.3=py312h5eee18b_0
tqdm=4.66.2=py312he106c6f_0
traitlets=5.7.1=py312h06a4308_0
typing_extensions=4.9.0=py312h06a4308_1
tzdata=2024a=h04d1e81_0
ujson=5.4.0=py312h6a678d5_0
unicodedata2=15.1.0=py312h5eee18b_0
unidecode=1.2.0=pyhd3eb1b0_0
urllib3=2.1.0=py312h06a4308_1
watchdog=2.1.6=py312h06a4308_0
wcwidth=0.2.5=pyhd3eb1b0_0
webencodings=0.5.1=py312h06a4308_2
whatthepatch=1.0.2=py312h06a4308_0
wheel=0.41.2=py312h06a4308_0
wrapt=1.14.1=py312h5eee18b_0
wurlitzer=3.0.2=py312h06a4308_0
xz=5.4.6=h5eee18b_1
yaml=0.2.5=h7b6447c_0
yapf=0.40.2=py312h06a4308_0
zeromq=4.3.5=h6a678d5_0
zipp=3.17.0=py312h06a4308_0
zlib=1.2.13=h5eee18b_1
zstd=1.5.5=hc292b87_1
--------------------------------------------------------------------------------
/data/train/label/labels_snr_0dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/train/label/labels_snr_0dB.pt
--------------------------------------------------------------------------------
/data/train/label/labels_snr_10dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/train/label/labels_snr_10dB.pt
--------------------------------------------------------------------------------
/data/train/label/labels_snr_15dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/train/label/labels_snr_15dB.pt
--------------------------------------------------------------------------------
/data/train/label/labels_snr_20dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/train/label/labels_snr_20dB.pt
--------------------------------------------------------------------------------
/data/train/label/labels_snr_25dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/train/label/labels_snr_25dB.pt
--------------------------------------------------------------------------------
/data/train/label/labels_snr_30dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/train/label/labels_snr_30dB.pt
--------------------------------------------------------------------------------
/data/train/label/labels_snr_5dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/train/label/labels_snr_5dB.pt
--------------------------------------------------------------------------------
/data/train/signal/signals_snr_0dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/train/signal/signals_snr_0dB.pt
--------------------------------------------------------------------------------
/data/train/signal/signals_snr_10dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/train/signal/signals_snr_10dB.pt
--------------------------------------------------------------------------------
/data/train/signal/signals_snr_15dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/train/signal/signals_snr_15dB.pt
--------------------------------------------------------------------------------
/data/train/signal/signals_snr_20dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/train/signal/signals_snr_20dB.pt
--------------------------------------------------------------------------------
/data/train/signal/signals_snr_25dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/train/signal/signals_snr_25dB.pt
--------------------------------------------------------------------------------
/data/train/signal/signals_snr_30dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/train/signal/signals_snr_30dB.pt
--------------------------------------------------------------------------------
/data/train/signal/signals_snr_5dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/train/signal/signals_snr_5dB.pt
--------------------------------------------------------------------------------
/data/val/label/labels_snr_0dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/val/label/labels_snr_0dB.pt
--------------------------------------------------------------------------------
/data/val/label/labels_snr_10dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/val/label/labels_snr_10dB.pt
--------------------------------------------------------------------------------
/data/val/label/labels_snr_15dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/val/label/labels_snr_15dB.pt
--------------------------------------------------------------------------------
/data/val/label/labels_snr_20dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/val/label/labels_snr_20dB.pt
--------------------------------------------------------------------------------
/data/val/label/labels_snr_25dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/val/label/labels_snr_25dB.pt
--------------------------------------------------------------------------------
/data/val/label/labels_snr_30dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/val/label/labels_snr_30dB.pt
--------------------------------------------------------------------------------
/data/val/label/labels_snr_5dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/val/label/labels_snr_5dB.pt
--------------------------------------------------------------------------------
/data/val/signal/signals_snr_0dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/val/signal/signals_snr_0dB.pt
--------------------------------------------------------------------------------
/data/val/signal/signals_snr_10dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/val/signal/signals_snr_10dB.pt
--------------------------------------------------------------------------------
/data/val/signal/signals_snr_15dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/val/signal/signals_snr_15dB.pt
--------------------------------------------------------------------------------
/data/val/signal/signals_snr_20dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/val/signal/signals_snr_20dB.pt
--------------------------------------------------------------------------------
/data/val/signal/signals_snr_25dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/val/signal/signals_snr_25dB.pt
--------------------------------------------------------------------------------
/data/val/signal/signals_snr_30dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/val/signal/signals_snr_30dB.pt
--------------------------------------------------------------------------------
/data/val/signal/signals_snr_5dB.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/data/val/signal/signals_snr_5dB.pt
--------------------------------------------------------------------------------
/fig/Accuracy1_SLA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/fig/Accuracy1_SLA.png
--------------------------------------------------------------------------------
/fig/Accuracy1_ULA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/fig/Accuracy1_ULA.png
--------------------------------------------------------------------------------
/fig/Accuracy2_SLA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/fig/Accuracy2_SLA.png
--------------------------------------------------------------------------------
/fig/Accuracy2_ULA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/fig/Accuracy2_ULA.png
--------------------------------------------------------------------------------
/fig/Example_SLA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/fig/Example_SLA.png
--------------------------------------------------------------------------------
/fig/Example_SLA_real.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/fig/Example_SLA_real.png
--------------------------------------------------------------------------------
/fig/Example_ULA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/fig/Example_ULA.png
--------------------------------------------------------------------------------
/fig/Example_ULA_real.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/fig/Example_ULA_real.png
--------------------------------------------------------------------------------
/fig/Network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/fig/Network.png
--------------------------------------------------------------------------------
/fig/Separate_SLA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/fig/Separate_SLA.png
--------------------------------------------------------------------------------
/fig/Separate_ULA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/fig/Separate_ULA.png
--------------------------------------------------------------------------------
/models/SADOANet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import sys
sys.path.append('../')
from scr.helpers import steering_vector


class DOANet(nn.Module):
    def __init__(self, number_element=20, output_size=61):
        super(DOANet, self).__init__()
        # Layer configurations
        self.input_size = number_element
        self.output_size = output_size
        hidden_sizes = [2048, 1024, 512, 256, 128]  # Adjustable hidden layer sizes

        # Network layers
        layers = []
for h1, h2 in zip([self.input_size] + hidden_sizes[:-1], hidden_sizes): 19 | layers.append(nn.Linear(h1, h2)) 20 | layers.append(nn.ReLU()) 21 | layers.append(nn.Linear(hidden_sizes[-1], self.output_size)) 22 | layers.append(nn.Sigmoid()) 23 | 24 | self.layers = nn.Sequential(*layers) 25 | 26 | def forward(self, x): 27 | return self.layers(x) 28 | 29 | 30 | class SparseLayer(nn.Module): 31 | def __init__(self, input_size=10, max_sparsity=0.5): 32 | super(SparseLayer, self).__init__() 33 | self.input_size = input_size 34 | self.max_zeros = int(input_size * max_sparsity) 35 | 36 | def forward(self, x): 37 | if self.training: 38 | batch_size,N = x.size() # Get the batch size 39 | sparsity = torch.zeros(batch_size, dtype=torch.long).to(x.device) # Tensor to store the number of zeros 40 | # Generate a random mask for each example in the batch 41 | masks = torch.ones((batch_size, self.input_size)).to(x.device) # Start with all ones 42 | NN = torch.ones(batch_size) * N 43 | num_zeros = torch.randint(0, self.max_zeros + 1, (batch_size,)) 44 | for i in range(batch_size): 45 | zero_indices = torch.randperm(self.input_size)[:num_zeros[i]] # Random indices to be zeroed 46 | masks[i, zero_indices] = 0 # Set selected indices to zero 47 | sparsity = NN - num_zeros # Store the number of zeros used for this mask 48 | x_sparse = x * masks 49 | 50 | else: 51 | x_sparse = x 52 | sparsity, masks = self.thresholding(x) 53 | 54 | return x_sparse, sparsity, masks.to(x_sparse.dtype).to(x_sparse.device) 55 | 56 | 57 | def thresholding(self, x): 58 | threshold = 0.001 59 | mask = (torch.abs(x) > threshold).float() 60 | return torch.sum(mask, dim=1), mask 61 | 62 | 63 | class SALayer(nn.Module): 64 | def __init__(self, number_element=10, output_size=61, max_sparsity=0.5): 65 | super(SALayer, self).__init__() 66 | # Initialize SparseLayer 67 | self.sparselayer = SparseLayer(number_element, max_sparsity) 68 | self.output_size = output_size 69 | self.hidden_size = 512 70 | 71 | # Calculate 
steering vectors 72 | a_theta = steering_vector(number_element, torch.arange(-30, 31)).conj() 73 | self.AH = torch.transpose(a_theta, 0, 1) 74 | 75 | # Define the network layers 76 | self.linear_layer = nn.Linear(number_element * 2, self.hidden_size) 77 | self.relu = nn.ReLU() # ReLU activation function 78 | 79 | def forward(self, x): 80 | batch_size, N = x.size() 81 | # Sparse masking, then beamforming spectra of the masked signal and of the mask 82 | x_sparse, sparsity, masks = self.sparselayer(x) 83 | x_fft = self.apply_fft(x_sparse, sparsity, batch_size) 84 | masks_fft = self.apply_fft(masks, sparsity, batch_size) 85 | 86 | # Flatten the sparse input and apply linear transformation 87 | x_flat = torch.view_as_real(x_sparse).view(batch_size, -1) 88 | embedded_values = self.relu(self.linear_layer(x_flat)) 89 | normalized_values = self.normalize(embedded_values, sparsity) 90 | 91 | # Concatenate features from different processing streams 92 | output = torch.cat((masks_fft, x_fft, normalized_values), dim=1) 93 | return output 94 | 95 | def normalize(self, x, sparsity): 96 | """Normalize the data by the number of active antenna elements.""" 97 | normalization_factor = sparsity.unsqueeze(-1).to(x.device) 98 | return x / normalization_factor 99 | 100 | def apply_fft(self, x, sparsity, batch_size): 101 | """Compute the beamforming spectrum |A^H x| on the -30..30 degree grid, normalized by the number of active elements.""" 102 | AH_batched = self.AH.unsqueeze(0).repeat(batch_size, 1, 1) 103 | x_expanded = x.unsqueeze(-1) 104 | fft_output = torch.abs(torch.matmul(AH_batched.to(x_expanded.device), x_expanded)).squeeze(-1) 105 | return self.normalize(fft_output,sparsity) 106 | 107 | 108 | class SADOANet(nn.Module): 109 | def __init__(self, number_element=10, output_size=61, max_sparsity=0.5, is_sparse=True): 110 | super(SADOANet, self).__init__() 111 | self.salayer = SALayer(number_element, output_size, max_sparsity) 112 | input_size = 512 + output_size * 2 if is_sparse else number_element * 2 # 512-d embedding plus signal and mask spectra 113 | self.doanet = DOANet(input_size, output_size) 114 | self.is_sparse = is_sparse 115
| 116 | def forward(self, x): 117 | if len(x.size()) == 3: # real-valued input of shape (batch, N, 2): convert to complex 118 | x = torch.view_as_complex(x) 119 | if self.is_sparse: 120 | x = self.salayer(x) 121 | else: 122 | x = torch.view_as_real(x).view(x.size(0), -1) 123 | return self.doanet(x) 124 | 125 | -------------------------------------------------------------------------------- /models/__pycache__/SADOANet.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/models/__pycache__/SADOANet.cpython-312.pyc -------------------------------------------------------------------------------- /real_World_DOA_dataset/README.md: -------------------------------------------------------------------------------- 1 | # Real-World Dataset for Direction-of-Arrival (DOA) Estimation 2 | 3 | ## Motivation 4 | In the field of DOA estimation, the absence of publicly available real-world datasets has been a significant barrier to advancing and validating DOA estimation technologies. Historically, researchers have relied on simulated datasets to train and evaluate their models. To bridge this gap and contribute a practical resource to the community, we have developed a DOA estimation dataset gathered under real-world conditions. 5 | 6 | ## Data acquisition vehicle platform 7 | A Lexus RX450h data acquisition platform equipped with a high-resolution imaging radar, LiDAR, and stereo cameras was used to carry out field experiments at the University of Alabama. 8 |

9 | 10 |

11 | 12 | ## Field experiment 13 | The dataset was collected in a parking lot by a stationary vehicle equipped with a TI Cascade Imaging Radar. The radar captured signals from a corner reflector placed 15 meters away, with measurements covering all possible directions. 14 | 15 |

16 | 17 |

18 | 19 | This data collection yielded 195 high-SNR signals from a single target, each with a unique angle of arrival. 20 |

21 | 22 |

23 | 24 | 25 | To increase the complexity and usability of the dataset, we superimposed these signals to simulate multi-target scenarios. Below are some examples of FFT spectra of multi-target signals. 26 |

27 | 28 |

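The superposition step described above can be illustrated with a short sketch. This is not the authors' processing code: it assumes a 10-element, half-wavelength uniform linear array (the convention used by `steering_vector` in scr/helpers.py), uses noise-free synthetic single-target snapshots as stand-ins for rows of `bv_list`, and computes the "FFT spectrum" as a steering-vector beamforming scan, which is what the `FFT` function in scr/eval_fun.py does. The helper names `steering_matrix`, `dbf_spectrum`, and `pick_peaks` are hypothetical.

```python
import numpy as np


def steering_matrix(num_elements, angles_deg, spacing=0.5):
    """ULA steering vectors (assumed half-wavelength spacing), one column per angle."""
    n = np.arange(num_elements)[:, None]                              # element indices, (N, 1)
    theta = np.deg2rad(np.asarray(angles_deg, dtype=float))[None, :]  # angles in radians, (1, K)
    return np.exp(1j * 2 * np.pi * spacing * n * np.sin(theta))       # (N, K)


def dbf_spectrum(snapshot, angle_grid_deg):
    """Conventional beamforming spectrum |a(theta)^H y| / N over an angle grid."""
    A = steering_matrix(snapshot.shape[0], angle_grid_deg)
    return np.abs(A.conj().T @ snapshot) / snapshot.shape[0]


def pick_peaks(spectrum, angle_grid_deg, num_peaks):
    """Angles of the `num_peaks` largest interior local maxima, sorted ascending."""
    s = spectrum
    is_peak = (s[1:-1] > s[:-2]) & (s[1:-1] > s[2:])  # strictly above both neighbours
    peak_idx = np.flatnonzero(is_peak) + 1
    strongest = peak_idx[np.argsort(s[peak_idx])[::-1][:num_peaks]]
    return np.sort(angle_grid_deg[strongest])


num_elements = 10
grid = np.arange(-30, 31, 1)  # 1-degree grid, as used throughout the repository

# Two single-target snapshots (stand-ins for two rows of bv_list).
y_a = steering_matrix(num_elements, [0.0])[:, 0]
y_b = steering_matrix(num_elements, [20.0])[:, 0]

# Superimposing single-target measurements emulates a two-target scene.
y_two = y_a + y_b

spec = dbf_spectrum(y_two, grid)
estimated = pick_peaks(spec, grid, num_peaks=2)
print(estimated)  # [ 0 20]
```

To try this on the measured data, `y_a` and `y_b` could be replaced by two rows of `data['bv_list']`, for example restricted to the first 10 channels, similar to how scr/run_eval.py takes the first 10 channels of realData_demo.mat. Whether those channels form a uniform half-wavelength array is an assumption to check against the radar's antenna layout.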
29 | 30 | ## Dataset structure 31 | data.mat contains: 32 | - ang_list (1 x 195): the ground-truth DOA of each signal 33 | - bv_list (195 x 86): 195 raw signals, each recorded on 86 antennas 34 | 35 | ## How to use 36 | MATLAB 37 | ``` matlab 38 | load('data.mat') 39 | ``` 40 | Python 41 | ``` python 42 | import scipy.io 43 | data = scipy.io.loadmat('data.mat') 44 | ``` 45 | 46 | If this dataset contributes to your research, please acknowledge its use with the following citation: 47 | ``` LATEX 48 | @ARTICLE{10348517, 49 | author={Zheng, Ruxin and Sun, Shunqiao and Liu, Hongshan and Chen, Honglei and Li, Jian}, 50 | journal={IEEE Sensors Journal}, 51 | title={Interpretable and Efficient Beamforming-Based Deep Learning for Single Snapshot DOA Estimation}, 52 | year={2023}, 53 | volume={}, 54 | number={}, 55 | pages={1-1}, 56 | keywords={Direction-of-arrival estimation;Estimation;Deep learning;Covariance matrices;Sensors;Mathematical models;Array signal processing;Single snapshot DOA estimation;array signal processing;automotive radar;interpretability;deep learning}, 57 | doi={10.1109/JSEN.2023.3338575}} 58 | ``` 59 | -------------------------------------------------------------------------------- /real_World_DOA_dataset/data.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/real_World_DOA_dataset/data.mat -------------------------------------------------------------------------------- /real_World_DOA_dataset/fig/DOA data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/real_World_DOA_dataset/fig/DOA data.png -------------------------------------------------------------------------------- /real_World_DOA_dataset/fig/example.gif: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/real_World_DOA_dataset/fig/example.gif -------------------------------------------------------------------------------- /real_World_DOA_dataset/fig/multiExamples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/real_World_DOA_dataset/fig/multiExamples.png -------------------------------------------------------------------------------- /real_World_DOA_dataset/fig/platform.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/real_World_DOA_dataset/fig/platform.png -------------------------------------------------------------------------------- /scr/__pycache__/eval_fun.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/scr/__pycache__/eval_fun.cpython-312.pyc -------------------------------------------------------------------------------- /scr/__pycache__/helpers.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/scr/__pycache__/helpers.cpython-311.pyc -------------------------------------------------------------------------------- /scr/__pycache__/helpers.cpython-312.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/scr/__pycache__/helpers.cpython-312.pyc -------------------------------------------------------------------------------- /scr/dataset_gen.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from helpers import generate_data 4 | 5 | 6 | def main(args): 7 | print("Generating validation data ... ...") 8 | generate_data( 9 | N=args.N, 10 | num_samples=args.num_samples_val, 11 | max_targets=args.max_targets, 12 | folder_path=os.path.join(args.output_dir, 'data/val') 13 | ) 14 | 15 | print("Generating training data ... ...") 16 | generate_data( 17 | N=args.N, 18 | num_samples=args.num_samples_train, 19 | max_targets=args.max_targets, 20 | folder_path=os.path.join(args.output_dir, 'data/train') 21 | ) 22 | 23 | if __name__ == "__main__": 24 | parser = argparse.ArgumentParser(description='Generate dataset for antenna signal processing.') 25 | parser.add_argument('--output_dir', type=str, default='./', 26 | help='Base directory for output data') 27 | parser.add_argument('--num_samples_val', type=int, default=1024, 28 | help='Number of validation samples to generate') 29 | parser.add_argument('--num_samples_train', type=int, default=100000, 30 | help='Number of training samples to generate') 31 | parser.add_argument('--N', type=int, default=10, 32 | help='Number of antenna elements') 33 | parser.add_argument('--max_targets', type=int, default=3, 34 | help='Maximum number of targets per sample') 35 | args = parser.parse_args() 36 | main(args) -------------------------------------------------------------------------------- /scr/eval_fun.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import matplotlib.pyplot as plt 4 | from scipy.signal import find_peaks 5 | import time 6 | from helpers import * 7 | import sys 8 | sys.path.append('../') 9 | from models.SADOANet import SADOANet 10 | from tqdm import tqdm 11 | import torch 12 | from collections import OrderedDict 13 | 14 | def load_sadoanet(num_elements, output_size, sparsity, is_sparse, device, model_path): 15 | 
model = SADOANet(num_elements, output_size, sparsity, is_sparse).to(device) 16 | state_dict = torch.load(model_path, map_location=device) # map_location lets GPU-saved checkpoints load on a CPU-only machine 17 | # The checkpoints were saved from an nn.DataParallel model, 18 | # so strip the "module." prefix from each key before loading into a plain model 19 | new_state_dict = OrderedDict() 20 | for k, v in state_dict.items(): 21 | new_state_dict[k.replace("module.","")] = v 22 | model.load_state_dict(new_state_dict) 23 | 24 | return model.eval() 25 | 26 | def randSparse(signal,sparsity): 27 | sparseSignal = signal.clone() 28 | sparseInd = torch.randperm(signal.numel())[:int(signal.numel() * sparsity)] # indices of elements to disable 29 | sparseSignal[sparseInd] = 0 # emulate failed antenna elements 30 | return sparseSignal 31 | 32 | def FFT(signal, num_antennas=10, ang_min = -30, ang_max = 31, ang_step = 1): # conventional beamforming (DBF) scan over the angle grid 33 | ang_list = torch.arange(ang_min,ang_max,ang_step) 34 | a_theta = steering_vector(num_antennas,ang_list ).conj() 35 | AH = torch.transpose(a_theta, 0, 1) 36 | spec = torch.abs(torch.matmul(AH,signal)/num_antennas).squeeze().numpy() 37 | return ang_list, spec 38 | 39 | def IAA(y, Niter=15): 40 | """ 41 | Iterative Adaptive Approach (IAA) spectral estimation.
42 | """ 43 | N = y.shape[0] 44 | ang_list = torch.arange(-90,91,1) 45 | A = steering_vector(N,ang_list) 46 | AH = torch.transpose(A.conj(), 0, 1) 47 | N, K = A.shape 48 | Ns = y.shape[1] 49 | Pk = np.zeros(K, dtype=np.complex128) 50 | 51 | y = y.squeeze().numpy() 52 | A = A.numpy() 53 | AH = AH.resolve_conj().numpy() 54 | # Initial power estimate from the normalized beamformer output |a_k^H y / N|^2 55 | 56 | Pk = np.abs(AH @ y / N) ** 2 57 | P = np.diag(Pk) 58 | R = A @ P @ AH 59 | 60 | # Main iteration 61 | for _ in range(Niter): 62 | R += 0e-3 * np.eye(N) # diagonal loading (currently 0, i.e. disabled) 63 | ak_R = AH @ np.linalg.pinv(R) 64 | T = ak_R @ y 65 | B = ak_R @ A 66 | b = B.diagonal() 67 | sk = T/np.abs(b) # s_k = a_k^H R^-1 y / (a_k^H R^-1 a_k) 68 | Pk = np.abs(sk) ** 2 69 | P = np.diag(Pk) 70 | R = A @ P @ A.conj().T 71 | # spec = Pk 72 | spec = Pk[60:121] # grid indices 60..120 correspond to -30:1:30 degrees 73 | ang_list = torch.arange(-30,31,1) 74 | return ang_list, spec 75 | 76 | def DLapproach(signal,model,device): 77 | model_result = model(signal.squeeze().unsqueeze(0).to(device)) 78 | model_result = model_result.squeeze().cpu().detach().numpy() 79 | ang_list = torch.arange(-30,31,1) 80 | spec = model_result 81 | return ang_list, spec 82 | 83 | def estimate_doa(ang_list, spec, scale = 0.7): 84 | """ 85 | Estimate DOAs as the angles of spectral peaks above scale * max(spec), strongest first. 86 | """ 87 | max_height = np.max(spec) 88 | min_peak_height = (scale * max_height) 89 | peaks, properties = find_peaks(spec, height=min_peak_height) 90 | # Sort peaks by their magnitudes in descending order 91 | sorted_indices = np.argsort(properties['peak_heights'])[::-1] # Get indices to sort in descending order 92 | sorted_peaks = peaks[sorted_indices] 93 | sorted_peak_heights = properties['peak_heights'][sorted_indices] 94 | doa = ang_list[sorted_peaks] 95 | 96 | return doa 97 | 98 | def plot_results(snr_levels, mse_metrics): 99 | plt.figure(figsize=(10, 6)) 100 | markers = {'fft': 'o', 'iaa': 's', 'mlp': '^', 'sparse': 'd', 'grid': 'x'} 101 | colors = {'fft': 'blue', 'iaa': 'green', 'mlp': 'red', 'sparse': 'cyan', 'grid': 'black'} 102 | for key, color in colors.items(): 103 |
plt.semilogy(snr_levels, mse_metrics[key], label=key.upper(), color=color, marker=markers[key]) 104 | 105 | plt.title('MSE vs SNR for Different DOA Estimation Methods') 106 | plt.xlabel('SNR (dB)') 107 | plt.ylabel('Mean Squared Error (deg^2)') 108 | plt.legend() 109 | plt.grid(True) 110 | plt.show() 111 | 112 | def plot_HR(snr_levels, HR_metrics): # x values are the separation angles returned by run_monte_carlo_sep 113 | plt.figure(figsize=(10, 6)) 114 | markers = {'fft': 'o', 'iaa': 's', 'mlp': '^', 'sparse': 'd'} 115 | colors = {'fft': 'blue', 'iaa': 'green', 'mlp': 'red', 'sparse': 'cyan'} 116 | for key, color in colors.items(): 117 | plt.plot(snr_levels, HR_metrics[key], label=key.upper(), color=color, marker=markers[key]) 118 | 119 | plt.title('Hit Rate vs Angular Separation for Different DOA Estimation Methods') 120 | plt.xlabel('Angular Separation (deg)') 121 | plt.ylabel('Hit Rate (fraction of trials)') 122 | plt.legend() 123 | plt.grid(True) 124 | plt.show() 125 | 126 | def handle_estimated_angles(tmp): 127 | """ Keep the two strongest estimates sorted by angle; duplicate a single estimate, or return [30, 30] if none. """ 128 | if tmp.nelement() == 0: 129 | return np.array([30, 30]) # Default error value when no angles are resolved 130 | elif tmp.nelement() == 1: 131 | return np.array([tmp[0].item(), tmp[0].item()]) 132 | else: 133 | tmp = tmp[:2] 134 | sorted_tmp = torch.sort(tmp)[0].numpy() 135 | return sorted_tmp 136 | 137 | def calculate_mse(actual_angles_1, actual_angles_2, estimates): 138 | """ Calculate the mean squared error for estimated angles.
""" 139 | mse_1 = np.mean((actual_angles_1 - estimates[:, 0]) ** 2) 140 | mse_2 = np.mean((actual_angles_2 - estimates[:, 1]) ** 2) 141 | return np.mean([mse_1, mse_2]) 142 | 143 | def count_parameters(model): 144 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 145 | 146 | ################################################################# 147 | ## MONTE CARLO TEST FUNCTIONS ## 148 | ################################################################# 149 | def run_monte_carlo_accuracy(num_simulations, num_antennas=10, sparse_flag = True): 150 | snr_levels = np.arange(-5, 35, 5) # SNR levels from -5 dB to 30 dB in 5 dB steps 151 | mse_metrics = {'fft': [], 'iaa': [], 'mlp': [], 'sparse': [], 'grid': []} 152 | 153 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 154 | # Load models 155 | mlp_model = load_sadoanet(10, 61, 0.3, False, device, '../checkpoint/filled/best_model_checkpoint.pth') 156 | sparse_model = load_sadoanet(10, 61, 0.3, True, device, '../checkpoint/sparse/best_model_checkpoint.pth') 157 | 158 | methods = { 159 | 'fft': lambda sig: FFT(sig), 160 | 'iaa': lambda sig: IAA(sig), 161 | 'mlp': lambda sig: DLapproach(sig, mlp_model, device), 162 | 'sparse': lambda sig: DLapproach(sig, sparse_model, device) 163 | } 164 | 165 | for snr_db in tqdm(snr_levels, desc='SNR levels', unit='snr'): 166 | actual_angles = np.random.uniform(-30, 30, num_simulations) 167 | estimates = {key: np.zeros(num_simulations) for key in mse_metrics} 168 | 169 | for i in range(num_simulations): 170 | signal = generate_complex_signal(num_antennas, snr_db, torch.tensor([actual_angles[i]])) 171 | if sparse_flag: 172 | signal = randSparse(signal, 0.3) 173 | 174 | for method_name, method in methods.items(): 175 | ang_list, spec = method(signal) 176 | tmp = estimate_doa(ang_list, spec) 177 | estimates[method_name][i] = actual_angles[i] if tmp.nelement() == 0 else tmp[0].item() 178 | 179 | for key in mse_metrics: 180 | if key != 'grid': 181 | 
mse_metrics[key].append(np.mean((actual_angles - estimates[key]) ** 2)) 182 | else : 183 | mse_metrics[key].append(np.mean((actual_angles - np.rint(actual_angles)) ** 2)) 184 | return snr_levels, mse_metrics 185 | 186 | 187 | ################################################################## 188 | def run_monte_carlo_accuracy2(num_simulations, num_antennas=10, sparse_flag = True): 189 | """ 190 | Run Monte Carlo simulations to evaluate various DOA estimation methods across SNR levels. 191 | """ 192 | snr_levels = np.arange(-5, 35, 5) # SNR levels from -5 dB to 30 dB in 5 dB steps 193 | mse_metrics = {'fft': [], 'iaa': [], 'mlp': [], 'sparse': [], 'grid': []} 194 | 195 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 196 | # Load models 197 | mlp_model = load_sadoanet(10, 61, 0.3, False, device, '../checkpoint/filled/best_model_checkpoint.pth') 198 | sparse_model = load_sadoanet(10, 61, 0.3, True, device, '../checkpoint/sparse/best_model_checkpoint.pth') 199 | 200 | methods = { 201 | 'fft': lambda sig: FFT(sig), 202 | 'iaa': lambda sig: IAA(sig), 203 | 'mlp': lambda sig: DLapproach(sig, mlp_model, device), 204 | 'sparse': lambda sig: DLapproach(sig, sparse_model, device) 205 | } 206 | 207 | for snr_db in tqdm(snr_levels, desc='SNR levels', unit='snr'): 208 | actual_angles_1 = np.random.uniform(-0.6, 0.4, num_simulations) 209 | actual_angles_2 = np.random.uniform(9.6, 10.4, num_simulations) 210 | actual_angles = np.column_stack((actual_angles_1, actual_angles_2)).flatten() 211 | 212 | estimates = {key: np.zeros((num_simulations, 2)) for key in mse_metrics} 213 | 214 | for i in range(num_simulations): 215 | signal = generate_complex_signal(num_antennas, snr_db, torch.tensor([actual_angles_1[i], actual_angles_2[i]])) 216 | if sparse_flag: 217 | signal = randSparse(signal, 0.3) 218 | 219 | for method_name, method in methods.items(): 220 | ang_list, spec = method(signal) 221 | tmp = estimate_doa(ang_list, spec,0) 222 | estimates[method_name][i] = 
handle_estimated_angles(tmp) 223 | 224 | # Calculate MSE for each method 225 | for key in mse_metrics: 226 | if key != 'grid': 227 | mse_metrics[key].append(calculate_mse(actual_angles_1, actual_angles_2, estimates[key])) 228 | else: # error floor from quantizing to the 1-degree grid 229 | tmp1 = np.mean((actual_angles_1- np.rint(actual_angles_1)) ** 2) 230 | tmp2 = np.mean((actual_angles_2- np.rint(actual_angles_2)) ** 2) 231 | mse_metrics[key].append(np.mean([tmp1, tmp2])) # average over both targets, as in calculate_mse 232 | return snr_levels, mse_metrics 233 | 234 | 235 | 236 | ################################################################## 237 | def run_monte_carlo_sep(num_simulations, num_antennas=10, sparse_flag = True): 238 | """ 239 | Run Monte Carlo simulations to evaluate two-target resolution (hit rate) across different angular separations. 240 | """ 241 | sep_angles = np.arange(2, 30, 2) # Separation angles from 2 to 28 degrees in 2 degree steps 242 | HR_metrics = {'fft': [], 'iaa': [], 'mlp': [], 'sparse': []} 243 | 244 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 245 | # Load models 246 | mlp_model = load_sadoanet(10, 61, 0.3, False, device, '../checkpoint/filled/best_model_checkpoint.pth') 247 | sparse_model = load_sadoanet(10, 61, 0.3, True, device, '../checkpoint/sparse/best_model_checkpoint.pth') 248 | 249 | methods = { 250 | 'fft': lambda sig: FFT(sig), 251 | 'iaa': lambda sig: IAA(sig), 252 | 'mlp': lambda sig: DLapproach(sig, mlp_model, device), 253 | 'sparse': lambda sig: DLapproach(sig, sparse_model, device) 254 | } 255 | 256 | 257 | for sep in tqdm(sep_angles, desc='Separation', unit='deg', ncols=86): 258 | actual_angles = np.array([-sep / 2, sep / 2]) # Centered angles 259 | tmp_rates = {key: np.zeros(num_simulations) for key in HR_metrics} 260 | 261 | for i in range(num_simulations): 262 | signal = generate_complex_signal(num_antennas, 40, torch.from_numpy(actual_angles)) 263 | if sparse_flag: 264 | signal = randSparse(signal, 0.3) 265 | 266 | for method_name, method in methods.items(): 267 | ang_list, spec =
method(signal) 268 | estimated_angles = estimate_doa(ang_list, spec, 0.7 if method_name in ['fft', 'iaa'] else 0.2) 269 | estimated_angles, _ = torch.sort(estimated_angles) 270 | if estimated_angles.nelement() == 2 and np.allclose(estimated_angles.numpy(), actual_angles, atol=2): 271 | tmp_rates[method_name][i] = 1 272 | 273 | for key in HR_metrics: 274 | HR_metrics[key].append(np.mean(tmp_rates[key])) 275 | 276 | return sep_angles, HR_metrics 277 | 278 | 279 | ################################################################## 280 | def run_examples(signal,num_antennas=10,sparse_flag = True): 281 | # load model 282 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 283 | mlp_model = load_sadoanet(10, 61, 0.3, False, device, '../checkpoint/filled/best_model_checkpoint.pth') 284 | sparse_model = load_sadoanet(10, 61, 0.3, True, device, '../checkpoint/sparse/best_model_checkpoint.pth') 285 | signal = torch.from_numpy(signal).to(torch.cfloat) 286 | sparsity= 0.3 287 | antennaPos = np.arange(num_antennas) 288 | if sparse_flag: 289 | signal = randSparse(signal, 0.3) 290 | zero_indices = torch.where(signal == 0)[0].numpy() 291 | print(zero_indices) 292 | antennaPos = np.arange(num_antennas) 293 | antennaPos = np.delete(antennaPos, zero_indices) 294 | 295 | # Create a 1x4 subplot grid 296 | fig, axs = plt.subplots(1, 4, figsize=(20, 5)) # Adjust the figsize to ensure all subplots are visible and not squished 297 | 298 | # Plot the first subplot similar to the uploaded image 299 | axs[0].stem(antennaPos, np.ones_like(antennaPos), linefmt='blue', markerfmt='bo', basefmt="r-") 300 | if sparse_flag: 301 | axs[0].set_title('Sparse Linear Array') 302 | else: 303 | axs[0].set_title('Uniform Linear Array') 304 | axs[0].set_xlabel('Horizontal [Half Wavelength]') 305 | axs[0].set_ylabel('Amplitude') 306 | axs[0].set_ylim([0, 2]) # Setting the y-axis limits to match the uploaded image 307 | axs[0].set_xlim([0, num_antennas-1]) # Setting the x-axis limits to match 
the uploaded image 308 | axs[0].set_yticks([0, 1, 2]) 309 | axs[0].set_xticks(np.arange(num_antennas)) 310 | axs[0].grid(True) 311 | 312 | #FFT 313 | ang_list, spec0 = FFT(signal,num_antennas) 314 | spec0 = spec0/np.max(spec0) 315 | #IAA 316 | ang_list, spec1 = IAA(signal) 317 | spec1 = spec1/np.max(spec1) 318 | axs[1].plot(ang_list, spec0, label='DBF', color='blue',linewidth=2) 319 | axs[1].plot(ang_list, spec1, label='IAA', color='red',linewidth=2) 320 | axs[1].axvline(0, color='black', linestyle='--', linewidth=1.5) 321 | axs[1].axvline(7, color='black', linestyle='--', linewidth=1.5) 322 | axs[1].set_title('DBF vs IAA') 323 | axs[1].set_xlabel('Angle [degree]') 324 | axs[1].set_ylabel('Magnitude') 325 | axs[1].set_ylim([0, 1]) # Setting the y-axis limits to match the uploaded image 326 | axs[1].set_xlim([-30, 30]) # Setting the x-axis limits to match the uploaded image 327 | axs[1].set_xticks(np.arange(-30, 31, 10)) 328 | axs[1].set_yticks([0, 0.5, 1]) 329 | axs[1].legend() 330 | axs[1].grid(True) 331 | 332 | #MLP 333 | ang_list, spec2 = DLapproach(signal,mlp_model,device) 334 | spec2 = spec2/np.max(spec2) 335 | axs[2].plot(ang_list, spec2, color='blue',linewidth=2) 336 | axs[2].axvline(0, color='black', linestyle='--', linewidth=1.5) 337 | axs[2].axvline(7, color='black', linestyle='--', linewidth=1.5) 338 | axs[2].set_title('MLP') 339 | axs[2].set_xlabel('Angle [degree]') 340 | axs[2].set_ylabel('Magnitude') 341 | axs[2].set_ylim([0, 1]) # Setting the y-axis limits to match the uploaded image 342 | axs[2].set_xlim([-30, 30]) # Setting the x-axis limits to match the uploaded image 343 | axs[2].set_xticks(np.arange(-30, 31, 10)) 344 | axs[2].set_yticks([0, 0.5, 1]) 345 | axs[2].grid(True) 346 | 347 | #Sparse 348 | ang_list, spec3 = DLapproach(signal,sparse_model,device) 349 | spec3 = spec3/np.max(spec3) 350 | axs[3].plot(ang_list, spec3, color='blue',linewidth=2) 351 | axs[3].axvline(0, color='black', linestyle='--', linewidth=1.5) 352 | axs[3].axvline(7, 
color='black', linestyle='--', linewidth=1.5) 353 | axs[3].set_title('Ours') 354 | axs[3].set_xlabel('Angle [degree]') 355 | axs[3].set_ylabel('Magnitude') 356 | axs[3].set_ylim([0, 1]) # Setting the y-axis limits to match the uploaded image 357 | axs[3].set_xlim([-30, 30]) # Setting the x-axis limits to match the uploaded image 358 | axs[3].set_xticks(np.arange(-30, 31, 10)) 359 | axs[3].set_yticks([0, 0.5, 1]) 360 | axs[3].grid(True) 361 | plt.show() 362 | return ang_list, spec0, spec1, spec2, spec3 363 | 364 | ################################################################## 365 | def model_complexity(num_elements = 10, output_size = 61, sparsity = 0.3): 366 | mlp_model = SADOANet(num_elements, output_size, sparsity, False) 367 | num_params = count_parameters(mlp_model) 368 | sparse_model = SADOANet(num_elements, output_size, sparsity, True) 369 | num_params_sparse = count_parameters(sparse_model) 370 | print(f"Total trainable parameters in MLP model: {num_params}") 371 | print(f"Total trainable parameters in Ours model: {num_params_sparse}") 372 | return num_params, num_params_sparse 373 | 374 | ################################################################## 375 | # def main(): 376 | # parser = argparse.ArgumentParser(description="Run Monte Carlo simulations for DOA estimation accuracy") 377 | # parser.add_argument('--num_simulations', type=int, default=1000, help="Number of Monte Carlo simulations to run") 378 | # parser.add_argument('--num_antennas', type=int, default=10, help="Number of antennas in the array") 379 | # args = parser.parse_args() 380 | 381 | # # Accuracy - single target 382 | # snr_levels, mse_metrics = run_monte_carlo_accuracy(args.num_simulations, args.num_antennas, False) 383 | # plot_results(snr_levels, mse_metrics) 384 | # snr_levels, mse_metrics = run_monte_carlo_accuracy(args.num_simulations, args.num_antennas) 385 | # plot_results(snr_levels, mse_metrics) 386 | 387 | # # Accuracy - two targets 388 | # snr_levels, mse_metrics = 
run_monte_carlo_accuracy2(args.num_simulations, args.num_antennas, False) 389 | # plot_results(snr_levels, mse_metrics) 390 | # snr_levels, mse_metrics = run_monte_carlo_accuracy2(args.num_simulations, args.num_antennas) 391 | # plot_results(snr_levels, mse_metrics) 392 | 393 | # # Generate estimation example on real data 394 | # signal = generate_complex_signal(10, 40, torch.tensor([0, 7])).numpy() 395 | # run_examples(signal,num_antennas=10,sparse_flag = False) 396 | # run_examples(signal,num_antennas=10,sparse_flag = True) 397 | # run_examples(signal,num_antennas=10,sparse_flag = True) 398 | # run_examples(signal,num_antennas=10,sparse_flag = True) 399 | 400 | # # Counts total trainable parameters 401 | # model_complexity() 402 | # if __name__ == "__main__": 403 | # main() 404 | -------------------------------------------------------------------------------- /scr/helpers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from tqdm.auto import tqdm 4 | from torch.utils.data import Dataset, DataLoader 5 | 6 | 7 | def steering_vector(N, deg): 8 | """ 9 | Calculate the steering vector for a uniform linear array using given antenna configuration. 10 | 11 | Args: 12 | N (int): Number of antenna elements. 13 | deg (float): Angle of arrival in degrees. 14 | 15 | Returns: 16 | torch.Tensor: The steering vector as a complex-valued tensor. 
17 | """ 18 | d = 0.5 # Element spacing (in units of wavelength) 19 | wavelength = 1.0 # Wavelength of the signal (same units as d) 20 | k = 2 * torch.pi / wavelength # Wavenumber 21 | n = torch.arange(0, N).view(N, 1) # Antenna element indices [0, 1, ..., N-1] 22 | theta = deg * torch.pi / 180 # Convert degrees to radians 23 | phases = k * d * n * torch.sin(theta) # Phase shift for each element 24 | 25 | return torch.exp(1j * phases) # Complex exponential for each phase shift 26 | 27 | 28 | 29 | def generate_complex_signal(N=10, snr_db=10, deg=torch.tensor([30])): 30 | """ 31 | Generates a complex-valued signal for an array of N antenna elements. 32 | 33 | Args: 34 | N (int): Number of antenna elements. 35 | snr_db (float): Signal-to-Noise Ratio in decibels. 36 | deg (tensor): Angles of arrival in degrees, one entry per target. 37 | 38 | Returns: 39 | torch.Tensor: Complex-valued tensor of shape (N, 1) representing the received signals. 40 | """ 41 | a_theta = steering_vector(N, deg) 42 | phase = torch.exp(2j * torch.pi * torch.randn(a_theta.size()[1])).view(-1, 1) # random unit-modulus amplitude per target 43 | signal = torch.matmul(a_theta.to(phase.dtype), phase) 44 | signal_power = torch.mean(torch.abs(signal)**2) 45 | snr_linear = 10**(snr_db / 10) 46 | 47 | noise_power = signal_power / snr_linear 48 | noise_real = torch.sqrt(noise_power / 2) * torch.randn_like(signal.real) 49 | noise_imag = torch.sqrt(noise_power / 2) * torch.randn_like(signal.imag) 50 | noise = torch.complex(noise_real, noise_imag) # circular complex Gaussian noise at the requested SNR 51 | 52 | return signal + noise 53 | 54 | 55 | def generate_label(degrees, min_angle=-30, max_angle=30): 56 | """ 57 | Generate multi-hot labels on the 1-degree grid [min_angle, max_angle] for the given degrees. 58 | 59 | Args: 60 | degrees (tensor): Target angles in degrees. 61 | 62 | Returns: 63 | torch.Tensor: Multi-hot label vector with a 1 at each target angle.
64 | """ 65 | labels = torch.zeros(max_angle - min_angle + 1) 66 | indices = degrees - min_angle 67 | labels[indices.long()] = 1 68 | return labels 69 | 70 | def generate_data(N, num_samples=1, max_targets=3, folder_path='/content/drive/MyDrive/Asilomar2024/data/'): 71 | """ 72 | Generate a dataset with a random number of targets (1 to max_targets) per sample at SNR levels of 0 to 30 dB in 5 dB steps. 73 | 74 | Args: 75 | N (int): Number of antenna elements. 76 | num_samples (int): Number of samples to generate for each SNR level. 77 | max_targets (int): Maximum number of targets. 78 | folder_path (str): Base folder path for saving data. 79 | 80 | Returns: 81 | None: Signals and labels are saved to folder_path/signal and folder_path/label, one .pt file per SNR level. 82 | """ 83 | angles = torch.arange(-30, 31, 1) 84 | signal_folder = os.path.join(folder_path, 'signal') 85 | label_folder = os.path.join(folder_path, 'label') 86 | os.makedirs(signal_folder, exist_ok=True) 87 | os.makedirs(label_folder, exist_ok=True) 88 | 89 | for snr_db in tqdm(range(0, 35, 5), desc='SNR levels', unit='snr', dynamic_ncols=True): 90 | all_signals, all_labels = [], [] 91 | for _ in range(num_samples): 92 | num_targets = torch.randint(1, max_targets + 1, (1,)).item() 93 | deg_indices = torch.randperm(len(angles))[:num_targets] 94 | degs = angles[deg_indices] 95 | label = generate_label(degs) 96 | noisy_signal = generate_complex_signal(N=N, snr_db=snr_db, deg=degs) 97 | all_signals.append(noisy_signal) 98 | all_labels.append(label) 99 | torch.save(all_signals, os.path.join(signal_folder, f'signals_snr_{snr_db}dB.pt')) 100 | torch.save(all_labels, os.path.join(label_folder, f'labels_snr_{snr_db}dB.pt')) 101 | return None 102 | 103 | 104 | class SignalDataset(Dataset): 105 | def __init__(self, file_paths, label_paths): 106 | """ 107 | Initializes a dataset containing signals and their corresponding labels. 108 | 109 | Args: 110 | file_paths (list): Paths to files containing signals. 111 | label_paths (list): Paths to files containing labels.
        """
        self.signals = [torch.stack(torch.load(file), dim=0) for file in file_paths]
        self.labels = [torch.stack(torch.load(label), dim=0) for label in label_paths]
        self.signals = torch.cat(self.signals, dim=0)
        self.labels = torch.cat(self.labels, dim=0)

    def __len__(self):
        return len(self.signals)

    def __getitem__(self, idx):
        return self.signals[idx], self.labels[idx]

def create_dataloader(data_path, batch_size=32, shuffle=True):
    """
    Create a DataLoader for batching and shuffling the dataset.

    Args:
        data_path (str): Directory containing the `signal/` and `label/` subdirectories.
        batch_size (int): Number of samples per batch.
        shuffle (bool): Whether to shuffle the data.

    Returns:
        DataLoader: Configured DataLoader for the dataset.
    """
    signal_dir_path = os.path.join(data_path, "signal")
    label_dir_path = os.path.join(data_path, "label")
    signal_files = [os.path.join(signal_dir_path, f) for f in os.listdir(signal_dir_path) if 'signals' in f]
    label_files = [os.path.join(label_dir_path, f) for f in os.listdir(label_dir_path) if 'labels' in f]
    # Sorting keeps each signals_snr_XdB.pt paired with the matching labels_snr_XdB.pt
    dataset = SignalDataset(sorted(signal_files), sorted(label_files))
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
--------------------------------------------------------------------------------
/scr/realData_demo.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ruxinzh/Deep_RSA_DOA/02574383dd1a48eac7c0c780a80234e3e8d5603f/scr/realData_demo.mat
--------------------------------------------------------------------------------
/scr/run_eval.py:
--------------------------------------------------------------------------------
import argparse
import torch
import scipy.io
from eval_fun import (run_monte_carlo_accuracy, run_monte_carlo_accuracy2, run_monte_carlo_sep, plot_results,
                      plot_HR, model_complexity, generate_complex_signal, run_examples)


def str2bool(value):
    """Parse a boolean CLI value; argparse's type=bool would treat the string 'False' as True."""
    if isinstance(value, bool):
        return value
    if value.lower() in ('true', '1', 'yes'):
        return True
    if value.lower() in ('false', '0', 'no'):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def main():
    parser = argparse.ArgumentParser(description="Run Monte Carlo simulations for DOA estimation accuracy")
    parser.add_argument('--num_simulations', type=int, default=100, help="Number of Monte Carlo simulations to run")
    parser.add_argument('--num_antennas', type=int, default=10, help="Number of antennas in the array")
    parser.add_argument('--evaluation_mode', type=str, default='accuracy1', help="Evaluation mode: accuracy1, accuracy2, separate, examples, complexity")
    parser.add_argument('--real', type=str2bool, default=True, help="Use the real-world data demo (True/False)")
    args = parser.parse_args()

    if args.evaluation_mode == 'accuracy1':
        # Accuracy - single target
        print('Monte Carlo simulation on single target case with ULA \n')
        snr_levels, mse_metrics = run_monte_carlo_accuracy(args.num_simulations, args.num_antennas, False)
        plot_results(snr_levels, mse_metrics)
        print('Monte Carlo simulation on single target case with SLA \n')
        snr_levels, mse_metrics = run_monte_carlo_accuracy(args.num_simulations, args.num_antennas)
        plot_results(snr_levels, mse_metrics)

    elif args.evaluation_mode == 'accuracy2':
        # Accuracy - two targets
        print('Monte Carlo simulation on two targets case with ULA \n')
        snr_levels, mse_metrics = run_monte_carlo_accuracy2(args.num_simulations, args.num_antennas, False)
        plot_results(snr_levels, mse_metrics)
        print('Monte Carlo simulation on two targets case with SLA \n')
        snr_levels, mse_metrics = run_monte_carlo_accuracy2(args.num_simulations, args.num_antennas)
        plot_results(snr_levels, mse_metrics)

    elif args.evaluation_mode == 'separate':
        # Separability
        print('Monte Carlo simulation on separability with ULA \n')
        sep_angles, HR_metrics = run_monte_carlo_sep(args.num_simulations, args.num_antennas, False)
        plot_HR(sep_angles, HR_metrics)
        print('Monte Carlo simulation on separability with SLA \n')
        sep_angles, HR_metrics = run_monte_carlo_sep(args.num_simulations, args.num_antennas)
        plot_HR(sep_angles, HR_metrics)

    elif args.evaluation_mode == 'examples':
        # Generate estimation examples on real (or simulated) data
        if args.real:
            # Relative path: run this script from the scr/ directory
            file_path = 'realData_demo.mat'
            mat_data = scipy.io.loadmat(file_path)
            bv = mat_data['bv_final']
            bv = bv[:, 0:10].T
            signal = torch.from_numpy(bv).to(torch.cfloat).numpy()
        else:
            signal = generate_complex_signal(10, 40, torch.tensor([0, 7])).numpy()
        run_examples(signal, num_antennas=10, sparse_flag=False)
        run_examples(signal, num_antennas=10, sparse_flag=True)

    elif args.evaluation_mode == 'complexity':
        # Count total trainable parameters
        model_complexity()

if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse
import os
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.auto import tqdm

from scr.helpers import *
from models.SADOANet import SADOANet

def validate_model(model, dataloader, criterion, device):
    """Perform validation and return the average loss."""
    model.eval()  # Set the model to evaluation mode
    total_loss = 0
    with torch.no_grad():
        for signals, labels in dataloader:
            signals = signals.to(device).squeeze(-1)
            labels = labels.to(device)
            outputs = model(signals)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
    return total_loss / len(dataloader)

def main(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Initialize the model
    model = SADOANet(args.number_elements, args.output_size,
                     args.sparsity, args.use_sparse)

    # Wrap the model in DataParallel only when at least 4 GPUs are available
    if torch.cuda.device_count() >= 4:
        print(f"Using 4 of {torch.cuda.device_count()} GPUs!")
        model = nn.DataParallel(model, device_ids=list(range(4)))  # Modify here to change the number of GPUs used
    model = model.to(device)

    # Loss function and optimizer
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

    train_loader = create_dataloader(os.path.join(args.data_path, 'train'), batch_size=args.batch_size)
    val_loader = create_dataloader(os.path.join(args.data_path, 'val'), batch_size=args.batch_size)

    # Sparse and filled array models are saved in separate subdirectories
    save_checkpoint_path = os.path.join(args.checkpoint_path, 'sparse' if args.use_sparse else 'filled')
    os.makedirs(save_checkpoint_path, exist_ok=True)

    checkpoint_path = os.path.join(save_checkpoint_path, 'best_model_checkpoint.pth')
    final_model_path = os.path.join(save_checkpoint_path, 'final_model.pth')
    # model.load_state_dict(torch.load(checkpoint_path), strict=False)

    # Training loop
    best_val_loss = float('inf')
    for epoch in range(args.epochs):
        model.train()
        train_loss = 0
        for signals, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1} [Training]"):
            signals = signals.to(device).squeeze(-1)
            labels = labels.to(device)
            outputs = model(signals)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        average_train_loss = train_loss / len(train_loader)
        print(f"Epoch {epoch + 1} Training Loss: {average_train_loss}")

        # Validate every 5 epochs and keep the checkpoint with the lowest validation loss
        if (epoch + 1) % 5 == 0:
            val_loss = validate_model(model, val_loader, criterion, device)
            print(f"Epoch {epoch + 1} Validation Loss: {val_loss}")
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(), checkpoint_path)
                print(f"Best model saved with validation loss {best_val_loss} at {checkpoint_path}")

    torch.save(model.state_dict(), final_model_path)
    print(f"Final model saved at {final_model_path}")


def str2bool(value):
    """Parse a boolean CLI value; argparse's type=bool would treat the string 'False' as True."""
    if isinstance(value, bool):
        return value
    if value.lower() in ('true', '1', 'yes'):
        return True
    if value.lower() in ('false', '0', 'no'):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a DOA estimation model")
    parser.add_argument('--data_path', type=str, default='./data',
                        help='Path to training and validation data directory')
    parser.add_argument('--checkpoint_path', type=str, default='./checkpoint',
                        help='Directory in which to save model checkpoints')
    parser.add_argument('--number_elements', type=int, default=10,
                        help='Number of array elements in the model')
    parser.add_argument('--output_size', type=int, default=61,
                        help='Number of output angle bins (61 covers -30 to 30 degrees in 1-degree steps)')
    parser.add_argument('--sparsity', type=float, default=0.3,
                        help='Sparsity level used in the model')
    parser.add_argument('--use_sparse', type=str2bool, default=False,
                        help='Whether to use the sparse augmentation layer in the model (True/False)')
    parser.add_argument('--learning_rate', type=float, default=0.0001,
                        help='Learning rate for the optimizer')
    parser.add_argument('--batch_size', type=int, default=1024,
                        help='Batch size for training and validation')
    parser.add_argument('--epochs', type=int, default=300,
                        help='Number of epochs to train the model')

    args = parser.parse_args()
    main(args)
--------------------------------------------------------------------------------
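In this repository the network's output is a grid of angle bins: `generate_data` draws target angles from `torch.arange(-30, 31, 1)`, `generate_label` marks each target with a 1 at index `degree - min_angle` (assuming its defaults are `min_angle=-30` and `max_angle=30`), and `train.py` scores every bin independently with `nn.BCELoss` (default `--output_size 61`). The following is a minimal, dependency-free sketch of that multi-hot encoding, plus a hypothetical threshold decoder for turning per-bin probabilities back into angles. The `encode_angles`/`decode_angles` names and the 0.5 threshold are assumptions for illustration, not the repository's evaluation code.

```python
from typing import List, Sequence

# Assumed grid: -30..30 degrees in 1-degree steps (61 bins), matching the
# angles drawn in generate_data and train.py's default --output_size.
MIN_ANGLE, MAX_ANGLE = -30, 30


def encode_angles(degrees: Sequence[int]) -> List[float]:
    """Multi-hot label: 1.0 at index (degree - MIN_ANGLE), mirroring generate_label."""
    labels = [0.0] * (MAX_ANGLE - MIN_ANGLE + 1)
    for deg in degrees:
        if not MIN_ANGLE <= deg <= MAX_ANGLE:
            # Guard against silent wrap-around from negative list indices
            raise ValueError(f"angle {deg} is outside [{MIN_ANGLE}, {MAX_ANGLE}]")
        labels[deg - MIN_ANGLE] = 1.0
    return labels


def decode_angles(probs: Sequence[float], threshold: float = 0.5) -> List[int]:
    """Hypothetical decoder: report every grid angle whose probability reaches the threshold."""
    return [MIN_ANGLE + i for i, p in enumerate(probs) if p >= threshold]


if __name__ == "__main__":
    label = encode_angles([0, 7])
    print(len(label), decode_angles(label))  # 61 [0, 7]
```

Thresholding alone can report several adjacent bins for one broad peak in a soft network output; adding a peak-picking step would be a natural refinement, but this sketch deliberately keeps the decoder to a single comparison per bin.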