├── .gitignore ├── LICENSE ├── NOTICE ├── README.md ├── complexity_analysis.py ├── datasets.py ├── debug.py ├── doit.sh ├── download_data.sh ├── figures ├── data_example.pdf ├── data_example.png └── dmel.png ├── init_dataset.py ├── main.py ├── models.py ├── panns.py ├── predict_test.py ├── produce_figures.py ├── produce_tables.py ├── requirements.txt ├── run_experiments.sh ├── run_test_predictions.sh ├── search_spaces.py ├── time_frequency.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | notebooks 132 | results 133 | data 134 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2024 John Martinsson 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 |
-------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright 2024 John Martinsson 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 |
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DMEL: The differentiable log-Mel spectrogram as a trainable layer in neural networks 2 | 3 | ![DMEL](figures/dmel.png) 4 | 5 | The official implementation of DMEL, the method presented in the paper "DMEL: The differentiable log-Mel spectrogram as a trainable layer in neural networks". 6 | 7 | [Paper](https://johnmartinsson.org/publications/2024/differentiable-log-mel-spectrogram) 8 | 9 | Cite as: 10 | 11 | @INPROCEEDINGS{Martinsson2024, 12 | author={Martinsson, John and Sandsten, Maria}, 13 | booktitle={ICASSP 2024 - 2024 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 14 | title={DMEL: The Differentiable Log-Mel Spectrogram as a Trainable Layer in Neural Networks}, 15 | year={2024}, 16 | volume={}, 17 | number={}, 18 | pages={5005-5009}, 19 | keywords={Neural networks;Transforms;Acoustics;Computational efficiency;Task analysis;Speech processing;Spectrogram;Deep learning;STFT;learnable Mel spectrogram;audio classification;adaptive transforms}, 20 | doi={10.1109/ICASSP48485.2024.10446816} 21 | } 22 | 23 | # Reproduce the results in the paper 24 | 25 | ## Install the environment 26 | Requires Python < 3.11. 27 | 28 | Using conda and pip (tested): 29 | 30 | conda create -n reproduce python==3.10 31 | conda activate reproduce 32 | conda install pip 33 | 34 | pip3 install -r requirements.txt 35 | 36 | ## Run everything in one script 37 | Run the doit.sh script to download the audio data, run the experiments and produce the plots. 38 | 39 | sh doit.sh 40 | 41 | which runs the commands 42 | 43 | # download the ESC-50 and AudioMNIST datasets 44 | sh download_data.sh 45 | 46 | # run all the experiments (takes time ...) 47 | sh run_experiments.sh 48 | 49 | # run all test predictions 50 | sh run_test_predictions.sh 51 | 52 | # produce all the tables 53 | python produce_tables.py 54 | 55 | The default number of runs is set to 1. Run once to check how long it takes and that the results are reasonable, then increase the number of runs to 10 (set in run_experiments.sh), as in the paper. 56 | 57 | # More details 58 | Some details on how to use the code.
59 | 60 | ## Run experiments 61 | An experiment is defined as a Ray Tune search space. The three search spaces presented in the paper are found in search_spaces.py and are called esc50, audio_mnist and time_frequency. If you want to change the hyperparameter distributions, simply modify the search space (see https://docs.ray.io/en/latest/tune/api/search_space.html), or define another search space and update the appropriate line in main.py where the search space is loaded. 62 | 63 | The experiments will not reproduce the exact results in the paper since the random seed was never fixed, but the same trends in the averages and standard deviations should be observed when re-running the experiments. 64 | 65 | Run all the experiments: 66 | 67 | sh run_experiments.sh 68 | 69 | The code uses 0.25 GPUs and 4 CPUs per experiment; edit the tune.with_resources line in main.py if you want to use more or fewer GPUs or CPUs. The device defaults to cuda:0. 70 | 71 | All search spaces have been defined using 'grid_search' in Ray Tune. This means that the flag --num_samples defines the number of times each configuration in the grid is run; see the sketch below. If a search space is re-written using e.g. 'uniform' or other sampling functions from the search space API (https://docs.ray.io/en/latest/tune/api/search_space.html), then --num_samples instead defines the number of samples drawn from that distribution.
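As an illustration, a minimal search space could look as follows. This is a sketch only -- the real definitions live in search_spaces.py, and the grid values below are made up for illustration; the keys mirror the configuration keys used in main.py and debug.py.

    # illustrative sketch, not one of the actual search spaces from the paper
    from ray import tune

    def my_search_space(max_epochs):
        return {
            'dataset_name' : 'audio_mnist',
            'max_epochs'   : max_epochs,
            # grid_search enumerates every listed value exhaustively;
            # --num_samples then repeats each grid configuration that many times
            'trainable'    : tune.grid_search([True, False]),
            'init_lambd'   : tune.grid_search([80.0, 160.0, 320.0]),
        }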
-------------------------------------------------------------------------------- /complexity_analysis.py: -------------------------------------------------------------------------------- 1 | # Mel spectrogram 2 | 3 | import numpy as np 4 | 5 | import matplotlib.pyplot as plt 6 | import matplotlib as mpl 7 | 8 | plt.rcParams['text.usetex'] = True 9 | mpl.rc('font', family = 'serif') 10 | 11 | #init_mis = [0.01, 0.3] 12 | init_mis = [0.02, 0.3] 13 | mi_labels = [r'$l_{\lambda_{init}} = 20$ ms', r'$l_{\lambda_{init}} = 300$ ms'] 14 | #mi_labels = [r'$l_{\lambda_{init}} = 10$ ms', r'$l_{\lambda_{init}} = 300$ ms'] 15 | 16 | fig, ax = plt.subplots(1, 2, figsize=(5, 2.5)) #plt.figure(figsize=(4/2, 3/2)) 17 | for idx_mi in range(2): 18 | C1s = [0.0001, 0.9999] 19 | #c1_labels = [r'$C_2 \gg C_1$',r'$C_1 \gg C_2$'] 20 | c1_labels = ['Cost dominated by NN', 'Cost dominated by FFT'] 21 | 22 | for idx_C1 in range(2): 23 | C1 = C1s[idx_C1] 24 | C2 = 1-C1 25 | 26 | K_min, K_max = 1, 60 27 | Ks = np.arange(K_min, K_max) # number of models 28 | Fs = 8000 # sample rate 29 | n = Fs * 5 # signal length 30 | 31 | M = 128 # mel bins 32 | c = 0.010 * Fs # hop length 33 | lr = 0.001 # pseudo-learning rate 34 | 35 | init_mi = init_mis[idx_mi] # initial window length 36 | #print("init_mi: ", init_mi) 37 | opt_mi = 0.035 # optimal window length 38 | 39 | B = int(np.abs(init_mi - opt_mi)/lr) # number of forward passes 40 | #print(B) 41 | 42 | cost_quote = np.zeros(len(Ks)) 43 | 44 | for idx_K, K in enumerate(Ks): 45 | base_mi = np.linspace(c*2, 0.3 * Fs, K) # TODO: discuss this 46 | ours_mi = np.linspace(init_mi * Fs, opt_mi * Fs, B) 47 | 48 | cost_base_tf = B * C1 * np.sum(n * np.log(base_mi)) # TODO: w/o DA, or w. enough space B disappears 49 | #print("base_mi: ", base_mi) 50 | #print("ours_mi: ", ours_mi) 51 | #cost_base_tf = B * C1 * np.sum(n/base_mi * np.log(base_mi)) # TODO: w/o DA, or w. enough space B disappears 52 | cost_base_nn = B * C2 * np.sum(2 * M * n / base_mi) 53 | cost_base = cost_base_tf + cost_base_nn 54 | 55 | cost_ours_tf = C1 * n/c * np.sum(ours_mi * np.log(ours_mi)) 56 | #cost_ours_tf = C1 * n/c * np.sum(np.log(ours_mi)) 57 | cost_ours_nn = B * C2 * M * n / c 58 | cost_ours = cost_ours_tf + cost_ours_nn 59 | 60 | cost_quote[idx_K] = cost_ours / cost_base 61 | 62 | ax[idx_C1].plot(Ks, cost_quote, label=mi_labels[idx_mi]) 63 | #if idx_C1 == 0: 64 | # ax[idx_C1].plot(Ks, cost_quote, label=mi_labels[idx_mi]) 65 | #print(idx_mi) 66 | #else: 67 | # ax[idx_C1].plot(Ks, cost_quote) 68 | 69 | ax[idx_C1].set_title(c1_labels[idx_C1]) 70 | ax[idx_C1].set_xlabel('D') 71 | ax[idx_C1].set_ylim([0, 2.0]) 72 | 73 | plt.tight_layout() 74 | ax[1].hlines(1, color='purple', xmin=K_min, xmax=K_max, label='reference', linestyle='dashed') 75 | ax[0].hlines(1, color='purple', xmin=K_min, xmax=K_max, label='reference', linestyle='dashed') 76 | ax[0].set_ylabel(r'$C_{DMEL} / C_{baseline}$') 77 | ax[0].legend() 78 | #ax[1].set_yticks([]) 79 | ax[1].legend() 80 | plt.savefig('time_complexity.pdf', bbox_inches='tight')
-------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import librosa 4 | import tqdm 5 | import numpy as np 6 | import glob 7 | 8 | import time_frequency as tf 9 | 10 | def fmconst(n_points, fnorm=0.25): 11 | ts = torch.arange(n_points) 12 | random_phase = torch.rand(1) * (2*torch.pi) 13 | 14 | y = torch.sin(2.0 * torch.pi * fnorm * ts + random_phase) 15 | y = y / torch.max(y) 16 | return y 17 | 18 | def gauss_pulse(t_loc, f_loc, sigma, n_points): 19 | gauss_window = tf.gauss_whole(sigma, t_loc, n_points) 20 | fm_signal = fmconst(n_points, f_loc) 21 | gp = gauss_window * fm_signal 22 | 23 | return gp - torch.mean(gp) 24 | 25 | def torch_random_uniform(limits): 26 | r1, r2 = limits 27 | x = (r1 - r2) * torch.rand(1) + r2 28 | return x 29 | 30 | class GaussPulseDatasetTimeFrequency(torch.utils.data.Dataset): 31 | def __init__(self, sigma, n_points, noise_std, n_samples=10000, f_center_max_offset=0, t_center_max_offset=0, demo=False): 32 | 33 | self.xs = torch.empty((n_samples, n_points), dtype=torch.float64) 34 | self.ys = torch.empty((n_samples), dtype=torch.long) 35 | self.locs = torch.zeros((n_samples, 4), dtype=torch.float64) 36 | 37 | # maximum displacement limits for time-offset for pulses from center 38 | image_displacement = 5 39 | t_max = n_points / image_displacement 40 | f_max = 0.5 / image_displacement 41 | 42 | # lower displacement limits 43 | t_min = sigma 44 | f_min = 0.5 * (t_min / n_points) # = sigma / (2N) = sigma / K = 1/(2*pi*sigma) 45 | 46 | 47 | sigma_scale_max = (2*t_max)/(6*sigma) + 1 48 | # minimum duration scaling for pulses on center 49 | sigma_scale_min = 1 / sigma_scale_max 50 | 51 | 52 | # generate samples 53 | for idx in range(n_samples): 54 | if demo: 55 | f_center_offset = 0 56 | t_center_offset = 0 57 | else: 58 | f_center_offset = torch_random_uniform([-f_center_max_offset, f_center_max_offset]) 59 | t_center_offset = torch_random_uniform([-t_center_max_offset, t_center_max_offset]) 60 | 61 | t_center = t_center_offset + torch.tensor(n_points/2, dtype=torch.float) 62 | f_center = f_center_offset + 0.25 63 | 64 | if demo: 65 | f_offset = 0.5 * f_max 66 | t_offset = 0.5 * t_max 67 | else: 68 | f_offset = torch_random_uniform([f_min, f_max]) 69 | t_offset = torch_random_uniform([t_min,
t_max]) 70 | 71 | y = np.random.choice([0, 1, 2]) 72 | 73 | if y == 0: 74 | # spread randomly along frequency or time axis 75 | r = np.random.choice([True, False]) 76 | if r: 77 | sigma_scale = torch_random_uniform([1.0, sigma_scale_max]) 78 | else: 79 | sigma_scale = torch_random_uniform([sigma_scale_min, 1.0]) 80 | 81 | if demo: 82 | sigma_scale = 1.0 83 | 84 | x = gauss_pulse(t_center, f_center, sigma*sigma_scale, n_points) 85 | 86 | # used to sanity check 87 | self.locs[idx, 0] = t_center 88 | self.locs[idx, 1] = f_center 89 | self.locs[idx, 2] = int(r) 90 | self.locs[idx, 3] = sigma_scale 91 | 92 | elif y == 1: 93 | f_loc = f_center 94 | t_loc_1 = t_center - t_offset 95 | t_loc_2 = t_center + t_offset 96 | 97 | x1 = gauss_pulse(t_loc_1, f_loc, sigma, n_points) 98 | x2 = gauss_pulse(t_loc_2, f_loc, sigma, n_points) 99 | x = x1 + x2 100 | 101 | # used to sanity check 102 | self.locs[idx, 0] = t_loc_1 103 | self.locs[idx, 1] = f_loc 104 | self.locs[idx, 2] = t_loc_2 105 | self.locs[idx, 3] = f_loc 106 | else: 107 | t_loc = t_center 108 | f_loc_1 = f_center - f_offset 109 | f_loc_2 = f_center + f_offset 110 | 111 | x1 = gauss_pulse(t_loc, f_loc_1, sigma, n_points) 112 | x2 = gauss_pulse(t_loc, f_loc_2, sigma, n_points) 113 | x = x1 + x2 114 | 115 | # used to sanity check 116 | self.locs[idx, 0] = t_loc 117 | self.locs[idx, 1] = f_loc_1 118 | self.locs[idx, 2] = t_loc 119 | self.locs[idx, 3] = f_loc_2 120 | 121 | # variability 122 | noise = noise_std * torch.rand(n_points) 123 | 124 | if demo: 125 | amplitude_scale = 1.0 126 | else: 127 | amplitude_scale = torch_random_uniform([0.5, 1]) 128 | x = (x * amplitude_scale) + noise 129 | 130 | self.ys[idx] = torch.tensor(y, dtype=torch.long) 131 | self.xs[idx] = x - torch.mean(x) 132 | 133 | def __len__(self): 134 | return len(self.xs) 135 | 136 | def __getitem__(self, idx): 137 | return self.xs[idx], self.ys[idx] 138 | 139 | def parse_row(row): 140 | filename = row[0] 141 | fold = int(row[1]) 142 | target = int(row[2]) 143 | category = row[3] 144 | 145 | return filename, fold, target, category 146 | 147 | def parse_csv(csv_file): 148 | with open(csv_file, 'r') as f: 149 | lines = f.readlines() 150 | meta = [] 151 | for line in lines[1:]: 152 | row = line.rstrip().split(',') 153 | filename, fold, target, category = parse_row(row) 154 | meta.append((filename, fold, target, category)) 155 | return meta 156 | 157 | def load_meta_data(source_dir): 158 | csv_file = os.path.join(source_dir, 'meta', 'esc50.csv') 159 | meta_data = parse_csv(csv_file) 160 | return meta_data 161 | 162 | class AudioMNISTBigDataset(torch.utils.data.Dataset): 163 | def __init__(self, wav_paths): 164 | self.xs = [] 165 | self.ys = [] 166 | 167 | sample_rates = [] 168 | for wav_path in wav_paths: 169 | audio, sr = librosa.load(wav_path, sr=None) # already 8000 Hz 170 | sample_rates.append(sr) 171 | target = int(os.path.basename(wav_path).split('_')[0]) 172 | 173 | x = audio.copy() 174 | # zero pad signal on both sides 175 | x = np.pad(x, 1 + (8000-len(x)) // 2)[:8000] 176 | self.xs.append(x) 177 | self.ys.append(target) 178 | 179 | assert len(list(set(self.ys))) == 10 # assert 10 classes 180 | 181 | self.xs = np.array(self.xs) 182 | self.ys = np.array(self.ys) 183 | 184 | # assert all files have the same sample rates 185 | assert len(list(set(sample_rates))) == 1 186 | # assert proper sample rate 187 | assert sample_rates[0] == 8000 188 | 189 | self.sample_rate = 8000 190 | 191 | def __len__(self): 192 | return len(self.xs) 193 | def __getitem__(self, idx): 194 | return 
self.xs[idx], self.ys[idx] 195 | 196 | 197 | class AudioMNISTDataset(torch.utils.data.Dataset): 198 | # VERSION: https://doi.org/10.5281/zenodo.1342401 199 | def __init__(self, source_dir): 200 | # load data 201 | wav_paths = glob.glob(os.path.join(source_dir, 'recordings', '*.wav')) 202 | 203 | self.xs = [] 204 | self.ys = [] 205 | 206 | sample_rates = [] 207 | for wav_path in wav_paths: 208 | audio, sr = librosa.load(wav_path, sr=None) # already 8000 Hz 209 | sample_rates.append(sr) 210 | target = int(os.path.basename(wav_path).split('_')[0]) 211 | 212 | if len(audio) >= 1500 and len(audio) <= 5500: 213 | x = audio.copy() 214 | x.resize(5500) # pad end with zeros up to max length 215 | self.xs.append(x) 216 | self.ys.append(target) 217 | 218 | assert len(list(set(self.ys))) == 10 # assert 10 classes 219 | 220 | self.xs = np.array(self.xs) 221 | self.ys = np.array(self.ys) 222 | 223 | # assert all files have the same sample rates 224 | assert len(list(set(sample_rates))) == 1 225 | # assert proper sample rate 226 | assert sample_rates[0] == 8000 227 | 228 | self.sample_rate = 8000 229 | 230 | def __len__(self): 231 | return len(self.xs) 232 | def __getitem__(self, idx): 233 | return self.xs[idx], self.ys[idx] 234 | 235 | class ESC50Dataset(torch.utils.data.Dataset): 236 | def __init__(self, source_dir, resample_rate=8000): 237 | meta_data = load_meta_data(source_dir) 238 | 239 | self.xs = [] 240 | self.ys = [] 241 | self.sample_rate = None 242 | 243 | xs_path = os.path.join(source_dir, "{}_xs.npy".format(resample_rate)) 244 | ys_path = os.path.join(source_dir, "{}_ys.npy".format(resample_rate)) 245 | 246 | if os.path.exists(xs_path) and os.path.exists(ys_path): 247 | self.xs = np.load(xs_path) 248 | self.ys = np.load(ys_path) 249 | self.sample_rate = resample_rate 250 | else: 251 | sample_rates = [] 252 | for (filename, fold, target, category) in tqdm.tqdm(meta_data): 253 | # load audio 254 | audio_file = os.path.join(source_dir, 'audio', filename) 255 | audio, sr = librosa.load(audio_file, sr=resample_rate, res_type='kaiser_fast') 256 | sample_rates.append(sr) 257 | self.xs.append(audio) 258 | self.ys.append(target) 259 | 260 | self.xs = np.array(self.xs) 261 | self.ys = np.array(self.ys) 262 | 263 | np.save(xs_path, self.xs) 264 | np.save(ys_path, self.ys) 265 | 266 | # assert all files have the same sample rates 267 | assert len(list(set(sample_rates))) == 1 268 | # assert proper resampling 269 | assert sample_rates[0] == resample_rate 270 | 271 | self.sample_rate = resample_rate 272 | 273 | def __len__(self): 274 | return len(self.xs) 275 | 276 | def __getitem__(self, idx): 277 | return self.xs[idx], self.ys[idx] 278 | -------------------------------------------------------------------------------- /debug.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import utils 3 | 4 | resample_rate = 8000 5 | config = { 6 | # model 7 | 'model_name' : 'mel_linear_net', 8 | 'n_mels' : 128, 9 | 'hop_length' :int(resample_rate * 0.01), 10 | 'energy_normalize' : True, 11 | 'optimized' : True, 12 | 'normalize_window' : False, 13 | 'augment' : False, 14 | 'trainable' : True, 15 | 16 | # training 17 | 'pretrained' : False, 18 | 'checkpoint_path' : '/home/john/gits/differentiable-time-frequency-transforms/weights/Cnn6_mAP=0.343.pth', 19 | 'optimizer_name' : 'adam', 20 | 'lr_model' : 1e-3, 21 | 'lr_tf' : 1e-1, 22 | 'batch_size' : 16, 23 | 'max_epochs' : 1, 24 | 'patience' : 10000, 25 | 'device' : 'cuda:0', 26 | 27 | # dataset 28 | 'resample_rate' : 
resample_rate, 29 | 'init_lambd' : resample_rate*0.025/6, 30 | 'dataset_name' : 'esc50', 31 | 'n_points' : resample_rate * 5, 32 | } 33 | 34 | trainset, validset, _ = utils.get_dataset_by_config(config, data_dir='./data/esc50') 35 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=config['batch_size'], shuffle=True, num_workers=2) 36 | validloader = torch.utils.data.DataLoader(validset, batch_size=config['batch_size'], shuffle=False, num_workers=2) 37 | 38 | net = utils.get_model_by_config(config) 39 | net.to(config['device']) 40 | net.spectrogram_layer.requires_grad_(config['trainable']) 41 | 42 | if 'panns' in config['model_name']: 43 | one_hot = True 44 | loss_fn = torch.nn.functional.binary_cross_entropy 45 | else: 46 | one_hot = False 47 | loss_fn = torch.nn.CrossEntropyLoss() 48 | 49 | # grab a single batch for debugging 50 | for idx_batch, data in enumerate(trainloader): 51 | inputs, labels = data 52 | inputs, labels = inputs.to(config['device']), labels.to(config['device']) 53 | break 54 | 55 | if one_hot: 56 | # TODO: this won't work in general (assumes 50 classes) 57 | labels = torch.nn.functional.one_hot(labels, 50).float() 58 | 59 | logits, s = net(inputs) 60 | loss = loss_fn(logits, labels) 61 | print("batch_loss = ", loss.item()) 62 | 63 | # debug logits 64 | print("logits: ", logits[0].detach().cpu().numpy()) 65 | print("softmax: ", torch.nn.functional.softmax(logits, dim=1).detach().cpu().numpy()[0]) 66 | print("label: ", labels[0].detach().cpu().numpy()) 67 | print("spectrogram: ", s[0].detach().cpu().numpy()) 68 | 69 | # debug spectrogram 70 |
-------------------------------------------------------------------------------- /doit.sh: -------------------------------------------------------------------------------- 1 | # download the ESC-50 and AudioMNIST datasets 2 | echo "downloading all data ..." 3 | sh download_data.sh 4 | 5 | # run all the experiments (takes ~16h on a 2080Ti) 6 | echo "running all experiments ..." 7 | sh run_experiments.sh 8 | 9 | # run all test predictions 10 | echo "running all test predictions ..." 11 | sh run_test_predictions.sh 12 | 13 | # produce all the tables 14 | echo "producing all tables ..." 15 | python produce_tables.py --ray_results_dir=$(pwd)/ray_results/ 16 |
-------------------------------------------------------------------------------- /download_data.sh: -------------------------------------------------------------------------------- 1 | # clone datasets 2 | git clone https://github.com/soerenab/AudioMNIST.git 3 | git clone https://github.com/karolpiczak/ESC-50.git 4 | 5 | mkdir data 6 | 7 | # move datasets to data folder 8 | mv ESC-50/ data/esc50 9 | mv AudioMNIST/ data/audio_mnist 10 | 11 | # resample AudioMNIST to 8000 Hz 12 | echo "resample all Audio-MNIST files to 8000 Hz" 13 | for file in $(find ./data/audio_mnist -type f -name "*.wav"); do 14 | sox $file -r 8000 ${file%.wav}_8k.wav 15 | mv ${file%.wav}_8k.wav $file 16 | done 17 | 18 | # initialize datasets 19 | echo "initialize audio datasets ..." 20 | python3 init_dataset.py $(pwd)/data
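If sox is not available, the resampling step above can also be done in Python. A minimal sketch, assuming librosa and soundfile are installed; like the sox loop, it overwrites the files in place:

    # hypothetical Python alternative to the sox loop in download_data.sh
    import glob
    import librosa
    import soundfile as sf

    for path in glob.glob('./data/audio_mnist/**/*.wav', recursive=True):
        audio, _ = librosa.load(path, sr=8000)  # load and resample to 8 kHz
        sf.write(path, audio, 8000)             # write back in place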
-------------------------------------------------------------------------------- /figures/data_example.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johnmartinsson/differentiable-mel-spectrogram/0667fc1beed34afd7c417a3d8b29a055b6cb8045/figures/data_example.pdf -------------------------------------------------------------------------------- /figures/data_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johnmartinsson/differentiable-mel-spectrogram/0667fc1beed34afd7c417a3d8b29a055b6cb8045/figures/data_example.png -------------------------------------------------------------------------------- /figures/dmel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johnmartinsson/differentiable-mel-spectrogram/0667fc1beed34afd7c417a3d8b29a055b6cb8045/figures/dmel.png
-------------------------------------------------------------------------------- /init_dataset.py: -------------------------------------------------------------------------------- 1 | import utils 2 | import sys 3 | 4 | def main(): 5 | 6 | # ESC-50 7 | config = { 8 | 'dataset_name': 'esc50', 9 | 'resample_rate': 8000, 10 | } 11 | 12 | base_data_dir = sys.argv[1] 13 | data_dir = base_data_dir + '/esc50' 14 | 15 | # get the dataset once to initialize dataset files 16 | trainset, validset, testset = utils.get_dataset_by_config(config, data_dir) 17 | print("ESC-50 data. trainset: {}, validset: {}, testset: {}".format(len(trainset), len(validset), len(testset))) 18 | 19 | # Audio-MNIST 20 | config = { 21 | 'dataset_name': 'audio_mnist', 22 | } 23 | 24 | data_dir = base_data_dir + '/audio_mnist' 25 | 26 | # get the dataset once to initialize dataset files 27 | trainset, validset, testset = utils.get_dataset_by_config(config, data_dir) 28 | print("Audio-MNIST data. trainset: {}, validset: {}, testset: {}".format(len(trainset), len(validset), len(testset))) 29 | 30 | 31 | if __name__ == '__main__': 32 | main()
-------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from functools import partial 4 | 5 | from ray import tune 6 | from ray import air 7 | from ray.tune import CLIReporter 8 | 9 | import argparse 10 | 11 | import datasets 12 | import models 13 | import train 14 | import utils 15 | import search_spaces 16 | 17 | def run_experiment(config, data_dir): 18 | # load dataset 19 | trainset, validset, _ = utils.get_dataset_by_config(config, data_dir) 20 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=config['batch_size'], shuffle=True, num_workers=2) 21 | validloader = torch.utils.data.DataLoader(validset, batch_size=config['batch_size'], shuffle=False, num_workers=2) 22 | 23 | # load model 24 | net = utils.get_model_by_config(config) 25 | net.to(config['device']) 26 | 27 | net.spectrogram_layer.requires_grad_(config['trainable']) 28 | 29 | # pre-trained 30 | if config['model_name'] == 'panns_cnn6' and config['pretrained'] is not None: 31 | if config['pretrained']: 32 | checkpoint_path = config['checkpoint_path'] 33 | # load weights 34 | utils.load_checkpoint(net, checkpoint_path=checkpoint_path, device=config['device']) 35 | 36 | parameters = [] 37 | for idx, (name, param) in enumerate(net.named_parameters()): 38 | 39 | if name == "spectrogram_layer.lambd": 40 | parameters += [{ 41 | 'params' : [param], 42 | 'lr' : config['lr_tf'], 43 | }] 44 | else: 45 | parameters += [{ 46 | 'params' : [param], 47 | 'lr' : config['lr_model'], 48 | }] 49 | 50 | if config['optimizer_name'] == 'sgd': 51 | optimizer = torch.optim.SGD(parameters) 52 | elif config['optimizer_name'] == 'adam': 53 | optimizer = torch.optim.Adam(parameters) 54 | else: 55 | raise ValueError("optimizer not found: ", config['optimizer_name']) 56 | 57 | # PANNs models are trained with binary cross entropy 58 | if 'panns' in config['model_name']: 59 | one_hot = True 60 | loss_fn = torch.nn.functional.binary_cross_entropy 61 | else: 62 | one_hot = False 63 | loss_fn = torch.nn.CrossEntropyLoss() 64 | 65 | # TODO: this is not doing anything since gamma = 1.0 66 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 67 | step_size = 20, # Period of learning rate decay 68 | gamma = 1.0) # Multiplicative factor of learning rate decay 69 | 70 | net, history = train.train_model( 71 | net=net, 72 | optimizer=optimizer, 73 | loss_fn=loss_fn, 74 | trainloader=trainloader, 75 | validloader=validloader, 76 | scheduler=scheduler, 77 | patience=config['patience'], 78 | max_epochs=config['max_epochs'], 79 | verbose=0, 80 | device=config['device'], 81 | one_hot=one_hot, 82 | n_classes=10 if 'audio_mnist' in config['dataset_name'] else 50, 83 | ) 84 | 85 | def main(): 86 | 87 | parser = argparse.ArgumentParser(description='Hyperparameter search.') 88 | parser.add_argument('--num_samples', help='The number of hyperparameter samples.', required=True, type=int) 89 | parser.add_argument('--max_epochs', help='The maximum number of epochs.', required=True, type=int) 90 | parser.add_argument('--name', help='The name of the hyperparameter search experiment.', required=True, type=str) 91 | parser.add_argument('--ray_root_dir', help='The name of the directory to save the ray search results.', required=True, type=str) 92 | parser.add_argument('--data_dir', help='The absolute path to the dataset directory (e.g., data/audio_mnist or data/esc50).', required=True, type=str) 93 | args = parser.parse_args() 94 | 95 | # hyperparameter search space 96 | if "audio_mnist" in args.name: 97 | search_space = search_spaces.audio_mnist(args.max_epochs) 98 | elif "esc50" in args.name: 99 | search_space = search_spaces.esc50(args.max_epochs) 100 | elif "time_frequency" in args.name: 101 | search_space = search_spaces.time_frequency(args.max_epochs) 102 | else: 103 | raise ValueError("search space not found ...") 104 | 105 | 106 | # results terminal reporter 107 | reporter = CLIReporter( 108 | metric_columns=[ 109 | "loss", 110 | "valid_loss", 111 | "valid_acc", 112 | "best_valid_acc", 113 | "lambd_est", 114 | "training_iteration", 115 | ], 116 | parameter_columns = [ 117 | 'init_lambd', 118 | 'trainable', 119 | 'speaker_id', 120 | #'augment', 121 | #'lr_tf', 122 | #'pretrained', 123 | #'normalize_window', 124 | 'model_name', 125 | ], 126 | max_column_length = 10 127 | ) 128 | 129 | run_experiment_fn = partial(run_experiment, data_dir=args.data_dir) 130 | 131 | trainable_with_resources = tune.with_resources(run_experiment_fn, {"cpu" : 4.0, "gpu": 0.25}) 132 | 133 | tuner = tune.Tuner( 134 | trainable_with_resources, 135 | param_space = search_space, 136 | run_config = air.RunConfig( 137 | verbose=1, 138 | progress_reporter = reporter, 139 | name = args.name, 140 | local_dir = args.ray_root_dir, 141 | ), 142 | tune_config = tune.TuneConfig( 143 | num_samples = args.num_samples, 144 | ), 145 | ) 146 | 147 | result = tuner.fit() 148 | 149 | if __name__ == '__main__': 150 | main()
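For quick one-off runs it can be convenient to bypass the command line interface above. A minimal sketch that mirrors what main() does; the data path and the experiment name below are placeholders, not values from the paper:

    from functools import partial
    from ray import tune
    from ray import air
    import search_spaces
    from main import run_experiment

    # grid search over the audio_mnist space, one repetition per grid configuration
    search_space = search_spaces.audio_mnist(50)  # 50 = max_epochs
    trainable = tune.with_resources(
        partial(run_experiment, data_dir='/absolute/path/to/data/audio_mnist'),
        {"cpu": 4.0, "gpu": 0.25},
    )
    tuner = tune.Tuner(
        trainable,
        param_space=search_space,
        run_config=air.RunConfig(name='audio_mnist_sketch'),
        tune_config=tune.TuneConfig(num_samples=1),
    )
    tuner.fit()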
-------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchaudio 5 | 6 | #from torchlibrosa.stft import Spectrogram, LogmelFilterBank 7 | 8 | import panns 9 | import time_frequency as tf 10 | 11 | ############################################################################### 12 | # Differentiable Mel spectrogram 13 | ############################################################################### 14 | class MelSpectrogramLayer(nn.Module): 15 | def __init__(self, init_lambd, n_mels, n_points, sample_rate, f_min=0, f_max=None, hop_length=1, device='cpu', optimized=False, normalize_window=False): 16 | super(MelSpectrogramLayer, self).__init__() 17 | 18 | self.hop_length = hop_length 19 | self.lambd = nn.Parameter(init_lambd) 20 | self.device = device 21 | self.optimized = optimized 22 | self.normalize_window = normalize_window 23 | 24 | self.f_min = f_min 25 | self.f_max = f_max if f_max is not None else sample_rate // 2 26 | self.n_mels = n_mels 27 | self.sample_rate = sample_rate 28 | 29 | self.n_freq = n_mels 30 | self.n_time = n_points // hop_length + 1 31 | 32 | 33 | def forward(self, x): 34 | 35 | (batch_size, n_points) = x.shape 36 | mel_spectrograms = torch.empty((batch_size, 1, self.n_freq, self.n_time), dtype=torch.float32).to(self.device) 37 | for idx in range(batch_size): 38 | spectrogram = tf.differentiable_spectrogram(x[idx]-torch.mean(x[idx]), torch.abs(self.lambd), device=self.device, optimized=self.optimized, hop_length=self.hop_length, norm=self.normalize_window) 39 | 40 | (n_freq, _) = spectrogram.shape 41 | 42 | mel_fb = torchaudio.functional.melscale_fbanks( 43 | n_freqs = n_freq, 44 | f_min = self.f_min, 45 | f_max = self.f_max, 46 | n_mels = self.n_mels, 47 | sample_rate = self.sample_rate, 48 | ) 49
| 50 | mel_fb = mel_fb.to(self.device) 51 | mel_fb = mel_fb.to(spectrogram.dtype) 52 | 53 | mel_spectrogram = torch.matmul(spectrogram.transpose(-1, -2), mel_fb).transpose(-1, -2) 54 | mel_spectrograms[idx,:,:,:] = torch.unsqueeze(mel_spectrogram, axis=0) 55 | 56 | return mel_spectrograms 57 | 58 | class MelLinearNet(nn.Module): 59 | def __init__(self, n_classes, init_lambd, device, n_mels, sample_rate, n_points, hop_length=1, optimized=False, energy_normalize=False, normalize_window=False): 60 | super(MelLinearNet, self).__init__() 61 | self.spectrogram_layer = MelSpectrogramLayer(init_lambd, n_mels=n_mels, n_points=n_points, sample_rate=sample_rate, hop_length=hop_length, device=device, optimized=optimized, normalize_window=normalize_window) 62 | self.device = device 63 | self.size = (n_mels, n_points // hop_length + 1) 64 | self.energy_normalize = energy_normalize 65 | 66 | self.fc = nn.Linear(self.size[0] * self.size[1], n_classes) 67 | 68 | def forward(self, x): 69 | # compute spectrograms 70 | s = self.spectrogram_layer(x) 71 | # normalization of s? PCEN? 72 | if self.energy_normalize: 73 | s = torch.log(s + 1e-10) 74 | 75 | # dropout on spectrogram 76 | x = F.dropout(s.view(-1, self.size[0] * self.size[1]), p=0.2) 77 | x = self.fc(x) 78 | return x, s 79 | 80 | class MelMlpNet(nn.Module): 81 | def __init__(self, n_classes, init_lambd, device, n_mels, sample_rate, n_points, hop_length=1, optimized=False, energy_normalize=False, normalize_window=False): 82 | super(MelMlpNet, self).__init__() 83 | self.spectrogram_layer = MelSpectrogramLayer(init_lambd, n_mels=n_mels, n_points=n_points, sample_rate=sample_rate, hop_length=hop_length, device=device, optimized=optimized, normalize_window=normalize_window) 84 | self.device = device 85 | self.size = (n_mels, n_points // hop_length + 1) 86 | 87 | self.fc1 = nn.Linear(self.size[0] * self.size[1], 32) 88 | self.fc2 = nn.Linear(32, n_classes) 89 | self.energy_normalize = energy_normalize 90 | 91 | def forward(self, x): 92 | # compute spectrograms 93 | s = self.spectrogram_layer(x) 94 | 95 | # normalization of s? PCEN? 96 | if self.energy_normalize: 97 | s = torch.log(s + 1e-10) 98 | 99 | x = self.fc1(s.view(-1, self.size[0] * self.size[1])) 100 | x = F.relu(x) 101 | x = F.dropout(x, p=0.2) 102 | x = self.fc2(x) 103 | return x, s 104 | 105 | class MelConvNet(nn.Module): 106 | def __init__(self, n_classes, init_lambd, device, n_mels, sample_rate, n_points, hop_length=1, optimized=False, energy_normalize=False, normalize_window=False): 107 | super(MelConvNet, self).__init__() 108 | self.spectrogram_layer = MelSpectrogramLayer(init_lambd, n_mels=n_mels, n_points=n_points, sample_rate=sample_rate, hop_length=hop_length, device=device, optimized=optimized, normalize_window=normalize_window) 109 | 110 | self.device = device 111 | self.size = (n_mels, n_points // hop_length + 1) 112 | self.energy_normalize = energy_normalize 113 | 114 | self.hidden_state = 32 115 | 116 | self.conv1 = nn.Conv2d(1, self.hidden_state, 5, padding='same') 117 | self.fc1 = nn.Linear(self.hidden_state * (self.size[0]) * (self.size[1]), self.hidden_state) 118 | self.fc2 = nn.Linear(self.hidden_state, n_classes) 119 | 120 | def forward(self, x): 121 | # compute spectrograms 122 | s = self.spectrogram_layer(x) 123 | 124 | # normalization of s? PCEN? 
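# note: the log below applies standard log compression of the Mel spectrogram; the 1e-10 floor avoids log(0).
# PCEN, mentioned above, would be an alternative normalization and is not implemented here.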
125 | if self.energy_normalize: 126 | s = torch.log(s + 1e-10) 127 | 128 | x = self.conv1(s) 129 | x = F.relu(x) 130 | 131 | x = x.view(-1, self.hidden_state * (self.size[0]) * (self.size[1])) 132 | x = self.fc1(x) 133 | x = F.relu(x) 134 | x = self.fc2(x) 135 | 136 | return x, s 137 | 138 | class MelPANNsNet(nn.Module): 139 | def __init__(self, n_classes, init_lambd, device, n_mels, sample_rate, n_points, hop_length=1, optimized=False, energy_normalize=False, normalize_window=False, augment=False): 140 | super(MelPANNsNet, self).__init__() 141 | 142 | self.energy_normalize = energy_normalize 143 | 144 | # Mel spectrogram extractor 145 | self.spectrogram_layer = MelSpectrogramLayer(init_lambd, n_mels=n_mels, n_points=n_points, sample_rate=sample_rate, hop_length=hop_length, device=device, optimized=optimized, normalize_window=normalize_window) 146 | #self.spectrogram_layer = MelSpectrogramLayerDebug(n_mels=n_mels) 147 | 148 | self.spectrogram_model = panns.Cnn6(n_classes, n_mels, augment=augment) 149 | 150 | def forward(self, x): 151 | """ Input (batch_size, n_points) """ 152 | 153 | # shape (batch_size, 1, mel_bins, time_steps) 154 | s = self.spectrogram_layer(x) 155 | 156 | if self.energy_normalize: 157 | s = torch.log(s + 1e-10) 158 | 159 | # shape (batch_size, 1, time_steps, mel_bins) 160 | #print("###################################################################") 161 | #print("SHAPE: ", s.shape) 162 | x = s.transpose(2, 3) 163 | 164 | x = self.spectrogram_model(x) 165 | 166 | return x, s 167 | 168 | ############################################################################### 169 | # Differentiable spectrogram 170 | ############################################################################### 171 | class SpectrogramLayer(nn.Module): 172 | def __init__(self, init_lambd, device='cpu', optimized=False, size=(512, 1024), hop_length=1, normalize_window=False): 173 | super(SpectrogramLayer, self).__init__() 174 | 175 | self.hop_length = hop_length 176 | self.lambd = nn.Parameter(init_lambd) 177 | self.device = device 178 | self.size = size #(512, 1024) 179 | self.optimized = optimized 180 | self.normalize_window = normalize_window 181 | 182 | def forward(self, x): 183 | 184 | (batch_size, n_points) = x.shape 185 | if self.optimized: 186 | spectrograms = torch.empty((batch_size, 1, self.size[0], self.size[1]), dtype=torch.float32).to(self.device) 187 | else: 188 | # redundancy in spectrogram 189 | spectrograms = torch.empty((batch_size, 1, n_points + 1, n_points // self.hop_length + 1), dtype=torch.float32).to(self.device) 190 | 191 | for idx in range(batch_size): 192 | spectrogram = tf.differentiable_spectrogram(x[idx]-torch.mean(x[idx]), torch.abs(self.lambd), optimized=self.optimized, device=self.device, hop_length=self.hop_length, norm=self.normalize_window) 193 | 194 | #if self.optimized: 195 | # spectrogram = F.interpolate(torch.unsqueeze(torch.unsqueeze(spectrogram, axis=0), axis=0), size=(self.size[0], self.size[1])) 196 | # spectrogram = spectrogram[0,0] 197 | 198 | spectrograms[idx,:,:,:] = torch.unsqueeze(spectrogram, axis=0) 199 | 200 | return spectrograms 201 | 202 | 203 | class MlpNet(nn.Module): 204 | def __init__(self, n_classes, init_lambd, device, optimized=False, size=(512, 1024), hop_length=1, normalize_window=False): 205 | super(MlpNet, self).__init__() 206 | self.spectrogram_layer = SpectrogramLayer(init_lambd, device=device, optimized=optimized, size=size, hop_length=hop_length, normalize_window=normalize_window) 207 | self.device = device 208 | self.size = size 
209 | 210 | self.fc1 = nn.Linear(size[0] * size[1], 128) 211 | self.fc2 = nn.Linear(128, n_classes) 212 | 213 | def forward(self, x): 214 | # compute spectrograms 215 | s = self.spectrogram_layer(x) 216 | x = self.fc1(s.view(-1, self.size[0] * self.size[1])) 217 | x = F.relu(x) 218 | #x = F.dropout(x, p=0.2) 219 | x = self.fc2(x) 220 | return x, s 221 | 222 | class LinearNet(nn.Module): 223 | def __init__(self, n_classes, init_lambd, device, optimized=False, size=(512, 1024), hop_length=1, normalize_window=False): 224 | super(LinearNet, self).__init__() 225 | self.spectrogram_layer = SpectrogramLayer(init_lambd, device=device, optimized=optimized, size=size, hop_length=hop_length, normalize_window=normalize_window) 226 | 227 | self.device = device 228 | self.size = size 229 | 230 | self.fc = nn.Linear(size[0] * size[1], n_classes) 231 | 232 | def forward(self, x): 233 | # compute spectrograms 234 | s = self.spectrogram_layer(x) 235 | #x = F.dropout(s.view(-1, self.size[0] * self.size[1]), p=0.2) 236 | x = s.view(-1, self.size[0] * self.size[1]) 237 | x = self.fc(x) 238 | return x, s 239 | 240 | class BatchNormLinearNet(nn.Module): 241 | def __init__(self, n_classes, init_lambd, device, optimized=False, size=(512, 1024), hop_length=1, normalize_window=False): 242 | super(BatchNormLinearNet, self).__init__() 243 | self.spectrogram_layer = SpectrogramLayer(init_lambd, device=device, optimized=optimized, size=size, hop_length=hop_length, normalize_window=normalize_window) 244 | 245 | self.device = device 246 | self.size = size 247 | 248 | self.fc = nn.Linear(size[0] * size[1], n_classes) 249 | self.bn = torch.nn.BatchNorm2d(size[0]) 250 | 251 | def forward(self, x): 252 | # compute spectrograms 253 | s = self.spectrogram_layer(x) 254 | 255 | s = s.transpose(1, 2) 256 | s = self.bn(s) 257 | s = s.transpose(1, 2) 258 | 259 | x = s.view(-1, self.size[0] * self.size[1]) 260 | x = self.fc(x) 261 | return x, s 262 | 263 | 264 | class ConvNet(nn.Module): 265 | def __init__(self, n_classes, init_lambd, device, optimized=False, size=(512, 1024), hop_length=1, normalize_window=False): 266 | super(ConvNet, self).__init__() 267 | self.spectrogram_layer = SpectrogramLayer(init_lambd, device=device, optimized=optimized, size=size, hop_length=hop_length, normalize_window=normalize_window) 268 | 269 | self.device = device 270 | self.size = size 271 | 272 | self.hidden_state = 32 273 | 274 | self.conv1 = nn.Conv2d(1, self.hidden_state, 5, padding='same') 275 | self.fc1 = nn.Linear(self.hidden_state * (size[0]) * (size[1]), self.hidden_state) 276 | self.fc2 = nn.Linear(self.hidden_state, n_classes) 277 | 278 | #self.dropout = nn.Dropout(p=0.2) 279 | 280 | def forward(self, x): 281 | # compute spectrograms 282 | s = self.spectrogram_layer(x) 283 | 284 | x = self.conv1(s) 285 | x = F.relu(x) 286 | 287 | x = x.view(-1, self.hidden_state * (self.size[0]) * (self.size[1])) 288 | x = self.fc1(x) 289 | x = F.relu(x) 290 | #x = self.dropout(x) 291 | x = self.fc2(x) 292 | 293 | return x, s 294 | 295 | class MelSpectrogramLayerDebug(nn.Module): 296 | def __init__(self, sample_rate=8000, n_mels=128, window_size=1024, hop_length=320): 297 | super(MelSpectrogramLayerDebug, self).__init__() 298 | 299 | # Spectrogram extractor 300 | self.mel_spectrogram_extractor = torchaudio.transforms.MelSpectrogram( 301 | sample_rate = sample_rate, 302 | n_fft = window_size, 303 | win_length = window_size, 304 | hop_length = hop_length, 305 | f_min = 50, 306 | f_max = 4000, 307 | n_mels = n_mels, 308 | pad_mode = 'reflect' 309 | ) 310 | 
311 | def forward(self, x): 312 | x = self.mel_spectrogram_extractor(x) 313 | x = torch.unsqueeze(x, 1) 314 | return x 315 | 316 | 317 | -------------------------------------------------------------------------------- /panns.py: -------------------------------------------------------------------------------- 1 | import torchaudio 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | #from torchlibrosa.stft import Spectrogram, LogmelFilterBank 6 | 7 | def init_layer(layer): 8 | """Initialize a Linear or Convolutional layer. """ 9 | nn.init.xavier_uniform_(layer.weight) 10 | 11 | if hasattr(layer, 'bias'): 12 | if layer.bias is not None: 13 | layer.bias.data.fill_(0.) 14 | 15 | 16 | def init_bn(bn): 17 | """Initialize a Batchnorm layer. """ 18 | bn.bias.data.fill_(0.) 19 | bn.weight.data.fill_(1.) 20 | 21 | 22 | class ConvBlock(nn.Module): 23 | def __init__(self, in_channels, out_channels): 24 | 25 | super(ConvBlock, self).__init__() 26 | 27 | self.conv1 = nn.Conv2d(in_channels=in_channels, 28 | out_channels=out_channels, 29 | kernel_size=(3, 3), stride=(1, 1), 30 | padding=(1, 1), bias=False) 31 | 32 | self.conv2 = nn.Conv2d(in_channels=out_channels, 33 | out_channels=out_channels, 34 | kernel_size=(3, 3), stride=(1, 1), 35 | padding=(1, 1), bias=False) 36 | 37 | self.bn1 = nn.BatchNorm2d(out_channels) 38 | self.bn2 = nn.BatchNorm2d(out_channels) 39 | 40 | self.init_weight() 41 | 42 | def init_weight(self): 43 | init_layer(self.conv1) 44 | init_layer(self.conv2) 45 | init_bn(self.bn1) 46 | init_bn(self.bn2) 47 | 48 | 49 | def forward(self, input, pool_size=(2, 2), pool_type='avg'): 50 | 51 | x = input 52 | x = F.relu_(self.bn1(self.conv1(x))) 53 | x = F.relu_(self.bn2(self.conv2(x))) 54 | if pool_type == 'max': 55 | x = F.max_pool2d(x, kernel_size=pool_size) 56 | elif pool_type == 'avg': 57 | x = F.avg_pool2d(x, kernel_size=pool_size) 58 | elif pool_type == 'avg+max': 59 | x1 = F.avg_pool2d(x, kernel_size=pool_size) 60 | x2 = F.max_pool2d(x, kernel_size=pool_size) 61 | x = x1 + x2 62 | else: 63 | raise Exception('Incorrect argument!') 64 | 65 | return x 66 | 67 | 68 | class ConvBlock5x5(nn.Module): 69 | def __init__(self, in_channels, out_channels): 70 | 71 | super(ConvBlock5x5, self).__init__() 72 | 73 | self.conv1 = nn.Conv2d(in_channels=in_channels, 74 | out_channels=out_channels, 75 | kernel_size=(5, 5), stride=(1, 1), 76 | padding=(2, 2), bias=False) 77 | 78 | self.bn1 = nn.BatchNorm2d(out_channels) 79 | 80 | self.init_weight() 81 | 82 | def init_weight(self): 83 | init_layer(self.conv1) 84 | init_bn(self.bn1) 85 | 86 | 87 | def forward(self, input, pool_size=(2, 2), pool_type='avg'): 88 | 89 | x = input 90 | x = F.relu_(self.bn1(self.conv1(x))) 91 | if pool_type == 'max': 92 | x = F.max_pool2d(x, kernel_size=pool_size) 93 | elif pool_type == 'avg': 94 | x = F.avg_pool2d(x, kernel_size=pool_size) 95 | elif pool_type == 'avg+max': 96 | x1 = F.avg_pool2d(x, kernel_size=pool_size) 97 | x2 = F.max_pool2d(x, kernel_size=pool_size) 98 | x = x1 + x2 99 | else: 100 | raise Exception('Incorrect argument!') 101 | 102 | return x 103 | 104 | 105 | class AttBlock(nn.Module): 106 | def __init__(self, n_in, n_out, activation='linear', temperature=1.): 107 | super(AttBlock, self).__init__() 108 | 109 | self.activation = activation 110 | self.temperature = temperature 111 | self.att = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True) 112 | self.cla = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, 
padding=0, bias=True) 113 | 114 | self.bn_att = nn.BatchNorm1d(n_out) 115 | self.init_weights() 116 | 117 | def init_weights(self): 118 | init_layer(self.att) 119 | init_layer(self.cla) 120 | init_bn(self.bn_att) 121 | 122 | def forward(self, x): 123 | # x: (n_samples, n_in, n_time) 124 | norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1) 125 | cla = self.nonlinear_transform(self.cla(x)) 126 | x = torch.sum(norm_att * cla, dim=2) 127 | return x, norm_att, cla 128 | 129 | def nonlinear_transform(self, x): 130 | if self.activation == 'linear': 131 | return x 132 | elif self.activation == 'sigmoid': 133 | return torch.sigmoid(x) 134 | 135 | class Cnn6(nn.Module): 136 | def __init__(self, classes_num, n_mels, augment=False): 137 | super(Cnn6, self).__init__() 138 | 139 | # augmentation 140 | self.augment = augment 141 | mask_time = torchaudio.transforms.TimeMasking(time_mask_param=64, iid_masks=True) 142 | mask_freq = torchaudio.transforms.FrequencyMasking(freq_mask_param=8, iid_masks=True) 143 | self.mask_time = mask_time 144 | self.mask_freq = mask_freq 145 | 146 | self.bn1 = nn.BatchNorm2d(n_mels) 147 | 148 | self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64) 149 | self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128) 150 | self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256) 151 | self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512) 152 | 153 | self.fc1 = nn.Linear(512, 512, bias=True) 154 | self.fc_esc50 = nn.Linear(512, classes_num, bias=True) 155 | 156 | self.init_weight() 157 | 158 | def init_weight(self): 159 | init_bn(self.bn1) 160 | init_layer(self.fc1) 161 | init_layer(self.fc_esc50) 162 | 163 | def forward(self, x): 164 | """ 165 | Input: (batch_size, 1, time_steps, mel_bins)""" 166 | 167 | #print("MEL SHAPE: ", x.shape) 168 | 169 | x = x.transpose(1, 3) 170 | # (N = batch_size, C = mel_bins, H = time_steps, W = 1) 171 | x = self.bn1(x) 172 | x = x.transpose(1, 3) 173 | 174 | if self.training and self.augment: 175 | x = self.mask_time(x) 176 | x = self.mask_freq(x) 177 | 178 | # Mixup on spectrogram 179 | #if self.training and mixup_lambda is not None: 180 | # x = do_mixup(x, mixup_lambda) 181 | 182 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 183 | x = F.dropout(x, p=0.2, training=self.training) 184 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') 185 | x = F.dropout(x, p=0.2, training=self.training) 186 | x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') 187 | x = F.dropout(x, p=0.2, training=self.training) 188 | x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') 189 | x = F.dropout(x, p=0.2, training=self.training) 190 | x = torch.mean(x, dim=3) 191 | 192 | (x1, _) = torch.max(x, dim=2) 193 | x2 = torch.mean(x, dim=2) 194 | x = x1 + x2 195 | x = F.dropout(x, p=0.5, training=self.training) 196 | x = F.relu_(self.fc1(x)) 197 | embedding = F.dropout(x, p=0.5, training=self.training) 198 | 199 | # TODO: should there be a sigmoid activation here? 
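# Editor's note on the TODO above: the sigmoid mirrors the original PANNs
# multi-label AudioSet head. For single-label ESC-50 trained with a
# cross-entropy loss, returning the raw logits self.fc_esc50(x) would be the
# more common choice; kept as-is here to match the upstream implementation.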
200 | clipwise_output = torch.sigmoid(self.fc_esc50(x))
201 | 
202 | return clipwise_output
203 | 
204 | # NOTE: Cnn14 below depends on Spectrogram/LogmelFilterBank (see the commented-out torchlibrosa import at the top of this file) as well as SpecAugmentation and do_mixup, none of which are defined in this repository. It is kept for reference only and raises NameError if instantiated; the experiments use Cnn6 instead.
205 | class Cnn14(nn.Module):
206 | def __init__(self, sample_rate, window_size, hop_size, mel_bins, fmin,
207 | fmax, classes_num):
208 | 
209 | super(Cnn14, self).__init__()
210 | 
211 | window = 'hann'
212 | center = True
213 | pad_mode = 'reflect'
214 | ref = 1.0
215 | amin = 1e-10
216 | top_db = None
217 | 
218 | # Spectrogram extractor
219 | self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size,
220 | win_length=window_size, window=window, center=center, pad_mode=pad_mode,
221 | freeze_parameters=True)
222 | 
223 | # Logmel feature extractor
224 | self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size,
225 | n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db,
226 | freeze_parameters=True)
227 | 
228 | # Spec augmenter
229 | self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2,
230 | freq_drop_width=8, freq_stripes_num=2)
231 | 
232 | self.bn0 = nn.BatchNorm2d(64)
233 | 
234 | self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
235 | self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
236 | self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
237 | self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
238 | self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
239 | self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
240 | 
241 | self.fc1 = nn.Linear(2048, 2048, bias=True)
242 | self.fc_audioset = nn.Linear(2048, classes_num, bias=True)
243 | 
244 | self.init_weight()
245 | 
246 | def init_weight(self):
247 | init_bn(self.bn0)
248 | init_layer(self.fc1)
249 | init_layer(self.fc_audioset)
250 | 
251 | def forward(self, input, mixup_lambda=None):
252 | """
253 | Input: (batch_size, data_length)"""
254 | 
255 | x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins)
256 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
257 | 
258 | x = x.transpose(1, 3)
259 | x = self.bn0(x)
260 | x = x.transpose(1, 3)
261 | 
262 | if self.training:
263 | x = self.spec_augmenter(x)
264 | 
265 | # Mixup on spectrogram
266 | if self.training and mixup_lambda is not None:
267 | x = do_mixup(x, mixup_lambda)
268 | 
269 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
270 | x = F.dropout(x, p=0.2, training=self.training)
271 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
272 | x = F.dropout(x, p=0.2, training=self.training)
273 | x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
274 | x = F.dropout(x, p=0.2, training=self.training)
275 | x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
276 | x = F.dropout(x, p=0.2, training=self.training)
277 | x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg')
278 | x = F.dropout(x, p=0.2, training=self.training)
279 | x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg')
280 | x = F.dropout(x, p=0.2, training=self.training)
281 | x = torch.mean(x, dim=3)
282 | 
283 | (x1, _) = torch.max(x, dim=2)
284 | x2 = torch.mean(x, dim=2)
285 | x = x1 + x2
286 | x = F.dropout(x, p=0.5, training=self.training)
287 | x = F.relu_(self.fc1(x))
288 | embedding = F.dropout(x, p=0.5, training=self.training)
289 | clipwise_output = torch.sigmoid(self.fc_audioset(x))
290 | 
291 | return clipwise_output
292 | 
-------------------------------------------------------------------------------- /predict_test.py: 
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import tqdm
4 | import numpy as np
5 | import utils
6 | import torch
7 | 
8 | from ray import tune
9 | 
10 | def predict_test(df, data_base_dir, ray_results_dir, dataset_name):
11 | df['test_accuracy'] = 0
12 | #predictionss = []
13 | #labelss = []
14 | # all result rows share the same dataset settings, so the first row's config suffices to build the test loader
15 | config = utils.get_config_by_row(next(df.iterrows()))
16 | 
17 | 
18 | 
19 | _, _, testset = utils.get_dataset_by_config(config, os.path.join(data_base_dir, dataset_name))
20 | testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=8)
21 | 
22 | print("making test predictions (takes a couple of minutes on GPU) ...")
23 | for row in tqdm.tqdm(df.iterrows()):
24 | idx = row[0]
25 | #labels, predictions = utils.get_predictions_by_row(row, data_dir, split='test', device='cuda:0')
26 | 
27 | labels, predictions = utils.get_predictions_by_row_new(row, testloader, device='cuda:0')
28 | test_acc = np.mean(labels == predictions)
29 | df.at[idx, 'test_accuracy'] = test_acc
30 | 
31 | #predictionss.append(predictions)
32 | #labelss.append(labels)
33 | 
34 | experiment_path = os.path.join(ray_results_dir, dataset_name)
35 | print("saving predictions to file {} ...".format(os.path.join(experiment_path, "{}.csv".format(dataset_name))))
36 | df.to_csv(os.path.join(experiment_path, "{}.csv".format(dataset_name)))
37 | 
38 | return df
39 | 
40 | def main():
41 | parser = argparse.ArgumentParser(description='Make test-set predictions.')
42 | parser.add_argument('--ray_results_dir', help='The name of the ray results directory.', required=True, type=str)
43 | parser.add_argument('--data_base_dir', help='The path to the audio data directory.', required=True, type=str)
44 | parser.add_argument('--dataset_name', help='The dataset name.', required=True, type=str)
45 | args = parser.parse_args()
46 | 
47 | experiment_path = os.path.join(args.ray_results_dir, args.dataset_name)
48 | tuner = tune.Tuner.restore(path=experiment_path)
49 | result = tuner.fit()
50 | df = result.get_dataframe()
51 | 
52 | df = predict_test(df, args.data_base_dir, args.ray_results_dir, args.dataset_name)
53 | 
54 | print(df.head())
55 | 
56 | 
57 | if __name__ == '__main__':
58 | main()
59 | 
-------------------------------------------------------------------------------- /produce_figures.py: --------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import torch
5 | import tqdm
6 | 
7 | from ray import tune, air
8 | 
9 | import numpy as np
10 | import matplotlib.pyplot as plt
11 | import seaborn as sns
12 | import pandas as pd
13 | from predict_test import predict_test  # used below when producing test-split plots
14 | import utils
15 | import datasets
16 | import time_frequency as tf
17 | 
18 | def produce_data_example_plot():
19 | # data
20 | sigma_ref = torch.tensor(6.38)
21 | dataset = datasets.GaussPulseDatasetTimeFrequency(
22 | sigma = sigma_ref,
23 | n_points = 128,
24 | noise_std = torch.tensor(.0),
25 | n_samples = 500,
26 | f_center_max_offset=0,
27 | t_center_max_offset=0,
28 | demo=True, # disables most variability, for a more pedagogical figure
29 | )
30 | 
31 | plt.rcParams['text.usetex'] = True
32 | 
33 | short_scale = 0.2
34 | long_scale = 5
35 | 
36 | idx = 0
37 | n_classes = 3
38 | count = 0
39 | 
40 | lambda_param = sigma_ref
41 | 
42 | fig, ax = plt.subplots(3, 3, figsize=(8,3*2.7))
43 | 
44 | while True:
45 | # walk through the dataset and pick one example of each class, in class order
46 | x, y = dataset[idx]
47 | 
48 | if count % n_classes == y:
49 | count += 1
50 | t1, f1, t2, f2 = dataset.locs[idx]
51 | 
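# For each selected example, render the same signal at the reference window
# length and at the short (0.2x) and long (5x) scales, one figure column per scale.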
52 | s, w = tf.differentiable_spectrogram(x-torch.mean(x), lambd=lambda_param, return_window=True)
53 | utils.plot_spectrogram(s.detach().numpy(), ax[count-1, 0], decorate_axes=False)
54 | 
55 | s, w = tf.differentiable_spectrogram(x-torch.mean(x), lambd=lambda_param*short_scale, return_window=True)
56 | utils.plot_spectrogram(s.detach().numpy(), ax[count-1, 1], decorate_axes=False)
57 | 
58 | s, w = tf.differentiable_spectrogram(x-torch.mean(x), lambd=lambda_param*long_scale, return_window=True)
59 | utils.plot_spectrogram(s.detach().numpy(), ax[count-1, 2], decorate_axes=False)
60 | 
61 | idx += 1
62 | else:
63 | idx += 1
64 | 
65 | if count > 2:
66 | break
67 | 
68 | scales = [1.0, short_scale, long_scale]
69 | for i in range(3):
70 | ax[i, 0].set_ylabel('normalized frequency')
71 | ax[2, i].set_xlabel('time')
72 | ax[0, i].set_title(r'$\lambda = {0:.1f}$'.format(lambda_param * scales[i]))
73 | 
74 | plt.tight_layout()
75 | plt.savefig(os.path.join('figures', 'data_example.pdf'), bbox_inches='tight')  # fixed: experiment_path was undefined in this scope; the figure ships under figures/
76 | 
77 | 
78 | def produce_accuracy_plot(experiment_path, data_dir, split='valid'):
79 | if 'audio_mnist' in experiment_path:
80 | dataset_name = 'audio_mnist'
81 | model_names = ['mel_linear_net', 'mel_conv_net']
82 | elif 'esc50' in experiment_path:
83 | dataset_name = 'esc50'
84 | model_names = ['panns_cnn6']
85 | elif 'time_frequency' in experiment_path:
86 | dataset_name = 'time_frequency'
87 | model_names = ['linear_net', 'conv_net']
88 | else: raise ValueError('experiment_path must contain one of: audio_mnist, esc50, time_frequency')
89 | print("############################################")
90 | print("Dataset : {}, and models: {}".format(dataset_name, model_names))
91 | print("############################################")
92 | 
93 | tuner = tune.Tuner.restore(path=experiment_path)
94 | print(experiment_path)
95 | result = tuner.fit()
96 | df = result.get_dataframe()
97 | 
98 | 
99 | if split == 'test':
100 | # make test predictions if they do not exist
101 | #if not os.path.exists(os.path.join(experiment_path, "{}.csv".format(dataset_name))):
102 | # predict_test(df, data_dir, os.path.dirname(experiment_path), dataset_name)
103 | 
104 | # load test predictions (the .npy files below are not written by predict_test.py and are unused here)
105 | df = pd.read_csv(os.path.join(experiment_path, "{}.csv".format(dataset_name)))
106 | #predictionss = np.load(os.path.join(experiment_path, "{}_predictionss.npy".format(dataset_name)))
107 | #labelss = np.load(os.path.join(experiment_path, "{}_labelss.npy".format(dataset_name)))
108 | 
109 | column_width = 4
110 | figure_height = 3
111 | 
112 | #####################################################################
113 | # Accuracy plot
114 | #####################################################################
115 | df = df[(df['config/dataset_name'] == dataset_name)]
116 | 
117 | fig, ax = plt.subplots(2, 2, figsize=(column_width*2, figure_height*2))
118 | 
119 | for idx_column, model_name in enumerate(model_names):
120 | df_model = df[(df['config/model_name'] == model_name)]
121 | 
122 | model_title = get_model_title(model_name)
123 | 
124 | ax[0, idx_column].set_title(model_title)
125 | 
126 | if split == 'valid':
127 | y_str = 'best_valid_acc'
128 | y_title = 'Validation accuracy'
129 | elif split == 'test':
130 | y_str = 'test_accuracy'
131 | y_title = 'Test accuracy'
132 | else:
133 | raise ValueError('split not found: ', split)
134 | 
135 | sns.lineplot(data=df_model, x="config/init_lambd",
136 | y=y_str, marker='o',
137 | hue='config/trainable', ax=ax[0, idx_column])
138 | 
139 | ax[0, idx_column].legend(loc='lower center', title='Trainable')
140 | 
141 | g = sns.lineplot(data=df_model, x="config/init_lambd",
142 | y='lambd_est',
hue='config/trainable',
143 | marker="o", ax=ax[1, idx_column])
144 | 
145 | ax[1, idx_column].legend(loc='upper left', title='Trainable')
146 | 
147 | ax[0, 0].set_ylabel(y_title)
148 | ax[0, 0].set_xlabel("")
149 | ax[0, 1].set_ylabel("")
150 | ax[0, 1].set_xlabel("")
151 | ax[1, 0].set_ylabel(r'$\lambda_{est}$')
152 | ax[1, 0].set_xlabel(r'$\lambda_{init}$')
153 | ax[1, 1].set_ylabel("")
154 | ax[1, 1].set_xlabel(r'$\lambda_{init}$')
155 | 
156 | if dataset_name == 'audio_mnist':
157 | ax[0, 0].set_ylim([0.75, 0.96])
158 | ax[0, 1].set_ylim([0.75, 0.96])
159 | 
160 | if dataset_name == 'time_frequency':
161 | ax[0, 0].set_ylim([0.95, 1])
162 | ax[0, 1].set_ylim([0.95, 1])
163 | 
164 | plt.tight_layout()
165 | fig_path = os.path.join(experiment_path, '{}_{}.pdf'.format(split, dataset_name))
166 | print("saving figure ... ", fig_path)
167 | plt.savefig(fig_path, bbox_inches='tight')
168 | 
169 | def produce_baseline_plot(experiment_path, dataset_name, model_names, data_dir, split='valid'):
170 | print("############################################")
171 | print("Dataset : {}, and models: {}".format(dataset_name, model_names))
172 | print("############################################")
173 | 
174 | tuner = tune.Tuner.restore(path=experiment_path)
175 | print(experiment_path)
176 | result = tuner.fit()
177 | df = result.get_dataframe()
178 | 
179 | if split == 'test':
180 | # make test predictions if they do not exist
181 | if not os.path.exists(os.path.join(experiment_path, "{}.csv".format(dataset_name))):
182 | predict_test(df, data_dir, os.path.dirname(experiment_path), dataset_name)  # argument order fixed to match predict_test(df, data_base_dir, ray_results_dir, dataset_name)
183 | 
184 | # load test predictions (the .npy files below are not written by predict_test.py and are unused here)
185 | df = pd.read_csv(os.path.join(experiment_path, "{}.csv".format(dataset_name)))
186 | #predictionss = np.load(os.path.join(experiment_path, "{}_predictionss.npy".format(dataset_name)))
187 | #labelss = np.load(os.path.join(experiment_path, "{}_labelss.npy".format(dataset_name)))
188 | 
189 | column_width = 4
190 | figure_height = 3
191 | 
192 | #####################################################################
193 | # Accuracy plot
194 | #####################################################################
195 | df = df[(df['config/dataset_name'] == dataset_name)]
196 | 
197 | fig, ax = plt.subplots(1, 1, figsize=(column_width*1, figure_height*1))
198 | 
199 | for idx_column, model_name in enumerate(model_names):
200 | df_model = df[(df['config/model_name'] == model_name)]
201 | 
202 | model_title = get_model_title(model_name)
203 | 
204 | #ax[0, idx_column].set_title(model_title)
205 | 
206 | if split == 'valid':
207 | y_str = 'best_valid_acc'
208 | y_title = 'Validation accuracy'
209 | elif split == 'test':
210 | y_str = 'test_accuracy'
211 | y_title = 'Test accuracy'
212 | else:
213 | raise ValueError('split not found: ', split)
214 | 
215 | sns.lineplot(data=df_model, x="config/init_lambd",
216 | y=y_str, marker='o',
217 | hue='config/trainable', ax=ax)
218 | 
219 | #ax[0, idx_column].legend(loc='lower center', title='Trainable')
220 | 
221 | #g = sns.lineplot(data=df_model, x="config/init_lambd",
222 | # y='lambd_est', hue='config/trainable',
223 | # marker="o", ax=ax[1, idx_column])
224 | 
225 | #ax[1, idx_column].legend(loc='upper left', title='Trainable')
226 | 
227 | ax.set_ylabel(y_title)
228 | ax.set_xlabel(r'$\lambda_{init}$')
229 | 
230 | if dataset_name == 'audio_mnist':
231 | ax.set_ylim([0.75, 0.96])
232 | 
233 | 
234 | if dataset_name == 'time_frequency':
235 | ax.set_ylim([0.95, 1])
236 | 
237 | 
238 | if dataset_name == 'esc50':
239 | ax.set_ylim([0.65, 0.9])
240 | 
241 | 
242 | plt.tight_layout()
243 | fig_path = os.path.join(experiment_path, '{}_{}.pdf'.format(split, dataset_name))
244 | print("saving figure ... ", fig_path)
245 | plt.savefig(fig_path, bbox_inches='tight')
246 | 
247 | 
248 | 
249 | def main():
250 | parser = argparse.ArgumentParser(description='Produce plots.')
251 | parser.add_argument('--experiment_path', help='The name of the experiment directory.', required=True, type=str)
252 | parser.add_argument('--data_dir', help='The absolute path to the audio-mnist data directory.', required=True, type=str)
253 | #parser.add_argument('--name', help='The name of the experiment', required=True, type=str)
254 | parser.add_argument('--split', help='The name of the split [valid, test].', required=True, type=str)
255 | args = parser.parse_args()
256 | 
257 | # produce figure 1
258 | #produce_data_example_plot()
259 | 
260 | # produce figure 2
261 | #experiment_path = os.path.join(args.ray_root_dir, 'time_frequency')
262 | #produce_accuracy_plot(experiment_path, data_dir=args.data_dir, split=args.split)
263 | 
264 | # produce figure 3
265 | #experiment_path = os.path.join(args.ray_root_dir, 'audio_mnist')
266 | #produce_accuracy_plot(experiment_path, data_dir=args.data_dir, split=args.split)
267 | 
268 | #produce_accuracy_plot(experiment_path, data_dir=args.data_dir, split=args.split)
269 | 
270 | experiment_path = args.experiment_path
271 | produce_baseline_plot(experiment_path, dataset_name='esc50', model_names=['panns_cnn6'], data_dir=args.data_dir, split=args.split)
272 | 
273 | #print("")
274 | #print("the plots can now be found in ./results/figures")
275 | 
276 | 
277 | def get_model_title(model_name):
278 | if model_name == 'conv_net':
279 | model_title = 'ConvNet'
280 | elif model_name == 'linear_net':
281 | model_title = 'LinearNet'
282 | elif model_name == 'mel_linear_net':
283 | model_title = 'MelLinearNet'
284 | elif model_name == 'mel_conv_net':
285 | model_title = 'MelConvNet'
286 | elif model_name == 'panns_cnn6':
287 | model_title = 'PANNs CNN6'
288 | else:
289 | raise ValueError("model_name: {} is not defined.".format(model_name))
290 | 
291 | return model_title
292 | 
293 | def get_data_title(dataset_name):
294 | if dataset_name == 'time_frequency':
295 | dataset_title = "Gaussian-pulse dataset"
296 | elif dataset_name == 'audio_mnist':
297 | dataset_title = "Audio MNIST dataset"
298 | else:
299 | raise ValueError("dataset_name: {} is not defined.".format(dataset_name))
300 | 
301 | return dataset_title
302 | 
303 | def get_title(dataset_name, model_name):
304 | 
305 | dataset_title = get_data_title(dataset_name)
306 | model_title = get_model_title(model_name)
307 | 
308 | return "{} with {}".format(dataset_title, model_title)
309 | 
310 | if __name__ == '__main__':
311 | main()
312 | 
-------------------------------------------------------------------------------- /produce_tables.py: --------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import torch
5 | import tqdm
6 | 
7 | #from ray import tune, air
8 | 
9 | import numpy as np
10 | import matplotlib.pyplot as plt
11 | import seaborn as sns
12 | import pandas as pd
13 | 
14 | import utils
15 | import datasets
16 | import time_frequency as tf
17 | 
18 | def get_window_length_results(df, window_length, sr=8000):
19 | init_lambd = (window_length * sr) / 6  # same convention as search_spaces.py: lambd = (sr * window_length_in_seconds) / 6
20 | eps = 1e-5
21 | df_res = df[(df['config/init_lambd'] > (init_lambd - eps)) &
(df['config/init_lambd'] < (init_lambd + eps))] 22 | 23 | return df_res 24 | 25 | def produce_table_1(experiment_path, dataset_name): 26 | df = pd.read_csv(os.path.join(experiment_path, "{}.csv".format(dataset_name))) 27 | df_train = df[df['config/trainable'] == True] 28 | df_fixed = df[df['config/trainable'] == False] 29 | 30 | window_lengths = [0.010, 0.035, 0.300] 31 | 32 | print("Model & $l_{\lambda_{init}}$ & $l_{\lambda_{est}}$ & Method & Accuracy \\\\") 33 | print("\\hline \\hline") 34 | 35 | 36 | for window_length in window_lengths: 37 | df_train_win = get_window_length_results(df_train, window_length, sr=8000) 38 | df_fixed_win = get_window_length_results(df_fixed, window_length, sr=8000) 39 | 40 | mean_train_acc = df_train_win['test_accuracy'].mean() * 100 41 | std_train_acc = df_train_win['test_accuracy'].std() * 100 42 | 43 | mean_fixed_acc = df_fixed_win['test_accuracy'].mean() * 100 44 | std_fixed_acc = df_fixed_win['test_accuracy'].std() * 100 45 | 46 | min_lambd_est = df_train_win['best_lambd_est'].abs().min() * 6 / 8000 47 | max_lambd_est = df_train_win['best_lambd_est'].abs().max() * 6 / 8000 48 | 49 | row_format = "{} & {} ms & ({}, {}) ms & {} & ${:.1f} \pm {:.1f}$ \\\\" 50 | print(row_format.format( 51 | "LNet", int(window_length * 1000), int(min_lambd_est * 1000), int(max_lambd_est * 1000), 52 | "DMEL", mean_train_acc, std_train_acc) 53 | ) 54 | row_format = "{} & {} ms & {} ms & {} & ${:.1f} \pm {:.1f}$ \\\\" 55 | print(row_format.format( 56 | "LNet", int(window_length * 1000), int(window_length * 1000), 57 | "baseline", mean_fixed_acc, std_fixed_acc) 58 | ) 59 | print("\\hline") 60 | 61 | def produce_table_2(experiment_path, dataset_name): 62 | df = pd.read_csv(os.path.join(experiment_path, "{}.csv".format(dataset_name))) 63 | #print(df) 64 | df_train = df[df['config/trainable'] == True] 65 | df_fixed = df[df['config/trainable'] == False] 66 | 67 | sigma_ref = 6.38 68 | #lambd_inits = [sigma_ref * 0.2, sigma_ref*0.6, sigma_ref, sigma_ref*1.8, sigma_ref*2.6] 69 | lambd_inits = [sigma_ref * 0.2, sigma_ref, sigma_ref*5.0] 70 | 71 | print("Model & $\lambda_{init}$ & $\lambda_{est}$ & Method & Accuracy \\\\") 72 | print("\\hline \\hline") 73 | 74 | 75 | for lambd_init in lambd_inits: 76 | df_train_win = df_train[df_train['config/init_lambd'] == lambd_init] 77 | df_fixed_win = df_fixed[df_fixed['config/init_lambd'] == lambd_init] 78 | 79 | mean_train_acc = df_train_win['test_accuracy'].mean() * 100 80 | std_train_acc = df_train_win['test_accuracy'].std() * 100 81 | 82 | mean_fixed_acc = df_fixed_win['test_accuracy'].mean() * 100 83 | std_fixed_acc = df_fixed_win['test_accuracy'].std() * 100 84 | 85 | min_lambd_est = df_train_win['best_lambd_est'].abs().min() 86 | max_lambd_est = df_train_win['best_lambd_est'].abs().max() 87 | 88 | mean_lambd_est = df_train_win['best_lambd_est'].abs().mean() 89 | std_lambd_est = df_train_win['best_lambd_est'].abs().std() 90 | 91 | row_format = "{} & {:.1f} & ({:.1f}, {:.1f}) & {} & ${:.1f} \pm {:.1f}$ \\\\" 92 | print(row_format.format( 93 | "LinearNet", lambd_init, min_lambd_est, max_lambd_est, 94 | "DSPEC", mean_train_acc, std_train_acc) 95 | ) 96 | row_format = "{} & {:.1f} & {:.1f} & {} & ${:.1f} \pm {:.1f}$ \\\\" 97 | print(row_format.format( 98 | "LinearNet", lambd_init, lambd_init, 99 | "baseline", mean_fixed_acc, std_fixed_acc) 100 | ) 101 | print("\\hline") 102 | 103 | def produce_result_table(experiment_path, dataset_name): 104 | if dataset_name == 'audio_mnist': 105 | model_names = ['mel_conv_net', 'mel_linear_net'] 106 | if 
dataset_name == 'esc50': 107 | model_names = ['panns_cnn6'] 108 | 109 | print("############################################") 110 | print("Dataset : {}, and models: {}".format(dataset_name, model_names)) 111 | print("############################################") 112 | 113 | # load test predictions 114 | df = pd.read_csv(os.path.join(experiment_path, "{}.csv".format(dataset_name))) 115 | predictionss = np.load(os.path.join(experiment_path, "{}_predictionss.npy".format(dataset_name))) 116 | labelss = np.load(os.path.join(experiment_path, "{}_labelss.npy".format(dataset_name))) 117 | 118 | column_width = 4 119 | figure_height = 3 120 | 121 | df = df[(df['config/dataset_name'] == dataset_name)] 122 | df = df[(df['config/init_lambd'] == 8000*0.025 / 6)] 123 | 124 | print("Trainable & True & False \\\\") 125 | for idx_column, model_name in enumerate(model_names): 126 | df_model = df[(df['config/model_name'] == model_name)] 127 | df_train = df_model[(df_model['config/trainable'] == True)] 128 | df_fixed = df_model[(df_model['config/trainable'] == False)] 129 | model_title = get_model_title(model_name) 130 | 131 | trainable_mean_acc = df_train['test_accuracy'].mean() 132 | trainable_std_acc = df_train['test_accuracy'].std() 133 | 134 | fixed_mean_acc = df_fixed['test_accuracy'].mean() 135 | fixed_std_acc = df_fixed['test_accuracy'].std() 136 | 137 | print("{} & ${:.2f} \\pm {:.2f}$ & ${:.2f} \\pm {:.2f}$ \\\\".format( 138 | model_title, 139 | trainable_mean_acc, trainable_std_acc, 140 | fixed_mean_acc, fixed_std_acc 141 | )) 142 | 143 | 144 | 145 | def main(): 146 | parser = argparse.ArgumentParser(description='Produce plots.') 147 | parser.add_argument('--ray_results_dir', help='The name of the ray results directory.', required=True, type=str) 148 | #parser.add_argument('--experiment_path', help='The name of the experiment directory.', required=True, type=str) 149 | #parser.add_argument('--dataset_name', help='The dataset name.', required=True, type=str) 150 | args = parser.parse_args() 151 | 152 | #experiment_path = os.path.join(args.experiment_path) 153 | print("ESC50") 154 | produce_table_1(os.path.join(args.ray_results_dir, 'esc50'), 'esc50') 155 | print("") 156 | 157 | print("A-MNIST") 158 | produce_table_1(os.path.join(args.ray_results_dir, 'audio_mnist'), 'audio_mnist') 159 | print("") 160 | 161 | print("time-frequency") 162 | produce_table_2(os.path.join(args.ray_results_dir, 'time_frequency'), 'time_frequency') 163 | print("") 164 | 165 | 166 | def get_model_title(model_name): 167 | if model_name == 'conv_net': 168 | model_title = 'ConvNet' 169 | elif model_name == 'linear_net': 170 | model_title = 'LinearNet' 171 | elif model_name == 'mel_linear_net': 172 | model_title = 'MelLinearNet' 173 | elif model_name == 'mel_conv_net': 174 | model_title = 'MelConvNet' 175 | elif model_name == 'panns_cnn6': 176 | model_title = 'PANNs CNN6' 177 | else: 178 | raise ValueError("model_name: {} is not defined.".format(model_name)) 179 | 180 | return model_title 181 | 182 | def get_data_title(dataset_name): 183 | if dataset_name == 'time_frequency': 184 | dataset_title = "Gaussian-pulse dataset" 185 | elif dataset_name == 'audio_mnist': 186 | dataset_title = "Audio MNIST dataset" 187 | else: 188 | raise ValueError("dataset_name: {} is not defined.".format(dataset_name)) 189 | 190 | return dataset_title 191 | 192 | def get_title(dataset_name, model_name): 193 | 194 | dataset_title = get_data_title(dataset_name) 195 | model_title = get_model_title(model_name) 196 | 197 | return "{} with 
{}".format(dataset_title, model_title) 198 | 199 | if __name__ == '__main__': 200 | main() 201 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | librosa==0.9.2 2 | matplotlib==3.6.2 3 | numpy==1.23.5 4 | pandas==1.5.1 5 | ray==2.0.1 6 | seaborn==0.12.2 7 | torch==1.13.1 8 | torchaudio==0.13.1 9 | tqdm==4.64.1 10 | ray[tune] 11 | ray[rllib] 12 | -------------------------------------------------------------------------------- /run_experiments.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py --num_samples=1 --max_epochs=1000 --name=audio_mnist --ray_root_dir=$(pwd)/ray_results --data_dir=$(pwd)/data/audio_mnist 2 | CUDA_VISIBLE_DEVICES=0 python main.py --num_samples=1 --max_epochs=1000 --name=time_frequency --ray_root_dir=$(pwd)/ray_results --data_dir=$(pwd)/data/time_frequency 3 | CUDA_VISIBLE_DEVICES=0 python main.py --num_samples=1 --max_epochs=1000 --name=esc50 --ray_root_dir=$(pwd)/ray_results --data_dir=$(pwd)/data/esc50 -------------------------------------------------------------------------------- /run_test_predictions.sh: -------------------------------------------------------------------------------- 1 | python predict_test.py --ray_results_dir=$(pwd)/ray_results/ --data_base_dir=$(pwd)/data/ --dataset_name=audio_mnist 2 | python predict_test.py --ray_results_dir=$(pwd)/ray_results/ --data_base_dir=$(pwd)/data/ --dataset_name=esc50 3 | python predict_test.py --ray_results_dir=$(pwd)/ray_results/ --data_base_dir=$(pwd)/data/ --dataset_name=time_frequency 4 | -------------------------------------------------------------------------------- /search_spaces.py: -------------------------------------------------------------------------------- 1 | from ray import tune 2 | 3 | def esc50(max_epochs): 4 | resample_rate = 8000 5 | search_space = { 6 | # model 7 | 'model_name' : 'panns_cnn6', 8 | 'n_mels' : 64, 9 | 'hop_length' :int(resample_rate * 0.010), 10 | 'energy_normalize' : True, 11 | 'optimized' : True, 12 | 'normalize_window' : False, 13 | 'augment' : False, 14 | 15 | # training 16 | 'pretrained' : False, 17 | 'checkpoint_path' : '/home/john/gits/differentiable-time-frequency-transforms/weights/Cnn6_mAP=0.343.pth', 18 | 'optimizer_name' : 'adam', 19 | 'lr_model' : 1e-4, 20 | 'lr_tf' : 1.0, 21 | 'batch_size' : 32, 22 | 'trainable' : tune.grid_search([True, False]), 23 | 'max_epochs' : max_epochs, 24 | 'patience' : 100, 25 | 'device' : 'cuda:0', 26 | 27 | # dataset 28 | 'resample_rate' : resample_rate, 29 | 'init_lambd' : tune.grid_search([(resample_rate*x)/6 for x in [0.01, 0.035, 0.3]]), 30 | 'dataset_name' : 'esc50', 31 | 'n_points' : resample_rate * 5, # hard coded zero-padding 32 | } 33 | 34 | return search_space 35 | 36 | def audio_mnist(max_epochs): 37 | resample_rate = 8000 38 | search_space = { 39 | # model 40 | 'model_name' : 'mel_linear_net', 41 | 'n_mels' : 64, 42 | 'hop_length' :int(resample_rate * 0.010), 43 | 'energy_normalize' : True, 44 | 'optimized' : True, 45 | 'normalize_window' : False, 46 | 'augment' : False, 47 | 48 | # training 49 | 'pretrained' : False, 50 | 'checkpoint_path' : '/home/john/gits/differentiable-time-frequency-transforms/weights/Cnn6_mAP=0.343.pth', 51 | 'optimizer_name' : 'adam', 52 | 'lr_model' : 1e-4, 53 | 'lr_tf' : 1.0, 54 | 'batch_size' : 64, 55 | 'trainable' : tune.grid_search([True, False]), 56 | 'max_epochs' : max_epochs, 57 | 
'patience' : 100,
58 | 'device' : 'cuda:0',
59 | 
60 | # dataset
61 | 'resample_rate' : resample_rate,
62 | 'init_lambd' : tune.grid_search([(resample_rate*x)/6 for x in [0.01, 0.035, 0.3]]),
63 | 'dataset_name' : 'audio_mnist',
64 | 'n_points' : 8000, # hard coded zero-padding
65 | #'speaker_id' : tune.grid_search([[28, 56, 7, 19, 35]]),
66 | }
67 | 
68 | return search_space
69 | 
70 | def time_frequency(max_epochs):
71 | sigma_ref = 6.38
72 | 
73 | search_space = {
74 | # model
75 | 'model_name' : 'linear_net',
76 | 'hop_length' : 1,
77 | 'optimized' : False,
78 | 'normalize_window' : False,
79 | 
80 | # training
81 | 'optimizer_name' : 'sgd',
82 | 'lr_model' : 1e-3,
83 | 'lr_tf' : 1,
84 | 'batch_size' : 128,
85 | 'trainable' : tune.grid_search([True, False]),
86 | 'max_epochs' : max_epochs,
87 | 'patience' : 100,
88 | 'device' : 'cuda:0',
89 | 
90 | # dataset
91 | 'n_points' : 128,
92 | 'noise_std' : 0.5,
93 | 'init_lambd' : tune.grid_search([x * sigma_ref for x in [0.2, 1.0, 5.0]]),
94 | 'n_samples' : 5000,
95 | 'sigma_ref' : sigma_ref,
96 | 'dataset_name' : 'time_frequency',
97 | 'center_offset' : False,
98 | }
99 | 
100 | return search_space
101 | 
-------------------------------------------------------------------------------- /time_frequency.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | def gauss_whole(sigma, tc, signal_length, norm='amplitude', device='cpu'):
6 | ts = torch.arange(0, signal_length).float()
7 | 
8 | ts = ts.to(device)
9 | tc = tc.to(device)
10 | sigma = sigma.to(device)
11 | 
12 | window = torch.exp(-0.5 * torch.pow((ts-tc) / (sigma + 1e-15), 2))
13 | 
14 | if norm == 'energy':
15 | window_norm = window / torch.sum(torch.pow(window, 2))
16 | elif norm == 'amplitude':
17 | window_norm = window / torch.max(window)
18 | else: raise ValueError("norm must be 'energy' or 'amplitude', got: {}".format(norm))
19 | return window_norm
20 | 
21 | def differentiable_gaussian_window(lambd, window_length, device='cpu', norm=True):
22 | m = torch.arange(0, window_length).float().to(device)
23 | 
24 | window = torch.exp(-0.5 * torch.pow((m - window_length / 2) / (lambd + 1e-15), 2))
25 | window_norm = window / torch.sqrt(torch.sum(torch.pow(window, 2)))
26 | 
27 | if norm:
28 | return window_norm
29 | else:
30 | return window
31 | 
32 | def differentiable_spectrogram(x, lambd, optimized=False, device='cpu', hop_length=1, return_window=False, norm=False, n_stds=6):
33 | 
34 | # optimization potentially makes gradients weaker, but faster
35 | if optimized:
36 | # TODO: not sure if this optimization works as intended,
37 | # never used in the experiments. Will become important
38 | # for longer signals.
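# Editor's note on the truncation heuristic below: the window is cut off at
# +/- n_stds/2 standard deviations from its center (+/- 3 sigma for the default
# n_stds=6), where the Gaussian amplitude has decayed to exp(-n_stds**2/8),
# about 1% of the peak, so essentially all of the window energy is retained
# while the FFT size shrinks to the next power of two of lambd * n_stds.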
39 | window_length = next_power_of_2((lambd * n_stds).detach().cpu().numpy())
40 | else:
41 | window_length = len(x)
42 | 
43 | window = differentiable_gaussian_window(lambd, window_length=window_length, device=device, norm=norm).to(device)
44 | n_fft = len(window)
45 | 
46 | # quadratic TF-image without redundancy
47 | if optimized:
48 | s = torch.stft(x, n_fft=len(window), hop_length=hop_length, win_length=len(window), window=window, return_complex=True, pad_mode='constant')
49 | else:
50 | # quadratic TF-image with redundancy
51 | s = torch.stft(x, n_fft=len(window)*2, hop_length=hop_length, win_length=len(window), window=window, return_complex=True, pad_mode='constant')
52 | 
53 | s = torch.pow(torch.abs(s), 2)
54 | 
55 | if not return_window:
56 | return s
57 | else:
58 | return s, window
59 | 
60 | def shift_bit_length(x):
61 | x = int(x)
62 | return 1<<(x-1).bit_length()
63 | 
64 | def next_power_of_2(x):
65 | return shift_bit_length(x)
66 | 
67 | 
68 | # NOTE: initial time-frequency implementation from scratch,
69 | # a bit slower than the pytorch implementation, kept here
70 | # if found useful in the future, or simply for educational
71 | # purposes.
72 | 
73 | #def stft(x, windows):
74 | # dim = (len(x) // 2 + 1, len(windows))
75 | # s = torch.empty(dim, dtype=torch.complex64)
76 | # for idx, window in enumerate(windows):
77 | # x_w = x * window
78 | # fft = torch.fft.rfft(x_w)
79 | # s[:,idx] = fft
80 | #
81 | # return s
82 | #
83 | #def spectrogram(x, lambd, overlap=0.5):
84 | # signal_length = len(x)
85 | # windows = get_gauss_windows(signal_length, lambd, overlap)
86 | # s = stft(x, windows)
87 | # s = torch.pow(torch.abs(s), 2)
88 | # return s
89 | #
90 | #def spectrogram_whole(x, lambd, device='cpu'):
91 | # signal_length = len(x)
92 | # windows = get_gauss_windows_whole(signal_length, lambd, device=device)
93 | # windows = [w.to(device) for w in windows]
94 | # s = stft(x, windows)
95 | # s = torch.pow(torch.abs(s), 2)
96 | # return s
97 | #
98 | #def get_gauss_windows(signal_length, lambd, overlap):
99 | # hop_size = (lambd*6*(1-overlap)).int()
100 | # windows = [
101 | # gauss_whole(sigma=lambd, tc=(i+1)*hop_size, signal_length=signal_length)
102 | # for i in range(signal_length // hop_size - 1)
103 | # ]
104 | #
105 | # return windows
106 | #
107 | #def get_gauss_windows_whole(signal_length, lambd, device='cpu'):
108 | # half_window_length = 0 #(lambd*3).int()
109 | # windows = [
110 | # gauss_whole(sigma=lambd, tc=torch.tensor(i).float(), signal_length=signal_length, device=device)
111 | # for i in range(half_window_length, signal_length-half_window_length)
112 | # ]
113 | #
114 | # return windows
115 | 
116 | 
-------------------------------------------------------------------------------- /train.py: --------------------------------------------------------------------------------
1 | import os
2 | from ray import tune
3 | 
4 | import torch
5 | import time
6 | import numpy as np
7 | import matplotlib.pyplot as plt # fixed: plt was used in the validation loop below without being imported
8 | def train_model(net, optimizer, loss_fn, trainloader, validloader, scheduler, patience, max_epochs, verbose=1, device='cuda:0', one_hot=False, n_classes=50):
9 | history = {
10 | "best_valid_acc" : 0,
11 | "best_valid_loss" : np.inf,
12 | "init_lambd" : net.spectrogram_layer.lambd.item(),
13 | "converged" : False,
14 | }
15 | best_valid_acc = 0
16 | best_valid_loss = np.inf
17 | patience_count = 0
18 | for epoch in range(max_epochs):
19 | 
20 | net.train()
21 | 
22 | running_loss = 0.0
23 | running_energy = 0.0
24 | count = 0
25 | for i, data in enumerate(trainloader):
26 | t_tot1 = time.time()
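# One optimization step per batch: the forward pass runs through the
# spectrogram layer and the classifier, so loss.backward() also produces
# a gradient for the window parameter lambd, which optimizer.step() then
# updates jointly with the network weights.
27 | 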
inputs, labels = data 28 | 29 | if one_hot: 30 | # TODO: this won't work in general 31 | labels = torch.nn.functional.one_hot(labels, n_classes).float() 32 | 33 | t1 = time.time() 34 | inputs, labels = inputs.to(device), labels.to(device) 35 | t2 = time.time() 36 | #print("time 1 = {}, batch = {}".format(t2-t1, i)) 37 | 38 | 39 | optimizer.zero_grad() 40 | 41 | t1 = time.time() 42 | logits, s = net(inputs) 43 | t2 = time.time() 44 | #print("time 2 = {}, batch = {}".format(t2-t1, i)) 45 | 46 | loss = loss_fn(logits, labels)# + aux_loss 47 | loss.backward() 48 | 49 | optimizer.step() 50 | 51 | if verbose >= 2: 52 | if i % 10 == 0: 53 | print("max values: ", torch.max(logits, dim=1).values.cpu().detach().numpy()) 54 | print("batch loss = {}".format(loss.item())) 55 | print("est. lambd = ", net.spectrogram_layer.lambd.item()) 56 | 57 | running_loss += loss.item() 58 | running_energy += np.sum(s.cpu().detach().numpy()) 59 | count += 1 60 | 61 | t_tot2 = time.time() 62 | #print("time total = {}, batch = {}".format(t_tot2-t_tot1, i)) 63 | 64 | # step scheduler 65 | scheduler.step() 66 | 67 | train_loss = running_loss / count 68 | train_energy = running_energy / count 69 | 70 | if verbose >= 1: 71 | print("epoch {}, train loss = {}".format(epoch, running_loss / count)) 72 | print("est. lambd = ", net.spectrogram_layer.lambd.item()) 73 | 74 | running_loss = 0.0 75 | count = 0 76 | running_acc = 0.0 77 | 78 | net.eval() 79 | for data in validloader: 80 | inputs, labels = data 81 | 82 | if one_hot: 83 | # TODO: this won't work in general 84 | labels = torch.nn.functional.one_hot(labels, n_classes).float() 85 | 86 | inputs, labels = inputs.to(device), labels.to(device) 87 | 88 | outputs, spectrograms = net(inputs) 89 | loss = loss_fn(outputs, labels) 90 | 91 | predictions = torch.argmax(outputs, axis=1) 92 | 93 | if one_hot: 94 | labels = torch.argmax(labels, axis=1) 95 | 96 | accuracy = torch.mean((predictions == labels).float()) 97 | running_acc += accuracy.item() 98 | 99 | running_loss += loss.item() 100 | count += 1 101 | 102 | valid_loss = running_loss / count 103 | valid_acc = running_acc / count 104 | 105 | 106 | # save epoch model 107 | #with tune.checkpoint_dir(epoch) as checkpoint_dir: 108 | # path = os.path.join(checkpoint_dir, "checkpoint") 109 | # torch.save((net.state_dict(), optimizer.state_dict()), path) 110 | 111 | 112 | if valid_loss < best_valid_loss: # < best_valid_loss: 113 | 114 | # save best model 115 | with tune.checkpoint_dir(0) as checkpoint_dir: 116 | path = os.path.join(checkpoint_dir, "best_model") 117 | torch.save((net.state_dict(), optimizer.state_dict()), path) 118 | 119 | best_valid_acc = valid_acc 120 | best_valid_loss = valid_loss 121 | best_lambd_est = net.spectrogram_layer.lambd.item() 122 | patience_count = 0 123 | if verbose >= 1: 124 | print("new best valid acc = {}, patience_count = {}".format(best_valid_acc, patience_count)) 125 | else: 126 | patience_count += 1 127 | 128 | # report results 129 | tune.report(loss=train_loss, lambd_est=net.spectrogram_layer.lambd.item(), valid_loss=valid_loss, valid_acc=valid_acc, best_valid_acc=best_valid_acc, best_valid_loss=best_valid_loss, energy=train_energy, best_lambd_est=best_lambd_est) 130 | 131 | if verbose >= 1: 132 | print("epoch {}, valid loss = {}".format(epoch, valid_loss)) 133 | print("epoch {}, valid acc = {}".format(epoch, valid_acc)) 134 | 135 | # plot spectrogram 136 | plt.imshow(np.flip(spectrograms[0,0,:,:].cpu().detach().numpy(), axis=0), aspect='auto') 137 | plt.title("label = 
{}".format(labels.cpu().detach().numpy()[0])) 138 | plt.show() 139 | 140 | running_loss = 0.0 141 | running_acc = 0.0 142 | count = 0 143 | 144 | if patience_count >= patience: 145 | print("no more patience, break training loop ...") 146 | history["converged"] = True 147 | break 148 | 149 | # save history 150 | history["best_valid_acc"] = best_valid_acc 151 | history["best_valid_loss"] = best_valid_loss 152 | history["est_lambd"] = net.spectrogram_layer.lambd.item() 153 | 154 | return net, history 155 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import torch 3 | import numpy as np 4 | import os 5 | import tqdm 6 | import glob 7 | 8 | import models 9 | import datasets 10 | 11 | def create_folder(fd): 12 | if not os.path.exists(fd): 13 | os.makedirs(fd) 14 | 15 | def load_checkpoint(model, checkpoint_path=None, device=torch.device('cpu')): 16 | if not checkpoint_path: 17 | checkpoint_path='./weights/Cnn6_mAP=0.343.pth' 18 | print('Checkpoint path: {}'.format(checkpoint_path)) 19 | 20 | if not os.path.exists(checkpoint_path): # or os.path.getsize(checkpoint_path) < 3e8: 21 | print("checkpoint path does not exist: ", checkpoint_path) 22 | create_folder(os.path.dirname(checkpoint_path)) 23 | zenodo_path = 'https://zenodo.org/record/3987831/files/Cnn6_mAP%3D0.343.pth' 24 | os.system('wget -O "{}" "{}"'.format(checkpoint_path, zenodo_path)) 25 | 26 | print("loading weights: ", checkpoint_path) 27 | checkpoint = torch.load(checkpoint_path, map_location=device) 28 | state_dict = checkpoint['model'] 29 | 30 | new_state_dict = collections.OrderedDict() 31 | 32 | for key, value in state_dict.items(): 33 | new_key = "spectrogram_model." 
+ key 34 | new_state_dict[new_key] = value 35 | 36 | model.load_state_dict(new_state_dict, strict=False) 37 | 38 | def get_config_by_row(row): 39 | config = {} 40 | r = row[1] 41 | for k in r.keys(): 42 | if 'config' in k: 43 | config[k.split('/')[1]] = r[k] 44 | return config 45 | 46 | def get_dataset_by_config(config, data_dir): 47 | if config['dataset_name'] == 'audio_mnist': 48 | #trainset_speaker_id = config['speaker_id'] #28, 56, 7, 19, 35, 1, 6, 16, 23, 34, 46, 53, 36, 57, 9, 24, 37, 2, 8, 17, 29, 39, 48, 54, 43, 58, 14, 25, 38, 3, 10, 20, 30, 40, 49, 55] 49 | trainset_speaker_id = [28, 56, 7, 19, 35, 1, 6, 16, 23, 34, 46, 53, 36, 57, 9, 24, 37, 2, 8, 17, 29, 39, 48, 54, 43, 58, 14, 25, 38, 3, 10, 20, 30, 40, 49, 55] 50 | validset_speaker_id = [12, 47, 59, 15, 27, 41, 4, 11, 21, 31, 44, 50] 51 | testset_speaker_id = [26, 52, 60, 18, 32, 42, 5, 13, 22, 33, 45, 51] 52 | 53 | # assert no overlap 54 | assert(len(trainset_speaker_id + validset_speaker_id + testset_speaker_id) == 60) 55 | assert(len(set(trainset_speaker_id + validset_speaker_id + testset_speaker_id)) == 60) 56 | 57 | train_wav_paths = [] 58 | valid_wav_paths = [] 59 | test_wav_paths = [] 60 | 61 | for speaker_id in trainset_speaker_id: 62 | wav_paths = glob.glob(os.path.join(data_dir, 'data', '{:02d}'.format(speaker_id), '*.wav')) 63 | train_wav_paths += wav_paths 64 | 65 | for speaker_id in validset_speaker_id: 66 | wav_paths = glob.glob(os.path.join(data_dir, 'data', '{:02d}'.format(speaker_id), '*.wav')) 67 | valid_wav_paths += wav_paths 68 | 69 | for speaker_id in testset_speaker_id: 70 | wav_paths = glob.glob(os.path.join(data_dir, 'data', '{:02d}'.format(speaker_id), '*.wav')) 71 | test_wav_paths += wav_paths 72 | 73 | all_wav_paths = glob.glob(os.path.join(data_dir, 'data', '*/*.wav')) 74 | 75 | trainset = datasets.AudioMNISTBigDataset( 76 | wav_paths = train_wav_paths 77 | ) 78 | validset = datasets.AudioMNISTBigDataset( 79 | wav_paths = valid_wav_paths 80 | ) 81 | testset = datasets.AudioMNISTBigDataset( 82 | wav_paths = test_wav_paths 83 | ) 84 | 85 | assert((len(trainset) + len(validset) + len(testset)) == 30000) 86 | #print("Trainset: {}".format(len(trainset))) 87 | 88 | return trainset, validset, testset 89 | elif config['dataset_name'] == 'esc50': 90 | dataset = datasets.ESC50Dataset( 91 | source_dir = data_dir, 92 | resample_rate = config['resample_rate'], 93 | ) 94 | else: 95 | # random offset 1/5 of tf-image in each direction 96 | if config['center_offset']: 97 | f_center_max_offset = 0.1 98 | t_center_max_offset = config['n_points']/5 99 | else: 100 | f_center_max_offset = 0.0 101 | t_center_max_offset = 0.0 102 | 103 | 104 | if config['dataset_name'] == 'time': 105 | dataset = datasets.GaussPulseDatasetTime( 106 | sigma = torch.tensor(config['sigma_ref']), 107 | n_points = config['n_points'], 108 | noise_std = torch.tensor(config['noise_std']), 109 | n_samples = config['n_samples'], 110 | f_center_max_offset = f_center_max_offset, 111 | t_center_max_offset = t_center_max_offset, 112 | ) 113 | elif config['dataset_name'] == 'frequency': 114 | dataset = datasets.GaussPulseDatasetFrequency( 115 | sigma = torch.tensor(config['sigma_ref']), 116 | n_points = config['n_points'], 117 | noise_std = torch.tensor(config['noise_std']), 118 | n_samples = config['n_samples'], 119 | f_center_max_offset = f_center_max_offset, 120 | t_center_max_offset = t_center_max_offset, 121 | ) 122 | elif config['dataset_name'] == 'time_frequency': 123 | dataset = datasets.GaussPulseDatasetTimeFrequency( 124 | sigma = 
torch.tensor(config['sigma_ref']), 125 | n_points = config['n_points'], 126 | noise_std = torch.tensor(config['noise_std']), 127 | n_samples = config['n_samples'], 128 | f_center_max_offset = f_center_max_offset, 129 | t_center_max_offset = t_center_max_offset, 130 | ) 131 | else: 132 | raise ValueError("dataset not defined: ", config['dataset_name']) 133 | 134 | 135 | gen = torch.Generator() 136 | gen.manual_seed(0) 137 | trainset, validset, testset = torch.utils.data.random_split( 138 | dataset, [0.7, 0.1, 0.2], 139 | generator=gen 140 | ) 141 | 142 | return trainset, validset, testset 143 | 144 | def get_model_by_config(config): 145 | if config['dataset_name'] == 'time_frequency': 146 | n_classes = 3 147 | elif config['dataset_name'] == 'audio_mnist': 148 | n_classes = 10 149 | elif config['dataset_name'] == 'esc50': 150 | n_classes = 50 151 | else: 152 | raise ValueError('dataset_name: {} not supported.'.format(config['dataset_name'])) 153 | 154 | if config['model_name'] == 'linear_net': 155 | net = models.LinearNet( 156 | n_classes = n_classes, 157 | init_lambd = torch.tensor(config['init_lambd']), 158 | device = config['device'], 159 | size = (config['n_points']+1, config['n_points']+1), 160 | hop_length = config['hop_length'], 161 | optimized = config['optimized'], 162 | normalize_window = config['normalize_window'], 163 | ) 164 | elif config['model_name'] == 'bn_linear_net': 165 | net = models.BatchNormLinearNet( 166 | n_classes = n_classes, 167 | init_lambd = torch.tensor(config['init_lambd']), 168 | device = config['device'], 169 | size = (config['n_points']+1, config['n_points']+1), 170 | hop_length = config['hop_length'], 171 | optimized = config['optimized'], 172 | normalize_window = config['normalize_window'], 173 | ) 174 | 175 | elif config['model_name'] == 'non_linear_net': 176 | net = models.NonLinearNet( 177 | n_classes = n_classes, 178 | init_lambd = torch.tensor(config['init_lambd']), 179 | device = config['device'], 180 | size = (config['n_points']+1, config['n_points']+1), 181 | hop_length = config['hop_length'], 182 | optimized = config['optimized'], 183 | normalize_window = config['normalize_window'], 184 | ) 185 | elif config['model_name'] == 'mlp_net': 186 | net = models.MlpNet( 187 | n_classes = n_classes, 188 | init_lambd = torch.tensor(config['init_lambd']), 189 | device = config['device'], 190 | size = (config['n_points']+1, config['n_points']+1), 191 | hop_length = config['hop_length'], 192 | optimized = config['optimized'], 193 | normalize_window = config['normalize_window'], 194 | ) 195 | elif config['model_name'] == 'conv_net': 196 | net = models.ConvNet( 197 | n_classes = n_classes, 198 | init_lambd = torch.tensor(config['init_lambd']), 199 | device = config['device'], 200 | size = (config['n_points']+1, config['n_points']+1), 201 | hop_length = config['hop_length'], 202 | optimized = config['optimized'], 203 | normalize_window = config['normalize_window'], 204 | ) 205 | elif config['model_name'] == 'mel_linear_net': 206 | net = models.MelLinearNet( 207 | n_classes = n_classes, 208 | init_lambd = torch.tensor(config['init_lambd']), 209 | device = config['device'], 210 | n_mels = config['n_mels'], 211 | sample_rate = config['resample_rate'], 212 | n_points = config['n_points'], 213 | hop_length = config['hop_length'], 214 | optimized = config['optimized'], 215 | energy_normalize = config['energy_normalize'], 216 | normalize_window = config['normalize_window'], 217 | ) 218 | elif config['model_name'] == 'mel_mlp_net': 219 | net = models.MelMlpNet( 220 | 
n_classes = n_classes, 221 | init_lambd = torch.tensor(config['init_lambd']), 222 | device = config['device'], 223 | n_mels = config['n_mels'], 224 | sample_rate = config['resample_rate'], 225 | n_points = config['n_points'], 226 | hop_length = config['hop_length'], 227 | optimized = config['optimized'], 228 | energy_normalize = config['energy_normalize'], 229 | normalize_window = config['normalize_window'], 230 | ) 231 | elif config['model_name'] == 'mel_conv_net': 232 | net = models.MelConvNet( 233 | n_classes = n_classes, 234 | init_lambd = torch.tensor(config['init_lambd']), 235 | device = config['device'], 236 | n_mels = config['n_mels'], 237 | sample_rate = config['resample_rate'], 238 | n_points = config['n_points'], 239 | hop_length = config['hop_length'], 240 | optimized = config['optimized'], 241 | energy_normalize = config['energy_normalize'], 242 | normalize_window = config['normalize_window'], 243 | ) 244 | elif config['model_name'] == 'panns_cnn6': 245 | net = models.MelPANNsNet( 246 | n_classes = n_classes, 247 | init_lambd = torch.tensor(config['init_lambd']), 248 | device = config['device'], 249 | n_mels = config['n_mels'], 250 | sample_rate = config['resample_rate'], 251 | n_points = config['n_points'], 252 | hop_length = config['hop_length'], 253 | optimized = config['optimized'], 254 | augment = config['augment'], 255 | energy_normalize = config['energy_normalize'], 256 | normalize_window = config['normalize_window'], 257 | ) 258 | else: 259 | raise ValueError("model name not found: ", config['model_name']) 260 | 261 | return net 262 | 263 | def get_predictions_by_row_new(row, dataloader, device='cpu'): 264 | config = get_config_by_row(row) 265 | config['device'] = device 266 | logdir = row[1]['logdir'] 267 | model_chechpoint_path = os.path.join(logdir, 'checkpoint_000000', 'best_model') 268 | net = get_model_by_config(config) 269 | (model_state, optimizer_state) = torch.load(model_chechpoint_path) 270 | net.load_state_dict(model_state) 271 | net.to(device) 272 | 273 | net.eval() 274 | all_predictions = [] 275 | all_labels = [] 276 | for data in tqdm.tqdm(dataloader): 277 | inputs, labels = data 278 | inputs, labels = inputs.to(device), labels.to(device) 279 | 280 | outputs, _ = net(inputs) 281 | predictions = torch.argmax(outputs, axis=1) 282 | 283 | all_labels.append(labels.detach().cpu().numpy()) 284 | all_predictions.append(predictions.detach().cpu().numpy()) 285 | 286 | return np.concatenate(all_labels), np.concatenate(all_predictions) 287 | 288 | 289 | def get_predictions_by_row(row, data_dir, device='cpu', split='valid'): 290 | config = get_config_by_row(row) 291 | config['device'] = device 292 | logdir = row[1]['logdir'] 293 | model_chechpoint_path = os.path.join(logdir, 'checkpoint_000000', 'best_model') 294 | net = get_model_by_config(config) 295 | (model_state, optimizer_state) = torch.load(model_chechpoint_path) 296 | net.load_state_dict(model_state) 297 | net.to(device) 298 | 299 | # load dataset 300 | trainset,validset,testset = get_dataset_by_config(config, data_dir) 301 | 302 | if split == 'valid': 303 | dataset = validset 304 | elif split == 'train': 305 | dataset = trainset 306 | elif split == 'test': 307 | dataset = testset 308 | else: 309 | raise ValueError("data split: {} is not supported".format(split)) 310 | 311 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=8) 312 | 313 | net.eval() 314 | all_predictions = [] 315 | all_labels = [] 316 | for data in tqdm.tqdm(dataloader): 317 | inputs, labels = data 
318 | inputs, labels = inputs.to(device), labels.to(device)
319 | 
320 | outputs, _ = net(inputs)
321 | predictions = torch.argmax(outputs, axis=1)
322 | 
323 | all_labels.append(labels.detach().cpu().numpy())
324 | all_predictions.append(predictions.detach().cpu().numpy())
325 | 
326 | return np.concatenate(all_labels), np.concatenate(all_predictions)
327 | 
328 | def plot_spectrogram(s, ax, decorate_axes=True):
329 | ax.imshow(np.flip(s, axis=0), aspect='auto')
330 | 
331 | # decorate axes
332 | if decorate_axes:
333 | ax.set_xlabel('time')
334 | ax.set_ylabel('normalized frequency')
335 | 
336 | (fbins, tbins) = s.shape
337 | yticks = [t for t in np.linspace(0, fbins-1, 5)]
338 | yticklabels = [str(l) for l in np.linspace(0.5, 0, 5)]
339 | ax.set_yticks(yticks)
340 | ax.set_yticklabels(yticklabels)
341 | 
342 | 
343 | # kept in case I need it again
344 | def sample_d_tloc_d_floc_ellipse(sigma):
345 | 
346 | d_freq = 1/(torch.pi * sigma)
347 | d_time = 2 * sigma
348 | 
349 | angle = torch.rand(1) * torch.pi * 2
350 | d_tloc = torch.sin(angle) * d_time
351 | d_floc = torch.cos(angle) * d_freq
352 | 
353 | return d_tloc, d_floc
354 | 
355 | def sample_d_tloc_d_floc_between_ellipses(sigma, scale):
356 | r1 = 1/(torch.pi * sigma)
357 | r2 = r1 * scale #1/(torch.pi * scale * sigma)
358 | d_freq = (r1 - r2) * torch.rand(1) + r2
359 | 
360 | r1 = 2 * sigma
361 | r2 = 2 * scale * sigma
362 | 
363 | d_time = (r1 - r2) * torch.rand(1) + r2
364 | 
365 | angle = torch.rand(1) * torch.pi * 2
366 | d_tloc = torch.sin(angle) * d_time
367 | d_floc = torch.cos(angle) * d_freq
368 | 
369 | return d_tloc, d_floc
370 | 
371 | def sample_around_optimal_ellipse(t_loc, f_loc, sigma, scale=2):
372 | d_tloc, d_floc = sample_d_tloc_d_floc_between_ellipses(sigma, scale)
373 | return t_loc + d_tloc, f_loc + d_floc
374 | 
375 | def sample_on_optimal_ellipse(t_loc, f_loc, sigma):
376 | d_tloc, d_floc = sample_d_tloc_d_floc_ellipse(sigma)
377 | return t_loc + d_tloc, f_loc + d_floc
378 | 
379 | 
380 | 
381 | 
382 | 
383 | 
384 | 
--------------------------------------------------------------------------------
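Editor's addendum: a minimal usage sketch (not part of the repository) of the core mechanism the code above implements. The Gaussian window length lambd of time_frequency.differentiable_spectrogram is an ordinary tensor, so it can be optimized by gradient descent; the toy objective below (maximizing the peak of the time-frequency image of a random signal) is only a stand-in for the classification losses used in train.py.

import torch
import time_frequency as tf

x = torch.randn(128)                             # toy 128-sample waveform
lambd = torch.tensor(6.38, requires_grad=True)   # trainable window parameter (sigma_ref init)
opt = torch.optim.SGD([lambd], lr=1.0)           # lr matches 'lr_tf' in search_spaces.py

for step in range(10):
    s = tf.differentiable_spectrogram(x, lambd, hop_length=1)  # (freq_bins, time_frames)
    loss = -s.max()      # toy objective: concentrate energy at the strongest TF point
    opt.zero_grad()
    loss.backward()      # gradient flows through the Gaussian window into lambd
    opt.step()

print("estimated lambd:", lambd.item())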