├── .gitignore
├── LICENSE
├── README.md
├── egs
│   ├── daps
│   │   ├── config.json
│   │   └── run.py
│   ├── edinburgh_tts
│   │   ├── config.json
│   │   └── run.py
│   └── wsj0-2mix
│       ├── chimera
│       │   ├── evaluate.py
│       │   ├── msa
│       │   │   ├── RESULT
│       │   │   ├── config.json
│       │   │   └── run.py
│       │   └── psa
│       │       ├── RESULT
│       │       ├── config.json
│       │       └── run.py
│       ├── deep_clustering
│       │   ├── RESULT
│       │   ├── config.json
│       │   ├── evaluate.py
│       │   └── run.py
│       ├── phase-net
│       │   ├── config.json
│       │   └── run.py
│       └── tasnet
│           ├── conv-tasnet
│           │   ├── config.json
│           │   └── run.py
│           ├── evaluate.py
│           └── lstm-tasnet
│               ├── config.json
│               └── run.py
└── onssen
    ├── __init__.py
    ├── data
    │   ├── __init__.py
    │   ├── daps_enhance.py
    │   ├── edinburgh_tts.py
    │   ├── feature_utils.py
    │   └── wsj0_2mix.py
    ├── evaluate
    │   ├── __init__.py
    │   └── sdr.py
    ├── loss
    │   ├── __init__.py
    │   ├── loss_chimera.py
    │   ├── loss_dc.py
    │   ├── loss_e2e.py
    │   ├── loss_mask.py
    │   ├── loss_phase.py
    │   └── loss_util.py
    ├── nn
    │   ├── __init__.py
    │   ├── chimera.py
    │   ├── deep_clustering.py
    │   ├── enhancement.py
    │   ├── phase_network.py
    │   ├── tasnet.py
    │   └── uPIT-LSTM.py
    └── utils
        ├── __init__.py
        ├── basic.py
        ├── test.py
        └── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc.
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 | <one line to give the program's name and a brief idea of what it does.>
635 | Copyright (C) <year>  <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | <program>  Copyright (C) <year>  <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ONSSEN: An Open-source Speech Separation and Enhancement Library
2 | ======
3 | Onssen, pronounced as おんせん(温泉, Japanese hot spring), is a PyTorch-based library for speech separation, speech enhancement, or speech style transformation.
4 |
5 | Development plan:
6 | ------
7 | * [ ] Provide template classes for data, model, and evaluation
8 | * [ ] Move models to separate folders (i.e. Kaldi style)
9 | * [ ] Reproduce scores and upload pretrained models
10 | * [ ] Finish inference method for online separation
11 |
12 | 2020-04-20 Updates:
13 | -----
14 | + Add evaluation method for deep clustering
15 | + Use W_{MR} weight in deep clustering
16 | + Minor changes
17 |
18 |
19 | Supported Models
20 | ------
21 |
22 | + Deep Clustering
23 | + Chimera Net
24 | + Chimera++
25 | + Phase Estimation Network
26 | + Speech Enhancement with Restoration Layers
27 |
28 |
29 | Supported Datasets
30 | ------
31 |
32 | + Wsj0-2mix (http://www.merl.com/demos/deep-clustering)
33 | + Daps (https://archive.org/details/daps_dataset)
34 | + Edinburgh-TTS (https://datashare.is.ed.ac.uk/handle/10283/2791)
35 |
36 | Requirements
37 | ------
38 | + PyTorch
39 | + LibRosa
40 | + NumPy (the example scripts additionally import attrdict and scikit-learn)
41 |
42 | Usage
43 | ------
44 | You can use an existing config JSON file as-is, or customize it, to train an enhancement or separation model. For example, to train deep clustering, run the following
45 | under the egs/wsj0-2mix/deep_clustering/ directory:
46 | ```
47 | python run.py -c config.json
48 | ```
49 |
50 |
51 | Citing
52 | ------
53 |
54 | If you use onssen in your research, please cite one of the following BibTeX entries:
55 |
56 | @article{ni2019onssen,
57 | title={Onssen: an open-source speech separation and enhancement library},
58 | author={Ni, Zhaoheng and Mandel, Michael I},
59 | journal={arXiv preprint arXiv:1911.00982},
60 | year={2019}
61 | }
62 |
63 | @Misc{onssen,
64 | author = {Zhaoheng Ni and Michael Mandel},
65 | title = "ONSSEN: An Open-source Speech Separation and Enhancement Library",
66 | howpublished = {\url{https://github.com/speechLabBcCuny/onssen}},
67 | year = {2019}
68 | }
69 |
--------------------------------------------------------------------------------
/egs/daps/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_name": "enhance",
3 | "dataset": "daps",
4 | "feature_options": {
5 | "data_path": "/home/data/daps",
6 | "batch_size": 8,
7 | "frame_length": 400,
8 | "sampling_rate": 44100,
9 | "window_size": 2048,
10 | "hop_size": 441
11 | },
12 | "optimizer_options": {
13 | "name": "adam",
14 | "lr": 0.001
15 | },
16 | "model_options": {
17 | "input_dim": 1025,
18 | "output_dim": 1025,
19 | "hidden_dim": 300,
20 | "num_layers": 3,
21 | "dropout": 0.3
22 | },
23 | "device": "cpu",
24 | "num_speaker": 2,
25 | "num_epoch": 200,
26 | "resume_from_checkpoint": "False",
27 | "checkpoint_path": "./checkpoint/"
28 | }
29 |
--------------------------------------------------------------------------------
/egs/daps/run.py:
--------------------------------------------------------------------------------
import argparse
import json

import torch
from attrdict import AttrDict

from onssen import data, loss, nn, utils


def main():
    """Train and then evaluate a speech-enhancement model on DAPS.

    The JSON config passed with ``-c``/``--config`` is wrapped in an
    ``AttrDict``. In addition to the keys in ``config.json`` it must define
    ``train_num_batch`` and ``validate_num_batch`` (the legacy misspelling
    ``vaildate_num_batch`` is still accepted). They are forwarded as the
    ``num_batch`` argument of ``data.daps_enhance_dataloader``.

    Raises:
        KeyError: if either batch-count key is missing from the config.
    """
    parser = argparse.ArgumentParser(description='Parse the config path')
    # required=True: a missing -c used to leave path=None and fail in open().
    parser.add_argument("-c", "--config", dest="path", required=True,
                        help='The path to the config file. e.g. python run.py --config dc_config.json')

    config = parser.parse_args()
    with open(config.path) as f:
        args = json.load(f)
    args = AttrDict(args)
    device = torch.device(args.device)

    train_num_batch = args.get("train_num_batch")
    valid_num_batch = args.get("validate_num_batch", args.get("vaildate_num_batch"))
    if train_num_batch is None or valid_num_batch is None:
        raise KeyError(
            "config must define 'train_num_batch' and 'validate_num_batch'")

    # NOTE(review): assumes onssen.nn exposes `enhance` accepting the
    # model_options entries as keyword arguments (the wsj0-2mix recipes build
    # their models this way) — confirm against onssen/nn/__init__.py.
    args.model = nn.enhance(**args.model_options)
    args.model.to(device)
    # Signature: daps_enhance_dataloader(num_batch, feature_options, partition, device=None)
    args.train_loader = data.daps_enhance_dataloader(
        train_num_batch, args.feature_options, 'train', device=device)
    args.valid_loader = data.daps_enhance_dataloader(
        valid_num_batch, args.feature_options, 'validation', device=device)
    args.optimizer = utils.build_optimizer(args.model.parameters(), args.optimizer_options)
    args.loss_fn = loss.loss_mask_msa
    trainer = utils.trainer(args)
    trainer.run()

    tester = utils.tester(args)
    tester.eval()


if __name__ == "__main__":
    main()
32 |
--------------------------------------------------------------------------------
/egs/edinburgh_tts/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_name": "chimera++",
3 | "dataset": "edinburgh_tts",
4 | "feature_options": {
5 | "data_path": "/home/data/Edinburgh_TTS/",
6 | "batch_size": 16,
7 | "frame_length": 400,
8 | "sampling_rate": 16000,
9 | "window_size": 1024,
10 | "hop_size": 256,
11 | "db_threshold": 40
12 | },
13 | "optimizer_options": {
14 | "name": "adam",
15 | "lr": 0.001
16 | },
17 | "model_options": {
18 | "input_dim": 513,
19 | "hidden_dim": 600,
20 | "embedding_dim": 40,
21 | "num_layers": 3,
22 | "dropout": 0.3
23 | },
24 | "device": "cpu",
25 | "num_speaker": 2,
26 | "num_epoch": 200,
27 | "resume_from_checkpoint": "False",
28 | "checkpoint_path": "./checkpoint/"
29 | }
30 |
--------------------------------------------------------------------------------
/egs/edinburgh_tts/run.py:
--------------------------------------------------------------------------------
import argparse
import json

import torch
from attrdict import AttrDict

from onssen import data, loss, nn, utils


def main():
    """Train a Chimera++ network on Edinburgh-TTS, then evaluate it.

    The JSON config passed with ``-c``/``--config`` supplies ``device``,
    ``model_name``, ``model_options``, ``feature_options`` and
    ``optimizer_options`` (see config.json next to this script).
    """
    parser = argparse.ArgumentParser(description='Parse the config path')
    # required=True: a missing -c used to leave path=None and fail in open().
    parser.add_argument("-c", "--config", dest="path", required=True,
                        help='The path to the config file. e.g. python run.py --config dc_config.json')

    config = parser.parse_args()
    with open(config.path) as f:
        args = json.load(f)
    args = AttrDict(args)
    device = torch.device(args.device)
    # Build the model the same way as the wsj0-2mix chimera recipes, which
    # unpack model_options as keyword arguments.
    # NOTE(review): this config also contains "dropout" — confirm nn.chimera accepts it.
    args.model = nn.chimera(**args.model_options)
    args.model.to(device)
    # NOTE(review): edinburgh_tts_dataloader's signature is not visible here; this
    # assumes it accepts `device` as a keyword like daps_enhance_dataloader does
    # (the old call passed an undefined `self.device` and a config key,
    # cuda_option, that config.json does not define) — confirm in
    # onssen/data/edinburgh_tts.py.
    args.train_loader = data.edinburgh_tts_dataloader(
        args.model_name, args.feature_options, 'train', device=device)
    args.valid_loader = data.edinburgh_tts_dataloader(
        args.model_name, args.feature_options, 'validation', device=device)
    args.optimizer = utils.build_optimizer(args.model.parameters(), args.optimizer_options)
    args.loss_fn = loss.loss_chimera_psa
    trainer = utils.trainer(args)
    trainer.run()

    tester = utils.tester(args)
    tester.eval()


if __name__ == "__main__":
    main()
32 |
--------------------------------------------------------------------------------
/egs/wsj0-2mix/chimera/evaluate.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../../../../../onssen/')
3 |
4 | from onssen import utils
5 | from sklearn.cluster import KMeans
6 | import librosa
7 | import numpy as np
8 | import torch
9 |
10 |
class tester_chimera(utils.tester):
    """Tester that resynthesizes Chimera-network estimates from its two mask heads."""

    # STFT hop used by the inverse transform; 64 matches "hop_size" in the
    # wsj0-2mix chimera configs. Override on a subclass or instance if needed.
    istft_hop = 64

    def get_est_sig(self, input, label, output):
        """
        Apply each speaker mask to the mixture STFT and invert it to a waveform.

        args:
            input: (feature_mix,) — not needed for mask-based reconstruction
            output: (embedding, mask_A, mask_B); mask_A / mask_B are
                batch x frame x frequency (the embedding is unused here)
            label: (stft_r_mix, stft_i_mix, sig_ref) with
                stft_r_mix / stft_i_mix: batch x frame x frequency,
                sig_ref: batch x num_spk x nsample
        return:
            sig_est: batch x num_spk x nsample (float64 tensor on self.device);
                rows for speakers beyond the two masks stay zero
            sig_ref: returned unchanged
        """
        _embedding, mask_A, mask_B = output
        stft_r_mix, stft_i_mix, sig_ref = label

        stft_mix = stft_r_mix.detach().cpu().numpy() + 1j * stft_i_mix.detach().cpu().numpy()
        # batch x 2 x frame x frequency: one mask per speaker head
        masks = np.stack(
            (mask_A.detach().cpu().numpy(), mask_B.detach().cpu().numpy()), axis=1)

        batch, num_spk, nsample = sig_ref.shape
        sig_est = np.zeros((batch, num_spk, nsample))
        # Reconstruct every utterance. The previous version used only index 0 and,
        # through broadcasting, mixed different utterances when batch == num_spk.
        for b in range(batch):
            for i in range(masks.shape[1]):
                stft_est = stft_mix[b] * masks[b, i]
                # Spectra are stored frame x frequency; librosa expects frequency x frame.
                sig_est[b, i] = librosa.istft(
                    stft_est.T, hop_length=self.istft_hop, length=nsample)
        sig_est = torch.tensor(sig_est).to(self.device)
        return sig_est, sig_ref
46 |
47 |
--------------------------------------------------------------------------------
/egs/wsj0-2mix/chimera/msa/RESULT:
--------------------------------------------------------------------------------
1 | SI-SDR: 10.33
2 |
--------------------------------------------------------------------------------
/egs/wsj0-2mix/chimera/msa/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_name": "chimera",
3 | "feature_options": {
4 | "data_path": "/home/data/wsj0-mix/2speakers/",
5 | "batch_size": 8,
6 | "frame_length": 400,
7 | "sampling_rate": 8000,
8 | "window_size": 256,
9 | "hop_size": 64,
10 | "db_threshold": 40
11 | },
12 | "optimizer_options": {
13 | "name": "adam",
14 | "lr": 0.001
15 | },
16 | "model_options": {
17 | "input_dim": 129,
18 | "hidden_dim": 300,
19 | "embedding_dim": 20,
20 | "num_layers": 4
21 | },
22 | "device": "cuda",
23 | "num_speaker": 2,
24 | "num_epoch": 200,
25 | "resume_from_checkpoint": "False",
26 | "checkpoint_path": "./checkpoint/"
27 | }
28 |
--------------------------------------------------------------------------------
/egs/wsj0-2mix/chimera/msa/run.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../../../../../onssen/')
3 | sys.path.append('../')
4 | from onssen import data, loss, nn, utils
5 | from attrdict import AttrDict
6 | import torch
7 | import json
8 | from evaluate import tester_chimera
9 |
def main():
    """Train a Chimera network with the MSA loss on wsj0-2mix, then evaluate it.

    All settings are read from ``./config.json`` in the working directory.
    """
    with open('./config.json') as config_file:
        args = AttrDict(json.load(config_file))
    device = torch.device(args.device)

    model = nn.chimera(**(args['model_options']))
    model.to(device)
    args.model = model

    # One loader per wsj0-2mix split: training, cross-validation, test.
    args.train_loader, args.valid_loader, args.test_loader = (
        data.wsj0_2mix_dataloader(args.model_name, args.feature_options, split, device)
        for split in ('tr', 'cv', 'tt')
    )
    args.optimizer = utils.build_optimizer(model.parameters(), args.optimizer_options)
    args.loss_fn = loss.loss_chimera_msa

    utils.trainer(args).run()
    tester_chimera(args).eval()


if __name__ == "__main__":
    main()
31 |
--------------------------------------------------------------------------------
/egs/wsj0-2mix/chimera/psa/RESULT:
--------------------------------------------------------------------------------
1 | SI-SDR: 10.93
2 |
--------------------------------------------------------------------------------
/egs/wsj0-2mix/chimera/psa/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_name": "chimera++",
3 | "feature_options": {
4 | "data_path": "/home/data/wsj0-mix/2speakers/",
5 | "batch_size": 8,
6 | "frame_length": 400,
7 | "sampling_rate": 8000,
8 | "window_size": 256,
9 | "hop_size": 64,
10 | "db_threshold": 40
11 | },
12 | "optimizer_options": {
13 | "name": "adam",
14 | "lr": 0.001
15 | },
16 | "model_options": {
17 | "input_dim": 129,
18 | "hidden_dim": 600,
19 | "embedding_dim": 20,
20 | "num_layers": 4
21 | },
22 | "device": "cuda",
23 | "num_speaker": 2,
24 | "num_epoch": 200,
25 | "resume_from_checkpoint": "False",
26 | "checkpoint_path": "./checkpoint/"
27 | }
28 |
--------------------------------------------------------------------------------
/egs/wsj0-2mix/chimera/psa/run.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../../../../../onssen/')
3 | sys.path.append('../')
4 | from onssen import data, loss, nn, utils
5 | from attrdict import AttrDict
6 | import torch
7 | import json
8 | from evaluate import tester_chimera
9 |
10 |
def main():
    """Train a Chimera++ network with the PSA loss on wsj0-2mix, then evaluate it.

    Settings come from ``./config.json`` in the current directory.
    """
    config_file = './config.json'
    with open(config_file) as fp:
        cfg = json.load(fp)
    args = AttrDict(cfg)
    device = torch.device(args.device)

    network = nn.chimera(**(args['model_options']))
    args.model = network
    args.model.to(device)

    splits = (('train_loader', 'tr'), ('valid_loader', 'cv'), ('test_loader', 'tt'))
    for attr_name, split in splits:
        setattr(args, attr_name,
                data.wsj0_2mix_dataloader(args.model_name, args.feature_options, split, device))

    args.optimizer = utils.build_optimizer(network.parameters(), args.optimizer_options)
    args.loss_fn = loss.loss_chimera_psa
    utils.trainer(args).run()
    tester_chimera(args).eval()


if __name__ == "__main__":
    main()
32 |
--------------------------------------------------------------------------------
/egs/wsj0-2mix/deep_clustering/RESULT:
--------------------------------------------------------------------------------
1 | SI-SDR: 8.858
2 |
--------------------------------------------------------------------------------
/egs/wsj0-2mix/deep_clustering/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_name": "dc",
3 | "feature_options": {
4 | "data_path": "/home/data/wsj0-mix/2speakers/",
5 | "batch_size": 16,
6 | "frame_length": 400,
7 | "sampling_rate": 8000,
8 | "window_size": 256,
9 | "hop_size": 64,
10 | "db_threshold": 40
11 | },
12 | "optimizer_options": {
13 | "name": "adam",
14 | "lr": 0.001
15 | },
16 | "model_options": {
17 | "input_dim": 129,
18 | "hidden_dim": 600,
19 | "embedding_dim": 20,
20 | "num_layers": 3
21 | },
22 | "device": "cuda:0",
23 | "num_speaker": 2,
24 | "num_epoch": 200,
25 | "resume_from_checkpoint": "False",
26 | "checkpoint_path": "./checkpoint/"
27 | }
28 |
--------------------------------------------------------------------------------
/egs/wsj0-2mix/deep_clustering/evaluate.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('/home/near/onssen/')
3 | from onssen import utils
4 | from sklearn.cluster import KMeans
5 | import librosa
6 | import numpy as np
7 | import torch
8 |
9 |
class tester_dc(utils.tester):
    """Tester that turns deep-clustering embeddings into separated waveforms via K-means."""

    # STFT hop used by the inverse transform; 64 matches "hop_size" in config.json.
    istft_hop = 64
    # Only T-F bins whose feature lies within db_threshold / 20 of the
    # utterance maximum are clustered; 40 matches "db_threshold" in config.json.
    db_threshold = 40

    def get_est_sig(self, input, label, output):
        """
        Cluster each utterance's embeddings into binary speaker masks and resynthesize.

        args:
            input: (feature_mix,), feature_mix: batch x frame x frequency
            output: (embedding,), embedding: batch x frame x frequency x embedding_dim
                (any layout that reshapes to that is accepted)
            label: (stft_r_mix, stft_i_mix, sig_ref) with
                stft_r_mix / stft_i_mix: batch x frame x frequency,
                sig_ref: batch x num_spk x nsample
        return:
            sig_est: batch x num_spk x nsample (float64 tensor on self.device)
            sig_ref: returned unchanged
        """
        feature_mix, = input
        embedding, = output
        stft_r_mix, stft_i_mix, sig_ref = label

        stft_mix = stft_r_mix.detach().cpu().numpy() + 1j * stft_i_mix.detach().cpu().numpy()
        feature_mix = feature_mix.detach().cpu().numpy()
        batch, frame, frequency = feature_mix.shape
        _, num_spk, nsample = sig_ref.shape
        embedding = embedding.detach().cpu().numpy().reshape(batch, frame, frequency, -1)

        sig_est = np.zeros((batch, num_spk, nsample))
        # Handle each utterance separately; the previous version reshaped the
        # whole batch into one utterance, so it only worked for batch == 1.
        for b in range(batch):
            feature = feature_mix[b]
            # Threshold is db_threshold / 20 below the maximum, so the feature is
            # presumably log10 magnitude — TODO confirm in the wsj0_2mix dataloader.
            active = feature >= np.max(feature) - self.db_threshold / 20
            assignment = KMeans(n_clusters=num_spk, random_state=0).fit_predict(
                embedding[b][active])
            # Binary masks over the active bins (two speakers assumed):
            # cluster 1 -> speaker 0, cluster 0 -> speaker 1.
            mask = np.zeros((num_spk, frame, frequency))
            mask[0, active] = assignment
            mask[1, active] = 1 - assignment
            stft_est = stft_mix[b] * mask
            for i in range(num_spk):
                # Spectra are frame x frequency; librosa expects frequency x frame.
                sig_est[b, i] = librosa.istft(
                    stft_est[i].T, hop_length=self.istft_hop, length=nsample)
        sig_est = torch.tensor(sig_est).to(self.device)
        return sig_est, sig_ref
48 |
--------------------------------------------------------------------------------
/egs/wsj0-2mix/deep_clustering/run.py:
--------------------------------------------------------------------------------
1 |
2 | import sys
3 | sys.path.append('../../../../onssen/')
4 |
5 | from onssen import data, loss, nn, utils
6 | from evaluate import tester_dc
7 | from attrdict import AttrDict
8 | import argparse
9 | import torch
10 | import json
11 |
12 |
13 | def main():
14 | parser = argparse.ArgumentParser(description='Parse the config path')
15 | parser.add_argument("-c", "--config", dest="path",
16 | help='The path to the config file. e.g. python run.py --config onfig.json')
17 |
18 | config = parser.parse_args()
19 | with open(config.path) as f:
20 | args = json.load(f)
21 | args = AttrDict(args)
22 | device = torch.device(args.device)
23 | args.model = nn.deep_clustering(**(args['model_options']))
24 | args.model.to(device)
25 | args.train_loader = data.wsj0_2mix_dataloader(args.model_name, args.feature_options, 'tr', device)
26 | args.valid_loader = data.wsj0_2mix_dataloader(args.model_name, args.feature_options, 'cv', device)
27 | args.test_loader = data.wsj0_2mix_dataloader(args.model_name, args.feature_options, 'tt', device)
28 | args.optimizer = utils.build_optimizer(args.model.parameters(), args.optimizer_options)
29 | args.loss_fn = loss.loss_dc
30 | trainer = utils.trainer(args)
31 | trainer.run()
32 |
33 | tester = tester_dc(args)
34 | tester.eval()
35 |
36 |
37 | if __name__ == "__main__":
38 | main()
39 |
--------------------------------------------------------------------------------
/egs/wsj0-2mix/phase-net/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_name": "phase",
3 | "dataset": "wsj0-2mix",
4 | "feature_options": {
5 | "data_path": "/home/data/wsj0-2mix/",
6 | "batch_size": 16,
7 | "frame_length": 400,
8 | "sampling_rate": 8000,
9 | "window_size": 256,
10 | "hop_size": 64,
11 | "db_threshold": 40
12 | },
13 | "optimizer_options": {
14 | "name": "adam",
15 | "lr": 0.001
16 | },
17 | "model_options": {
18 | "input_dim": 129,
19 | "hidden_dim": 300,
20 | "embedding_dim": 20,
21 | "num_layers": 3
22 | },
23 | "loss_option": "loss_phase",
24 | "num_speaker": 2,
25 | "num_epoch": 200,
26 | "output_path": "./output/",
27 | "cuda_option": "True"
28 | }
29 |
--------------------------------------------------------------------------------
/egs/wsj0-2mix/phase-net/run.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/speechLabBcCuny/onssen/179cff94451918601f648d17c76ed0788fc5295c/egs/wsj0-2mix/phase-net/run.py
--------------------------------------------------------------------------------
/egs/wsj0-2mix/tasnet/conv-tasnet/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_name": "conv-tasnet",
3 | "feature_options": {
4 | "data_path": "/home/data/wsj0-mix/2speakers/",
5 | "batch_size": 3,
6 | "sampling_rate": 8000,
7 | "chunk_size": 32000
8 | },
9 | "optimizer_options": {
10 | "name": "adam",
11 | "lr": 0.001
12 | },
13 | "model_options": {
14 | "N": 512,
15 | "L": 16,
16 | "B": 128,
17 | "H": 512,
18 | "P": 3,
19 | "X": 8,
20 | "R": 3,
21 | "norm": "gln",
22 | "num_spks": 2,
23 | "activate": "sigmoid",
24 | "causal": false
25 | },
26 | "device": "cuda:1",
27 | "num_epoch": 200,
28 | "resume_from_checkpoint": "False",
29 | "checkpoint_path": "./checkpoint/"
30 | }
31 |
--------------------------------------------------------------------------------
/egs/wsj0-2mix/tasnet/conv-tasnet/run.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../../../../../onssen/')
3 | sys.path.append('../')
4 | from onssen import data, loss, nn, utils
5 | from attrdict import AttrDict
6 | import torch
7 | import json
8 | from evaluate import tester_tasnet
9 |
10 |
def main():
    """Train a Conv-TasNet separator on wsj0-2mix with the SI-SNR loss, then evaluate it.

    Settings come from ``./config.json``. ``args.device`` is replaced by the
    parsed ``torch.device`` so that downstream code receives a device object.
    """
    with open('./config.json') as handle:
        args = AttrDict(json.load(handle))
    device = torch.device(args.device)
    args.device = device

    model = nn.ConvTasNet(**args["model_options"])
    model.to(device)
    args.model = model

    def make_loader(split):
        # split is one of the wsj0-2mix partitions: 'tr', 'cv' or 'tt'
        return data.wsj0_2mix_dataloader(args.model_name, args.feature_options, split, device)

    args.train_loader = make_loader('tr')
    args.valid_loader = make_loader('cv')
    args.test_loader = make_loader('tt')

    args.optimizer = utils.build_optimizer(model.parameters(), args.optimizer_options)
    args.loss_fn = loss.si_snr_loss
    trainer = utils.trainer(args)
    trainer.run()
    tester_tasnet(args).eval()


if __name__ == "__main__":
    main()
33 |
--------------------------------------------------------------------------------
/egs/wsj0-2mix/tasnet/evaluate.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../../../../../onssen/')
3 |
4 | from onssen import utils
5 | from sklearn.cluster import KMeans
6 | import librosa
7 | import numpy as np
8 | import torch
9 |
10 |
class tester_tasnet(utils.tester):
    """Tester for TasNet-style models that output one waveform tensor per speaker."""

    def get_est_sig(self, input, label, output):
        """
        Stack the per-speaker waveform estimates and trim them to the reference length.

        args:
            input: (mixture,) — not used here
            label: (sig_ref,), sig_ref: batch x num_spk x nsample
            output: sequence with (at least) num_spk tensors; each is presumably
                batch x T, or just T when the model squeezes a batch of one, with
                T >= nsample — TODO confirm against the ConvTasNet implementation
        return:
            sig_est: batch x num_spk x nsample, allocated on self.device
            sig_ref: returned unchanged
        """
        sig_ref, = label
        batch, num_spk, N = sig_ref.shape
        sig_est = torch.zeros((batch, num_spk, N), device=self.device)
        for i in range(num_spk):
            # Trim along the last (time) axis. The old `output[i][0:N]` sliced the
            # first axis, which is the batch axis whenever output[i] is 2-D.
            sig_est[:, i, :] = output[i][..., :N]
        return sig_est, sig_ref
30 |
31 |
--------------------------------------------------------------------------------
/egs/wsj0-2mix/tasnet/lstm-tasnet/config.json:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/speechLabBcCuny/onssen/179cff94451918601f648d17c76ed0788fc5295c/egs/wsj0-2mix/tasnet/lstm-tasnet/config.json
--------------------------------------------------------------------------------
/egs/wsj0-2mix/tasnet/lstm-tasnet/run.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/speechLabBcCuny/onssen/179cff94451918601f648d17c76ed0788fc5295c/egs/wsj0-2mix/tasnet/lstm-tasnet/run.py
--------------------------------------------------------------------------------
/onssen/__init__.py:
--------------------------------------------------------------------------------
1 | import onssen.data
2 | import onssen.loss
3 | import onssen.evaluate
4 | import onssen.nn
5 | import onssen.utils
6 |
--------------------------------------------------------------------------------
/onssen/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .daps_enhance import daps_enhance_dataloader
2 | from .edinburgh_tts import edinburgh_tts_dataloader
3 | from .wsj0_2mix import wsj0_2mix_dataloader
4 |
--------------------------------------------------------------------------------
/onssen/data/daps_enhance.py:
--------------------------------------------------------------------------------
1 | """
2 | We need to define a batch size for training the deep clustering model.
3 | Each batch has a shape (batch_size, 100/400, feature_dim)
4 |
5 | For STFT:
6 | 8kHz fs
7 | 32 ms window length 32*8 = 256
8 | 8 ms window shift = 64
9 |
10 | 44.1kHz fs (egs/daps/config.json)
11 | window_size 2048 samples, about 46 ms
12 | hop_size 441 samples = 10 ms
13 | """
14 |
15 | from torch.utils.data.dataset import Dataset
16 | from torch.utils.data import DataLoader
17 | from .feature_utils import *
18 | import glob,librosa, numpy as np, os, random, torch
19 |
20 | """
21 | What do we want from the DataLoader?
22 | Each batch should contain different speakers.
23 | Each sample should be a 400 * 513 tensor.
24 | We have a list of files, each containing N times 400 frames.
25 | To make full use of them, we need to generate something.
26 | We have 20 speakers; each will contain N X 400 X 513 tensors.
27 |
28 | """
29 |
def daps_enhance_dataloader(num_batch, feature_options, partition, device=None):
    """Wrap a ``daps_dataset`` for ``partition`` in a shuffling DataLoader.

    ``feature_options.batch_size`` sets the loader's batch size; every other
    argument is forwarded unchanged to ``daps_dataset``.
    """
    dataset = daps_dataset(num_batch, feature_options, partition, device=device)
    batch_size = feature_options.batch_size
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)
36 |
37 |
class daps_dataset(Dataset):
    """Stream fixed-length chunks of (noisy, clean) spectrogram pairs from DAPS.

    Files are consumed sequentially: each ``__getitem__`` call returns the
    next ``frame_length`` frames of the currently loaded file, and a new file
    (popped from a shuffled list) is loaded once fewer than ``frame_length``
    frames remain. Because of this internal state, ``index`` only selects
    which file is loaded next.
    """

    def __init__(self, num_batch, feature_options, partition, device=None):
        """
        The arguments:
            num_batch: each epoch yields num_batch * batch_size chunks
            feature_options: object exposing the feature params as
                attributes, e.g.
                "feature_options": {
                    "data_path": "/path/to/daps",
                    "batch_size": 16,
                    "frame_length": 400,
                    "sampling_rate": 8000,
                    "window_size": 256,
                    "hop_size": 64
                }
            partition: name of a list file under data_path (e.g. "train",
                "validation") containing one noisy wav path per line
            device: torch device for the returned tensors (CPU when None)
        The returns (per item):
            input: [log10-magnitude of the noisy STFT, magnitude of the noisy STFT]
            label: [magnitude of the clean STFT, cos(phase_noisy - phase_clean)]
        """
        self.sampling_rate = feature_options.sampling_rate
        self.window_size = feature_options.window_size
        self.hop_size = feature_options.hop_size
        self.frame_length = feature_options.frame_length
        self.num_batch = num_batch
        self.batch_size = feature_options.batch_size
        self.file_list = []
        self.base_path = feature_options.data_path
        self.partition = partition
        # frames still unread in the currently loaded file; 0 forces a load
        self.length_remaining = 0
        self.get_item_list()
        if device is None:
            self.device = torch.device('cpu')
        else:
            self.device = device

    def get_item_list(self):
        """(Re)load the partition's file list and shuffle it in place."""
        # the with-block guarantees the list file is closed
        with open(self.base_path + '/' + self.partition) as f:
            self.file_list = [line.replace('\n', '') for line in f]
        random.shuffle(self.file_list)

    def __getitem__(self, index):
        """Return the next ([input], [label]) chunk of frame_length frames."""
        if self.length_remaining >= self.frame_length:
            return self.cutoff_feature()

        if len(self.file_list) == 0:
            # every listed file has been consumed: start a new pass
            self.get_item_list()
        # add one more file, removing it from the list
        index = index % len(self.file_list)
        f_noisy = self.file_list.pop(index)
        base_names = os.path.basename(f_noisy).split("_")
        f_clean = self.base_path + "/clean/" + base_names[0] + "_" + base_names[1] + "_clean.wav"
        stft_noisy = get_stft(f_noisy, self.sampling_rate, self.window_size, self.hop_size)
        stft_clean = get_stft(f_clean, self.sampling_rate, self.window_size, self.hop_size)

        if stft_noisy.shape[0] < self.frame_length:
            # tile short files so the returned chunk is always frame_length
            # frames long; otherwise the DataLoader cannot stack the batch
            times = self.frame_length // stft_noisy.shape[0] + 1
            stft_noisy = np.concatenate([stft_noisy] * times, axis=0)
            stft_clean = np.concatenate([stft_clean] * times, axis=0)

        feature = get_log_magnitude(stft_noisy)
        mag_noisy = np.abs(stft_noisy)
        mag_clean = np.abs(stft_clean)
        cos_diff = get_cos_difference(stft_noisy, stft_clean)
        inputs, labels = [feature, mag_noisy], [mag_clean, cos_diff]

        self.input = [torch.tensor(ele).to(self.device) for ele in inputs]
        self.label = [torch.tensor(ele).to(self.device) for ele in labels]
        return self.cutoff_feature()

    def cutoff_feature(self):
        """Return the first frame_length frames of input/label and keep the rest."""
        n = self.frame_length
        chunk_input = [ele[0:n] for ele in self.input]
        chunk_label = [ele[0:n] for ele in self.label]
        self.input = [ele[n:] for ele in self.input]
        self.label = [ele[n:] for ele in self.label]
        self.length_remaining = self.input[0].shape[0]
        return chunk_input, chunk_label

    def __len__(self):
        return self.num_batch * self.batch_size
127 |
--------------------------------------------------------------------------------
/onssen/data/edinburgh_tts.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data.dataset import Dataset
2 | from torch.utils.data import DataLoader
3 | from .feature_utils import *
4 | import glob
5 | import numpy as np
6 | import random
7 | import torch
8 |
9 |
def edinburgh_tts_dataloader(model_name, feature_options, partition, device=None):
    """Build a shuffled DataLoader over the Edinburgh noisy-TTS dataset.

    Args:
        model_name: model whose input/label layout the dataset produces.
        feature_options: object exposing the feature settings as attributes;
            ``batch_size`` is read here.
        partition: name of the list file under ``feature_options.data_path``.
        device: torch device for the produced tensors (CPU when None).
    """
    dataset = edinburgh_tts_dataset(model_name, feature_options, partition, device=device)
    loader_batch_size = feature_options.batch_size
    return DataLoader(dataset, batch_size=loader_batch_size, shuffle=True)
16 |
17 |
class edinburgh_tts_dataset(Dataset):
    """Edinburgh noisy/clean TTS corpus as a two-"source" separation dataset.

    Source 1 is the clean speech; source 2 is the noise, computed as the
    noisy signal minus the clean signal. Each item covers the first
    frame_length STFT frames (short utterances are tiled first).
    """

    # model names get_feature knows how to build inputs/labels for
    supported_models = ("dc", "chimera", "chimera++", "phase")

    def __init__(self, model_name, feature_options, partition, device=None):
        """
        The arguments:
            feature_options: object exposing the feature params as attributes
            partition: list file under data_path, e.g. "train", "validation"
            model_name: can be "dc", "chimera", "chimera++", "phase"
            e.g.
            "feature_options": {
                "data_path": "/home/data/Edinburg_tts",
                "batch_size": 16,
                "frame_length": 400,
                "sampling_rate": 16000,
                "window_size": 512,
                "hop_size": 128,
                "db_threshold": 40
            }
        The returns:
            input: a list which follows the requirement of the loss
            label: a list which follows the requirement of the loss
            e.g.
            for dc loss:
                input: (feature_mix)
                label: (one_hot_label, mag_mix)
            for chimera loss:
                input: (feature_mix)
                label: (one_hot_label, mag_mix, mag_s1, mag_s2)
        """
        self.sampling_rate = feature_options.sampling_rate
        self.window_size = feature_options.window_size
        self.hop_size = feature_options.hop_size
        self.frame_length = feature_options.frame_length
        self.db_threshold = feature_options.db_threshold
        self.model_name = model_name
        self.data_path = feature_options.data_path
        self.partition = partition
        self.file_list = []
        self.get_file_list()
        if device is None:
            self.device = torch.device('cpu')
        else:
            self.device = device

    def get_file_list(self):
        """Append the partition's noisy wav paths to file_list and shuffle it."""
        noisy_dir = self.data_path + '/noisy_trainset_28spk_wav/'
        with open(self.data_path + '/' + self.partition, 'r') as f:
            for line in f:
                self.file_list.append(noisy_dir + line.replace('\n', ''))
        random.shuffle(self.file_list)

    def get_feature(self, fn):
        """Build ([inputs], [labels]) tensors for the noisy file fn.

        Raises:
            ValueError: if model_name is not one of supported_models.
        """
        if self.model_name not in self.supported_models:
            raise ValueError(
                "Unsupported model_name {!r}; expected one of {}".format(
                    self.model_name, self.supported_models))

        fn_clean = fn.replace('/noisy_trainset_28spk_wav', '/clean_trainset_28spk_wav')
        stft_mix = get_stft(fn, self.sampling_rate, self.window_size, self.hop_size)
        stft_s1 = get_stft(fn_clean, self.sampling_rate, self.window_size, self.hop_size)
        # the second "source" is the noise: noisy minus clean
        stft_s2 = get_stft_from_subtraction(fn, fn_clean, self.sampling_rate, self.window_size, self.hop_size)

        if stft_mix.shape[0] <= self.frame_length:
            # pad in a repeat-copy fashion so frame_length frames exist
            times = self.frame_length // stft_mix.shape[0] + 1
            stft_mix = np.concatenate([stft_mix] * times, axis=0)
            stft_s1 = np.concatenate([stft_s1] * times, axis=0)
            stft_s2 = np.concatenate([stft_s2] * times, axis=0)

        stft_mix = stft_mix[:self.frame_length]
        stft_s1 = stft_s1[:self.frame_length]
        stft_s2 = stft_s2[:self.frame_length]
        # base feature
        feature_mix = get_log_magnitude(stft_mix)
        mag_mix = np.abs(stft_mix)
        mag_s1 = np.abs(stft_s1)
        mag_s2 = np.abs(stft_s2)
        one_hot_label = get_one_hot(feature_mix, mag_s1, mag_s2, self.db_threshold)

        if self.model_name == "dc":
            # loss_dc unpacks exactly [one_hot_label, mag_mix]
            inputs, labels = [feature_mix], [one_hot_label, mag_mix]
        elif self.model_name == "chimera":
            inputs, labels = [feature_mix], [one_hot_label, mag_mix, mag_s1, mag_s2]
        elif self.model_name == "chimera++":
            cos_s1 = get_cos_difference(stft_mix, stft_s1)
            cos_s2 = get_cos_difference(stft_mix, stft_s2)
            inputs, labels = [feature_mix], [one_hot_label, mag_mix, mag_s1, mag_s2, cos_s1, cos_s2]
        else:  # "phase"
            phase_mix = get_phase(stft_mix)
            phase_s1 = get_phase(stft_s1)
            phase_s2 = get_phase(stft_s2)
            inputs, labels = [feature_mix, phase_mix], [one_hot_label, mag_mix, mag_s1, mag_s2, phase_s1, phase_s2]

        inputs = [torch.tensor(ele).to(self.device) for ele in inputs]
        labels = [torch.tensor(ele).to(self.device) for ele in labels]
        return inputs, labels

    def __getitem__(self, index):
        file_name_mix = self.file_list[index]
        return self.get_feature(file_name_mix)

    def __len__(self):
        return len(self.file_list)
121 |
--------------------------------------------------------------------------------
/onssen/data/feature_utils.py:
--------------------------------------------------------------------------------
1 | import librosa
2 | import numpy as np
3 |
4 |
def get_stft(fn, sampling_rate, window_size, hop_size):
    """
    Load a wav file and return its STFT.

    fn: the absolute path of the wav file
    sampling_rate: target sampling rate in Hz; the audio is resampled when the
        file's native rate differs
    window_size: FFT size (n_fft) in samples
    hop_size: hop between consecutive frames in samples

    return:
        stft: frame * frequency complex numpy array (librosa's output transposed)
    """
    sig, fs = librosa.load(fn, sr=None)
    if fs != sampling_rate:
        # orig_sr/target_sr are keyword-only in librosa >= 0.10; passing them
        # positionally raises TypeError there. Keywords work on older releases too.
        sig = librosa.resample(sig, orig_sr=fs, target_sr=sampling_rate)
    stft = np.transpose(librosa.stft(sig, n_fft=window_size, hop_length=hop_size))
    return stft
22 |
23 |
def get_stft_from_subtraction(f_mix, f_clean, sampling_rate, window_size, hop_size):
    """
    Return the STFT of the residual signal ``mix - clean``.

    f_mix: path of the noisy/mixture wav file
    f_clean: path of the matching clean wav file (same sampling rate and length)
    sampling_rate: target sampling rate in Hz; the residual is resampled when
        the files' native rate differs
    window_size: FFT size (n_fft) in samples
    hop_size: hop between consecutive frames in samples

    return:
        stft: frame * frequency complex numpy array

    raises:
        ValueError: if the two files have different sampling rates, since the
            sample-wise subtraction would then be meaningless
    """
    sig_mix, fs_mix = librosa.load(f_mix, sr=None)
    sig_clean, fs_clean = librosa.load(f_clean, sr=None)
    if fs_mix != fs_clean:
        raise ValueError(
            "Sampling rates differ: {} ({} Hz) vs {} ({} Hz)".format(
                f_mix, fs_mix, f_clean, fs_clean))
    sig_noise = sig_mix - sig_clean
    if fs_mix != sampling_rate:
        # keyword arguments are required by librosa >= 0.10
        sig_noise = librosa.resample(sig_noise, orig_sr=fs_mix, target_sr=sampling_rate)
    stft = np.transpose(librosa.stft(sig_noise, n_fft=window_size, hop_length=hop_size))
    return stft
34 |
35 |
def get_log_mel_spectrogram(fn, sampling_rate, window_size, hop_size, epsilon=1e-7):
    """
    Load a wav file and return its log10 mel spectrogram.

    fn: path of the wav file
    sampling_rate: expected sampling rate in Hz (no resampling is done)
    window_size: FFT size (n_fft) in samples
    hop_size: hop between consecutive frames in samples
    epsilon: added before the log so silent bins stay finite

    return:
        frame * mel_band numpy array

    raises:
        ValueError: if the file's sampling rate differs from sampling_rate
    """
    sig, fs = librosa.load(fn, sr=None)
    if fs != sampling_rate:
        raise ValueError(
            "{} has sampling rate {} Hz, expected {} Hz".format(fn, fs, sampling_rate))
    # y must be passed by keyword: librosa >= 0.10 makes it keyword-only
    mel_spectra = librosa.feature.melspectrogram(
        y=sig,
        sr=sampling_rate,
        n_fft=window_size,
        hop_length=hop_size,
    )
    mel_spectra = np.transpose(np.log10(mel_spectra + epsilon))
    return mel_spectra
47 |
48 |
def get_log_magnitude(stft, epsilon=1e-7):
    """Return log10(|stft| + epsilon); epsilon keeps all-zero bins finite."""
    magnitude = np.abs(stft)
    return np.log10(magnitude + epsilon)
52 |
53 |
def get_phase(stft):
    """
    Split a complex spectrogram into real and imaginary channels.

    stft: frame * frequency complex numpy array
    return:
        phase: frame * frequency * 2 real numpy array where [..., 0] holds the
        real part and [..., 1] the imaginary part
    """
    parts = (np.real(stft), np.imag(stft))
    return np.transpose(np.stack(parts), axes=(1, 2, 0))
65 |
66 |
def get_angle(stft):
    """
    Element-wise angle of a complex spectrogram.

    stft: frame * frequency complex numpy array
    return:
        angle: array of the same shape holding np.angle(stft) in radians
    """
    return np.angle(stft)
75 |
76 |
def get_cos_difference(stft_1, stft_2):
    """Return cos(angle(stft_1) - angle(stft_2)) element-wise."""
    phase_gap = get_angle(stft_1) - get_angle(stft_2)
    return np.cos(phase_gap)
81 |
82 |
def get_one_hot(feature_mix, mag_s1, mag_s2, db_threshold):
    """
    Build the ideal binary assignment used as the deep-clustering target.

    feature_mix: log10 magnitude of the mixture (same shape as mag_s1)
    mag_s1, mag_s2: magnitudes of the two sources
    db_threshold: bins whose feature_mix lies more than db_threshold dB
        (db_threshold / 20 in log10 units) below the maximum are treated as
        silence and get an all-zero row

    return:
        float array of shape mag_s1.shape + (2,); each non-silent bin is
        one-hot on the louder source (ties go to s1)
    """
    dominant = np.argmax(np.asarray([mag_s1, mag_s2]), axis=0)
    Y = np.eye(2)[dominant]
    silence_floor = np.max(feature_mix) - db_threshold / 20
    Y[feature_mix < silence_floor] = 0
    return Y
96 |
--------------------------------------------------------------------------------
/onssen/data/wsj0_2mix.py:
--------------------------------------------------------------------------------
1 | """
2 | We need to define a batch size for training the deep clustering model.
3 | Each batch has a shape (batch_size, 100/400, feature_dim)
4 |
5 | For STFT:
6 | 8kHz fs
7 | 32 ms window length 32*8 = 256
8 | 8 ms window shift = 64
9 |
10 | 44kHz fs
11 | 128 ms window length 128*8 = 1024
12 | 32 ms window shift = 256
13 | """
14 |
15 | from torch.utils.data.dataset import Dataset
16 | from torch.utils.data import DataLoader
17 | from .feature_utils import *
18 | import glob
19 | import numpy as np
20 | import random
21 | import torch
22 | import torchaudio
23 | import torch.nn.functional as F
24 |
25 |
def wsj0_2mix_dataloader(model_name, feature_options, partition, device=None):
    """Build a DataLoader for one wsj0-2mix partition.

    "tr" and "cv" use wsj0_2mix_dataset (random fixed-length excerpts,
    shuffled, batched by feature_options.batch_size); "tt" uses
    wsj0_2mix_eval_dataset (whole utterances, batch size 1, not shuffled).

    Raises:
        ValueError: if partition is not "tr", "cv" or "tt" (previously the
            function silently returned None).
    """
    if partition in ("tr", "cv"):
        return DataLoader(
            wsj0_2mix_dataset(model_name, feature_options, partition, device=device),
            batch_size=feature_options.batch_size,
            shuffle=True,
        )
    if partition == "tt":
        return DataLoader(
            wsj0_2mix_eval_dataset(model_name, feature_options, partition, device=device),
            batch_size=1,
        )
    raise ValueError(
        'Unknown partition {!r}; expected "tr", "cv" or "tt"'.format(partition))
class wsj0_2mix_dataset(Dataset):
    """Training/validation ("tr"/"cv") dataset for wsj0-2mix.

    Each item is one randomly positioned, fixed-length excerpt of an
    utterance: chunk_size raw samples for the TasNet models, or frame_length
    STFT frames for the spectrogram models.
    """

    # models that consume raw waveforms instead of STFT features
    time_domain_models = ("lstm-tasnet", "conv-tasnet")
    # spectrogram models whose inputs/labels get_feature can build
    spectral_models = ("dc", "chimera", "chimera++", "phase")

    def __init__(self, model_name, feature_options, partition, device=None):
        """
        The arguments:
            feature_options: object exposing the feature params as attributes
            partition: can be "tr", "cv"
            model_name: can be "dc", "chimera", "chimera++", "phase",
                "lstm-tasnet", "conv-tasnet"
            e.g.
            "feature_options": {
                "data_path": "/home/data/wsj0-2mix",
                "batch_size": 16,
                "frame_length": 400,
                "sampling_rate": 8000,
                "window_size": 256,
                "hop_size": 64,
                "db_threshold": 40
            }
            (the TasNet models read "chunk_size" instead of the STFT params)
        The returns:
            input: a list which follows the requirement of the loss
            label: a list which follows the requirement of the loss
            e.g.
            for dc loss:
                input: (feature_mix)
                label: (one_hot_label, mag_mix)
            for chimera loss:
                input: (feature_mix)
                label: (one_hot_label, mag_mix, mag_s1, mag_s2)
        """
        self.model_name = model_name
        self.sampling_rate = feature_options.sampling_rate
        if self.model_name in self.time_domain_models:
            self.chunk_size = feature_options.chunk_size
        else:
            self.window_size = feature_options.window_size
            self.hop_size = feature_options.hop_size
            self.frame_length = feature_options.frame_length
            self.db_threshold = feature_options.db_threshold
        full_path = feature_options.data_path + '/wav8k/min/' + partition + '/mix/*.wav'
        self.file_list = glob.glob(full_path)
        if device is None:
            self.device = torch.device('cpu')
        else:
            self.device = device

    @staticmethod
    def _source_path(fn, source):
        """Map a mixture path '.../mix/<name>.wav' to '.../<source>/<name>.wav'.

        Only the last '/mix' is replaced, so a data_path that itself contains
        '/mix' (e.g. '/data/mixtures/wsj0-2mix') is left untouched.
        """
        head, sep, tail = fn.rpartition('/mix')
        if not sep:
            return fn
        return head + '/' + source + tail

    def _random_crop(self, *stfts):
        """Cut the same random frame_length-frame window out of every array.

        Arrays with at most frame_length frames are first tiled along the
        frame axis so the window always fits. All arrays share one offset so
        mixture and sources stay aligned.
        """
        num_frames = stfts[0].shape[0]
        if num_frames <= self.frame_length:
            # pad in a repeat-copy fashion
            times = self.frame_length // num_frames + 1
            stfts = tuple(np.concatenate([ele] * times, axis=0) for ele in stfts)
            num_frames = stfts[0].shape[0]
        # randint's upper bound is exclusive: +1 makes the last valid offset reachable
        start = np.random.randint(num_frames - self.frame_length + 1)
        return tuple(ele[start:start + self.frame_length] for ele in stfts)

    def get_tr_sigs(self, fn, sr):
        """Load mixture and both sources and cut one aligned chunk_size excerpt.

        Signals shorter than chunk_size are zero-padded at the end.

        Returns:
            (sig_mix, sig_s1, sig_s2) as returned by torchaudio.load,
            trimmed/padded to chunk_size samples along dim 1.
        Raises:
            ValueError: if the mixture's sampling rate differs from sr.
        """
        sig, rate = torchaudio.load(fn)
        if rate != sr:
            raise ValueError(
                "{} has sampling rate {} Hz, expected {} Hz".format(fn, rate, sr))
        sig_s1, _ = torchaudio.load(self._source_path(fn, 's1'))
        sig_s2, _ = torchaudio.load(self._source_path(fn, 's2'))
        if sig.shape[1] < self.chunk_size:
            gap = self.chunk_size - sig.shape[1]
            sig = F.pad(sig, (0, gap), mode='constant')
            sig_s1 = F.pad(sig_s1, (0, gap), mode='constant')
            sig_s2 = F.pad(sig_s2, (0, gap), mode='constant')
        else:
            start = random.randint(0, sig.shape[1] - self.chunk_size)
            stop = start + self.chunk_size
            sig = sig[:, start:stop]
            sig_s1 = sig_s1[:, start:stop]
            sig_s2 = sig_s2[:, start:stop]
        return sig, sig_s1, sig_s2

    def get_feature(self, fn):
        """Build ([inputs], [labels]) tensors for the mixture file fn.

        Raises:
            ValueError: if model_name is not a supported model.
        """
        if self.model_name in self.time_domain_models:
            sig_mix, sig_s1, sig_s2 = self.get_tr_sigs(fn, self.sampling_rate)
            inputs = [sig_mix.reshape(-1,)]
            labels = [sig_s1.reshape(-1,), sig_s2.reshape(-1,)]
            inputs = [ele.to(self.device) for ele in inputs]
            labels = [ele.to(self.device) for ele in labels]
            return inputs, labels

        if self.model_name not in self.spectral_models:
            raise ValueError(
                "Unsupported model_name {!r}; expected one of {}".format(
                    self.model_name, self.time_domain_models + self.spectral_models))

        stft_mix = get_stft(fn, self.sampling_rate, self.window_size, self.hop_size)
        stft_s1 = get_stft(self._source_path(fn, 's1'), self.sampling_rate, self.window_size, self.hop_size)
        stft_s2 = get_stft(self._source_path(fn, 's2'), self.sampling_rate, self.window_size, self.hop_size)
        stft_mix, stft_s1, stft_s2 = self._random_crop(stft_mix, stft_s1, stft_s2)

        # base feature
        feature_mix = get_log_magnitude(stft_mix)
        mag_mix = np.abs(stft_mix)
        mag_s1 = np.abs(stft_s1)
        mag_s2 = np.abs(stft_s2)
        one_hot_label = get_one_hot(feature_mix, mag_s1, mag_s2, self.db_threshold)

        if self.model_name == "dc":
            inputs, labels = [feature_mix], [one_hot_label, mag_mix]
        elif self.model_name == "chimera":
            inputs, labels = [feature_mix], [one_hot_label, mag_mix, mag_s1, mag_s2]
        elif self.model_name == "chimera++":
            cos_s1 = get_cos_difference(stft_mix, stft_s1)
            cos_s2 = get_cos_difference(stft_mix, stft_s2)
            inputs, labels = [feature_mix], [one_hot_label, mag_mix, mag_s1, mag_s2, cos_s1, cos_s2]
        else:  # "phase"
            phase_mix = get_phase(stft_mix)
            phase_s1 = get_phase(stft_s1)
            phase_s2 = get_phase(stft_s2)
            inputs, labels = [feature_mix, phase_mix], [one_hot_label, mag_mix, mag_s1, mag_s2, phase_s1, phase_s2]

        inputs = [torch.tensor(ele).to(self.device) for ele in inputs]
        labels = [torch.tensor(ele).to(self.device) for ele in labels]
        return inputs, labels

    def __getitem__(self, index):
        file_name_mix = self.file_list[index]
        return self.get_feature(file_name_mix)

    def __len__(self):
        return len(self.file_list)
168 |
169 |
class wsj0_2mix_eval_dataset(Dataset):
    """Test-set ("tt") dataset for wsj0-2mix that yields whole utterances."""

    def __init__(self, model_name, feature_options, partition, device=None):
        """
        The arguments:
            feature_options: object exposing the feature params as attributes
                (sampling_rate and data_path always; chunk_size for the TasNet
                models; window_size, hop_size, frame_length, db_threshold
                otherwise)
            partition: the test partition, "tt" (source paths are derived by
                replacing 'tt/mix/' in the mixture path)
            model_name: "lstm-tasnet", "conv-tasnet", or a spectrogram model
                such as "dc", "chimera", "chimera++", "phase"
            device: torch device for the returned tensors (CPU when None)
        The returns (per item):
            TasNet models:
                input: [mixture waveform, zero-padded to a multiple of 32 samples]
                label: [s1 and s2 waveforms concatenated along dim 0, padded alike]
            other models:
                input: [log10-magnitude of the mixture STFT]
                label: [real part, imaginary part of the mixture STFT,
                        s1 and s2 waveforms concatenated along dim 0]
        """
        self.model_name = model_name
        self.sampling_rate = feature_options.sampling_rate
        if self.model_name in ["lstm-tasnet", "conv-tasnet"]:
            self.chunk_size = feature_options.chunk_size
        else:
            self.window_size = feature_options.window_size
            self.hop_size = feature_options.hop_size
            self.frame_length = feature_options.frame_length
            self.db_threshold = feature_options.db_threshold
        full_path = feature_options.data_path + '/wav8k/min/' + partition + '/mix/*.wav'
        self.file_list = glob.glob(full_path)
        if device is None:
            self.device = torch.device('cpu')
        else:
            self.device = device

    @staticmethod
    def _pad_amount(num_samples, multiple=32):
        """Number of zeros that rounds num_samples up to a multiple of `multiple` (0 if already aligned)."""
        return (-num_samples) % multiple

    def get_sigs(self, fn, sr):
        """Load the mixture and both references, zero-padded to a multiple of 32 samples.

        Returns:
            (sig_mix flattened to 1-D, torch.cat((sig_s1, sig_s2), dim=0))
        Raises:
            ValueError: if the mixture's sampling rate differs from sr.
        """
        sig_mix, rate = torchaudio.load(fn)
        if rate != sr:
            raise ValueError(
                "{} has sampling rate {} Hz, expected {} Hz".format(fn, rate, sr))
        sig_s1, _ = torchaudio.load(fn.replace('tt/mix/', 'tt/s1/'))
        sig_s2, _ = torchaudio.load(fn.replace('tt/mix/', 'tt/s2/'))
        gap = self._pad_amount(sig_mix.shape[1])
        sig_mix = F.pad(sig_mix, (0, gap), mode='constant')
        sig_s1 = F.pad(sig_s1, (0, gap), mode='constant')
        sig_s2 = F.pad(sig_s2, (0, gap), mode='constant')
        sig_ref = torch.cat((sig_s1, sig_s2), dim=0)
        sig_mix = sig_mix.reshape(-1,)
        return sig_mix, sig_ref

    def get_ref_sig(self, fn):
        """Load both reference sources of mixture fn, concatenated along dim 0, as a numpy array.

        NOTE(review): the layout mirrors get_sigs' sig_ref without padding
        ((2, num_samples) for mono files); confirm it matches what the
        spectrogram-model evaluation script expects.
        """
        sig_s1, _ = torchaudio.load(fn.replace('tt/mix/', 'tt/s1/'))
        sig_s2, _ = torchaudio.load(fn.replace('tt/mix/', 'tt/s2/'))
        return torch.cat((sig_s1, sig_s2), dim=0).numpy()

    def get_feature(self, fn):
        """Build ([inputs], [labels]) tensors for the mixture file fn."""
        if self.model_name in ["lstm-tasnet", "conv-tasnet"]:
            sig_mix, sig_ref = self.get_sigs(fn, self.sampling_rate)
            return [sig_mix.to(self.device)], [sig_ref.to(self.device)]

        stft_mix = get_stft(fn, self.sampling_rate, self.window_size, self.hop_size)
        stft_r_mix = np.real(stft_mix)
        stft_i_mix = np.imag(stft_mix)
        feature_mix = get_log_magnitude(stft_mix)
        sig_ref = self.get_ref_sig(fn)
        inputs = [torch.tensor(ele).to(self.device) for ele in [feature_mix]]
        labels = [torch.tensor(ele).to(self.device) for ele in [stft_r_mix, stft_i_mix, sig_ref]]
        return inputs, labels

    def __getitem__(self, index):
        file_name_mix = self.file_list[index]
        return self.get_feature(file_name_mix)

    def __len__(self):
        return len(self.file_list)
255 |
--------------------------------------------------------------------------------
/onssen/evaluate/__init__.py:
--------------------------------------------------------------------------------
1 | from .sdr import batch_SDR_torch
2 |
3 |
--------------------------------------------------------------------------------
/onssen/evaluate/sdr.py:
--------------------------------------------------------------------------------
1 | ### Forked from https://github.com/yluo42/TAC/blob/master/utility/sdr.py
2 | import numpy as np
3 | from itertools import permutations
4 | from torch.autograd import Variable
5 |
6 | import scipy,time,numpy
7 |
8 | import torch
9 |
10 | # Pytorch implementation with batch processing
def calc_sdr_torch(estimation, origin, mask=None):
    """
    Scale-invariant SDR of one source, computed per batch element.

    estimation: (batch, nsample) tensor
    origin: (batch, nsample) tensor
    mask: optional tensor multiplied into both signals first, so zero-padded
        tail samples can be excluded from the SDR

    returns: (batch,) tensor of SDR values in dB
    """
    if mask is not None:
        origin = origin * mask
        estimation = estimation * mask

    # project the estimate onto the reference; 1e-8 guards zero-energy signals
    ref_energy = (origin ** 2).sum(1, keepdim=True) + 1e-8
    gain = (origin * estimation).sum(1, keepdim=True) / ref_energy

    target = gain * origin
    residual = estimation - target

    target_energy = (target ** 2).sum(1) + 1e-8
    residual_energy = (residual ** 2).sum(1) + 1e-8

    return 10 * torch.log10(target_energy) - 10 * torch.log10(residual_energy)
38 |
39 |
def batch_SDR_torch(estimation, origin, mask=None, return_perm=False):
    """
    Permutation-invariant SDR for a batch of multi-source estimates.

    estimation: (batch, nsource, nsample)
    origin: (batch, nsource, nsample)
    mask: optional, (batch, nsample), binary; forwarded to calc_sdr_torch
    return_perm: also return the index of the best permutation (indices follow
        the lexicographic order of permutations of range(nsource))

    returns: (batch,) mean SDR over sources under the best permutation
        [, (batch,) best permutation index]
    """
    batch_size_est, nsource_est, nsample_est = estimation.size()
    batch_size_ori, nsource_ori, nsample_ori = origin.size()

    shape_msg = "Estimation and original sources should have same shape."
    assert batch_size_est == batch_size_ori, shape_msg
    assert nsource_est == nsource_ori, shape_msg
    assert nsample_est == nsample_ori, shape_msg
    assert nsource_est < nsample_est, "Axis 1 should be the number of sources, and axis 2 should be the signal."

    batch_size, nsource = batch_size_est, nsource_est

    # remove the DC offset of every source
    estimation = estimation - torch.mean(estimation, 2, keepdim=True).expand_as(estimation)
    origin = origin - torch.mean(origin, 2, keepdim=True).expand_as(estimation)

    # pairwise[b, i, j]: SDR of estimated source i against reference j
    pairwise = torch.zeros((batch_size, nsource, nsource)).type(estimation.type())
    for est_idx in range(nsource):
        for ref_idx in range(nsource):
            pairwise[:, est_idx, ref_idx] = calc_sdr_torch(estimation[:, est_idx], origin[:, ref_idx], mask)

    # total SDR of each assignment; permutations(range(n)) is already in
    # lexicographic order
    per_perm_scores = []
    for assignment in permutations(range(nsource)):
        picked = [pairwise[:, est_idx, ref_idx].view(batch_size, -1)
                  for est_idx, ref_idx in enumerate(assignment)]
        per_perm_scores.append(torch.sum(torch.cat(picked, 1), 1).view(batch_size, 1))
    scores = torch.cat(per_perm_scores, 1)
    best_total, best_idx = torch.max(scores, dim=1)

    mean_sdr = best_total / nsource
    if return_perm:
        return mean_sdr, best_idx
    return mean_sdr
88 |
--------------------------------------------------------------------------------
/onssen/loss/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | For every new loss function added, please import it here and add its name to __all__
3 | The loss function should take two arguments:
4 | output: a tuple from the network
5 | label: a tuple which is from dataloader
6 | You need to assert the format of the output and label in the loss function!
7 | """
8 | from .loss_dc import loss_dc
9 | from .loss_chimera import loss_chimera_msa, loss_chimera_psa
10 | from .loss_mask import loss_mask_msa, loss_mask_psa
11 | from .loss_phase import loss_phase
12 | from .loss_e2e import SI_SNR, permute_SI_SNR, sisnr, si_snr_loss
13 | from .loss_util import T, norm, norm_1d
14 |
15 |
# public loss API of onssen.loss
__all__ = [
    # deep clustering / chimera
    'loss_dc',
    'loss_chimera_msa',
    'loss_chimera_psa',
    # mask-based enhancement
    'loss_mask_msa',
    'loss_mask_psa',
    # phase
    'loss_phase',
    # end-to-end (time-domain) SI-SNR
    'SI_SNR',
    'permute_SI_SNR',
    'sisnr',
    'si_snr_loss',
]
22 |
--------------------------------------------------------------------------------
/onssen/loss/loss_chimera.py:
--------------------------------------------------------------------------------
1 | from .loss_util import T, norm, norm_1d
2 | from .loss_dc import loss_dc
3 | import torch
4 | import torch.nn.functional as F
5 |
def loss_chimera_msa(output, label):
    """
    Chimera loss whose mask head uses magnitude spectrum approximation (MSA).

    output: [embedding, mask_A, mask_B]
        embedding: 4-D deep-clustering embedding (batch_size X T X F X embedding_dim)
        mask_A, mask_B: batch_size X T X F masks, one per speaker
    label: [one_hot_label, mag_mix, mag_s1, mag_s2]
        one_hot_label: the label for deep clustering
        mag_mix: the magnitude of mix speech
        mag_s1, mag_s2: the magnitudes of clean speech s1 and s2
    returns: 0.975 * deep-clustering loss + 0.025 * permutation-invariant mask loss
    """
    embedding, mask_A, mask_B = output
    one_hot_label, mag_mix, mag_s1, mag_s2 = label
    # unpacking only enforces that the masks are 3-D
    _batch_size, _frames, _freqs = mask_A.shape

    loss_embedding = loss_dc([embedding], [one_hot_label, mag_mix])

    est_A = mask_A * mag_mix
    est_B = mask_B * mag_mix
    # permutation invariance: score both speaker assignments, keep the better one
    straight = norm_1d(est_A - mag_s1) + norm_1d(est_B - mag_s2)
    swapped = norm_1d(est_B - mag_s1) + norm_1d(est_A - mag_s2)
    loss_mask = torch.min(straight, swapped)

    return loss_embedding * 0.975 + loss_mask * 0.025
32 |
def loss_chimera_psa(output, label):
    """
    Chimera loss whose mask head uses phase-sensitive spectrum approximation (PSA).

    output: [embedding, mask_A, mask_B]
        embedding: 4-D deep-clustering embedding (batch_size X T X F X embedding_dim)
        mask_A, mask_B: batch_size X T X F masks, one per speaker
    label: [one_hot_label, mag_mix, mag_s1, mag_s2, cos_s1, cos_s2]
        one_hot_label: the label for deep clustering
        mag_mix: the magnitude of mix speech
        mag_s1, mag_s2: the magnitudes of clean speech s1 and s2
        cos_s1, cos_s2: cosine of the phase difference between mix and s1/s2
    returns: 0.975 * deep-clustering loss + 0.025 * permutation-invariant mask loss
    """
    embedding, mask_A, mask_B = output
    one_hot_label, mag_mix, mag_s1, mag_s2, cos_s1, cos_s2 = label
    # unpacking only enforces that the masks are 3-D
    _batch_size, _frames, _freqs = mask_A.shape

    loss_embedding = loss_dc([embedding], [one_hot_label, mag_mix])

    # PSA targets: phase-projected source magnitude, clipped to [0, mag_mix]
    target_s1 = torch.min(mag_mix, F.relu(mag_s1 * cos_s1))
    target_s2 = torch.min(mag_mix, F.relu(mag_s2 * cos_s2))
    est_A = mask_A * mag_mix
    est_B = mask_B * mag_mix
    straight = norm_1d(est_A - target_s1) + norm_1d(est_B - target_s2)
    swapped = norm_1d(est_B - target_s1) + norm_1d(est_A - target_s2)
    loss_mask = torch.min(straight, swapped)

    return loss_embedding * 0.975 + loss_mask * 0.025
60 |
--------------------------------------------------------------------------------
/onssen/loss/loss_dc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from .loss_util import T, norm
4 |
5 |
def loss_dc(output, label):
    """
    Weighted deep clustering loss, adopted from the nussl loss function:
    https://github.com/interactiveaudiolab/nussl/blob/master/nussl/transformers/transformer_deep_clustering.py

    inputs:
        output: [embedding], a batch_size X T X F X embedding_dim tensor
        label: [one_hot_label, mag_mix]
            one_hot_label: batch_size X T X F X num_speaker tensor; rows that
                sum to 0 mark silent bins, which are excluded
            mag_mix: mixture magnitude (flattened to batch_size X T*F here)
    outputs:
        affinity loss of the weighted embeddings, scaled by the summed
        mixture magnitude of each batch item
    """
    assert len(output)==1, "Number of output must be 1 for Deep Clustering"
    assert len(label)==2, "Number of label must be 2 for Deep Clustering"
    (embedding,) = output
    targets, mag_mix = label
    targets = targets.float()
    batch_size, frame_dim, frequency_dim, one_hot_dim = targets.size()
    _, _, _, embedding_dim = embedding.size()
    num_bins = frame_dim * frequency_dim

    embedding = embedding.view(batch_size, -1, embedding_dim)
    mag_mix = mag_mix.detach().view(batch_size, -1)
    targets = targets.view(batch_size, -1, one_hot_dim)

    # zero the embeddings of silent bins (their one-hot rows are all zero)
    embedding = targets.sum(2, keepdim=True) * embedding

    # per-bin weight W_i = sqrt(|x_i| / sum_j |x_j|)
    weights = torch.sqrt(mag_mix / mag_mix.sum(1, keepdim=True)).view(batch_size, num_bins, 1)
    targets = targets * weights
    embedding = embedding * weights

    # ||V^T V||^2 - 2 ||V^T Y||^2 + ||Y^T Y||^2 via batched matrix products
    loss_est = norm(torch.bmm(T(embedding), embedding))
    loss_est_true = 2 * norm(torch.bmm(T(embedding), targets))
    loss_true = norm(torch.bmm(T(targets), targets))
    return (loss_est - loss_est_true + loss_true) * mag_mix.sum(1, keepdim=True)
45 |
--------------------------------------------------------------------------------
/onssen/loss/loss_e2e.py:
--------------------------------------------------------------------------------
1 | ### Created by Kai Li
2 | ### https://github.com/JusperLee/Conv-TasNet/blob/master/SI_SNR.py
3 | import torch
4 | from itertools import permutations
5 |
6 |
def SI_SNR(_s, s, zero_mean=True):
    '''
    Scale-invariant SNR between an estimated and a reference signal.
    The larger the value, the better the separation.
    input:
        _s: generated (estimated) audio, 1-D tensor
        s: ground truth audio, 1-D tensor of the same length
        zero_mean: subtract each signal's mean before scoring
    output:
        SI-SNR in dB as a 0-d tensor (no epsilon is added, so a perfect
        estimate yields inf)
    '''
    if zero_mean:
        _s = _s - torch.mean(_s)
        s = s - torch.mean(s)
    # torch.sum does the dot product as one vectorised op; Python's builtin
    # sum() would iterate the tensor element by element, which is very slow
    # for audio-length signals and builds a huge autograd graph
    s_target = torch.sum(_s * s) * s / torch.pow(torch.norm(s, p=2), 2)
    e_noise = _s - s_target
    return 20 * torch.log10(torch.norm(s_target, p=2) / torch.norm(e_noise, p=2))
23 |
24 |
def permute_SI_SNR(_s_lists, s_lists):
    '''
    Permutation-invariant SI-SNR: the best mean SI_SNR over every way of
    pairing the generated signals with the ground-truth signals.
    input:
        _s_lists: list of generated audio signals
        s_lists: list of ground-truth audio signals (same length)
    output:
        max over permutations of the mean SI-SNR
    '''
    n_src = len(_s_lists)

    def mean_score(order):
        pairs = zip(_s_lists, (s_lists[k] for k in order))
        return sum([SI_SNR(est, ref) for est, ref in pairs]) / n_src

    return max(mean_score(order) for order in permutations(range(n_src)))
43 |
44 |
def sisnr(x, s, eps=1e-8):
    """
    Scale-invariant SNR computed along the last axis.

    input:
        x: separated signal, N x S tensor
        s: reference signal, N x S tensor
        eps: small constant guarding the divisions and the log
    Return:
        sisnr: N tensor (dB)
    Raises:
        RuntimeError: if x and s have different shapes
    """
    if x.shape != s.shape:
        raise RuntimeError(
            "Dimention mismatch when calculate si-snr, {} vs {}".format(
                x.shape, s.shape))

    x_centered = x - torch.mean(x, dim=-1, keepdim=True)
    s_centered = s - torch.mean(s, dim=-1, keepdim=True)

    ref_norm = torch.norm(s_centered, dim=-1, keepdim=True)
    projection = (torch.sum(x_centered * s_centered, dim=-1, keepdim=True)
                  * s_centered / (ref_norm ** 2 + eps))
    ratio = torch.norm(projection, dim=-1) / (torch.norm(x_centered - projection, dim=-1) + eps)
    return 20 * torch.log10(eps + ratio)
68 |
69 |
def si_snr_loss(ests, refs):
    """
    Negative permutation-invariant SI-SNR loss, averaged over utterances.

    ests, refs: per-speaker sequences of N x S tensors (spks x N x S).
    For each speaker permutation, the per-utterance SI-SNR is averaged over
    speakers. Each utterance then keeps its best permutation, and the
    negated mean over the N utterances is returned.
    """
    num_spks = len(refs)
    # unpacking also enforces that each reference is 2-D (N x S)
    batch, _ = refs[0].shape

    per_perm = []
    for perm in permutations(range(num_spks)):
        scores = [sisnr(ests[i], refs[j]) for i, j in enumerate(perm)]
        per_perm.append(sum(scores) / len(perm))
    # P x N: the best permutation is chosen independently per utterance
    best, _ = torch.max(torch.stack(per_perm), dim=0)
    return -torch.sum(best) / batch
88 |
89 |
if __name__ == "__main__":
    # Smoke test with a single estimate/reference pair, so the only
    # "permutation" is the identity and the result equals SI_SNR(est, ref).
    est = torch.tensor([1, 2, 3], dtype=torch.float32)
    ref = torch.tensor([1, 4, 6], dtype=torch.float32)
    print(permute_SI_SNR([est], [ref]))
94 |
--------------------------------------------------------------------------------
/onssen/loss/loss_mask.py:
--------------------------------------------------------------------------------
1 | from .loss_util import norm, norm_1d
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
def loss_mask_msa(output, label):
    """
    The loss function of Magnitude Spectrum Approximation (MSA).
    It is for enhancing speech in a noisy recording.

    The model output is compared directly with the clean magnitude; no mask
    is applied here.
    output: one-element sequence
        clean_est: estimated clean magnitude, same shape as mag_clean
    label: two-element sequence
        mag_clean: the magnitude of clean speech s1
        (second element): ignored by this loss
    returns:
        mean squared error over all elements (0-d tensor)
    """
    [clean_est] = output
    # the label must still have exactly two entries, as before
    mag_clean, _ = label
    return F.mse_loss(clean_est, mag_clean)
23 |
24 |
def loss_mask_psa(output, label):
    """
    The loss function of Phase-sensitive Spectrum Approximation (PSA).
    It is for enhancing speech in a noisy recording.

    The target is the phase-sensitive clean magnitude mag_clean * cos_diff,
    truncated to the range [0, mag_noisy]. The masked noisy magnitude is
    compared against it with norm_1d (from loss_util).
    output:
        mask: batch_size X T X F tensor
    label (in this order):
        mag_noisy: the magnitude of mix speech
        mag_clean: the magnitude of clean speech s1
        cos_diff: the cosine of phase difference between mix and clean
    """
    mask, = output
    mag_noisy, mag_clean, cos_diff = label
    masked_noisy = mask * mag_noisy
    psa_target = torch.min(mag_noisy, F.relu(mag_clean * cos_diff))
    return norm_1d(masked_noisy - psa_target)
41 |
--------------------------------------------------------------------------------
/onssen/loss/loss_phase.py:
--------------------------------------------------------------------------------
1 | from .loss_util import T, norm, norm_1d
2 | from .loss_dc import loss_dc
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | def loss_phase(output, label):
7 | assert len(output) == 6, "There must be 5 tensors in the output"
8 | assert len(label) == 6, "There must be 6 tensors in the label"
9 | [embedding, mask_A, mask_B, phase_A, phase_B] = output
10 | [one_hot_label, mag_mix, mag_s1, mag_s2, phase_s1, phase_s2] = label
11 | batch_size, time_size, frequency_size = mag_mix.size()
12 | # compute the loss of embedding part
13 | loss_embedding = loss_dc([embedding, mag_mix], [one_hot_label])
14 |
15 | #compute the loss of mask part
16 | loss_mask1 = norm_1d(mask_A*mag_mix - mag_s1)\
17 | + norm_1d(mask_B*mag_mix - mag_s2)
18 | loss_mask2 = norm_1d(mask_B*mag_mix - mag_s1)\
19 | + norm_1d(mask_A*mag_mix - mag_s2)
20 |
21 | amin = loss_mask1 n x N x T
193 | self.encoder = Conv1D(1, N, L, stride=L // 2, padding=0)
194 | # n x N x T Layer Normalization of Separation
195 | self.LayerN_S = select_norm('cln', N)
196 | # n x B x T Conv 1 x 1 of Separation
197 | self.BottleN_S = Conv1D(N, B, 1)
198 | # Separation block
199 | # n x B x T => n x B x T
200 | self.separation = self._Sequential_repeat(
201 | R, X, in_channels=B, out_channels=H, kernel_size=P, norm=norm, causal=causal)
202 | # n x B x T => n x 2*N x T
203 | self.gen_masks = Conv1D(B, num_spks*N, 1)
204 | # n x N x T => n x 1 x L
205 | self.decoder = ConvTrans1D(N, 1, L, stride=L//2)
206 | # activation function
207 | active_f = {
208 | 'relu': nn.ReLU(),
209 | 'sigmoid': nn.Sigmoid(),
210 | 'softmax': nn.Softmax(dim=0)
211 | }
212 | self.activation_type = activate
213 | self.activation = active_f[activate]
214 | self.num_spks = num_spks
215 |
def _Sequential_block(self, num_blocks, **block_kwargs):
    '''
    Stack num_blocks Conv1D_Block modules into one nn.Sequential.

    Block number idx gets dilation 2**idx, so dilations double along the
    stack (1, 2, 4, ...).
    input:
        num_blocks: how many blocks in every repeat
        **block_kwargs: parameters forwarded to every Conv1D_Block
    '''
    blocks = []
    for idx in range(num_blocks):
        blocks.append(Conv1D_Block(**block_kwargs, dilation=2 ** idx))
    return nn.Sequential(*blocks)
227 |
def _Sequential_repeat(self, num_repeats, num_blocks, **block_kwargs):
    '''
    Chain num_repeats independently built _Sequential_block stacks.

    Each repeat is a fresh set of modules (no weight sharing).
    input:
        num_repeats: number of repeats
        num_blocks: number of blocks in every repeat
        **block_kwargs: parameters of Conv1D_Block
    '''
    repeats = (self._Sequential_block(num_blocks, **block_kwargs)
               for _ in range(num_repeats))
    return nn.Sequential(*repeats)
239 |
def forward(self, inp):
    '''
    Separate a mixture into num_spks outputs.

    input:
        inp: one-element sequence holding the mixture x, either a 1-D
             tensor (single utterance, a batch dim is added) or a 2-D tensor
    output:
        list of num_spks decoder outputs, one per speaker
    raises:
        RuntimeError if x has 3 or more dimensions
    '''
    x, = inp
    if x.dim() >= 3:
        # Module instances have no __name__ attribute; read it from the
        # class, otherwise this check raised AttributeError instead.
        raise RuntimeError(
            "{} accept 1/2D tensor as input, but got {:d}".format(
                type(self).__name__, x.dim()))
    if x.dim() == 1:
        x = torch.unsqueeze(x, 0)
    # n x L => n x N x T (Conv1D presumably adds the channel dim — confirm)
    w = self.encoder(x)
    # n x N x T => n x B x T
    e = self.LayerN_S(w)
    e = self.BottleN_S(e)
    # n x B x T => n x B x T
    e = self.separation(e)
    # n x B x T => n x num_spks*N x T
    m = self.gen_masks(e)
    # num_spks chunks of n x N x T, stacked to num_spks x n x N x T so the
    # activation sees the speaker axis as dim 0
    m = torch.chunk(m, chunks=self.num_spks, dim=1)
    m = self.activation(torch.stack(m, dim=0))
    d = [w * m[i] for i in range(self.num_spks)]
    # decoder part: one output per speaker
    s = [self.decoder(d[i], squeeze=True) for i in range(self.num_spks)]
    return s
265 |
266 |
def check_parameters(net):
    '''
    Count the parameters of a module, in millions.

    Every tensor yielded by net.parameters() contributes its element count,
    and the total is divided by 10**6. This is a parameter count, not a
    size in megabytes.
    '''
    total = 0
    for param in net.parameters():
        total += param.numel()
    return total / 10**6
273 |
274 |
def test_convtasnet():
    '''
    Smoke test: run a default ConvTasNet on random input, then print its
    parameter count (millions) and the shape of the second output.
    '''
    x = torch.randn(4, 32)
    nnet = ConvTasNet()
    # forward() unpacks its argument as a one-element sequence (x, = inp),
    # so the mixture has to be wrapped in a list.
    s = nnet([x])
    print(str(check_parameters(nnet)) + ' Mb')
    print(s[1].shape)
282 |
--------------------------------------------------------------------------------
/onssen/nn/uPIT-LSTM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
class uPIT_LSTM(nn.Module):
    """
    Bidirectional LSTM followed by a linear layer that predicts one sigmoid
    mask per speaker.

    Parameters:
        input_dim: feature size of each input frame
        output_dim: size of each mask along its last axis
        hidden_dim: LSTM hidden size per direction
        embedding_dim: unused; kept for signature compatibility
        num_layers: number of stacked LSTM layers
        dropout: dropout applied between LSTM layers
        num_speaker: number of masks produced
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        hidden_dim=300,
        embedding_dim=20,
        num_layers=3,
        dropout=0.5,
        num_speaker=2,
    ):
        super(uPIT_LSTM, self).__init__()
        self.output_dim = output_dim
        self.num_speaker = num_speaker
        # attribute assignment registers submodules under the same names
        # ('rnn', 'fc_mi') and order as add_module, so checkpoints still load
        self.rnn = nn.LSTM(
            input_dim,
            hidden_dim,
            num_layers,
            dropout=dropout,
            bidirectional=True,
            batch_first=True,
        )
        # both LSTM directions are concatenated, hence hidden_dim * 2
        self.fc_mi = nn.Linear(hidden_dim * 2, output_dim * num_speaker)

    def forward(self, input):
        """
        input: sequence whose first element is a B x T x F tensor
        returns: list with one mask per speaker, each B x T x output_dim,
            with sigmoid-bounded values
        """
        x = input[0].float()
        batch_size, frame_size, _ = x.size()
        self.rnn.flatten_parameters()
        rnn_output, _ = self.rnn(x)
        masks = torch.sigmoid(self.fc_mi(rnn_output))
        # linear features are laid out as (output_dim, num_speaker):
        # feature f * num_speaker + k belongs to speaker k
        masks = masks.reshape(batch_size, frame_size, self.output_dim, -1)
        # one mask per speaker, so num_speaker != 2 is no longer truncated
        return [masks[:, :, :, k] for k in range(masks.size(-1))]
42 |
--------------------------------------------------------------------------------
/onssen/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .train import trainer
2 | from .test import tester
3 | from .basic import build_optimizer, get_free_gpu, AverageMeter, generate_train_validation_list
4 |
# Public API of this package: the names imported above from the train,
# test and basic submodules, exported by `from onssen.utils import *`.
__all__ = [
    'trainer', 'tester', 'build_optimizer',
    'AverageMeter', 'get_free_gpu',
    'generate_train_validation_list'
]
10 |
--------------------------------------------------------------------------------
/onssen/utils/basic.py:
--------------------------------------------------------------------------------
1 | import glob, numpy as np, torch
2 | from sklearn.model_selection import train_test_split
3 |
4 |
def build_optimizer(params, optimizer_options):
    """
    Build a torch optimizer from a config object.

    params: iterable of parameters (e.g. model.parameters())
    optimizer_options: object with attributes
        name: "adam", "sgd" (momentum 0.9) or "rmsprop"
        lr: learning rate
    Raises:
        ValueError for any other name (this used to return None silently,
        which only failed later at the first optimizer call).
    """
    builders = {
        "adam": lambda: torch.optim.Adam(params, lr=optimizer_options.lr),
        "sgd": lambda: torch.optim.SGD(
            params, lr=optimizer_options.lr, momentum=0.9),
        "rmsprop": lambda: torch.optim.RMSprop(params, lr=optimizer_options.lr),
    }
    name = optimizer_options.name
    if name not in builders:
        raise ValueError("Unsupported optimizer {!r}; expected one of {}".format(
            name, sorted(builders)))
    return builders[name]()
12 |
class AverageMeter:
    """Track the latest value and the running, count-weighted mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the average, the running sum and the count."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record val as the latest value, weighted n times in the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
29 |
def get_free_gpu():
    """
    Return the index of the GPU with the most free memory, per nvidia-smi.

    Runs `nvidia-smi --query-gpu=memory.free --format=csv,noheader,nounits`,
    which prints one free-memory value (MiB) per GPU. Returns numpy.argmax
    over those values. The output is read through a pipe, so no temporary
    file is written to the working directory.
    Raises:
        FileNotFoundError if nvidia-smi is not installed;
        subprocess.CalledProcessError if nvidia-smi exits with an error.
    """
    import subprocess

    out = subprocess.check_output(
        ["nvidia-smi", "--query-gpu=memory.free",
         "--format=csv,noheader,nounits"],
        universal_newlines=True,
    )
    memory_available = [int(line.strip()) for line in out.splitlines()
                        if line.strip()]
    return np.argmax(memory_available)
34 |
def generate_train_validation_list(data_path, train_size=0.8, random_state=None):
    """
    Randomly split the .wav files matching `data_path + '*.wav'` into
    training and validation lists.

    data_path: prefix prepended verbatim to '*.wav' (include the trailing
        separator when it is a directory, e.g. 'wavs/')
    train_size: fraction (float) or count (int) of files used for training,
        as interpreted by sklearn's train_test_split
    random_state: optional seed forwarded to train_test_split for a
        reproducible split
    returns: (train, validation) numpy arrays of file paths
    raises: ValueError if no file matches the pattern
    """
    pattern = data_path + '*.wav'
    # glob order is filesystem-dependent; sorting makes a fixed
    # random_state produce the same split everywhere
    file_list = np.array(sorted(glob.glob(pattern)))
    if file_list.size == 0:
        raise ValueError("No .wav files found for pattern {!r}".format(pattern))
    train, validation = train_test_split(
        file_list, train_size=train_size, random_state=random_state)
    return train, validation
39 |
--------------------------------------------------------------------------------
/onssen/utils/test.py:
--------------------------------------------------------------------------------
1 | from attrdict import AttrDict
2 | from .basic import AverageMeter
3 | import argparse, json, os, time, torch
4 | from ..evaluate import batch_SDR_torch
5 |
6 |
class tester:
    def __init__(self, args):
        """
        args: a dictionary-like object with attribute access containing
            model_name(str): the name of the model
            model(nn.Module): the module object
            test_loader(DataLoader): the PyTorch built-in DataLoader object
            device: device the model is moved to
            checkpoint_path(str): directory holding final.mdl
        """
        self.model_name = args.model_name
        self.test_loader = args.test_loader
        self.device = args.device

        # build model. Loading onto the CPU first lets a checkpoint saved
        # on GPU be restored anywhere; load_state_dict copies the values
        # into the model's own tensors, which are then moved to self.device.
        self.model = args.model
        saved_dict = torch.load(args.checkpoint_path + '/final.mdl',
                                map_location="cpu")
        self.model.load_state_dict(saved_dict["model"])
        self.model = self.model.to(self.device)
        print("Loaded the model...")

    def get_est_sig(self, input, label, output):
        """
        Return (estimated signals, reference signals) for one batch.

        Subclasses must override this: how a model's output maps back to
        waveforms is model-specific.
        """
        raise NotImplementedError(
            "tester.get_est_sig must be implemented by a subclass")

    def eval(self):
        """
        Average batch_SDR_torch over test_loader, printing the running mean.

        returns: the average SDR over all batches
        """
        sdrs = AverageMeter()
        self.model = self.model.eval()
        with torch.no_grad():
            for inputs, label in self.test_loader:
                output = self.model(inputs)
                sig_est, sig_ref = self.get_est_sig(inputs, label, output)
                sdrs.update(batch_SDR_torch(sig_est, sig_ref))
                print("SDR: %.2f"%(sdrs.avg), end='\r')

        print("\n")
        return sdrs.avg
42 |
43 |
def main():
    """Read the JSON config given by --config and run the tester on it."""
    parser = argparse.ArgumentParser(description='Parse the config path')
    parser.add_argument(
        "-c", "--config", dest="path",
        help='The path to the config file. e.g. python train.py --config config.json')
    config_path = parser.parse_args().path
    with open(config_path) as config_file:
        settings = AttrDict(json.load(config_file))
    tester(settings).eval()


if __name__ == "__main__":
    main()
59 |
--------------------------------------------------------------------------------
/onssen/utils/train.py:
--------------------------------------------------------------------------------
1 | from attrdict import AttrDict
2 | from .basic import AverageMeter
3 | import argparse, json, os, time, torch
4 |
5 |
class trainer:
    # epochs without validation improvement before run() stops early
    early_stop_patience = 8

    def __init__(self, args):
        """
        args: a dictionary-like object (supports `in` and attribute access) containing
            model_name(str): the name of the model
            model(nn.Module): the module object
            optimizer(torch.optim.Optimizer): optimizer over model's parameters
            train_loader / valid_loader(DataLoader): PyTorch DataLoader objects
            loss_fn(function): the loss function, called as loss_fn(output, label)
            num_epoch(int): total number of epochs
            resume_from_checkpoint(bool or "True"): resume from
                checkpoint_path/final.mdl, False by default
            checkpoint_path(str): the directory for the saved dictionary
            device(torch.device): the device for training
            cv_device (optional): stored as self.cv_device, defaults to device
            early_stop_patience (optional int): defaults to 8
        """
        # Keep the flag in a local: storing it as self.resume_from_checkpoint
        # would shadow the method of the same name and make it uncallable.
        resume = ("resume_from_checkpoint" in args
                  and args.resume_from_checkpoint in (True, "True"))

        self.device = args.device
        self.cv_device = args.cv_device if "cv_device" in args else self.device
        self.model_name = args.model_name
        self.train_loader = args.train_loader
        self.valid_loader = args.valid_loader
        self.loss_fn = args.loss_fn
        if "early_stop_patience" in args:
            self.early_stop_patience = args.early_stop_patience

        # build model; the optimizer is assumed to be bound to args.model's
        # parameters, so a checkpoint is loaded into these objects in place
        self.model = args.model
        self.optimizer = args.optimizer
        self.epoch = 0
        self.min_loss = float("inf")
        self.early_stop_count = 0
        if resume:
            self.resume_from_checkpoint(args.checkpoint_path)
        print("Loaded the model...")
        self.num_epoch = args.num_epoch
        self.checkpoint_path = args.checkpoint_path
        os.makedirs(self.checkpoint_path, exist_ok=True)

    def resume_from_checkpoint(self, checkpoint_path):
        """
        Restore model/optimizer state and training counters from
        checkpoint_path/final.mdl, in the format written by validate().

        Requires self.model, self.optimizer and self.device to be set. The
        checkpoint holds a state_dict, so it is loaded into the existing
        model instead of replacing it. Training resumes at the epoch after
        the saved one.
        """
        saved_dict = torch.load(checkpoint_path + '/final.mdl',
                                map_location="cpu")
        self.model.load_state_dict(saved_dict["model"])
        # move before loading optimizer state so that state is cast to the
        # parameters' device
        self.model = self.model.to(self.device)
        opt_state = saved_dict.get("optimizer")
        if opt_state is not None:
            if isinstance(opt_state, torch.optim.Optimizer):
                # checkpoints from earlier versions pickled the optimizer itself
                opt_state = opt_state.state_dict()
            self.optimizer.load_state_dict(opt_state)
        self.epoch = saved_dict["epoch"] + 1
        self.min_loss = saved_dict["cv_loss"]
        self.early_stop_count = saved_dict["early_stop_count"]

    def run(self):
        """Train and validate from self.epoch up to num_epoch, or until early stopping."""
        for epoch in range(self.epoch, self.num_epoch):
            self.train(epoch)
            self.validate(epoch)
            if self.early_stop_count >= self.early_stop_patience:
                print("Model stops improving, stop the training")
                break
        print("Model training is finished.")

    def train(self, epoch):
        """One pass over train_loader: backprop, clip grad norm to 5, step."""
        losses = AverageMeter()
        times = AverageMeter()
        self.model = self.model.train()
        len_d = len(self.train_loader)
        init_time = time.time()
        end = init_time
        for i, (inputs, label) in enumerate(self.train_loader):
            output = self.model(inputs)
            loss_avg = torch.mean(self.loss_fn(output, label))
            losses.update(loss_avg.item())
            self.optimizer.zero_grad()
            loss_avg.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5)
            self.optimizer.step()
            times.update(time.time() - end)
            end = time.time()
            print('epoch %d, %d/%d, training loss: %f, time estimated: %.2f/%.2f seconds'%(epoch, i+1,len_d,losses.avg, end-init_time, times.avg*len_d), end='\r')
        print("\n")

    def validate(self, epoch):
        """
        Evaluate on valid_loader without gradients. Save a checkpoint to
        checkpoint_path/final.mdl when the mean loss improves; otherwise
        increment early_stop_count.
        """
        self.model = self.model.eval()
        losses = AverageMeter()
        times = AverageMeter()
        len_d = len(self.valid_loader)
        init_time = time.time()
        end = init_time
        with torch.no_grad():
            for i, (inputs, label) in enumerate(self.valid_loader):
                output = self.model(inputs)
                loss_avg = torch.mean(self.loss_fn(output, label))
                losses.update(loss_avg.item())
                times.update(time.time() - end)
                end = time.time()
                print('epoch %d, %d/%d, validation loss: %f, time estimated: %.2f/%.2f seconds'%(epoch, i+1,len_d,losses.avg, end-init_time, times.avg*len_d), end='\r')
        print("\n")
        if losses.avg < self.min_loss:
            self.early_stop_count = 0
            self.min_loss = losses.avg
            saved_dict = {
                'model': self.model.state_dict(),
                'epoch': epoch,
                # store the state_dict, not the optimizer object, so the
                # checkpoint does not depend on pickling optimizer internals
                'optimizer': self.optimizer.state_dict(),
                'cv_loss': self.min_loss,
                "early_stop_count": self.early_stop_count
            }
            torch.save(saved_dict, self.checkpoint_path + "/final.mdl")
            print("Saved new model")
        else:
            self.early_stop_count += 1
125 |
126 |
def main():
    """Read the JSON config given by --config and start training."""
    parser = argparse.ArgumentParser(description='Parse the config path')
    parser.add_argument(
        "-c", "--config", dest="path",
        help='The path to the config file. e.g. python train.py --config configs/dc_config.json')
    config_path = parser.parse_args().path
    with open(config_path) as config_file:
        settings = AttrDict(json.load(config_file))
    trainer(settings).run()


if __name__ == "__main__":
    main()
142 |
--------------------------------------------------------------------------------