├── LICENSE ├── README.md ├── dataset ├── MIR1K │ └── info.csv └── pre_MIR1K.py ├── evaluate.py ├── src ├── __init__.py ├── constants.py ├── dataset.py ├── inference.py ├── loss.py ├── model.py ├── modules.py ├── seq.py └── utils.py ├── train_Base.py ├── train_DJCM.py └── train_MMOE.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DJCM 2 | 3 | This repo is the PyTorch implementation of ["DJCM: A Deep Joint Cascade Model for Singing Voice Separation and Vocal Pitch Estimation"](https://arxiv.org/abs/2401.03856). 
4 | -------------------------------------------------------------------------------- /dataset/MIR1K/info.csv: -------------------------------------------------------------------------------- 1 | name,duration,split 2 | Ani_1_01.wav,8.8,train 3 | Ani_1_02.wav,9.8880625,test 4 | Ani_1_03.wav,6.1440625,test 5 | Ani_1_04.wav,6.2720625,train 6 | Ani_1_05.wav,8.1280625,train 7 | Ani_1_06.wav,6.4640625,train 8 | Ani_1_07.wav,9.45975,train 9 | Ani_2_01.wav,6.272,test 10 | Ani_2_02.wav,9.2480625,train 11 | Ani_2_03.wav,7.072,train 12 | Ani_2_04.wav,6.3680625,train 13 | Ani_2_05.wav,11.36,test 14 | Ani_2_06.wav,10.8480625,train 15 | Ani_2_07.wav,7.514875,train 16 | Ani_3_01.wav,11.9038125,train 17 | Ani_3_02.wav,8.5438125,train 18 | Ani_3_03.wav,12.0,train 19 | Ani_3_04.wav,6.3680625,train 20 | Ani_3_05.wav,11.5838125,train 21 | Ani_3_06.wav,6.183875,train 22 | Ani_4_01.wav,8.96,train 23 | Ani_4_02.wav,7.8080625,train 24 | Ani_4_03.wav,8.9920625,train 25 | Ani_4_04.wav,6.1760625,train 26 | Ani_4_05.wav,9.7920625,train 27 | Ani_4_06.wav,8.0960625,train 28 | Ani_4_07.wav,8.1920625,train 29 | Ani_4_08.wav,6.4320625,train 30 | Ani_4_09.wav,6.3360625,train 31 | Ani_4_10.wav,7.2716875,train 32 | Ani_5_01.wav,6.944,train 33 | Ani_5_02.wav,8.4480625,test 34 | Ani_5_03.wav,6.2080625,train 35 | Ani_5_04.wav,8.0640625,train 36 | Ani_5_05.wav,6.0800625,train 37 | Ani_5_06.wav,6.0480625,train 38 | Ani_5_07.wav,6.9760625,train 39 | Ani_5_08.wav,6.6479375,test 40 | Kenshin_1_01.wav,7.2,train 41 | Kenshin_1_02.wav,10.4640625,train 42 | Kenshin_1_03.wav,7.0720625,train 43 | Kenshin_1_04.wav,10.6880625,train 44 | Kenshin_1_05.wav,11.2640625,train 45 | Kenshin_1_06.wav,7.4240625,train 46 | Kenshin_1_07.wav,11.3280625,train 47 | Kenshin_1_08.wav,11.4560625,train 48 | Kenshin_1_09.wav,6.3680625,train 49 | Kenshin_1_10.wav,11.1360625,train 50 | Kenshin_1_11.wav,6.9201875,train 51 | Kenshin_2_01.wav,8.192,train 52 | Kenshin_2_02.wav,7.3280625,train 53 | Kenshin_2_03.wav,8.2240625,train 54 | 
Kenshin_2_04.wav,7.2960625,train 55 | Kenshin_2_05.wav,8.2240625,train 56 | Kenshin_2_06.wav,9.2160625,train 57 | Kenshin_2_07.wav,9.2800625,train 58 | Kenshin_2_10.wav,4.9796875,train 59 | Kenshin_2_11.wav,7.115125,train 60 | Kenshin_3_01.wav,6.4,test 61 | Kenshin_3_02.wav,8.7680625,train 62 | Kenshin_3_03.wav,6.5920625,train 63 | Kenshin_3_04.wav,6.5280625,train 64 | Kenshin_3_05.wav,5.626125,train 65 | Kenshin_3_06.wav,7.42,train 66 | Kenshin_3_07.wav,6.6880625,train 67 | Kenshin_3_08.wav,11.641375,train 68 | Kenshin_4_01.wav,6.624,train 69 | Kenshin_4_02.wav,8.9920625,test 70 | Kenshin_4_03.wav,9.4400625,test 71 | Kenshin_4_04.wav,10.7200625,test 72 | Kenshin_4_05.wav,9.3120625,train 73 | Kenshin_4_06.wav,6.7200625,train 74 | Kenshin_4_07.wav,8.5120625,train 75 | Kenshin_4_08.wav,7.6800625,train 76 | Kenshin_4_09.wav,7.4240625,train 77 | Kenshin_4_10.wav,7.0400625,train 78 | Kenshin_4_11.wav,8.618875,train 79 | Kenshin_5_01.wav,6.112,test 80 | Kenshin_5_02.wav,6.7200625,train 81 | Kenshin_5_03.wav,9.4080625,test 82 | Kenshin_5_04.wav,10.8160625,test 83 | Kenshin_5_05.wav,6.7840625,train 84 | Kenshin_5_06.wav,6.4960625,train 85 | Kenshin_5_07.wav,6.9120625,train 86 | Kenshin_5_08.wav,7.6480625,train 87 | Kenshin_5_09.wav,10.2080625,train 88 | Kenshin_5_10.wav,9.7920625,test 89 | Kenshin_5_11.wav,6.5920625,test 90 | Kenshin_5_12.wav,10.3360625,train 91 | Kenshin_5_13.wav,11.2253125,train 92 | abjones_1_01.wav,11.616,test 93 | abjones_1_02.wav,6.8800625,test 94 | abjones_1_03.wav,7.1360625,train 95 | abjones_1_04.wav,6.5600625,test 96 | abjones_2_01.wav,6.56,train 97 | abjones_2_02.wav,8.3840625,train 98 | abjones_2_03.wav,10.1440625,train 99 | abjones_2_04.wav,6.4640625,train 100 | abjones_2_05.wav,6.2400625,test 101 | abjones_2_06.wav,7.781625,train 102 | abjones_2_07.wav,3.0784375,train 103 | abjones_2_08.wav,6.2080625,train 104 | abjones_2_09.wav,8.0320625,train 105 | abjones_2_10.wav,8.8640625,train 106 | abjones_2_11.wav,5.445625,train 107 | 
abjones_2_12.wav,5.7293125,train 108 | abjones_3_01.wav,8.896,train 109 | abjones_3_02.wav,7.2320625,test 110 | abjones_3_03.wav,7.7440625,train 111 | abjones_3_04.wav,6.7840625,train 112 | abjones_3_06.wav,7.4880625,train 113 | abjones_3_09.wav,8.7040625,train 114 | abjones_3_10.wav,6.8800625,train 115 | abjones_3_11.wav,8.0640625,train 116 | abjones_3_12.wav,7.2640625,train 117 | abjones_3_14.wav,6.8173125,train 118 | abjones_4_01.wav,7.008,train 119 | abjones_4_02.wav,8.2560625,train 120 | abjones_4_03.wav,11.0080625,train 121 | abjones_4_04.wav,11.2960625,train 122 | abjones_4_05.wav,7.5840625,test 123 | abjones_4_06.wav,7.2320625,train 124 | abjones_4_07.wav,7.5520625,train 125 | abjones_4_08.wav,7.3709375,test 126 | abjones_5_01.wav,8.48,train 127 | abjones_5_02.wav,7.4880625,train 128 | abjones_5_03.wav,7.0400625,train 129 | abjones_5_04.wav,7.3600625,test 130 | abjones_5_05.wav,11.7120625,train 131 | abjones_5_06.wav,5.575,train 132 | abjones_5_07.wav,6.602,test 133 | abjones_5_08.wav,7.3600625,train 134 | abjones_5_09.wav,8.4160625,test 135 | amy_10_01.wav,8.768,train 136 | amy_10_02.wav,10.3040625,train 137 | amy_10_03.wav,6.1760625,train 138 | amy_10_04.wav,6.3680625,train 139 | amy_10_05.wav,6.9440625,test 140 | amy_10_06.wav,8.2880625,test 141 | amy_10_07.wav,9.4080625,test 142 | amy_10_08.wav,8.1245,train 143 | amy_11_01.wav,10.592,train 144 | amy_11_02.wav,8.7040625,test 145 | amy_11_03.wav,6.7200625,train 146 | amy_11_04.wav,7.1040625,train 147 | amy_11_05.wav,11.4240625,train 148 | amy_11_06.wav,8.5440625,train 149 | amy_11_07.wav,8.580375,train 150 | amy_12_01.wav,10.112,train 151 | amy_12_02.wav,8.0640625,train 152 | amy_12_03.wav,8.6720625,train 153 | amy_12_04.wav,8.7680625,test 154 | amy_12_05.wav,8.4160625,test 155 | amy_12_06.wav,10.3680625,train 156 | amy_12_07.wav,8.471875,train 157 | amy_13_01.wav,10.816,train 158 | amy_13_02.wav,8.2560625,train 159 | amy_13_03.wav,10.5600625,train 160 | amy_13_04.wav,8.2880625,train 161 | 
amy_13_05.wav,6.2720625,train 162 | amy_13_06.wav,8.0320625,test 163 | amy_13_07.wav,6.71625,train 164 | amy_14_01.wav,11.936,train 165 | amy_14_02.wav,10.3360625,train 166 | amy_14_03.wav,6.9440625,train 167 | amy_14_04.wav,8.0320625,train 168 | amy_14_05.wav,10.7200625,train 169 | amy_14_06.wav,10.5526875,train 170 | amy_15_01.wav,8.736,train 171 | amy_15_02.wav,11.9680625,train 172 | amy_15_03.wav,6.9440625,train 173 | amy_15_04.wav,8.9280625,train 174 | amy_15_05.wav,11.3280625,test 175 | amy_15_06.wav,5.058125,test 176 | amy_15_07.wav,7.5853125,train 177 | amy_15_08.wav,7.1040625,train 178 | amy_15_09.wav,6.3360625,train 179 | amy_15_11.wav,8.6400625,train 180 | amy_15_12.wav,7.8080625,train 181 | amy_15_13.wav,9.4530625,train 182 | amy_16_01.wav,7.584,train 183 | amy_16_02.wav,6.5600625,train 184 | amy_16_03.wav,6.7200625,train 185 | amy_16_04.wav,10.9760625,train 186 | amy_16_05.wav,8.6720625,train 187 | amy_16_06.wav,6.1440625,test 188 | amy_16_07.wav,7.1360625,train 189 | amy_16_08.wav,9.1185625,test 190 | amy_1_01.wav,11.36,train 191 | amy_1_02.wav,9.8560625,train 192 | amy_1_03.wav,11.9040625,test 193 | amy_1_04.wav,8.7360625,train 194 | amy_1_05.wav,6.0764375,test 195 | amy_1_06.wav,4.8454375,train 196 | amy_1_07.wav,9.255125,train 197 | amy_2_01.wav,9.408,train 198 | amy_2_02.wav,9.0880625,test 199 | amy_2_03.wav,6.0480625,train 200 | amy_2_04.wav,8.9920625,train 201 | amy_2_05.wav,6.0480625,train 202 | amy_2_06.wav,8.7360625,train 203 | amy_2_07.wav,9.074125,train 204 | amy_3_01.wav,6.528,test 205 | amy_3_02.wav,9.1520625,train 206 | amy_3_03.wav,6.7840625,test 207 | amy_3_04.wav,11.5840625,train 208 | amy_3_05.wav,7.9680625,test 209 | amy_3_06.wav,8.3200625,test 210 | amy_3_07.wav,7.2320625,test 211 | amy_3_08.wav,6.433,train 212 | amy_4_01.wav,8.0,train 213 | amy_4_02.wav,6.9440625,train 214 | amy_4_03.wav,7.5840625,train 215 | amy_4_04.wav,6.3680625,train 216 | amy_4_05.wav,6.7200625,train 217 | amy_4_06.wav,6.5920625,train 218 | 
amy_4_07.wav,8.8640625,train 219 | amy_4_08.wav,7.3280625,train 220 | amy_4_09.wav,11.7120625,train 221 | amy_4_10.wav,10.2720625,train 222 | amy_4_11.wav,7.2181875,train 223 | amy_5_01.wav,11.328,train 224 | amy_5_02.wav,9.3760625,train 225 | amy_5_03.wav,7.4560625,train 226 | amy_5_04.wav,7.5200625,train 227 | amy_5_05.wav,6.0160625,train 228 | amy_5_06.wav,6.0480625,test 229 | amy_5_07.wav,10.6880625,test 230 | amy_5_08.wav,6.559625,train 231 | amy_6_01.wav,6.144,train 232 | amy_6_02.wav,6.0480625,train 233 | amy_6_03.wav,6.6560625,train 234 | amy_6_04.wav,6.1440625,train 235 | amy_6_05.wav,6.2400625,train 236 | amy_6_06.wav,6.1760625,train 237 | amy_6_07.wav,6.1120625,train 238 | amy_6_08.wav,7.6800625,test 239 | amy_6_09.wav,8.8320625,train 240 | amy_6_10.wav,11.8720625,train 241 | amy_6_11.wav,7.8765,train 242 | amy_7_01.wav,6.304,train 243 | amy_7_02.wav,11.6160625,train 244 | amy_7_03.wav,6.7840625,train 245 | amy_7_04.wav,6.2080625,train 246 | amy_7_05.wav,11.2640625,test 247 | amy_7_06.wav,7.8080625,train 248 | amy_7_07.wav,6.8160625,train 249 | amy_7_08.wav,7.79225,train 250 | amy_8_01.wav,10.88,train 251 | amy_8_02.wav,7.9360625,train 252 | amy_8_03.wav,10.8800625,test 253 | amy_8_04.wav,6.7840625,train 254 | amy_8_05.wav,8.8000625,train 255 | amy_8_06.wav,9.3120625,train 256 | amy_8_07.wav,7.5200625,train 257 | amy_8_08.wav,9.13375,test 258 | amy_9_01.wav,8.672,train 259 | amy_9_02.wav,8.3200625,test 260 | amy_9_03.wav,6.5600625,test 261 | amy_9_04.wav,9.8560625,test 262 | amy_9_05.wav,8.2560625,train 263 | amy_9_06.wav,6.4960625,test 264 | amy_9_07.wav,6.6240625,train 265 | amy_9_08.wav,6.1440625,train 266 | amy_9_09.wav,6.2918125,train 267 | annar_1_01.wav,6.752,test 268 | annar_1_02.wav,9.7280625,test 269 | annar_1_03.wav,6.8800625,test 270 | annar_1_04.wav,5.2160625,train 271 | annar_1_05.wav,9.9520625,train 272 | annar_1_06.wav,4.9920625,test 273 | annar_1_07.wav,6.5920625,train 274 | annar_1_08.wav,6.8104375,test 275 | annar_2_01.wav,5.312,train 
276 | annar_2_02.wav,8.3520625,test 277 | annar_2_03.wav,9.6960625,train 278 | annar_2_04.wav,8.0320625,train 279 | annar_2_05.wav,9.6000625,train 280 | annar_2_06.wav,8.8320625,train 281 | annar_2_07.wav,5.3120625,train 282 | annar_2_08.wav,6.8118125,test 283 | annar_3_01.wav,6.272,test 284 | annar_3_02.wav,5.9520625,train 285 | annar_3_03.wav,5.8560625,train 286 | annar_3_04.wav,5.8560625,test 287 | annar_3_05.wav,5.3120625,train 288 | annar_3_06.wav,6.3680625,train 289 | annar_3_07.wav,5.2800625,train 290 | annar_3_08.wav,5.22625,train 291 | annar_4_01.wav,7.648,train 292 | annar_4_02.wav,5.1200625,train 293 | annar_4_03.wav,9.8880625,train 294 | annar_4_04.wav,6.0160625,train 295 | annar_4_05.wav,5.2160625,train 296 | annar_4_06.wav,9.4080625,test 297 | annar_4_07.wav,5.6320625,train 298 | annar_4_08.wav,5.3120625,train 299 | annar_4_09.wav,9.99525,test 300 | annar_5_01.wav,5.472,test 301 | annar_5_02.wav,6.6240625,train 302 | annar_5_03.wav,5.9520625,train 303 | annar_5_04.wav,9.5040625,train 304 | annar_5_05.wav,5.5360625,train 305 | annar_5_06.wav,9.5680625,train 306 | annar_5_07.wav,6.3360625,train 307 | annar_5_08.wav,6.4640625,train 308 | annar_5_09.wav,6.7153125,train 309 | ariel_1_01.wav,9.216,test 310 | ariel_1_02.wav,9.6000625,train 311 | ariel_1_03.wav,10.1440625,train 312 | ariel_1_04.wav,7.5520625,test 313 | ariel_1_05.wav,7.9360625,train 314 | ariel_1_06.wav,10.1440625,train 315 | ariel_1_07.wav,6.6065625,train 316 | ariel_2_01.wav,6.688,train 317 | ariel_2_02.wav,6.5600625,train 318 | ariel_2_03.wav,7.0400625,train 319 | ariel_2_04.wav,6.6880625,train 320 | ariel_2_05.wav,6.5600625,train 321 | ariel_2_06.wav,10.0160625,train 322 | ariel_2_07.wav,10.2400625,test 323 | ariel_2_08.wav,7.5840625,train 324 | ariel_2_09.wav,6.85625,train 325 | ariel_3_01.wav,9.184,test 326 | ariel_3_02.wav,9.4080625,train 327 | ariel_3_03.wav,8.6400625,train 328 | ariel_3_04.wav,8.4480625,train 329 | ariel_3_05.wav,7.1360625,train 330 | ariel_3_06.wav,6.5280625,train 
331 | ariel_3_07.wav,7.3920625,test 332 | ariel_3_08.wav,8.0024375,train 333 | ariel_4_01.wav,7.584,train 334 | ariel_4_02.wav,6.930375,train 335 | ariel_4_03.wav,3.633375,train 336 | ariel_4_04.wav,6.7840625,train 337 | ariel_4_05.wav,9.6640625,train 338 | ariel_4_06.wav,11.7120625,train 339 | ariel_4_07.wav,10.6880625,train 340 | ariel_4_08.wav,10.1453125,train 341 | ariel_5_01.wav,6.72,train 342 | ariel_5_02.wav,11.2320625,train 343 | ariel_5_03.wav,6.1120625,train 344 | ariel_5_04.wav,6.0160625,train 345 | ariel_5_05.wav,6.2080625,train 346 | ariel_5_06.wav,6.9440625,train 347 | ariel_5_07.wav,8.1600625,train 348 | ariel_5_08.wav,8.5086875,train 349 | bobon_1_01.wav,6.976,test 350 | bobon_1_02.wav,6.2400625,test 351 | bobon_1_03.wav,7.5840625,test 352 | bobon_1_04.wav,6.1440625,test 353 | bobon_1_05.wav,7.4240625,train 354 | bobon_1_06.wav,6.9120625,train 355 | bobon_1_07.wav,6.3680625,train 356 | bobon_1_08.wav,11.8400625,train 357 | bobon_1_09.wav,7.2320625,train 358 | bobon_1_10.wav,8.1156875,train 359 | bobon_2_01.wav,11.648,train 360 | bobon_2_02.wav,10.4640625,train 361 | bobon_2_03.wav,11.0400625,test 362 | bobon_2_04.wav,7.4240625,train 363 | bobon_2_05.wav,6.6880625,train 364 | bobon_2_06.wav,8.0640625,train 365 | bobon_2_07.wav,6.8160625,train 366 | bobon_2_08.wav,8.622625,train 367 | bobon_3_01.wav,6.4,test 368 | bobon_3_02.wav,6.6240625,train 369 | bobon_3_03.wav,11.9680625,test 370 | bobon_3_04.wav,10.4320625,train 371 | bobon_3_05.wav,6.5600625,train 372 | bobon_3_06.wav,6.0480625,train 373 | bobon_3_07.wav,6.6560625,train 374 | bobon_3_08.wav,4.4394375,train 375 | bobon_3_09.wav,5.75025,train 376 | bobon_3_10.wav,10.01425,test 377 | bobon_4_01.wav,6.304,train 378 | bobon_4_02.wav,8.0000625,train 379 | bobon_4_03.wav,7.2320625,train 380 | bobon_4_04.wav,6.0160625,test 381 | bobon_4_05.wav,8.7040625,train 382 | bobon_4_06.wav,8.1920625,train 383 | bobon_4_07.wav,8.7040625,train 384 | bobon_4_08.wav,10.7520625,train 385 | 
bobon_4_09.wav,6.0160625,train 386 | bobon_4_10.wav,6.6359375,test 387 | bobon_5_01.wav,6.656,train 388 | bobon_5_02.wav,6.4640625,train 389 | bobon_5_03.wav,6.8800625,train 390 | bobon_5_04.wav,6.0480625,test 391 | bobon_5_05.wav,9.7920625,train 392 | bobon_5_06.wav,9.8560625,train 393 | bobon_5_07.wav,6.6560625,test 394 | bobon_5_08.wav,8.6400625,train 395 | bobon_5_09.wav,3.8924375,train 396 | bobon_5_10.wav,9.277125,train 397 | bobon_5_11.wav,8.8320625,train 398 | bobon_5_12.wav,6.449625,train 399 | bug_1_07.wav,7.9360625,train 400 | bug_1_08.wav,8.0960625,test 401 | bug_1_09.wav,7.2960625,train 402 | bug_1_10.wav,9.261875,test 403 | bug_2_01.wav,7.04,test 404 | bug_2_02.wav,6.3040625,train 405 | bug_2_03.wav,7.1040625,train 406 | bug_2_04.wav,8.1280625,train 407 | bug_2_05.wav,8.4800625,test 408 | bug_2_06.wav,7.0400625,train 409 | bug_2_07.wav,6.1760625,test 410 | bug_2_08.wav,9.1789375,train 411 | bug_3_01.wav,10.336,train 412 | bug_3_02.wav,6.4320625,train 413 | bug_3_03.wav,9.5040625,train 414 | bug_3_04.wav,10.7520625,test 415 | bug_3_05.wav,10.1440625,train 416 | bug_3_06.wav,10.7840625,train 417 | bug_3_07.wav,11.7120625,train 418 | bug_3_08.wav,9.6000625,train 419 | bug_3_09.wav,11.2165,train 420 | bug_4_01.wav,8.8,train 421 | bug_4_02.wav,7.6156875,train 422 | bug_4_03.wav,6.1938125,train 423 | bug_5_01.wav,6.08,train 424 | bug_5_02.wav,11.9676875,test 425 | bug_5_03.wav,7.1356875,test 426 | bug_5_04.wav,10.56,train 427 | bug_5_05.wav,6.3356875,train 428 | bug_5_06.wav,6.113875,train 429 | bug_5_07.wav,4.3283125,train 430 | bug_5_08.wav,10.624,train 431 | bug_5_09.wav,6.272,train 432 | bug_5_10.wav,8.2876875,train 433 | bug_5_11.wav,7.264,train 434 | bug_5_12.wav,8.1276875,train 435 | bug_5_13.wav,6.3676875,test 436 | bug_5_14.wav,6.562125,train 437 | davidson_1_01.wav,7.68,train 438 | davidson_1_02.wav,7.6480625,train 439 | davidson_1_03.wav,6.2080625,train 440 | davidson_1_04.wav,7.9360625,train 441 | davidson_1_05.wav,9.6640625,test 442 | 
davidson_1_06.wav,10.9760625,train 443 | davidson_1_07.wav,7.6160625,train 444 | davidson_1_08.wav,6.4000625,train 445 | davidson_1_09.wav,7.6800625,train 446 | davidson_1_10.wav,8.9616875,train 447 | davidson_2_01.wav,7.648,train 448 | davidson_2_02.wav,6.8480625,train 449 | davidson_2_03.wav,6.7840625,train 450 | davidson_2_04.wav,9.7600625,test 451 | davidson_2_05.wav,9.6640625,train 452 | davidson_2_06.wav,8.6720625,train 453 | davidson_2_07.wav,11.0720625,train 454 | davidson_2_08.wav,8.1920625,train 455 | davidson_2_09.wav,6.0160625,train 456 | davidson_2_10.wav,10.9299375,train 457 | davidson_3_01.wav,10.24,train 458 | davidson_3_02.wav,6.2720625,train 459 | davidson_3_03.wav,6.9440625,train 460 | davidson_3_04.wav,10.5920625,train 461 | davidson_3_05.wav,6.5280625,train 462 | davidson_3_06.wav,6.6240625,train 463 | davidson_3_07.wav,7.8720625,train 464 | davidson_3_08.wav,6.6240625,train 465 | davidson_3_09.wav,9.8880625,train 466 | davidson_3_10.wav,10.0480625,test 467 | davidson_3_11.wav,6.5920625,train 468 | davidson_3_12.wav,9.9840625,train 469 | davidson_3_13.wav,6.5600625,train 470 | davidson_3_14.wav,6.05025,train 471 | davidson_4_01.wav,8.864,train 472 | davidson_4_02.wav,8.6080625,train 473 | davidson_4_03.wav,8.7040625,train 474 | davidson_4_04.wav,6.9120625,train 475 | davidson_4_05.wav,8.7680625,train 476 | davidson_4_06.wav,8.6720625,train 477 | davidson_4_07.wav,8.9496875,train 478 | davidson_5_01.wav,7.872,train 479 | davidson_5_02.wav,6.6240625,train 480 | davidson_5_03.wav,7.2000625,train 481 | davidson_5_04.wav,7.6160625,train 482 | davidson_5_05.wav,7.1040625,train 483 | davidson_5_06.wav,11.1680625,test 484 | davidson_5_07.wav,6.2400625,test 485 | davidson_5_08.wav,10.9760625,train 486 | davidson_5_09.wav,6.7840625,train 487 | davidson_5_10.wav,7.5520625,train 488 | davidson_5_11.wav,8.405375,train 489 | fdps_1_01.wav,8.0,train 490 | fdps_1_02.wav,6.4000625,test 491 | fdps_1_03.wav,6.4000625,test 492 | fdps_1_04.wav,6.1120625,test 493 | 
fdps_1_05.wav,7.9360625,train 494 | fdps_1_06.wav,7.6800625,train 495 | fdps_1_07.wav,6.6560625,train 496 | fdps_1_08.wav,9.6320625,test 497 | fdps_1_09.wav,7.6480625,train 498 | fdps_1_10.wav,11.2000625,test 499 | fdps_1_11.wav,9.4400625,train 500 | fdps_1_12.wav,6.5600625,train 501 | fdps_1_13.wav,7.0400625,train 502 | fdps_1_14.wav,7.718,train 503 | fdps_2_01.wav,7.584,test 504 | fdps_2_02.wav,7.2000625,train 505 | fdps_2_03.wav,7.2640625,train 506 | fdps_2_04.wav,6.9760625,test 507 | fdps_2_05.wav,7.3600625,train 508 | fdps_2_06.wav,6.2720625,train 509 | fdps_2_07.wav,8.7360625,train 510 | fdps_2_08.wav,6.2720625,train 511 | fdps_2_09.wav,7.0080625,train 512 | fdps_2_10.wav,6.6560625,test 513 | fdps_2_11.wav,7.8400625,train 514 | fdps_2_12.wav,7.7394375,test 515 | fdps_3_01.wav,9.632,train 516 | fdps_3_02.wav,9.8880625,train 517 | fdps_3_03.wav,8.1280625,train 518 | fdps_3_04.wav,6.0800625,test 519 | fdps_3_05.wav,6.0480625,train 520 | fdps_3_06.wav,6.9760625,train 521 | fdps_3_07.wav,7.1410625,test 522 | fdps_4_01.wav,10.464,train 523 | fdps_4_02.wav,10.9120625,train 524 | fdps_4_03.wav,10.7840625,train 525 | fdps_4_04.wav,11.1360625,train 526 | fdps_4_05.wav,9.4080625,train 527 | fdps_4_06.wav,11.2879375,train 528 | fdps_5_01.wav,9.376,test 529 | fdps_5_02.wav,11.6480625,test 530 | fdps_5_03.wav,10.2400625,train 531 | fdps_5_04.wav,9.1520625,train 532 | fdps_5_05.wav,7.6480625,train 533 | fdps_5_06.wav,10.4640625,train 534 | fdps_5_07.wav,9.9520625,train 535 | fdps_5_08.wav,8.5440625,train 536 | fdps_5_09.wav,7.6653125,train 537 | geniusturtle_1_01.wav,11.328,train 538 | geniusturtle_1_02.wav,11.4560625,test 539 | geniusturtle_1_03.wav,10.5600625,train 540 | geniusturtle_1_04.wav,6.1440625,train 541 | geniusturtle_1_05.wav,11.3280625,test 542 | geniusturtle_1_06.wav,6.0160625,train 543 | geniusturtle_1_07.wav,10.8800625,train 544 | geniusturtle_1_08.wav,5.719625,test 545 | geniusturtle_1_09.wav,6.816625,train 546 | geniusturtle_2_01.wav,10.336,train 547 | 
geniusturtle_2_02.wav,9.5360625,train 548 | geniusturtle_2_05.wav,11.8720625,train 549 | geniusturtle_2_06.wav,9.7920625,train 550 | geniusturtle_2_07.wav,4.822,train 551 | geniusturtle_2_08.wav,9.644,train 552 | geniusturtle_3_01.wav,7.136,train 553 | geniusturtle_3_02.wav,6.5280625,train 554 | geniusturtle_3_03.wav,8.6720625,train 555 | geniusturtle_3_04.wav,6.8160625,train 556 | geniusturtle_3_05.wav,8.4480625,train 557 | geniusturtle_3_06.wav,6.4320625,test 558 | geniusturtle_3_07.wav,7.4560625,train 559 | geniusturtle_3_08.wav,8.97025,train 560 | geniusturtle_4_01.wav,7.2,train 561 | geniusturtle_4_02.wav,6.5920625,train 562 | geniusturtle_4_03.wav,7.3600625,train 563 | geniusturtle_4_04.wav,6.6560625,train 564 | geniusturtle_4_05.wav,7.5200625,train 565 | geniusturtle_4_06.wav,6.6240625,test 566 | geniusturtle_4_07.wav,7.0720625,train 567 | geniusturtle_4_08.wav,7.1680625,train 568 | geniusturtle_4_09.wav,7.3600625,train 569 | geniusturtle_4_10.wav,6.2400625,train 570 | geniusturtle_4_11.wav,8.2240625,test 571 | geniusturtle_4_12.wav,7.2124375,train 572 | geniusturtle_5_01.wav,10.88,test 573 | geniusturtle_5_02.wav,10.9440625,train 574 | geniusturtle_5_03.wav,10.6880625,train 575 | geniusturtle_5_04.wav,11.4575,train 576 | geniusturtle_6_01.wav,10.336,train 577 | geniusturtle_6_02.wav,10.2720625,train 578 | geniusturtle_6_03.wav,10.2400625,train 579 | geniusturtle_6_04.wav,9.6320625,train 580 | geniusturtle_6_05.wav,10.0480625,train 581 | geniusturtle_6_06.wav,8.2560625,test 582 | geniusturtle_6_07.wav,9.0395625,train 583 | geniusturtle_7_01.wav,7.712,train 584 | geniusturtle_7_02.wav,8.0640625,train 585 | geniusturtle_7_03.wav,7.2000625,train 586 | geniusturtle_7_04.wav,7.7760625,train 587 | geniusturtle_7_05.wav,8.2240625,train 588 | geniusturtle_7_06.wav,11.1680625,train 589 | geniusturtle_7_07.wav,11.4880625,test 590 | geniusturtle_7_08.wav,8.1920625,train 591 | geniusturtle_7_09.wav,7.5840625,train 592 | geniusturtle_7_10.wav,8.2240625,train 593 | 
geniusturtle_7_11.wav,7.1040625,train 594 | geniusturtle_7_12.wav,8.0960625,train 595 | geniusturtle_7_13.wav,7.4560625,train 596 | geniusturtle_7_14.wav,8.3200625,train 597 | geniusturtle_7_15.wav,10.1209375,train 598 | geniusturtle_8_01.wav,6.432,train 599 | geniusturtle_8_02.wav,6.4320625,train 600 | geniusturtle_8_03.wav,8.9920625,train 601 | geniusturtle_8_04.wav,7.9680625,train 602 | geniusturtle_8_05.wav,9.7920625,train 603 | geniusturtle_8_06.wav,9.5040625,train 604 | geniusturtle_8_07.wav,7.3920625,train 605 | geniusturtle_8_08.wav,8.313625,train 606 | heycat_1_01.wav,7.753375,train 607 | heycat_1_02.wav,5.7328125,train 608 | heycat_1_03.wav,9.6640625,train 609 | heycat_1_04.wav,7.0080625,train 610 | heycat_1_05.wav,8.1280625,train 611 | heycat_1_06.wav,7.2000625,train 612 | heycat_1_07.wav,6.1120625,train 613 | heycat_1_08.wav,10.0306875,train 614 | heycat_2_01.wav,12.032,train 615 | heycat_2_02.wav,6.0800625,test 616 | heycat_2_03.wav,7.7440625,test 617 | heycat_2_04.wav,6.5600625,train 618 | heycat_2_05.wav,7.9680625,train 619 | heycat_2_06.wav,11.5520625,train 620 | heycat_2_07.wav,8.189125,train 621 | heycat_3_01.wav,9.152,train 622 | heycat_3_02.wav,6.7520625,train 623 | heycat_3_03.wav,7.1680625,train 624 | heycat_3_04.wav,10.4640625,train 625 | heycat_3_05.wav,6.7200625,test 626 | heycat_3_06.wav,6.4000625,train 627 | heycat_3_07.wav,9.8560625,train 628 | heycat_3_08.wav,6.691625,train 629 | heycat_4_01.wav,6.496,train 630 | heycat_4_02.wav,6.4640625,train 631 | heycat_4_03.wav,7.8400625,train 632 | heycat_4_04.wav,9.9840625,train 633 | heycat_4_05.wav,7.0080625,train 634 | heycat_4_06.wav,7.0080625,test 635 | heycat_4_07.wav,7.2640625,train 636 | heycat_4_08.wav,7.4560625,train 637 | heycat_4_09.wav,6.3959375,train 638 | heycat_5_01.wav,10.528,train 639 | heycat_5_02.wav,6.3360625,test 640 | heycat_5_03.wav,6.4640625,test 641 | heycat_5_04.wav,6.4960625,test 642 | heycat_5_05.wav,6.3680625,train 643 | heycat_5_06.wav,6.6880625,train 644 | 
heycat_5_07.wav,9.9520625,train 645 | heycat_5_08.wav,10.4515625,train 646 | jmzen_1_01.wav,9.248,train 647 | jmzen_1_02.wav,6.5600625,test 648 | jmzen_1_03.wav,11.4560625,train 649 | jmzen_1_04.wav,7.3280625,train 650 | jmzen_1_05.wav,7.4560625,train 651 | jmzen_1_07.wav,11.4240625,train 652 | jmzen_1_08.wav,6.1440625,train 653 | jmzen_1_09.wav,6.9440625,train 654 | jmzen_1_10.wav,4.8209375,train 655 | jmzen_1_11.wav,7.45875,train 656 | jmzen_1_12.wav,7.837,train 657 | jmzen_2_01.wav,6.144,test 658 | jmzen_2_02.wav,6.0480625,train 659 | jmzen_2_03.wav,6.7520625,test 660 | jmzen_2_04.wav,12.0000625,train 661 | jmzen_2_06.wav,5.3738125,train 662 | jmzen_2_07.wav,7.101125,train 663 | jmzen_2_08.wav,11.4880625,test 664 | jmzen_2_09.wav,11.6160625,train 665 | jmzen_2_10.wav,6.3360625,train 666 | jmzen_2_11.wav,8.8000625,train 667 | jmzen_2_12.wav,9.911875,test 668 | jmzen_3_01.wav,6.624,train 669 | jmzen_3_02.wav,8.1600625,train 670 | jmzen_3_03.wav,7.2000625,train 671 | jmzen_3_04.wav,6.4320625,train 672 | jmzen_3_05.wav,9.5040625,test 673 | jmzen_3_06.wav,6.1760625,test 674 | jmzen_3_07.wav,10.6240625,train 675 | jmzen_3_08.wav,7.6160625,train 676 | jmzen_3_09.wav,8.7040625,train 677 | jmzen_3_10.wav,11.3920625,train 678 | jmzen_3_11.wav,6.975125,train 679 | jmzen_4_01.wav,10.88,train 680 | jmzen_4_02.wav,10.7520625,train 681 | jmzen_4_03.wav,11.4560625,train 682 | jmzen_4_04.wav,10.1120625,train 683 | jmzen_4_05.wav,10.0480625,train 684 | jmzen_4_06.wav,8.3840625,train 685 | jmzen_4_07.wav,6.4320625,test 686 | jmzen_4_08.wav,7.2320625,train 687 | jmzen_4_09.wav,8.4160625,test 688 | jmzen_4_10.wav,11.0215,train 689 | jmzen_5_01.wav,9.088,train 690 | jmzen_5_02.wav,7.2320625,test 691 | jmzen_5_03.wav,6.8480625,train 692 | jmzen_5_04.wav,9.3440625,train 693 | jmzen_5_05.wav,7.0080625,test 694 | jmzen_5_06.wav,8.7360625,train 695 | jmzen_5_07.wav,8.6080625,train 696 | jmzen_5_08.wav,7.9680625,train 697 | jmzen_5_09.wav,10.7434375,train 698 | khair_1_01.wav,6.592,train 
699 | khair_1_02.wav,6.8800625,train 700 | khair_1_03.wav,7.2320625,train 701 | khair_1_04.wav,10.5280625,train 702 | khair_1_05.wav,8.5760625,train 703 | khair_1_06.wav,6.7840625,test 704 | khair_1_07.wav,7.0080625,train 705 | khair_1_08.wav,7.1005,train 706 | khair_2_01.wav,6.4,train 707 | khair_2_02.wav,7.2320625,train 708 | khair_2_03.wav,10.3360625,train 709 | khair_2_04.wav,10.6560625,train 710 | khair_2_05.wav,9.1520625,train 711 | khair_2_06.wav,6.8800625,test 712 | khair_2_07.wav,10.3839375,train 713 | khair_3_01.wav,7.488,train 714 | khair_3_02.wav,7.5200625,train 715 | khair_3_03.wav,6.5920625,train 716 | khair_3_04.wav,11.0720625,train 717 | khair_3_05.wav,11.8720625,train 718 | khair_3_06.wav,8.2880625,train 719 | khair_3_07.wav,8.5959375,train 720 | khair_4_01.wav,8.416,train 721 | khair_4_02.wav,6.2400625,train 722 | khair_4_03.wav,8.1920625,train 723 | khair_4_04.wav,7.4880625,test 724 | khair_4_05.wav,8.6080625,test 725 | khair_4_06.wav,6.2080625,train 726 | khair_4_07.wav,8.9280625,train 727 | khair_4_08.wav,7.2344375,train 728 | khair_5_01.wav,8.704,train 729 | khair_5_02.wav,7.9680625,test 730 | khair_5_03.wav,7.5840625,train 731 | khair_5_04.wav,8.9280625,train 732 | khair_5_05.wav,11.4560625,test 733 | khair_5_06.wav,7.9360625,train 734 | khair_5_07.wav,8.228,train 735 | khair_6_01.wav,11.84,train 736 | khair_6_02.wav,6.6240625,train 737 | khair_6_03.wav,8.4800625,train 738 | khair_6_04.wav,7.2000625,train 739 | khair_6_05.wav,6.3040625,train 740 | khair_6_06.wav,10.1120625,train 741 | khair_6_07.wav,10.2231875,train 742 | leon_1_01.wav,7.168,train 743 | leon_1_02.wav,6.5600625,train 744 | leon_1_03.wav,8.0640625,train 745 | leon_1_04.wav,8.2240625,train 746 | leon_1_05.wav,6.5920625,test 747 | leon_1_06.wav,6.2720625,train 748 | leon_1_07.wav,8.0000625,train 749 | leon_1_08.wav,8.0640625,train 750 | leon_1_09.wav,7.0400625,train 751 | leon_1_10.wav,6.4960625,train 752 | leon_1_11.wav,9.2480625,train 753 | leon_1_12.wav,10.3766875,test 754 | 
leon_2_01.wav,6.496,train 755 | leon_2_02.wav,7.4240625,test 756 | leon_2_03.wav,9.1200625,train 757 | leon_2_04.wav,10.7200625,train 758 | leon_2_05.wav,6.4960625,train 759 | leon_2_06.wav,6.52325,train 760 | leon_2_07.wav,6.764875,train 761 | leon_2_08.wav,6.6038125,train 762 | leon_2_09.wav,6.09725,train 763 | leon_2_10.wav,7.4560625,train 764 | leon_2_11.wav,7.69875,train 765 | leon_3_01.wav,7.072,train 766 | leon_3_02.wav,6.1440625,test 767 | leon_3_03.wav,7.6480625,train 768 | leon_3_04.wav,6.0160625,train 769 | leon_3_05.wav,6.6560625,train 770 | leon_3_06.wav,7.8400625,train 771 | leon_3_07.wav,7.7760625,train 772 | leon_3_08.wav,7.9360625,train 773 | leon_3_09.wav,9.0560625,train 774 | leon_3_10.wav,11.9360625,train 775 | leon_3_11.wav,7.5520625,test 776 | leon_3_12.wav,7.7760625,train 777 | leon_3_13.wav,10.392875,test 778 | leon_4_01.wav,10.144,train 779 | leon_4_02.wav,8.1280625,train 780 | leon_4_03.wav,11.1360625,train 781 | leon_4_04.wav,8.4800625,train 782 | leon_4_05.wav,9.2480625,train 783 | leon_4_06.wav,8.6080625,train 784 | leon_4_07.wav,10.9760625,train 785 | leon_4_08.wav,7.782875,train 786 | leon_5_01.wav,6.304,train 787 | leon_5_02.wav,6.2400625,train 788 | leon_5_03.wav,6.0800625,train 789 | leon_5_04.wav,6.4000625,train 790 | leon_5_05.wav,6.2491875,train 791 | leon_5_06.wav,5.79975,train 792 | leon_5_07.wav,6.7520625,train 793 | leon_5_08.wav,7.7440625,train 794 | leon_5_09.wav,6.9440625,train 795 | leon_5_10.wav,6.9120625,train 796 | leon_5_11.wav,7.2640625,train 797 | leon_5_12.wav,6.1249375,train 798 | leon_6_01.wav,8.256,train 799 | leon_6_02.wav,9.9520625,train 800 | leon_6_03.wav,9.8240625,train 801 | leon_6_04.wav,8.6720625,train 802 | leon_6_05.wav,9.7920625,train 803 | leon_6_06.wav,7.6160625,test 804 | leon_6_07.wav,6.0160625,test 805 | leon_6_08.wav,6.9440625,train 806 | leon_6_09.wav,6.0548125,train 807 | leon_7_01.wav,8.096,train 808 | leon_7_02.wav,6.1440625,train 809 | leon_7_03.wav,7.8400625,train 810 | 
leon_7_04.wav,7.4880625,train 811 | leon_7_05.wav,9.0880625,test 812 | leon_7_06.wav,9.1200625,train 813 | leon_7_07.wav,7.1360625,train 814 | leon_7_08.wav,8.2240625,test 815 | leon_7_09.wav,6.2080625,test 816 | leon_7_10.wav,8.9600625,train 817 | leon_7_11.wav,7.9360625,train 818 | leon_7_12.wav,8.4800625,test 819 | leon_7_13.wav,11.92175,train 820 | leon_8_01.wav,7.136,train 821 | leon_8_02.wav,6.2400625,train 822 | leon_8_03.wav,6.8480625,train 823 | leon_8_04.wav,9.1200625,test 824 | leon_8_05.wav,6.4640625,train 825 | leon_8_06.wav,6.1760625,test 826 | leon_8_07.wav,9.6320625,train 827 | leon_8_08.wav,7.0400625,train 828 | leon_8_09.wav,6.2400625,train 829 | leon_8_10.wav,11.4240625,train 830 | leon_8_11.wav,7.6160625,test 831 | leon_8_12.wav,9.6320625,train 832 | leon_8_13.wav,10.1693125,test 833 | leon_9_01.wav,10.048,test 834 | leon_9_02.wav,9.5680625,train 835 | leon_9_03.wav,10.5280625,test 836 | leon_9_04.wav,9.7280625,train 837 | leon_9_05.wav,10.3360625,train 838 | leon_9_06.wav,10.5435,train 839 | stool_1_01.wav,7.296,train 840 | stool_1_02.wav,6.9440625,train 841 | stool_1_03.wav,9.4720625,train 842 | stool_1_04.wav,6.7520625,train 843 | stool_1_05.wav,9.5360625,train 844 | stool_1_06.wav,6.7200625,train 845 | stool_1_07.wav,9.0880625,train 846 | stool_1_08.wav,7.5175625,train 847 | stool_1_09.wav,4.034125,train 848 | stool_2_01.wav,9.024,train 849 | stool_2_02.wav,7.1680625,train 850 | stool_2_03.wav,9.4080625,test 851 | stool_2_04.wav,7.4240625,train 852 | stool_2_05.wav,9.2480625,test 853 | stool_2_06.wav,6.6880625,test 854 | stool_2_07.wav,8.5440625,train 855 | stool_2_08.wav,8.18225,test 856 | stool_3_01.wav,7.648,train 857 | stool_3_02.wav,11.6480625,test 858 | stool_3_03.wav,9.6640625,test 859 | stool_3_04.wav,6.1760625,train 860 | stool_3_05.wav,6.1440625,train 861 | stool_3_06.wav,6.4320625,train 862 | stool_3_07.wav,7.4240625,train 863 | stool_3_08.wav,11.7760625,test 864 | stool_3_09.wav,8.3200625,train 865 | stool_3_10.wav,7.1145625,test 
866 | stool_4_01.wav,6.048,test 867 | stool_4_02.wav,6.0480625,train 868 | stool_4_03.wav,10.5600625,train 869 | stool_4_04.wav,8.4480625,test 870 | stool_4_05.wav,6.2400625,train 871 | stool_4_06.wav,5.1151875,train 872 | stool_4_07.wav,5.38425,test 873 | stool_4_08.wav,11.4560625,train 874 | stool_4_09.wav,11.3600625,train 875 | stool_4_10.wav,11.68025,train 876 | stool_5_01.wav,11.008,train 877 | stool_5_02.wav,6.9760625,train 878 | stool_5_03.wav,7.3920625,train 879 | stool_5_04.wav,10.3680625,train 880 | stool_5_05.wav,11.0720625,test 881 | stool_5_06.wav,11.4560625,test 882 | stool_5_07.wav,6.9760625,train 883 | stool_5_08.wav,7.30025,train 884 | tammy_1_01.wav,11.232,train 885 | tammy_1_02.wav,9.6640625,train 886 | tammy_1_03.wav,7.3280625,train 887 | tammy_1_04.wav,7.5200625,train 888 | tammy_1_05.wav,6.4320625,train 889 | tammy_1_06.wav,7.8400625,train 890 | tammy_1_07.wav,6.6880625,train 891 | tammy_1_08.wav,6.5968125,train 892 | titon_1_01.wav,8.032,test 893 | titon_1_02.wav,7.0720625,test 894 | titon_1_03.wav,11.6480625,train 895 | titon_1_04.wav,7.7120625,train 896 | titon_1_05.wav,6.2080625,train 897 | titon_1_06.wav,6.1120625,train 898 | titon_1_07.wav,8.7360625,train 899 | titon_1_08.wav,9.6835,test 900 | titon_2_01.wav,10.72,train 901 | titon_2_02.wav,9.9840625,test 902 | titon_2_03.wav,10.2080625,test 903 | titon_2_04.wav,7.4560625,train 904 | titon_2_05.wav,6.8160625,train 905 | titon_2_06.wav,6.7200625,train 906 | titon_2_07.wav,10.5280625,train 907 | titon_2_08.wav,10.5920625,train 908 | titon_2_09.wav,10.8959375,train 909 | titon_3_01.wav,8.0,train 910 | titon_3_02.wav,11.9360625,test 911 | titon_3_03.wav,11.7440625,train 912 | titon_3_04.wav,10.2720625,train 913 | titon_3_05.wav,6.1760625,train 914 | titon_3_06.wav,7.2000625,test 915 | titon_3_07.wav,10.5600625,test 916 | titon_3_08.wav,6.6458125,train 917 | titon_4_01.wav,6.816,train 918 | titon_4_02.wav,7.3600625,test 919 | titon_4_03.wav,8.3520625,train 920 | titon_4_04.wav,10.8800625,test 
921 | titon_4_05.wav,6.2400625,train 922 | titon_4_06.wav,7.6160625,train 923 | titon_4_07.wav,11.0080625,train 924 | titon_4_08.wav,7.2960625,train 925 | titon_4_09.wav,9.1520625,train 926 | titon_4_10.wav,10.4640625,train 927 | titon_4_11.wav,7.0416875,train 928 | titon_5_01.wav,6.016,train 929 | titon_5_02.wav,6.2720625,test 930 | titon_5_03.wav,7.6160625,train 931 | titon_5_04.wav,6.7200625,train 932 | titon_5_05.wav,11.5200625,train 933 | titon_5_06.wav,7.3600625,test 934 | titon_5_07.wav,5.4335,train 935 | titon_5_08.wav,6.1938125,train 936 | titon_5_09.wav,6.5236875,train 937 | yifen_1_01.wav,5.088,train 938 | yifen_1_02.wav,6.9120625,train 939 | yifen_1_03.wav,5.5040625,train 940 | yifen_1_04.wav,5.6960625,test 941 | yifen_1_05.wav,5.0240625,train 942 | yifen_1_06.wav,5.5680625,train 943 | yifen_1_07.wav,7.9680625,train 944 | yifen_1_08.wav,5.5360625,test 945 | yifen_1_09.wav,5.7280625,train 946 | yifen_1_10.wav,5.4080625,train 947 | yifen_1_11.wav,8.0960625,train 948 | yifen_1_12.wav,7.3600625,train 949 | yifen_1_13.wav,7.9040625,test 950 | yifen_1_14.wav,5.4720625,train 951 | yifen_1_15.wav,5.1840625,train 952 | yifen_1_16.wav,5.0666875,train 953 | yifen_2_01.wav,7.712,train 954 | yifen_2_02.wav,5.7280625,train 955 | yifen_2_03.wav,6.0800625,train 956 | yifen_2_04.wav,5.6320625,train 957 | yifen_2_05.wav,6.1120625,train 958 | yifen_2_06.wav,6.0800625,train 959 | yifen_2_07.wav,5.9520625,test 960 | yifen_2_08.wav,8.1280625,test 961 | yifen_2_09.wav,6.7520625,test 962 | yifen_2_10.wav,5.9200625,train 963 | yifen_2_11.wav,6.2720625,train 964 | yifen_2_12.wav,7.7760625,train 965 | yifen_2_13.wav,9.9200625,train 966 | yifen_2_14.wav,6.2080625,test 967 | yifen_2_15.wav,7.2659375,test 968 | yifen_3_01.wav,5.6,train 969 | yifen_3_02.wav,5.7920625,test 970 | yifen_3_03.wav,5.0880625,train 971 | yifen_3_04.wav,5.0560625,train 972 | yifen_3_05.wav,6.4320625,test 973 | yifen_3_06.wav,8.4480625,train 974 | yifen_3_07.wav,7.0400625,train 975 | 
yifen_3_08.wav,9.5360625,train 976 | yifen_3_09.wav,5.0240625,test 977 | yifen_3_10.wav,8.2560625,train 978 | yifen_3_11.wav,4.9920625,train 979 | yifen_3_12.wav,9.6078125,train 980 | yifen_4_01.wav,7.744,train 981 | yifen_4_02.wav,9.3760625,train 982 | yifen_4_03.wav,9.0560625,train 983 | yifen_4_04.wav,7.0720625,test 984 | yifen_4_05.wav,6.3680625,train 985 | yifen_4_06.wav,6.5920625,test 986 | yifen_4_07.wav,5.5680625,test 987 | yifen_4_08.wav,6.5920625,train 988 | yifen_4_09.wav,9.9520625,train 989 | yifen_4_10.wav,6.5920625,train 990 | yifen_4_11.wav,6.890875,train 991 | yifen_5_01.wav,5.408,train 992 | yifen_5_02.wav,5.0240625,train 993 | yifen_5_03.wav,6.8160625,train 994 | yifen_5_04.wav,5.0880625,train 995 | yifen_5_05.wav,7.4560625,train 996 | yifen_5_06.wav,5.5360625,train 997 | yifen_5_07.wav,8.7040625,train 998 | yifen_5_08.wav,9.7920625,train 999 | yifen_5_09.wav,7.9680625,train 1000 | yifen_5_10.wav,4.4650625,train 1001 | yifen_5_11.wav,4.4970625,train 1002 | -------------------------------------------------------------------------------- /dataset/pre_MIR1K.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pandas as pd 4 | from tqdm import tqdm 5 | import librosa 6 | import soundfile as sf 7 | import shutil 8 | 9 | df_info = pd.read_csv(r'D:\ICASSP_2024\SVSDT\dataset\MIR1K\info.csv') 10 | path_in = r'D:\联合模型\Data\MIR-1K\Wavfile' 11 | path_label_in = r'D:\联合模型\Data\MIR-1K\PitchLabel' 12 | path_out = r'D:\ICASSP_2024\SVSDT\dataset\MIR1K' 13 | 14 | for _, row in tqdm(df_info.iterrows()): 15 | filename, _, split = row[0], row[1], row[2] 16 | audio_m, _ = librosa.load(os.path.join(path_in, filename), sr=16000, mono=True) 17 | audio_t, _ = librosa.load(os.path.join(path_in, filename), sr=16000, mono=False) 18 | audio_v = audio_t[1] 19 | sf.write(os.path.join(path_out, split, filename.replace('.wav', '_m.wav')), audio_m.T, 16000, 'PCM_24') 20 | sf.write(os.path.join(path_out, split, 
filename.replace('.wav', '_v.wav')), audio_v.T, 16000, 'PCM_24') 21 | shutil.copy(os.path.join(path_label_in, filename.replace('.wav', '.pv')), 22 | os.path.join(path_out, split, filename.replace('.wav', '.pv'))) 23 | 24 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | from tqdm import tqdm 7 | 8 | from collections import defaultdict 9 | import soundfile as sf 10 | import torch.nn.functional as F 11 | from src import to_local_average_cents, Inference 12 | from mir_eval.melody import raw_pitch_accuracy, to_cent_voicing, overall_accuracy, raw_chroma_accuracy 13 | from mir_eval.melody import voicing_false_alarm, voicing_recall 14 | from src import SAMPLE_RATE 15 | 16 | 17 | def calculate_sdr(ref, est): 18 | s_true = ref 19 | s_artif = est - ref 20 | sdr = 10.0 * ( 21 | torch.log10(torch.clip(torch.mean(s_true ** 2), 1e-8)) 22 | - torch.log10(torch.clip(torch.mean(s_artif ** 2), 1e-8)) 23 | ) 24 | return sdr 25 | 26 | 27 | def evaluate(dataset, model, batch_size, hop_length, seq_l, device, path=None, pitch_th=0.5): 28 | metrics = defaultdict(list) 29 | seq_l = int(seq_l * SAMPLE_RATE) 30 | hop_length = int(hop_length / 1000 * SAMPLE_RATE) 31 | seg_frames = seq_l // hop_length 32 | infer = Inference(model, seq_l, seg_frames, hop_length, batch_size, device) 33 | 34 | for data in tqdm(dataset): 35 | audio_m = data['audio_m'].to(device) 36 | audio_v = data['audio_v'].to(device) 37 | pitch_label = data['pitch'].to(device) 38 | 39 | audio_v_pred, pitch_pred = infer.inference(audio_m) 40 | loss_svs = F.l1_loss(audio_v_pred, audio_v) 41 | loss_pitch = F.binary_cross_entropy(pitch_pred, pitch_label) 42 | loss = loss_svs + loss_pitch 43 | metrics['loss_svs'].append(loss_svs.item()) 44 | metrics['loss_pe'].append(loss_pitch.item()) 45 | 
metrics['loss_total'].append(loss.item()) 46 | 47 | cents = to_local_average_cents(pitch_label.detach().cpu().numpy(), None, pitch_th) 48 | cents_pred = to_local_average_cents(pitch_pred.detach().cpu().numpy(), None, pitch_th) 49 | freqs = np.array([10 * (2 ** (cent / 1200)) if cent else 0 for cent in cents]) 50 | freqs_pred = np.array([10 * (2 ** (cent / 1200)) if cent else 0 for cent in cents_pred]) 51 | 52 | time_slice = np.array([i * hop_length / SAMPLE_RATE for i in range(len(freqs))]) 53 | ref_v, ref_c, est_v, est_c = to_cent_voicing(time_slice, freqs, time_slice, freqs_pred) 54 | rpa = raw_pitch_accuracy(ref_v, ref_c, est_v, est_c) 55 | rca = raw_chroma_accuracy(ref_v, ref_c, est_v, est_c) 56 | oa = overall_accuracy(ref_v, ref_c, est_v, est_c) 57 | vfa = voicing_false_alarm(ref_v, est_v) 58 | vr = voicing_recall(ref_v, est_v) 59 | 60 | metrics['RPA'].append(rpa) 61 | metrics['RCA'].append(rca) 62 | metrics['OA'].append(oa) 63 | metrics['VFA'].append(vfa) 64 | metrics['VR'].append(vr) 65 | 66 | if path is not None: 67 | sf.write(os.path.join(path, data['file'].replace('_v.wav', '.wav')), audio_v_pred.cpu().numpy(), 68 | samplerate=16000) 69 | df_pitch = pd.DataFrame(columns=['times', 'freqs', 'confi']) 70 | df_pitch['times'] = time_slice 71 | df_pitch['freqs'] = freqs_pred 72 | df_pitch['confi'] = torch.max(pitch_pred, dim=-1).values.numpy() 73 | df_pitch.to_csv(os.path.join(path, data['file'].replace('_v.wav', '.csv')), index=False) 74 | sdr = calculate_sdr(audio_v, audio_v_pred).item() 75 | sdr1 = calculate_sdr(audio_v, audio_m).item() 76 | metrics['SDR'].append(sdr) 77 | metrics['NSDR'].append(sdr - sdr1) 78 | metrics['NSDR_W'].append(len(audio_v) * (sdr - sdr1)) 79 | metrics['LENGTH'].append(len(audio_v)) 80 | print(sdr, '\t', rpa, '\t', rca) 81 | 82 | return metrics 83 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | 
from .dataset import MIR1K 2 | from .model import JM_Base, JM_MMOE, DJCM 3 | from .utils import summary, cycle, to_local_average_cents 4 | from .inference import Inference 5 | from .constants import SAMPLE_RATE 6 | from .loss import bce, FL, mse, mae, dynamic_weight_average 7 | from .constants import * 8 | -------------------------------------------------------------------------------- /src/constants.py: -------------------------------------------------------------------------------- 1 | SAMPLE_RATE = 16000 2 | 3 | N_CLASS = 360 4 | 5 | N_MELS = 256 6 | MEL_FMIN = 30 7 | MEL_FMAX = SAMPLE_RATE // 2 8 | WINDOW_LENGTH = 2048 9 | CONST = 1997.3794084376191 10 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import librosa 3 | import numpy as np 4 | import torch 5 | 6 | from torch.utils.data import Dataset 7 | from tqdm import tqdm 8 | from glob import glob 9 | from .constants import * 10 | 11 | 12 | class MIR1K(Dataset): 13 | def __init__(self, path, hop_length, sequence_length=None, groups=None): 14 | self.path = path 15 | self.HOP_LENGTH = int(hop_length / 1000 * SAMPLE_RATE) 16 | self.seq_len = None if not sequence_length else int(sequence_length * SAMPLE_RATE) 17 | self.num_class = N_CLASS 18 | self.data = [] 19 | 20 | print(f"Loading {len(groups)} group{'s' if len(groups) > 1 else ''} " 21 | f"of {self.__class__.__name__} at {path}") 22 | for group in groups: 23 | for input_files in tqdm(self.files(group), desc='Loading group %s' % group): 24 | self.data.extend(self.load(*input_files)) 25 | 26 | def __getitem__(self, index): 27 | return self.data[index] 28 | 29 | def __len__(self): 30 | return len(self.data) 31 | 32 | @staticmethod 33 | def availabe_groups(): 34 | return ['test'] 35 | 36 | def files(self, group): 37 | audio_m_files = glob(os.path.join(self.path, group, '*_m.wav')) 38 | audio_v_files = 
[f.replace('_m.wav', '_v.wav') for f in audio_m_files] 39 | label_files = [f.replace('_m.wav', '.pv') for f in audio_m_files] 40 | 41 | assert (all(os.path.isfile(audio_v_file) for audio_v_file in audio_v_files)) 42 | assert (all(os.path.isfile(label_file) for label_file in label_files)) 43 | 44 | return sorted(zip(audio_m_files, audio_v_files, label_files)) 45 | 46 | def load(self, audio_m_path, audio_v_path, label_path): 47 | data = [] 48 | audio_m, _ = librosa.load(audio_m_path, sr=SAMPLE_RATE) 49 | if audio_m.ndim == 1: 50 | audio_m = np.array([audio_m]) 51 | audio_m = torch.from_numpy(audio_m) 52 | 53 | audio_v, _ = librosa.load(audio_v_path, sr=SAMPLE_RATE) 54 | if audio_v.ndim == 1: 55 | audio_v = np.array([audio_v]) 56 | audio_v = torch.from_numpy(audio_v) 57 | 58 | audio_l = audio_m.shape[-1] 59 | audio_steps = audio_l // self.HOP_LENGTH + 1 60 | 61 | pitch_label = torch.zeros(audio_steps, self.num_class, dtype=torch.float) 62 | voice_label = torch.zeros(audio_steps, dtype=torch.float) 63 | with open(label_path, 'r') as f: 64 | lines = f.readlines() 65 | i = 0 66 | for line in lines: 67 | i += 1 68 | if float(line) != 0: 69 | freq = 440 * (2.0 ** ((float(line) - 69.0) / 12.0)) 70 | cent = 1200 * np.log2(freq/10) 71 | index = int(round((cent-CONST)/20)) 72 | pitch_label[i][index] = 1 73 | voice_label[i] = 1 74 | 75 | if self.seq_len is not None: 76 | n_steps = self.seq_len // self.HOP_LENGTH 77 | for i in range(audio_l // self.seq_len): 78 | begin_t = i * self.seq_len 79 | end_t = begin_t + self.seq_len 80 | begin_step = begin_t // self.HOP_LENGTH 81 | end_step = begin_step + n_steps 82 | data.append(dict(audio_m=audio_m[:, begin_t:end_t], audio_v=audio_v[:, begin_t:end_t], 83 | pitch=pitch_label[begin_step:end_step], voice=voice_label[begin_step:end_step], 84 | file=os.path.basename(audio_m_path))) 85 | data.append(dict(audio_m=audio_m[:, -self.seq_len:], audio_v=audio_v[:, -self.seq_len:], 86 | pitch=pitch_label[-n_steps:], voice=voice_label[-n_steps:], 
87 | file=os.path.basename(audio_m_path))) 88 | else: 89 | data.append(dict(audio_m=audio_m, audio_v=audio_v, pitch=pitch_label, voice=voice_label, 90 | file=os.path.basename(audio_m_path))) 91 | return data 92 | -------------------------------------------------------------------------------- /src/inference.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class Inference: 6 | def __init__(self, model, seg_len, seg_frames, hop_length, batch_size, device): 7 | super(Inference, self).__init__() 8 | self.model = model.eval() 9 | self.seg_len = seg_len 10 | self.seg_frames = seg_frames 11 | self.batch_size = batch_size 12 | self.hop_length = hop_length 13 | self.device = device 14 | 15 | def inference(self, audio): 16 | with torch.no_grad(): 17 | padded_audio = self.pad_audio(audio) 18 | segments = self.en_frame(padded_audio) 19 | sep_segments, pitch_segments = self.forward_in_mini_batch(self.model, segments) 20 | out_audio = self.de_frame(sep_segments, type_seg='audio')[:, :audio.shape[-1]] 21 | pitch_pred = self.de_frame(pitch_segments, type_seg='pitch')[:(audio.shape[-1]//self.hop_length+1)] 22 | return out_audio, pitch_pred 23 | 24 | def pad_audio(self, audio): 25 | c, audio_len = audio.shape 26 | seg_nums = int(np.ceil(audio_len / self.seg_len)) + 1 27 | pad_len = seg_nums * self.seg_len - audio_len + self.seg_len // 2 28 | padded_audio = torch.cat([torch.zeros(c, self.seg_len // 4).to(self.device), audio, 29 | torch.zeros(c, pad_len - self.seg_len // 4).to(self.device)], dim=1) 30 | return padded_audio 31 | 32 | def en_frame(self, audio): 33 | c, audio_len = audio.shape 34 | assert audio_len % (self.seg_len // 2) == 0 35 | 36 | segments = [] 37 | start = 0 38 | while start + self.seg_len <= audio_len: 39 | segments.append(audio[:, start:start + self.seg_len]) 40 | start += self.seg_len // 2 41 | segments = torch.stack(segments, dim=0) 42 | return segments 43 | 44 | def 
forward_in_mini_batch(self, model, segments): 45 | out_segments = [] 46 | pitch_segments = [] 47 | segments_num = segments.shape[0] 48 | # print(segments_num, end='\t') 49 | batch_start = 0 50 | while True: 51 | # print('#', end='\t') 52 | if batch_start + self.batch_size >= segments_num: 53 | batch_tmp = segments[batch_start:].shape[0] 54 | segment_in = torch.cat([segments[batch_start:], 55 | torch.zeros_like(segments)[:self.batch_size-batch_tmp].to(self.device)], dim=0) 56 | # out_audio = model(segment_in) 57 | out_audio, pitch_pred = model(segment_in) 58 | out_segments.append(out_audio[:batch_tmp, :]) 59 | pitch_segments.append(pitch_pred[:batch_tmp, :]) 60 | break 61 | else: 62 | segment_in = segments[batch_start:batch_start+self.batch_size] 63 | out_audio, pitch_pred = model(segment_in) 64 | out_segments.append(out_audio) 65 | pitch_segments.append(pitch_pred) 66 | batch_start += self.batch_size 67 | out_segments = torch.cat(out_segments, dim=0) 68 | pitch_segments = torch.cat(pitch_segments, dim=0) 69 | 70 | return out_segments, pitch_segments 71 | 72 | def de_frame(self, segments, type_seg='audio'): 73 | output = [] 74 | if type_seg == 'audio': 75 | for segment in segments: 76 | output.append(segment[:, self.seg_len // 4: int(self.seg_len * 0.75)]) 77 | output = torch.cat(output, dim=1) 78 | else: 79 | for segment in segments: 80 | output.append(segment[self.seg_frames // 4: int(self.seg_frames * 0.75)]) 81 | output = torch.cat(output, dim=0) 82 | return output 83 | -------------------------------------------------------------------------------- /src/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | 7 | def mae(input, target, weight=None): 8 | l1_loss = nn.L1Loss(reduce=False) 9 | loss = l1_loss(input, target) 10 | if weight is not None: 11 | loss = weight * loss 12 | return torch.mean(loss) 13 | 14 | 15 | def 
mse(input, target, weight=None): 16 | l2_loss = nn.MSELoss(reduce=False) 17 | loss = l2_loss(input, target) 18 | if weight is not None: 19 | loss = weight * loss 20 | return torch.mean(loss) 21 | 22 | 23 | def ce(input, target, weight=None): 24 | ce = nn.CrossEntropyLoss(reduce=False) 25 | loss = ce(input, target) 26 | if weight is not None: 27 | loss = loss * weight 28 | return torch.mean(loss) 29 | 30 | 31 | def bce(input, target, weight=None): 32 | bce = nn.BCELoss(reduce=False) 33 | loss = bce(input, target) 34 | if weight is not None: 35 | loss = loss * weight 36 | return torch.mean(loss) 37 | 38 | 39 | def FL(inputs, targets, alpha, gamma, weight_t=None): 40 | loss = F.binary_cross_entropy(inputs, targets, reduce=False) 41 | weight = torch.ones(inputs.shape, dtype=torch.float).to(inputs.device) 42 | weight[targets == 1] = float(alpha) 43 | loss_w = F.binary_cross_entropy(inputs, targets, weight=weight, reduce=False) 44 | pt = torch.exp(-loss) 45 | weight_gamma = (1 - pt) ** gamma 46 | if weight_t is not None: 47 | weight_gamma = weight_gamma * weight_t 48 | F_loss = torch.mean(weight_gamma * loss_w) 49 | return F_loss 50 | 51 | 52 | def dynamic_weight_average(loss_t_1, loss_t_2, T=2): 53 | """ 54 | 55 | :param loss_t_1: 每个task上一轮的loss列表,并且为标量 56 | :param loss_t_2: 57 | :return: 58 | """ 59 | # 第1和2轮,w初设化为1,lambda也对应为1 60 | if not loss_t_1 or not loss_t_2: 61 | return [1, 1] 62 | 63 | assert len(loss_t_1) == len(loss_t_2) 64 | task_n = len(loss_t_1) 65 | 66 | w = [l_1 / l_2 for l_1, l_2 in zip(loss_t_1, loss_t_2)] 67 | 68 | lamb = [math.exp(v / T) for v in w] 69 | 70 | lamb_sum = sum(lamb) 71 | 72 | return [task_n * l / lamb_sum for l in lamb] 73 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from .modules import Encoder, LatentBlocks, SVS_Decoder, PE_Decoder 4 | 
class JM_Base(nn.Module):
    """Joint SVS + pitch-estimation baseline built on a shared encoder."""

    def __init__(self, in_channels, n_blocks, hop_length, latent_layers, seq_frames, seq='gru', seq_layers=1):
        super().__init__()
        # hop_length is given in milliseconds; convert to samples once.
        hop_samples = int(hop_length / 1000 * SAMPLE_RATE)
        self.to_spec = Wav2Spec(hop_samples, WINDOW_LENGTH)
        self.to_wav = Spec2Wav(hop_samples, WINDOW_LENGTH)
        # Normalise across the 2048//2 + 1 = 1025 frequency bins.
        self.bn = nn.BatchNorm2d(2048 // 2 + 1, momentum=0.01)
        init_bn(self.bn)
        self.model = SVS_PE_Base(in_channels, n_blocks, latent_layers, seq_frames, seq, seq_layers)

    def forward(self, audio_m, audio_v=None):
        """Separate vocals and estimate pitch from a mixture waveform.

        Args:
            audio_m: mixture waveform batch.
            audio_v: optional clean-vocal waveform; when provided, an L1
                spectrogram loss against the separated output is also returned.

        Returns:
            (out_audio, pe_out) or (out_audio, pe_out, loss_spec).
        """
        spec_m, cos_m, sin_m = self.to_spec(audio_m)
        # Batch-norm over frequency, then drop the Nyquist bin.
        normed = self.bn(spec_m.transpose(1, 3)).transpose(1, 3)[..., :-1]
        pe_out, svs_out = self.model(normed)
        out_audio, out_spec = self.to_wav(svs_out, spec_m, cos_m, sin_m, audio_m.shape[-1])
        if audio_v is None:
            return out_audio, pe_out
        spec_v, _, _ = self.to_spec(audio_v)
        loss_spec = F.l1_loss(out_spec[..., :-1], spec_v[..., :-1])
        return out_audio, pe_out, loss_spec
class DJCM(nn.Module):
    """Cascaded joint model: SVS (separation) first, then PE (pitch) on the
    separated spectrogram — unlike JM_Base/JM_MMOE, the two tasks have fully
    separate encoder/latent/decoder stacks and PE consumes SVS's output."""

    def __init__(self, in_channels, n_blocks, hop_length, latent_layers, seq_frames, gate=False, seq='gru', seq_layers=1):
        super(DJCM, self).__init__()
        # hop_length is in milliseconds; converted to samples for the STFT.
        self.to_spec = Wav2Spec(int(hop_length / 1000 * SAMPLE_RATE), WINDOW_LENGTH)
        self.to_wav = Spec2Wav(int(hop_length / 1000 * SAMPLE_RATE), WINDOW_LENGTH)
        # BatchNorm over the 2048//2 + 1 = 1025 frequency bins.
        self.bn = nn.BatchNorm2d(2048 // 2 + 1, momentum=0.01)
        init_bn(self.bn)
        # Separation branch (U-Net style encoder / latent / decoder).
        self.svs_encoder = Encoder(in_channels, n_blocks)
        self.svs_latent = LatentBlocks(n_blocks, latent_layers)
        self.svs_decoder = SVS_Decoder(in_channels, n_blocks, gate)

        # Pitch branch; runs on the *separated* spectrogram, not the mixture.
        self.pe_encoder = Encoder(in_channels, n_blocks)
        self.pe_latent = LatentBlocks(n_blocks, latent_layers)
        self.pe_decoder = PE_Decoder(n_blocks, seq_frames, seq, seq_layers, gate)

    def forward(self, audio_m, audio_v=None):
        """Separate vocals from ``audio_m``, then estimate pitch from the
        separated spectrogram.

        Args:
            audio_m: mixture waveform batch.
            audio_v: optional clean-vocal waveform; when given, also returns
                an L1 loss between separated and clean spectrograms.

        Returns:
            (out_audio, pe_out) or (out_audio, pe_out, loss_spec).
        """
        spec_m, cos_m, sin_m = self.to_spec(audio_m)
        # Frequency-wise batch-norm, then drop the Nyquist bin for the U-Net.
        x = self.bn(spec_m.transpose(1, 3)).transpose(1, 3)[..., :-1]
        x, concat_tensors = self.svs_encoder(x)
        x = self.svs_latent(x)
        x = self.svs_decoder(x, concat_tensors)
        # Re-append the dropped Nyquist bin (zero-pad) before inverting.
        svs_out = F.pad(x, pad=(0, 1))
        out_audio, out_spec = self.to_wav(svs_out, spec_m, cos_m, sin_m, audio_m.shape[-1])
        # Pitch branch consumes the separated magnitude spectrogram
        # (gradients flow through out_spec back into the SVS branch).
        x, concat_tensors = self.pe_encoder(out_spec[..., :-1])
        x = self.pe_latent(x)
        pe_out = self.pe_decoder(x, concat_tensors)
        if audio_v is None:
            return out_audio, pe_out
        else:
            spec_v, _, _ = self.to_spec(audio_v)
            loss_spec = F.l1_loss(out_spec[..., :-1], spec_v[..., :-1])
            return out_audio, pe_out, loss_spec
class ResConvBlock(nn.Module):
    """Pre-activation residual block: two BN -> PReLU -> Conv3x3 stages plus
    a skip connection. A 1x1 projection is used on the skip path only when
    the channel count changes, keeping the addition shape-compatible.
    """

    def __init__(self, in_planes, planes, bias=False):
        super(ResConvBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.01)
        self.bn2 = nn.BatchNorm2d(planes, momentum=0.01)
        self.act1 = nn.PReLU()
        self.act2 = nn.PReLU()
        self.conv1 = nn.Conv2d(in_planes, planes, (3, 3), padding=(1, 1), bias=bias)
        self.conv2 = nn.Conv2d(planes, planes, (3, 3), padding=(1, 1), bias=bias)
        # Project the identity path only when channel counts differ.
        self.is_shortcut = in_planes != planes
        if self.is_shortcut:
            self.shortcut = nn.Conv2d(in_planes, planes, (1, 1))
        self.init_weights()

    def init_weights(self):
        """Initialise batch-norms to identity and convs with Xavier."""
        init_bn(self.bn1)
        init_bn(self.bn2)
        init_layer(self.conv1)
        init_layer(self.conv2)
        if self.is_shortcut:
            init_layer(self.shortcut)

    def forward(self, x):
        residual = self.conv1(self.act1(self.bn1(x)))
        residual = self.conv2(self.act2(self.bn2(residual)))
        skip = self.shortcut(x) if self.is_shortcut else x
        return skip + residual
class DecoderBlock(nn.Module):
    """Upsampling decoder block: transposed conv, optional attention gate on
    the skip connection, then a stack of residual conv blocks.

    When ``gate`` is True the skip tensor is re-weighted by an additive
    attention gate (psi(relu(W_g(x) + W_x(skip))) — the Attention-U-Net
    pattern) before concatenation.
    """

    def __init__(self, in_channels, out_channels, n_blocks, stride, bias, gate=False):
        super(DecoderBlock, self).__init__()
        self.gate = gate
        if self.gate:
            # Attention-gate sub-networks: W_g transforms the upsampled
            # features, W_x the skip features, psi produces a 1-channel mask.
            self.W_g = nn.Sequential(
                nn.Conv2d(out_channels, out_channels // 2, (1, 1)),
                nn.BatchNorm2d(out_channels // 2)
            )
            self.W_x = nn.Sequential(
                nn.Conv2d(out_channels, out_channels // 2, (1, 1)),
                nn.BatchNorm2d(out_channels // 2)
            )
            self.psi = nn.Sequential(
                nn.Conv2d(out_channels // 2, 1, (1, 1)),
                nn.BatchNorm2d(1),
                nn.Sigmoid()
            )
        # Transposed conv upsamples by `stride` (kernel size == stride, no
        # padding), so output size doubles along the strided axis.
        self.conv1 = nn.ConvTranspose2d(in_channels, out_channels, stride, stride, (0, 0), bias=bias)
        self.bn1 = nn.BatchNorm2d(in_channels, momentum=0.01)
        # First residual block takes 2*out_channels: upsampled + skip concat.
        self.conv = nn.ModuleList([
            ResConvBlock(out_channels * 2, out_channels, bias)
        ])
        for i in range(n_blocks - 1):
            self.conv.append(ResConvBlock(out_channels, out_channels, bias))
        self.init_weights()

    def init_weights(self):
        # Identity-init BN, Xavier-init the transposed conv.
        init_bn(self.bn1)
        init_layer(self.conv1)

    def forward(self, x, concat):
        """Upsample ``x`` and fuse it with the encoder skip tensor ``concat``."""
        x = self.conv1(F.relu_(self.bn1(x)))
        if self.gate:
            # Attention mask in [0, 1] re-weights the skip features.
            concat = x * self.psi(F.relu_(self.W_g(x) + self.W_x(concat)))
        x = torch.cat((x, concat), dim=1)
        for each_layer in self.conv:
            x = each_layer(x)
        return x
class Decoder(nn.Module):
    """Six-stage upsampling decoder mirroring ``Encoder``'s channel schedule."""

    def __init__(self, n_blocks, gate=False):
        super(Decoder, self).__init__()
        # (in_channels, out_channels) per stage, deepest first — the mirror
        # image of Encoder's 32 -> 64 -> 128 -> 256 -> 384 -> 384 schedule.
        channel_pairs = [(384, 384), (384, 384), (384, 256),
                         (256, 128), (128, 64), (64, 32)]
        self.de_blocks = nn.ModuleList([
            DecoderBlock(c_in, c_out, n_blocks, (1, 2), False, gate)
            for c_in, c_out in channel_pairs
        ])

    def forward(self, x, concat_tensors):
        """Upsample ``x``, fusing encoder skip features deepest-first."""
        for layer, skip in zip(self.de_blocks, reversed(concat_tensors)):
            x = layer(x, skip)
        return x
class PE_Decoder(nn.Module):
    """Pitch-estimation head: decode latent features to a single-channel map,
    optionally model time with a bidirectional RNN, then classify each frame
    into ``N_CLASS`` pitch bins with sigmoid activations."""

    def __init__(self, n_blocks, seq_frames, seq='gru', seq_layers=1, gate=False):
        super(PE_Decoder, self).__init__()
        self.de_blocks = Decoder(n_blocks, gate)
        self.after_conv1 = EncoderBlock(32, 32, n_blocks, None, False)
        self.after_conv2 = nn.Conv2d(32, 1, (1, 1))
        init_layer(self.after_conv2)
        # Optional temporal model ahead of the per-frame classifier.
        seq_kind = seq.lower()
        if seq_kind == 'gru':
            temporal = BiGRU((seq_frames, 1024), (1, 1024), 1, seq_layers)
        elif seq_kind == 'lstm':
            temporal = BiLSTM((seq_frames, 1024), (1, 1024), 1, seq_layers)
        else:
            temporal = None
        classifier = [nn.Linear(1024, N_CLASS), nn.Sigmoid()]
        if temporal is not None:
            self.fc = nn.Sequential(temporal, *classifier)
        else:
            self.fc = nn.Sequential(*classifier)

    def forward(self, x, concat_tensors):
        """Return frame-wise pitch-bin probabilities (channel dim squeezed)."""
        x = self.de_blocks(x, concat_tensors)
        x = self.after_conv2(self.after_conv1(x))
        return self.fc(x).squeeze(1)
concat_tensors) 278 | svs_x = self.svs_latent(x) 279 | svs_out = F.pad(self.svs_decoder(svs_x, concat_tensors), pad=(0, 1)) 280 | return pe_out, svs_out 281 | 282 | 283 | class SVS_PE_MMOE(nn.Module): 284 | def __init__(self, in_channels, n_blocks, latent_layers, seq_frames, expert_num=2, seq='gru', seq_layers=1): 285 | super(SVS_PE_MMOE, self).__init__() 286 | self.expert_num = expert_num 287 | self.encoder_expert = nn.ModuleList([ 288 | Encoder(in_channels, n_blocks) for _ in range(expert_num) 289 | ]) 290 | self.svs_gate = nn.Sequential( 291 | nn.Linear(1024, 512), 292 | nn.PReLU(), 293 | nn.Linear(512, expert_num), 294 | nn.Softmax(dim=-1) 295 | ) 296 | self.pe_gate = nn.Sequential( 297 | nn.Linear(1024, 512), 298 | nn.PReLU(), 299 | nn.Linear(512, expert_num), 300 | nn.Softmax(dim=-1) 301 | ) 302 | self.svs_latent = LatentBlocks(n_blocks, latent_layers) 303 | self.pe_latent = LatentBlocks(n_blocks, latent_layers) 304 | self.svs_decoder = SVS_Decoder(in_channels, n_blocks) 305 | self.pe_decoder = PE_Decoder(n_blocks, seq_frames, seq, seq_layers) 306 | 307 | def forward(self, spec_m): 308 | x, concat_tensors = [], [] 309 | for layer in self.encoder_expert: 310 | x_tmp, concat_tensors_tmp = layer(spec_m) 311 | x.append(x_tmp) 312 | concat_tensors.append(concat_tensors_tmp) 313 | x = torch.stack(x, dim=-1) 314 | 315 | svs_gate, pe_gate = self.svs_gate(spec_m).unsqueeze(-1), self.pe_gate(spec_m).unsqueeze(-1) 316 | svs_concat, pe_concat = [], [] 317 | for i in range(len(concat_tensors[0])): 318 | tmp = torch.stack([concat_tensors[j][i] for j in range(self.expert_num)], dim=-1) 319 | svs_concat.append(torch.matmul(tmp, svs_gate).squeeze(-1)) 320 | pe_concat.append(torch.matmul(tmp, pe_gate).squeeze(-1)) 321 | svs_x = torch.matmul(x, svs_gate).squeeze(-1) 322 | svs_x = self.svs_latent(svs_x) 323 | pe_x = torch.matmul(x, pe_gate).squeeze(-1) 324 | pe_x = self.pe_latent(pe_x) 325 | svs_out = F.pad(self.svs_decoder(svs_x, svs_concat), pad=(0, 1)) 326 | pe_out = 
class BiLSTM(nn.Module):
    """Bidirectional LSTM over patch embeddings of a (b, c, w, h) feature map.

    Mirrors ``BiGRU``: the input map is cut into patches, flattened to a
    (batch, seq, patch_dim) sequence, and run through a bidirectional LSTM
    whose output keeps dimension patch_dim (patch_dim // 2 per direction).

    Args:
        image_size: (width, height) of the incoming feature map (or one int).
        patch_size: (width, height) of each patch (or one int).
        channels: number of input channels.
        depth: number of stacked LSTM layers.
    """

    def __init__(self, image_size, patch_size, channels, depth):
        super(BiLSTM, self).__init__()
        image_width, image_height = pair(image_size)
        patch_width, patch_height = pair(patch_size)

        # The map must tile exactly into patches.
        assert image_height % patch_height == 0 and image_width % patch_width == 0
        patch_dim = channels * patch_height * patch_width
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (w p1) (h p2) -> b (w h) (p1 p2 c)', p1=patch_width, p2=patch_height),
        )
        self.lstm = nn.LSTM(patch_dim, patch_dim // 2, num_layers=depth, batch_first=True, bidirectional=True)

    def forward(self, x):
        # BUG FIX: unlike BiGRU, the patch embedding was skipped here, so the
        # 4-D (b, c, w, h) input went straight into nn.LSTM (which only
        # accepts 2-D/3-D input) and crashed PE_Decoder's seq='lstm' path.
        x = self.to_patch_embedding(x)
        return self.lstm(x)[0]
def to_local_average_cents(salience, center=None, thred=0.0):
    """Salience-weighted average of cent values in a 9-bin window at the peak.

    For a 1-D salience vector the window spans bins [center-4, center+4]
    (clipped to the array); returns 0 when the window's maximum does not
    exceed ``thred``. A 2-D input is treated as one salience vector per row.
    """
    # Lazily build and cache the 360-bin -> cents lookup table on the
    # function object itself; it is identical for every call.
    if not hasattr(to_local_average_cents, 'cents_mapping'):
        to_local_average_cents.cents_mapping = (
            np.linspace(0, 7180, 360) + 1997.3794084376191)
    mapping = to_local_average_cents.cents_mapping

    if salience.ndim == 2:
        # Row-wise recursion: one pitch estimate per frame.
        return np.array([to_local_average_cents(row, None, thred)
                         for row in salience])
    if salience.ndim != 1:
        raise Exception("label should be either 1d or 2d ndarray")

    if center is None:
        center = int(np.argmax(salience))
    lo = max(0, center - 4)
    hi = min(len(salience), center + 5)
    window = salience[lo:hi]
    if np.max(window) <= thred:
        return 0
    return np.sum(window * mapping[lo:hi]) / np.sum(window)
= 250 36 | early_stop_epochs = 10 37 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 38 | # path, hop_length, sequence_length = None, groups = None 39 | train_dataset = MIR1K(path='./dataset/MIR1K', hop_length=hop_length, groups=['train'], sequence_length=seq_l) 40 | print('train nums:', len(train_dataset)) 41 | valid_dataset = MIR1K(path='./dataset/MIR1K', hop_length=hop_length, groups=['test'], sequence_length=None) 42 | print('valid nums:', len(valid_dataset)) 43 | data_loader = DataLoader(train_dataset, batch_size, shuffle=True) 44 | epoch_nums = len(data_loader) 45 | print('epoch_nums:', epoch_nums) 46 | learning_rate_decay_steps = len(data_loader) * learning_rate_decay_epochs 47 | iterations = epoch_nums * train_epochs 48 | 49 | resume_iteration = None 50 | os.makedirs(logdir, exist_ok=True) 51 | writer = SummaryWriter(logdir) 52 | 53 | if resume_iteration is None: 54 | # in_channels, n_blocks, hop_length, latent_layers, seq_frames, seq='gru', seq_layers=1 55 | model = JM_Base(in_channels, n_blocks, hop_length, latent_layers, seq_frames) 56 | model = nn.DataParallel(model).to(device) 57 | optimizer = torch.optim.Adam(model.parameters(), learning_rate) 58 | resume_iteration = 0 59 | else: 60 | model_path = os.path.join(logdir, f'model-{resume_iteration}.pt') 61 | model = torch.load(model_path) 62 | optimizer = torch.optim.Adam(model.parameters(), learning_rate) 63 | optimizer.load_state_dict(torch.load(os.path.join(logdir, 'last-optimizer-state.pt'))) 64 | 65 | scheduler = StepLR(optimizer, step_size=learning_rate_decay_steps, gamma=learning_rate_decay_rate) 66 | summary(model) 67 | SDR, RPA, GNSDR, RCA, it = 0, 0, 0, 0, 0 68 | loop = tqdm(range(resume_iteration + 1, iterations + 1)) 69 | 70 | for i, data in zip(loop, cycle(data_loader)): 71 | audio_m = data['audio_m'].to(device) 72 | audio_v = data['audio_v'].to(device) 73 | pitch_label = data['pitch'].to(device) 74 | out_audio, out_pitch, loss_spec = model(audio_m, audio_v) 75 | 76 | 
loss_svs = mae(out_audio, audio_v) 77 | loss_pitch = FL(out_pitch, pitch_label, alpha, gamma) 78 | # weight_pe = loss_svs.item() / loss_pitch.item() 79 | loss_total = weight_svs * loss_svs + weight_pe * loss_pitch 80 | 81 | optimizer.zero_grad() 82 | loss_total.backward() 83 | if clip_grad_norm: 84 | clip_grad_norm_(model.parameters(), clip_grad_norm) 85 | optimizer.step() 86 | scheduler.step() 87 | 88 | print(i, end='\t') 89 | print('loss_total:{:.6f}'.format(loss_total.item()), end='\t') 90 | print('loss_svs:{:.6f}'.format(loss_svs.item()), end='\t') 91 | print('loss_pe:{:.6f}'.format(loss_pitch.item())) 92 | 93 | writer.add_scalar('loss/loss_total', loss_total.item(), global_step=i) 94 | writer.add_scalar('loss/loss_svs', loss_svs.item(), global_step=i) 95 | writer.add_scalar('loss/loss_pe', loss_pitch.item(), global_step=i) 96 | 97 | if i % epoch_nums == 0: 98 | print('*' * 50) 99 | print(i, '\t', epoch_nums) 100 | model.eval() 101 | with torch.no_grad(): 102 | metrics = evaluate(valid_dataset, model, batch_size, hop_length, seq_l, device, None, pitch_th) 103 | for key, value in metrics.items(): 104 | writer.add_scalar('validation/' + key, np.mean(value), global_step=i) 105 | gnsdr = np.round((np.sum(metrics["NSDR_W"]) / np.sum(metrics["LENGTH"])), 2) 106 | writer.add_scalar('validation/GNSDR', gnsdr, global_step=i) 107 | sdr = np.round(np.mean(metrics['SDR']), 2) 108 | rpa = np.round(np.mean(metrics['RPA']) * 100, 2) 109 | rca = np.round(np.mean(metrics['RCA']) * 100, 2) 110 | oa = np.round(np.mean(metrics['OA']) * 100, 2) 111 | if sdr + rpa >= RPA + SDR: 112 | SDR, GNSDR, RPA, RCA, it = sdr, gnsdr, rpa, rca, i 113 | with open(os.path.join(logdir, 'result.txt'), 'a') as f: 114 | f.write(str(i) + '\t') 115 | f.write(str(RPA) + '±' + str(np.round(np.std(metrics['RPA']) * 100, 2)) + '\t') 116 | f.write(str(RCA) + '±' + str(np.round(np.std(metrics['RCA']) * 100, 2)) + '\t') 117 | f.write(str(oa) + '±' + str(np.round(np.std(metrics['OA']) * 100, 2)) + '\t') 118 | 
def train(weight_svs):
    """Train the cascaded DJCM model on MIR-1K for one SVS/PE loss weighting.

    ``weight_svs`` weights the separation (MAE) loss; the pitch (focal) loss
    gets ``2 - weight_svs`` so the two always sum to 2. Logs to TensorBoard,
    keeps the best checkpoint by SDR + RPA, and early-stops after
    ``early_stop_epochs`` epochs without improvement.
    """
    # Focal-loss hyper-parameters (alpha balances positives; gamma=0 means
    # plain alpha-weighted BCE).
    alpha = 10
    gamma = 0
    weight_pe = 2 - weight_svs
    in_channels = 1
    n_blocks = 1
    latent_layers = 1
    seq_l = 2.56           # segment length in seconds
    hop_length = 20        # milliseconds per frame
    seq_frames = int(seq_l * 1000 / hop_length)  # frames per segment (128)
    # Run directory encodes every hyper-parameter for easy comparison.
    logdir = 'runs/MIR1K_Cascade/' + 'nblocks' + str(n_blocks) + '_latent' + str(latent_layers) + '_frames' + str(seq_frames) \
             + '_alpha' + str(alpha) + '_gamma' + str(gamma) + '_svs' + str(weight_svs) + '_pe' + str(weight_pe) + \
             '_gateT'

    pitch_th = 0.5
    learning_rate = 5e-4
    batch_size = 16
    clip_grad_norm = 3
    learning_rate_decay_rate = 0.95
    learning_rate_decay_epochs = 5
    train_epochs = 250
    early_stop_epochs = 10
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # MIR1K(path, hop_length, sequence_length=None, groups=None)
    train_dataset = MIR1K(path='./dataset/MIR1K', hop_length=hop_length, groups=['train'], sequence_length=seq_l)
    print('train nums:', len(train_dataset))
    # sequence_length=None -> full-length clips for evaluation.
    valid_dataset = MIR1K(path='./dataset/MIR1K', hop_length=hop_length, groups=['test'], sequence_length=None)
    print('valid nums:', len(valid_dataset))
    data_loader = DataLoader(train_dataset, batch_size, shuffle=True)
    epoch_nums = len(data_loader)   # iterations per epoch
    print('epoch_nums:', epoch_nums)
    learning_rate_decay_steps = len(data_loader) * learning_rate_decay_epochs
    iterations = epoch_nums * train_epochs

    # Set to a saved iteration number to resume from logdir; None = fresh run.
    resume_iteration = None
    os.makedirs(logdir, exist_ok=True)
    writer = SummaryWriter(logdir)

    if resume_iteration is None:
        # DJCM(in_channels, n_blocks, hop_length, latent_layers, seq_frames, ...)
        model = DJCM(in_channels, n_blocks, hop_length, latent_layers, seq_frames)
        model = nn.DataParallel(model).to(device)
        optimizer = torch.optim.Adam(model.parameters(), learning_rate)
        resume_iteration = 0
    else:
        # NOTE(review): torch.load without map_location assumes the checkpoint
        # device matches the current one — confirm when resuming across machines.
        model_path = os.path.join(logdir, f'model-{resume_iteration}.pt')
        model = torch.load(model_path)
        optimizer = torch.optim.Adam(model.parameters(), learning_rate)
        optimizer.load_state_dict(torch.load(os.path.join(logdir, 'last-optimizer-state.pt')))

    # Step-wise (not epoch-wise) LR decay: x0.95 every 5 epochs' worth of steps.
    scheduler = StepLR(optimizer, step_size=learning_rate_decay_steps, gamma=learning_rate_decay_rate)
    summary(model)
    # Best-so-far metrics and the iteration (`it`) they were achieved at.
    SDR, RPA, GNSDR, RCA, it = 0, 0, 0, 0, 0
    loop = tqdm(range(resume_iteration + 1, iterations + 1))

    # `cycle` restarts the DataLoader each epoch, so shuffling is refreshed.
    for i, data in zip(loop, cycle(data_loader)):
        audio_m = data['audio_m'].to(device)
        audio_v = data['audio_v'].to(device)
        pitch_label = data['pitch'].to(device)
        out_audio, out_pitch, loss_spec = model(audio_m, audio_v)

        # NOTE(review): loss_spec is returned by the model but not included
        # in the objective — confirm this is intentional.
        loss_svs = mae(out_audio, audio_v)
        loss_pitch = FL(out_pitch, pitch_label, alpha, gamma)
        loss_total = weight_svs * loss_svs + weight_pe * loss_pitch

        optimizer.zero_grad()
        loss_total.backward()
        if clip_grad_norm:
            clip_grad_norm_(model.parameters(), clip_grad_norm)
        optimizer.step()
        scheduler.step()

        print(i, end='\t')
        print('loss_total:{:.6f}'.format(loss_total.item()), end='\t')
        print('loss_svs:{:.6f}'.format(loss_svs.item()), end='\t')
        print('loss_pe:{:.6f}'.format(loss_pitch.item()))

        writer.add_scalar('loss/loss_total', loss_total.item(), global_step=i)
        writer.add_scalar('loss/loss_svs', loss_svs.item(), global_step=i)
        writer.add_scalar('loss/loss_pe', loss_pitch.item(), global_step=i)

        # Validate once per epoch.
        if i % epoch_nums == 0:
            print('*' * 50)
            print(i, '\t', epoch_nums)
            model.eval()
            with torch.no_grad():
                metrics = evaluate(valid_dataset, model, batch_size, hop_length, seq_l, device, None, pitch_th)
                for key, value in metrics.items():
                    writer.add_scalar('validation/' + key, np.mean(value), global_step=i)
                # Global NSDR: length-weighted average over all clips.
                gnsdr = np.round((np.sum(metrics["NSDR_W"]) / np.sum(metrics["LENGTH"])), 2)
                writer.add_scalar('validation/GNSDR', gnsdr, global_step=i)
                sdr = np.round(np.mean(metrics['SDR']), 2)
                rpa = np.round(np.mean(metrics['RPA']) * 100, 2)
                rca = np.round(np.mean(metrics['RCA']) * 100, 2)
                oa = np.round(np.mean(metrics['OA']) * 100, 2)
                # Model selection: combined separation + pitch criterion.
                if sdr + rpa >= SDR + RPA:
                    SDR, GNSDR, RPA, RCA, it = sdr, gnsdr, rpa, rca, i
                    with open(os.path.join(logdir, 'result.txt'), 'a') as f:
                        f.write(str(i) + '\t')
                        f.write(str(SDR) + '±' + str(np.round(np.std(metrics['SDR']), 2)) + '\t')
                        f.write(str(GNSDR) + '\t')
                        f.write(str(RPA) + '±' + str(np.round(np.std(metrics['RPA']) * 100, 2)) + '\t')
                        f.write(str(RCA) + '±' + str(np.round(np.std(metrics['RCA']) * 100, 2)) + '\t')
                        f.write(str(oa) + '±' + str(np.round(np.std(metrics['OA']) * 100, 2)) + '\n')
                    # Whole-model pickle (not just state_dict) — loading
                    # therefore requires the source tree to be importable.
                    torch.save(model, os.path.join(logdir, f'model-{i}.pt'))
                    torch.save(optimizer.state_dict(), os.path.join(logdir, 'last-optimizer-state.pt'))
            model.train()

        # Early stop: no new best for `early_stop_epochs` epochs.
        if i - it >= epoch_nums * early_stop_epochs:
            break
groups=['test'], sequence_length=None) 47 | print('valid nums:', len(valid_dataset)) 48 | data_loader = DataLoader(train_dataset, batch_size, shuffle=True) 49 | epoch_nums = len(data_loader) 50 | print('epoch_nums:', epoch_nums) 51 | learning_rate_decay_steps = len(data_loader) * learning_rate_decay_epochs 52 | iterations = epoch_nums * train_epochs 53 | 54 | resume_iteration = None 55 | os.makedirs(logdir, exist_ok=True) 56 | writer = SummaryWriter(logdir) 57 | 58 | if resume_iteration is None: 59 | # in_channels, n_blocks, hop_length, latent_layers, seq_frames, expert_num = 2, seq = 'gru', seq_layers = 1 60 | model = JM_MMOE(in_channels, n_blocks, hop_length, latent_layers, seq_frames, expert_num) 61 | model = nn.DataParallel(model).to(device) 62 | optimizer = torch.optim.Adam(model.parameters(), learning_rate) 63 | resume_iteration = 0 64 | else: 65 | model_path = os.path.join(logdir, f'model-{resume_iteration}.pt') 66 | model = torch.load(model_path) 67 | optimizer = torch.optim.Adam(model.parameters(), learning_rate) 68 | optimizer.load_state_dict(torch.load(os.path.join(logdir, 'last-optimizer-state.pt'))) 69 | 70 | scheduler = StepLR(optimizer, step_size=learning_rate_decay_steps, gamma=learning_rate_decay_rate) 71 | summary(model) 72 | SDR, RPA, GNSDR, RCA, it = 0, 0, 0, 0, 0 73 | loop = tqdm(range(resume_iteration + 1, iterations + 1)) 74 | 75 | for i, data in zip(loop, cycle(data_loader)): 76 | audio_m = data['audio_m'].to(device) 77 | audio_v = data['audio_v'].to(device) 78 | pitch_label = data['pitch'].to(device) 79 | out_audio, out_pitch, loss_spec = model(audio_m, audio_v) 80 | 81 | loss_svs = mae(out_audio, audio_v) 82 | loss_pitch = FL(out_pitch, pitch_label, alpha, gamma) 83 | loss_total = weight_svs * loss_svs + weight_pe * loss_pitch 84 | 85 | optimizer.zero_grad() 86 | loss_total.backward() 87 | if clip_grad_norm: 88 | clip_grad_norm_(model.parameters(), clip_grad_norm) 89 | optimizer.step() 90 | scheduler.step() 91 | 92 | print(i, end='\t') 93 
| print('loss_total:{:.6f}'.format(loss_total.item()), end='\t') 94 | print('loss_svs:{:.6f}'.format(loss_svs.item()), end='\t') 95 | print('loss_pe:{:.6f}'.format(loss_pitch.item())) 96 | 97 | writer.add_scalar('loss/loss_total', loss_total.item(), global_step=i) 98 | writer.add_scalar('loss/loss_svs', loss_svs.item(), global_step=i) 99 | writer.add_scalar('loss/loss_pe', loss_pitch.item(), global_step=i) 100 | 101 | if i % epoch_nums == 0: 102 | print('*' * 50) 103 | print(i, '\t', epoch_nums) 104 | model.eval() 105 | with torch.no_grad(): 106 | metrics = evaluate(valid_dataset, model, batch_size, hop_length, seq_l, device, None, pitch_th) 107 | for key, value in metrics.items(): 108 | writer.add_scalar('validation/' + key, np.mean(value), global_step=i) 109 | gnsdr = np.round((np.sum(metrics["NSDR_W"]) / np.sum(metrics["LENGTH"])), 2) 110 | writer.add_scalar('validation/GNSDR', gnsdr, global_step=i) 111 | sdr = np.round(np.mean(metrics['SDR']), 2) 112 | rpa = np.round(np.mean(metrics['RPA']) * 100, 2) 113 | rca = np.round(np.mean(metrics['RCA']) * 100, 2) 114 | oa = np.round(np.mean(metrics['OA']) * 100, 2) 115 | if sdr > SDR or rpa > RPA: 116 | SDR, GNSDR, RPA, RCA, it = sdr, gnsdr, rpa, rca, i 117 | with open(os.path.join(logdir, 'result.txt'), 'a') as f: 118 | f.write(str(i) + '\t') 119 | f.write(str(RPA) + '±' + str(np.round(np.std(metrics['RPA']) * 100, 2)) + '\t') 120 | f.write(str(RCA) + '±' + str(np.round(np.std(metrics['RCA']) * 100, 2)) + '\t') 121 | f.write(str(oa) + '±' + str(np.round(np.std(metrics['OA']) * 100, 2)) + '\t') 122 | f.write(str(SDR) + '±' + str(np.round(np.std(metrics['SDR']), 2)) + '\t') 123 | f.write(str(GNSDR) + '\n') 124 | torch.save(model, os.path.join(logdir, f'model-{i}.pt')) 125 | torch.save(optimizer.state_dict(), os.path.join(logdir, 'last-optimizer-state.pt')) 126 | model.train() 127 | 128 | if i - it >= epoch_nums * early_stop_epochs: 129 | break 130 | 131 | 132 | for alpha in [1, 2, 3, 4, 5]: 133 | train(2, alpha) 134 | 
--------------------------------------------------------------------------------