├── .github
└── workflows
│ └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── doa_py
├── __init__.py
├── algorithm
│ ├── __init__.py
│ ├── broadband.py
│ ├── esprit_based.py
│ ├── music_based.py
│ ├── sparse.py
│ └── utils.py
├── arrays.py
├── plot.py
└── signals.py
├── examples
├── broad.py
├── notebooks
│ ├── broadband.ipynb
│ ├── uca.ipynb
│ └── ula.ipynb
└── ula.py
├── pics
├── doa_py.svg
├── esprit.svg
├── l1_svd.svg
├── music_spectrum.svg
└── uca_rb_music.svg
├── pyproject.toml
├── ruff.toml
└── setup.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
name: Upload Python Package

# Trigger: runs only when a GitHub release is published.
on:
  release:
    types: [published]

# Restrict the default GITHUB_TOKEN to read-only access to repository contents.
permissions:
  contents: read

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v3
    - name: Set up Python
      uses: actions/setup-python@v3
      with:
        python-version: '3.x'
    # Only the PEP 517 "build" frontend is needed to produce the distributions.
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    # The publish action is pinned to a commit SHA rather than a mutable tag;
    # it authenticates with an API token stored in the PYPI_API_TOKEN secret.
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}
40 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Ignore Python files whose name starts with an underscore (presumably local
# scratch scripts). That pattern also matches every package __init__.py, so
# re-include those explicitly or new packages would silently go untracked.
_*.py
!__init__.py
uv.lock
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 | cover/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | .pybuilder/
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | # For a library or package, you might want to ignore these files since the code is
90 | # intended to run in multiple environments; otherwise, check them in:
91 | # .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # poetry
101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102 | # This is especially recommended for binary packages to ensure reproducibility, and is more
103 | # commonly ignored for libraries.
104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105 | #poetry.lock
106 |
107 | # pdm
108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109 | #pdm.lock
110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111 | # in version control.
112 | # https://pdm.fming.dev/#use-with-ide
113 | .pdm.toml
114 |
115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116 | __pypackages__/
117 |
118 | # Celery stuff
119 | celerybeat-schedule
120 | celerybeat.pid
121 |
122 | # SageMath parsed files
123 | *.sage.py
124 |
125 | # Environments
126 | .env
127 | .venv
128 | env/
129 | venv/
130 | ENV/
131 | env.bak/
132 | venv.bak/
133 |
134 | # Spyder project settings
135 | .spyderproject
136 | .spyproject
137 |
138 | # Rope project settings
139 | .ropeproject
140 |
141 | # mkdocs documentation
142 | /site
143 |
144 | # mypy
145 | .mypy_cache/
146 | .dmypy.json
147 | dmypy.json
148 |
149 | # Pyre type checker
150 | .pyre/
151 |
152 | # pytype static type analyzer
153 | .pytype/
154 |
155 | # Cython debug symbols
156 | cython_debug/
157 |
158 | # PyCharm
159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161 | # and can be added to the global gitignore or merged into this file. For a more nuclear
162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163 | #.idea/
164 |
165 | # vscode
166 | .vscode/*
167 | !.vscode/settings.json
168 | !.vscode/tasks.json
169 | !.vscode/launch.json
170 | !.vscode/extensions.json
171 | !.vscode/*.code-snippets
172 |
173 | # Local History for Visual Studio Code
174 | .history/
175 |
176 | # Built Visual Studio Code Extensions
177 | *.vsix
178 |
179 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Qian Xu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |

3 |
4 |
5 | # DOA_Py
6 |
7 | DOA Estimation algorithms implemented in Python. It can be used for ULA, UCA and broadband/wideband DOA estimation.
8 |
9 | ## Getting Started
10 |
11 | ### Installation
12 |
13 | ```bash
14 | pip install doa_py
15 | ```
16 |
17 | or install from source
18 |
19 | ```bash
20 | git clone https://github.com/zhiim/doa_py.git
21 | cd doa_py
22 | pip install .
23 | ```
24 |
25 | ### Usage
26 |
27 | A simple example of DOA estimation using the MUSIC algorithm.
28 |
29 | ```python
30 | import numpy as np
31 |
32 | from doa_py import arrays, signals
33 | from doa_py.algorithm import music
34 | from doa_py.plot import plot_spatial_spectrum
35 |
36 | # Create an 8-element ULA with 0.5m spacing
37 | ula = arrays.UniformLinearArray(m=8, dd=0.5)
38 | # Create a complex stochastic signal
39 | source = signals.ComplexStochasticSignal(fc=3e8)
40 |
41 | # Simulate the received data
42 | received_data = ula.received_signal(
43 | signal=source, snr=0, nsamples=1000, angle_incidence=np.array([0, 30]), unit="deg"
44 | )
45 |
46 | # Calculate the MUSIC spectrum
47 | angle_grids = np.arange(-90, 90, 1)
48 | spectrum = music(
49 | received_data=received_data,
50 | num_signal=2,
51 | array=ula,
52 | signal_fre=3e8,
53 | angle_grids=angle_grids,
54 | unit="deg",
55 | )
56 |
57 | # Plot the spatial spectrum
58 | plot_spatial_spectrum(
59 | spectrum=spectrum,
60 | ground_truth=np.array([0, 30]),
61 | angle_grids=angle_grids,
62 | num_signal=2,
63 | )
64 | ```
65 |
66 | You will get a figure like this:
67 | 
68 |
69 | Check [examples](./examples/) for more details on how to use it.
70 |
71 | You can see more plot results of the algorithm in the [Showcase](#showcase).
72 |
73 | ## What's implemented
74 |
75 | ### Array Structures
76 |
77 | - Uniform Linear Array (support array position error and mutual coupling error)
78 | - Uniform Circular Array
79 |
80 | ### Signal Models
81 |
82 | - **Narrowband**
83 | - _ComplexStochasticSignal_: The amplitude of signals at each sampling point is a complex random variable.
84 | - _RandomFreqSignal_: Signals transmitted by different sources have different intermediate frequencies (support coherent mode).
85 | - **Broadband**
86 | - _ChirpSignal_: Chirp signals with different chirp bandwidths within the sampling period.
87 | - _MultiFreqSignal_: Broadband signals formed by the superposition of multiple single-frequency signals within a certain frequency band.
88 | - _MixedSignal_: Narrowband and broadband mixed signal
89 |
90 | ### Algorithms
91 |
92 | - DOA estimation for ULA
93 | - [x] MUSIC
94 | - [x] ESPRIT
95 | - [x] Root-MUSIC
96 | - [x] OMP
97 | - [x] $l_1$-SVD
98 | - DOA estimation for URA
99 | - [ ] URA-MUSIC
100 | - [ ] URA-ESPRIT
101 | - DOA estimation for UCA
102 | - [x] UCA-RB-MUSIC
103 | - [x] UCA-ESPRIT
104 | - Broadband/Wideband DOA estimation
105 | - [x] iMUSIC
106 | - [x] CSSM
107 | - [x] TOPS
108 | - Coherent DOA estimation
109 | - [x] smoothed-MUSIC
110 |
111 | ### Showcase
112 |
113 | 
114 |
115 | 
116 |
117 | 
118 |
119 | ## License
120 |
121 | This project is licensed under the [MIT](LICENSE) License - see the LICENSE file for details.
122 |
--------------------------------------------------------------------------------
/doa_py/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | # DOA_Py
3 |
4 | DOA Estimation algorithms implemented in Python. It can be used for ULA, UCA and
5 | broadband/wideband DOA estimation.
6 |
7 | ## Getting Started
8 |
9 | ### Installation
10 |
11 | ```bash
12 | pip install doa_py
13 | ```
14 |
15 | or install from source
16 |
17 | ```bash
18 | git clone https://github.com/zhiim/doa_py.git
19 | cd doa_py
20 | pip install .
21 | ```
22 |
23 | ### Usage
24 |
25 | A simple example of DOA estimation using the MUSIC algorithm.
26 |
27 | ```python
28 | import numpy as np
29 |
30 | from doa_py import arrays, signals
31 | from doa_py.algorithm import music
32 | from doa_py.plot import plot_spatial_spectrum
33 |
34 | # Create an 8-element ULA with 0.5m spacing
35 | ula = arrays.UniformLinearArray(m=8, dd=0.5)
36 | # Create a complex stochastic signal
37 | source = signals.ComplexStochasticSignal(fc=3e8)
38 |
39 | # Simulate the received data
40 | received_data = ula.received_signal(
41 | signal=source, snr=0, nsamples=1000, angle_incidence=np.array([0, 30]),
42 | unit="deg"
43 | )
44 |
45 | # Calculate the MUSIC spectrum
46 | angle_grids = np.arange(-90, 90, 1)
47 | spectrum = music(
48 | received_data=received_data,
49 | num_signal=2,
50 | array=ula,
51 | signal_fre=3e8,
52 | angle_grids=angle_grids,
53 | unit="deg",
54 | )
55 |
56 | # Plot the spatial spectrum
57 | plot_spatial_spectrum(
58 | spectrum=spectrum,
59 | ground_truth=np.array([0, 30]),
60 | angle_grids=angle_grids,
61 | num_signal=2,
62 | )
63 | ```
64 |
65 | You will get a figure like this:
66 | 
67 |
68 | Check [examples](https://github.com/zhiim/doa_py/tree/master/examples) for more
69 | examples.
70 |
71 | ## What's implemented
72 |
73 | ### Array Structures
74 |
75 | - Uniform Linear Array (ULA)
76 | - Uniform Rectangular Array (URA, to be implemented)
77 | - Uniform Circular Array (UCA)
78 |
79 | ### Signal Models
80 |
81 | - **Narrowband**
82 | - _ComplexStochasticSignal_: The amplitude of signals at each sampling point
83 | is a complex random variable.
84 | - _RandomFreqSignal_: Signals transmitted by different sources have different
85 | intermediate frequencies (IF).
86 | - **Broadband**
87 | - _ChirpSignal_: Chirp signals with different chirp bandwidths within the
88 | sampling period.
89 | - _MultiFreqSignal_: Broadband signals formed by the superposition of multiple
90 | single-frequency signals within a certain frequency band.
91 | - _MixedSignal_: Narrowband and broadband mixed signal
92 |
93 | ### Algorithms
94 |
95 | - DOA estimation for ULA
96 | - [x] MUSIC
97 | - [x] ESPRIT
98 | - [x] Root-MUSIC
99 | - [x] OMP
100 | - [x] l1-SVD
101 | - DOA estimation for URA
102 | - [ ] URA-MUSIC
103 | - [ ] URA-ESPRIT
104 | - DOA estimation for UCA
105 | - [x] UCA-RB-MUSIC
106 | - [x] UCA-ESPRIT
107 | - Broadband/Wideband DOA estimation
108 | - [x] iMUSIC (ISSM)
109 | - [x] CSSM
110 | - [x] TOPS
111 |
112 | ## License
113 |
114 | This project is licensed under the [MIT](LICENSE) License - see the LICENSE file
115 | for details.
116 | """
117 |
# Package metadata: release version string and author.
__version__ = "0.4.0"
__author__ = "Qian Xu"
120 |
--------------------------------------------------------------------------------
/doa_py/algorithm/__init__.py:
--------------------------------------------------------------------------------
1 | from .broadband import *
2 | from .esprit_based import *
3 | from .music_based import *
4 | from .sparse import *
5 |
--------------------------------------------------------------------------------
/doa_py/algorithm/broadband.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from .music_based import music
4 | from .utils import (
5 | divide_into_fre_bins,
6 | get_noise_space,
7 | get_signal_space,
8 | )
9 |
# Propagation speed (speed of light, m/s), kept for parity with the sibling
# algorithm modules. No function in this module currently references it.
C = 3e8
11 |
12 |
def imusic(
    received_data,
    num_signal,
    array,
    fs,
    angle_grids,
    num_groups,
    f_min=None,
    f_max=None,
    n_fft_min=128,
    unit="deg",
):
    """Incoherent MUSIC (iMUSIC / ISSM) estimator for wideband DOA estimation.

    The received samples are split into frequency bins, narrowband MUSIC is
    evaluated independently in every bin, and the per-bin spectra are fused
    by a plain (unweighted) average.

    Args:
        received_data: Array received signals.
        num_signal: Number of signals.
        array: Instance of array class.
        fs: Sampling frequency.
        angle_grids: Angle grids of the spatial spectrum, as a numpy array.
        num_groups: Number of groups the sampling points are divided into;
            an FFT is performed separately on each group.
        f_min: Minimum frequency of interest. Defaults to None.
        f_max: Maximum frequency of interest. Defaults to None.
        n_fft_min: Minimum number of FFT points.
        unit: Unit of angle, 'rad' for radians, 'deg' for degrees. Defaults
            to 'deg'.

    Returns:
        The averaged spatial spectrum on ``angle_grids`` (squeezed).

    References:
        Wax, M., Tie-Jun Shan, and T. Kailath. “Spatio-Temporal Spectral
        Analysis by Eigenstructure Methods.” IEEE Transactions on Acoustics,
        Speech, and Signal Processing 32, no. 4 (August 1984): 817-27.
        https://doi.org/10.1109/TASSP.1984.1164400.
    """
    fre_domain_data, fre_points = divide_into_fre_bins(
        received_data, num_groups, fs, f_min, f_max, n_fft_min
    )

    # Row k of this matrix holds the narrowband MUSIC spectrum of bin k.
    per_bin_spectra = np.zeros((fre_domain_data.shape[1], angle_grids.size))
    for bin_idx, bin_fre in enumerate(fre_points):
        # The last axis of the bin (presumably one entry per group) serves as
        # the snapshot axis for MUSIC.
        per_bin_spectra[bin_idx, :] = music(
            received_data=fre_domain_data[:, bin_idx, :],
            num_signal=num_signal,
            array=array,
            signal_fre=bin_fre,
            angle_grids=angle_grids,
            unit=unit,
        )

    # Incoherent fusion: simple mean over all frequency bins.
    return np.squeeze(per_bin_spectra.mean(axis=0))
67 |
68 |
def norm_music(
    received_data,
    num_signal,
    array,
    fs,
    angle_grids,
    num_groups,
    f_min=None,
    f_max=None,
    n_fft_min=128,
    unit="deg",
):
    """Normalized incoherent MUSIC estimator for wideband DOA estimation.

    Like incoherent MUSIC, but each per-bin spectrum is first divided by its
    own peak magnitude so that every frequency bin contributes on the same
    scale before averaging.

    Args:
        received_data: Array received signals.
        num_signal: Number of signals.
        array: Instance of array class.
        fs: Sampling frequency.
        angle_grids: Angle grids of the spatial spectrum, as a numpy array.
        num_groups: Number of groups the sampling points are divided into;
            an FFT is performed separately on each group.
        f_min: Minimum frequency of interest. Defaults to None.
        f_max: Maximum frequency of interest. Defaults to None.
        n_fft_min: Minimum number of FFT points.
        unit: Unit of angle, 'rad' for radians, 'deg' for degrees. Defaults
            to 'deg'.

    Returns:
        The averaged, per-bin-normalized spatial spectrum (squeezed).

    References:
        Salvati, Daniele, Carlo Drioli, and Gian Luca Foresti. “Incoherent
        Frequency Fusion for Broadband Steered Response Power Algorithms in
        Noisy Environments.” IEEE Signal Processing Letters 21, no. 5
        (May 2014): 581–85. https://doi.org/10.1109/LSP.2014.2311164.
    """
    fre_domain_data, fre_points = divide_into_fre_bins(
        received_data, num_groups, fs, f_min, f_max, n_fft_min
    )

    # Row k of this matrix holds the narrowband MUSIC spectrum of bin k.
    per_bin_spectra = np.zeros((fre_domain_data.shape[1], angle_grids.size))
    for bin_idx, bin_fre in enumerate(fre_points):
        per_bin_spectra[bin_idx, :] = music(
            received_data=fre_domain_data[:, bin_idx, :],
            num_signal=num_signal,
            array=array,
            signal_fre=bin_fre,
            angle_grids=angle_grids,
            unit=unit,
        )

    # Peak magnitude of each row (the row-wise infinity norm), kept as a
    # column so it broadcasts across the angle axis.
    peak_per_bin = np.max(np.abs(per_bin_spectra), axis=1, keepdims=True)
    normalized_spectra = per_bin_spectra / peak_per_bin

    return np.squeeze(np.mean(normalized_spectra, axis=0))
127 |
128 |
def cssm(
    received_data,
    num_signal,
    array,
    fs,
    angle_grids,
    pre_estimate,
    fre_ref=None,
    f_min=None,
    f_max=None,
    unit="deg",
):
    """Coherent Signal Subspace Method (CSSM) for wideband DOA estimation.

    Every FFT bin of the received data is focused onto a reference frequency
    with a rotational signal subspace (RSS) focusing matrix built from the
    pre-estimated angles. Narrowband MUSIC is then run at the reference
    frequency, using the focused bins as snapshots.

    Args:
        received_data: Array received signals, antennas along axis 0 and
            time samples along axis 1.
        num_signal: Number of signals.
        array: Instance of array class.
        fs: Sampling frequency.
        angle_grids: Angle grids of the spatial spectrum, as a numpy array.
        pre_estimate: Pre-estimated angles (any array-like, in ``unit``).
            They are flattened into a single row before use.
        fre_ref: Reference (focusing) frequency. If it's not provided, the FFT
            bin with the largest magnitude summed over antennas is used.
        f_min: Minimum frequency of interest. Defaults to None.
        f_max: Maximum frequency of interest. Defaults to None.
        unit: Unit of angle, 'rad' for radians, 'deg' for degrees. Defaults
            to 'deg'.

    Returns:
        The spatial spectrum on ``angle_grids`` (squeezed).

    References:
        Wang, H., and M. Kaveh. “Coherent Signal-Subspace Processing for the
        Detection and Estimation of Angles of Arrival of Multiple Wide-Band
        Sources.” IEEE Transactions on Acoustics, Speech, and Signal Processing
        33, no. 4 (August 1985): 823-31.
        https://doi.org/10.1109/TASSP.1985.1164667.
    """
    num_snapshots = received_data.shape[1]
    # Accept lists as well as numpy arrays; lay the angles out as one row.
    pre_estimate = np.asarray(pre_estimate).reshape(1, -1)

    # Bin k of an N-point FFT lies at frequency k * fs / N.
    delta_f = fs / num_snapshots
    # Widen the requested band by one bin on each side so that edge
    # frequencies are not lost to integer truncation, while staying within
    # the non-negative half of the spectrum [0, N // 2].
    idx_f_min = max(int(f_min / delta_f) - 1, 0) if f_min is not None else 0
    idx_f_max = (
        min(int(f_max / delta_f) + 1, num_snapshots // 2)
        if f_max is not None
        else num_snapshots // 2
    )
    signal_fre_bins = np.fft.fft(received_data, axis=1)[
        :, idx_f_min : idx_f_max + 1
    ]
    # Label the selected bins as k * delta_f. np.fft.fftfreq would map
    # k = N / 2 (even N) to -fs / 2, giving the Nyquist bin -- which is
    # included above -- a negative frequency for the focusing matrices.
    fre_bins = np.arange(idx_f_min, idx_f_max + 1) * delta_f

    if fre_ref is None:
        # Frequency point with the maximum total magnitude across antennas
        fre_ref = fre_bins[np.argmax(np.abs(signal_fre_bins).sum(axis=0))]

    # Manifold matrix of the pre-estimated angles at the reference frequency
    matrix_a_ref = array.steering_vector(fre_ref, pre_estimate, unit=unit)

    for i, fre in enumerate(fre_bins):
        # Manifold matrix of the pre-estimated angles at this frequency point
        matrix_a_f = array.steering_vector(fre, pre_estimate, unit=unit)
        matrix_q = matrix_a_f @ matrix_a_ref.transpose().conj()
        matrix_u, _, matrix_vh = np.linalg.svd(matrix_q)
        # RSS focusing matrix T(f) = V U^H, the unitary matrix closest to
        # mapping A(f) onto A(f_ref)
        matrix_t_f = matrix_vh.transpose().conj() @ matrix_u.transpose().conj()
        # Focus this frequency bin onto the reference frequency. The slice is
        # a view into a freshly computed FFT, so no caller data is modified.
        signal_fre_bins[:, i] = matrix_t_f @ signal_fre_bins[:, i]

    spectrum = music(
        received_data=signal_fre_bins,
        num_signal=num_signal,
        array=array,
        signal_fre=fre_ref,
        angle_grids=angle_grids,
        unit=unit,
    )

    return np.squeeze(spectrum)
213 |
214 |
def tops(
    received_data,
    num_signal,
    array,
    fs,
    angle_grids,
    num_groups,
    fre_ref=None,
    f_min=None,
    f_max=None,
    n_fft_min=128,
    unit="deg",
):
    """Test of orthogonality of projected subspaces (TOPS) method for wideband
    DOA estimation.

    Args:
        received_data: Received signals from the array, antennas along
            axis 0.
        num_signal: Number of signals.
        array: Instance of array class.
        fs: Sampling frequency.
        angle_grids: Grid points of spatial spectrum, should be a numpy array.
        num_groups: Number of groups for FFT, each group performs an
            independent FFT.
        fre_ref: Reference frequency. If it's not provided, the frequency bin
            with the maximum power is used. A supplied value is snapped to the
            nearest available frequency bin, because the reference signal
            subspace can only be estimated at a bin.
        f_min: Minimum frequency of interest. Defaults to None.
        f_max: Maximum frequency of interest. Defaults to None.
        n_fft_min: Minimum number of FFT points.
        unit: Unit of angle measurement, 'rad' for radians, 'deg' for degrees.
            Defaults to 'deg'.

    Returns:
        Spatial spectrum on ``angle_grids``: the reciprocal of the smallest
        singular value of the TOPS matrix D at each grid point.

    References:
        Yoon, Yeo-Sun, L.M. Kaplan, and J.H. McClellan. “TOPS: New DOA Estimator
        for Wideband Signals.” IEEE Transactions on Signal Processing 54, no. 6
        (June 2006): 1977-89. https://doi.org/10.1109/TSP.2006.872581.
    """
    num_antennas = received_data.shape[0]

    signal_fre_bins, fre_bins = divide_into_fre_bins(
        received_data, num_groups, fs, f_min, f_max, n_fft_min
    )

    # Position of the reference frequency inside fre_bins / signal_fre_bins.
    # (fre_bins may cover only the [f_min, f_max] sub-band, so the index must
    # be looked up among the returned bins rather than derived from fs.)
    if fre_ref is None:
        ref_index = int(np.argmax(np.abs(signal_fre_bins).sum(axis=(0, 2))))
    else:
        ref_index = int(np.argmin(np.abs(np.asarray(fre_bins) - fre_ref)))
    fre_ref = fre_bins[ref_index]

    # Signal subspace at the reference frequency
    signal_space_ref = get_signal_space(
        np.cov(signal_fre_bins[:, ref_index, :]), num_signal=num_signal
    )

    # The noise subspace of each bin does not depend on the scanned angle, so
    # compute it once per bin instead of once per (angle, bin) pair.
    noise_spaces = [
        get_noise_space(np.cov(signal_fre_bins[:, j, :]), num_signal)
        for j in range(len(fre_bins))
    ]

    spectrum = np.zeros(angle_grids.size)
    for i, grid in enumerate(angle_grids):
        d_blocks = []
        for fre, noise_space_f in zip(fre_bins, noise_spaces):
            # Diagonal transformation mapping the reference frequency to fre
            matrix_phi = array.steering_vector(fre - fre_ref, grid, unit=unit)
            matrix_phi = np.diag(np.squeeze(matrix_phi))

            # Transform the reference signal subspace to the current frequency
            matrix_u = matrix_phi @ signal_space_ref

            # Projection onto the orthogonal complement of a(fre, grid),
            # P = I - a (a^H a)^{-1} a^H, which reduces errors in matrix U
            matrix_a_f = array.steering_vector(fre, grid, unit=unit)
            a_norm_sq = matrix_a_f.transpose().conj() @ matrix_a_f
            matrix_p = np.eye(num_antennas) - (
                1 / a_norm_sq * matrix_a_f
            ) @ matrix_a_f.transpose().conj()

            matrix_u = matrix_p @ matrix_u

            d_blocks.append(matrix_u.T.conj() @ noise_space_f)

        matrix_d = np.hstack(d_blocks)

        # Near a true DOA, D loses rank, so its smallest singular value dips
        singular_values = np.linalg.svd(matrix_d, compute_uv=False)
        spectrum[i] = 1 / singular_values.min()

    return spectrum
308 |
--------------------------------------------------------------------------------
/doa_py/algorithm/esprit_based.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from .utils import get_signal_space
4 |
# Propagation speed (speed of light, m/s); C / signal_fre is the wavelength.
C = 3e8
6 |
7 |
def esprit(received_data, num_signal, array, signal_fre, unit="deg"):
    """Total least-squares (TLS) ESPRIT. Most matrix names are taken directly
    from the reference paper.

    Args:
        received_data: Array received signals.
        num_signal: Number of signals.
        array: Instance of array class.
        signal_fre: Signal frequency.
        unit: Unit of angle, 'rad' for radians, 'deg' for degrees. Defaults
            to 'deg'.

    Returns:
        Estimated angles, sorted in ascending order, in ``unit``.

    Reference:
        Roy, R., and T. Kailath. “ESPRIT-Estimation of Signal Parameters via
        Rotational Invariance Techniques.” IEEE Transactions on Acoustics,
        Speech, and Signal Processing 37, no. 7 (July 1989): 984-95.
        https://doi.org/10.1109/29.32276.
    """
    signal_space = get_signal_space(np.cov(received_data), num_signal)

    # Signal subspaces of the two sub-arrays, each made of M-1 antennas
    matrix_e_x = signal_space[:-1, :]
    matrix_e_y = signal_space[1:, :]
    # The fixed displacement between corresponding elements of the two
    # sub-arrays provides the rotational invariance
    sub_array_spacing = array.array_position[1][1] - array.array_position[0][1]

    matrix_e_xy = np.hstack((matrix_e_x, matrix_e_y))
    matrix_c = matrix_e_xy.transpose().conj() @ matrix_e_xy

    # matrix_c is Hermitian by construction, so use eigh: it guarantees real
    # eigenvalues and an orthonormal eigenbasis. eigh returns eigenvalues in
    # ascending order; flip the columns to get descending order.
    _, eigenvectors = np.linalg.eigh(matrix_c)
    matrix_e = eigenvectors[:, ::-1]

    # Upper-right and lower-right sub-matrices (minor-eigenvalue columns)
    matrix_e_12 = matrix_e[:num_signal, num_signal:]
    matrix_e_22 = matrix_e[num_signal:, num_signal:]

    matrix_psi = -matrix_e_12 @ np.linalg.inv(matrix_e_22)
    matrix_phi = np.linalg.eigvals(matrix_psi)

    # Note: the signal model used here differs from the reference paper,
    # hence the "-2 pi f" factor
    angles = np.arcsin(
        C
        * np.angle(matrix_phi)
        / ((-2 * np.pi * signal_fre) * sub_array_spacing)
    )

    if unit == "deg":
        angles = angles / np.pi * 180

    return np.sort(angles)
63 |
64 |
def uca_esprit(received_data, num_signal, array, signal_fre, unit="deg"):
    """UCA-ESPRIT for Uniform Circular Array (UCA).

    The element-space data is transformed into the UCA beamspace (phase-mode
    excitation). An ESPRIT-like relation between adjacent phase modes then
    estimates azimuth and elevation jointly.

    Args:
        received_data: Array received signals.
        num_signal: Number of signals.
        array: Instance of array class.
        signal_fre: Signal frequency.
        unit: Unit of angle, 'rad' for radians, 'deg' for degrees. Defaults
            to 'deg'.

    Returns:
        Tuple ``(azimuth, elevation)``, each an array with one entry per
        estimated source, in ``unit``.

    Reference:
        Mathews, C.P., and M.D. Zoltowski. “Eigenstructure Techniques for 2-D
        Angle Estimation with Uniform Circular Arrays.” IEEE Transactions on
        Signal Processing 42, no. 9 (September 1994): 2395-2407.
        https://doi.org/10.1109/78.317861.
    """
    wavelength = C / signal_fre
    # Highest phase mode that can be excited: floor(2 * pi * r / lambda)
    m = int(np.floor(2 * np.pi * array.radius / wavelength))

    # C_v = diag(j^-m, ..., j^-1, j^0, j^-1, ..., j^-m)
    matrix_c_v = np.diag(
        1j ** np.concatenate((np.arange(-m, 0), np.arange(0, -m - 1, step=-1)))
    )
    matrix_v = (
        1
        / np.sqrt(array.num_antennas)
        * np.exp(
            -1j
            * 2
            * np.pi
            * np.arange(0, array.num_antennas).reshape(-1, 1)
            @ np.arange(-m, m + 1).reshape(1, -1)
            / array.num_antennas
        )
    )
    matrix_f_e = matrix_v @ matrix_c_v.conj().transpose()
    matrix_w = (
        1
        / np.sqrt(2 * m + 1)
        * np.exp(
            1j
            * 2
            * np.pi
            * np.arange(-m, m + 1).reshape(-1, 1)
            @ np.arange(-m, m + 1).reshape(1, -1)
            / (2 * m + 1)
        )
    )
    matrix_f_r = matrix_f_e @ matrix_w

    # Beamspace data vector
    beamspace_data = matrix_f_r.conj().transpose() @ received_data

    # Only the real part of the beamspace covariance matrix is used
    cov_real = np.real(np.cov(beamspace_data))
    signal_space = get_signal_space(cov_real, num_signal)

    # C_o = diag((-1)^m, ..., (-1)^1, 1, 1, ..., 1)
    matrix_c_o = np.diag(
        (-1) ** np.concatenate((np.arange(m, -1, step=-1), np.zeros(m)))
    )
    signal_space = matrix_c_o @ matrix_w @ signal_space

    s1 = signal_space[:-2, :]
    s2 = signal_space[1:-1, :]
    s3 = signal_space[2:, :]

    # Gamma = (lambda / (pi * r)) * diag(-(m-1), ..., m-1), as in the
    # reference. It comes from the Bessel recurrence
    # J_{k-1}(z) + J_{k+1}(z) = (2k / z) J_k(z) with z = 2*pi*r*sin(theta)/lambda.
    # The wavelength factor keeps Gamma dimensionless; without it the result
    # would only be correct when the wavelength happens to be 1 m.
    matrix_gamma = (wavelength / (np.pi * array.radius)) * np.diag(
        np.arange(-(m - 1), m)
    )
    matrix_e = np.hstack((s1, s3))
    matrix_psi_hat = (
        np.linalg.inv(matrix_e.conj().transpose() @ matrix_e)
        @ matrix_e.conj().transpose()
        @ matrix_gamma
        @ s2
    )
    matrix_psi = matrix_psi_hat[: len(matrix_psi_hat) // 2, :]

    eig_values = np.linalg.eigvals(matrix_psi)

    # Per the reference the eigenvalues are ~ sin(theta) * exp(j * phi), with
    # theta measured from the array normal, so arccos(|.|) yields the angle
    # from the array plane. NOTE(review): the sign flip in the azimuth
    # presumably matches this library's steering-vector convention -- confirm.
    elevation = np.arccos(np.abs(eig_values))
    azimuth = np.angle(-eig_values)

    if unit == "deg":
        elevation = elevation / np.pi * 180
        azimuth = azimuth / np.pi * 180

    return azimuth, elevation
151 |
--------------------------------------------------------------------------------
/doa_py/algorithm/music_based.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from .utils import forward_backward_smoothing, get_noise_space
4 |
# Propagation speed (speed of light, m/s); C / signal_fre is the wavelength.
C = 3e8
6 |
7 |
def music(
    received_data, num_signal, array, signal_fre, angle_grids, unit="deg"
):
    """Narrowband 1-D MUSIC spatial spectrum.

    Args:
        received_data: Array received signals.
        num_signal: Number of signals.
        array: Instance of array class.
        signal_fre: Signal frequency.
        angle_grids: Angle grids of the spatial spectrum, as a numpy array.
        unit: Unit of angle, 'rad' for radians, 'deg' for degrees. Defaults
            to 'deg'.

    Returns:
        The pseudo-spectrum 1 / ||E_n^H a(theta)||^2 for every grid point,
        squeezed, where E_n is the noise subspace of the sample covariance.
    """
    covariance = np.cov(received_data)
    noise_subspace = get_noise_space(covariance, num_signal)

    # Steering vectors for every candidate direction, one column per grid point
    steering_matrix = array.steering_vector(signal_fre, angle_grids, unit=unit)

    # Project each steering vector onto the noise subspace; directions that
    # are nearly orthogonal to it have small projection energy
    projections = noise_subspace.conj().T @ steering_matrix
    projection_energy = np.linalg.norm(projections, axis=0) ** 2

    return np.squeeze(1 / projection_energy)
38 |
39 |
def root_music(received_data, num_signal, array, signal_fre, unit="deg"):
    """Root-MUSIC.

    The MUSIC null spectrum is written as a polynomial in z. The roots closest
    to the unit circle give the source directions.

    Args:
        received_data: Array of received signals.
        num_signal: Number of signals.
        array: Instance of array class.
        signal_fre: Signal frequency.
        unit: The unit of the angle, `rad` represents radian, `deg` represents
            degree. Defaults to 'deg'.

    Returns:
        Estimated angles, sorted in ascending order, in ``unit``.

    References:
        Rao, B.D., and K.V.S. Hari. “Performance Analysis of Root-Music.”
        IEEE Transactions on Acoustics, Speech, and Signal Processing 37,
        no. 12 (December 1989): 1939-49. https://doi.org/10.1109/29.45540.
    """
    noise_space = get_noise_space(np.cov(received_data), num_signal)

    num_antennas = array.num_antennas
    antenna_spacing = array.array_position[1][1] - array.array_position[0][1]

    # The coefficient construction follows the root-MUSIC implementation of
    # doatools. The polynomial coefficients are the sums along the diagonals
    # of the noise-subspace projector E_n E_n^H.
    projector = noise_space @ noise_space.transpose().conj()
    upper_diag_sums = np.array(
        [np.sum(np.diag(projector, k)) for k in range(1, num_antennas)],
        dtype=np.complex128,
    )
    poly_coeffs = np.hstack(
        (
            upper_diag_sums[::-1],
            np.sum(np.diag(projector)),
            upper_diag_sums.conj(),
        )
    )
    roots = np.roots(poly_coeffs)

    # The conjugate-symmetric coefficients make roots come in pairs (z, 1/z*).
    # Keep only roots on or inside the unit circle so each pair counts once.
    inner_roots = roots[np.abs(roots) <= 1]
    # The signal roots are those closest to the unit circle
    distance_to_circle = np.abs(np.abs(inner_roots) - 1)
    chosen_roots = inner_roots[np.argsort(distance_to_circle)[:num_signal]]

    wavelength = C / signal_fre
    angles = np.arcsin(
        wavelength / (-2 * np.pi * antenna_spacing) * np.angle(chosen_roots)
    )

    if unit == "deg":
        angles = angles / np.pi * 180

    return np.sort(angles)
94 |
95 |
def uca_rb_music(
    received_data,
    num_signal,
    array,
    signal_fre,
    azimuth_grids,
    elevation_grids,
    unit="deg",
):
    """Real-beamspace MUSIC (UCA-RB-MUSIC) for Uniform Circular Array (UCA)

    The element-space data are mapped by the beamformer ``F_r^H`` (phase-mode
    excitation ``C_v V^H`` followed by a DFT ``W^H`` over the phase modes)
    into beamspace, and MUSIC is applied to the real part of the beamspace
    covariance matrix.

    Args:
        received_data : Array received signals, one row per antenna
        num_signal : Number of signals. Must be smaller than the beamspace
            dimension ``2 * M + 1``, where ``M`` is the highest phase mode
            that can be excited.
        array : Instance of a uniform circular array class; its ``radius``
            and ``num_antennas`` attributes are used
        signal_fre : Signal frequency
        azimuth_grids : Azimuth grid points, a 1D numpy array
        elevation_grids : Elevation grid points, a 1D numpy array
        unit : Unit of angle, 'rad' for radians, 'deg' for degrees. Defaults
            to 'deg'.

    Returns:
        Spatial spectrum of shape ``(azimuth_grids.size,
        elevation_grids.size)``.

    Raises:
        ValueError: If ``num_signal`` leaves no noise subspace in beamspace.

    References:
        Mathews, C.P., and M.D. Zoltowski. “Eigenstructure Techniques for 2-D
        Angle Estimation with Uniform Circular Arrays.” IEEE Transactions on
        Signal Processing 42, no. 9 (September 1994): 2395-2407.
        https://doi.org/10.1109/78.317861.
    """
    # max number of phase modes can be excited: floor(2 * pi * r / wavelength)
    max_mode = int(np.floor(2 * np.pi * array.radius / (C / signal_fre)))
    modes = np.arange(-max_mode, max_mode + 1)
    num_modes = 2 * max_mode + 1

    if num_signal >= num_modes:
        raise ValueError(
            f"num_signal ({num_signal}) must be smaller than the beamspace "
            f"dimension ({num_modes})"
        )

    # C_v = diag(j^-M, ..., j^-1, j^0, j^-1, ..., j^-M)
    matrix_c_v = np.diag(1j ** (-np.abs(modes)))
    matrix_v = (
        1
        / np.sqrt(array.num_antennas)
        * np.exp(
            -1j
            * 2
            * np.pi
            * np.arange(0, array.num_antennas).reshape(-1, 1)
            @ modes.reshape(1, -1)
            / array.num_antennas
        )
    )
    matrix_f_e = matrix_v @ matrix_c_v.conj().transpose()
    matrix_w = (
        1
        / np.sqrt(num_modes)
        * np.exp(
            1j
            * 2
            * np.pi
            * modes.reshape(-1, 1)
            @ modes.reshape(1, -1)
            / num_modes
        )
    )
    matrix_f_r = matrix_f_e @ matrix_w
    beamformer = matrix_f_r.conj().transpose()

    # beamspace data vector
    beamspace_data = beamformer @ received_data

    # only use the real part of covariance matrix
    cov_real = np.real(np.cov(beamspace_data))
    noise_space = get_noise_space(cov_real, num_signal)

    spectrum = np.zeros((azimuth_grids.size, elevation_grids.size))
    for i, elevation in enumerate(elevation_grids):
        angle_grids = np.vstack(
            (azimuth_grids, elevation * np.ones_like(azimuth_grids))
        )
        # manifold of every azimuth grid point at this elevation, mapped into
        # beamspace with the same beamformer as the data
        manifold_all_grids = beamformer @ array.steering_vector(
            signal_fre, angle_grids, unit=unit
        )

        v = noise_space.transpose().conj() @ manifold_all_grids

        # squared 2-norm of each column (one per azimuth grid point)
        spectrum[:, i] = 1 / np.linalg.norm(v, axis=0) ** 2

    return spectrum
180 |
181 |
def smoothed_music(
    received_data,
    num_signal,
    array,
    signal_fre,
    angle_grids,
    subarray_size=None,
    unit="deg",
):
    """1D MUSIC on a forward-backward smoothed covariance (coherent signals).

    Args:
        received_data : Array received signals, one row per antenna
        num_signal : Number of signals
        array : Instance of array class
        signal_fre: Signal frequency
        angle_grids : Angle grids at which the spatial spectrum is evaluated.
            It should be a numpy array.
        subarray_size : Size of subarray for spatial smoothing. When None,
            half of the number of antennas (integer division) is used.
        unit : Unit of angle, 'rad' for radians, 'deg' for degrees. Defaults
            to 'deg'.

    Returns:
        The MUSIC pseudo-spectrum over `angle_grids`, squeezed.

    Raises:
        ValueError: If the subarray is not larger than the number of signals.
    """
    size = (
        received_data.shape[0] // 2 if subarray_size is None else subarray_size
    )

    if size < num_signal + 1:
        raise ValueError("Subarray size must be greater than number of signals")

    covariance = forward_backward_smoothing(received_data, size)
    noise_basis = get_noise_space(covariance, num_signal)

    # steering vectors restricted to the first `size` elements (one subarray)
    full_manifold = array.steering_vector(signal_fre, angle_grids, unit=unit)
    sub_manifold = full_manifold[0:size, :]

    # squared 2-norm of each column's projection onto the noise subspace
    projections = noise_basis.conj().T @ sub_manifold
    residual_power = np.linalg.norm(projections, axis=0) ** 2
    return np.squeeze(1 / residual_power)
228 |
--------------------------------------------------------------------------------
/doa_py/algorithm/sparse.py:
--------------------------------------------------------------------------------
1 | import cvxpy as cp
2 | import numpy as np
3 |
# Propagation speed (3e8, i.e. the speed of light in m/s). Not used by the
# sparse-representation functions in this module.
C = 3.0e8
5 |
6 |
def omp(received_data, num_signal, array, signal_fre, angle_grids, unit="deg"):
    """OMP based sparse representation algorithms for DOA estimation

    Greedily selects `num_signal` atoms (grid steering vectors) from an
    overcomplete dictionary. Each step picks the atom most correlated with
    the current residual (Frobenius norm over snapshots), then refits all
    chosen atoms to the data by least squares.

    Args:
        received_data : Array received signals, one row per antenna
        num_signal : Number of signals
        array : Instance of array class
        signal_fre: Signal frequency
        angle_grids : Angle grids corresponding to spatial spectrum; any
            array-like, flattened to 1D.
        unit : Unit of angle, 'rad' for radians, 'deg' for degrees. Defaults to
            'deg'.

    Returns:
        The `num_signal` selected grid angles, sorted in ascending order.

    Raises:
        ValueError: If `num_signal` exceeds the number of grid points.

    Reference:
        Cotter, Shane F. “Multiple Snapshot Matching Pursuit for Direction of
        Arrival (DOA) Estimation.” In 2007 15th European Signal Processing
        Conference, 247-51, 2007.
        https://ieeexplore.ieee.org/abstract/document/7098802.
    """
    angle_grids = np.asarray(angle_grids).reshape(-1)
    if num_signal > angle_grids.size:
        raise ValueError(
            "Number of signals cannot exceed the number of angle grids"
        )

    # overcomplete basis: one steering vector (atom) per grid point
    dictionary = array.steering_vector(signal_fre, angle_grids, unit=unit)

    atom_index = []
    residual = received_data
    for _ in range(num_signal):
        # relevance of each atom: norm over snapshots of its correlation with
        # the residual
        relevance = np.linalg.norm(dictionary.conj().T @ residual, axis=1)
        # Never re-select an atom. After the least-squares update the residual
        # is (numerically) orthogonal to the chosen atoms; on rank-deficient
        # data argmax could otherwise land on one of them and stall forever.
        relevance[atom_index] = -np.inf
        atom_index.append(int(np.argmax(relevance)))

        # refit all chosen atoms jointly and update the residual
        chosen = dictionary[:, atom_index]
        coefficients, *_ = np.linalg.lstsq(chosen, received_data, rcond=None)
        residual = received_data - chosen @ coefficients

    return np.sort(angle_grids[atom_index])
59 |
60 |
def l1_svd(
    received_data,
    num_signal,
    array,
    signal_fre,
    angle_grids,
    unit="deg",
    reg_param=2,
):
    """L1 norm based sparse representation algorithms for DOA estimation

    The data are first reduced to their `num_signal` dominant right singular
    directions, ``Y_sv = Y V D_k``. The problem solved is then
    ``min ||Y_sv - A S||_F + reg_param * sum_i ||S[i, :]||_2``, a row-sparse
    fit of the overcomplete steering matrix ``A``.

    Args:
        received_data : Array received signals, one row per antenna
        num_signal : Number of signals
        array : Instance of array class
        signal_fre: Signal frequency
        angle_grids : Angle grids corresponding to spatial spectrum; any
            array-like, flattened to 1D.
        unit : Unit of angle, 'rad' for radians, 'deg' for degrees. Defaults to
            'deg'.
        reg_param : Weight of the row-sparsity (l1 of row norms) term in the
            objective. Defaults to 2.

    Returns:
        Spatial spectrum: sum over columns of ``|S|`` for each grid point.

    Raises:
        RuntimeError: If the convex solver does not return a solution.

    Reference:
        Malioutov, D., M. Cetin, and A.S. Willsky. “A Sparse Signal
        Reconstruction Perspective for Source Localization with Sensor Arrays.”
        IEEE Transactions on Signal Processing 53, no. 8 (August 2005): 3010-22.
        https://doi.org/10.1109/TSP.2005.850882.
    """
    angle_grids = np.asarray(angle_grids).reshape(-1)
    num_grids = angle_grids.size

    # build the overcomplete basis
    a_over = array.steering_vector(signal_fre, angle_grids, unit=unit)

    # Economy SVD: only the leading right singular vectors are needed, so do
    # not form the full (num_samples x num_samples) unitary matrix.
    _, _, vh = np.linalg.svd(received_data, full_matrices=False)
    # Y V D_k with D_k = [I_k; 0] keeps the first num_signal columns of Y V
    y_sv = received_data @ vh[:num_signal].conj().transpose()
    # fewer than num_signal columns only if num_signal exceeds the data rank
    # bound; the dropped columns of Y V would be zero anyway
    num_columns = y_sv.shape[1]

    # solve the l1 norm problem using cvxpy (epigraph form)
    p = cp.Variable()
    q = cp.Variable()
    r = cp.Variable(num_grids)
    s_sv = cp.Variable((num_grids, num_columns), complex=True)

    constraints = [cp.norm(y_sv - a_over @ s_sv, "fro") <= p, cp.sum(r) <= q]
    constraints += [cp.norm(s_sv[i, :]) <= r[i] for i in range(num_grids)]

    prob = cp.Problem(cp.Minimize(p + reg_param * q), constraints)
    prob.solve()

    if s_sv.value is None:
        raise RuntimeError(
            f"L1-SVD optimization did not converge (status: {prob.status})"
        )

    return np.sum(np.abs(s_sv.value), axis=1)
115 |
--------------------------------------------------------------------------------
/doa_py/algorithm/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
def _sorted_eigenvectors(matrix):
    """Return eigenvectors of `matrix` ordered by ascending |eigenvalue|.

    Hermitian input (covariance matrices) is decomposed with ``eigh``, which
    guarantees real eigenvalues and an orthonormal eigenbasis even when
    eigenvalues are repeated, as the noise eigenvalues ideally are. Any other
    input falls back to the general ``eig``.
    """
    matrix = np.asarray(matrix)
    asymmetry = np.linalg.norm(matrix - matrix.conj().T)
    if asymmetry <= 1e-10 * np.linalg.norm(matrix):
        eigenvalues, eigenvectors = np.linalg.eigh(matrix)
    else:
        eigenvalues, eigenvectors = np.linalg.eig(matrix)
    sorted_index = np.argsort(np.abs(eigenvalues))  # ascending order
    return eigenvectors[:, sorted_index]


def get_noise_space(corvariance_matrix, num_signal):
    """Eigenvectors spanning the noise subspace of a covariance matrix.

    Args:
        corvariance_matrix: Square covariance matrix.
        num_signal: Number of signals; the eigenvectors of the `num_signal`
            largest-magnitude eigenvalues are excluded.

    Returns:
        Matrix whose columns are the remaining eigenvectors, ordered by
        ascending eigenvalue magnitude (all of them when `num_signal` is 0).
    """
    eigenvectors = _sorted_eigenvectors(corvariance_matrix)
    # explicit count instead of [:-num_signal], which is empty for 0 signals
    num_noise = max(eigenvectors.shape[1] - num_signal, 0)
    return eigenvectors[:, :num_noise]


def get_signal_space(corvariance_matrix, num_signal):
    """Eigenvectors spanning the signal subspace of a covariance matrix.

    Args:
        corvariance_matrix: Square covariance matrix.
        num_signal: Number of signals.

    Returns:
        Matrix whose columns are the eigenvectors of the `num_signal`
        largest-magnitude eigenvalues, ordered by ascending magnitude (no
        columns when `num_signal` is 0).
    """
    eigenvectors = _sorted_eigenvectors(corvariance_matrix)
    # explicit count instead of [-num_signal:], which is everything for 0
    num_noise = max(eigenvectors.shape[1] - num_signal, 0)
    return eigenvectors[:, num_noise:]
18 |
19 |
def divide_into_fre_bins(
    received_data, num_groups, fs, f_min=None, f_max=None, n_fft_min=128
):
    """Do FFT on array signal of each channel, and divide signal into different
    frequency points.

    The snapshots are split into `num_groups` consecutive, non-overlapping
    groups (trailing snapshots that do not fill a group are dropped), and each
    group is transformed separately.

    Args:
        received_data : array received signal, one row per antenna
        num_groups : how many groups divide snapshots into
        fs : sampling frequency
        f_min: minimum frequency of interest
        f_max: maximum frequency of interest
        n_fft_min: minimum number of FFT points; shorter groups are
            zero-padded up to this length

    Returns:
        `signal_fre_bins`: a (m, n, l) tensor, in which m equals to number of
            antennas, n is the number of retained frequency bins, l is the
            number of groups
        `fre_bins`: corresponding frequency of each retained FFT bin

    Raises:
        ValueError: If `num_groups` is not positive or exceeds the number of
            snapshots.
    """
    num_snapshots = received_data.shape[1]
    if num_groups < 1:
        raise ValueError(f"num_groups must be positive, got {num_groups}")

    # number of sampling points in each group
    n_each_group = num_snapshots // num_groups
    if n_each_group == 0:
        raise ValueError(
            f"Cannot split {num_snapshots} snapshots into {num_groups} groups"
        )
    # zero padding when sampling points is not enough
    n_fft = max(n_each_group, n_fft_min)

    delta_f = fs / n_fft
    # widen the requested band by one bin on each side (clipped to the
    # non-negative half of the spectrum) to use as wide a range as possible
    idx_f_min = max(int(f_min / delta_f) - 1, 0) if f_min is not None else 0
    idx_f_max = (
        min(int(f_max / delta_f) + 1, n_fft // 2)
        if f_max is not None
        else n_fft // 2
    )
    idx_range = idx_f_max - idx_f_min + 1

    signal_fre_bins = np.zeros(
        (received_data.shape[0], idx_range, num_groups),
        dtype=np.complex128,
    )
    # do FFT separately in each group
    for group_i in range(num_groups):
        start = group_i * n_each_group
        signal_fre_bins[:, :, group_i] = np.fft.fft(
            received_data[:, start : start + n_each_group],
            n=n_fft,
            axis=1,
        )[:, idx_f_min : idx_f_max + 1]
    fre_bins = np.fft.fftfreq(n_fft, 1 / fs)[idx_f_min : idx_f_max + 1]

    return signal_fre_bins, fre_bins
74 |
75 |
def forward_backward_smoothing(received_data, subarray_size):
    """Forward-backward spatially smoothed covariance matrix.

    The covariance matrices of all overlapping subarrays of `subarray_size`
    consecutive elements are summed (forward smoothing). The sum is then
    combined with its exchanged conjugate ``J R^* J`` (backward smoothing)
    and normalised by twice the number of subarrays.

    Args:
        received_data : Array received signals, one row per element
        subarray_size : Number of consecutive elements in each subarray; must
            satisfy ``1 <= subarray_size <= received_data.shape[0]``

    Returns:
        Complex ``(subarray_size, subarray_size)`` smoothed covariance matrix.

    Raises:
        ValueError: If `subarray_size` is outside the valid range.
    """
    num_elements = received_data.shape[0]
    if not 1 <= subarray_size <= num_elements:
        raise ValueError(
            f"subarray_size must be between 1 and {num_elements}, "
            f"got {subarray_size}"
        )
    num_subarrays = num_elements - subarray_size + 1

    # The covariance of a subarray is the matching diagonal block of the full
    # covariance (np.cov pairs rows independently), so compute it only once.
    full_cov = np.atleast_2d(np.cov(received_data))

    # forward smoothing
    smoothed_data = np.zeros(
        (subarray_size, subarray_size), dtype=np.complex128
    )
    for i in range(num_subarrays):
        smoothed_data += full_cov[i : i + subarray_size, i : i + subarray_size]

    # backward smoothing with the exchange (anti-identity) matrix J
    matrix_j = np.fliplr(np.eye(subarray_size))
    smoothed_data += matrix_j @ smoothed_data.conj() @ matrix_j

    smoothed_data /= 2 * num_subarrays

    return smoothed_data
95 |
--------------------------------------------------------------------------------
/doa_py/arrays.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from abc import ABC
3 | from typing import Literal
4 |
5 | import numpy as np
6 | from scipy.linalg import toeplitz
7 | from typing_extensions import override
8 |
9 | from .signals import BroadSignal, NarrowSignal, RandomFreqSignal, Signal
10 |
# Wave propagation speed (3e8, i.e. the speed of light in m/s); element
# positions are presumably expressed in meters.
C = 3.0e8
12 |
13 |
class Array(ABC):
    """Base class of antenna arrays.

    Two copies of the element positions are kept. The ideal (nominal)
    positions are the only ones DOA algorithms see, through
    `steering_vector`. The actual positions may be perturbed by
    `add_position_error` and are used to synthesize received signals.
    """

    def __init__(
        self,
        element_position_x,
        element_position_y,
        element_position_z,
        rng=None,
    ):
        """element position should be defined in 3D (x, y, z) coordinate
        system

        Args:
            element_position_x: x coordinates of the array elements
            element_position_y: y coordinates of the array elements
            element_position_z: z coordinates of the array elements
            rng (np.random.Generator, optional): random generator used for
                noise and array imperfections; a new default generator is
                created when None
        """
        # (num_antennas, 3): one (x, y, z) row per element; these are the
        # actual positions used when generating received signals
        self._element_position = np.vstack(
            (element_position_x, element_position_y, element_position_z)
        ).T

        # nominal positions, the only ones known by the algorithms
        self._ideal_position = self._element_position.copy()

        self._rng = np.random.default_rng() if rng is None else rng

    @property
    def num_antennas(self):
        """Number of array elements."""
        return self._element_position.shape[0]

    @property
    def array_position(self):
        """Ideal (nominal) element positions, shape (num_antennas, 3)."""
        return self._ideal_position

    def set_rng(self, rng):
        """Setting random number generator

        Args:
            rng (np.random.Generator): random generator used to generator random
        """
        self._rng = rng

    def _unify_unit(self, variable, unit):
        """Convert `variable` from degrees to radians when `unit` is 'deg';
        any other unit leaves it unchanged."""
        if unit == "deg":
            variable = variable / 180 * np.pi

        return variable

    @staticmethod
    def _as_azimuth_elevation(angle_incidence):
        """Return incidence angles as a (2, K) array: azimuth row, elevation
        row.

        A scalar, or anything that squeezes to 1D, is azimuth-only and gets
        zero elevation. The one exception is a (2, 1) array, which is a single
        (azimuth, elevation) pair as documented for 2xN input. Everything
        else is reshaped to two rows.
        """
        angles = np.asarray(angle_incidence)
        single_pair = angles.ndim == 2 and angles.shape == (2, 1)
        if np.squeeze(angles).ndim <= 1 and not single_pair:
            azimuth = angles.reshape(1, -1)
            return np.vstack((azimuth, np.zeros(azimuth.size)))
        return np.reshape(angles, (2, -1))

    def _direction_vectors(self, angle_incidence, unit):
        """(3, K) matrix of direction vectors
        (cos(az)cos(el), sin(az)cos(el), sin(el)), one column per direction.
        """
        angles = self._unify_unit(
            self._as_azimuth_elevation(angle_incidence), unit
        )
        azimuth, elevation = angles[0], angles[1]
        cos_cos = np.cos(azimuth) * np.cos(elevation)
        sin_cos = np.sin(azimuth) * np.cos(elevation)
        sin_ = np.sin(elevation)
        return np.vstack((cos_cos, sin_cos, sin_))

    def steering_vector(self, fre, angle_incidence, unit="deg"):
        """Calculate steering vector corresponding to the angle of incidence

        Args:
            fre (float): Frequency of carrier wave
            angle_incidence (float | np.ndarray): Incidence angle. If only
                azimuth is considered, `angle_incidence` is a scalar or a 1xN
                dimensional matrix; if two dimensions are considered,
                `angle_incidence` is a 2xN dimensional matrix, where the first
                row is the azimuth and the second row is the elevation angle.
            unit: The unit of the angle, `rad` represents radian,
                `deg` represents degree. Defaults to 'deg'.

        Returns:
            If `angle_incidence` corresponds to single signal, return a steering
            vector of dimension `Mx1`. If `angle_incidence` is doa of k signals,
            return a steering maxtrix of dimension `Mxk`
        """
        directions = self._direction_vectors(angle_incidence, unit)

        # only the ideal position is known by algorithms
        time_delay = 1 / C * self._ideal_position @ directions
        return np.exp(-1j * 2 * np.pi * fre * time_delay)

    def _steering_vector_with_error(self, fre, angle_incidence, unit="deg"):
        """Steering vector built from the actual (possibly perturbed) element
        positions, with mutual coupling applied if it has been set."""
        directions = self._direction_vectors(angle_incidence, unit)

        # used to generate received signal using the actual position
        time_delay = 1 / C * self._element_position @ directions
        steering_vector = np.exp(-1j * 2 * np.pi * fre * time_delay)

        # applying mutual coupling effects
        if getattr(self, "_coupling_matrix", None) is not None:
            steering_vector = self._coupling_matrix @ steering_vector

        return steering_vector

    def add_position_error(self, *args, **kwargs):
        """Add position error to the array

        The base class only warns; subclasses that support position errors
        override this and accept arguments such as `error_std` (magnitude of
        the position error) and `error_type` (`gaussian` or `uniform`). Any
        arguments are accepted here so calls written for a subclass warn
        instead of raising TypeError.
        """
        warnings.warn(
            "This method is not implemented for {}".format(type(self).__name__),
            UserWarning,
        )

    def add_mutual_coupling(self, coupling_matrix=None):
        """Add mutual coupling effects to the array

        Args:
            coupling_matrix (np.ndarray): A square matrix representing mutual
                coupling between array elements. Should be of size
                (num_antennas, num_antennas). When None, the result of
                `get_default_coupling_matrix` is used.

        Raises:
            ValueError: If the matrix shape does not match the array.
        """
        if coupling_matrix is None:
            coupling_matrix = self.get_default_coupling_matrix()
        coupling_matrix = np.asarray(coupling_matrix)

        if coupling_matrix.shape != (self.num_antennas, self.num_antennas):
            raise ValueError(
                f"Coupling matrix shape {coupling_matrix.shape} does not match "
                f"the number of antennas {self.num_antennas}"
            )

        self._coupling_matrix = coupling_matrix

    def get_default_coupling_matrix(self):
        """Default mutual coupling matrix: the identity (no coupling), with a
        warning. Subclasses may override this."""
        warnings.warn(
            "Default mutual coupling matrix is not provided for {}, no mutual "
            "coupling will be considered".format(type(self).__name__),
            UserWarning,
        )
        return np.eye(self.num_antennas)

    def add_correlation_matrix(self, correlation_matrix=None):
        """Add spatial correlation matrix to the array, which is used to
        generate spatially correlated noise.

        If this method is not called, use the spatially and temporally
        uncorrelated noise.

        If the `correlation_matrix` is not provided, use the default correlation
        matrix.

        Args:
            correlation_matrix (np.ndarray): A square matrix representing
                spatial correlation between array elements. Should be of size
                (num_antennas, num_antennas). Defaults to None.

        Raises:
            ValueError: If the matrix shape does not match the array.
        """
        if correlation_matrix is None:
            correlation_matrix = self.get_default_correlation_matrix()
        correlation_matrix = np.asarray(correlation_matrix)

        if correlation_matrix.shape != (self.num_antennas, self.num_antennas):
            raise ValueError(
                f"Correlation matrix shape {correlation_matrix.shape} does not "
                f"match the number of antennas {self.num_antennas}"
            )

        self._correlation_matrix = correlation_matrix

    def get_default_correlation_matrix(self):
        """Use identity matrix as the default correlation matrix, which results
        in spatially white noise.

        You can override this method to provide a different default correlation.
        """
        warnings.warn(
            "Default correlation matrix is not provided for {}, the generated "
            "noise will be spatially white noise".format(type(self).__name__),
            UserWarning,
        )
        return np.eye(self.num_antennas)

    def received_signal(
        self,
        signal: Signal,
        angle_incidence,
        snr=None,
        nsamples=100,
        amp=None,
        unit="deg",
        use_cache=False,
        calc_method: Literal["delay", "fft"] = "delay",
        **kwargs,
    ):
        """Generate array received signal based on array signal model

        The broadband model is used when `signal` is a `BroadSignal`, the
        narrowband model when it is a `NarrowSignal`.

        Args:
            signal: An instance of the `Signal` class
            angle_incidence: Incidence angle. If only azimuth is considered,
                `angle_incidence` is a 1xN dimensional matrix; if two dimensions
                are considered, `angle_incidence` is a 2xN dimensional matrix,
                where the first row is the azimuth and the second row is the
                elevation angle.
            snr: Signal-to-noise ratio. If set to None, no noise will be added
            nsamples (int): Number of snapshots, defaults to 100
            amp: The amplitude of each signal, 1d numpy array
            unit: The unit of the angle, `rad` represents radian,
                `deg` represents degree. Defaults to 'deg'.
            use_cache (bool): If True, use cache to generate identical signals
                (noise is random). Default to `False`.
            calc_method (str): Only used when generate broadband signal.
                Generate broadband signal in frequency domain, or time domain
                using delay. Defaults to `delay`.
            **kwargs: Forwarded to the broadband signal generator.

        Raises:
            TypeError: If `signal` is neither broadband nor narrowband.
        """
        # Convert the angle from degree to radians
        angle_incidence = self._unify_unit(np.asarray(angle_incidence), unit)

        if isinstance(signal, BroadSignal):
            return self._gen_broadband(
                signal=signal,
                snr=snr,
                nsamples=nsamples,
                angle_incidence=angle_incidence,
                amp=amp,
                use_cache=use_cache,
                calc_method=calc_method,
                **kwargs,
            )
        if isinstance(signal, NarrowSignal):
            return self._gen_narrowband(
                signal=signal,
                snr=snr,
                nsamples=nsamples,
                angle_incidence=angle_incidence,
                amp=amp,
                use_cache=use_cache,
            )
        raise TypeError(f"Unsupported signal type: {type(signal).__name__}")

    def _gen_narrowband(
        self,
        signal: NarrowSignal,
        snr,
        nsamples,
        angle_incidence,
        amp,
        use_cache=False,
    ):
        """Generate narrowband received signal

        `azimuth` and `elevation` are already in radians
        """
        num_signal = self._as_azimuth_elevation(angle_incidence).shape[1]

        incidence_signal = signal.gen(
            n=num_signal, nsamples=nsamples, amp=amp, use_cache=use_cache
        )

        if (
            getattr(signal, "_multipath_enabled", False)
            # multipath DOA is only supported for RandomFreqSignal
            and isinstance(signal, RandomFreqSignal)
        ):
            angle_incidence = self._get_multi_path_doa(
                angle_incidence, signal._num_paths
            )
            # used to get the multipath DOAs
            self._doa = angle_incidence

        manifold_matrix = self._steering_vector_with_error(
            signal.frequency, angle_incidence, unit="rad"
        )

        received = manifold_matrix @ incidence_signal

        if snr is not None:
            received = self._add_noise(received, snr)

        return received

    def _gen_broadband(
        self,
        signal: BroadSignal,
        snr,
        nsamples,
        angle_incidence,
        amp,
        use_cache=False,
        calc_method: Literal["delay", "fft"] = "delay",
        **kwargs,
    ):
        """Dispatch broadband generation to the FFT or the delay model.

        Raises:
            ValueError: If `calc_method` is not "fft" or "delay".
        """
        if calc_method not in ("fft", "delay"):
            raise ValueError("Invalid calculation method")

        generate = (
            self._gen_broadband_fft
            if calc_method == "fft"
            else self._gen_broadband_delay
        )
        return generate(
            signal=signal,
            snr=snr,
            nsamples=nsamples,
            angle_incidence=angle_incidence,
            amp=amp,
            use_cache=use_cache,
            **kwargs,
        )

    def _gen_broadband_fft(
        self,
        signal: BroadSignal,
        snr,
        nsamples,
        angle_incidence,
        amp,
        use_cache=False,
        **kwargs,
    ):
        """Generate broadband received signal using FFT model

        `azimuth` and `elevation` are already in radians
        """
        num_signal = self._as_azimuth_elevation(angle_incidence).shape[1]

        incidence_signal = signal.gen(
            n=num_signal,
            nsamples=nsamples,
            amp=amp,
            use_cache=use_cache,
            delay=None,
            **kwargs,
        )

        # generate array signal in frequency domain
        signal_fre_domain = np.fft.fft(incidence_signal, axis=1)

        received_fre_domain = np.zeros(
            (self.num_antennas, nsamples), dtype=np.complex128
        )
        fre_points = np.fft.fftfreq(nsamples, 1 / signal.fs)
        for i, fre in enumerate(fre_points):
            manifold_fre = self._steering_vector_with_error(
                fre, angle_incidence, unit="rad"
            )

            # calculate array received signal at every frequency point
            received_fre_domain[:, i] = manifold_fre @ signal_fre_domain[:, i]

        received = np.fft.ifft(received_fre_domain, axis=1)

        if snr is not None:
            received = self._add_noise(received, snr)

        return received

    def _gen_broadband_delay(
        self,
        signal: BroadSignal,
        snr,
        nsamples,
        angle_incidence,
        amp,
        use_cache=False,
        **kwargs,
    ):
        """Generate broadband received signal by applying delay

        `azimuth` and `elevation` are already in radians
        """
        num_signal = self._as_azimuth_elevation(angle_incidence).shape[1]

        # calculate time delay of every signal at every (actual) element
        directions = self._direction_vectors(angle_incidence, unit="rad")
        time_delay = -(1 / C * self._element_position @ directions)

        received = np.zeros((self.num_antennas, nsamples), dtype=np.complex128)

        # clear cache if not use cache
        if not use_cache:
            signal.clear_cache()

        # must use cache as the same signal is received by different antennas
        for i in range(self.num_antennas):
            received[i, :] = np.sum(
                signal.gen(
                    n=num_signal,
                    nsamples=nsamples,
                    amp=amp,
                    use_cache=True,
                    delay=time_delay[i, :],
                    **kwargs,
                ),
                axis=0,
            )

        if snr is not None:
            received = self._add_noise(received, snr)

        return received

    def _add_noise(self, signal, snr_db):
        """Add complex Gaussian noise at `snr_db` relative to each antenna's
        mean signal power, spatially correlated if a correlation matrix has
        been set."""
        sig_pow = np.mean(np.abs(signal) ** 2, axis=1)
        noise_pow = sig_pow / 10 ** (snr_db / 10)

        noise = (np.sqrt(noise_pow / 2)).reshape(-1, 1) * (
            self._rng.standard_normal(size=signal.shape)
            + 1j * self._rng.standard_normal(size=signal.shape)
        )

        # if spatial correlation matrix is provided, add correlated noise
        if hasattr(self, "_correlation_matrix"):
            # use Cholesky decomposition to generate spatially correlated noise
            matrix_sqrt = np.linalg.cholesky(self._correlation_matrix)
            noise = matrix_sqrt @ noise

        return signal + noise

    def _get_multi_path_doa(self, real_doa, num_paths):
        """Generate multipath DOAs

        The returned 1D array holds the original DOAs followed by `num_paths`
        groups of path DOAs, each offset from the originals by a uniform
        random angle in [-pi/6, pi/6) radians (i.e. +-30 degrees).

        Args:
            real_doa: Real DOAs in radians; must be at most 1D (azimuth only)
            num_paths: number of paths

        Raises:
            ValueError: If `real_doa` has more than one dimension.
        """
        real_doa = np.asarray(real_doa)
        if real_doa.ndim > 1:
            raise ValueError("Only 1D DOA is supported")
        real_doa = real_doa.reshape(-1)
        num_signal = real_doa.size

        multipath_angles = np.zeros(num_signal * (num_paths + 1))
        multipath_angles[:num_signal] = real_doa

        for i in range(num_paths):
            # angle offsets between -30 and 30 degrees
            angle_offsets = self._rng.uniform(-np.pi / 6, np.pi / 6, num_signal)
            # add offsets to the original angles
            multipath_angles[(i + 1) * num_signal : (i + 2) * num_signal] = (
                angle_offsets + real_doa
            )

        return multipath_angles
495 |
496 |
class UniformLinearArray(Array):
    def __init__(self, m: int, dd: float, rng=None):
        """Uniform linear array.

        The array is uniformly arranged along the y-axis, with the first
        element at the origin.

        Args:
            m (int): number of antenna elements
            dd (float): distance between adjacent antennas
            rng (np.random.Generator): random generator used to generator random
        """
        # antenna position in (x, y, z) coordinate system
        element_position_x = np.zeros(m)
        element_position_y = np.arange(m) * dd
        element_position_z = np.zeros(m)

        super().__init__(
            element_position_x, element_position_y, element_position_z, rng
        )

    @override
    def add_position_error(self, error_std=0.3, error_type="uniform"):
        """Perturb the actual element positions along the array axis (y).

        Each element is displaced independently; the nominal positions seen by
        the algorithms are left unchanged.

        Args:
            error_std (float): error magnitude relative to the nominal element
                spacing `dd`. For "gaussian" the errors have standard deviation
                `error_std * dd`; for "uniform" they are drawn from
                U(-error_std * dd, error_std * dd), i.e. `error_std` acts as a
                half-width rather than a standard deviation.
            error_type (str): `gaussian` or `uniform`

        Raises:
            ValueError: If `error_type` is invalid, or the array has fewer than
                two elements (no spacing to scale the error by).
        """
        if self.num_antennas < 2:
            raise ValueError(
                "Position error requires at least two antennas to define the "
                "element spacing"
            )
        dd = self._ideal_position[1, 1] - self._ideal_position[0, 1]
        # error magnitude is expressed relative to the element spacing
        sigma = error_std * dd
        if error_type == "gaussian":
            error = self._rng.normal(0, sigma, self.num_antennas)
        elif error_type == "uniform":
            error = self._rng.uniform(-sigma, sigma, self.num_antennas)
        else:
            raise ValueError("Invalid error type")

        self._element_position[:, 1] = self._ideal_position[:, 1] + error

    @override
    def get_default_coupling_matrix(self, rho=0.6):
        """Return the default mutual coupling matrix of the ULA.

        The matrix is Toeplitz with first column ``[1, c, c^2, ...]`` where
        ``c = rho * exp(j * pi / 3)``.

        Args:
            rho (float): amplitude of the coupling coefficient between
                adjacent elements. Defaults to 0.6.

        Returns:
            np.ndarray: (num_antennas, num_antennas) coupling matrix.
        """
        # reference: Liu, Zhang-Meng, Chenwei Zhang, and Philip S. Yu.
        # “Direction-of-Arrival Estimation Based on Deep Neural Networks
        # With Robustness to Array Imperfections.” IEEE Transactions on
        # Antennas and Propagation 66, no. 12 (December 2018): 7315–27.
        # https://doi.org/10.1109/TAP.2018.2874430.

        coefficient = (rho * np.exp(1j * np.pi / 3)) ** np.arange(
            self.num_antennas
        )
        # diagonal comes from the identity added below
        coefficient[0] = 0
        # NOTE(review): scipy's toeplitz(c) uses conj(c) as the first row, so
        # this matrix is Hermitian rather than complex-symmetric; confirm this
        # matches the coupling model of the reference.
        coupling_matrix = toeplitz(coefficient) + np.eye(self.num_antennas)

        return coupling_matrix

    @override
    def get_default_correlation_matrix(self, rho=0.5):
        """Return the default spatial noise correlation matrix of the ULA.

        Hermitian Toeplitz matrix with first column
        ``rho^k * exp(-j * pi / 2 * k)`` for ``k = 0 .. num_antennas - 1``.

        Args:
            rho (float): correlation magnitude between adjacent elements.
                Defaults to 0.5.

        Returns:
            np.ndarray: (num_antennas, num_antennas) correlation matrix.
        """
        # reference: Agrawal, M., and S. Prasad. “A Modified Likelihood
        # Function Approach to DOA Estimation in the Presence of Unknown
        # Spatially Correlated Gaussian Noise Using a Uniform Linear Array.”
        # IEEE Transactions on Signal Processing 48, no. 10 (October 2000):
        # 2743–49. https://doi.org/10.1109/78.869024.
        correlation_matrix = (rho ** np.arange(self.num_antennas)) * np.exp(
            -1j * np.pi / 2 * np.arange(self.num_antennas)
        )
        correlation_matrix = toeplitz(correlation_matrix)

        return correlation_matrix
568 |
569 |
class UniformCircularArray(Array):
    def __init__(self, m, r, rng=None):
        """Uniform circular array in the x-y plane.

        Elements are evenly spaced on a circle of radius `r` centred at the
        origin. The first element lies on the positive x-axis, and the others
        follow counterclockwise (the positive direction).

        Args:
            m (int): Number of antennas.
            r (float): Radius of the circular array.
            rng (optional): Random number generator. Defaults to None.
        """
        self._radius = r

        # angular position of every element, measured from the x-axis
        element_angles = 2 * np.pi * np.arange(m) / m

        super().__init__(
            r * np.cos(element_angles),
            r * np.sin(element_angles),
            np.zeros(m),
            rng,
        )

    @property
    def radius(self):
        """Radius of the circular array."""
        return self._radius
595 |
--------------------------------------------------------------------------------
/doa_py/plot.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | from scipy.signal import find_peaks
4 | from skimage.feature import peak_local_max
5 |
6 |
def plot_spatial_spectrum(
    spectrum,
    angle_grids,
    num_signal,
    ground_truth=None,
    x_label="Angle",
    y_label="Spectrum",
):
    """Plot a 1D spatial spectrum and mark its strongest peaks.

    The spectrum is normalized so its maximum is 1 and drawn on a
    logarithmic y-axis. The ``num_signal`` highest local maxima found by
    ``scipy.signal.find_peaks`` are marked with red crosses and annotated
    with their angle.

    Args:
        spectrum: Spatial spectrum estimated by the algorithm (1D).
        angle_grids: Angle grids corresponding to the spatial spectrum.
        num_signal: Number of peaks (signals) to mark.
        ground_truth: True incident angles, drawn as green dashed lines.
        x_label: x-axis label.
        y_label: y-axis label.
    """
    spectrum = spectrum / np.max(spectrum)
    # allow plain lists; the grid is fancy-indexed with peak positions below
    angle_grids = np.asarray(angle_grids)

    # find peaks and peak heights
    peaks_idx, properties = find_peaks(spectrum, height=0)
    peak_heights = properties["peak_heights"]

    # Keep the `num_signal` highest peaks. The count is clamped explicitly:
    # a plain `[-num_signal:]` slice would select *every* peak when
    # num_signal is 0.
    num_keep = min(max(int(num_signal), 0), len(peak_heights))
    idx = peak_heights.argsort()[len(peak_heights) - num_keep :]
    peaks_idx = peaks_idx[idx]
    heights = peak_heights[idx]

    angles = angle_grids[peaks_idx]

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)

    # set ticks
    # NOTE(review): the +1 presumably compensates for an endpoint-exclusive
    # grid such as np.arange(-90, 90); confirm before changing.
    grids_min = angle_grids[0]
    grids_max = angle_grids[-1]
    major_space = (grids_max - grids_min + 1) / 6
    minor_space = major_space / 5
    ax.set_xlim(grids_min, grids_max)
    ax.xaxis.set_major_locator(plt.MultipleLocator(major_space))
    ax.xaxis.set_minor_locator(plt.MultipleLocator(minor_space))

    # plot spectrum
    (spectrum_line,) = ax.plot(angle_grids, spectrum)
    ax.set_yscale("log")

    # plot peaks
    peak_markers = ax.scatter(angles, heights, color="red", marker="x")
    for angle, height in zip(angles, heights):
        ax.annotate(angle, xy=(angle, height))

    handles = [spectrum_line, peak_markers]
    labels = ["Spectrum", "Estimated"]

    # ground truth
    if ground_truth is not None:
        truth_line = None
        for angle in ground_truth:
            truth_line = ax.axvline(x=angle, color="green", linestyle="--")
        if truth_line is not None:
            handles.append(truth_line)
            labels.append("Ground Truth")

    # Pass handles explicitly so each label is bound to the right artist
    # regardless of the order in which matplotlib collects artists.
    ax.legend(handles, labels)

    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)

    plt.show()
71 |
72 |
def plot_estimated_value(
    estimates,
    ticks_min=-90,
    ticks_max=90,
    ground_truth=None,
    x_label="Angle",
    y_label="Spectrum",
):
    """Display estimated angle values as vertical dashed lines.

    Args:
        estimates: Angle estimates, drawn as red dashed lines.
        ticks_min (int, optional): Minimum value for x-axis ticks.
            Defaults to -90.
        ticks_max (int, optional): Maximum value for x-axis ticks.
            Defaults to 90.
        ground_truth: True incident angles, drawn as cyan dashed lines.
        x_label (str, optional): x-axis label. Defaults to "Angle".
        y_label (str, optional): y-axis label. Defaults to "Spectrum".
    """
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)

    # set ticks
    major_space = (ticks_max - ticks_min) / 6
    minor_space = major_space / 5
    ax.set_xlim(ticks_min, ticks_max)
    ax.xaxis.set_major_locator(plt.MultipleLocator(major_space))
    ax.xaxis.set_minor_locator(plt.MultipleLocator(minor_space))

    handles, labels = [], []

    # ground truth; an empty sequence draws nothing, so nothing to label
    if ground_truth is not None:
        truth_line = None
        for angle in ground_truth:
            truth_line = ax.axvline(x=angle, color="c", linestyle="--")
        if truth_line is not None:
            handles.append(truth_line)
            labels.append("Ground Truth")

    # plot estimates; tracked as None so empty input cannot leave the
    # handle unbound
    estimate_line = None
    for angle in estimates:
        estimate_line = ax.axvline(x=angle, color="r", linestyle="--")
    if estimate_line is not None:
        handles.append(estimate_line)
        labels.append("Estimated")

    # set labels
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)

    # set legend only for the line groups that were actually drawn
    if handles:
        ax.legend(handles, labels)

    plt.show()
123 |
124 |
def plot_spatial_spectrum_2d(
    spectrum,
    azimuth_grids,
    elevation_grids,
    num_signal,
    ground_truth=None,
    x_label="Elevation",
    y_label="Azimuth",
    z_label="Spectrum",
):
    """Plot a 2D spatial spectrum as a 3D surface and mark its peaks.

    The spectrum is normalized to a maximum of 1 and drawn on a natural-log
    scale. Up to ``num_signal`` local maxima are marked with red crosses and
    annotated with their (elevation, azimuth) coordinates. The maxima are
    found with ``skimage.feature.peak_local_max`` on the normalized linear
    spectrum.

    Args:
        spectrum: 2D spatial spectrum whose rows follow ``azimuth_grids``
            and columns follow ``elevation_grids`` (the layout produced by
            ``np.meshgrid(elevation_grids, azimuth_grids)``).
        azimuth_grids: Azimuth grids corresponding to the spatial spectrum.
        elevation_grids: Elevation grids corresponding to the spatial
            spectrum.
        num_signal: Maximum number of peaks to mark.
        ground_truth: True incident angles as ``(azimuths, elevations)``,
            drawn as green dashed stems.
        x_label: x-axis label. Defaults to "Elevation".
        y_label: y-axis label. Defaults to "Azimuth".
        z_label: z-axis label. Defaults to "Spectrum".
    """
    x, y = np.meshgrid(elevation_grids, azimuth_grids)
    spectrum = spectrum / spectrum.max()
    # Find the peaks in the (normalized, linear) surface
    peaks = peak_local_max(spectrum, num_peaks=num_signal)
    # the small offset keeps the log finite where the spectrum is zero
    spectrum = np.log(spectrum + 1e-10)

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1, projection="3d")

    # plot spectrum
    surf = ax.plot_surface(x, y, spectrum, cmap="viridis", antialiased=True)
    handles = [surf]
    labels = ["Spectrum"]

    # Plot the peaks on the surface
    peak_dot = None
    for row, col in peaks:
        px, py, pz = x[row, col], y[row, col], spectrum[row, col]
        peak_dot = ax.scatter(px, py, pz, c="r", marker="x")
        ax.text(px, py, pz, "({}, {})".format(px, py))
    if peak_dot is not None:
        handles.append(peak_dot)
        labels.append("Estimated")

    # plot ground truth
    if ground_truth is not None:
        truth_lines = ax.stem(
            ground_truth[1],
            ground_truth[0],
            np.ones_like(ground_truth[0]),
            bottom=spectrum.min(),
            linefmt="g--",
            markerfmt=" ",
            basefmt=" ",
        )
        handles.append(truth_lines)
        labels.append("Ground Truth")

    # handles and labels are appended pairwise, so they always match
    ax.legend(handles, labels)

    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_zlabel(z_label)

    plt.show()
198 |
199 |
def plot_estimated_value_2d(
    estimated_azimuth,
    estimated_elevation,
    ground_truth=None,
    unit="deg",
):
    """Display estimated (azimuth, elevation) pairs on a polar plot.

    Azimuth is the polar angle and elevation the radius. The radial axis is
    fixed to [0, 90]. Each estimate is annotated with its azimuth in degrees
    and its elevation as given.

    Args:
        estimated_azimuth: Estimated azimuth angles, in ``unit``.
        estimated_elevation: Estimated elevation angles, plotted as-is on
            the radial axis.
        ground_truth: Optional array-like ``(azimuths, elevations)``.
            Azimuths are in ``unit``; elevations are plotted as-is.
        unit (str): Unit of the azimuth inputs: "deg" for degrees,
            anything else is treated as radians. Defaults to "deg".
    """
    estimated_azimuth = np.atleast_1d(np.asarray(estimated_azimuth, dtype=float))
    estimated_elevation = np.atleast_1d(
        np.asarray(estimated_elevation, dtype=float)
    )
    if unit == "deg":
        estimated_azimuth = estimated_azimuth / 180 * np.pi

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1, projection="polar")

    handles, labels = [], []
    if ground_truth is not None:
        # np.array copies, so the caller's ground truth is never modified
        ground_truth = np.array(ground_truth, dtype=float)
        # convert ground-truth azimuths with the same rule as the estimates
        if unit == "deg":
            ground_truth[0] = ground_truth[0] / 180 * np.pi
        handles.append(
            ax.scatter(ground_truth[0], ground_truth[1], marker="o", color="g")
        )
        labels.append("Ground Truth")
    handles.append(
        ax.scatter(estimated_azimuth, estimated_elevation, marker="x", color="r")
    )
    labels.append("Estimated")

    ax.set_rlabel_position(90)

    for azimuth, elevation in zip(estimated_azimuth, estimated_elevation):
        ax.annotate(
            "({:.2f}, {:.2f})".format(azimuth / np.pi * 180, elevation),
            (azimuth, elevation),
        )

    ax.set_xticks(np.arange(0, 2 * np.pi, step=np.pi / 6))
    ax.set_rlim([0, 90])
    ax.set_yticks(np.arange(0, 90, step=15))

    # explicit handles: without ground truth, the only entry is "Estimated"
    ax.legend(handles, labels)

    plt.show()
236 |
--------------------------------------------------------------------------------
/doa_py/signals.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from abc import ABC, abstractmethod
3 | from typing import Any, Literal, Optional, Union
4 |
5 | import numpy as np
6 | import numpy.typing as npt
7 | from typing_extensions import override
8 |
9 | ListLike = Union[npt.NDArray[np.number], list[int | float | complex]]
10 |
11 |
12 | class Signal(ABC):
13 | """Base class for all signal classes
14 |
15 | Signals that inherit from this base class must implement the gen() method to
16 | generate simulated sampled signals.
17 | """
18 |
19 | def __init__(self, rng: Optional[np.random.Generator] = None):
20 | if rng is None:
21 | self._rng = np.random.default_rng()
22 | else:
23 | self._rng = rng
24 |
25 | # caches used to generate the identical signals
26 | self._cache = {}
27 |
28 | def set_rng(self, rng: np.random.Generator):
29 | """Setting random number generator
30 |
31 | Args:
32 | rng (np.random.Generator): random generator used to generate random
33 | numbers
34 | """
35 | self._rng = rng
36 |
37 | def _set_cache(self, key: str, value: Any):
38 | """Set cache value
39 |
40 | Args:
41 | key (str): Cache key
42 | value (Any): Cache value
43 | """
44 | self._cache[key] = value
45 |
46 | def clear_cache(self):
47 | self._cache = {}
48 |
49 | def _get_amp(
50 | self,
51 | amp: Optional[ListLike],
52 | n: int,
53 | ) -> npt.NDArray[np.number]:
54 | # Default to generate signals with equal amplitudes
55 | if amp is None:
56 | amp = np.diag(np.ones(n))
57 | else:
58 | if not isinstance(amp, np.ndarray):
59 | amp = np.array(amp)
60 |
61 | amp = np.squeeze(amp)
62 | if not (amp.ndim == 1 and amp.size == n):
63 | raise TypeError(
64 | "amp should be an 1D array of size n = {}".format(n)
65 | )
66 |
67 | amp = np.diag(amp)
68 |
69 | return amp
70 |
71 | @abstractmethod
72 | def gen(
73 | self,
74 | n: int,
75 | nsamples: int,
76 | amp=Optional[ListLike],
77 | use_cache: bool = False,
78 | ) -> npt.NDArray[np.complex128]:
79 | """Generate sampled signals
80 |
81 | Args:
82 | n (int): Number of signals
83 | nsamples (int): Number of snapshots
84 | amp (np.array): Amplitude of the signals (1D array of size n), used
85 | to define different amplitudes for different signals.
86 | By default it will generate equal amplitude signal.
87 | use_cache (bool): If True, use cache to generate identical signals.
88 | Default to `False`.
89 |
90 | Returns:
91 | signal (np.array): Sampled signals
92 | """
93 | pass
94 |
95 | def add_multipath(self, num_paths=1):
96 | """Add multipath effect to the signal
97 |
98 | Args:
99 | num_paths (int): Number of multipath components
100 | """
101 | self._multipath_enabled = True
102 | self._num_paths = num_paths
103 |
104 | def _generate_multipath(self):
105 | """Generate multipath components"""
106 | warnings.warn("This method is not implemented", UserWarning)
107 |
108 |
class NarrowSignal(Signal):
    """Abstract base for narrowband signals described by one frequency."""

    def __init__(
        self, fc: Union[int, float], rng: Optional[np.random.Generator] = None
    ):
        """Narrowband signal

        Args:
            fc (float): Signal frequency
            rng (np.random.Generator): Random generator used to generate
                random numbers
        """
        self._fc = fc
        super().__init__(rng=rng)

    @property
    def frequency(self):
        """Frequency ``fc`` given at construction."""
        return self._fc

    @abstractmethod
    def gen(
        self,
        n: int,
        nsamples: int,
        amp: Optional[ListLike] = None,
        use_cache: bool = False,
    ) -> npt.NDArray[np.complex128]:
        """Generate ``n`` sampled narrowband signals of ``nsamples`` each."""
        ...
138 |
139 |
class ComplexStochasticSignal(NarrowSignal):
    def __init__(self, fc, rng=None):
        """Complex stochastic signal (complex exponential form of random phase
        signal)

        Args:
            fc (float): Signal frequency
            rng (np.random.Generator): Random generator used to generate random
                numbers
        """
        super().__init__(fc, rng)

    def gen(
        self, n, nsamples, amp=None, use_cache=False
    ) -> npt.NDArray[np.complex128]:
        """Generate circular complex Gaussian envelopes.

        Real and imaginary parts are independent standard normals scaled by
        ``sqrt(1/2)``, then mixed by the diagonal amplitude matrix.

        Returns:
            np.ndarray: Complex array of shape (n, nsamples).
        """
        amp_matrix = self._get_amp(amp, n)
        expected_shape = (n, nsamples)

        if use_cache and self._cache:
            # use cache
            re_part = self._cache["real"]
            im_part = self._cache["imag"]
            assert (
                re_part.shape == expected_shape
                and im_part.shape == expected_shape
            ), "Cache shape mismatch"
        else:
            # draw fresh real parts first, then imaginary parts
            re_part = self._rng.standard_normal(size=expected_shape)
            im_part = self._rng.standard_normal(size=expected_shape)
            self._set_cache("real", re_part)
            self._set_cache("imag", im_part)

        # Generate complex envelope
        envelope = np.sqrt(1 / 2) * (re_part + 1j * im_part)
        return amp_matrix @ envelope
175 |
176 |
class RandomFreqSignal(NarrowSignal):
    def __init__(
        self,
        fc: Union[int, float],
        freq_ratio: float = 0.05,
        coherent: bool = False,
        rng: Optional[np.random.Generator] = None,
    ):
        """Random frequency signal

        Each signal is a complex exponential whose frequency is drawn
        uniformly from ``[0, freq_ratio * fc)`` and whose initial phase is
        drawn uniformly from ``[0, 2*pi)``.

        Args:
            fc (float): Signal frequency
            freq_ratio (float): Ratio of the maximum frequency deviation from fc
            coherent (bool): If True, generate coherent signals (all signals
                share one randomly drawn frequency)
            rng (np.random.Generator): Random generator used to generate random
                numbers
        """
        super().__init__(fc, rng)

        assert (
            0 < freq_ratio < 0.1
        ), "This signal must be narrowband: freq_ratio in (0, 0.1)"
        self._freq_ratio = freq_ratio

        self._coherent = coherent

    def gen(self, n, nsamples, amp=None, use_cache=False):
        """Generate sampled random-frequency signals.

        Args:
            n (int): Number of signals
            nsamples (int): Number of snapshots
            amp (np.array): Amplitude of the signals (1D array of size n)
            use_cache (bool): If True, reuse the cached frequencies, phases
                and (when multipath is enabled) path gains, so repeated
                calls produce identical signals.

        Returns:
            np.ndarray: Shape (n, nsamples), or (n * (1 + num_paths),
            nsamples) when multipath is enabled. In that case the direct
            signals come first, followed by the multipath copies.
        """
        amp = self._get_amp(amp, n)

        if use_cache and self._cache:
            freq = self._cache["freq"]
            phase = self._cache["phase"]
            assert freq.shape == (n, 1) and phase.shape == (
                n,
                1,
            ), "Cache shape mismatch"
        else:
            # Generate random phase signal
            if self._coherent:
                freq = self._rng.uniform(
                    0, self._freq_ratio * self._fc
                ) * np.ones((n, 1))
            else:
                freq = self._rng.uniform(
                    0, self._freq_ratio * self._fc, size=(n, 1)
                )
            phase = self._rng.uniform(0, 2 * np.pi, size=(n, 1))
            self._set_cache("freq", freq)
            self._set_cache("phase", phase)

        # sample at 5x the largest possible frequency offset
        fs = self._fc * self._freq_ratio * 5
        # Generate random frequency signal
        signal = (
            amp
            @ np.exp(1j * 2 * np.pi * freq / fs * np.arange(nsamples))
            * np.exp(1j * phase)  # phase
        )

        # add multipath effect
        if getattr(self, "_multipath_enabled", False):
            path_signal = self._generate_multipath(
                signal, self.frequency, use_cache=use_cache
            )
            signal = np.vstack([signal, path_signal])

        return signal

    @override
    def _generate_multipath(self, signal_data, signal_fre, use_cache=False):
        """Build attenuated, phase-rotated copies of the direct signals.

        Args:
            signal_data (np.ndarray): Direct-path signals, shape
                (num_signal, num_snapshots).
            signal_fre: Carrier frequency; not used by this implementation.
            use_cache (bool): Reuse the cached path amplitudes and phases
                when they exist, so ``gen(use_cache=True)`` is repeatable.

        Returns:
            np.ndarray: Shape (num_signal * num_paths, num_snapshots). Rows
            ``[i * num_signal, (i + 1) * num_signal)`` hold path ``i``.
        """
        num_signal, num_snapshots = signal_data.shape
        param_shape = (num_signal, self._num_paths)

        # -- add multi path delay and amplitude --------------------------------
        if use_cache and "path_amplitudes" in self._cache:
            relative_amplitudes = self._cache["path_amplitudes"]
            relative_phases = self._cache["path_phases"]
            assert (
                relative_amplitudes.shape == param_shape
                and relative_phases.shape == param_shape
            ), "Cache shape mismatch"
        else:
            relative_amplitudes = self._rng.uniform(0.3, 0.9, param_shape)
            relative_phases = self._rng.uniform(0, 2 * np.pi, param_shape)
            self._set_cache("path_amplitudes", relative_amplitudes)
            self._set_cache("path_phases", relative_phases)

        path_signal = np.zeros(
            (num_signal * self._num_paths, num_snapshots), dtype=np.complex128
        )
        for i in range(self._num_paths):
            path_signal[i * num_signal : (i + 1) * num_signal, :] = (
                relative_amplitudes[:, i].reshape(-1, 1)
                * signal_data
                * np.exp(1j * relative_phases[:, i].reshape(-1, 1))
            )

        return path_signal
266 |
267 |
class BroadSignal(Signal):
    """Abstract base class for broadband signals.

    Subclasses draw random sub-bands inside ``[f_min, f_max]`` (see
    ``_gen_fre_bands``). They implement ``gen`` and the ``f_min``/``f_max``
    properties.
    """

    def __init__(
        self,
        f_min: Union[int, float],
        f_max: Union[int, float],
        fs: Union[int, float],
        min_length_ratio: float = 0.1,
        no_overlap: bool = False,
        rng: Optional[np.random.Generator] = None,
    ):
        """
        Args:
            f_min (float): Lower edge of the frequency range bands are drawn
                from
            f_max (float): Upper edge of the frequency range bands are drawn
                from
            fs (int | float): Sampling frequency
            min_length_ratio (float): Minimum band length, as a fraction of
                the available range
            no_overlap (bool): If True, generate non-overlapping bands
            rng (np.random.Generator): Random generator used to generate
                random numbers
        """
        self._f_min = f_min
        self._f_max = f_max
        self._fs = fs
        self._min_length_ratio = min_length_ratio
        self._no_overlap = no_overlap

        super().__init__(rng=rng)

    @property
    def fs(self):
        """Sampling frequency."""
        return self._fs

    @property
    @abstractmethod
    def f_min(self):
        pass

    @property
    @abstractmethod
    def f_max(self):
        pass

    @abstractmethod
    def gen(
        self,
        n: int,
        nsamples: int,
        amp: Optional[ListLike] = None,
        use_cache: bool = False,
        delay: Optional[Union[npt.NDArray, int, float]] = None,
    ) -> npt.NDArray[np.complex128]:
        """Generate sampled signals

        Args:
            n (int): Number of signals
            nsamples (int): Number of snapshots
            amp (np.array): Amplitude of the signals (1D array of size n), used
                to define different amplitudes for different signals.
                By default it will generate equal amplitude signal.
            use_cache (bool): If True, use cache to generate identical signals.
                Default to `False`.
            delay (float | None): If not None, apply delay to all signals.

        Returns:
            signal (np.array): Sampled signals
        """
        pass

    def _gen_fre_bands(self, n: int):
        """Generate frequency ranges for each broadband signal

        Args:
            n (int): Number of signals

        Returns:
            ranges (np.array): Frequency ranges for each signal with shape
                (n, 2)
        """
        if self._no_overlap:
            return self._gen_fre_bands_no_overlapping(n)
        return self._gen_fre_bands_overlapping(n)

    def _gen_fre_bands_overlapping(self, n: int):
        """Generate frequency bands that may overlap each other."""
        min_length = (self._f_max - self._f_min) * self._min_length_ratio
        bands = np.zeros((n, 2))
        for i in range(n):
            length = self._rng.uniform(min_length, self._f_max - self._f_min)
            start = self._rng.uniform(self._f_min, self._f_max - length)
            bands[i] = [start, start + length]
        return bands

    def _gen_fre_bands_no_overlapping(self, n: int):
        """Generate non-overlapping frequency bands.

        ``[f_min, f_max]`` is split into ``n`` equal slots and band ``i`` is
        placed at a random position inside slot ``i``.
        """
        # True division: floor division would truncate the slot width (to
        # zero when the range is narrower than n, e.g. normalized
        # frequencies).
        max_length = (self._f_max - self._f_min) / n
        min_length = max_length * self._min_length_ratio

        bands = np.zeros((n, 2))

        for i in range(n):
            length = self._rng.uniform(min_length, max_length)
            start = self._rng.uniform(
                self._f_min + i * max_length,
                self._f_min + (i + 1) * max_length - length,
            )
            bands[i] = [start, start + length]

        return bands
368 |
369 |
class ChirpSignal(BroadSignal):
    def __init__(
        self,
        f_min,
        f_max,
        fs,
        min_length_ratio: float = 0.1,
        no_overlap: bool = False,
        rng=None,
    ):
        """Chirp signal

        Args:
            f_min (float): Minimum frequency
            f_max (float): Maximum frequency
            fs (int | float): Sampling frequency
            min_length_ratio (float): Minimum length ratio of the frequency
                band in (f_max - f_min)
            no_overlap (bool): If True, generate signals with non-overlapping
                bands
            rng (np.random.Generator): Random generator used to generate random
                numbers
        """
        super().__init__(
            f_min, f_max, fs, min_length_ratio, no_overlap, rng=rng
        )

    @property
    def f_min(self):
        """Lowest start frequency among the cached chirp bands.

        Raises:
            ValueError: If no bands are cached yet.
        """
        if "fre_ranges" not in self._cache:
            raise ValueError("fre_ranges not in cache")
        return np.min(self._cache["fre_ranges"][:, 0])

    @property
    def f_max(self):
        """Highest end frequency among the cached chirp bands.

        Raises:
            ValueError: If no bands are cached yet.
        """
        if "fre_ranges" not in self._cache:
            raise ValueError("fre_ranges not in cache")
        return np.max(self._cache["fre_ranges"][:, 1])

    def gen(
        self,
        n,
        nsamples,
        amp=None,
        use_cache=False,
        delay=None,
    ) -> npt.NDArray[np.complex128]:
        """Generate linear chirp signals.

        Signal ``i`` sweeps linearly from ``fre_ranges[i, 0]`` at ``t = 0`` to
        ``fre_ranges[i, 1]`` at the last sample time ``(nsamples - 1) / fs``,
        with a random initial phase.

        Args:
            n (int): Number of signals
            nsamples (int): Number of snapshots
            amp (np.array): Amplitude of the signals (1D array of size n)
            use_cache (bool): If True, reuse cached bands and phases.
            delay (float | array-like | None): Time delay applied to all
                signals (scalar) or per signal (length-n array).

        Returns:
            np.ndarray: Complex array of shape (n, nsamples).
        """
        amp = self._get_amp(amp, n)

        # use cache
        if use_cache and self._cache:
            fre_ranges = self._cache["fre_ranges"]
            phase = self._cache["phase"]
            assert fre_ranges.shape == (n, 2) and phase.shape == (
                n,
                1,
            ), "Cache shape mismatch"
        # generate new and write to cache
        else:
            fre_ranges = self._gen_fre_bands(n)
            phase = self._rng.uniform(0, 2 * np.pi, size=(n, 1))
            self._set_cache("fre_ranges", fre_ranges)
            self._set_cache("phase", phase)

        t = np.arange(nsamples) * 1 / self._fs

        # start freq
        f0 = fre_ranges[:, 0]
        # sweep rate so that the frequency reaches f1 at the last sample
        if nsamples > 1:
            k = (fre_ranges[:, 1] - fre_ranges[:, 0]) / t[-1]
        else:
            # a single sample spans no time: no sweep (avoids 0/0 -> NaN)
            k = np.zeros(n)

        if delay is not None:
            delay = np.asarray(delay)
            if delay.ndim == 0:
                delay = np.ones(n) * delay
            t = t + delay.reshape(-1, 1)

        signal = np.exp(
            1j
            * 2
            * np.pi
            * (f0.reshape(-1, 1) * t + 0.5 * k.reshape(-1, 1) * t**2)
        ) * np.exp(1j * phase)

        signal = amp @ signal

        return signal
456 |
457 |
class MultiFreqSignal(BroadSignal):
    def __init__(
        self,
        f_min: Union[int, float],
        f_max: Union[int, float],
        fs: Union[int, float],
        min_length_ratio: float = 0.1,
        no_overlap: bool = False,
        rng: Optional[np.random.Generator] = None,
        ncarriers: int = 100,
    ):
        """Broadband signal consisting of multiple narrowband signals modulated
        on different carrier frequencies.

        Args:
            f_min (float): Minimum frequency
            f_max (float): Maximum frequency
            fs (int | float): Sampling frequency
            min_length_ratio (float): Minimum length ratio of the frequency
                band in (f_max - f_min)
            no_overlap (bool): If True, generate signals with non-overlapping
                bands
            rng (np.random.Generator): Random generator used to generate random
                numbers
            ncarriers (int): Number of carrier frequencies in each broadband
        """
        super().__init__(
            f_min, f_max, fs, min_length_ratio, no_overlap, rng=rng
        )

        self._ncarriers = ncarriers

    @property
    def f_min(self):
        """Lowest cached carrier frequency.

        Raises:
            ValueError: If no carriers are cached yet.
        """
        if "fres" not in self._cache:
            raise ValueError("fres not in cache")
        return np.min(self._cache["fres"])

    @property
    def f_max(self):
        """Highest cached carrier frequency.

        Raises:
            ValueError: If no carriers are cached yet.
        """
        if "fres" not in self._cache:
            raise ValueError("fres not in cache")
        return np.max(self._cache["fres"])

    def gen(
        self,
        n,
        nsamples,
        amp=None,
        use_cache=False,
        delay=None,
    ) -> npt.NDArray[np.complex128]:
        """Generate sampled signals

        Each of the ``n`` signals is the sum of ``ncarriers`` complex
        exponentials. Their carrier frequencies are drawn uniformly from that
        signal's random sub-band, and each carrier has its own random phase.

        Args:
            n (int): Number of signals
            nsamples (int): Number of snapshots
            amp (np.array): Amplitude of the signals (1D array of size n), used
                to define different amplitudes for different signals.
                By default it will generate equal amplitude signal.
            use_cache (bool): If True, use cache to generate identical signals.
                Default to `False`.
            delay (float | array-like | None): If not None, delay applied to
                all signals (scalar) or per signal (length-n array).

        Returns:
            signal (np.array): Sampled signals, shape (n, nsamples)
        """
        amp = self._get_amp(amp, n)

        if use_cache and self._cache:
            fres = self._cache["fres"]
            phase = self._cache["phase"]
            assert fres.shape == (n, self._ncarriers) and phase.shape == (
                n,
                self._ncarriers,
                1,
            ), "Cache shape mismatch"
        else:
            fre_ranges = self._gen_fre_bands(n)
            # generate random carrier frequencies
            fres = self._rng.uniform(
                fre_ranges[:, 0].reshape(-1, 1),
                fre_ranges[:, 1].reshape(-1, 1),
                size=(n, self._ncarriers),
            )
            phase = self._rng.uniform(
                0, 2 * np.pi, size=(n, self._ncarriers, 1)
            )
            self._set_cache("fres", fres)
            self._set_cache("phase", phase)

        t = np.arange(nsamples) * (1 / self._fs)

        if delay is not None:
            delay = np.asarray(delay)
            if delay.ndim == 0:
                delay = np.ones(n) * delay
            # per-signal time axes of shape (n, 1, nsamples) so they
            # broadcast against the carrier axis below
            t = np.expand_dims(t + delay.reshape(-1, 1), axis=1)

        # (n, ncarriers, 1) * t broadcasts to (n, ncarriers, nsamples);
        # summing over axis 1 adds up the carriers of each signal
        signal = np.sum(
            np.exp(1j * 2 * np.pi * fres[:, :, np.newaxis] * t)
            * np.exp(1j * phase),
            axis=1,
        )

        # Normalize to unit average power. NOTE: the mean is taken over all
        # n signals jointly, so an individual signal need not have unit
        # power.
        signal = signal / np.sqrt(np.mean(np.abs(signal) ** 2))

        signal = amp @ signal

        return signal
576 |
577 |
class MixedSignal(BroadSignal):
    def __init__(
        self,
        f_min: Union[int, float],
        f_max: Union[int, float],
        fs: Union[int, float],
        min_length_ratio: float = 0.1,
        no_overlap: bool = False,
        rng: Optional[np.random.Generator] = None,
        base: Literal["chirp", "multifreq"] = "chirp",
        ncarriers: int = 100,
    ):
        """Narrowband and broadband mixed signal

        Args:
            f_min (float): Minimum frequency
            f_max (float): Maximum frequency
            fs (int | float): Sampling frequency
            min_length_ratio (float): Minimum length ratio of the frequency
                band in (f_max - f_min)
            no_overlap (bool): If True, generate signals with non-overlapping
                bands
            rng (np.random.Generator): Random generator used to generate random
                numbers
            base (str): Type of base signal, either 'chirp' or 'multifreq'
            ncarriers (int): Only for `multifreq` base. Number of carrier
                frequencies in each broadband

        Raises:
            ValueError: If base is not 'chirp' or 'multifreq'
        """
        if base not in ["chirp", "multifreq"]:
            raise ValueError("base must be either 'chirp' or 'multifreq'")
        if base == "chirp":
            self._base = ChirpSignal(
                f_min, f_max, fs, min_length_ratio, no_overlap, rng
            )
        elif base == "multifreq":
            self._base = MultiFreqSignal(
                f_min, f_max, fs, min_length_ratio, no_overlap, rng, ncarriers
            )

        super().__init__(
            f_min, f_max, fs, min_length_ratio, no_overlap, rng=rng
        )

    def clear_cache(self):
        """Clear this signal's cache and the broadband base signal's cache."""
        super().clear_cache()
        self._base.clear_cache()

    @property
    def f_min(self):
        """Lowest frequency over the narrowband tones and the base signal.

        Raises:
            ValueError: If no narrowband frequencies are cached yet.
        """
        if "narrow_freqs" not in self._cache:
            raise ValueError("narrow_freqs not in cache")
        return np.min([np.min(self._cache["narrow_freqs"]), self._base.f_min])

    @property
    def f_max(self):
        """Highest frequency over the narrowband tones and the base signal.

        Raises:
            ValueError: If no narrowband frequencies are cached yet.
        """
        if "narrow_freqs" not in self._cache:
            raise ValueError("narrow_freqs not in cache")
        return np.max([np.max(self._cache["narrow_freqs"]), self._base.f_max])

    def gen(
        self,
        n: int,
        nsamples: int,
        amp: Optional[ListLike] = None,
        use_cache=False,
        delay: Optional[Union[npt.NDArray, int, float]] = None,
        m: Optional[int] = None,
        narrow_idx: Union[npt.NDArray[np.int_], list[int], None] = None,
    ):
        """Generate sampled signals

        Args:
            n (int): Number of all signals (narrowband and broadband)
            nsamples (int): Number of snapshots
            amp (np.array): Amplitude of the signals (1D array of size n), used
                to define different amplitudes for different signals.
                By default it will generate equal amplitude signal.
            use_cache (bool): If True, use cache to generate identical signals.
                Default to `False`.
            delay (float | array-like | None): If not None, delay applied to
                all signals (scalar) or per signal (length-n array).
            m (int): Number of narrowband signals inside `n`. If set to `None`,
                it reuses the cached count when `use_cache` is set and a cache
                exists, and otherwise draws a random int in [1, n).
            narrow_idx (array): index of where narrowband signal is located in n
                signals

        Raises:
            ValueError: If ``m >= n``, or if ``m`` must be drawn at random
                but ``n < 2``.
        """
        cached = bool(use_cache and self._cache)

        if m is None:
            if cached:
                # Reuse the split from the cached call; a freshly drawn m would
                # almost always contradict the cached shapes.
                m = len(self._cache["narrow_idx"])
            elif n < 2:
                raise ValueError(
                    "n must be at least 2 to draw a random number of "
                    "narrowband signals"
                )
            else:
                m = self._rng.integers(1, n)
        elif m >= n:
            raise ValueError(
                "Number of narrowband signals must be less than n"
            )

        amp = self._get_amp(amp, n)

        if cached:
            narrow_freqs = self._cache["narrow_freqs"]
            phase = self._cache["phase"]
            narrow_idx = self._cache["narrow_idx"]
            assert narrow_freqs.shape == (m, 1), "Cache shape mismatch"
            assert phase.shape == (m, 1), "Cache shape mismatch"
            assert isinstance(narrow_idx, np.ndarray)
        else:
            narrow_freqs = self._rng.uniform(
                self._f_min, self._f_max, size=m
            ).reshape(-1, 1)
            phase = self._rng.uniform(0, 2 * np.pi, size=(m, 1))
            if narrow_idx is None:
                narrow_idx = self._rng.choice(n, m, replace=False)
            narrow_idx = np.array(narrow_idx)
            assert len(narrow_idx) == m, "narrow_idx length mismatch"
            self._set_cache("narrow_freqs", narrow_freqs)
            self._set_cache("phase", phase)
            self._set_cache("narrow_idx", narrow_idx)

        # normalize delay to a length-n array (scalars, numpy scalars and
        # lists included) so it can be indexed per signal below
        if delay is None:
            delay = np.zeros(n)
        else:
            delay = np.asarray(delay)
            if delay.ndim == 0:
                delay = np.ones(n) * delay

        broad_idx = ~np.isin(np.arange(n), narrow_idx)

        # generate narrowband signals
        t = np.arange(nsamples) * (1 / self._fs)
        t = t + delay.reshape(-1, 1)[narrow_idx]

        narrow_s = (
            np.exp(1j * 2 * np.pi * narrow_freqs * t)  # sine wave
            * np.exp(1j * phase)  # phase
        )

        # generate broadband signals
        broad_s = self._base.gen(
            n=n - m,
            nsamples=nsamples,
            use_cache=use_cache,
            delay=delay.reshape(-1, 1)[broad_idx],
        )

        # combine narrowband and broadband signals
        signal = np.zeros((n, nsamples), dtype=np.complex128)
        signal[narrow_idx] = narrow_s
        signal[broad_idx] = broad_s

        signal = amp @ signal

        return signal
725 |
--------------------------------------------------------------------------------
/examples/broad.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 |
8 | from doa_py.algorithm import cssm, imusic, tops
9 | from doa_py.arrays import UniformLinearArray
10 | from doa_py.plot import plot_spatial_spectrum
11 | from doa_py.signals import ChirpSignal
12 |
# ---- signal parameters ------------------------------------------------------
angle_incidence = np.array([0, 30])  # true incident angles, degrees
num_signal = len(angle_incidence)
num_snapshots = 1000
fre_min = 1e6
fre_max = 1e7
fs = 2.5e7
snr = 0

# ---- array parameters -------------------------------------------------------
num_antennas = 8
# element spacing: half the wavelength of the highest frequency in the band
antenna_spacing = 0.5 * (3e8 / fre_max)

search_grids = np.arange(-90, 90, 1)


def show_spectrum(spectrum):
    """Plot an estimated spatial spectrum against the true incident angles."""
    plot_spatial_spectrum(
        spectrum=spectrum,
        ground_truth=angle_incidence,
        angle_grids=search_grids,
        num_signal=num_signal,
    )


# ---- generate signal and received data --------------------------------------
signal = ChirpSignal(f_min=fre_min, f_max=fre_max, fs=fs)
array = UniformLinearArray(m=num_antennas, dd=antenna_spacing)

# plot one realization of the signals in the frequency domain
freq_axis = np.fft.fftshift(np.fft.fftfreq(num_snapshots, 1 / fs))
waveforms = signal.gen(n=num_signal, nsamples=num_snapshots)
magnitude = np.abs(np.fft.fftshift(np.fft.fft(waveforms))).transpose()
plt.plot(freq_axis, magnitude)
plt.xlabel("Frequency (Hz)")
plt.ylabel("Magnitude")
plt.show()

received_data = array.received_signal(
    signal=signal,
    snr=snr,
    nsamples=num_snapshots,
    angle_incidence=angle_incidence,
    unit="deg",
)

# ---- broadband DOA estimators -----------------------------------------------
# keyword arguments shared by every estimator below
common_inputs = {
    "received_data": received_data,
    "num_signal": num_signal,
    "array": array,
    "fs": fs,
    "angle_grids": search_grids,
    "unit": "deg",
}

show_spectrum(imusic(**common_inputs, num_groups=16))

show_spectrum(
    cssm(
        **common_inputs,
        fre_ref=(fre_min + fre_max) / 2,
        pre_estimate=np.array([-1, 29]),
    )
)

show_spectrum(tops(**common_inputs, num_groups=32, fre_ref=4e6))
111 |
--------------------------------------------------------------------------------
/examples/ula.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
5 |
6 | import numpy as np
7 |
8 | from doa_py.algorithm import esprit, l1_svd, music, omp, root_music
9 | from doa_py.arrays import UniformLinearArray
10 | from doa_py.plot import plot_estimated_value, plot_spatial_spectrum
11 | from doa_py.signals import ComplexStochasticSignal
12 |
# signal parameters
num_snapshots = 300  # snapshots (time samples) collected by every antenna
signal_fre = 2e7  # carrier frequency of the narrowband sources (Hz)
snr = -5  # signal-to-noise ratio handed to `received_signal` (presumably dB)
# No sampling rate is defined here: every estimator below is narrowband and
# only needs the carrier frequency `signal_fre`.

# array parameters
num_antennas = 8
wavelength = 3e8 / signal_fre  # speed of light (m/s) / carrier frequency
antenna_spacing = 0.5 * wavelength  # half-wavelength element spacing

# incident angles, in degrees (every call below passes or assumes unit="deg")
angle_incidence = np.array([0, 30])
num_signal = len(angle_incidence)

# initialize signal instance
signal = ComplexStochasticSignal(fc=signal_fre)

# initialize array instance
array = UniformLinearArray(m=num_antennas, dd=antenna_spacing)

# generate received data
received_data = array.received_signal(
    signal=signal,
    snr=snr,
    nsamples=num_snapshots,
    angle_incidence=angle_incidence,
    unit="deg",
)

# candidate angles scanned by the grid-based estimators: -90, -89, ..., 89 deg
search_grids = np.arange(-90, 90, 1)

# MUSIC: spatial spectrum evaluated on the search grid
music_spectrum = music(
    received_data=received_data,
    num_signal=num_signal,
    array=array,
    signal_fre=signal_fre,
    angle_grids=search_grids,
    unit="deg",
)

# plot spatial spectrum
plot_spatial_spectrum(
    spectrum=music_spectrum,
    angle_grids=search_grids,
    ground_truth=angle_incidence,
    num_signal=num_signal,
    y_label="MUSIC Spectrum (dB)",
)

# Root-MUSIC: grid-free, returns the estimated angles directly
rmusic_estimates = root_music(
    received_data=received_data,
    num_signal=num_signal,
    array=array,
    signal_fre=signal_fre,
    unit="deg",
)
plot_estimated_value(
    estimates=rmusic_estimates,
    ground_truth=angle_incidence,
    y_label="Root-MUSIC Estimated Angle (deg)",
)

# ESPRIT: grid-free; called without `unit`, relying on the library default
esprit_estimates = esprit(
    received_data=received_data,
    num_signal=num_signal,
    array=array,
    signal_fre=signal_fre,
)

plot_estimated_value(
    estimates=esprit_estimates,
    ground_truth=angle_incidence,
    y_label="ESPRIT Estimated Angle (deg)",
)

# OMP: sparse recovery on the search grid, returns estimated angles
omp_estimates = omp(
    received_data=received_data,
    num_signal=num_signal,
    array=array,
    signal_fre=signal_fre,
    angle_grids=search_grids,
    unit="deg",
)

plot_estimated_value(
    estimates=omp_estimates,
    ground_truth=angle_incidence,
    y_label="OMP Estimated Angle (deg)",
)

# L1-SVD: sparse recovery on the search grid, returns a spatial spectrum
l1_svd_spectrum = l1_svd(
    received_data=received_data,
    num_signal=num_signal,
    array=array,
    signal_fre=signal_fre,
    angle_grids=search_grids,
    unit="deg",
)

plot_spatial_spectrum(
    spectrum=l1_svd_spectrum,
    angle_grids=search_grids,
    ground_truth=angle_incidence,
    num_signal=num_signal,
    y_label="L1-SVD Spectrum (dB)",
)
121 |
--------------------------------------------------------------------------------
/pics/doa_py.svg:
--------------------------------------------------------------------------------
1 |
2 |
70 |
--------------------------------------------------------------------------------
/pics/esprit.svg:
--------------------------------------------------------------------------------
1 |
2 |
4 |
1310 |
--------------------------------------------------------------------------------
/pics/l1_svd.svg:
--------------------------------------------------------------------------------
1 |
2 |
4 |
1460 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "doa_py"
3 | version = "0.4.0"
4 | description = "DOA estimation algorithms implemented in Python"
5 | readme = "README.md"
6 | license = { text = "MIT" }
7 | requires-python = ">=3.10"
8 | authors = [{ name = "Qian Xu", email = "xuq3196@outlook.com" }]
9 | maintainers = [{ name = "Qian Xu", email = "xuq3196@outlook.com" }]
10 | keywords = [
11 | "doa",
12 | "direction of arrival",
13 | "doa estimation",
14 | "array signal processing",
15 | ]
16 | classifiers = [
17 | "Development Status :: 3 - Alpha",
18 | "Intended Audience :: Science/Research",
19 | "License :: OSI Approved :: MIT License",
20 | "Programming Language :: Python :: 3.10",
21 | "Programming Language :: Python :: 3.11",
22 | "Programming Language :: Python :: 3.12",
23 | ]
24 | urls = { Homepage = "https://github.com/zhiim/doa_py" }
25 | dependencies = [
26 | "cvxpy>=1.5.3",
27 | "matplotlib>=3.9.2",
28 | "numpy>=2.1.1",
29 | "scikit-image>=0.24.0",
30 | "scipy>=1.14.1",
31 | ]
32 |
--------------------------------------------------------------------------------
/ruff.toml:
--------------------------------------------------------------------------------
1 | # Exclude a variety of commonly ignored directories.
2 | exclude = [
3 | ".bzr",
4 | ".direnv",
5 | ".eggs",
6 | ".git",
7 | ".git-rewrite",
8 | ".hg",
9 | ".ipynb_checkpoints",
10 | ".mypy_cache",
11 | ".nox",
12 | ".pants.d",
13 | ".pyenv",
14 | ".pytest_cache",
15 | ".pytype",
16 | ".ruff_cache",
17 | ".svn",
18 | ".tox",
19 | ".venv",
20 | ".vscode",
21 | "__pypackages__",
22 | "_build",
23 | "buck-out",
24 | "build",
25 | "dist",
26 | "node_modules",
27 | "site-packages",
28 | "venv",
29 | ]
30 |
31 | # Black uses 88 columns; this project deliberately uses 80 (indent width matches Black).
32 | line-length = 80
33 | indent-width = 4
34 |
35 | [lint]
36 | # Ruff enables only Pyflakes (`F`) and a subset of pycodestyle (`E`) by default;
37 | # the selection below adds pycodestyle warnings (`W`), naming, builtin shadowing,
38 | # Pylint, NumPy-specific and isort rules. McCabe complexity (`C901`) stays off.
39 | select = [
40 | "F", # Pyflakes (various error checks)
41 | "E", # pycodestyle errors
42 | "W", # pycodestyle warnings
43 | "N", # pep8-naming
44 | "A", # flake8-builtins
45 | "PLC", # Pylint conventions
46 | "PLE", # Pylint errors
47 | "PLW", # Pylint warnings
48 | "NPY", # Numpy-specific rules
49 | "I", # isort
50 | ]
51 | ignore = [
52 | # "N802", # function name should be lowercase (conflicts with PyQt)
53 | # "E402", # imports must be at the top of the file (in some cases
54 | # sys.path.append must be called before importing some modules)
55 | ]
56 |
57 | # Allow fixes for all enabled rules (when `--fix` is provided).
58 | fixable = ["ALL"]
59 | unfixable = ["F401"]
60 |
61 | # Allow unused variables when underscore-prefixed.
62 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
63 |
64 | [lint.pydocstyle]
65 | convention = "google" # Google docstring style (only used when "D" rules are selected)
66 |
67 | [format]
68 | # Like Black, use double quotes for strings.
69 | quote-style = "double"
70 |
71 | # Like Black, indent with spaces, rather than tabs.
72 | indent-style = "space"
73 |
74 | # Like Black, respect magic trailing commas.
75 | skip-magic-trailing-comma = false
76 |
77 | # Like Black, automatically detect the appropriate line ending.
78 | line-ending = "auto"
79 |
80 | # Enable auto-formatting of code examples in docstrings. Markdown,
81 | # reStructuredText code/literal blocks and doctests are all supported.
82 | #
83 | # This is currently disabled by default, but it is planned for this
84 | # to be opt-out in the future.
85 | docstring-code-format = false
86 |
87 | # Set the line length limit used when formatting code snippets in
88 | # docstrings.
89 | #
90 | # This only has an effect when the `docstring-code-format` setting is
91 | # enabled.
92 | docstring-code-line-length = "dynamic"
93 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import os
import re

from setuptools import find_packages, setup
2 |
3 |
def find_long_description(path=None):
    """Return the README text used as the package's long description.

    Args:
        path: File to read. Defaults to ``README.md`` in the directory that
            contains this ``setup.py``, so the result does not depend on the
            current working directory the build is launched from.

    Returns:
        The whole file decoded as UTF-8.
    """
    if path is None:
        # Resolve next to setup.py rather than relative to os.getcwd().
        path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "README.md"
        )
    with open(path, encoding="utf-8") as f:
        return f.read()
7 |
8 |
def find_version(path=None):
    """Read the ``__version__`` string assigned in ``doa_py/__init__.py``.

    The first top-level line of the form ``__version__ = "X"`` (single or
    double quotes, optional type annotation, any spacing, optional trailing
    comment) wins.

    Args:
        path: File to scan. Defaults to ``doa_py/__init__.py`` next to this
            ``setup.py``, independent of the current working directory.

    Returns:
        The quoted version string, or ``""`` when no such assignment exists.
        The empty fallback is kept on purpose: pyproject.toml also declares a
        static version, so raising here could break builds that work today
        (confirm before tightening).
    """
    if path is None:
        path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "doa_py", "__init__.py"
        )
    # Anchored at column 0 like the original startswith() check; captures the
    # text between matching quotes so trailing comments are ignored.
    pattern = re.compile(r"""^__version__\s*(?::[^=]*)?=\s*(['"])(.+?)\1""")
    with open(path, encoding="utf-8") as f:
        for line in f:
            match = pattern.match(line)
            if match:
                return match.group(2)
    return ""
15 |
16 |
# Package metadata. pyproject.toml declares the same project in its
# [project] table; the dependency and Python-version constraints below mirror
# it so both sources agree -- keep them in sync when either changes.
setup(
    name="doa_py",
    version=find_version(),
    packages=find_packages(),
    description="DOA estimation algorithms implemented in Python",
    author="Qian Xu",
    author_email="xuq3196@outlook.com",
    url="https://github.com/zhiim/doa_py",
    long_description=find_long_description(),
    long_description_content_type="text/markdown",
    license="MIT",
    # Matches ``requires-python`` in pyproject.toml and the classifiers below.
    python_requires=">=3.10",
    # Minimum versions mirror ``dependencies`` in pyproject.toml.
    install_requires=[
        "cvxpy>=1.5.3",
        "matplotlib>=3.9.2",
        "numpy>=2.1.1",
        "scikit-image>=0.24.0",
        "scipy>=1.14.1",
    ],
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: Python :: 3.12",
    ],
)
38 |
--------------------------------------------------------------------------------