├── .gitattributes ├── .gitmodules ├── LICENSE ├── README.md ├── deepwavetorch ├── layers │ ├── backproj.py │ └── graph_conv.py ├── models │ └── deepwave.py ├── tests │ ├── inference │ │ └── mse_benchmark.py │ └── simple_load │ │ ├── grid.npy │ │ └── load_model.py └── utils │ ├── activations.py │ ├── laplacian.py │ └── matrix_operations.py ├── figures ├── DeepWave_fields_comparison.png └── task4_recording1.gif ├── notebooks ├── .ipynb_checkpoints │ └── inference_benchmark-checkpoint.ipynb ├── DeepWave_training.ipynb ├── deepwave_sketches.ipynb └── inference_benchmark.ipynb ├── requirements.txt ├── setup.cfg ├── setup.py └── tracks ├── eigenmike_grid.npy ├── locata_task1_recording2.wav ├── pretrained_freq0_locata_weights.npz └── task1_recording1.wav /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb -linguist-detectable 2 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "ImoT_tools"] 2 | path = ImoT_tools 3 | url = git@github.com:imagingofthings/ImoT_tools.git 4 | [submodule "DeepWave"] 5 | path = DeepWave 6 | url = git@github.com:imagingofthings/DeepWave.git 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. 
Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. 
Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution 4.0 International Public License 58 | 59 | By exercising the Licensed Rights (defined below), You accept and agree 60 | to be bound by the terms and conditions of this Creative Commons 61 | Attribution 4.0 International Public License ("Public License"). To the 62 | extent this Public License may be interpreted as a contract, You are 63 | granted the Licensed Rights in consideration of Your acceptance of 64 | these terms and conditions, and the Licensor grants You such rights in 65 | consideration of benefits the Licensor receives from making the 66 | Licensed Material available under these terms and conditions. 67 | 68 | 69 | Section 1 -- Definitions. 70 | 71 | a. Adapted Material means material subject to Copyright and Similar 72 | Rights that is derived from or based upon the Licensed Material 73 | and in which the Licensed Material is translated, altered, 74 | arranged, transformed, or otherwise modified in a manner requiring 75 | permission under the Copyright and Similar Rights held by the 76 | Licensor. For purposes of this Public License, where the Licensed 77 | Material is a musical work, performance, or sound recording, 78 | Adapted Material is always produced where the Licensed Material is 79 | synched in timed relation with a moving image. 
80 | 81 | b. Adapter's License means the license You apply to Your Copyright 82 | and Similar Rights in Your contributions to Adapted Material in 83 | accordance with the terms and conditions of this Public License. 84 | 85 | c. Copyright and Similar Rights means copyright and/or similar rights 86 | closely related to copyright including, without limitation, 87 | performance, broadcast, sound recording, and Sui Generis Database 88 | Rights, without regard to how the rights are labeled or 89 | categorized. For purposes of this Public License, the rights 90 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 91 | Rights. 92 | 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i. 
Share means to provide material to the public by any means or 116 | process that requires permission under the Licensed Rights, such 117 | as reproduction, public display, public performance, distribution, 118 | dissemination, communication, or importation, and to make material 119 | available to the public including in ways that members of the 120 | public may access the material from a place and at a time 121 | individually chosen by them. 122 | 123 | j. Sui Generis Database Rights means rights other than copyright 124 | resulting from Directive 96/9/EC of the European Parliament and of 125 | the Council of 11 March 1996 on the legal protection of databases, 126 | as amended and/or succeeded, as well as other essentially 127 | equivalent rights anywhere in the world. 128 | 129 | k. You means the individual or entity exercising the Licensed Rights 130 | under this Public License. Your has a corresponding meaning. 131 | 132 | 133 | Section 2 -- Scope. 134 | 135 | a. License grant. 136 | 137 | 1. Subject to the terms and conditions of this Public License, 138 | the Licensor hereby grants You a worldwide, royalty-free, 139 | non-sublicensable, non-exclusive, irrevocable license to 140 | exercise the Licensed Rights in the Licensed Material to: 141 | 142 | a. reproduce and Share the Licensed Material, in whole or 143 | in part; and 144 | 145 | b. produce, reproduce, and Share Adapted Material. 146 | 147 | 2. Exceptions and Limitations. For the avoidance of doubt, where 148 | Exceptions and Limitations apply to Your use, this Public 149 | License does not apply, and You do not need to comply with 150 | its terms and conditions. 151 | 152 | 3. Term. The term of this Public License is specified in Section 153 | 6(a). 154 | 155 | 4. Media and formats; technical modifications allowed. 
The 156 | Licensor authorizes You to exercise the Licensed Rights in 157 | all media and formats whether now known or hereafter created, 158 | and to make technical modifications necessary to do so. The 159 | Licensor waives and/or agrees not to assert any right or 160 | authority to forbid You from making technical modifications 161 | necessary to exercise the Licensed Rights, including 162 | technical modifications necessary to circumvent Effective 163 | Technological Measures. For purposes of this Public License, 164 | simply making modifications authorized by this Section 2(a) 165 | (4) never produces Adapted Material. 166 | 167 | 5. Downstream recipients. 168 | 169 | a. Offer from the Licensor -- Licensed Material. Every 170 | recipient of the Licensed Material automatically 171 | receives an offer from the Licensor to exercise the 172 | Licensed Rights under the terms and conditions of this 173 | Public License. 174 | 175 | b. No downstream restrictions. You may not offer or impose 176 | any additional or different terms or conditions on, or 177 | apply any Effective Technological Measures to, the 178 | Licensed Material if doing so restricts exercise of the 179 | Licensed Rights by any recipient of the Licensed 180 | Material. 181 | 182 | 6. No endorsement. Nothing in this Public License constitutes or 183 | may be construed as permission to assert or imply that You 184 | are, or that Your use of the Licensed Material is, connected 185 | with, or sponsored, endorsed, or granted official status by, 186 | the Licensor or others designated to receive attribution as 187 | provided in Section 3(a)(1)(A)(i). 188 | 189 | b. Other rights. 190 | 191 | 1. 
Moral rights, such as the right of integrity, are not 192 | licensed under this Public License, nor are publicity, 193 | privacy, and/or other similar personality rights; however, to 194 | the extent possible, the Licensor waives and/or agrees not to 195 | assert any such rights held by the Licensor to the limited 196 | extent necessary to allow You to exercise the Licensed 197 | Rights, but not otherwise. 198 | 199 | 2. Patent and trademark rights are not licensed under this 200 | Public License. 201 | 202 | 3. To the extent possible, the Licensor waives any right to 203 | collect royalties from You for the exercise of the Licensed 204 | Rights, whether directly or through a collecting society 205 | under any voluntary or waivable statutory or compulsory 206 | licensing scheme. In all other cases the Licensor expressly 207 | reserves any right to collect such royalties. 208 | 209 | 210 | Section 3 -- License Conditions. 211 | 212 | Your exercise of the Licensed Rights is expressly made subject to the 213 | following conditions. 214 | 215 | a. Attribution. 216 | 217 | 1. If You Share the Licensed Material (including in modified 218 | form), You must: 219 | 220 | a. retain the following if it is supplied by the Licensor 221 | with the Licensed Material: 222 | 223 | i. identification of the creator(s) of the Licensed 224 | Material and any others designated to receive 225 | attribution, in any reasonable manner requested by 226 | the Licensor (including by pseudonym if 227 | designated); 228 | 229 | ii. a copyright notice; 230 | 231 | iii. a notice that refers to this Public License; 232 | 233 | iv. a notice that refers to the disclaimer of 234 | warranties; 235 | 236 | v. a URI or hyperlink to the Licensed Material to the 237 | extent reasonably practicable; 238 | 239 | b. indicate if You modified the Licensed Material and 240 | retain an indication of any previous modifications; and 241 | 242 | c. 
indicate the Licensed Material is licensed under this 243 | Public License, and include the text of, or the URI or 244 | hyperlink to, this Public License. 245 | 246 | 2. You may satisfy the conditions in Section 3(a)(1) in any 247 | reasonable manner based on the medium, means, and context in 248 | which You Share the Licensed Material. For example, it may be 249 | reasonable to satisfy the conditions by providing a URI or 250 | hyperlink to a resource that includes the required 251 | information. 252 | 253 | 3. If requested by the Licensor, You must remove any of the 254 | information required by Section 3(a)(1)(A) to the extent 255 | reasonably practicable. 256 | 257 | 4. If You Share Adapted Material You produce, the Adapter's 258 | License You apply must not prevent recipients of the Adapted 259 | Material from complying with this Public License. 260 | 261 | 262 | Section 4 -- Sui Generis Database Rights. 263 | 264 | Where the Licensed Rights include Sui Generis Database Rights that 265 | apply to Your use of the Licensed Material: 266 | 267 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 268 | to extract, reuse, reproduce, and Share all or a substantial 269 | portion of the contents of the database; 270 | 271 | b. if You include all or a substantial portion of the database 272 | contents in a database in which You have Sui Generis Database 273 | Rights, then the database in which You have Sui Generis Database 274 | Rights (but not its individual contents) is Adapted Material; and 275 | 276 | c. You must comply with the conditions in Section 3(a) if You Share 277 | all or a substantial portion of the contents of the database. 278 | 279 | For the avoidance of doubt, this Section 4 supplements and does not 280 | replace Your obligations under this Public License where the Licensed 281 | Rights include other Copyright and Similar Rights. 282 | 283 | 284 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 285 | 286 | a. 
UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 287 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 288 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 289 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 290 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 291 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 292 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 293 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 294 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 295 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 296 | 297 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 298 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 299 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 300 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 301 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 302 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 303 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 304 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 305 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 306 | 307 | c. The disclaimer of warranties and limitation of liability provided 308 | above shall be interpreted in a manner that, to the extent 309 | possible, most closely approximates an absolute disclaimer and 310 | waiver of all liability. 311 | 312 | 313 | Section 6 -- Term and Termination. 314 | 315 | a. This Public License applies for the term of the Copyright and 316 | Similar Rights licensed here. However, if You fail to comply with 317 | this Public License, then Your rights under this Public License 318 | terminate automatically. 319 | 320 | b. Where Your right to use the Licensed Material has terminated under 321 | Section 6(a), it reinstates: 322 | 323 | 1. 
automatically as of the date the violation is cured, provided 324 | it is cured within 30 days of Your discovery of the 325 | violation; or 326 | 327 | 2. upon express reinstatement by the Licensor. 328 | 329 | For the avoidance of doubt, this Section 6(b) does not affect any 330 | right the Licensor may have to seek remedies for Your violations 331 | of this Public License. 332 | 333 | c. For the avoidance of doubt, the Licensor may also offer the 334 | Licensed Material under separate terms or conditions or stop 335 | distributing the Licensed Material at any time; however, doing so 336 | will not terminate this Public License. 337 | 338 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 339 | License. 340 | 341 | 342 | Section 7 -- Other Terms and Conditions. 343 | 344 | a. The Licensor shall not be bound by any additional or different 345 | terms or conditions communicated by You unless expressly agreed. 346 | 347 | b. Any arrangements, understandings, or agreements regarding the 348 | Licensed Material not stated herein are separate from and 349 | independent of the terms and conditions of this Public License. 350 | 351 | 352 | Section 8 -- Interpretation. 353 | 354 | a. For the avoidance of doubt, this Public License does not, and 355 | shall not be interpreted to, reduce, limit, restrict, or impose 356 | conditions on any use of the Licensed Material that could lawfully 357 | be made without permission under this Public License. 358 | 359 | b. To the extent possible, if any provision of this Public License is 360 | deemed unenforceable, it shall be automatically reformed to the 361 | minimum extent necessary to make it enforceable. If the provision 362 | cannot be reformed, it shall be severed from this Public License 363 | without affecting the enforceability of the remaining terms and 364 | conditions. 365 | 366 | c. 
No term or condition of this Public License will be waived and no 367 | failure to comply consented to unless expressly agreed to by the 368 | Licensor. 369 | 370 | d. Nothing in this Public License constitutes or may be interpreted 371 | as a limitation upon, or waiver of, any privileges and immunities 372 | that apply to the Licensor or You, including from the legal 373 | processes of any jurisdiction or authority. 374 | 375 | 376 | ======================================================================= 377 | 378 | Creative Commons is not a party to its public 379 | licenses. Notwithstanding, Creative Commons may elect to apply one of 380 | its public licenses to material it publishes and in those instances 381 | will be considered the “Licensor.” The text of the Creative Commons 382 | public licenses is dedicated to the public domain under the CC0 Public 383 | Domain Dedication. Except for the limited purpose of indicating that 384 | material is shared under a Creative Commons public license or as 385 | otherwise permitted by the Creative Commons policies published at 386 | creativecommons.org/policies, Creative Commons does not authorize the 387 | use of the trademark "Creative Commons" or any other trademark or logo 388 | of Creative Commons without its prior written consent including, 389 | without limitation, in connection with any unauthorized modifications 390 | to any of its public licenses or any other arrangements, 391 | understandings, or agreements concerning use of licensed material. For 392 | the avoidance of doubt, this paragraph does not form part of the 393 | public licenses. 394 | 395 | Creative Commons may be contacted at creativecommons.org. 
396 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## DeepWave: A Recurrent Neural-Network for Real-Time Acoustic Imaging (PyTorch) 2 | 3 | 4 | 5 | This repository contains a PyTorch implementation of the DeepWave model originally published at NeurIPS 2019. 6 | 7 | | [paper](https://proceedings.neurips.cc/paper/2019/file/e9bf14a419d77534105016f5ec122d62-Paper.pdf) | [original code](https://github.com/imagingofthings/DeepWave) | 8 | 9 | Get started with DeepWave (PyTorch) inference 10 | 11 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1BC72KmoAyeydS0X7Dti3fFxzLAxfbdWC?usp=sharing) 12 | 13 | ## Main building blocks: 14 | 15 | - SphericalChebConv: Spherical Chebyshev graph convolutions 16 | - BackProjLayer: project correlation matrix into image form (intensity map form) 17 | - ReTanh: Rectified hyperbolic tangent activation function 18 | - DeepWave: the actual model architecture 19 | 20 | DeepWave (PyTorch) architecture: 21 | ``` 22 | DeepWave: input=S (visibility matrix), trainable={mu, D, tau} 23 | y <- BackProjLayer(S) 24 | conv4 <- SphericalChebConv(I_init) + y 25 | conv4 <- ReTanh(conv4) 26 | conv3 <- SphericalChebConv(conv4) + y 27 | conv3 <- ReTanh(conv3) 28 | conv2 <- SphericalChebConv(conv3) + y 29 | conv2 <- ReTanh(conv2) 30 | conv1 <- SphericalChebConv(conv2) + y 31 | conv1 <- ReTanh(conv1) 32 | conv0 <- SphericalChebConv(conv1) + y 33 | I_out <- ReTanh(conv0) 34 | ``` 35 | 36 | Many of the operations used in this implementation were borrowed from the repository [deepsphere-pytorch](https://github.com/deepsphere/deepsphere-pytorch).
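The unrolled recurrence above is easy to prototype in a few lines of PyTorch. The following is only a minimal sketch, not the repository's implementation: a dense random matrix stands in for `SphericalChebConv`, and ReTanh is assumed here to take the form `alpha * tanh(relu(x) / alpha)` (the repo's `ReTanh` class may be defined slightly differently).

```python
import torch

def retanh(x, alpha=1.0):
    # Assumed form of the rectified tanh: alpha * tanh(relu(x) / alpha).
    # Illustration only; the repo's ReTanh may differ.
    return alpha * torch.tanh(torch.relu(x) / alpha)

def deepwave_unrolled(y, conv, n_layers=5):
    """Unrolled DeepWave recurrence: x <- ReTanh(conv(x) + y), n_layers times.

    y    -- back-projected image (Npx,), i.e. the output of BackProjLayer(S)
    conv -- callable standing in for the shared spherical graph convolution
    """
    x = torch.zeros_like(y)      # I_init: start from an empty intensity map
    for _ in range(n_layers):
        x = retanh(conv(x) + y)  # the same conv weights are reused each layer
    return x

# Toy usage: a dense random matrix stands in for SphericalChebConv.
torch.manual_seed(0)
Npx = 8
W = torch.randn(Npx, Npx) * 0.1
y = torch.rand(Npx)
I_out = deepwave_unrolled(y, conv=lambda x: W @ x)
print(I_out.shape)  # torch.Size([8])
```

The key point the sketch illustrates is weight sharing: the same convolution (and the same back-projected `y`) is reused at every one of the five unrolled layers, which is what makes DeepWave a recurrent network rather than a feed-forward stack.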
37 | 38 | ## Installation 39 | 40 | #### First time setup 41 | Create a Python virtual environment 42 | ``` 43 | python3 -m venv /path/to/your_new_venv 44 | ``` 45 | 46 | Activate `your_new_venv` 47 | ``` 48 | source /path/to/your_new_venv/bin/activate 49 | ``` 50 | 51 | - Clone `DeepWaveTorch` (this repo!). 52 | ``` 53 | git clone git@github.com:adrianSRoman/DeepWaveTorch.git 54 | ``` 55 | 56 | - Initialize and fetch submodules 57 | ``` 58 | git submodule update --init 59 | ``` 60 | 61 | - Set up `ImoT_tools`: plotting library used to nicely visualize DeepWave's output 62 | ``` 63 | cd ImoT_tools 64 | pip install -r requirements.txt 65 | python3 setup.py develop 66 | ``` 67 | 68 | - Set up `DeepWave`: the original DeepWave implementation, used for benchmarking against the PyTorch implementation. Data loaders are also re-used from the original implementation. 69 | ``` 70 | cd DeepWave 71 | python3 setup.py develop 72 | ``` 73 | 74 | - Set up `DeepWaveTorch`: the new DeepWave PyTorch implementation. 75 | ``` 76 | cd DeepWaveTorch 77 | pip install -r requirements.txt 78 | python3 setup.py develop 79 | ``` 80 | 81 | ### Executing example `notebooks` 82 | 83 | - Activate `your_new_venv` 84 | ``` 85 | source /path/to/your_new_venv/bin/activate 86 | ``` 87 | 88 | - Create a Jupyter kernel to contain your required packages (first time setup only) 89 | ``` 90 | pip install ipykernel 91 | ipython kernel install --user --name=your_new_kernel_name 92 | ``` 93 | 94 | - Start Jupyter 95 | 96 | ``` 97 | jupyter notebook 98 | ``` 99 | 100 | - Select the notebook you want to work with 101 | - Select `your_new_kernel_name` under: Kernel > Change kernel > `your_new_kernel_name` 102 | 103 | ## Extracting visibility matrices (DeepWave's input) 104 | 105 | Extracting visibility matrices `S` is the main preprocessing step you will need, because they are DeepWave's input!
To save you some time, we created a simple notebook that, for a given audio track of M microphones and N samples (we use an eigenmike32), generates a visibility matrix (MxM) over 100 msec audio frames. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1BC72KmoAyeydS0X7Dti3fFxzLAxfbdWC?usp=sharing) 106 | 107 | Note: the data extraction for the experiments in this repo (including extracting visibility matrices) was done using the `run.sh` files from the original DeepWave code [see example run.sh](https://github.com/imagingofthings/DeepWave/blob/master/datasets/FRIDA/run.sh). The `run.sh` files perform three tasks: 108 | 109 | (1) Extract data from a mic array of N channels to generate visibility matrices `S` (this is what DeepWave takes as input), and extract a ground-truth intensity field `I` which DeepWave learns how to generate. The script will then generate a `.npz` file. 110 | 111 | (2) Merge datasets into single `.npz` files for the 9 different extracted frequency bands across all tracks. 112 | 113 | (3) Train the original DeepWave model. After training is done, you will find `.npz` files containing the trained weights of the model for each frequency band. 114 | 115 | Overall, you can either use the Colab notebook to extract visibility matrices, or execute the `run.sh` scripts from the original DeepWave repo (you choose). 116 | 117 | ## Qualitative and quantitative comparison against the original DeepWave implementation 118 | 119 | #### Qualitatively, the implementation from this repository generates the same intensity fields as the original DeepWave implementation. 120 | 121 | ##### Inferred intensity field for a single frequency band: DeepWave original vs. DeepWave PyTorch 122 | 123 |
![DeepWave original vs. DeepWave PyTorch: inferred intensity fields](figures/DeepWave_fields_comparison.png)
124 | 125 |

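For the quantitative side, the repository's `tests/inference/mse_benchmark.py` compares the two implementations numerically. A per-frame mean-squared error between intensity-map stacks can be sketched as follows (random arrays stand in for real model outputs, and the helper name `frame_mse` is ours, not the repo's):

```python
import numpy as np

def frame_mse(maps_a, maps_b):
    """Per-frame mean-squared error between two stacks of intensity maps.

    Both arguments have shape (N_frames, Npx); returns shape (N_frames,).
    """
    maps_a = np.asarray(maps_a, dtype=np.float64)
    maps_b = np.asarray(maps_b, dtype=np.float64)
    return np.mean((maps_a - maps_b) ** 2, axis=1)

# Hypothetical data: 10 frames of 242-pixel maps, two nearly identical stacks.
rng = np.random.default_rng(0)
ref = rng.random((10, 242))                         # "original" outputs
test = ref + 1e-6 * rng.standard_normal((10, 242))  # "PyTorch" outputs
mse = frame_mse(ref, test)
print(mse.shape)  # (10,)
```

A near-zero MSE across all frames indicates the two implementations agree frame by frame, not merely on average.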
126 | 127 | 128 | ## YouTube presentation 129 | 130 | [![Mexico LatAm BISH Bash: DeepWave](https://img.youtube.com/vi/ZO5jfqY_NwA/0.jpg)](https://www.youtube.com/watch?v=ZO5jfqY_NwA) 131 | 132 | 133 | ## License 134 | Shield: [![CC BY 4.0][cc-by-shield]][cc-by] 135 | 136 | This work is licensed under a 137 | [Creative Commons Attribution 4.0 International License][cc-by]. 138 | 139 | [![CC BY 4.0][cc-by-image]][cc-by] 140 | 141 | [cc-by]: http://creativecommons.org/licenses/by/4.0/ 142 | [cc-by-image]: https://i.creativecommons.org/l/by/4.0/88x31.png 143 | [cc-by-shield]: https://img.shields.io/badge/License-CC%20BY%204.0-lightgrey.svg 144 | -------------------------------------------------------------------------------- /deepwavetorch/layers/backproj.py: -------------------------------------------------------------------------------- 1 | """Back projection layer 2 | Project frequency domain mic correlation matrix into an image form of N pixels 3 | """ 4 | 5 | import torch 6 | from torch import nn 7 | 8 | class BackProjLayer(torch.nn.Module): 9 | """Back-projection layer: maps a mic-array correlation matrix to an intensity map. 10 | """ 11 | 12 | def __init__(self, Nch, Npx, tau=None, D=None): 13 | """Initialization. 14 | Args: 15 | Nch (int): number of channels in mic array 16 | Npx (int): number of pixels in Robinson projection 17 | """ 18 | super().__init__() 19 | if tau is None or D is None: 20 | self.tau = torch.nn.Parameter(torch.empty((Npx), dtype=torch.float64)) 21 | self.D = torch.nn.Parameter(torch.empty((Nch, Npx), dtype=torch.complex128)) 22 | self.reset_parameters() 23 | else: 24 | self.tau = torch.nn.Parameter(tau) 25 | self.D = torch.nn.Parameter(D) 26 | 27 | 28 | def reset_parameters(self): 29 | std = 1e-5 30 | self.tau.data.normal_(0, std) 31 | self.D.data.normal_(0, std) 32 | 33 | def forward(self, S): 34 | """Forward Pass. 35 | Args: 36 | S (:obj:`torch.Tensor`): input to be forwarded.
(N_sample, Nch, Nch) 37 | Returns: 38 | :obj:`torch.Tensor`: output: (N_sample, Npx) 39 | """ 40 | N_sample, N_px = S.shape[0], self.tau.shape[0] 41 | y = torch.zeros((N_sample, N_px), dtype=torch.float64) 42 | for i in range(N_sample): # Loop to handle linalg.eigh: broadcasting can be slower 43 | Ds, Vs = torch.linalg.eigh(S[i]) # (Nch, Nch), (Nch, Nch) 44 | idx = Ds > 0 # Keep positive eigenvalues to avoid torch.sqrt() issues. 45 | Ds, Vs = Ds[idx], Vs[:, idx] 46 | y[i] = torch.linalg.norm(self.D.conj().T @ (Vs * torch.sqrt(Ds)), dim=1) ** 2 # (Npx, Nch) dot ((Nch, Nch) * (Nch, Nch)) 47 | y -= self.tau 48 | return y -------------------------------------------------------------------------------- /deepwavetorch/layers/graph_conv.py: -------------------------------------------------------------------------------- 1 | """Chebyshev convolution layer. 2 | PyTorch implementation inspired by: https://github.com/deepsphere/deepsphere-pytorch/blob/master/deepsphere/layers/chebyshev.py 3 | Based upon NumPy implementation from: https://github.com/imagingofthings/DeepWave 4 | """ 5 | 6 | import math 7 | 8 | import torch 9 | from torch import nn 10 | 11 | def cheb_conv(laplacian, inputs, weight): 12 | """Chebyshev convolution. 13 | Args: 14 | laplacian (:obj:`torch.sparse.Tensor`): The laplacian corresponding to the current sampling of the sphere. (Npx, Npx) 15 | inputs (:obj:`torch.Tensor`): The current input data being forwarded. (Nsamps, Npx) 16 | weight (:obj:`torch.Tensor`): The weights of the current layer. (K,) 17 | Returns: 18 | :obj:`torch.Tensor`: Inputs after applying Chebyshev convolution.
(Nsamps, Npx) 19 | """ 20 | K = weight.shape[0] 21 | was_1d = (inputs.ndim == 1) 22 | if was_1d: 23 | inputs = inputs.unsqueeze(0) 24 | N_sample, Npx = inputs.shape[0], inputs.shape[1] 25 | 26 | x0 = inputs.T # (Nsamps, Npx).T -> (Npx, Nsamps) 27 | inputs = x0.T.unsqueeze(0) # (1, Nsamps, Npx) 28 | x1 = torch.sparse.mm(laplacian, x0) # (Npx, Npx) x (Npx, Nsamps) -> (Npx, Nsamps) 29 | inputs = torch.cat((inputs, x1.T.unsqueeze(0)), 0) # (1, Nsamps, Npx) + (1, Nsamps, Npx) = (2, Nsamps, Npx) 30 | for _ in range(1, K - 1): 31 | x2 = 2 * torch.sparse.mm(laplacian, x1) - x0 # (Npx, Npx) x (Npx, Nsamps) - (Npx, Nsamps) = (Npx, Nsamps) 32 | inputs = torch.cat((inputs, x2.T.unsqueeze(0)), 0) # (ki, Nsamps, Npx) + (1, Nsamps, Npx) = (ki+1, Nsamps, Npx) 33 | x0, x1 = x1, x2 # (Npx, Nsamps), (Npx, Nsamps) 34 | inputs = inputs.permute(1, 2, 0).contiguous() # (K, Nsamps, Npx) -> (Nsamps, Npx, K) 35 | inputs = inputs.view([N_sample*Npx, K]) # (Nsamps*Npx, K) 36 | inputs = inputs.matmul(weight) # (Nsamps*Npx, K) x (K,) -> (Nsamps*Npx,) 37 | inputs = inputs.view([N_sample, Npx]) # (Nsamps, Npx) 38 | return inputs 39 | 40 | class ChebConv(torch.nn.Module): 41 | """Graph convolutional layer. 42 | """ 43 | 44 | def __init__(self, in_channels, out_channels, kernel_size, weight=None, bias=False, conv=cheb_conv): 45 | """Initialize the Chebyshev layer. 46 | Args: 47 | in_channels (int): Number of channels/features in the input graph. 48 | out_channels (int): Number of channels/features in the output graph. 49 | kernel_size (int): Number of trainable parameters per filter, which is also the size of the convolutional kernel. 50 | The order of the Chebyshev polynomials is kernel_size - 1. 51 | weight (torch.Tensor): pre-trained or initial state weight vector (K,) 52 | bias (bool): Whether to add a bias term. 53 | conv (callable): Function which will perform the graph convolution.
54 | """ 55 | super().__init__() 56 | 57 | self.in_channels = in_channels 58 | self.out_channels = out_channels 59 | self.kernel_size = kernel_size 60 | self._conv = conv 61 | if weight is None: 62 | shape = (kernel_size,) 63 | self.weight = torch.nn.Parameter(torch.DoubleTensor(*shape)) 64 | std = math.sqrt(2 / (self.in_channels * self.kernel_size)) 65 | self.weight.data.normal_(0, std) 66 | else: 67 | self.weight = torch.nn.Parameter(weight) 68 | 69 | if bias: 70 | self.bias = torch.nn.Parameter(torch.DoubleTensor(out_channels)) 71 | else: 72 | self.register_parameter("bias", None) 73 | 74 | self.bias_initialization() 75 | 76 | def bias_initialization(self): 77 | """Initialize bias. 78 | """ 79 | if self.bias is not None: 80 | self.bias.data.fill_(0.00001) 81 | 82 | def forward(self, laplacian, inputs): 83 | """Forward graph convolution. 84 | Args: 85 | laplacian (:obj:`torch.sparse.Tensor`): The laplacian corresponding to the current sampling of the sphere. 86 | inputs (:obj:`torch.Tensor`): The current input data being forwarded. 87 | Returns: 88 | :obj:`torch.Tensor`: The convolved inputs. 89 | """ 90 | outputs = self._conv(laplacian, inputs, self.weight) 91 | if self.bias is not None: 92 | outputs += self.bias 93 | return outputs 94 | 95 | class SphericalChebConv(torch.nn.Module): 96 | """Chebyshev Graph Convolution. 97 | """ 98 | 99 | def __init__(self, in_channels, out_channels, lap, kernel_size, weight=None): 100 | """Initialization. 101 | Args: 102 | in_channels (int): initial number of channels 103 | out_channels (int): output number of channels 104 | lap (:obj:`torch.sparse.DoubleTensor`): laplacian 105 | kernel_size (int): number of Chebyshev coefficients K (the polynomial order is K - 1).
106 | weight (:obj:`torch.DoubleTensor`): weight convolutional matrix (K,) 107 | """ 108 | super().__init__() 109 | self.register_buffer("laplacian", lap) 110 | self.chebconv = ChebConv(in_channels, out_channels, kernel_size, weight) 111 | 112 | def forward(self, x): 113 | """Forward pass. 114 | Args: 115 | x (:obj:`torch.tensor`): input [batch x vertices x channels/features] 116 | Returns: 117 | :obj:`torch.tensor`: output [batch x vertices x channels/features] 118 | """ 119 | x = self.chebconv(self.laplacian, x) 120 | return x -------------------------------------------------------------------------------- /deepwavetorch/models/deepwave.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from deepwavetorch.layers.backproj import BackProjLayer 5 | from deepwavetorch.layers.graph_conv import SphericalChebConv 6 | from deepwavetorch.utils.activations import ReTanh 7 | from deepwavetorch.utils.laplacian import laplacian_exp 8 | 9 | class DeepWave(torch.nn.Module): 10 | """DeepWave: real-time recurrent neural network for acoustic imaging. 11 | """ 12 | 13 | def __init__(self, R, kernel_size, Nch, Npx, batch_size=1, depth=1, pretr_params=None): 14 | """Initialization. 15 | Args: 16 | R: Cartesian coordinates of point set (3, N) 17 | kernel_size (int): polynomial degree. 
Nch (int): number of channels in mic array 19 | Npx (int): number of pixels in Robinson projection 20 | """ 21 | super().__init__() 22 | 23 | self.laps, self.rho = laplacian_exp(R, depth) 24 | if pretr_params: # use pre-trained parameters: tau, mu and D 25 | self.y_backproj = BackProjLayer(Nch, Npx, tau=pretr_params['tau'], D=pretr_params['D']) 26 | self.sconvl = SphericalChebConv(Nch, Npx, self.laps[0], kernel_size, weight=pretr_params['mu']) 27 | else: 28 | self.y_backproj = BackProjLayer(Nch, Npx) 29 | self.sconvl = SphericalChebConv(Nch, Npx, self.laps[0], kernel_size) 30 | self.retanh = ReTanh(alpha=1.000000) 31 | 32 | def reset_parameters(self): 33 | std = 1e-4 34 | self.I.data.random_(0, std) 35 | 36 | def forward(self, S, I_prev=None): 37 | """Forward Pass. 38 | Args: 39 | S (:obj:`torch.Tensor`): input visibility matrix (Nch, Nch) 40 | I_prev (:obj:`torch.Tensor`, optional): previous intensity estimate (Npx,) 41 | Returns: 42 | out (:obj:`torch.Tensor`): estimated intensity field (N_sample, Npx) 43 | """ 44 | y_proj = self.y_backproj(S) 45 | if I_prev is None: 46 | I_prev = torch.zeros(y_proj.shape[1], dtype=torch.float64) 47 | x_conv4 = self.sconvl(I_prev) 48 | x_conv4 = x_conv4.add(y_proj) 49 | x_conv4 = self.retanh(x_conv4) 50 | x_conv3 = self.sconvl(x_conv4) 51 | x_conv3 = x_conv3.add(y_proj) 52 | x_conv3 = self.retanh(x_conv3) 53 | x_conv2 = self.sconvl(x_conv3) 54 | x_conv2 = x_conv2.add(y_proj) 55 | x_conv2 = self.retanh(x_conv2) 56 | x_conv1 = self.sconvl(x_conv2) 57 | x_conv1 = x_conv1.add(y_proj) 58 | x_conv1 = self.retanh(x_conv1) 59 | x_conv0 = self.sconvl(x_conv1) 60 | x_conv0 = x_conv0.add(y_proj) 61 | x_conv0 = self.retanh(x_conv0) 62 | out = x_conv0 63 | 64 | return out 65 | -------------------------------------------------------------------------------- /deepwavetorch/tests/inference/mse_benchmark.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import deepwave.nn as nn 3 | import deepwave.tools.math.func as func 4 | import 
deepwave.tools.math.graph as graph 5 | import deepwave.nn.crnn as deepwave_numpy # NumPy DeepWave model library 6 | 7 | import torch 8 | from deepwavetorch.models.deepwave import DeepWave as deepwave_torch # PyTorch DeepWave model 9 | 10 | Df = nn.DataSet.from_file("/home/asroman/repos/DeepWave/datasets/FRIDA/dataset/D_1-5_freq0_cold.npz") # 0th frequency training data 11 | Pf = np.load("/home/asroman/repos/DeepWave/datasets/FRIDA/dataset/D_freq0_train.npz") # trained model parameters 12 | N_antenna = Df.XYZ.shape[1] # number of microphones in the array data 13 | N_px = Df.R.shape[1] # number of pixels in the intensity map 14 | K = int(Pf['K']) # Chebyshev filter polynomial order 15 | print("Num antenna:", N_antenna) 16 | print("Num pixels:", N_px) 17 | print("Filter order K:", K) 18 | parameter = deepwave_numpy.Parameter(N_antenna, N_px, K) 19 | sampler = Df.sampler() 20 | p_opt = Pf['p_opt'][np.argmin(Pf['v_loss'])] 21 | 22 | # We will use the 0th image to test DeepWave original vs. DeepWave PyTorch 23 | S, I, I_prev = sampler.decode(Df[0]) 24 | N_layer = Pf['N_layer'] 25 | print("Number of layers:", N_layer) 26 | p_mu, p_D, p_tau = parameter.decode(p_opt) # Load trained parameters 27 | 28 | # Load the DeepWave NumPy model 29 | Ln, _ = graph.laplacian_exp(Df.R, normalized=True) 30 | afunc = lambda _: func.retanh(Pf['tanh_lin_limit'], _) 31 | deepwavenumpy = deepwave_numpy.Evaluator(N_layer, parameter, p_opt, Ln, afunc) 32 | 33 | # Next we want to re-use the network parameters for the PyTorch DeepWave model 34 | # Save the original DeepWave network's trained parameters 35 | pretr_params = {"mu": torch.from_numpy(p_mu), 36 | "tau": torch.from_numpy(p_tau), 37 | "D": torch.from_numpy(p_D)} 38 | # Load the DeepWave PyTorch model 39 | deepwavetorch = deepwave_torch(R=Df.R, kernel_size=K, Nch=N_antenna, Npx=N_px, 40 | batch_size=1, depth=1, pretr_params=pretr_params) 41 | 42 | # Comparison of NumPy vs. 
PyTorch DeepWave model 43 | I_numpy = deepwavenumpy(S, I_prev) 44 | I_torch = deepwavetorch(torch.from_numpy(S).unsqueeze(0), torch.from_numpy(I_prev).double()).cpu().detach().numpy() 45 | print("Intensity fields MSE (NumPy vs. PyTorch):") 46 | mse = np.square(np.subtract(I_numpy, I_torch)).mean() 47 | print(mse) # NOTE: we expect an MSE of ~1e-15, close to double-precision machine epsilon 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /deepwavetorch/tests/simple_load/grid.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adrianSRoman/DeepWaveTorch/7897779076428e2a69f36d7420b067f827236f89/deepwavetorch/tests/simple_load/grid.npy -------------------------------------------------------------------------------- /deepwavetorch/tests/simple_load/load_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from deepwavetorch.models.deepwave import DeepWave 5 | 6 | N_px = 2234 7 | R = np.load("./grid.npy") 8 | K = 22 9 | N_antenna = 48 10 | 11 | deepwave_torch = DeepWave(R=R, kernel_size=K, 12 | Nch=N_antenna, Npx=N_px, 13 | batch_size=1, depth=1, 14 | pretr_params=None) 15 | 16 | print(deepwave_torch) 17 | -------------------------------------------------------------------------------- /deepwavetorch/utils/activations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class ReTanh(torch.nn.Module): 5 | ''' 6 | Rectified Hyperbolic Tangent 7 | ''' 8 | def __init__(self, alpha=1.000000): 9 | super().__init__() 10 | self.alpha = alpha 11 | 12 | def forward(self, x): 13 | beta = self.alpha / torch.tanh(torch.ones(1, dtype=torch.float64)) 14 | return torch.fmax(torch.zeros(x.shape, dtype=torch.float64), beta * torch.tanh(x)) 
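For clarity, the `ReTanh` module above evaluates `max(0, beta * tanh(x))` with `beta = alpha / tanh(1)`, so the output is zero for non-positive inputs, passes through `alpha` at `x = 1`, and saturates at `beta`. A minimal NumPy sketch of the same mapping (the name `retanh_np` is ours, for illustration only):

```python
import numpy as np

def retanh_np(x, alpha=1.0):
    # Rectified hyperbolic tangent, mirroring ReTanh.forward above:
    # beta = alpha / tanh(1), so the curve passes through (1, alpha).
    beta = alpha / np.tanh(1.0)
    return np.maximum(0.0, beta * np.tanh(x))

print(retanh_np(np.array([-2.0, 0.0, 1.0])))  # negative inputs clip to 0
```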
-------------------------------------------------------------------------------- /deepwavetorch/utils/laplacian.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | import scipy 5 | import scipy.sparse as sp 6 | import scipy.linalg as linalg 7 | import scipy.spatial as spatial 8 | import scipy.sparse.linalg as splinalg 9 | from scipy.sparse import coo_matrix 10 | from pygsp.graphs import Graph 11 | 12 | # Reference: https://github.com/deepsphere/deepsphere-pytorch/blob/master/deepsphere/utils/laplacian_funcs.py 13 | def scipy_csr_to_sparse_tensor(csr_mat): 14 | """Convert scipy csr to sparse pytorch tensor. 15 | Args: 16 | csr_mat (csr_matrix): The sparse scipy matrix. 17 | Returns: 18 | sparse_tensor :obj:`torch.sparse.DoubleTensor`: The sparse torch matrix. 19 | """ 20 | coo = coo_matrix(csr_mat) 21 | values = coo.data 22 | indices = np.vstack((coo.row, coo.col)) 23 | idx = torch.LongTensor(indices) 24 | vals = torch.DoubleTensor(values) 25 | shape = coo.shape 26 | sparse_tensor = torch.sparse.DoubleTensor(idx, vals, torch.Size(shape)) 27 | sparse_tensor = sparse_tensor.coalesce() 28 | return sparse_tensor 29 | 30 | 31 | def prepare_laplacian(laplacian): 32 | """Prepare a graph Laplacian to be fed to a graph convolutional layer. 33 | """ 34 | def estimate_lmax(laplacian, tol=5e-3): 35 | """Estimate the largest eigenvalue of an operator. 36 | """ 37 | lmax = sp.linalg.eigsh(laplacian, k=1, tol=tol, ncv=min(laplacian.shape[0], 10), return_eigenvectors=False) 38 | lmax = lmax[0] 39 | lmax *= 1 + 2 * tol # Be robust to errors. 40 | return lmax 41 | 42 | def scale_operator(L, lmax, scale=1): 43 | """Scale the eigenvalues from [0, lmax] to [-scale, scale]. 
44 | """ 45 | I = sp.identity(L.shape[0], format=L.format, dtype=L.dtype) 46 | L *= 2 * scale / lmax 47 | L -= I 48 | return L 49 | 50 | lmax = estimate_lmax(laplacian) 51 | laplacian = scale_operator(laplacian, lmax) 52 | laplacian = scipy_csr_to_sparse_tensor(laplacian) 53 | return laplacian 54 | 55 | 56 | def cvxhull_graph(R: np.ndarray, cheb_normalized: bool = True, compute_differential_operator: bool = True): 57 | 58 | r""" 59 | Build the convex hull graph of a point set in :math:`\mathbb{R}^3`. 60 | The graph edges have exponential-decay weighting. 61 | Definitions of the graph Laplacians: 62 | 63 | .. math:: 64 | 65 | L = I - D^{-1/2} W D^{-1/2},\qquad L_{n} = (2 / \mu_{\max}) L - I 66 | 67 | Parameters 68 | ---------- 69 | R : :py:class:`~numpy.ndarray` 70 | (3, N) Cartesian coordinates of point set with size N. All points must be **distinct**. 71 | cheb_normalized : bool 72 | Rescale Laplacian spectrum to [-1, 1]. 73 | compute_differential_operator : bool 74 | Computes the graph gradient. 75 | 76 | Returns 77 | ------- 78 | G : :py:class:`~pygsp.graphs.Graph` 79 | If ``cheb_normalized = True``, ``G.Ln`` is created (Chebyshev Laplacian :math:`L_{n}` above) 80 | If ``compute_differential_operator = True``, ``G.D`` is created and contains the gradient. 81 | rho : float 82 | Scale parameter :math:`\rho` corresponding to the average distance of a point 83 | on the graph to its nearest neighbors. 84 | 85 | Examples 86 | -------- 87 | 88 | .. 
plot:: 89 | 90 | import numpy as np 91 | from pycgsp.graph import cvxhull_graph 92 | from pygsp.plotting import plot_graph 93 | theta, phi = np.linspace(0,np.pi,6, endpoint=False)[1:], np.linspace(0,2*np.pi,9, endpoint=False) 94 | theta, phi = np.meshgrid(theta, phi) 95 | x,y,z = np.cos(phi)*np.sin(theta), np.sin(phi)*np.sin(theta), np.cos(theta) 96 | R = np.stack((x.flatten(), y.flatten(), z.flatten()), axis=-1) 97 | G, _ = cvxhull_graph(R) 98 | plot_graph(G) 99 | 100 | Warnings 101 | -------- 102 | In the newest version of PyGSP (> 0.5.1) the convention is changed: ``Graph.D`` is the divergence operator and 103 | ``Graph.D.transpose()`` the gradient (see routine ``Graph.compute_differential_operator``). The code should be adapted when this new version is released. 104 | 105 | """ 106 | 107 | # Form convex hull to extract nearest neighbors. Each row in 108 | # cvx_hull.simplices is a triangle of connected points. 109 | cvx_hull = spatial.ConvexHull(R.T) 110 | cols = np.roll(cvx_hull.simplices, shift=1, axis=-1).reshape(-1) 111 | rows = cvx_hull.simplices.reshape(-1) 112 | 113 | # Form sparse affinity matrix from extracted pairs 114 | W = sp.coo_matrix((cols * 0 + 1, (rows, cols)), 115 | shape=(cvx_hull.vertices.size, cvx_hull.vertices.size)) 116 | # Symmetrize the matrix to obtain an undirected graph. 
117 | extended_row = np.concatenate([W.row, W.col]) 118 | extended_col = np.concatenate([W.col, W.row]) 119 | W.row, W.col = extended_row, extended_col 120 | W.data = np.concatenate([W.data, W.data]) 121 | W = W.tocsr().tocoo() # Delete potential duplicate pairs 122 | 123 | # Weight matrix elements according to the exponential kernel 124 | distance = linalg.norm(cvx_hull.points[W.row, :] - 125 | cvx_hull.points[W.col, :], axis=-1) 126 | rho = np.mean(distance) 127 | W.data = np.exp(- (distance / rho) ** 2) 128 | W = W.tocsc() 129 | 130 | G = _graph_laplacian(W, R, compute_differential_operator=compute_differential_operator, 131 | cheb_normalized=cheb_normalized) 132 | return G, rho 133 | 134 | def _graph_laplacian(W, R, compute_differential_operator=False, cheb_normalized=False): 135 | ''' 136 | Form Graph Laplacian 137 | ''' 138 | G = Graph(W, gtype='undirected', lap_type='normalized', coords=R) 139 | G.compute_laplacian(lap_type='normalized') # Stored in G.L, sparse matrix, csc ordering 140 | if compute_differential_operator is True: 141 | G.compute_differential_operator() # stored in G.D, also accessible via G.grad() or G.div() (for the adjoint). 142 | else: 143 | pass 144 | 145 | if cheb_normalized: 146 | D_max = splinalg.eigsh(G.L, k=1, return_eigenvectors=False) 147 | Ln = (2 / D_max[0]) * G.L - sp.identity(W.shape[0], dtype=np.float64, format='csc') 148 | G.Ln = Ln 149 | else: 150 | pass 151 | return G 152 | 153 | def laplacian_exp(R, depth): 154 | """Get the convex-hull graph Laplacian list for a certain depth. 155 | Args: 156 | R : :py:class:`~numpy.ndarray` 157 | (3, N) Cartesian coordinates of point set with size N. All points must be **distinct**. depth (int): number of Laplacian levels to build. 
158 | Returns: 159 | laps (list of :obj:`torch.sparse.Tensor`): prepared graph Laplacians (reversed order) 160 | rho (float): graph scale parameter (average nearest-neighbor distance) 161 | """ 162 | laps = [] 163 | for i in range(depth): 164 | G, rho = cvxhull_graph(R) 165 | laplacian = prepare_laplacian(G.L) 166 | laps.append(laplacian) 167 | return laps[::-1], rho 168 | -------------------------------------------------------------------------------- /deepwavetorch/utils/matrix_operations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.linalg as linalg 3 | import scipy.sparse.linalg as splinalg 4 | 5 | def estimate_chevpol_order(XYZ, rho, wl, eps): 6 | r""" 7 | Compute order of polynomial filter to approximate asymptotic 8 | point-spread function on :math:`\mathbb{S}^{2}`. 9 | 10 | Parameters 11 | ---------- 12 | XYZ : :py:class:`~numpy.ndarray` 13 | (3, N_antenna) Cartesian instrument coordinates. 14 | rho : float 15 | Scale parameter :math:`\rho` corresponding to the average distance of a point 16 | on the graph to its nearest neighbor. 17 | Output of :py:func:`~deepwave.tools.math.graph.laplacian_exp`. 18 | wl : float 19 | Wavelength of observations [m]. 20 | eps : float 21 | Ratio in (0, 1). 22 | Ensures all PSF magnitudes lower than `max(PSF)*eps` past the main 23 | lobe are clipped at 0. 24 | 25 | Returns 26 | ------- 27 | K : int 28 | Order of polynomial filter. 29 | """ 30 | XYZ = XYZ / wl 31 | XYZ_centroid = np.mean(XYZ, axis=1, keepdims=True) 32 | XYZ_radius = np.mean(linalg.norm(XYZ - XYZ_centroid, axis=0)) 33 | 34 | theta = np.linspace(0, np.pi, 1000) 35 | f = 20 * np.log10(np.abs(np.sinc(theta / np.pi))) 36 | eps_dB = 10 * np.log10(eps) 37 | theta_max = np.max(theta[f >= eps_dB]) 38 | 39 | beam_width = theta_max / (2 * np.pi * XYZ_radius) 40 | K = np.sqrt(2 - 2 * np.cos(beam_width)) / rho 41 | K = int(np.ceil(K)) 42 | return K 43 | 44 | def steering_operator(XYZ, R, wl): 45 | ''' 46 | XYZ : :py:class:`~numpy.ndarray` 47 | (3, N_antenna) Cartesian array geometry. 
48 | R : :py:class:`~numpy.ndarray` 49 | (3, N_px) Cartesian grid points in :math:`\mathbb{S}^{2}`. 50 | wl : float 51 | Wavelength [m]. 52 | return: steering matrix 53 | ''' 54 | scale = 2 * np.pi / wl 55 | A = np.exp((-1j * scale * XYZ.T) @ R) 56 | return A 57 | 58 | def eighMax(A): 59 | r""" 60 | Evaluate :math:`\mu_{\max}(B)` with 61 | 62 | .. math:: B = (\overline{A} \circ A)^{H} (\overline{A} \circ A) 63 | 64 | Uses a matrix-free formulation of the Lanczos algorithm. 65 | Parameters 66 | ---------- 67 | A : :py:class:`~numpy.ndarray` 68 | (M, N) array. 69 | 70 | Returns 71 | ------- 72 | D_max : float 73 | Leading eigenvalue of `B`. 74 | """ 75 | if A.ndim != 2: 76 | raise ValueError('Parameter[A] has wrong dimensions.') 77 | 78 | def matvec(v): 79 | r""" 80 | Parameters 81 | ---------- 82 | v : :py:class:`~numpy.ndarray` 83 | (N,) or (N, 1) array 84 | 85 | Returns 86 | ------- 87 | w : :py:class:`~numpy.ndarray` 88 | (N,) array containing :math:`B v` 89 | """ 90 | v = v.reshape(-1) 91 | 92 | C = (A * v) @ A.conj().T 93 | D = C @ A 94 | w = np.sum(A.conj() * D, axis=0).real 95 | return w 96 | 97 | M, N = A.shape 98 | B = splinalg.LinearOperator(shape=(N, N), 99 | matvec=matvec, 100 | dtype=np.float64) 101 | D_max = splinalg.eigsh(B, k=1, which='LM', return_eigenvectors=False) 102 | return D_max[0] -------------------------------------------------------------------------------- /figures/DeepWave_fields_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adrianSRoman/DeepWaveTorch/7897779076428e2a69f36d7420b067f827236f89/figures/DeepWave_fields_comparison.png -------------------------------------------------------------------------------- /figures/task4_recording1.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/adrianSRoman/DeepWaveTorch/7897779076428e2a69f36d7420b067f827236f89/figures/task4_recording1.gif -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | astropy==5.2.1 2 | asttokens==2.2.1 3 | backcall==0.2.0 4 | basemap==1.3.6 5 | basemap-data==1.3.2 6 | comm==0.1.2 7 | contourpy==1.0.7 8 | cycler==0.11.0 9 | debugpy==1.6.5 10 | decorator==5.1.1 11 | entrypoints==0.4 12 | executing==1.2.0 13 | fonttools==4.38.0 14 | healpy==1.16.2 15 | ipykernel==6.20.2 16 | ipython==8.8.0 17 | jedi==0.18.2 18 | jupyter-client==7.4.9 19 | jupyter-core==5.1.3 20 | kiwisolver==1.4.4 21 | matplotlib==3.6.3 22 | matplotlib-inline==0.1.6 23 | nest-asyncio==1.5.6 24 | numpy==1.23.5 25 | packaging==23.0 26 | pandas==1.5.3 27 | parso==0.8.3 28 | pexpect==4.8.0 29 | pickleshare==0.7.5 30 | Pillow==9.4.0 31 | platformdirs==2.6.2 32 | prompt-toolkit==3.0.36 33 | psutil==5.9.4 34 | ptyprocess==0.7.0 35 | pure-eval==0.2.2 36 | pyerfa==2.0.0.1 37 | Pygments==2.14.0 38 | PyGSP==0.5.1 39 | pyparsing==3.0.9 40 | pyproj==1.9.6 41 | pyshp==2.3.1 42 | python-dateutil==2.8.2 43 | pytz==2022.7.1 44 | PyYAML==6.0 45 | pyzmq==25.0.0 46 | scipy==1.10.0 47 | six==1.16.0 48 | stack-data==0.6.2 49 | torch==1.11.0 50 | tornado==6.2 51 | tqdm==4.28.1 52 | traitlets==5.8.1 53 | typing-extensions==4.4.0 54 | wcwidth==0.2.6 55 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # ############################################################################ 2 | # setup.cfg 3 | # ========= 4 | # Author : Adrian S. ROMAN [asroman@ucdavis.edu] 5 | # ############################################################################ 6 | 7 | [metadata] 8 | name = deepwavetorch 9 | summary = DeepWave pytorch implementation. 
10 | long_description = file:README.md 11 | long_description_content_type = text/markdown; charset=UTF-8 12 | keywords = 13 | array signal processing 14 | first-order convex optimization 15 | inverse problems 16 | neural networks 17 | graph neural networks 18 | 19 | author = Adrian S. ROMAN [ucdavis, usc] 20 | author_email = asroman@ucdavis.edu 21 | url = https://github.com/adrianSRoman/DeepWaveTorch 22 | download_url = git@github.com:adrianSRoman/DeepWaveTorch.git 23 | 24 | classifiers = 25 | Intended Audience :: Science/Research 26 | License :: OSI Approved :: GNU General Public License v3 (GPLv3) 27 | Operating System :: POSIX :: Linux 28 | Programming Language :: Python :: 3 29 | Programming Language :: Python :: Implementation :: pytorch 30 | Topic :: Scientific/Engineering 31 | license = GPLv3 32 | 33 | [options] 34 | include_package_data = True 35 | python_requires = >=3.6 36 | zip_safe = False 37 | 38 | [files] 39 | packages = 40 | deepwavetorch 41 | data_files = 42 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # ############################################################################ 4 | # setup.py 5 | # ======== 6 | # Author : Adrian S ROMAN [asroman@ucdavis.edu] 7 | # ############################################################################ 8 | 9 | """ 10 | DeepWave setup script. 
11 | """ 12 | 13 | from setuptools import setup 14 | 15 | setup(setup_requires=['pbr'], pbr=True) 16 | -------------------------------------------------------------------------------- /tracks/eigenmike_grid.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adrianSRoman/DeepWaveTorch/7897779076428e2a69f36d7420b067f827236f89/tracks/eigenmike_grid.npy -------------------------------------------------------------------------------- /tracks/locata_task1_recording2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adrianSRoman/DeepWaveTorch/7897779076428e2a69f36d7420b067f827236f89/tracks/locata_task1_recording2.wav -------------------------------------------------------------------------------- /tracks/pretrained_freq0_locata_weights.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adrianSRoman/DeepWaveTorch/7897779076428e2a69f36d7420b067f827236f89/tracks/pretrained_freq0_locata_weights.npz -------------------------------------------------------------------------------- /tracks/task1_recording1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adrianSRoman/DeepWaveTorch/7897779076428e2a69f36d7420b067f827236f89/tracks/task1_recording1.wav --------------------------------------------------------------------------------
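For reference, the `cheb_conv` routine in `deepwavetorch/layers/graph_conv.py` above applies a Chebyshev polynomial graph filter using the three-term recurrence T_0(L)x = x, T_1(L)x = Lx, T_k(L)x = 2L T_{k-1}(L)x - T_{k-2}(L)x. A dense NumPy sketch of the same computation (the helper `cheb_filter` is hypothetical, not part of this package):

```python
import numpy as np

def cheb_filter(L, x, weight):
    # Evaluate sum_k weight[k] * T_k(L) @ x via the Chebyshev recurrence,
    # as cheb_conv does with sparse matrix-vector products.
    K = len(weight)
    x0 = x                       # T_0(L) x = x
    out = weight[0] * x0
    if K > 1:
        x1 = L @ x0              # T_1(L) x = L x
        out = out + weight[1] * x1
    for k in range(2, K):
        x2 = 2 * (L @ x1) - x0   # T_k(L) x = 2 L T_{k-1}(L) x - T_{k-2}(L) x
        out = out + weight[k] * x2
        x0, x1 = x1, x2
    return out
```

Since T_k(1) = 1 for every k, filtering with L = I simply scales x by sum(weight), which gives a quick sanity check on the recurrence.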