├── .gitignore
├── README.md
├── __init__.py
├── doc
│   ├── sing_workflow.json
│   └── speech_workflow.json
├── requirements.txt
└── seedvc
    ├── dac
    │   ├── __init__.py
    │   ├── __main__.py
    │   ├── model
    │   │   ├── __init__.py
    │   │   ├── base.py
    │   │   ├── dac.py
    │   │   ├── discriminator.py
    │   │   └── encodec.py
    │   ├── nn
    │   │   ├── __init__.py
    │   │   ├── layers.py
    │   │   ├── loss.py
    │   │   └── quantize.py
    │   └── utils
    │       ├── __init__.py
    │       ├── decode.py
    │       └── encode.py
    └── modules
        ├── alias_free_torch
        │   ├── __init__.py
        │   ├── act.py
        │   ├── filter.py
        │   └── resample.py
        ├── audio.py
        ├── bigvgan
        │   ├── activations.py
        │   ├── alias_free_activation
        │   │   ├── cuda
        │   │   │   ├── __init__.py
        │   │   │   ├── activation1d.py
        │   │   │   ├── anti_alias_activation.cpp
        │   │   │   ├── anti_alias_activation_cuda.cu
        │   │   │   ├── compat.h
        │   │   │   ├── load.py
        │   │   │   └── type_shim.h
        │   │   └── torch
        │   │       ├── __init__.py
        │   │       ├── act.py
        │   │       ├── filter.py
        │   │       └── resample.py
        │   ├── bigvgan.py
        │   ├── config.json
        │   ├── env.py
        │   ├── meldataset.py
        │   └── utils.py
        ├── campplus
        │   ├── DTDNN.py
        │   ├── classifier.py
        │   └── layers.py
        ├── commons.py
        ├── diffusion_transformer.py
        ├── encodec.py
        ├── flow_matching.py
        ├── gpt_fast
        │   ├── generate.py
        │   ├── model.py
        │   └── quantize.py
        ├── hifigan
        │   ├── f0_predictor.py
        │   └── generator.py
        ├── layers.py
        ├── length_regulator.py
        ├── quantize.py
        ├── rmvpe.py
        └── wavenet.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SeedVC-ComfyUI
2 | A custom node for [seed-vc](https://github.com/Plachtaa/seed-vc). Example workflows can be found in [doc](./doc).
3 |
4 | ## Update
5 | - 2024-10-28:
6 | - Updated fine-tuned 44k singing voice conversion model with better audio quality
7 | - 2024-10-24:
8 | - Updated v0.3 pretrained model, changed speech content encoder to OpenAI Whisper
9 | - 2024-09-22:
10 | - Updated singing voice conversion model to use BigVGAN from NVIDIA, providing large improvement to high-pitched singing voices
11 | - 2024-09-18: Added singing voice conversion
12 |
13 | ## Tutorials
14 | - [Demo](https://b23.tv/IfDHZ9w)
15 | - [Cloud image](https://www.xiangongyun.com/image/detail/f19243de-f62b-435e-96fc-ce29acbedd85)
16 | - [One-click package](https://b23.tv/2Uj8QHD)
17 | ## Disclaimer
18 | We do not hold any responsibility for any illegal usage of the codebase. Please refer to your local laws about DMCA and other related laws.
19 | ## Example
20 | |source|reference|output|
21 | |--|--|--|
25 |
--------------------------------------------------------------------------------
/doc/sing_workflow.json:
--------------------------------------------------------------------------------
1 | {
2 | "last_node_id": 16,
3 | "last_link_id": 26,
4 | "nodes": [
5 | {
6 | "id": 2,
7 | "type": "LoadAudio",
8 | "pos": {
9 | "0": 34,
10 | "1": 75
11 | },
12 | "size": {
13 | "0": 315,
14 | "1": 124
15 | },
16 | "flags": {},
17 | "order": 0,
18 | "mode": 0,
19 | "inputs": [],
20 | "outputs": [
21 | {
22 | "name": "AUDIO",
23 | "type": "AUDIO",
24 | "links": [
25 | 1
26 | ],
27 | "shape": 3
28 | }
29 | ],
30 | "title": "LoadAudio上传歌曲",
31 | "properties": {
32 | "Node name for S&R": "LoadAudio"
33 | },
34 | "widgets_values": [
35 | "qinhua.mp3",
36 | null,
37 | ""
38 | ]
39 | },
40 | {
41 | "id": 3,
42 | "type": "PreviewAudio",
43 | "pos": {
44 | "0": 967,
45 | "1": 74
46 | },
47 | "size": {
48 | "0": 315,
49 | "1": 76
50 | },
51 | "flags": {},
52 | "order": 4,
53 | "mode": 0,
54 | "inputs": [
55 | {
56 | "name": "audio",
57 | "type": "AUDIO",
58 | "link": 2
59 | }
60 | ],
61 | "outputs": [],
62 | "title": "PreviewAudio原声",
63 | "properties": {
64 | "Node name for S&R": "PreviewAudio"
65 | },
66 | "widgets_values": [
67 | null
68 | ]
69 | },
70 | {
71 | "id": 14,
72 | "type": "PreviewAudio",
73 | "pos": {
74 | "0": 983,
75 | "1": 199
76 | },
77 | "size": {
78 | "0": 315,
79 | "1": 76
80 | },
81 | "flags": {},
82 | "order": 5,
83 | "mode": 0,
84 | "inputs": [
85 | {
86 | "name": "audio",
87 | "type": "AUDIO",
88 | "link": 19
89 | }
90 | ],
91 | "outputs": [],
92 | "title": "PreviewAudio伴奏",
93 | "properties": {
94 | "Node name for S&R": "PreviewAudio"
95 | },
96 | "widgets_values": [
97 | null
98 | ]
99 | },
100 | {
101 | "id": 13,
102 | "type": "PreviewAudio",
103 | "pos": {
104 | "0": 976,
105 | "1": 595
106 | },
107 | "size": {
108 | "0": 315,
109 | "1": 76
110 | },
111 | "flags": {},
112 | "order": 6,
113 | "mode": 0,
114 | "inputs": [
115 | {
116 | "name": "audio",
117 | "type": "AUDIO",
118 | "link": 18
119 | }
120 | ],
121 | "outputs": [],
122 | "properties": {
123 | "Node name for S&R": "PreviewAudio"
124 | },
125 | "widgets_values": [
126 | null
127 | ]
128 | },
129 | {
130 | "id": 11,
131 | "type": "PreviewAudio",
132 | "pos": {
133 | "0": 976,
134 | "1": 342
135 | },
136 | "size": {
137 | "0": 315,
138 | "1": 76
139 | },
140 | "flags": {},
141 | "order": 10,
142 | "mode": 0,
143 | "inputs": [
144 | {
145 | "name": "audio",
146 | "type": "AUDIO",
147 | "link": 16
148 | }
149 | ],
150 | "outputs": [],
151 | "title": "PreviewAudio最终音频",
152 | "properties": {
153 | "Node name for S&R": "PreviewAudio"
154 | },
155 | "widgets_values": [
156 | null
157 | ]
158 | },
159 | {
160 | "id": 1,
161 | "type": "VocalSeparationNode",
162 | "pos": {
163 | "0": 437,
164 | "1": 48
165 | },
166 | "size": {
167 | "0": 315,
168 | "1": 126
169 | },
170 | "flags": {},
171 | "order": 2,
172 | "mode": 0,
173 | "inputs": [
174 | {
175 | "name": "music",
176 | "type": "AUDIO",
177 | "link": 1
178 | }
179 | ],
180 | "outputs": [
181 | {
182 | "name": "vocals_AUDIO",
183 | "type": "AUDIO",
184 | "links": [
185 | 2,
186 | 21
187 | ],
188 | "slot_index": 0,
189 | "shape": 3
190 | },
191 | {
192 | "name": "instrumental_AUDIO",
193 | "type": "AUDIO",
194 | "links": [
195 | 14,
196 | 19
197 | ],
198 | "slot_index": 1,
199 | "shape": 3
200 | }
201 | ],
202 | "properties": {
203 | "Node name for S&R": "VocalSeparationNode"
204 | },
205 | "widgets_values": [
206 | "bs_roformer",
207 | 4,
208 | true
209 | ]
210 | },
211 | {
212 | "id": 8,
213 | "type": "VocalSeparationNode",
214 | "pos": {
215 | "0": 426,
216 | "1": 542
217 | },
218 | "size": {
219 | "0": 315,
220 | "1": 126
221 | },
222 | "flags": {},
223 | "order": 3,
224 | "mode": 0,
225 | "inputs": [
226 | {
227 | "name": "music",
228 | "type": "AUDIO",
229 | "link": 9
230 | }
231 | ],
232 | "outputs": [
233 | {
234 | "name": "vocals_AUDIO",
235 | "type": "AUDIO",
236 | "links": [
237 | 18,
238 | 22
239 | ],
240 | "slot_index": 0,
241 | "shape": 3
242 | },
243 | {
244 | "name": "instrumental_AUDIO",
245 | "type": "AUDIO",
246 | "links": [],
247 | "slot_index": 1,
248 | "shape": 3
249 | }
250 | ],
251 | "properties": {
252 | "Node name for S&R": "VocalSeparationNode"
253 | },
254 | "widgets_values": [
255 | "bs_roformer",
256 | 4,
257 | true
258 | ]
259 | },
260 | {
261 | "id": 7,
262 | "type": "LoadAudio",
263 | "pos": {
264 | "0": 42,
265 | "1": 270
266 | },
267 | "size": {
268 | "0": 315,
269 | "1": 124
270 | },
271 | "flags": {},
272 | "order": 1,
273 | "mode": 0,
274 | "inputs": [],
275 | "outputs": [
276 | {
277 | "name": "AUDIO",
278 | "type": "AUDIO",
279 | "links": [
280 | 9
281 | ],
282 | "slot_index": 0,
283 | "shape": 3
284 | }
285 | ],
286 | "title": "LoadAudio上传目标音色",
287 | "properties": {
288 | "Node name for S&R": "LoadAudio"
289 | },
290 | "widgets_values": [
291 | "dingzhen_0.wav",
292 | null,
293 | ""
294 | ]
295 | },
296 | {
297 | "id": 15,
298 | "type": "PreviewAudio",
299 | "pos": {
300 | "0": 972,
301 | "1": 464
302 | },
303 | "size": {
304 | "0": 315,
305 | "1": 76
306 | },
307 | "flags": {},
308 | "order": 9,
309 | "mode": 0,
310 | "inputs": [
311 | {
312 | "name": "audio",
313 | "type": "AUDIO",
314 | "link": 26
315 | }
316 | ],
317 | "outputs": [],
318 | "title": "PreviewAudio克隆结果",
319 | "properties": {
320 | "Node name for S&R": "PreviewAudio"
321 | },
322 | "widgets_values": [
323 | null
324 | ]
325 | },
326 | {
327 | "id": 16,
328 | "type": "SeedVC4SingNode",
329 | "pos": {
330 | "0": 440,
331 | "1": 302
332 | },
333 | "size": {
334 | "0": 315,
335 | "1": 174
336 | },
337 | "flags": {},
338 | "order": 7,
339 | "mode": 0,
340 | "inputs": [
341 | {
342 | "name": "source",
343 | "type": "AUDIO",
344 | "link": 21
345 | },
346 | {
347 | "name": "target",
348 | "type": "AUDIO",
349 | "link": 22
350 | }
351 | ],
352 | "outputs": [
353 | {
354 | "name": "AUDIO",
355 | "type": "AUDIO",
356 | "links": [
357 | 24,
358 | 26
359 | ],
360 | "shape": 3,
361 | "slot_index": 0
362 | }
363 | ],
364 | "properties": {
365 | "Node name for S&R": "SeedVC4SingNode"
366 | },
367 | "widgets_values": [
368 | 0,
369 | 50,
370 | 1,
371 | 0.7,
372 | 3
373 | ]
374 | },
375 | {
376 | "id": 9,
377 | "type": "CombineAudioNode",
378 | "pos": {
379 | "0": 731,
380 | "1": 215
381 | },
382 | "size": {
383 | "0": 229.20001220703125,
384 | "1": 46
385 | },
386 | "flags": {},
387 | "order": 8,
388 | "mode": 0,
389 | "inputs": [
390 | {
391 | "name": "vocal",
392 | "type": "AUDIO",
393 | "link": 24
394 | },
395 | {
396 | "name": "instrumental",
397 | "type": "AUDIO",
398 | "link": 14
399 | }
400 | ],
401 | "outputs": [
402 | {
403 | "name": "AUDIO",
404 | "type": "AUDIO",
405 | "links": [
406 | 16
407 | ],
408 | "slot_index": 0,
409 | "shape": 3
410 | }
411 | ],
412 | "properties": {
413 | "Node name for S&R": "CombineAudioNode"
414 | }
415 | }
416 | ],
417 | "links": [
418 | [
419 | 1,
420 | 2,
421 | 0,
422 | 1,
423 | 0,
424 | "AUDIO"
425 | ],
426 | [
427 | 2,
428 | 1,
429 | 0,
430 | 3,
431 | 0,
432 | "AUDIO"
433 | ],
434 | [
435 | 9,
436 | 7,
437 | 0,
438 | 8,
439 | 0,
440 | "AUDIO"
441 | ],
442 | [
443 | 14,
444 | 1,
445 | 1,
446 | 9,
447 | 1,
448 | "AUDIO"
449 | ],
450 | [
451 | 16,
452 | 9,
453 | 0,
454 | 11,
455 | 0,
456 | "AUDIO"
457 | ],
458 | [
459 | 18,
460 | 8,
461 | 0,
462 | 13,
463 | 0,
464 | "AUDIO"
465 | ],
466 | [
467 | 19,
468 | 1,
469 | 1,
470 | 14,
471 | 0,
472 | "AUDIO"
473 | ],
474 | [
475 | 21,
476 | 1,
477 | 0,
478 | 16,
479 | 0,
480 | "AUDIO"
481 | ],
482 | [
483 | 22,
484 | 8,
485 | 0,
486 | 16,
487 | 1,
488 | "AUDIO"
489 | ],
490 | [
491 | 24,
492 | 16,
493 | 0,
494 | 9,
495 | 0,
496 | "AUDIO"
497 | ],
498 | [
499 | 26,
500 | 16,
501 | 0,
502 | 15,
503 | 0,
504 | "AUDIO"
505 | ]
506 | ],
507 | "groups": [
508 | {
509 | "title": "输入组",
510 | "bounding": [
511 | 24,
512 | 2,
513 | 350,
514 | 689
515 | ],
516 | "color": "#3f789e",
517 | "font_size": 24,
518 | "flags": {}
519 | },
520 | {
521 | "title": "输出组",
522 | "bounding": [
523 | 956,
524 | 0,
525 | 352,
526 | 679
527 | ],
528 | "color": "#3f789e",
529 | "font_size": 24,
530 | "flags": {}
531 | }
532 | ],
533 | "config": {},
534 | "extra": {
535 | "ds": {
536 | "scale": 1,
537 | "offset": [
538 | 88,
539 | -12.79998779296875
540 | ]
541 | }
542 | },
543 | "version": 0.4
544 | }
--------------------------------------------------------------------------------
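For readers tracing the sing workflow by hand: each entry in the `links` array above appears to follow ComfyUI's LiteGraph serialization, `[link_id, source_node_id, source_slot, target_node_id, target_slot, type]`. A small sketch (the relative path is hypothetical; it works equally for either workflow file in `doc/`) that prints the routing:

```python
import json

# Load one of the exported workflows and print every connection.
with open("doc/sing_workflow.json", "r", encoding="utf-8") as f:
    wf = json.load(f)

nodes = {n["id"]: n for n in wf["nodes"]}
for link_id, src, src_slot, dst, dst_slot, ltype in wf["links"]:
    src_name = nodes[src].get("title", nodes[src]["type"])
    dst_name = nodes[dst].get("title", nodes[dst]["type"])
    print(f"link {link_id}: {src_name}[{src_slot}] --{ltype}--> {dst_name}[{dst_slot}]")
```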
/doc/speech_workflow.json:
--------------------------------------------------------------------------------
1 | {
2 | "last_node_id": 13,
3 | "last_link_id": 23,
4 | "nodes": [
5 | {
6 | "id": 9,
7 | "type": "CombineAudioNode",
8 | "pos": {
9 | "0": 809,
10 | "1": 234
11 | },
12 | "size": {
13 | "0": 229.20001220703125,
14 | "1": 46
15 | },
16 | "flags": {},
17 | "order": 7,
18 | "mode": 0,
19 | "inputs": [
20 | {
21 | "name": "vocal",
22 | "type": "AUDIO",
23 | "link": 23
24 | },
25 | {
26 | "name": "instrumental",
27 | "type": "AUDIO",
28 | "link": 14
29 | }
30 | ],
31 | "outputs": [
32 | {
33 | "name": "AUDIO",
34 | "type": "AUDIO",
35 | "links": [
36 | 16
37 | ],
38 | "slot_index": 0,
39 | "shape": 3
40 | }
41 | ],
42 | "properties": {
43 | "Node name for S&R": "CombineAudioNode"
44 | }
45 | },
46 | {
47 | "id": 3,
48 | "type": "PreviewAudio",
49 | "pos": {
50 | "0": 897.6000366210938,
51 | "1": 73.80000305175781
52 | },
53 | "size": {
54 | "0": 315,
55 | "1": 76
56 | },
57 | "flags": {},
58 | "order": 5,
59 | "mode": 0,
60 | "inputs": [
61 | {
62 | "name": "audio",
63 | "type": "AUDIO",
64 | "link": 2
65 | }
66 | ],
67 | "outputs": [],
68 | "properties": {
69 | "Node name for S&R": "PreviewAudio"
70 | },
71 | "widgets_values": [
72 | null
73 | ]
74 | },
75 | {
76 | "id": 11,
77 | "type": "PreviewAudio",
78 | "pos": {
79 | "0": 898,
80 | "1": 366
81 | },
82 | "size": {
83 | "0": 315,
84 | "1": 76
85 | },
86 | "flags": {},
87 | "order": 8,
88 | "mode": 0,
89 | "inputs": [
90 | {
91 | "name": "audio",
92 | "type": "AUDIO",
93 | "link": 16
94 | }
95 | ],
96 | "outputs": [],
97 | "properties": {
98 | "Node name for S&R": "PreviewAudio"
99 | },
100 | "widgets_values": [
101 | null
102 | ]
103 | },
104 | {
105 | "id": 7,
106 | "type": "LoadAudio",
107 | "pos": {
108 | "0": 26,
109 | "1": 511
110 | },
111 | "size": {
112 | "0": 315,
113 | "1": 124
114 | },
115 | "flags": {},
116 | "order": 0,
117 | "mode": 0,
118 | "inputs": [],
119 | "outputs": [
120 | {
121 | "name": "AUDIO",
122 | "type": "AUDIO",
123 | "links": [
124 | 9
125 | ],
126 | "slot_index": 0,
127 | "shape": 3
128 | }
129 | ],
130 | "title": "LoadAudio目标音色",
131 | "properties": {
132 | "Node name for S&R": "LoadAudio"
133 | },
134 | "widgets_values": [
135 | "dingzhen_0.wav",
136 | null,
137 | ""
138 | ]
139 | },
140 | {
141 | "id": 2,
142 | "type": "LoadAudio",
143 | "pos": {
144 | "0": 27,
145 | "1": 151
146 | },
147 | "size": {
148 | "0": 315,
149 | "1": 124
150 | },
151 | "flags": {},
152 | "order": 1,
153 | "mode": 0,
154 | "inputs": [],
155 | "outputs": [
156 | {
157 | "name": "AUDIO",
158 | "type": "AUDIO",
159 | "links": [
160 | 1
161 | ],
162 | "slot_index": 0,
163 | "shape": 3
164 | }
165 | ],
166 | "title": "LoadAudio上传演讲",
167 | "properties": {
168 | "Node name for S&R": "LoadAudio"
169 | },
170 | "widgets_values": [
171 | "s4p1.wav",
172 | null,
173 | ""
174 | ]
175 | },
176 | {
177 | "id": 1,
178 | "type": "VocalSeparationNode",
179 | "pos": {
180 | "0": 432,
181 | "1": 74
182 | },
183 | "size": {
184 | "0": 315,
185 | "1": 126
186 | },
187 | "flags": {},
188 | "order": 3,
189 | "mode": 0,
190 | "inputs": [
191 | {
192 | "name": "music",
193 | "type": "AUDIO",
194 | "link": 1
195 | }
196 | ],
197 | "outputs": [
198 | {
199 | "name": "vocals_AUDIO",
200 | "type": "AUDIO",
201 | "links": [
202 | 2,
203 | 21
204 | ],
205 | "slot_index": 0,
206 | "shape": 3
207 | },
208 | {
209 | "name": "instrumental_AUDIO",
210 | "type": "AUDIO",
211 | "links": [
212 | 14
213 | ],
214 | "slot_index": 1,
215 | "shape": 3
216 | }
217 | ],
218 | "properties": {
219 | "Node name for S&R": "VocalSeparationNode"
220 | },
221 | "widgets_values": [
222 | "bs_roformer",
223 | 4,
224 | true
225 | ]
226 | },
227 | {
228 | "id": 8,
229 | "type": "VocalSeparationNode",
230 | "pos": {
231 | "0": 418,
232 | "1": 520
233 | },
234 | "size": {
235 | "0": 315,
236 | "1": 126
237 | },
238 | "flags": {},
239 | "order": 2,
240 | "mode": 0,
241 | "inputs": [
242 | {
243 | "name": "music",
244 | "type": "AUDIO",
245 | "link": 9
246 | }
247 | ],
248 | "outputs": [
249 | {
250 | "name": "vocals_AUDIO",
251 | "type": "AUDIO",
252 | "links": [
253 | 22
254 | ],
255 | "slot_index": 0,
256 | "shape": 3
257 | },
258 | {
259 | "name": "instrumental_AUDIO",
260 | "type": "AUDIO",
261 | "links": [
262 | 15
263 | ],
264 | "slot_index": 1,
265 | "shape": 3
266 | }
267 | ],
268 | "properties": {
269 | "Node name for S&R": "VocalSeparationNode"
270 | },
271 | "widgets_values": [
272 | "bs_roformer",
273 | 4,
274 | true
275 | ]
276 | },
277 | {
278 | "id": 10,
279 | "type": "PreviewAudio",
280 | "pos": {
281 | "0": 897,
282 | "1": 550
283 | },
284 | "size": {
285 | "0": 315,
286 | "1": 76
287 | },
288 | "flags": {},
289 | "order": 4,
290 | "mode": 0,
291 | "inputs": [
292 | {
293 | "name": "audio",
294 | "type": "AUDIO",
295 | "link": 15
296 | }
297 | ],
298 | "outputs": [],
299 | "properties": {
300 | "Node name for S&R": "PreviewAudio"
301 | },
302 | "widgets_values": [
303 | null
304 | ]
305 | },
306 | {
307 | "id": 13,
308 | "type": "SeedVCNode",
309 | "pos": {
310 | "0": 426,
311 | "1": 294
312 | },
313 | "size": {
314 | "0": 315,
315 | "1": 150
316 | },
317 | "flags": {},
318 | "order": 6,
319 | "mode": 0,
320 | "inputs": [
321 | {
322 | "name": "source",
323 | "type": "AUDIO",
324 | "link": 21
325 | },
326 | {
327 | "name": "target",
328 | "type": "AUDIO",
329 | "link": 22
330 | }
331 | ],
332 | "outputs": [
333 | {
334 | "name": "AUDIO",
335 | "type": "AUDIO",
336 | "links": [
337 | 23
338 | ],
339 | "shape": 3,
340 | "slot_index": 0
341 | }
342 | ],
343 | "properties": {
344 | "Node name for S&R": "SeedVCNode"
345 | },
346 | "widgets_values": [
347 | 50,
348 | 1,
349 | 0.7,
350 | 3
351 | ]
352 | }
353 | ],
354 | "links": [
355 | [
356 | 1,
357 | 2,
358 | 0,
359 | 1,
360 | 0,
361 | "AUDIO"
362 | ],
363 | [
364 | 2,
365 | 1,
366 | 0,
367 | 3,
368 | 0,
369 | "AUDIO"
370 | ],
371 | [
372 | 9,
373 | 7,
374 | 0,
375 | 8,
376 | 0,
377 | "AUDIO"
378 | ],
379 | [
380 | 14,
381 | 1,
382 | 1,
383 | 9,
384 | 1,
385 | "AUDIO"
386 | ],
387 | [
388 | 15,
389 | 8,
390 | 1,
391 | 10,
392 | 0,
393 | "AUDIO"
394 | ],
395 | [
396 | 16,
397 | 9,
398 | 0,
399 | 11,
400 | 0,
401 | "AUDIO"
402 | ],
403 | [
404 | 21,
405 | 1,
406 | 0,
407 | 13,
408 | 0,
409 | "AUDIO"
410 | ],
411 | [
412 | 22,
413 | 8,
414 | 0,
415 | 13,
416 | 1,
417 | "AUDIO"
418 | ],
419 | [
420 | 23,
421 | 13,
422 | 0,
423 | 9,
424 | 0,
425 | "AUDIO"
426 | ]
427 | ],
428 | "groups": [],
429 | "config": {},
430 | "extra": {
431 | "ds": {
432 | "scale": 1,
433 | "offset": [
434 | 88,
435 | -12.79998779296875
436 | ]
437 | }
438 | },
439 | "version": 0.4
440 | }
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | scipy==1.13.1
2 | onnxruntime-gpu==1.19.0
3 | librosa==0.10.2
4 | huggingface-hub
5 | munch
6 | einops
7 | descript-audio-codec
8 | git+https://github.com/openai/whisper.git
--------------------------------------------------------------------------------
/seedvc/dac/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "1.0.0"
2 |
3 | # preserved here for legacy reasons
4 | __model_version__ = "latest"
5 |
6 | import audiotools
7 |
8 | audiotools.ml.BaseModel.INTERN += ["dac.**"]
9 | audiotools.ml.BaseModel.EXTERN += ["einops"]
10 |
11 |
12 | from . import nn
13 | from . import model
14 | from . import utils
15 | from .model import DAC
16 | from .model import DACFile
17 |
--------------------------------------------------------------------------------
/seedvc/dac/__main__.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import argbind
4 |
5 | from .utils import download
6 | from .utils.decode import decode
7 | from .utils.encode import encode
8 |
9 | STAGES = ["encode", "decode", "download"]
10 |
11 |
12 | def run(stage: str):
13 | """Run stages.
14 |
15 | Parameters
16 | ----------
17 | stage : str
18 | Stage to run
19 | """
20 | if stage not in STAGES:
21 | raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}")
22 | stage_fn = globals()[stage]
23 |
24 | if stage == "download":
25 | stage_fn()
26 | return
27 |
28 | stage_fn()
29 |
30 |
31 | if __name__ == "__main__":
32 | group = sys.argv.pop(1)
33 | args = argbind.parse_args(group=group)
34 |
35 | with argbind.scope(args):
36 | run(group)
37 |
--------------------------------------------------------------------------------
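The dispatcher above resolves the stage name through the module's globals, so `encode`, `decode`, and `download` (the names in `STAGES`) are the only valid commands. A minimal sketch of calling it programmatically; the import path assumes the repository root is on `sys.path`, and normally the module is instead executed with the stage as the first CLI argument:

```python
# Hypothetical programmatic use of the stage dispatcher; argbind defaults apply
# because no argbind.scope(...) is active here.
from seedvc.dac.__main__ import STAGES, run

print(STAGES)     # ['encode', 'decode', 'download']
run("download")   # resolves to the download() helper imported from .utils
```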
/seedvc/dac/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import CodecMixin
2 | from .base import DACFile
3 | from .dac import DAC
4 | from .discriminator import Discriminator
5 |
--------------------------------------------------------------------------------
/seedvc/dac/model/base.py:
--------------------------------------------------------------------------------
1 | import math
2 | from dataclasses import dataclass
3 | from pathlib import Path
4 | from typing import Union
5 |
6 | import numpy as np
7 | import torch
8 | import tqdm
9 | from audiotools import AudioSignal
10 | from torch import nn
11 |
12 | SUPPORTED_VERSIONS = ["1.0.0"]
13 |
14 |
15 | @dataclass
16 | class DACFile:
17 | codes: torch.Tensor
18 |
19 | # Metadata
20 | chunk_length: int
21 | original_length: int
22 | input_db: float
23 | channels: int
24 | sample_rate: int
25 | padding: bool
26 | dac_version: str
27 |
28 | def save(self, path):
29 | artifacts = {
30 | "codes": self.codes.numpy().astype(np.uint16),
31 | "metadata": {
32 | "input_db": self.input_db.numpy().astype(np.float32),
33 | "original_length": self.original_length,
34 | "sample_rate": self.sample_rate,
35 | "chunk_length": self.chunk_length,
36 | "channels": self.channels,
37 | "padding": self.padding,
38 | "dac_version": SUPPORTED_VERSIONS[-1],
39 | },
40 | }
41 | path = Path(path).with_suffix(".dac")
42 | with open(path, "wb") as f:
43 | np.save(f, artifacts)
44 | return path
45 |
46 | @classmethod
47 | def load(cls, path):
48 | artifacts = np.load(path, allow_pickle=True)[()]
49 | codes = torch.from_numpy(artifacts["codes"].astype(int))
50 | if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
51 | raise RuntimeError(
52 | f"Given file {path} can't be loaded with this version of descript-audio-codec."
53 | )
54 | return cls(codes=codes, **artifacts["metadata"])
55 |
56 |
57 | class CodecMixin:
58 | @property
59 | def padding(self):
60 | if not hasattr(self, "_padding"):
61 | self._padding = True
62 | return self._padding
63 |
64 | @padding.setter
65 | def padding(self, value):
66 | assert isinstance(value, bool)
67 |
68 | layers = [
69 | l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))
70 | ]
71 |
72 | for layer in layers:
73 | if value:
74 | if hasattr(layer, "original_padding"):
75 | layer.padding = layer.original_padding
76 | else:
77 | layer.original_padding = layer.padding
78 | layer.padding = tuple(0 for _ in range(len(layer.padding)))
79 |
80 | self._padding = value
81 |
82 | def get_delay(self):
83 | # Any number works here, delay is invariant to input length
84 | l_out = self.get_output_length(0)
85 | L = l_out
86 |
87 | layers = []
88 | for layer in self.modules():
89 | if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
90 | layers.append(layer)
91 |
92 | for layer in reversed(layers):
93 | d = layer.dilation[0]
94 | k = layer.kernel_size[0]
95 | s = layer.stride[0]
96 |
97 | if isinstance(layer, nn.ConvTranspose1d):
98 | L = ((L - d * (k - 1) - 1) / s) + 1
99 | elif isinstance(layer, nn.Conv1d):
100 | L = (L - 1) * s + d * (k - 1) + 1
101 |
102 | L = math.ceil(L)
103 |
104 | l_in = L
105 |
106 | return (l_in - l_out) // 2
107 |
108 | def get_output_length(self, input_length):
109 | L = input_length
110 | # Calculate output length
111 | for layer in self.modules():
112 | if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
113 | d = layer.dilation[0]
114 | k = layer.kernel_size[0]
115 | s = layer.stride[0]
116 |
117 | if isinstance(layer, nn.Conv1d):
118 | L = ((L - d * (k - 1) - 1) / s) + 1
119 | elif isinstance(layer, nn.ConvTranspose1d):
120 | L = (L - 1) * s + d * (k - 1) + 1
121 |
122 | L = math.floor(L)
123 | return L
124 |
125 | @torch.no_grad()
126 | def compress(
127 | self,
128 | audio_path_or_signal: Union[str, Path, AudioSignal],
129 | win_duration: float = 1.0,
130 | verbose: bool = False,
131 | normalize_db: float = -16,
132 | n_quantizers: int = None,
133 | ) -> DACFile:
134 | """Processes an audio signal from a file or AudioSignal object into
135 | discrete codes. This function processes the signal in short windows,
136 | using constant GPU memory.
137 |
138 | Parameters
139 | ----------
140 | audio_path_or_signal : Union[str, Path, AudioSignal]
141 |             audio signal to compress
142 |         win_duration : float, optional
143 |             window duration in seconds, by default 1.0
144 |         verbose : bool, optional
145 |             show a progress bar if True, by default False
146 | normalize_db : float, optional
147 | normalize db, by default -16
148 |
149 | Returns
150 | -------
151 | DACFile
152 | Object containing compressed codes and metadata
153 | required for decompression
154 | """
155 | audio_signal = audio_path_or_signal
156 | if isinstance(audio_signal, (str, Path)):
157 | audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
158 |
159 | self.eval()
160 | original_padding = self.padding
161 | original_device = audio_signal.device
162 |
163 | audio_signal = audio_signal.clone()
164 | original_sr = audio_signal.sample_rate
165 |
166 | resample_fn = audio_signal.resample
167 | loudness_fn = audio_signal.loudness
168 |
169 |         # If audio is 10 hours or longer, use the ffmpeg versions
170 | if audio_signal.signal_duration >= 10 * 60 * 60:
171 | resample_fn = audio_signal.ffmpeg_resample
172 | loudness_fn = audio_signal.ffmpeg_loudness
173 |
174 | original_length = audio_signal.signal_length
175 | resample_fn(self.sample_rate)
176 | input_db = loudness_fn()
177 |
178 | if normalize_db is not None:
179 | audio_signal.normalize(normalize_db)
180 | audio_signal.ensure_max_of_audio()
181 |
182 | nb, nac, nt = audio_signal.audio_data.shape
183 | audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
184 | win_duration = (
185 | audio_signal.signal_duration if win_duration is None else win_duration
186 | )
187 |
188 | if audio_signal.signal_duration <= win_duration:
189 | # Unchunked compression (used if signal length < win duration)
190 | self.padding = True
191 | n_samples = nt
192 | hop = nt
193 | else:
194 | # Chunked inference
195 | self.padding = False
196 | # Zero-pad signal on either side by the delay
197 | audio_signal.zero_pad(self.delay, self.delay)
198 | n_samples = int(win_duration * self.sample_rate)
199 | # Round n_samples to nearest hop length multiple
200 | n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
201 | hop = self.get_output_length(n_samples)
202 |
203 | codes = []
204 | range_fn = range if not verbose else tqdm.trange
205 |
206 | for i in range_fn(0, nt, hop):
207 | x = audio_signal[..., i : i + n_samples]
208 | x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
209 |
210 | audio_data = x.audio_data.to(self.device)
211 | audio_data = self.preprocess(audio_data, self.sample_rate)
212 | _, c, _, _, _ = self.encode(audio_data, n_quantizers)
213 | codes.append(c.to(original_device))
214 | chunk_length = c.shape[-1]
215 |
216 | codes = torch.cat(codes, dim=-1)
217 |
218 | dac_file = DACFile(
219 | codes=codes,
220 | chunk_length=chunk_length,
221 | original_length=original_length,
222 | input_db=input_db,
223 | channels=nac,
224 | sample_rate=original_sr,
225 | padding=self.padding,
226 | dac_version=SUPPORTED_VERSIONS[-1],
227 | )
228 |
229 | if n_quantizers is not None:
230 | codes = codes[:, :n_quantizers, :]
231 |
232 | self.padding = original_padding
233 | return dac_file
234 |
235 | @torch.no_grad()
236 | def decompress(
237 | self,
238 | obj: Union[str, Path, DACFile],
239 | verbose: bool = False,
240 | ) -> AudioSignal:
241 | """Reconstruct audio from a given .dac file
242 |
243 | Parameters
244 | ----------
245 | obj : Union[str, Path, DACFile]
246 | .dac file location or corresponding DACFile object.
247 | verbose : bool, optional
248 | Prints progress if True, by default False
249 |
250 | Returns
251 | -------
252 | AudioSignal
253 | Object with the reconstructed audio
254 | """
255 | self.eval()
256 | if isinstance(obj, (str, Path)):
257 | obj = DACFile.load(obj)
258 |
259 | original_padding = self.padding
260 | self.padding = obj.padding
261 |
262 | range_fn = range if not verbose else tqdm.trange
263 | codes = obj.codes
264 | original_device = codes.device
265 | chunk_length = obj.chunk_length
266 | recons = []
267 |
268 | for i in range_fn(0, codes.shape[-1], chunk_length):
269 | c = codes[..., i : i + chunk_length].to(self.device)
270 | z = self.quantizer.from_codes(c)[0]
271 | r = self.decode(z)
272 | recons.append(r.to(original_device))
273 |
274 | recons = torch.cat(recons, dim=-1)
275 | recons = AudioSignal(recons, self.sample_rate)
276 |
277 | resample_fn = recons.resample
278 | loudness_fn = recons.loudness
279 |
280 |         # If audio is 10 hours or longer, use the ffmpeg versions
281 | if recons.signal_duration >= 10 * 60 * 60:
282 | resample_fn = recons.ffmpeg_resample
283 | loudness_fn = recons.ffmpeg_loudness
284 |
285 | recons.normalize(obj.input_db)
286 | resample_fn(obj.sample_rate)
287 | recons = recons[..., : obj.original_length]
288 | loudness_fn()
289 | recons.audio_data = recons.audio_data.reshape(
290 | -1, obj.channels, obj.original_length
291 | )
292 |
293 | self.padding = original_padding
294 | return recons
295 |
--------------------------------------------------------------------------------
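Taken together, `compress` and `decompress` give a file-level round trip through the codec. A usage sketch, assuming the `DAC` model defined in `model/dac.py` can be constructed with its defaults and that pretrained weights are loaded separately (the audio paths are hypothetical):

```python
import torch
from seedvc.dac import DAC, DACFile

model = DAC()                                   # default constructor; load weights before real use
model.to("cuda" if torch.cuda.is_available() else "cpu")
model.eval()

# Encode a file into discrete codes plus the metadata needed to invert them.
dac_file = model.compress("input.wav", win_duration=1.0, verbose=True)
saved_path = dac_file.save("input")             # writes "input.dac" via np.save

# Later, or in another process: restore the codes and decode back to audio.
restored = DACFile.load(saved_path)
signal = model.decompress(restored, verbose=True)
signal.write("reconstructed.wav")               # AudioSignal.write from audiotools
```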
/seedvc/dac/model/discriminator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from audiotools import AudioSignal
5 | from audiotools import ml
6 | from audiotools import STFTParams
7 | from einops import rearrange
8 | from torch.nn.utils import weight_norm
9 |
10 |
11 | def WNConv1d(*args, **kwargs):
12 | act = kwargs.pop("act", True)
13 | conv = weight_norm(nn.Conv1d(*args, **kwargs))
14 | if not act:
15 | return conv
16 | return nn.Sequential(conv, nn.LeakyReLU(0.1))
17 |
18 |
19 | def WNConv2d(*args, **kwargs):
20 | act = kwargs.pop("act", True)
21 | conv = weight_norm(nn.Conv2d(*args, **kwargs))
22 | if not act:
23 | return conv
24 | return nn.Sequential(conv, nn.LeakyReLU(0.1))
25 |
26 |
27 | class MPD(nn.Module):
28 | def __init__(self, period):
29 | super().__init__()
30 | self.period = period
31 | self.convs = nn.ModuleList(
32 | [
33 | WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
34 | WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
35 | WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
36 | WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
37 | WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
38 | ]
39 | )
40 | self.conv_post = WNConv2d(
41 | 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
42 | )
43 |
44 | def pad_to_period(self, x):
45 | t = x.shape[-1]
46 | x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
47 | return x
48 |
49 | def forward(self, x):
50 | fmap = []
51 |
52 | x = self.pad_to_period(x)
53 | x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
54 |
55 | for layer in self.convs:
56 | x = layer(x)
57 | fmap.append(x)
58 |
59 | x = self.conv_post(x)
60 | fmap.append(x)
61 |
62 | return fmap
63 |
64 |
65 | class MSD(nn.Module):
66 | def __init__(self, rate: int = 1, sample_rate: int = 44100):
67 | super().__init__()
68 | self.convs = nn.ModuleList(
69 | [
70 | WNConv1d(1, 16, 15, 1, padding=7),
71 | WNConv1d(16, 64, 41, 4, groups=4, padding=20),
72 | WNConv1d(64, 256, 41, 4, groups=16, padding=20),
73 | WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
74 | WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
75 | WNConv1d(1024, 1024, 5, 1, padding=2),
76 | ]
77 | )
78 | self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
79 | self.sample_rate = sample_rate
80 | self.rate = rate
81 |
82 | def forward(self, x):
83 | x = AudioSignal(x, self.sample_rate)
84 | x.resample(self.sample_rate // self.rate)
85 | x = x.audio_data
86 |
87 | fmap = []
88 |
89 | for l in self.convs:
90 | x = l(x)
91 | fmap.append(x)
92 | x = self.conv_post(x)
93 | fmap.append(x)
94 |
95 | return fmap
96 |
97 |
98 | BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
99 |
100 |
101 | class MRD(nn.Module):
102 | def __init__(
103 | self,
104 | window_length: int,
105 | hop_factor: float = 0.25,
106 | sample_rate: int = 44100,
107 | bands: list = BANDS,
108 | ):
109 | """Complex multi-band spectrogram discriminator.
110 | Parameters
111 | ----------
112 | window_length : int
113 | Window length of STFT.
114 | hop_factor : float, optional
115 |             Hop factor of the STFT (hop length is ``hop_factor * window_length``), by default 0.25.
116 | sample_rate : int, optional
117 | Sampling rate of audio in Hz, by default 44100
118 | bands : list, optional
119 | Bands to run discriminator over.
120 | """
121 | super().__init__()
122 |
123 | self.window_length = window_length
124 | self.hop_factor = hop_factor
125 | self.sample_rate = sample_rate
126 | self.stft_params = STFTParams(
127 | window_length=window_length,
128 | hop_length=int(window_length * hop_factor),
129 | match_stride=True,
130 | )
131 |
132 | n_fft = window_length // 2 + 1
133 | bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
134 | self.bands = bands
135 |
136 | ch = 32
137 | convs = lambda: nn.ModuleList(
138 | [
139 | WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
140 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
141 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
142 | WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
143 | WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
144 | ]
145 | )
146 | self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
147 | self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
148 |
149 | def spectrogram(self, x):
150 | x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
151 | x = torch.view_as_real(x.stft())
152 | x = rearrange(x, "b 1 f t c -> (b 1) c t f")
153 | # Split into bands
154 | x_bands = [x[..., b[0] : b[1]] for b in self.bands]
155 | return x_bands
156 |
157 | def forward(self, x):
158 | x_bands = self.spectrogram(x)
159 | fmap = []
160 |
161 | x = []
162 | for band, stack in zip(x_bands, self.band_convs):
163 | for layer in stack:
164 | band = layer(band)
165 | fmap.append(band)
166 | x.append(band)
167 |
168 | x = torch.cat(x, dim=-1)
169 | x = self.conv_post(x)
170 | fmap.append(x)
171 |
172 | return fmap
173 |
174 |
175 | class Discriminator(nn.Module):
176 | def __init__(
177 | self,
178 | rates: list = [],
179 | periods: list = [2, 3, 5, 7, 11],
180 | fft_sizes: list = [2048, 1024, 512],
181 | sample_rate: int = 44100,
182 | bands: list = BANDS,
183 | ):
184 | """Discriminator that combines multiple discriminators.
185 |
186 | Parameters
187 | ----------
188 | rates : list, optional
189 | sampling rates (in Hz) to run MSD at, by default []
190 | If empty, MSD is not used.
191 | periods : list, optional
192 | periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
193 | fft_sizes : list, optional
194 | Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
195 | sample_rate : int, optional
196 | Sampling rate of audio in Hz, by default 44100
197 | bands : list, optional
198 | Bands to run MRD at, by default `BANDS`
199 | """
200 | super().__init__()
201 | discs = []
202 | discs += [MPD(p) for p in periods]
203 | discs += [MSD(r, sample_rate=sample_rate) for r in rates]
204 | discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
205 | self.discriminators = nn.ModuleList(discs)
206 |
207 | def preprocess(self, y):
208 | # Remove DC offset
209 | y = y - y.mean(dim=-1, keepdims=True)
210 | # Peak normalize the volume of input audio
211 | y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
212 | return y
213 |
214 | def forward(self, x):
215 | x = self.preprocess(x)
216 | fmaps = [d(x) for d in self.discriminators]
217 | return fmaps
218 |
219 |
220 | if __name__ == "__main__":
221 | disc = Discriminator()
222 | x = torch.zeros(1, 1, 44100)
223 | results = disc(x)
224 | for i, result in enumerate(results):
225 | print(f"disc{i}")
226 | for i, r in enumerate(result):
227 | print(r.shape, r.mean(), r.min(), r.max())
228 | print()
229 |
--------------------------------------------------------------------------------
/seedvc/dac/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from . import layers
2 | from . import loss
3 | from . import quantize
4 |
--------------------------------------------------------------------------------
/seedvc/dac/nn/layers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from einops import rearrange
6 | from torch.nn.utils import weight_norm
7 |
8 |
9 | def WNConv1d(*args, **kwargs):
10 | return weight_norm(nn.Conv1d(*args, **kwargs))
11 |
12 |
13 | def WNConvTranspose1d(*args, **kwargs):
14 | return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
15 |
16 |
17 | # Scripting this brings model speed up 1.4x
18 | @torch.jit.script
19 | def snake(x, alpha):
20 | shape = x.shape
21 | x = x.reshape(shape[0], shape[1], -1)
22 | x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
23 | x = x.reshape(shape)
24 | return x
25 |
26 |
27 | class Snake1d(nn.Module):
28 | def __init__(self, channels):
29 | super().__init__()
30 | self.alpha = nn.Parameter(torch.ones(1, channels, 1))
31 |
32 | def forward(self, x):
33 | return snake(x, self.alpha)
34 |
--------------------------------------------------------------------------------
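A quick illustration of the shape contract of the layers above (a sketch; the import path assumes the package layout shown in the tree and a package root on `sys.path`):

```python
import torch
from seedvc.dac.nn.layers import Snake1d, WNConv1d

act = Snake1d(channels=64)                 # one learnable alpha per channel
conv = WNConv1d(64, 64, kernel_size=3, padding=1)

x = torch.randn(2, 64, 16000)              # (batch, channels, time)
y = conv(act(x))                           # snake(x) = x + (1/alpha) * sin^2(alpha * x), shape preserved
assert y.shape == x.shape
```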
/seedvc/dac/nn/loss.py:
--------------------------------------------------------------------------------
1 | import typing
2 | from typing import List
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from audiotools import AudioSignal
7 | from audiotools import STFTParams
8 | from torch import nn
9 |
10 |
11 | class L1Loss(nn.L1Loss):
12 | """L1 Loss between AudioSignals. Defaults
13 | to comparing ``audio_data``, but any
14 | attribute of an AudioSignal can be used.
15 |
16 | Parameters
17 | ----------
18 | attribute : str, optional
19 | Attribute of signal to compare, defaults to ``audio_data``.
20 | weight : float, optional
21 | Weight of this loss, defaults to 1.0.
22 |
23 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
24 | """
25 |
26 | def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
27 | self.attribute = attribute
28 | self.weight = weight
29 | super().__init__(**kwargs)
30 |
31 | def forward(self, x: AudioSignal, y: AudioSignal):
32 | """
33 | Parameters
34 | ----------
35 | x : AudioSignal
36 | Estimate AudioSignal
37 | y : AudioSignal
38 | Reference AudioSignal
39 |
40 | Returns
41 | -------
42 | torch.Tensor
43 | L1 loss between AudioSignal attributes.
44 | """
45 | if isinstance(x, AudioSignal):
46 | x = getattr(x, self.attribute)
47 | y = getattr(y, self.attribute)
48 | return super().forward(x, y)
49 |
50 |
51 | class SISDRLoss(nn.Module):
52 | """
53 | Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
54 | of estimated and reference audio signals or aligned features.
55 |
56 | Parameters
57 | ----------
58 | scaling : int, optional
59 | Whether to use scale-invariant (True) or
60 | signal-to-noise ratio (False), by default True
61 | reduction : str, optional
62 | How to reduce across the batch (either 'mean',
63 |         'sum', or 'none'), by default 'mean'
64 | zero_mean : int, optional
65 | Zero mean the references and estimates before
66 | computing the loss, by default True
67 | clip_min : int, optional
68 | The minimum possible loss value. Helps network
69 | to not focus on making already good examples better, by default None
70 | weight : float, optional
71 | Weight of this loss, defaults to 1.0.
72 |
73 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
74 | """
75 |
76 | def __init__(
77 | self,
78 | scaling: int = True,
79 | reduction: str = "mean",
80 | zero_mean: int = True,
81 | clip_min: int = None,
82 | weight: float = 1.0,
83 | ):
84 | self.scaling = scaling
85 | self.reduction = reduction
86 | self.zero_mean = zero_mean
87 | self.clip_min = clip_min
88 | self.weight = weight
89 | super().__init__()
90 |
91 | def forward(self, x: AudioSignal, y: AudioSignal):
92 | eps = 1e-8
93 | # nb, nc, nt
94 | if isinstance(x, AudioSignal):
95 | references = x.audio_data
96 | estimates = y.audio_data
97 | else:
98 | references = x
99 | estimates = y
100 |
101 | nb = references.shape[0]
102 | references = references.reshape(nb, 1, -1).permute(0, 2, 1)
103 | estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
104 |
105 | # samples now on axis 1
106 | if self.zero_mean:
107 | mean_reference = references.mean(dim=1, keepdim=True)
108 | mean_estimate = estimates.mean(dim=1, keepdim=True)
109 | else:
110 | mean_reference = 0
111 | mean_estimate = 0
112 |
113 | _references = references - mean_reference
114 | _estimates = estimates - mean_estimate
115 |
116 | references_projection = (_references**2).sum(dim=-2) + eps
117 | references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
118 |
119 | scale = (
120 | (references_on_estimates / references_projection).unsqueeze(1)
121 | if self.scaling
122 | else 1
123 | )
124 |
125 | e_true = scale * _references
126 | e_res = _estimates - e_true
127 |
128 | signal = (e_true**2).sum(dim=1)
129 | noise = (e_res**2).sum(dim=1)
130 | sdr = -10 * torch.log10(signal / noise + eps)
131 |
132 | if self.clip_min is not None:
133 | sdr = torch.clamp(sdr, min=self.clip_min)
134 |
135 | if self.reduction == "mean":
136 | sdr = sdr.mean()
137 | elif self.reduction == "sum":
138 | sdr = sdr.sum()
139 | return sdr
140 |
141 |
142 | class MultiScaleSTFTLoss(nn.Module):
143 | """Computes the multi-scale STFT loss from [1].
144 |
145 | Parameters
146 | ----------
147 | window_lengths : List[int], optional
148 | Length of each window of each STFT, by default [2048, 512]
149 | loss_fn : typing.Callable, optional
150 | How to compare each loss, by default nn.L1Loss()
151 | clamp_eps : float, optional
152 | Clamp on the log magnitude, below, by default 1e-5
153 | mag_weight : float, optional
154 | Weight of raw magnitude portion of loss, by default 1.0
155 | log_weight : float, optional
156 | Weight of log magnitude portion of loss, by default 1.0
157 | pow : float, optional
158 | Power to raise magnitude to before taking log, by default 2.0
159 | weight : float, optional
160 | Weight of this loss, by default 1.0
161 | match_stride : bool, optional
162 | Whether to match the stride of convolutional layers, by default False
163 |
164 | References
165 | ----------
166 |
167 | 1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
168 | "DDSP: Differentiable Digital Signal Processing."
169 | International Conference on Learning Representations. 2019.
170 |
171 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
172 | """
173 |
174 | def __init__(
175 | self,
176 | window_lengths: List[int] = [2048, 512],
177 | loss_fn: typing.Callable = nn.L1Loss(),
178 | clamp_eps: float = 1e-5,
179 | mag_weight: float = 1.0,
180 | log_weight: float = 1.0,
181 | pow: float = 2.0,
182 | weight: float = 1.0,
183 | match_stride: bool = False,
184 | window_type: str = None,
185 | ):
186 | super().__init__()
187 | self.stft_params = [
188 | STFTParams(
189 | window_length=w,
190 | hop_length=w // 4,
191 | match_stride=match_stride,
192 | window_type=window_type,
193 | )
194 | for w in window_lengths
195 | ]
196 | self.loss_fn = loss_fn
197 | self.log_weight = log_weight
198 | self.mag_weight = mag_weight
199 | self.clamp_eps = clamp_eps
200 | self.weight = weight
201 | self.pow = pow
202 |
203 | def forward(self, x: AudioSignal, y: AudioSignal):
204 | """Computes multi-scale STFT between an estimate and a reference
205 | signal.
206 |
207 | Parameters
208 | ----------
209 | x : AudioSignal
210 | Estimate signal
211 | y : AudioSignal
212 | Reference signal
213 |
214 | Returns
215 | -------
216 | torch.Tensor
217 | Multi-scale STFT loss.
218 | """
219 | loss = 0.0
220 | for s in self.stft_params:
221 | x.stft(s.window_length, s.hop_length, s.window_type)
222 | y.stft(s.window_length, s.hop_length, s.window_type)
223 | loss += self.log_weight * self.loss_fn(
224 | x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
225 | y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
226 | )
227 | loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
228 | return loss
229 |
230 |
231 | class MelSpectrogramLoss(nn.Module):
232 | """Compute distance between mel spectrograms. Can be used
233 | in a multi-scale way.
234 |
235 | Parameters
236 | ----------
237 | n_mels : List[int]
238 |         Number of mels per STFT, by default [150, 80]
239 | window_lengths : List[int], optional
240 | Length of each window of each STFT, by default [2048, 512]
241 | loss_fn : typing.Callable, optional
242 | How to compare each loss, by default nn.L1Loss()
243 | clamp_eps : float, optional
244 | Clamp on the log magnitude, below, by default 1e-5
245 | mag_weight : float, optional
246 | Weight of raw magnitude portion of loss, by default 1.0
247 | log_weight : float, optional
248 | Weight of log magnitude portion of loss, by default 1.0
249 | pow : float, optional
250 | Power to raise magnitude to before taking log, by default 2.0
251 | weight : float, optional
252 | Weight of this loss, by default 1.0
253 | match_stride : bool, optional
254 | Whether to match the stride of convolutional layers, by default False
255 |
256 | Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
257 | """
258 |
259 | def __init__(
260 | self,
261 | n_mels: List[int] = [150, 80],
262 | window_lengths: List[int] = [2048, 512],
263 | loss_fn: typing.Callable = nn.L1Loss(),
264 | clamp_eps: float = 1e-5,
265 | mag_weight: float = 1.0,
266 | log_weight: float = 1.0,
267 | pow: float = 2.0,
268 | weight: float = 1.0,
269 | match_stride: bool = False,
270 | mel_fmin: List[float] = [0.0, 0.0],
271 | mel_fmax: List[float] = [None, None],
272 | window_type: str = None,
273 | ):
274 | super().__init__()
275 | self.stft_params = [
276 | STFTParams(
277 | window_length=w,
278 | hop_length=w // 4,
279 | match_stride=match_stride,
280 | window_type=window_type,
281 | )
282 | for w in window_lengths
283 | ]
284 | self.n_mels = n_mels
285 | self.loss_fn = loss_fn
286 | self.clamp_eps = clamp_eps
287 | self.log_weight = log_weight
288 | self.mag_weight = mag_weight
289 | self.weight = weight
290 | self.mel_fmin = mel_fmin
291 | self.mel_fmax = mel_fmax
292 | self.pow = pow
293 |
294 | def forward(self, x: AudioSignal, y: AudioSignal):
295 | """Computes mel loss between an estimate and a reference
296 | signal.
297 |
298 | Parameters
299 | ----------
300 | x : AudioSignal
301 | Estimate signal
302 | y : AudioSignal
303 | Reference signal
304 |
305 | Returns
306 | -------
307 | torch.Tensor
308 | Mel loss.
309 | """
310 | loss = 0.0
311 | for n_mels, fmin, fmax, s in zip(
312 | self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
313 | ):
314 | kwargs = {
315 | "window_length": s.window_length,
316 | "hop_length": s.hop_length,
317 | "window_type": s.window_type,
318 | }
319 | x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
320 | y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
321 |
322 | loss += self.log_weight * self.loss_fn(
323 | x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
324 | y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
325 | )
326 | loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
327 | return loss
328 |
329 |
330 | class GANLoss(nn.Module):
331 | """
332 | Computes a discriminator loss, given a discriminator on
333 | generated waveforms/spectrograms compared to ground truth
334 | waveforms/spectrograms. Computes the loss for both the
335 | discriminator and the generator in separate functions.
336 | """
337 |
338 | def __init__(self, discriminator):
339 | super().__init__()
340 | self.discriminator = discriminator
341 |
342 | def forward(self, fake, real):
343 | d_fake = self.discriminator(fake.audio_data)
344 | d_real = self.discriminator(real.audio_data)
345 | return d_fake, d_real
346 |
347 | def discriminator_loss(self, fake, real):
348 | d_fake, d_real = self.forward(fake.clone().detach(), real)
349 |
350 | loss_d = 0
351 | for x_fake, x_real in zip(d_fake, d_real):
352 | loss_d += torch.mean(x_fake[-1] ** 2)
353 | loss_d += torch.mean((1 - x_real[-1]) ** 2)
354 | return loss_d
355 |
356 | def generator_loss(self, fake, real):
357 | d_fake, d_real = self.forward(fake, real)
358 |
359 | loss_g = 0
360 | for x_fake in d_fake:
361 | loss_g += torch.mean((1 - x_fake[-1]) ** 2)
362 |
363 | loss_feature = 0
364 |
365 | for i in range(len(d_fake)):
366 | for j in range(len(d_fake[i]) - 1):
367 | loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
368 | return loss_g, loss_feature
369 |
--------------------------------------------------------------------------------
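A minimal sketch of combining the reconstruction losses defined above on a pair of `AudioSignal`s; the loss weights are hypothetical and are not the ones used to train seed-vc:

```python
import torch
from audiotools import AudioSignal
from seedvc.dac.nn.loss import MelSpectrogramLoss, MultiScaleSTFTLoss, SISDRLoss

stft_loss = MultiScaleSTFTLoss()
mel_loss = MelSpectrogramLoss()
sisdr_loss = SISDRLoss()

sr = 44100
reference = AudioSignal(torch.randn(1, 1, sr), sr)   # one second of ground-truth audio
estimate = AudioSignal(torch.randn(1, 1, sr), sr)    # stand-in for a model output

# Hypothetical weighting; every term is a scalar tensor.
# Note: SISDRLoss.forward takes the reference first (see its implementation above).
total = (
    1.0 * stft_loss(estimate, reference)
    + 15.0 * mel_loss(estimate, reference)
    + 0.1 * sisdr_loss(reference, estimate)
)
print(total.item())
```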
/seedvc/dac/nn/quantize.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from einops import rearrange
8 | from torch.nn.utils import weight_norm
9 |
10 | from .layers import WNConv1d
11 |
12 | class VectorQuantizeLegacy(nn.Module):
13 | """
14 | Implementation of VQ similar to Karpathy's repo:
15 | https://github.com/karpathy/deep-vector-quantization
16 | removed in-out projection
17 | """
18 |
19 | def __init__(self, input_dim: int, codebook_size: int):
20 | super().__init__()
21 | self.codebook_size = codebook_size
22 | self.codebook = nn.Embedding(codebook_size, input_dim)
23 |
24 | def forward(self, z, z_mask=None):
25 |         """Quantizes the input tensor using a fixed codebook and returns
26 | the corresponding codebook vectors
27 |
28 | Parameters
29 | ----------
30 | z : Tensor[B x D x T]
31 |
32 |         Returns
33 |         -------
34 |         Tensor[B x D x T]
35 |             Quantized continuous representation of input
36 |         Tensor[B x T]
37 |             Codebook indices (quantized discrete representation of input)
38 |         Tensor[B x D x T]
39 |             Projected latents (continuous representation of input before quantization)
40 |         Tensor[1]
41 |             Commitment loss to train encoder to predict vectors closer to codebook
42 |             entries
43 |         Tensor[1]
44 |             Codebook loss to update the codebook
45 |         """
46 |
47 | z_e = z
48 | z_q, indices = self.decode_latents(z)
49 |
50 | if z_mask is not None:
51 | commitment_loss = (F.mse_loss(z_e, z_q.detach(), reduction="none").mean(1) * z_mask).sum() / z_mask.sum()
52 | codebook_loss = (F.mse_loss(z_q, z_e.detach(), reduction="none").mean(1) * z_mask).sum() / z_mask.sum()
53 | else:
54 | commitment_loss = F.mse_loss(z_e, z_q.detach())
55 | codebook_loss = F.mse_loss(z_q, z_e.detach())
56 | z_q = (
57 | z_e + (z_q - z_e).detach()
58 | ) # noop in forward pass, straight-through gradient estimator in backward pass
59 |
60 | return z_q, indices, z_e, commitment_loss, codebook_loss
61 |
62 | def embed_code(self, embed_id):
63 | return F.embedding(embed_id, self.codebook.weight)
64 |
65 | def decode_code(self, embed_id):
66 | return self.embed_code(embed_id).transpose(1, 2)
67 |
68 | def decode_latents(self, latents):
69 | encodings = rearrange(latents, "b d t -> (b t) d")
70 | codebook = self.codebook.weight # codebook: (N x D)
71 |
72 | # L2 normalize encodings and codebook (ViT-VQGAN)
73 | encodings = F.normalize(encodings)
74 | codebook = F.normalize(codebook)
75 |
76 | # Compute euclidean distance with codebook
77 | dist = (
78 | encodings.pow(2).sum(1, keepdim=True)
79 | - 2 * encodings @ codebook.t()
80 | + codebook.pow(2).sum(1, keepdim=True).t()
81 | )
82 | indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
83 | z_q = self.decode_code(indices)
84 | return z_q, indices
85 |
86 | class VectorQuantize(nn.Module):
87 | """
88 | Implementation of VQ similar to Karpathy's repo:
89 | https://github.com/karpathy/deep-vector-quantization
90 | Additionally uses following tricks from Improved VQGAN
91 | (https://arxiv.org/pdf/2110.04627.pdf):
92 | 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
93 | for improved codebook usage
94 | 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
95 | improves training stability
96 | """
97 |
98 | def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
99 | super().__init__()
100 | self.codebook_size = codebook_size
101 | self.codebook_dim = codebook_dim
102 |
103 | self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
104 | self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
105 | self.codebook = nn.Embedding(codebook_size, codebook_dim)
106 |
107 | def forward(self, z, z_mask=None):
108 |         """Quantizes the input tensor using a fixed codebook and returns
109 | the corresponding codebook vectors
110 |
111 | Parameters
112 | ----------
113 | z : Tensor[B x D x T]
114 |
115 | Returns
116 | -------
117 | Tensor[B x D x T]
118 | Quantized continuous representation of input
119 | Tensor[1]
120 | Commitment loss to train encoder to predict vectors closer to codebook
121 | entries
122 | Tensor[1]
123 | Codebook loss to update the codebook
124 | Tensor[B x T]
125 | Codebook indices (quantized discrete representation of input)
126 | Tensor[B x D x T]
127 | Projected latents (continuous representation of input before quantization)
128 | """
129 |
130 | # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
131 | z_e = self.in_proj(z) # z_e : (B x D x T)
132 | z_q, indices = self.decode_latents(z_e)
133 |
134 | if z_mask is not None:
135 | commitment_loss = (F.mse_loss(z_e, z_q.detach(), reduction="none").mean(1) * z_mask).sum() / z_mask.sum()
136 | codebook_loss = (F.mse_loss(z_q, z_e.detach(), reduction="none").mean(1) * z_mask).sum() / z_mask.sum()
137 | else:
138 | commitment_loss = F.mse_loss(z_e, z_q.detach())
139 | codebook_loss = F.mse_loss(z_q, z_e.detach())
140 |
141 | z_q = (
142 | z_e + (z_q - z_e).detach()
143 | ) # noop in forward pass, straight-through gradient estimator in backward pass
144 |
145 | z_q = self.out_proj(z_q)
146 |
147 | return z_q, commitment_loss, codebook_loss, indices, z_e
148 |
149 | def embed_code(self, embed_id):
150 | return F.embedding(embed_id, self.codebook.weight)
151 |
152 | def decode_code(self, embed_id):
153 | return self.embed_code(embed_id).transpose(1, 2)
154 |
155 | def decode_latents(self, latents):
156 | encodings = rearrange(latents, "b d t -> (b t) d")
157 | codebook = self.codebook.weight # codebook: (N x D)
158 |
159 | # L2 normalize encodings and codebook (ViT-VQGAN)
160 | encodings = F.normalize(encodings)
161 | codebook = F.normalize(codebook)
162 |
163 | # Compute euclidean distance with codebook
164 | dist = (
165 | encodings.pow(2).sum(1, keepdim=True)
166 | - 2 * encodings @ codebook.t()
167 | + codebook.pow(2).sum(1, keepdim=True).t()
168 | )
169 | indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
170 | z_q = self.decode_code(indices)
171 | return z_q, indices
172 |
173 |
174 | class ResidualVectorQuantize(nn.Module):
175 | """
176 | Introduced in SoundStream: An end2end neural audio codec
177 | https://arxiv.org/abs/2107.03312
178 | """
179 |
180 | def __init__(
181 | self,
182 | input_dim: int = 512,
183 | n_codebooks: int = 9,
184 | codebook_size: int = 1024,
185 | codebook_dim: Union[int, list] = 8,
186 | quantizer_dropout: float = 0.0,
187 | ):
188 | super().__init__()
189 | if isinstance(codebook_dim, int):
190 | codebook_dim = [codebook_dim for _ in range(n_codebooks)]
191 |
192 | self.n_codebooks = n_codebooks
193 | self.codebook_dim = codebook_dim
194 | self.codebook_size = codebook_size
195 |
196 | self.quantizers = nn.ModuleList(
197 | [
198 | VectorQuantize(input_dim, codebook_size, codebook_dim[i])
199 | for i in range(n_codebooks)
200 | ]
201 | )
202 | self.quantizer_dropout = quantizer_dropout
203 |
204 | def forward(self, z, n_quantizers: int = None):
205 |         """Quantizes the input tensor using a fixed set of `n` codebooks and returns
206 | the corresponding codebook vectors
207 | Parameters
208 | ----------
209 | z : Tensor[B x D x T]
210 | n_quantizers : int, optional
211 | No. of quantizers to use
212 | (n_quantizers < self.n_codebooks ex: for quantizer dropout)
213 | Note: if `self.quantizer_dropout` is True, this argument is ignored
214 | when in training mode, and a random number of quantizers is used.
215 | Returns
216 | -------
217 |         tuple
218 |             A tuple of the following tensors, in order:
219 |
220 | "z" : Tensor[B x D x T]
221 | Quantized continuous representation of input
222 | "codes" : Tensor[B x N x T]
223 | Codebook indices for each codebook
224 | (quantized discrete representation of input)
225 | "latents" : Tensor[B x N*D x T]
226 | Projected latents (continuous representation of input before quantization)
227 | "vq/commitment_loss" : Tensor[1]
228 | Commitment loss to train encoder to predict vectors closer to codebook
229 | entries
230 | "vq/codebook_loss" : Tensor[1]
231 | Codebook loss to update the codebook
232 | """
233 | z_q = 0
234 | residual = z
235 | commitment_loss = 0
236 | codebook_loss = 0
237 |
238 | codebook_indices = []
239 | latents = []
240 |
241 | if n_quantizers is None:
242 | n_quantizers = self.n_codebooks
243 | if self.training:
244 | n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
245 | dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
246 | n_dropout = int(z.shape[0] * self.quantizer_dropout)
247 | n_quantizers[:n_dropout] = dropout[:n_dropout]
248 | n_quantizers = n_quantizers.to(z.device)
249 |
250 | for i, quantizer in enumerate(self.quantizers):
251 | if self.training is False and i >= n_quantizers:
252 | break
253 |
254 | z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
255 | residual
256 | )
257 |
258 | # Create mask to apply quantizer dropout
259 | mask = (
260 | torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
261 | )
262 | z_q = z_q + z_q_i * mask[:, None, None]
263 | residual = residual - z_q_i
264 |
265 | # Sum losses
266 | commitment_loss += (commitment_loss_i * mask).mean()
267 | codebook_loss += (codebook_loss_i * mask).mean()
268 |
269 | codebook_indices.append(indices_i)
270 | latents.append(z_e_i)
271 |
272 | codes = torch.stack(codebook_indices, dim=1)
273 | latents = torch.cat(latents, dim=1)
274 |
275 | return z_q, codes, latents, commitment_loss, codebook_loss
276 |
277 | def from_codes(self, codes: torch.Tensor):
278 | """Given the quantized codes, reconstruct the continuous representation
279 | Parameters
280 | ----------
281 | codes : Tensor[B x N x T]
282 | Quantized discrete representation of input
283 | Returns
284 | -------
285 | Tensor[B x D x T]
286 | Quantized continuous representation of input
287 | """
288 | z_q = 0.0
289 | z_p = []
290 | n_codebooks = codes.shape[1]
291 | for i in range(n_codebooks):
292 | z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
293 | z_p.append(z_p_i)
294 |
295 | z_q_i = self.quantizers[i].out_proj(z_p_i)
296 | z_q = z_q + z_q_i
297 | return z_q, torch.cat(z_p, dim=1), codes
298 |
299 | def from_latents(self, latents: torch.Tensor):
300 | """Given the unquantized latents, reconstruct the
301 | continuous representation after quantization.
302 |
303 | Parameters
304 | ----------
305 | latents : Tensor[B x N x T]
306 | Continuous representation of input after projection
307 |
308 | Returns
309 | -------
310 | Tensor[B x D x T]
311 | Quantized representation of full-projected space
312 | Tensor[B x D x T]
313 | Quantized representation of latent space
314 | """
315 | z_q = 0
316 | z_p = []
317 | codes = []
318 | dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
319 |
320 | n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
321 | 0
322 | ]
323 | for i in range(n_codebooks):
324 | j, k = dims[i], dims[i + 1]
325 | z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
326 | z_p.append(z_p_i)
327 | codes.append(codes_i)
328 |
329 | z_q_i = self.quantizers[i].out_proj(z_p_i)
330 | z_q = z_q + z_q_i
331 |
332 | return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
333 |
334 |
335 | if __name__ == "__main__":
336 | rvq = ResidualVectorQuantize(quantizer_dropout=True)
337 | x = torch.randn(16, 512, 80)
338 |     z_q, codes, latents, commitment_loss, codebook_loss = rvq(x)
339 |     print(latents.shape)
340 |
--------------------------------------------------------------------------------
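A minimal round-trip sketch for the residual quantizer above (not part of the repository): the import path is illustrative and assumes this `quantize.py` is importable as a package module; the tuple unpacking follows the `forward` return value shown above.

import torch
from seedvc.dac.nn.quantize import ResidualVectorQuantize  # illustrative path

rvq = ResidualVectorQuantize(input_dim=512, n_codebooks=9, codebook_size=1024)
rvq.eval()
x = torch.randn(2, 512, 80)                                   # (B, D, T) continuous features
z_q, codes, latents, commitment_loss, codebook_loss = rvq(x)
print(codes.shape)                                            # (B, n_codebooks, T) integer indices
z_q_rec, z_p, _ = rvq.from_codes(codes)                       # rebuild the continuous z_q from codes
assert z_q_rec.shape == x.shape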
/seedvc/dac/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import argbind
4 | from audiotools import ml
5 |
6 | import dac
7 |
8 | DAC = dac.model.DAC
9 | Accelerator = ml.Accelerator
10 |
11 | __MODEL_LATEST_TAGS__ = {
12 | ("44khz", "8kbps"): "0.0.1",
13 | ("24khz", "8kbps"): "0.0.4",
14 | ("16khz", "8kbps"): "0.0.5",
15 | ("44khz", "16kbps"): "1.0.0",
16 | }
17 |
18 | __MODEL_URLS__ = {
19 | (
20 | "44khz",
21 | "0.0.1",
22 | "8kbps",
23 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth",
24 | (
25 | "24khz",
26 | "0.0.4",
27 | "8kbps",
28 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth",
29 | (
30 | "16khz",
31 | "0.0.5",
32 | "8kbps",
33 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth",
34 | (
35 | "44khz",
36 | "1.0.0",
37 | "16kbps",
38 | ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth",
39 | }
40 |
41 |
42 | @argbind.bind(group="download", positional=True, without_prefix=True)
43 | def download(
44 | model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest"
45 | ):
46 | """
47 | Function that downloads the weights file from URL if a local cache is not found.
48 |
49 | Parameters
50 | ----------
51 | model_type : str
52 | The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz".
53 | model_bitrate: str
54 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
55 | Only 44khz model supports 16kbps.
56 | tag : str
57 | The tag of the model to download. Defaults to "latest".
58 |
59 | Returns
60 | -------
61 | Path
62 | Directory path required to load model via audiotools.
63 | """
64 | model_type = model_type.lower()
65 | tag = tag.lower()
66 |
67 | assert model_type in [
68 | "44khz",
69 | "24khz",
70 | "16khz",
71 | ], "model_type must be one of '44khz', '24khz', or '16khz'"
72 |
73 | assert model_bitrate in [
74 | "8kbps",
75 | "16kbps",
76 | ], "model_bitrate must be one of '8kbps', or '16kbps'"
77 |
78 | if tag == "latest":
79 | tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)]
80 |
81 | download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None)
82 |
83 | if download_link is None:
84 | raise ValueError(
85 | f"Could not find model with tag {tag} and model type {model_type}"
86 | )
87 |
88 | local_path = (
89 | Path.home()
90 | / ".cache"
91 | / "descript"
92 | / "dac"
93 | / f"weights_{model_type}_{model_bitrate}_{tag}.pth"
94 | )
95 | if not local_path.exists():
96 | local_path.parent.mkdir(parents=True, exist_ok=True)
97 |
98 | # Download the model
99 | import requests
100 |
101 | response = requests.get(download_link)
102 |
103 | if response.status_code != 200:
104 | raise ValueError(
105 | f"Could not download model. Received response code {response.status_code}"
106 | )
107 | local_path.write_bytes(response.content)
108 |
109 | return local_path
110 |
111 |
112 | def load_model(
113 | model_type: str = "44khz",
114 | model_bitrate: str = "8kbps",
115 | tag: str = "latest",
116 | load_path: str = None,
117 | ):
118 | if not load_path:
119 | load_path = download(
120 | model_type=model_type, model_bitrate=model_bitrate, tag=tag
121 | )
122 | generator = DAC.load(load_path)
123 | return generator
124 |
--------------------------------------------------------------------------------
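A usage sketch for the download helper above (not part of the repository): it assumes the upstream `descript-audio-codec` package layout (the module's own `import dac`), that `argbind` and `audiotools` are installed, and that network access is available on first use.

from dac.utils import download, load_model

weights_path = download(model_type="44khz", model_bitrate="8kbps", tag="latest")
print(weights_path)   # cached under ~/.cache/descript/dac/, e.g. weights_44khz_8kbps_0.0.1.pth
model = load_model(model_type="44khz", model_bitrate="8kbps")  # downloads the weights if not cached
model.eval()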
/seedvc/dac/utils/decode.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from pathlib import Path
3 |
4 | import argbind
5 | import numpy as np
6 | import torch
7 | from audiotools import AudioSignal
8 | from tqdm import tqdm
9 |
10 | from dac import DACFile
11 | from dac.utils import load_model
12 |
13 | warnings.filterwarnings("ignore", category=UserWarning)
14 |
15 |
16 | @argbind.bind(group="decode", positional=True, without_prefix=True)
17 | @torch.inference_mode()
18 | @torch.no_grad()
19 | def decode(
20 | input: str,
21 | output: str = "",
22 | weights_path: str = "",
23 | model_tag: str = "latest",
24 | model_bitrate: str = "8kbps",
25 | device: str = "cuda",
26 | model_type: str = "44khz",
27 | verbose: bool = False,
28 | ):
29 | """Decode audio from codes.
30 |
31 | Parameters
32 | ----------
33 | input : str
34 | Path to input directory or file
35 | output : str, optional
36 | Path to output directory, by default "".
37 | If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
38 | weights_path : str, optional
39 | Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
40 | model_tag and model_type.
41 | model_tag : str, optional
42 | Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
43 | model_bitrate: str
44 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
45 | device : str, optional
46 | Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU.
47 | model_type : str, optional
48 | The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
49 | """
50 | generator = load_model(
51 | model_type=model_type,
52 | model_bitrate=model_bitrate,
53 | tag=model_tag,
54 | load_path=weights_path,
55 | )
56 | generator.to(device)
57 | generator.eval()
58 |
59 | # Find all .dac files in input directory
60 | _input = Path(input)
61 | input_files = list(_input.glob("**/*.dac"))
62 |
63 | # If input is a .dac file, add it to the list
64 | if _input.suffix == ".dac":
65 | input_files.append(_input)
66 |
67 | # Create output directory
68 | output = Path(output)
69 | output.mkdir(parents=True, exist_ok=True)
70 |
71 | for i in tqdm(range(len(input_files)), desc=f"Decoding files"):
72 | # Load file
73 | artifact = DACFile.load(input_files[i])
74 |
75 | # Reconstruct audio from codes
76 | recons = generator.decompress(artifact, verbose=verbose)
77 |
78 | # Compute output path
79 | relative_path = input_files[i].relative_to(input)
80 | output_dir = output / relative_path.parent
81 | if not relative_path.name:
82 | output_dir = output
83 | relative_path = input_files[i]
84 | output_name = relative_path.with_suffix(".wav").name
85 | output_path = output_dir / output_name
86 | output_path.parent.mkdir(parents=True, exist_ok=True)
87 |
88 | # Write to file
89 | recons.write(output_path)
90 |
91 |
92 | if __name__ == "__main__":
93 | args = argbind.parse_args()
94 | with argbind.scope(args):
95 | decode()
96 |
--------------------------------------------------------------------------------
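A hedged programmatic call of the decode entry point above; it assumes previously encoded `.dac` artifacts exist under `codes/` and that a CUDA device is available (pass `device="cpu"` otherwise). Directory names are placeholders.

from dac.utils.decode import decode  # assumes the upstream `dac` package layout

decode(input="codes", output="reconstructed", model_type="44khz", model_bitrate="8kbps", device="cuda")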
/seedvc/dac/utils/encode.py:
--------------------------------------------------------------------------------
1 | import math
2 | import warnings
3 | from pathlib import Path
4 |
5 | import argbind
6 | import numpy as np
7 | import torch
8 | from audiotools import AudioSignal
9 | from audiotools.core import util
10 | from tqdm import tqdm
11 |
12 | from dac.utils import load_model
13 |
14 | warnings.filterwarnings("ignore", category=UserWarning)
15 |
16 |
17 | @argbind.bind(group="encode", positional=True, without_prefix=True)
18 | @torch.inference_mode()
19 | @torch.no_grad()
20 | def encode(
21 | input: str,
22 | output: str = "",
23 | weights_path: str = "",
24 | model_tag: str = "latest",
25 | model_bitrate: str = "8kbps",
26 | n_quantizers: int = None,
27 | device: str = "cuda",
28 | model_type: str = "44khz",
29 | win_duration: float = 5.0,
30 | verbose: bool = False,
31 | ):
32 | """Encode audio files in input path to .dac format.
33 |
34 | Parameters
35 | ----------
36 | input : str
37 | Path to input audio file or directory
38 | output : str, optional
39 | Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
40 | weights_path : str, optional
41 | Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
42 | model_tag and model_type.
43 | model_tag : str, optional
44 | Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
45 | model_bitrate: str
46 | Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
47 | n_quantizers : int, optional
48 | Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate.
49 | device : str, optional
50 | Device to use, by default "cuda"
51 | model_type : str, optional
52 | The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
53 | """
54 | generator = load_model(
55 | model_type=model_type,
56 | model_bitrate=model_bitrate,
57 | tag=model_tag,
58 | load_path=weights_path,
59 | )
60 | generator.to(device)
61 | generator.eval()
62 | kwargs = {"n_quantizers": n_quantizers}
63 |
64 | # Find all audio files in input path
65 | input = Path(input)
66 | audio_files = util.find_audio(input)
67 |
68 | output = Path(output)
69 | output.mkdir(parents=True, exist_ok=True)
70 |
71 | for i in tqdm(range(len(audio_files)), desc="Encoding files"):
72 | # Load file
73 | signal = AudioSignal(audio_files[i])
74 |
75 | # Encode audio to .dac format
76 | artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs)
77 |
78 | # Compute output path
79 | relative_path = audio_files[i].relative_to(input)
80 | output_dir = output / relative_path.parent
81 | if not relative_path.name:
82 | output_dir = output
83 | relative_path = audio_files[i]
84 | output_name = relative_path.with_suffix(".dac").name
85 | output_path = output_dir / output_name
86 | output_path.parent.mkdir(parents=True, exist_ok=True)
87 |
88 | artifact.save(output_path)
89 |
90 |
91 | if __name__ == "__main__":
92 | args = argbind.parse_args()
93 | with argbind.scope(args):
94 | encode()
95 |
--------------------------------------------------------------------------------
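The encoder counterpart, sketched under the same assumptions: `audio/` is a placeholder directory of input files, and leaving `n_quantizers=None` uses all codebooks (maximum bitrate), as documented above.

from dac.utils.encode import encode  # assumes the upstream `dac` package layout

encode(input="audio", output="codes", model_type="44khz", model_bitrate="8kbps", n_quantizers=None, device="cuda")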
/seedvc/modules/alias_free_torch/__init__.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2 |
3 | from .filter import *
4 | from .resample import *
5 | from .act import *
6 |
--------------------------------------------------------------------------------
/seedvc/modules/alias_free_torch/act.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2 |
3 | import torch.nn as nn
4 | from .resample import UpSample1d, DownSample1d
5 |
6 |
7 | class Activation1d(nn.Module):
8 | def __init__(
9 | self,
10 | activation,
11 | up_ratio: int = 2,
12 | down_ratio: int = 2,
13 | up_kernel_size: int = 12,
14 | down_kernel_size: int = 12,
15 | ):
16 | super().__init__()
17 | self.up_ratio = up_ratio
18 | self.down_ratio = down_ratio
19 | self.act = activation
20 | self.upsample = UpSample1d(up_ratio, up_kernel_size)
21 | self.downsample = DownSample1d(down_ratio, down_kernel_size)
22 |
23 | # x: [B,C,T]
24 | def forward(self, x):
25 | x = self.upsample(x)
26 | x = self.act(x)
27 | x = self.downsample(x)
28 |
29 | return x
30 |
--------------------------------------------------------------------------------
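A small sketch of the anti-aliased activation wrapper above: upsample by 2, apply the wrapped activation, then low-pass and downsample by 2, so the (B, C, T) shape is preserved. The import path is illustrative, and `nn.SiLU` stands in for any pointwise activation.

import torch
import torch.nn as nn
from seedvc.modules.alias_free_torch.act import Activation1d  # illustrative path

act = Activation1d(activation=nn.SiLU(), up_ratio=2, down_ratio=2)
x = torch.randn(1, 64, 1000)   # (B, C, T)
y = act(x)
print(y.shape)                 # torch.Size([1, 64, 1000]): same shape, aliasing suppressed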
/seedvc/modules/alias_free_torch/filter.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import math
7 |
8 | if "sinc" in dir(torch):
9 | sinc = torch.sinc
10 | else:
11 | # This code is adopted from adefossez's julius.core.sinc under the MIT License
12 | # https://adefossez.github.io/julius/julius/core.html
13 | def sinc(x: torch.Tensor):
14 | """
15 | Implementation of sinc, i.e. sin(pi * x) / (pi * x)
16 | __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
17 | """
18 | return torch.where(
19 | x == 0,
20 | torch.tensor(1.0, device=x.device, dtype=x.dtype),
21 | torch.sin(math.pi * x) / math.pi / x,
22 | )
23 |
24 |
25 | # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
26 | # https://adefossez.github.io/julius/julius/lowpass.html
27 | def kaiser_sinc_filter1d(
28 | cutoff, half_width, kernel_size
29 | ): # return filter [1,1,kernel_size]
30 | even = kernel_size % 2 == 0
31 | half_size = kernel_size // 2
32 |
33 | # For kaiser window
34 | delta_f = 4 * half_width
35 | A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
36 | if A > 50.0:
37 | beta = 0.1102 * (A - 8.7)
38 | elif A >= 21.0:
39 | beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
40 | else:
41 | beta = 0.0
42 | window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
43 |
44 | # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
45 | if even:
46 | time = torch.arange(-half_size, half_size) + 0.5
47 | else:
48 | time = torch.arange(kernel_size) - half_size
49 | if cutoff == 0:
50 | filter_ = torch.zeros_like(time)
51 | else:
52 | filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
53 | # Normalize filter to have sum = 1, otherwise we will have a small leakage
54 | # of the constant component in the input signal.
55 | filter_ /= filter_.sum()
56 | filter = filter_.view(1, 1, kernel_size)
57 |
58 | return filter
59 |
60 |
61 | class LowPassFilter1d(nn.Module):
62 | def __init__(
63 | self,
64 | cutoff=0.5,
65 | half_width=0.6,
66 | stride: int = 1,
67 | padding: bool = True,
68 | padding_mode: str = "replicate",
69 | kernel_size: int = 12,
70 | ):
71 | # kernel_size should be even number for stylegan3 setup,
72 | # in this implementation, odd number is also possible.
73 | super().__init__()
74 | if cutoff < -0.0:
75 | raise ValueError("Minimum cutoff must be larger than zero.")
76 | if cutoff > 0.5:
77 | raise ValueError("A cutoff above 0.5 does not make sense.")
78 | self.kernel_size = kernel_size
79 | self.even = kernel_size % 2 == 0
80 | self.pad_left = kernel_size // 2 - int(self.even)
81 | self.pad_right = kernel_size // 2
82 | self.stride = stride
83 | self.padding = padding
84 | self.padding_mode = padding_mode
85 | filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
86 | self.register_buffer("filter", filter)
87 |
88 | # input [B, C, T]
89 | def forward(self, x):
90 | _, C, _ = x.shape
91 |
92 | if self.padding:
93 | x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
94 | out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
95 |
96 | return out
97 |
--------------------------------------------------------------------------------
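A quick sketch of the Kaiser-windowed sinc filter and the low-pass module above; with `stride=2` the filter also decimates, halving the time axis. The import path is illustrative.

import torch
from seedvc.modules.alias_free_torch.filter import kaiser_sinc_filter1d, LowPassFilter1d  # illustrative path

filt = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
print(filt.shape, float(filt.sum()))   # (1, 1, 12); sums to ~1.0 after normalization

lpf = LowPassFilter1d(cutoff=0.25, half_width=0.3, stride=2, kernel_size=12)
x = torch.randn(1, 8, 1000)
print(lpf(x).shape)                    # torch.Size([1, 8, 500]): low-passed and decimated by 2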
/seedvc/modules/alias_free_torch/resample.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2 |
3 | import torch.nn as nn
4 | from torch.nn import functional as F
5 | from .filter import LowPassFilter1d
6 | from .filter import kaiser_sinc_filter1d
7 |
8 |
9 | class UpSample1d(nn.Module):
10 | def __init__(self, ratio=2, kernel_size=None):
11 | super().__init__()
12 | self.ratio = ratio
13 | self.kernel_size = (
14 | int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
15 | )
16 | self.stride = ratio
17 | self.pad = self.kernel_size // ratio - 1
18 | self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
19 | self.pad_right = (
20 | self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
21 | )
22 | filter = kaiser_sinc_filter1d(
23 | cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
24 | )
25 | self.register_buffer("filter", filter)
26 |
27 | # x: [B, C, T]
28 | def forward(self, x):
29 | _, C, _ = x.shape
30 |
31 | x = F.pad(x, (self.pad, self.pad), mode="replicate")
32 | x = self.ratio * F.conv_transpose1d(
33 | x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
34 | )
35 | x = x[..., self.pad_left : -self.pad_right]
36 |
37 | return x
38 |
39 |
40 | class DownSample1d(nn.Module):
41 | def __init__(self, ratio=2, kernel_size=None):
42 | super().__init__()
43 | self.ratio = ratio
44 | self.kernel_size = (
45 | int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
46 | )
47 | self.lowpass = LowPassFilter1d(
48 | cutoff=0.5 / ratio,
49 | half_width=0.6 / ratio,
50 | stride=ratio,
51 | kernel_size=self.kernel_size,
52 | )
53 |
54 | def forward(self, x):
55 | xx = self.lowpass(x)
56 |
57 | return xx
58 |
--------------------------------------------------------------------------------
/seedvc/modules/audio.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.utils.data
4 | from librosa.filters import mel as librosa_mel_fn
5 | from scipy.io.wavfile import read
6 |
7 | MAX_WAV_VALUE = 32768.0
8 |
9 |
10 | def load_wav(full_path):
11 | sampling_rate, data = read(full_path)
12 | return data, sampling_rate
13 |
14 |
15 | def dynamic_range_compression(x, C=1, clip_val=1e-5):
16 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
17 |
18 |
19 | def dynamic_range_decompression(x, C=1):
20 | return np.exp(x) / C
21 |
22 |
23 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
24 | return torch.log(torch.clamp(x, min=clip_val) * C)
25 |
26 |
27 | def dynamic_range_decompression_torch(x, C=1):
28 | return torch.exp(x) / C
29 |
30 |
31 | def spectral_normalize_torch(magnitudes):
32 | output = dynamic_range_compression_torch(magnitudes)
33 | return output
34 |
35 |
36 | def spectral_de_normalize_torch(magnitudes):
37 | output = dynamic_range_decompression_torch(magnitudes)
38 | return output
39 |
40 |
41 | mel_basis = {}
42 | hann_window = {}
43 |
44 |
45 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
46 | if torch.min(y) < -1.0:
47 | print("min value is ", torch.min(y))
48 | if torch.max(y) > 1.0:
49 | print("max value is ", torch.max(y))
50 |
51 | global mel_basis, hann_window # pylint: disable=global-statement
52 | if f"{str(sampling_rate)}_{str(fmax)}_{str(y.device)}" not in mel_basis:
53 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
54 | mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
55 | hann_window[str(sampling_rate) + "_" + str(y.device)] = torch.hann_window(win_size).to(y.device)
56 |
57 | y = torch.nn.functional.pad(
58 | y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
59 | )
60 | y = y.squeeze(1)
61 |
62 | spec = torch.view_as_real(
63 | torch.stft(
64 | y,
65 | n_fft,
66 | hop_length=hop_size,
67 | win_length=win_size,
68 | window=hann_window[str(sampling_rate) + "_" + str(y.device)],
69 | center=center,
70 | pad_mode="reflect",
71 | normalized=False,
72 | onesided=True,
73 | return_complex=True,
74 | )
75 | )
76 |
77 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
78 |
79 | spec = torch.matmul(mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)], spec)
80 | spec = spectral_normalize_torch(spec)
81 |
82 | return spec
83 |
--------------------------------------------------------------------------------
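A sketch of the mel front end above, using the 22.05 kHz settings that appear in the BigVGAN config later in this dump; the input is random audio purely to show shapes, and the import path is illustrative.

import torch
from seedvc.modules.audio import mel_spectrogram  # illustrative path

y = torch.rand(1, 22050) * 2 - 1   # 1 s of audio in [-1, 1], shape (B, T)
mel = mel_spectrogram(y, n_fft=1024, num_mels=80, sampling_rate=22050,
                      hop_size=256, win_size=1024, fmin=0, fmax=None)
print(mel.shape)                   # (1, 80, n_frames): log-compressed mel magnitudes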
/seedvc/modules/bigvgan/activations.py:
--------------------------------------------------------------------------------
1 | # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
2 | # LICENSE is in incl_licenses directory.
3 |
4 | import torch
5 | from torch import nn, sin, pow
6 | from torch.nn import Parameter
7 |
8 |
9 | class Snake(nn.Module):
10 | '''
11 | Implementation of a sine-based periodic activation function
12 | Shape:
13 | - Input: (B, C, T)
14 | - Output: (B, C, T), same shape as the input
15 | Parameters:
16 | - alpha - trainable parameter
17 | References:
18 | - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
19 | https://arxiv.org/abs/2006.08195
20 | Examples:
21 | >>> a1 = snake(256)
22 | >>> x = torch.randn(256)
23 | >>> x = a1(x)
24 | '''
25 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
26 | '''
27 | Initialization.
28 | INPUT:
29 | - in_features: shape of the input
30 | - alpha: trainable parameter
31 | alpha is initialized to 1 by default, higher values = higher-frequency.
32 | alpha will be trained along with the rest of your model.
33 | '''
34 | super(Snake, self).__init__()
35 | self.in_features = in_features
36 |
37 | # initialize alpha
38 | self.alpha_logscale = alpha_logscale
39 | if self.alpha_logscale: # log scale alphas initialized to zeros
40 | self.alpha = Parameter(torch.zeros(in_features) * alpha)
41 | else: # linear scale alphas initialized to ones
42 | self.alpha = Parameter(torch.ones(in_features) * alpha)
43 |
44 | self.alpha.requires_grad = alpha_trainable
45 |
46 | self.no_div_by_zero = 0.000000001
47 |
48 | def forward(self, x):
49 | '''
50 | Forward pass of the function.
51 | Applies the function to the input elementwise.
52 | Snake ∶= x + 1/a * sin^2 (xa)
53 | '''
54 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
55 | if self.alpha_logscale:
56 | alpha = torch.exp(alpha)
57 | x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
58 |
59 | return x
60 |
61 |
62 | class SnakeBeta(nn.Module):
63 | '''
64 | A modified Snake function which uses separate parameters for the magnitude of the periodic components
65 | Shape:
66 | - Input: (B, C, T)
67 | - Output: (B, C, T), same shape as the input
68 | Parameters:
69 | - alpha - trainable parameter that controls frequency
70 | - beta - trainable parameter that controls magnitude
71 | References:
72 | - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
73 | https://arxiv.org/abs/2006.08195
74 | Examples:
75 | >>> a1 = snakebeta(256)
76 | >>> x = torch.randn(256)
77 | >>> x = a1(x)
78 | '''
79 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
80 | '''
81 | Initialization.
82 | INPUT:
83 | - in_features: shape of the input
84 | - alpha - trainable parameter that controls frequency
85 | - beta - trainable parameter that controls magnitude
86 | alpha is initialized to 1 by default, higher values = higher-frequency.
87 | beta is initialized to 1 by default, higher values = higher-magnitude.
88 | alpha will be trained along with the rest of your model.
89 | '''
90 | super(SnakeBeta, self).__init__()
91 | self.in_features = in_features
92 |
93 | # initialize alpha
94 | self.alpha_logscale = alpha_logscale
95 | if self.alpha_logscale: # log scale alphas initialized to zeros
96 | self.alpha = Parameter(torch.zeros(in_features) * alpha)
97 | self.beta = Parameter(torch.zeros(in_features) * alpha)
98 | else: # linear scale alphas initialized to ones
99 | self.alpha = Parameter(torch.ones(in_features) * alpha)
100 | self.beta = Parameter(torch.ones(in_features) * alpha)
101 |
102 | self.alpha.requires_grad = alpha_trainable
103 | self.beta.requires_grad = alpha_trainable
104 |
105 | self.no_div_by_zero = 0.000000001
106 |
107 | def forward(self, x):
108 | '''
109 | Forward pass of the function.
110 | Applies the function to the input elementwise.
111 | SnakeBeta ∶= x + 1/b * sin^2 (xa)
112 | '''
113 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
114 | beta = self.beta.unsqueeze(0).unsqueeze(-1)
115 | if self.alpha_logscale:
116 | alpha = torch.exp(alpha)
117 | beta = torch.exp(beta)
118 | x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
119 |
120 | return x
--------------------------------------------------------------------------------
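A shape check for the Snake activations above (sketch only; import path illustrative). Both keep the (B, C, T) shape and learn one alpha (and beta) per channel.

import torch
from seedvc.modules.bigvgan.activations import Snake, SnakeBeta  # illustrative path

x = torch.randn(4, 128, 200)                          # (B, C, T)
print(Snake(128)(x).shape)                            # torch.Size([4, 128, 200])
print(SnakeBeta(128, alpha_logscale=True)(x).shape)   # same shape, separate alpha/beta per channel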
/seedvc/modules/bigvgan/alias_free_activation/cuda/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIFSH/SeedVC-ComfyUI/10731e18e28ccd4d73f87423a7f2a645ce3e6555/seedvc/modules/bigvgan/alias_free_activation/cuda/__init__.py
--------------------------------------------------------------------------------
/seedvc/modules/bigvgan/alias_free_activation/cuda/activation1d.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 NVIDIA CORPORATION.
2 | # Licensed under the MIT license.
3 |
4 | import torch
5 | import torch.nn as nn
6 | from ..torch.resample import UpSample1d, DownSample1d
7 |
8 | # load fused CUDA kernel: this enables importing anti_alias_activation_cuda
9 | from ..cuda import load
10 |
11 | anti_alias_activation_cuda = load.load()
12 |
13 |
14 | class FusedAntiAliasActivation(torch.autograd.Function):
15 | """
16 | Assumes filter size 12, replication padding on upsampling/downsampling, and logscale alpha/beta parameters as inputs.
17 | The hyperparameters are hard-coded in the kernel to maximize speed.
18 |     NOTE: The fused kernel is incorrect for Activation1d with different hyperparameters.
19 | """
20 |
21 | @staticmethod
22 | def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
23 | activation_results = anti_alias_activation_cuda.forward(
24 | inputs, up_ftr, down_ftr, alpha, beta
25 | )
26 |
27 | return activation_results
28 |
29 | @staticmethod
30 | def backward(ctx, output_grads):
31 | raise NotImplementedError
32 | return output_grads, None, None
33 |
34 |
35 | class Activation1d(nn.Module):
36 | def __init__(
37 | self,
38 | activation,
39 | up_ratio: int = 2,
40 | down_ratio: int = 2,
41 | up_kernel_size: int = 12,
42 | down_kernel_size: int = 12,
43 | fused: bool = True,
44 | ):
45 | super().__init__()
46 | self.up_ratio = up_ratio
47 | self.down_ratio = down_ratio
48 | self.act = activation
49 | self.upsample = UpSample1d(up_ratio, up_kernel_size)
50 | self.downsample = DownSample1d(down_ratio, down_kernel_size)
51 |
52 | self.fused = fused # Whether to use fused CUDA kernel or not
53 |
54 | def forward(self, x):
55 | if not self.fused:
56 | x = self.upsample(x)
57 | x = self.act(x)
58 | x = self.downsample(x)
59 | return x
60 | else:
61 | if self.act.__class__.__name__ == "Snake":
62 | beta = self.act.alpha.data # Snake uses same params for alpha and beta
63 | else:
64 | beta = (
65 | self.act.beta.data
66 | ) # Snakebeta uses different params for alpha and beta
67 | alpha = self.act.alpha.data
68 | if (
69 | not self.act.alpha_logscale
70 | ): # Exp baked into cuda kernel, cancel it out with a log
71 | alpha = torch.log(alpha)
72 | beta = torch.log(beta)
73 |
74 | x = FusedAntiAliasActivation.apply(
75 | x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
76 | )
77 | return x
78 |
--------------------------------------------------------------------------------
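Sketch only: importing this module JIT-compiles the fused CUDA kernel, so it assumes a CUDA device and a working nvcc toolchain; the fused path also expects log-scale alpha/beta, hence `alpha_logscale=True`. The call is wrapped in `no_grad` because the fused op has no backward implemented. Import paths are illustrative.

import torch
from seedvc.modules.bigvgan.activations import SnakeBeta
from seedvc.modules.bigvgan.alias_free_activation.cuda.activation1d import Activation1d

act = Activation1d(activation=SnakeBeta(64, alpha_logscale=True), fused=True).cuda()
x = torch.randn(1, 64, 4096, device="cuda")
with torch.no_grad():
    y = act(x)
print(y.shape)   # torch.Size([1, 64, 4096])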
/seedvc/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp:
--------------------------------------------------------------------------------
1 | /* coding=utf-8
2 | * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | #include <torch/extension.h>
18 |
19 | extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta);
20 |
21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22 | m.def("forward", &fwd_cuda, "Anti-Alias Activation forward (CUDA)");
23 | }
--------------------------------------------------------------------------------
/seedvc/modules/bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu:
--------------------------------------------------------------------------------
1 | /* coding=utf-8
2 | * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | #include <ATen/ATen.h>
18 | #include <cuda.h>
19 | #include <cuda_runtime.h>
20 | #include <cuda_fp16.h>
21 | #include <cuda_profiler_api.h>
22 | #include <ATen/cuda/CUDAContext.h>
23 | #include <torch/extension.h>
24 | #include "type_shim.h"
25 | #include <assert.h>
26 | #include <cfloat>
27 | #include <limits>
28 | #include <stdint.h>
29 | #include <c10/macros/Macros.h>
30 |
31 | namespace
32 | {
33 | // Hard-coded hyperparameters
34 | // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
35 | constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
36 | constexpr int BUFFER_SIZE = 32;
37 | constexpr int FILTER_SIZE = 12;
38 | constexpr int HALF_FILTER_SIZE = 6;
39 | constexpr int UPSAMPLE_REPLICATION_PAD = 5; // 5 on each side, matching torch impl
40 | constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5; // matching torch impl
41 | constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl
42 |
43 |     template <typename input_t, typename output_t, typename acc_t>
44 | __global__ void anti_alias_activation_forward(
45 | output_t *dst,
46 | const input_t *src,
47 | const input_t *up_ftr,
48 | const input_t *down_ftr,
49 | const input_t *alpha,
50 | const input_t *beta,
51 | int batch_size,
52 | int channels,
53 | int seq_len)
54 | {
55 | // Up and downsample filters
56 | input_t up_filter[FILTER_SIZE];
57 | input_t down_filter[FILTER_SIZE];
58 |
59 | // Load data from global memory including extra indices reserved for replication paddings
60 | input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0};
61 | input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0};
62 |
63 | // Output stores downsampled output before writing to dst
64 | output_t output[BUFFER_SIZE];
65 |
66 | // blockDim/threadIdx = (128, 1, 1)
67 | // gridDim/blockIdx = (seq_blocks, channels, batches)
68 | int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
69 | int local_offset = threadIdx.x * BUFFER_SIZE;
70 | int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;
71 |
72 | // intermediate have double the seq_len
73 | int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;
74 | int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset;
75 |
76 | // Get values needed for replication padding before moving pointer
77 | const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
78 | input_t seq_left_most_value = right_most_pntr[0];
79 | input_t seq_right_most_value = right_most_pntr[seq_len - 1];
80 |
81 | // Move src and dst pointers
82 | src += block_offset + local_offset;
83 | dst += block_offset + local_offset;
84 |
85 |         // Alpha and beta values for snake activations. Applies exp by default
86 | alpha = alpha + blockIdx.y;
87 | input_t alpha_val = expf(alpha[0]);
88 | beta = beta + blockIdx.y;
89 | input_t beta_val = expf(beta[0]);
90 |
91 | #pragma unroll
92 | for (int it = 0; it < FILTER_SIZE; it += 1)
93 | {
94 | up_filter[it] = up_ftr[it];
95 | down_filter[it] = down_ftr[it];
96 | }
97 |
98 | // Apply replication padding for upsampling, matching torch impl
99 | #pragma unroll
100 | for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE; it += 1)
101 | {
102 | int element_index = seq_offset + it; // index for element
103 | if ((element_index < 0) && (element_index >= -UPSAMPLE_REPLICATION_PAD))
104 | {
105 | elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_left_most_value;
106 | }
107 | if ((element_index >= seq_len) && (element_index < seq_len + UPSAMPLE_REPLICATION_PAD))
108 | {
109 | elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_right_most_value;
110 | }
111 | if ((element_index >= 0) && (element_index < seq_len))
112 | {
113 | elements[2 * (HALF_FILTER_SIZE + it)] = 2 * src[it];
114 | }
115 | }
116 |
117 |         // Apply upsampling strided convolution and write to intermediates. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT for replication padding of the downsampling conv later
118 | #pragma unroll
119 | for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1)
120 | {
121 | input_t acc = 0.0;
122 | int element_index = intermediate_seq_offset + it; // index for intermediate
123 | #pragma unroll
124 | for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
125 | {
126 | if ((element_index + f_idx) >= 0)
127 | {
128 | acc += up_filter[f_idx] * elements[it + f_idx];
129 | }
130 | }
131 | intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] = acc;
132 | }
133 |
134 |         // Apply activation function. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT and DOWNSAMPLE_REPLICATION_PAD_RIGHT for replication padding of the downsampling conv later
135 | double no_div_by_zero = 0.000000001;
136 | #pragma unroll
137 | for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1)
138 | {
139 | intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val);
140 | }
141 |
142 | // Apply replication padding before downsampling conv from intermediates
143 | #pragma unroll
144 | for (int it = 0; it < DOWNSAMPLE_REPLICATION_PAD_LEFT; it += 1)
145 | {
146 | intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT];
147 | }
148 | #pragma unroll
149 | for (int it = DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it < DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE + DOWNSAMPLE_REPLICATION_PAD_RIGHT; it += 1)
150 | {
151 | intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE - 1];
152 | }
153 |
154 | // Apply downsample strided convolution (assuming stride=2) from intermediates
155 | #pragma unroll
156 | for (int it = 0; it < BUFFER_SIZE; it += 1)
157 | {
158 | input_t acc = 0.0;
159 | #pragma unroll
160 | for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
161 | {
162 | // Add constant DOWNSAMPLE_REPLICATION_PAD_RIGHT to match torch implementation
163 | acc += down_filter[f_idx] * intermediates[it * 2 + f_idx + DOWNSAMPLE_REPLICATION_PAD_RIGHT];
164 | }
165 | output[it] = acc;
166 | }
167 |
168 | // Write output to dst
169 | #pragma unroll
170 | for (int it = 0; it < BUFFER_SIZE; it += ELEMENTS_PER_LDG_STG)
171 | {
172 | int element_index = seq_offset + it;
173 | if (element_index < seq_len)
174 | {
175 | dst[it] = output[it];
176 | }
177 | }
178 |
179 | }
180 |
181 |     template <typename input_t, typename output_t, typename acc_t>
182 | void dispatch_anti_alias_activation_forward(
183 | output_t *dst,
184 | const input_t *src,
185 | const input_t *up_ftr,
186 | const input_t *down_ftr,
187 | const input_t *alpha,
188 | const input_t *beta,
189 | int batch_size,
190 | int channels,
191 | int seq_len)
192 | {
193 | if (seq_len == 0)
194 | {
195 | return;
196 | }
197 | else
198 | {
199 |             // Use 128 threads per block to maximize gpu utilization
200 | constexpr int threads_per_block = 128;
201 | constexpr int seq_len_per_block = 4096;
202 | int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;
203 | dim3 blocks(blocks_per_seq_len, channels, batch_size);
204 | dim3 threads(threads_per_block, 1, 1);
205 |
206 |             anti_alias_activation_forward<input_t, output_t, acc_t>
207 |                 <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, up_ftr, down_ftr, alpha, beta, batch_size, channels, seq_len);
208 | }
209 | }
210 | }
211 |
212 | extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta)
213 | {
214 | // Input is a 3d tensor with dimensions [batches, channels, seq_len]
215 | const int batches = input.size(0);
216 | const int channels = input.size(1);
217 | const int seq_len = input.size(2);
218 |
219 | // Output
220 | auto act_options = input.options().requires_grad(false);
221 |
222 | torch::Tensor anti_alias_activation_results =
223 | torch::empty({batches, channels, seq_len}, act_options);
224 |
225 |     void *input_ptr = static_cast<void *>(input.data_ptr());
226 |     void *up_filter_ptr = static_cast<void *>(up_filter.data_ptr());
227 |     void *down_filter_ptr = static_cast<void *>(down_filter.data_ptr());
228 |     void *alpha_ptr = static_cast<void *>(alpha.data_ptr());
229 |     void *beta_ptr = static_cast<void *>(beta.data_ptr());
230 |     void *anti_alias_activation_results_ptr = static_cast<void *>(anti_alias_activation_results.data_ptr());
231 |
232 | DISPATCH_FLOAT_HALF_AND_BFLOAT(
233 | input.scalar_type(),
234 | "dispatch anti alias activation_forward",
235 |         dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float>(
236 |             reinterpret_cast<scalar_t *>(anti_alias_activation_results_ptr),
237 |             reinterpret_cast<const scalar_t *>(input_ptr),
238 |             reinterpret_cast<const scalar_t *>(up_filter_ptr),
239 |             reinterpret_cast<const scalar_t *>(down_filter_ptr),
240 |             reinterpret_cast<const scalar_t *>(alpha_ptr),
241 |             reinterpret_cast<const scalar_t *>(beta_ptr),
242 | batches,
243 | channels,
244 | seq_len););
245 | return anti_alias_activation_results;
246 | }
--------------------------------------------------------------------------------
/seedvc/modules/bigvgan/alias_free_activation/cuda/compat.h:
--------------------------------------------------------------------------------
1 | /* coding=utf-8
2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | /* This code is copied from NVIDIA apex:
18 | * https://github.com/NVIDIA/apex
19 | * with minor changes. */
20 |
21 | #ifndef TORCH_CHECK
22 | #define TORCH_CHECK AT_CHECK
23 | #endif
24 |
25 | #ifdef VERSION_GE_1_3
26 | #define DATA_PTR data_ptr
27 | #else
28 | #define DATA_PTR data
29 | #endif
30 |
--------------------------------------------------------------------------------
/seedvc/modules/bigvgan/alias_free_activation/cuda/load.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 NVIDIA CORPORATION.
2 | # Licensed under the MIT license.
3 |
4 | import os
5 | import pathlib
6 | import subprocess
7 |
8 | from torch.utils import cpp_extension
9 |
10 | """
11 | Setting this param to a list can generate compilation commands with a different order of architectures, leading to recompilation of the fused kernels.
12 | Set it to an empty string to avoid recompilation and assign arch flags explicitly in extra_cuda_cflags below.
13 | """
14 | os.environ["TORCH_CUDA_ARCH_LIST"] = ""
15 |
16 |
17 | def load():
18 | # Check if cuda 11 is installed for compute capability 8.0
19 | cc_flag = []
20 | _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
21 | if int(bare_metal_major) >= 11:
22 | cc_flag.append("-gencode")
23 | cc_flag.append("arch=compute_80,code=sm_80")
24 |
25 | # Build path
26 | srcpath = pathlib.Path(__file__).parent.absolute()
27 | buildpath = srcpath / "build"
28 | _create_build_dir(buildpath)
29 |
30 | # Helper function to build the kernels.
31 | def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
32 | return cpp_extension.load(
33 | name=name,
34 | sources=sources,
35 | build_directory=buildpath,
36 | extra_cflags=[
37 | "-O3",
38 | ],
39 | extra_cuda_cflags=[
40 | "-O3",
41 | "-gencode",
42 | "arch=compute_70,code=sm_70",
43 | "--use_fast_math",
44 | ]
45 | + extra_cuda_flags
46 | + cc_flag,
47 | verbose=True,
48 | )
49 |
50 | extra_cuda_flags = [
51 | "-U__CUDA_NO_HALF_OPERATORS__",
52 | "-U__CUDA_NO_HALF_CONVERSIONS__",
53 | "--expt-relaxed-constexpr",
54 | "--expt-extended-lambda",
55 | ]
56 |
57 | sources = [
58 | srcpath / "anti_alias_activation.cpp",
59 | srcpath / "anti_alias_activation_cuda.cu",
60 | ]
61 | anti_alias_activation_cuda = _cpp_extention_load_helper(
62 | "anti_alias_activation_cuda", sources, extra_cuda_flags
63 | )
64 |
65 | return anti_alias_activation_cuda
66 |
67 |
68 | def _get_cuda_bare_metal_version(cuda_dir):
69 | raw_output = subprocess.check_output(
70 | [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
71 | )
72 | output = raw_output.split()
73 | release_idx = output.index("release") + 1
74 | release = output[release_idx].split(".")
75 | bare_metal_major = release[0]
76 | bare_metal_minor = release[1][0]
77 |
78 | return raw_output, bare_metal_major, bare_metal_minor
79 |
80 |
81 | def _create_build_dir(buildpath):
82 | try:
83 | os.mkdir(buildpath)
84 | except OSError:
85 | if not os.path.isdir(buildpath):
86 | print(f"Creation of the build directory {buildpath} failed")
87 |
--------------------------------------------------------------------------------
/seedvc/modules/bigvgan/alias_free_activation/cuda/type_shim.h:
--------------------------------------------------------------------------------
1 | /* coding=utf-8
2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | #include <ATen/ATen.h>
18 | #include "compat.h"
19 |
20 | #define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \
21 | switch (TYPE) \
22 | { \
23 | case at::ScalarType::Float: \
24 | { \
25 | using scalar_t = float; \
26 | __VA_ARGS__; \
27 | break; \
28 | } \
29 | case at::ScalarType::Half: \
30 | { \
31 | using scalar_t = at::Half; \
32 | __VA_ARGS__; \
33 | break; \
34 | } \
35 | case at::ScalarType::BFloat16: \
36 | { \
37 | using scalar_t = at::BFloat16; \
38 | __VA_ARGS__; \
39 | break; \
40 | } \
41 | default: \
42 | AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
43 | }
44 |
45 | #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
46 | switch (TYPEIN) \
47 | { \
48 | case at::ScalarType::Float: \
49 | { \
50 | using scalar_t_in = float; \
51 | switch (TYPEOUT) \
52 | { \
53 | case at::ScalarType::Float: \
54 | { \
55 | using scalar_t_out = float; \
56 | __VA_ARGS__; \
57 | break; \
58 | } \
59 | case at::ScalarType::Half: \
60 | { \
61 | using scalar_t_out = at::Half; \
62 | __VA_ARGS__; \
63 | break; \
64 | } \
65 | case at::ScalarType::BFloat16: \
66 | { \
67 | using scalar_t_out = at::BFloat16; \
68 | __VA_ARGS__; \
69 | break; \
70 | } \
71 | default: \
72 | AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
73 | } \
74 | break; \
75 | } \
76 | case at::ScalarType::Half: \
77 | { \
78 | using scalar_t_in = at::Half; \
79 | using scalar_t_out = at::Half; \
80 | __VA_ARGS__; \
81 | break; \
82 | } \
83 | case at::ScalarType::BFloat16: \
84 | { \
85 | using scalar_t_in = at::BFloat16; \
86 | using scalar_t_out = at::BFloat16; \
87 | __VA_ARGS__; \
88 | break; \
89 | } \
90 | default: \
91 | AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
92 | }
93 |
--------------------------------------------------------------------------------
/seedvc/modules/bigvgan/alias_free_activation/torch/__init__.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2 | # LICENSE is in incl_licenses directory.
3 |
4 | from .filter import *
5 | from .resample import *
6 | from .act import *
7 |
--------------------------------------------------------------------------------
/seedvc/modules/bigvgan/alias_free_activation/torch/act.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2 | # LICENSE is in incl_licenses directory.
3 |
4 | import torch.nn as nn
5 | from .resample import UpSample1d, DownSample1d
6 |
7 |
8 | class Activation1d(nn.Module):
9 | def __init__(
10 | self,
11 | activation,
12 | up_ratio: int = 2,
13 | down_ratio: int = 2,
14 | up_kernel_size: int = 12,
15 | down_kernel_size: int = 12,
16 | ):
17 | super().__init__()
18 | self.up_ratio = up_ratio
19 | self.down_ratio = down_ratio
20 | self.act = activation
21 | self.upsample = UpSample1d(up_ratio, up_kernel_size)
22 | self.downsample = DownSample1d(down_ratio, down_kernel_size)
23 |
24 | # x: [B,C,T]
25 | def forward(self, x):
26 | x = self.upsample(x)
27 | x = self.act(x)
28 | x = self.downsample(x)
29 |
30 | return x
31 |
--------------------------------------------------------------------------------
/seedvc/modules/bigvgan/alias_free_activation/torch/filter.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2 | # LICENSE is in incl_licenses directory.
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import math
8 |
9 | if "sinc" in dir(torch):
10 | sinc = torch.sinc
11 | else:
12 | # This code is adopted from adefossez's julius.core.sinc under the MIT License
13 | # https://adefossez.github.io/julius/julius/core.html
14 | # LICENSE is in incl_licenses directory.
15 | def sinc(x: torch.Tensor):
16 | """
17 | Implementation of sinc, i.e. sin(pi * x) / (pi * x)
18 | __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
19 | """
20 | return torch.where(
21 | x == 0,
22 | torch.tensor(1.0, device=x.device, dtype=x.dtype),
23 | torch.sin(math.pi * x) / math.pi / x,
24 | )
25 |
26 |
27 | # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
28 | # https://adefossez.github.io/julius/julius/lowpass.html
29 | # LICENSE is in incl_licenses directory.
30 | def kaiser_sinc_filter1d(
31 | cutoff, half_width, kernel_size
32 | ): # return filter [1,1,kernel_size]
33 | even = kernel_size % 2 == 0
34 | half_size = kernel_size // 2
35 |
36 | # For kaiser window
37 | delta_f = 4 * half_width
38 | A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
39 | if A > 50.0:
40 | beta = 0.1102 * (A - 8.7)
41 | elif A >= 21.0:
42 | beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
43 | else:
44 | beta = 0.0
45 | window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
46 |
47 | # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
48 | if even:
49 | time = torch.arange(-half_size, half_size) + 0.5
50 | else:
51 | time = torch.arange(kernel_size) - half_size
52 | if cutoff == 0:
53 | filter_ = torch.zeros_like(time)
54 | else:
55 | filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
56 | """
57 | Normalize filter to have sum = 1, otherwise we will have a small leakage of the constant component in the input signal.
58 | """
59 | filter_ /= filter_.sum()
60 | filter = filter_.view(1, 1, kernel_size)
61 |
62 | return filter
63 |
64 |
65 | class LowPassFilter1d(nn.Module):
66 | def __init__(
67 | self,
68 | cutoff=0.5,
69 | half_width=0.6,
70 | stride: int = 1,
71 | padding: bool = True,
72 | padding_mode: str = "replicate",
73 | kernel_size: int = 12,
74 | ):
75 | """
76 | kernel_size should be even number for stylegan3 setup, in this implementation, odd number is also possible.
77 | """
78 | super().__init__()
79 | if cutoff < -0.0:
80 | raise ValueError("Minimum cutoff must be larger than zero.")
81 | if cutoff > 0.5:
82 | raise ValueError("A cutoff above 0.5 does not make sense.")
83 | self.kernel_size = kernel_size
84 | self.even = kernel_size % 2 == 0
85 | self.pad_left = kernel_size // 2 - int(self.even)
86 | self.pad_right = kernel_size // 2
87 | self.stride = stride
88 | self.padding = padding
89 | self.padding_mode = padding_mode
90 | filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
91 | self.register_buffer("filter", filter)
92 |
93 | # Input [B, C, T]
94 | def forward(self, x):
95 | _, C, _ = x.shape
96 |
97 | if self.padding:
98 | x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
99 | out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
100 |
101 | return out
102 |
--------------------------------------------------------------------------------
/seedvc/modules/bigvgan/alias_free_activation/torch/resample.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2 | # LICENSE is in incl_licenses directory.
3 |
4 | import torch.nn as nn
5 | from torch.nn import functional as F
6 | from .filter import LowPassFilter1d
7 | from .filter import kaiser_sinc_filter1d
8 |
9 |
10 | class UpSample1d(nn.Module):
11 | def __init__(self, ratio=2, kernel_size=None):
12 | super().__init__()
13 | self.ratio = ratio
14 | self.kernel_size = (
15 | int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
16 | )
17 | self.stride = ratio
18 | self.pad = self.kernel_size // ratio - 1
19 | self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
20 | self.pad_right = (
21 | self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
22 | )
23 | filter = kaiser_sinc_filter1d(
24 | cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
25 | )
26 | self.register_buffer("filter", filter)
27 |
28 | # x: [B, C, T]
29 | def forward(self, x):
30 | _, C, _ = x.shape
31 |
32 | x = F.pad(x, (self.pad, self.pad), mode="replicate")
33 | x = self.ratio * F.conv_transpose1d(
34 | x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
35 | )
36 | x = x[..., self.pad_left : -self.pad_right]
37 |
38 | return x
39 |
40 |
41 | class DownSample1d(nn.Module):
42 | def __init__(self, ratio=2, kernel_size=None):
43 | super().__init__()
44 | self.ratio = ratio
45 | self.kernel_size = (
46 | int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
47 | )
48 | self.lowpass = LowPassFilter1d(
49 | cutoff=0.5 / ratio,
50 | half_width=0.6 / ratio,
51 | stride=ratio,
52 | kernel_size=self.kernel_size,
53 | )
54 |
55 | def forward(self, x):
56 | xx = self.lowpass(x)
57 |
58 | return xx
59 |
--------------------------------------------------------------------------------
/seedvc/modules/bigvgan/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "resblock": "1",
3 | "num_gpus": 0,
4 | "batch_size": 32,
5 | "learning_rate": 0.0001,
6 | "adam_b1": 0.8,
7 | "adam_b2": 0.99,
8 | "lr_decay": 0.9999996,
9 | "seed": 1234,
10 |
11 | "upsample_rates": [4,4,2,2,2,2],
12 | "upsample_kernel_sizes": [8,8,4,4,4,4],
13 | "upsample_initial_channel": 1536,
14 | "resblock_kernel_sizes": [3,7,11],
15 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
16 |
17 | "use_tanh_at_final": false,
18 | "use_bias_at_final": false,
19 |
20 | "activation": "snakebeta",
21 | "snake_logscale": true,
22 |
23 | "use_cqtd_instead_of_mrd": true,
24 | "cqtd_filters": 128,
25 | "cqtd_max_filters": 1024,
26 | "cqtd_filters_scale": 1,
27 | "cqtd_dilations": [1, 2, 4],
28 | "cqtd_hop_lengths": [512, 256, 256],
29 | "cqtd_n_octaves": [9, 9, 9],
30 | "cqtd_bins_per_octaves": [24, 36, 48],
31 |
32 | "mpd_reshapes": [2, 3, 5, 7, 11],
33 | "use_spectral_norm": false,
34 | "discriminator_channel_mult": 1,
35 |
36 | "use_multiscale_melloss": true,
37 | "lambda_melloss": 15,
38 |
39 | "clip_grad_norm": 500,
40 |
41 | "segment_size": 65536,
42 | "num_mels": 80,
43 | "num_freq": 1025,
44 | "n_fft": 1024,
45 | "hop_size": 256,
46 | "win_size": 1024,
47 |
48 | "sampling_rate": 22050,
49 |
50 | "fmin": 0,
51 | "fmax": null,
52 | "fmax_for_loss": null,
53 |
54 | "normalize_volume": true,
55 |
56 | "num_workers": 4,
57 |
58 | "dist_config": {
59 | "dist_backend": "nccl",
60 | "dist_url": "tcp://localhost:54321",
61 | "world_size": 1
62 | }
63 | }
64 |
--------------------------------------------------------------------------------
/seedvc/modules/bigvgan/env.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
2 | # LICENSE is in incl_licenses directory.
3 |
4 | import os
5 | import shutil
6 |
7 |
8 | class AttrDict(dict):
9 | def __init__(self, *args, **kwargs):
10 | super(AttrDict, self).__init__(*args, **kwargs)
11 | self.__dict__ = self
12 |
13 |
14 | def build_env(config, config_name, path):
15 | t_path = os.path.join(path, config_name)
16 | if config != t_path:
17 | os.makedirs(path, exist_ok=True)
18 | shutil.copyfile(config, os.path.join(path, config_name))
--------------------------------------------------------------------------------
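AttrDict is what lets the JSON config above be read with attribute access, mirroring the original HiFi-GAN/BigVGAN training scripts. A small sketch (the relative path is an assumption about where the repository is checked out):

    import json
    from seedvc.modules.bigvgan.env import AttrDict

    with open("seedvc/modules/bigvgan/config.json") as f:
        h = AttrDict(json.load(f))                    # hyperparameters become attributes
    print(h.sampling_rate, h.num_mels, h.hop_size)    # 22050 80 256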
/seedvc/modules/bigvgan/meldataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 NVIDIA CORPORATION.
2 | # Licensed under the MIT license.
3 |
4 | # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
5 | # LICENSE is in incl_licenses directory.
6 |
7 | import math
8 | import os
9 | import random
10 | import torch
11 | import torch.utils.data
12 | import numpy as np
13 | from librosa.util import normalize
14 | from scipy.io.wavfile import read
15 | from librosa.filters import mel as librosa_mel_fn
16 | import pathlib
17 | from tqdm import tqdm
18 |
19 | MAX_WAV_VALUE = 32767.0  # NOTE: 32768.0 - 1 to prevent int16 overflow (overflow causes popping sounds in corner cases)
20 |
21 |
22 | def load_wav(full_path, sr_target):
23 | sampling_rate, data = read(full_path)
24 | if sampling_rate != sr_target:
25 | raise RuntimeError(
26 | f"Sampling rate of the file {full_path} is {sampling_rate} Hz, but the model requires {sr_target} Hz"
27 | )
28 | return data, sampling_rate
29 |
30 |
31 | def dynamic_range_compression(x, C=1, clip_val=1e-5):
32 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
33 |
34 |
35 | def dynamic_range_decompression(x, C=1):
36 | return np.exp(x) / C
37 |
38 |
39 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
40 | return torch.log(torch.clamp(x, min=clip_val) * C)
41 |
42 |
43 | def dynamic_range_decompression_torch(x, C=1):
44 | return torch.exp(x) / C
45 |
46 |
47 | def spectral_normalize_torch(magnitudes):
48 | return dynamic_range_compression_torch(magnitudes)
49 |
50 |
51 | def spectral_de_normalize_torch(magnitudes):
52 | return dynamic_range_decompression_torch(magnitudes)
53 |
54 |
55 | mel_basis_cache = {}
56 | hann_window_cache = {}
57 |
58 |
59 | def mel_spectrogram(
60 | y: torch.Tensor,
61 | n_fft: int,
62 | num_mels: int,
63 | sampling_rate: int,
64 | hop_size: int,
65 | win_size: int,
66 | fmin: int,
67 | fmax: int = None,
68 | center: bool = False,
69 | ) -> torch.Tensor:
70 | """
71 | Calculate the mel spectrogram of an input signal.
72 | This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft).
73 |
74 | Args:
75 | y (torch.Tensor): Input signal.
76 | n_fft (int): FFT size.
77 | num_mels (int): Number of mel bins.
78 | sampling_rate (int): Sampling rate of the input signal.
79 | hop_size (int): Hop size for STFT.
80 | win_size (int): Window size for STFT.
81 | fmin (int): Minimum frequency for mel filterbank.
82 | fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn
83 | center (bool): Whether to pad the input to center the frames. Default is False.
84 |
85 | Returns:
86 | torch.Tensor: Mel spectrogram.
87 | """
88 | if torch.min(y) < -1.0:
89 | print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}")
90 | if torch.max(y) > 1.0:
91 | print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}")
92 |
93 | device = y.device
94 | key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
95 |
96 | if key not in mel_basis_cache:
97 | mel = librosa_mel_fn(
98 | sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
99 | )
100 | mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
101 | hann_window_cache[key] = torch.hann_window(win_size).to(device)
102 |
103 | mel_basis = mel_basis_cache[key]
104 | hann_window = hann_window_cache[key]
105 |
106 | padding = (n_fft - hop_size) // 2
107 | y = torch.nn.functional.pad(
108 | y.unsqueeze(1), (padding, padding), mode="reflect"
109 | ).squeeze(1)
110 |
111 | spec = torch.stft(
112 | y,
113 | n_fft,
114 | hop_length=hop_size,
115 | win_length=win_size,
116 | window=hann_window,
117 | center=center,
118 | pad_mode="reflect",
119 | normalized=False,
120 | onesided=True,
121 | return_complex=True,
122 | )
123 | spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
124 |
125 | mel_spec = torch.matmul(mel_basis, spec)
126 | mel_spec = spectral_normalize_torch(mel_spec)
127 |
128 | return mel_spec
129 |
130 |
131 | def get_mel_spectrogram(wav, h):
132 | """
133 | Generate mel spectrogram from a waveform using given hyperparameters.
134 |
135 | Args:
136 | wav (torch.Tensor): Input waveform.
137 | h: Hyperparameters object with attributes n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax.
138 |
139 | Returns:
140 | torch.Tensor: Mel spectrogram.
141 | """
142 | return mel_spectrogram(
143 | wav,
144 | h.n_fft,
145 | h.num_mels,
146 | h.sampling_rate,
147 | h.hop_size,
148 | h.win_size,
149 | h.fmin,
150 | h.fmax,
151 | )
152 |
153 |
154 | def get_dataset_filelist(a):
155 | training_files = []
156 | validation_files = []
157 | list_unseen_validation_files = []
158 |
159 | with open(a.input_training_file, "r", encoding="utf-8") as fi:
160 | training_files = [
161 | os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
162 | for x in fi.read().split("\n")
163 | if len(x) > 0
164 | ]
165 | print(f"first training file: {training_files[0]}")
166 |
167 | with open(a.input_validation_file, "r", encoding="utf-8") as fi:
168 | validation_files = [
169 | os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
170 | for x in fi.read().split("\n")
171 | if len(x) > 0
172 | ]
173 | print(f"first validation file: {validation_files[0]}")
174 |
175 | for i in range(len(a.list_input_unseen_validation_file)):
176 | with open(a.list_input_unseen_validation_file[i], "r", encoding="utf-8") as fi:
177 | unseen_validation_files = [
178 | os.path.join(a.list_input_unseen_wavs_dir[i], x.split("|")[0] + ".wav")
179 | for x in fi.read().split("\n")
180 | if len(x) > 0
181 | ]
182 | print(
183 | f"first unseen {i}th validation fileset: {unseen_validation_files[0]}"
184 | )
185 | list_unseen_validation_files.append(unseen_validation_files)
186 |
187 | return training_files, validation_files, list_unseen_validation_files
188 |
189 |
190 | class MelDataset(torch.utils.data.Dataset):
191 | def __init__(
192 | self,
193 | training_files,
194 | hparams,
195 | segment_size,
196 | n_fft,
197 | num_mels,
198 | hop_size,
199 | win_size,
200 | sampling_rate,
201 | fmin,
202 | fmax,
203 | split=True,
204 | shuffle=True,
205 | n_cache_reuse=1,
206 | device=None,
207 | fmax_loss=None,
208 | fine_tuning=False,
209 | base_mels_path=None,
210 | is_seen=True,
211 | ):
212 | self.audio_files = training_files
213 | random.seed(1234)
214 | if shuffle:
215 | random.shuffle(self.audio_files)
216 | self.hparams = hparams
217 | self.is_seen = is_seen
218 | if self.is_seen:
219 | self.name = pathlib.Path(self.audio_files[0]).parts[0]
220 | else:
221 | self.name = "-".join(pathlib.Path(self.audio_files[0]).parts[:2]).strip("/")
222 |
223 | self.segment_size = segment_size
224 | self.sampling_rate = sampling_rate
225 | self.split = split
226 | self.n_fft = n_fft
227 | self.num_mels = num_mels
228 | self.hop_size = hop_size
229 | self.win_size = win_size
230 | self.fmin = fmin
231 | self.fmax = fmax
232 | self.fmax_loss = fmax_loss
233 | self.cached_wav = None
234 | self.n_cache_reuse = n_cache_reuse
235 | self._cache_ref_count = 0
236 | self.device = device
237 | self.fine_tuning = fine_tuning
238 | self.base_mels_path = base_mels_path
239 |
240 | print("[INFO] checking dataset integrity...")
241 | for i in tqdm(range(len(self.audio_files))):
242 | assert os.path.exists(
243 | self.audio_files[i]
244 | ), f"{self.audio_files[i]} not found"
245 |
246 | def __getitem__(self, index):
247 | filename = self.audio_files[index]
248 | if self._cache_ref_count == 0:
249 | audio, sampling_rate = load_wav(filename, self.sampling_rate)
250 | audio = audio / MAX_WAV_VALUE
251 | if not self.fine_tuning:
252 | audio = normalize(audio) * 0.95
253 | self.cached_wav = audio
254 | if sampling_rate != self.sampling_rate:
255 | raise ValueError(
256 | f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR"
257 | )
258 | self._cache_ref_count = self.n_cache_reuse
259 | else:
260 | audio = self.cached_wav
261 | self._cache_ref_count -= 1
262 |
263 | audio = torch.FloatTensor(audio)
264 | audio = audio.unsqueeze(0)
265 |
266 | if not self.fine_tuning:
267 | if self.split:
268 | if audio.size(1) >= self.segment_size:
269 | max_audio_start = audio.size(1) - self.segment_size
270 | audio_start = random.randint(0, max_audio_start)
271 | audio = audio[:, audio_start : audio_start + self.segment_size]
272 | else:
273 | audio = torch.nn.functional.pad(
274 | audio, (0, self.segment_size - audio.size(1)), "constant"
275 | )
276 |
277 | mel = mel_spectrogram(
278 | audio,
279 | self.n_fft,
280 | self.num_mels,
281 | self.sampling_rate,
282 | self.hop_size,
283 | self.win_size,
284 | self.fmin,
285 | self.fmax,
286 | center=False,
287 | )
288 | else: # Validation step
289 | # Match audio length to self.hop_size * n for evaluation
290 | if (audio.size(1) % self.hop_size) != 0:
291 | audio = audio[:, : -(audio.size(1) % self.hop_size)]
292 | mel = mel_spectrogram(
293 | audio,
294 | self.n_fft,
295 | self.num_mels,
296 | self.sampling_rate,
297 | self.hop_size,
298 | self.win_size,
299 | self.fmin,
300 | self.fmax,
301 | center=False,
302 | )
303 | assert (
304 | audio.shape[1] == mel.shape[2] * self.hop_size
305 | ), f"audio shape {audio.shape} mel shape {mel.shape}"
306 |
307 | else:
308 | mel = np.load(
309 | os.path.join(
310 | self.base_mels_path,
311 | os.path.splitext(os.path.split(filename)[-1])[0] + ".npy",
312 | )
313 | )
314 | mel = torch.from_numpy(mel)
315 |
316 | if len(mel.shape) < 3:
317 | mel = mel.unsqueeze(0)
318 |
319 | if self.split:
320 | frames_per_seg = math.ceil(self.segment_size / self.hop_size)
321 |
322 | if audio.size(1) >= self.segment_size:
323 | mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
324 | mel = mel[:, :, mel_start : mel_start + frames_per_seg]
325 | audio = audio[
326 | :,
327 | mel_start
328 | * self.hop_size : (mel_start + frames_per_seg)
329 | * self.hop_size,
330 | ]
331 | else:
332 | mel = torch.nn.functional.pad(
333 | mel, (0, frames_per_seg - mel.size(2)), "constant"
334 | )
335 | audio = torch.nn.functional.pad(
336 | audio, (0, self.segment_size - audio.size(1)), "constant"
337 | )
338 |
339 | mel_loss = mel_spectrogram(
340 | audio,
341 | self.n_fft,
342 | self.num_mels,
343 | self.sampling_rate,
344 | self.hop_size,
345 | self.win_size,
346 | self.fmin,
347 | self.fmax_loss,
348 | center=False,
349 | )
350 |
351 | return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
352 |
353 | def __len__(self):
354 | return len(self.audio_files)
355 |
--------------------------------------------------------------------------------
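For quick inspection, the cached-filterbank helper above can be called directly on a waveform tensor. A minimal sketch on dummy audio, using the hyperparameters from config.json (not part of the repository):

    import torch
    from seedvc.modules.bigvgan.meldataset import mel_spectrogram

    wav = torch.rand(1, 22050) * 2 - 1     # one second of dummy audio in [-1, 1], shape [B, T]
    mel = mel_spectrogram(wav, n_fft=1024, num_mels=80, sampling_rate=22050,
                          hop_size=256, win_size=1024, fmin=0, fmax=None)
    print(mel.shape)                        # approximately [1, 80, T // hop_size]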
/seedvc/modules/bigvgan/utils.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
2 | # LICENSE is in incl_licenses directory.
3 |
4 | import glob
5 | import os
6 | import matplotlib
7 | import torch
8 | from torch.nn.utils import weight_norm
9 |
10 | matplotlib.use("Agg")
11 | import matplotlib.pylab as plt
12 | from .meldataset import MAX_WAV_VALUE
13 | from scipy.io.wavfile import write
14 |
15 |
16 | def plot_spectrogram(spectrogram):
17 | fig, ax = plt.subplots(figsize=(10, 2))
18 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
19 | plt.colorbar(im, ax=ax)
20 |
21 | fig.canvas.draw()
22 | plt.close()
23 |
24 | return fig
25 |
26 |
27 | def plot_spectrogram_clipped(spectrogram, clip_max=2.0):
28 | fig, ax = plt.subplots(figsize=(10, 2))
29 | im = ax.imshow(
30 | spectrogram,
31 | aspect="auto",
32 | origin="lower",
33 | interpolation="none",
34 | vmin=1e-6,
35 | vmax=clip_max,
36 | )
37 | plt.colorbar(im, ax=ax)
38 |
39 | fig.canvas.draw()
40 | plt.close()
41 |
42 | return fig
43 |
44 |
45 | def init_weights(m, mean=0.0, std=0.01):
46 | classname = m.__class__.__name__
47 | if classname.find("Conv") != -1:
48 | m.weight.data.normal_(mean, std)
49 |
50 |
51 | def apply_weight_norm(m):
52 | classname = m.__class__.__name__
53 | if classname.find("Conv") != -1:
54 | weight_norm(m)
55 |
56 |
57 | def get_padding(kernel_size, dilation=1):
58 | return int((kernel_size * dilation - dilation) / 2)
59 |
60 |
61 | def load_checkpoint(filepath, device):
62 | assert os.path.isfile(filepath)
63 | print(f"Loading '{filepath}'")
64 | checkpoint_dict = torch.load(filepath, map_location=device)
65 | print("Complete.")
66 | return checkpoint_dict
67 |
68 |
69 | def save_checkpoint(filepath, obj):
70 | print(f"Saving checkpoint to {filepath}")
71 | torch.save(obj, filepath)
72 | print("Complete.")
73 |
74 |
75 | def scan_checkpoint(cp_dir, prefix, renamed_file=None):
76 | # Fallback to original scanning logic first
77 | pattern = os.path.join(cp_dir, prefix + "????????")
78 | cp_list = glob.glob(pattern)
79 |
80 | if len(cp_list) > 0:
81 | last_checkpoint_path = sorted(cp_list)[-1]
82 | print(f"[INFO] Resuming from checkpoint: '{last_checkpoint_path}'")
83 | return last_checkpoint_path
84 |
85 | # If no pattern-based checkpoints are found, check for renamed file
86 | if renamed_file:
87 | renamed_path = os.path.join(cp_dir, renamed_file)
88 | if os.path.isfile(renamed_path):
89 | print(f"[INFO] Resuming from renamed checkpoint: '{renamed_file}'")
90 | return renamed_path
91 |
92 | return None
93 |
94 |
95 | def save_audio(audio, path, sr):
96 |     # audio: 1-D torch tensor of waveform samples in [-1, 1]
97 | audio = audio * MAX_WAV_VALUE
98 | audio = audio.cpu().numpy().astype("int16")
99 | write(path, sr, audio)
100 |
--------------------------------------------------------------------------------
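scan_checkpoint first looks for numbered checkpoints (prefix followed by an 8-digit step), then falls back to a single renamed file. A hedged sketch; the directory and file names below are placeholders, not paths shipped with this repository:

    from seedvc.modules.bigvgan.utils import scan_checkpoint, load_checkpoint

    ckpt_path = scan_checkpoint("exp/bigvgan", "g_", renamed_file="bigvgan_generator.pt")  # placeholder paths
    if ckpt_path is not None:
        state_dict = load_checkpoint(ckpt_path, device="cpu")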
/seedvc/modules/campplus/DTDNN.py:
--------------------------------------------------------------------------------
1 | # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3 |
4 | from collections import OrderedDict
5 |
6 | import torch
7 | from torch import nn
8 | import torch.nn.functional as F
9 |
10 | from .layers import DenseLayer, StatsPool, TDNNLayer, CAMDenseTDNNBlock, TransitLayer, BasicResBlock, get_nonlinear
11 |
12 |
13 | class FCM(nn.Module):
14 | def __init__(self,
15 | block=BasicResBlock,
16 | num_blocks=[2, 2],
17 | m_channels=32,
18 | feat_dim=80):
19 | super(FCM, self).__init__()
20 | self.in_planes = m_channels
21 | self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
22 | self.bn1 = nn.BatchNorm2d(m_channels)
23 |
24 | self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
25 | self.layer2 = self._make_layer(block, m_channels, num_blocks[1], stride=2)
26 |
27 | self.conv2 = nn.Conv2d(m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False)
28 | self.bn2 = nn.BatchNorm2d(m_channels)
29 | self.out_channels = m_channels * (feat_dim // 8)
30 |
31 | def _make_layer(self, block, planes, num_blocks, stride):
32 | strides = [stride] + [1] * (num_blocks - 1)
33 | layers = []
34 | for stride in strides:
35 | layers.append(block(self.in_planes, planes, stride))
36 | self.in_planes = planes * block.expansion
37 | return nn.Sequential(*layers)
38 |
39 | def forward(self, x):
40 | x = x.unsqueeze(1)
41 | out = F.relu(self.bn1(self.conv1(x)))
42 | out = self.layer1(out)
43 | out = self.layer2(out)
44 | out = F.relu(self.bn2(self.conv2(out)))
45 |
46 | shape = out.shape
47 | out = out.reshape(shape[0], shape[1]*shape[2], shape[3])
48 | return out
49 |
50 | class CAMPPlus(nn.Module):
51 | def __init__(self,
52 | feat_dim=80,
53 | embedding_size=512,
54 | growth_rate=32,
55 | bn_size=4,
56 | init_channels=128,
57 | config_str='batchnorm-relu',
58 | memory_efficient=True):
59 | super(CAMPPlus, self).__init__()
60 |
61 | self.head = FCM(feat_dim=feat_dim)
62 | channels = self.head.out_channels
63 |
64 | self.xvector = nn.Sequential(
65 | OrderedDict([
66 |
67 | ('tdnn',
68 | TDNNLayer(channels,
69 | init_channels,
70 | 5,
71 | stride=2,
72 | dilation=1,
73 | padding=-1,
74 | config_str=config_str)),
75 | ]))
76 | channels = init_channels
77 | for i, (num_layers, kernel_size,
78 | dilation) in enumerate(zip((12, 24, 16), (3, 3, 3), (1, 2, 2))):
79 | block = CAMDenseTDNNBlock(num_layers=num_layers,
80 | in_channels=channels,
81 | out_channels=growth_rate,
82 | bn_channels=bn_size * growth_rate,
83 | kernel_size=kernel_size,
84 | dilation=dilation,
85 | config_str=config_str,
86 | memory_efficient=memory_efficient)
87 | self.xvector.add_module('block%d' % (i + 1), block)
88 | channels = channels + num_layers * growth_rate
89 | self.xvector.add_module(
90 | 'transit%d' % (i + 1),
91 | TransitLayer(channels,
92 | channels // 2,
93 | bias=False,
94 | config_str=config_str))
95 | channels //= 2
96 |
97 | self.xvector.add_module(
98 | 'out_nonlinear', get_nonlinear(config_str, channels))
99 |
100 | self.xvector.add_module('stats', StatsPool())
101 | self.xvector.add_module(
102 | 'dense',
103 | DenseLayer(channels * 2, embedding_size, config_str='batchnorm_'))
104 |
105 | for m in self.modules():
106 | if isinstance(m, (nn.Conv1d, nn.Linear)):
107 | nn.init.kaiming_normal_(m.weight.data)
108 | if m.bias is not None:
109 | nn.init.zeros_(m.bias)
110 |
111 | def forward(self, x):
112 | x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
113 | x = self.head(x)
114 | x = self.xvector(x)
115 | return x
--------------------------------------------------------------------------------
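CAMPPlus consumes fbank features laid out as (batch, time, feature) and pools them into a single utterance-level speaker embedding. A minimal sketch on random features; the dimensions follow the class defaults, not a specific pretrained checkpoint:

    import torch
    from seedvc.modules.campplus.DTDNN import CAMPPlus

    model = CAMPPlus(feat_dim=80).eval()    # default embedding_size=512; pretrained checkpoints may differ
    feats = torch.randn(1, 200, 80)         # [B, T, F] dummy fbank frames
    with torch.no_grad():
        emb = model(feats)
    print(emb.shape)                        # [1, 512] with the defaults above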
/seedvc/modules/campplus/classifier.py:
--------------------------------------------------------------------------------
1 | # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from .layers import DenseLayer
9 |
10 |
11 | class CosineClassifier(nn.Module):
12 | def __init__(
13 | self,
14 | input_dim,
15 | num_blocks=0,
16 | inter_dim=512,
17 | out_neurons=1000,
18 | ):
19 |
20 | super().__init__()
21 | self.blocks = nn.ModuleList()
22 |
23 | for index in range(num_blocks):
24 | self.blocks.append(
25 | DenseLayer(input_dim, inter_dim, config_str='batchnorm')
26 | )
27 | input_dim = inter_dim
28 |
29 | self.weight = nn.Parameter(
30 | torch.FloatTensor(out_neurons, input_dim)
31 | )
32 | nn.init.xavier_uniform_(self.weight)
33 |
34 | def forward(self, x):
35 | # x: [B, dim]
36 | for layer in self.blocks:
37 | x = layer(x)
38 |
39 | # normalized
40 | x = F.linear(F.normalize(x), F.normalize(self.weight))
41 | return x
42 |
43 | class LinearClassifier(nn.Module):
44 | def __init__(
45 | self,
46 | input_dim,
47 | num_blocks=0,
48 | inter_dim=512,
49 | out_neurons=1000,
50 | ):
51 |
52 | super().__init__()
53 | self.blocks = nn.ModuleList()
54 |
55 | self.nonlinear = nn.ReLU(inplace=True)
56 | for index in range(num_blocks):
57 | self.blocks.append(
58 | DenseLayer(input_dim, inter_dim, bias=True)
59 | )
60 | input_dim = inter_dim
61 |
62 | self.linear = nn.Linear(input_dim, out_neurons, bias=True)
63 |
64 | def forward(self, x):
65 | # x: [B, dim]
66 | x = self.nonlinear(x)
67 | for layer in self.blocks:
68 | x = layer(x)
69 | x = self.linear(x)
70 | return x
--------------------------------------------------------------------------------
/seedvc/modules/campplus/layers.py:
--------------------------------------------------------------------------------
1 | # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | import torch.utils.checkpoint as cp
7 | from torch import nn
8 |
9 |
10 | def get_nonlinear(config_str, channels):
11 | nonlinear = nn.Sequential()
12 | for name in config_str.split('-'):
13 | if name == 'relu':
14 | nonlinear.add_module('relu', nn.ReLU(inplace=True))
15 | elif name == 'prelu':
16 | nonlinear.add_module('prelu', nn.PReLU(channels))
17 | elif name == 'batchnorm':
18 | nonlinear.add_module('batchnorm', nn.BatchNorm1d(channels))
19 | elif name == 'batchnorm_':
20 | nonlinear.add_module('batchnorm',
21 | nn.BatchNorm1d(channels, affine=False))
22 | else:
23 | raise ValueError('Unexpected module ({}).'.format(name))
24 | return nonlinear
25 |
26 | def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2):
27 | mean = x.mean(dim=dim)
28 | std = x.std(dim=dim, unbiased=unbiased)
29 | stats = torch.cat([mean, std], dim=-1)
30 | if keepdim:
31 | stats = stats.unsqueeze(dim=dim)
32 | return stats
33 |
34 |
35 | class StatsPool(nn.Module):
36 | def forward(self, x):
37 | return statistics_pooling(x)
38 |
39 |
40 | class TDNNLayer(nn.Module):
41 | def __init__(self,
42 | in_channels,
43 | out_channels,
44 | kernel_size,
45 | stride=1,
46 | padding=0,
47 | dilation=1,
48 | bias=False,
49 | config_str='batchnorm-relu'):
50 | super(TDNNLayer, self).__init__()
51 | if padding < 0:
52 | assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
53 | kernel_size)
54 | padding = (kernel_size - 1) // 2 * dilation
55 | self.linear = nn.Conv1d(in_channels,
56 | out_channels,
57 | kernel_size,
58 | stride=stride,
59 | padding=padding,
60 | dilation=dilation,
61 | bias=bias)
62 | self.nonlinear = get_nonlinear(config_str, out_channels)
63 |
64 | def forward(self, x):
65 | x = self.linear(x)
66 | x = self.nonlinear(x)
67 | return x
68 |
69 |
70 | class CAMLayer(nn.Module):
71 | def __init__(self,
72 | bn_channels,
73 | out_channels,
74 | kernel_size,
75 | stride,
76 | padding,
77 | dilation,
78 | bias,
79 | reduction=2):
80 | super(CAMLayer, self).__init__()
81 | self.linear_local = nn.Conv1d(bn_channels,
82 | out_channels,
83 | kernel_size,
84 | stride=stride,
85 | padding=padding,
86 | dilation=dilation,
87 | bias=bias)
88 | self.linear1 = nn.Conv1d(bn_channels, bn_channels // reduction, 1)
89 | self.relu = nn.ReLU(inplace=True)
90 | self.linear2 = nn.Conv1d(bn_channels // reduction, out_channels, 1)
91 | self.sigmoid = nn.Sigmoid()
92 |
93 | def forward(self, x):
94 | y = self.linear_local(x)
95 | context = x.mean(-1, keepdim=True)+self.seg_pooling(x)
96 | context = self.relu(self.linear1(context))
97 | m = self.sigmoid(self.linear2(context))
98 | return y*m
99 |
100 | def seg_pooling(self, x, seg_len=100, stype='avg'):
101 | if stype == 'avg':
102 | seg = F.avg_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
103 | elif stype == 'max':
104 | seg = F.max_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
105 | else:
106 | raise ValueError('Wrong segment pooling type.')
107 | shape = seg.shape
108 | seg = seg.unsqueeze(-1).expand(*shape, seg_len).reshape(*shape[:-1], -1)
109 | seg = seg[..., :x.shape[-1]]
110 | return seg
111 |
112 |
113 | class CAMDenseTDNNLayer(nn.Module):
114 | def __init__(self,
115 | in_channels,
116 | out_channels,
117 | bn_channels,
118 | kernel_size,
119 | stride=1,
120 | dilation=1,
121 | bias=False,
122 | config_str='batchnorm-relu',
123 | memory_efficient=False):
124 | super(CAMDenseTDNNLayer, self).__init__()
125 | assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
126 | kernel_size)
127 | padding = (kernel_size - 1) // 2 * dilation
128 | self.memory_efficient = memory_efficient
129 | self.nonlinear1 = get_nonlinear(config_str, in_channels)
130 | self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False)
131 | self.nonlinear2 = get_nonlinear(config_str, bn_channels)
132 | self.cam_layer = CAMLayer(bn_channels,
133 | out_channels,
134 | kernel_size,
135 | stride=stride,
136 | padding=padding,
137 | dilation=dilation,
138 | bias=bias)
139 |
140 | def bn_function(self, x):
141 | return self.linear1(self.nonlinear1(x))
142 |
143 | def forward(self, x):
144 | if self.training and self.memory_efficient:
145 | x = cp.checkpoint(self.bn_function, x)
146 | else:
147 | x = self.bn_function(x)
148 | x = self.cam_layer(self.nonlinear2(x))
149 | return x
150 |
151 |
152 | class CAMDenseTDNNBlock(nn.ModuleList):
153 | def __init__(self,
154 | num_layers,
155 | in_channels,
156 | out_channels,
157 | bn_channels,
158 | kernel_size,
159 | stride=1,
160 | dilation=1,
161 | bias=False,
162 | config_str='batchnorm-relu',
163 | memory_efficient=False):
164 | super(CAMDenseTDNNBlock, self).__init__()
165 | for i in range(num_layers):
166 | layer = CAMDenseTDNNLayer(in_channels=in_channels + i * out_channels,
167 | out_channels=out_channels,
168 | bn_channels=bn_channels,
169 | kernel_size=kernel_size,
170 | stride=stride,
171 | dilation=dilation,
172 | bias=bias,
173 | config_str=config_str,
174 | memory_efficient=memory_efficient)
175 | self.add_module('tdnnd%d' % (i + 1), layer)
176 |
177 | def forward(self, x):
178 | for layer in self:
179 | x = torch.cat([x, layer(x)], dim=1)
180 | return x
181 |
182 |
183 | class TransitLayer(nn.Module):
184 | def __init__(self,
185 | in_channels,
186 | out_channels,
187 | bias=True,
188 | config_str='batchnorm-relu'):
189 | super(TransitLayer, self).__init__()
190 | self.nonlinear = get_nonlinear(config_str, in_channels)
191 | self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
192 |
193 | def forward(self, x):
194 | x = self.nonlinear(x)
195 | x = self.linear(x)
196 | return x
197 |
198 |
199 | class DenseLayer(nn.Module):
200 | def __init__(self,
201 | in_channels,
202 | out_channels,
203 | bias=False,
204 | config_str='batchnorm-relu'):
205 | super(DenseLayer, self).__init__()
206 | self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
207 | self.nonlinear = get_nonlinear(config_str, out_channels)
208 |
209 | def forward(self, x):
210 | if len(x.shape) == 2:
211 | x = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1)
212 | else:
213 | x = self.linear(x)
214 | x = self.nonlinear(x)
215 | return x
216 |
217 |
218 | class BasicResBlock(nn.Module):
219 | expansion = 1
220 |
221 | def __init__(self, in_planes, planes, stride=1):
222 | super(BasicResBlock, self).__init__()
223 | self.conv1 = nn.Conv2d(in_planes,
224 | planes,
225 | kernel_size=3,
226 | stride=(stride, 1),
227 | padding=1,
228 | bias=False)
229 | self.bn1 = nn.BatchNorm2d(planes)
230 | self.conv2 = nn.Conv2d(planes,
231 | planes,
232 | kernel_size=3,
233 | stride=1,
234 | padding=1,
235 | bias=False)
236 | self.bn2 = nn.BatchNorm2d(planes)
237 |
238 | self.shortcut = nn.Sequential()
239 | if stride != 1 or in_planes != self.expansion * planes:
240 | self.shortcut = nn.Sequential(
241 | nn.Conv2d(in_planes,
242 | self.expansion * planes,
243 | kernel_size=1,
244 | stride=(stride, 1),
245 | bias=False),
246 | nn.BatchNorm2d(self.expansion * planes))
247 |
248 | def forward(self, x):
249 | out = F.relu(self.bn1(self.conv1(x)))
250 | out = self.bn2(self.conv2(out))
251 | out += self.shortcut(x)
252 | out = F.relu(out)
253 | return out
--------------------------------------------------------------------------------
/seedvc/modules/diffusion_transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import math
4 |
5 | from .gpt_fast.model import ModelArgs, Transformer
6 | # from modules.torchscript_modules.gpt_fast_model import ModelArgs, Transformer
7 | from .wavenet import WN
8 | from .commons import sequence_mask
9 |
10 | from torch.nn.utils import weight_norm
11 |
12 | def modulate(x, shift, scale):
13 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
14 |
15 |
16 | #################################################################################
17 | # Embedding Layers for Timesteps and Class Labels #
18 | #################################################################################
19 |
20 | class TimestepEmbedder(nn.Module):
21 | """
22 | Embeds scalar timesteps into vector representations.
23 | """
24 | def __init__(self, hidden_size, frequency_embedding_size=256):
25 | super().__init__()
26 | self.mlp = nn.Sequential(
27 | nn.Linear(frequency_embedding_size, hidden_size, bias=True),
28 | nn.SiLU(),
29 | nn.Linear(hidden_size, hidden_size, bias=True),
30 | )
31 | self.frequency_embedding_size = frequency_embedding_size
32 | self.max_period = 10000
33 | self.scale = 1000
34 |
35 | half = frequency_embedding_size // 2
36 | freqs = torch.exp(
37 | -math.log(self.max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
38 | )
39 | self.register_buffer("freqs", freqs)
40 |
41 | def timestep_embedding(self, t):
42 | """
43 | Create sinusoidal timestep embeddings.
44 | :param t: a 1-D Tensor of N indices, one per batch element.
45 | These may be fractional.
46 |         The output dimension is self.frequency_embedding_size, and self.max_period
47 |         controls the minimum frequency of the embeddings.
48 | :return: an (N, D) Tensor of positional embeddings.
49 | """
50 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
51 |
52 | args = self.scale * t[:, None].float() * self.freqs[None]
53 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
54 | if self.frequency_embedding_size % 2:
55 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
56 | return embedding
57 |
58 | def forward(self, t):
59 | t_freq = self.timestep_embedding(t)
60 | t_emb = self.mlp(t_freq)
61 | return t_emb
62 |
63 |
64 | class StyleEmbedder(nn.Module):
65 | """
66 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
67 | """
68 | def __init__(self, input_size, hidden_size, dropout_prob):
69 | super().__init__()
70 | use_cfg_embedding = dropout_prob > 0
71 | self.embedding_table = nn.Embedding(int(use_cfg_embedding), hidden_size)
72 | self.style_in = weight_norm(nn.Linear(input_size, hidden_size, bias=True))
73 | self.input_size = input_size
74 | self.dropout_prob = dropout_prob
75 |
76 | def forward(self, labels, train, force_drop_ids=None):
77 | use_dropout = self.dropout_prob > 0
78 | if (train and use_dropout) or (force_drop_ids is not None):
79 | labels = self.token_drop(labels, force_drop_ids)
80 | else:
81 | labels = self.style_in(labels)
82 | embeddings = labels
83 | return embeddings
84 |
85 | class FinalLayer(nn.Module):
86 | """
87 | The final layer of DiT.
88 | """
89 | def __init__(self, hidden_size, patch_size, out_channels):
90 | super().__init__()
91 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
92 | self.linear = weight_norm(nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True))
93 | self.adaLN_modulation = nn.Sequential(
94 | nn.SiLU(),
95 | nn.Linear(hidden_size, 2 * hidden_size, bias=True)
96 | )
97 |
98 | def forward(self, x, c):
99 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
100 | x = modulate(self.norm_final(x), shift, scale)
101 | x = self.linear(x)
102 | return x
103 |
104 | class DiT(torch.nn.Module):
105 | def __init__(
106 | self,
107 | args
108 | ):
109 | super(DiT, self).__init__()
110 | self.time_as_token = args.DiT.time_as_token if hasattr(args.DiT, 'time_as_token') else False
111 | self.style_as_token = args.DiT.style_as_token if hasattr(args.DiT, 'style_as_token') else False
112 | self.uvit_skip_connection = args.DiT.uvit_skip_connection if hasattr(args.DiT, 'uvit_skip_connection') else False
113 | model_args = ModelArgs(
114 | block_size=16384,#args.DiT.block_size,
115 | n_layer=args.DiT.depth,
116 | n_head=args.DiT.num_heads,
117 | dim=args.DiT.hidden_dim,
118 | head_dim=args.DiT.hidden_dim // args.DiT.num_heads,
119 | vocab_size=1024,
120 | uvit_skip_connection=self.uvit_skip_connection,
121 | )
122 | self.transformer = Transformer(model_args)
123 | self.in_channels = args.DiT.in_channels
124 | self.out_channels = args.DiT.in_channels
125 | self.num_heads = args.DiT.num_heads
126 |
127 | self.x_embedder = weight_norm(nn.Linear(args.DiT.in_channels, args.DiT.hidden_dim, bias=True))
128 |
129 | self.content_type = args.DiT.content_type # 'discrete' or 'continuous'
130 | self.content_codebook_size = args.DiT.content_codebook_size # for discrete content
131 | self.content_dim = args.DiT.content_dim # for continuous content
132 | self.cond_embedder = nn.Embedding(args.DiT.content_codebook_size, args.DiT.hidden_dim) # discrete content
133 | self.cond_projection = nn.Linear(args.DiT.content_dim, args.DiT.hidden_dim, bias=True) # continuous content
134 |
135 | self.is_causal = args.DiT.is_causal
136 |
137 | self.n_f0_bins = args.DiT.n_f0_bins
138 | self.f0_bins = torch.arange(2, 1024, 1024 // args.DiT.n_f0_bins)
139 | self.f0_embedder = nn.Embedding(args.DiT.n_f0_bins, args.DiT.hidden_dim)
140 | self.f0_condition = args.DiT.f0_condition
141 |
142 | self.t_embedder = TimestepEmbedder(args.DiT.hidden_dim)
143 | self.t_embedder2 = TimestepEmbedder(args.wavenet.hidden_dim)
144 | # self.style_embedder1 = weight_norm(nn.Linear(1024, args.DiT.hidden_dim, bias=True))
145 | # self.style_embedder2 = weight_norm(nn.Linear(1024, args.style_encoder.dim, bias=True))
146 |
147 | input_pos = torch.arange(16384)
148 | self.register_buffer("input_pos", input_pos)
149 |
150 | self.conv1 = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim)
151 | self.conv2 = nn.Conv1d(args.wavenet.hidden_dim, args.DiT.in_channels, 1)
152 | self.final_layer_type = args.DiT.final_layer_type # mlp or wavenet
153 | if self.final_layer_type == 'wavenet':
154 | self.wavenet = WN(hidden_channels=args.wavenet.hidden_dim,
155 | kernel_size=args.wavenet.kernel_size,
156 | dilation_rate=args.wavenet.dilation_rate,
157 | n_layers=args.wavenet.num_layers,
158 | gin_channels=args.wavenet.hidden_dim,
159 | p_dropout=args.wavenet.p_dropout,
160 | causal=False)
161 | self.final_layer = FinalLayer(args.wavenet.hidden_dim, 1, args.wavenet.hidden_dim)
162 | else:
163 | self.final_mlp = nn.Sequential(
164 | nn.Linear(args.DiT.hidden_dim, args.DiT.hidden_dim),
165 | nn.SiLU(),
166 | nn.Linear(args.DiT.hidden_dim, args.DiT.in_channels),
167 | )
168 | self.transformer_style_condition = args.DiT.style_condition
169 | self.wavenet_style_condition = args.wavenet.style_condition
170 | assert args.DiT.style_condition == args.wavenet.style_condition
171 |
172 | self.class_dropout_prob = args.DiT.class_dropout_prob
173 | self.content_mask_embedder = nn.Embedding(1, args.DiT.hidden_dim)
174 |         self.res_projection = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim)  # residual connection from transformer output to final output
175 | self.long_skip_connection = args.DiT.long_skip_connection
176 | self.skip_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels, args.DiT.hidden_dim)
177 |
178 | self.cond_x_merge_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels * 2 +
179 | args.style_encoder.dim * self.transformer_style_condition * (not self.style_as_token),
180 | args.DiT.hidden_dim)
181 | if self.style_as_token:
182 | self.style_in = nn.Linear(args.style_encoder.dim, args.DiT.hidden_dim)
183 |
184 | def setup_caches(self, max_batch_size, max_seq_length):
185 | self.transformer.setup_caches(max_batch_size, max_seq_length, use_kv_cache=False)
186 | def forward(self, x, prompt_x, x_lens, t, style, cond, f0=None, mask_content=False):
187 | class_dropout = False
188 | if self.training and torch.rand(1) < self.class_dropout_prob:
189 | class_dropout = True
190 | if not self.training and mask_content:
191 | class_dropout = True
192 | # cond_in_module = self.cond_embedder if self.content_type == 'discrete' else self.cond_projection
193 | cond_in_module = self.cond_projection
194 |
195 | B, _, T = x.size()
196 |
197 |
198 | t1 = self.t_embedder(t) # (N, D)
199 |
200 | cond = cond_in_module(cond)
201 | if self.f0_condition and f0 is not None:
202 | quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device)) # (N, T)
203 | cond = cond + self.f0_embedder(quantized_f0)
204 |
205 | x = x.transpose(1, 2)
206 | prompt_x = prompt_x.transpose(1, 2)
207 |
208 | x_in = torch.cat([x, prompt_x, cond], dim=-1)
209 | if self.transformer_style_condition and not self.style_as_token:
210 | x_in = torch.cat([x_in, style[:, None, :].repeat(1, T, 1)], dim=-1)
211 | if class_dropout:
212 | x_in[..., self.in_channels:] = x_in[..., self.in_channels:] * 0
213 | x_in = self.cond_x_merge_linear(x_in) # (N, T, D)
214 |
215 | if self.style_as_token:
216 | style = self.style_in(style)
217 | style = torch.zeros_like(style) if class_dropout else style
218 | x_in = torch.cat([style.unsqueeze(1), x_in], dim=1)
219 | if self.time_as_token:
220 | x_in = torch.cat([t1.unsqueeze(1), x_in], dim=1)
221 | x_mask = sequence_mask(x_lens + self.style_as_token + self.time_as_token).to(x.device).unsqueeze(1)
222 | input_pos = self.input_pos[:x_in.size(1)] # (T,)
223 | x_mask_expanded = x_mask[:, None, :].repeat(1, 1, x_in.size(1), 1) if not self.is_causal else None
224 | x_res = self.transformer(x_in, None if self.time_as_token else t1.unsqueeze(1), input_pos, x_mask_expanded)
225 | x_res = x_res[:, 1:] if self.time_as_token else x_res
226 | x_res = x_res[:, 1:] if self.style_as_token else x_res
227 | if self.long_skip_connection:
228 | x_res = self.skip_linear(torch.cat([x_res, x], dim=-1))
229 | if self.final_layer_type == 'wavenet':
230 | x = self.conv1(x_res)
231 | x = x.transpose(1, 2)
232 | t2 = self.t_embedder2(t)
233 | x = self.wavenet(x, x_mask, g=t2.unsqueeze(2)).transpose(1, 2) + self.res_projection(
234 | x_res) # long residual connection
235 | x = self.final_layer(x, t1).transpose(1, 2)
236 | x = self.conv2(x)
237 | else:
238 | x = self.final_mlp(x_res)
239 | x = x.transpose(1, 2)
240 | return x
241 |
--------------------------------------------------------------------------------
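TimestepEmbedder above follows the standard sinusoidal scheme (with the time value scaled by 1000) followed by a two-layer MLP. A minimal sketch of its input/output shapes, assuming the repository root is importable:

    import torch
    from seedvc.modules.diffusion_transformer import TimestepEmbedder

    t_emb = TimestepEmbedder(hidden_size=512)
    t = torch.rand(4)          # one diffusion/flow time in [0, 1] per batch element
    print(t_emb(t).shape)      # [4, 512]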
/seedvc/modules/encodec.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """Convolutional layers wrappers and utilities."""
8 |
9 | import math
10 | import typing as tp
11 | import warnings
12 |
13 | import torch
14 | from torch import nn
15 | from torch.nn import functional as F
16 | from torch.nn.utils import spectral_norm, weight_norm
17 |
18 | import typing as tp
19 |
20 | import einops
21 |
22 |
23 | class ConvLayerNorm(nn.LayerNorm):
24 | """
25 | Convolution-friendly LayerNorm that moves channels to last dimensions
26 | before running the normalization and moves them back to original position right after.
27 | """
28 | def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
29 | super().__init__(normalized_shape, **kwargs)
30 |
31 | def forward(self, x):
32 | x = einops.rearrange(x, 'b ... t -> b t ...')
33 | x = super().forward(x)
34 | x = einops.rearrange(x, 'b t ... -> b ... t')
35 |         return x
36 |
37 |
38 | CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
39 | 'time_layer_norm', 'layer_norm', 'time_group_norm'])
40 |
41 |
42 | def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
43 | assert norm in CONV_NORMALIZATIONS
44 | if norm == 'weight_norm':
45 | return weight_norm(module)
46 | elif norm == 'spectral_norm':
47 | return spectral_norm(module)
48 | else:
49 |         # We already checked that norm is in CONV_NORMALIZATIONS, so any other choice
50 | # doesn't need reparametrization.
51 | return module
52 |
53 |
54 | def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
55 | """Return the proper normalization module. If causal is True, this will ensure the returned
56 | module is causal, or return an error if the normalization doesn't support causal evaluation.
57 | """
58 | assert norm in CONV_NORMALIZATIONS
59 | if norm == 'layer_norm':
60 | assert isinstance(module, nn.modules.conv._ConvNd)
61 | return ConvLayerNorm(module.out_channels, **norm_kwargs)
62 | elif norm == 'time_group_norm':
63 | if causal:
64 | raise ValueError("GroupNorm doesn't support causal evaluation.")
65 | assert isinstance(module, nn.modules.conv._ConvNd)
66 | return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
67 | else:
68 | return nn.Identity()
69 |
70 |
71 | def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
72 | padding_total: int = 0) -> int:
73 | """See `pad_for_conv1d`.
74 | """
75 | length = x.shape[-1]
76 | n_frames = (length - kernel_size + padding_total) / stride + 1
77 | ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
78 | return ideal_length - length
79 |
80 |
81 | def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
82 | """Pad for a convolution to make sure that the last window is full.
83 | Extra padding is added at the end. This is required to ensure that we can rebuild
84 | an output of the same length, as otherwise, even with padding, some time steps
85 | might get removed.
86 | For instance, with total padding = 4, kernel size = 4, stride = 2:
87 | 0 0 1 2 3 4 5 0 0 # (0s are padding)
88 | 1 2 3 # (output frames of a convolution, last 0 is never used)
89 | 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
90 | 1 2 3 4 # once you removed padding, we are missing one time step !
91 | """
92 | extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
93 | return F.pad(x, (0, extra_padding))
94 |
95 |
96 | def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
97 | """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
98 |     If this is the case, we insert extra 0 padding to the right before the reflection happens.
99 | """
100 | length = x.shape[-1]
101 | padding_left, padding_right = paddings
102 | assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
103 | if mode == 'reflect':
104 | max_pad = max(padding_left, padding_right)
105 | extra_pad = 0
106 | if length <= max_pad:
107 | extra_pad = max_pad - length + 1
108 | x = F.pad(x, (0, extra_pad))
109 | padded = F.pad(x, paddings, mode, value)
110 | end = padded.shape[-1] - extra_pad
111 | return padded[..., :end]
112 | else:
113 | return F.pad(x, paddings, mode, value)
114 |
115 |
116 | def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
117 | """Remove padding from x, handling properly zero padding. Only for 1d!"""
118 | padding_left, padding_right = paddings
119 | assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
120 | assert (padding_left + padding_right) <= x.shape[-1]
121 | end = x.shape[-1] - padding_right
122 | return x[..., padding_left: end]
123 |
124 |
125 | class NormConv1d(nn.Module):
126 | """Wrapper around Conv1d and normalization applied to this conv
127 | to provide a uniform interface across normalization approaches.
128 | """
129 | def __init__(self, *args, causal: bool = False, norm: str = 'none',
130 | norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
131 | super().__init__()
132 | self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
133 | self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
134 | self.norm_type = norm
135 |
136 | def forward(self, x):
137 | x = self.conv(x)
138 | x = self.norm(x)
139 | return x
140 |
141 |
142 | class NormConv2d(nn.Module):
143 | """Wrapper around Conv2d and normalization applied to this conv
144 | to provide a uniform interface across normalization approaches.
145 | """
146 | def __init__(self, *args, norm: str = 'none',
147 | norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
148 | super().__init__()
149 | self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
150 | self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
151 | self.norm_type = norm
152 |
153 | def forward(self, x):
154 | x = self.conv(x)
155 | x = self.norm(x)
156 | return x
157 |
158 |
159 | class NormConvTranspose1d(nn.Module):
160 | """Wrapper around ConvTranspose1d and normalization applied to this conv
161 | to provide a uniform interface across normalization approaches.
162 | """
163 | def __init__(self, *args, causal: bool = False, norm: str = 'none',
164 | norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
165 | super().__init__()
166 | self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
167 | self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
168 | self.norm_type = norm
169 |
170 | def forward(self, x):
171 | x = self.convtr(x)
172 | x = self.norm(x)
173 | return x
174 |
175 |
176 | class NormConvTranspose2d(nn.Module):
177 | """Wrapper around ConvTranspose2d and normalization applied to this conv
178 | to provide a uniform interface across normalization approaches.
179 | """
180 | def __init__(self, *args, norm: str = 'none',
181 | norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
182 | super().__init__()
183 | self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
184 | self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
185 |
186 | def forward(self, x):
187 | x = self.convtr(x)
188 | x = self.norm(x)
189 | return x
190 |
191 |
192 | class SConv1d(nn.Module):
193 | """Conv1d with some builtin handling of asymmetric or causal padding
194 | and normalization.
195 | """
196 | def __init__(self, in_channels: int, out_channels: int,
197 | kernel_size: int, stride: int = 1, dilation: int = 1,
198 | groups: int = 1, bias: bool = True, causal: bool = False,
199 | norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
200 | pad_mode: str = 'reflect', **kwargs):
201 | super().__init__()
202 | # warn user on unusual setup between dilation and stride
203 | if stride > 1 and dilation > 1:
204 | warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1'
205 | f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
206 | self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
207 | dilation=dilation, groups=groups, bias=bias, causal=causal,
208 | norm=norm, norm_kwargs=norm_kwargs)
209 | self.causal = causal
210 | self.pad_mode = pad_mode
211 |
212 | def forward(self, x):
213 | B, C, T = x.shape
214 | kernel_size = self.conv.conv.kernel_size[0]
215 | stride = self.conv.conv.stride[0]
216 | dilation = self.conv.conv.dilation[0]
217 | kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
218 | padding_total = kernel_size - stride
219 | extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
220 | if self.causal:
221 | # Left padding for causal
222 | x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
223 | else:
224 | # Asymmetric padding required for odd strides
225 | padding_right = padding_total // 2
226 | padding_left = padding_total - padding_right
227 | x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
228 | return self.conv(x)
229 |
230 |
231 | class SConvTranspose1d(nn.Module):
232 | """ConvTranspose1d with some builtin handling of asymmetric or causal padding
233 | and normalization.
234 | """
235 | def __init__(self, in_channels: int, out_channels: int,
236 | kernel_size: int, stride: int = 1, causal: bool = False,
237 | norm: str = 'none', trim_right_ratio: float = 1.,
238 | norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
239 | super().__init__()
240 | self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
241 | causal=causal, norm=norm, norm_kwargs=norm_kwargs)
242 | self.causal = causal
243 | self.trim_right_ratio = trim_right_ratio
244 | assert self.causal or self.trim_right_ratio == 1., \
245 | "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
246 | assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
247 |
248 | def forward(self, x):
249 | kernel_size = self.convtr.convtr.kernel_size[0]
250 | stride = self.convtr.convtr.stride[0]
251 | padding_total = kernel_size - stride
252 |
253 | y = self.convtr(x)
254 |
255 | # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
256 | # removed at the very end, when keeping only the right length for the output,
257 | # as removing it here would require also passing the length at the matching layer
258 | # in the encoder.
259 | if self.causal:
260 | # Trim the padding on the right according to the specified ratio
261 | # if trim_right_ratio = 1.0, trim everything from right
262 | padding_right = math.ceil(padding_total * self.trim_right_ratio)
263 | padding_left = padding_total - padding_right
264 | y = unpad1d(y, (padding_left, padding_right))
265 | else:
266 | # Asymmetric padding required for odd strides
267 | padding_right = padding_total // 2
268 | padding_left = padding_total - padding_right
269 | y = unpad1d(y, (padding_left, padding_right))
270 | return y
271 |
272 | class SLSTM(nn.Module):
273 | """
274 | LSTM without worrying about the hidden state, nor the layout of the data.
275 | Expects input as convolutional layout.
276 | """
277 | def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
278 | super().__init__()
279 | self.skip = skip
280 | self.lstm = nn.LSTM(dimension, dimension, num_layers)
281 | self.hidden = None
282 |
283 | def forward(self, x):
284 | x = x.permute(2, 0, 1)
285 | if self.training:
286 | y, _ = self.lstm(x)
287 | else:
288 | y, self.hidden = self.lstm(x, self.hidden)
289 | if self.skip:
290 | y = y + x
291 | y = y.permute(1, 2, 0)
292 | return y
--------------------------------------------------------------------------------
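The padding math in SConv1d keeps causal convolutions streaming-friendly: all padding goes to the left, so each output frame depends only on past samples and the output length stays at roughly T / stride. A small sketch (not part of the repository):

    import torch
    from seedvc.modules.encodec import SConv1d

    conv = SConv1d(1, 16, kernel_size=7, stride=2, causal=True, norm='weight_norm')
    x = torch.randn(1, 1, 16000)
    y = conv(x)
    print(y.shape)      # [1, 16, 8000]: left-padded only, no look-ahead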
/seedvc/modules/flow_matching.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from .diffusion_transformer import DiT
7 | from .commons import sequence_mask
8 |
9 | from tqdm import tqdm
10 |
11 | class BASECFM(torch.nn.Module, ABC):
12 | def __init__(
13 | self,
14 | args,
15 | ):
16 | super().__init__()
17 | self.sigma_min = 1e-6
18 |
19 | self.estimator = None
20 |
21 | self.in_channels = args.DiT.in_channels
22 |
23 | self.criterion = torch.nn.MSELoss() if args.reg_loss_type == "l2" else torch.nn.L1Loss()
24 |
25 | if hasattr(args.DiT, 'zero_prompt_speech_token'):
26 | self.zero_prompt_speech_token = args.DiT.zero_prompt_speech_token
27 | else:
28 | self.zero_prompt_speech_token = False
29 |
30 | @torch.inference_mode()
31 | def inference(self, mu, x_lens, prompt, style, f0, n_timesteps, temperature=1.0, inference_cfg_rate=0.5):
32 |         """Flow-matching inference with a fixed-step Euler solver.
33 | 
34 |         Args:
35 |             mu (torch.Tensor): content features from the encoder
36 |                 shape: (batch_size, mel_timesteps, content_dim)
37 |             x_lens (torch.Tensor): valid length of each sequence in the batch
38 |             prompt (torch.Tensor): reference mel prompt, shape (batch_size, n_feats, prompt_timesteps)
39 |             style (torch.Tensor): speaker style embedding
40 |             f0 (torch.Tensor or None): frame-level F0 condition
41 |             n_timesteps (int): number of Euler steps
42 |             temperature (float, optional): scaling of the initial noise. Defaults to 1.0.
43 |             inference_cfg_rate (float, optional): classifier-free guidance strength. Defaults to 0.5.
44 | 
45 |         Returns:
46 |             sample: generated mel-spectrogram
47 |                 shape: (batch_size, n_feats, mel_timesteps)
48 |         """
49 | B, T = mu.size(0), mu.size(1)
50 | z = torch.randn([B, self.in_channels, T], device=mu.device) * temperature
51 | t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
52 | # t_span = t_span + (-1) * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)
53 | return self.solve_euler(z, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate)
54 |
55 | def solve_euler(self, x, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate=0.5):
56 |         """
57 |         Fixed-step Euler solver for the flow-matching ODE.
58 |         Args:
59 |             x (torch.Tensor): random noise, shape (batch_size, n_feats, mel_timesteps)
60 |             x_lens (torch.Tensor): valid length of each sequence in the batch
61 |             prompt (torch.Tensor): reference mel-spectrogram prompt
62 |             mu (torch.Tensor): content features from the encoder
63 |                 shape: (batch_size, mel_timesteps, content_dim)
64 |             style (torch.Tensor): speaker style embedding
65 |             f0 (torch.Tensor or None): frame-level F0 condition
66 |             t_span (torch.Tensor): interpolated time steps
67 |                 shape: (n_timesteps + 1,)
68 |             inference_cfg_rate (float, optional): classifier-free guidance strength. Defaults to 0.5.
69 |         """
70 | t, _, _ = t_span[0], t_span[-1], t_span[1] - t_span[0]
71 |
72 | # I am storing this because I can later plot it by putting a debugger here and saving it to a file
73 | # Or in future might add like a return_all_steps flag
74 | sol = []
75 | # apply prompt
76 | prompt_len = prompt.size(-1)
77 | prompt_x = torch.zeros_like(x)
78 | prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
79 | x[..., :prompt_len] = 0
80 | if self.zero_prompt_speech_token:
81 | mu[..., :prompt_len] = 0
82 | for step in tqdm(range(1, len(t_span))):
83 | dt = t_span[step] - t_span[step - 1]
84 | if inference_cfg_rate > 0:
85 | # Stack original and CFG (null) inputs for batched processing
86 | stacked_prompt_x = torch.cat([prompt_x, torch.zeros_like(prompt_x)], dim=0)
87 | stacked_style = torch.cat([style, torch.zeros_like(style)], dim=0)
88 | stacked_mu = torch.cat([mu, torch.zeros_like(mu)], dim=0)
89 | stacked_x = torch.cat([x, x], dim=0)
90 |
91 | # Perform a single forward pass for both original and CFG inputs
92 | stacked_dphi_dt = self.estimator(
93 | stacked_x, stacked_prompt_x, x_lens, t.unsqueeze(0), stacked_style, stacked_mu, None
94 | )
95 |
96 | # Split the output back into the original and CFG components
97 | dphi_dt, cfg_dphi_dt = stacked_dphi_dt.chunk(2, dim=0)
98 |
99 | # Apply CFG formula
100 | dphi_dt = (1.0 + inference_cfg_rate) * dphi_dt - inference_cfg_rate * cfg_dphi_dt
101 | else:
102 | dphi_dt = self.estimator(x, prompt_x, x_lens, t.unsqueeze(0), style, mu, f0)
103 |
104 | x = x + dt * dphi_dt
105 | t = t + dt
106 | sol.append(x)
107 | if step < len(t_span) - 1:
108 | dt = t_span[step + 1] - t
109 | x[:, :, :prompt_len] = 0
110 |
111 | return sol[-1]
112 |
113 |
114 |
115 | class CFM(BASECFM):
116 | def __init__(self, args):
117 | super().__init__(
118 | args
119 | )
120 | if args.dit_type == "DiT":
121 | self.estimator = DiT(args)
122 | else:
123 | raise NotImplementedError(f"Unknown diffusion type {args.dit_type}")
124 |
--------------------------------------------------------------------------------
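The classifier-free guidance step inside solve_euler combines a conditional velocity estimate with an unconditional one computed from zeroed prompt, style, and content inputs. Shown in isolation on dummy tensors (a sketch of the formula only, not a runnable CFM example):

    import torch

    cfg_rate = 0.7
    dphi_dt = torch.randn(2, 80, 100)        # velocity from the conditional forward pass
    cfg_dphi_dt = torch.randn(2, 80, 100)    # velocity from the null-conditioned pass
    guided = (1.0 + cfg_rate) * dphi_dt - cfg_rate * cfg_dphi_dt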
/seedvc/modules/hifigan/f0_predictor.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import torch
15 | import torch.nn as nn
16 | from torch.nn.utils import weight_norm
17 |
18 |
19 | class ConvRNNF0Predictor(nn.Module):
20 | def __init__(self,
21 | num_class: int = 1,
22 | in_channels: int = 80,
23 | cond_channels: int = 512
24 | ):
25 | super().__init__()
26 |
27 | self.num_class = num_class
28 | self.condnet = nn.Sequential(
29 | weight_norm(
30 | nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
31 | ),
32 | nn.ELU(),
33 | weight_norm(
34 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
35 | ),
36 | nn.ELU(),
37 | weight_norm(
38 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
39 | ),
40 | nn.ELU(),
41 | weight_norm(
42 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
43 | ),
44 | nn.ELU(),
45 | weight_norm(
46 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
47 | ),
48 | nn.ELU(),
49 | )
50 | self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
51 |
52 | def forward(self, x: torch.Tensor) -> torch.Tensor:
53 | x = self.condnet(x)
54 | x = x.transpose(1, 2)
55 | return torch.abs(self.classifier(x).squeeze(-1))
56 |
--------------------------------------------------------------------------------
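A minimal usage sketch for `ConvRNNF0Predictor` above; the (B, 80, T) mel input shape follows the default `in_channels`, and checkpoint loading is omitted, so this only illustrates tensor shapes.

```python
import torch
from seedvc.modules.hifigan.f0_predictor import ConvRNNF0Predictor  # assumes the repo root is on sys.path

predictor = ConvRNNF0Predictor(num_class=1, in_channels=80, cond_channels=512)
mel = torch.randn(1, 80, 200)   # (B, n_mels, T) mel-spectrogram frames
f0 = predictor(mel)             # (B, T) non-negative per-frame F0 estimate
print(f0.shape)                 # torch.Size([1, 200])
```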
/seedvc/modules/length_regulator.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn import functional as F
5 | from .commons import sequence_mask
6 | import numpy as np
7 | from ..dac.nn.quantize import VectorQuantize
8 |
9 | # f0_bin = 256
10 | f0_max = 1100.0
11 | f0_min = 50.0
12 | f0_mel_min = 1127 * np.log(1 + f0_min / 700)
13 | f0_mel_max = 1127 * np.log(1 + f0_max / 700)
14 |
15 | def f0_to_coarse(f0, f0_bin):
16 | f0_mel = 1127 * (1 + f0 / 700).log()
17 | a = (f0_bin - 2) / (f0_mel_max - f0_mel_min)
18 | b = f0_mel_min * a - 1.
19 | f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel)
20 | # torch.clip_(f0_mel, min=1., max=float(f0_bin - 1))
21 | f0_coarse = torch.round(f0_mel).long()
22 | f0_coarse = f0_coarse * (f0_coarse > 0)
23 | f0_coarse = f0_coarse + ((f0_coarse < 1) * 1)
24 | f0_coarse = f0_coarse * (f0_coarse < f0_bin)
25 | f0_coarse = f0_coarse + ((f0_coarse >= f0_bin) * (f0_bin - 1))
26 | return f0_coarse
27 |
28 | class InterpolateRegulator(nn.Module):
29 | def __init__(
30 | self,
31 | channels: int,
32 | sampling_ratios: Tuple,
33 | is_discrete: bool = False,
34 | in_channels: int = None, # only applies to continuous input
35 | vector_quantize: bool = False, # whether to use vector quantization, only applies to continuous input
36 | codebook_size: int = 1024, # for discrete only
37 | out_channels: int = None,
38 | groups: int = 1,
39 | n_codebooks: int = 1, # number of codebooks
40 | quantizer_dropout: float = 0.0, # dropout for quantizer
41 | f0_condition: bool = False,
42 | n_f0_bins: int = 512,
43 | ):
44 | super().__init__()
45 | self.sampling_ratios = sampling_ratios
46 | out_channels = out_channels or channels
47 | model = nn.ModuleList([])
48 | if len(sampling_ratios) > 0:
49 | self.interpolate = True
50 | for _ in sampling_ratios:
51 | module = nn.Conv1d(channels, channels, 3, 1, 1)
52 | norm = nn.GroupNorm(groups, channels)
53 | act = nn.Mish()
54 | model.extend([module, norm, act])
55 | else:
56 | self.interpolate = False
57 | model.append(
58 | nn.Conv1d(channels, out_channels, 1, 1)
59 | )
60 | self.model = nn.Sequential(*model)
61 | self.embedding = nn.Embedding(codebook_size, channels)
62 | self.is_discrete = is_discrete
63 |
64 | self.mask_token = nn.Parameter(torch.zeros(1, channels))
65 |
66 | self.n_codebooks = n_codebooks
67 | if n_codebooks > 1:
68 | self.extra_codebooks = nn.ModuleList([
69 | nn.Embedding(codebook_size, channels) for _ in range(n_codebooks - 1)
70 | ])
71 | self.extra_codebook_mask_tokens = nn.ParameterList([
72 | nn.Parameter(torch.zeros(1, channels)) for _ in range(n_codebooks - 1)
73 | ])
74 | self.quantizer_dropout = quantizer_dropout
75 |
76 | if f0_condition:
77 | self.f0_embedding = nn.Embedding(n_f0_bins, channels)
78 | self.f0_condition = f0_condition
79 | self.n_f0_bins = n_f0_bins
80 | self.f0_bins = torch.arange(2, 1024, 1024 // n_f0_bins)
81 | self.f0_mask = nn.Parameter(torch.zeros(1, channels))
82 | else:
83 | self.f0_condition = False
84 |
85 | if not is_discrete:
86 | self.content_in_proj = nn.Linear(in_channels, channels)
87 | if vector_quantize:
88 | self.vq = VectorQuantize(channels, codebook_size, 8)
89 |
90 | def forward(self, x, ylens=None, n_quantizers=None, f0=None):
91 | # apply token drop
92 | if self.training:
93 | n_quantizers = torch.ones((x.shape[0],)) * self.n_codebooks
94 | dropout = torch.randint(1, self.n_codebooks + 1, (x.shape[0],))
95 | n_dropout = int(x.shape[0] * self.quantizer_dropout)
96 | n_quantizers[:n_dropout] = dropout[:n_dropout]
97 | n_quantizers = n_quantizers.to(x.device)
98 | # decide whether to drop for each sample in batch
99 | else:
100 | n_quantizers = torch.ones((x.shape[0],), device=x.device) * (self.n_codebooks if n_quantizers is None else n_quantizers)
101 | if self.is_discrete:
102 | if self.n_codebooks > 1:
103 | assert len(x.size()) == 3
104 | x_emb = self.embedding(x[:, 0])
105 | for i, emb in enumerate(self.extra_codebooks):
106 | x_emb = x_emb + (n_quantizers > i+1)[..., None, None] * emb(x[:, i+1])
107 | # add mask token if not using this codebook
108 | # x_emb = x_emb + (n_quantizers <= i+1)[..., None, None] * self.extra_codebook_mask_tokens[i]
109 | x = x_emb
110 | elif self.n_codebooks == 1:
111 | if len(x.size()) == 2:
112 | x = self.embedding(x)
113 | else:
114 | x = self.embedding(x[:, 0])
115 | else:
116 | x = self.content_in_proj(x)
117 | # x in (B, T, D)
118 | mask = sequence_mask(ylens).unsqueeze(-1)
119 | if self.interpolate:
120 | x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
121 | else:
122 | x = x.transpose(1, 2).contiguous()
123 | mask = mask[:, :x.size(2), :]
124 | ylens = ylens.clamp(max=x.size(2)).long()
125 | if self.f0_condition:
126 | if f0 is None:
127 | x = x + self.f0_mask.unsqueeze(-1)
128 | else:
129 | #quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device)) # (N, T)
130 | quantized_f0 = f0_to_coarse(f0, self.n_f0_bins)
131 | quantized_f0 = quantized_f0.clamp(0, self.n_f0_bins - 1).long()
132 | f0_emb = self.f0_embedding(quantized_f0)
133 | f0_emb = F.interpolate(f0_emb.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
134 | x = x + f0_emb
135 | out = self.model(x).transpose(1, 2).contiguous()
136 | if hasattr(self, 'vq'):
137 | out_q, commitment_loss, codebook_loss, codes, out = self.vq(out.transpose(1, 2))
138 | out_q = out_q.transpose(1, 2)
139 | return out_q * mask, ylens, codes, commitment_loss, codebook_loss
140 | olens = ylens
141 | return out * mask, olens, None, None, None
142 |
--------------------------------------------------------------------------------
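For reference, a small check of how `f0_to_coarse` in the file above maps Hz values onto mel-spaced bins; the test values are arbitrary and the exact bin indices depend on the `f0_bin` argument.

```python
import torch
from seedvc.modules.length_regulator import f0_to_coarse  # assumes the repo root is on sys.path

f0 = torch.tensor([[0.0, 55.0, 220.0, 440.0, 1000.0]])  # Hz per frame, 0 = unvoiced
coarse = f0_to_coarse(f0, 512)
# unvoiced frames map to bin 1; voiced frames land on mel-spaced bins in [1, 511]
print(coarse)
```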
/seedvc/modules/quantize.py:
--------------------------------------------------------------------------------
1 | from ..dac.nn.quantize import ResidualVectorQuantize  # vendored DAC copy, consistent with the relative imports below
2 | from torch import nn
3 | from .wavenet import WN
4 | import torch
5 | import torchaudio
6 | import torchaudio.functional as audio_F
7 | import numpy as np
8 | from .alias_free_torch import *
9 | from torch.nn.utils import weight_norm
10 | from torch import nn, sin, pow
11 | from einops.layers.torch import Rearrange
12 | from ..dac.model.encodec import SConv1d
13 |
14 | def init_weights(m):
15 | if isinstance(m, nn.Conv1d):
16 | nn.init.trunc_normal_(m.weight, std=0.02)
17 | nn.init.constant_(m.bias, 0)
18 |
19 |
20 | def WNConv1d(*args, **kwargs):
21 | return weight_norm(nn.Conv1d(*args, **kwargs))
22 |
23 |
24 | def WNConvTranspose1d(*args, **kwargs):
25 | return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
26 |
27 | class SnakeBeta(nn.Module):
28 | """
29 | A modified Snake function which uses separate parameters for the magnitude of the periodic components
30 | Shape:
31 | - Input: (B, C, T)
32 | - Output: (B, C, T), same shape as the input
33 | Parameters:
34 | - alpha - trainable parameter that controls frequency
35 | - beta - trainable parameter that controls magnitude
36 | References:
37 | - This activation is a modified version of Snake, based on the paper by Liu Ziyin, Tilman Hartwig, and Masahito Ueda:
38 | https://arxiv.org/abs/2006.08195
39 | Examples:
40 | >>> a1 = SnakeBeta(256)
41 | >>> x = torch.randn(1, 256, 100)  # (B, C, T)
42 | >>> x = a1(x)
43 | """
44 |
45 | def __init__(
46 | self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
47 | ):
48 | """
49 | Initialization.
50 | INPUT:
51 | - in_features: number of input channels
52 | - alpha: trainable parameter that controls frequency
53 | - beta: trainable parameter that controls magnitude (initialized from alpha)
54 | alpha is initialized to 1 by default; higher values give higher frequency.
55 | beta is initialized to 1 by default; higher values give higher magnitude.
56 | both alpha and beta are trained along with the rest of the model.
57 | """
58 | super(SnakeBeta, self).__init__()
59 | self.in_features = in_features
60 |
61 | # initialize alpha
62 | self.alpha_logscale = alpha_logscale
63 | if self.alpha_logscale: # log scale alphas initialized to zeros
64 | self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
65 | self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
66 | else: # linear scale alphas initialized to ones
67 | self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
68 | self.beta = nn.Parameter(torch.ones(in_features) * alpha)
69 |
70 | self.alpha.requires_grad = alpha_trainable
71 | self.beta.requires_grad = alpha_trainable
72 |
73 | self.no_div_by_zero = 0.000000001
74 |
75 | def forward(self, x):
76 | """
77 | Forward pass of the function.
78 | Applies the function to the input elementwise.
79 | SnakeBeta := x + 1/b * sin^2 (xa)
80 | """
81 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
82 | beta = self.beta.unsqueeze(0).unsqueeze(-1)
83 | if self.alpha_logscale:
84 | alpha = torch.exp(alpha)
85 | beta = torch.exp(beta)
86 | x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
87 |
88 | return x
89 |
90 | class ResidualUnit(nn.Module):
91 | def __init__(self, dim: int = 16, dilation: int = 1):
92 | super().__init__()
93 | pad = ((7 - 1) * dilation) // 2
94 | self.block = nn.Sequential(
95 | Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
96 | WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
97 | Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
98 | WNConv1d(dim, dim, kernel_size=1),
99 | )
100 |
101 | def forward(self, x):
102 | return x + self.block(x)
103 |
104 | class CNNLSTM(nn.Module):
105 | def __init__(self, indim, outdim, head, global_pred=False):
106 | super().__init__()
107 | self.global_pred = global_pred
108 | self.model = nn.Sequential(
109 | ResidualUnit(indim, dilation=1),
110 | ResidualUnit(indim, dilation=2),
111 | ResidualUnit(indim, dilation=3),
112 | Activation1d(activation=SnakeBeta(indim, alpha_logscale=True)),
113 | Rearrange("b c t -> b t c"),
114 | )
115 | self.heads = nn.ModuleList([nn.Linear(indim, outdim) for i in range(head)])
116 |
117 | def forward(self, x):
118 | # x: [B, C, T]
119 | x = self.model(x)
120 | if self.global_pred:
121 | x = torch.mean(x, dim=1, keepdim=False)
122 | outs = [head(x) for head in self.heads]
123 | return outs
124 |
125 | def sequence_mask(length, max_length=None):
126 | if max_length is None:
127 | max_length = length.max()
128 | x = torch.arange(max_length, dtype=length.dtype, device=length.device)
129 | return x.unsqueeze(0) < length.unsqueeze(1)
130 | class FAquantizer(nn.Module):
131 | def __init__(self, in_dim=1024,
132 | n_p_codebooks=1,
133 | n_c_codebooks=2,
134 | n_t_codebooks=2,
135 | n_r_codebooks=3,
136 | codebook_size=1024,
137 | codebook_dim=8,
138 | quantizer_dropout=0.5,
139 | causal=False,
140 | separate_prosody_encoder=False,
141 | timbre_norm=False,):
142 | super(FAquantizer, self).__init__()
143 | conv1d_type = SConv1d  # if causal else nn.Conv1d
144 | self.prosody_quantizer = ResidualVectorQuantize(
145 | input_dim=in_dim,
146 | n_codebooks=n_p_codebooks,
147 | codebook_size=codebook_size,
148 | codebook_dim=codebook_dim,
149 | quantizer_dropout=quantizer_dropout,
150 | )
151 |
152 | self.content_quantizer = ResidualVectorQuantize(
153 | input_dim=in_dim,
154 | n_codebooks=n_c_codebooks,
155 | codebook_size=codebook_size,
156 | codebook_dim=codebook_dim,
157 | quantizer_dropout=quantizer_dropout,
158 | )
159 |
160 | self.residual_quantizer = ResidualVectorQuantize(
161 | input_dim=in_dim,
162 | n_codebooks=n_r_codebooks,
163 | codebook_size=codebook_size,
164 | codebook_dim=codebook_dim,
165 | quantizer_dropout=quantizer_dropout,
166 | )
167 |
168 | self.melspec_linear = conv1d_type(in_channels=20, out_channels=256, kernel_size=1, causal=causal)
169 | self.melspec_encoder = WN(hidden_channels=256, kernel_size=5, dilation_rate=1, n_layers=8, gin_channels=0, p_dropout=0.2, causal=causal)
170 | self.melspec_linear2 = conv1d_type(in_channels=256, out_channels=1024, kernel_size=1, causal=causal)
171 |
172 | self.prob_random_mask_residual = 0.75
173 |
174 | SPECT_PARAMS = {
175 | "n_fft": 2048,
176 | "win_length": 1200,
177 | "hop_length": 300,
178 | }
179 | MEL_PARAMS = {
180 | "n_mels": 80,
181 | }
182 |
183 | self.to_mel = torchaudio.transforms.MelSpectrogram(
184 | n_mels=MEL_PARAMS["n_mels"], sample_rate=24000, **SPECT_PARAMS
185 | )
186 | self.mel_mean, self.mel_std = -4, 4
187 | self.frame_rate = 24000 / 300
188 | self.hop_length = 300
189 |
190 | def preprocess(self, wave_tensor, n_bins=20):
191 | mel_tensor = self.to_mel(wave_tensor.squeeze(1))
192 | mel_tensor = (torch.log(1e-5 + mel_tensor) - self.mel_mean) / self.mel_std
193 | return mel_tensor[:, :n_bins, :int(wave_tensor.size(-1) / self.hop_length)]
194 |
195 | def forward(self, x, wave_segments):
196 | outs = 0
197 | prosody_feature = self.preprocess(wave_segments)
198 |
199 | f0_input = prosody_feature  # (B, 20, T)
200 | f0_input = self.melspec_linear(f0_input)
201 | f0_input = self.melspec_encoder(f0_input, torch.ones(f0_input.shape[0], 1, f0_input.shape[2]).to(
202 | f0_input.device).bool())
203 | f0_input = self.melspec_linear2(f0_input)
204 |
205 | common_min_size = min(f0_input.size(2), x.size(2))
206 | f0_input = f0_input[:, :, :common_min_size]
207 |
208 | x = x[:, :, :common_min_size]
209 |
210 | z_p, codes_p, latents_p, commitment_loss_p, codebook_loss_p = self.prosody_quantizer(
211 | f0_input, 1
212 | )
213 | outs += z_p.detach()
214 |
215 | z_c, codes_c, latents_c, commitment_loss_c, codebook_loss_c = self.content_quantizer(
216 | x, 2
217 | )
218 | outs += z_c.detach()
219 |
220 | residual_feature = x - z_p.detach() - z_c.detach()
221 |
222 | z_r, codes_r, latents_r, commitment_loss_r, codebook_loss_r = self.residual_quantizer(
223 | residual_feature, 3
224 | )
225 |
226 | quantized = [z_p, z_c, z_r]
227 | codes = [codes_p, codes_c, codes_r]
228 |
229 | return quantized, codes
--------------------------------------------------------------------------------
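A quick numerical check of the `SnakeBeta` activation defined above: with `alpha_logscale=True` the parameters start at zero, so after exponentiation alpha = beta = 1 and the activation reduces to x + sin(x)^2 (up to the tiny `no_div_by_zero` term). The import path assumes the repo root is on sys.path.

```python
import torch
from seedvc.modules.quantize import SnakeBeta

act = SnakeBeta(in_features=4, alpha_logscale=True)
x = torch.randn(2, 4, 16)        # (B, C, T)
y = act(x)
assert y.shape == x.shape
# alpha = beta = exp(0) = 1, so this is x + sin(x)**2 within tolerance
torch.testing.assert_close(y, x + torch.sin(x) ** 2)
```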
/seedvc/modules/wavenet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F
5 |
6 | from .encodec import SConv1d
7 |
8 | from . import commons
9 | LRELU_SLOPE = 0.1
10 |
11 | class LayerNorm(nn.Module):
12 | def __init__(self, channels, eps=1e-5):
13 | super().__init__()
14 | self.channels = channels
15 | self.eps = eps
16 |
17 | self.gamma = nn.Parameter(torch.ones(channels))
18 | self.beta = nn.Parameter(torch.zeros(channels))
19 |
20 | def forward(self, x):
21 | x = x.transpose(1, -1)
22 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
23 | return x.transpose(1, -1)
24 |
25 |
26 | class ConvReluNorm(nn.Module):
27 | def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
28 | super().__init__()
29 | self.in_channels = in_channels
30 | self.hidden_channels = hidden_channels
31 | self.out_channels = out_channels
32 | self.kernel_size = kernel_size
33 | self.n_layers = n_layers
34 | self.p_dropout = p_dropout
35 | assert n_layers > 1, "Number of layers should be larger than 1."
36 |
37 | self.conv_layers = nn.ModuleList()
38 | self.norm_layers = nn.ModuleList()
39 | self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
40 | self.norm_layers.append(LayerNorm(hidden_channels))
41 | self.relu_drop = nn.Sequential(
42 | nn.ReLU(),
43 | nn.Dropout(p_dropout))
44 | for _ in range(n_layers - 1):
45 | self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
46 | self.norm_layers.append(LayerNorm(hidden_channels))
47 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
48 | self.proj.weight.data.zero_()
49 | self.proj.bias.data.zero_()
50 |
51 | def forward(self, x, x_mask):
52 | x_org = x
53 | for i in range(self.n_layers):
54 | x = self.conv_layers[i](x * x_mask)
55 | x = self.norm_layers[i](x)
56 | x = self.relu_drop(x)
57 | x = x_org + self.proj(x)
58 | return x * x_mask
59 |
60 |
61 | class DDSConv(nn.Module):
62 | """
63 | Dilated and Depth-Separable Convolution
64 | """
65 |
66 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
67 | super().__init__()
68 | self.channels = channels
69 | self.kernel_size = kernel_size
70 | self.n_layers = n_layers
71 | self.p_dropout = p_dropout
72 |
73 | self.drop = nn.Dropout(p_dropout)
74 | self.convs_sep = nn.ModuleList()
75 | self.convs_1x1 = nn.ModuleList()
76 | self.norms_1 = nn.ModuleList()
77 | self.norms_2 = nn.ModuleList()
78 | for i in range(n_layers):
79 | dilation = kernel_size ** i
80 | padding = (kernel_size * dilation - dilation) // 2
81 | self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
82 | groups=channels, dilation=dilation, padding=padding
83 | ))
84 | self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
85 | self.norms_1.append(LayerNorm(channels))
86 | self.norms_2.append(LayerNorm(channels))
87 |
88 | def forward(self, x, x_mask, g=None):
89 | if g is not None:
90 | x = x + g
91 | for i in range(self.n_layers):
92 | y = self.convs_sep[i](x * x_mask)
93 | y = self.norms_1[i](y)
94 | y = F.gelu(y)
95 | y = self.convs_1x1[i](y)
96 | y = self.norms_2[i](y)
97 | y = F.gelu(y)
98 | y = self.drop(y)
99 | x = x + y
100 | return x * x_mask
101 |
102 |
103 | class WN(torch.nn.Module):
104 | def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0, causal=False):
105 | super(WN, self).__init__()
106 | conv1d_type = SConv1d
107 | assert (kernel_size % 2 == 1)
108 | self.hidden_channels = hidden_channels
109 | self.kernel_size = kernel_size
110 | self.dilation_rate = dilation_rate
111 | self.n_layers = n_layers
112 | self.gin_channels = gin_channels
113 | self.p_dropout = p_dropout
114 |
115 | self.in_layers = torch.nn.ModuleList()
116 | self.res_skip_layers = torch.nn.ModuleList()
117 | self.drop = nn.Dropout(p_dropout)
118 |
119 | if gin_channels != 0:
120 | self.cond_layer = conv1d_type(gin_channels, 2 * hidden_channels * n_layers, 1, norm='weight_norm')
121 |
122 | for i in range(n_layers):
123 | dilation = dilation_rate ** i
124 | padding = int((kernel_size * dilation - dilation) / 2)
125 | in_layer = conv1d_type(hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation,
126 | padding=padding, norm='weight_norm', causal=causal)
127 | self.in_layers.append(in_layer)
128 |
129 | # last one is not necessary
130 | if i < n_layers - 1:
131 | res_skip_channels = 2 * hidden_channels
132 | else:
133 | res_skip_channels = hidden_channels
134 |
135 | res_skip_layer = conv1d_type(hidden_channels, res_skip_channels, 1, norm='weight_norm', causal=causal)
136 | self.res_skip_layers.append(res_skip_layer)
137 |
138 | def forward(self, x, x_mask, g=None, **kwargs):
139 | output = torch.zeros_like(x)
140 | n_channels_tensor = torch.IntTensor([self.hidden_channels])
141 |
142 | if g is not None:
143 | g = self.cond_layer(g)
144 |
145 | for i in range(self.n_layers):
146 | x_in = self.in_layers[i](x)
147 | if g is not None:
148 | cond_offset = i * 2 * self.hidden_channels
149 | g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :]
150 | else:
151 | g_l = torch.zeros_like(x_in)
152 |
153 | acts = commons.fused_add_tanh_sigmoid_multiply(
154 | x_in,
155 | g_l,
156 | n_channels_tensor)
157 | acts = self.drop(acts)
158 |
159 | res_skip_acts = self.res_skip_layers[i](acts)
160 | if i < self.n_layers - 1:
161 | res_acts = res_skip_acts[:, :self.hidden_channels, :]
162 | x = (x + res_acts) * x_mask
163 | output = output + res_skip_acts[:, self.hidden_channels:, :]
164 | else:
165 | output = output + res_skip_acts
166 | return output * x_mask
167 |
168 | def remove_weight_norm(self):
169 | if self.gin_channels != 0:
170 | torch.nn.utils.remove_weight_norm(self.cond_layer)
171 | for l in self.in_layers:
172 | torch.nn.utils.remove_weight_norm(l)
173 | for l in self.res_skip_layers:
174 | torch.nn.utils.remove_weight_norm(l)
--------------------------------------------------------------------------------
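The `WN` module above relies on `commons.fused_add_tanh_sigmoid_multiply` for its gated activation. Below is a minimal sketch of that WaveNet-style gate; the real helper in commons.py may be fused or JIT-scripted, so treat this version as illustrative only.

```python
import torch

def gated_activation(x_in: torch.Tensor, g_l: torch.Tensor, n_channels: int) -> torch.Tensor:
    """tanh/sigmoid gate over a 2*n_channels sum of input and conditioning, as used inside WN."""
    z = x_in + g_l                               # (B, 2*C, T)
    t_act = torch.tanh(z[:, :n_channels, :])     # "filter" half
    s_act = torch.sigmoid(z[:, n_channels:, :])  # "gate" half
    return t_act * s_act                         # (B, C, T)

x_in = torch.randn(1, 512, 50)                   # 2 * hidden_channels = 512
g_l = torch.zeros_like(x_in)                     # no conditioning
print(gated_activation(x_in, g_l, 256).shape)    # torch.Size([1, 256, 50])
```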