├── .gitignore
├── LICENSE.md
├── README.md
├── emo-net
│   ├── __init__.py
│   ├── cli.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── compute_scaling.py
│   │   └── loader.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── adapter_resnet.py
│   │   ├── adapter_rnn.py
│   │   ├── attention.py
│   │   ├── build_model.py
│   │   └── input_layers.py
│   ├── training
│   │   ├── __init__.py
│   │   ├── evaluate.py
│   │   ├── losses.py
│   │   ├── metrics.py
│   │   └── train.py
│   └── utils.py
└── requirements.txt

/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | auDeep.egg-info
3 | build/
4 | __pycache__
5 | .vscode
6 | .env/
7 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | ### GNU GENERAL PUBLIC LICENSE
2 |
3 | Version 3, 29 June 2007
4 |
5 | Copyright (C) 2007 Free Software Foundation, Inc.
6 | <https://fsf.org/>
7 |
8 | Everyone is permitted to copy and distribute verbatim copies of this
9 | license document, but changing it is not allowed.
10 |
11 | ### Preamble
12 |
13 | The GNU General Public License is a free, copyleft license for
14 | software and other kinds of works.
15 |
16 | The licenses for most software and other practical works are designed
17 | to take away your freedom to share and change the works. By contrast,
18 | the GNU General Public License is intended to guarantee your freedom
19 | to share and change all versions of a program--to make sure it remains
20 | free software for all its users. We, the Free Software Foundation, use
21 | the GNU General Public License for most of our software; it applies
22 | also to any other work released this way by its authors. You can apply
23 | it to your programs, too.
24 |
25 | When we speak of free software, we are referring to freedom, not
26 | price. Our General Public Licenses are designed to make sure that you
27 | have the freedom to distribute copies of free software (and charge for
28 | them if you wish), that you receive source code or can get it if you
29 | want it, that you can change the software or use pieces of it in new
30 | free programs, and that you know you can do these things.
31 |
32 | To protect your rights, we need to prevent others from denying you
33 | these rights or asking you to surrender the rights. Therefore, you
34 | have certain responsibilities if you distribute copies of the
35 | software, or if you modify it: responsibilities to respect the freedom
36 | of others.
37 |
38 | For example, if you distribute copies of such a program, whether
39 | gratis or for a fee, you must pass on to the recipients the same
40 | freedoms that you received. You must make sure that they, too, receive
41 | or can get the source code. And you must show them these terms so they
42 | know their rights.
43 |
44 | Developers that use the GNU GPL protect your rights with two steps:
45 | (1) assert copyright on the software, and (2) offer you this License
46 | giving you legal permission to copy, distribute and/or modify it.
47 |
48 | For the developers' and authors' protection, the GPL clearly explains
49 | that there is no warranty for this free software. For both users' and
50 | authors' sake, the GPL requires that modified versions be marked as
51 | changed, so that their problems will not be attributed erroneously to
52 | authors of previous versions.
53 |
54 | Some devices are designed to deny users access to install or run
55 | modified versions of the software inside them, although the
56 | manufacturer can do so.
This is fundamentally incompatible with the 57 | aim of protecting users' freedom to change the software. The 58 | systematic pattern of such abuse occurs in the area of products for 59 | individuals to use, which is precisely where it is most unacceptable. 60 | Therefore, we have designed this version of the GPL to prohibit the 61 | practice for those products. If such problems arise substantially in 62 | other domains, we stand ready to extend this provision to those 63 | domains in future versions of the GPL, as needed to protect the 64 | freedom of users. 65 | 66 | Finally, every program is threatened constantly by software patents. 67 | States should not allow patents to restrict development and use of 68 | software on general-purpose computers, but in those that do, we wish 69 | to avoid the special danger that patents applied to a free program 70 | could make it effectively proprietary. To prevent this, the GPL 71 | assures that patents cannot be used to render the program non-free. 72 | 73 | The precise terms and conditions for copying, distribution and 74 | modification follow. 75 | 76 | ### TERMS AND CONDITIONS 77 | 78 | #### 0. Definitions. 79 | 80 | "This License" refers to version 3 of the GNU General Public License. 81 | 82 | "Copyright" also means copyright-like laws that apply to other kinds 83 | of works, such as semiconductor masks. 84 | 85 | "The Program" refers to any copyrightable work licensed under this 86 | License. Each licensee is addressed as "you". "Licensees" and 87 | "recipients" may be individuals or organizations. 88 | 89 | To "modify" a work means to copy from or adapt all or part of the work 90 | in a fashion requiring copyright permission, other than the making of 91 | an exact copy. The resulting work is called a "modified version" of 92 | the earlier work or a work "based on" the earlier work. 93 | 94 | A "covered work" means either the unmodified Program or a work based 95 | on the Program. 96 | 97 | To "propagate" a work means to do anything with it that, without 98 | permission, would make you directly or secondarily liable for 99 | infringement under applicable copyright law, except executing it on a 100 | computer or modifying a private copy. Propagation includes copying, 101 | distribution (with or without modification), making available to the 102 | public, and in some countries other activities as well. 103 | 104 | To "convey" a work means any kind of propagation that enables other 105 | parties to make or receive copies. Mere interaction with a user 106 | through a computer network, with no transfer of a copy, is not 107 | conveying. 108 | 109 | An interactive user interface displays "Appropriate Legal Notices" to 110 | the extent that it includes a convenient and prominently visible 111 | feature that (1) displays an appropriate copyright notice, and (2) 112 | tells the user that there is no warranty for the work (except to the 113 | extent that warranties are provided), that licensees may convey the 114 | work under this License, and how to view a copy of this License. If 115 | the interface presents a list of user commands or options, such as a 116 | menu, a prominent item in the list meets this criterion. 117 | 118 | #### 1. Source Code. 119 | 120 | The "source code" for a work means the preferred form of the work for 121 | making modifications to it. "Object code" means any non-source form of 122 | a work. 
123 | 124 | A "Standard Interface" means an interface that either is an official 125 | standard defined by a recognized standards body, or, in the case of 126 | interfaces specified for a particular programming language, one that 127 | is widely used among developers working in that language. 128 | 129 | The "System Libraries" of an executable work include anything, other 130 | than the work as a whole, that (a) is included in the normal form of 131 | packaging a Major Component, but which is not part of that Major 132 | Component, and (b) serves only to enable use of the work with that 133 | Major Component, or to implement a Standard Interface for which an 134 | implementation is available to the public in source code form. A 135 | "Major Component", in this context, means a major essential component 136 | (kernel, window system, and so on) of the specific operating system 137 | (if any) on which the executable work runs, or a compiler used to 138 | produce the work, or an object code interpreter used to run it. 139 | 140 | The "Corresponding Source" for a work in object code form means all 141 | the source code needed to generate, install, and (for an executable 142 | work) run the object code and to modify the work, including scripts to 143 | control those activities. However, it does not include the work's 144 | System Libraries, or general-purpose tools or generally available free 145 | programs which are used unmodified in performing those activities but 146 | which are not part of the work. For example, Corresponding Source 147 | includes interface definition files associated with source files for 148 | the work, and the source code for shared libraries and dynamically 149 | linked subprograms that the work is specifically designed to require, 150 | such as by intimate data communication or control flow between those 151 | subprograms and other parts of the work. 152 | 153 | The Corresponding Source need not include anything that users can 154 | regenerate automatically from other parts of the Corresponding Source. 155 | 156 | The Corresponding Source for a work in source code form is that same 157 | work. 158 | 159 | #### 2. Basic Permissions. 160 | 161 | All rights granted under this License are granted for the term of 162 | copyright on the Program, and are irrevocable provided the stated 163 | conditions are met. This License explicitly affirms your unlimited 164 | permission to run the unmodified Program. The output from running a 165 | covered work is covered by this License only if the output, given its 166 | content, constitutes a covered work. This License acknowledges your 167 | rights of fair use or other equivalent, as provided by copyright law. 168 | 169 | You may make, run and propagate covered works that you do not convey, 170 | without conditions so long as your license otherwise remains in force. 171 | You may convey covered works to others for the sole purpose of having 172 | them make modifications exclusively for you, or provide you with 173 | facilities for running those works, provided that you comply with the 174 | terms of this License in conveying all material for which you do not 175 | control copyright. Those thus making or running the covered works for 176 | you must do so exclusively on your behalf, under your direction and 177 | control, on terms that prohibit them from making any copies of your 178 | copyrighted material outside their relationship with you. 
179 | 180 | Conveying under any other circumstances is permitted solely under the 181 | conditions stated below. Sublicensing is not allowed; section 10 makes 182 | it unnecessary. 183 | 184 | #### 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 185 | 186 | No covered work shall be deemed part of an effective technological 187 | measure under any applicable law fulfilling obligations under article 188 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 189 | similar laws prohibiting or restricting circumvention of such 190 | measures. 191 | 192 | When you convey a covered work, you waive any legal power to forbid 193 | circumvention of technological measures to the extent such 194 | circumvention is effected by exercising rights under this License with 195 | respect to the covered work, and you disclaim any intention to limit 196 | operation or modification of the work as a means of enforcing, against 197 | the work's users, your or third parties' legal rights to forbid 198 | circumvention of technological measures. 199 | 200 | #### 4. Conveying Verbatim Copies. 201 | 202 | You may convey verbatim copies of the Program's source code as you 203 | receive it, in any medium, provided that you conspicuously and 204 | appropriately publish on each copy an appropriate copyright notice; 205 | keep intact all notices stating that this License and any 206 | non-permissive terms added in accord with section 7 apply to the code; 207 | keep intact all notices of the absence of any warranty; and give all 208 | recipients a copy of this License along with the Program. 209 | 210 | You may charge any price or no price for each copy that you convey, 211 | and you may offer support or warranty protection for a fee. 212 | 213 | #### 5. Conveying Modified Source Versions. 214 | 215 | You may convey a work based on the Program, or the modifications to 216 | produce it from the Program, in the form of source code under the 217 | terms of section 4, provided that you also meet all of these 218 | conditions: 219 | 220 | - a) The work must carry prominent notices stating that you modified 221 | it, and giving a relevant date. 222 | - b) The work must carry prominent notices stating that it is 223 | released under this License and any conditions added under 224 | section 7. This requirement modifies the requirement in section 4 225 | to "keep intact all notices". 226 | - c) You must license the entire work, as a whole, under this 227 | License to anyone who comes into possession of a copy. This 228 | License will therefore apply, along with any applicable section 7 229 | additional terms, to the whole of the work, and all its parts, 230 | regardless of how they are packaged. This License gives no 231 | permission to license the work in any other way, but it does not 232 | invalidate such permission if you have separately received it. 233 | - d) If the work has interactive user interfaces, each must display 234 | Appropriate Legal Notices; however, if the Program has interactive 235 | interfaces that do not display Appropriate Legal Notices, your 236 | work need not make them do so. 
237 | 238 | A compilation of a covered work with other separate and independent 239 | works, which are not by their nature extensions of the covered work, 240 | and which are not combined with it such as to form a larger program, 241 | in or on a volume of a storage or distribution medium, is called an 242 | "aggregate" if the compilation and its resulting copyright are not 243 | used to limit the access or legal rights of the compilation's users 244 | beyond what the individual works permit. Inclusion of a covered work 245 | in an aggregate does not cause this License to apply to the other 246 | parts of the aggregate. 247 | 248 | #### 6. Conveying Non-Source Forms. 249 | 250 | You may convey a covered work in object code form under the terms of 251 | sections 4 and 5, provided that you also convey the machine-readable 252 | Corresponding Source under the terms of this License, in one of these 253 | ways: 254 | 255 | - a) Convey the object code in, or embodied in, a physical product 256 | (including a physical distribution medium), accompanied by the 257 | Corresponding Source fixed on a durable physical medium 258 | customarily used for software interchange. 259 | - b) Convey the object code in, or embodied in, a physical product 260 | (including a physical distribution medium), accompanied by a 261 | written offer, valid for at least three years and valid for as 262 | long as you offer spare parts or customer support for that product 263 | model, to give anyone who possesses the object code either (1) a 264 | copy of the Corresponding Source for all the software in the 265 | product that is covered by this License, on a durable physical 266 | medium customarily used for software interchange, for a price no 267 | more than your reasonable cost of physically performing this 268 | conveying of source, or (2) access to copy the Corresponding 269 | Source from a network server at no charge. 270 | - c) Convey individual copies of the object code with a copy of the 271 | written offer to provide the Corresponding Source. This 272 | alternative is allowed only occasionally and noncommercially, and 273 | only if you received the object code with such an offer, in accord 274 | with subsection 6b. 275 | - d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | - e) Convey the object code using peer-to-peer transmission, 288 | provided you inform other peers where the object code and 289 | Corresponding Source of the work are being offered to the general 290 | public at no charge under subsection 6d. 291 | 292 | A separable portion of the object code, whose source code is excluded 293 | from the Corresponding Source as a System Library, need not be 294 | included in conveying the object code work. 
295 | 296 | A "User Product" is either (1) a "consumer product", which means any 297 | tangible personal property which is normally used for personal, 298 | family, or household purposes, or (2) anything designed or sold for 299 | incorporation into a dwelling. In determining whether a product is a 300 | consumer product, doubtful cases shall be resolved in favor of 301 | coverage. For a particular product received by a particular user, 302 | "normally used" refers to a typical or common use of that class of 303 | product, regardless of the status of the particular user or of the way 304 | in which the particular user actually uses, or expects or is expected 305 | to use, the product. A product is a consumer product regardless of 306 | whether the product has substantial commercial, industrial or 307 | non-consumer uses, unless such uses represent the only significant 308 | mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to 312 | install and execute modified versions of a covered work in that User 313 | Product from a modified version of its Corresponding Source. The 314 | information must suffice to ensure that the continued functioning of 315 | the modified object code is in no case prevented or interfered with 316 | solely because modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or 331 | updates for a work that has been modified or installed by the 332 | recipient, or for the User Product in which it has been modified or 333 | installed. Access to a network may be denied when the modification 334 | itself materially and adversely affects the operation of the network 335 | or violates the rules and protocols for communication across the 336 | network. 337 | 338 | Corresponding Source conveyed, and Installation Information provided, 339 | in accord with this section must be in a format that is publicly 340 | documented (and with an implementation available to the public in 341 | source code form), and must require no special password or key for 342 | unpacking, reading or copying. 343 | 344 | #### 7. Additional Terms. 345 | 346 | "Additional permissions" are terms that supplement the terms of this 347 | License by making exceptions from one or more of its conditions. 348 | Additional permissions that are applicable to the entire Program shall 349 | be treated as though they were included in this License, to the extent 350 | that they are valid under applicable law. 
If additional permissions 351 | apply only to part of the Program, that part may be used separately 352 | under those permissions, but the entire Program remains governed by 353 | this License without regard to the additional permissions. 354 | 355 | When you convey a copy of a covered work, you may at your option 356 | remove any additional permissions from that copy, or from any part of 357 | it. (Additional permissions may be written to require their own 358 | removal in certain cases when you modify the work.) You may place 359 | additional permissions on material, added by you to a covered work, 360 | for which you have or can give appropriate copyright permission. 361 | 362 | Notwithstanding any other provision of this License, for material you 363 | add to a covered work, you may (if authorized by the copyright holders 364 | of that material) supplement the terms of this License with terms: 365 | 366 | - a) Disclaiming warranty or limiting liability differently from the 367 | terms of sections 15 and 16 of this License; or 368 | - b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | - c) Prohibiting misrepresentation of the origin of that material, 372 | or requiring that modified versions of such material be marked in 373 | reasonable ways as different from the original version; or 374 | - d) Limiting the use for publicity purposes of names of licensors 375 | or authors of the material; or 376 | - e) Declining to grant rights under trademark law for use of some 377 | trade names, trademarks, or service marks; or 378 | - f) Requiring indemnification of licensors and authors of that 379 | material by anyone who conveys the material (or modified versions 380 | of it) with contractual assumptions of liability to the recipient, 381 | for any liability that these contractual assumptions directly 382 | impose on those licensors and authors. 383 | 384 | All other non-permissive additional terms are considered "further 385 | restrictions" within the meaning of section 10. If the Program as you 386 | received it, or any part of it, contains a notice stating that it is 387 | governed by this License along with a term that is a further 388 | restriction, you may remove that term. If a license document contains 389 | a further restriction but permits relicensing or conveying under this 390 | License, you may add to a covered work material governed by the terms 391 | of that license document, provided that the further restriction does 392 | not survive such relicensing or conveying. 393 | 394 | If you add terms to a covered work in accord with this section, you 395 | must place, in the relevant source files, a statement of the 396 | additional terms that apply to those files, or a notice indicating 397 | where to find the applicable terms. 398 | 399 | Additional terms, permissive or non-permissive, may be stated in the 400 | form of a separately written license, or stated as exceptions; the 401 | above requirements apply either way. 402 | 403 | #### 8. Termination. 404 | 405 | You may not propagate or modify a covered work except as expressly 406 | provided under this License. Any attempt otherwise to propagate or 407 | modify it is void, and will automatically terminate your rights under 408 | this License (including any patent licenses granted under the third 409 | paragraph of section 11). 
410 | 411 | However, if you cease all violation of this License, then your license 412 | from a particular copyright holder is reinstated (a) provisionally, 413 | unless and until the copyright holder explicitly and finally 414 | terminates your license, and (b) permanently, if the copyright holder 415 | fails to notify you of the violation by some reasonable means prior to 416 | 60 days after the cessation. 417 | 418 | Moreover, your license from a particular copyright holder is 419 | reinstated permanently if the copyright holder notifies you of the 420 | violation by some reasonable means, this is the first time you have 421 | received notice of violation of this License (for any work) from that 422 | copyright holder, and you cure the violation prior to 30 days after 423 | your receipt of the notice. 424 | 425 | Termination of your rights under this section does not terminate the 426 | licenses of parties who have received copies or rights from you under 427 | this License. If your rights have been terminated and not permanently 428 | reinstated, you do not qualify to receive new licenses for the same 429 | material under section 10. 430 | 431 | #### 9. Acceptance Not Required for Having Copies. 432 | 433 | You are not required to accept this License in order to receive or run 434 | a copy of the Program. Ancillary propagation of a covered work 435 | occurring solely as a consequence of using peer-to-peer transmission 436 | to receive a copy likewise does not require acceptance. However, 437 | nothing other than this License grants you permission to propagate or 438 | modify any covered work. These actions infringe copyright if you do 439 | not accept this License. Therefore, by modifying or propagating a 440 | covered work, you indicate your acceptance of this License to do so. 441 | 442 | #### 10. Automatic Licensing of Downstream Recipients. 443 | 444 | Each time you convey a covered work, the recipient automatically 445 | receives a license from the original licensors, to run, modify and 446 | propagate that work, subject to this License. You are not responsible 447 | for enforcing compliance by third parties with this License. 448 | 449 | An "entity transaction" is a transaction transferring control of an 450 | organization, or substantially all assets of one, or subdividing an 451 | organization, or merging organizations. If propagation of a covered 452 | work results from an entity transaction, each party to that 453 | transaction who receives a copy of the work also receives whatever 454 | licenses to the work the party's predecessor in interest had or could 455 | give under the previous paragraph, plus a right to possession of the 456 | Corresponding Source of the work from the predecessor in interest, if 457 | the predecessor has it or can get it with reasonable efforts. 458 | 459 | You may not impose any further restrictions on the exercise of the 460 | rights granted or affirmed under this License. For example, you may 461 | not impose a license fee, royalty, or other charge for exercise of 462 | rights granted under this License, and you may not initiate litigation 463 | (including a cross-claim or counterclaim in a lawsuit) alleging that 464 | any patent claim is infringed by making, using, selling, offering for 465 | sale, or importing the Program or any portion of it. 466 | 467 | #### 11. Patents. 468 | 469 | A "contributor" is a copyright holder who authorizes use under this 470 | License of the Program or a work on which the Program is based. 
The 471 | work thus licensed is called the contributor's "contributor version". 472 | 473 | A contributor's "essential patent claims" are all patent claims owned 474 | or controlled by the contributor, whether already acquired or 475 | hereafter acquired, that would be infringed by some manner, permitted 476 | by this License, of making, using, or selling its contributor version, 477 | but do not include claims that would be infringed only as a 478 | consequence of further modification of the contributor version. For 479 | purposes of this definition, "control" includes the right to grant 480 | patent sublicenses in a manner consistent with the requirements of 481 | this License. 482 | 483 | Each contributor grants you a non-exclusive, worldwide, royalty-free 484 | patent license under the contributor's essential patent claims, to 485 | make, use, sell, offer for sale, import and otherwise run, modify and 486 | propagate the contents of its contributor version. 487 | 488 | In the following three paragraphs, a "patent license" is any express 489 | agreement or commitment, however denominated, not to enforce a patent 490 | (such as an express permission to practice a patent or covenant not to 491 | sue for patent infringement). To "grant" such a patent license to a 492 | party means to make such an agreement or commitment not to enforce a 493 | patent against the party. 494 | 495 | If you convey a covered work, knowingly relying on a patent license, 496 | and the Corresponding Source of the work is not available for anyone 497 | to copy, free of charge and under the terms of this License, through a 498 | publicly available network server or other readily accessible means, 499 | then you must either (1) cause the Corresponding Source to be so 500 | available, or (2) arrange to deprive yourself of the benefit of the 501 | patent license for this particular work, or (3) arrange, in a manner 502 | consistent with the requirements of this License, to extend the patent 503 | license to downstream recipients. "Knowingly relying" means you have 504 | actual knowledge that, but for the patent license, your conveying the 505 | covered work in a country, or your recipient's use of the covered work 506 | in a country, would infringe one or more identifiable patents in that 507 | country that you have reason to believe are valid. 508 | 509 | If, pursuant to or in connection with a single transaction or 510 | arrangement, you convey, or propagate by procuring conveyance of, a 511 | covered work, and grant a patent license to some of the parties 512 | receiving the covered work authorizing them to use, propagate, modify 513 | or convey a specific copy of the covered work, then the patent license 514 | you grant is automatically extended to all recipients of the covered 515 | work and works based on it. 516 | 517 | A patent license is "discriminatory" if it does not include within the 518 | scope of its coverage, prohibits the exercise of, or is conditioned on 519 | the non-exercise of one or more of the rights that are specifically 520 | granted under this License. 
You may not convey a covered work if you 521 | are a party to an arrangement with a third party that is in the 522 | business of distributing software, under which you make payment to the 523 | third party based on the extent of your activity of conveying the 524 | work, and under which the third party grants, to any of the parties 525 | who would receive the covered work from you, a discriminatory patent 526 | license (a) in connection with copies of the covered work conveyed by 527 | you (or copies made from those copies), or (b) primarily for and in 528 | connection with specific products or compilations that contain the 529 | covered work, unless you entered into that arrangement, or that patent 530 | license was granted, prior to 28 March 2007. 531 | 532 | Nothing in this License shall be construed as excluding or limiting 533 | any implied license or other defenses to infringement that may 534 | otherwise be available to you under applicable patent law. 535 | 536 | #### 12. No Surrender of Others' Freedom. 537 | 538 | If conditions are imposed on you (whether by court order, agreement or 539 | otherwise) that contradict the conditions of this License, they do not 540 | excuse you from the conditions of this License. If you cannot convey a 541 | covered work so as to satisfy simultaneously your obligations under 542 | this License and any other pertinent obligations, then as a 543 | consequence you may not convey it at all. For example, if you agree to 544 | terms that obligate you to collect a royalty for further conveying 545 | from those to whom you convey the Program, the only way you could 546 | satisfy both those terms and this License would be to refrain entirely 547 | from conveying the Program. 548 | 549 | #### 13. Use with the GNU Affero General Public License. 550 | 551 | Notwithstanding any other provision of this License, you have 552 | permission to link or combine any covered work with a work licensed 553 | under version 3 of the GNU Affero General Public License into a single 554 | combined work, and to convey the resulting work. The terms of this 555 | License will continue to apply to the part which is the covered work, 556 | but the special requirements of the GNU Affero General Public License, 557 | section 13, concerning interaction through a network will apply to the 558 | combination as such. 559 | 560 | #### 14. Revised Versions of this License. 561 | 562 | The Free Software Foundation may publish revised and/or new versions 563 | of the GNU General Public License from time to time. Such new versions 564 | will be similar in spirit to the present version, but may differ in 565 | detail to address new problems or concerns. 566 | 567 | Each version is given a distinguishing version number. If the Program 568 | specifies that a certain numbered version of the GNU General Public 569 | License "or any later version" applies to it, you have the option of 570 | following the terms and conditions either of that numbered version or 571 | of any later version published by the Free Software Foundation. If the 572 | Program does not specify a version number of the GNU General Public 573 | License, you may choose any version ever published by the Free 574 | Software Foundation. 575 | 576 | If the Program specifies that a proxy can decide which future versions 577 | of the GNU General Public License can be used, that proxy's public 578 | statement of acceptance of a version permanently authorizes you to 579 | choose that version for the Program. 
580 |
581 | Later license versions may give you additional or different
582 | permissions. However, no additional obligations are imposed on any
583 | author or copyright holder as a result of your choosing to follow a
584 | later version.
585 |
586 | #### 15. Disclaimer of Warranty.
587 |
588 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
589 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
590 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT
591 | WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT
592 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
593 | A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND
594 | PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE PROGRAM PROVE
595 | DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR
596 | CORRECTION.
597 |
598 | #### 16. Limitation of Liability.
599 |
600 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
601 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR
602 | CONVEYS THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES,
603 | INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES
604 | ARISING OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT
605 | NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR
606 | LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM
607 | TO OPERATE WITH ANY OTHER PROGRAMS), EVEN IF SUCH HOLDER OR OTHER
608 | PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
609 |
610 | #### 17. Interpretation of Sections 15 and 16.
611 |
612 | If the disclaimer of warranty and limitation of liability provided
613 | above cannot be given local legal effect according to their terms,
614 | reviewing courts shall apply local law that most closely approximates
615 | an absolute waiver of all civil liability in connection with the
616 | Program, unless a warranty or assumption of liability accompanies a
617 | copy of the Program in return for a fee.
618 |
619 | END OF TERMS AND CONDITIONS
620 |
621 | ### How to Apply These Terms to Your New Programs
622 |
623 | If you develop a new program, and you want it to be of the greatest
624 | possible use to the public, the best way to achieve this is to make it
625 | free software which everyone can redistribute and change under these
626 | terms.
627 |
628 | To do so, attach the following notices to the program. It is safest to
629 | attach them to the start of each source file to most effectively state
630 | the exclusion of warranty; and each file should have at least the
631 | "copyright" line and a pointer to where the full notice is found.
632 |
633 |     <one line to give the program's name and a brief idea of what it does.>
634 |     Copyright (C) <year>  <name of author>
635 |
636 |     This program is free software: you can redistribute it and/or modify
637 |     it under the terms of the GNU General Public License as published by
638 |     the Free Software Foundation, either version 3 of the License, or
639 |     (at your option) any later version.
640 |
641 |     This program is distributed in the hope that it will be useful,
642 |     but WITHOUT ANY WARRANTY; without even the implied warranty of
643 |     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
644 |     GNU General Public License for more details.
645 |
646 |     You should have received a copy of the GNU General Public License
647 |     along with this program. If not, see <https://www.gnu.org/licenses/>.
648 |
649 | Also add information on how to contact you by electronic and paper
650 | mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 |     <program>  Copyright (C) <year>  <name of author>
656 |     This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 |     This is free software, and you are welcome to redistribute it
658 |     under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands \`show w' and \`show c' should show the
661 | appropriate parts of the General Public License. Of course, your
662 | program's commands might be different; for a GUI interface, you would
663 | use an "about box".
664 |
665 | You should also get your employer (if you work as a programmer) or
666 | school, if any, to sign a "copyright disclaimer" for the program, if
667 | necessary. For more information on this, and how to apply and follow
668 | the GNU GPL, see <https://www.gnu.org/licenses/>.
669 |
670 | The GNU General Public License does not permit incorporating your
671 | program into proprietary programs. If your program is a subroutine
672 | library, you may consider it more useful to permit linking proprietary
673 | applications with the library. If this is what you want to do, use the
674 | GNU Lesser General Public License instead of this License. But first,
675 | please read <https://www.gnu.org/philosophy/why-not-lgpl.html>.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # EmoNet
2 |
3 |
4 | **EmoNet** is a Python toolkit for multi-corpus speech emotion recognition and other audio classification tasks.
5 |
6 | **(c) 2021 Maurice Gerczuk, Shahin Amiriparian, Björn Schuller: Universität Augsburg**
7 |
8 | Please direct any questions or requests to Maurice Gerczuk (maurice.gerczuk at uni-a.de) or Shahin Amiriparian (shahin.amiriparian at uni-a.de).
9 |
10 | # Citing
11 | If you use EmoNet or any code from EmoNet in your research work, you are kindly asked to acknowledge the use of EmoNet in your publications.
12 | > M. Gerczuk, S. Amiriparian, S. Ottl, and B. Schuller, “EmoNet: A transfer learning framework for multi-corpus speech emotion recognition,” 2021. [https://arxiv.org/abs/2103.08310](https://arxiv.org/abs/2103.08310)
13 |
14 |
15 | ```
16 | @misc{gerczuk2021emonet,
17 |       title={EmoNet: A Transfer Learning Framework for Multi-Corpus Speech Emotion Recognition},
18 |       author={Maurice Gerczuk and Shahin Amiriparian and Sandra Ottl and Björn Schuller},
19 |       year={2021},
20 |       eprint={2103.08310},
21 |       archivePrefix={arXiv},
22 |       primaryClass={cs.SD}
23 | }
24 | ```
25 |
26 |
27 | ## Installation
28 |
29 | All dependencies can be installed via pip from the requirements.txt:
30 |
31 | ```bash
32 | pip install -r requirements.txt
33 | ```
34 |
35 | It is advisable to do this from within a newly created virtual environment.
36 |
37 | ## Usage
38 |
39 | The basic command line is accessible from the repository's base directory by calling:
40 |
41 | ```bash
42 | python -m emo-net.cli --help
43 | ```
44 |
45 | This prints a help message specifying the list of subcommands. For each subcommand, more help is available via:
46 |
47 | ```bash
48 | python -m emo-net.cli [subcommand] --help
49 | ```
50 |
51 | ### Data Preparation
52 |
53 | The toolkit can be used for arbitrary audio classification tasks. To prepare your dataset, resample all audio content to 16 kHz wav files (e.g., with ffmpeg). Afterwards, you need label files in .csv format that specify the categorical target for each sample in the training, development and test partitions, i.e., three files "train.csv", "devel.csv" and "test.csv". The files must include the path to each audio file in the first column - relative to a common base directory - and a categorical label in the second column. A header line "file,label" should be included.
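For illustration, a minimal label file might look like this (the file names and labels below are made up; yours will depend on your dataset):

```csv
file,label
speaker01/clip_0001.wav,anger
speaker01/clip_0002.wav,neutral
speaker02/clip_0001.wav,happiness
```

The same format applies to "devel.csv" and "test.csv".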
54 |
55 | ### Command line options
56 |
57 | The CLI has a nested structure, i.e., it uses two layers of subcommands. The first subcommand specifies the type of neural network architecture that is used. Here, "cnn" gives access to the ResNet architecture which also includes residual adapters, depending on the training setting. Two other options, "rnn" and "fusion", are also included but untested and in early stages of development. The rest of this guide will therefore focus on the "cnn" subcommand. After specifying the model type, two distinct subcommands are accessible: "single-task" and "multi-task", which refer to the type of training procedure. For single-task training, the model is trained on one database at a time, specified by its base directory and the label files for the training, validation and test partitions:
58 |
59 | ```bash
60 | python -m emo-net.cli -v cnn single-task -t [taskName] --data-path /path/to/task/wavs -tr train.csv -v devel.csv -te test.csv
61 | ```
62 |
63 | One additional parameter defines the type of training that is performed. Here, the choice can be made between training a fresh model from scratch (`-m scratch`), fully finetuning an existing model (`-m finetune`), training only the classifier head (`-m last-layer`) and the residual adapter approach (`-m adapters`). For the last three methods, a pre-trained model has to be loaded by specifying the path to its weights via `-im /path/to/weights.h5`. All other parameters have sensible default values; the full list is given in the table after the example below.
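For instance, adapter-based training on top of a pre-trained multi-corpus model might be launched as follows (the task name and all paths here are placeholders):

```bash
python -m emo-net.cli -v cnn single-task -t [taskName] -m adapters \
    -im /path/to/pretrained/weights.h5 \
    --data-path /path/to/task/wavs -tr train.csv -v devel.csv -te test.csv
```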
64 |
65 | | Option | Type | Description |
66 | | --- | --- | --- |
67 | | -dp, --data-path | DIRECTORY | Directory of data files. [required] |
68 | | -t, --task | TEXT | Name of the task that is trained. [required] |
69 | | -tr, --train-csv | FILE | Path to training csv file. [required] |
70 | | -v, --val-csv | FILE | Path to validation csv file. [required] |
71 | | -te, --test-csv | FILE | Path to test csv file. [required] |
72 | | -bs, --batch-size | INTEGER | Define batch size. |
73 | | -nm, --num-mels | INTEGER | Number of mel bands in spectrogram. |
74 | | -e, --epochs | INTEGER | Define max number of training epochs. |
75 | | -p, --patience | INTEGER | Define patience before early stopping / reducing learning rate in epochs. |
76 | | -im, --initial-model | FILE | Initial model for resuming training. |
77 | | -bw, --balanced-weights | FLAG | Automatically set balanced class weights. |
78 | | -lr, --learning-rate | FLOAT | Initial learning rate for optimizer. |
79 | | -do, --dropout | FLOAT | Dropout for the two positions (after first and second convolution of each block). |
80 | | -ebp, --experiment-base-path | PATH | Base path where logs and checkpoints should be stored. |
81 | | -o, --optimizer | [sgd\|rmsprop\|adam\|adadelta] | Optimizer used for training. |
82 | | -N, --number-of-resnet-blocks | INTEGER | Number of convolutional blocks in the ResNet layers. |
83 | | -nf, --number-of-filters | INTEGER | Number of filters in first convolutional block. |
84 | | -wf, --widen-factor | INTEGER | Widen factor of the wide ResNet. |
85 | | -c, --classifier | [avgpool\|FCNAttention] | The classification top of the network architecture. Choose between simple pooling + dense layer (needs fixed window size) and fully convolutional attention. |
86 | | -w, --window | FLOAT | Window size in seconds. |
87 | | -l, --loss | [crossentropy\|focal\|ordinal] | Classification loss. Ordinal loss uses sorted class labels. |
88 | | -m, --mode | [scratch\|adapters\|last-layer\|finetune] | Type of training to be performed. |
89 | | -sfl, --share-feature-layer | FLAG | Share the feature layer (weighted attention of deep features) between tasks. |
90 | | -iwd, --individual-weight-decay | FLAG | Set weight decay in adapters according to size of training dataset. Smaller datasets will have larger weight decay to keep closer to the pre-trained network. |
91 | | --help | FLAG | Show this message and exit. |
92 |
93 | The "multi-task" command line differs slightly from the one described above. The most notable difference is in how the data is passed. Instead of passing individual .csv files for each partition, you specify a directory ("--multi-task-setup") that contains a folder with "train.csv", "val.csv" and "test.csv" files for each database. Additionally, "-t" is now used to specify a list of databases (subfolders of the multi-task setup) that should be used for training. As multi-domain training is done in a round-robin fashion, there is no predefined notion of a training epoch. Therefore, an additional option ("--steps-per-epoch") defines the size of an artificial training epoch. These additional parameters are given in the table below; an illustrative invocation follows the table.
94 |
95 | | Option | Type | Description |
96 | | --- | --- | --- |
97 | | -dp, --data-path | DIRECTORY | Directory of wav files. [required] |
98 | | -mts, --multi-task-setup | DIRECTORY | Directory with the setup csvs ("train.csv", "val.csv", "test.csv") for each task in a separate folder. [required] |
99 | | -t, --tasks | TEXT | Names of the tasks that are trained. [required] |
100 | | -spe, --steps-per-epoch | INTEGER | Number of training steps for each artificial epoch. |
101 |
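As an illustration, a hypothetical multi-task setup with two databases (the names and paths are placeholders) could be laid out and launched like this:

```bash
# Expected layout of the --multi-task-setup directory:
#   /path/to/setup/EMO-DB/train.csv, val.csv, test.csv
#   /path/to/setup/IEMOCAP-4cl/train.csv, val.csv, test.csv
python -m emo-net.cli -v cnn multi-task -dp /path/to/wavs -mts /path/to/setup \
    -t EMO-DB -t IEMOCAP-4cl -spe 500
```

Whether multiple task names are passed by repeating "-t" (as sketched here) or as a single delimited list depends on the CLI definition; check `python -m emo-net.cli cnn multi-task --help`.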
--------------------------------------------------------------------------------
/emo-net/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EIHW/EmoNet/e76dd53ab4c33e99182c69f6247e4d924f9a3d99/emo-net/__init__.py
--------------------------------------------------------------------------------
/emo-net/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EIHW/EmoNet/e76dd53ab4c33e99182c69f6247e4d924f9a3d99/emo-net/data/__init__.py
--------------------------------------------------------------------------------
/emo-net/data/compute_scaling.py:
--------------------------------------------------------------------------------
1 | # EmoNet
2 | # ==============================================================================
3 | # Copyright (C) 2021 Maurice Gerczuk, Shahin Amiriparian,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
18 | # ==============================================================================
19 | import tensorflow as tf
20 | import numpy as np
21 | import pickle
22 | import glob
23 | from tqdm import tqdm
24 | from ..models.input_layers import LogMelgramLayer
25 | from ..data.loader import AudioDataGenerator
26 | from os.path import join
27 | from sklearn.preprocessing import StandardScaler
28 |
29 |
31 | def compute_scaling(dataset_base):  # computes global log-mel mean/std for one dataset and pickles them
32 |     train_generator = AudioDataGenerator(join(dataset_base, 'train.csv'),
33 |                                          '/mnt/nas/data_work/shahin/EmoSet/wavs-reordered/',
34 |                                          batch_size=1,
35 |                                          window=None,
36 |                                          shuffle=False,
37 |                                          sr=16000,
38 |                                          time_stretch=None,
39 |                                          pitch_shift=None,
40 |                                          save_dir=None,
41 |                                          val_split=None,
42 |                                          subset='train',
43 |                                          variable_duration=True)
44 |     train_dataset = train_generator.tf_dataset().prefetch(tf.data.experimental.AUTOTUNE)
45 |
46 |     input_tensor = tf.keras.layers.Input(shape=(None,))
47 |     input_reshaped = tf.keras.layers.Reshape(
48 |         target_shape=(-1, ))(input_tensor)
49 |
50 |     x = LogMelgramLayer(num_fft=512,
51 |                         hop_length=256,
52 |                         sample_rate=16000,
53 |                         f_min=80,
54 |                         f_max=8000,
55 |                         num_mels=64,
56 |                         eps=1e-6,
57 |                         return_decibel=True,
58 |                         name='trainable_stft')(input_reshaped)
59 |     model = tf.keras.Model(inputs=input_tensor, outputs=x)
60 |     spectrograms = []
61 |     for batch in tqdm(train_dataset):
62 |         spectrograms.append(np.squeeze(model.predict_on_batch(batch)))
63 |     spectrograms_concat = np.concatenate(spectrograms)
64 |     mean = np.mean(spectrograms_concat)
65 |     std = np.std(spectrograms_concat)
66 |     mean_std = {'mean': mean, 'std': std}
67 |     print(dataset_base, mean, std)
68 |     with open(join(dataset_base, 'mean_std.pkl'), 'wb') as f:
69 |         pickle.dump(mean_std, f)
70 |
71 |
72 | if __name__ == '__main__':
73 |     datasets = glob.glob('/mnt/student/MauriceGerczuk/EmoSet/multiTaskSetup-wavs-with-test/*/')
74 |     for dataset in datasets:
75 |         compute_scaling(dataset)
--------------------------------------------------------------------------------
/emo-net/data/loader.py:
--------------------------------------------------------------------------------
1 | # EmoNet
2 | # ==============================================================================
3 | # Copyright (C) 2021 Maurice Gerczuk, Shahin Amiriparian,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see <https://www.gnu.org/licenses/>.
18 | # ==============================================================================
19 | import numpy as np
20 | import time
21 | import pandas as pd
23 | import itertools
24 | import csv
25 | import librosa
26 | from tensorflow.keras.preprocessing.image import ImageDataGenerator
27 | from tensorflow.keras import applications
28 | from tensorflow.keras.utils import Sequence
29 | from tensorflow.keras.utils import to_categorical
30 | from sklearn.utils import class_weight
31 | from sklearn.preprocessing import LabelEncoder
34 | from tensorflow.keras.preprocessing.sequence import pad_sequences
35 | from sklearn.model_selection import StratifiedShuffleSplit
36 | from os.path import join, dirname, basename, relpath
37 | from os import makedirs
38 | from math import ceil
39 | from abc import ABC, abstractmethod
40 | from glob import glob
41 | from PIL import Image
42 | import tensorflow as tf
43 | import logging
44 | logger = logging.getLogger(__name__)
45 |
46 | class AudioDataGenerator(Sequence):
47 |     def __init__(self,
48 |                  csv_file,
49 |                  directory,
50 |                  batch_size=32,
51 |                  window=1,
52 |                  shuffle=True,
53 |                  sr=16000,
54 |                  time_stretch=None,
55 |                  pitch_shift=None,
56 |                  save_dir=None,
57 |                  val_split=0.2,
58 |                  val_indices=None,
59 |                  subset='train',
60 |                  variable_duration=False):
61 |         self.random_state = 42
62 |         self.variable_duration = variable_duration
63 |         self.files = []
64 |         self.classes = []
65 |         with open(csv_file) as f:
66 |             reader = csv.reader(f, delimiter=',')
67 |             header = next(reader)
68 |             if 'label' in header:
69 |                 label_index = header.index('label')
70 |                 logger.info(f'Setup csv "{csv_file}" contains "label" column at index {label_index}.')
72 |             else:
73 |                 label_index = len(header) - 1
74 |                 logger.warning(f'Setup csv "{csv_file}" does not contain "label" column. Choosing last column: "{header[label_index]}" instead.')
75 |             if 'file' in header:
76 |                 path_index = header.index('file')
77 |                 logger.info(f'Setup csv "{csv_file}" contains "file" column at index {path_index}.')
79 |             else:
80 |                 path_index = 0
81 |                 logger.warning(f'Setup csv "{csv_file}" does not contain "file" column. Choosing first column: "{header[path_index]}" instead.')
82 |             for line in reader:
83 |                 self.files.append(
84 |                     join(directory, line[path_index]))
85 |                 self.classes.append(line[label_index])
86 |
87 |         logger.info(f'Parsed {len(self.files)} audio files')
88 |         self.val_split = val_split
89 |         self.train_indices = None
90 |         self.val_indices = val_indices
91 |         self.subset = subset
92 |
95 |         self.label_binarizer = LabelEncoder()
96 |         self.label_binarizer.fit(self.classes)
97 |
98 |         if self.val_split is not None and subset == 'train':
99 |             self.__create_split()
100 |         elif self.val_indices is not None:
101 |             self.__apply_split()
102 |
103 |         self.directory = directory
104 |         self.window = window
105 |         self.classes = self.label_binarizer.transform(self.classes)
106 |         if len(self.label_binarizer.classes_) > 2:
107 |             self.categorical_classes = to_categorical(self.classes)
108 |         else:
109 |             self.categorical_classes = self.classes
110 |         self.class_indices = {c: i for i, c in enumerate(self.label_binarizer.classes_)}
111 |         logger.info(f'Class indices: {self.class_indices}')
112 |         self.batch_size = batch_size
113 |         self.shuffle = shuffle
114 |         self.time_stretch = time_stretch
115 |         self.pitch_shift = pitch_shift
116 |         self.save_dir = save_dir
117 |         self.sr = sr
118 |         np.random.seed(self.random_state)
119 |         self.on_epoch_end()
120 |
121 |
122 |     @staticmethod
123 |     def load_audio(filename, label):
124 |         raw = tf.io.read_file(filename)
125 |         audio, sr = tf.audio.decode_wav(raw, desired_channels=1)
126 |         audio = tf.reshape(audio, (-1,))
127 |         return audio, label
128 |
129 |
130 |     @staticmethod
131 |     def random_slice(audio, label, size):
132 |         size = tf.math.minimum(tf.shape(audio), size)
133 |         audio = tf.image.random_crop(audio, size, seed=42)
134 |         return audio, label
135 |
136 |     @staticmethod
137 |     def center_slice(audio, label, size):
138 |         duration = tf.shape(audio)[0]
139 |         start = tf.where(duration // 2 > size, duration // 2, 0)  # graph-safe equivalent of a Python `if` on a tensor
140 |         audio = audio[start:start+size]
141 |         return audio, label
142 |
143 |
144 |     def tf_dataset(self):
145 |         dataset = tf.data.Dataset.from_tensor_slices((self.files, self.categorical_classes))
146 |         binary = len(self.categorical_classes.shape) < 2
147 |         if self.shuffle:
148 |             dataset = dataset.shuffle(len(self.files), seed=42)
149 |         dataset = dataset.map(AudioDataGenerator.load_audio, num_parallel_calls=tf.data.experimental.AUTOTUNE)
150 |         #dataset = dataset.filter(lambda x, _: tf.math.count_nonzero(x) > 0)
151 |         if self.window is not None:
152 |             window_size = int(self.window*self.sr)
153 |             padded_size = window_size if not self.variable_duration else None
154 |             if self.shuffle:
155 |                 dataset = dataset.map(lambda audio, label: AudioDataGenerator.random_slice(audio, label, size=window_size), num_parallel_calls=tf.data.experimental.AUTOTUNE)
156 |             else:
157 |                 dataset = dataset.map(lambda audio, label: AudioDataGenerator.center_slice(audio, label, size=window_size), num_parallel_calls=tf.data.experimental.AUTOTUNE)
158 |         else:
159 |             padded_size = None
160 |         padded_label_size = () if binary else (self.categorical_classes.shape[1],)
161 |         dataset = dataset.padded_batch(self.batch_size, ((padded_size,), padded_label_size))  # pad variable-length clips to a common length per batch
162 |         #dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE).cache()
163 |         return dataset
164 |
165 |
170 |     def __create_split(self):
171 |         sss = StratifiedShuffleSplit(n_splits=1, test_size=self.val_split, random_state=self.random_state)
172 |         for train_index, test_index in sss.split(self.files, self.classes):
self.val_indices = test_index 174 | self.train_indices = train_index 175 | 176 | 177 | def __apply_split(self): 178 | indices = self.train_indices if self.subset == 'train' else self.val_indices # rows to drop from this subset's view of the data 179 | for index in sorted(indices, reverse=True): 180 | del self.files[index] 181 | del self.classes[index] 182 | 183 | 184 | def __len__(self): 185 | return ceil(len(self.files) / self.batch_size) 190 | 191 | def __getitem__(self, index): 192 | # Generate indexes of the batch 193 | index = index % len(self) 194 | indices = self.indices[index * self.batch_size:min(len(self.indices), (index + 1) * 195 | self.batch_size)] 196 | 197 | files_batch = [self.files[k] for k in indices] 198 | y = np.asarray([self.categorical_classes[k] for k in indices]) 199 | 200 | # Generate data 201 | x = self.__data_generation(files_batch) 202 | 203 | return x, y 204 | 205 | def _set_index_array(self): 206 | self.indices = np.arange(len(self.files)) 207 | if self.shuffle: 208 | np.random.shuffle(self.indices) 209 | 210 | def on_epoch_end(self): 211 | 'Updates indexes after each epoch' 212 | self._set_index_array() 213 | 214 | def __data_generation(self, files): 215 | audio_data = [] 216 | 217 | for file in files: 218 | duration = librosa.core.get_duration(filename=file) 219 | 220 | if self.window is not None: 221 | stretched_window = self.window * ( 222 | 1 + self.time_stretch 223 | ) if self.time_stretch is not None else self.window 224 | if self.shuffle: 225 | start = np.random.randint(0, max(1, int(duration - stretched_window))) 226 | 227 | else: 228 | start = duration / 2 if duration / 2 > stretched_window else 0 # take the middle chunk 229 | y, sr = librosa.core.load(file, 230 | offset=start, 231 | duration=min(stretched_window, duration), 232 | sr=self.sr) 233 | y = self.__get_random_transform(y, sr) 234 | end_sample = min(int(self.window * sr), int(duration * sr)) 235 | y = y[:end_sample] 236 | else: 237 | y, sr = librosa.core.load(file, sr=self.sr) 238 | y = self.__get_random_transform(y, sr) 239 | 240 | if self.save_dir: 241 | rel_path = relpath(file, self.directory) 242 | save_path = join(self.save_dir, rel_path) # rel_path is a plain str; keep it as the relative output path 243 | makedirs(dirname(save_path), exist_ok=True) 244 | librosa.output.write_wav( 245 | save_path, 246 | y, sr) # write the current clip, not the accumulated batch; note librosa.output was removed in librosa 0.8 (use soundfile.write there) 247 | audio_data.append(y) 248 | if (self.window is not None) and (not self.variable_duration): 249 | audio_data = pad_sequences( 250 | audio_data, maxlen=int(self.window*self.sr), dtype='float32') 251 | else: 252 | audio_data = pad_sequences( 253 | audio_data, dtype='float32') 254 | 255 | return audio_data 256 | 257 | def __get_random_transform(self, y, sr): 258 | if self.time_stretch is not None: 259 | factor = np.random.normal(1, self.time_stretch) 260 | y = librosa.effects.time_stretch(y, factor) 261 | if self.pitch_shift is not None: 262 | steps = np.random.randint(0 - self.pitch_shift, 263 | 1 + self.pitch_shift) 264 | y = librosa.effects.pitch_shift(y, sr, steps) 265 | return y 266 | 267 | 268 | def benchmark(dataset, num_epochs=2): 269 | start_time = time.perf_counter() 270 | for epoch_num in range(num_epochs): 271 | for sample in dataset: 272 | # Performing a training step 273 | time.sleep(0.01) 274 | tf.print("Execution time:", time.perf_counter() - start_time) 275 | 276 | def benchmark_generator(generator, num_epochs=2): 277 | start_time = time.perf_counter() 278 | for 
epoch_num in range(num_epochs): 279 | for i in range(len(generator)): 280 | sample = generator[i] 281 | # Performing a training step 282 | #time.sleep(0.01) 283 | tf.print("Execution time:", time.perf_counter() - start_time) 284 | -------------------------------------------------------------------------------- /emo-net/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EIHW/EmoNet/e76dd53ab4c33e99182c69f6247e4d924f9a3d99/emo-net/models/__init__.py -------------------------------------------------------------------------------- /emo-net/models/adapter_resnet.py: -------------------------------------------------------------------------------- 1 | # EmoNet 2 | # ============================================================================== 3 | # Copyright (C) 2021 Maurice Gerczuk, Shahin Amiriparian, 4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved. 5 | # 6 | # This program is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with this program. If not, see . 18 | # ============================================================================== 19 | import tensorflow as tf 20 | 21 | kernel_regularizer = tf.keras.regularizers.l2(1e-6) 22 | 23 | channel_axis = -1 24 | 25 | class BasicBlock(object): 26 | def __init__(self, 27 | filters, 28 | factor, 29 | strides=2, 30 | dropout1=0, 31 | dropout2=0, 32 | shortcut=True, 33 | learnall=True, 34 | tasks=['IEMOCAP-4cl'], 35 | weight_decays=None, 36 | **kwargs): 37 | self.filters = filters 38 | self.factor = factor 39 | self.strides = strides 40 | self.dropout1 = tf.keras.layers.Dropout(dropout1) 41 | self.dropout2 = tf.keras.layers.Dropout(dropout2) 42 | self.shortcut = shortcut 43 | self.learnall = learnall 44 | self.tasks = tasks 45 | self.weight_decays = weight_decays if weight_decays is not None else [1e-6]*len(self.tasks) 46 | self.conv_task1 = ConvTasks(filters, 47 | factor, 48 | strides=strides, 49 | learnall=learnall, 50 | dropout=dropout1, 51 | tasks=tasks, 52 | weight_decays=self.weight_decays, 53 | **kwargs) 54 | self.conv_task2 = ConvTasks(filters, 55 | factor, 56 | strides=1, 57 | learnall=learnall, 58 | dropout=dropout2, 59 | tasks=tasks, 60 | weight_decays=self.weight_decays, 61 | **kwargs) 62 | 63 | self.relu = tf.keras.layers.Activation('relu') 64 | if self.shortcut: 65 | self.avg_pool = tf.keras.layers.AveragePooling2D((2, 2), padding='same') 66 | self.lmbda = tf.keras.layers.Lambda(lambda x: x * 0) 67 | self.add = tf.keras.layers.Add() 68 | 69 | def __call__(self, input_tensor, task): 70 | residual = input_tensor 71 | x = self.conv_task1(input_tensor, task=task) 72 | x = self.relu(x) 73 | x = self.conv_task2(x, task=task) 74 | if self.shortcut: 75 | residual = self.avg_pool(residual) 76 | residual0 = self.lmbda(residual) 77 | residual = tf.keras.layers.concatenate([residual, residual0], axis=-1) 78 | x = self.add([residual, x]) 79 | x = self.relu(x) 80 | return x 81 | 82 | 
def _add_new_task(self, task, weight_decay=1e-6): 83 | self.conv_task1._add_new_task(task, weight_decay=weight_decay) 84 | self.conv_task2._add_new_task(task, weight_decay=weight_decay) 85 | 86 | 87 | class ConvTasks(object): 88 | def __init__(self, 89 | filters, 90 | factor=1, 91 | strides=1, 92 | learnall=True, 93 | dropout=0, 94 | tasks=['IEMOCAP-4cl', 'GEMEP'], 95 | weight_decays=None, 96 | reuse_batchnorm=False, 97 | **kwargs): 98 | self.filters = filters 99 | self.factor = factor 100 | self.strides = strides 101 | self.learnall = learnall 102 | self.dropout = tf.keras.layers.Dropout(dropout) 103 | self.tasks = tasks 104 | self.weight_decays = weight_decays if weight_decays is not None else [1e-6]*len(self.tasks) 105 | self.reuse_batchnorm = reuse_batchnorm 106 | 107 | # shared parameters 108 | self.conv2d = tf.keras.layers.Convolution2D(self.filters * self.factor, (3, 3), 109 | strides=self.strides, 110 | padding='same', 111 | kernel_initializer='he_normal', 112 | use_bias=False, 113 | trainable=self.learnall, 114 | kernel_regularizer=kernel_regularizer) 115 | 116 | # task-specific parameters 117 | self.res_adapts = {} 118 | self.add = tf.keras.layers.Add() 119 | self.bns = {} 120 | self.core_bn = tf.keras.layers.BatchNormalization( 121 | axis=channel_axis, 122 | name=f'core_{self.conv2d.name}_batch_normalization') 123 | for task, weight_decay in zip(self.tasks, self.weight_decays): 124 | self._add_new_task(task, weight_decay=weight_decay) 125 | 126 | def __call__(self, input_tensor, task): 127 | in_t = input_tensor 128 | if task is None: 129 | in_t = self.dropout(in_t) 130 | x = self.conv2d(in_t) 131 | if task is not None: 132 | adapter_in = self.dropout(in_t) 133 | res_adapt = self.res_adapts[task](adapter_in) 134 | x = self.add([x, res_adapt]) 135 | if self.reuse_batchnorm or task is None: 136 | x = self.core_bn(x) 137 | else: 138 | x = self.bns[task](x) 139 | return x 140 | 141 | def _add_new_task(self, task, weight_decay=1e-6): 142 | assert task not in self.bns, 'Task already exists!' 
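# The 1x1 convolution created below is the task-specific residual adapter:
# its output is summed with the shared 3x3 convolution's output in __call__,
# so supporting a new task only adds these small per-task weights (plus a
# task-specific batch normalization unless batch norm is reused).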
143 | self.res_adapts[task] = tf.keras.layers.Convolution2D( 144 | self.filters * self.factor, (1, 1), 145 | padding='valid', 146 | kernel_initializer='he_normal', 147 | strides=self.strides, 148 | use_bias=False, 149 | kernel_regularizer=tf.keras.regularizers.l2(weight_decay), 150 | name=f'{task}_{self.conv2d.name}_adapter') 151 | 152 | if not self.reuse_batchnorm: 153 | self.bns[task] = tf.keras.layers.BatchNormalization( 154 | axis=channel_axis, 155 | name=f'{task}_{self.conv2d.name}_batch_normalization') 156 | 157 | 158 | class ResNet(object): 159 | def __init__(self, 160 | filters=32, 161 | factor=1, 162 | N=2, 163 | verbose=1, 164 | learnall=True, 165 | dropout1=0, 166 | dropout2=0, 167 | tasks=['IEMOCAP-4cl', 'GEMEP'], 168 | weight_decays=None, 169 | reuse_batchnorm=False, 170 | input_bn=False): 171 | self.filters = filters 172 | self.factor = factor 173 | self.N = N 174 | self.learnall = learnall 175 | self.dropout1 = dropout1 176 | self.dropout2 = dropout2 177 | self.tasks = tasks 178 | self.weight_decays = weight_decays if weight_decays is not None else [1e-6]*len(self.tasks) 179 | self.reuse_batchnorm = reuse_batchnorm 180 | self.input_bn = input_bn 181 | if self.input_bn: 182 | self.input_core_bn = tf.keras.layers.BatchNormalization(axis=channel_axis, name=f'core_input_batch_normalization') 183 | self.input_bns = { 184 | task: tf.keras.layers.BatchNormalization(axis=channel_axis, 185 | name=f'{task}_input_batch_normalization') 186 | for task in self.tasks 187 | } 188 | 189 | # conv blocks 190 | self.pre_conv = ConvTasks(filters=self.filters, 191 | factor=factor, 192 | strides=1, 193 | learnall=learnall, 194 | tasks=self.tasks, 195 | weight_decays=self.weight_decays, 196 | reuse_batchnorm=reuse_batchnorm) 197 | self.blocks = [] 198 | self.nb_conv = 1 199 | for i in range(1, 4): 200 | block = BasicBlock(self.filters * (2**i), 201 | self.factor, 202 | strides=2, 203 | dropout1=self.dropout1, 204 | dropout2=self.dropout2, 205 | shortcut=True, 206 | learnall=self.learnall, 207 | tasks=self.tasks, 208 | weight_decays=self.weight_decays, 209 | reuse_batchnorm=reuse_batchnorm) 210 | self.blocks.append(block) 211 | for j in range(N - 1): 212 | block = BasicBlock(filters=self.filters * 213 | (2**i), 214 | factor=self.factor, 215 | strides=1, 216 | dropout1=self.dropout1, 217 | dropout2=self.dropout2, 218 | shortcut=False, 219 | learnall=self.learnall, 220 | tasks=self.tasks, 221 | weight_decays=self.weight_decays, 222 | reuse_batchnorm=reuse_batchnorm) 223 | self.blocks.append(block) 224 | self.nb_conv += 2 225 | self.nb_conv += 6 226 | 227 | # bns and relus 228 | self.relu = tf.keras.layers.Activation('relu') 229 | self.bns = { 230 | task: tf.keras.layers.BatchNormalization(axis=channel_axis, 231 | name=f'{task}_final_batch_normalization') 232 | for task in self.tasks 233 | } 234 | self.core_bn = tf.keras.layers.BatchNormalization(axis=channel_axis, 235 | name=f'core_final_batch_normalization') 236 | 237 | def _add_new_task(self, task, weight_decay=1e-6): 238 | assert task not in self.bns, f'Task {task} already exists!' 
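# Registering a task fans out through the whole backbone: the stem and every
# residual block receive fresh adapters, and new task-specific batch-norm
# layers are created unless batch norm is shared across tasks.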
239 | self.pre_conv._add_new_task(task, weight_decay=weight_decay) 240 | for block in self.blocks: 241 | block._add_new_task(task, weight_decay=weight_decay) 242 | if not self.reuse_batchnorm: 243 | self.bns[task] = tf.keras.layers.BatchNormalization(axis=channel_axis, name=f'{task}_final_batch_normalization') 244 | if self.input_bn: 245 | self.input_bns[task] = tf.keras.layers.BatchNormalization(axis=channel_axis, name=f'{task}_input_batch_normalization') 246 | 247 | def __call__(self, input_tensor, task): 248 | if self.input_bn: 249 | if task is None or self.reuse_batchnorm: 250 | x = self.input_core_bn(input_tensor) 251 | else: 252 | x = self.input_bns[task](input_tensor) 253 | else: 254 | x = input_tensor 255 | x = self.pre_conv(x, task=task) 256 | for block in self.blocks: 257 | x = block(x, task=task) 258 | if task is None or self.reuse_batchnorm: 259 | x = self.core_bn(x) 260 | else: 261 | x = self.bns[task](x) 262 | x = self.relu(x) 263 | return x -------------------------------------------------------------------------------- /emo-net/models/adapter_rnn.py: -------------------------------------------------------------------------------- 1 | # EmoNet 2 | # ============================================================================== 3 | # Copyright (C) 2021 Maurice Gerczuk, Shahin Amiriparian, 4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved. 5 | # 6 | # This program is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with this program. If not, see . 
18 | # ============================================================================== 19 | import tensorflow as tf 20 | #from .attention import SeqSelfAttention, SeqWeightedAttention  # SeqSelfAttention/SeqWeightedAttention are referenced below but are not defined in .attention; they must be provided elsewhere (their API matches the keras-self-attention package) 21 | 22 | kernel_regularizer = tf.keras.regularizers.l2(1e-6) 23 | rnn_regularizer = tf.keras.regularizers.L1L2(1e-6) 24 | 25 | 26 | class RNNWithAdapters(object): 27 | def __init__(self, 28 | input_dims, 29 | hidden_size=512, 30 | learnall=True, 31 | dropout=0.2, 32 | layers=2, 33 | input_projection_factor=1, 34 | adapter_projection_factor=2, 35 | tasks=[], 36 | recurrent_cell='lstm', 37 | bidirectional=False, 38 | input_projection=True, 39 | input_bn=False, 40 | downpool=None, 41 | share_feature_layer=False, 42 | share_attention=False, 43 | use_attention=True, 44 | **kwargs): 45 | 46 | self.hidden_size = hidden_size 47 | self.learnall = learnall 48 | self.dropout = dropout 49 | self.tasks = tasks 50 | self.input_dims = input_dims 51 | self.adapter_projection_factor = adapter_projection_factor 52 | self.input_projection = input_projection 53 | self.input_projection_factor = input_projection_factor 55 | self.recurrent_cell = recurrent_cell 56 | self.layers = layers 57 | self.bidirectional = bidirectional 58 | self.input_bn = input_bn 59 | self.cnn_input = len(input_dims) > 3 60 | self.share_feature_layer = share_feature_layer 61 | self.use_attention = use_attention 62 | self.share_attention = share_attention 63 | if self.cnn_input: # cnn feature extractor 64 | feature_dims = input_dims[1] * input_dims[3] 65 | else: 66 | feature_dims = input_dims[-1] 67 | self.reorder_dims = tf.keras.layers.Permute((2, 1, 3)) 68 | if downpool is not None: 69 | self.downpool = tf.keras.layers.AveragePooling1D( 70 | pool_size=downpool, strides=downpool, padding='same', name='rnn_downpool') 71 | else: 72 | self.downpool = None 73 | self.reshape = tf.keras.layers.Reshape(target_shape=(-1, 74 | feature_dims)) 75 | if self.input_bn: 76 | self.input_bns = {task: tf.keras.layers.BatchNormalization( 77 | trainable=True, name=f'{task}_rnn_input_bn') for task in tasks} 78 | # the shared 'core' BN is created unconditionally so calls with task=None work 79 | self.core_input_bn = tf.keras.layers.BatchNormalization( 80 | trainable=True, name=f'core_rnn_input_bn') 81 | self.projection = tf.keras.layers.Dense(feature_dims // self.input_projection_factor, 82 | activation=None, 83 | trainable=learnall, 84 | kernel_regularizer=kernel_regularizer) 85 | self.adapter_hidden_size = self.hidden_size * \ 86 | 2 if self.bidirectional else self.hidden_size 87 | self.rnns = [] 88 | self.selfattentions = [] 89 | self.selfattention = [] 90 | self.adapters = [] 91 | for i in range(self.layers): 92 | rnn = tf.keras.layers.GRU(self.hidden_size, 93 | dropout=self.dropout, 94 | return_sequences=True, 95 | trainable=learnall, kernel_regularizer=rnn_regularizer) if self.recurrent_cell.lower() == 'gru' else tf.keras.layers.LSTM(self.hidden_size, 96 | dropout=self.dropout, 97 | return_sequences=True, 98 | trainable=learnall, kernel_regularizer=rnn_regularizer) 99 | if self.bidirectional: 100 | rnn = tf.keras.layers.Bidirectional(rnn) 101 | self.rnns.append(rnn) 102 | self.adapters.append({ 103 | task: RNNAdapter(self.adapter_hidden_size, self.adapter_projection_factor, 104 | task, i) 105 | for task in tasks 106 | }) 107 | if i < self.layers - 1: 108 | self.selfattentions.append({task: SeqSelfAttention( 109 | attention_activation='sigmoid', 110 | kernel_regularizer=kernel_regularizer, 111 | use_attention_bias=False, 112 | trainable=True, 113 | name=f'{task}_self_attention_{i}') for task in tasks}) 
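# A shared 'core' self-attention is kept alongside the per-task ones created
# above; __call__ selects between them based on share_attention and on
# whether a task is passed at all.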
114 | self.selfattention.append(SeqSelfAttention( 115 | attention_activation='sigmoid', 116 | kernel_regularizer=kernel_regularizer, 117 | use_attention_bias=False, 118 | trainable=learnall, 119 | name=f'core_self_attention_{i}')) 120 | 121 | self.add = tf.keras.layers.Add() 122 | 123 | 124 | self.weighted_attentions = {task: SeqWeightedAttention( 125 | trainable=True, name=f'{task}_seq_weighted_attention') for task in tasks} 126 | self.weighted_attention = SeqWeightedAttention( 127 | trainable=learnall, name=f'core_seq_weighted_attention') 128 | 129 | def __call__(self, x, task, mask=None): 130 | if self.input_bn: 131 | if task is not None: 132 | x = self.input_bns[task](x) 133 | else: 134 | x = self.core_input_bn(x) 135 | if self.cnn_input: 136 | x = self.reorder_dims(x) 137 | x = self.reshape(x) 138 | if self.downpool is not None: 139 | x = self.downpool(x) 140 | #x = self.mask(x) 141 | if self.input_projection: 142 | x = self.projection(x) 143 | for i in range(self.layers): 144 | x = self.rnns[i](x, mask=mask) 145 | if task is not None: 146 | adapter = self.adapters[i][task](x) 147 | x = self.add([x, adapter]) 148 | if i < self.layers - 1 and self.use_attention: 149 | if self.share_attention: 150 | x = self.selfattention[i](x, mask=mask) 151 | else: 152 | x = self.selfattentions[i][task](x, mask=mask) 153 | else: 154 | if i < self.layers - 1 and self.use_attention: 155 | x = self.selfattention[i](x, mask=mask) 156 | if task is not None and not self.share_feature_layer: 157 | x = self.weighted_attentions[task](x, mask=mask) 158 | else: 159 | x = self.weighted_attention(x, mask=mask) 160 | return x 161 | 162 | def _add_new_task(self, task): 163 | assert all(task not in layer_adapters for layer_adapters in self.adapters), f'Task {task} already exists!' # self.adapters is a list of per-layer dicts 164 | for i in range(self.layers): 165 | self.adapters[i][task] = RNNAdapter(self.adapter_hidden_size, 166 | self.adapter_projection_factor, task,i) 167 | if i < self.layers - 1: 168 | self.selfattentions[i][task] = SeqSelfAttention( 169 | attention_activation='sigmoid', 170 | kernel_regularizer=kernel_regularizer, 171 | use_attention_bias=False, 172 | trainable=True, 173 | name=f'{task}_self_attention_{i}') 174 | self.weighted_attentions[task] = SeqWeightedAttention( 175 | trainable=True, name=f'{task}_seq_weighted_attention') 176 | if self.input_bn: 177 | self.input_bns[task] = tf.keras.layers.BatchNormalization( 178 | trainable=True, name=f'{task}_rnn_input_bn') 179 | 180 | 181 | class RNNAdapter(object): 182 | def __init__(self, input_size, downprojection=4, task='IEMOCAP', index=1): 183 | self.input_size = input_size 184 | self.downprojection_factor = downprojection 185 | self.task = task 186 | self.layer_norm = tf.keras.layers.TimeDistributed(tf.keras.layers.LayerNormalization(), 187 | name=f'{task}_rnn_adapter_{index}_layer_norm') 188 | self.downprojection = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense( 189 | self.input_size // self.downprojection_factor, activation='relu', use_bias=False, kernel_regularizer=kernel_regularizer), 190 | name=f'{task}_rnn_adapter_{index}_downprojection') 191 | self.upprojection = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(self.input_size, use_bias=False, kernel_regularizer=kernel_regularizer), 192 | name=f'{task}_rnn_adapter_{index}_upprojection') 193 | 194 | def __call__(self, x): 195 | x = self.layer_norm(x) 196 | x = self.downprojection(x) 197 | x = self.upprojection(x) 198 | #x = self.selfattention(x) 199 | return x 200 | --------------------------------------------------------------------------------
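The adapter modules above are plain Python callables that build Keras layers lazily, so they can be exercised directly. A minimal sketch of the intended workflow (illustrative only: it assumes adapter_resnet.py is importable from the current path and that inputs are batches of log-Mel spectrograms shaped (mels, frames, 1); all variable names below are placeholders, not part of the repository):

import tensorflow as tf
from adapter_resnet import ResNet  # assumption: emo-net/models is on sys.path

# One shared backbone with adapters for two base corpora.
backbone = ResNet(filters=32, factor=1, N=2, tasks=['IEMOCAP-4cl', 'GEMEP'])

spec = tf.keras.layers.Input(shape=(128, None, 1))  # (mels, frames, channel)
feats = backbone(spec, task='IEMOCAP-4cl')          # shared 3x3 convs + per-task adapters/BN

# A further corpus only adds 1x1 adapters and batch-norm layers;
# the shared convolution weights are untouched.
backbone._add_new_task('FAU-AIBO', weight_decay=1e-6)
feats_new = backbone(spec, task='FAU-AIBO')

This mirrors how build_model.py drives these classes: construct once with the base tasks, call once per task, and register additional tasks via _add_new_task.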
/emo-net/models/attention.py: -------------------------------------------------------------------------------- 1 | # EmoNet 2 | # ============================================================================== 3 | # Copyright (C) 2021 Maurice Gerczuk, Shahin Amiriparian, 4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved. 5 | # 6 | # This program is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with this program. If not, see . 18 | # ============================================================================== 19 | import tensorflow as tf 20 | import tensorflow.keras.backend as K 21 | 22 | 23 | class Attention2Dtanh(tf.keras.layers.Layer): 24 | def __init__(self, lmbda=0.3, mlp_units=256, **kwargs): 25 | super(Attention2Dtanh, self).__init__(**kwargs) 26 | self.mlp_units = mlp_units 27 | self.tanh = tf.keras.layers.Activation('tanh') 28 | self.lmbda = lmbda 29 | 30 | def build(self, input_shape): 31 | self.w = self.add_weight(shape=(input_shape[-1], self.mlp_units), 32 | initializer='random_normal', 33 | trainable=True, 34 | name='W') 35 | self.b = self.add_weight(shape=(self.mlp_units, ), 36 | initializer='random_normal', 37 | trainable=True, 38 | name='b') 39 | self.u = self.add_weight(shape=(self.mlp_units, ), # must match the projection width from w, not the input channels 40 | initializer='random_normal', 41 | trainable=True, 42 | name='u') 43 | self.flatten = tf.keras.layers.Reshape(target_shape=(-1, 44 | input_shape[-1])) 45 | super(Attention2Dtanh, self).build(input_shape) 46 | 47 | def call(self, inputs): 48 | flat_input = self.flatten(inputs) 49 | x = tf.matmul(flat_input, self.w) + self.b 50 | x = self.tanh(x) 51 | e = tf.tensordot(self.u, x, axes=[[0], [-1]]) * self.lmbda 52 | a = tf.nn.softmax(e, axis=-1) 53 | weighted_sum = tf.reduce_sum(tf.expand_dims(a, -1) * flat_input, 54 | axis=1) 55 | return weighted_sum 56 | 57 | def get_config(self): 58 | config = {'lmbda': self.lmbda, 'mlp_units': self.mlp_units} 59 | base_config = super(Attention2Dtanh, self).get_config() 60 | return dict(list(base_config.items()) + list(config.items())) 61 | -------------------------------------------------------------------------------- /emo-net/models/build_model.py: -------------------------------------------------------------------------------- 1 | # EmoNet 2 | # ============================================================================== 3 | # Copyright (C) 2021 Maurice Gerczuk, Shahin Amiriparian, 4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved. 5 | # 6 | # This program is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with this program. If not, see . 18 | # ============================================================================== 19 | import tensorflow as tf 20 | tf.random.set_seed(42) 21 | 22 | import tensorflow.keras.backend as K 23 | import h5py 24 | from .input_layers import * 25 | from .adapter_rnn import * 26 | from .adapter_resnet import * 27 | from .attention import * 28 | from ..utils import array_list_equal 29 | 30 | import logging 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | 35 | def avgpool(x): 36 | x = tf.keras.layers.AveragePooling2D((8, 8))(x) 37 | x = tf.keras.layers.Flatten()(x) 38 | return x 39 | 40 | def global_avgpool(x): 41 | x = tf.keras.layers.GlobalAveragePooling2D()(x) 42 | x = tf.keras.layers.Flatten()(x) 43 | return x 44 | 45 | 46 | def infer_tasks_from_weightfile(initial_weights): 47 | with h5py.File(initial_weights) as f: 48 | base_tasks = [] 49 | base_nb_classes = [] 50 | for k in f['model_weights']: 51 | prefices = ('activation', 'add', 'average_pooling', 52 | 'batch_normalization', 'bidirectional', 'concat', 'conv2d', 'dropout', 'attention', 53 | 'dense', 'flatten', 'input', 'lambda', 54 | 'normalization2d', 'reshape', 'trainable_stft', 'apply_zero_mask', 'core', 'lstm', 'masking', 'permute', 'pooled', 'seq', 'zero_mask', 'adapter', 'mask', 'expand', 'mfcc', 'downpool') 55 | skip_layer = any([prefix in k for prefix in prefices]) 56 | if not skip_layer: 57 | task = k 58 | classes = _find_n_classes(f['model_weights'][k], k) 59 | #classes = f['model_weights'][k]['softmax'][k]['softmax']['kernel:0'].shape[1] 60 | logger.info(f'Found task {k} with {classes} classes.') 61 | base_tasks.append(task) 62 | base_nb_classes.append(classes) 63 | return base_tasks, base_nb_classes 64 | 65 | 66 | def _find_n_classes(weight_dict, task): 67 | if 'sigmoid' in weight_dict: 68 | return 2 69 | elif 'kernel:0' in weight_dict: 70 | output_shape = weight_dict['kernel:0'].shape[1] 71 | if output_shape == 1: # binary 72 | output_shape += 1 73 | return output_shape 74 | elif 'softmax' in weight_dict: 75 | return _find_n_classes(weight_dict['softmax'], task) 76 | elif task in weight_dict: 77 | return _find_n_classes(weight_dict[task], task) 78 | 79 | 80 | def input_features_and_mask(audio_in, num_fft=1024, hop_length=512, sample_rate=16000, f_min=20, f_max=8000, num_mels=128, eps=1e-6, return_decibel=False, num_mfccs=None): 81 | input_features = LogMelgramLayer(num_fft=num_fft, 82 | hop_length=hop_length, 83 | sample_rate=sample_rate, 84 | f_min=f_min, 85 | f_max=f_max, 86 | num_mels=num_mels, 87 | eps=eps, 88 | return_decibel=return_decibel, 89 | name='trainable_stft') 90 | x = input_features(audio_in) 91 | mask = ComputeMask(input_features.num_fft, 92 | input_features.hop_length)(audio_in) 93 | if num_mfccs is not None: 94 | x = MFCCLayer(num_mfccs=num_mfccs)(x) 95 | return x, mask 96 | 97 | 98 | def create_multi_task_networks(input_dim, feature_extractor='cnn', 99 | initial_weights=None, 100 | base_nb_classes=None, 101 | learnall=True, 102 | num_mels=128, 103 | base_tasks=None, 104 | new_tasks=None, 105 | new_nb_classes=None, 106 | mode=None, 107 | random_noise=None, 108 | input_bn=False, 109 | share_feature_layer=True, 110 | base_weight_decays=None, 111 | new_weight_decays=None, 112 | **kwargs): 113 | if feature_extractor == 'cnn': 114 | return create_multi_task_resnets(input_dim=input_dim, mode=mode, num_mels=num_mels, 
initial_weights=initial_weights, base_nb_classes=base_nb_classes, base_weight_decays=base_weight_decays, new_weight_decays=new_weight_decays, learnall=learnall, base_tasks=base_tasks, new_tasks=new_tasks, new_nb_classes=new_nb_classes, random_noise=random_noise, input_bn=input_bn, share_feature_layer=share_feature_layer, **kwargs) 115 | elif feature_extractor == 'rnn': 116 | return create_multi_task_rnn(input_dim=input_dim, mode=mode, num_mels=num_mels, initial_weights=initial_weights, base_nb_classes=base_nb_classes, learnall=learnall, base_tasks=base_tasks, new_tasks=new_tasks, new_nb_classes=new_nb_classes, random_noise=random_noise, input_bn=input_bn, share_feature_layer=share_feature_layer, **kwargs) 117 | elif feature_extractor == 'vgg16': 118 | return create_multi_task_vgg16(input_dim=input_dim, 119 | tasks=base_tasks, 120 | num_mels=num_mels, 121 | nb_classes=base_nb_classes, 122 | random_noise=random_noise, 123 | initial_weights=initial_weights, 124 | share_feature_layer=share_feature_layer, 125 | **kwargs) 126 | elif feature_extractor == 'fusion': 127 | return create_multi_task_fusion(input_dim=input_dim, mode=mode, num_mels=num_mels, initial_weights=initial_weights, base_nb_classes=base_nb_classes, learnall=learnall, base_tasks=base_tasks, new_tasks=new_tasks, new_nb_classes=new_nb_classes, random_noise=random_noise, input_bn=input_bn, share_feature_layer=share_feature_layer, **kwargs) 128 | 129 | 130 | def create_multi_task_fusion(input_dim,filters=32, 131 | factor=1, 132 | N=4, 133 | hidden_dim=512, 134 | cell='lstm', 135 | number_of_layers=2, 136 | down_pool=8, 137 | bidirectional=False, 138 | num_mels=128, 139 | learnall=True, 140 | learnall_classifier=True, 141 | mode='adapters', 142 | dropout1=0, 143 | dropout2=0, 144 | rnn_dropout=0.2, 145 | base_tasks=['EMO-DB', 'GEMEP'], 146 | new_tasks=None, 147 | base_nb_classes=[6, 10], 148 | new_nb_classes=None, 149 | initial_weights=None, 150 | random_noise=0.1, 151 | reuse_batchnorm=False, 152 | input_bn=False, 153 | share_feature_layer=False): 154 | channel_axis = -1 155 | if base_tasks is None: 156 | assert initial_weights is not None, f'Either base tasks or initial weights have to be specified!' 
157 | logger.info( 158 | 'Trying to determine trained tasks from initial weights...') 159 | base_tasks, base_nb_classes = infer_tasks_from_weightfile( 160 | initial_weights) 161 | 162 | input_tensor = tf.keras.layers.Input(shape=input_dim) 163 | 164 | input_reshaped = tf.keras.layers.Reshape( 165 | target_shape=(-1, ))(input_tensor) 166 | 167 | 168 | x, mask = input_features_and_mask(input_reshaped, num_fft=1024, 169 | hop_length=512, 170 | sample_rate=16000, 171 | f_min=20, 172 | f_max=8000, 173 | num_mels=num_mels, 174 | eps=1e-6, 175 | return_decibel=True, 176 | num_mfccs=None) 177 | 178 | adapter_rnn = RNNWithAdapters(K.int_shape(x), 179 | hidden_size=hidden_dim, 180 | learnall=learnall, 181 | dropout=rnn_dropout, 182 | input_projection_factor=1, 183 | adapter_projection_factor=4, 184 | bidirectional=bidirectional, 185 | layers=number_of_layers, 186 | recurrent_cell=cell, 187 | input_bn=input_bn, 188 | downpool=down_pool, 189 | input_projection=False, 190 | # tasks=base_tasks, 191 | tasks=base_tasks if mode == 'adapters' else [], 192 | share_feature_layer=share_feature_layer) 193 | if down_pool is not None: 194 | mask = PoolMask((down_pool,))(mask) 195 | 196 | 197 | expand_dims = tf.keras.layers.Lambda( 198 | lambda x: tf.expand_dims(x, 3), name='expand_input_dims') 199 | x_resnet = expand_dims(x) 200 | x_resnet = tf.keras.layers.Permute((2, 1, 3))(x_resnet) 201 | x_rnn = x 202 | 203 | 204 | adapter_resnet = ResNet(filters=filters, 205 | factor=factor, 206 | N=N, 207 | learnall=learnall, 208 | dropout1=dropout1, 209 | dropout2=dropout2, 210 | # tasks=base_tasks, 211 | tasks=base_tasks if mode == 'adapters' else [], 212 | input_bn=input_bn) 213 | 214 | if new_tasks is not None: 215 | really_new_tasks = [t for t in new_tasks if t not in base_tasks] 216 | new_nb_classes = [ 217 | c for c, t in zip(new_nb_classes, new_tasks) if t not in base_tasks 218 | ] 219 | else: 220 | really_new_tasks = [] 221 | new_nb_classes = [] 222 | 223 | task_models = {} 224 | outputs = [] 225 | attention2d = None 226 | for task, classes in zip(base_tasks, base_nb_classes): 227 | logger.info(f'Building model for {task} with {classes} classes...') 228 | adapters_in = task if mode == 'adapters' else None 229 | y_resnet = adapter_resnet(x_resnet, task=adapters_in) 230 | y_rnn = adapter_rnn(x_rnn, task=adapters_in, mask=mask) 231 | 232 | 233 | if attention2d is None: 234 | attention2d = Attention2Dtanh(lmbda=0.3, mlp_units=K.int_shape(y_resnet)[-1], name=f'core_2d_attention', trainable=not(mode=='adapters')) 235 | if not share_feature_layer and mode == 'adapters': 236 | attention2d = Attention2Dtanh(name=f'{task}_2d_attention', mlp_units=K.int_shape(y_resnet)[-1], trainable=True) 237 | y_resnet = attention2d(y_resnet) 238 | 239 | y = tf.keras.layers.Concatenate()([y_resnet, y_rnn]) 240 | 241 | y = tf.keras.layers.Dense(K.int_shape(y)[-1]//2, activation=None, name=f'{task}_dense')(y) 242 | y = tf.keras.layers.BatchNormalization(name=f'{task}_dense_batchnorm')(y) 243 | y = tf.keras.layers.Activation('relu', name=f'{task}_dense_relu')(y) 244 | y = tf.keras.layers.Dropout(0.2)(y) 245 | if classes == 2: 246 | y = tf.keras.layers.Dense(1, activation='sigmoid', name=task)(y) 247 | else: 248 | y = tf.keras.layers.Dense( 249 | classes, activation='softmax', name=task)(y) 250 | model = tf.keras.Model(inputs=input_tensor, outputs=y) 251 | outputs.append(y) 252 | task_models[task] = model 253 | 254 | if really_new_tasks is not None and new_nb_classes is not None: 255 | for task, classes in zip(really_new_tasks, new_nb_classes): 
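# Tasks that are not among the base tasks get fresh adapters on the fly
# before their first forward pass; the shared trunk is reused as-is.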
256 | logger.info(f'Building model for {task} with {classes} classes...') 257 | adapter_resnet._add_new_task(task) 258 | adapter_rnn._add_new_task(task) 259 | y_resnet = adapter_resnet(x_resnet, task) # the resnet branch consumes the 4-D spectrogram view, not the flat rnn input 260 | y_rnn = adapter_rnn(x_rnn, task, mask=mask) 261 | 262 | if attention2d is None: 263 | attention2d = Attention2Dtanh(lmbda=0.3, mlp_units=K.int_shape(y_resnet)[-1], name=f'core_2d_attention', trainable=not(mode=='adapters')) 264 | if not share_feature_layer and mode == 'adapters': 265 | attention2d = Attention2Dtanh(lmbda=0.3, name=f'{task}_2d_attention', mlp_units=K.int_shape(y_resnet)[-1], trainable=True) 266 | y_resnet = attention2d(y_resnet) 267 | 268 | y = tf.keras.layers.Concatenate()([y_resnet, y_rnn]) 269 | 270 | y = tf.keras.layers.Dense(K.int_shape(y)[-1]//2, activation=None, name=f'{task}_dense')(y) 271 | y = tf.keras.layers.BatchNormalization(name=f'{task}_dense_batchnorm')(y) 272 | y = tf.keras.layers.Activation('relu', name=f'{task}_dense_relu')(y) 273 | y = tf.keras.layers.Dropout(0.2)(y) 274 | if classes == 2: 275 | y = tf.keras.layers.Dense( 276 | 1, activation='sigmoid', name=task)(y) 277 | else: 278 | y = tf.keras.layers.Dense( 279 | classes, activation='softmax', name=task)(y) 280 | model = tf.keras.Model(inputs=input_tensor, outputs=y) 281 | outputs.append(y) 282 | task_models[task] = model 283 | shared_model = tf.keras.Model(inputs=input_tensor, outputs=outputs) 284 | 285 | if initial_weights is not None: 286 | shared_model, task_models = load_and_assert_loaded(shared_model, task_models, initial_weights) 287 | 288 | return task_models, shared_model 289 | 290 | def create_multi_task_resnets(input_dim, 291 | filters=32, 292 | factor=1, 293 | N=4, 294 | num_mels=128, 295 | learnall=True, 296 | learnall_classifier=True, 297 | mode='adapters', 298 | dropout1=0, 299 | dropout2=0, 300 | rnn_dropout=0.2, 301 | base_tasks=['EMO-DB', 'GEMEP'], 302 | new_tasks=None, 303 | base_weight_decays=None, 304 | new_weight_decays=None, 305 | base_nb_classes=[6, 10], 306 | new_nb_classes=None, 307 | initial_weights=None, 308 | random_noise=0.1, 309 | classifier='rnn', 310 | input_bn=False, 311 | share_feature_layer=False): 312 | 313 | channel_axis = -1 314 | if base_tasks is None: 315 | assert initial_weights is not None, f'Either base tasks or initial weights have to be specified!' 
316 | logger.info( 317 | 'Trying to determine trained tasks from initial weights...') 318 | base_tasks, base_nb_classes = infer_tasks_from_weightfile( 319 | initial_weights) 320 | 321 | base_weight_decays = base_weight_decays if base_weight_decays is not None else [1e-6]*len(base_tasks) 322 | 323 | if new_tasks is not None: 324 | really_new_tasks = [t for t in new_tasks if t not in base_tasks] 325 | new_nb_classes = [ 326 | c for c, t in zip(new_nb_classes, new_tasks) if t not in base_tasks 327 | ] 328 | else: 329 | really_new_tasks = [] 330 | new_nb_classes = [] 331 | 332 | new_weight_decays = new_weight_decays if new_weight_decays is not None else [1e-6]*len(really_new_tasks) 333 | 334 | # check if batchnorm should be reused 335 | if new_tasks is None or len(new_tasks) != 1: 336 | reuse_batchnorm = False 337 | elif new_tasks[0] in base_tasks and len(base_tasks) > 1: 338 | reuse_batchnorm = False 339 | else: 340 | reuse_batchnorm = True 341 | logger.debug(f'Reusing batchnorm: {reuse_batchnorm}') 342 | task_models = {} 343 | outputs = [] 344 | adapter_rnn = None 345 | attention2d = None 346 | input_tensor = tf.keras.layers.Input(shape=input_dim) 347 | variable_duration = not (classifier == 'avgpool') 348 | if variable_duration: 349 | input_reshaped = tf.keras.layers.Reshape( 350 | target_shape=(-1, ))(input_tensor) 351 | else: 352 | input_reshaped = tf.keras.layers.Reshape( 353 | target_shape=(input_dim[0], ))(input_tensor) 354 | 355 | x, mask = input_features_and_mask(input_reshaped, num_fft=1024, 356 | hop_length=512, 357 | sample_rate=16000, 358 | f_min=20, 359 | f_max=8000, 360 | num_mels=num_mels, 361 | eps=1e-6, 362 | return_decibel=False, 363 | num_mfccs=None) 364 | 365 | pooled_mask = PoolMask((8,))(mask) 366 | 367 | expand_dims = tf.keras.layers.Lambda( 368 | lambda x: tf.expand_dims(x, 3), name='expand_input_dims') 369 | x = expand_dims(x) 370 | x = tf.keras.layers.Permute((2, 1, 3))(x) 371 | adapter_resnet = ResNet(filters=filters, 372 | factor=factor, 373 | N=N, 374 | learnall=learnall, 375 | dropout1=dropout1, 376 | dropout2=dropout2, 377 | weight_decays=base_weight_decays, 378 | reuse_batchnorm=reuse_batchnorm, 379 | tasks=base_tasks if mode == 'adapters' else [], 380 | input_bn=input_bn) 381 | 382 | 383 | for task, classes in zip(base_tasks, base_nb_classes): 384 | logger.info(f'Building model for {task} with {classes} classes...') 385 | adapters_in_resnet = task if mode == 'adapters' else None 386 | y = adapter_resnet(x, task=adapters_in_resnet) 387 | 388 | if initial_weights is not None and classifier == 'avgpool': # might need new last dense layer 389 | name = f'{task}_1' 390 | else: 391 | name = task 392 | if classifier == 'avgpool': 393 | y = avgpool(y) 394 | elif classifier == 'FCNAttention': 395 | if attention2d is None: 396 | attention2d = Attention2Dtanh(lmbda=0.3, mlp_units=K.int_shape(y)[-1], name=f'core_2d_attention', trainable=learnall_classifier) 397 | if not share_feature_layer and mode == 'adapters': 398 | attention2d = Attention2Dtanh(name=f'{task}_2d_attention', mlp_units=K.int_shape(y)[-1], trainable=True) 399 | y = attention2d(y) 400 | 401 | elif classifier == 'rnn': 402 | if adapter_rnn is None: 403 | adapter_rnn = RNNWithAdapters(K.int_shape(y), 404 | hidden_size=K.int_shape(y)[-1], 405 | learnall=learnall_classifier, 406 | dropout=rnn_dropout, 407 | input_projection_factor=4, 408 | adapter_projection_factor=4, 409 | share_feature_layer=share_feature_layer, 410 | # tasks=base_tasks, 411 | tasks=base_tasks if mode == 'adapters' else []) 412 | 413 | y = adapter_rnn(y, adapters_in_resnet, mask=pooled_mask) 
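# Per-task classification head: project to half the feature width, then
# BN + ReLU + dropout, and finally a sigmoid unit (binary) or a softmax over
# the classes. The output layer is named after the task so that
# load_weights(by_name=True) can restore heads selectively.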
414 | y = tf.keras.layers.Dense(K.int_shape(y)[-1]//2, activation=None, name=f'{name}_dense')(y) 415 | y = tf.keras.layers.BatchNormalization(name=f'{name}_dense_batchnorm')(y) 416 | y = tf.keras.layers.Activation('relu', name=f'{name}_dense_relu')(y) 417 | y = tf.keras.layers.Dropout(0.2)(y) 418 | if classes == 2: 419 | y = tf.keras.layers.Dense(1, activation='sigmoid', name=name)(y) 420 | else: 421 | y = tf.keras.layers.Dense( 422 | classes, activation='softmax', name=name)(y) 423 | model = tf.keras.Model(inputs=input_tensor, outputs=y) 424 | outputs.append(y) 425 | task_models[task] = model 426 | 427 | if really_new_tasks is not None and new_nb_classes is not None: 428 | for task, classes, weight_decay in zip(really_new_tasks, new_nb_classes, new_weight_decays): 429 | logger.info(f'Building model for {task} with {classes} classes...') 430 | adapter_resnet._add_new_task(task, weight_decay=weight_decay) 431 | y = adapter_resnet(x, task) 432 | 433 | if classifier == 'avgpool': 434 | y = avgpool(y) 435 | elif classifier == 'rnn': 436 | adapter_rnn._add_new_task(task) 437 | y = adapter_rnn(y, task, mask=pooled_mask) # keep masking consistent with the base-task branch 438 | elif classifier == 'FCNAttention': 439 | if attention2d is None: 440 | attention2d = Attention2Dtanh(lmbda=0.3, mlp_units=K.int_shape(y)[-1], name=f'core_2d_attention', trainable=learnall_classifier) 441 | if not share_feature_layer and mode == 'adapters': 442 | attention2d = Attention2Dtanh(lmbda=0.3, name=f'{task}_2d_attention', mlp_units=K.int_shape(y)[-1], trainable=True) 443 | y = attention2d(y) 444 | 445 | 446 | 447 | 448 | y = tf.keras.layers.Dense(K.int_shape(y)[-1]//2, activation=None, name=f'{task}_dense')(y) 449 | y = tf.keras.layers.BatchNormalization(name=f'{task}_dense_batchnorm')(y) 450 | y = tf.keras.layers.Activation('relu', name=f'{task}_dense_relu')(y) 451 | y = tf.keras.layers.Dropout(0.2)(y) 452 | if classes == 2: 453 | y = tf.keras.layers.Dense( 454 | 1, activation='sigmoid', name=task)(y) 455 | else: 456 | y = tf.keras.layers.Dense( 457 | classes, activation='softmax', name=task)(y) 458 | model = tf.keras.Model(inputs=input_tensor, outputs=y) 459 | outputs.append(y) 460 | task_models[task] = model 461 | shared_model = tf.keras.Model(inputs=input_tensor, outputs=outputs) 462 | 463 | if initial_weights is not None: 464 | shared_model, task_models = load_and_assert_loaded(shared_model, task_models, initial_weights) 465 | 466 | return task_models, shared_model 467 | 468 | 469 | def create_multi_task_vgg16(input_dim, 470 | tasks=['EMO-DB', 'GEMEP'], 471 | nb_classes=[6, 10], 472 | random_noise=0.1, 473 | num_mels=128, 474 | classifier='attention2d', 475 | dropout=0.2, 476 | initial_weights=None, 477 | share_feature_layer=False, 478 | freeze_up_to=None): 479 | channel_axis = -1 480 | if tasks is None: 481 | assert initial_weights is not None, f'Either base tasks or initial weights have to be specified!' 
482 | logger.info( 483 | 'Trying to determine trained tasks from initial weights...') 484 | tasks, nb_classes = infer_tasks_from_weightfile( 485 | initial_weights) # assign to the parameters actually used below 486 | 487 | input_tensor = tf.keras.layers.Input(shape=input_dim) 488 | variable_duration = not (classifier == 'avgpool') 489 | if variable_duration: 490 | input_reshaped = tf.keras.layers.Reshape( 491 | target_shape=(-1, ))(input_tensor) 492 | else: 493 | input_reshaped = tf.keras.layers.Reshape( 494 | target_shape=(input_dim[0], ))(input_tensor) 495 | 496 | x, mask = input_features_and_mask(input_reshaped, num_fft=1024, 497 | hop_length=512, 498 | sample_rate=16000, 499 | f_min=20, 500 | f_max=8000, 501 | num_mels=num_mels, 502 | eps=1e-6, 503 | return_decibel=True, 504 | num_mfccs=None) 505 | 506 | pooled_mask = PoolMask((8,))(mask) 507 | 508 | expand_dims = tf.keras.layers.Lambda( 509 | lambda x: tf.expand_dims(x, 3), name='expand_input_dims') 510 | x = expand_dims(x) 511 | x = tf.keras.layers.Permute((2, 1, 3))(x) 512 | x = tf.keras.layers.Convolution2D(3, 1, activation='relu', name='learn_colourmapping', use_bias=False)(x) 513 | x = tf.keras.layers.BatchNormalization()(x) 514 | vgg16 = tf.keras.applications.VGG16(include_top=False, weights='imagenet', pooling=None) 515 | #vgg16 = tf.keras.applications.ResNet50(include_top=False, weights='imagenet', pooling=None) 516 | if freeze_up_to is not None: 517 | for layer in vgg16.layers[:freeze_up_to]: 518 | layer.trainable = False 519 | else: 520 | for layer in vgg16.layers: 521 | layer.trainable = False 522 | task_models = {} 523 | outputs = [] 524 | adapter_rnn = None 525 | attention2d = None 526 | for task, classes in zip(tasks, nb_classes): 527 | logger.info(f'Building model for {task} with {classes} classes...') 528 | y = vgg16(x) 529 | 530 | if initial_weights is not None and classifier == 'avgpool': # might need new last dense layer 531 | name = f'{task}_1' 532 | else: 533 | name = task 534 | if classifier == 'avgpool': 535 | y = global_avgpool(y) 536 | 537 | 538 | if classifier == 'FCNAttention': 539 | if attention2d is None: 540 | attention2d = Attention2Dtanh(lmbda=0.3, mlp_units=K.int_shape(y)[-1], name=f'core_2d_attention', trainable=True) 541 | if not share_feature_layer: 542 | attention2d = Attention2Dtanh(name=f'{task}_2d_attention', mlp_units=K.int_shape(y)[-1], trainable=True) 543 | y = attention2d(y) 544 | 545 | if classifier == 'rnn': 546 | if adapter_rnn is None: 547 | adapter_rnn = RNNWithAdapters(K.int_shape(y), 548 | hidden_size=K.int_shape(y)[-1], 549 | learnall=True, 550 | dropout=dropout, 551 | input_projection_factor=4, 552 | adapter_projection_factor=4, 553 | share_feature_layer=share_feature_layer, 554 | # tasks=base_tasks, 555 | tasks=[]) 556 | 557 | y = adapter_rnn(y, None, mask=pooled_mask) 558 | 559 | y = tf.keras.layers.Dense(K.int_shape(y)[-1]//2, activation=None, name=f'{task}_dense')(y) 560 | y = tf.keras.layers.BatchNormalization(name=f'{task}_dense_batchnorm')(y) 561 | y = tf.keras.layers.Activation('relu', name=f'{task}_dense_relu')(y) 562 | y = tf.keras.layers.Dropout(dropout)(y) 563 | if classes == 2: 564 | y = tf.keras.layers.Dense(1, activation='sigmoid', name=name)(y) 565 | else: 566 | y = tf.keras.layers.Dense( 567 | classes, activation='softmax', name=name)(y) 568 | model = tf.keras.Model(inputs=input_tensor, outputs=y) 569 | outputs.append(y) 570 | task_models[task] = model 571 | 572 | shared_model = tf.keras.Model(inputs=input_tensor, outputs=outputs) 573 | 574 | if initial_weights is not None: 575 | shared_model, 
task_models = load_and_assert_loaded(shared_model, task_models, initial_weights) 576 | 577 | return task_models, shared_model 578 | 579 | 580 | def create_multi_task_rnn(input_dim, 581 | num_mels=128, 582 | num_mfccs=40, 583 | hidden_dim=512, 584 | cell='lstm', 585 | number_of_layers=2, 586 | down_pool=8, 587 | mode='adapters', 588 | bidirectional=False, 589 | learnall=True, 590 | dropout=0.2, 591 | base_tasks=['EMO-DB', 'GEMEP'], 592 | new_tasks=[], 593 | base_nb_classes=[6, 10], 594 | new_nb_classes=None, 595 | initial_weights=None, 596 | random_noise=0.1, 597 | input_bn=False, 598 | share_feature_layer=False, 599 | use_attention=True, 600 | share_attention=False, 601 | input_projection=True): 602 | channel_axis = -1 603 | if base_tasks is None: 604 | assert initial_weights is not None, f'Either base tasks or initial weights have to be specified!' 605 | 606 | logger.info( 607 | 'Trying to determine trained tasks from initial weights...') 608 | base_tasks, base_nb_classes = infer_tasks_from_weightfile( 609 | initial_weights) 610 | 611 | input_tensor = tf.keras.layers.Input(shape=input_dim) 612 | input_reshaped = tf.keras.layers.Reshape(target_shape=(-1, ))(input_tensor) 613 | 614 | x, mask = input_features_and_mask(input_reshaped, 615 | num_fft=1024, 616 | hop_length=512, 617 | num_mfccs=num_mfccs, 618 | sample_rate=16000, 619 | f_min=20, 620 | f_max=8000, 621 | num_mels=num_mels, 622 | eps=1e-6, 623 | return_decibel=num_mels is None) 624 | adapter_rnn = RNNWithAdapters(K.int_shape(x), 625 | hidden_size=hidden_dim, 626 | learnall=learnall, 627 | dropout=dropout, 628 | input_projection_factor=1, 629 | adapter_projection_factor=4, 630 | bidirectional=bidirectional, 631 | layers=number_of_layers, 632 | recurrent_cell=cell, 633 | input_bn=input_bn, 634 | downpool=down_pool, 635 | input_projection=input_projection, 636 | use_attention=use_attention, 637 | share_attention=share_attention, 638 | # tasks=base_tasks, 639 | tasks=base_tasks if mode == 'adapters' else [], 640 | share_feature_layer=share_feature_layer) 641 | if down_pool is not None: 642 | mask = PoolMask((down_pool,))(mask) 643 | 644 | 645 | if new_tasks is not None: 646 | really_new_tasks = [t for t in new_tasks if t not in base_tasks] 647 | new_nb_classes = [ 648 | c for c, t in zip(new_nb_classes, new_tasks) if t not in base_tasks 649 | ] 650 | else: 651 | really_new_tasks = [] 652 | new_nb_classes = [] 653 | 654 | task_models = {} 655 | outputs = [] 656 | for task, classes in zip(base_tasks, base_nb_classes): 657 | logger.info(f'Building model for {task} with {classes} classes...') 658 | #y = adapter_resnet(x, task) 659 | adapters_in = task if mode == 'adapters' else None 660 | y = adapter_rnn(x, adapters_in, mask=mask) 661 | if initial_weights is not None: # might need new last dense layer 662 | name = f'{task}_1' 663 | else: 664 | name = task 665 | y = tf.keras.layers.Dense(K.int_shape(y)[-1]//2, activation=None, name=f'{task}_dense')(y) 666 | y = tf.keras.layers.BatchNormalization(name=f'{task}_dense_batchnorm')(y) 667 | y = tf.keras.layers.Activation('relu', name=f'{task}_dense_relu')(y) 668 | y = tf.keras.layers.Dropout(0.2)(y) 669 | if classes == 2: 670 | y = tf.keras.layers.Dense(1, activation='sigmoid', name=name)(y) 671 | else: 672 | y = tf.keras.layers.Dense( 673 | classes, activation='softmax', name=name)(y) 674 | model = tf.keras.Model(inputs=input_tensor, outputs=y) 675 | outputs.append(y) 676 | task_models[task] = model 677 | 678 | if really_new_tasks is not None and new_nb_classes is not None: 679 | for task, 
classes in zip(really_new_tasks, new_nb_classes): 680 | logger.info(f'Building model for {task} with {classes} classes...') 681 | adapter_rnn._add_new_task(task) 682 | y = adapter_rnn(x, task, mask) 683 | # y = apply_zero_mask([pooled_mask, 684 | # y]) # zero out silence activations 685 | y = tf.keras.layers.Dense(K.int_shape(y)[-1]//2, activation=None, name=f'{task}_dense')(y) 686 | y = tf.keras.layers.BatchNormalization(name=f'{task}_dense_batchnorm')(y) 687 | y = tf.keras.layers.Activation('relu', name=f'{task}_dense_relu')(y) 688 | y = tf.keras.layers.Dropout(0.2)(y) 689 | if classes == 2: 690 | y = tf.keras.layers.Dense( 691 | 1, activation='sigmoid', name=task)(y) 692 | else: 693 | y = tf.keras.layers.Dense( 694 | classes, activation='softmax', name=task)(y) 695 | model = tf.keras.Model(inputs=input_tensor, outputs=y) 696 | outputs.append(y) 697 | task_models[task] = model 698 | shared_model = tf.keras.Model(inputs=input_tensor, outputs=outputs) 699 | 700 | if initial_weights is not None: 701 | shared_model, task_models = load_and_assert_loaded(shared_model, task_models, initial_weights) 702 | return task_models, shared_model 703 | 704 | 705 | def load_and_assert_loaded(shared_model, task_models, initial_weights): 706 | preloaded_layers = shared_model.layers.copy() 707 | shared_weights_pre_load, shared_names_pre_load = [], [] 708 | for layer in preloaded_layers: 709 | if not layer.trainable: 710 | logger.debug(f'Recording weights of {layer.name} before load.') 711 | shared_names_pre_load.append(layer.name) 712 | shared_weights_pre_load.append(layer.get_weights()) 713 | logger.info('Loading weights from pre-trained model...') 714 | shared_model.load_weights(initial_weights, by_name=True) 715 | logger.info('Finished.') 716 | 717 | shared_weights_post_load = [] 718 | shared_names_post_load = [] 719 | for layer in shared_model.layers: 720 | if not layer.trainable: 721 | logger.debug(f'Recording weights of {layer.name} after load.') 722 | shared_names_post_load.append(layer.name) 723 | shared_weights_post_load.append(layer.get_weights()) 724 | shared_weights_single_task = [] 725 | shared_names_single_task = [] 726 | single_task_model = task_models[list( 727 | task_models.keys())[0]] 728 | for layer in single_task_model.layers: 729 | if not layer.trainable: 730 | logger.debug( 731 | f'Recording weights of {layer.name} of single task model after load.') 732 | shared_names_single_task.append(layer.name) 733 | shared_weights_single_task.append(layer.get_weights()) 734 | loaded, not_loaded, errors = 0, 0, 0 735 | assert shared_names_pre_load == shared_names_post_load and shared_names_post_load == shared_names_single_task, f'Layer name mismatch: {shared_names_pre_load, shared_names_post_load, shared_names_single_task}.' 736 | for pre, pre_n, post, post_n, single, single_n in zip(shared_weights_pre_load, shared_names_pre_load, shared_weights_post_load, shared_names_post_load, shared_weights_single_task, shared_names_single_task): 737 | if array_list_equal(post, pre): 738 | not_loaded += 1 739 | logger.debug( 740 | f'Not loaded weights for layer {pre_n}: Total not loaded: {not_loaded}') 741 | elif array_list_equal(single, pre): 742 | not_loaded += 1 743 | logger.debug( 744 | f'Not loaded weights for layer {pre_n}: Total not loaded: {not_loaded}') 745 | elif array_list_equal(post, single): 746 | loaded += 1 747 | logger.debug( 748 | f'Loaded weights for layer {pre_n}. 
Total loaded: {loaded}') 749 | else: 750 | errors += 1 751 | logger.debug( 752 | f'Something went wrong with {pre_n, post_n, single_n}: {errors}') 753 | logger.info( 754 | f'Weights for {loaded} layers have been loaded from pre-trained model {initial_weights}.') 755 | return shared_model, task_models 756 | -------------------------------------------------------------------------------- /emo-net/models/input_layers.py: -------------------------------------------------------------------------------- 1 | # EmoNet 2 | # ============================================================================== 3 | # Copyright (C) 2021 Maurice Gerczuk, Shahin Amiriparian, 4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved. 5 | # 6 | # This program is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with this program. If not, see . 18 | # ============================================================================== 19 | import tensorflow as tf 20 | import tensorflow.keras.backend as K 21 | 22 | """ https://gist.github.com/keunwoochoi/f4854acb68acf791a49a051893bcd23b """ 23 | class LogMelgramLayer(tf.keras.layers.Layer): 24 | def __init__( 25 | self, num_fft, hop_length, num_mels, sample_rate, f_min=80, f_max=7600, eps=1e-6, return_decibel=True, top_db=80, mask_zero=True, **kwargs 26 | ): 27 | super(LogMelgramLayer, self).__init__(**kwargs) 28 | self.num_fft = num_fft 29 | self.hop_length = hop_length 30 | self.num_mels = num_mels 31 | self.sample_rate = sample_rate 32 | self.f_min = f_min 33 | self.f_max = f_max 34 | self.eps = eps 35 | self.return_decibel = return_decibel 36 | self.num_freqs = num_fft // 2 + 1 37 | self.mask_zero = mask_zero 38 | self.top_db = top_db 39 | 40 | lin_to_mel_matrix = tf.signal.linear_to_mel_weight_matrix( 41 | num_mel_bins=self.num_mels, 42 | num_spectrogram_bins=self.num_freqs, 43 | sample_rate=self.sample_rate, 44 | lower_edge_hertz=self.f_min, 45 | upper_edge_hertz=self.f_max, 46 | ) 47 | 48 | self.lin_to_mel_matrix = lin_to_mel_matrix 49 | 50 | def build(self, input_shape): 51 | self.non_trainable_weights.append(self.lin_to_mel_matrix) 52 | super(LogMelgramLayer, self).build(input_shape) 53 | 54 | def call(self, input): 55 | """ 56 | Args: 57 | input (tensor): Batch of mono waveform, shape: (None, N) 58 | Returns: 59 | log_melgrams (tensor): Batch of log mel-spectrograms, shape: (None, num_frame, mel_bins, channel=1) 60 | """ 61 | def _tf_log10(x): 62 | numerator = tf.math.log(x) 63 | denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype)) 64 | return numerator / denominator 65 | 66 | stfts = tf.signal.stft( 67 | input, 68 | frame_length=self.num_fft, 69 | frame_step=self.hop_length, 70 | pad_end=False, # librosa test compatibility 71 | ) 72 | mag_stfts = tf.abs(stfts) 73 | 74 | melgrams = tf.tensordot( 75 | tf.square(mag_stfts), self.lin_to_mel_matrix, 1 76 | ) 77 | melgrams.set_shape(mag_stfts.shape[:-1].concatenate(self.lin_to_mel_matrix.shape[-1:])) 78 | 79 | if self.return_decibel: 80 | 
log_melgrams = 10 * _tf_log10((melgrams + self.eps) / tf.reduce_max(melgrams)) 81 | if self.top_db is not None: 82 | if self.top_db < 0: 83 | raise ValueError('top_db must be non-negative') 84 | log_melgrams = tf.math.maximum(log_melgrams, tf.reduce_max(log_melgrams) - self.top_db) 85 | #log_melgrams = (log_melgrams + self.top_db) / self.top_db 86 | 87 | else: 88 | log_melgrams = tf.math.log(melgrams + self.eps) 89 | return log_melgrams 90 | 91 | 92 | 93 | def get_config(self): 94 | config = { 95 | 'num_fft': self.num_fft, 96 | 'hop_length': self.hop_length, 97 | 'num_mels': self.num_mels, 98 | 'sample_rate': self.sample_rate, 99 | 'f_min': self.f_min, 100 | 'f_max': self.f_max, 101 | 'eps': self.eps, 102 | 'return_decibel': self.return_decibel, 103 | 'mask_zero': self.mask_zero, 104 | 'top_db': self.top_db 105 | } 106 | base_config = super(LogMelgramLayer, self).get_config() 107 | return dict(list(config.items()) + list(base_config.items())) 108 | 109 | 110 | class MFCCLayer(tf.keras.layers.Layer): 111 | def __init__( 112 | self, num_mfccs=50, **kwargs 113 | ): 114 | super(MFCCLayer, self).__init__(**kwargs) 115 | self.num_mfccs = num_mfccs 116 | 117 | 118 | def build(self, input_shape): 119 | super(MFCCLayer, self).build(input_shape) 120 | 121 | def call(self, input): 122 | """ 123 | Args: 124 | input (tensor): Batch of log mel-spectrograms, shape: (None, num_frame, mel_bins) 125 | Returns: 126 | mfccs (tensor): Batch of mfccs, shape: (None, num_frame, num_mfccs) 127 | """ 128 | 129 | log_mel_spectrograms = input 130 | 131 | mfccs = tf.signal.mfccs_from_log_mel_spectrograms( 132 | log_mel_spectrograms)[..., :self.num_mfccs] 133 | return mfccs 134 | 135 | def get_config(self): 136 | config = { 137 | 'num_mfccs': self.num_mfccs 138 | 139 | } 140 | base_config = super(MFCCLayer, self).get_config() 141 | return dict(list(config.items()) + list(base_config.items())) 142 | 143 | class ComputeMask(tf.keras.layers.Layer): 144 | def __init__(self, num_fft, hop_length, **kwargs): 145 | super(ComputeMask, self).__init__(**kwargs) 146 | self.num_fft = num_fft 147 | self.hop_length = hop_length 148 | 149 | def call(self, x): 150 | frames = tf.signal.frame(x, self.num_fft, self.hop_length, pad_end=False, 151 | axis=-1, 152 | name=None) 153 | non_zeros = tf.math.count_nonzero(frames, axis=-1) 154 | mask = tf.not_equal(non_zeros, 0) 155 | return mask 156 | 157 | def get_config(self): 158 | config = { 159 | 'num_fft': self.num_fft, 160 | 'hop_length': self.hop_length, 161 | } 162 | base_config = super(ComputeMask, self).get_config() 163 | return dict(list(config.items()) + list(base_config.items())) 164 | 165 | 166 | class PoolMask(tf.keras.layers.Layer): 167 | def __init__(self, pool_size, **kwargs): 168 | super(PoolMask, self).__init__(**kwargs) 169 | self.pool_size = pool_size 170 | 171 | def call(self, x): 172 | x = tf.expand_dims(x, -1) 173 | x = tf.cast(x, dtype='int8') 174 | x = tf.nn.pool(x, self.pool_size, pooling_type='MAX', padding='SAME', strides=self.pool_size) 175 | x = K.batch_flatten(x) 176 | x = tf.not_equal(x, 0) 177 | return x 178 | 179 | def get_config(self): 180 | config = { 181 | 'pool_size': self.pool_size, 182 | } 183 | base_config = super(PoolMask, self).get_config() 184 | return dict(list(config.items()) + list(base_config.items())) -------------------------------------------------------------------------------- /emo-net/training/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/EIHW/EmoNet/e76dd53ab4c33e99182c69f6247e4d924f9a3d99/emo-net/training/__init__.py -------------------------------------------------------------------------------- /emo-net/training/evaluate.py: -------------------------------------------------------------------------------- 1 | # EmoNet 2 | # ============================================================================== 3 | # Copyright (C) 2021 Maurice Gerczuk, Shahin Amiriparian, 4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved. 5 | # 6 | # This program is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with this program. If not, see . 18 | # ============================================================================== 19 | import tensorflow as tf 20 | tf.random.set_seed(42) 21 | 22 | import time 23 | import pandas as pd 24 | from os.path import join 25 | from sklearn.utils import class_weight 26 | from .losses import * 27 | from .metrics import * 28 | from ..models.build_model import create_multi_task_resnets, create_multi_task_rnn, create_multi_task_networks 29 | from ..data.loader import * 30 | from os import makedirs 31 | import logging 32 | logger = logging.getLogger(__name__) 33 | 34 | 35 | def evaluate( 36 | initial_weights='weights.h5', 37 | feature_extractor='cnn', 38 | batch_size=64, 39 | window=5, 40 | num_mels=128, 41 | task="", 42 | directory='EmoSet/IEMOCAP', 43 | val_csv='val.csv', 44 | share_feature_layer=True, 45 | input_bn=False, 46 | mode='adapters', 47 | output='pred.csv', 48 | **kwargs): 49 | if feature_extractor in ['cnn', 'vgg16']: 50 | variable_duration = False if kwargs['classifier'] == 'avgpool' else True 51 | else: 52 | variable_duration = True 53 | #variable_duration = False 54 | 55 | 56 | val_generator = AudioDataGenerator(val_csv, 57 | directory, 58 | batch_size=batch_size, 59 | window=window, 60 | shuffle=False, 61 | sr=16000, 62 | time_stretch=None, 63 | pitch_shift=None, 64 | save_dir=None, 65 | variable_duration=variable_duration) 66 | 67 | val_dataset = val_generator.tf_dataset() 68 | 69 | 70 | 71 | 72 | 73 | x, _ = val_generator[0] 74 | if not variable_duration: 75 | init = x.shape[1:] 76 | else: 77 | init = (None, ) 78 | models, shared_model = create_multi_task_networks( 79 | init, 80 | feature_extractor=feature_extractor, 81 | initial_weights=initial_weights, 82 | num_mels=num_mels, 83 | new_tasks=[], 84 | new_nb_classes=[], 85 | mode=mode, 86 | input_bn=input_bn, 87 | learnall=False, 88 | share_feature_layer=share_feature_layer, 89 | **kwargs) 90 | model = models[task] 91 | #model.load_weights(initial_weights, by_name=True) 92 | 93 | model.summary() 94 | #print(model.non_trainable_weights) 95 | #model.load_weights(initial_weights, by_name=True) 96 | 97 | 98 | metric_callback = ClassificationMetricCallback( 99 | validation_generator=val_dataset.prefetch(tf.data.experimental.AUTOTUNE), dataset_name='Test', labels=val_generator.class_indices, true=val_generator.categorical_classes) 100 | 101 | 
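# Keep only the last four path components of each audio file so the prediction CSV stores dataset-relative paths: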
102 | 103 | filenames = list(map(lambda x: join(*(x.split('/')[-4:])), val_generator.files)) 104 | index_to_class = {v: k for k, v in val_generator.class_indices.items()} 105 | 106 | logger.info("Model loaded.") 107 | model.compile(loss='categorical_crossentropy', optimizer=tf.keras.optimizers.SGD(lr=0.1), metrics=['accuracy']) 108 | metric_callback.set_model(model) 109 | x = model.evaluate(val_generator.tf_dataset(), 110 | # use_multiprocessing=True, 111 | # max_queue_size=n_workers * 2, 112 | verbose=1, 113 | callbacks=[ 114 | metric_callback 115 | ]) 116 | metric_callback.on_epoch_end(epoch=1) 117 | predictions = model.predict(val_generator.tf_dataset()) 118 | probas = predictions 119 | if predictions.shape[1] > 1: 120 | predictions = list(map(lambda x: index_to_class[x], np.argmax(predictions, axis=-1))) 121 | else: 122 | predictions = list(map(lambda x: index_to_class[x], np.squeeze(np.where(predictions < 0.5, 0, 1)))) 123 | true = list(map(lambda x: index_to_class[x], val_generator.classes)) 124 | columns = ['filename', *[f'probability_{index_to_class[i]}' for i in range(probas.shape[1])], 'pred_label', 'true_label'] 125 | df = pd.DataFrame(columns=columns) 126 | df['filename'] = filenames 127 | df['pred_label'] = predictions 128 | df['true_label'] = true 129 | df[[f'probability_{index_to_class[i]}' for i in range(probas.shape[1])]] = probas 130 | df.to_csv(output, index=False) 131 | 132 | -------------------------------------------------------------------------------- /emo-net/training/losses.py: -------------------------------------------------------------------------------- 1 | # EmoNet 2 | # ============================================================================== 3 | # Copyright (C) 2021 Maurice Gerczuk, Shahin Amiriparian, 4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved. 5 | # 6 | # This program is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with this program. If not, see . 18 | # ============================================================================== 19 | """ 20 | Define our custom loss function. 21 | 22 | https://github.com/umbertogriffo/focal-loss-keras/blob/master/losses.py 23 | """ 24 | from tensorflow.keras import backend as K 25 | import tensorflow as tf 26 | 27 | import dill 28 | 29 | 30 | def soft_ordinal_categorical_loss(n_classes=3, metric=lambda x, y: tf.math.abs(x-y)): 31 | ranks_tensor = tf.constant(list(range(n_classes)), dtype='float32') 32 | 33 | def soft_ordinal_categorical_loss_fixed(y_true, y_pred): 34 | trues_tensor = tf.cast(tf.math.argmax(y_true, -1, output_type='int32'), 'float32') 35 | diff = metric(tf.expand_dims(trues_tensor, -1), tf.expand_dims(ranks_tensor, 0)) 36 | softmax = tf.nn.softmax(-diff) 37 | return tf.keras.losses.categorical_crossentropy(softmax, y_pred) 38 | 39 | return soft_ordinal_categorical_loss_fixed 40 | 41 | 42 | def binary_focal_loss(gamma=2., alpha=.25): 43 | """ 44 | Binary form of focal loss. 
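Down-weights the loss contribution of well-classified examples so that training concentrates on hard, misclassified ones.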
45 | FL(p_t) = -alpha * (1 - p_t)**gamma * log(p_t) 46 | where p = sigmoid(x), p_t = p or 1 - p depending on if the label is 1 or 0, respectively. 47 | References: 48 | https://arxiv.org/pdf/1708.02002.pdf 49 | Usage: 50 | model.compile(loss=[binary_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=adam) 51 | """ 52 | def binary_focal_loss_fixed(y_true, y_pred): 53 | """ 54 | :param y_true: A tensor of the same shape as `y_pred` 55 | :param y_pred: A tensor resulting from a sigmoid 56 | :return: Output tensor. 57 | """ 58 | pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred)) 59 | pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred)) 60 | 61 | epsilon = K.epsilon() 62 | # clip to prevent NaN's and Inf's 63 | pt_1 = K.clip(pt_1, epsilon, 1. - epsilon) 64 | pt_0 = K.clip(pt_0, epsilon, 1. - epsilon) 65 | 66 | return -K.sum(alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1)) \ 67 | -K.sum((1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0)) 68 | 69 | return binary_focal_loss_fixed 70 | 71 | 72 | def categorical_focal_loss(gamma=2., alpha=.25): 73 | """ 74 | Softmax version of focal loss. 75 | m 76 | FL = ∑ -alpha * (1 - p_o,c)^gamma * y_o,c * log(p_o,c) 77 | c=1 78 | where m = number of classes, c = class and o = observation 79 | Parameters: 80 | alpha -- the same as weighing factor in balanced cross entropy 81 | gamma -- focusing parameter for modulating factor (1-p) 82 | Default value: 83 | gamma -- 2.0 as mentioned in the paper 84 | alpha -- 0.25 as mentioned in the paper 85 | References: 86 | Official paper: https://arxiv.org/pdf/1708.02002.pdf 87 | https://www.tensorflow.org/api_docs/python/tf/keras/backend/categorical_crossentropy 88 | Usage: 89 | model.compile(loss=[categorical_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=adam) 90 | """ 91 | def categorical_focal_loss_fixed(y_true, y_pred): 92 | """ 93 | :param y_true: A tensor of the same shape as `y_pred` 94 | :param y_pred: A tensor resulting from a softmax 95 | :return: Output tensor. 96 | """ 97 | 98 | # Scale predictions so that the class probas of each sample sum to 1 99 | y_pred /= K.sum(y_pred, axis=-1, keepdims=True) 100 | 101 | # Clip the prediction value to prevent NaN's and Inf's 102 | epsilon = K.epsilon() 103 | y_pred = K.clip(y_pred, epsilon, 1. - epsilon) 104 | 105 | # Calculate Cross Entropy 106 | cross_entropy = -y_true * K.log(y_pred) 107 | 108 | # Calculate Focal Loss 109 | loss = alpha * K.pow(1 - y_pred, gamma) * cross_entropy 110 | 111 | # Sum the losses in mini_batch 112 | return K.sum(loss, axis=1) 113 | 114 | return categorical_focal_loss_fixed 115 | 116 | 117 | if __name__ == '__main__': 118 | 119 | # Test serialization of nested functions 120 | bin_inner = dill.loads(dill.dumps(binary_focal_loss(gamma=2., alpha=.25))) 121 | print(bin_inner) 122 | 123 | cat_inner = dill.loads(dill.dumps(categorical_focal_loss(gamma=2., alpha=.25))) 124 | print(cat_inner) -------------------------------------------------------------------------------- /emo-net/training/metrics.py: -------------------------------------------------------------------------------- 1 | # EmoNet 2 | # ============================================================================== 3 | # Copyright (C) 2021 Maurice Gerczuk, Shahin Amiriparian, 4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved. 
5 | # 6 | # This program is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with this program. If not, see . 18 | # ============================================================================== 19 | def warn(*args, **kwargs): 20 | pass 21 | import warnings 22 | warnings.warn = warn 23 | 24 | import numpy as np 25 | import tensorflow as tf 26 | from abc import ABC, abstractmethod 27 | from dataclasses import dataclass, field 28 | from scipy.stats import shapiro, pearsonr 29 | from sklearn.metrics import recall_score, make_scorer, accuracy_score, f1_score, mean_squared_error, classification_report, confusion_matrix, multilabel_confusion_matrix, precision_score, roc_auc_score, average_precision_score, roc_curve 30 | from sklearn.metrics.scorer import _BaseScorer 31 | from statistics import pstdev, mean 32 | from typing import Dict, List, ClassVar, Set 33 | from math import sqrt 34 | from tqdm import tqdm 35 | 36 | import logging 37 | logger = logging.getLogger(__name__) 38 | 39 | 40 | def mask_metric(func): 41 | def mask_metric_function(*args, **kwargs): 42 | mask = np.not_equal(kwargs['y_true'], -1).astype(float) 43 | kwargs['y_true'] = (kwargs['y_true'] * mask) 44 | kwargs['y_pred'] = (kwargs['y_pred'] * mask) 45 | return func(*args, **kwargs) 46 | 47 | return mask_metric_function 48 | 49 | 50 | def optimal_threshold(fpr, tpr, thresholds): 51 | optimal_idx = np.argmax(tpr - fpr) 52 | optimal_threshold = thresholds[optimal_idx] 53 | return optimal_threshold 54 | 55 | 56 | def compute_binary_cutoffs(y_true, y_pred): 57 | if y_true.shape == y_pred.shape and len(y_true.shape) == 1: # 2 classes 58 | fpr, tpr, thresholds = roc_curve(y_true, y_pred) 59 | return [optimal_threshold(fpr, tpr, thresholds)] 60 | elif y_true.shape == y_pred.shape and len(y_true.shape) == 2: # multilabel 61 | fpr_tpr_thresholds = [ 62 | roc_curve(y_true[:, i], y_pred[:, i]) 63 | for i in range(y_true.shape[1]) 64 | ] 65 | return [optimal_threshold(*x) for x in fpr_tpr_thresholds] 66 | 67 | 68 | class ClassificationMetricCallback(tf.keras.callbacks.Callback): 69 | def __init__(self, 70 | labels: List = None, 71 | validation_generator=None, 72 | validation_data=None, 73 | multi_label=False, 74 | partition='validation', 75 | true=None, 76 | period=1, 77 | dataset_name='default'): 78 | super().__init__() 79 | if labels is not None: 80 | self.labels = {name: index for index, name in enumerate(labels)} 81 | self.binary = (len(labels) == 2) 82 | 83 | elif validation_generator is not None: 84 | self.labels = validation_generator.class_indices 85 | self.binary = len(self.labels) == 2 86 | 87 | # if true is not None: 88 | # self.y_val = np.squeeze(true) 89 | 90 | 91 | self.validation_generator = validation_generator 92 | self.validation_data = validation_data 93 | if isinstance(self.validation_generator, tf.data.Dataset): 94 | self.y_val = [] 95 | for features, labels in self.validation_generator.take(-1): # only take first element of dataset 96 | labels_numpy = labels.numpy() 97 | if 
labels_numpy.shape[-1] == 1: 98 | labels_numpy = np.squeeze(labels_numpy, axis=-1) 99 | self.y_val.append(labels_numpy) 100 | self.y_val = np.concatenate(self.y_val) 101 | else: 102 | self.y_val = np.squeeze(self.validation_generator.categorical_classes) 103 | self.multi_label = multi_label 104 | self.partition = partition 105 | self.keras_metric_quantities = KERAS_METRIC_QUANTITIES 106 | self.dataset_name = dataset_name 107 | 108 | self._binary_cutoffs = [] 109 | self._data = [] 110 | self.period = period 111 | 112 | def on_train_begin(self, logs={}): 113 | pass 114 | 115 | def on_epoch_end(self, epoch, logs={}): 116 | if epoch % self.period == 0: 117 | if self.validation_generator is None: 118 | X_val, y_val = self.validation_data[0], self.validation_data[1] 119 | y_pred = np.asarray(self.model.predict(X_val)) 120 | else: 121 | y_pred = np.squeeze(self.model.predict(self.validation_generator)) 122 | 123 | logs = self.compute_metrics(self.y_val, 124 | y_pred, 125 | multi_label=self.multi_label, 126 | binary=self.binary, 127 | labels=sorted( 128 | self.labels.values()), 129 | prefix=f'{self.partition}', 130 | logs=logs, 131 | target_names=sorted(self.labels.keys())) 132 | 133 | return 134 | 135 | def get_data(self): 136 | return self._data 137 | 138 | def compute_metrics(self, 139 | y_val, 140 | y_pred, 141 | multi_label=False, 142 | binary=False, 143 | labels=None, 144 | prefix='', 145 | logs={}, 146 | target_names=None): 147 | eval_string = f'\nEvaluation results for partition {self.partition} of dataset {self.dataset_name}:\n' 148 | all_classes_present = np.all(np.any(y_val > 0, axis=0)) 149 | if multi_label: 150 | binary_cutoffs = compute_binary_cutoffs(y_true=y_val, 151 | y_pred=y_pred) 152 | self._binary_cutoffs.append(binary_cutoffs) 153 | logger.info(f'Optimal cutoffs: {binary_cutoffs}') 154 | else: 155 | binary_cutoffs = None 156 | y_val_t, y_pred_t = ClassificationMetric._transform_arrays( 157 | y_true=y_val, 158 | y_pred=y_pred, 159 | multi_label=multi_label, 160 | binary=binary, 161 | binary_cutoffs=binary_cutoffs) 162 | eval_string += classification_report(y_val_t, 163 | y_pred_t, 164 | target_names=target_names) 165 | if self.multi_label: 166 | eval_string += '\n'+ str(multilabel_confusion_matrix(y_true=y_val_t, 167 | y_pred=y_pred_t, 168 | labels=labels)) 169 | else: 170 | conf_matrix = confusion_matrix(y_true=np.argmax(y_val_t, axis=1) if 171 | len(y_val_t.shape) > 1 else y_val_t, 172 | y_pred=np.argmax(y_pred_t, axis=1) if 173 | len(y_pred_t.shape) > 1 else y_pred_t, 174 | labels=labels) 175 | eval_string += '\n'+ str(conf_matrix) 176 | logs[f'{prefix}confusion_matrix'] = conf_matrix 177 | for i, cm in enumerate(CLASSIFICATION_METRICS): 178 | if all_classes_present or not (cm == ROC_AUC or cm == PR_AUC): 179 | if cm.needs_categorical: 180 | metric = cm.compute(y_true=y_val_t, 181 | y_pred=y_pred_t, 182 | labels=labels, 183 | binary=binary, 184 | multi_label=multi_label, 185 | binary_cutoffs=binary_cutoffs) 186 | else: 187 | metric = cm.compute(y_true=y_val, 188 | y_pred=y_pred, 189 | labels=labels, 190 | binary=binary, 191 | multi_label=multi_label, 192 | binary_cutoffs=binary_cutoffs) 193 | metric_value = metric.value 194 | eval_string += f'\n{prefix} {cm.description}: {metric_value}' 195 | if not self._data: # first recorded value 196 | self._data.append({ 197 | f'{self.keras_metric_quantities[cm]}/{prefix}': 198 | metric_value, 199 | }) 200 | elif i == 0 and self._data and f'{self.keras_metric_quantities[cm]}/{prefix}' in self._data[ 201 | -1].keys(): 202 | 
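# First metric of this round while the last record already contains it: a new epoch has begun, so start a fresh entry.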
self._data.append({ 203 | f'{self.keras_metric_quantities[cm]}/{prefix}': 204 | metric_value, 205 | }) 206 | else: 207 | self._data[-1][ 208 | f'{self.keras_metric_quantities[cm]}/{prefix}'] = metric_value 209 | if len( 210 | self._data 211 | ) > 1: # this is the second epoch and metrics have been recorded for the first epoch 212 | cur_best = self._data[-2][ 213 | f'{self.keras_metric_quantities[cm]}_best/{prefix}'] 214 | else: # this is the first epoch 215 | cur_best = metric_value 216 | 217 | new_best = metric_value if metric > cm( 218 | value=cur_best) else cur_best 219 | 220 | self._data[-1][ 221 | f'{self.keras_metric_quantities[cm]}_best/{prefix}'] = new_best 222 | 223 | logs[f'{self.keras_metric_quantities[cm]}/{prefix}'] = metric_value 224 | logs[f'{self.keras_metric_quantities[cm]}_best/{prefix}'] = new_best 225 | else: 226 | logger.info( 227 | f'Not all classes occur in the validation data, skipping ROC AUC and PR AUC.' 228 | ) 229 | logger.info(eval_string) 230 | return logs 231 | 232 | 233 | class RegressionMetricCallback(tf.keras.callbacks.Callback): 234 | def __init__(self, validation_data=()): 235 | super().__init__() 236 | self.validation_data = validation_data 237 | 238 | def on_train_begin(self, logs={}): 239 | self._data = [] 240 | 241 | def on_epoch_end(self, batch, logs={}): 242 | X_val, y_val = self.validation_data[0], self.validation_data[1] 243 | y_predict = np.asarray(self.model.predict(X_val)) 244 | 245 | for metric in REGRESSION_METRICS: 246 | metric_value = metric.compute(y_true=y_val, y_pred=y_predict).value 247 | self._data.append({f'val_{metric.__name__.lower()}': metric_value}) 248 | logs[f'val_{metric.__name__.lower()}'] = metric_value 249 | return 250 | 251 | def get_data(self): 252 | return self._data 253 | 254 | 255 | @dataclass(order=True) 256 | class Metric(ABC): 257 | sort_index: float = field(init=False, repr=False) 258 | description: ClassVar[str] = 'Metric' 259 | key: ClassVar[str] = 'M' 260 | value: float 261 | scikit_scorer: ClassVar[_BaseScorer] = field(init=False, repr=False) 262 | greater_is_better: ClassVar[bool] = True 263 | 264 | def __post_init__(self): 265 | self.sort_index = self.value if self.greater_is_better else -self.value 266 | 267 | 268 | @dataclass(order=True) 269 | class ClassificationMetric(Metric, ABC): 270 | multi_label: bool = False 271 | binary: bool = False 272 | average: ClassVar[str] = None 273 | needs_categorical: ClassVar[bool] = True 274 | 275 | @classmethod 276 | @mask_metric 277 | @abstractmethod 278 | def compute(cls, 279 | y_true: np.array, 280 | y_pred: np.array, 281 | labels: List, 282 | multi_label: bool, 283 | binary: bool, 284 | binary_cutoffs: List[float] = None) -> Metric: 285 | pass 286 | 287 | @staticmethod 288 | def _transform_arrays(y_true: np.array, 289 | y_pred: np.array, 290 | multi_label: bool, 291 | binary: bool, 292 | binary_cutoffs: List[float] = None 293 | ) -> (np.array, np.array): 294 | if binary: 295 | if len(y_pred.shape) > 1: 296 | y_pred = np.reshape(y_pred, -1) 297 | if len(y_true.shape) > 1: 298 | y_true = np.reshape(y_true, -1) 299 | assert ( 300 | y_true.shape == y_pred.shape and len(y_true.shape) == 1 301 | ), f'Shapes of predictions and labels for binary classification should conform to (n_samples,) but received {y_pred.shape} and {y_true.shape}.' 
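# NOTE: the fixed 0.5 threshold below overrides any ROC-derived cutoffs for the binary case (cf. the commented-out check); the multilabel branch further down still uses them.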
302 | #if binary_cutoffs is None: 303 | binary_cutoffs = 0.5 304 | #y_pred_transformed = np.zeros_like(y_pred, dtype=int) 305 | #y_pred_transformed[y_pred > binary_cutoffs[0]] = 1 306 | y_pred_transformed = np.where(y_pred > binary_cutoffs, 1, 0) 307 | y_true_transformed = y_true 308 | 309 | elif multi_label: 310 | assert ( 311 | y_true.shape == y_pred.shape 312 | ), f'Shapes of predictions and labels for multilabel classification should conform to (n_samples, n_classes) but received {y_pred.shape} and {y_true.shape}.' 313 | if binary_cutoffs is None: 314 | binary_cutoffs = compute_binary_cutoffs(y_true, y_pred) 315 | # y_pred_transformed = np.zeros_like(y_pred, dtype=int) 316 | # y_pred_transformed[y_pred > 0.5] = 1 317 | y_pred_transformed = np.where(y_pred > binary_cutoffs, 1, 0) 318 | y_true_transformed = y_true 319 | else: 320 | if y_true.shape[1] > 1: 321 | y_true_transformed = np.zeros_like(y_true) 322 | y_true_transformed[range(len(y_true)), y_true.argmax(1)] = 1 323 | if y_pred.shape[1] > 1: 324 | y_pred_transformed = np.zeros_like(y_pred) 325 | y_pred_transformed[range(len(y_pred)), y_pred.argmax(1)] = 1 326 | assert ( 327 | y_true.shape == y_pred.shape 328 | ), f'Shapes of predictions and labels for multiclass classification should conform to (n_samples,n_classes) but received {y_pred.shape} and {y_true.shape}.' 329 | return y_true_transformed, y_pred_transformed 330 | 331 | 332 | @dataclass(order=True) 333 | class RegressionMetric(Metric, ABC): 334 | @staticmethod 335 | @abstractmethod 336 | def compute(y_true: np.array, y_pred: np.array) -> Metric: 337 | pass 338 | 339 | 340 | @dataclass(order=True) 341 | class MicroRecall(ClassificationMetric): 342 | description: ClassVar[str] = 'Micro Average Recall' 343 | average: ClassVar[str] = 'micro' 344 | key: ClassVar[str] = 'Recall/Micro' 345 | 346 | scikit_scorer: ClassVar[_BaseScorer] = make_scorer(recall_score, 347 | average='micro') 348 | greater_is_better: ClassVar[bool] = True 349 | 350 | @classmethod 351 | @mask_metric 352 | def compute(cls, 353 | y_true: np.array, 354 | y_pred: np.array, 355 | labels: List, 356 | multi_label: bool, 357 | binary: bool, 358 | binary_cutoffs: List[float] = None) -> ClassificationMetric: 359 | score = recall_score(y_true=y_true, 360 | y_pred=y_pred, 361 | labels=labels, 362 | average=cls.average) 363 | return cls(value=score, multi_label=multi_label, binary=binary) 364 | 365 | 366 | @dataclass(order=True) 367 | class UAR(MicroRecall): 368 | average: ClassVar[str] = 'macro' 369 | description: ClassVar[str] = 'Unweighted Average Recall' 370 | key: ClassVar[str] = 'Recall/Macro' 371 | scikit_scorer: ClassVar[_BaseScorer] = make_scorer(recall_score, 372 | average='macro') 373 | greater_is_better: ClassVar[bool] = True 374 | 375 | 376 | @dataclass(order=True) 377 | class Accuracy(ClassificationMetric): 378 | description: ClassVar[str] = 'Accuracy' 379 | key: ClassVar[str] = 'acc' 380 | 381 | scikit_scorer: ClassVar[_BaseScorer] = make_scorer(accuracy_score) 382 | greater_is_better: ClassVar[bool] = True 383 | 384 | @classmethod 385 | @mask_metric 386 | def compute(cls, 387 | y_true: np.array, 388 | y_pred: np.array, 389 | labels: List, 390 | multi_label: bool, 391 | binary: bool, 392 | binary_cutoffs: List[float] = None) -> ClassificationMetric: 393 | score = accuracy_score(y_true=y_true, y_pred=y_pred) 394 | return cls(value=score, multi_label=multi_label, binary=binary) 395 | 396 | 397 | @dataclass(order=True) 398 | class MacroF1(ClassificationMetric): 399 | average: ClassVar[str] = 'macro' 400 | 
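# Macro averaging weights every class equally, regardless of its support.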
description: ClassVar[str] = 'Macro Average F1' 401 | key: ClassVar[str] = 'F1/Macro' 402 | scikit_scorer: ClassVar[_BaseScorer] = make_scorer(f1_score, 403 | average='macro') 404 | greater_is_better: ClassVar[bool] = True 405 | 406 | @classmethod 407 | @mask_metric 408 | def compute(cls, 409 | y_true: np.array, 410 | y_pred: np.array, 411 | labels: List, 412 | multi_label: bool, 413 | binary: bool, 414 | binary_cutoffs: List[float] = None) -> ClassificationMetric: 415 | score = f1_score(y_true=y_true, 416 | y_pred=y_pred, 417 | labels=labels, 418 | average=cls.average) 419 | return cls(value=score, multi_label=multi_label, binary=binary) 420 | 421 | 422 | @dataclass(order=True) 423 | class MicroF1(MacroF1): 424 | average: ClassVar[str] = 'micro' 425 | description: ClassVar[str] = 'Micro Average F1' 426 | key: ClassVar[str] = 'F1/Micro' 427 | scikit_scorer: ClassVar[_BaseScorer] = make_scorer(f1_score, 428 | average='micro') 429 | greater_is_better: ClassVar[bool] = True 430 | 431 | 432 | @dataclass(order=True) 433 | class MacroPrecision(ClassificationMetric): 434 | average: ClassVar[str] = 'macro' 435 | description: ClassVar[str] = 'Macro Average Precision' 436 | key: ClassVar[str] = 'Prec/Macro' 437 | 438 | scikit_scorer: ClassVar[_BaseScorer] = make_scorer(precision_score, 439 | average='macro') 440 | greater_is_better: ClassVar[bool] = True 441 | 442 | @classmethod 443 | @mask_metric 444 | def compute(cls, 445 | y_true: np.array, 446 | y_pred: np.array, 447 | labels: List, 448 | multi_label: bool, 449 | binary: bool, 450 | binary_cutoffs: List[float] = None) -> ClassificationMetric: 451 | score = precision_score(y_true=y_true, 452 | y_pred=y_pred, 453 | labels=labels, 454 | average=cls.average) 455 | return cls(value=score, multi_label=multi_label, binary=binary) 456 | 457 | 458 | @dataclass(order=True) 459 | class MicroPrecision(MacroPrecision): 460 | average: ClassVar[str] = 'micro' 461 | description: ClassVar[str] = 'Micro Average Prec' 462 | key: ClassVar[str] = 'Prec/Micro' 463 | scikit_scorer: ClassVar[_BaseScorer] = make_scorer(precision_score, 464 | average='micro') 465 | greater_is_better: ClassVar[bool] = True 466 | 467 | 468 | @dataclass(order=True) 469 | class ROC_AUC(ClassificationMetric): 470 | average: ClassVar[str] = 'macro' 471 | description: ClassVar[ 472 | str] = 'Area Under the Receiver Operating Characteristic Curve' 473 | key: ClassVar[str] = 'ROC AUC' 474 | scikit_scorer: ClassVar[_BaseScorer] = make_scorer(roc_auc_score, 475 | average='macro') 476 | greater_is_better: ClassVar[bool] = True 477 | needs_categorical: ClassVar[bool] = False 478 | 479 | @classmethod 480 | @mask_metric 481 | def compute(cls, 482 | y_true: np.array, 483 | y_pred: np.array, 484 | labels: List, 485 | multi_label: bool, 486 | binary: bool, 487 | binary_cutoffs: List[float] = None) -> ClassificationMetric: 488 | score = roc_auc_score(y_true=y_true, 489 | y_score=y_pred, 490 | average=cls.average) 491 | return cls(value=score, multi_label=multi_label, binary=binary) 492 | 493 | 494 | @dataclass(order=True) 495 | class PR_AUC(ClassificationMetric): 496 | average: ClassVar[str] = 'macro' 497 | description: ClassVar[str] = 'Area Under the Precision Recall Curve' 498 | key: ClassVar[str] = 'PR AUC' 499 | scikit_scorer: ClassVar[_BaseScorer] = make_scorer(average_precision_score, 500 | average='macro') 501 | greater_is_better: ClassVar[bool] = True 502 | needs_categorical: ClassVar[bool] = False 503 | 504 | 505 | @classmethod 506 | @mask_metric 507 | def compute(cls, 508 | y_true: np.array, 509 
| y_pred: np.array, 510 | labels: List, 511 | multi_label: bool, 512 | binary: bool, 513 | binary_cutoffs: List[float] = None) -> ClassificationMetric: 514 | score = average_precision_score(y_true=y_true, 515 | y_score=y_pred, 516 | average=cls.average) 517 | return cls(value=score, multi_label=multi_label, binary=binary) 518 | 519 | 520 | @dataclass(order=True) 521 | class MSE(RegressionMetric): 522 | description: ClassVar[str] = 'Mean Squared Error' 523 | key: ClassVar[str] = 'mse' 524 | scikit_scorer: ClassVar[_BaseScorer] = make_scorer(mean_squared_error, 525 | greater_is_better=False) 526 | greater_is_better: ClassVar[bool] = False 527 | 528 | @staticmethod 529 | def compute(y_true: np.array, y_pred: np.array) -> RegressionMetric: 530 | score = mean_squared_error(y_true=y_true, y_pred=y_pred) 531 | return MSE(value=score) 532 | 533 | 534 | def pearson_correlation_coefficient(y_true, y_pred): 535 | return pearsonr(y_true, y_pred)[0] 536 | 537 | 538 | @dataclass(order=True) 539 | class PCC(RegressionMetric): 540 | description: ClassVar[str] = 'Pearson\'s Correlation Coefficient' 541 | key: ClassVar[str] = 'pcc' 542 | scikit_scorer: ClassVar[_BaseScorer] = make_scorer( 543 | pearson_correlation_coefficient, greater_is_better=True) 544 | greater_is_better: ClassVar[bool] = True 545 | 546 | @staticmethod 547 | def compute(y_true: np.array, y_pred: np.array) -> RegressionMetric: 548 | score = pearson_correlation_coefficient(y_true=y_true, y_pred=y_pred) 549 | return PCC(value=score) 550 | 551 | 552 | def concordance_correlation_coefficient(y_true, y_pred): 553 | ccc = 2 * pearson_correlation_coefficient(y_true=y_true, y_pred=y_pred) * np.std(y_true) * np.std(y_pred) / ( 554 | np.var(y_true) + np.var(y_pred) + 555 | (np.mean(y_true) - np.mean(y_pred))**2) 556 | return ccc 557 | 558 | 559 | @dataclass(order=True) 560 | class CCC(RegressionMetric): 561 | description: ClassVar[str] = 'Concordance Correlation Coefficient' 562 | key: ClassVar[str] = 'ccc' 563 | scikit_scorer: ClassVar[_BaseScorer] = make_scorer( 564 | concordance_correlation_coefficient, greater_is_better=True) 565 | greater_is_better: ClassVar[bool] = True 566 | 567 | @staticmethod 568 | def compute(y_true: np.array, y_pred: np.array) -> RegressionMetric: 569 | score = concordance_correlation_coefficient(y_true=y_true, 570 | y_pred=y_pred) 571 | return CCC(value=score) 572 | 573 | 574 | def root_mean_squared_error(y_true, y_pred): 575 | return sqrt(mean_squared_error(y_true=y_true, y_pred=y_pred)) 576 | 577 | 578 | @dataclass(order=True) 579 | class RMSE(RegressionMetric): 580 | description: ClassVar[str] = 'Root Mean Squared Error' 581 | key: ClassVar[str] = 'rmse' 582 | 583 | scikit_scorer: ClassVar[_BaseScorer] = make_scorer(root_mean_squared_error, 584 | greater_is_better=False) 585 | greater_is_better: ClassVar[bool] = False 586 | 587 | @staticmethod 588 | def compute(y_true: np.array, y_pred: np.array) -> RegressionMetric: 589 | score = root_mean_squared_error(y_true=y_true, y_pred=y_pred) 590 | return RMSE(value=score) 591 | 592 | 593 | @dataclass 594 | class MetricStats(): 595 | mean: float 596 | standard_deviation: float 597 | normality_tests: Dict[str, tuple] 598 | 599 | 600 | def compute_metric_stats(metrics: List[Metric]) -> MetricStats: 601 | metric_values = [metric.value for metric in metrics] 602 | normality_tests = dict() 603 | if len(metric_values) > 2: 604 | normality_tests['Shapiro-Wilk'] = shapiro(metric_values) 605 | return MetricStats(mean=mean(metric_values), 606 | standard_deviation=pstdev(metric_values), 607 | 
normality_tests=normality_tests) 608 | 609 | 610 | def all_subclasses(cls): 611 | return set(cls.__subclasses__()).union( 612 | [s for c in cls.__subclasses__() for s in all_subclasses(c)]) 613 | 614 | 615 | CLASSIFICATION_METRICS = all_subclasses(ClassificationMetric) 616 | REGRESSION_METRICS = all_subclasses(RegressionMetric) 617 | ALL_METRICS = all_subclasses(Metric) 618 | 619 | SCIKIT_CLASSIFICATION_SCORERS = { 620 | M.__name__: M.scikit_scorer 621 | for M in CLASSIFICATION_METRICS if M != ROC_AUC and M != PR_AUC 622 | } 623 | 624 | SCIKIT_CLASSIFICATION_SCORERS_EXTENDED = { 625 | M.__name__: M.scikit_scorer 626 | for M in CLASSIFICATION_METRICS 627 | } 628 | 629 | SCIKIT_REGRESSION_SCORERS = { 630 | M.__name__: M.scikit_scorer 631 | for M in REGRESSION_METRICS 632 | } 633 | 634 | KERAS_METRIC_QUANTITIES = { 635 | M: f'val_{"_".join(M.key.lower().split(" "))}' 636 | for M in ALL_METRICS 637 | } 638 | 639 | KERAS_METRIC_MODES = { 640 | M: 'max' if M.greater_is_better else 'min' 641 | for M in ALL_METRICS 642 | } 643 | 644 | KEY_TO_METRIC = {metric.__name__: metric for metric in ALL_METRICS} 645 | -------------------------------------------------------------------------------- /emo-net/training/train.py: -------------------------------------------------------------------------------- 1 | # EmoNet 2 | # ============================================================================== 3 | # Copyright (C) 2021 Maurice Gerczuk, Shahin Amiriparian, 4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved. 5 | # 6 | # This program is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with this program. If not, see . 
18 | # ============================================================================== 19 | import tensorflow as tf 20 | tf.random.set_seed(42) 21 | 22 | import time 23 | from os.path import join 24 | from sklearn.utils import class_weight 25 | from .losses import * 26 | from .metrics import * 27 | from ..models.build_model import create_multi_task_resnets, create_multi_task_rnn, create_multi_task_networks 28 | from ..data.loader import * 29 | from os import makedirs 30 | import logging 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | 35 | categorical_loss_map = {'crossentropy': "categorical_crossentropy", "focal": categorical_focal_loss( 36 | ), "ordinal": soft_ordinal_categorical_loss} 37 | 38 | binary_loss_map = {'crossentropy': "binary_crossentropy", "focal": binary_focal_loss( 39 | ), "ordinal": "binary_crossentropy"} 40 | 41 | def determine_decay(generator, batch_size): 42 | if 10000 > len(generator) * batch_size >= 1000: 43 | decay = 0.0005 44 | elif 1000 > len(generator) * batch_size >= 500: 45 | decay = 0.002 46 | elif len(generator) * batch_size < 500: 47 | decay = 0.005 48 | else: 49 | decay = 1e-6 50 | return decay 51 | 52 | 53 | def named_logs(model, logs): 54 | result = {} 55 | for l in zip(model.metrics_names, logs): 56 | result[l[0]] = l[1] 57 | return result 58 | 59 | def __feature_extractor_params_string(feature_extractor, **kwargs): 60 | if feature_extractor == 'cnn': 61 | return __cnn_params(**kwargs) 62 | elif feature_extractor == 'rnn': 63 | return __rnn_params(**kwargs) 64 | elif feature_extractor == 'vgg16': 65 | return __vgg16_params(**kwargs) 66 | elif feature_extractor == 'fusion': 67 | return __fusion_params(**kwargs) 68 | 69 | def __cnn_params(classifier, N, factor, dropout1, dropout2, rnn_dropout, filters, learnall_classifier): 70 | return f'filters-{filters}-N-{N}_factor-{factor}-do1-{dropout1}-do2-{dropout2}-classifier-{classifier}-learnall_classifier-{learnall_classifier}{"-rd-"+str(rnn_dropout) if classifier == "rnn" else ""}' 71 | 72 | def __fusion_params(N, factor, dropout1, dropout2, rnn_dropout, filters, hidden_dim, cell, bidirectional, number_of_layers, down_pool): 73 | return f'filters-{filters}-N-{N}_factor-{factor}-do1-{dropout1}-do2-{dropout2}-cell-{cell}-bidirectional-{bidirectional}-hidden_dim-{number_of_layers}x{hidden_dim}-rd{rnn_dropout}-downpool-{down_pool}' 74 | 75 | 76 | def __rnn_params(hidden_dim, cell, bidirectional, dropout, number_of_layers, down_pool, num_mfccs, use_attention, share_attention, input_projection): 77 | return f'cell-{cell}-bidirectional-{bidirectional}-hidden_dim-{number_of_layers}x{hidden_dim}-do-{dropout}-downpool-{down_pool}-mfccs-{num_mfccs}-attention-{use_attention}-shareAttention-{share_attention}-ip-{input_projection}' 78 | 79 | def __vgg16_params(freeze_up_to, classifier, dropout): 80 | return f'freezeUpTo-{freeze_up_to}-classifier-{classifier}-dropout-{dropout}' 81 | 82 | 83 | def train_single_task( 84 | initial_weights='/mnt/student/MauriceGerczuk/EmoSet/experiments/residual-adapters-emonet-revised/parallel/128mels/2s/scratch/N-2_factor-1-balancedClassWeights-True/GEMEP/weights_GEMEP.h5', 85 | feature_extractor='cnn', 86 | batch_size=64, 87 | epochs=50, 88 | balanced_weights=True, 89 | window=6, 90 | num_mels=128, 91 | task='IEMOCAP-4cl', 92 | loss='crossentropy', # must be a key of the loss maps above ('crossentropy', 'focal' or 'ordinal') 93 | directory='/mnt/nas/data_work/shahin/EmoSet/wavs-reordered/', 94 | train_csv='train.csv', 95 | val_csv='val.csv', 96 | test_csv='test.csv', 97 | 
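# NOTE: the default paths above and below point to the authors' cluster storage and need to be overridden for other setups.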
experiment_base_path='/mnt/student/MauriceGerczuk/EmoSet/experiments/residual-adapters-emonet-revised', 98 | random_noise=None, 99 | learnall=False, 100 | last_layer_only=False, 101 | initial_learning_rate=0.1, 102 | optimizer=tf.keras.optimizers.SGD, 103 | n_workers=5, 104 | patience=20, 105 | mode='adapters', 106 | input_bn=False, 107 | share_feature_layer=False, 108 | individual_weight_decay=False, 109 | **kwargs): 110 | if feature_extractor in ['cnn', 'vgg16']: 111 | variable_duration = False if kwargs['classifier'] == 'avgpool' else True 112 | else: 113 | variable_duration = True 114 | #variable_duration = False 115 | base_tasks = None 116 | base_nb_classes = None 117 | feature_extractor_params = __feature_extractor_params_string(feature_extractor, **kwargs) 118 | training_params = f'balancedClassWeights-{balanced_weights}-loss-{loss}-optimizer-{optimizer.__name__}-lr-{initial_learning_rate}-bs-{batch_size}-patience-{patience}-random_noise-{random_noise}-numMels-{num_mels}-ib-{input_bn}-sfl-{share_feature_layer}-iwd-{individual_weight_decay}' 119 | experiment_base_path = f"{join(experiment_base_path, 'single-task', feature_extractor, mode, f'Window-{window}s', feature_extractor_params, training_params)}" 120 | 121 | train_generator = AudioDataGenerator(train_csv, 122 | directory, 123 | batch_size=batch_size, 124 | window=window, 125 | shuffle=True, 126 | sr=16000, 127 | time_stretch=None, 128 | pitch_shift=None, 129 | save_dir=None, 130 | val_split=None, 131 | subset='train', 132 | variable_duration=variable_duration) 133 | val_generator = AudioDataGenerator(val_csv, 134 | directory, 135 | batch_size=batch_size, 136 | window=window, 137 | shuffle=False, 138 | sr=16000, 139 | time_stretch=None, 140 | pitch_shift=None, 141 | save_dir=None, 142 | variable_duration=variable_duration) 143 | test_generator = AudioDataGenerator(test_csv, 144 | directory, 145 | batch_size=batch_size, 146 | window=window, 147 | shuffle=False, 148 | sr=16000, 149 | time_stretch=None, 150 | pitch_shift=None, 151 | save_dir=None, 152 | variable_duration=variable_duration) 153 | val_dataset = val_generator.tf_dataset() 154 | test_dataset = test_generator.tf_dataset() 155 | 156 | decay = determine_decay(train_generator, batch_size) 157 | 158 | if initial_weights is not None and mode == 'adapters': 159 | new_tasks = [task] 160 | new_nb_classes = [len(set(train_generator.classes))] 161 | new_weight_decays = [decay] if individual_weight_decay else None 162 | base_weight_decays = None 163 | 164 | else: 165 | base_tasks = [task] 166 | base_nb_classes = [len(set(train_generator.classes))] 167 | base_weight_decays = [decay] if individual_weight_decay else None 168 | new_tasks = [] 169 | new_nb_classes = [] 170 | new_weight_decays = None 171 | 172 | if balanced_weights: 173 | class_weights = class_weight.compute_class_weight( 174 | 'balanced', np.unique(train_generator.classes), 175 | train_generator.classes) 176 | class_weight_dict = dict(enumerate(class_weights)) 177 | logger.info(f'Class weights: {class_weight_dict}') 178 | else: 179 | class_weight_dict = None 180 | logger.info('Not using class weights.') 181 | 182 | task_base_path = join(experiment_base_path, task) 183 | weights = join(task_base_path, "weights_" + task + ".h5") 184 | 185 | 186 | x, _ = train_generator[0] 187 | if not variable_duration: 188 | init = x.shape[1:] 189 | else: 190 | init = (None, ) 191 | models, shared_model = create_multi_task_networks( 192 | init, 193 | feature_extractor=feature_extractor, 194 | initial_weights=initial_weights, 195 | 
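# initial_weights, if given, are loaded layer by layer inside the model builder (cf. the loading log in build_model.py).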
num_mels=num_mels, 196 | mode=mode, 197 | base_nb_classes=base_nb_classes, 198 | base_weight_decays=base_weight_decays, 199 | learnall=learnall, 200 | base_tasks=base_tasks, 201 | new_tasks=new_tasks, 202 | new_nb_classes=new_nb_classes, 203 | new_weight_decays=new_weight_decays, 204 | random_noise=random_noise, 205 | input_bn=input_bn, 206 | share_feature_layer=share_feature_layer, 207 | **kwargs) 208 | model = models[task] 209 | #model.load_weights(initial_weights, by_name=True) 210 | if last_layer_only: 211 | for layer in model.layers[:-5]: 212 | layer.trainable = False 213 | model.summary() 214 | #print(model.non_trainable_weights) 215 | #model.load_weights(initial_weights, by_name=True) 216 | 217 | tbCallBack = tf.keras.callbacks.TensorBoard(log_dir=join(task_base_path, 'log'), 218 | histogram_freq=0, 219 | write_graph=True) 220 | #hpCallback = hp.KerasCallback(join(task_base_path, 'log', 'hparam_tuning'), hparams) 221 | mc = tf.keras.callbacks.ModelCheckpoint(weights, 222 | monitor='val_recall/macro/validation', 223 | verbose=1, 224 | save_best_only=True, 225 | save_weights_only=False, 226 | mode='max', 227 | period=1) 228 | metric_callback = ClassificationMetricCallback( 229 | validation_generator=val_dataset.prefetch(tf.data.experimental.AUTOTUNE), dataset_name=task, labels=val_generator.class_indices, true=val_generator.categorical_classes) 230 | metric_callback_test = ClassificationMetricCallback( 231 | validation_generator=test_dataset.prefetch(tf.data.experimental.AUTOTUNE), dataset_name=task, partition='test', labels=test_generator.class_indices, true=test_generator.categorical_classes) 232 | 233 | lrs = [ 234 | initial_learning_rate, initial_learning_rate * 0.1, 235 | initial_learning_rate * 0.01 236 | ] 237 | makedirs(task_base_path, exist_ok=True) 238 | stopped_epoch = 0 239 | best = 0 240 | patience = patience 241 | load = False 242 | loss_string = loss 243 | if len(set(train_generator.classes)) == 2: 244 | loss = binary_loss_map[loss_string] 245 | else: 246 | loss = categorical_loss_map[loss_string] 247 | if loss_string == 'ordinal': 248 | loss = loss(n_classes=len(set(train_generator.classes))) 249 | for i, lr in enumerate(lrs): 250 | if optimizer.__name__ == tf.keras.optimizers.SGD.__name__: 251 | opt = optimizer(learning_rate=lr, decay=decay if not individual_weight_decay else 1e-6, momentum=0.9, nesterov=False) 252 | else: 253 | opt = optimizer(learning_rate=lr) 254 | model.compile(loss=loss, optimizer=opt, metrics=["acc"], experimental_run_tf_function=False) 255 | 256 | if load: 257 | model.load_weights(weights) 258 | logger.info("Model loaded.") 259 | early_stopper = tf.keras.callbacks.EarlyStopping(monitor='val_recall/macro/validation', 260 | min_delta=0.005, 261 | patience=patience, 262 | verbose=1, 263 | mode='max', 264 | restore_best_weights=False, 265 | baseline=best) 266 | model.fit(train_generator.tf_dataset().prefetch(tf.data.experimental.AUTOTUNE), 267 | validation_data=val_generator.tf_dataset(), 268 | epochs=epochs, 269 | workers=n_workers // 2, 270 | initial_epoch=stopped_epoch, 271 | class_weight=class_weight_dict, 272 | # use_multiprocessing=True, 273 | # max_queue_size=n_workers * 2, 274 | verbose=2, 275 | callbacks=[ 276 | metric_callback, metric_callback_test, 277 | early_stopper, tbCallBack, mc 278 | ]) 279 | load = True 280 | stopped_epoch = early_stopper.stopped_epoch 281 | best = early_stopper.best 282 | 283 | 284 | def train_multi_task( 285 | batch_size=64, 286 | epochs=50, 287 | balanced_weights=True, 288 | feature_extractor='cnn', 289 | 
window=2, 290 | num_mels=128, 291 | mode='adapters', 292 | initial_learning_rate=0.1, 293 | tasks=[ 294 | "AirplaneBehaviourCorpus", "AngerDetection", "BurmeseEmotionalSpeech", 295 | "CASIA", "ChineseVocalEmotions", "DanishEmotionalSpeech", "DEMoS", 296 | "EA-ACT", "EA-BMW", "EA-WSJ", "EMO-DB", "EmoFilm", "EmotiW-2014", 297 | "ENTERFACE", "EU-EmoSS", "FAU_AIBO", "GEMEP", "GVESS", 298 | "MandarinEmotionalSpeech", "MELD", "PPMK-EMO", "SIMIS", "SMARTKOM", 299 | "SUSAS", "TurkishEmoBUEE" 300 | ], 301 | loss='crossentropy', 302 | directory='/mnt/nas/data_work/shahin/EmoSet/wavs-reordered/', 303 | experiment_base_path='/mnt/student/MauriceGerczuk/EmoSet/experiments/residual-adapters-emonet-revised', 304 | multi_task_setup='/mnt/student/MauriceGerczuk/EmoSet/multiTaskSetup-wavs-with-test/', 305 | steps_per_epoch=20, 306 | optimizer=tf.keras.optimizers.SGD, 307 | random_noise=None, 308 | input_bn=False, 309 | share_feature_layer=False, 310 | individual_weight_decay=False, 311 | **kwargs): 312 | if feature_extractor == 'cnn': 313 | variable_duration = False if kwargs['classifier'] == 'avgpool' else True 314 | else: 315 | variable_duration = True 316 | feature_extractor_params = __feature_extractor_params_string(feature_extractor, **kwargs) 317 | training_params = f'balancedClassWeights-{balanced_weights}-loss-{loss}-optimizer-{optimizer.__name__}-lr-{initial_learning_rate}-bs-{batch_size}-epochs-{epochs}-spe-{steps_per_epoch}-random_noise-{random_noise}-numMels-{num_mels}-ib-{input_bn}-sfl-{share_feature_layer}-iwd-{individual_weight_decay}' 318 | experiment_base_path = f"{join(experiment_base_path, 'multi-task', '-'.join(map(lambda x: x[:4], tasks)), feature_extractor, mode, f'Window-{window}s', feature_extractor_params, training_params)}" 319 | 320 | 321 | train_generators = [ 322 | AudioDataGenerator(f'{multi_task_setup}/{task}/train.csv', 323 | directory, 324 | batch_size=batch_size, 325 | window=window, 326 | shuffle=True, 327 | sr=16000, 328 | time_stretch=None, 329 | pitch_shift=None, 330 | variable_duration=variable_duration, 331 | save_dir=None, 332 | val_split=None, 333 | subset='train') for task in tasks 334 | ] 335 | val_generators = [ 336 | AudioDataGenerator(f'{multi_task_setup}/{task}/val.csv', 337 | directory, 338 | batch_size=batch_size, 339 | window=window, 340 | shuffle=False, 341 | sr=16000, 342 | time_stretch=None, 343 | variable_duration=variable_duration, 344 | pitch_shift=None, 345 | save_dir=None) for task in tasks 346 | ] 347 | test_generators = [ 348 | AudioDataGenerator(f'{multi_task_setup}/{task}/test.csv', 349 | directory, 350 | batch_size=batch_size, 351 | window=window, 352 | shuffle=False, 353 | sr=16000, 354 | time_stretch=None, 355 | variable_duration=variable_duration, 356 | pitch_shift=None, 357 | save_dir=None) for task in tasks 358 | ] 359 | 360 | train_datasets = tuple(gen.tf_dataset().repeat() for gen in train_generators) 361 | val_datasets = tuple(gen.tf_dataset() for gen in val_generators) 362 | test_datasets = tuple(gen.tf_dataset() for gen in test_generators) 363 | 364 | 365 | 366 | if balanced_weights: 367 | class_weights = [ 368 | class_weight.compute_class_weight('balanced', np.unique(t.classes), 369 | t.classes) 370 | for t in train_generators 371 | ] 372 | class_weight_dicts = [dict(enumerate(cw)) for cw in class_weights] 373 | logger.info(f'Class weights: {class_weight_dicts}') 374 | else: 375 | class_weight_dicts = [None] * len(tasks) 376 | logger.info('Not using class weights.') 377 | 378 | task_base_paths = [join(experiment_base_path, task) for 
task in tasks] 379 | weight_paths = [ 380 | join(task_base_path, "weights_" + task + ".h5") 381 | for task_base_path, task in zip(task_base_paths, tasks) 382 | ] 383 | 384 | tbCallBacks = [ 385 | tf.keras.callbacks.TensorBoard(log_dir=join(task_base_path, 'log'), 386 | histogram_freq=0, 387 | write_graph=True) for task_base_path in task_base_paths 388 | ] 389 | 390 | metric_callbacks = [ 391 | ClassificationMetricCallback(validation_generator=val_dataset, 392 | period=1, dataset_name=task, labels=val_generator.class_indices) 393 | for val_dataset, val_generator, task in zip(val_datasets, val_generators, tasks) 394 | ] 395 | metric_callbacks_test = [ 396 | ClassificationMetricCallback(validation_generator=test_dataset, 397 | partition='test', 398 | period=1, 399 | dataset_name=task, 400 | labels=test_generator.class_indices) 401 | for test_dataset, test_generator, task in zip(test_datasets, test_generators, tasks) 402 | ] 403 | decays = [determine_decay(tg, batch_size) for tg in train_generators] 404 | 405 | #steps_per_epoch = 10 406 | x, _ = train_generators[0][0] 407 | if not variable_duration: 408 | init = x.shape[1:] 409 | else: 410 | init = (None, ) 411 | 412 | lrs = [ 413 | initial_learning_rate, initial_learning_rate * 0.1, 414 | initial_learning_rate * 0.01 415 | ] 416 | nb_classes = [len(tg.class_indices) for tg in train_generators] 417 | models, shared_model = create_multi_task_networks( 418 | init, 419 | feature_extractor=feature_extractor, 420 | mode=mode, 421 | num_mels=num_mels, 422 | base_nb_classes=nb_classes, 423 | learnall=True, 424 | base_tasks=tasks, 425 | base_weight_decays=decays if individual_weight_decay else None, 426 | random_noise=random_noise, 427 | input_bn=input_bn, 428 | share_feature_layer=share_feature_layer, 429 | **kwargs) 430 | shared_model.summary() 431 | for i, t in enumerate(tasks): 432 | tbCallBacks[i].set_model(models[t]) 433 | metric_callbacks[i].set_model(models[t]) 434 | metric_callbacks_test[i].set_model(models[t]) 435 | 436 | max_steps= epochs * steps_per_epoch 437 | loss_string = loss 438 | for step, lr in enumerate(lrs): 439 | for i, batch in tqdm(tf.data.Dataset.zip(train_datasets).enumerate().prefetch(1), total=max_steps): 440 | if i >= max_steps: 441 | break 442 | for t, task in enumerate(tasks): 443 | model = models[task] 444 | if i == 0: # reset learning rate 445 | if len(set(train_generators[t].classes)) == 2: 446 | loss = binary_loss_map[loss_string] 447 | else: 448 | loss = categorical_loss_map[loss_string] 449 | if loss_string == 'ordinal': 450 | loss = loss(n_classes=len(set(train_generators[t].classes))) 451 | if optimizer.__name__ == tf.keras.optimizers.SGD.__name__: 452 | opt = optimizer(lr=lr, 453 | decay=decays[t] if not individual_weight_decay else 1e-6, 454 | momentum=0.9, 455 | nesterov=False) 456 | else: 457 | opt = optimizer(lr=lr) 458 | model.compile(loss=loss, optimizer=opt, metrics=["acc"]) 459 | # if i % len(train_generators[t]) == 0: 460 | # train_generators[t].on_epoch_end() 461 | logs = model.train_on_batch(*batch[t]) 462 | 463 | named_l = named_logs(model, logs) 464 | # loss = named_l["loss"] 465 | # logger.info(f'Step {i}: loss {loss} ({task})') 466 | logger.debug(f'i % steps_per_epoch: {i%steps_per_epoch}') 467 | if i % steps_per_epoch == 0: 468 | logger.debug('In epoch end') 469 | metric_callbacks[t].on_epoch_end( 470 | i // steps_per_epoch + step * epochs, named_l) 471 | metric_callbacks_test[t].on_epoch_end( 472 | i // steps_per_epoch + step * epochs, named_l) 473 | model.save(weight_paths[t]) 474 | 
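# Persist the task-specific model at the end of every pseudo-epoch (steps_per_epoch train steps) before updating TensorBoard.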
tbCallBacks[t].on_epoch_end( 475 | i // steps_per_epoch + step * epochs, named_l) 476 | shared_model.save(join(experiment_base_path, 'shared_model.h5')) 477 | -------------------------------------------------------------------------------- /emo-net/utils.py: -------------------------------------------------------------------------------- 1 | # EmoNet 2 | # ============================================================================== 3 | # Copyright (C) 2021 Maurice Gerczuk, Shahin Amiriparian, 4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved. 5 | # 6 | # This program is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU General Public License 17 | # along with this program. If not, see . 18 | # ============================================================================== 19 | import numpy as np 20 | 21 | def array_list_equal(a_list, b_list): 22 | if isinstance(a_list, list) and isinstance(b_list, list): 23 | if len(a_list) != len(b_list): 24 | return False 25 | else: 26 | for a, b in zip(a_list, b_list): 27 | if not np.array_equal(a,b): 28 | return False 29 | return True 30 | elif isinstance(a_list, np.ndarray) and isinstance(b_list, np.ndarray): 31 | return np.array_equal(a_list, b_list) 32 | else: 33 | return False -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Click 2 | dill 3 | imbalanced-learn 4 | librosa 5 | numpy 6 | pandas 7 | Pillow 8 | numba==0.48.* 9 | scikit-learn==0.22 10 | tensorflow==2.2.* 11 | tqdm --------------------------------------------------------------------------------
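Usage sketch (not a file from this repository): a minimal example of how the focal loss factory and the metric classes above could be exercised. It assumes the `emo-net` package has been made importable as `emo_net` (the dash in the directory name is not a valid Python module name, so a rename or path tweak is needed), and it calls the private helper `ClassificationMetric._transform_arrays` purely for illustration.

import numpy as np
import tensorflow as tf
from emo_net.training.losses import categorical_focal_loss
from emo_net.training.metrics import UAR, ClassificationMetric

# Compile an arbitrary Keras classifier with the focal loss factory from losses.py.
model = tf.keras.Sequential(
    [tf.keras.layers.Dense(4, activation='softmax', input_shape=(10,))])
model.compile(loss=categorical_focal_loss(gamma=2., alpha=.25),
              optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),
              metrics=['accuracy'])

# Compute unweighted average recall the way ClassificationMetricCallback does:
# one-hot targets and softmax outputs are binarised before scoring.
y_true = np.eye(4)[[0, 1, 2, 3, 1]]            # one-hot ground truth
y_prob = model.predict(np.random.rand(5, 10))  # mock softmax predictions
y_t, y_p = ClassificationMetric._transform_arrays(
    y_true=y_true, y_pred=y_prob, multi_label=False, binary=False)
uar = UAR.compute(y_true=y_t, y_pred=y_p, labels=[0, 1, 2, 3],
                  multi_label=False, binary=False)
print(f'{uar.description}: {uar.value:.3f}')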