├── .gitignore
├── LICENSE.md
├── README.md
├── emo-net
├── __init__.py
├── cli.py
├── data
│ ├── __init__.py
│ ├── compute_scaling.py
│ └── loader.py
├── models
│ ├── __init__.py
│ ├── adapter_resnet.py
│ ├── adapter_rnn.py
│ ├── attention.py
│ ├── build_model.py
│ └── input_layers.py
├── training
│ ├── __init__.py
│ ├── evaluate.py
│ ├── losses.py
│ ├── metrics.py
│ └── train.py
└── utils.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | auDeep.egg-info
3 | build/
4 | __pycache__
5 | .vscode
6 | .env/
7 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | ### GNU GENERAL PUBLIC LICENSE
2 |
3 | Version 3, 29 June 2007
4 |
5 | Copyright (C) 2007 Free Software Foundation, Inc.
6 |
7 |
8 | Everyone is permitted to copy and distribute verbatim copies of this
9 | license document, but changing it is not allowed.
10 |
11 | ### Preamble
12 |
13 | The GNU General Public License is a free, copyleft license for
14 | software and other kinds of works.
15 |
16 | The licenses for most software and other practical works are designed
17 | to take away your freedom to share and change the works. By contrast,
18 | the GNU General Public License is intended to guarantee your freedom
19 | to share and change all versions of a program--to make sure it remains
20 | free software for all its users. We, the Free Software Foundation, use
21 | the GNU General Public License for most of our software; it applies
22 | also to any other work released this way by its authors. You can apply
23 | it to your programs, too.
24 |
25 | When we speak of free software, we are referring to freedom, not
26 | price. Our General Public Licenses are designed to make sure that you
27 | have the freedom to distribute copies of free software (and charge for
28 | them if you wish), that you receive source code or can get it if you
29 | want it, that you can change the software or use pieces of it in new
30 | free programs, and that you know you can do these things.
31 |
32 | To protect your rights, we need to prevent others from denying you
33 | these rights or asking you to surrender the rights. Therefore, you
34 | have certain responsibilities if you distribute copies of the
35 | software, or if you modify it: responsibilities to respect the freedom
36 | of others.
37 |
38 | For example, if you distribute copies of such a program, whether
39 | gratis or for a fee, you must pass on to the recipients the same
40 | freedoms that you received. You must make sure that they, too, receive
41 | or can get the source code. And you must show them these terms so they
42 | know their rights.
43 |
44 | Developers that use the GNU GPL protect your rights with two steps:
45 | (1) assert copyright on the software, and (2) offer you this License
46 | giving you legal permission to copy, distribute and/or modify it.
47 |
48 | For the developers' and authors' protection, the GPL clearly explains
49 | that there is no warranty for this free software. For both users' and
50 | authors' sake, the GPL requires that modified versions be marked as
51 | changed, so that their problems will not be attributed erroneously to
52 | authors of previous versions.
53 |
54 | Some devices are designed to deny users access to install or run
55 | modified versions of the software inside them, although the
56 | manufacturer can do so. This is fundamentally incompatible with the
57 | aim of protecting users' freedom to change the software. The
58 | systematic pattern of such abuse occurs in the area of products for
59 | individuals to use, which is precisely where it is most unacceptable.
60 | Therefore, we have designed this version of the GPL to prohibit the
61 | practice for those products. If such problems arise substantially in
62 | other domains, we stand ready to extend this provision to those
63 | domains in future versions of the GPL, as needed to protect the
64 | freedom of users.
65 |
66 | Finally, every program is threatened constantly by software patents.
67 | States should not allow patents to restrict development and use of
68 | software on general-purpose computers, but in those that do, we wish
69 | to avoid the special danger that patents applied to a free program
70 | could make it effectively proprietary. To prevent this, the GPL
71 | assures that patents cannot be used to render the program non-free.
72 |
73 | The precise terms and conditions for copying, distribution and
74 | modification follow.
75 |
76 | ### TERMS AND CONDITIONS
77 |
78 | #### 0. Definitions.
79 |
80 | "This License" refers to version 3 of the GNU General Public License.
81 |
82 | "Copyright" also means copyright-like laws that apply to other kinds
83 | of works, such as semiconductor masks.
84 |
85 | "The Program" refers to any copyrightable work licensed under this
86 | License. Each licensee is addressed as "you". "Licensees" and
87 | "recipients" may be individuals or organizations.
88 |
89 | To "modify" a work means to copy from or adapt all or part of the work
90 | in a fashion requiring copyright permission, other than the making of
91 | an exact copy. The resulting work is called a "modified version" of
92 | the earlier work or a work "based on" the earlier work.
93 |
94 | A "covered work" means either the unmodified Program or a work based
95 | on the Program.
96 |
97 | To "propagate" a work means to do anything with it that, without
98 | permission, would make you directly or secondarily liable for
99 | infringement under applicable copyright law, except executing it on a
100 | computer or modifying a private copy. Propagation includes copying,
101 | distribution (with or without modification), making available to the
102 | public, and in some countries other activities as well.
103 |
104 | To "convey" a work means any kind of propagation that enables other
105 | parties to make or receive copies. Mere interaction with a user
106 | through a computer network, with no transfer of a copy, is not
107 | conveying.
108 |
109 | An interactive user interface displays "Appropriate Legal Notices" to
110 | the extent that it includes a convenient and prominently visible
111 | feature that (1) displays an appropriate copyright notice, and (2)
112 | tells the user that there is no warranty for the work (except to the
113 | extent that warranties are provided), that licensees may convey the
114 | work under this License, and how to view a copy of this License. If
115 | the interface presents a list of user commands or options, such as a
116 | menu, a prominent item in the list meets this criterion.
117 |
118 | #### 1. Source Code.
119 |
120 | The "source code" for a work means the preferred form of the work for
121 | making modifications to it. "Object code" means any non-source form of
122 | a work.
123 |
124 | A "Standard Interface" means an interface that either is an official
125 | standard defined by a recognized standards body, or, in the case of
126 | interfaces specified for a particular programming language, one that
127 | is widely used among developers working in that language.
128 |
129 | The "System Libraries" of an executable work include anything, other
130 | than the work as a whole, that (a) is included in the normal form of
131 | packaging a Major Component, but which is not part of that Major
132 | Component, and (b) serves only to enable use of the work with that
133 | Major Component, or to implement a Standard Interface for which an
134 | implementation is available to the public in source code form. A
135 | "Major Component", in this context, means a major essential component
136 | (kernel, window system, and so on) of the specific operating system
137 | (if any) on which the executable work runs, or a compiler used to
138 | produce the work, or an object code interpreter used to run it.
139 |
140 | The "Corresponding Source" for a work in object code form means all
141 | the source code needed to generate, install, and (for an executable
142 | work) run the object code and to modify the work, including scripts to
143 | control those activities. However, it does not include the work's
144 | System Libraries, or general-purpose tools or generally available free
145 | programs which are used unmodified in performing those activities but
146 | which are not part of the work. For example, Corresponding Source
147 | includes interface definition files associated with source files for
148 | the work, and the source code for shared libraries and dynamically
149 | linked subprograms that the work is specifically designed to require,
150 | such as by intimate data communication or control flow between those
151 | subprograms and other parts of the work.
152 |
153 | The Corresponding Source need not include anything that users can
154 | regenerate automatically from other parts of the Corresponding Source.
155 |
156 | The Corresponding Source for a work in source code form is that same
157 | work.
158 |
159 | #### 2. Basic Permissions.
160 |
161 | All rights granted under this License are granted for the term of
162 | copyright on the Program, and are irrevocable provided the stated
163 | conditions are met. This License explicitly affirms your unlimited
164 | permission to run the unmodified Program. The output from running a
165 | covered work is covered by this License only if the output, given its
166 | content, constitutes a covered work. This License acknowledges your
167 | rights of fair use or other equivalent, as provided by copyright law.
168 |
169 | You may make, run and propagate covered works that you do not convey,
170 | without conditions so long as your license otherwise remains in force.
171 | You may convey covered works to others for the sole purpose of having
172 | them make modifications exclusively for you, or provide you with
173 | facilities for running those works, provided that you comply with the
174 | terms of this License in conveying all material for which you do not
175 | control copyright. Those thus making or running the covered works for
176 | you must do so exclusively on your behalf, under your direction and
177 | control, on terms that prohibit them from making any copies of your
178 | copyrighted material outside their relationship with you.
179 |
180 | Conveying under any other circumstances is permitted solely under the
181 | conditions stated below. Sublicensing is not allowed; section 10 makes
182 | it unnecessary.
183 |
184 | #### 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
185 |
186 | No covered work shall be deemed part of an effective technological
187 | measure under any applicable law fulfilling obligations under article
188 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
189 | similar laws prohibiting or restricting circumvention of such
190 | measures.
191 |
192 | When you convey a covered work, you waive any legal power to forbid
193 | circumvention of technological measures to the extent such
194 | circumvention is effected by exercising rights under this License with
195 | respect to the covered work, and you disclaim any intention to limit
196 | operation or modification of the work as a means of enforcing, against
197 | the work's users, your or third parties' legal rights to forbid
198 | circumvention of technological measures.
199 |
200 | #### 4. Conveying Verbatim Copies.
201 |
202 | You may convey verbatim copies of the Program's source code as you
203 | receive it, in any medium, provided that you conspicuously and
204 | appropriately publish on each copy an appropriate copyright notice;
205 | keep intact all notices stating that this License and any
206 | non-permissive terms added in accord with section 7 apply to the code;
207 | keep intact all notices of the absence of any warranty; and give all
208 | recipients a copy of this License along with the Program.
209 |
210 | You may charge any price or no price for each copy that you convey,
211 | and you may offer support or warranty protection for a fee.
212 |
213 | #### 5. Conveying Modified Source Versions.
214 |
215 | You may convey a work based on the Program, or the modifications to
216 | produce it from the Program, in the form of source code under the
217 | terms of section 4, provided that you also meet all of these
218 | conditions:
219 |
220 | - a) The work must carry prominent notices stating that you modified
221 | it, and giving a relevant date.
222 | - b) The work must carry prominent notices stating that it is
223 | released under this License and any conditions added under
224 | section 7. This requirement modifies the requirement in section 4
225 | to "keep intact all notices".
226 | - c) You must license the entire work, as a whole, under this
227 | License to anyone who comes into possession of a copy. This
228 | License will therefore apply, along with any applicable section 7
229 | additional terms, to the whole of the work, and all its parts,
230 | regardless of how they are packaged. This License gives no
231 | permission to license the work in any other way, but it does not
232 | invalidate such permission if you have separately received it.
233 | - d) If the work has interactive user interfaces, each must display
234 | Appropriate Legal Notices; however, if the Program has interactive
235 | interfaces that do not display Appropriate Legal Notices, your
236 | work need not make them do so.
237 |
238 | A compilation of a covered work with other separate and independent
239 | works, which are not by their nature extensions of the covered work,
240 | and which are not combined with it such as to form a larger program,
241 | in or on a volume of a storage or distribution medium, is called an
242 | "aggregate" if the compilation and its resulting copyright are not
243 | used to limit the access or legal rights of the compilation's users
244 | beyond what the individual works permit. Inclusion of a covered work
245 | in an aggregate does not cause this License to apply to the other
246 | parts of the aggregate.
247 |
248 | #### 6. Conveying Non-Source Forms.
249 |
250 | You may convey a covered work in object code form under the terms of
251 | sections 4 and 5, provided that you also convey the machine-readable
252 | Corresponding Source under the terms of this License, in one of these
253 | ways:
254 |
255 | - a) Convey the object code in, or embodied in, a physical product
256 | (including a physical distribution medium), accompanied by the
257 | Corresponding Source fixed on a durable physical medium
258 | customarily used for software interchange.
259 | - b) Convey the object code in, or embodied in, a physical product
260 | (including a physical distribution medium), accompanied by a
261 | written offer, valid for at least three years and valid for as
262 | long as you offer spare parts or customer support for that product
263 | model, to give anyone who possesses the object code either (1) a
264 | copy of the Corresponding Source for all the software in the
265 | product that is covered by this License, on a durable physical
266 | medium customarily used for software interchange, for a price no
267 | more than your reasonable cost of physically performing this
268 | conveying of source, or (2) access to copy the Corresponding
269 | Source from a network server at no charge.
270 | - c) Convey individual copies of the object code with a copy of the
271 | written offer to provide the Corresponding Source. This
272 | alternative is allowed only occasionally and noncommercially, and
273 | only if you received the object code with such an offer, in accord
274 | with subsection 6b.
275 | - d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 | - e) Convey the object code using peer-to-peer transmission,
288 | provided you inform other peers where the object code and
289 | Corresponding Source of the work are being offered to the general
290 | public at no charge under subsection 6d.
291 |
292 | A separable portion of the object code, whose source code is excluded
293 | from the Corresponding Source as a System Library, need not be
294 | included in conveying the object code work.
295 |
296 | A "User Product" is either (1) a "consumer product", which means any
297 | tangible personal property which is normally used for personal,
298 | family, or household purposes, or (2) anything designed or sold for
299 | incorporation into a dwelling. In determining whether a product is a
300 | consumer product, doubtful cases shall be resolved in favor of
301 | coverage. For a particular product received by a particular user,
302 | "normally used" refers to a typical or common use of that class of
303 | product, regardless of the status of the particular user or of the way
304 | in which the particular user actually uses, or expects or is expected
305 | to use, the product. A product is a consumer product regardless of
306 | whether the product has substantial commercial, industrial or
307 | non-consumer uses, unless such uses represent the only significant
308 | mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to
312 | install and execute modified versions of a covered work in that User
313 | Product from a modified version of its Corresponding Source. The
314 | information must suffice to ensure that the continued functioning of
315 | the modified object code is in no case prevented or interfered with
316 | solely because modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or
331 | updates for a work that has been modified or installed by the
332 | recipient, or for the User Product in which it has been modified or
333 | installed. Access to a network may be denied when the modification
334 | itself materially and adversely affects the operation of the network
335 | or violates the rules and protocols for communication across the
336 | network.
337 |
338 | Corresponding Source conveyed, and Installation Information provided,
339 | in accord with this section must be in a format that is publicly
340 | documented (and with an implementation available to the public in
341 | source code form), and must require no special password or key for
342 | unpacking, reading or copying.
343 |
344 | #### 7. Additional Terms.
345 |
346 | "Additional permissions" are terms that supplement the terms of this
347 | License by making exceptions from one or more of its conditions.
348 | Additional permissions that are applicable to the entire Program shall
349 | be treated as though they were included in this License, to the extent
350 | that they are valid under applicable law. If additional permissions
351 | apply only to part of the Program, that part may be used separately
352 | under those permissions, but the entire Program remains governed by
353 | this License without regard to the additional permissions.
354 |
355 | When you convey a copy of a covered work, you may at your option
356 | remove any additional permissions from that copy, or from any part of
357 | it. (Additional permissions may be written to require their own
358 | removal in certain cases when you modify the work.) You may place
359 | additional permissions on material, added by you to a covered work,
360 | for which you have or can give appropriate copyright permission.
361 |
362 | Notwithstanding any other provision of this License, for material you
363 | add to a covered work, you may (if authorized by the copyright holders
364 | of that material) supplement the terms of this License with terms:
365 |
366 | - a) Disclaiming warranty or limiting liability differently from the
367 | terms of sections 15 and 16 of this License; or
368 | - b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 | - c) Prohibiting misrepresentation of the origin of that material,
372 | or requiring that modified versions of such material be marked in
373 | reasonable ways as different from the original version; or
374 | - d) Limiting the use for publicity purposes of names of licensors
375 | or authors of the material; or
376 | - e) Declining to grant rights under trademark law for use of some
377 | trade names, trademarks, or service marks; or
378 | - f) Requiring indemnification of licensors and authors of that
379 | material by anyone who conveys the material (or modified versions
380 | of it) with contractual assumptions of liability to the recipient,
381 | for any liability that these contractual assumptions directly
382 | impose on those licensors and authors.
383 |
384 | All other non-permissive additional terms are considered "further
385 | restrictions" within the meaning of section 10. If the Program as you
386 | received it, or any part of it, contains a notice stating that it is
387 | governed by this License along with a term that is a further
388 | restriction, you may remove that term. If a license document contains
389 | a further restriction but permits relicensing or conveying under this
390 | License, you may add to a covered work material governed by the terms
391 | of that license document, provided that the further restriction does
392 | not survive such relicensing or conveying.
393 |
394 | If you add terms to a covered work in accord with this section, you
395 | must place, in the relevant source files, a statement of the
396 | additional terms that apply to those files, or a notice indicating
397 | where to find the applicable terms.
398 |
399 | Additional terms, permissive or non-permissive, may be stated in the
400 | form of a separately written license, or stated as exceptions; the
401 | above requirements apply either way.
402 |
403 | #### 8. Termination.
404 |
405 | You may not propagate or modify a covered work except as expressly
406 | provided under this License. Any attempt otherwise to propagate or
407 | modify it is void, and will automatically terminate your rights under
408 | this License (including any patent licenses granted under the third
409 | paragraph of section 11).
410 |
411 | However, if you cease all violation of this License, then your license
412 | from a particular copyright holder is reinstated (a) provisionally,
413 | unless and until the copyright holder explicitly and finally
414 | terminates your license, and (b) permanently, if the copyright holder
415 | fails to notify you of the violation by some reasonable means prior to
416 | 60 days after the cessation.
417 |
418 | Moreover, your license from a particular copyright holder is
419 | reinstated permanently if the copyright holder notifies you of the
420 | violation by some reasonable means, this is the first time you have
421 | received notice of violation of this License (for any work) from that
422 | copyright holder, and you cure the violation prior to 30 days after
423 | your receipt of the notice.
424 |
425 | Termination of your rights under this section does not terminate the
426 | licenses of parties who have received copies or rights from you under
427 | this License. If your rights have been terminated and not permanently
428 | reinstated, you do not qualify to receive new licenses for the same
429 | material under section 10.
430 |
431 | #### 9. Acceptance Not Required for Having Copies.
432 |
433 | You are not required to accept this License in order to receive or run
434 | a copy of the Program. Ancillary propagation of a covered work
435 | occurring solely as a consequence of using peer-to-peer transmission
436 | to receive a copy likewise does not require acceptance. However,
437 | nothing other than this License grants you permission to propagate or
438 | modify any covered work. These actions infringe copyright if you do
439 | not accept this License. Therefore, by modifying or propagating a
440 | covered work, you indicate your acceptance of this License to do so.
441 |
442 | #### 10. Automatic Licensing of Downstream Recipients.
443 |
444 | Each time you convey a covered work, the recipient automatically
445 | receives a license from the original licensors, to run, modify and
446 | propagate that work, subject to this License. You are not responsible
447 | for enforcing compliance by third parties with this License.
448 |
449 | An "entity transaction" is a transaction transferring control of an
450 | organization, or substantially all assets of one, or subdividing an
451 | organization, or merging organizations. If propagation of a covered
452 | work results from an entity transaction, each party to that
453 | transaction who receives a copy of the work also receives whatever
454 | licenses to the work the party's predecessor in interest had or could
455 | give under the previous paragraph, plus a right to possession of the
456 | Corresponding Source of the work from the predecessor in interest, if
457 | the predecessor has it or can get it with reasonable efforts.
458 |
459 | You may not impose any further restrictions on the exercise of the
460 | rights granted or affirmed under this License. For example, you may
461 | not impose a license fee, royalty, or other charge for exercise of
462 | rights granted under this License, and you may not initiate litigation
463 | (including a cross-claim or counterclaim in a lawsuit) alleging that
464 | any patent claim is infringed by making, using, selling, offering for
465 | sale, or importing the Program or any portion of it.
466 |
467 | #### 11. Patents.
468 |
469 | A "contributor" is a copyright holder who authorizes use under this
470 | License of the Program or a work on which the Program is based. The
471 | work thus licensed is called the contributor's "contributor version".
472 |
473 | A contributor's "essential patent claims" are all patent claims owned
474 | or controlled by the contributor, whether already acquired or
475 | hereafter acquired, that would be infringed by some manner, permitted
476 | by this License, of making, using, or selling its contributor version,
477 | but do not include claims that would be infringed only as a
478 | consequence of further modification of the contributor version. For
479 | purposes of this definition, "control" includes the right to grant
480 | patent sublicenses in a manner consistent with the requirements of
481 | this License.
482 |
483 | Each contributor grants you a non-exclusive, worldwide, royalty-free
484 | patent license under the contributor's essential patent claims, to
485 | make, use, sell, offer for sale, import and otherwise run, modify and
486 | propagate the contents of its contributor version.
487 |
488 | In the following three paragraphs, a "patent license" is any express
489 | agreement or commitment, however denominated, not to enforce a patent
490 | (such as an express permission to practice a patent or covenant not to
491 | sue for patent infringement). To "grant" such a patent license to a
492 | party means to make such an agreement or commitment not to enforce a
493 | patent against the party.
494 |
495 | If you convey a covered work, knowingly relying on a patent license,
496 | and the Corresponding Source of the work is not available for anyone
497 | to copy, free of charge and under the terms of this License, through a
498 | publicly available network server or other readily accessible means,
499 | then you must either (1) cause the Corresponding Source to be so
500 | available, or (2) arrange to deprive yourself of the benefit of the
501 | patent license for this particular work, or (3) arrange, in a manner
502 | consistent with the requirements of this License, to extend the patent
503 | license to downstream recipients. "Knowingly relying" means you have
504 | actual knowledge that, but for the patent license, your conveying the
505 | covered work in a country, or your recipient's use of the covered work
506 | in a country, would infringe one or more identifiable patents in that
507 | country that you have reason to believe are valid.
508 |
509 | If, pursuant to or in connection with a single transaction or
510 | arrangement, you convey, or propagate by procuring conveyance of, a
511 | covered work, and grant a patent license to some of the parties
512 | receiving the covered work authorizing them to use, propagate, modify
513 | or convey a specific copy of the covered work, then the patent license
514 | you grant is automatically extended to all recipients of the covered
515 | work and works based on it.
516 |
517 | A patent license is "discriminatory" if it does not include within the
518 | scope of its coverage, prohibits the exercise of, or is conditioned on
519 | the non-exercise of one or more of the rights that are specifically
520 | granted under this License. You may not convey a covered work if you
521 | are a party to an arrangement with a third party that is in the
522 | business of distributing software, under which you make payment to the
523 | third party based on the extent of your activity of conveying the
524 | work, and under which the third party grants, to any of the parties
525 | who would receive the covered work from you, a discriminatory patent
526 | license (a) in connection with copies of the covered work conveyed by
527 | you (or copies made from those copies), or (b) primarily for and in
528 | connection with specific products or compilations that contain the
529 | covered work, unless you entered into that arrangement, or that patent
530 | license was granted, prior to 28 March 2007.
531 |
532 | Nothing in this License shall be construed as excluding or limiting
533 | any implied license or other defenses to infringement that may
534 | otherwise be available to you under applicable patent law.
535 |
536 | #### 12. No Surrender of Others' Freedom.
537 |
538 | If conditions are imposed on you (whether by court order, agreement or
539 | otherwise) that contradict the conditions of this License, they do not
540 | excuse you from the conditions of this License. If you cannot convey a
541 | covered work so as to satisfy simultaneously your obligations under
542 | this License and any other pertinent obligations, then as a
543 | consequence you may not convey it at all. For example, if you agree to
544 | terms that obligate you to collect a royalty for further conveying
545 | from those to whom you convey the Program, the only way you could
546 | satisfy both those terms and this License would be to refrain entirely
547 | from conveying the Program.
548 |
549 | #### 13. Use with the GNU Affero General Public License.
550 |
551 | Notwithstanding any other provision of this License, you have
552 | permission to link or combine any covered work with a work licensed
553 | under version 3 of the GNU Affero General Public License into a single
554 | combined work, and to convey the resulting work. The terms of this
555 | License will continue to apply to the part which is the covered work,
556 | but the special requirements of the GNU Affero General Public License,
557 | section 13, concerning interaction through a network will apply to the
558 | combination as such.
559 |
560 | #### 14. Revised Versions of this License.
561 |
562 | The Free Software Foundation may publish revised and/or new versions
563 | of the GNU General Public License from time to time. Such new versions
564 | will be similar in spirit to the present version, but may differ in
565 | detail to address new problems or concerns.
566 |
567 | Each version is given a distinguishing version number. If the Program
568 | specifies that a certain numbered version of the GNU General Public
569 | License "or any later version" applies to it, you have the option of
570 | following the terms and conditions either of that numbered version or
571 | of any later version published by the Free Software Foundation. If the
572 | Program does not specify a version number of the GNU General Public
573 | License, you may choose any version ever published by the Free
574 | Software Foundation.
575 |
576 | If the Program specifies that a proxy can decide which future versions
577 | of the GNU General Public License can be used, that proxy's public
578 | statement of acceptance of a version permanently authorizes you to
579 | choose that version for the Program.
580 |
581 | Later license versions may give you additional or different
582 | permissions. However, no additional obligations are imposed on any
583 | author or copyright holder as a result of your choosing to follow a
584 | later version.
585 |
586 | #### 15. Disclaimer of Warranty.
587 |
588 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
589 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
590 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT
591 | WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT
592 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
593 | A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND
594 | PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE PROGRAM PROVE
595 | DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR
596 | CORRECTION.
597 |
598 | #### 16. Limitation of Liability.
599 |
600 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
601 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR
602 | CONVEYS THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES,
603 | INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES
604 | ARISING OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT
605 | NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR
606 | LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM
607 | TO OPERATE WITH ANY OTHER PROGRAMS), EVEN IF SUCH HOLDER OR OTHER
608 | PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
609 |
610 | #### 17. Interpretation of Sections 15 and 16.
611 |
612 | If the disclaimer of warranty and limitation of liability provided
613 | above cannot be given local legal effect according to their terms,
614 | reviewing courts shall apply local law that most closely approximates
615 | an absolute waiver of all civil liability in connection with the
616 | Program, unless a warranty or assumption of liability accompanies a
617 | copy of the Program in return for a fee.
618 |
619 | END OF TERMS AND CONDITIONS
620 |
621 | ### How to Apply These Terms to Your New Programs
622 |
623 | If you develop a new program, and you want it to be of the greatest
624 | possible use to the public, the best way to achieve this is to make it
625 | free software which everyone can redistribute and change under these
626 | terms.
627 |
628 | To do so, attach the following notices to the program. It is safest to
629 | attach them to the start of each source file to most effectively state
630 | the exclusion of warranty; and each file should have at least the
631 | "copyright" line and a pointer to where the full notice is found.
632 |
633 |
634 | Copyright (C)
635 |
636 | This program is free software: you can redistribute it and/or modify
637 | it under the terms of the GNU General Public License as published by
638 | the Free Software Foundation, either version 3 of the License, or
639 | (at your option) any later version.
640 |
641 | This program is distributed in the hope that it will be useful,
642 | but WITHOUT ANY WARRANTY; without even the implied warranty of
643 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
644 | GNU General Public License for more details.
645 |
646 | You should have received a copy of the GNU General Public License
647 | along with this program. If not, see .
648 |
649 | Also add information on how to contact you by electronic and paper
650 | mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | Copyright (C)
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands \`show w' and \`show c' should show the
661 | appropriate parts of the General Public License. Of course, your
662 | program's commands might be different; for a GUI interface, you would
663 | use an "about box".
664 |
665 | You should also get your employer (if you work as a programmer) or
666 | school, if any, to sign a "copyright disclaimer" for the program, if
667 | necessary. For more information on this, and how to apply and follow
668 | the GNU GPL, see .
669 |
670 | The GNU General Public License does not permit incorporating your
671 | program into proprietary programs. If your program is a subroutine
672 | library, you may consider it more useful to permit linking proprietary
673 | applications with the library. If this is what you want to do, use the
674 | GNU Lesser General Public License instead of this License. But first,
675 | please read .
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # EmoNet
2 |
3 |
4 | **EmoNet** is a Python toolkit for multi-corpus speech emotion recognition and other audio classification tasks.
5 |
6 | **(c) 2021 Maurice Gerczuk, Shahin Amiriparian, Björn Schuller: Universität Augsburg**
7 |
8 | Please direct any questions or requests to Maurice Gerczuk (maurice.gerczuk at uni-a.de) or Shahin Amiriparian (shahin.amiriparian at uni-a.de).
9 |
10 | # Citing
11 | If you use EmoNet or any code from EmoNet in your research work, you are kindly asked to acknowledge the use of EmoNet in your publications.
12 | > M. Gerczuk, S. Amiriparian, S. Ottl, and B. Schuller, “EmoNet: A transfer learning framework for multi-corpus speech emotionrecognition,” 2021. [https://arxiv.org/abs/2103.08310](https://arxiv.org/abs/2103.08310)
13 |
14 |
15 | ```
16 | @misc{gerczuk2021emonet,
17 | title={EmoNet: A Transfer Learning Framework for Multi-Corpus Speech Emotion Recognition},
18 | author={Maurice Gerczuk and Shahin Amiriparian and Sandra Ottl and Björn Schuller},
19 | year={2021},
20 | eprint={2103.08310},
21 | archivePrefix={arXiv},
22 | primaryClass={cs.SD}
23 | }
24 | ```
25 |
26 |
27 | ## Installation
28 |
29 | All dependencies can be installed via pip from the requirements.txt:
30 |
31 | ```bash
32 | pip install -r requirements.txt
33 | ```
34 |
35 | It is advisable to do this from within a newly created virtual environment.
36 |
37 | ## Usage
38 |
39 | The basic commandline is accessible from the repository's basedirectory by calling:
40 |
41 | ```bash
42 | python -m emo-net.cli --help
43 | ```
44 |
45 | This prints a help message specifying the list of subcommands. For each subcommand, more help is available via:
46 |
47 | ```bash
48 | python -m emo-net.cli [subcommand] --help
49 | ```
50 |
51 | ### Data Preparation
52 |
53 | The toolkit can be used for arbitrary audio classification tasks. To prepare your dataset, resample all audio content to 16kHz wav files (e.g. with ffmpeg). Afterwards, you need label files in .csv format that specify the categorical target for each sample in the training, development and test partitions, i.e., three files "train.csv", "devel.csv" and "test.csv". The files must include the path to each audio file in the first column - relative to a common basedirectory - and a categorical label in the second column. A header line "file,label" should be included.
54 |
55 | ### Command line options
56 |
57 | The CLI has a nested structure, i.e., it uses two layers of subcommands. The first subcommand specifies the type of neural network architecture that is used. Here, "cnn" gives access to the ResNet architecture which also includes residual adapters, based on the training setting. Two other options, "rnn" and "fusion" are also included but untested and in early stages of development. The rest of this guide will therefore focus on the "cnn" subcommand. After specifying the model type, two distinct subcommands are accessible: "single-task" and "multi-task", which refer to the type of training procedure. For single task, training is performed on one database at a time specified by its basedirectory and the label files for train, validation and developments:
58 |
59 | ```bash
60 | python -m emo-net.cli -v cnn single-task -t [taskName] --data-path /path/to/task/wavs -tr train.csv -v devel.csv -te test.csv
61 | ```
62 |
63 | One additional parameter is needed that defines the type of training performed. Here, the choice can be made between tuning a fresh model from scratch (`-m scratch`), fully finetuning an existing model (`-m finetune`), training only the classifier head (`-m last-layer`) and the residual adapter approach (`-m adapters`). For the last three methods, a pre-trained model has to be loaded by specifying the path to its weights via `-im /path/to/weights.h5`. While all other parameters have sensible default values, the full list is given below:
64 |
65 | | Option | Type | Description |
66 | | ------------------------------- | ----------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
67 | | -dp, --data-path | DIRECTORY | Directory of data files. [required] |
68 | | -t, --task | TEXT | Name of the task that is trained. [required] |
69 | | -tr, --train-csv | FILE | Path to training csv file. [required] |
70 | | -v, --val-csv | FILE | Path to validation csv file. [required] |
71 | | -te, --test-csv | FILE | Path to test csv file. [required] |
72 | | -bs, --batch-size | INTEGER | Define batch size. |
73 | | -nm, --num-mels | INTEGER | Number of mel bands in spectrogram. |
74 | | -e, --epochs | INTEGER | Define max number of training epochs. |
75 | | -p, --patience | INTEGER | Define patience before early stopping / reducing learning rate in epochs. |
76 | | -im, --initial-model | FILE | Initial model for resuming training. |
77 | | -bw, --balanced-weights | FLAG | Automatically set balanced class weights. |
78 | | -lr, --learning-rate | FLOAT | Initial earning rate for optimizer. |
79 | | -do, --dropout | FLOAT | Dropout for the two positions (after first and second convolution of each block). |
80 | | -ebp, --experiment-base-path | PATH | Basepath where logs and checkpoints should be stored. |
81 | | -o, --optimizer | [sgd\|rmsprop\|adam\|adadelta] | Optimizer used for training. |
82 | | -N, --number-of-resnet-blocks | INTEGER | Number of convolutional blocks in the ResNet layers. |
83 | | -nf, --number-of-filters | INTEGER | Number of filters in first convolutional block. |
84 | | -wf, --widen-factor | INTEGER | Widen factor of wide ResNet |
85 | | -c, --classifier | [avgpool\|FCNAttention] | The classification top of the network architeture. Choose between simple pooling + dense layer (needs fixed window size) and fully convolutional attention. |
86 | | -w, --window | FLOAT | Window size in seconds. |
87 | | -l, --loss | [crossentropy\|focal\|ordinal] | Classification loss. Ordinal loss ues sorted class labels. |
88 | | -m, --mode | [scratch\|adapters\|last-layer\|finetune] | Type of training to be performed. |
89 | | -sfl, --share-feature-layer | FLAG | Share the feature layer (weighted attention of deep features) between tasks. |
90 | | -iwd, --individual-weight-decay | FLAG | Set weight decay in adapters according to size of training dataset. Smaller datasets will have larger weight decay to keep closer to the pre-trained network. |
91 | | --help | FLAG | Show this message and exit. |
92 |
93 | The "multi-task" command line slightly differs from the one described above. The most notable difference is in how the data is passed. Instead of passing individual .csv files for each partition, a directory - "--multi-task-setup" - which contains a folder with "train.csv", "val.csv" and "test.csv" files for each database has to be specified. Additionally, "-t" now is used to specify a list of databases (subfolders of the multi task setup) that should be used for training. As multi-domain training is done in a round-robin fashion, there is no predefined notion of a training epoch. Therefore, an additional option ("--steps-per-epoch") is used to define the size of an artificial training epoch. These additional parameters are also given in the table below.
94 |
95 | | Option | Type | Description |
96 | | ------------------------ | --------- | ----------------------------------------------------------------------------------------------------------------- |
97 | | -dp, --data-path | DIRECTORY | Directory of wav files. [required] |
98 | | -mts, --multi-task-setup | DIRECTORY | Directory with the setup csvs ("train.csv", "val.csv", "test.csv") for each task in a separate folder. [required] |
99 | | -t, --tasks | TEXT | Names of the tasks that are trained. [required] |
100 | | -spe, --steps-per-epoch | INTEGER | Number of training steps for each artificial epoch. |
101 |
--------------------------------------------------------------------------------
/emo-net/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EIHW/EmoNet/e76dd53ab4c33e99182c69f6247e4d924f9a3d99/emo-net/__init__.py
--------------------------------------------------------------------------------
/emo-net/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EIHW/EmoNet/e76dd53ab4c33e99182c69f6247e4d924f9a3d99/emo-net/data/__init__.py
--------------------------------------------------------------------------------
/emo-net/data/compute_scaling.py:
--------------------------------------------------------------------------------
1 | # EmoNet
2 | # ==============================================================================
3 | # Copyright (C) 2021 Maurice Gerczuk, Shahin Amiriparian,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see .
18 | # ==============================================================================
19 | import tensorflow as tf
20 | import numpy as np
21 | import pickle
22 | import glob
23 | from tqdm import tqdm
24 | from ..models.input_layers import LogMelgramLayer
25 | from ..data.loader import AudioDataGenerator
26 | from os.path import join
27 | from sklearn.preprocessing import StandardScaler
28 |
29 |
30 |
31 | def compute_scaling(dataset_base):
32 | train_generator = AudioDataGenerator(join(dataset_base, 'train.csv'),
33 | '/mnt/nas/data_work/shahin/EmoSet/wavs-reordered/',
34 | batch_size=1,
35 | window=None,
36 | shuffle=False,
37 | sr=16000,
38 | time_stretch=None,
39 | pitch_shift=None,
40 | save_dir=None,
41 | val_split=None,
42 | subset='train',
43 | variable_duration=True)
44 | train_dataset = train_generator.tf_dataset().prefetch(tf.data.experimental.AUTOTUNE)
45 |
46 | input_tensor = tf.keras.layers.Input(shape=(None,))
47 | input_reshaped = tf.keras.layers.Reshape(
48 | target_shape=(-1, ))(input_tensor)
49 |
50 | x = LogMelgramLayer(num_fft=512,
51 | hop_length=256,
52 | sample_rate=16000,
53 | f_min=80,
54 | f_max=8000,
55 | num_mels=64,
56 | eps=1e-6,
57 | return_decibel=True,
58 | name='trainable_stft')(input_reshaped)
59 | model = tf.keras.Model(inputs=input_tensor, outputs=x)
60 | spectrograms = []
61 | for batch in tqdm(train_dataset):
62 | spectrograms.append(np.squeeze(model.predict_on_batch(batch)))
63 | spectrograms_concat = np.concatenate(spectrograms)
64 | mean = np.mean(spectrograms_concat)
65 | std = np.std(spectrograms_concat)
66 | mean_std = {'mean': mean, 'std': std}
67 | print(dataset_base, mean, std)
68 | with open(join(dataset_base, 'mean_std.pkl'), 'wb') as f:
69 | pickle.dump(mean_std, f)
70 |
71 |
72 | if __name__=='__main__':
73 | datasets = glob.glob('/mnt/student/MauriceGerczuk/EmoSet/multiTaskSetup-wavs-with-test/*/')
74 | for dataset in datasets:
75 | compute_scaling(dataset)
--------------------------------------------------------------------------------
/emo-net/data/loader.py:
--------------------------------------------------------------------------------
1 | # EmoNet
2 | # ==============================================================================
3 | # Copyright (C) 2021 Maurice Gerczuk, Shahin Amiriparian,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see .
18 | # ==============================================================================
19 | import numpy as np
20 | import time
21 | import pandas as pd
22 | import numpy as np
23 | import itertools
24 | import csv
25 | import librosa
26 | from tensorflow.keras.preprocessing.image import ImageDataGenerator
27 | from tensorflow.keras import applications
28 | from tensorflow.keras.utils import Sequence
29 | from tensorflow.keras.utils import to_categorical
30 | from sklearn.utils import class_weight
31 | from sklearn.preprocessing import LabelEncoder
32 | #from vgg16bn import Vgg16BN
33 | from tensorflow.keras.preprocessing.image import ImageDataGenerator
34 | from tensorflow.keras.preprocessing.sequence import pad_sequences
35 | from sklearn.model_selection import StratifiedShuffleSplit
36 | from os.path import join, dirname, basename, relpath
37 | from os import makedirs
38 | from math import ceil
39 | from abc import ABC, abstractmethod
40 | from glob import glob
41 | from PIL import Image
42 | import tensorflow as tf
43 | import logging
44 | logger = logging.getLogger(__name__)
45 |
46 | class AudioDataGenerator(Sequence):
47 | def __init__(self,
48 | csv_file,
49 | directory,
50 | batch_size=32,
51 | window=1,
52 | shuffle=True,
53 | sr=16000,
54 | time_stretch=None,
55 | pitch_shift=None,
56 | save_dir=None,
57 | val_split=0.2,
58 | val_indices=None,
59 | subset='train',
60 | variable_duration=False):
61 | self.random_state = 42
62 | self.variable_duration = variable_duration
63 | self.files = []
64 | self.classes = []
65 | with open(csv_file) as f:
66 | reader = csv.reader(f, delimiter=',')
67 | header = next(reader)
68 | if 'label' in header:
69 | label_index = header.index('label')
70 | logger.info(f'Setup csv "{csv_file}" contains "label" column at index {label_index}.')
71 |
72 | else:
73 | label_index = len(header) - 1
74 | logger.warn(f'Setup csv "{csv_file}" does not contain "label" column. Choosing last column: "{header[label_index]}" instead.')
75 | if 'file' in header:
76 | path_index = header.index('file')
77 | logger.info(f'Setup csv "{csv_file}" contains "file" column at index {path_index}.')
78 |
79 | else:
80 | path_index = 0
81 | logger.warn(f'Setup csv "{csv_file}" does not contain "file" column. Choosing first column: "{header[path_index]}" instead.')
82 | for line in reader:
83 | self.files.append(
84 | join(directory, line[path_index]))
85 | self.classes.append(line[label_index])
86 |
87 | logger.info(f'Parsed {len(self.files)} audio files')
88 | self.val_split = val_split
89 | self.train_indices = None
90 | self.val_indices = val_indices
91 | self.subset = subset
92 |
93 |
94 |
95 | self.label_binarizer = LabelEncoder()
96 | self.label_binarizer.fit(self.classes)
97 |
98 | if self.val_split is not None and subset == 'train':
99 | self.__create_split()
100 | elif not (self.val_indices is None):
101 | self.__apply_split()
102 |
103 | self.directory = directory
104 | self.window = window
105 | self.classes = self.label_binarizer.transform(self.classes)
106 | if len(self.label_binarizer.classes_) > 2:
107 | self.categorical_classes = to_categorical(self.classes)
108 | else:
109 | self.categorical_classes = self.classes
110 | self.class_indices = {c: i for i, c in enumerate(self.label_binarizer.classes_) }
111 | logger.info(f'Class indices: {self.class_indices}')
112 | self.batch_size = batch_size
113 | self.shuffle = shuffle
114 | self.time_stretch = time_stretch
115 | self.pitch_shift = pitch_shift
116 | self.save_dir = save_dir
117 | self.sr = sr
118 | np.random.seed(self.random_state)
119 | self.on_epoch_end()
120 |
121 |
122 | @staticmethod
123 | def load_audio(filename, label):
124 | raw = tf.io.read_file(filename)
125 | audio, sr = tf.audio.decode_wav(raw, desired_channels=1)
126 | audio = tf.reshape(audio, (-1,))
127 | return audio, label
128 |
129 |
130 | @staticmethod
131 | def random_slice(audio, label, size):
132 | size = tf.math.minimum(tf.shape(audio), size)
133 | audio = tf.image.random_crop(audio, size, seed=42)
134 | return audio, label
135 |
136 | @staticmethod
137 | def center_slice(audio, label, size):
138 | duration = tf.shape(audio)[0]
139 | start = duration // 2 if duration // 2 > size else 0
140 | audio = audio[start:start+size]
141 | return audio, label
142 |
143 |
144 | def tf_dataset(self):
145 | dataset = tf.data.Dataset.from_tensor_slices((self.files, self.categorical_classes))
146 | binary = len(self.categorical_classes.shape) < 2
147 | if self.shuffle:
148 | dataset = dataset.shuffle(len(self.files), seed=42)
149 | dataset = dataset.map(AudioDataGenerator.load_audio, num_parallel_calls=tf.data.experimental.AUTOTUNE)
150 | #dataset = dataset.filter(lambda x, _: tf.math.count_nonzero(x) > 0)
151 | if self.window is not None:
152 | window_size = int(self.window*self.sr)
153 | padded_size = window_size if not self.variable_duration else None
154 | if self.shuffle:
155 | dataset = dataset.map(lambda audio, label: AudioDataGenerator.random_slice(audio, label, size=window_size), num_parallel_calls=tf.data.experimental.AUTOTUNE)
156 | else:
157 | dataset = dataset.map(lambda audio, label: AudioDataGenerator.center_slice(audio, label, size=window_size), num_parallel_calls=tf.data.experimental.AUTOTUNE)
158 | else :
159 | padded_size = None
160 | padded_label_size = () if binary else (self.categorical_classes.shape[1],)
161 | dataset = dataset.padded_batch(self.batch_size, ((padded_size,), padded_label_size))
162 | #dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE).cache()
163 | return dataset
164 |
165 |
166 |
167 |
168 |
169 |
170 | def __create_split(self):
171 | sss = StratifiedShuffleSplit(n_splits=1, test_size=self.val_split, random_state=self.random_state)
172 | for train_index, test_index in sss.split(self.files, self.classes):
173 | self.val_indices = test_index
174 | self.train_indices = train_index
175 |
176 |
177 | def __apply_split(self):
178 | indices = self.train_indices if self.subset == 'train' else self.val_indices
179 | for index in sorted(indices, reverse=True):
180 | del self.files[index]
181 | del self.classes[index]
182 |
183 |
184 | def __len__(self):
185 | return ceil(len(self.files) / self.batch_size)
186 | """ if len(self.files) % self.batch_size == 0:
187 | return int(len(self.files) / self.batch_size)
188 | else:
189 | return int(len(self.files) / self.batch_size) + 1 """
190 |
191 | def __getitem__(self, index):
192 | # Generate indexes of the batch
193 | index = index % len(self)
194 | indices = self.indices[index * self.batch_size:min(len(self.indices), (index + 1) *
195 | self.batch_size)]
196 |
197 | files_batch = [self.files[k] for k in indices]
198 | y = np.asarray([self.categorical_classes[k] for k in indices])
199 |
200 | # Generate data
201 | x = self.__data_generation(files_batch)
202 |
203 | return x, y
204 |
205 | def _set_index_array(self):
206 | self.indices = np.arange(len(self.files))
207 | if self.shuffle:
208 | np.random.shuffle(self.indices)
209 |
210 | def on_epoch_end(self):
211 | 'Updates indexes after each epoch'
212 | self._set_index_array()
213 |
214 | def __data_generation(self, files):
215 | audio_data = []
216 |
217 | for file in files:
218 | duration = librosa.core.get_duration(filename=file)
219 |
220 | if self.window is not None:
221 | stretched_window = self.window * (
222 | 1 + self.time_stretch
223 | ) if self.time_stretch is not None else self.window
224 | if self.shuffle:
225 | start = np.random.randint(0, max(1, int(duration - stretched_window)))
226 |
227 | else:
228 | start = duration / 2 if duration / 2 > stretched_window else 0 # take the middle chunk
229 | y, sr = librosa.core.load(file,
230 | offset=start,
231 | duration=min(stretched_window, duration),
232 | sr=self.sr)
233 | y = self.__get_random_transform(y, sr)
234 | end_sample = min(int(self.window * sr), int(duration * sr))
235 | y = y[:end_sample]
236 | else:
237 | y, sr = librosa.core.load(file, sr=self.sr)
238 | y = self.__get_random_transform(y, sr)
239 |
240 | if self.save_dir:
241 | rel_path = relpath(file, self.directory)
242 | save_path = join(self.save_dir, rel_path.wav)
243 | makedirs(dirname(save_path), exist_ok=True)
244 | librosa.output.write_wav(
245 | join(self.save_dir, rel_path),
246 | audio_data, sr)
247 | audio_data.append(y)
248 | if (self.window is not None) and (not self.variable_duration):
249 | audio_data = pad_sequences(
250 | audio_data, maxlen=int(self.window*self.sr), dtype='float32')
251 | else:
252 | audio_data = pad_sequences(
253 | audio_data, dtype='float32')
254 |
255 | return audio_data
256 |
257 | def __get_random_transform(self, y, sr):
258 | if self.time_stretch is not None:
259 | factor = np.random.normal(1, self.time_stretch)
260 | y = librosa.effects.time_stretch(y, factor)
261 | if self.pitch_shift is not None:
262 | steps = np.random.randint(0 - self.pitch_shift,
263 | 1 + self.pitch_shift)
264 | y = librosa.effects.pitch_shift(y, sr, steps)
265 | return y
266 |
267 |
268 | def benchmark(dataset, num_epochs=2):
269 | start_time = time.perf_counter()
270 | for epoch_num in range(num_epochs):
271 | for sample in dataset:
272 | # Performing a training step
273 | time.sleep(0.01)
274 | tf.print("Execution time:", time.perf_counter() - start_time)
275 |
276 | def benchmark_generator(generator, num_epochs=2):
277 | start_time = time.perf_counter()
278 | for epoch_num in range(num_epochs):
279 | for i in range(len(generator)):
280 | sample = generator[i]
281 | # Performing a training step
282 | #time.sleep(0.01)
283 | tf.print("Execution time:", time.perf_counter() - start_time)
284 |
--------------------------------------------------------------------------------
/emo-net/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EIHW/EmoNet/e76dd53ab4c33e99182c69f6247e4d924f9a3d99/emo-net/models/__init__.py
--------------------------------------------------------------------------------
/emo-net/models/adapter_resnet.py:
--------------------------------------------------------------------------------
1 | # EmoNet
2 | # ==============================================================================
3 | # Copyright (C) 2021 Maurice Gerczuk, Shahin Amiriparian,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see .
18 | # ==============================================================================
19 | import tensorflow as tf
20 |
21 | kernel_regularizer = tf.keras.regularizers.l2(1e-6)
22 |
23 | channel_axis = -1
24 |
25 | class BasicBlock(object):
26 | def __init__(self,
27 | filters,
28 | factor,
29 | strides=2,
30 | dropout1=0,
31 | dropout2=0,
32 | shortcut=True,
33 | learnall=True,
34 | tasks=['IEMOCAP-4cl'],
35 | weight_decays=None,
36 | **kwargs):
37 | self.filters = filters
38 | self.factor = factor
39 | self.strides = strides
40 | self.dropout1 = tf.keras.layers.Dropout(dropout1)
41 | self.dropout2 = tf.keras.layers.Dropout(dropout2)
42 | self.shortcut = shortcut
43 | self.learnall = learnall
44 | self.tasks = weight_decays if weight_decays is not None else [1e-6]*len(self.tasks)
45 | self.weight_decays = weight_decays
46 | self.conv_task1 = ConvTasks(filters,
47 | factor,
48 | strides=strides,
49 | learnall=learnall,
50 | dropout=dropout1,
51 | tasks=tasks,
52 | weight_decays=self.weight_decays,
53 | **kwargs)
54 | self.conv_task2 = ConvTasks(filters,
55 | factor,
56 | strides=1,
57 | learnall=learnall,
58 | dropout=dropout2,
59 | tasks=tasks,
60 | weight_decays=self.weight_decays,
61 | **kwargs)
62 |
63 | self.relu = tf.keras.layers.Activation('relu')
64 | if self.shortcut:
65 | self.avg_pool = tf.keras.layers.AveragePooling2D((2, 2), padding='same')
66 | self.lmbda = tf.keras.layers.Lambda(lambda x: x * 0)
67 | self.add = tf.keras.layers.Add()
68 |
69 | def __call__(self, input_tensor, task):
70 | residual = input_tensor
71 | x = self.conv_task1(input_tensor, task=task)
72 | x = self.relu(x)
73 | x = self.conv_task2(x, task=task)
74 | if self.shortcut:
75 | residual = self.avg_pool(residual)
76 | residual0 = self.lmbda(residual)
77 | residual = tf.keras.layers.concatenate([residual, residual0], axis=-1)
78 | x = self.add([residual, x])
79 | x = self.relu(x)
80 | return x
81 |
82 | def _add_new_task(self, task, weight_decay=1e-6):
83 | self.conv_task1._add_new_task(task, weight_decay=weight_decay)
84 | self.conv_task2._add_new_task(task, weight_decay=weight_decay)
85 |
86 |
87 | class ConvTasks(object):
88 | def __init__(self,
89 | filters,
90 | factor=1,
91 | strides=1,
92 | learnall=True,
93 | dropout=0,
94 | tasks=['IEMOCAP-4cl', 'GEMEP'],
95 | weight_decays=None,
96 | reuse_batchnorm=False,
97 | **kwargs):
98 | self.filters = filters
99 | self.factor = factor
100 | self.strides = strides
101 | self.learnall = learnall
102 | self.dropout = tf.keras.layers.Dropout(dropout)
103 | self.tasks = tasks
104 | self.weight_decays = weight_decays if weight_decays is not None else [1e-6]*len(self.tasks)
105 | self.reuse_batchnorm = reuse_batchnorm
106 |
107 | # shared parameters
108 | self.conv2d = tf.keras.layers.Convolution2D(self.filters * self.factor, (3, 3),
109 | strides=self.strides,
110 | padding='same',
111 | kernel_initializer='he_normal',
112 | use_bias=False,
113 | trainable=self.learnall,
114 | kernel_regularizer=kernel_regularizer)
115 |
116 | # task specificparameters
117 | self.res_adapts = {}
118 | self.add = tf.keras.layers.Add()
119 | self.bns = {}
120 | self.core_bn = tf.keras.layers.BatchNormalization(
121 | axis=channel_axis,
122 | name=f'core_{self.conv2d.name}_batch_normalization')
123 | for task, weight_decay in zip(self.tasks, self.weight_decays):
124 | self._add_new_task(task, weight_decay=weight_decay)
125 |
126 | def __call__(self, input_tensor, task):
127 | in_t = input_tensor
128 | if task is None:
129 | in_t = self.dropout(in_t)
130 | x = self.conv2d(in_t)
131 | if task is not None:
132 | adapter_in = self.dropout(in_t)
133 | res_adapt = self.res_adapts[task](adapter_in)
134 | x = self.add([x, res_adapt])
135 | if self.reuse_batchnorm or task is None:
136 | x = self.core_bn(x)
137 | else:
138 | x = self.bns[task](x)
139 | return x
140 |
141 | def _add_new_task(self, task, weight_decay=1e-6):
142 | assert task not in self.bns, 'Task already exists!'
143 | self.res_adapts[task] = tf.keras.layers.Convolution2D(
144 | self.filters * self.factor, (1, 1),
145 | padding='valid',
146 | kernel_initializer='he_normal',
147 | strides=self.strides,
148 | use_bias=False,
149 | kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
150 | name=f'{task}_{self.conv2d.name}_adapter')
151 |
152 | if not self.reuse_batchnorm:
153 | self.bns[task] = tf.keras.layers.BatchNormalization(
154 | axis=channel_axis,
155 | name=f'{task}_{self.conv2d.name}_batch_normalization')
156 |
157 |
158 | class ResNet(object):
159 | def __init__(self,
160 | filters=32,
161 | factor=1,
162 | N=2,
163 | verbose=1,
164 | learnall=True,
165 | dropout1=0,
166 | dropout2=0,
167 | tasks=['IEMOCAP-4cl', 'GEMEP'],
168 | weight_decays=None,
169 | reuse_batchnorm=False,
170 | input_bn=False):
171 | self.filters = filters
172 | self.factor = factor
173 | self.N = N
174 | self.learnall = learnall
175 | self.dropout1 = dropout1
176 | self.dropout2 = dropout2
177 | self.tasks = tasks
178 | self.weight_decays = weight_decays if weight_decays is not None else [1e-6]*len(self.tasks)
179 | self.reuse_batchnorm = reuse_batchnorm
180 | self.input_bn = input_bn
181 | if self.input_bn:
182 | self.input_core_bn = tf.keras.layers.BatchNormalization(axis=channel_axis, name=f'core_input_batch_normalization')
183 | self.input_bns = {
184 | task: tf.keras.layers.BatchNormalization(axis=channel_axis,
185 | name=f'{task}_input_batch_normalization')
186 | for task in self.tasks
187 | }
188 |
189 | # conv blocks
190 | self.pre_conv = ConvTasks(filters=self.filters,
191 | factor=factor,
192 | strides=1,
193 | learnall=learnall,
194 | tasks=self.tasks,
195 | weight_decays=self.weight_decays,
196 | reuse_batchnorm=reuse_batchnorm)
197 | self.blocks = []
198 | self.nb_conv = 1
199 | for i in range(1, 4):
200 | block = BasicBlock(self.filters * (2**i),
201 | self.factor,
202 | strides=2,
203 | dropout1=self.dropout1,
204 | dropout2=self.dropout2,
205 | shortcut=True,
206 | learnall=self.learnall,
207 | tasks=self.tasks,
208 | weight_decays=self.weight_decays,
209 | reuse_batchnorm=reuse_batchnorm)
210 | self.blocks.append(block)
211 | for j in range(N - 1):
212 | block = BasicBlock(filters=self.filters *
213 | (2**i),
214 | factor=self.factor,
215 | strides=1,
216 | dropout1=self.dropout1,
217 | dropout2=self.dropout2,
218 | shortcut=False,
219 | learnall=self.learnall,
220 | tasks=self.tasks,
221 | weight_decays=self.weight_decays,
222 | reuse_batchnorm=reuse_batchnorm)
223 | self.blocks.append(block)
224 | self.nb_conv += 2
225 | self.nb_conv += 6
226 |
227 | # bns and relus
228 | self.relu = tf.keras.layers.Activation('relu')
229 | self.bns = {
230 | task: tf.keras.layers.BatchNormalization(axis=channel_axis,
231 | name=f'{task}_final_batch_normalization')
232 | for task in self.tasks
233 | }
234 | self.core_bn = tf.keras.layers.BatchNormalization(axis=channel_axis,
235 | name=f'core_final_batch_normalization')
236 |
237 | def _add_new_task(self, task, weight_decay=1e-6):
238 | assert task not in self.bns, f'Task {task} already exists!'
239 | self.pre_conv._add_new_task(task, weight_decay=weight_decay)
240 | for block in self.blocks:
241 | block._add_new_task(task, weight_decay=weight_decay)
242 | if not self.reuse_batchnorm:
243 | self.bns[task] = tf.keras.layers.BatchNormalization(axis=channel_axis, name=f'{task}_final_batch_normalization')
244 | if self.input_bn:
245 | self.input_bns[task] = tf.keras.layers.BatchNormalization(axis=channel_axis, name=f'{task}_input_batch_normalization')
246 |
247 | def __call__(self, input_tensor, task):
248 | if self.input_bn:
249 | if task is None or self.reuse_batchnorm:
250 | x = self.input_core_bn(input_tensor)
251 | else:
252 | x = self.input_bns[task](input_tensor)
253 | else:
254 | x = input_tensor
255 | x = self.pre_conv(x, task=task)
256 | for block in self.blocks:
257 | x = block(x, task=task)
258 | if task is None or self.reuse_batchnorm:
259 | x = self.core_bn(x)
260 | else:
261 | x = self.bns[task](x)
262 | x = self.relu(x)
263 | return x
--------------------------------------------------------------------------------
/emo-net/models/adapter_rnn.py:
--------------------------------------------------------------------------------
1 | # EmoNet
2 | # ==============================================================================
3 | # Copyright (C) 2021 Maurice Gerczuk, Shahin Amiriparian,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see .
18 | # ==============================================================================
19 | import tensorflow as tf
20 | #from .attention import SeqSelfAttention, SeqWeightedAttention
21 |
22 | kernel_regularizer = tf.keras.regularizers.l2(1e-6)
23 | rnn_regularizer = tf.keras.regularizers.L1L2(1e-6)
24 |
25 |
26 | class RNNWithAdapters(object):
27 | def __init__(self,
28 | input_dims,
29 | hidden_size=512,
30 | learnall=True,
31 | dropout=0.2,
32 | layers=2,
33 | input_projection_factor=1,
34 | adapter_projection_factor=2,
35 | tasks=[],
36 | recurrent_cell='lstm',
37 | bidirectional=False,
38 | input_projection=True,
39 | input_bn=False,
40 | downpool=None,
41 | share_feature_layer=False,
42 | share_attention=False,
43 | use_attention=True,
44 | **kwargs):
45 |
46 | self.hidden_size = hidden_size
47 | self.learnall = learnall
48 | self.dropout = dropout
49 | self.tasks = tasks
50 | self.input_dims = input_dims
51 | self.adapter_projection_factor = adapter_projection_factor
52 | self.input_projection = input_projection
53 | self.input_projection_factor = input_projection_factor
54 | self.tasks = tasks
55 | self.recurrent_cell = recurrent_cell
56 | self.layers = layers
57 | self.bidirectional = bidirectional
58 | self.input_bn = input_bn
59 | self.cnn_input = len(input_dims) > 3
60 | self.share_feature_layer = share_feature_layer
61 | self.use_attention = use_attention
62 | self.share_attention = share_attention
63 | if self.cnn_input: # cnn feature extractor
64 | feature_dims = input_dims[1] * input_dims[3]
65 | else:
66 | feature_dims = input_dims[-1]
67 | self.reorder_dims = tf.keras.layers.Permute((2, 1, 3))
68 | if downpool is not None:
69 | self.downpool = tf.keras.layers.AveragePooling1D(
70 | pool_size=downpool, strides=downpool, padding='same', name='rnn_downpool')
71 | else:
72 | self.downpool = None
73 | self.reshape = tf.keras.layers.Reshape(target_shape=(-1,
74 | feature_dims))
75 | if self.input_bn:
76 | self.input_bns = {task: tf.keras.layers.BatchNormalization(
77 | trainable=True, name=f'{task}_rnn_input_bn') for task in tasks}
78 | if not self.input_bns:
79 | self.core_input_bn = tf.keras.layers.BatchNormalization(
80 | trainable=True, name=f'core_rnn_input_bn')
81 | self.projection = tf.keras.layers.Dense(feature_dims // self.input_projection_factor,
82 | activation=None,
83 | trainable=learnall,
84 | kernel_regularizer=kernel_regularizer)
85 | self.adapter_hidden_size = self.hidden_size * \
86 | 2 if self.bidirectional else self.hidden_size
87 | self.rnns = []
88 | self.selfattentions = []
89 | self.selfattention = []
90 | self.adapters = []
91 | for i in range(self.layers):
92 | rnn = tf.keras.layers.GRU(self.hidden_size,
93 | dropout=self.dropout,
94 | return_sequences=True,
95 | trainable=learnall, kernel_regularizer=rnn_regularizer) if self.recurrent_cell.lower() == 'gru' else tf.keras.layers.LSTM(self.hidden_size,
96 | dropout=self.dropout,
97 | return_sequences=True,
98 | trainable=learnall, kernel_regularizer=rnn_regularizer)
99 | if self.bidirectional:
100 | rnn = tf.keras.layers.Bidirectional(rnn)
101 | self.rnns.append(rnn)
102 | self.adapters.append({
103 | task: RNNAdapter(self.adapter_hidden_size, self.adapter_projection_factor,
104 | task, i)
105 | for task in tasks
106 | })
107 | if i < self.layers - 1:
108 | self.selfattentions.append({task: SeqSelfAttention(
109 | attention_activation='sigmoid',
110 | kernel_regularizer=kernel_regularizer,
111 | use_attention_bias=False,
112 | trainable=True,
113 | name=f'{task}_self_attention_{i}') for task in tasks})
114 | self.selfattention.append(SeqSelfAttention(
115 | attention_activation='sigmoid',
116 | kernel_regularizer=kernel_regularizer,
117 | use_attention_bias=False,
118 | trainable=learnall,
119 | name=f'core_self_attention_{i}'))
120 |
121 | self.add = tf.keras.layers.Add()
122 |
123 |
124 | self.weighted_attentions = {task: SeqWeightedAttention(
125 | trainable=True, name=f'{task}_seq_weighted_attention') for task in tasks}
126 | self.weighted_attention = SeqWeightedAttention(
127 | trainable=learnall, name=f'core_seq_weighted_attention')
128 |
129 | def __call__(self, x, task, mask=None):
130 | if self.input_bn:
131 | if task is not None:
132 | self.input_bns[task](x)
133 | else:
134 | self.core_input_bn(x)
135 | if self.cnn_input:
136 | x = self.reorder_dims(x)
137 | x = self.reshape(x)
138 | if self.downpool is not None:
139 | x = self.downpool(x)
140 | #x = self.mask(x)
141 | if self.input_projection:
142 | x = self.projection(x)
143 | for i in range(self.layers):
144 | x = self.rnns[i](x, mask=mask)
145 | if task is not None:
146 | adapter = self.adapters[i][task](x)
147 | x = self.add([x, adapter])
148 | if i < self.layers - 1 and self.use_attention:
149 | if self.share_attention:
150 | x = self.selfattention[i](x, mask=mask)
151 | else:
152 | x = self.selfattentions[i][task](x, mask=mask)
153 | else:
154 | if i < self.layers - 1 and self.use_attention:
155 | x = self.selfattention[i](x, mask=mask)
156 | if task is not None and not self.share_feature_layer:
157 | x = self.weighted_attentions[task](x, mask=mask)
158 | else:
159 | x = self.weighted_attention(x, mask=mask)
160 | return x
161 |
162 | def _add_new_task(self, task):
163 | assert task not in self.adapters, f'Task {task} already exists!'
164 | for i in range(self.layers):
165 | self.adapters[i][task] = RNNAdapter(self.adapter_hidden_size,
166 | self.adapter_projection_factor, task,i)
167 | if i < self.layers - 1:
168 | self.selfattentions[i][task] = SeqSelfAttention(
169 | attention_activation='sigmoid',
170 | kernel_regularizer=kernel_regularizer,
171 | use_attention_bias=False,
172 | trainable=True,
173 | name=f'{task}_self_attention_{i}')
174 | self.weighted_attentions[task] = SeqWeightedAttention(
175 | trainable=True, name=f'{task}_seq_weighted_attention')
176 | if self.input_bn:
177 | self.input_bns[task] = tf.keras.layers.BatchNormalization(
178 | trainable=True, name=f'{task}_rnn_input_bn')
179 |
180 |
181 | class RNNAdapter(object):
182 | def __init__(self, input_size, downprojection=4, task='IEMOCAP', index=1):
183 | self.input_size = input_size
184 | self.downprojection_factor = downprojection
185 | self.task = task
186 | self.layer_norm = tf.keras.layers.TimeDistributed(tf.keras.layers.LayerNormalization(),
187 | name=f'{task}_rnn_adapter_{index}_layer_norm')
188 | self.downprojection = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(
189 | self.input_size // self.downprojection_factor, activation='relu', use_bias=False, kernel_regularizer=kernel_regularizer),
190 | name=f'{task}_rnn_adapter_{index}_downprojection')
191 | self.upprojection = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(self.input_size, use_bias=False, kernel_regularizer=kernel_regularizer),
192 | name=f'{task}_rnn_adapter_{index}_upprojection')
193 |
194 | def __call__(self, x):
195 | x = self.layer_norm(x)
196 | x = self.downprojection(x)
197 | x = self.upprojection(x)
198 | #x = self.selfattention(x)
199 | return x
200 |
--------------------------------------------------------------------------------
/emo-net/models/attention.py:
--------------------------------------------------------------------------------
1 | # EmoNet
2 | # ==============================================================================
3 | # Copyright (C) 2021 Maurice Gerczuk, Shahin Amiriparian,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see .
18 | # ==============================================================================
19 | import tensorflow as tf
20 | import tensorflow.keras.backend as K
21 |
22 |
23 | class Attention2Dtanh(tf.keras.layers.Layer):
24 | def __init__(self, lmbda=0.3, mlp_units=256, **kwargs):
25 | super(Attention2Dtanh, self).__init__(**kwargs)
26 | self.mlp_units = mlp_units
27 | self.tanh = tf.keras.layers.Activation('tanh')
28 | self.lmbda = lmbda
29 |
30 | def build(self, input_shape):
31 | self.w = self.add_weight(shape=(input_shape[-1], self.mlp_units),
32 | initializer='random_normal',
33 | trainable=True,
34 | name='W')
35 | self.b = self.add_weight(shape=(self.mlp_units, ),
36 | initializer='random_normal',
37 | trainable=True,
38 | name='b')
39 | self.u = self.add_weight(shape=(input_shape[-1], ),
40 | initializer='random_normal',
41 | trainable=True,
42 | name='u')
43 | self.flatten = tf.keras.layers.Reshape(target_shape=(-1,
44 | input_shape[-1]))
45 | super(Attention2Dtanh, self).build(input_shape)
46 |
47 | def call(self, inputs):
48 | flat_input = self.flatten(inputs)
49 | x = tf.matmul(flat_input, self.w) + self.b
50 | x = self.tanh(x)
51 | e = tf.tensordot(self.u, x, axes=[[0], [-1]]) * self.lmbda
52 | a = tf.nn.softmax(e, axis=-1)
53 | weighted_sum = tf.reduce_sum(tf.expand_dims(a, -1) * flat_input,
54 | axis=1)
55 | return weighted_sum
56 |
57 | def get_config(self):
58 | config = {'lmbda': self.lmbda, 'mlp_units': self.mlp_units}
59 | base_config = super(Attention2Dtanh, self).get_config()
60 | return dict(list(base_config.items()) + list(config.items()))
61 |
--------------------------------------------------------------------------------
/emo-net/models/build_model.py:
--------------------------------------------------------------------------------
1 | # EmoNet
2 | # ==============================================================================
3 | # Copyright (C) 2021 Maurice Gerczuk, Shahin Amiriparian,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see .
18 | # ==============================================================================
19 | import tensorflow as tf
20 | tf.random.set_seed(42)
21 |
22 | import tensorflow.keras.backend as K
23 | import h5py
24 | from .input_layers import *
25 | from .adapter_rnn import *
26 | from .adapter_resnet import *
27 | from .attention import *
28 | from ..utils import array_list_equal
29 |
30 | import logging
31 | logger = logging.getLogger(__name__)
32 |
33 |
34 |
35 | def avgpool(x):
36 | x = tf.keras.layers.AveragePooling2D((8, 8))(x)
37 | x = tf.keras.layers.Flatten()(x)
38 | return x
39 |
40 | def global_avgpool(x):
41 | x = tf.keras.layers.GlobalAveragePooling2D()(x)
42 | x = tf.keras.layers.Flatten()(x)
43 | return x
44 |
45 |
46 | def infer_tasks_from_weightfile(initial_weights):
47 | with h5py.File(initial_weights) as f:
48 | base_tasks = []
49 | base_nb_classes = []
50 | for k in f['model_weights']:
51 | prefices = ('activation', 'add', 'average_pooling',
52 | 'batch_normalization', 'bidirectional', 'concat', 'conv2d', 'dropout', 'attention',
53 | 'dense', 'flatten', 'input', 'lambda',
54 | 'normalization2d', 'reshape', 'trainable_stft', 'apply_zero_mask', 'core', 'lstm', 'masking', 'permute', 'pooled', 'seq', 'zero_mask', 'adapter', 'mask', 'expand', 'mfcc', 'downpool')
55 | skip_layer = any([prefix in k for prefix in prefices])
56 | if not skip_layer:
57 | task = k
58 | classes = _find_n_classes(f['model_weights'][k], k)
59 | #classes = f['model_weights'][k]['softmax'][k]['softmax']['kernel:0'].shape[1]
60 | logger.info(f'Found task {k} with {classes} classes.')
61 | base_tasks.append(task)
62 | base_nb_classes.append(classes)
63 | return base_tasks, base_nb_classes
64 |
65 |
66 | def _find_n_classes(weight_dict, task):
67 | if 'sigmoid' in weight_dict:
68 | return 2
69 | elif 'kernel:0' in weight_dict:
70 | output_shape = weight_dict['kernel:0'].shape[1]
71 | if output_shape == 1: # binary
72 | output_shape += 1
73 | return output_shape
74 | elif 'softmax' in weight_dict:
75 | return _find_n_classes(weight_dict['softmax'], task)
76 | elif task in weight_dict:
77 | return _find_n_classes(weight_dict[task], task)
78 |
79 |
80 | def input_features_and_mask(audio_in, num_fft=1024, hop_length=512, sample_rate=16000, f_min=20, f_max=8000, num_mels=128, eps=1e-6, return_decibel=False, num_mfccs=None):
81 | input_features = LogMelgramLayer(num_fft=num_fft,
82 | hop_length=hop_length,
83 | sample_rate=sample_rate,
84 | f_min=f_min,
85 | f_max=f_max,
86 | num_mels=num_mels,
87 | eps=eps,
88 | return_decibel=return_decibel,
89 | name='trainable_stft')
90 | x = input_features(audio_in)
91 | mask = ComputeMask(input_features.num_fft,
92 | input_features.hop_length)(audio_in)
93 | if num_mfccs is not None:
94 | x = MFCCLayer(num_mfccs=num_mfccs)(x)
95 | return x, mask
96 |
97 |
98 | def create_multi_task_networks(input_dim, feature_extractor='cnn',
99 | initial_weights=None,
100 | base_nb_classes=None,
101 | learnall=True,
102 | num_mels=128,
103 | base_tasks=None,
104 | new_tasks=None,
105 | new_nb_classes=None,
106 | mode=None,
107 | random_noise=None,
108 | input_bn=False,
109 | share_feature_layer=True,
110 | base_weight_decays=None,
111 | new_weight_decays=None,
112 | **kwargs):
113 | if feature_extractor == 'cnn':
114 | return create_multi_task_resnets(input_dim=input_dim, mode=mode, num_mels=num_mels, initial_weights=initial_weights, base_nb_classes=base_nb_classes, base_weight_decays=base_weight_decays, new_weight_decays=new_weight_decays, learnall=learnall, base_tasks=base_tasks, new_tasks=new_tasks, new_nb_classes=new_nb_classes, random_noise=random_noise, input_bn=input_bn, share_feature_layer=share_feature_layer, **kwargs)
115 | elif feature_extractor == 'rnn':
116 | return create_multi_task_rnn(input_dim=input_dim, mode=mode, num_mels=num_mels, initial_weights=initial_weights, base_nb_classes=base_nb_classes, learnall=learnall, base_tasks=base_tasks, new_tasks=new_tasks, new_nb_classes=new_nb_classes, random_noise=random_noise, input_bn=input_bn, share_feature_layer=share_feature_layer, **kwargs)
117 | elif feature_extractor == 'vgg16':
118 | return create_multi_task_vgg16(input_dim=input_dim,
119 | tasks=base_tasks,
120 | num_mels=num_mels,
121 | nb_classes=base_nb_classes,
122 | random_noise=random_noise,
123 | initial_weights=initial_weights,
124 | share_feature_layer=share_feature_layer,
125 | **kwargs)
126 | elif feature_extractor == 'fusion':
127 | return create_multi_task_fusion(input_dim=input_dim, mode=mode, num_mels=num_mels, initial_weights=initial_weights, base_nb_classes=base_nb_classes, learnall=learnall, base_tasks=base_tasks, new_tasks=new_tasks, new_nb_classes=new_nb_classes, random_noise=random_noise, input_bn=input_bn, share_feature_layer=share_feature_layer, **kwargs)
128 |
129 |
130 | def create_multi_task_fusion(input_dim,filters=32,
131 | factor=1,
132 | N=4,
133 | hidden_dim=512,
134 | cell='lstm',
135 | number_of_layers=2,
136 | down_pool=8,
137 | bidirectional=False,
138 | num_mels=128,
139 | learnall=True,
140 | learnall_classifier=True,
141 | mode='adapters',
142 | dropout1=0,
143 | dropout2=0,
144 | rnn_dropout=0.2,
145 | base_tasks=['EMO-DB', 'GEMEP'],
146 | new_tasks=None,
147 | base_nb_classes=[6, 10],
148 | new_nb_classes=None,
149 | initial_weights=None,
150 | random_noise=0.1,
151 | reuse_batchnorm=False,
152 | input_bn=False,
153 | share_feature_layer=False):
154 | channel_axis = -1
155 | if base_tasks is None:
156 | assert initial_weights is not None, f'Either base tasks or initial weights have to be specified!'
157 | logger.info(
158 | 'Trying to determine trained tasks from initial weights...')
159 | base_tasks, base_nb_classes = infer_tasks_from_weightfile(
160 | initial_weights)
161 |
162 | input_tensor = tf.keras.layers.Input(shape=input_dim)
163 |
164 | input_reshaped = tf.keras.layers.Reshape(
165 | target_shape=(-1, ))(input_tensor)
166 |
167 |
168 | x, mask = input_features_and_mask(input_reshaped, num_fft=1024,
169 | hop_length=512,
170 | sample_rate=16000,
171 | f_min=20,
172 | f_max=8000,
173 | num_mels=num_mels,
174 | eps=1e-6,
175 | return_decibel=True,
176 | num_mfccs=None)
177 |
178 | adapter_rnn = RNNWithAdapters(K.int_shape(x),
179 | hidden_size=hidden_dim,
180 | learnall=learnall,
181 | dropout=rnn_dropout,
182 | input_projection_factor=1,
183 | adapter_projection_factor=4,
184 | bidirectional=bidirectional,
185 | layers=number_of_layers,
186 | recurrent_cell=cell,
187 | input_bn=input_bn,
188 | downpool=down_pool,
189 | input_projection=False,
190 | # tasks=base_tasks,
191 | tasks=base_tasks if mode == 'adapters' else [],
192 | share_feature_layer=share_feature_layer)
193 | if down_pool is not None:
194 | mask = PoolMask((down_pool,))(mask)
195 |
196 |
197 | expand_dims = tf.keras.layers.Lambda(
198 | lambda x: tf.expand_dims(x, 3), name='expand_input_dims')
199 | x_resnet = expand_dims(x)
200 | x_resnet = tf.keras.layers.Permute((2, 1, 3))(x_resnet)
201 | x_rnn = x
202 |
203 |
204 | adapter_resnet = ResNet(filters=filters,
205 | factor=factor,
206 | N=N,
207 | learnall=learnall,
208 | dropout1=dropout1,
209 | dropout2=dropout2,
210 | # tasks=base_tasks,
211 | tasks=base_tasks if mode == 'adapters' else [],
212 | input_bn=input_bn)
213 |
214 | if new_tasks is not None:
215 | really_new_tasks = [t for t in new_tasks if t not in base_tasks]
216 | new_nb_classes = [
217 | c for c, t in zip(new_nb_classes, new_tasks) if t not in base_tasks
218 | ]
219 | else:
220 | really_new_tasks = []
221 | new_nb_classes = []
222 |
223 | task_models = {}
224 | outputs = []
225 | attention2d = None
226 | for task, classes in zip(base_tasks, base_nb_classes):
227 | logger.info(f'Building model for {task} with {classes} classes...')
228 | adapters_in = task if mode == 'adapters' else None
229 | y_resnet = adapter_resnet(x_resnet, task=adapters_in)
230 | y_rnn = adapter_rnn(x_rnn, task=adapters_in, mask=mask)
231 |
232 |
233 | if attention2d is None:
234 | attention2d = Attention2Dtanh(lmbda=0.3, mlp_units=K.int_shape(y_resnet)[-1], name=f'core_2d_attention', trainable=not(mode=='adapters'))
235 | if not share_feature_layer and mode == 'adapters':
236 | attention2d = Attention2Dtanh(name=f'{task}_2d_attention', mlp_units=K.int_shape(y_resnet)[-1], trainable=True)
237 | y_resnet = attention2d(y_resnet)
238 |
239 | y = tf.keras.layers.Concatenate()([y_resnet, y_rnn])
240 |
241 | y = tf.keras.layers.Dense(K.int_shape(y)[-1]//2, activation=None, name=f'{task}_dense')(y)
242 | y = tf.keras.layers.BatchNormalization(name=f'{task}_dense_batchnorm')(y)
243 | y = tf.keras.layers.Activation('relu', name=f'{task}_dense_relu')(y)
244 | y = tf.keras.layers.Dropout(0.2)(y)
245 | if classes == 2:
246 | y = tf.keras.layers.Dense(1, activation='sigmoid', name=task)(y)
247 | else:
248 | y = tf.keras.layers.Dense(
249 | classes, activation='softmax', name=task)(y)
250 | model = tf.keras.Model(inputs=input_tensor, outputs=y)
251 | outputs.append(y)
252 | task_models[task] = model
253 |
254 | if really_new_tasks is not None and new_nb_classes is not None:
255 | for task, classes in zip(really_new_tasks, new_nb_classes):
256 | logger.info(f'Building model for {task} with {classes} classes...')
257 | adapter_resnet._add_new_task(task)
258 | adapter_rnn._add_new_task(task)
259 | y_resnet = adapter_resnet(x, task)
260 | y_rnn = adapter_rnn(x_rnn, task, mask=mask)
261 |
262 | if attention2d is None:
263 | attention2d = Attention2Dtanh(lmbda=0.3, mlp_units=K.int_shape(y_resnet)[-1], name=f'core_2d_attention', trainable=not(mode=='adapters'))
264 | if not share_feature_layer and mode == 'adapters':
265 | attention2d = Attention2Dtanh(lmbda=0.3, name=f'{task}_2d_attention', mlp_units=K.int_shape(y_resnet)[-1], trainable=True)
266 | y_resnet = attention2d(y_resnet)
267 |
268 | y = tf.keras.layers.Concatenate()([y_resnet, y_rnn])
269 |
270 | y = tf.keras.layers.Dense(K.int_shape(y)[-1]//2, activation=None, name=f'{task}_dense')(y)
271 | y = tf.keras.layers.BatchNormalization(name=f'{task}_dense_batchnorm')(y)
272 | y = tf.keras.layers.Activation('relu', name=f'{task}_dense_relu')(y)
273 | y = tf.keras.layers.Dropout(0.2)(y)
274 | if classes == 2:
275 | y = tf.keras.layers.Dense(
276 | 1, activation='sigmoid', name=task)(y)
277 | else:
278 | y = tf.keras.layers.Dense(
279 | classes, activation='softmax', name=task)(y)
280 | model = tf.keras.Model(inputs=input_tensor, outputs=y)
281 | outputs.append(y)
282 | task_models[task] = model
283 | shared_model = tf.keras.Model(inputs=input_tensor, outputs=outputs)
284 |
285 | if initial_weights is not None:
286 | shared_model, task_models = load_and_assert_loaded(shared_model, task_models, initial_weights)
287 |
288 | return task_models, shared_model
289 |
290 | def create_multi_task_resnets(input_dim,
291 | filters=32,
292 | factor=1,
293 | N=4,
294 | num_mels=128,
295 | learnall=True,
296 | learnall_classifier=True,
297 | mode='adapters',
298 | dropout1=0,
299 | dropout2=0,
300 | rnn_dropout=0.2,
301 | base_tasks=['EMO-DB', 'GEMEP'],
302 | new_tasks=None,
303 | base_weight_decays=None,
304 | new_weight_decays=None,
305 | base_nb_classes=[6, 10],
306 | new_nb_classes=None,
307 | initial_weights=None,
308 | random_noise=0.1,
309 | classifier='rnn',
310 | input_bn=False,
311 | share_feature_layer=False):
312 |
313 | channel_axis = -1
314 | if base_tasks is None:
315 | assert initial_weights is not None, f'Either base tasks or initial weights have to be specified!'
316 | logger.info(
317 | 'Trying to determine trained tasks from initial weights...')
318 | base_tasks, base_nb_classes = infer_tasks_from_weightfile(
319 | initial_weights)
320 |
321 | base_weight_decays = base_weight_decays if base_weight_decays is not None else [1e-6]*len(base_tasks)
322 |
323 | if new_tasks is not None:
324 | really_new_tasks = [t for t in new_tasks if t not in base_tasks]
325 | new_nb_classes = [
326 | c for c, t in zip(new_nb_classes, new_tasks) if t not in base_tasks
327 | ]
328 | else:
329 | really_new_tasks = []
330 | new_nb_classes = []
331 |
332 | new_weight_decays = new_weight_decays if new_weight_decays is not None else [1e-6]*len(really_new_tasks)
333 |
334 | # check if batchnorm should be reused
335 | if len(new_tasks) != 1:
336 | reuse_batchnorm = False
337 | elif new_tasks[0] in base_tasks and len(base_tasks) > 1:
338 | reuse_batchnorm = False
339 | else:
340 | reuse_batchnorm = True
341 | print(reuse_batchnorm)
342 | task_models = {}
343 | outputs = []
344 | adapter_rnn = None
345 | attention2d = None
346 | input_tensor = tf.keras.layers.Input(shape=input_dim)
347 | variable_duration = not (classifier == 'avgpool')
348 | if variable_duration:
349 | input_reshaped = tf.keras.layers.Reshape(
350 | target_shape=(-1, ))(input_tensor)
351 | else:
352 | input_reshaped = tf.keras.layers.Reshape(
353 | target_shape=(input_dim[0], ))(input_tensor)
354 |
355 | x, mask = input_features_and_mask(input_reshaped, num_fft=1024,
356 | hop_length=512,
357 | sample_rate=16000,
358 | f_min=20,
359 | f_max=8000,
360 | num_mels=num_mels,
361 | eps=1e-6,
362 | return_decibel=False,
363 | num_mfccs=None)
364 |
365 | pooled_mask = PoolMask((8,))(mask)
366 |
367 | expand_dims = tf.keras.layers.Lambda(
368 | lambda x: tf.expand_dims(x, 3), name='expand_input_dims')
369 | x = expand_dims(x)
370 | x = tf.keras.layers.Permute((2, 1, 3))(x)
371 | adapter_resnet = ResNet(filters=filters,
372 | factor=factor,
373 | N=N,
374 | learnall=learnall,
375 | dropout1=dropout1,
376 | dropout2=dropout2,
377 | weight_decays=base_weight_decays,
378 | reuse_batchnorm=reuse_batchnorm,
379 | tasks=base_tasks if mode == 'adapters' else [],
380 | input_bn=input_bn)
381 |
382 |
383 | for task, classes in zip(base_tasks, base_nb_classes):
384 | logger.info(f'Building model for {task} with {classes} classes...')
385 | adapters_in_resnet = task if mode == 'adapters' else None
386 | y = adapter_resnet(x, task=adapters_in_resnet)
387 |
388 | if initial_weights is not None and classifier == 'avgpool': # might need new last dense layer
389 | name = f'{task}_1'
390 | else:
391 | name = task
392 | if classifier == 'avgpool':
393 | y = avgpool(y)
394 | elif classifier == 'FCNAttention':
395 | if attention2d is None:
396 | attention2d = Attention2Dtanh(lmbda=0.3, mlp_units=K.int_shape(y)[-1], name=f'core_2d_attention', trainable=learnall_classifier)
397 | if not share_feature_layer and mode == 'adapters':
398 | attention2d = Attention2Dtanh(name=f'{task}_2d_attention', mlp_units=K.int_shape(y)[-1], trainable=True)
399 | y = attention2d(y)
400 |
401 | elif classifier == 'rnn':
402 | if adapter_rnn is None:
403 | adapter_rnn = RNNWithAdapters(K.int_shape(y),
404 | hidden_size=K.int_shape(y)[-1],
405 | learnall=learnall_classifier,
406 | dropout=rnn_dropout,
407 | input_projection_factor=4,
408 | adapter_projection_factor=4,
409 | share_feature_layer=share_feature_layer,
410 | # tasks=base_tasks,
411 | tasks=base_tasks if mode == 'adapters' else [])
412 |
413 | y = adapter_rnn(y, adapters_in_resnet, mask=pooled_mask)
414 | y = tf.keras.layers.Dense(K.int_shape(y)[-1]//2, activation=None, name=f'{name}_dense')(y)
415 | y = tf.keras.layers.BatchNormalization(name=f'{name}_dense_batchnorm')(y)
416 | y = tf.keras.layers.Activation('relu', name=f'{name}_dense_relu')(y)
417 | y = tf.keras.layers.Dropout(0.2)(y)
418 | if classes == 2:
419 | y = tf.keras.layers.Dense(1, activation='sigmoid', name=name)(y)
420 | else:
421 | y = tf.keras.layers.Dense(
422 | classes, activation='softmax', name=name)(y)
423 | model = tf.keras.Model(inputs=input_tensor, outputs=y)
424 | outputs.append(y)
425 | task_models[task] = model
426 |
427 | if really_new_tasks is not None and new_nb_classes is not None:
428 | for task, classes, weight_decay in zip(really_new_tasks, new_nb_classes, new_weight_decays):
429 | logger.info(f'Building model for {task} with {classes} classes...')
430 | adapter_resnet._add_new_task(task, weight_decay=weight_decay)
431 | y = adapter_resnet(x, task)
432 |
433 | if classifier == 'avgpool':
434 | y = avgpool(y)
435 | elif classifier == 'rnn':
436 | adapter_rnn._add_new_task(task)
437 | y = adapter_rnn(y, task)
438 | elif classifier == 'FCNAttention':
439 | if attention2d is None:
440 | attention2d = Attention2Dtanh(lmbda=0.3, mlp_units=K.int_shape(y)[-1], name=f'core_2d_attention', trainable=learnall_classifier)
441 | if not share_feature_layer and mode == 'adapters':
442 | attention2d = Attention2Dtanh(lmbda=0.3, name=f'{task}_2d_attention', mlp_units=K.int_shape(y)[-1], trainable=True)
443 | y = attention2d(y)
444 |
445 |
446 |
447 |
448 | y = tf.keras.layers.Dense(K.int_shape(y)[-1]//2, activation=None, name=f'{task}_dense')(y)
449 | y = tf.keras.layers.BatchNormalization(name=f'{task}_dense_batchnorm')(y)
450 | y = tf.keras.layers.Activation('relu', name=f'{task}_dense_relu')(y)
451 | y = tf.keras.layers.Dropout(0.2)(y)
452 | if classes == 2:
453 | y = tf.keras.layers.Dense(
454 | 1, activation='sigmoid', name=task)(y)
455 | else:
456 | y = tf.keras.layers.Dense(
457 | classes, activation='softmax', name=task)(y)
458 | model = tf.keras.Model(inputs=input_tensor, outputs=y)
459 | outputs.append(y)
460 | task_models[task] = model
461 | shared_model = tf.keras.Model(inputs=input_tensor, outputs=outputs)
462 |
463 | if initial_weights is not None:
464 | shared_model, task_models = load_and_assert_loaded(shared_model, task_models, initial_weights)
465 |
466 | return task_models, shared_model
467 |
468 |
469 | def create_multi_task_vgg16(input_dim,
470 | tasks=['EMO-DB', 'GEMEP'],
471 | nb_classes=[6, 10],
472 | random_noise=0.1,
473 | num_mels=128,
474 | classifier='attention2d',
475 | dropout=0.2,
476 | initial_weights=None,
477 | share_feature_layer=False,
478 | freeze_up_to=None):
479 | channel_axis = -1
480 | if tasks is None:
481 | assert initial_weights is not None, f'Either base tasks or initial weights have to be specified!'
482 | logger.info(
483 | 'Trying to determine trained tasks from initial weights...')
484 | base_tasks, base_nb_classes = infer_tasks_from_weightfile(
485 | initial_weights)
486 |
487 | input_tensor = tf.keras.layers.Input(shape=input_dim)
488 | variable_duration = not (classifier == 'avgpool')
489 | if variable_duration:
490 | input_reshaped = tf.keras.layers.Reshape(
491 | target_shape=(-1, ))(input_tensor)
492 | else:
493 | input_reshaped = tf.keras.layers.Reshape(
494 | target_shape=(input_dim[0], ))(input_tensor)
495 |
496 | x, mask = input_features_and_mask(input_reshaped, num_fft=1024,
497 | hop_length=512,
498 | sample_rate=16000,
499 | f_min=20,
500 | f_max=8000,
501 | num_mels=num_mels,
502 | eps=1e-6,
503 | return_decibel=True,
504 | num_mfccs=None)
505 |
506 | pooled_mask = PoolMask((8,))(mask)
507 |
508 | expand_dims = tf.keras.layers.Lambda(
509 | lambda x: tf.expand_dims(x, 3), name='expand_input_dims')
510 | x = expand_dims(x)
511 | x = tf.keras.layers.Permute((2, 1, 3))(x)
512 | x = tf.keras.layers.Convolution2D(3, 1, activation='relu', name='learn_colourmapping', use_bias=False)(x)
513 | x = tf.keras.layers.BatchNormalization()(x)
514 | vgg16 = tf.keras.applications.VGG16(include_top=False, weights='imagenet', pooling=None)
515 | #vgg16 = tf.keras.applications.ResNet50(include_top=False, weights='imagenet', pooling=None)
516 | if freeze_up_to is not None:
517 | for layer in vgg16.layers[:freeze_up_to]:
518 | layer.trainable = False
519 | else:
520 | for layer in vgg16.layers:
521 | layer.trainable = False
522 | task_models = {}
523 | outputs = []
524 | adapter_rnn = None
525 | attention2d = None
526 | for task, classes in zip(tasks, nb_classes):
527 | logger.info(f'Building model for {task} with {classes} classes...')
528 | y = vgg16(x)
529 |
530 | if initial_weights is not None and classifier == 'avgpool': # might need new last dense layer
531 | name = f'{task}_1'
532 | else:
533 | name = task
534 | if classifier == 'avgpool':
535 | y = global_avgpool(y)
536 |
537 |
538 | if classifier == 'FCNAttention':
539 | if attention2d is None:
540 | attention2d = Attention2Dtanh(lmbda=0.3, mlp_units=K.int_shape(y)[-1], name=f'core_2d_attention', trainable=True)
541 | if not share_feature_layer:
542 | attention2d = Attention2Dtanh(name=f'{task}_2d_attention', mlp_units=K.int_shape(y)[-1], trainable=True)
543 | y = attention2d(y)
544 |
545 | if classifier == 'rnn':
546 | if adapter_rnn is None:
547 | adapter_rnn = RNNWithAdapters(K.int_shape(y),
548 | hidden_size=K.int_shape(y)[-1],
549 | learnall=True,
550 | dropout=dropout,
551 | input_projection_factor=4,
552 | adapter_projection_factor=4,
553 | share_feature_layer=share_feature_layer,
554 | # tasks=base_tasks,
555 | tasks=[])
556 |
557 | y = adapter_rnn(y, None, mask=pooled_mask)
558 |
559 | y = tf.keras.layers.Dense(K.int_shape(y)[-1]//2, activation=None, name=f'{task}_dense')(y)
560 | y = tf.keras.layers.BatchNormalization(name=f'{task}_dense_batchnorm')(y)
561 | y = tf.keras.layers.Activation('relu', name=f'{task}_dense_relu')(y)
562 | y = tf.keras.layers.Dropout(dropout)(y)
563 | if classes == 2:
564 | y = tf.keras.layers.Dense(1, activation='sigmoid', name=name)(y)
565 | else:
566 | y = tf.keras.layers.Dense(
567 | classes, activation='softmax', name=name)(y)
568 | model = tf.keras.Model(inputs=input_tensor, outputs=y)
569 | outputs.append(y)
570 | task_models[task] = model
571 |
572 | shared_model = tf.keras.Model(inputs=input_tensor, outputs=outputs)
573 |
574 | if initial_weights is not None:
575 | shared_model, task_models = load_and_assert_loaded(shared_model, task_models, initial_weights)
576 |
577 | return task_models, shared_model
578 |
579 |
580 | def create_multi_task_rnn(input_dim,
581 | num_mels=128,
582 | num_mfccs=40,
583 | hidden_dim=512,
584 | cell='lstm',
585 | number_of_layers=2,
586 | down_pool=8,
587 | mode='adapters',
588 | bidirectional=False,
589 | learnall=True,
590 | dropout=0.2,
591 | base_tasks=['EMO-DB', 'GEMEP'],
592 | new_tasks=[],
593 | base_nb_classes=[6, 10],
594 | new_nb_classes=None,
595 | initial_weights=None,
596 | random_noise=0.1,
597 | input_bn=False,
598 | share_feature_layer=False,
599 | use_attention=True,
600 | share_attention=False,
601 | input_projection=True):
602 | channel_axis = -1
603 | if base_tasks is None:
604 | assert initial_weights is not None, f'Either base tasks or initial weights have to be specified!'
605 |
606 | logger.info(
607 | 'Trying to determine trained tasks from initial weights...')
608 | base_tasks, base_nb_classes = infer_tasks_from_weightfile(
609 | initial_weights)
610 |
611 | input_tensor = tf.keras.layers.Input(shape=input_dim)
612 | input_reshaped = tf.keras.layers.Reshape(target_shape=(-1, ))(input_tensor)
613 |
614 | x, mask = input_features_and_mask(input_reshaped,
615 | num_fft=1024,
616 | hop_length=512,
617 | num_mfccs=num_mfccs,
618 | sample_rate=16000,
619 | f_min=20,
620 | f_max=8000,
621 | num_mels=num_mels,
622 | eps=1e-6,
623 | return_decibel=num_mels is None)
624 | adapter_rnn = RNNWithAdapters(K.int_shape(x),
625 | hidden_size=hidden_dim,
626 | learnall=learnall,
627 | dropout=dropout,
628 | input_projection_factor=1,
629 | adapter_projection_factor=4,
630 | bidirectional=bidirectional,
631 | layers=number_of_layers,
632 | recurrent_cell=cell,
633 | input_bn=input_bn,
634 | downpool=down_pool,
635 | input_projection=input_projection,
636 | use_attention=use_attention,
637 | share_attention=share_attention,
638 | # tasks=base_tasks,
639 | tasks=base_tasks if mode == 'adapters' else [],
640 | share_feature_layer=share_feature_layer)
641 | if down_pool is not None:
642 | mask = PoolMask((down_pool,))(mask)
643 |
644 |
645 | if new_tasks is not None:
646 | really_new_tasks = [t for t in new_tasks if t not in base_tasks]
647 | new_nb_classes = [
648 | c for c, t in zip(new_nb_classes, new_tasks) if t not in base_tasks
649 | ]
650 | else:
651 | really_new_tasks = []
652 | new_nb_classes = []
653 |
654 | task_models = {}
655 | outputs = []
656 | for task, classes in zip(base_tasks, base_nb_classes):
657 | logger.info(f'Building model for {task} with {classes} classes...')
658 | #y = adapter_resnet(x, task)
659 | adapters_in = task if mode == 'adapters' else None
660 | y = adapter_rnn(x, adapters_in, mask=mask)
661 | if initial_weights is not None: # might need new last dense layer
662 | name = f'{task}_1'
663 | else:
664 | name = task
665 | y = tf.keras.layers.Dense(K.int_shape(y)[-1]//2, activation=None, name=f'{task}_dense')(y)
666 | y = tf.keras.layers.BatchNormalization(name=f'{task}_dense_batchnorm')(y)
667 | y = tf.keras.layers.Activation('relu', name=f'{task}_dense_relu')(y)
668 | y = tf.keras.layers.Dropout(0.2)(y)
669 | if classes == 2:
670 | y = tf.keras.layers.Dense(1, activation='sigmoid', name=name)(y)
671 | else:
672 | y = tf.keras.layers.Dense(
673 | classes, activation='softmax', name=name)(y)
674 | model = tf.keras.Model(inputs=input_tensor, outputs=y)
675 | outputs.append(y)
676 | task_models[task] = model
677 |
678 | if really_new_tasks is not None and new_nb_classes is not None:
679 | for task, classes in zip(really_new_tasks, new_nb_classes):
680 | logger.info(f'Building model for {task} with {classes} classes...')
681 | adapter_rnn._add_new_task(task)
682 | y = adapter_rnn(x, task, mask)
683 | # y = apply_zero_mask([pooled_mask,
684 | # y]) # zero out silence activations
685 | y = tf.keras.layers.Dense(K.int_shape(y)[-1]//2, activation=None, name=f'{task}_dense')(y)
686 | y = tf.keras.layers.BatchNormalization(name=f'{task}_dense_batchnorm')(y)
687 | y = tf.keras.layers.Activation('relu', name=f'{task}_dense_relu')(y)
688 | y = tf.keras.layers.Dropout(0.2)(y)
689 | if classes == 2:
690 | y = tf.keras.layers.Dense(
691 | 1, activation='sigmoid', name=task)(y)
692 | else:
693 | y = tf.keras.layers.Dense(
694 | classes, activation='softmax', name=task)(y)
695 | model = tf.keras.Model(inputs=input_tensor, outputs=y)
696 | outputs.append(y)
697 | task_models[task] = model
698 | shared_model = tf.keras.Model(inputs=input_tensor, outputs=outputs)
699 |
700 | if initial_weights is not None:
701 | shared_model, task_models = load_and_assert_loaded(shared_model, task_models, initial_weights)
702 | return task_models, shared_model
703 |
704 |
705 | def load_and_assert_loaded(shared_model, task_models, initial_weights):
706 | preloaded_layers = shared_model.layers.copy()
707 | shared_weights_pre_load, shared_names_pre_load = [], []
708 | for layer in preloaded_layers:
709 | if not layer.trainable:
710 | logger.debug(f'Appending weights of {layer.name} with shape before load.')
711 | shared_names_pre_load.append(layer.name)
712 | shared_weights_pre_load.append(layer.get_weights())
713 | logger.info('Loading weights from pre-trained model...')
714 | shared_model.load_weights(initial_weights, by_name=True)
715 | logger.info('Finished.')
716 |
717 | shared_weights_post_load = []
718 | shared_names_post_load = []
719 | for layer in shared_model.layers:
720 | if not layer.trainable:
721 | logger.debug(f'Appending weights of {layer.name} after load.')
722 | shared_names_post_load.append(layer.name)
723 | shared_weights_post_load.append(layer.get_weights())
724 | shared_weights_single_task = []
725 | shared_names_single_task = []
726 | single_task_model = task_models[list(
727 | task_models.keys())[0]]
728 | for layer in single_task_model.layers:
729 | if not layer.trainable:
730 | logger.debug(
731 | f'Appending weights of {layer.name} of single task model after load.')
732 | shared_names_single_task.append(layer.name)
733 | shared_weights_single_task.append(layer.get_weights())
734 | loaded, not_loaded, errors = 0, 0, 0
735 | assert shared_names_pre_load == shared_names_post_load and shared_names_post_load == shared_names_single_task, f'Layer name mistmatch: {shared_names_pre_load, shared_names_post_load, shared_names_single_task}.'
736 | for pre, pre_n, post, post_n, single, single_n in zip(shared_weights_pre_load, shared_names_pre_load, shared_weights_post_load, shared_names_post_load, shared_weights_single_task, shared_names_single_task):
737 | if array_list_equal(post, pre):
738 | not_loaded += 1
739 | logger.debug(
740 | f'Not loaded weights for layer {pre_n}: Total not loaded: {not_loaded}')
741 | elif array_list_equal(single, pre):
742 | not_loaded += 1
743 | logger.debug(
744 | f'Not loaded weights for layer {pre_n}: Total not loaded: {not_loaded}')
745 | elif array_list_equal(post, single):
746 | loaded += 1
747 | logger.debug(
748 | f'Loaded weights for layer {pre_n}. Total loaded: {loaded}')
749 | else:
750 | errors += 1
751 | logger.debug(
752 | f'Something went wrong with {pre_n, post_n, single_n}: {errors}')
753 | logger.info(
754 | f'Weights for {loaded} layers have been loaded from pre-trained model {initial_weights}.')
755 | return shared_model, task_models
756 |
--------------------------------------------------------------------------------
/emo-net/models/input_layers.py:
--------------------------------------------------------------------------------
1 | # EmoNet
2 | # ==============================================================================
3 | # Copyright (C) 2021 Maurice Gerczuk, Shahin Amiriparian,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see .
18 | # ==============================================================================
19 | import tensorflow as tf
20 | import tensorflow.keras.backend as K
21 |
22 | """ https://gist.github.com/keunwoochoi/f4854acb68acf791a49a051893bcd23b """
23 | class LogMelgramLayer(tf.keras.layers.Layer):
24 | def __init__(
25 | self, num_fft, hop_length, num_mels, sample_rate, f_min=80, f_max=7600, eps=1e-6, return_decibel=True, top_db=80, mask_zero=True, **kwargs
26 | ):
27 | super(LogMelgramLayer, self).__init__(**kwargs)
28 | self.num_fft = num_fft
29 | self.hop_length = hop_length
30 | self.num_mels = num_mels
31 | self.sample_rate = sample_rate
32 | self.f_min = f_min
33 | self.f_max = f_max
34 | self.eps = eps
35 | self.return_decibel = return_decibel
36 | self.num_freqs = num_fft // 2 + 1
37 | self.mask_zero = mask_zero
38 | self.top_db = top_db
39 |
40 | lin_to_mel_matrix = tf.signal.linear_to_mel_weight_matrix(
41 | num_mel_bins=self.num_mels,
42 | num_spectrogram_bins=self.num_freqs,
43 | sample_rate=self.sample_rate,
44 | lower_edge_hertz=self.f_min,
45 | upper_edge_hertz=self.f_max,
46 | )
47 |
48 | self.lin_to_mel_matrix = lin_to_mel_matrix
49 |
50 | def build(self, input_shape):
51 | self.non_trainable_weights.append(self.lin_to_mel_matrix)
52 | super(LogMelgramLayer, self).build(input_shape)
53 |
54 | def call(self, input):
55 | """
56 | Args:
57 | input (tensor): Batch of mono waveform, shape: (None, N)
58 | Returns:
59 | log_melgrams (tensor): Batch of log mel-spectrograms, shape: (None, num_frame, mel_bins, channel=1)
60 | """
61 | def _tf_log10(x):
62 | numerator = tf.math.log(x)
63 | denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype))
64 | return numerator / denominator
65 |
66 | stfts = tf.signal.stft(
67 | input,
68 | frame_length=self.num_fft,
69 | frame_step=self.hop_length,
70 | pad_end=False, # librosa test compatibility
71 | )
72 | mag_stfts = tf.abs(stfts)
73 |
74 | melgrams = tf.tensordot(
75 | tf.square(mag_stfts), self.lin_to_mel_matrix, 1
76 | )
77 | melgrams.set_shape(mag_stfts.shape[:-1].concatenate(self.lin_to_mel_matrix.shape[-1:]))
78 |
79 | if self.return_decibel:
80 | log_melgrams = 10 * _tf_log10((melgrams + self.eps) / tf.reduce_max(melgrams))
81 | if self.top_db is not None:
82 | if self.top_db < 0:
83 | raise ParameterError('top_db must be non-negative')
84 | log_melgrams = tf.math.maximum(log_melgrams, tf.reduce_max(log_melgrams) - self.top_db)
85 | #log_melgrams = (log_melgrams + self.top_db) / self.top_db
86 |
87 | else:
88 | log_melgrams = tf.math.log(melgrams + self.eps)
89 | return log_melgrams
90 |
91 |
92 |
93 | def get_config(self):
94 | config = {
95 | 'num_fft': self.num_fft,
96 | 'hop_length': self.hop_length,
97 | 'num_mels': self.num_mels,
98 | 'sample_rate': self.sample_rate,
99 | 'f_min': self.f_min,
100 | 'f_max': self.f_max,
101 | 'eps': self.eps,
102 | 'return_decibel': self.return_decibel,
103 | 'mask_zero': self.mask_zero,
104 | 'top_db': self.top_db
105 | }
106 | base_config = super(LogMelgramLayer, self).get_config()
107 | return dict(list(config.items()) + list(base_config.items()))
108 |
109 |
110 | class MFCCLayer(tf.keras.layers.Layer):
111 | def __init__(
112 | self, num_mfccs=50, **kwargs
113 | ):
114 | super(MFCCLayer, self).__init__(**kwargs)
115 | self.num_mfccs = num_mfccs
116 |
117 |
118 | def build(self, input_shape):
119 | super(MFCCLayer, self).build(input_shape)
120 |
121 | def call(self, input):
122 | """
123 | Args:
124 | input (tensor): Batch of log mel-spectrograms, shape: (None, num_frame, mel_bins, channel=1)
125 | Returns:
126 | mfccs (tensor): Batch of mfccs, shape: (None, num_frame, num_mfccs)
127 | """
128 |
129 | log_mel_spectrograms = input
130 |
131 | mfccs = tf.signal.mfccs_from_log_mel_spectrograms(
132 | log_mel_spectrograms)[..., :self.num_mfccs]
133 | return mfccs
134 |
135 | def get_config(self):
136 | config = {
137 | 'num_mfccs': self.num_mfccs
138 |
139 | }
140 | base_config = super(MFCCLayer, self).get_config()
141 | return dict(list(config.items()) + list(base_config.items()))
142 |
143 | class ComputeMask(tf.keras.layers.Layer):
144 | def __init__(self, num_fft, hop_length, **kwargs):
145 | super(ComputeMask, self).__init__(**kwargs)
146 | self.num_fft = num_fft
147 | self.hop_length = hop_length
148 |
149 | def call(self, x):
150 | frames = tf.signal.frame(x, self.num_fft, self.hop_length, pad_end=False,
151 | axis=-1,
152 | name=None)
153 | non_zeros = tf.math.count_nonzero(frames, axis=-1)
154 | mask = tf.not_equal(non_zeros, 0)
155 | return mask
156 |
157 | def get_config(self):
158 | config = {
159 | 'num_fft': self.num_fft,
160 | 'hop_length': self.hop_length,
161 | }
162 | base_config = super(ComputeMask, self).get_config()
163 | return dict(list(config.items()) + list(base_config.items()))
164 |
165 |
166 | class PoolMask(tf.keras.layers.Layer):
167 | def __init__(self, pool_size, **kwargs):
168 | super(PoolMask, self).__init__(**kwargs)
169 | self.pool_size = pool_size
170 |
171 | def call(self, x):
172 | x = tf.expand_dims(x, -1)
173 | x = tf.cast(x, dtype='int8')
174 | x = tf.nn.pool(x, self.pool_size, pooling_type='MAX', padding='SAME', strides=self.pool_size)
175 | x = K.batch_flatten(x)
176 | x = tf.not_equal(x, 0)
177 | return x
178 |
179 | def get_config(self):
180 | config = {
181 | 'pool_size': self.pool_size,
182 | }
183 | base_config = super(PoolMask, self).get_config()
184 | return dict(list(config.items()) + list(base_config.items()))
--------------------------------------------------------------------------------
/emo-net/training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EIHW/EmoNet/e76dd53ab4c33e99182c69f6247e4d924f9a3d99/emo-net/training/__init__.py
--------------------------------------------------------------------------------
/emo-net/training/evaluate.py:
--------------------------------------------------------------------------------
1 | # EmoNet
2 | # ==============================================================================
3 | # Copyright (C) 2021 Maurice Gerczuk, Shahin Amiriparian,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see .
18 | # ==============================================================================
19 | import tensorflow as tf
20 | tf.random.set_seed(42)
21 |
22 | import time
23 | import pandas as pd
24 | from os.path import join
25 | from sklearn.utils import class_weight
26 | from .losses import *
27 | from .metrics import *
28 | from ..models.build_model import create_multi_task_resnets, create_multi_task_rnn, create_multi_task_networks
29 | from ..data.loader import *
30 | from os import makedirs
31 | import logging
32 | logger = logging.getLogger(__name__)
33 |
34 |
35 | def evaluate(
36 | initial_weights='weights.h5',
37 | feature_extractor='cnn',
38 | batch_size=64,
39 | window=5,
40 | num_mels=128,
41 | task="",
42 | directory='EmoSet/IEMOCAP',
43 | val_csv='val.csv',
44 | share_feature_layer=True,
45 | input_bn=False,
46 | mode='adapters',
47 | output='pred.csv',
48 | **kwargs):
49 | if feature_extractor in ['cnn', 'vgg16']:
50 | variable_duration = False if kwargs['classifier'] == 'avgpool' else True
51 | else:
52 | variable_duration = True
53 | #variable_duration = False
54 |
55 |
56 | val_generator = AudioDataGenerator(val_csv,
57 | directory,
58 | batch_size=batch_size,
59 | window=window,
60 | shuffle=False,
61 | sr=16000,
62 | time_stretch=None,
63 | pitch_shift=None,
64 | save_dir=None,
65 | variable_duration=variable_duration)
66 |
67 | val_dataset = val_generator.tf_dataset()
68 |
69 |
70 |
71 |
72 |
73 | x, _ = val_generator[0]
74 | if not variable_duration:
75 | init = x.shape[1:]
76 | else:
77 | init = (None, )
78 | models, shared_model = create_multi_task_networks(
79 | init,
80 | feature_extractor=feature_extractor,
81 | initial_weights=initial_weights,
82 | num_mels=num_mels,
83 | new_tasks=[],
84 | new_nb_classes=[],
85 | mode=mode,
86 | input_bn=input_bn,
87 | learnall=False,
88 | share_feature_layer=share_feature_layer,
89 | **kwargs)
90 | model = models[task]
91 | #model.load_weights(initial_weights, by_name=True)
92 |
93 | model.summary()
94 | #print(model.non_trainable_weights)
95 | #model.load_weights(initial_weights, by_name=True)
96 |
97 |
98 | metric_callback = ClassificationMetricCallback(
99 | validation_generator=val_dataset.prefetch(tf.data.experimental.AUTOTUNE), dataset_name='Test', labels=val_generator.class_indices, true=val_generator.categorical_classes)
100 |
101 |
102 |
103 | filenames = list(map(lambda x: join(*(x.split('/')[-4:])), val_generator.files))
104 | index_to_class = {v: k for k, v in val_generator.class_indices.items()}
105 |
106 | logger.info("Model loaded.")
107 | model.compile(loss='categorical_crossentropy', optimizer=tf.keras.optimizers.SGD(lr=0.1), metrics=['accuracy'])
108 | metric_callback.set_model(model)
109 | x = model.evaluate(val_generator.tf_dataset(),
110 | # use_multiprocessing=True,
111 | # max_queue_size=n_workers * 2,
112 | verbose=1,
113 | callbacks=[
114 | metric_callback
115 | ])
116 | metric_callback.on_epoch_end(epoch=1)
117 | predictions = model.predict(val_generator.tf_dataset())
118 | probas = predictions
119 | if predictions.shape[1] > 1:
120 | predictions = list(map(lambda x: index_to_class[x], np.argmax(predictions, axis=-1)))
121 | else:
122 | predictions = list(map(lambda x: index_to_class[x], np.squeeze(np.where(predictions < 0.5, 0, 1))))
123 | true = list(map(lambda x: index_to_class[x], val_generator.classes))
124 | columns = ['filename', *[f'probability_{index_to_class[i]}' for i in range(probas.shape[1])], 'pred_label', 'true_label']
125 | df = pd.DataFrame(columns=columns)
126 | df['filename'] = filenames
127 | df['pred_label'] = predictions
128 | df['true_label'] = true
129 | df[[f'probability_{index_to_class[i]}' for i in range(probas.shape[1])]] = probas
130 | df.to_csv(output, index=False)
131 |
132 |
--------------------------------------------------------------------------------
/emo-net/training/losses.py:
--------------------------------------------------------------------------------
1 | # EmoNet
2 | # ==============================================================================
3 | # Copyright (C) 2021 Maurice Gerczuk, Shahin Amiriparian,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see .
18 | # ==============================================================================
19 | """
20 | Define our custom loss function.
21 |
22 | https://github.com/umbertogriffo/focal-loss-keras/blob/master/losses.py
23 | """
24 | from tensorflow.keras import backend as K
25 | import tensorflow as tf
26 |
27 | import dill
28 |
29 |
30 | def soft_ordinal_categorical_loss(n_classes=3, metric=lambda x, y: tf.math.abs(x-y)):
31 | ranks_tensor = tf.constant(list(range(n_classes)), dtype='float32')
32 |
33 | def soft_ordinal_categorical_loss_fixed(y_true, y_pred):
34 | trues_tensor = tf.cast(tf.math.argmax(y_true, -1, output_type='int32'), 'float32')
35 | diff = metric(tf.expand_dims(trues_tensor, -1), tf.expand_dims(ranks_tensor, 0))
36 | softmax = tf.nn.softmax(-diff)
37 | return tf.keras.losses.categorical_crossentropy(softmax, y_pred)
38 |
39 | return soft_ordinal_categorical_loss_fixed
40 |
41 |
42 | def binary_focal_loss(gamma=2., alpha=.25):
43 | """
44 | Binary form of focal loss.
45 | FL(p_t) = -alpha * (1 - p_t)**gamma * log(p_t)
46 | where p = sigmoid(x), p_t = p or 1 - p depending on if the label is 1 or 0, respectively.
47 | References:
48 | https://arxiv.org/pdf/1708.02002.pdf
49 | Usage:
50 | model.compile(loss=[binary_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=adam)
51 | """
52 | def binary_focal_loss_fixed(y_true, y_pred):
53 | """
54 | :param y_true: A tensor of the same shape as `y_pred`
55 | :param y_pred: A tensor resulting from a sigmoid
56 | :return: Output tensor.
57 | """
58 | pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
59 | pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
60 |
61 | epsilon = K.epsilon()
62 | # clip to prevent NaN's and Inf's
63 | pt_1 = K.clip(pt_1, epsilon, 1. - epsilon)
64 | pt_0 = K.clip(pt_0, epsilon, 1. - epsilon)
65 |
66 | return -K.sum(alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1)) \
67 | -K.sum((1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0))
68 |
69 | return binary_focal_loss_fixed
70 |
71 |
72 | def categorical_focal_loss(gamma=2., alpha=.25):
73 | """
74 | Softmax version of focal loss.
75 | m
76 | FL = ∑ -alpha * (1 - p_o,c)^gamma * y_o,c * log(p_o,c)
77 | c=1
78 | where m = number of classes, c = class and o = observation
79 | Parameters:
80 | alpha -- the same as weighing factor in balanced cross entropy
81 | gamma -- focusing parameter for modulating factor (1-p)
82 | Default value:
83 | gamma -- 2.0 as mentioned in the paper
84 | alpha -- 0.25 as mentioned in the paper
85 | References:
86 | Official paper: https://arxiv.org/pdf/1708.02002.pdf
87 | https://www.tensorflow.org/api_docs/python/tf/keras/backend/categorical_crossentropy
88 | Usage:
89 | model.compile(loss=[categorical_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=adam)
90 | """
91 | def categorical_focal_loss_fixed(y_true, y_pred):
92 | """
93 | :param y_true: A tensor of the same shape as `y_pred`
94 | :param y_pred: A tensor resulting from a softmax
95 | :return: Output tensor.
96 | """
97 |
98 | # Scale predictions so that the class probas of each sample sum to 1
99 | y_pred /= K.sum(y_pred, axis=-1, keepdims=True)
100 |
101 | # Clip the prediction value to prevent NaN's and Inf's
102 | epsilon = K.epsilon()
103 | y_pred = K.clip(y_pred, epsilon, 1. - epsilon)
104 |
105 | # Calculate Cross Entropy
106 | cross_entropy = -y_true * K.log(y_pred)
107 |
108 | # Calculate Focal Loss
109 | loss = alpha * K.pow(1 - y_pred, gamma) * cross_entropy
110 |
111 | # Sum the losses in mini_batch
112 | return K.sum(loss, axis=1)
113 |
114 | return categorical_focal_loss_fixed
115 |
116 |
117 | if __name__ == '__main__':
118 |
119 | # Test serialization of nested functions
120 | bin_inner = dill.loads(dill.dumps(binary_focal_loss(gamma=2., alpha=.25)))
121 | print(bin_inner)
122 |
123 | cat_inner = dill.loads(dill.dumps(categorical_focal_loss(gamma=2., alpha=.25)))
124 | print(cat_inner)
--------------------------------------------------------------------------------
/emo-net/training/metrics.py:
--------------------------------------------------------------------------------
1 | # EmoNet
2 | # ==============================================================================
3 | # Copyright (C) 2021 Maurice Gerczuk, Shahin Amiriparian,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see .
18 | # ==============================================================================
19 | def warn(*args, **kwargs):
20 | pass
21 | import warnings
22 | warnings.warn = warn
23 |
24 | import numpy as np
25 | import tensorflow as tf
26 | from abc import ABC, abstractmethod
27 | from dataclasses import dataclass, field
28 | from scipy.stats import shapiro, pearsonr
29 | from sklearn.metrics import recall_score, make_scorer, accuracy_score, f1_score, mean_squared_error, classification_report, confusion_matrix, multilabel_confusion_matrix, precision_score, roc_auc_score, average_precision_score, roc_curve
30 | from sklearn.metrics.scorer import _BaseScorer
31 | from statistics import pstdev, mean
32 | from typing import Dict, List, ClassVar, Set
33 | from math import sqrt
34 | from tqdm import tqdm
35 |
36 | import logging
37 | logger = logging.getLogger(__name__)
38 |
39 |
40 | def mask_metric(func):
41 | def mask_metric_function(*args, **kwargs):
42 | mask = np.not_equal(kwargs['y_true'], -1).astype(float)
43 | kwargs['y_true'] = (kwargs['y_true'] * mask)
44 | kwargs['y_pred'] = (kwargs['y_pred'] * mask)
45 | return func(*args, **kwargs)
46 |
47 | return mask_metric_function
48 |
49 |
50 | def optimal_threshold(fpr, tpr, thresholds):
51 | optimal_idx = np.argmax(tpr - fpr)
52 | optimal_threshold = thresholds[optimal_idx]
53 | return optimal_threshold
54 |
55 |
56 | def compute_binary_cutoffs(y_true, y_pred):
57 | if y_true.shape == y_pred.shape and len(y_true.shape) == 1: # 2 classes
58 | fpr, tpr, thresholds = roc_curve(y_true, y_pred)
59 | return [optimal_threshold(fpr, tpr, thresholds)]
60 | elif y_true.shape == y_pred.shape and len(y_true.shape) == 2: # multilabel
61 | fpr_tpr_thresholds = [
62 | roc_curve(y_true[:, i], y_pred[:, i])
63 | for i in range(y_true.shape[1])
64 | ]
65 | return [optimal_threshold(*x) for x in fpr_tpr_thresholds]
66 |
67 |
68 | class ClassificationMetricCallback(tf.keras.callbacks.Callback):
69 | def __init__(self,
70 | labels: List = None,
71 | validation_generator=None,
72 | validation_data=None,
73 | multi_label=False,
74 | partition='validation',
75 | true=None,
76 | period=1,
77 | dataset_name='default'):
78 | super().__init__()
79 | if labels is not None:
80 | self.labels = {name: index for index, name in enumerate(labels)}
81 | self.binary = (len(labels) == 2)
82 |
83 | elif validation_generator is not None:
84 | self.labels = validation_generator.class_indices
85 | self.binary = len(self.labels) == 2
86 |
87 | # if true is not None:
88 | # self.y_val = np.squeeze(true)
89 |
90 |
91 | self.validation_generator = validation_generator
92 | self.validation_data = validation_data
93 | if isinstance(self.validation_generator, tf.data.Dataset):
94 | self.y_val = []
95 | for features, labels in self.validation_generator.take(-1): # only take first element of dataset
96 | labels_numpy = labels.numpy()
97 | if labels_numpy.shape[-1] == 1:
98 | labels_numpy = np.squeeze(labels_numpy, axis=-1)
99 | self.y_val.append(labels_numpy)
100 | self.y_val = np.concatenate(self.y_val)
101 | else:
102 | self.y_val = np.squeeze(self.validation_generator.categorical_classes)
103 | self.multi_label = multi_label
104 | self.partition = partition
105 | self.keras_metric_quantities = KERAS_METRIC_QUANTITIES
106 | self.dataset_name = dataset_name
107 |
108 | self._binary_cutoffs = []
109 | self._data = []
110 | self.period = period
111 |
112 | def on_train_begin(self, logs={}):
113 | pass
114 |
115 | def on_epoch_end(self, epoch, logs={}):
116 | if epoch % self.period == 0:
117 | if self.validation_generator is None:
118 | X_val, y_val = self.validation_data[0], self.validation_data[1]
119 | y_pred = np.asarray(self.model.predict(X_val))
120 | else:
121 | y_pred = np.squeeze(self.model.predict(self.validation_generator))
122 |
123 | logs = self.compute_metrics(self.y_val,
124 | y_pred,
125 | multi_label=self.multi_label,
126 | binary=self.binary,
127 | labels=sorted(
128 | self.labels.values()),
129 | prefix=f'{self.partition}',
130 | logs=logs,
131 | target_names=sorted(self.labels.keys()))
132 |
133 | return
134 |
135 | def get_data(self):
136 | return self._data
137 |
138 | def compute_metrics(self,
139 | y_val,
140 | y_pred,
141 | multi_label=False,
142 | binary=False,
143 | labels=None,
144 | prefix='',
145 | logs={},
146 | target_names=None):
147 | eval_string = f'\nEvaluation results for partition {self.partition} of dataset {self.dataset_name}:\n'
148 | all_classes_present = np.all(np.any(y_val > 0, axis=0))
149 | if multi_label:
150 | binary_cutoffs = compute_binary_cutoffs(y_true=y_val,
151 | y_pred=y_pred)
152 | self._binary_cutoffs.append(binary_cutoffs)
153 | logger.info(f'Optimal cutoffs: {binary_cutoffs}')
154 | else:
155 | binary_cutoffs = None
156 | y_val_t, y_pred_t = ClassificationMetric._transform_arrays(
157 | y_true=y_val,
158 | y_pred=y_pred,
159 | multi_label=multi_label,
160 | binary=binary,
161 | binary_cutoffs=binary_cutoffs)
162 | eval_string += classification_report(y_val_t,
163 | y_pred_t,
164 | target_names=target_names)
165 | if self.multi_label:
166 | eval_string += '\n'+ str(multilabel_confusion_matrix(y_true=y_val_t,
167 | y_pred=y_pred_t,
168 | labels=labels))
169 | else:
170 | conf_matrix = confusion_matrix(y_true=np.argmax(y_val_t, axis=1) if
171 | len(y_val_t.shape) > 1 else y_val_t,
172 | y_pred=np.argmax(y_pred_t, axis=1) if
173 | len(y_pred_t.shape) > 1 else y_pred_t,
174 | labels=labels)
175 | eval_string += '\n'+ str(conf_matrix)
176 | logs[f'{prefix}confusion_matrix'] = conf_matrix
177 | for i, cm in enumerate(CLASSIFICATION_METRICS):
178 | if all_classes_present or not (cm == ROC_AUC or cm == PR_AUC):
179 | if cm.needs_categorical:
180 | metric = cm.compute(y_true=y_val_t,
181 | y_pred=y_pred_t,
182 | labels=labels,
183 | binary=binary,
184 | multi_label=multi_label,
185 | binary_cutoffs=binary_cutoffs)
186 | else:
187 | metric = cm.compute(y_true=y_val,
188 | y_pred=y_pred,
189 | labels=labels,
190 | binary=binary,
191 | multi_label=multi_label,
192 | binary_cutoffs=binary_cutoffs)
193 | metric_value = metric.value
194 | eval_string += f'\n{prefix} {cm.description}: {metric_value}'
195 | if not self._data: # first recorded value
196 | self._data.append({
197 | f'{self.keras_metric_quantities[cm]}/{prefix}':
198 | metric_value,
199 | })
200 | elif i == 0 and self._data and f'{self.keras_metric_quantities[cm]}/{prefix}' in self._data[
201 | -1].keys():
202 | self._data.append({
203 | f'{self.keras_metric_quantities[cm]}/{prefix}':
204 | metric_value,
205 | })
206 | else:
207 | self._data[-1][
208 | f'{self.keras_metric_quantities[cm]}/{prefix}'] = metric_value
209 | if len(
210 | self._data
211 | ) > 1: # this is the second epoch and metrics have been recorded for the first epoch
212 | cur_best = self._data[-2][
213 | f'{self.keras_metric_quantities[cm]}_best/{prefix}']
214 | else: # this is the first epoch
215 | cur_best = metric_value
216 |
217 | new_best = metric_value if metric > cm(
218 | value=cur_best) else cur_best
219 |
220 | self._data[-1][
221 | f'{self.keras_metric_quantities[cm]}_best/{prefix}'] = new_best
222 |
223 | logs[f'{self.keras_metric_quantities[cm]}/{prefix}'] = metric_value
224 | logs[f'{self.keras_metric_quantities[cm]}_best/{prefix}'] = new_best
225 | else:
226 | logger.info(
227 | f'Not all classes occur in the validation data, skipping ROC AUC and PR AUC.'
228 | )
229 | logger.info(eval_string)
230 | return logs
231 |
232 |
233 | class RegressionMetricCallback(tf.keras.callbacks.Callback):
234 | def __init__(self, validation_data=()):
235 | super().__init__()
236 | self.validation_data = validation_data
237 |
238 | def on_train_begin(self, logs={}):
239 | self._data = []
240 |
241 | def on_epoch_end(self, batch, logs={}):
242 | X_val, y_val = self.validation_data[0], self.validation_data[1]
243 | y_predict = np.asarray(self.model.predict(X_val))
244 |
245 | for metric in REGRESSION_METRICS:
246 | metric_value = metric.compute(y_true=y_val, y_pred=y_predict).value
247 | self._data.append({f'val_{metric.__name__.lower()}': metric_value})
248 | logs[f'val_{metric.__name__.lower()}'] = metric_value
249 | return
250 |
251 | def get_data(self):
252 | return self._data
253 |
254 |
255 | @dataclass(order=True)
256 | class Metric(ABC):
257 | sort_index: float = field(init=False, repr=False)
258 | description: ClassVar[str] = 'Metric'
259 | key: ClassVar[str] = 'M'
260 | value: float
261 | scikit_scorer: ClassVar[_BaseScorer] = field(init=False, repr=False)
262 | greater_is_better: ClassVar[bool] = True
263 |
264 | def __post_init__(self):
265 | self.sort_index = self.value if self.greater_is_better else -self.value
266 |
267 |
268 | @dataclass(order=True)
269 | class ClassificationMetric(Metric, ABC):
270 | multi_label: bool = False
271 | binary: bool = False
272 | average: ClassVar[str] = None
273 | needs_categorical: ClassVar[bool] = True
274 |
275 | @classmethod
276 | @mask_metric
277 | @abstractmethod
278 | def compute(cls,
279 | y_true: np.array,
280 | y_pred: np.array,
281 | labels: List,
282 | multi_label: bool,
283 | binary: bool,
284 | binary_cutoffs: List[float] = None) -> Metric:
285 | pass
286 |
287 | @staticmethod
288 | def _transform_arrays(y_true: np.array,
289 | y_pred: np.array,
290 | multi_label: bool,
291 | binary: bool,
292 | binary_cutoffs: List[float] = None
293 | ) -> (np.array, np.array):
294 | if binary:
295 | if len(y_pred.shape) > 1:
296 | y_pred = np.reshape(y_pred, -1)
297 | if len(y_true.shape) > 1:
298 | y_true = np.reshape(y_true, -1)
299 | assert (
300 | y_true.shape == y_pred.shape and len(y_true.shape) == 1
301 | ), f'Shapes of predictions and labels for binary classification should conform to (n_samples,) but received {y_pred.shape} and {y_true.shape}.'
302 | #if binary_cutoffs is None:
303 | binary_cutoffs = 0.5
304 | #y_pred_transformed = np.zeros_like(y_pred, dtype=int)
305 | #y_pred_transformed[y_pred > binary_cutoffs[0]] = 1
306 | y_pred_transformed = np.where(y_pred > binary_cutoffs, 1, 0)
307 | y_true_transformed = y_true
308 |
309 | elif multi_label:
310 | assert (
311 | y_true.shape == y_pred.shape
312 | ), f'Shapes of predictions and labels for multilabel classification should conform to (n_samples, n_classes) but received {y_pred.shape} and {y_true.shape}.'
313 | if binary_cutoffs is None:
314 | binary_cutoffs = compute_binary_cutoffs(y_true, y_pred)
315 | # y_pred_transformed = np.zeros_like(y_pred, dtype=int)
316 | # y_pred_transformed[y_pred > 0.5] = 1
317 | y_pred_transformed = np.where(y_pred > binary_cutoffs, 1, 0)
318 | y_true_transformed = y_true
319 | else:
320 | if y_true.shape[1] > 1:
321 | y_true_transformed = np.zeros_like(y_true)
322 | y_true_transformed[range(len(y_true)), y_true.argmax(1)] = 1
323 | if y_pred.shape[1] > 1:
324 | y_pred_transformed = np.zeros_like(y_pred)
325 | y_pred_transformed[range(len(y_pred)), y_pred.argmax(1)] = 1
326 | assert (
327 | y_true.shape == y_pred.shape
328 | ), f'Shapes of predictions and labels for multiclass classification should conform to (n_samples,n_classes) but received {y_pred.shape} and {y_true.shape}.'
329 | return y_true_transformed, y_pred_transformed
330 |
331 |
332 | @dataclass(order=True)
333 | class RegressionMetric(Metric, ABC):
334 | @staticmethod
335 | @abstractmethod
336 | def compute(y_true: np.array, y_pred: np.array) -> Metric:
337 | pass
338 |
339 |
340 | @dataclass(order=True)
341 | class MicroRecall(ClassificationMetric):
342 | description: ClassVar[str] = 'Micro Average Recall'
343 | average: ClassVar[str] = 'micro'
344 | key: ClassVar[str] = 'Recall/Micro'
345 |
346 | scikit_scorer: ClassVar[_BaseScorer] = make_scorer(recall_score,
347 | average='micro')
348 | greater_is_better: ClassVar[bool] = True
349 |
350 | @classmethod
351 | @mask_metric
352 | def compute(cls,
353 | y_true: np.array,
354 | y_pred: np.array,
355 | labels: List,
356 | multi_label: bool,
357 | binary: bool,
358 | binary_cutoffs: List[float] = None) -> ClassificationMetric:
359 | score = recall_score(y_true=y_true,
360 | y_pred=y_pred,
361 | labels=labels,
362 | average=cls.average)
363 | return cls(value=score, multi_label=multi_label, binary=binary)
364 |
365 |
366 | @dataclass(order=True)
367 | class UAR(MicroRecall):
368 | average: ClassVar[str] = 'macro'
369 | description: ClassVar[str] = 'Unweighted Average Recall'
370 | key: ClassVar[str] = 'Recall/Macro'
371 | scikit_scorer: ClassVar[_BaseScorer] = make_scorer(recall_score,
372 | average='macro')
373 | greater_is_better: ClassVar[bool] = True
374 |
375 |
376 | @dataclass(order=True)
377 | class Accuracy(ClassificationMetric):
378 | description: ClassVar[str] = 'Accuracy'
379 | key: ClassVar[str] = 'acc'
380 |
381 | scikit_scorer: ClassVar[_BaseScorer] = make_scorer(accuracy_score)
382 | greater_is_better: ClassVar[bool] = True
383 |
384 | @classmethod
385 | @mask_metric
386 | def compute(cls,
387 | y_true: np.array,
388 | y_pred: np.array,
389 | labels: List,
390 | multi_label: bool,
391 | binary: bool,
392 | binary_cutoffs: List[float] = None) -> ClassificationMetric:
393 | score = accuracy_score(y_true=y_true, y_pred=y_pred)
394 | return cls(value=score, multi_label=multi_label, binary=binary)
395 |
396 |
397 | @dataclass(order=True)
398 | class MacroF1(ClassificationMetric):
399 | average: ClassVar[str] = 'macro'
400 | description: ClassVar[str] = 'Macro Average F1'
401 | key: ClassVar[str] = 'F1/Macro'
402 | scikit_scorer: ClassVar[_BaseScorer] = make_scorer(f1_score,
403 | average='macro')
404 | greater_is_better: ClassVar[bool] = True
405 |
406 | @classmethod
407 | @mask_metric
408 | def compute(cls,
409 | y_true: np.array,
410 | y_pred: np.array,
411 | labels: List,
412 | multi_label: bool,
413 | binary: bool,
414 | binary_cutoffs: List[float] = None) -> ClassificationMetric:
415 | score = f1_score(y_true=y_true,
416 | y_pred=y_pred,
417 | labels=labels,
418 | average=cls.average)
419 | return cls(value=score, multi_label=multi_label, binary=binary)
420 |
421 |
422 | @dataclass(order=True)
423 | class MicroF1(MacroF1):
424 | average: ClassVar[str] = 'micro'
425 | description: ClassVar[str] = 'Micro Average F1'
426 | key: ClassVar[str] = 'F1/Micro'
427 | scikit_scorer: ClassVar[_BaseScorer] = make_scorer(f1_score,
428 | average='micro')
429 | greater_is_better: ClassVar[bool] = True
430 |
431 |
432 | @dataclass(order=True)
433 | class MacroPrecision(ClassificationMetric):
434 | average: ClassVar[str] = 'macro'
435 | description: ClassVar[str] = 'Macro Average Precision'
436 | key: ClassVar[str] = 'Prec/Macro'
437 |
438 | scikit_scorer: ClassVar[_BaseScorer] = make_scorer(precision_score,
439 | average='macro')
440 | greater_is_better: ClassVar[bool] = True
441 |
442 | @classmethod
443 | @mask_metric
444 | def compute(cls,
445 | y_true: np.array,
446 | y_pred: np.array,
447 | labels: List,
448 | multi_label: bool,
449 | binary: bool,
450 | binary_cutoffs: List[float] = None) -> ClassificationMetric:
451 | score = precision_score(y_true=y_true,
452 | y_pred=y_pred,
453 | labels=labels,
454 | average=cls.average)
455 | return cls(value=score, multi_label=multi_label, binary=binary)
456 |
457 |
458 | @dataclass(order=True)
459 | class MicroPrecision(MacroPrecision):
460 | average: ClassVar[str] = 'micro'
461 | description: ClassVar[str] = 'Micro Average Prec'
462 | key: ClassVar[str] = 'Prec/Micro'
463 | scikit_scorer: ClassVar[_BaseScorer] = make_scorer(precision_score,
464 | average='micro')
465 | greater_is_better: ClassVar[bool] = True
466 |
467 |
468 | @dataclass(order=True)
469 | class ROC_AUC(ClassificationMetric):
470 | average: ClassVar[str] = 'macro'
471 | description: ClassVar[
472 | str] = 'Area Under the Receiver Operating Characteristic Curve'
473 | key: ClassVar[str] = 'ROC AUC'
474 | scikit_scorer: ClassVar[_BaseScorer] = make_scorer(roc_auc_score,
475 | average='macro')
476 | greater_is_better: ClassVar[bool] = True
477 | needs_categorical: ClassVar[bool] = False
478 |
479 | @classmethod
480 | @mask_metric
481 | def compute(cls,
482 | y_true: np.array,
483 | y_pred: np.array,
484 | labels: List,
485 | multi_label: bool,
486 | binary: bool,
487 | binary_cutoffs: List[float] = None) -> ClassificationMetric:
488 | score = roc_auc_score(y_true=y_true,
489 | y_score=y_pred,
490 | average=cls.average)
491 | return cls(value=score, multi_label=multi_label, binary=binary)
492 |
493 |
494 | @dataclass(order=True)
495 | class PR_AUC(ClassificationMetric):
496 | average: ClassVar[str] = 'macro'
497 | description: ClassVar[str] = 'Area Under the Precision Recall Curve'
498 | key: ClassVar[str] = 'PR AUC'
499 | scikit_scorer: ClassVar[_BaseScorer] = make_scorer(average_precision_score,
500 | average='macro')
501 | greater_is_better: ClassVar[bool] = True
502 | needs_categorical: ClassVar[bool] = False
503 |
504 |
505 | @classmethod
506 | @mask_metric
507 | def compute(cls,
508 | y_true: np.array,
509 | y_pred: np.array,
510 | labels: List,
511 | multi_label: bool,
512 | binary: bool,
513 | binary_cutoffs: List[float] = None) -> ClassificationMetric:
514 | score = average_precision_score(y_true=y_true,
515 | y_score=y_pred,
516 | average=cls.average)
517 | return cls(value=score, multi_label=multi_label, binary=binary)
518 |
519 |
520 | @dataclass(order=True)
521 | class MSE(RegressionMetric):
522 | description: ClassVar[str] = 'Mean Squared Error'
523 | key: ClassVar[str] = 'mse'
524 | scikit_scorer: ClassVar[_BaseScorer] = make_scorer(mean_squared_error,
525 | greater_is_better=False)
526 | greater_is_better: ClassVar[bool] = False
527 |
528 | @staticmethod
529 | def compute(y_true: np.array, y_pred: np.array) -> RegressionMetric:
530 | score = mean_squared_error(y_true=y_true, y_pred=y_pred)
531 | return MSE(value=score)
532 |
533 |
534 | def pearson_correlation_coefficient(y_true, y_pred):
535 | return pearsonr(y_true, y_pred)[0]
536 |
537 |
538 | @dataclass(order=True)
539 | class PCC(RegressionMetric):
540 | description: ClassVar[str] = 'Pearson\'s Correlation Coeffiecient'
541 | key: ClassVar[str] = 'pcc'
542 | scikit_scorer: ClassVar[_BaseScorer] = make_scorer(
543 | pearson_correlation_coefficient, greater_is_better=True)
544 | greater_is_better: ClassVar[bool] = True
545 |
546 | @staticmethod
547 | def compute(y_true: np.array, y_pred: np.array) -> RegressionMetric:
548 | score = pearson_correlation_coefficient(y_true=y_true, y_pred=y_pred)
549 | return PCC(value=score)
550 |
551 |
552 | def concordance_correlation_coefficient(y_true, y_pred):
553 | ccc = 2 * pearson_correlation_coefficient(y_true=y_true, y_pred=y_pred) / (
554 | np.var(y_true) + np.var(y_pred) +
555 | (np.mean(y_true) - np.mean(y_pred))**2)
556 | return ccc
557 |
558 |
559 | @dataclass(order=True)
560 | class CCC(RegressionMetric):
561 | description: ClassVar[str] = 'Concordance Correlation Coeffiecient'
562 | key: ClassVar[str] = 'ccc'
563 | scikit_scorer: ClassVar[_BaseScorer] = make_scorer(
564 | pearson_correlation_coefficient, greater_is_better=True)
565 | greater_is_better: ClassVar[bool] = True
566 |
567 | @staticmethod
568 | def compute(y_true: np.array, y_pred: np.array) -> RegressionMetric:
569 | score = concordance_correlation_coefficient(y_true=y_true,
570 | y_pred=y_pred)
571 | return CCC(value=score)
572 |
573 |
574 | def root_mean_squared_error(y_true, y_pred):
575 | return sqrt(mean_squared_error(y_true=y_true, y_pred=y_pred))
576 |
577 |
578 | @dataclass(order=True)
579 | class RMSE(RegressionMetric):
580 | description: ClassVar[str] = 'Root Mean Squared Error'
581 | key: ClassVar[str] = 'rmse'
582 |
583 | scikit_scorer: ClassVar[_BaseScorer] = make_scorer(root_mean_squared_error,
584 | greater_is_better=False)
585 | greater_is_better: ClassVar[bool] = False
586 |
587 | @staticmethod
588 | def compute(y_true: np.array, y_pred: np.array) -> RegressionMetric:
589 | score = root_mean_squared_error(y_true=y_true, y_pred=y_pred)
590 | return RMSE(value=score)
591 |
592 |
593 | @dataclass
594 | class MetricStats():
595 | mean: float
596 | standard_deviation: float
597 | normality_tests: Dict[str, tuple]
598 |
599 |
600 | def compute_metric_stats(metrics: List[Metric]) -> MetricStats:
601 | metric_values = [metric.value for metric in metrics]
602 | normality_tests = dict()
603 | if len(metric_values) > 2:
604 | normality_tests['Shapiro-Wilk'] = shapiro(metric_values)
605 | return MetricStats(mean=mean(metric_values),
606 | standard_deviation=pstdev(metric_values),
607 | normality_tests=normality_tests)
608 |
609 |
610 | def all_subclasses(cls):
611 | return set(cls.__subclasses__()).union(
612 | [s for c in cls.__subclasses__() for s in all_subclasses(c)])
613 |
614 |
615 | CLASSIFICATION_METRICS = all_subclasses(ClassificationMetric)
616 | REGRESSION_METRICS = all_subclasses(RegressionMetric)
617 | ALL_METRICS = all_subclasses(Metric)
618 |
619 | SCIKIT_CLASSIFICATION_SCORERS = {
620 | M.__name__: M.scikit_scorer
621 | for M in CLASSIFICATION_METRICS if M != ROC_AUC and M != PR_AUC
622 | }
623 |
624 | SCIKIT_CLASSIFICATION_SCORERS_EXTENDED = {
625 | M.__name__: M.scikit_scorer
626 | for M in CLASSIFICATION_METRICS
627 | }
628 |
629 | SCIKIT_REGRESSION_SCORERS = {
630 | M.__name__: M.scikit_scorer
631 | for M in REGRESSION_METRICS
632 | }
633 |
634 | KERAS_METRIC_QUANTITIES = {
635 | M: f'val_{"_".join(M.key.lower().split(" "))}'
636 | for M in ALL_METRICS
637 | }
638 |
639 | KERAS_METRIC_MODES = {
640 | M: 'max' if M.greater_is_better else 'min'
641 | for M in ALL_METRICS
642 | }
643 |
644 | KEY_TO_METRIC = {metric.__name__: metric for metric in ALL_METRICS}
645 |
--------------------------------------------------------------------------------
/emo-net/training/train.py:
--------------------------------------------------------------------------------
1 | # EmoNet
2 | # ==============================================================================
3 | # Copyright (C) 2021 Maurice Gerczuk, Shahin Amiriparian,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see .
18 | # ==============================================================================
19 | import tensorflow as tf
20 | tf.random.set_seed(42)
21 |
22 | import time
23 | from os.path import join
24 | from sklearn.utils import class_weight
25 | from .losses import *
26 | from .metrics import *
27 | from ..models.build_model import create_multi_task_resnets, create_multi_task_rnn, create_multi_task_networks
28 | from ..data.loader import *
29 | from os import makedirs
30 | import logging
31 | logger = logging.getLogger(__name__)
32 |
33 |
34 |
35 | categorical_loss_map = {'crossentropy': "categorical_crossentropy", "focal": categorical_focal_loss(
36 | ), "ordinal": soft_ordinal_categorical_loss}
37 |
38 | binary_loss_map = {'crossentropy': "binary_crossentropy", "focal": binary_focal_loss(
39 | ), "ordinal": "binary_crossentropy"}
40 |
41 | def determine_decay(generator, batch_size):
42 | if 10000 > len(generator) * batch_size >= 1000:
43 | decay = 0.0005
44 | elif 1000 > len(generator) * batch_size >= 500:
45 | decay = 0.002
46 | elif len(generator) * batch_size < 500:
47 | decay = 0.005
48 | else:
49 | decay = 1e-6
50 | return decay
51 |
52 |
53 | def named_logs(model, logs):
54 | result = {}
55 | for l in zip(model.metrics_names, logs):
56 | result[l[0]] = l[1]
57 | return result
58 |
59 | def __feature_extractor_params_string(feature_extractor, **kwargs):
60 | if feature_extractor == 'cnn':
61 | return __cnn_params(**kwargs)
62 | elif feature_extractor == 'rnn':
63 | return __rnn_params(**kwargs)
64 | elif feature_extractor == 'vgg16':
65 | return __vgg16_params(**kwargs)
66 | elif feature_extractor == 'fusion':
67 | return __fusion_params(**kwargs)
68 |
69 | def __cnn_params(classifier, N, factor, dropout1, dropout2, rnn_dropout, filters, learnall_classifier):
70 | return f'filters-{filters}-N-{N}_factor-{factor}-do1-{dropout1}-do2-{dropout2}-classifier-{classifier}-learnall_classifier-{learnall_classifier}{"-rd-"+str(rnn_dropout) if classifier == "rnn" else ""}'
71 |
72 | def __fusion_params(N, factor, dropout1, dropout2, rnn_dropout, filters, hidden_dim, cell, bidirectional, number_of_layers, down_pool):
73 | return f'filters-{filters}-N-{N}_factor-{factor}-do1-{dropout1}-do2-{dropout2}-cell-{cell}-bidirectional-{bidirectional}-hidden_dim-{number_of_layers}x{hidden_dim}-rd{rnn_dropout}-downpool-{down_pool}'
74 |
75 |
76 | def __rnn_params(hidden_dim, cell, bidirectional, dropout, number_of_layers, down_pool, num_mfccs, use_attention, share_attention, input_projection):
77 | return f'cell-{cell}-bidirectional-{bidirectional}-hidden_dim-{number_of_layers}x{hidden_dim}-do-{dropout}-downpool-{down_pool}-mfccs-{num_mfccs}-attention-{use_attention}-shareAttention-{share_attention}-ip-{input_projection}'
78 |
79 | def __vgg16_params(freeze_up_to, classifier, dropout):
80 | return f'freezeUpTo-{freeze_up_to}-classifier-{classifier}-dropout-{dropout}'
81 |
82 |
83 | def train_single_task(
84 | initial_weights='/mnt/student/MauriceGerczuk/EmoSet/experiments/residual-adapters-emonet-revised/parallel/128mels/2s/scratch/N-2_factor-1-balancedClassWeights-True/GEMEP/weights_GEMEP.h5',
85 | feature_extractor='cnn',
86 | batch_size=64,
87 | epochs=50,
88 | balanced_weights=True,
89 | window=6,
90 | num_mels=128,
91 | task='IEMOCAP-4cl',
92 | loss='categorical_cross_entropy',
93 | directory='/mnt/nas/data_work/shahin/EmoSet/wavs-reordered/',
94 | train_csv='train.csv',
95 | val_csv='val.csv',
96 | test_csv='test.csv',
97 | experiment_base_path='/mnt/student/MauriceGerczuk/EmoSet/experiments/residual-adapters-emonet-revised',
98 | random_noise=None,
99 | learnall=False,
100 | last_layer_only=False,
101 | initial_learning_rate=0.1,
102 | optimizer=tf.keras.optimizers.SGD,
103 | n_workers=5,
104 | patience=20,
105 | mode='adapters',
106 | input_bn=False,
107 | share_feature_layer=False,
108 | individual_weight_decay=False,
109 | **kwargs):
110 | if feature_extractor in ['cnn', 'vgg16']:
111 | variable_duration = False if kwargs['classifier'] == 'avgpool' else True
112 | else:
113 | variable_duration = True
114 | #variable_duration = False
115 | base_tasks = None
116 | base_nb_classes = None
117 | feature_extractor_params = __feature_extractor_params_string(feature_extractor, **kwargs)
118 | training_params = f'balancedClassWeights-{balanced_weights}-loss-{loss}-optimizer-{optimizer.__name__}-lr-{initial_learning_rate}-bs-{batch_size}-patience-{patience}-random_noise-{random_noise}-numMels-{num_mels}-ib-{input_bn}-sfl-{share_feature_layer}-iwd-{individual_weight_decay}'
119 | experiment_base_path = f"{join(experiment_base_path, 'single-task', feature_extractor, mode, f'Window-{window}s', feature_extractor_params, training_params)}"
120 |
121 | train_generator = AudioDataGenerator(train_csv,
122 | directory,
123 | batch_size=batch_size,
124 | window=window,
125 | shuffle=True,
126 | sr=16000,
127 | time_stretch=None,
128 | pitch_shift=None,
129 | save_dir=None,
130 | val_split=None,
131 | subset='train',
132 | variable_duration=variable_duration)
133 | val_generator = AudioDataGenerator(val_csv,
134 | directory,
135 | batch_size=batch_size,
136 | window=window,
137 | shuffle=False,
138 | sr=16000,
139 | time_stretch=None,
140 | pitch_shift=None,
141 | save_dir=None,
142 | variable_duration=variable_duration)
143 | test_generator = AudioDataGenerator(test_csv,
144 | directory,
145 | batch_size=batch_size,
146 | window=window,
147 | shuffle=False,
148 | sr=16000,
149 | time_stretch=None,
150 | pitch_shift=None,
151 | save_dir=None,
152 | variable_duration=variable_duration)
153 | val_dataset = val_generator.tf_dataset()
154 | test_dataset = test_generator.tf_dataset()
155 |
156 | decay = determine_decay(train_generator, batch_size)
157 |
158 | if initial_weights is not None and mode == 'adapters':
159 | new_tasks = [task]
160 | new_nb_classes = [len(set(train_generator.classes))]
161 | new_weight_decays = [decay] if individual_weight_decay else None
162 | base_weight_decays = None
163 |
164 | else:
165 | base_tasks = [task]
166 | base_nb_classes = [len(set(train_generator.classes))]
167 | base_weight_decays = [decay] if individual_weight_decay else None
168 | new_tasks = []
169 | new_nb_classes = []
170 | new_weight_decays = None
171 |
172 | if balanced_weights:
173 | class_weights = class_weight.compute_class_weight(
174 | 'balanced', np.unique(train_generator.classes),
175 | train_generator.classes)
176 | class_weight_dict = dict(enumerate(class_weights))
177 | logger.info(f'Class weights: {class_weight_dict}')
178 | else:
179 | class_weight_dict = None
180 | logger.info('Not using class weights.')
181 |
182 | task_base_path = join(experiment_base_path, task)
183 | weights = join(task_base_path, "weights_" + task + ".h5")
184 |
185 |
186 | x, _ = train_generator[0]
187 | if not variable_duration:
188 | init = x.shape[1:]
189 | else:
190 | init = (None, )
191 | models, shared_model = create_multi_task_networks(
192 | init,
193 | feature_extractor=feature_extractor,
194 | initial_weights=initial_weights,
195 | num_mels=num_mels,
196 | mode=mode,
197 | base_nb_classes=base_nb_classes,
198 | base_weight_decays=base_weight_decays,
199 | learnall=learnall,
200 | base_tasks=base_tasks,
201 | new_tasks=new_tasks,
202 | new_nb_classes=new_nb_classes,
203 | new_weight_decays=new_weight_decays,
204 | random_noise=random_noise,
205 | input_bn=input_bn,
206 | share_feature_layer=share_feature_layer,
207 | **kwargs)
208 | model = models[task]
209 | #model.load_weights(initial_weights, by_name=True)
210 | if last_layer_only:
211 | for layer in model.layers[:-5]:
212 | layer.trainable = False
213 | model.summary()
214 | #print(model.non_trainable_weights)
215 | #model.load_weights(initial_weights, by_name=True)
216 |
217 | tbCallBack = tf.keras.callbacks.TensorBoard(log_dir=join(task_base_path, 'log'),
218 | histogram_freq=0,
219 | write_graph=True)
220 | #hpCallback = hp.KerasCallback(join(task_base_path, 'log', 'hparam_tuning'), hparams)
221 | mc = tf.keras.callbacks.ModelCheckpoint(weights,
222 | monitor='val_recall/macro/validation',
223 | verbose=1,
224 | save_best_only=True,
225 | save_weights_only=False,
226 | mode='max',
227 | period=1)
228 | metric_callback = ClassificationMetricCallback(
229 | validation_generator=val_dataset.prefetch(tf.data.experimental.AUTOTUNE), dataset_name=task, labels=val_generator.class_indices, true=val_generator.categorical_classes)
230 | metric_callback_test = ClassificationMetricCallback(
231 | validation_generator=test_dataset.prefetch(tf.data.experimental.AUTOTUNE), dataset_name=task, partition='test', labels=test_generator.class_indices, true=test_generator.categorical_classes)
232 |
233 | lrs = [
234 | initial_learning_rate, initial_learning_rate * 0.1,
235 | initial_learning_rate * 0.01
236 | ]
237 | makedirs(task_base_path, exist_ok=True)
238 | stopped_epoch = 0
239 | best = 0
240 | patience = patience
241 | load = False
242 | loss_string = loss
243 | if len(set(train_generator.classes)) == 2:
244 | loss = binary_loss_map[loss_string]
245 | else:
246 | loss = categorical_loss_map[loss_string]
247 | if loss_string == 'ordinal':
248 | loss = loss(n_classes=len(set(train_generator.classes)))
249 | for i, lr in enumerate(lrs):
250 | if optimizer.__name__ == tf.keras.optimizers.SGD.__name__:
251 | opt = optimizer(learning_rate=lr, decay=decay if not individual_weight_decay else 1e-6, momentum=0.9, nesterov=False)
252 | else:
253 | opt = optimizer(learning_rate=lr)
254 | model.compile(loss=loss, optimizer=opt, metrics=["acc"], experimental_run_tf_function=False)
255 |
256 | if load:
257 | model.load_weights(weights)
258 | logger.info("Model loaded.")
259 | early_stopper = tf.keras.callbacks.EarlyStopping(monitor='val_recall/macro/validation',
260 | min_delta=0.005,
261 | patience=patience,
262 | verbose=1,
263 | mode='max',
264 | restore_best_weights=False,
265 | baseline=best)
266 | model.fit(train_generator.tf_dataset().prefetch(tf.data.experimental.AUTOTUNE),
267 | validation_data=val_generator.tf_dataset(),
268 | epochs=epochs,
269 | workers=n_workers // 2,
270 | initial_epoch=stopped_epoch,
271 | class_weight=class_weight_dict,
272 | # use_multiprocessing=True,
273 | # max_queue_size=n_workers * 2,
274 | verbose=2,
275 | callbacks=[
276 | metric_callback, metric_callback_test,
277 | early_stopper, tbCallBack, mc
278 | ])
279 | load = True
280 | stopped_epoch = early_stopper.stopped_epoch
281 | best = early_stopper.best
282 |
283 |
284 | def train_multi_task(
285 | batch_size=64,
286 | epochs=50,
287 | balanced_weights=True,
288 | feature_extractor='cnn',
289 | window=2,
290 | num_mels=128,
291 | mode='adapters',
292 | initial_learning_rate=0.1,
293 | tasks=[
294 | "AirplaneBehaviourCorpus", "AngerDetection", "BurmeseEmotionalSpeech",
295 | "CASIA", "ChineseVocalEmotions", "DanishEmotionalSpeech", "DEMoS",
296 | "EA-ACT", "EA-BMW", "EA-WSJ", "EMO-DB", "EmoFilm", "EmotiW-2014",
297 | "ENTERFACE", "EU-EmoSS", "FAU_AIBO", "GEMEP", "GVESS",
298 | "MandarinEmotionalSpeech", "MELD", "PPMK-EMO", "SIMIS", "SMARTKOM",
299 | "SUSAS", "TurkishEmoBUEE"
300 | ],
301 | loss='crossentropy',
302 | directory='/mnt/nas/data_work/shahin/EmoSet/wavs-reordered/',
303 | experiment_base_path='/mnt/student/MauriceGerczuk/EmoSet/experiments/residual-adapters-emonet-revised',
304 | multi_task_setup='/mnt/student/MauriceGerczuk/EmoSet/multiTaskSetup-wavs-with-test/',
305 | steps_per_epoch=20,
306 | optimizer=tf.keras.optimizers.SGD,
307 | random_noise=None,
308 | input_bn=False,
309 | share_feature_layer=False,
310 | individual_weight_decay=False,
311 | **kwargs):
312 | if feature_extractor == 'cnn':
313 | variable_duration = False if kwargs['classifier'] == 'avgpool' else True
314 | else:
315 | variable_duration = True
316 | feature_extractor_params = __feature_extractor_params_string(feature_extractor, **kwargs)
317 | training_params = f'balancedClassWeights-{balanced_weights}-loss-{loss}-optimizer-{optimizer.__name__}-lr-{initial_learning_rate}-bs-{batch_size}-epochs-{epochs}-spe-{steps_per_epoch}-random_noise-{random_noise}-numMels-{num_mels}-ib-{input_bn}-sfl-{share_feature_layer}-iwd-{individual_weight_decay}'
318 | experiment_base_path = f"{join(experiment_base_path, 'multi-task', '-'.join(map(lambda x: x[:4], tasks)), feature_extractor, mode, f'Window-{window}s', feature_extractor_params, training_params)}"
319 |
320 |
321 | train_generators = [
322 | AudioDataGenerator(f'{multi_task_setup}/{task}/train.csv',
323 | directory,
324 | batch_size=batch_size,
325 | window=window,
326 | shuffle=True,
327 | sr=16000,
328 | time_stretch=None,
329 | pitch_shift=None,
330 | variable_duration=variable_duration,
331 | save_dir=None,
332 | val_split=None,
333 | subset='train') for task in tasks
334 | ]
335 | val_generators = [
336 | AudioDataGenerator(f'{multi_task_setup}/{task}/val.csv',
337 | directory,
338 | batch_size=batch_size,
339 | window=window,
340 | shuffle=False,
341 | sr=16000,
342 | time_stretch=None,
343 | variable_duration=variable_duration,
344 | pitch_shift=None,
345 | save_dir=None) for task in tasks
346 | ]
347 | test_generators = [
348 | AudioDataGenerator(f'{multi_task_setup}/{task}/test.csv',
349 | directory,
350 | batch_size=batch_size,
351 | window=window,
352 | shuffle=False,
353 | sr=16000,
354 | time_stretch=None,
355 | variable_duration=variable_duration,
356 | pitch_shift=None,
357 | save_dir=None) for task in tasks
358 | ]
359 |
360 | train_datasets = tuple(gen.tf_dataset().repeat() for gen in train_generators)
361 | val_datasets = tuple(gen.tf_dataset() for gen in val_generators)
362 | test_datasets = tuple(gen.tf_dataset() for gen in test_generators)
363 |
364 |
365 |
366 | if balanced_weights:
367 | class_weights = [
368 | class_weight.compute_class_weight('balanced', np.unique(t.classes),
369 | t.classes)
370 | for t in train_generators
371 | ]
372 | class_weight_dicts = [dict(enumerate(cw)) for cw in class_weights]
373 | logger.info(f'Class weights: {class_weight_dicts}')
374 | else:
375 | class_weight_dicts = [None] * len(tasks)
376 | logger.info('Not using class weights.')
377 |
378 | task_base_paths = [join(experiment_base_path, task) for task in tasks]
379 | weight_paths = [
380 | join(task_base_path, "weights_" + task + ".h5")
381 | for task_base_path, task in zip(task_base_paths, tasks)
382 | ]
383 |
384 | tbCallBacks = [
385 | tf.keras.callbacks.TensorBoard(log_dir=join(task_base_path, 'log'),
386 | histogram_freq=0,
387 | write_graph=True) for task_base_path in task_base_paths
388 | ]
389 |
390 | metric_callbacks = [
391 | ClassificationMetricCallback(validation_generator=val_dataset,
392 | period=1, dataset_name=task, labels=val_generator.class_indices)
393 | for val_dataset, val_generator, task in zip(val_datasets, val_generators, tasks)
394 | ]
395 | metric_callbacks_test = [
396 | ClassificationMetricCallback(validation_generator=test_dataset,
397 | partition='test',
398 | period=1,
399 | dataset_name=task,
400 | labels=test_generator.class_indices)
401 | for test_dataset, test_generator, task in zip(test_datasets, test_generators, tasks)
402 | ]
403 | decays = [determine_decay(tg, batch_size) for tg in train_generators]
404 |
405 | #steps_per_epoch = 10
406 | x, _ = train_generators[0][0]
407 | if not variable_duration:
408 | init = x.shape[1:]
409 | else:
410 | init = (None, )
411 |
412 | lrs = [
413 | initial_learning_rate, initial_learning_rate * 0.1,
414 | initial_learning_rate * 0.01
415 | ]
416 | nb_classes = [len(tg.class_indices) for tg in train_generators]
417 | models, shared_model = create_multi_task_networks(
418 | init,
419 | feature_extractor=feature_extractor,
420 | mode=mode,
421 | num_mels=num_mels,
422 | base_nb_classes=nb_classes,
423 | learnall=True,
424 | base_tasks=tasks,
425 | base_weight_decays=decays if individual_weight_decay else None,
426 | random_noise=random_noise,
427 | input_bn=input_bn,
428 | share_feature_layer=share_feature_layer,
429 | **kwargs)
430 | shared_model.summary()
431 | for i, t in enumerate(tasks):
432 | tbCallBacks[i].set_model(models[t])
433 | metric_callbacks[i].set_model(models[t])
434 | metric_callbacks_test[i].set_model(models[t])
435 |
436 | max_steps= epochs * steps_per_epoch
437 | loss_string = loss
438 | for step, lr in enumerate(lrs):
439 | for i, batch in tqdm(tf.data.Dataset.zip(train_datasets).enumerate().prefetch(1), total=max_steps):
440 | if i >= max_steps:
441 | break
442 | for t, task in enumerate(tasks):
443 | model = models[task]
444 | if i == 0: # reset learning rate
445 | if len(set(train_generators[t].classes)) == 2:
446 | loss = binary_loss_map[loss_string]
447 | else:
448 | loss = categorical_loss_map[loss_string]
449 | if loss_string == 'ordinal':
450 | loss = loss(n_classes=len(set(train_generators[t].classes)))
451 | if optimizer.__name__ == tf.keras.optimizers.SGD.__name__:
452 | opt = optimizer(lr=lr,
453 | decay=decays[t] if not individual_weight_decay else 1e-6,
454 | momentum=0.9,
455 | nesterov=False)
456 | else:
457 | opt = optimizer(lr=lr)
458 | model.compile(loss=loss, optimizer=opt, metrics=["acc"])
459 | # if i % len(train_generators[t]) == 0:
460 | # train_generators[t].on_epoch_end()
461 | logs = model.train_on_batch(*batch[t])
462 |
463 | named_l = named_logs(model, logs)
464 | # loss = named_l["loss"]
465 | # logger.info(f'Step {i}: loss {loss} ({task})')
466 | logger.debug(f'i % steps_per_epoch: {i%steps_per_epoch}')
467 | if i % steps_per_epoch == 0:
468 | logger.debug('In epoch end')
469 | metric_callbacks[t].on_epoch_end(
470 | i // steps_per_epoch + step * epochs, named_l)
471 | metric_callbacks_test[t].on_epoch_end(
472 | i // steps_per_epoch + step * epochs, named_l)
473 | model.save(weight_paths[t])
474 | tbCallBacks[t].on_epoch_end(
475 | i // steps_per_epoch + step * epochs, named_l)
476 | shared_model.save(join(experiment_base_path, 'shared_model.h5'))
477 |
--------------------------------------------------------------------------------
/emo-net/utils.py:
--------------------------------------------------------------------------------
1 | # EmoNet
2 | # ==============================================================================
3 | # Copyright (C) 2021 Maurice Gerczuk, Shahin Amiriparian,
4 | # Sandra Ottl, Björn Schuller: University of Augsburg. All Rights Reserved.
5 | #
6 | # This program is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU General Public License
17 | # along with this program. If not, see .
18 | # ==============================================================================
19 | import numpy as np
20 |
21 | def array_list_equal(a_list, b_list):
22 | if type(a_list) == list and type(b_list) == list:
23 | if len(a_list) != len(b_list):
24 | return False
25 | else:
26 | for a, b in zip(a_list, b_list):
27 | if not np.array_equal(a,b):
28 | return False
29 | return True
30 | elif type(a_list) == np.array and type(b_list) == np.array:
31 | return np.array_equal(a_list, b_list)
32 | else:
33 | return False
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | Click
2 | dill
3 | imbalanced-learn
4 | librosa
5 | numpy
6 | pandas
7 | Pillow
8 | numba==0.48.*
9 | scikit-learn==0.22
10 | tensorflow==2.2.*
11 | tqdm
--------------------------------------------------------------------------------