├── .DS_Store
├── .gitignore
├── .idea
│   ├── deployment.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── torch_npss.iml
│   ├── vcs.xml
│   └── workspace.xml
├── LICENSE
├── README.md
├── README_CN.md
├── data
│   ├── __init__.py
│   ├── cut_raw.py
│   ├── data_util.py
│   ├── dataset.py
│   ├── gen_wav
│   │   └── 29-test.wav
│   ├── preprocess.py
│   ├── raw
│   │   ├── nitech_jp_song070_f001_029.lab
│   │   └── nitech_jp_song070_f001_029.raw
│   └── timbre_model
│       ├── all_phonetic.npy
│       ├── min_max_record.npy
│       └── test
│           ├── ap
│           │   └── nitech_jp_song070_f001_029_ap.npy
│           ├── condition
│           │   └── nitech_jp_song070_f001_029_condi.npy
│           ├── sp
│           │   └── nitech_jp_song070_f001_029_sp.npy
│           └── vuv
│               └── nitech_jp_song070_f001_029_vuv.npy
├── hparams.py
├── inference.ipynb
├── inference.py
├── learn.ipynb
├── model
│   ├── timbre_training.py
│   ├── util.py
│   └── wavenet_model.py
├── model_logging.py
├── playground.py
├── requirements.txt
├── snapshots
│   ├── aperiodic
│   │   └── aper_1649_2019-09-06_06-20-24
│   ├── harmonic
│   │   └── harm_1649_2019-09-06_07-03-37
│   └── vuv
│       └── vuv_1649_2019-09-06_06-04-09
├── temp.py
├── train_aperoidic.py
├── train_harmonoc.py
├── train_script.py
└── train_vuv.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seaniezhao/torch_npss/c49dcb97e6fc11b5ac026799fcbed14aa0ed34aa/.DS_Store
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | #custom
2 | *.wav
3 | *.raw
4 | *.lab
5 | *.npy
6 | /snapshots
7 | .idea/
8 | *.xml
9 | .DS_Store
10 |
11 |
12 | # Byte-compiled / optimized / DLL files
13 | __pycache__/
14 | *.py[cod]
15 | *$py.class
16 |
17 | # C extensions
18 | *.so
19 |
20 | # Distribution / packaging
21 | .Python
22 | build/
23 | develop-eggs/
24 | dist/
25 | downloads/
26 | eggs/
27 | .eggs/
28 | lib/
29 | lib64/
30 | parts/
31 | sdist/
32 | var/
33 | wheels/
34 | *.egg-info/
35 | .installed.cfg
36 | *.egg
37 | MANIFEST
38 |
39 | # PyInstaller
40 | # Usually these files are written by a python script from a template
41 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
42 | *.manifest
43 | *.spec
44 |
45 | # Installer logs
46 | pip-log.txt
47 | pip-delete-this-directory.txt
48 |
49 | # Unit test / coverage reports
50 | htmlcov/
51 | .tox/
52 | .coverage
53 | .coverage.*
54 | .cache
55 | nosetests.xml
56 | coverage.xml
57 | *.cover
58 | .hypothesis/
59 | .pytest_cache/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 | db.sqlite3
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 |
80 | # PyBuilder
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # celery beat schedule file
90 | celerybeat-schedule
91 |
92 | # SageMath parsed files
93 | *.sage.py
94 |
95 | # Environments
96 | .env
97 | .venv
98 | env/
99 | venv/
100 | ENV/
101 | env.bak/
102 | venv.bak/
103 |
104 | # Spyder project settings
105 | .spyderproject
106 | .spyproject
107 |
108 | # Rope project settings
109 | .ropeproject
110 |
111 | # mkdocs documentation
112 | /site
113 |
114 | # mypy
115 | .mypy_cache/
116 |
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/torch_npss.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 sean zhao
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # torch_npss
2 |
3 | [中文(chinese)](README_CN.md)
4 |
5 | ### A PyTorch implementation of "A Neural Parametric Singing Synthesizer": https://arxiv.org/abs/1704.03809
6 | * pretrained models are provided in snapshots/
7 | * generated samples are in data/gen_wav/
8 |
9 | dataset: https://drive.google.com/file/d/137dTlTiN7jSadV76sRDRwNJ_ysTz1psA/view?usp=sharing
10 |
11 | [sample](https://soundcloud.com/sean-zhao-236492288/29-test)
12 |
17 | ### try it out!
18 | ```
19 | note: test labels are in data/timbre_model/test
20 | pip install -r requirements.txt
21 | python inference.py
22 | ```
23 |
24 | ### try with your own data
25 | ```
26 | put your own raw and label data in data/raw/
27 | change custom_test in data/preprocess.py to True
28 | run data/preprocess.py
29 | run generate_test('your_file_name') in inference.py (see the sketch after this file)
30 | ```
31 |
32 | ### train your own model
33 | - put your audio and label files in data/raw
34 | - run data/preprocess.py
35 | - adjust condition_channel in hparams.py according to your data
36 | - run train_harmonoc.py, train_aperoidic.py and train_vuv.py
37 |
38 | * If you have any questions, feel free to open an issue.
39 |
40 | ## For a complete implementation that includes the F0 and vuv models, see: https://github.com/seaniezhao/cnnpss
41 |
--------------------------------------------------------------------------------
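The README's "try with your own data" steps above boil down to a single call once preprocessing has run. A minimal sketch, assuming a hypothetical file stem my_song (i.e. data/raw/my_song.raw plus my_song.lab, preprocessed with custom_test = True) and the pretrained models under snapshots/:

```python
# a minimal sketch of the README's "try with your own data" flow;
# 'my_song' is a hypothetical stem standing in for your own recording
from inference import generate_test

# expects data/timbre_model/test/{sp,ap,condition,vuv}/my_song_*.npy from
# data/preprocess.py (custom_test = True) and models under snapshots/
generate_test('my_song')   # writes ./data/gen_wav/my_song.wav
```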
/README_CN.md:
--------------------------------------------------------------------------------
1 | # torch_npss
2 | [English](README.md)
3 |
4 | ### 1. Overview
5 |
6 | This project is a partial PyTorch implementation of [Merlijn Blaauw and Jordi Bonada's "A Neural Parametric Singing Synthesizer"](https://arxiv.org/abs/1704.03809/). It synthesizes singing voice from given conditioning features; in one sentence, a deep-learning-based "AI singer".
7 |
8 | ### 2. Audio sample
9 | [Listen](https://soundcloud.com/sean-zhao-236492288/29-test)
10 |
11 |
12 | ### 3. Installing dependencies
13 | ```
14 | pip install -r requirements.txt
15 | ```
16 |
17 | ### 4. Preparing training and test data
18 | Dataset used in the paper: https://drive.google.com/file/d/137dTlTiN7jSadV76sRDRwNJ_ysTz1psA/view?usp=sharing
19 |
20 | Put the audio files and label files under data/raw, then run
21 |
22 | ```
23 | python data/preprocess.py
24 | ```
25 | Remember to adjust condition_channel in hparams.py according to the processed data.
26 | ###### If you only want to test with your own data:
27 | - 1. Put your data under data/raw
28 | - 2. Set custom_test in data/preprocess.py to True
29 | - 3. Run data/preprocess.py
30 | - 4. Change the file name in inference.py to your own file name
31 | - 5. python inference.py
32 |
33 | ### 5. Training the models
34 | ```
35 | python train_harmonoc.py
36 | python train_aperoidic.py
37 | python train_vuv.py
38 | ```
39 |
40 | ### 6. Generation
41 | ```
42 | Note: the labels needed for generation are already in data/timbre_model/test; you can also generate your own data and put it in the corresponding folders under test.
43 | pip install -r requirements.txt
44 | python inference.py
45 | ```
46 |
47 | ### 7. For any usage questions, or to connect and collaborate, add WeChat: seanweichat
48 | ### Finally: if you like this project, please give it a star, thanks!
49 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seaniezhao/torch_npss/c49dcb97e6fc11b5ac026799fcbed14aa0ed34aa/data/__init__.py
--------------------------------------------------------------------------------
/data/cut_raw.py:
--------------------------------------------------------------------------------
1 | import os
2 | import fnmatch
3 | import soundfile as sf
4 |
5 | dist_folder = './cut_raw'
6 |
7 | def find_cut_point(txt_path):
8 | file = open(txt_path, 'r')
9 |
10 | min_interval = 1500000  # 0.15 s in the 100 ns units used by .lab files
11 | phn_timing = []
12 | try:
13 | text_lines = file.readlines()
14 | len_line = len(text_lines)
15 | for i, line in enumerate(text_lines):
16 | line = line.replace('\n', '')
17 | [start, end, phn] = line.split(' ')
18 | start = int(start)
19 | end = int(end)
20 | interval = end - start
21 | if phn == 'None' and interval > min_interval:
22 | if start == 0:
23 | # first pau
24 | phn_timing.append([start, end-min_interval, 'cut'])
25 | phn_timing.append([end - min_interval, end, phn])
26 | elif i == len_line-1:
27 | # last pau
28 | phn_timing.append([start, start + min_interval, phn])
29 | phn_timing.append([start + min_interval, end, 'cut'])
30 | else:
31 | phn_timing.append([start, start + min_interval, phn])
32 | phn_timing.append([start + min_interval, end - min_interval, 'cut'])
33 | phn_timing.append([end - min_interval, end, phn])
34 |
35 | else:
36 | if phn == 'sil' or phn == 'pau':
37 | phn = 'cut'
38 | phn_timing.append([start, end, phn])
39 |
40 | finally:
41 | file.close()
42 |
43 |
44 | return phn_timing
45 |
46 | def cut_txt(phn_timing, fname, dist_folder):
47 |
48 | txts = []
49 | time_pairs = []
50 | time_pair = [None, None]
51 |
52 | for i, line in enumerate(phn_timing):
53 | [start, end, phn] = line
54 | if phn == 'cut':
55 | if None not in time_pair:
56 | pass
57 | else:
58 | if time_pair[0] == None:
59 | pass
60 | else:
61 | time_pair[1] = start
62 | if time_pair[0] != time_pair[1]:
63 | time_pairs.append(time_pair)
64 | time_pair = [None, None]
65 | time_pair[0] = end
66 | txts.append([])
67 | else:
68 | txts[-1].append(line)
69 |
70 | clean_txt = []
71 | for i in txts:
72 | if len(i)>0:
73 | clean_txt.append(i)
74 |
75 | # drop leftover segments that contain only a redundant pau
76 | clean_txt_trim=[]
77 | time_pair_trim=[]
78 | assert len(clean_txt) == len(time_pairs)
79 | for i in range(len(clean_txt)):
80 | if len(clean_txt[i]) == 1 and clean_txt[i][0][-1] == 'pau':
81 | continue
82 | clean_txt_trim.append(clean_txt[i])
83 | time_pair_trim.append(time_pairs[i])
84 |
85 |
86 | for i, item in enumerate(clean_txt_trim):
87 | new_path = os.path.join(dist_folder, fname+'_'+str(i)+'.lab')  # use the fname parameter, not the global file_name
88 | file = open(new_path, 'w')
89 | start_num = 0
90 | for j, line in enumerate(item):  # j avoids shadowing the outer index i
91 | [start, end, phn] = line
92 | if j == 0:
93 | start_num = start
94 | start -= start_num
95 | end -= start_num
96 | line = str(start)+' '+str(end)+' '+phn
97 | file.write(line)
98 | file.write("\n")
99 | file.close()
100 |
101 |
102 | return time_pair_trim
103 |
104 |
105 | def cut_wav(wav_path, time_pair, fname, dist_folder):
106 |
107 | [start, stop] = time_pair
108 | start = int(int(start)*48000/10000000)
109 | stop = int(int(stop)*48000/10000000)
110 |
111 | y, osr = sf.read(wav_path, subtype='PCM_16', channels=1, samplerate=48000,
112 | endian='LITTLE', start=start, stop=stop)
113 |
114 | new_path = os.path.join(dist_folder,fname+'.raw')
115 | sf.write(new_path, y, subtype='PCM_16', samplerate=48000, endian='LITTLE')
116 |
117 | if __name__ == '__main__':
118 |
119 | if not os.path.exists(dist_folder):
120 | os.mkdir(dist_folder)
121 |
122 | raw_folder = './raw'
123 |
124 | supportedExtensions = '*.raw'
125 | for dirpath, dirs, files in os.walk(raw_folder):
126 | for file in fnmatch.filter(files, supportedExtensions):
127 |
128 | file_name = file.replace('.raw','')
129 | raw_path = os.path.join(dirpath, file)
130 | txt_path = raw_path.replace('.raw', '.lab')
131 |
132 | phn_timing = find_cut_point(txt_path)
133 | time_pairs = cut_txt(phn_timing, file_name, dist_folder)
134 |
135 | for i, tp in enumerate(time_pairs):
136 | fname = file_name+'_'+str(i)
137 | cut_wav(raw_path, tp, fname, dist_folder)
138 |
--------------------------------------------------------------------------------
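cut_raw.py slices long recordings at extended pauses; its one non-obvious step is the time-base conversion in cut_wav between HTS-style .lab timestamps (100 ns units) and 48 kHz sample indices. A worked example:

```python
# a worked example of cut_raw.py's time-base conversion: .lab timestamps are
# in 100 ns units, so sample_index = t * 48000 / 10_000_000 at 48 kHz
t_100ns = 24100000                            # first 'pau' boundary in nitech_jp_song070_f001_029.lab
seconds = t_100ns / 10_000_000                # 2.41 s
sample = int(t_100ns * 48000 / 10_000_000)    # 115680 samples
print(seconds, sample)                        # 2.41 115680
```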
/data/data_util.py:
--------------------------------------------------------------------------------
1 | import scipy.fftpack as fftpack
2 | import librosa
3 | import pyworld as pw
4 | import numpy as np
5 | import os
6 | import soundfile as sf
7 | import fnmatch
8 | import matplotlib.pyplot as plt
9 | import pysptk
10 | from librosa.display import specshow
11 | import copy
12 |
13 |
14 | gamma = 0
15 | mcepInput = 3 # 0 for dB, 3 for magnitude
16 | alpha = 0.45
17 | en_floor = 10 ** (-80 / 20)
18 |
19 |
20 | def code_harmonic(sp, order):
21 |
22 | #get mcep
23 | mceps = np.apply_along_axis(pysptk.mcep, 1, sp, order - 1, alpha, itype=mcepInput, threshold=en_floor)
24 |
25 | #do fft and take real
26 | scale_mceps = copy.copy(mceps)
27 | scale_mceps[:, 0] *= 2
28 | scale_mceps[:, -1] *= 2
29 | mirror = np.hstack([scale_mceps[:, :-1], scale_mceps[:, -1:0:-1]])
30 | mfsc = np.fft.rfft(mirror).real
31 |
32 | return mfsc
33 |
34 |
35 | def decode_harmonic(mfsc, fftlen):
36 | # get mcep back
37 | mceps_mirror = np.fft.irfft(mfsc)
38 | mceps_back = mceps_mirror[:, :60]  # keep the first 60 coefficients (the mcep order used in code_harmonic)
39 | mceps_back[:, 0] /= 2
40 | mceps_back[:, -1] /= 2
41 |
42 | #get sp
43 | spSm = np.exp(np.apply_along_axis(pysptk.mgc2sp, 1, mceps_back, alpha, gamma, fftlen=fftlen).real)
44 |
45 | return spSm
46 |
47 |
48 | if __name__ == '__main__':
49 | y, osr = sf.read('cut_raw/nitech_jp_song070_f001_040_1.raw', subtype='PCM_16', channels=1, samplerate=48000,
50 | endian='LITTLE') # , start=56640, stop=262560)
51 |
52 | sr = 32000
53 | if osr != sr:
54 | y = librosa.resample(y, osr, sr)
55 |
56 | # D = np.abs(librosa.stft(y, hop_length=160)) ** 2
57 | # #D_db = librosa.power_to_db(D, ref=np.max)
58 | # S = librosa.feature.melspectrogram(S=D)
59 | # ptd_S = librosa.power_to_db(S)
60 | # mfcc = librosa.feature.mfcc(S=ptd_S, n_mfcc=60)
61 | #
62 | #
63 | # estimate the fundamental frequency F0 with the DIO algorithm
64 | _f0, t = pw.dio(y, sr, f0_floor=50.0, f0_ceil=800.0, channels_in_octave=2, frame_period=pw.default_frame_period)
65 | print(_f0.shape)
66 |
67 | # estimate the spectral envelope with the CheapTrick algorithm
68 | sp = pw.cheaptrick(y, _f0, t, sr)
69 |
70 | _ap = pw.d4c(y, _f0, t, sr)
71 | # ptd_S = librosa.power_to_db(np.transpose(_sp))
72 | # tran_ptd_S = (ptd_S - 0.45)/(1 - 0.45*ptd_S)
73 | # mfcc = librosa.feature.mfcc(S=tran_ptd_S, n_mfcc=60)
74 | #
75 | # _sp_min = np.min(mfcc)
76 | # _sp_max = np.max(mfcc)
77 | # mfcc = (mfcc - _sp_min)/(_sp_max - _sp_min)
78 | #
79 | # code_sp = pw.code_spectral_envelope(_sp, sr, 60)
80 | # t_code_sp = np.transpose(code_sp)
81 | #
82 | # _sp_min = np.min(t_code_sp)
83 | # _sp_max = np.max(t_code_sp)
84 | # t_code_sp = (t_code_sp - _sp_min) / (_sp_max - _sp_min)
85 | #
86 | # plt.imshow(mfcc, aspect='auto', origin='bottom', interpolation='none')
87 | # plt.show()
88 | # plt.imshow(t_code_sp, aspect='auto', origin='bottom', interpolation='none')
89 | # plt.show()
90 | #
91 | # decode_sp = pw.decode_spectral_envelope(code_sp, 32000, 2048)
92 | # x = code_harmonic(_sp)
93 | order = 60
94 | gamma = 0
95 | mcepInput = 3 # 0 for dB, 3 for magnitude
96 | alpha = 0.35
97 | fftlen = (sp.shape[1] - 1) * 2
98 | en_floor = 10 ** (-80 / 20)
99 |
100 | # Reduction and Interpolation
101 | mceps = np.apply_along_axis(pysptk.mcep, 1, sp, order - 1, alpha, itype=mcepInput, threshold=en_floor)
102 |
103 | # scale_mceps = copy.copy(mceps)
104 | # scale_mceps[:, 0] *=2
105 | # scale_mceps[:, -1] *=2
106 | # mirror = np.hstack([scale_mceps[:, :-1], scale_mceps[:, -1:0:-1]])
107 | # mfsc = np.fft.rfft2(mirror).real
108 | mfsc = fftpack.dct(mceps, norm='ortho')
109 |
110 | specshow(mfsc.T, sr=sr, hop_length=80, x_axis='time')
111 | plt.colorbar()
112 | plt.title('mfsc')
113 | plt.tight_layout()
114 | plt.show()
115 |
116 | # itest = np.fft.ifft2(mfsc).real
117 | itest = fftpack.idct(mfsc, norm='ortho')
118 |
119 | specshow(itest.T, sr=sr, hop_length=80, x_axis='time')
120 | plt.colorbar()
121 | plt.title('itest')
122 | plt.tight_layout()
123 | plt.show()
124 |
125 | spSm = np.exp(np.apply_along_axis(pysptk.mgc2sp, 1, itest, alpha, gamma, fftlen=fftlen).real)
126 |
127 | specshow(10 * np.log10(sp.T), sr=sr, hop_length=80, x_axis='time')
128 | plt.colorbar()
129 | plt.title('Original envelope spectrogram')
130 | plt.tight_layout()
131 | plt.show()
132 |
133 | specshow(10 * np.log10(spSm.T), sr=sr, hop_length=80, x_axis='time')
134 | plt.colorbar()
135 | plt.title('Smooth envelope spectrogram')
136 | plt.tight_layout()
137 | plt.show()
138 |
139 | synthesized = pw.synthesize(_f0, spSm, _ap, 32000, pw.default_frame_period)
140 | # write the resynthesized audio
141 | sf.write('gen_wav/dct.wav', synthesized, 32000)
142 |
143 | # mgc = pysptk.mcep(np.sqrt(fft), 59, 0.35, itype=3)
144 | # mfsc = np.exp(pysptk.mgc2sp(mgc, 0.35, fftlen=2048).real)
145 | # pysptk.mgc2sp
146 | # pass
147 |
--------------------------------------------------------------------------------
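code_harmonic and decode_harmonic above compress a WORLD spectral envelope to 60 coefficients and back. A minimal round-trip sketch; the random positive matrix is only a stand-in for pw.cheaptrick output, so just the shapes are meaningful:

```python
# a minimal round-trip sketch for code_harmonic/decode_harmonic; the random
# positive "envelope" is a stand-in for pw.cheaptrick output
import numpy as np
from data.data_util import code_harmonic, decode_harmonic

fft_size = 2048
sp = np.abs(np.random.randn(100, fft_size // 2 + 1)) + 1e-6   # (frames, 1025)
mfsc = code_harmonic(sp, 60)                                  # (100, 60) compressed envelope
sp_back = decode_harmonic(mfsc, fft_size)                     # (100, 1025) smoothed envelope
print(mfsc.shape, sp_back.shape)
```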
/data/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import math
4 | import threading
5 | import torch
6 | import torch.utils.data
7 | import numpy as np
8 | #from torch.distributions.normal import Normal
9 | import librosa as lr
10 | import bisect
11 |
12 |
13 | class NpssDataset(torch.utils.data.Dataset):
14 | def __init__(self,
15 | dataset_file,
16 | condition_file,
17 | receptive_field,
18 | target_length,
19 | train=True):
20 |
21 | # |----receptive_field----|
22 | # example: | | | | | | | | | | | | | | | | | | | | |
23 | # target: | | | | | | | | |
24 | self.dataset_file = dataset_file
25 | self._receptive_field = receptive_field
26 | self.target_length = target_length
27 | self.item_length = self._receptive_field+self.target_length
28 |
29 | self.data = np.load(self.dataset_file)
30 |
31 |
32 | self.condition = np.load(condition_file).astype(np.float64)
33 |
34 |
35 | self._length = 0
36 | self.calculate_length()
37 | self.train = train
38 |
39 |
40 |
41 | def calculate_length(self):
42 |
43 | available_length = self.data.shape[0] - self._receptive_field
44 | self._length = math.floor(available_length / self.target_length)
45 |
46 |
47 | def __getitem__(self, idx):
48 |
49 | sample_index = idx * self.target_length
50 |
51 | sample = self.data[sample_index:sample_index+self.item_length, :]
52 |
53 | item_condition = np.transpose(self.condition[sample_index+self.item_length-1:sample_index+self.item_length, :])
54 |
55 | example = torch.from_numpy(sample)
56 |
57 | item = example[:self._receptive_field].transpose(0, 1)
58 | target = example[-self.target_length:].transpose(0, 1)
59 | return (item, item_condition), target
60 |
61 | def __len__(self):
62 |
63 | return self._length
64 |
65 |
66 | class TimbreDataset(torch.utils.data.Dataset):
67 | # type 0:harmonic, 1:aperiodic, 2:vuv
68 | def __init__(self,
69 | data_folder,
70 | receptive_field,
71 | type = 0,
72 | target_length=210,
73 | train=True):
74 |
75 | # |----receptive_field----|
76 | # example: | | | | | | | | | | | | | | | | | | | | |
77 | # target: | | | | | | | | |
78 | self.type = type
79 | self._receptive_field = receptive_field
80 | self.target_length = target_length
81 | # offset by one frame; the rest of the padding happens inside the model
82 | self.item_length = 1+self.target_length
83 |
84 | if train:
85 | data_folder = os.path.join(data_folder, 'train')
86 | else:
87 | data_folder = os.path.join(data_folder, 'test')
88 |
89 | sp_folder = os.path.join(data_folder, 'sp')
90 | ap_folder = os.path.join(data_folder, 'ap')
91 | condi_folder = os.path.join(data_folder, 'condition')
92 | vuv_folder = os.path.join(data_folder, 'vuv')
93 |
94 | # store every data length
95 | self.data_lengths = []
96 | self.dataset_files = []
97 | dirlist = os.listdir(sp_folder)
98 | for item in dirlist:
99 | name = item.replace('_sp.npy','')
100 |
101 | sp = np.load(os.path.join(sp_folder, item))
102 | ap = np.load(os.path.join(ap_folder, name+'_ap.npy'))
103 | vuv = np.load(os.path.join(vuv_folder, name+'_vuv.npy')).astype(np.uint8)
104 | condition = np.load(os.path.join(condi_folder, name+'_condi.npy')).astype(np.float64)
105 |
106 | assert len(sp) == len(ap) == len(vuv) == len(condition)
107 |
108 | self.data_lengths.append(math.ceil(len(sp)/target_length))
109 |
110 | # pad one zero frame in front of each sequence (the model itself pads up to the receptive field)
111 | sp = np.pad(sp, ((1, 0), (0, 0)), 'constant', constant_values=0)
112 | ap = np.pad(ap, ((1, 0), (0, 0)), 'constant', constant_values=0)
113 | vuv = np.pad(vuv, (1, 0), 'constant', constant_values=0)
114 |
115 | self.dataset_files.append((sp, ap, vuv, condition))
116 | # for test
117 | # break
118 |
119 | self._length = 0
120 | self.calculate_length()
121 | self.train = train
122 |
123 |
124 |
125 | def calculate_length(self):
126 |
127 | self._length = 0
128 | for _len in self.data_lengths:
129 | self._length += _len
130 |
131 |
132 | def __getitem__(self, idx):
133 |
134 | # find which file this index falls into
135 | current_files = None
136 | current_files_idx = 0
137 | total_len = 0
138 | for fid, _len in enumerate(self.data_lengths):
139 | current_files_idx = idx - total_len
140 | total_len += _len
141 | if idx < total_len:
142 | current_files = self.dataset_files[fid]
143 | break
144 |
145 | sp, ap, vuv, condition = current_files
146 | target_index = current_files_idx*self.target_length
147 | short_sample = self.target_length - (len(sp) - 1 - target_index)
148 | if short_sample > 0:
149 | target_index -= short_sample
150 | item_condition = torch.Tensor(condition[target_index:target_index+self.target_length, :]).transpose(0, 1)
151 |
152 | # note one frame was padded in front, so item and target are offset by one
153 | sp_sample = torch.Tensor(sp[target_index:target_index+self.item_length, :]).transpose(0, 1)
154 | sp_item = sp_sample[:, :self.target_length]
155 | sp_target = sp_sample[:, -self.target_length:]
156 |
157 | ap_sample = torch.Tensor(ap[target_index:target_index + self.item_length, :]).transpose(0, 1)
158 | ap_item = ap_sample[:, :self.target_length]
159 | ap_item = torch.cat((ap_item, sp_item), 0)
160 | ap_target = ap_sample[:, -self.target_length:]
161 |
162 | vuv_sample = torch.Tensor(vuv[target_index:target_index + self.item_length])
163 | vuv_item = vuv_sample[:self.target_length]
164 | # note ap_item is already (ap, sp) concatenated, so sp is not concatenated again
165 | vuv_item = torch.cat((vuv_item.unsqueeze(0), ap_item), 0)
166 | vuv_target = vuv_sample[-self.target_length:]
167 |
168 | if self.type == 0:
169 | return (sp_item, item_condition), sp_target
170 | elif self.type == 1:
171 | return (ap_item, item_condition), ap_target
172 | else:
173 | return (vuv_item, item_condition), vuv_target
174 |
175 |
176 | def __len__(self):
177 |
178 | return self._length
--------------------------------------------------------------------------------
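TimbreDataset yields ((input, condition), target) windows of 210 frames, so it plugs straight into a standard DataLoader. A minimal sketch, assuming data/preprocess.py has populated data/timbre_model/train; receptive_field is stored but unused for windowing (the dataset pads one frame and leaves the rest to the model), so the value below is a placeholder:

```python
# a minimal sketch of iterating TimbreDataset (type 0 = harmonic)
import torch
from data.dataset import TimbreDataset

dataset = TimbreDataset('data/timbre_model', receptive_field=9,  # placeholder value
                        type=0, train=True)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

(x, condition), target = next(iter(loader))
# for the bundled data: x (32, 60, 210), condition (32, 364, 210), target (32, 60, 210)
print(x.shape, condition.shape, target.shape)
```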
/data/gen_wav/29-test.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seaniezhao/torch_npss/c49dcb97e6fc11b5ac026799fcbed14aa0ed34aa/data/gen_wav/29-test.wav
--------------------------------------------------------------------------------
/data/preprocess.py:
--------------------------------------------------------------------------------
1 | import librosa
2 | import pyworld as pw
3 | import numpy as np
4 | import os
5 | import soundfile as sf
6 | import fnmatch
7 | # when running this script directly, make the data package importable without editing the import path every time
8 | if __name__ == "__main__":
9 | import sys
10 | data_libary_dir = os.path.join(os.path.dirname(__file__), "../")
11 | sys.path.append(data_libary_dir)
12 |
13 | from data.data_util import code_harmonic, decode_harmonic
14 | import sys
15 |
16 | sp_min, sp_max = sys.maxsize, (-sys.maxsize - 1)
17 | ap_min, ap_max = sys.maxsize, (-sys.maxsize - 1)
18 |
19 |
20 | custom_test = True
21 |
22 | rt_folder = os.path.join(os.path.dirname(__file__), 'timbre_model')
23 | tr_folder = os.path.join(os.path.dirname(__file__), 'timbre_model/train')
24 |
25 | sp_folder = os.path.join(os.path.dirname(__file__), 'timbre_model/train/sp/')
26 | ap_folder = os.path.join(os.path.dirname(__file__), 'timbre_model/train/ap/')
27 | vuv_folder = os.path.join(os.path.dirname(__file__), 'timbre_model/train/vuv/')
28 | condition_folder = os.path.join(os.path.dirname(__file__), 'timbre_model/train/condition/')
29 |
30 | te_folder = os.path.join(os.path.dirname(__file__), 'timbre_model/test')
31 |
32 | te_sp_folder = os.path.join(os.path.dirname(__file__), 'timbre_model/test/sp/')
33 | te_ap_folder = os.path.join(os.path.dirname(__file__), 'timbre_model/test/ap/')
34 | te_vuv_folder = os.path.join(os.path.dirname(__file__), 'timbre_model/test/vuv/')
35 | te_condition_folder = os.path.join(os.path.dirname(__file__), 'timbre_model/test/condition/')
36 |
37 | f0_bin = 256
38 | f0_max = 1100.0
39 | f0_min = 50.0
40 | # extract the acoustic features (F0, spectral envelope, aperiodicity) from a wav file; they are saved in npy format later
41 | def process_wav(wav_path):
42 | y, osr = sf.read(wav_path, subtype='PCM_16', channels=1, samplerate=48000,
43 | endian='LITTLE') #, start=56640, stop=262560)
44 |
45 | sr = 32000
46 | if osr != sr:
47 | y = librosa.resample(y, osr, sr)
48 |
49 | # estimate the fundamental frequency F0 with the Harvest algorithm
50 | _f0, t = pw.harvest(y, sr, f0_floor=f0_min, f0_ceil=f0_max, frame_period=pw.default_frame_period)
51 | _f0 = pw.stonemask(y, _f0, t, sr)
52 | print(_f0.shape)
53 |
54 | # estimate the spectral envelope with the CheapTrick algorithm
55 | _sp = pw.cheaptrick(y, _f0, t, sr)
56 |
57 | code_sp = code_harmonic(_sp, 60)
58 | print(_sp.shape, code_sp.shape)
59 | # estimate the aperiodicity parameters
60 | _ap = pw.d4c(y, _f0, t, sr)
61 |
62 | code_ap = pw.code_aperiodicity(_ap, sr)
63 | print(_ap.shape, code_ap.shape)
64 |
65 | return _f0, _sp, code_sp, _ap, code_ap
66 |
67 |
68 | def process_phon_label(label_path):
69 | file = open(label_path, 'r')
70 |
71 | time_phon_list = []
72 | phon_list = []
73 | try:
74 | text_lines = file.readlines()
75 | print(type(text_lines), text_lines)
76 | for line in text_lines:
77 | line = line.replace('\n', '')
78 | l_c = line.split(' ')
79 | phn = l_c[2]
80 | tup = (float(l_c[0])*200/10000000, float(l_c[1])*200/10000000, phn)
81 | print(tup)
82 | time_phon_list.append(tup)
83 | if phn not in phon_list:
84 | phon_list.append(phn)
85 | finally:
86 | file.close()
87 |
88 | return time_phon_list, phon_list
89 |
90 |
91 | def process_timbre_model_condition(time_phon_list, all_phon, f0):
92 |
93 | # process f0
94 | # mappling to mel scale
95 |
96 | f0_mel = 1127 * np.log(1 + f0 / 700)
97 | f0_mel_min = 1127 * np.log(1 + f0_min / 700)
98 | f0_mel_max = 1127 * np.log(1 + f0_max / 700)
99 | # f0_mel[f0_mel == 0] = 0
100 | # voiced values (> 0) are quantized into 255 bins; 0 is reserved for unvoiced
101 | f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1
102 |
103 | f0_mel[f0_mel < 0] = 1
104 | f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
105 | f0_coarse = np.rint(f0_mel).astype(int)
106 | print('Max f0', np.max(f0_coarse), ' ||Min f0', np.min(f0_coarse))
107 | assert (np.max(f0_coarse) <= 256 and np.min(f0_coarse) >= 0)
108 |
109 | label_list = []
110 | oh_list = []
111 | for i in range(len(f0)):
112 | pre_phn, cur_phn, next_phn, pos_in_phon = (0, 0, 0, 0)
113 | for j in range(len(time_phon_list)):
114 | if time_phon_list[j][0] <= i <= time_phon_list[j][1]:
115 | cur_phn = all_phon.index(time_phon_list[j][2])
116 | if j == 0:
117 | pre_phn = all_phon.index('none')
118 | else:
119 | pre_phn = all_phon.index(time_phon_list[j - 1][2])
120 |
121 | if j == len(time_phon_list) - 1:
122 | next_phn = all_phon.index('none')
123 | else:
124 | next_phn = all_phon.index(time_phon_list[j + 1][2])
125 |
126 | begin = time_phon_list[j][0]
127 | end = time_phon_list[j][1]
128 | width = end - begin
129 |
130 | # normal speech rate is roughly 200 characters per minute
131 | if width < 150:
132 | fpos = width / 3
133 | spos = 2 * width / 3
134 | else:
135 | fpos = 50
136 | spos = width-50
137 |
138 | if i - begin < fpos:
139 | pos_in_phon = 0
140 | elif fpos <= i - begin < spos:
141 | pos_in_phon = 1
142 | else:
143 | pos_in_phon = 2
144 |
145 | label_list.append([pre_phn, cur_phn, next_phn, pos_in_phon, f0_coarse[i]])
146 |
147 | # onehot
148 | pre_phn_oh = np.zeros(len(all_phon))
149 | cur_phn_oh = np.zeros(len(all_phon))
150 | next_phn_oh = np.zeros(len(all_phon))
151 | pos_in_phon_oh = np.zeros(3)
152 | f0_coarse_oh = np.zeros(f0_bin)
153 |
154 | pre_phn_oh[pre_phn] = 1
155 | cur_phn_oh[cur_phn] = 1
156 | next_phn_oh[next_phn] = 1
157 | pos_in_phon_oh[pos_in_phon] = 1
158 | f0_coarse_oh[f0_coarse[i]] = 1
159 |
160 | oh_list.append(
161 | np.concatenate((pre_phn_oh, cur_phn_oh, next_phn_oh, pos_in_phon_oh, f0_coarse_oh)).astype(np.int8))
162 | if i == len(f0) - 1:
163 | print(len(oh_list[-1]), np.sum(oh_list[-1]))
164 |
165 | return oh_list
166 |
167 | if __name__ == '__main__':
168 | if not os.path.exists(rt_folder):
169 | os.mkdir(rt_folder)
170 | # train folders
171 | if not os.path.exists(tr_folder):
172 | os.mkdir(tr_folder)
173 |
174 | if not os.path.exists(sp_folder):
175 | os.mkdir(sp_folder)
176 | if not os.path.exists(ap_folder):
177 | os.mkdir(ap_folder)
178 | if not os.path.exists(vuv_folder):
179 | os.mkdir(vuv_folder)
180 | if not os.path.exists(condition_folder):
181 | os.mkdir(condition_folder)
182 | # test folders
183 | if not os.path.exists(te_folder):
184 | os.mkdir(te_folder)
185 |
186 | if not os.path.exists(te_sp_folder):
187 | os.mkdir(te_sp_folder)
188 | if not os.path.exists(te_ap_folder):
189 | os.mkdir(te_ap_folder)
190 | if not os.path.exists(te_vuv_folder):
191 | os.mkdir(te_vuv_folder)
192 | if not os.path.exists(te_condition_folder):
193 | os.mkdir(te_condition_folder)
194 |
195 | test_names = ['nitech_jp_song070_f001_015', 'nitech_jp_song070_f001_029', 'nitech_jp_song070_f001_040']
196 |
197 | raw_folder = os.path.join(os.path.dirname(__file__), './raw')
198 |
199 | all_phon = ['none']
200 | data_to_save = []
201 |
202 | supportedExtensions = '*.raw'
203 | for dirpath, dirs, files in os.walk(raw_folder):
204 | for file in fnmatch.filter(files, supportedExtensions):
205 | file_name = file.replace('.raw','')
206 | raw_path = os.path.join(dirpath, file)
207 | txt_path = raw_path.replace('.raw', '.lab')
208 | f0, _sp, code_sp, _ap, code_ap = process_wav(raw_path)
209 | v_uv = f0 > 0
210 |
211 | time_phon_list, phon_list = process_phon_label(txt_path)
212 | for item in phon_list:
213 | if item not in all_phon:
214 | all_phon.append(item)
215 |
216 | data_to_save.append((file_name, time_phon_list, f0, code_sp, code_ap, v_uv))
217 |
218 | _sp_min = np.min(code_sp)
219 | _sp_max = np.max(code_sp)
220 | if _sp_min < sp_min:
221 | sp_min = _sp_min
222 | if _sp_max > sp_max:
223 | sp_max = _sp_max
224 |
225 | _ap_min = np.min(code_ap)
226 | _ap_max = np.max(code_ap)
227 | if _ap_min < ap_min:
228 | ap_min = _ap_min
229 | if _ap_max > ap_max:
230 | ap_max = _ap_max
231 |
232 | if custom_test:
233 | all_phon = list(np.load(os.path.join(os.path.dirname(__file__), 'timbre_model/all_phonetic.npy')))
234 | else:
235 | np.save(os.path.join(os.path.dirname(__file__), 'timbre_model/min_max_record.npy'), [sp_min, sp_max, ap_min, ap_max])
236 | np.save(os.path.join(os.path.dirname(__file__), 'timbre_model/all_phonetic.npy'), all_phon)
237 |
238 |
239 | for file_name, time_phon_list, f0, code_sp, code_ap, v_uv in data_to_save:
240 | oh_list = process_timbre_model_condition(time_phon_list, all_phon, f0)
241 |
242 | code_sp = (code_sp - sp_min) / (sp_max - sp_min) - 0.5
243 | code_ap = (code_ap - ap_min) / (ap_max - ap_min) - 0.5
244 |
245 | test = False or custom_test
246 | for n in test_names:
247 | if n in file_name:
248 | test = True
249 | if test:
250 | np.save(te_condition_folder + file_name + '_condi.npy', oh_list)
251 | np.save(te_sp_folder + file_name + '_sp.npy', code_sp)
252 | np.save(te_ap_folder + file_name + '_ap.npy', code_ap)
253 | np.save(te_vuv_folder + file_name + '_vuv.npy', v_uv)
254 | else:
255 | np.save(condition_folder + file_name + '_condi.npy', oh_list)
256 | np.save(sp_folder + file_name + '_sp.npy', code_sp)
257 | np.save(ap_folder + file_name + '_ap.npy', code_ap)
258 | np.save(vuv_folder + file_name + '_vuv.npy', v_uv)
259 |
260 | # np.save('prepared_data/f0.npy', f0)
261 |
262 |
263 |
--------------------------------------------------------------------------------
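The coarse-F0 coding in process_timbre_model_condition above maps Hz to the mel scale and then into 255 voiced bins, keeping bin 0 for unvoiced frames. A standalone worked version:

```python
# a standalone worked version of preprocess.py's coarse-F0 coding
import numpy as np

f0_bin, f0_min, f0_max = 256, 50.0, 1100.0
f0 = np.array([0.0, 100.0, 440.0, 1000.0])       # 0 Hz marks an unvoiced frame

f0_mel = 1127 * np.log(1 + f0 / 700)             # Hz -> mel
mel_min = 1127 * np.log(1 + f0_min / 700)
mel_max = 1127 * np.log(1 + f0_max / 700)

# voiced frames go to bins 1..255; unvoiced frames stay at 0
f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - mel_min) * (f0_bin - 2) / (mel_max - mel_min) + 1
f0_mel[f0_mel < 0] = 1
f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
print(np.rint(f0_mel).astype(int))               # approximately [0, 20, 122, 238]
```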
/data/raw/nitech_jp_song070_f001_029.lab:
--------------------------------------------------------------------------------
1 | 0 24100000 pau
2 | 24100000 25000000 w
3 | 25000000 34300000 a
4 | 34300000 35650000 a
5 | 35650000 36700000 r
6 | 36700000 41700000 e
7 | 41700000 42650000 w
8 | 42650000 47600000 a
9 | 47600000 56450000 u
10 | 56450000 57100000 m
11 | 57100000 60000000 i
12 | 60000000 60450000 n
13 | 60450000 65550000 o
14 | 65550000 66350000 k
15 | 66350000 71100000 o
16 | 71100000 72400000 sh
17 | 72400000 80150000 i
18 | 80150000 80450000 r
19 | 80450000 83350000 a
20 | 83350000 84000000 n
21 | 84000000 88700000 a
22 | 88700000 89450000 m
23 | 89450000 95550000 i
24 | 95550000 96200000 n
25 | 96200000 112750000 o
26 | 112750000 119000000 pau
27 | 119000000 120300000 s
28 | 120300000 130500000 a
29 | 130500000 131150000 a
30 | 131150000 132350000 w
31 | 132350000 137800000 a
32 | 137800000 138350000 g
33 | 138350000 144100000 u
34 | 144100000 152500000 i
35 | 152500000 153400000 s
36 | 153400000 155900000 o
37 | 155900000 156450000 b
38 | 156450000 161600000 e
39 | 161600000 162450000 n
40 | 162450000 165650000 o
41 | 165650000 167750000 br
42 | 167750000 168250000 m
43 | 168250000 173300000 a
44 | 173300000 174550000 ts
45 | 174550000 179550000 u
46 | 179550000 180150000 b
47 | 180150000 188200000 a
48 | 188200000 189250000 r
49 | 189250000 192000000 a
50 | 192000000 192900000 n
51 | 192900000 208500000 i
52 | 208500000 214900000 pau
53 | 214900000 216050000 k
54 | 216050000 227650000 e
55 | 227650000 228300000 m
56 | 228300000 233900000 u
57 | 233900000 239450000 i
58 | 239450000 240150000 t
59 | 240150000 248500000 a
60 | 248500000 249200000 n
61 | 249200000 251850000 a
62 | 251850000 252450000 b
63 | 252450000 257550000 i
64 | 257550000 258400000 k
65 | 258400000 261200000 u
66 | 261200000 263150000 br
67 | 263150000 263900000 t
68 | 263900000 269100000 o
69 | 269100000 269800000 m
70 | 269800000 276050000 a
71 | 276050000 276750000 y
72 | 276750000 284250000 a
73 | 284250000 284950000 k
74 | 284950000 287900000 o
75 | 287900000 288700000 s
76 | 288700000 305350000 o
77 | 305350000 311200000 pau
78 | 311200000 312150000 w
79 | 312150000 319950000 a
80 | 319950000 320550000 g
81 | 320550000 323450000 a
82 | 323450000 324250000 n
83 | 324250000 329100000 a
84 | 329100000 330100000 ts
85 | 330100000 335750000 u
86 | 335750000 336600000 k
87 | 336600000 347200000 a
88 | 347200000 348300000 sh
89 | 348300000 353500000 i
90 | 353500000 354250000 k
91 | 354250000 357450000 i
92 | 357450000 359450000 br
93 | 359450000 360700000 s
94 | 360700000 365550000 u
95 | 365550000 365900000 m
96 | 365900000 371500000 i
97 | 371500000 372300000 k
98 | 372300000 380300000 a
99 | 380300000 381250000 n
100 | 381250000 384300000 a
101 | 384300000 384950000 r
102 | 384950000 404400000 e
103 | 404400000 432000000 pau
104 |
--------------------------------------------------------------------------------
/data/raw/nitech_jp_song070_f001_029.raw:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seaniezhao/torch_npss/c49dcb97e6fc11b5ac026799fcbed14aa0ed34aa/data/raw/nitech_jp_song070_f001_029.raw
--------------------------------------------------------------------------------
/data/timbre_model/all_phonetic.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seaniezhao/torch_npss/c49dcb97e6fc11b5ac026799fcbed14aa0ed34aa/data/timbre_model/all_phonetic.npy
--------------------------------------------------------------------------------
/data/timbre_model/min_max_record.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seaniezhao/torch_npss/c49dcb97e6fc11b5ac026799fcbed14aa0ed34aa/data/timbre_model/min_max_record.npy
--------------------------------------------------------------------------------
/data/timbre_model/test/ap/nitech_jp_song070_f001_029_ap.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seaniezhao/torch_npss/c49dcb97e6fc11b5ac026799fcbed14aa0ed34aa/data/timbre_model/test/ap/nitech_jp_song070_f001_029_ap.npy
--------------------------------------------------------------------------------
/data/timbre_model/test/condition/nitech_jp_song070_f001_029_condi.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seaniezhao/torch_npss/c49dcb97e6fc11b5ac026799fcbed14aa0ed34aa/data/timbre_model/test/condition/nitech_jp_song070_f001_029_condi.npy
--------------------------------------------------------------------------------
/data/timbre_model/test/sp/nitech_jp_song070_f001_029_sp.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seaniezhao/torch_npss/c49dcb97e6fc11b5ac026799fcbed14aa0ed34aa/data/timbre_model/test/sp/nitech_jp_song070_f001_029_sp.npy
--------------------------------------------------------------------------------
/data/timbre_model/test/vuv/nitech_jp_song070_f001_029_vuv.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seaniezhao/torch_npss/c49dcb97e6fc11b5ac026799fcbed14aa0ed34aa/data/timbre_model/test/vuv/nitech_jp_song070_f001_029_vuv.npy
--------------------------------------------------------------------------------
/hparams.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def create_harmonic_hparams(hparams_string=None, verbose=False):
5 | """Create model hyperparameters. Parse nondefault from given string."""
6 |
7 | hparams = tf.contrib.training.HParams(
8 | type=0,
9 | layers=3,
10 | blocks=2,
11 | dilation_channels=130,
12 | residual_channels=130,
13 | skip_channels=240,
14 | input_channel=60,
15 | condition_channel=364,
16 | output_channel=240,
17 | sample_channel=60,
18 | initial_kernel=10,
19 | kernel_size=2,
20 | bias=True
21 | )
22 |
23 | if hparams_string:
24 | tf.logging.info('Parsing harmonic hparams: %s', hparams_string)
25 | hparams.parse(hparams_string)
26 |
27 | if verbose:
28 | tf.logging.info('Final harmonic hparams: %s', hparams.values())
29 |
30 | return hparams
31 |
32 |
33 | def create_aperiodic_hparams(hparams_string=None, verbose=False):
34 | """Create model hyperparameters. Parse nondefault from given string."""
35 |
36 | hparams = tf.contrib.training.HParams(
37 | type=1,
38 | layers=3,
39 | blocks=2,
40 | dilation_channels=20,
41 | residual_channels=20,
42 | skip_channels=16,
43 | input_channel=64,
44 | condition_channel=364,
45 | output_channel=16,
46 | sample_channel=4,
47 | initial_kernel=10,
48 | kernel_size=2,
49 | bias=True
50 | )
51 |
52 | if hparams_string:
53 | tf.logging.info('Parsing aperiodic hparams: %s', hparams_string)
54 | hparams.parse(hparams_string)
55 |
56 | if verbose:
57 | tf.logging.info('Final aperiodic hparams: %s', hparams.values())
58 |
59 | return hparams
60 |
61 |
62 | def create_vuv_hparams(hparams_string=None, verbose=False):
63 | """Create model hyperparameters. Parse nondefault from given string."""
64 |
65 | hparams = tf.contrib.training.HParams(
66 | type=2,
67 | layers=3,
68 | blocks=2,
69 | dilation_channels=20,
70 | residual_channels=20,
71 | skip_channels=4,
72 | input_channel=65,
73 | condition_channel=364,
74 | output_channel=1,
75 | sample_channel=1,
76 | initial_kernel=10,
77 | kernel_size=2,
78 | bias=True
79 | )
80 |
81 | if hparams_string:
82 | tf.logging.info('Parsing vuv hparams: %s', hparams_string)
83 | hparams.parse(hparams_string)
84 |
85 | if verbose:
86 | tf.logging.info('Final vuv hparams: %s', hparams.values())
87 |
88 | return hparams
89 |
90 |
91 | def create_f0_hparams(hparams_string=None, verbose=False):
92 | """Create model hyperparameters. Parse nondefault from given string."""
93 |
94 | hparams = tf.contrib.training.HParams(
95 | type=3,
96 | layers=3,
97 | blocks=2,
98 | dilation_channels=130,
99 | residual_channels=130,
100 | skip_channels=240,
101 | input_channel=60,
102 | condition_channel=1126,
103 | cgm_factor=4,
104 | initial_kernel=10,
105 | kernel_size=2,
106 | bias=True
107 | )
108 |
109 | if hparams_string:
110 | tf.logging.info('Parsing f0 hparams: %s', hparams_string)
111 | hparams.parse(hparams_string)
112 |
113 | if verbose:
114 | tf.logging.info('f0 hparams: %s', hparams.values())
115 |
116 | return hparams
117 |
118 |
119 |
--------------------------------------------------------------------------------
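condition_channel (364 in all three timbre models above) is dictated by the conditioning vector built in data/preprocess.py: three phoneme one-hots, a 3-way position-in-phoneme code, and 256 coarse-F0 bins. A quick sanity check against your own preprocessed data (364 for the bundled dataset implies 35 phoneme symbols):

```python
# a quick check of condition_channel against the preprocessed phoneme inventory
import numpy as np

all_phon = np.load('data/timbre_model/all_phonetic.npy')
condition_channel = 3 * len(all_phon) + 3 + 256   # pre/cur/next phoneme, position, coarse F0
print(condition_channel)                          # must match condition_channel in hparams.py
```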
/inference.py:
--------------------------------------------------------------------------------
1 | from model.wavenet_model import *
2 | from data.dataset import NpssDataset
3 | import hparams
4 | import pyworld as pw
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import soundfile as sf
8 | from data.preprocess import process_wav
9 | from data.data_util import decode_harmonic
10 | import librosa
11 |
12 | fft_size = 2048
13 |
14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15 | def load_latest_model_from(mtype, location):
16 |
17 | files = [location + "/" + f for f in os.listdir(location)]
18 | newest_file = max(files, key=os.path.getctime)
19 | # debug
20 | # if mtype == 0:
21 | # newest_file = '/home/sean/pythonProj/torch_npss/snapshots/harmonic/harm_server1649'
22 | # else:
23 | # newest_file = '/home/sean/pythonProj/torch_npss/snapshots/aperiodic/ap_server1649'
24 |
25 |
26 | print("load model " + newest_file)
27 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
28 |
29 | if mtype == 0:
30 | hparam = hparams.create_harmonic_hparams()
31 | elif mtype == 1:
32 | hparam = hparams.create_aperiodic_hparams()
33 | else:
34 | hparam = hparams.create_vuv_hparams()
35 |
36 | model = WaveNetModel(hparam, device).to(device)
37 | states = torch.load(newest_file, map_location=device)
38 | model.load_state_dict(states['state_dict'])
39 |
40 | return model
41 |
42 |
43 | def load_timbre(path, m_type, mx, mn):
44 | load_t = np.load(path).astype(np.double)
45 |
46 | load_t = (load_t + 0.5) * (mx - mn) + mn
47 | # decode aperiodicity for m_type 1, otherwise decode the harmonic envelope
48 | decode_sp = pw.decode_aperiodicity(load_t, 32000, fft_size) if m_type == 1 else decode_harmonic(load_t, fft_size)
49 |
50 |
51 | return decode_sp
52 |
53 |
54 | # type 0:harmonic, 1:aperiodic,
55 | def generate_timbre(m_type, mx, mn, condition, cat_input=None):
56 | model_path = 'snapshots/harmonic'
57 | if m_type == 1:
58 | model_path = 'snapshots/aperiodic'
59 | model = load_latest_model_from(m_type, model_path)
60 | raw_gen = model.generate(condition, cat_input)
61 | sample = (raw_gen.transpose(0, 1).cpu().numpy().astype(np.double)+0.5) * (mx - mn) + mn
62 |
63 | decode_sp = None
64 | if m_type == 0:
65 | decode_sp = decode_harmonic(sample, fft_size)
66 | elif m_type == 1:
67 | decode_sp = pw.decode_aperiodicity(np.ascontiguousarray(sample), 32000, fft_size)
68 |
69 | return decode_sp, raw_gen
70 |
71 | def generate_vuv(condition, cat_input):
72 | model_path = 'snapshots/vuv'
73 | model = load_latest_model_from(2, model_path)
74 | gen = model.generate(condition, cat_input).squeeze()
75 |
76 | return gen.cpu().numpy().astype(np.uint8)
77 |
78 |
79 | def get_ap_cat():
80 |
81 | wav_path = 'data/timbre_model/test/sp/nitech_jp_song070_f001_015_sp.npy'
82 |
83 | code_sp = np.load(wav_path).astype(np.double)
84 | return torch.Tensor(code_sp).transpose(0, 1)
85 |
86 | def get_vuv_cat():
87 | wav_path = 'data/timbre_model/test/sp/20_sp.npy'
88 |
89 | code_sp = np.load(wav_path).astype(np.double)
90 | sp_cat = torch.Tensor(code_sp).transpose(0, 1)
91 |
92 | wav_path = 'data/timbre_model/test/ap/20_ap.npy'
93 |
94 | code_sp = np.load(wav_path).astype(np.double)
95 | ap_cat = torch.Tensor(code_sp).transpose(0, 1)
96 |
97 | cat = torch.cat((ap_cat, sp_cat), 0)
98 | return cat
99 |
100 |
101 |
102 |
103 | def get_condition(filename):
104 |
105 | c_path = 'data/timbre_model/test/condition/'+filename+'_condi.npy'
106 | condition = np.load(c_path).astype(np.float64)
107 | return torch.Tensor(condition).transpose(0, 1)
108 |
109 |
110 | def generate_test(filename):
111 |
112 | [sp_min, sp_max, ap_min, ap_max] = np.load('data/timbre_model/min_max_record.npy')
113 | condi = get_condition(filename)
114 | # cat_input = get_ap_cat()
115 | # first_input = get_first_input()
116 |
117 | sp, raw_sp = generate_timbre(0, sp_max, sp_min, condi, None)
118 |
119 | plt.imshow(np.log(np.transpose(sp)), aspect='auto', origin='lower', interpolation='none')
120 | plt.show()
121 |
122 | sp1 = load_timbre('data/timbre_model/test/sp/'+filename+'_sp.npy', 0, sp_max, sp_min)
123 |
124 | plt.imshow(np.log(np.transpose(sp1)), aspect='auto', origin='lower', interpolation='none')
125 | plt.show()
126 | ####################################################################################################
127 | ap, raw_ap = generate_timbre(1, ap_max, ap_min, condi, raw_sp)
128 |
129 | plt.imshow(np.log(np.transpose(ap)), aspect='auto', origin='lower', interpolation='none')
130 | plt.show()
131 |
132 | ap1 = load_timbre('data/timbre_model/test/ap/'+filename+'_ap.npy', 1, ap_max, ap_min)
133 |
134 | plt.imshow(np.log(np.transpose(ap1)), aspect='auto', origin='lower', interpolation='none')
135 | plt.show()
136 |
137 | #########################################################################################################
138 | # vuv_cat = get_vuv_cat()
139 | # gen_cat = torch.cat((raw_ap, raw_sp), 0)
140 |
141 | # vuv = generate_vuv(condi, vuv_cat)
142 | # plt.plot(vuv)
143 | # plt.show()
144 | #
145 | # vuv1 = np.load('data/timbre_model/test/vuv/nitech_jp_song070_f001_029_vuv.npy')
146 | # plt.plot(vuv1)
147 | # plt.show()
148 |
149 | path = 'data/raw/'+filename+'.raw'
150 | _f0, _sp, code_sp, _ap, code_ap = process_wav(path)
151 | # synthesize audio from the original F0 and the generated sp/ap
152 | synthesized = pw.synthesize(_f0, sp, ap, 32000, pw.default_frame_period)
153 | # write the generated audio
154 | sf.write('./data/gen_wav/'+filename+'.wav', synthesized, 32000)
156 |
157 |
158 |
159 | if __name__ == '__main__':
160 | generate_test('nitech_jp_song070_f001_029')
161 |
162 |
--------------------------------------------------------------------------------
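load_timbre and generate_timbre above invert the [-0.5, 0.5] min-max normalization that data/preprocess.py applies before saving. A minimal round-trip sketch:

```python
# a minimal sketch of the normalization round trip shared by preprocess.py
# (forward) and inference.py (inverse)
import numpy as np

sp_min, sp_max, ap_min, ap_max = np.load('data/timbre_model/min_max_record.npy')

sp = np.random.uniform(sp_min, sp_max, size=(5, 60))  # stand-in for code_harmonic output
norm = (sp - sp_min) / (sp_max - sp_min) - 0.5        # preprocess.py: into [-0.5, 0.5]
back = (norm + 0.5) * (sp_max - sp_min) + sp_min      # inference.py: back to the coded range
assert np.allclose(sp, back)
```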
/model/timbre_training.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 | import torch.utils.data
4 | import time
5 | import os
6 | import torch.nn as nn
7 | from data.dataset import TimbreDataset
8 | from datetime import datetime
9 | from torch.autograd import Variable
10 | from model.util import *
11 | import pyworld as pw
12 | import matplotlib.pyplot as plt
13 |
14 |
15 |
16 | class ModelTrainer:
17 | def __init__(self,
18 | model,
19 | data_folder,
20 | device,
21 | snapshot_path=None,
22 | snapshot_name='snapshot',
23 | snapshot_interval=1000,
24 | lr=0.0005,
25 | weight_decay=0):
26 |
27 | self.model = model
28 |
29 | self.trainset = TimbreDataset(data_folder=data_folder, receptive_field=model.receptive_field, type=model.model_type, train=True)
30 | print('the dataset has ' + str(len(self.trainset)) + ' items')
31 |
32 | self.testset = TimbreDataset(data_folder=data_folder, receptive_field=model.receptive_field, type=model.model_type, train=False)
33 |
34 | self.dataloader = None
35 | self.lr = lr
36 | self.weight_decay = weight_decay
37 | self.device = device
38 |
39 |
40 | self.snapshot_path = snapshot_path
41 | self.snapshot_name = snapshot_name
42 | self.snapshot_interval = snapshot_interval
43 |
44 | self.optimizer = optim.Adam(params=self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
45 |
46 | self.clip = None
47 |
48 | self.device_count = torch.cuda.device_count()
49 |
50 | self.start_epoch = 0
51 | self.epoch = 0
52 |
53 | self.model_type = self.model.model_type
54 |
55 | def adjust_learning_rate(self):
56 |
57 | real_epoch = self.start_epoch + self.epoch
58 | lr = self.lr / (1 + 0.00001 * real_epoch)
59 |
60 | print('lr '+str(lr)+' epoch '+str(real_epoch))
61 | for param_group in self.optimizer.param_groups:
62 | param_group['lr'] = lr
63 |
64 | def train(self, batch_size=32, epochs=10):
65 |
66 | self.model.train()
67 | if self.device_count > 1:
68 | self.model = nn.DataParallel(self.model)
69 | print('using multiple devices:', self.device_count)
70 |
71 | self.dataloader = torch.utils.data.DataLoader(self.trainset,
72 | batch_size=batch_size,
73 | shuffle=True,
74 | num_workers=8,
75 | pin_memory=False)
76 |
77 | self.testdataloader = torch.utils.data.DataLoader(self.testset,
78 | batch_size=batch_size,
79 | shuffle=False,
80 | num_workers=8,
81 | pin_memory=False)
82 |
83 | step = 0
84 | for current_epoch in range(epochs):
85 | print("epoch", current_epoch)
86 | self.epoch = current_epoch
87 |
88 | self.adjust_learning_rate()
89 | tic = time.time()
90 |
91 | total_loss = 0
92 | epoch_step = 0
93 | for (x, target) in iter(self.dataloader):
94 | x, condi = x
95 | x = x.to(self.device)
96 | condi = condi.to(self.device)
97 |
98 | target = target.to(self.device)
99 |
100 | output = self.model(x, condi)
101 | if self.model_type == 2:
102 | loss = torch.mean((output.squeeze()-target.squeeze())**2)
103 | else:
104 | loss = CGM_loss(output, target)
105 |
106 | self.optimizer.zero_grad()
107 | loss.backward()
108 | loss = loss.item()
109 | total_loss += loss
110 | epoch_step += 1
111 | #print('loss: ', loss)
112 |
113 | if self.clip is not None:
114 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
115 | self.optimizer.step()
116 | step += 1
117 |
118 | # time step duration:
119 | if step == 100:
120 | toc = time.time()
121 | print("one training step does take approximately " + str((toc - tic) * 0.01) + " seconds)")
122 |
123 | #self.save_model()
124 | test_loss = self.validate()
125 | toc = time.time()
126 | print("one epoch does take approximately " + str((toc - tic)) + " seconds), average loss: " + str(total_loss/epoch_step)+" test loss: "+str(test_loss))
127 |
128 | self.save_model()
129 |
130 |
131 | def load_checkpoint(self, filename):
132 |
133 | if os.path.isfile(filename):
134 | print("=> loading checkpoint '{}'".format(filename))
135 | checkpoint = torch.load(filename)
136 | self.start_epoch = checkpoint['epoch']
137 | self.model.load_state_dict(checkpoint['state_dict'])
138 | self.optimizer.load_state_dict(checkpoint['optimizer'])
139 |
140 | print("=> loaded checkpoint '{}' (epoch {})"
141 | .format(filename, checkpoint['epoch']))
142 | else:
143 | print("=> no checkpoint found at '{}'".format(filename))
144 |
145 | return self.start_epoch
146 |
147 |
148 | def validate(self):
149 |
150 | self.model.eval()
151 |
152 | total_loss = 0
153 | epoch_step = 0
154 | for (x, target) in iter(self.testdataloader):
155 | x, condi = x
156 | x = x.to(self.device)
157 | condi = condi.to(self.device)
158 |
159 | target = target.to(self.device)
160 |
161 | output = self.model(x, condi)
162 | if self.model_type == 2:
163 | loss = torch.mean((output.squeeze() - target.squeeze()) ** 2)
164 | else:
165 | loss = CGM_loss(output, target)
166 |
167 | loss = loss.item()
168 |
169 | total_loss += loss
170 | epoch_step += 1
171 |
172 | self.model.train()
173 | avg_loss = total_loss/epoch_step
174 | return avg_loss
175 |
176 | def save_model(self):
177 | if self.snapshot_path is None:
178 | return
179 | time_string = time.strftime("%Y-%m-%d_%H-%M-%S", time.gmtime())
180 | if not os.path.exists(self.snapshot_path):
181 | os.mkdir(self.snapshot_path)
182 | to_save = self.model
183 | if self.device_count > 1:
184 | to_save = self.model.module
185 |
186 | str_epoch = str(self.start_epoch + self.epoch)
187 | filename = self.snapshot_path + '/' + self.snapshot_name + '_' + str_epoch + '_' + time_string
188 | state = {'epoch': self.epoch + 1, 'state_dict': to_save.state_dict(),
189 | 'optimizer': self.optimizer.state_dict()}
190 | torch.save(state, filename)
191 |
192 | print('model saved')
193 |
194 |
195 | # test
196 | def get_first_input(self):
197 | wav_path = './data/prepared_data/sp.npy'
198 |
199 | code_sp = np.load(wav_path).astype(np.double)
200 | return torch.Tensor(code_sp)
201 |
202 | def get_condition(self):
203 | c_path = './data/prepared_data/condition.npy'
204 | condition = np.load(c_path).astype(np.float64)
205 | return torch.Tensor(condition).transpose(0, 1)
206 |
207 | def generate_audio(self):
208 |
209 | first_input = self.get_first_input()
210 | condi = self.get_condition()
211 | gen = self.model.generate(condi, first_input.transpose(0, 1))
212 |
213 |         sq_err = torch.sum((first_input - gen.cpu())**2)
214 |         print("total squared error vs. input: ", sq_err)
215 | return gen
216 |
--------------------------------------------------------------------------------
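
Resuming training from a snapshot takes two calls; a minimal sketch, assuming a `ModelTrainer` named `trainer` wired up as in the `train_*.py` scripts. `load_checkpoint` restores the weights, the optimizer state, and the starting epoch used in snapshot filenames:

```python
# Sketch only: `trainer` is an assumed ModelTrainer instance; the snapshot
# path below is one of the files checked in under /snapshots.
trainer.load_checkpoint('./snapshots/harmonic/harm_1649_2019-09-06_07-03-37')
trainer.train(batch_size=32, epochs=1650)  # epoch numbering continues from the checkpoint
```
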
/model/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import numpy as np
4 | from torch.distributions.normal import Normal
5 |
6 |
7 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
8 | # compute the CGM parameters: sigmas, mus and mixture weights ws
9 | def cal_para(out, temperature):
10 |
11 | sqrt = math.sqrt
12 |
13 | cgm_factor = 4
14 | r_u = 1.6
15 | r_s = 1.1
16 | r_w = 1 / 1.75
17 |
18 |
19 | out = out.permute(0, 2, 1).contiguous()
20 | out = out.view(out.shape[0], out.shape[1], -1, cgm_factor)
21 |
22 | a0 = out[:, :, :, 0]
23 | a1 = out[:, :, :, 1]
24 | a2 = out[:, :, :, 2]
25 | a3 = out[:, :, :, 3]
26 |
27 | xi = 2 * torch.sigmoid(a0) - 1
28 | omega = torch.exp(4 * torch.sigmoid(a1)) * 2 / 255
29 | alpha = 2 * torch.sigmoid(a2) - 1
30 | beta = 2 * torch.sigmoid(a3)
31 |
32 | # cal temperature
33 | use_t = False
34 | if temperature != 0:
35 | use_t = True
36 | tempers = []
37 | for i in range(xi.shape[-1]):
38 | if i<=3:
39 | temper = 0.05
40 | elif i>=8:
41 | temper = 0.5
42 | else:
43 | temper = 0.05 + (i-3)*0.09
44 | tempers.append(temper)
45 |
46 | #tempers = tempers[::-1]
47 | tempers = torch.Tensor(tempers)
48 | tempers = tempers.expand(xi.shape).to(device)
49 |         # temperature != 0.01 means this is the harmonic model, so use the piecewise-linear schedule above
50 | if temperature != 0.01:
51 | temperature = tempers
52 | sqrt = torch.sqrt
53 | # end cal temperature
54 |
55 | sigmas = []
56 | for k in range(cgm_factor):
57 | sigma = omega * torch.exp(k * (torch.abs(alpha) * r_s - 1))
58 | sigmas.append(sigma)
59 |
60 | mus = []
61 | for k in range(cgm_factor):
62 | temp_sum = 0
63 | for i in range(k):
64 | temp_sum += sigmas[i] * r_u * alpha
65 | mu = xi + temp_sum
66 | mus.append(mu)
67 |
68 | ws = []
69 | temp_sum = 0
70 | for i in range(cgm_factor):
71 | temp_sum += alpha.pow(2 * i) * beta.pow(i) * (r_w ** i)
72 | for k in range(cgm_factor):
73 | w = (alpha.pow(2 * k) * beta.pow(k) * (r_w ** k)) / temp_sum
74 | ws.append(w)
75 |
76 | if use_t:
77 | _mus = 0
78 | for k in range(cgm_factor):
79 | _mus += ws[k]*mus[k]
80 |
81 | for k in range(cgm_factor):
82 | mus[k] = mus[k] + (_mus - mus[k])*(1 - temperature)
83 | sigmas[k] *= sqrt(temperature)
84 |
85 |
86 | return sigmas, mus, ws
87 |
88 | # x dim = (batch, output_channel, length)
89 | # l dim = (batch, output_channel * cgm_factor, length)
90 |
91 |
92 | def CGM_loss(out, y):
93 | y = y.permute(0, 2, 1)
94 |
95 | sigmas, mus, ws = cal_para(out, 0)
96 |
97 | #print(torch.mean(sigmas[0]))
98 |     # sanity check (debug only): the mixture weights should sum to 1
99 |     w_sum = 0
100 |     for k in range(4):
101 |         tw = ws[k].view(-1)
102 |         w_sum += tw
103 |
104 |
105 |     # accumulate the weighted log-probabilities of the four components
106 | probs = 0
107 | for k in range(4):
108 | dist = Normal(mus[k], sigmas[k])
109 | log_prob = dist.log_prob(y)
110 |
111 |         # no sampling is needed here; the loss only uses log_prob
112 |
113 | probs += ws[k] * log_prob
114 |
115 | return -torch.mean(probs)
116 |
117 |
118 | def sample_from_CGM(out, temperature=0.01):
119 | #temperature = 0.01
120 | out = out.unsqueeze(1)
121 | out = out.unsqueeze(0)
122 | sigmas, mus, ws = cal_para(out, temperature)
123 |
124 | value = 0
125 | rand = torch.rand(ws[0].shape).to(device)
126 |
127 | for k in range(4):
128 | mask_btm = torch.zeros(ws[k].shape).to(device)
129 | for i in range(k):
130 | mask_btm += ws[i]
131 | mask = (rand < (ws[k] + mask_btm)) * (rand >= mask_btm)
132 | mask = mask.float()
133 | gaussian_dist = Normal(loc=mus[k], scale=sigmas[k])
134 | x = gaussian_dist.sample()
135 | value += mask * x
136 |
137 |     # value shape: (batch, length, channels=60)
138 | return value
139 |
140 |
--------------------------------------------------------------------------------
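
The mixture weights returned by `cal_para` are normalized over the four CGM components by construction; a quick smoke test, where the shapes are assumptions (60 sample channels × `cgm_factor` 4, as in the harmonic model):

```python
import torch
from model.util import cal_para

# assumed shapes: batch=1, output_channel=240 (60 channels * cgm_factor 4), length=8
out = torch.randn(1, 240, 8)
sigmas, mus, ws = cal_para(out, 0)  # temperature 0 disables temperature scaling

w_sum = sum(ws)  # elementwise sum over the 4 components, shape (1, 8, 60)
print(torch.allclose(w_sum, torch.ones_like(w_sum)))  # True up to float error
```
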
/model/wavenet_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import time
4 | from data.dataset import *
5 | from model.util import *
6 | import torch.nn as nn
7 | from torch.distributions.normal import Normal
8 | from torch.distributions.multivariate_normal import MultivariateNormal
9 |
10 | class WaveNetModel(nn.Module):
11 |
12 | def __init__(self, hparams, device):
13 |
14 | super(WaveNetModel, self).__init__()
15 | self.noise_lambda = 0.2
16 | self.model_type = hparams.type
17 | self.layers = hparams.layers
18 | self.blocks = hparams.blocks
19 | self.dilation_channels = hparams.dilation_channels
20 | self.residual_channels = hparams.residual_channels
21 | self.skip_channels = hparams.skip_channels
22 | self.input_channel = hparams.input_channel
23 | self.initial_kernel = hparams.initial_kernel
24 | self.kernel_size = hparams.kernel_size
25 | self.output_channel = hparams.output_channel
26 |         # when using CGM, sample_channel * cgm_factor == output_channel
27 | self.sample_channel = hparams.sample_channel
28 | self.condition_channel = hparams.condition_channel
29 | self.bias = hparams.bias
30 |
31 | self.device = device
32 | # build model
33 | receptive_field = 1
34 | init_dilation = 1
35 | #
36 | # self.dilations = []
37 | # self.dilated_queues = []
38 |
39 | self.dilated_pads = nn.ModuleList()
40 | self.dilated_convs = nn.ModuleList()
41 |
42 | self.residual_convs = nn.ModuleList()
43 | self.skip_convs = nn.ModuleList()
44 |
45 | self.condi_convs = nn.ModuleList()
46 |
47 | # 1x1 convolution to create channels
48 |         self.start_pad = nn.ConstantPad1d((self.initial_kernel - 1, 0), 0)  # pad so start_conv preserves length
49 | self.start_conv = nn.Conv1d(in_channels=self.input_channel,
50 | out_channels=self.residual_channels,
51 | kernel_size=self.initial_kernel,
52 | bias=self.bias)
53 | nn.init.xavier_uniform_(
54 | self.start_conv.weight, gain=nn.init.calculate_gain('linear'))
55 |
56 | for b in range(self.blocks):
57 | additional_scope = self.kernel_size - 1
58 | new_dilation = 1
59 | actual_layer = self.layers
60 | if b == self.blocks-1:
61 | actual_layer = self.layers - 1
62 | for i in range(actual_layer):
63 |
64 | self.condi_convs.append(nn.Conv1d(in_channels=self.condition_channel,
65 | out_channels=self.dilation_channels*2,
66 | kernel_size=1,
67 | bias=self.bias))
68 | nn.init.xavier_uniform_(
self.condi_convs[-1].weight, gain=nn.init.calculate_gain('linear'))  # [-1]: init the conv just appended
70 |
71 |
72 | self.dilated_pads.append(nn.ConstantPad1d((new_dilation, 0), 0))
73 | # dilated convolutions
74 | self.dilated_convs.append(nn.Conv1d(in_channels=self.residual_channels,
75 | out_channels=self.dilation_channels*2,
76 | kernel_size=self.kernel_size,
77 | bias=self.bias,
78 | dilation=new_dilation))
79 | nn.init.xavier_uniform_(
self.dilated_convs[-1].weight, gain=nn.init.calculate_gain('linear'))
81 |
82 |
83 | # 1x1 convolution for residual connection
84 | self.residual_convs.append(nn.Conv1d(in_channels=self.dilation_channels,
85 | out_channels=self.residual_channels,
86 | kernel_size=1,
87 | bias=self.bias))
88 | nn.init.xavier_uniform_(
self.residual_convs[-1].weight, gain=nn.init.calculate_gain('linear'))
90 |
91 | # 1x1 convolution for skip connection
92 | self.skip_convs.append(nn.Conv1d(in_channels=self.dilation_channels,
93 | out_channels=self.skip_channels,
94 | kernel_size=1,
95 | bias=self.bias))
96 | nn.init.xavier_uniform_(
self.skip_convs[-1].weight, gain=nn.init.calculate_gain('linear'))
98 |
99 |
100 | receptive_field += additional_scope
101 | additional_scope *= 2
102 | init_dilation = new_dilation
103 | new_dilation *= 2
104 |
105 | self.end_conv = nn.Conv1d(in_channels=self.skip_channels,
106 | out_channels=self.output_channel,
107 | kernel_size=1,
108 | bias=self.bias)
109 | nn.init.xavier_uniform_(
110 | self.end_conv.weight, gain=nn.init.calculate_gain('linear'))
111 |
112 | # condition end conv
113 | self.cond_end_conv = nn.Conv1d(in_channels=self.condition_channel,
114 | out_channels=self.skip_channels,
115 | kernel_size=1,
116 | bias=self.bias)
117 | nn.init.xavier_uniform_(
118 | self.cond_end_conv.weight, gain=nn.init.calculate_gain('linear'))
119 |
120 |
121 | self.receptive_field = receptive_field + self.initial_kernel - 1
122 |
123 | def wavenet(self, input, condition):
124 | # input shape: (B,N,L) N is channel
125 | # condition shape: (B,cN,L) cN is condition channel
126 |
127 | input = self.start_pad(input)
128 | x = self.start_conv(input)
129 | skip = 0
130 |
131 | # WaveNet layers
132 | for i in range(self.blocks * self.layers - 1):
133 |
134 | # |----------------------------------------------------| *residual*
135 | # | |
136 | # | |---- tanh --------| |
137 | # -> dilate_conv ->| * ----|-- 1x1 -- + --> *input*
138 | # |---- sigm --------| |
139 | # 1x1
140 | # |
141 | # ----------------------------------------> + -------------> *skip*
142 |
143 | residual = x
144 |
145 | x = self.dilated_pads[i](x)
146 | dilated = self.dilated_convs[i](x)
147 | # here plus condition
148 |
149 | condi = self.condi_convs[i](condition)
150 | dilated = dilated + condi
151 |
152 | filter, gate = torch.chunk(dilated, 2, dim=1)
153 |
154 | # dilated convolution
155 | filter = torch.tanh(filter)
156 | gate = torch.sigmoid(gate)
157 | x = filter * gate
158 |
159 | # parametrized skip connection
160 | s = x
161 | s = self.skip_convs[i](s)
162 |             if torch.is_tensor(skip):
163 |                 # crop the running skip sum to the current output length
164 |                 skip = skip[:, :, -s.size(2):]
165 |             # before any layer contributes, skip is still the integer 0
166 | skip = s + skip
167 |
168 | x = self.residual_convs[i](x)
169 | x = x + residual[:, :, -x.size(2):]
170 |
171 |         # add the conditioning to the skip sum
172 | condi = self.cond_end_conv(condition)
173 | skip = skip + condi
174 |
175 | x = torch.tanh(skip)
176 | x = self.end_conv(x)
177 |
178 | if self.model_type == 2:
179 | x = torch.sigmoid(x)
180 |
181 | return x
182 |
183 | def forward(self, input, condition):
184 | # input noise
185 | # input shape : (B, N, L) for harmonic N = 60 = self.sample_channel
186 | # if self.sample_channel == 1:
187 | # dist = Normal(input, self.noise_lambda)
188 | # x = self.wavenet(dist.sample(), condition)
189 | # else:
190 | # input = input.permute(0, 2, 1)
191 | # sigmas = self.noise_lambda * torch.eye(self.sample_channel)
192 | # sigmas = sigmas.repeat(input.shape[0], input.shape[1], 1, 1).to(self.device)
193 | # dist = MultivariateNormal(input, sigmas)
194 | # r_input = dist.sample().permute(0, 2, 1)
195 | # x = self.wavenet(r_input, condition)
196 |
197 | dist = Normal(input, self.noise_lambda)
198 | x = self.wavenet(dist.sample(), condition)
199 | return x
200 |
201 | def parameter_count(self):
202 | par = list(self.parameters())
203 | s = sum([np.prod(list(d.size())) for d in par])
204 | return s
205 |
206 | def generate(self, conditions, cat_input=None):
207 | # conditions shape: (condition_channel, len)
208 | self.eval()
209 |
210 | if cat_input is not None:
211 | cat_input = cat_input.to(self.device)
212 |
213 | conditions = conditions.to(self.device)
214 | num_samples = conditions.shape[1]
215 | generated = torch.zeros(self.sample_channel, num_samples).to(self.device)
216 |
217 | model_input = torch.zeros(1, self.input_channel, 1).to(self.device)
218 |
219 | from tqdm import tqdm
220 | for i in tqdm(range(num_samples)):
221 | if i < self.receptive_field:
222 | condi = conditions[:, :i + 1]
223 | else:
224 | condi = conditions[:, i - self.receptive_field + 1:i + 1]
225 | condi = condi.unsqueeze(0)
226 |
227 | x = self.wavenet(model_input, condi)
228 | x = x[:, :, -1].squeeze()
229 | if self.model_type == 2:
230 | x_sample = 0
231 | if x > 0.5:
232 | x_sample = 1
233 | x_sample = torch.Tensor([x_sample]).to(self.device).unsqueeze(0)
234 | else:
235 | t = 0.01
236 | if self.model_type == 0:
237 | t = 0.05
238 | x_sample = sample_from_CGM(x.detach(), t)
239 |
240 | generated[:, i] = x_sample.squeeze(0)
241 |
242 | # set new input
243 | if i < self.receptive_field - 1:
244 | model_input = generated[:, :i + 1]
245 | if cat_input is not None:
246 | to_cat = cat_input[:, :i + 1]
247 | model_input = torch.cat((to_cat, model_input), 0)
248 |
249 |                 # left-pad one zero frame; F.pad avoids the CPU/numpy round trip
250 |                 model_input = torch.nn.functional.pad(model_input, (1, 0))
251 | else:
252 | model_input = generated[:, i - self.receptive_field + 1:i + 1]
253 | if cat_input is not None:
254 | to_cat = cat_input[:, i - self.receptive_field + 1:i + 1]
255 | model_input = torch.cat((to_cat, model_input), 0)
256 |
257 | model_input = model_input.unsqueeze(0)
258 |
259 | self.train()
260 | return generated
261 |
262 |
--------------------------------------------------------------------------------
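
The receptive-field bookkeeping in `WaveNetModel.__init__` reduces to a short closed loop. A standalone sketch; the default values below are illustrative assumptions, not necessarily the repo's actual hparams:

```python
# Mirrors the dilation loop in WaveNetModel.__init__: each layer doubles its
# dilation, and the final block runs one layer fewer.
def receptive_field(blocks=2, layers=10, kernel_size=2, initial_kernel=10):
    rf = 1
    for b in range(blocks):
        scope = kernel_size - 1
        actual_layers = layers - 1 if b == blocks - 1 else layers
        for _ in range(actual_layers):
            rf += scope
            scope *= 2
    return rf + initial_kernel - 1

print(receptive_field())  # 1544 with these assumed values
```
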
/model_logging.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import scipy.misc
4 | import threading
5 |
6 | try:
7 | from StringIO import StringIO # Python 2.7
8 | except ImportError:
9 | from io import BytesIO # Python 3.x
10 |
11 |
12 | class Logger:
13 | def __init__(self,
14 | log_interval=50,
15 | validation_interval=200,
16 | generate_interval=500,
17 | trainer=None,
18 | generate_function=None):
19 | self.trainer = trainer
20 | self.log_interval = log_interval
21 | self.validation_interval = validation_interval
22 | self.generate_interval = generate_interval
23 | self.accumulated_loss = 0
24 | self.generate_function = generate_function
25 | if self.generate_function is not None:
26 | self.generate_thread = threading.Thread(target=self.generate_function)
27 |             self.generate_thread.daemon = True
28 |
29 | def log(self, current_step, current_loss):
30 | self.accumulated_loss += current_loss
31 | if current_step % self.log_interval == 0:
32 | self.log_loss(current_step)
33 | self.accumulated_loss = 0
34 | if current_step % self.validation_interval == 0:
35 | self.validate(current_step)
36 | if current_step % self.generate_interval == 0:
37 | self.generate(current_step)
38 |
39 | def log_loss(self, current_step):
40 | avg_loss = self.accumulated_loss / self.log_interval
41 | print("loss at step " + str(current_step) + ": " + str(avg_loss))
42 |
43 | def validate(self, current_step):
44 | avg_loss, avg_accuracy = self.trainer.validate()
45 | print("validation loss: " + str(avg_loss))
46 | print("validation accuracy: " + str(avg_accuracy * 100) + "%")
47 |
48 | def generate(self, current_step):
49 | if self.generate_function is None:
50 | return
51 |
52 | if self.generate_thread.is_alive():
53 | print("Last generate is still running, skipping this one")
54 | else:
55 | self.generate_thread = threading.Thread(target=self.generate_function,
56 | args=[current_step])
57 | self.generate_thread.daemon = True
58 | self.generate_thread.start()
59 |
60 |
61 | # Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
62 | class TensorboardLogger(Logger):
63 | def __init__(self,
64 | log_interval=50,
65 | validation_interval=200,
66 | generate_interval=500,
67 | trainer=None,
68 | generate_function=None,
69 | log_dir='logs'):
70 | super().__init__(log_interval, validation_interval, generate_interval, trainer, generate_function)
71 | self.writer = tf.summary.FileWriter(log_dir)
72 |
73 | def log_loss(self, current_step):
74 | # loss
75 | avg_loss = self.accumulated_loss / self.log_interval
76 | self.scalar_summary('loss', avg_loss, current_step)
77 |
78 | # parameter histograms
79 |         for tag, value in self.trainer.model.named_parameters():
80 | tag = tag.replace('.', '/')
81 | self.histo_summary(tag, value.data.cpu().numpy(), current_step)
82 | if value.grad is not None:
83 | self.histo_summary(tag + '/grad', value.grad.data.cpu().numpy(), current_step)
84 |
85 | def validate(self, current_step):
86 | avg_loss, avg_accuracy = self.trainer.validate()
87 | self.scalar_summary('validation loss', avg_loss, current_step)
88 | self.scalar_summary('validation accuracy', avg_accuracy, current_step)
89 |
90 | def log_audio(self, step):
91 | samples = self.generate_function()
92 | tf_samples = tf.convert_to_tensor(samples)
93 | self.audio_summary('audio sample', tf_samples, step, sr=16000)
94 |
95 | def scalar_summary(self, tag, value, step):
96 | """Log a scalar variable."""
97 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
98 | self.writer.add_summary(summary, step)
99 |
100 | def image_summary(self, tag, images, step):
101 | """Log a list of images."""
102 |
103 | img_summaries = []
104 | for i, img in enumerate(images):
105 | # Write the image to a string
106 | try:
107 | s = StringIO()
108 | except:
109 | s = BytesIO()
110 | scipy.misc.toimage(img).save(s, format="png")
111 |
112 | # Create an Image object
113 | img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
114 | height=img.shape[0],
115 | width=img.shape[1])
116 | # Create a Summary value
117 | img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))
118 |
119 | # Create and write Summary
120 | summary = tf.Summary(value=img_summaries)
121 | self.writer.add_summary(summary, step)
122 |
123 | def audio_summary(self, tag, sample, step, sr=16000):
124 | with tf.Session() as sess:
125 | audio_summary = tf.summary.audio(tag, sample, sample_rate=sr, max_outputs=4)
126 | summary = sess.run(audio_summary)
127 | self.writer.add_summary(summary, step)
128 | self.writer.flush()
129 |
130 |
131 | def histo_summary(self, tag, values, step, bins=200):
132 | """Log a histogram of the tensor of values."""
133 |
134 | # Create a histogram using numpy
135 | counts, bin_edges = np.histogram(values, bins=bins)
136 |
137 | # Fill the fields of the histogram proto
138 | hist = tf.HistogramProto()
139 | hist.min = float(np.min(values))
140 | hist.max = float(np.max(values))
141 | hist.num = int(np.prod(values.shape))
142 | hist.sum = float(np.sum(values))
143 | hist.sum_squares = float(np.sum(values ** 2))
144 |
145 | # Drop the start of the first bin
146 | bin_edges = bin_edges[1:]
147 |
148 | # Add bin edges and counts
149 | for edge in bin_edges:
150 | hist.bucket_limit.append(edge)
151 | for c in counts:
152 | hist.bucket.append(c)
153 |
154 | # Create and write Summary
155 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
156 | self.writer.add_summary(summary, step)
157 | self.writer.flush()
158 |
159 | def tensor_summary(self, tag, tensor, step):
160 | tf_tensor = tf.Variable(tensor).to_proto()
161 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, tensor=tf_tensor)])
162 | #summary = tf.summary.tensor_summary(name=tag, tensor=tensor)
163 | self.writer.add_summary(summary, step)
164 |
165 |
--------------------------------------------------------------------------------
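
Roughly how a training loop would drive the plain `Logger`; a sketch with a hypothetical stub trainer, since `Logger.validate` expects a `(loss, accuracy)` pair while `ModelTrainer.validate` above returns a single loss, so the two are not drop-in compatible:

```python
from model_logging import Logger

class StubTrainer:
    """Hypothetical stand-in satisfying Logger's validate() contract."""
    def validate(self):
        return 0.1, 0.9  # (avg_loss, avg_accuracy)

logger = Logger(log_interval=50, validation_interval=200, trainer=StubTrainer())
for step in range(1, 1001):
    loss = 1.0 / step  # placeholder loss value
    logger.log(step, loss)  # prints averaged loss every 50 steps, validates every 200
```
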
/playground.py:
--------------------------------------------------------------------------------
1 | import librosa
2 | import pyworld as pw
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | import soundfile as sf
6 | import os
7 | #np.set_printoptions(threshold=np.nan)
8 |
9 |
10 | def process_wav(wav_path):
11 | y, osr = sf.read(wav_path, subtype='PCM_16', channels=1, samplerate=48000,
12 | endian='LITTLE') # , start=56640, stop=262560)
13 |
14 | sr = 32000
15 | y = librosa.resample(y, osr, sr)
16 |
17 |     # estimate F0 with the DIO algorithm
18 | _f0, t = pw.dio(y, sr, f0_floor=50.0, f0_ceil=800.0, channels_in_octave=2, frame_period=pw.default_frame_period)
19 | print(_f0.shape)
20 |
21 |     # estimate the spectral envelope with CheapTrick
22 | _sp = pw.cheaptrick(y, _f0, t, sr)
23 |
24 | code_sp = pw.code_spectral_envelope(_sp, sr, 60)
25 | print(_sp.shape, code_sp.shape)
26 |     # compute the aperiodicity parameters with D4C
27 | _ap = pw.d4c(y, _f0, t, sr)
28 |
29 | code_ap = pw.code_aperiodicity(_ap, sr)
30 | print(_ap.shape, code_ap.shape)
31 |
32 | np.save('data/prepared_data/f0', _f0)
33 | np.save('data/prepared_data/ap', code_ap)
34 |
35 |     # resynthesize with WORLD, shifting F0 down by 200 Hz
36 |     synthesized = pw.synthesize(_f0-200, _sp, _ap, 32000, pw.default_frame_period)
37 |     # write the shifted audio
38 |     sf.write('./data/gen_wav/test-200.wav', synthesized, 32000)
39 |
40 | #process_wav('/home/sean/pythonProj/torch_npss/data/raw/nitech_jp_song070_f001_055.raw')
41 |
42 |
43 | def get_feature(wav_path):
44 | y, osr = sf.read(wav_path) # , start=56640, stop=262560)
45 |
46 | sr = 32000
47 | y = librosa.resample(y, osr, sr)
48 |
49 |     # estimate F0 with the DIO algorithm
50 | _f0, t = pw.dio(y, sr, f0_floor=50.0, f0_ceil=800.0, channels_in_octave=2, frame_period=pw.default_frame_period)
51 | print(_f0.shape)
52 |
53 |     # estimate the spectral envelope with CheapTrick
54 | _sp = pw.cheaptrick(y, _f0, t, sr)
55 |
56 |     # compute the aperiodicity parameters with D4C
57 | _ap = pw.d4c(y, _f0, t, sr)
58 |
59 | return _f0, _sp, _ap
60 |
61 | a = '/home/sean/Desktop/f0_tets/counddown_ori.wav'
62 | b = '/home/sean/Desktop/f0_tets/counddown_joe.wav'
63 |
64 | af0, asp, aap = get_feature(a)
65 | bf0, bsp, bap = get_feature(b)
66 |
67 | plt.plot(af0)
68 | plt.show()
69 | plt.plot(bf0)
70 | plt.show()
71 |
72 | bf0[18:1019] = (bf0[18:1019] > 0)*af0
73 | #
74 | # for i,f0 in enumerate(bf0):
75 | # if i>1 and f0 == 0:
76 | # bf0[i] = bf0[i-1]
77 |
78 |
79 |
80 | # resynthesize with WORLD, scaling F0 down by a factor of 2.5
81 | synthesized = pw.synthesize(bf0[18:1019]/2.5, bsp[18:1019], bap[18:1019], 32000, pw.default_frame_period)
82 | # write the result
83 | sf.write('./data/gen_wav/countdown.wav', synthesized, 32000)
84 |
--------------------------------------------------------------------------------
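
Both pitch hacks above (`_f0 - 200` in `process_wav` and `bf0 / 2.5` here) shift F0 additively or by an arbitrary ratio, which distorts musical intervals. A sketch of a semitone-based alternative; the function name and defaults are my own, and only `pw.synthesize` is the library's API:

```python
import pyworld as pw

def pitch_shift(f0, sp, ap, sr=32000, semitones=-4.0):
    # multiplying F0 by 2^(n/12) transposes by n semitones and keeps intervals
    # intact; an additive shift can push low frames below 0 Hz
    shifted = f0 * (2.0 ** (semitones / 12.0))
    return pw.synthesize(shifted, sp, ap, sr, pw.default_frame_period)
```
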
/requirements.txt:
--------------------------------------------------------------------------------
1 | SoundFile
2 | torch
3 | librosa
4 | matplotlib
5 | pysptk
6 | tensorflow
7 | tqdm
8 | pyworld
9 | scipy
--------------------------------------------------------------------------------
/snapshots/aperiodic/aper_1649_2019-09-06_06-20-24:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seaniezhao/torch_npss/c49dcb97e6fc11b5ac026799fcbed14aa0ed34aa/snapshots/aperiodic/aper_1649_2019-09-06_06-20-24
--------------------------------------------------------------------------------
/snapshots/harmonic/harm_1649_2019-09-06_07-03-37:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seaniezhao/torch_npss/c49dcb97e6fc11b5ac026799fcbed14aa0ed34aa/snapshots/harmonic/harm_1649_2019-09-06_07-03-37
--------------------------------------------------------------------------------
/snapshots/vuv/vuv_1649_2019-09-06_06-04-09:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/seaniezhao/torch_npss/c49dcb97e6fc11b5ac026799fcbed14aa0ed34aa/snapshots/vuv/vuv_1649_2019-09-06_06-04-09
--------------------------------------------------------------------------------
/temp.py:
--------------------------------------------------------------------------------
1 | import re
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 |
5 | file = open('harmonic0_0005.log', 'r')
6 |
7 | text_content = file.read()
8 | print(text_content)
9 |
10 | #train = '(?<=average loss:)\s*\d*\.\d*'
11 | pattern = re.compile(r'(?<=average loss: )\-*\s*\d*\.\d*')
12 | train_loss = np.array(re.findall(pattern, text_content)).astype(np.float32)
13 |
14 | test_pattern = re.compile(r'(?<=test loss: )\-*\s*\d*\.\d*')
15 | test_loss = np.array(re.findall(test_pattern, text_content)).astype(np.float32)
16 | test_loss[test_loss > 1] /= 100000
17 |
18 |
19 | lst_iter = [i for i in range(1650)]
20 |
21 | title = 'weight_decay_loss'
22 | plt.plot(train_loss, '-b', label='train')
23 | plt.plot(test_loss, '-r', label='test')
24 |
25 | plt.xlabel("n epoch")
26 | plt.title(title)
27 |
28 | # save image
29 | plt.savefig(title+".png")  # must be called before plt.show()
30 |
31 | # show
32 | plt.show()
--------------------------------------------------------------------------------
/train_aperoidic.py:
--------------------------------------------------------------------------------
1 | import hparams
2 | from model.wavenet_model import *
3 | from data.dataset import TimbreDataset
4 | from model.timbre_training import *
5 |
6 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
7 | model = WaveNetModel(hparams.create_aperiodic_hparams(), device).to(device)
8 | print('model: ', model)
9 | print('receptive field: ', model.receptive_field)
10 | print('parameter count: ', model.parameter_count())
11 |
12 | trainer = ModelTrainer(model=model,
13 | data_folder='data/timbre_model',
14 | lr=0.0005,
15 | weight_decay=0.0,
16 | snapshot_path='./snapshots/aperiodic',
17 | snapshot_name='aper',
18 | snapshot_interval=50000,
19 | device=device)
20 |
21 | #epoch = trainer.load_checkpoint('/Users/zhaowenxiao/pythonProj/torch_npss/snapshots/aperiodic/chaconne_model_1021_2019-03-30_09-32-23')
22 | print('start training...')
23 | trainer.train(batch_size=32,
24 | epochs=1650)
25 |
--------------------------------------------------------------------------------
/train_harmonoc.py:
--------------------------------------------------------------------------------
1 | import hparams
2 | from model.wavenet_model import *
3 | from model.timbre_training import *
4 |
5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
6 |
7 | model = WaveNetModel(hparams.create_harmonic_hparams(), device).to(device)
8 | print('model: ', model)
9 | print('receptive field: ', model.receptive_field)
10 | print('parameter count: ', model.parameter_count())
11 | trainer = ModelTrainer(model=model,
12 | data_folder='data/timbre_model',
13 | lr=0.0005,
14 | weight_decay=0.0001,
15 | snapshot_path='./snapshots/harmonic',
16 | snapshot_name='harm',
17 | snapshot_interval=2000,
18 | device=device)
19 |
20 |
21 | def exit_handler():
22 | trainer.save_model()
23 |     print("model saved on exit")
24 |
25 |
26 | #atexit.register(exit_handler)
27 |
28 | #epoch = trainer.load_checkpoint('/home/sean/pythonProj/torch_npss/snapshots/harmonic/best_harmonic_model_1649_2019-03-31_17-43-00')
29 |
30 | print('start training...')
31 | trainer.train(batch_size=32,
32 | epochs=1650)
33 |
--------------------------------------------------------------------------------
/train_script.py:
--------------------------------------------------------------------------------
1 | import hparams
2 | from model.wavenet_model import *
3 | from data.dataset import TimbreDataset
4 | from model.timbre_training import *
5 | import atexit
6 |
7 | import os
8 | from model_logging import *
9 | from scipy.io import wavfile
10 | #os.environ["CUDA_VISIBLE_DEVICES"] = "0"
11 |
12 |
13 |
14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15 |
16 |
17 | model = WaveNetModel(hparams.create_harmonic_hparams(), device).to(device)
18 | print('model: ', model)
19 | print('receptive field: ', model.receptive_field)
20 | print('parameter count: ', model.parameter_count())
21 |
22 | trainer = ModelTrainer(model=model,
23 | data_folder='data/timbre_model',
24 | lr=0.0001,
25 | weight_decay=0.0,
26 | snapshot_path='./snapshots/harmonic',
27 | snapshot_name='chaconne_model',
28 | snapshot_interval=2000,
29 | device=device,
30 | temperature=0.05)
31 |
32 |
33 | def exit_handler():
34 | trainer.save_model()
35 |     print("model saved on exit")
36 |
37 |
38 | #atexit.register(exit_handler)
39 |
40 | #epoch = trainer.load_checkpoint('/home/sean/pythonProj/torch_npss/snapshots/harmonic/chaconne_model_930_2019-03-26_06-18-49')
41 |
42 | print('start training...')
43 | trainer.train(batch_size=128,
44 | epochs=1650)
45 |
46 |
47 |
48 |
49 | # model = WaveNetModel(hparams.create_aperiodic_hparams(), device).to(device)
50 | #
51 | # print('model: ', model)
52 | # print('receptive field: ', model.receptive_field)
53 | # print('parameter count: ', model.parameter_count())
54 | #
55 | # data = TimbreDataset(data_folder='data/timbre_model', receptive_field=model.receptive_field, type=1)
56 | #
57 | # print('the dataset has ' + str(len(data)) + ' items')
58 | #
59 | #
60 | #
61 | # trainer = TimbreTrainer(model=model,
62 | # dataset=data,
63 | # lr=0.0005,
64 | # weight_decay=0.0,
65 | # snapshot_path='./snapshots/aperiodic',
66 | # snapshot_name='chaconne_model',
67 | # snapshot_interval=50000,
68 | # device=device)
69 | #
70 | # print('start training...')
71 | # trainer.train(batch_size=32,
72 | # epochs=420)
73 |
74 |
--------------------------------------------------------------------------------
/train_vuv.py:
--------------------------------------------------------------------------------
1 | import hparams
2 | from model.wavenet_model import *
3 | from data.dataset import TimbreDataset
4 | from model.timbre_training import *
5 |
6 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
7 |
8 | model = WaveNetModel(hparams.create_vuv_hparams(), device).to(device)
9 | print('model: ', model)
10 | print('receptive field: ', model.receptive_field)
11 | print('parameter count: ', model.parameter_count())
12 |
13 | trainer = ModelTrainer(model=model,
14 | data_folder='data/timbre_model',
15 | lr=0.0005,
16 | weight_decay=0.0,
17 | snapshot_path='./snapshots/vuv',
18 | snapshot_name='vuv',
19 | snapshot_interval=50000,
20 | device=device)
21 |
22 |
23 | print('start training...')
24 | trainer.train(batch_size=32,
25 | epochs=1650)
26 |
--------------------------------------------------------------------------------