├── models
│   ├── __init__.py
│   ├── one_tier
│   │   ├── wavent.py
│   │   └── one_tier.py
│   └── two_tier
│       ├── two_tier_generate16k.py
│       ├── two_tier_generate32k.py
│       ├── two_tier16k.py
│       └── two_tier32k.py
├── datasets
│   ├── __init__.py
│   ├── music
│   │   ├── drum-preprocess.sh
│   │   ├── prune_flacs.py
│   │   ├── sum_flacs.py
│   │   ├── SNAREdrum-preprocessERRORS.md
│   │   ├── preprocess.sh
│   │   ├── download_archive_preprocess.sh
│   │   ├── _2npy.py
│   │   ├── preprocess.py
│   │   ├── _drum2npy.py
│   │   ├── log_mp3s
│   │   ├── new_experiment16k.py
│   │   ├── new_experiment32k.py
│   │   └── drum-preprocess.py
│   └── dataset.py
├── LICENSE
├── .gitignore
├── clean_results.py
├── lib
│   ├── generate.py
│   └── __init__.py
└── README.md
/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/datasets/music/drum-preprocess.sh:
--------------------------------------------------------------------------------
1 | SCRIPTPATH=$( cd "$(dirname "$0")" ; pwd -P )
2 | echo "Preprocessing"
3 | python drum-preprocess.py "$SCRIPTPATH"
4 | echo "Done!"
5 | 
6 | echo "Writing datasets"
7 | python _drum2npy.py
8 | echo "Done!"
--------------------------------------------------------------------------------
/datasets/music/prune_flacs.py:
--------------------------------------------------------------------------------
 1 | import subprocess
 2 | import sys
 3 | import glob
 4 | 
 5 | DIR = "."
 6 | fs = glob.glob(DIR + "/*.flac")
 7 | # glob already returns paths prefixed with DIR, so probe each path directly
 8 | # and flag any chunk whose duration deviates from the expected 3.762563 s.
 9 | for f in fs:
10 |     size = float(subprocess.check_output('ffprobe -i "{}" -show_entries format=duration -v quiet -of csv="p=0"'.format(f), shell=True))
11 |     if size != 3.762563:
12 |         print f
13 |         print size
--------------------------------------------------------------------------------
/datasets/music/sum_flacs.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import sys
3 | import glob
4 | 
5 | DIR = "."
6 | fs = glob.glob(DIR + "/*.wav")
7 | t = 0
8 | print 'counting...'
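 9 | # Each loop iteration below shells out to ffprobe, which prints just the
10 | # duration in seconds, e.g. (path and output illustrative):
11 | #   $ ffprobe -i ./track0.wav -show_entries format=duration -v quiet -of csv="p=0"
12 | #   3.762563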
13 | for f in fs:
14 |     # paths from glob already include DIR
15 |     size = float(subprocess.check_output('ffprobe -i "{}" -show_entries format=duration -v quiet -of csv="p=0"'.format(f), shell=True))
16 |     t = t + size
17 | print t, 'seconds'
--------------------------------------------------------------------------------
/datasets/music/SNAREdrum-preprocessERRORS.md:
--------------------------------------------------------------------------------
 1 | # SNARE drum-preprocess.py ERRORS
 2 | ## sample-rnn
 3 | ## 4/6/2017
 4 | 
 5 | B
 6 | =
 7 | 
 8 | ./p295d.flac 4.28575
 9 | 
10 | ./p1290d.flac 3.980813
11 | 
12 | ./p1290u.flac 3.980813
13 | 
14 | ./p295.flac 4.28575
15 | 
16 | ./p295u.flac 4.28575
17 | 
18 | ./p1290.flac 3.980813
19 | 
20 | BR
21 | =
22 | 
23 | ./p295d.flac 4.28575
24 | 
25 | ./p1290d.flac 3.980813
26 | 
27 | ./p1290u.flac 3.980813
28 | 
29 | ./p295.flac 4.28575
30 | 
31 | ./p295u.flac 4.28575
32 | 
33 | ./p1290.flac 3.980813
34 | 
35 | FR
36 | ==
37 | 
38 | ./p295d.flac 4.28575
39 | 
40 | ./p1290d.flac 3.980813
41 | 
42 | ./p1290u.flac 3.980813
43 | 
44 | ./p295.flac 4.28575
45 | 
46 | ./p295u.flac 4.28575
47 | 
48 | ./p1290.flac 3.980813
--------------------------------------------------------------------------------
/datasets/music/preprocess.sh:
--------------------------------------------------------------------------------
 1 | # Requires at most 2GB of free disk space.
 2 | SCRIPTPATH=$( cd "$(dirname "$0")" ; pwd -P )
 3 | echo "Converting from OGG to 16 kHz, 16-bit mono-channel WAV"
 4 | # The commented-out line below ends each ffmpeg call with & so conversions run
 5 | # in parallel background subshells; that is not recommended. Drop the & if it
 6 | # causes problems.
 7 | #for file in "$DL_PATH"*_64kb.mp3; do ffmpeg -i "$file" -ar 16000 -ac 1 "$DL_PATH""`basename "$file" _64kb.mp3`.wav" & done
 8 | for file in "$SCRIPTPATH"/*.ogg; do
 9 |     ffmpeg -i "$file" -ar 16000 -ac 1 "$SCRIPTPATH"/"`basename "$file" .ogg`.wav"
10 | done
11 | echo "Cleaning up"
12 | rm "$SCRIPTPATH"/*.ogg
13 | 
14 | echo "Preprocessing"
15 | python preprocess.py "$SCRIPTPATH"
16 | echo "Done!"
17 | 
18 | echo "Writing datasets"
19 | python _2npy.py
20 | echo "Done!"
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2016 Soroush Mehri
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/datasets/music/download_archive_preprocess.sh:
--------------------------------------------------------------------------------
 1 | # Requires at most 2GB of free disk space.
 2 | SCRIPTPATH=$( cd "$(dirname "$0")" ; pwd -P )
 3 | DL_PATH="$SCRIPTPATH"/download/
 4 | mkdir -p "$DL_PATH"
 5 | echo "Downloading files to $DL_PATH"
 6 | # See: https://blog.archive.org/2012/04/26/downloading-in-bulk-using-wget/
 7 | wget -r -H -nc -nH --cut-dirs=1 -A .ogg -R "*_vbr.mp3" -e robots=off -P "$DL_PATH" -l1 -i ./itemlist.txt -B 'http://archive.org/download/'
 8 | echo "Organizing files and folders"
 9 | mv "$DL_PATH"*/*.ogg "$DL_PATH"
10 | rmdir "$DL_PATH"*/
11 | echo "Converting from OGG to 16 kHz, 16-bit mono-channel WAV"
12 | # The commented-out line below ends each ffmpeg call with & so conversions run
13 | # in parallel background subshells; that is not recommended. Drop the & if it
14 | # causes problems.
15 | #for file in "$DL_PATH"*_64kb.mp3; do ffmpeg -i "$file" -ar 16000 -ac 1 "$DL_PATH""`basename "$file" _64kb.mp3`.wav" & done
16 | for file in "$DL_PATH"*.ogg; do
17 |     ffmpeg -i "$file" -ar 16000 -ac 1 "$DL_PATH""`basename "$file" .ogg`.wav"
18 | done
19 | echo "Cleaning up"
20 | rm "$DL_PATH"*.ogg
21 | 
22 | echo "Preprocessing"
23 | python preprocess.py "$DL_PATH"
24 | echo "Done!"
25 | 
26 | echo "Writing datasets"
27 | python _2npy.py
28 | echo "Done!"
--------------------------------------------------------------------------------
/datasets/music/_2npy.py:
--------------------------------------------------------------------------------
 1 | import numpy
 2 | np = numpy
 3 | import scipy.io.wavfile
 4 | import scikits.audiolab
 5 | 
 6 | import random
 7 | import time
 8 | import os
 9 | import glob
10 | 
11 | __RAND_SEED = 123
12 | def __fixed_shuffle(inp_list):
13 |     if isinstance(inp_list, list):
14 |         random.seed(__RAND_SEED)
15 |         random.shuffle(inp_list)
16 |         return
17 |     #import collections
18 |     #if isinstance(inp_list, (collections.Sequence)):
19 |     if isinstance(inp_list, numpy.ndarray):
20 |         numpy.random.seed(__RAND_SEED)
21 |         numpy.random.shuffle(inp_list)
22 |         return
23 |     # destructive operations; in place; no need to return
24 |     raise ValueError("inp_list is neither a list nor a numpy.ndarray but a " + str(type(inp_list)))
25 | 
26 | data_path = os.path.abspath('./download/parts')
27 | print data_path
28 | 
29 | paths = sorted(glob.glob(data_path + "/*.flac"))
30 | __fixed_shuffle(paths)
31 | 
32 | arr = [(scikits.audiolab.flacread(p)[0]).astype('float16') for p in paths]
33 | np_arr = np.array(arr)
34 | 
35 | # BEETHOVEN MUSIC DATASET SPLIT
36 | np.save('all_music.npy', np_arr)
37 | np.save('music_train.npy', np_arr[:-2*256])
38 | np.save('music_valid.npy', np_arr[-2*256:-256])
39 | np.save('music_test.npy', np_arr[-256:])
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | # Byte-compiled / optimized / DLL files
 2 | __pycache__/
 3 | *.py[cod]
 4 | *$py.class
 5 | 
 6 | # C extensions
 7 | *.so
 8 | 
 9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | #lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | 
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 | 
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 | 
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *,cover
46 | .hypothesis/
47 | 
48 | # Translations
49 | *.mo
50 | *.pot
51 | 
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 | 
56 | # Flask stuff:
57 | instance/
58 | .webassets-cache
59 | 
60 | # Scrapy stuff:
61 | .scrapy
62 | 
63 | # Sphinx documentation
64 | docs/_build/
65 | 
66 | # PyBuilder
67 | target/
68 | 
69 | # IPython Notebook
70 | .ipynb_checkpoints
71 | 
72 | # pyenv
73 | .python-version
74 | 
75 | # celery beat schedule file
76 | celerybeat-schedule
77 | 
78 | # dotenv
79 | .env
80 | 
81 | # virtualenv
82 | venv/
83 | ENV/
84 | 
85 | # Spyder project settings
86 | .spyderproject
87 | 
88 | # Rope project settings
89 | .ropeproject
90 | *.wav
91 | datasets/music/rev-preprocess.sh
--------------------------------------------------------------------------------
/clean_results.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | 
 3 | cwd = os.getcwd()
 4 | results_dir = os.path.join(cwd, 'results_2t')
 5 | def get_subdirectories(a_dir):
 6 |     return [name for name in os.listdir(a_dir)
 7 |             if os.path.isdir(os.path.join(a_dir, name))]
 8 | experiments = get_subdirectories(results_dir)
 9 | num_epochs = 5
10 | hit_list = ["params_e" + str(n) + "_" for n in xrange(num_epochs)]
11 | unused_files = []
12 | for e in experiments:
13 |     e_dir = os.path.join(results_dir, e)
14 |     params = os.path.join(e_dir, "params")
15 |     for root, dirs, files in os.walk(params):
16 |         for file in files:
17 |             for hit in hit_list:
18 |                 if file.startswith(hit):
19 |                     print file
20 |                     unused_files.append(os.path.join(root, file))
21 | def prompt_delete(num_prompts):
22 |     num_prompts -= 1
23 |     if num_prompts >= 0:
24 |         # raw_input returns the reply as a string; input() would eval() the
25 |         # reply in Python 2 and crash on a bare Y, so it must not be used here.
26 |         prompt = raw_input("Do you want to delete these " + str(len(unused_files)) + " files? ['Y'/'n'] ")
27 |         if prompt == "Y" or prompt == "yes":
28 |             print 'removing old epochs...'
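29 |             # os.remove deletes permanently (nothing goes to a trash
30 |             # folder), so review the list printed above before confirming.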
31 |             for uf in unused_files:
32 |                 os.remove(uf)
33 |         elif prompt == "n" or prompt == "no":
34 |             print "clean aborted: 0 files deleted"
35 |         else:
36 |             print "warning:", prompt, "is an unknown command"
37 |             prompt_delete(num_prompts)
38 |     else:
39 |         print "0 files deleted: Good-bye"
40 | if len(unused_files) > 0:
41 |     prompt_delete(3)
42 | else:
43 |     print 'found 0 files to clean: Good-bye'
--------------------------------------------------------------------------------
/datasets/music/preprocess.py:
--------------------------------------------------------------------------------
 1 | import os, sys
 2 | import subprocess
 3 | 
 4 | RAW_DATA_DIR = str(sys.argv[1])
 5 | OUTPUT_DIR = os.path.join(RAW_DATA_DIR, "parts")
 6 | os.makedirs(OUTPUT_DIR)
 7 | print RAW_DATA_DIR
 8 | print OUTPUT_DIR
 9 | 
10 | # Step 1: write all filenames to a list
11 | with open(os.path.join(OUTPUT_DIR, 'preprocess_file_list.txt'), 'w') as f:
12 |     for dirpath, dirnames, filenames in os.walk(RAW_DATA_DIR):
13 |         for filename in filenames:
14 |             if filename.endswith(".wav"):
15 |                 f.write("file '" + dirpath + '/' + filename + "'\n")
16 | 
17 | # Step 2: concatenate everything into one massive wav file
18 | os.system("ffmpeg -f concat -safe 0 -i {}/preprocess_file_list.txt {}/preprocess_all_audio.wav".format(OUTPUT_DIR, OUTPUT_DIR))
19 | audio = "preprocess_all_audio.wav"
20 | # get the length of the resulting file
21 | length = float(subprocess.check_output('ffprobe -i {}/{} -show_entries format=duration -v quiet -of csv="p=0"'.format(OUTPUT_DIR, audio), shell=True))
22 | print length, "DURATION"
23 | # optionally reverse the audio file; sys.argv values are strings, so compare
24 | # against the string 'True' (comparing against the boolean True never matches)
25 | if len(sys.argv) > 2 and sys.argv[2] == 'True':
26 |     os.system("sox {}/preprocess_all_audio.wav {}/reverse_preprocess_audio.wav reverse".format(OUTPUT_DIR, OUTPUT_DIR))
27 |     audio = "reverse_preprocess_audio.wav"
28 | # Step 3: split the big file into 8-second chunks (consecutive chunks start
29 | # 1 second apart, so they overlap heavily)
30 | for i in xrange((int(length)//8 - 1)/3):
31 |     os.system('ffmpeg -ss {} -t 8 -i {}/{} -ac 1 -ab 16k -ar 16000 {}/p{}.flac'.format(i, OUTPUT_DIR, audio, OUTPUT_DIR, i))
32 | 
33 | # Step 4: clean up temp files
34 | #os.system('rm {}/preprocess_all_audio.wav'.format(OUTPUT_DIR))
35 | os.system('rm {}/preprocess_file_list.txt'.format(OUTPUT_DIR))
--------------------------------------------------------------------------------
/datasets/music/_drum2npy.py:
--------------------------------------------------------------------------------
 1 | import numpy
 2 | np = numpy
 3 | import scipy.io.wavfile
 4 | import scikits.audiolab
 5 | 
 6 | import random
 7 | import time
 8 | import os
 9 | import glob
10 | 
11 | __RAND_SEED = 123
12 | def __fixed_shuffle(inp_list):
13 |     if isinstance(inp_list, list):
14 |         random.seed(__RAND_SEED)
15 |         random.shuffle(inp_list)
16 |         return
17 |     #import collections
18 |     #if isinstance(inp_list, (collections.Sequence)):
19 |     if isinstance(inp_list, numpy.ndarray):
20 |         numpy.random.seed(__RAND_SEED)
21 |         numpy.random.shuffle(inp_list)
22 |         return
23 |     # destructive operations; in place; no need to return
24 |     raise ValueError("inp_list is neither a list nor a numpy.ndarray but a " + str(type(inp_list)))
25 | 
26 | data_paths = [
27 |     (os.path.abspath('./fr-parts')),
28 |     (os.path.abspath('./br-parts')),
29 |     (os.path.abspath('./f-parts')),
30 |     (os.path.abspath('./b-parts'))
31 | ]
32 | 
33 | for dp in data_paths:
34 |     paths = sorted(glob.glob(dp + "/*.flac"))
35 |     __fixed_shuffle(paths)
36 |     arr = [(scikits.audiolab.flacread(p)[0]).astype('float16') for p in paths]
37 |     np_arr = np.array(arr)
38 |     print np_arr.shape
39 |     """ BEETHOVEN MUSIC SPLIT
40 |     np.save('all_music.npy', np_arr)
41 | np.save('music_train.npy', np_arr[:-2*256]) 42 | np.save('music_valid.npy', np_arr[-2*256:-256]) 43 | np.save('music_test.npy', np_arr[-256:]) 44 | """ 45 | # 88/6/6 split 46 | length = len(np_arr) 47 | train_size = int(np.floor(length * .88)) # train 48 | test_size = int(np.floor(length * .06)) # test 49 | np.save(dp+'all_drums.npy', np_arr) 50 | np.save(dp+'drums_train.npy', np_arr[:train_size]) 51 | np.save(dp+'drums_valid.npy', np_arr[train_size:train_size + test_size]) 52 | np.save(dp+'drums_test.npy', np_arr[train_size + test_size:]) 53 | -------------------------------------------------------------------------------- /datasets/music/log_mp3s: -------------------------------------------------------------------------------- 1 | download$ for f in *; do ffmpeg -i $f 2>&1 | grep Duration; done 2 | Duration: 00:22:18.52, start: 0.000000, bitrate: 320 kb/s 3 | Duration: 00:15:13.07, start: 0.000000, bitrate: 320 kb/s 4 | Duration: 00:13:44.23, start: 0.000000, bitrate: 320 kb/s 5 | Duration: 00:21:17.55, start: 0.000000, bitrate: 320 kb/s 6 | Duration: 00:24:03.82, start: 0.000000, bitrate: 320 kb/s 7 | Duration: 00:23:00.14, start: 0.000000, bitrate: 320 kb/s 8 | Duration: 00:21:24.58, start: 0.000000, bitrate: 320 kb/s 9 | Duration: 00:07:09.15, start: 0.000000, bitrate: 320 kb/s 10 | Duration: 00:07:20.90, start: 0.000000, bitrate: 320 kb/s 11 | Duration: 00:09:58.42, start: 0.000000, bitrate: 320 kb/s 12 | Duration: 00:10:17.88, start: 0.000000, bitrate: 320 kb/s 13 | Duration: 00:22:07.47, start: 0.000000, bitrate: 320 kb/s 14 | Duration: 00:09:47.16, start: 0.000000, bitrate: 320 kb/s 15 | Duration: 00:08:31.91, start: 0.000000, bitrate: 320 kb/s 16 | Duration: 00:07:00.63, start: 0.000000, bitrate: 320 kb/s 17 | Duration: 00:12:31.47, start: 0.000000, bitrate: 320 kb/s 18 | Duration: 00:19:19.51, start: 0.000000, bitrate: 320 kb/s 19 | Duration: 00:40:38.57, start: 0.000000, bitrate: 320 kb/s 20 | Duration: 00:26:01.98, start: 0.000000, bitrate: 320 kb/s 21 | Duration: 00:13:57.26, start: 0.000000, bitrate: 320 kb/s 22 | Duration: 00:16:23.42, start: 0.000000, bitrate: 320 kb/s 23 | Duration: 00:24:17.95, start: 0.025057, bitrate: 137 kb/s 24 | Duration: 00:17:26.14, start: 0.000000, bitrate: 320 kb/s 25 | Duration: 00:23:03.66, start: 0.000000, bitrate: 320 kb/s 26 | Duration: 00:20:31.32, start: 0.000000, bitrate: 320 kb/s 27 | Duration: 00:18:35.52, start: 0.000000, bitrate: 320 kb/s 28 | Duration: 00:25:45.52, start: 0.000000, bitrate: 320 kb/s 29 | Duration: 00:27:36.38, start: 0.000000, bitrate: 320 kb/s 30 | Duration: 00:16:26.45, start: 0.000000, bitrate: 320 kb/s 31 | Duration: 00:11:07.99, start: 0.000000, bitrate: 320 kb/s 32 | Duration: 00:24:12.24, start: 0.000000, bitrate: 320 kb/s 33 | Duration: 00:18:32.30, start: 0.000000, bitrate: 320 kb/s 34 | 35 | 560 minutes total 36 | -------------------------------------------------------------------------------- /datasets/music/new_experiment16k.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys, os, subprocess, scikits.audiolab, random, time, glob 3 | 4 | PWD = os.getcwd() 5 | print 'PWD is', PWD 6 | #store dataset name 7 | DATASET_NAME = str(sys.argv[1]) 8 | DOWNLOAD_DIR = str(sys.argv[2]) 9 | print 'dl_dir is set to', DOWNLOAD_DIR 10 | #create the 11 | print "creating directory for", DATASET_NAME 12 | DATASET_DIR = os.path.join(PWD, DATASET_NAME) 13 | os.makedirs(DATASET_DIR) 14 | #move samples from directory to use dataset name 15 | print 
"moving samples" 16 | types = {'wav', "mp3"} 17 | for t in types: 18 | os.system('mv {}/*.{} {}/'.format(DOWNLOAD_DIR, t, DATASET_DIR)) 19 | #run proprocess 20 | print "preprocessing" 21 | OUTPUT_DIR=os.path.join(DATASET_DIR, "parts") 22 | os.makedirs(OUTPUT_DIR) 23 | # Step 1: write all filenames to a list 24 | with open(os.path.join(DATASET_DIR, 'preprocess_file_list.txt'), 'w') as f: 25 | for dirpath, dirnames, filenames in os.walk(DATASET_DIR): 26 | for filename in filenames: 27 | if filename.endswith(".wav") or filename.endswith("mp3"): 28 | f.write("file '" + dirpath + '/'+ filename + "'\n") 29 | 30 | # Step 2: concatenate everything into one massive wav file 31 | print "concatenate all files" 32 | os.system('pwd') 33 | os.system("ffmpeg -f concat -safe 0 -i {}/preprocess_file_list.txt {}/preprocess_all_audio.wav".format(DATASET_DIR, OUTPUT_DIR)) 34 | audio = "preprocess_all_audio.wav" 35 | print "get length" 36 | # # get the length of the resulting file 37 | length = float(subprocess.check_output('ffprobe -i {}/{} -show_entries format=duration -v quiet -of csv="p=0"'.format(OUTPUT_DIR, audio), shell=True)) 38 | print length, "DURATION" 39 | print "print big file into chunks" 40 | # # Step 3: split the big file into 8-second chunks 41 | # overlapping 3 times per 8 seconds 42 | ''' 43 | for i in xrange(int((length//8)*3)-1): 44 | time = (i * 8 )/ 3 45 | os.system('ffmpeg -ss {} -t 8 -i {}/preprocess_all_audio.wav -ac 1 -ab 16k -ar 16000 {}/p{}.flac'.format(time, OUTPUT_DIR, OUTPUT_DIR, i)) 46 | ''' 47 | size = 8 48 | num = 3200 49 | for i in xrange(0, num): 50 | time = i * ((length-size)/float(num)) 51 | os.system('ffmpeg -ss {} -t 8 -i {}/preprocess_all_audio.wav -ac 1 -ab 16k -ar 16000 {}/p{}.flac'.format(time, OUTPUT_DIR, OUTPUT_DIR, i)) 52 | print "clean up" 53 | # # Step 4: clean up temp files 54 | os.system('rm {}/preprocess_all_audio.wav'.format(OUTPUT_DIR)) 55 | os.system('rm {}/preprocess_file_list.txt'.format(DATASET_DIR)) 56 | print 'save as .npy' 57 | __RAND_SEED = 123 58 | def __fixed_shuffle(inp_list): 59 | if isinstance(inp_list, list): 60 | random.seed(__RAND_SEED) 61 | random.shuffle(inp_list) 62 | return 63 | #import collections 64 | #if isinstance(inp_list, (collections.Sequence)): 65 | if isinstance(inp_list, numpy.ndarray): 66 | numpy.random.seed(__RAND_SEED) 67 | numpy.random.shuffle(inp_list) 68 | return 69 | # destructive operations; in place; no need to return 70 | raise ValueError("inp_list is neither a list nor a numpy.ndarray but a "+type(inp_list)) 71 | 72 | paths = sorted(glob.glob(OUTPUT_DIR+"/*.flac")) 73 | __fixed_shuffle(paths) 74 | 75 | arr = [(scikits.audiolab.flacread(p)[0]).astype('float16') for p in paths] 76 | np_arr = np.array(arr) 77 | # 88/6/6 split 78 | length = len(np_arr) 79 | train_size = int(np.floor(length * .88)) # train 80 | test_size = int(np.floor(length * .06)) # test 81 | 82 | np.save(os.path.join(DATASET_DIR,'all_music.npy'), np_arr) 83 | np.save(os.path.join(DATASET_DIR,'music_train.npy'), np_arr[:train_size]) 84 | np.save(os.path.join(DATASET_DIR,'music_valid.npy'), np_arr[train_size:train_size + test_size]) 85 | np.save(os.path.join(DATASET_DIR,'music_test.npy'), np_arr[train_size + test_size:]) 86 | 87 | #pass dataset name through two_tier.py || three_tier.py to datasets.py -------------------------------------------------------------------------------- /datasets/music/new_experiment32k.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys, os, subprocess, 
scikits.audiolab, random, time, glob 3 | 4 | PWD = os.getcwd() 5 | print 'PWD is', PWD 6 | #store dataset name 7 | DATASET_NAME = str(sys.argv[1]) 8 | DOWNLOAD_DIR = str(sys.argv[2]) 9 | print 'dl_dir is set to', DOWNLOAD_DIR 10 | #create the 11 | print "creating directory for", DATASET_NAME 12 | DATASET_DIR = os.path.join(PWD, DATASET_NAME) 13 | os.makedirs(DATASET_DIR) 14 | #move samples from directory to use dataset name 15 | print "moving samples" 16 | types = {'wav', "mp3"} 17 | for t in types: 18 | os.system('mv {}/*.{} {}/'.format(DOWNLOAD_DIR, t, DATASET_DIR)) 19 | #run proprocess 20 | print "preprocessing" 21 | OUTPUT_DIR=os.path.join(DATASET_DIR, "parts") 22 | os.makedirs(OUTPUT_DIR) 23 | # Step 1: write all filenames to a list 24 | with open(os.path.join(DATASET_DIR, 'preprocess_file_list.txt'), 'w') as f: 25 | for dirpath, dirnames, filenames in os.walk(DATASET_DIR): 26 | for filename in filenames: 27 | if filename.endswith(".wav") or filename.endswith("mp3"): 28 | f.write("file '" + dirpath + '/'+ filename + "'\n") 29 | 30 | # Step 2: concatenate everything into one massive wav file 31 | print "concatenate all files" 32 | os.system('pwd') 33 | os.system("ffmpeg -f concat -safe 0 -i {}/preprocess_file_list.txt {}/preprocess_all_audio.wav".format(DATASET_DIR, OUTPUT_DIR)) 34 | audio = "preprocess_all_audio.wav" 35 | print "get length" 36 | # # get the length of the resulting file 37 | length = float(subprocess.check_output('ffprobe -i {}/{} -show_entries format=duration -v quiet -of csv="p=0"'.format(OUTPUT_DIR, audio), shell=True)) 38 | print length, "DURATION" 39 | print "print big file into chunks" 40 | # # Step 3: split the big file into 8-second chunks 41 | # overlapping 3 times per 8 seconds 42 | ''' 43 | for i in xrange(int((length//8)*3)-1): 44 | time = (i * 8 )/ 3 45 | os.system('ffmpeg -ss {} -t 8 -i {}/preprocess_all_audio.wav -ac 1 -ab 16k -ar 16000 {}/p{}.flac'.format(time, OUTPUT_DIR, OUTPUT_DIR, i)) 46 | ''' 47 | size = 8 48 | num = 3200 49 | for i in xrange(0, num): 50 | time = i * ((length-size)/float(num)) 51 | os.system('ffmpeg -ss {} -t 8 -i {}/preprocess_all_audio.wav -ac 1 -ab 16k -ar 32000 {}/p{}.flac'.format(time, OUTPUT_DIR, OUTPUT_DIR, i)) 52 | print "clean up" 53 | # # Step 4: clean up temp files 54 | os.system('rm {}/preprocess_all_audio.wav'.format(OUTPUT_DIR)) 55 | os.system('rm {}/preprocess_file_list.txt'.format(DATASET_DIR)) 56 | print 'save as .npy' 57 | __RAND_SEED = 123 58 | def __fixed_shuffle(inp_list): 59 | if isinstance(inp_list, list): 60 | random.seed(__RAND_SEED) 61 | random.shuffle(inp_list) 62 | return 63 | #import collections 64 | #if isinstance(inp_list, (collections.Sequence)): 65 | if isinstance(inp_list, numpy.ndarray): 66 | numpy.random.seed(__RAND_SEED) 67 | numpy.random.shuffle(inp_list) 68 | return 69 | # destructive operations; in place; no need to return 70 | raise ValueError("inp_list is neither a list nor a numpy.ndarray but a "+type(inp_list)) 71 | 72 | paths = sorted(glob.glob(OUTPUT_DIR+"/*.flac")) 73 | __fixed_shuffle(paths) 74 | 75 | arr = [(scikits.audiolab.flacread(p)[0]).astype('float16') for p in paths] 76 | np_arr = np.array(arr) 77 | # 88/6/6 split 78 | length = len(np_arr) 79 | train_size = int(np.floor(length * .88)) # train 80 | test_size = int(np.floor(length * .06)) # test 81 | 82 | np.save(os.path.join(DATASET_DIR,'all_music.npy'), np_arr) 83 | np.save(os.path.join(DATASET_DIR,'music_train.npy'), np_arr[:train_size]) 84 | np.save(os.path.join(DATASET_DIR,'music_valid.npy'), np_arr[train_size:train_size + 
test_size]) 85 | np.save(os.path.join(DATASET_DIR,'music_test.npy'), np_arr[train_size + test_size:]) 86 | 87 | #pass dataset name through two_tier.py || three_tier.py to datasets.py -------------------------------------------------------------------------------- /lib/generate.py: -------------------------------------------------------------------------------- 1 | import os 2 | from time import time 3 | import scipy.io.wavfile 4 | import glob 5 | import sys 6 | import numpy 7 | import pickle 8 | import theano 9 | import theano.tensor as T 10 | 11 | tag = sys.argv[1] 12 | name = glob.glob("../results*/" + tag + "/args.pkl")[0] 13 | params = pickle.load(open(name, "r")) 14 | print params 15 | info = {} 16 | for p in xrange(1,len(params),2): 17 | if p+1 < len(params): 18 | info[params[p][2:]] = params[p+1] 19 | print info 20 | #exit() 21 | 22 | Q_TYPE = info["q_type"] 23 | Q_LEVELS = int(info["q_levels"]) 24 | N_RNN = int(info["n_rnn"]) 25 | DIM = int(info["dim"]) 26 | FRAME_SIZE = int(info["frame_size"]) 27 | 28 | 29 | #{'dim': '1024', 'q_type': 'linear', 'learn_h0': 'True', 'weight_norm': 'True', 'q_levels': '256', 'skip_conn': 'False', 'batch_size': '128', 'n_frames': '64', 'emb_size': '256', 'exp': 'KURT2x4', 'frame_size': '16', 'which_set': 'KURT', 'rnn_type': 'GRU', 'n_rnn': '4'} 30 | 31 | ###grab this stuff 32 | #args 33 | #Q_TYPE 34 | #Q_TEVELS 35 | #N_RNN 36 | #DIM 37 | #FRAME_SIZE 38 | 39 | BITRATE = 16000 40 | N_SEQS = 20 # Number of samples to generate every time monitoring. 41 | Q_ZERO = numpy.int32(Q_LEVELS//2) # Discrete value correponding to zero amplitude 42 | H0_MULT = 1 43 | 44 | RESULTS_DIR = 'results_2t' 45 | RESULTS_DIR = name.split("/")[1] 46 | print RESULTS_DIR 47 | 48 | FOLDER_PREFIX = os.path.join(RESULTS_DIR, tag) 49 | ### Create directories ### 50 | # FOLDER_PREFIX: root, contains: 51 | # log.txt, __note.txt, train_log.pkl, train_log.png [, model_settings.txt] 52 | # FOLDER_PREFIX/samples: keeps all checkpoint samples as wav 53 | SAMPLES_PATH = os.path.join(FOLDER_PREFIX, 'samples') 54 | 55 | print SAMPLES_PATH 56 | # Uniform [-0.5, 0.5) for half of initial state for generated samples 57 | # to study the behaviour of the model and also to introduce some diversity 58 | # to samples in a simple way. 
[it's disabled for now] 59 | sequences = T.imatrix('sequences') 60 | h0 = T.tensor3('h0') 61 | reset = T.iscalar('reset') 62 | mask = T.matrix('mask') 63 | fixed_rand_h0 = numpy.random.rand(N_SEQS//2, N_RNN, H0_MULT*DIM) 64 | fixed_rand_h0 -= 0.5 65 | fixed_rand_h0 = fixed_rand_h0.astype('float32') 66 | 67 | def generate_and_save_samples(): 68 | # Sampling at frame level 69 | frame_level_generate_fn = theano.function( 70 | [sequences, h0, reset], 71 | frame_level_rnn(sequences, h0, reset), 72 | on_unused_input='warn' 73 | ) 74 | def write_audio_file(name, data): 75 | data = data.astype('float32') 76 | data -= data.min() 77 | data /= data.max() 78 | data -= 0.5 79 | data *= 0.95 80 | scipy.io.wavfile.write( 81 | os.path.join(SAMPLES_PATH, name+'.wav'), 82 | BITRATE, 83 | data) 84 | 85 | total_time = time() 86 | # Generate N_SEQS' sample files, each 5 seconds long 87 | N_SECS = 5 88 | LENGTH = N_SECS*BITRATE 89 | 90 | samples = numpy.zeros((N_SEQS, LENGTH), dtype='int32') 91 | samples[:, :FRAME_SIZE] = Q_ZERO 92 | 93 | # First half zero, others fixed random at each checkpoint 94 | h0 = numpy.zeros( 95 | (N_SEQS-fixed_rand_h0.shape[0], N_RNN, H0_MULT*DIM), 96 | dtype='float32' 97 | ) 98 | h0 = numpy.concatenate((h0, fixed_rand_h0), axis=0) 99 | frame_level_outputs = None 100 | 101 | for t in xrange(FRAME_SIZE, LENGTH): 102 | 103 | if t % FRAME_SIZE == 0: 104 | frame_level_outputs, h0 = frame_level_generate_fn( 105 | samples[:, t-FRAME_SIZE:t], 106 | h0, 107 | #numpy.full((N_SEQS, ), (t == FRAME_SIZE), dtype='int32'), 108 | numpy.int32(t == FRAME_SIZE) 109 | ) 110 | 111 | samples[:, t] = sample_level_generate_fn( 112 | frame_level_outputs[:, t % FRAME_SIZE], 113 | samples[:, t-FRAME_SIZE:t], 114 | ) 115 | 116 | total_time = time() - total_time 117 | log = "{} samples of {} seconds length generated in {} seconds." 
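    # e.g. "20 samples of 5 seconds length generated in 312.7 seconds." (timing illustrative)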
118 | log = log.format(N_SEQS, N_SECS, total_time) 119 | print log, 120 | 121 | for i in xrange(N_SEQS): 122 | samp = samples[i] 123 | if Q_TYPE == 'mu-law': 124 | from datasets.dataset import mu2linear 125 | samp = mu2linear(samp) 126 | elif Q_TYPE == 'a-law': 127 | raise NotImplementedError('a-law is not implemented') 128 | write_audio_file("sample_{}_{}".format(tag, i), samp) 129 | 130 | generate_and_save_samples() -------------------------------------------------------------------------------- /datasets/music/drum-preprocess.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import subprocess 3 | # requires sox, ffmpeg, and ffprobe command line tools 4 | 5 | RAW_DATA_DIR=str(sys.argv[1]) 6 | TEMP_DIR=os.path.join(RAW_DATA_DIR, "temp") 7 | FR_DIR=os.path.join(RAW_DATA_DIR, "fr-parts") 8 | BR_DIR=os.path.join(RAW_DATA_DIR, "br-parts") 9 | F_DIR=os.path.join(RAW_DATA_DIR, "f-parts") 10 | B_DIR=os.path.join(RAW_DATA_DIR, "b-parts") 11 | SAMPLE_RATE = 16000 12 | os.makedirs(TEMP_DIR) 13 | os.makedirs(FR_DIR) 14 | os.makedirs(BR_DIR) 15 | os.makedirs(F_DIR) 16 | os.makedirs(B_DIR) 17 | 18 | def createParts(): 19 | def renderFlacs(fr, br, f, b): 20 | os.system('ffmpeg -i {}/{}_temp.wav -ac 1 -ab 16k -ar {} {}/p{}.flac'.format(TEMP_DIR, fr, SAMPLE_RATE, FR_DIR, i))#convert part to flac 21 | os.system('ffmpeg -i {}/{}_temp.wav -ac 1 -ab 16k -ar {} {}/p{}.flac'.format(TEMP_DIR, br, SAMPLE_RATE, BR_DIR, i))#convert part to flac 22 | os.system('ffmpeg -i {}/{}_temp.wav -ac 1 -ab 16k -ar {} {}/p{}.flac'.format(TEMP_DIR, f, SAMPLE_RATE, F_DIR, i))#convert part to flac 23 | os.system('ffmpeg -i {}/{}_temp.wav -ac 1 -ab 16k -ar {} {}/p{}.flac'.format(TEMP_DIR, b, SAMPLE_RATE, B_DIR, i))#convert part to flac 24 | #pitch down 25 | os.system('ffmpeg -i {}/{}_down.wav -ac 1 -ab 16k -ar {} {}/p{}d.flac'.format(TEMP_DIR, fr, SAMPLE_RATE, FR_DIR, i))#convert part to flac 26 | os.system('ffmpeg -i {}/{}_down.wav -ac 1 -ab 16k -ar {} {}/p{}d.flac'.format(TEMP_DIR, br, SAMPLE_RATE, BR_DIR, i))#convert part to flac 27 | os.system('ffmpeg -i {}/{}_down.wav -ac 1 -ab 16k -ar {} {}/p{}d.flac'.format(TEMP_DIR, f, SAMPLE_RATE, F_DIR, i))#convert part to flac 28 | os.system('ffmpeg -i {}/{}_down.wav -ac 1 -ab 16k -ar {} {}/p{}d.flac'.format(TEMP_DIR, b, SAMPLE_RATE, B_DIR, i))#convert part to flac 29 | #pitch up 30 | os.system('ffmpeg -i {}/{}_up.wav -ac 1 -ab 16k -ar {} {}/p{}u.flac'.format(TEMP_DIR, fr, SAMPLE_RATE, FR_DIR, i))#convert part to flac 31 | os.system('ffmpeg -i {}/{}_up.wav -ac 1 -ab 16k -ar {} {}/p{}u.flac'.format(TEMP_DIR, br, SAMPLE_RATE, BR_DIR, i))#convert part to flac 32 | os.system('ffmpeg -i {}/{}_up.wav -ac 1 -ab 16k -ar {} {}/p{}u.flac'.format(TEMP_DIR, f, SAMPLE_RATE, F_DIR, i))#convert part to flac 33 | os.system('ffmpeg -i {}/{}_up.wav -ac 1 -ab 16k -ar {} {}/p{}u.flac'.format(TEMP_DIR, b, SAMPLE_RATE, B_DIR, i))#convert part to flac 34 | #initial preparation 35 | os.system('ffmpeg -i "{}" -ac 1 -ab 16k -ar {} {}/this_temp.wav'.format(full_name, SAMPLE_RATE, TEMP_DIR)) #resample this file as mono 16000smpls/s 36 | this_length = float(subprocess.check_output('ffprobe -i {}/this_temp.wav -show_entries format=duration -v quiet -of csv="p=0"'.format(TEMP_DIR), shell=True)) #check length of resampled audio 37 | print full_name, ':', this_length, 'DURATION' 38 | pad_length = longest_length - this_length 39 | os.system('sox {}/this_temp.wav {}/r_temp.wav reverse'.format(TEMP_DIR, TEMP_DIR)) # reverse file 40 | if pad_length > 0.: # every 
audiofile except the largest 41 | #create temp files 42 | os.system('ffmpeg -f lavfi -i anullsrc=channel_layout=mono:sample_rate={} -t {} {}/anullsrc_temp.wav'.format(SAMPLE_RATE, pad_length, TEMP_DIR)) #create anullsrc_temp.wav zero-pad 43 | os.system('sox {}/anullsrc_temp.wav {}/r_temp.wav {}/fr_temp.wav'.format(TEMP_DIR, TEMP_DIR, TEMP_DIR)) #FR 44 | os.system('sox {}/r_temp.wav {}/anullsrc_temp.wav {}/br_temp.wav'.format(TEMP_DIR, TEMP_DIR, TEMP_DIR)) #BR 45 | os.system('sox {}/anullsrc_temp.wav {}/this_temp.wav {}/f_temp.wav'.format(TEMP_DIR, TEMP_DIR, TEMP_DIR)) #F 46 | os.system('sox {}/this_temp.wav {}/anullsrc_temp.wav {}/b_temp.wav'.format(TEMP_DIR, TEMP_DIR, TEMP_DIR)) #B 47 | # extend the data set by copying and repitching each sample up+down 1 semitone 48 | os.system('sox {}/fr_temp.wav {}/fr_down.wav pitch -100'.format(TEMP_DIR, TEMP_DIR))#FR down 49 | os.system('sox {}/br_temp.wav {}/br_down.wav pitch -100'.format(TEMP_DIR, TEMP_DIR))#BR down 50 | os.system('sox {}/f_temp.wav {}/f_down.wav pitch -100'.format(TEMP_DIR, TEMP_DIR))#F down 51 | os.system('sox {}/b_temp.wav {}/b_down.wav pitch -100'.format(TEMP_DIR, TEMP_DIR))#B down 52 | os.system('sox {}/fr_temp.wav {}/fr_up.wav pitch 100'.format(TEMP_DIR, TEMP_DIR))#FR up 53 | os.system('sox {}/br_temp.wav {}/br_up.wav pitch 100'.format(TEMP_DIR, TEMP_DIR))#BR up 54 | os.system('sox {}/f_temp.wav {}/f_up.wav pitch 100'.format(TEMP_DIR, TEMP_DIR))#F up 55 | os.system('sox {}/b_temp.wav {}/b_up.wav pitch 100'.format(TEMP_DIR, TEMP_DIR))#D up 56 | #final export 57 | renderFlacs('fr', 'br', 'f', 'b') #render parts 58 | #clean up temp files 59 | os.system('rm {}/anullsrc_temp.wav'.format(TEMP_DIR)) 60 | os.system('rm {}/fr_down.wav'.format(TEMP_DIR)) 61 | os.system('rm {}/br_down.wav'.format(TEMP_DIR)) 62 | os.system('rm {}/f_down.wav'.format(TEMP_DIR)) 63 | os.system('rm {}/b_down.wav'.format(TEMP_DIR)) 64 | os.system('rm {}/fr_up.wav'.format(TEMP_DIR)) 65 | os.system('rm {}/br_up.wav'.format(TEMP_DIR)) 66 | os.system('rm {}/f_up.wav'.format(TEMP_DIR)) 67 | os.system('rm {}/b_up.wav'.format(TEMP_DIR)) 68 | else: #longest file 69 | # extend the data set by copying and repitching each sample up+down 1 semitone 70 | os.system('sox {}/this_temp.wav {}/r_up.wav pitch 100'.format(TEMP_DIR, TEMP_DIR))# up 71 | os.system('sox {}/this_temp.wav {}/r_down.wav pitch -100'.format(TEMP_DIR, TEMP_DIR))# down 72 | os.system('sox {}/r_temp.wav {}/this_up.wav pitch 100'.format(TEMP_DIR, TEMP_DIR))#r up 73 | os.system('sox {}/r_temp.wav {}/this_down.wav pitch -100'.format(TEMP_DIR, TEMP_DIR))#r down 74 | # final export 75 | renderFlacs('r', 'r', 'this', 'this') 76 | #clean up temp files 77 | os.system('rm {}/r_up.wav'.format(TEMP_DIR)) 78 | os.system('rm {}/r_down.wav'.format(TEMP_DIR)) 79 | os.system('rm {}/this_up.wav'.format(TEMP_DIR)) 80 | os.system('rm {}/this_down.wav'.format(TEMP_DIR)) 81 | os.system('rm {}/r_temp.wav'.format(TEMP_DIR)) 82 | os.system('rm {}/this_temp.wav'.format(TEMP_DIR)) 83 | 84 | # Step 1: Find the largest file size in the audio dataset 85 | objects = os.listdir(RAW_DATA_DIR) 86 | sofar = 0 87 | largest = "" 88 | for item in objects: 89 | if ".wav" in item: 90 | size = os.path.getsize(item) 91 | if size > sofar: 92 | sofar = size 93 | largest = item 94 | 95 | print "Largest file is ", sofar 96 | print largest 97 | os.system('ffmpeg -i "{}" -ac 1 -ab 16k -ar {} {}/longest_temp.wav'.format(largest, SAMPLE_RATE, TEMP_DIR)) #resample the largest file as mono 98 | longest_length = float(subprocess.check_output('ffprobe -i 
{}/longest_temp.wav -show_entries format=duration -v quiet -of csv="p=0"'.format(TEMP_DIR), shell=True))
 99 | #clean up longest temp wav
100 | os.system('rm {}/longest_temp.wav'.format(TEMP_DIR))
101 | 
102 | i = 0
103 | for dirpath, dirnames, filenames in os.walk(RAW_DATA_DIR):
104 |     for filename in filenames:
105 |         if filename.endswith(".wav"):
106 |             full_name = dirpath + '/' + filename  # raw audio file
107 |             createParts()
108 |             i += 1
109 | #remove empty temp dir
110 | #os.system('rmdir {}'.format(TEMP_DIR))
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # SampleRNN (ZVK fork)
 2 | 
 3 | Code accompanying the paper [SampleRNN: An Unconditional End-to-End Neural Audio Generation Model](https://openreview.net/forum?id=SkxKPDv5xl). Samples are available [here](https://soundcloud.com/samplernn/sets).
 4 | 
 5 | ## Features of the ZVK fork:
 6 | 
 7 | - auto-preprocessing (audio conversion, concatenation, chunking, and saving .npy files)
 8 | - generate scripts for trained datasets
 9 | - scripts for different sample rates are available (16k, 32k)
10 | - any processed dataset can be loaded into the two-tier network via arguments
11 | - samples are drawn from the predicted distribution (not the argmax)
12 | - any number of RNN layers is possible (until you run out of memory)
13 | 
14 | ## Dependencies
15 | - cuDNN 5105
16 | - Python 2.7.12
17 | - Numpy 1.11.1
18 | - Theano 0.9.0rc3 or 1.0
19 | - Lasagne 0.2.dev1
20 | 
21 | ## Datasets
22 | To preprocess audio for a new 32k experiment, place your audio here:
23 | ```
24 | datasets/music/downloads/
25 | ```
26 | then run the new-experiment Python script located in the datasets/music directory:
27 | 
28 | ```
29 | cd datasets/music/
30 | python new_experiment32k.py your_datasets_name downloads/
31 | ```
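32 | 
33 | After preprocessing, the dataset directory holds the 8-second FLAC chunks and the four `.npy` splits that training loads (an illustrative layout; the directory name follows the command above):
34 | ```
35 | datasets/music/your_datasets_name/
36 | ├── parts/            # p0.flac ... 8-second FLAC chunks
37 | ├── all_music.npy
38 | ├── music_train.npy   # 88% of chunks
39 | ├── music_valid.npy   # 6% of chunks
40 | └── music_test.npy    # 6% of chunks
41 | ```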
42 | 
43 | ## Training
44 | To train a model on an existing dataset with GPU acceleration, run the following commands from the root of the `sampleRNN_ICLR2017` folder; the flags shown correspond to the best set of hyper-parameters found.
45 | 
46 | Mission control center:
47 | ```
48 | $ pwd
49 | /root/zvk/sampleRNN_ICLR2017
50 | ```
51 | ### SampleRNN (2-tier)
52 | ```
53 | $ python models/two_tier/two_tier32k.py -h
54 | usage: two_tier.py [-h] [--exp EXP] --n_frames N_FRAMES --frame_size
55 |                    FRAME_SIZE --weight_norm WEIGHT_NORM --emb_size EMB_SIZE
56 |                    --skip_conn SKIP_CONN --dim DIM --n_rnn {1,2,3,4,5}
57 |                    --rnn_type {LSTM,GRU} --learn_h0 LEARN_H0 --q_levels
58 |                    Q_LEVELS --q_type {linear,a-law,mu-law} --which_set
59 |                    {ONOM,BLIZZ,MUSIC} --batch_size {64,128,256} [--debug]
60 |                    [--resume]
61 | 
62 | two_tier.py No default value! Indicate every argument.
63 | 
64 | optional arguments:
65 |   -h, --help            show this help message and exit
66 |   --exp EXP             Experiment name
67 |   --n_frames N_FRAMES   How many "frames" to include in each Truncated BPTT
68 |                         pass
69 |   --frame_size FRAME_SIZE
70 |                         How many samples per frame
71 |   --weight_norm WEIGHT_NORM
72 |                         Adding learnable weight normalization to all the
73 |                         linear layers (except for the embedding layer)
74 |   --emb_size EMB_SIZE   Size of embedding layer (0 to disable)
75 |   --skip_conn SKIP_CONN
76 |                         Add skip connections to RNN
77 |   --dim DIM             Dimension of RNN and MLPs
78 |   --n_rnn {1,2,3,4,5,6,7,8,9,10,11,12,n,...}
79 |                         Number of layers in the stacked RNN
80 |   --rnn_type {LSTM,GRU}
81 |                         GRU or LSTM
82 |   --learn_h0 LEARN_H0   Whether to learn the initial state of RNN
83 |   --q_levels Q_LEVELS   Number of bins for quantization of audio samples.
84 |                         Should be 256 for mu-law.
85 |   --q_type {linear,a-law,mu-law}
86 |                         Quantization in linear scale, a-law companding, or
87 |                         mu-law companding. With mu-/a-law quantization,
88 |                         q_levels should be set to 256.
89 |   --which_set WHICH_SET
90 |                         any preprocessed set in the datasets/music/ directory
91 |   --batch_size {64,128,256}
92 |                         size of mini-batch
93 |   --debug               Debug mode
94 |   --resume              Resume the same model from the last checkpoint. Order
95 |                         of params is important. [for now]
96 | ```
97 | To run:
98 | ```
99 | $ THEANO_FLAGS=mode=FAST_RUN,device=gpu0,floatX=float32 python -u models/two_tier/two_tier32k.py --exp BEST_2TIER --n_frames 64 --frame_size 16 --emb_size 256 --skip_conn False --dim 1024 --n_rnn 3 --rnn_type GRU --q_levels 256 --q_type linear --batch_size 128 --weight_norm True --learn_h0 True --which_set user_dataset_name
100 | ```
101 | ### SampleRNN (3-tier)
102 | ```
103 | $ python models/three_tier/three_tier.py -h
104 | usage: three_tier16k.py [-h] [--exp EXP] --seq_len SEQ_LEN --big_frame_size
105 |                         BIG_FRAME_SIZE --frame_size FRAME_SIZE --weight_norm
106 |                         WEIGHT_NORM --emb_size EMB_SIZE --skip_conn SKIP_CONN
107 |                         --dim DIM --n_rnn {1,2,3,4,5} --rnn_type {LSTM,GRU}
108 |                         --learn_h0 LEARN_H0 --q_levels Q_LEVELS --q_type
109 |                         {linear,a-law,mu-law} --which_set {ONOM,BLIZZ,MUSIC}
110 |                         --batch_size {64,128,256} [--debug] [--resume]
111 | 
112 | three_tier.py No default value! Indicate every argument.
113 | 
114 | optional arguments:
115 |   -h, --help            show this help message and exit
116 |   --exp EXP             Experiment name
117 |   --seq_len SEQ_LEN     How many samples to include in each Truncated BPTT
118 |                         pass
119 |   --big_frame_size BIG_FRAME_SIZE
120 |                         How many samples per big frame in tier 3
121 |   --frame_size FRAME_SIZE
122 |                         How many samples per frame in tier 2
123 |   --weight_norm WEIGHT_NORM
124 |                         Adding learnable weight normalization to all the
125 |                         linear layers (except for the embedding layer)
126 |   --emb_size EMB_SIZE   Size of embedding layer (> 0)
127 |   --skip_conn SKIP_CONN
128 |                         Add skip connections to RNN
129 |   --dim DIM             Dimension of RNN and MLPs
130 |   --n_rnn {1,2,3,4,5}   Number of layers in the stacked RNN
131 |   --rnn_type {LSTM,GRU}
132 |                         GRU or LSTM
133 |   --learn_h0 LEARN_H0   Whether to learn the initial state of RNN
134 |   --q_levels Q_LEVELS   Number of bins for quantization of audio samples.
135 |                         Should be 256 for mu-law.
136 |   --q_type {linear,a-law,mu-law}
137 |                         Quantization in linear scale, a-law companding, or
138 |                         mu-law companding. With mu-/a-law quantization,
139 |                         q_levels should be set to 256.
140 |   --which_set WHICH_SET
141 |                         any preprocessed set in the datasets/music/ directory
142 |   --batch_size {64,128,256}
143 |                         size of mini-batch
144 |   --debug               Debug mode
145 |   --resume              Resume the same model from the last checkpoint. Order
146 |                         of params is important. [for now]
147 | ```
148 | To run:
149 | ```
150 | $ THEANO_FLAGS=mode=FAST_RUN,device=gpu0,floatX=float32 python -u models/three_tier/three_tier.py --exp BEST_3TIER --seq_len 512 --big_frame_size 8 --frame_size 2 --emb_size 256 --skip_conn False --dim 1024 --n_rnn 1 --rnn_type GRU --q_levels 256 --q_type linear --batch_size 128 --weight_norm True --learn_h0 True --which_set your_dataset_name
151 | ```
152 | 
153 | To generate 5 sequences (10 seconds each) from a trained model:
154 | ```
155 | $ THEANO_FLAGS=mode=FAST_RUN,device=gpu0,floatX=float32 python -u models/two_tier/two_tier_generate32k.py --exp BEST_2TIER --seq_len 512 --big_frame_size 8 --frame_size 2 --emb_size 256 --skip_conn False --dim 1024 --n_rnn 1 --rnn_type GRU --q_levels 256 --q_type linear --batch_size 128 --weight_norm True --learn_h0 True --which_set your_dataset_name --n_secs 10 --n_seqs 5
156 | ```
157 | 
158 | ## Reference
159 | If you are using this code, please cite the paper.
160 | 
161 | SampleRNN: An Unconditional End-to-End Neural Audio Generation Model. Soroush Mehri, Kundan Kumar, Ishaan Gulrajani, Rithesh Kumar, Shubham Jain, Jose Sotelo, Aaron Courville, Yoshua Bengio. In the 5th International Conference on Learning Representations (ICLR 2017).
--------------------------------------------------------------------------------
/datasets/dataset.py:
--------------------------------------------------------------------------------
 1 | """
 2 | Dataset loading and quantization utilities for the RNN audio generation model.
 3 | """
 4 | import numpy as np
 5 | import random, time, os, glob
 6 | 
 7 | def __getFile(dataset_name):
 8 |     return 'music/' + dataset_name + '/music_{}.npy'
 9 | 
10 | __base = [
11 |     ('Local', 'datasets/')
12 | ]
13 | 
14 | __train = lambda s: s.format('train')
15 | __valid = lambda s: s.format('valid')
16 | __test = lambda s: s.format('test')
17 | 
18 | def find_dataset(filename):
19 |     for (k, v) in __base:
20 |         tmp_path = os.path.join(v, filename)
21 |         if os.path.exists(tmp_path):
22 |             #print "Path on {}:".format(k)
23 |             #print tmp_path
24 |             return tmp_path
25 |         #print "not found on {}".format(k)
26 |     raise Exception('{} NOT FOUND!'.format(filename))
27 | 
28 | ### Basic utils ###
29 | def __round_to(x, y):
30 |     """round x up to the nearest y"""
31 |     return int(np.ceil(x / float(y))) * y
32 | 
33 | def __normalize(data):
34 |     """To range [0., 1.]"""
35 |     data -= data.min(axis=1)[:, None]
36 |     data /= data.max(axis=1)[:, None]
37 |     return data
38 | 
39 | def __linear_quantize(data, q_levels):
40 |     """
41 |     floats in (0, 1) to ints in [0, q_levels-1]
42 |     scales normalized across axis 1
43 |     """
44 |     # Normalization is on the mini-batch, not the whole file
45 |     #eps = np.float64(1e-5)
46 |     #data -= data.min(axis=1)[:, None]
47 |     #data *= ((q_levels - eps) / data.max(axis=1)[:, None])
48 |     #data += eps/2
49 |     #data = data.astype('int32')
50 | 
51 |     eps = np.float64(1e-5)
52 |     data *= (q_levels - eps)
53 |     data += eps/2
54 |     data = data.astype('int32')
55 |     return data
56 | 
57 | def __a_law_quantize(data):
58 |     """
59 |     :todo:
60 |     """
61 |     raise NotImplementedError
62 | 
63 | def linear2mu(x, mu=255):
64 |     """
65 |     From Joao
66 |     x should be normalized between -1 and 1
67 |     Converts an array according to mu-law and
discretizes it 68 | Note: 69 | mu2linear(linear2mu(x)) != x 70 | Because we are compressing to 8 bits here. 71 | They will sound pretty much the same, though. 72 | :usage: 73 | >>> bitrate, samples = scipy.io.wavfile.read('orig.wav') 74 | >>> norm = __normalize(samples)[None, :] # It takes 2D as inp 75 | >>> mu_encoded = linear2mu(2.*norm-1.) # From [0, 1] to [-1, 1] 76 | >>> print mu_encoded.min(), mu_encoded.max(), mu_encoded.dtype 77 | 0, 255, dtype('int16') 78 | >>> mu_decoded = mu2linear(mu_encoded) # Back to linear 79 | >>> print mu_decoded.min(), mu_decoded.max(), mu_decoded.dtype 80 | -1, 0.9574371, dtype('float32') 81 | """ 82 | x_mu = np.sign(x) * np.log(1 + mu*np.abs(x))/np.log(1 + mu) 83 | return ((x_mu + 1)/2 * mu).astype('int16') 84 | 85 | def mu2linear(x, mu=255): 86 | """ 87 | From Joao with modifications 88 | Converts an integer array from mu to linear 89 | For important notes and usage see: linear2mu 90 | """ 91 | mu = float(mu) 92 | x = x.astype('float32') 93 | y = 2. * (x - (mu+1.)/2.) / (mu+1.) 94 | return np.sign(y) * (1./mu) * ((1. + mu)**np.abs(y) - 1.) 95 | 96 | def __mu_law_quantize(data): 97 | return linear2mu(data) 98 | 99 | def __batch_quantize(data, q_levels, q_type): 100 | """ 101 | One of 'linear', 'a-law', 'mu-law' for q_type. 102 | """ 103 | data = data.astype('float64') 104 | data = __normalize(data) 105 | if q_type == 'linear': 106 | return __linear_quantize(data, q_levels) 107 | if q_type == 'a-law': 108 | return __a_law_quantize(data) 109 | if q_type == 'mu-law': 110 | # from [0, 1] to [-1, 1] 111 | data = 2.*data-1. 112 | # Automatically quantized to 256 bins. 113 | return __mu_law_quantize(data) 114 | raise NotImplementedError 115 | 116 | __RAND_SEED = 123 117 | def __fixed_shuffle(inp_list): 118 | if isinstance(inp_list, list): 119 | random.seed(__RAND_SEED) 120 | random.shuffle(inp_list) 121 | return 122 | #import collections 123 | #if isinstance(inp_list, (collections.Sequence)): 124 | if isinstance(inp_list, np.ndarray): 125 | np.random.seed(__RAND_SEED) 126 | np.random.shuffle(inp_list) 127 | return 128 | # destructive operations; in place; no need to return 129 | raise ValueError("inp_list is neither a list nor a np.ndarray but a "+type(inp_list)) 130 | 131 | def __make_random_batches(inp_list, batch_size): 132 | batches = [] 133 | for i in xrange(len(inp_list) / batch_size): 134 | batches.append(inp_list[i*batch_size:(i+1)*batch_size]) 135 | 136 | __fixed_shuffle(batches) 137 | return batches 138 | 139 | 140 | ### MUSIC DATASET LOADER ### 141 | def __music_feed_epoch(files, 142 | batch_size, 143 | seq_len, 144 | overlap, 145 | q_levels, 146 | q_zero, 147 | q_type, 148 | real_valued=False): 149 | """ 150 | Helper function to load music dataset. 151 | Generator that yields training inputs (subbatch, reset). `subbatch` contains 152 | quantized audio data; `reset` is a boolean indicating the start of a new 153 | sequence (i.e. you should reset h0 whenever `reset` is True). 154 | Feeds subsequences which overlap by a specified amount, so that the model 155 | can always have target for every input in a given subsequence. 156 | Assumes all flac files have the same length. 157 | returns: (subbatch, reset) 158 | subbatch.shape: (BATCH_SIZE, SEQ_LEN + OVERLAP) 159 | reset: True or False 160 | """ 161 | batches = __make_random_batches(files, batch_size) 162 | 163 | for bch in batches: 164 | # batch_seq_len = length of longest sequence in the batch, rounded up to 165 | # the nearest SEQ_LEN. 
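        # e.g. __round_to(128000, 512) == 128000, but __round_to(128100, 512)
        # == 128512, i.e. lengths get padded up to a multiple of seq_len.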
166 | batch_seq_len = len(bch[0]) # should be 8*16000 167 | batch_seq_len = __round_to(batch_seq_len, seq_len) 168 | 169 | batch = np.zeros( 170 | (batch_size, batch_seq_len), 171 | dtype='float64' 172 | ) 173 | 174 | mask = np.ones(batch.shape, dtype='float32') 175 | 176 | for i, data in enumerate(bch): 177 | #data, fs, enc = scikits.audiolab.flacread(path) 178 | # data is float16 from reading the npy file 179 | batch[i, :len(data)] = data 180 | # This shouldn't change anything. All the flac files for Music 181 | # are the same length and the mask should be 1 every where. 182 | # mask[i, len(data):] = np.float32(0) 183 | 184 | if not real_valued: 185 | batch = __batch_quantize(batch, q_levels, q_type) 186 | 187 | batch = np.concatenate([ 188 | np.full((batch_size, overlap), q_zero, dtype='int32'), 189 | batch 190 | ], axis=1) 191 | else: 192 | batch -= __music_train_mean_std[0] 193 | batch /= __music_train_mean_std[1] 194 | batch = np.concatenate([ 195 | np.full((batch_size, overlap), 0, dtype='float32'), 196 | batch 197 | ], axis=1).astype('float32') 198 | 199 | mask = np.concatenate([ 200 | np.full((batch_size, overlap), 1, dtype='float32'), 201 | mask 202 | ], axis=1) 203 | 204 | for i in xrange(batch_seq_len // seq_len): 205 | reset = np.int32(i==0) 206 | subbatch = batch[:, i*seq_len : (i+1)*seq_len + overlap] 207 | submask = mask[:, i*seq_len : (i+1)*seq_len + overlap] 208 | yield (subbatch, reset, submask) 209 | 210 | def music_train_feed_epoch(d_name, *args): 211 | """ 212 | :parameters: 213 | batch_size: int 214 | seq_len: 215 | overlap: 216 | q_levels: 217 | q_zero: 218 | q_type: One the following 'linear', 'a-law', or 'mu-law' 219 | 4,340 (9.65 hours) in total 220 | With batch_size = 128: 221 | 4,224 (9.39 hours) in total 222 | 3,712 (88%, 8.25 hours)for training set 223 | 256 (6%, .57 hours) for validation set 224 | 256 (6%, .57 hours) for test set 225 | Note: 226 | 32 of Beethoven's piano sonatas available on archive.org (Public Domain) 227 | :returns: 228 | A generator yielding (subbatch, reset, submask) 229 | """ 230 | # Just check if valid/test sets are also available. If not, raise. 
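    # Minimal consumption sketch (argument values illustrative; q_zero is
    # typically q_levels//2, as in lib/generate.py):
    #   gen = music_train_feed_epoch('your_dataset', 128, 512, 64, 256, 128, 'linear')
    #   for subbatch, reset, submask in gen:
    #       pass  # reset the RNN state whenever reset == 1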
231 | find_dataset(__valid(__getFile(d_name))) 232 | find_dataset(__test(__getFile(d_name))) 233 | # Load train set 234 | data_path = find_dataset(__train(__getFile(d_name))) 235 | files = np.load(data_path) 236 | generator = __music_feed_epoch(files, *args) 237 | return generator 238 | 239 | def music_valid_feed_epoch(d_name, *args): 240 | """ 241 | See: 242 | music_train_feed_epoch 243 | """ 244 | data_path = find_dataset(__valid(__getFile(d_name))) 245 | files = np.load(data_path) 246 | generator = __music_feed_epoch(files, *args) 247 | return generator 248 | 249 | def music_test_feed_epoch(d_name, *args): 250 | """ 251 | See: 252 | music_train_feed_epoch 253 | """ 254 | data_path = find_dataset(__test(__getFile(d_name))) 255 | files = np.load(data_path) 256 | generator = __music_feed_epoch(files, *args) 257 | return generator 258 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- 1 | import ops 2 | #import lasagne 3 | #from theano.compile.nanguardmode import NanGuardMode 4 | 5 | import math 6 | import time 7 | import locale 8 | 9 | import numpy 10 | import theano 11 | import theano.tensor as T 12 | import theano.gof 13 | 14 | import cPickle as pickle 15 | #import pickle 16 | import warnings 17 | import sys, os, errno, glob 18 | 19 | import matplotlib 20 | matplotlib.use('Agg') 21 | import matplotlib.pyplot as plt 22 | 23 | # TODO: Grouping is not working on cluster! :-? 24 | # Set a locale first or you won't get grouping at all 25 | locale.setlocale(locale.LC_ALL, '') 26 | # 'en_US.UTF-8' 27 | 28 | _params = {} 29 | def param(name, *args, **kwargs): 30 | """ 31 | A wrapper for `theano.shared` which enables parameter sharing in models. 32 | 33 | Creates and returns theano shared variables similarly to `theano.shared`, 34 | except if you try to create a param with the same name as a 35 | previously-created one, `param(...)` will just return the old one instead of 36 | making a new one. 37 | 38 | This constructor also adds a `param` attribute to the shared variables it 39 | creates, so that you can easily search a graph for all params. 40 | """ 41 | 42 | if name not in _params: 43 | kwargs['name'] = name 44 | param = theano.shared(*args, **kwargs) 45 | param.param = True 46 | _params[name] = param 47 | return _params[name] 48 | 49 | def delete_params(name): 50 | to_delete = [p_name for p_name in _params if name in p_name] 51 | for p_name in to_delete: 52 | del _params[p_name] 53 | 54 | def search(node, critereon): 55 | """ 56 | Traverse the Theano graph starting at `node` and return a list of all nodes 57 | which match the `critereon` function. 
When optimizing a cost function, you 58 | can use this to get a list of all of the trainable params in the graph, like 59 | so: 60 | 61 | `lib.search(cost, lambda x: hasattr(x, "param"))` 62 | or 63 | `lib.search(cost, lambda x: hasattr(x, "param") and x.param==True)` 64 | """ 65 | 66 | def _search(node, critereon, visited): 67 | if node in visited: 68 | return [] 69 | visited.add(node) 70 | 71 | results = [] 72 | if isinstance(node, T.Apply): 73 | for inp in node.inputs: 74 | results += _search(inp, critereon, visited) 75 | else: # Variable node 76 | if critereon(node): 77 | results.append(node) 78 | if node.owner is not None: 79 | results += _search(node.owner, critereon, visited) 80 | return results 81 | 82 | return _search(node, critereon, set()) 83 | 84 | def floatX(x): 85 | """ 86 | Convert `x` to the numpy type specified in `theano.config.floatX`. 87 | """ 88 | if theano.config.floatX == 'float16': 89 | return numpy.float16(x) 90 | elif theano.config.floatX == 'float32': 91 | return numpy.float32(x) 92 | else: # Theano's default float type is float64 93 | print "Warning: lib.floatX using float64" 94 | return numpy.float64(x) 95 | 96 | def save_params(path): 97 | param_vals = {} 98 | for name, param in _params.iteritems(): 99 | param_vals[name] = param.get_value() 100 | 101 | with open(path, 'wb') as f: 102 | pickle.dump(param_vals, f) 103 | 104 | def load_params(path): 105 | with open(path, 'rb') as f: 106 | param_vals = pickle.load(f) 107 | 108 | for name, val in param_vals.iteritems(): 109 | _params[name].set_value(val) 110 | 111 | def clear_all_params(): 112 | to_delete = [p_name for p_name in _params] 113 | for p_name in to_delete: 114 | del _params[p_name] 115 | 116 | def ensure_dir(dirname): 117 | """ 118 | Ensure that a named directory exists; if it does not, attempt to create it. 119 | """ 120 | try: 121 | os.makedirs(dirname) 122 | except OSError, e: 123 | if e.errno != errno.EEXIST: 124 | raise 125 | 126 | __model_setting_file_name = 'model_settings.txt' 127 | def print_model_settings(locals_var, path=None, sys_arg=False): 128 | """ 129 | Prints all variables in upper case in locals_var, 130 | except for T which usually stands for theano.tensor. 131 | If locals() passed as input to this method, will print 132 | all the variables in upper case defined so far, that is 133 | model settings. 134 | 135 | With `path` as an address to a directory it will _append_ it 136 | as a file named `model_settings.txt` as well. 137 | 138 | With `sys_arg` set to True, log information about Python, Numpy, 139 | and Theano and passed arguments to the script will be added too. 140 | args.pkl would be overwritten, specially in case of resuming a job. 141 | But again that wouldn't be much of a problem as all the passed args 142 | to the script except for '--resume' should be the same. 143 | 144 | With both `path` and `sys_arg` passed, dumps the theano.config. 
145 | 
146 |     :usage:
147 |         >>> import theano.tensor as T
148 |         >>> import lib
149 |         >>> BATCH_SIZE, DIM = 128, 512
150 |         >>> DATA_PATH = '/Path/to/dataset'
151 |         >>> lib.print_model_settings(locals(), path='./')
152 |     """
153 |     log = ""
154 |     if sys_arg:
155 |         try:
156 |             log += "Python:\n"
157 |             log += "\tsys.version_info\t{}\n".format(str(sys.version_info))
158 |             log += "Numpy:\n"
159 |             log += "\t.__version__\t{}\n".format(numpy.__version__)
160 |             log += "Theano:\n"
161 |             log += "\t.__version__\t{}\n".format(theano.__version__)
162 |             log += "\n\nAll passed args:\n"
163 |             log += str(sys.argv)
164 |             log += "\n"
165 |         except:
166 |             print "Something went wrong during sys_arg logging. Continuing anyway!"
167 | 
168 |     log += "\nModel settings:"
169 |     all_vars = [(k,v) for (k,v) in locals_var.items() if (k.isupper() and k != 'T')]
170 |     all_vars = sorted(all_vars, key=lambda x: x[0])
171 |     for var_name, var_value in all_vars:
172 |         log += ("\n\t%-20s %s" % (var_name, var_value))
173 |     print log
174 |     if path is not None:
175 |         ensure_dir(path)
176 |         # Don't overwrite; just append in case there is already something in the file.
177 |         with open(os.path.join(path, __model_setting_file_name), 'a+') as f:
178 |             f.write(log)
179 |         if sys_arg:
180 |             with open(os.path.join(path, 'th_conf.txt'), 'a+') as f:
181 |                 f.write(str(theano.config))
182 |             with open(os.path.join(path, 'args.pkl'), 'wb') as f:
183 |                 pickle.dump(sys.argv, f)
184 |             # To load:
185 |             # >>> import cPickle as pickle
186 |             # >>> args = pickle.load(open(os.path.join(path, 'args.pkl'), 'rb'))
187 | 
188 | def get_params(cost, criterion=lambda x: hasattr(x, 'param') and x.param==True):
189 |     """
190 |     Default criterion:
191 |         lambda x: hasattr(x, 'param') and x.param==True
192 |     This returns every parameter of `cost` from the computation graph.
193 | 
194 |     To exclude a parameter, just set its 'param' attribute to False:
195 |         >>> h0 = lib.param('h0',\
196 |                 numpy.zeros((3, 2*512), dtype=theano.config.floatX))
197 |         >>> print h0.param  # Default: True
198 |         >>> h0.param = False
199 | 
200 |     In this case one can still get the list of all params (False or True) with:
201 |         >>> lib.get_params(cost, lambda x: hasattr(x, 'param'))
202 | 
203 |     :returns:
204 |         A list of params
205 |     """
206 |     return search(cost, criterion)
207 | 
208 | def print_params_info(params, path=None):
209 |     """
210 |     Print information about the parameters in the given param set.
211 | 
212 |     With `path` as an address to a directory it will _append_ the same
213 |     information to a file named `model_settings.txt` there as well.
214 | 
215 |     :usage:
216 |         >>> params = lib.get_params(cost)
217 |         >>> lib.print_params_info(params, path='./')
218 |     """
219 |     params = sorted(params, key=lambda p: p.name)
220 |     values = [p.get_value(borrow=True) for p in params]
221 |     shapes = [p.shape for p in values]
222 |     total_param_count = 0
223 |     multiply_all = lambda a, b: a*b
224 |     log = "\nParams for cost:"
225 |     for param, value, shape in zip(params, values, shapes):
226 |         log += ("\n\t%-20s %s" % (shape, param.name))
227 |         total_param_count += reduce(multiply_all, shape)
228 | 
229 |     log += "\nTotal parameter count for this cost:\n\t{0}".format(
230 |         locale.format("%d", total_param_count, grouping=True)
231 |     )
232 |     print log
233 | 
234 |     if path is not None:
235 |         ensure_dir(path)
236 |         # Don't overwrite; just append in case there is already something in the file.
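    # For orientation, the appended block looks like this (illustrative
    # shapes and names, not from a real run):
    #
    #   Params for cost:
    #       (256, 512)           SampleLevel.Embedding
    #       (2048, 1024)         FrameLevel.GRU1.W
    #   Total parameter count for this cost:
    #       2,228,224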
237 | with open(os.path.join(path, __model_setting_file_name), 'a+') as f: 238 | f.write(log) 239 | 240 | __train_log_file_name = 'train_log.pkl' 241 | def save_training_info(values, path): 242 | """ 243 | Gets a set of values as dictionary and append them to a log file. 244 | stores in /train_log.pkl 245 | """ 246 | file_name = os.path.join(path, __train_log_file_name) 247 | try: 248 | with open(file_name, "rb") as f: 249 | log = pickle.load(f) 250 | except IOError: # first time 251 | log = {} 252 | for k in values.keys(): 253 | log[k] = [] 254 | for k, v in values.items(): 255 | log[k].append(v) 256 | with open(file_name, "wb") as f: 257 | pickle.dump(log, f) 258 | 259 | resume_key = 'last resume index' 260 | def resumable(path, 261 | iter_key='iter', 262 | epoch_key='epoch', 263 | add_resume_counter=True, 264 | other_keys=[]): 265 | """ 266 | :warning: 267 | This is a naive implementation of resuming a training session 268 | and does not save and reload the training loop. The serialization 269 | of training loop and everything is costly and error-prone. 270 | 271 | :todo: 272 | - Save and load a serializable training loop. (See warning above) 273 | - Heavily dependent on the "model" file and the names used there right 274 | now. It's really easy to miss anything. 275 | 276 | `path` should be pointing at the root directory where `train_log.pkl` 277 | (See __train_log_file_name) and `params/` reside. 278 | 279 | Always assuming all the values in the log dictionary (except `resume_key`), 280 | are lists with the same length. 281 | """ 282 | file_name = os.path.join(path, __train_log_file_name) 283 | # Raise error if does not exists. 284 | with open(file_name, "rb") as f: 285 | log = pickle.load(f) 286 | 287 | param_found = False 288 | res_path = os.path.join(path, 'params', 'params_e{}_i{}*.pkl') 289 | for reverse_idx in range(-1, -len(log[epoch_key])-1, -1): 290 | ep, it = log[epoch_key][reverse_idx], log[iter_key][reverse_idx] 291 | print "> Params file for epoch {} iter {}".format(ep, it), 292 | last_path = glob.glob(res_path.format(ep, it)) 293 | if len(last_path) == 1: 294 | res_path = last_path[0] 295 | param_found = True 296 | print "found." 297 | break 298 | elif len(last_path) == 0: 299 | print "[NOT FOUND]. FALLING BACK TO..." 300 | else: # > 1 301 | # choose one, warning, rare 302 | print "[multiple version found]:" 303 | for l_path in last_path: 304 | print l_path 305 | res_path = last_path[0] 306 | param_found = True 307 | print "Arbitrarily choosing first:\n\t{}".format(res_path) 308 | 309 | assert 'reverse_idx' in locals(), 'Empty train_log???\n{}'.format(log) 310 | # Finishing for loop with no success 311 | assert param_found, 'No matching params file with train_log' 312 | 313 | acceptable_len = reverse_idx+len(log[epoch_key])+1 314 | if acceptable_len != len(log[epoch_key]): 315 | # Backup of the old train_log 316 | with open(file_name+'.backup', 'wb') as f: 317 | pickle.dump(log, f) 318 | 319 | # Change the log file to match the last existing checkpoint. 320 | for k, v in log.items(): 321 | # Fix resume indices 322 | if k == resume_key: 323 | log[k] = [i for i in log[k] if i < acceptable_len] 324 | continue 325 | # Rest is useless with no param file. 
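            # For context: `log` is the dict that save_training_info above
            # pickles; each key ('iter', 'epoch', 'train NLL (bits)', ...)
            # maps to a list with one entry per monitoring step.
            # `acceptable_len` is the number of monitoring steps that still
            # have a matching params_e{}_i{}*.pkl checkpoint on disk, so
            # every list is cut back to that length to keep the log and the
            # checkpoints in sync.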
326 |             log[k] = v[:acceptable_len]
327 | 
328 |     epochs = log[epoch_key]
329 |     iters = log[iter_key]
330 | 
331 |     if add_resume_counter:
332 |         resume_val = len(epochs)
333 |         if not resume_key in log.keys():
334 |             log[resume_key] = [resume_val]
335 |         else:
336 |             if log[resume_key] == [] or log[resume_key][-1] != resume_val:
337 |                 log[resume_key].append(resume_val)
338 |         with open(file_name, "wb") as f:
339 |             pickle.dump(log, f)
340 | 
341 |     last_epoch = epochs[-1]
342 |     last_iter = iters[-1]
343 | 
344 |     # The if-else statement is more readable than `next`:
345 |     #iters_to_consume = next((last_iter%(i-1) for (e, i) in\
346 |     #        zip(epochs, iters) if e == 1), last_iter)
347 |     if last_epoch == 0:
348 |         iters_to_consume = last_iter
349 |     else:
350 |         for e, i in zip(epochs, iters):
351 |             # First time the epoch turns from 0 to 1.
352 |             # At the end of each epoch there should be
353 |             # a monitoring step, so this gives the
354 |             # number of iterations per epoch.
355 |             if e == 1:
356 |                 iters_per_epoch = i - 1
357 |                 break
358 |         iters_to_consume = last_iter % iters_per_epoch
359 | 
360 |     last_other_keys = [log[k][-1] for k in other_keys]
361 |     return iters_to_consume, res_path, last_epoch, last_iter, last_other_keys
362 | 
363 | def plot_traing_info(x, ylist, path):
364 |     """
365 |     Loads the log file and plots the given x and y values.
366 |     Saves as <path>/train_log.png
367 |     """
368 |     file_name = os.path.join(path, __train_log_file_name)
369 |     try:
370 |         with open(file_name, "rb") as f:
371 |             log = pickle.load(f)
372 |     except IOError:  # first time
373 |         warnings.warn("There is no {} file here!!!".format(file_name))
374 |         return
375 |     plt.figure()
376 |     x_vals = log[x]
377 |     for y in ylist:
378 |         y_vals = log[y]
379 |         if len(y_vals) != len(x_vals):
380 |             warnings.warn("One of y's: {} does not have the same length as x: {}".format(y, x))
381 |         plt.plot(x_vals, y_vals, label=y)
382 |         # assert len(y_vals) == len(x_vals), "not the same len"
383 |     plt.xlabel(x)
384 |     plt.legend()
385 |     #plt.show()
386 |     plt.savefig(file_name[:-3]+'png', bbox_inches='tight')
387 |     plt.close('all')
388 | 
389 | def create_logging_folders(path):
390 |     """
391 |     Handle structure of folders and naming here instead of the training file.
392 | 
393 |     :todo:
394 |         - Implement!
395 |     """
396 |     pass
397 | 
398 | def tv(var):
399 |     """
400 |     :todo:
401 |         - add tv() function for theano variables so that instead of calling
402 |           x.tag.test_value, you can get the same thing just by calling this
403 |           method in a faster way...
404 |         - also for x.tag.test_value.shape
405 |     """
406 |     # Based on EAFP (easier to ask for forgiveness than permission)
407 |     try:
408 |         return var.tag.test_value
409 |     except AttributeError:
410 |         print "NONE, test_value has not been set."
411 |         import ipdb; ipdb.set_trace()
412 | 
413 |     ## Rather than LBYL (look before you leap)
414 |     #if hasattr(var, 'tag'):
415 |     #    if hasattr(var.tag, 'test_value'):
416 |     #        return var.tag.test_value
417 |     #    else:
418 |     #        print "NONE, test_value has not been set."
419 |     #        import ipdb; ipdb.set_trace()
420 |     #else:
421 |     #    print "NONE, tag has not been set."
422 |     #    import ipdb; ipdb.set_trace()
423 | 
424 | def tvs(var):
425 |     """
426 |     :returns:
427 |         var.tag.test_value.shape
428 |     """
429 |     return tv(var).shape
430 | 
431 | def _is_symbolic(v):
432 |     r"""Return `True` if any of the arguments are symbolic.
433 |     See:
434 |         https://github.com/Theano/Theano/wiki/Cookbook
435 |     """
436 |     symbolic = False
437 |     v = list(v)
438 |     for _container, _iter in [(v, xrange(len(v)))]:
439 |         for _k in _iter:
440 |             _v = _container[_k]
441 |             if isinstance(_v, theano.gof.Variable):
442 |                 symbolic = True
443 |     return symbolic
444 | 
445 | def unique_list(inp_list):
446 |     """
447 |     Returns a list of the unique values in `inp_list` (order is not preserved).
448 |     :usage:
449 |         >>> inp_list = ['a', 'b', 'c']
450 |         >>> unique_inp_list = unique_list(inp_list*2)
451 |     """
452 |     return list(set(inp_list))
--------------------------------------------------------------------------------
/models/one_tier/wavent.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """
3 | WaveNets Audio Generation Model
4 | 
5 | How-to-run example:
6 | 
7 | sampleRNN$
8 | THEANO_FLAGS=mode=FAST_RUN,device=gpu1,floatX=float32,lib.cnmem=.95 python models/one_tier/wavent.py --dim 64 --q_levels 256 --q_type linear --which_set MUSIC --batch_size 8 --wavenet_blocks 4 --dilation_layers_per_block 10 --sequence_len_to_train 1600
9 | """
10 | import time
11 | from datetime import datetime
12 | print "Experiment started at:", datetime.strftime(datetime.now(), '%Y-%m-%d %H:%M')
13 | exp_start = time.time()
14 | 
15 | import os, sys
16 | sys.path.insert(1, os.getcwd())
17 | import argparse
18 | 
19 | import numpy
20 | numpy.random.seed(123)
21 | np = numpy
22 | import random
23 | random.seed(123)
24 | 
25 | import theano
26 | import theano.tensor as T
27 | import theano.ifelse
28 | import lasagne
29 | import scipy.io.wavfile
30 | 
31 | import lib
32 | 
33 | 
34 | ### Parsing passed args/hyperparameters ###
35 | def get_args():
36 |     def t_or_f(arg):
37 |         ua = str(arg).upper()
38 |         if 'TRUE'.startswith(ua):
39 |             return True
40 |         elif 'FALSE'.startswith(ua):
41 |             return False
42 |         else:
43 |             raise ValueError('Arg is neither `True` nor `False`')
44 | 
45 |     def check_non_negative(value):
46 |         ivalue = int(value)
47 |         if ivalue < 0:
48 |             raise argparse.ArgumentTypeError("%s is not non-negative!" % value)
49 |         return ivalue
50 | 
51 |     def check_positive(value):
52 |         ivalue = int(value)
53 |         if ivalue < 1:
54 |             raise argparse.ArgumentTypeError("%s is not positive!" % value)
55 |         return ivalue
56 | 
57 |     def check_unit_interval(value):
58 |         fvalue = float(value)
59 |         if fvalue < 0 or fvalue > 1:
60 |             raise argparse.ArgumentTypeError("%s is not in [0, 1] interval!" % value)
61 |         return fvalue
62 | 
63 |     # No default value here. Indicate every single argument.
64 |     parser = argparse.ArgumentParser(
65 |         description='wavent.py\nNo default value! Indicate every argument.')
66 | 
67 |     # Hyperparameter arguments:
68 |     parser.add_argument('--exp', help='Experiment name',
69 |             type=str, required=False, default='_')
70 |     parser.add_argument('--dim', help='Dimension of RNN and MLPs',\
71 |             type=check_positive, required=True)
72 |     parser.add_argument('--q_levels', help='Number of bins for quantization of audio samples. Should be 256 for mu-law.',\
73 |             type=check_positive, required=True)
74 |     parser.add_argument('--q_type', help='Quantization in linear-scale, a-law-companding, or mu-law companding.
With mu-/a-law, quantization levels should be set to 256',\
75 |             choices=['linear', 'a-law', 'mu-law'], required=True)
76 |     #parser.add_argument('--nll_coeff', help='Value of alpha in [0, 1] for cost=alpha*NLL+(1-alpha)*FFT_cost',\
77 |     #        type=check_unit_interval, required=True)
78 |     parser.add_argument('--which_set', help='ONOM, BLIZZ, or MUSIC',
79 |             choices=['ONOM', 'BLIZZ', 'MUSIC', 'HUCK'], required=True)
80 |     parser.add_argument('--batch_size', help='size of mini-batch',
81 |             type=check_positive, choices=[8, 16, 32, 64, 128, 256], required=True)
82 |     parser.add_argument('--wavenet_blocks', help='Number of wavenet blocks to use',
83 |             type=check_positive, required=True)
84 |     parser.add_argument('--dilation_layers_per_block', help='number of dilation layers per block',
85 |             type=check_positive, required=True)
86 | 
87 |     parser.add_argument('--sequence_len_to_train', help='Total length (# of samples) of each truncated BPTT sequence',
88 |             type=check_positive, required=True)
89 | 
90 |     parser.add_argument('--debug', help='Debug mode', required=False, default=False, action='store_true')
91 | 
92 |     parser.add_argument('--resume', help='Resume the same model from the last checkpoint. Order of params is important. [for now]',\
93 |             required=False, default=False, action='store_true')
94 | 
95 |     args = parser.parse_args()
96 | 
97 |     # Create tag for this experiment based on passed args
98 |     tag = reduce(lambda a, b: a+b, sys.argv).replace('--resume', '').replace('/', '-').replace('--', '-').replace('True', 'T').replace('False', 'F')
99 |     print "Created experiment tag for these args:"
100 |     print tag
101 | 
102 |     return args, tag
103 | 
104 | args, tag = get_args()
105 | 
106 | # N_FRAMES = args.n_frames # How many 'frames' to include in each truncated BPTT pass
107 | OVERLAP = (2**args.dilation_layers_per_block - 1)*args.wavenet_blocks + 1  # Receptive field of the stacked dilated convolutions
108 | #GLOBAL_NORM = args.global_norm
109 | DIM = args.dim # Model dimensionality.
110 | Q_LEVELS = args.q_levels # How many levels to use when discretizing samples. e.g. 256 = 8-bit scalar quantization
111 | Q_TYPE = args.q_type # log- or linear-scale
112 | #NLL_COEFF = args.nll_coeff
113 | WHICH_SET = args.which_set
114 | BATCH_SIZE = args.batch_size
115 | #DATA_PATH = args.data_path
116 | 
117 | if Q_TYPE == 'mu-law' and Q_LEVELS != 256:
118 |     raise ValueError('For mu-law, quantization levels should be exactly 256!')
119 | 
120 | # Fixed hyperparams
121 | GRAD_CLIP = 1 # Elementwise grad clip threshold
122 | BITRATE = 16000
123 | 
124 | # Other constants
125 | #TRAIN_MODE = 'iters' # To use PRINT_ITERS and STOP_ITERS
126 | TRAIN_MODE = 'time' # To use PRINT_TIME and STOP_TIME
127 | #TRAIN_MODE = 'time-iters'
128 | # To use PRINT_TIME for validation,
129 | # and (STOP_ITERS, STOP_TIME), whichever happens first, for stopping exp.
130 | #TRAIN_MODE = 'iters-time'
131 | # To use PRINT_ITERS for validation,
132 | # and (STOP_ITERS, STOP_TIME), whichever happens first, for stopping exp.
133 | PRINT_ITERS = 10000 # Print cost, generate samples, save model checkpoint every N iterations.
134 | STOP_ITERS = 100000 # Stop after this many iterations
135 | PRINT_TIME = 90*60 # Print cost, generate samples, save model checkpoint every N seconds.
136 | STOP_TIME = 60*60*60 # Stop after this many seconds of actual training (not including time req'd to generate samples etc.)
137 | N_SEQS = 10 # Number of samples to generate at each monitoring step.
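For intuition about OVERLAP above: a kernel-2 convolution with dilation d needs d extra past samples, so one block of L dilation layers (dilations 1, 2, ..., 2^(L-1)) covers 2^L - 1 past samples, and blocks stack additively. A standalone sketch using the how-to-run settings from the file header:

    blocks, layers_per_block = 4, 10
    receptive_field = (2**layers_per_block - 1)*blocks + 1
    print receptive_field            # 4093 samples
    print receptive_field / 16000.   # ~0.26 seconds of context at 16 kHz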
138 | FOLDER_PREFIX = os.path.join('results_wavenets', tag) 139 | SEQ_LEN = args.sequence_len_to_train # Total length (# of samples) of each truncated BPTT sequence 140 | Q_ZERO = numpy.int32(Q_LEVELS//2) # Discrete value correponding to zero amplitude 141 | 142 | LEARNING_RATE = lib.floatX(numpy.float32(0.0001)) 143 | RESUME = args.resume 144 | 145 | epoch_str = 'epoch' 146 | iter_str = 'iter' 147 | lowest_valid_str = 'lowest valid cost' 148 | corresp_test_str = 'correponding test cost' 149 | train_nll_str, valid_nll_str, test_nll_str = \ 150 | 'train NLL (bits)', 'valid NLL (bits)', 'test NLL (bits)' 151 | 152 | if args.debug: 153 | import warnings 154 | warnings.warn('----------RUNNING IN DEBUG MODE----------') 155 | TRAIN_MODE = 'time-iters' 156 | PRINT_TIME = 100 157 | STOP_TIME = 300 158 | STOP_ITERS = 1000 159 | 160 | ### Create directories ### 161 | # FOLDER_PREFIX: root, contains: 162 | # log.txt, __note.txt, train_log.pkl, train_log.png [, model_settings.txt] 163 | # FOLDER_PREFIX/params: saves all checkpoint params as pkl 164 | # FOLDER_PREFIX/samples: keeps all checkpoint samples as wav 165 | # FOLDER_PREFIX/best: keeps the best parameters, samples, ... 166 | 167 | if not os.path.exists(FOLDER_PREFIX): 168 | os.makedirs(FOLDER_PREFIX) 169 | 170 | PARAMS_PATH = os.path.join(FOLDER_PREFIX, 'params') 171 | 172 | if not os.path.exists(PARAMS_PATH): 173 | os.makedirs(PARAMS_PATH) 174 | 175 | SAMPLES_PATH = os.path.join(FOLDER_PREFIX, 'samples') 176 | 177 | if not os.path.exists(SAMPLES_PATH): 178 | os.makedirs(SAMPLES_PATH) 179 | 180 | BEST_PATH = os.path.join(FOLDER_PREFIX, 'best') 181 | 182 | if not os.path.exists(BEST_PATH): 183 | os.makedirs(BEST_PATH) 184 | 185 | lib.print_model_settings(locals(), path=FOLDER_PREFIX, sys_arg=True) 186 | 187 | ### Creating computation graph ### 188 | 189 | def create_wavenet_block(inp, num_dilation_layer, input_dim, output_dim, name =None): 190 | assert name is not None 191 | layer_out = inp 192 | skip_contrib = [] 193 | skip_weights = lib.param(name+".parametrized_weights", lib.floatX(numpy.ones((num_dilation_layer,)))) 194 | for i in range(num_dilation_layer): 195 | layer_out, skip_c = lib.ops.dil_conv_1D( 196 | layer_out, 197 | output_dim, 198 | input_dim if i == 0 else output_dim, 199 | 2, 200 | dilation = 2**i, 201 | non_linearity = 'gated', 202 | name = name+".dilation_{}".format(i+1) 203 | ) 204 | skip_c = skip_c*skip_weights[i] 205 | 206 | skip_contrib.append(skip_c) 207 | 208 | skip_out = skip_contrib[-1] 209 | 210 | j = 0 211 | for i in range(num_dilation_layer-1): 212 | j += 2**(num_dilation_layer-i-1) 213 | skip_out = skip_out + skip_contrib[num_dilation_layer-2 - i][:,j:] 214 | 215 | return layer_out, skip_out 216 | 217 | def create_model(inp): 218 | out = (inp.astype(theano.config.floatX)/lib.floatX(Q_LEVELS-1) - lib.floatX(0.5)) 219 | l_out = out.dimshuffle(0,1,'x') 220 | 221 | skips = [] 222 | for i in range(args.wavenet_blocks): 223 | l_out, skip_out = create_wavenet_block(l_out, args.dilation_layers_per_block, 1 if i == 0 else args.dim, args.dim, name = "block_{}".format(i+1)) 224 | skips.append(skip_out) 225 | 226 | out = skips[-1] 227 | 228 | for i in range(args.wavenet_blocks - 1): 229 | out = out + skips[args.wavenet_blocks - 2 - i][:,(2**args.dilation_layers_per_block - 1)*(i+1):] 230 | 231 | for i in range(3): 232 | out = lib.ops.conv1d("out_{}".format(i+1), out, args.dim, args.dim, 1, non_linearity='relu') 233 | 234 | out = lib.ops.conv1d("final", out, args.dim, args.q_levels, 1, non_linearity='identity') 235 | 236 | 
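    # Note on the slicing above: each kernel-2 convolution with dilation d
    # shortens the time axis by d samples, so one block of L dilation layers
    # trims 2**L - 1 steps. Skip outputs from earlier blocks are therefore
    # longer than later ones, and [:, (2**args.dilation_layers_per_block - 1)*(i+1):]
    # crops the earlier ones so all skip tensors align on the last timesteps
    # before being summed.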
return out 237 | 238 | sequences = T.imatrix('sequences') 239 | h0 = T.tensor3('h0') 240 | reset = T.iscalar('reset') 241 | mask = T.matrix('mask') 242 | 243 | if args.debug: 244 | # Solely for debugging purposes. 245 | # Maybe I should set the compute_test_value=warn from here. 246 | sequences.tag.test_value = numpy.zeros((BATCH_SIZE, SEQ_LEN), dtype='int32') 247 | 248 | input_sequences = sequences[:, :-1] 249 | target_sequences = sequences[:, (2**args.dilation_layers_per_block - 1)*args.wavenet_blocks + 1:] 250 | 251 | target_mask = mask[:, (2**args.dilation_layers_per_block - 1)*args.wavenet_blocks + 1:] 252 | 253 | output = create_model(input_sequences) 254 | 255 | cost = T.nnet.categorical_crossentropy( 256 | T.nnet.softmax(output.reshape((-1, Q_LEVELS))), 257 | target_sequences.flatten() 258 | ) 259 | 260 | cost = cost.reshape(target_sequences.shape) 261 | cost = cost * target_mask 262 | # Don't use these lines; could end up with NaN 263 | # Specially at the end of audio files where mask is 264 | # all zero for some of the shorter files in mini-batch. 265 | #cost = cost.sum(axis=1) / target_mask.sum(axis=1) 266 | #cost = cost.mean(axis=0) 267 | 268 | # Use this one instead. 269 | cost = cost.sum() 270 | cost = cost / target_mask.sum() 271 | 272 | # By default we report cross-entropy cost in bits. 273 | # Switch to nats by commenting out this line: 274 | # log_2(e) = 1.44269504089 275 | cost = cost * lib.floatX(numpy.log2(numpy.e)) 276 | 277 | ### Getting the params, grads, updates, and Theano functions ### 278 | params = lib.get_params(cost, lambda x: hasattr(x, 'param') and x.param==True) 279 | lib.print_params_info(params, path=FOLDER_PREFIX) 280 | 281 | grads = T.grad(cost, wrt=params, disconnected_inputs='warn') 282 | grads = [T.clip(g, lib.floatX(-GRAD_CLIP), lib.floatX(GRAD_CLIP)) for g in grads] 283 | 284 | updates = lasagne.updates.adam(grads, params, learning_rate=LEARNING_RATE) 285 | 286 | # Training function 287 | train_fn = theano.function( 288 | [sequences, mask], 289 | cost, 290 | updates=updates, 291 | on_unused_input='warn' 292 | ) 293 | 294 | # Validation and Test function 295 | test_fn = theano.function( 296 | [sequences, mask], 297 | cost, 298 | on_unused_input='warn' 299 | ) 300 | 301 | # Sampling at frame level 302 | generate_fn = theano.function( 303 | [sequences], 304 | lib.ops.softmax_and_sample(output), 305 | on_unused_input='warn' 306 | ) 307 | 308 | 309 | def generate_and_save_samples(tag): 310 | def write_audio_file(name, data): 311 | data = data.astype('float32') 312 | data -= data.min() 313 | data /= data.max() 314 | data -= 0.5 315 | data *= 0.95 316 | scipy.io.wavfile.write( 317 | os.path.join(SAMPLES_PATH, name+'.wav'), 318 | BITRATE, 319 | data) 320 | 321 | total_time = time.time() 322 | # Generate N_SEQS' sample files, each 5 seconds long 323 | N_SECS = 5 324 | LENGTH = N_SECS*BITRATE 325 | 326 | if args.debug: 327 | LENGTH = 1024 328 | 329 | num_prev_samples_to_use = (2**args.dilation_layers_per_block - 1)*args.wavenet_blocks + 1 330 | 331 | samples = numpy.zeros((N_SEQS, LENGTH + num_prev_samples_to_use), dtype='int32') 332 | samples[:, :num_prev_samples_to_use] = Q_ZERO 333 | 334 | for t in range(LENGTH): 335 | samples[:,num_prev_samples_to_use+t:num_prev_samples_to_use+t+1] = generate_fn(samples[:, t:t + num_prev_samples_to_use+1]) 336 | if (t > 2*BITRATE) and( t < 3*BITRATE): 337 | samples[:,num_prev_samples_to_use+t:num_prev_samples_to_use+t+1] = Q_ZERO 338 | 339 | total_time = time.time() - total_time 340 | log = "{} samples of {} seconds 
length generated in {} seconds." 341 | log = log.format(N_SEQS, N_SECS, total_time) 342 | print log, 343 | 344 | for i in xrange(N_SEQS): 345 | samp = samples[i, num_prev_samples_to_use: ] 346 | if Q_TYPE == 'mu-law': 347 | from datasets.dataset import mu2linear 348 | samp = mu2linear(samp) 349 | elif Q_TYPE == 'a-law': 350 | raise NotImplementedError('a-law is not implemented') 351 | write_audio_file("sample_{}_{}".format(tag, i), samp) 352 | 353 | ### Import the data_feeder ### 354 | # Handling WHICH_SET 355 | if WHICH_SET == 'ONOM': 356 | from datasets.dataset import onom_train_feed_epoch as train_feeder 357 | from datasets.dataset import onom_valid_feed_epoch as valid_feeder 358 | from datasets.dataset import onom_test_feed_epoch as test_feeder 359 | elif WHICH_SET == 'BLIZZ': 360 | from datasets.dataset import blizz_train_feed_epoch as train_feeder 361 | from datasets.dataset import blizz_valid_feed_epoch as valid_feeder 362 | from datasets.dataset import blizz_test_feed_epoch as test_feeder 363 | elif WHICH_SET == 'MUSIC': 364 | from datasets.dataset import music_train_feed_epoch as train_feeder 365 | from datasets.dataset import music_valid_feed_epoch as valid_feeder 366 | from datasets.dataset import music_test_feed_epoch as test_feeder 367 | elif WHICH_SET == 'HUCK': 368 | from datasets.dataset import huck_train_feed_epoch as train_feeder 369 | from datasets.dataset import huck_valid_feed_epoch as valid_feeder 370 | from datasets.dataset import huck_test_feed_epoch as test_feeder 371 | 372 | 373 | def monitor(data_feeder): 374 | """ 375 | Cost and time of test_fn on a given dataset section. 376 | Pass only one of `valid_feeder` or `test_feeder`. 377 | Don't pass `train_feed`. 378 | 379 | :returns: 380 | Mean cost over the input dataset (data_feeder) 381 | Total time spent 382 | """ 383 | _total_time = 0. 384 | _costs = [] 385 | _data_feeder = data_feeder(BATCH_SIZE, 386 | SEQ_LEN, 387 | OVERLAP, 388 | Q_LEVELS, 389 | Q_ZERO, 390 | Q_TYPE) 391 | 392 | for _seqs, _reset, _mask in _data_feeder: 393 | _start_time = time.time() 394 | _cost = test_fn(_seqs, _mask) 395 | _total_time += time.time() - _start_time 396 | 397 | _costs.append(_cost) 398 | 399 | return numpy.mean(_costs), _total_time 400 | 401 | 402 | print "Wall clock time spent before training started: {:.2f}h"\ 403 | .format((time.time()-exp_start)/3600.) 404 | print "Training!" 405 | total_iters = 0 406 | total_time = 0. 407 | last_print_time = 0. 408 | last_print_iters = 0 409 | costs = [] 410 | lowest_valid_cost = numpy.finfo(numpy.float32).max 411 | corresponding_test_cost = numpy.finfo(numpy.float32).max 412 | new_lowest_cost = False 413 | end_of_batch = False 414 | epoch = 0 # Important for mostly other datasets rather than Blizz 415 | 416 | # Initial load train dataset 417 | tr_feeder = train_feeder(BATCH_SIZE, 418 | SEQ_LEN, 419 | OVERLAP, 420 | Q_LEVELS, 421 | Q_ZERO, 422 | Q_TYPE) 423 | 424 | 425 | 426 | if RESUME: 427 | # Check if checkpoint from previous run is not corrupted. 428 | # Then overwrite some of the variables above. 429 | iters_to_consume, res_path, epoch, total_iters,\ 430 | [lowest_valid_cost, corresponding_test_cost, test_cost] = \ 431 | lib.resumable(path=FOLDER_PREFIX, 432 | iter_key=iter_str, 433 | epoch_key=epoch_str, 434 | add_resume_counter=True, 435 | other_keys=[lowest_valid_str, 436 | corresp_test_str, 437 | test_nll_str]) 438 | # At this point we saved the pkl file. 
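    # lib.resumable returns iters_to_consume: how many minibatches of the
    # interrupted epoch had already been used. Below, that many minibatches
    # are drawn from the fresh feeder and discarded, so the data stream
    # lines up with the loaded checkpoint before training continues.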
439 | last_print_iters = total_iters 440 | print "### RESUMING JOB FROM EPOCH {}, ITER {}".format(epoch, total_iters) 441 | # Consumes this much iters to get to the last point in training data. 442 | consume_time = time.time() 443 | for i in xrange(iters_to_consume): 444 | tr_feeder.next() 445 | consume_time = time.time() - consume_time 446 | print "Train data ready in {:.2f}secs after consuming {} minibatches.".\ 447 | format(consume_time, iters_to_consume) 448 | 449 | lib.load_params(res_path) 450 | print "Parameters from last available checkpoint loaded from path {}".format(res_path) 451 | 452 | test_time = 0.0 453 | 454 | while True: 455 | # THIS IS ONE ITERATION 456 | if total_iters % 500 == 0: 457 | print total_iters, 458 | 459 | total_iters += 1 460 | 461 | try: 462 | # Take as many mini-batches as possible from train set 463 | mini_batch = tr_feeder.next() 464 | except StopIteration: 465 | # Mini-batches are finished. Load it again. 466 | # Basically, one epoch. 467 | tr_feeder = train_feeder(BATCH_SIZE, 468 | SEQ_LEN, 469 | OVERLAP, 470 | Q_LEVELS, 471 | Q_ZERO, 472 | Q_TYPE) 473 | 474 | # and start taking new mini-batches again. 475 | mini_batch = tr_feeder.next() 476 | epoch += 1 477 | end_of_batch = True 478 | print "[Another epoch]", 479 | 480 | seqs, reset, mask = mini_batch 481 | 482 | 483 | ##Remove this 484 | # print seqs.shape 485 | # targ = generate_fn(seqs) 486 | # print targ.shape 487 | ##### 488 | 489 | start_time = time.time() 490 | cost = train_fn(seqs, mask) 491 | total_time += time.time() - start_time 492 | #print "This cost:", cost, "This h0.mean()", h0.mean() 493 | 494 | costs.append(cost) 495 | 496 | if (TRAIN_MODE=='iters' and total_iters-last_print_iters == PRINT_ITERS) or \ 497 | (TRAIN_MODE=='time' and total_time-last_print_time >= PRINT_TIME) or \ 498 | (TRAIN_MODE=='time-iters' and total_time-last_print_time >= PRINT_TIME) or \ 499 | (TRAIN_MODE=='iters-time' and total_iters-last_print_iters >= PRINT_ITERS) or \ 500 | end_of_batch: 501 | print "\nValidation!", 502 | valid_cost, valid_time = monitor(valid_feeder) 503 | print "Done!" 504 | 505 | # Only when the validation cost is improved get the cost for test set. 506 | if valid_cost < lowest_valid_cost: 507 | lowest_valid_cost = valid_cost 508 | print "\n>>> Best validation cost of {} reached. Testing!"\ 509 | .format(valid_cost), 510 | test_cost, test_time = monitor(test_feeder) 511 | print "Done!" 
512 |             # Report the one corresponding to the lowest validation cost:
513 |             print ">>> test cost:{}\ttotal time:{}".format(test_cost, test_time)
514 |             corresponding_test_cost = test_cost
515 |             new_lowest_cost = True
516 | 
517 |         # Stdout the training progress
518 |         print_info = "epoch:{}\ttotal iters:{}\twall clock time:{:.2f}h\n"
519 |         print_info += ">>> Lowest valid cost:{}\t Corresponding test cost:{}\n"
520 |         print_info += "\ttrain cost:{:.4f}\ttotal time:{:.2f}h\tper iter:{:.3f}s\n"
521 |         print_info += "\tvalid cost:{:.4f}\ttotal time:{:.2f}h\n"
522 |         print_info += "\ttest cost:{:.4f}\ttotal time:{:.2f}h"
523 |         print_info = print_info.format(epoch,
524 |                 total_iters,
525 |                 (time.time()-exp_start)/3600,
526 |                 lowest_valid_cost,
527 |                 corresponding_test_cost,
528 |                 numpy.mean(costs),
529 |                 total_time/3600,
530 |                 total_time/total_iters,
531 |                 valid_cost,
532 |                 valid_time/3600,
533 |                 test_cost,
534 |                 test_time/3600)
535 |         print print_info
536 | 
537 |         # Save and graph training progress
538 |         x_axis_str = 'iter'
539 |         train_nll_str, valid_nll_str, test_nll_str = \
540 |             'train NLL (bits)', 'valid NLL (bits)', 'test NLL (bits)'
541 |         training_info = {'epoch' : epoch,
542 |                 x_axis_str : total_iters,
543 |                 train_nll_str : numpy.mean(costs),
544 |                 valid_nll_str : valid_cost,
545 |                 test_nll_str : test_cost,
546 |                 'lowest valid cost' : lowest_valid_cost,
547 |                 'correponding test cost' : corresponding_test_cost,
548 |                 'train time' : total_time,
549 |                 'valid time' : valid_time,
550 |                 'test time' : test_time,
551 |                 'wall clock time' : time.time()-exp_start}
552 |         lib.save_training_info(training_info, FOLDER_PREFIX)
553 |         print "Train info saved!",
554 | 
555 |         y_axis_strs = [train_nll_str, valid_nll_str, test_nll_str]
556 |         lib.plot_traing_info(x_axis_str, y_axis_strs, FOLDER_PREFIX)
557 |         print "Plotted!"
558 | 
559 |         # Generate and save samples
560 |         print "Sampling!",
561 |         tag = "e{}_i{}_t{:.2f}_tr{:.4f}_v{:.4f}"
562 |         tag = tag.format(epoch,
563 |                 total_iters,
564 |                 total_time/3600,
565 |                 numpy.mean(costs),
566 |                 valid_cost)
567 |         tag += ("_best" if new_lowest_cost else "")
568 |         # Generate samples
569 |         generate_and_save_samples(tag)
570 |         print "Done!"
571 | 
572 |         # Save params of model
573 |         lib.save_params(
574 |                 os.path.join(PARAMS_PATH, 'params_{}.pkl'.format(tag))
575 |         )
576 |         print "Params saved!"
577 | 
578 |         if total_iters-last_print_iters == PRINT_ITERS \
579 |                 or total_time-last_print_time >= PRINT_TIME:
580 |             # If we are here because of end_of_batch, we shouldn't mess
581 |             # with costs and last_print_iters
582 |             costs = []
583 |             last_print_time += PRINT_TIME
584 |             last_print_iters += PRINT_ITERS
585 | 
586 |         end_of_batch = False
587 |         new_lowest_cost = False
588 | 
589 |         print "Validation Done!\nBack to Training..."
590 | 
591 |     if (TRAIN_MODE=='iters' and total_iters == STOP_ITERS) or \
592 |             (TRAIN_MODE=='time' and total_time >= STOP_TIME) or \
593 |             ((TRAIN_MODE=='time-iters' or TRAIN_MODE=='iters-time') and \
594 |             (total_iters == STOP_ITERS or total_time >= STOP_TIME)):
595 | 
596 |         print "Done!
Total iters:", total_iters, "Total time: ", total_time 597 | print "Experiment ended at:", datetime.strftime(datetime.now(), '%Y-%m-%d %H:%M') 598 | print "Wall clock time spent: {:.2f}h"\ 599 | .format((time.time()-exp_start)/3600) 600 | 601 | sys.exit() 602 | -------------------------------------------------------------------------------- /models/two_tier/two_tier_generate16k.py: -------------------------------------------------------------------------------- 1 | """ 2 | RNN Audio Generation Model 3 | 4 | Two-tier model, Quantized input 5 | For more info: 6 | $ python two_tier.py -h 7 | 8 | How-to-run example: 9 | sampleRNN$ pwd 10 | /u/mehris/sampleRNN 11 | 12 | sampleRNN$ \ 13 | THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python -u \ 14 | models/two_tier/two_tier.py --exp AXIS1 --n_frames 12 --frame_size 10 \ 15 | --weight_norm True --emb_size 64 --skip_conn False --dim 32 --n_rnn 2 \ 16 | --rnn_type LSTM --learn_h0 False --q_levels 16 --q_type linear \ 17 | --batch_size 128 --which_set MUSIC 18 | 19 | To resume add ` --resume` to the END of the EXACTLY above line. You can run the 20 | resume code as many time as possible, depending on the TRAIN_MODE. 21 | (folder name, file name, flags, their order, and the values are important) 22 | """ 23 | from time import time 24 | from datetime import datetime 25 | print "Experiment started at:", datetime.strftime(datetime.now(), '%Y-%m-%d %H:%M') 26 | exp_start = time() 27 | 28 | import os, sys, glob 29 | sys.path.insert(1, os.getcwd()) 30 | import argparse 31 | import datetime 32 | import numpy 33 | numpy.random.seed(123) 34 | np = numpy 35 | import random 36 | random.seed(123) 37 | import re 38 | 39 | 40 | import theano 41 | import theano.tensor as T 42 | import theano.ifelse 43 | import lasagne 44 | import scipy.io.wavfile 45 | 46 | import lib 47 | 48 | LEARNING_RATE = 0.001 49 | 50 | ### Parsing passed args/hyperparameters ### 51 | def get_args(): 52 | def t_or_f(arg): 53 | ua = str(arg).upper() 54 | if 'TRUE'.startswith(ua): 55 | return True 56 | elif 'FALSE'.startswith(ua): 57 | return False 58 | else: 59 | raise ValueError('Arg is neither `True` nor `False`') 60 | 61 | def check_non_negative(value): 62 | ivalue = int(value) 63 | if ivalue < 0: 64 | raise argparse.ArgumentTypeError("%s is not non-negative!" % value) 65 | return ivalue 66 | 67 | def check_positive(value): 68 | ivalue = int(value) 69 | if ivalue < 1: 70 | raise argparse.ArgumentTypeError("%s is not positive!" % value) 71 | return ivalue 72 | 73 | def check_unit_interval(value): 74 | fvalue = float(value) 75 | if fvalue < 0 or fvalue > 1: 76 | raise argparse.ArgumentTypeError("%s is not in [0, 1] interval!" % value) 77 | return fvalue 78 | 79 | # No default value here. Indicate every single arguement. 80 | parser = argparse.ArgumentParser( 81 | description='two_tier.py\nNo default value! 
Indicate every argument.')
82 | 
83 |     # Hyperparameter arguments:
84 |     parser.add_argument('--exp', help='Experiment name',
85 |             type=str, required=False, default='_')
86 |     parser.add_argument('--n_frames', help='How many "frames" to include in each\
87 |             truncated BPTT pass', type=check_positive, required=True)
88 |     parser.add_argument('--frame_size', help='How many samples per frame',\
89 |             type=check_positive, required=True)
90 |     parser.add_argument('--weight_norm', help='Adding learnable weight normalization\
91 |             to all the linear layers (except for the embedding layer)',\
92 |             type=t_or_f, required=True)
93 |     parser.add_argument('--emb_size', help='Size of embedding layer (0 to disable)', type=check_non_negative, required=True)
94 |     parser.add_argument('--skip_conn', help='Add skip connections to RNN', type=t_or_f, required=True)
95 |     parser.add_argument('--dim', help='Dimension of RNN and MLPs',\
96 |             type=check_positive, required=True)
97 |     parser.add_argument('--n_rnn', help='Number of layers in the stacked RNN',
98 |             type=check_positive, choices=xrange(1,40), required=True)
99 |     parser.add_argument('--rnn_type', help='GRU or LSTM', choices=['LSTM', 'GRU'],\
100 |             required=True)
101 |     parser.add_argument('--learn_h0', help='Whether to learn the initial state of RNN',\
102 |             type=t_or_f, required=True)
103 |     parser.add_argument('--q_levels', help='Number of bins for quantization of audio samples. Should be 256 for mu-law.',\
104 |             type=check_positive, required=True)
105 |     parser.add_argument('--q_type', help='Quantization in linear-scale, a-law-companding, or mu-law companding. With mu-/a-law, quantization levels should be set to 256',\
106 |             choices=['linear', 'a-law', 'mu-law'], required=True)
107 |     parser.add_argument('--which_set', help='The directory name of the dataset',
108 |             type=str, required=True)
109 |     parser.add_argument('--batch_size', help='size of mini-batch',
110 |             type=check_positive, choices=xrange(0, 129), required=True)
111 | 
112 |     parser.add_argument('--debug', help='Debug mode', required=False, default=False, action='store_true')
113 |     # NEW
114 |     parser.add_argument('--resume', help='Resume the same model from the last checkpoint. Order of params is important. [for now]',\
115 |             required=False, default=False, action='store_true')
116 | 
117 |     parser.add_argument('--n_secs', help='Seconds to generate',\
118 |             type=check_positive, required=True)
119 |     parser.add_argument('--n_seqs', help='Number of wavs to generate',\
120 |             type=check_positive, required=True)
121 | 
122 | 
123 |     args = parser.parse_args()
124 | 
125 |     # NEW
126 |     # Create tag for this experiment based on passed args
127 |     tag = reduce(lambda a, b: a+b, sys.argv).replace('--resume', '').replace('/', '-').replace('--', '-').replace('True', 'T').replace('False', 'F')
128 |     tag = re.sub(r'-n_secs[0-9]+', "", tag)
129 |     tag = re.sub(r'-n_seqs[0-9]+', "", tag)
130 |     tag = re.sub(r'_generate', "", tag)
131 |     tag += '-lr'+str(LEARNING_RATE)
132 |     print "Created experiment tag for these args:"
133 |     print tag
134 | 
135 |     return args, tag
136 | 
137 | args, tag = get_args()
138 | 
139 | 
140 | print "sup"
141 | 
142 | N_FRAMES = args.n_frames # How many 'frames' to include in each truncated BPTT pass
143 | OVERLAP = FRAME_SIZE = args.frame_size # How many samples per frame
144 | WEIGHT_NORM = args.weight_norm
145 | EMB_SIZE = args.emb_size
146 | SKIP_CONN = args.skip_conn
147 | DIM = args.dim # Model dimensionality.
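The regex cleanup in get_args above is what lets this generation script find its training run: stripping -n_secs/-n_seqs and the _generate suffix reproduces the tag the training script built, so FOLDER_PREFIX later resolves to the existing results directory whose checkpoint gets loaded. A standalone sketch with an abbreviated, hypothetical argv (a real invocation also carries all the required model flags):

    import re
    argv = ['models/two_tier/two_tier_generate16k.py',
            '--exp', 'AXIS1', '--n_secs', '5', '--n_seqs', '10']
    tag = reduce(lambda a, b: a+b, argv).replace('--resume', '')\
            .replace('/', '-').replace('--', '-')\
            .replace('True', 'T').replace('False', 'F')
    tag = re.sub(r'-n_secs[0-9]+', "", tag)
    tag = re.sub(r'-n_seqs[0-9]+', "", tag)
    tag = re.sub(r'_generate', "", tag)
    tag += '-lr' + str(0.001)
    print tag   # models-two_tier-two_tier16k.py-expAXIS1-lr0.001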
148 | N_RNN = args.n_rnn # How many RNNs to stack 149 | RNN_TYPE = args.rnn_type 150 | H0_MULT = 2 if RNN_TYPE == 'LSTM' else 1 151 | LEARN_H0 = args.learn_h0 152 | Q_LEVELS = args.q_levels # How many levels to use when discretizing samples. e.g. 256 = 8-bit scalar quantization 153 | Q_TYPE = args.q_type # log- or linear-scale 154 | WHICH_SET = args.which_set 155 | BATCH_SIZE = args.batch_size 156 | RESUME = args.resume 157 | N_SECS = args.n_secs 158 | N_SEQS = args.n_seqs 159 | 160 | 161 | print "hi" 162 | 163 | if Q_TYPE == 'mu-law' and Q_LEVELS != 256: 164 | raise ValueError('For mu-law Quantization levels should be exactly 256!') 165 | 166 | # Fixed hyperparams 167 | GRAD_CLIP = 1 # Elementwise grad clip threshold 168 | BITRATE = 16000 169 | 170 | # Other constants 171 | #TRAIN_MODE = 'iters' # To use PRINT_ITERS and STOP_ITERS 172 | TRAIN_MODE = 'time' # To use PRINT_TIME and STOP_TIME 173 | #TRAIN_MODE = 'time-iters' 174 | # To use PRINT_TIME for validation, 175 | # and (STOP_ITERS, STOP_TIME), whichever happened first, for stopping exp. 176 | #TRAIN_MODE = 'iters-time' 177 | # To use PRINT_ITERS for validation, 178 | # and (STOP_ITERS, STOP_TIME), whichever happened first, for stopping exp. 179 | PRINT_ITERS = 10000 # Print cost, generate samples, save model checkpoint every N iterations. 180 | STOP_ITERS = 100000 # Stop after this many iterations 181 | # TODO: 182 | PRINT_TIME = 90*60 # Print cost, generate samples, save model checkpoint every N seconds. 183 | STOP_TIME = 60*60*24*3 # Stop after this many seconds of actual training (not including time req'd to generate samples etc.) 184 | # TODO: 185 | RESULTS_DIR = 'results_2t' 186 | FOLDER_PREFIX = os.path.join(RESULTS_DIR, tag) 187 | SEQ_LEN = N_FRAMES * FRAME_SIZE # Total length (# of samples) of each truncated BPTT sequence 188 | Q_ZERO = numpy.int32(Q_LEVELS//2) # Discrete value correponding to zero amplitude 189 | 190 | 191 | print "SEQ_LEN", SEQ_LEN, N_FRAMES, FRAME_SIZE 192 | 193 | 194 | epoch_str = 'epoch' 195 | iter_str = 'iter' 196 | lowest_valid_str = 'lowest valid cost' 197 | corresp_test_str = 'correponding test cost' 198 | train_nll_str, valid_nll_str, test_nll_str = \ 199 | 'train NLL (bits)', 'valid NLL (bits)', 'test NLL (bits)' 200 | 201 | if args.debug: 202 | import warnings 203 | warnings.warn('----------RUNNING IN DEBUG MODE----------') 204 | TRAIN_MODE = 'time' 205 | PRINT_TIME = 100 206 | STOP_TIME = 3000 207 | STOP_ITERS = 1000 208 | 209 | ### Create directories ### 210 | # FOLDER_PREFIX: root, contains: 211 | # log.txt, __note.txt, train_log.pkl, train_log.png [, model_settings.txt] 212 | # FOLDER_PREFIX/params: saves all checkpoint params as pkl 213 | # FOLDER_PREFIX/samples: keeps all checkpoint samples as wav 214 | # FOLDER_PREFIX/best: keeps the best parameters, samples, ... 
215 | if not os.path.exists(FOLDER_PREFIX): 216 | os.makedirs(FOLDER_PREFIX) 217 | PARAMS_PATH = os.path.join(FOLDER_PREFIX, 'params') 218 | if not os.path.exists(PARAMS_PATH): 219 | os.makedirs(PARAMS_PATH) 220 | SAMPLES_PATH = os.path.join(FOLDER_PREFIX, 'samples') 221 | if not os.path.exists(SAMPLES_PATH): 222 | os.makedirs(SAMPLES_PATH) 223 | BEST_PATH = os.path.join(FOLDER_PREFIX, 'best') 224 | if not os.path.exists(BEST_PATH): 225 | os.makedirs(BEST_PATH) 226 | 227 | lib.print_model_settings(locals(), path=FOLDER_PREFIX, sys_arg=True) 228 | 229 | ### Import the data_feeder ### 230 | # Handling WHICH_SET 231 | from datasets.dataset import music_train_feed_epoch as train_feeder 232 | from datasets.dataset import music_valid_feed_epoch as valid_feeder 233 | from datasets.dataset import music_test_feed_epoch as test_feeder 234 | 235 | def load_data(data_feeder): 236 | """ 237 | Helper function to deal with interface of different datasets. 238 | `data_feeder` should be `train_feeder`, `valid_feeder`, or `test_feeder`. 239 | """ 240 | return data_feeder(WHICH_SET, BATCH_SIZE, 241 | SEQ_LEN, 242 | OVERLAP, 243 | Q_LEVELS, 244 | Q_ZERO, 245 | Q_TYPE) 246 | 247 | ### Creating computation graph ### 248 | def frame_level_rnn(input_sequences, h0, reset): 249 | """ 250 | input_sequences.shape: (batch size, n frames * FRAME_SIZE) 251 | h0.shape: (batch size, N_RNN, DIM) 252 | reset.shape: () 253 | 254 | output.shape: (batch size, n frames * FRAME_SIZE, DIM) 255 | """ 256 | frames = input_sequences.reshape(( 257 | input_sequences.shape[0], 258 | input_sequences.shape[1] // FRAME_SIZE, 259 | FRAME_SIZE 260 | )) 261 | 262 | # Rescale frames from ints in [0, Q_LEVELS) to floats in [-2, 2] 263 | # (a reasonable range to pass as inputs to the RNN) 264 | frames = (frames.astype('float32') / lib.floatX(Q_LEVELS/2)) - lib.floatX(1) 265 | frames *= lib.floatX(2) 266 | # (128, 64, 4) 267 | 268 | # Initial state of RNNs 269 | learned_h0 = lib.param( 270 | 'FrameLevel.h0', 271 | numpy.zeros((N_RNN, H0_MULT*DIM), dtype=theano.config.floatX) 272 | ) 273 | # Handling LEARN_H0 274 | learned_h0.param = LEARN_H0 275 | learned_h0 = T.alloc(learned_h0, h0.shape[0], N_RNN, H0_MULT*DIM) 276 | learned_h0 = T.unbroadcast(learned_h0, 0, 1, 2) 277 | h0 = theano.ifelse.ifelse(reset, learned_h0, h0) 278 | 279 | # Handling RNN_TYPE 280 | # Handling SKIP_CONN 281 | if RNN_TYPE == 'GRU': 282 | rnns_out, last_hidden = lib.ops.stackedGRU('FrameLevel.GRU', 283 | N_RNN, 284 | FRAME_SIZE, 285 | DIM, 286 | frames, 287 | h0=h0, 288 | weightnorm=WEIGHT_NORM, 289 | skip_conn=SKIP_CONN) 290 | elif RNN_TYPE == 'LSTM': 291 | rnns_out, last_hidden = lib.ops.stackedLSTM('FrameLevel.LSTM', 292 | N_RNN, 293 | FRAME_SIZE, 294 | DIM, 295 | frames, 296 | h0=h0, 297 | weightnorm=WEIGHT_NORM, 298 | skip_conn=SKIP_CONN) 299 | 300 | # rnns_out (bs, seqlen, dim) (128, 64, 512) 301 | output = lib.ops.Linear( 302 | 'FrameLevel.Output', 303 | DIM, 304 | FRAME_SIZE * DIM, 305 | rnns_out, 306 | initialization='he', 307 | weightnorm=WEIGHT_NORM 308 | ) 309 | # output: (2, 9, 4*dim) 310 | output = output.reshape((output.shape[0], output.shape[1] * FRAME_SIZE, DIM)) 311 | # output: (2, 9*4, dim) 312 | 313 | return (output, last_hidden) 314 | 315 | def sample_level_predictor(frame_level_outputs, prev_samples): 316 | """ 317 | batch size = BATCH_SIZE * SEQ_LEN 318 | SEQ_LEN = N_FRAMES * FRAME_SIZE 319 | 320 | frame_level_outputs.shape: (batch size, DIM) 321 | prev_samples.shape: (batch size, FRAME_SIZE) int32 322 | 323 | output.shape: (batch size, Q_LEVELS) 324 
| """ 325 | # Handling EMB_SIZE 326 | if EMB_SIZE == 0: 327 | prev_samples = lib.ops.T_one_hot(prev_samples, Q_LEVELS) 328 | # (BATCH_SIZE*N_FRAMES*FRAME_SIZE, FRAME_SIZE, Q_LEVELS) 329 | last_out_shape = Q_LEVELS 330 | elif EMB_SIZE > 0: 331 | prev_samples = lib.ops.Embedding( 332 | 'SampleLevel.Embedding', 333 | Q_LEVELS, 334 | EMB_SIZE, 335 | prev_samples) 336 | # (BATCH_SIZE*N_FRAMES*FRAME_SIZE, FRAME_SIZE, EMB_SIZE), f32 337 | last_out_shape = EMB_SIZE 338 | else: 339 | raise ValueError('EMB_SIZE cannot be negative.') 340 | 341 | prev_samples = prev_samples.reshape((-1, FRAME_SIZE * last_out_shape)) 342 | 343 | out = lib.ops.Linear( 344 | 'SampleLevel.L1_PrevSamples', 345 | FRAME_SIZE * last_out_shape, 346 | DIM, 347 | prev_samples, 348 | biases=False, 349 | initialization='he', 350 | weightnorm=WEIGHT_NORM) 351 | # shape: (BATCH_SIZE*N_FRAMES*FRAME_SIZE, DIM) 352 | 353 | out += frame_level_outputs 354 | # ^ (2*(9*4), dim) 355 | 356 | # L2 357 | out = lib.ops.Linear('SampleLevel.L2', 358 | DIM, 359 | DIM, 360 | out, 361 | initialization='he', 362 | weightnorm=WEIGHT_NORM) 363 | out = T.nnet.relu(out) 364 | 365 | # L3 366 | out = lib.ops.Linear('SampleLevel.L3', 367 | DIM, 368 | DIM, 369 | out, 370 | initialization='he', 371 | weightnorm=WEIGHT_NORM) 372 | out = T.nnet.relu(out) 373 | 374 | # Output 375 | # We apply the softmax later 376 | out = lib.ops.Linear('SampleLevel.Output', 377 | DIM, 378 | Q_LEVELS, 379 | out, 380 | weightnorm=WEIGHT_NORM) 381 | return out 382 | 383 | sequences = T.imatrix('sequences') 384 | h0 = T.tensor3('h0') 385 | reset = T.iscalar('reset') 386 | mask = T.matrix('mask') 387 | 388 | if args.debug: 389 | # Solely for debugging purposes. 390 | # Maybe I should set the compute_test_value=warn from here. 391 | sequences.tag.test_value = numpy.zeros((BATCH_SIZE, SEQ_LEN+OVERLAP), dtype='int32') 392 | h0.tag.test_value = numpy.zeros((BATCH_SIZE, N_RNN, H0_MULT*DIM), dtype='float32') 393 | reset.tag.test_value = numpy.array(1, dtype='int32') 394 | mask.tag.test_value = numpy.ones((BATCH_SIZE, SEQ_LEN+OVERLAP), dtype='float32') 395 | 396 | input_sequences = sequences[:, :-FRAME_SIZE] 397 | target_sequences = sequences[:, FRAME_SIZE:] 398 | 399 | target_mask = mask[:, FRAME_SIZE:] 400 | 401 | frame_level_outputs, new_h0 =\ 402 | frame_level_rnn(input_sequences, h0, reset) 403 | 404 | prev_samples = sequences[:, :-1] 405 | prev_samples = prev_samples.reshape((1, BATCH_SIZE, 1, -1)) 406 | prev_samples = T.nnet.neighbours.images2neibs(prev_samples, (1, FRAME_SIZE), neib_step=(1, 1), mode='valid') 407 | prev_samples = prev_samples.reshape((BATCH_SIZE * SEQ_LEN, FRAME_SIZE)) 408 | # (batch_size*n_frames*frame_size, frame_size) 409 | 410 | sample_level_outputs = sample_level_predictor( 411 | frame_level_outputs.reshape((BATCH_SIZE * SEQ_LEN, DIM)), 412 | prev_samples, 413 | ) 414 | 415 | cost = T.nnet.categorical_crossentropy( 416 | T.nnet.softmax(sample_level_outputs), 417 | target_sequences.flatten() 418 | ) 419 | cost = cost.reshape(target_sequences.shape) 420 | cost = cost * target_mask 421 | # Don't use these lines; could end up with NaN 422 | # Specially at the end of audio files where mask is 423 | # all zero for some of the shorter files in mini-batch. 424 | #cost = cost.sum(axis=1) / target_mask.sum(axis=1) 425 | #cost = cost.mean(axis=0) 426 | 427 | # Use this one instead. 428 | cost = cost.sum() 429 | cost = cost / target_mask.sum() 430 | 431 | # By default we report cross-entropy cost in bits. 
432 | # Switch to nats by commenting out this line: 433 | # log_2(e) = 1.44269504089 434 | cost = cost * lib.floatX(numpy.log2(numpy.e)) 435 | 436 | ### Getting the params, grads, updates, and Theano functions ### 437 | params = lib.get_params(cost, lambda x: hasattr(x, 'param') and x.param==True) 438 | lib.print_params_info(params, path=FOLDER_PREFIX) 439 | 440 | grads = T.grad(cost, wrt=params, disconnected_inputs='warn') 441 | grads = [T.clip(g, lib.floatX(-GRAD_CLIP), lib.floatX(GRAD_CLIP)) for g in grads] 442 | 443 | updates = lasagne.updates.adam(grads, params, learning_rate=LEARNING_RATE) 444 | 445 | # Training function 446 | train_fn = theano.function( 447 | [sequences, h0, reset, mask], 448 | [cost, new_h0], 449 | updates=updates, 450 | on_unused_input='warn' 451 | ) 452 | 453 | # Validation and Test function, hence no updates 454 | test_fn = theano.function( 455 | [sequences, h0, reset, mask], 456 | [cost, new_h0], 457 | on_unused_input='warn' 458 | ) 459 | 460 | # Sampling at frame level 461 | frame_level_generate_fn = theano.function( 462 | [sequences, h0, reset], 463 | frame_level_rnn(sequences, h0, reset), 464 | on_unused_input='warn' 465 | ) 466 | 467 | # Sampling at audio sample level 468 | frame_level_outputs = T.matrix('frame_level_outputs') 469 | prev_samples = T.imatrix('prev_samples') 470 | sample_level_generate_fn = theano.function( 471 | [frame_level_outputs, prev_samples], 472 | lib.ops.softmax_and_sample( 473 | sample_level_predictor( 474 | frame_level_outputs, 475 | prev_samples, 476 | ) 477 | ), 478 | on_unused_input='warn' 479 | ) 480 | 481 | # Uniform [-0.5, 0.5) for half of initial state for generated samples 482 | # to study the behaviour of the model and also to introduce some diversity 483 | # to samples in a simple way. [it's disabled for now] 484 | fixed_rand_h0 = numpy.random.rand(N_SEQS//2, N_RNN, H0_MULT*DIM) 485 | fixed_rand_h0 -= 0.5 486 | fixed_rand_h0 = fixed_rand_h0.astype('float32') 487 | 488 | def generate_and_save_samples(tag, N_SECS=5): 489 | def write_audio_file(name, data): 490 | data = data.astype('float32') 491 | data -= data.min() 492 | data /= data.max() 493 | data -= 0.5 494 | data *= 0.95 495 | scipy.io.wavfile.write( 496 | os.path.join(SAMPLES_PATH, name+'.wav'), 497 | BITRATE, 498 | data) 499 | 500 | total_time = time() 501 | # Generate N_SEQS' sample files, each 5 seconds long 502 | LENGTH = N_SECS*BITRATE if not args.debug else 100 503 | 504 | samples = numpy.zeros((N_SEQS, LENGTH), dtype='int32') 505 | samples[:, :FRAME_SIZE] = Q_ZERO 506 | 507 | # First half zero, others fixed random at each checkpoint 508 | h0 = numpy.zeros( 509 | (N_SEQS-fixed_rand_h0.shape[0], N_RNN, H0_MULT*DIM), 510 | dtype='float32' 511 | ) 512 | h0 = numpy.concatenate((h0, fixed_rand_h0), axis=0) 513 | frame_level_outputs = None 514 | 515 | for t in xrange(FRAME_SIZE, LENGTH): 516 | 517 | if t % FRAME_SIZE == 0: 518 | frame_level_outputs, h0 = frame_level_generate_fn( 519 | samples[:, t-FRAME_SIZE:t], 520 | h0, 521 | #numpy.full((N_SEQS, ), (t == FRAME_SIZE), dtype='int32'), 522 | numpy.int32(t == FRAME_SIZE) 523 | ) 524 | 525 | samples[:, t] = sample_level_generate_fn( 526 | frame_level_outputs[:, t % FRAME_SIZE], 527 | samples[:, t-FRAME_SIZE:t], 528 | ) 529 | 530 | total_time = time() - total_time 531 | log = "{} samples of {} seconds length generated in {} seconds." 
532 | log = log.format(N_SEQS, N_SECS, total_time) 533 | print log 534 | 535 | for i in xrange(N_SEQS): 536 | samp = samples[i] 537 | if Q_TYPE == 'mu-law': 538 | from datasets.dataset import mu2linear 539 | samp = mu2linear(samp) 540 | elif Q_TYPE == 'a-law': 541 | raise NotImplementedError('a-law is not implemented') 542 | 543 | now = datetime.datetime.now() 544 | now_time = "{}:{}:{}".format(now.hour, now.minute, now.second) 545 | 546 | file_name = "sample_{}_{}_{}_{}".format(tag, N_SECS, now_time, i) 547 | print "writing...", file_name 548 | write_audio_file(file_name, samp) 549 | 550 | 551 | 552 | def monitor(data_feeder): 553 | """ 554 | Cost and time of test_fn on a given dataset section. 555 | Pass only one of `valid_feeder` or `test_feeder`. 556 | Don't pass `train_feed`. 557 | 558 | :returns: 559 | Mean cost over the input dataset (data_feeder) 560 | Total time spent 561 | """ 562 | _total_time = time() 563 | _h0 = numpy.zeros((BATCH_SIZE, N_RNN, H0_MULT*DIM), dtype='float32') 564 | _costs = [] 565 | _data_feeder = load_data(data_feeder) 566 | for _seqs, _reset, _mask in _data_feeder: 567 | _cost, _h0 = test_fn(_seqs, _h0, _reset, _mask) 568 | _costs.append(_cost) 569 | 570 | return numpy.mean(_costs), time() - _total_time 571 | 572 | print "Wall clock time spent before training started: {:.2f}h"\ 573 | .format((time()-exp_start)/3600.) 574 | print "Training!" 575 | total_iters = 0 576 | total_time = 0. 577 | last_print_time = 0. 578 | last_print_iters = 0 579 | costs = [] 580 | lowest_valid_cost = numpy.finfo(numpy.float32).max 581 | corresponding_test_cost = numpy.finfo(numpy.float32).max 582 | new_lowest_cost = False 583 | end_of_batch = False 584 | epoch = 0 585 | 586 | h0 = numpy.zeros((BATCH_SIZE, N_RNN, H0_MULT*DIM), dtype='float32') 587 | 588 | # Initial load train dataset 589 | tr_feeder = load_data(train_feeder) 590 | 591 | ### Handling the resume option: 592 | if True: #if Resume: 593 | # Check if checkpoint from previous run is not corrupted. 594 | # Then overwrite some of the variables above. 595 | iters_to_consume, res_path, epoch, total_iters,\ 596 | [lowest_valid_cost, corresponding_test_cost, test_cost] = \ 597 | lib.resumable(path=FOLDER_PREFIX, 598 | iter_key=iter_str, 599 | epoch_key=epoch_str, 600 | add_resume_counter=True, 601 | other_keys=[lowest_valid_str, 602 | corresp_test_str, 603 | test_nll_str]) 604 | # At this point we saved the pkl file. 605 | last_print_iters = total_iters 606 | print "### RESUMING JOB FROM EPOCH {}, ITER {}".format(epoch, total_iters) 607 | # Consumes this much iters to get to the last point in training data. 608 | consume_time = time() 609 | for i in xrange(iters_to_consume): 610 | tr_feeder.next() 611 | consume_time = time() - consume_time 612 | print "Train data ready in {:.2f}secs after consuming {} minibatches.".\ 613 | format(consume_time, iters_to_consume) 614 | 615 | lib.load_params(res_path) 616 | print "Parameters from last available checkpoint loaded." 617 | 618 | 619 | 620 | # 2. Stdout the training progress 621 | print_info = "epoch:{}\ttotal iters:{}\twall clock time:{:.2f}h\n" 622 | print_info = print_info.format(epoch, 623 | total_iters, 624 | (time()-exp_start)/3600) 625 | print print_info 626 | 627 | tag = "e{}_i{}" 628 | tag = tag.format(epoch, 629 | total_iters) 630 | 631 | # 5. Generate and save samples (time consuming) 632 | # If not successful, we still have the params to sample afterward 633 | print "Sampling!", 634 | # Generate samples 635 | generate_and_save_samples(tag, N_SECS) 636 | print "Done!" 
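# Unlike the training scripts that share this scaffolding, everything after
# the graph definition here is a single pass: the `if True:` block above
# always takes the resume path to load the newest checkpoint, prints a
# summary, generates N_SEQS wav files of N_SECS seconds each, and exits.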
637 | 638 | print "Wall clock time spent: {:.2f}h"\ 639 | .format((time()-exp_start)/3600) 640 | 641 | sys.exit() -------------------------------------------------------------------------------- /models/two_tier/two_tier_generate32k.py: -------------------------------------------------------------------------------- 1 | """ 2 | RNN Audio Generation Model 3 | 4 | Two-tier model, Quantized input 5 | For more info: 6 | $ python two_tier.py -h 7 | 8 | How-to-run example: 9 | sampleRNN$ pwd 10 | /u/mehris/sampleRNN 11 | 12 | sampleRNN$ \ 13 | THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python -u \ 14 | models/two_tier/two_tier.py --exp AXIS1 --n_frames 12 --frame_size 10 \ 15 | --weight_norm True --emb_size 64 --skip_conn False --dim 32 --n_rnn 2 \ 16 | --rnn_type LSTM --learn_h0 False --q_levels 16 --q_type linear \ 17 | --batch_size 128 --which_set MUSIC 18 | 19 | To resume add ` --resume` to the END of the EXACTLY above line. You can run the 20 | resume code as many time as possible, depending on the TRAIN_MODE. 21 | (folder name, file name, flags, their order, and the values are important) 22 | """ 23 | from time import time 24 | from datetime import datetime 25 | print "Experiment started at:", datetime.strftime(datetime.now(), '%Y-%m-%d %H:%M') 26 | exp_start = time() 27 | 28 | import os, sys, glob 29 | sys.path.insert(1, os.getcwd()) 30 | import argparse 31 | import datetime 32 | import numpy 33 | numpy.random.seed(123) 34 | np = numpy 35 | import random 36 | random.seed(123) 37 | import re 38 | 39 | 40 | import theano 41 | import theano.tensor as T 42 | import theano.ifelse 43 | import lasagne 44 | import scipy.io.wavfile 45 | 46 | import lib 47 | 48 | LEARNING_RATE = 0.001 49 | 50 | ### Parsing passed args/hyperparameters ### 51 | def get_args(): 52 | def t_or_f(arg): 53 | ua = str(arg).upper() 54 | if 'TRUE'.startswith(ua): 55 | return True 56 | elif 'FALSE'.startswith(ua): 57 | return False 58 | else: 59 | raise ValueError('Arg is neither `True` nor `False`') 60 | 61 | def check_non_negative(value): 62 | ivalue = int(value) 63 | if ivalue < 0: 64 | raise argparse.ArgumentTypeError("%s is not non-negative!" % value) 65 | return ivalue 66 | 67 | def check_positive(value): 68 | ivalue = int(value) 69 | if ivalue < 1: 70 | raise argparse.ArgumentTypeError("%s is not positive!" % value) 71 | return ivalue 72 | 73 | def check_unit_interval(value): 74 | fvalue = float(value) 75 | if fvalue < 0 or fvalue > 1: 76 | raise argparse.ArgumentTypeError("%s is not in [0, 1] interval!" % value) 77 | return fvalue 78 | 79 | # No default value here. Indicate every single arguement. 80 | parser = argparse.ArgumentParser( 81 | description='two_tier.py\nNo default value! 
82 | 
83 |     # Hyperparameter arguments:
84 |     parser.add_argument('--exp', help='Experiment name',
85 |                         type=str, required=False, default='_')
86 |     parser.add_argument('--n_frames', help='How many "frames" to include in each\
87 |                         Truncated BPTT pass', type=check_positive, required=True)
88 |     parser.add_argument('--frame_size', help='How many samples per frame',\
89 |                         type=check_positive, required=True)
90 |     parser.add_argument('--weight_norm', help='Adding learnable weight normalization\
91 |                         to all the linear layers (except for the embedding layer)',\
92 |                         type=t_or_f, required=True)
93 |     parser.add_argument('--emb_size', help='Size of embedding layer (0 to disable)', type=check_non_negative, required=True)
94 |     parser.add_argument('--skip_conn', help='Add skip connections to RNN', type=t_or_f, required=True)
95 |     parser.add_argument('--dim', help='Dimension of RNN and MLPs',\
96 |                         type=check_positive, required=True)
97 |     parser.add_argument('--n_rnn', help='Number of layers in the stacked RNN',
98 |                         type=check_positive, choices=xrange(1, 40), required=True)
99 |     parser.add_argument('--rnn_type', help='GRU or LSTM', choices=['LSTM', 'GRU'],\
100 |                         required=True)
101 |     parser.add_argument('--learn_h0', help='Whether to learn the initial state of RNN',\
102 |                         type=t_or_f, required=True)
103 |     parser.add_argument('--q_levels', help='Number of bins for quantization of audio samples. Should be 256 for mu-law.',\
104 |                         type=check_positive, required=True)
105 |     parser.add_argument('--q_type', help='Quantization in linear-scale, a-law companding, or mu-law companding. With mu-/a-law the quantization level should be set to 256',\
106 |                         choices=['linear', 'a-law', 'mu-law'], required=True)
107 |     parser.add_argument('--which_set', help='the directory name of the dataset',
108 |                         type=str, required=True)
109 |     parser.add_argument('--batch_size', help='size of mini-batch',
110 |                         type=check_positive, choices=xrange(1, 129), required=True)  # NOTE: was xrange(0, 129); 0 is rejected by check_positive anyway
111 | 
112 |     parser.add_argument('--debug', help='Debug mode', required=False, default=False, action='store_true')
113 |     # NEW
114 |     parser.add_argument('--resume', help='Resume the same model from the last checkpoint. Order of params is important. [for now]',\
115 |                         required=False, default=False, action='store_true')
116 | 
117 |     parser.add_argument('--n_secs', help='Seconds to generate',\
118 |                         type=check_positive, required=True)
119 |     parser.add_argument('--n_seqs', help='Number of wavs to generate',\
120 |                         type=check_positive, required=True)
121 | 
122 | 
123 |     args = parser.parse_args()
124 | 
125 |     # NEW
126 |     # Create tag for this experiment based on passed args
127 |     tag = reduce(lambda a, b: a+b, sys.argv).replace('--resume', '').replace('/', '-').replace('--', '-').replace('True', 'T').replace('False', 'F')
128 |     tag = re.sub(r'-n_secs[0-9]+', "", tag)  # strip the generation-only flags so the tag
129 |     tag = re.sub(r'-n_seqs[0-9]+', "", tag)  # (and hence FOLDER_PREFIX) matches the training run's folder
130 |     tag = re.sub(r'_generate', "", tag)
131 |     tag += '-lr'+str(LEARNING_RATE)
132 |     print "Created experiment tag for these args:"
133 |     print tag
134 | 
135 |     return args, tag
136 | 
137 | args, tag = get_args()
138 | 
139 | 
140 | 
141 | 
142 | N_FRAMES = args.n_frames # How many 'frames' to include in each truncated BPTT pass
143 | OVERLAP = FRAME_SIZE = args.frame_size # How many samples per frame
144 | WEIGHT_NORM = args.weight_norm
145 | EMB_SIZE = args.emb_size
146 | SKIP_CONN = args.skip_conn
147 | DIM = args.dim # Model dimensionality.
148 | N_RNN = args.n_rnn # How many RNNs to stack
149 | RNN_TYPE = args.rnn_type
150 | H0_MULT = 2 if RNN_TYPE == 'LSTM' else 1
151 | LEARN_H0 = args.learn_h0
152 | Q_LEVELS = args.q_levels # How many levels to use when discretizing samples. e.g. 256 = 8-bit scalar quantization
153 | Q_TYPE = args.q_type # log- or linear-scale
154 | WHICH_SET = args.which_set
155 | BATCH_SIZE = args.batch_size
156 | RESUME = args.resume
157 | N_SECS = args.n_secs
158 | N_SEQS = args.n_seqs
159 | 
160 | 
161 | 
162 | 
163 | if Q_TYPE == 'mu-law' and Q_LEVELS != 256:
164 |     raise ValueError('For mu-law quantization, Q_LEVELS should be exactly 256!')
165 | 
166 | # Fixed hyperparams
167 | GRAD_CLIP = 1 # Elementwise grad clip threshold
168 | BITRATE = 32000
169 | 
170 | # Other constants
171 | #TRAIN_MODE = 'iters' # To use PRINT_ITERS and STOP_ITERS
172 | TRAIN_MODE = 'time' # To use PRINT_TIME and STOP_TIME
173 | #TRAIN_MODE = 'time-iters'
174 | # To use PRINT_TIME for validation,
175 | # and (STOP_ITERS, STOP_TIME), whichever happens first, for stopping the experiment.
176 | #TRAIN_MODE = 'iters-time'
177 | # To use PRINT_ITERS for validation,
178 | # and (STOP_ITERS, STOP_TIME), whichever happens first, for stopping the experiment.
179 | # (The monitoring trigger is spelled out in the illustrative snippet after the directory notes below.)
180 | PRINT_ITERS = 10000 # Print cost, generate samples, save model checkpoint every N iterations.
181 | STOP_ITERS = 100000 # Stop after this many iterations
182 | PRINT_TIME = 90*60 # Print cost, generate samples, save model checkpoint every N seconds.
183 | STOP_TIME = 60*60*24*3 # Stop after this many seconds of actual training (not including time req'd to generate samples etc.)
184 | # TODO:
185 | RESULTS_DIR = 'results_2t'
186 | FOLDER_PREFIX = os.path.join(RESULTS_DIR, tag)
187 | SEQ_LEN = N_FRAMES * FRAME_SIZE # Total length (# of samples) of each truncated BPTT sequence
188 | Q_ZERO = numpy.int32(Q_LEVELS//2) # Discrete value corresponding to zero amplitude
189 | 
190 | 
191 | print "SEQ_LEN", SEQ_LEN, N_FRAMES, FRAME_SIZE
192 | 
193 | 
194 | epoch_str = 'epoch'
195 | iter_str = 'iter'
196 | lowest_valid_str = 'lowest valid cost'
197 | corresp_test_str = 'correponding test cost'  # (sic) must match the key the training script wrote into its saved log
198 | train_nll_str, valid_nll_str, test_nll_str = \
199 |     'train NLL (bits)', 'valid NLL (bits)', 'test NLL (bits)'
200 | 
201 | if args.debug:
202 |     import warnings
203 |     warnings.warn('----------RUNNING IN DEBUG MODE----------')
204 |     TRAIN_MODE = 'time'
205 |     PRINT_TIME = 100
206 |     STOP_TIME = 3000
207 |     STOP_ITERS = 1000
208 | 
209 | ### Create directories ###
210 | # FOLDER_PREFIX: root, contains:
211 | #   log.txt, __note.txt, train_log.pkl, train_log.png [, model_settings.txt]
212 | # FOLDER_PREFIX/params: saves all checkpoint params as pkl
213 | # FOLDER_PREFIX/samples: keeps all checkpoint samples as wav
214 | # FOLDER_PREFIX/best: keeps the best parameters, samples, ...
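# Illustrative only -- how the four TRAIN_MODE options above map to the
# monitoring trigger used by the training scripts (the actual check lives in
# the training loop; see models/one_tier/one_tier.py further down).
# `should_monitor` is a hypothetical name; in one_tier.py the expression sits
# directly inside an `if`:
#
#   should_monitor = \
#       (TRAIN_MODE == 'iters' and total_iters - last_print_iters == PRINT_ITERS) or \
#       (TRAIN_MODE == 'time' and total_time - last_print_time >= PRINT_TIME) or \
#       (TRAIN_MODE == 'time-iters' and total_time - last_print_time >= PRINT_TIME) or \
#       (TRAIN_MODE == 'iters-time' and total_iters - last_print_iters >= PRINT_ITERS)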
215 | if not os.path.exists(FOLDER_PREFIX): 216 | os.makedirs(FOLDER_PREFIX) 217 | PARAMS_PATH = os.path.join(FOLDER_PREFIX, 'params') 218 | if not os.path.exists(PARAMS_PATH): 219 | os.makedirs(PARAMS_PATH) 220 | SAMPLES_PATH = os.path.join(FOLDER_PREFIX, 'samples') 221 | if not os.path.exists(SAMPLES_PATH): 222 | os.makedirs(SAMPLES_PATH) 223 | BEST_PATH = os.path.join(FOLDER_PREFIX, 'best') 224 | if not os.path.exists(BEST_PATH): 225 | os.makedirs(BEST_PATH) 226 | 227 | lib.print_model_settings(locals(), path=FOLDER_PREFIX, sys_arg=True) 228 | 229 | ### Import the data_feeder ### 230 | # Handling WHICH_SET 231 | from datasets.dataset import music_train_feed_epoch as train_feeder 232 | from datasets.dataset import music_valid_feed_epoch as valid_feeder 233 | from datasets.dataset import music_test_feed_epoch as test_feeder 234 | 235 | def load_data(data_feeder): 236 | """ 237 | Helper function to deal with interface of different datasets. 238 | `data_feeder` should be `train_feeder`, `valid_feeder`, or `test_feeder`. 239 | """ 240 | return data_feeder(WHICH_SET, BATCH_SIZE, 241 | SEQ_LEN, 242 | OVERLAP, 243 | Q_LEVELS, 244 | Q_ZERO, 245 | Q_TYPE) 246 | 247 | ### Creating computation graph ### 248 | def frame_level_rnn(input_sequences, h0, reset): 249 | """ 250 | input_sequences.shape: (batch size, n frames * FRAME_SIZE) 251 | h0.shape: (batch size, N_RNN, DIM) 252 | reset.shape: () 253 | 254 | output.shape: (batch size, n frames * FRAME_SIZE, DIM) 255 | """ 256 | frames = input_sequences.reshape(( 257 | input_sequences.shape[0], 258 | input_sequences.shape[1] // FRAME_SIZE, 259 | FRAME_SIZE 260 | )) 261 | 262 | # Rescale frames from ints in [0, Q_LEVELS) to floats in [-2, 2] 263 | # (a reasonable range to pass as inputs to the RNN) 264 | frames = (frames.astype('float32') / lib.floatX(Q_LEVELS/2)) - lib.floatX(1) 265 | frames *= lib.floatX(2) 266 | # (128, 64, 4) 267 | 268 | # Initial state of RNNs 269 | learned_h0 = lib.param( 270 | 'FrameLevel.h0', 271 | numpy.zeros((N_RNN, H0_MULT*DIM), dtype=theano.config.floatX) 272 | ) 273 | # Handling LEARN_H0 274 | learned_h0.param = LEARN_H0 275 | learned_h0 = T.alloc(learned_h0, h0.shape[0], N_RNN, H0_MULT*DIM) 276 | learned_h0 = T.unbroadcast(learned_h0, 0, 1, 2) 277 | h0 = theano.ifelse.ifelse(reset, learned_h0, h0) 278 | 279 | # Handling RNN_TYPE 280 | # Handling SKIP_CONN 281 | if RNN_TYPE == 'GRU': 282 | rnns_out, last_hidden = lib.ops.stackedGRU('FrameLevel.GRU', 283 | N_RNN, 284 | FRAME_SIZE, 285 | DIM, 286 | frames, 287 | h0=h0, 288 | weightnorm=WEIGHT_NORM, 289 | skip_conn=SKIP_CONN) 290 | elif RNN_TYPE == 'LSTM': 291 | rnns_out, last_hidden = lib.ops.stackedLSTM('FrameLevel.LSTM', 292 | N_RNN, 293 | FRAME_SIZE, 294 | DIM, 295 | frames, 296 | h0=h0, 297 | weightnorm=WEIGHT_NORM, 298 | skip_conn=SKIP_CONN) 299 | 300 | # rnns_out (bs, seqlen, dim) (128, 64, 512) 301 | output = lib.ops.Linear( 302 | 'FrameLevel.Output', 303 | DIM, 304 | FRAME_SIZE * DIM, 305 | rnns_out, 306 | initialization='he', 307 | weightnorm=WEIGHT_NORM 308 | ) 309 | # output: (2, 9, 4*dim) 310 | output = output.reshape((output.shape[0], output.shape[1] * FRAME_SIZE, DIM)) 311 | # output: (2, 9*4, dim) 312 | 313 | return (output, last_hidden) 314 | 315 | def sample_level_predictor(frame_level_outputs, prev_samples): 316 | """ 317 | batch size = BATCH_SIZE * SEQ_LEN 318 | SEQ_LEN = N_FRAMES * FRAME_SIZE 319 | 320 | frame_level_outputs.shape: (batch size, DIM) 321 | prev_samples.shape: (batch size, FRAME_SIZE) int32 322 | 323 | output.shape: (batch size, Q_LEVELS) 324 
| """ 325 | # Handling EMB_SIZE 326 | if EMB_SIZE == 0: 327 | prev_samples = lib.ops.T_one_hot(prev_samples, Q_LEVELS) 328 | # (BATCH_SIZE*N_FRAMES*FRAME_SIZE, FRAME_SIZE, Q_LEVELS) 329 | last_out_shape = Q_LEVELS 330 | elif EMB_SIZE > 0: 331 | prev_samples = lib.ops.Embedding( 332 | 'SampleLevel.Embedding', 333 | Q_LEVELS, 334 | EMB_SIZE, 335 | prev_samples) 336 | # (BATCH_SIZE*N_FRAMES*FRAME_SIZE, FRAME_SIZE, EMB_SIZE), f32 337 | last_out_shape = EMB_SIZE 338 | else: 339 | raise ValueError('EMB_SIZE cannot be negative.') 340 | 341 | prev_samples = prev_samples.reshape((-1, FRAME_SIZE * last_out_shape)) 342 | 343 | out = lib.ops.Linear( 344 | 'SampleLevel.L1_PrevSamples', 345 | FRAME_SIZE * last_out_shape, 346 | DIM, 347 | prev_samples, 348 | biases=False, 349 | initialization='he', 350 | weightnorm=WEIGHT_NORM) 351 | # shape: (BATCH_SIZE*N_FRAMES*FRAME_SIZE, DIM) 352 | 353 | out += frame_level_outputs 354 | # ^ (2*(9*4), dim) 355 | 356 | # L2 357 | out = lib.ops.Linear('SampleLevel.L2', 358 | DIM, 359 | DIM, 360 | out, 361 | initialization='he', 362 | weightnorm=WEIGHT_NORM) 363 | out = T.nnet.relu(out) 364 | 365 | # L3 366 | out = lib.ops.Linear('SampleLevel.L3', 367 | DIM, 368 | DIM, 369 | out, 370 | initialization='he', 371 | weightnorm=WEIGHT_NORM) 372 | out = T.nnet.relu(out) 373 | 374 | # Output 375 | # We apply the softmax later 376 | out = lib.ops.Linear('SampleLevel.Output', 377 | DIM, 378 | Q_LEVELS, 379 | out, 380 | weightnorm=WEIGHT_NORM) 381 | return out 382 | 383 | sequences = T.imatrix('sequences') 384 | h0 = T.tensor3('h0') 385 | reset = T.iscalar('reset') 386 | mask = T.matrix('mask') 387 | 388 | if args.debug: 389 | # Solely for debugging purposes. 390 | # Maybe I should set the compute_test_value=warn from here. 391 | sequences.tag.test_value = numpy.zeros((BATCH_SIZE, SEQ_LEN+OVERLAP), dtype='int32') 392 | h0.tag.test_value = numpy.zeros((BATCH_SIZE, N_RNN, H0_MULT*DIM), dtype='float32') 393 | reset.tag.test_value = numpy.array(1, dtype='int32') 394 | mask.tag.test_value = numpy.ones((BATCH_SIZE, SEQ_LEN+OVERLAP), dtype='float32') 395 | 396 | input_sequences = sequences[:, :-FRAME_SIZE] 397 | target_sequences = sequences[:, FRAME_SIZE:] 398 | 399 | target_mask = mask[:, FRAME_SIZE:] 400 | 401 | frame_level_outputs, new_h0 =\ 402 | frame_level_rnn(input_sequences, h0, reset) 403 | 404 | prev_samples = sequences[:, :-1] 405 | prev_samples = prev_samples.reshape((1, BATCH_SIZE, 1, -1)) 406 | prev_samples = T.nnet.neighbours.images2neibs(prev_samples, (1, FRAME_SIZE), neib_step=(1, 1), mode='valid') 407 | prev_samples = prev_samples.reshape((BATCH_SIZE * SEQ_LEN, FRAME_SIZE)) 408 | # (batch_size*n_frames*frame_size, frame_size) 409 | 410 | sample_level_outputs = sample_level_predictor( 411 | frame_level_outputs.reshape((BATCH_SIZE * SEQ_LEN, DIM)), 412 | prev_samples, 413 | ) 414 | 415 | cost = T.nnet.categorical_crossentropy( 416 | T.nnet.softmax(sample_level_outputs), 417 | target_sequences.flatten() 418 | ) 419 | cost = cost.reshape(target_sequences.shape) 420 | cost = cost * target_mask 421 | # Don't use these lines; could end up with NaN 422 | # Specially at the end of audio files where mask is 423 | # all zero for some of the shorter files in mini-batch. 424 | #cost = cost.sum(axis=1) / target_mask.sum(axis=1) 425 | #cost = cost.mean(axis=0) 426 | 427 | # Use this one instead. 428 | cost = cost.sum() 429 | cost = cost / target_mask.sum() 430 | 431 | # By default we report cross-entropy cost in bits. 
432 | # Switch to nats by commenting out this line: 433 | # log_2(e) = 1.44269504089 434 | cost = cost * lib.floatX(numpy.log2(numpy.e)) 435 | 436 | ### Getting the params, grads, updates, and Theano functions ### 437 | params = lib.get_params(cost, lambda x: hasattr(x, 'param') and x.param==True) 438 | lib.print_params_info(params, path=FOLDER_PREFIX) 439 | 440 | grads = T.grad(cost, wrt=params, disconnected_inputs='warn') 441 | grads = [T.clip(g, lib.floatX(-GRAD_CLIP), lib.floatX(GRAD_CLIP)) for g in grads] 442 | 443 | updates = lasagne.updates.adam(grads, params, learning_rate=LEARNING_RATE) 444 | 445 | # Training function 446 | train_fn = theano.function( 447 | [sequences, h0, reset, mask], 448 | [cost, new_h0], 449 | updates=updates, 450 | on_unused_input='warn' 451 | ) 452 | 453 | # Validation and Test function, hence no updates 454 | test_fn = theano.function( 455 | [sequences, h0, reset, mask], 456 | [cost, new_h0], 457 | on_unused_input='warn' 458 | ) 459 | 460 | # Sampling at frame level 461 | frame_level_generate_fn = theano.function( 462 | [sequences, h0, reset], 463 | frame_level_rnn(sequences, h0, reset), 464 | on_unused_input='warn' 465 | ) 466 | 467 | # Sampling at audio sample level 468 | frame_level_outputs = T.matrix('frame_level_outputs') 469 | prev_samples = T.imatrix('prev_samples') 470 | sample_level_generate_fn = theano.function( 471 | [frame_level_outputs, prev_samples], 472 | lib.ops.softmax_and_sample( 473 | sample_level_predictor( 474 | frame_level_outputs, 475 | prev_samples, 476 | ) 477 | ), 478 | on_unused_input='warn' 479 | ) 480 | 481 | # Uniform [-0.5, 0.5) for half of initial state for generated samples 482 | # to study the behaviour of the model and also to introduce some diversity 483 | # to samples in a simple way. [it's disabled for now] 484 | fixed_rand_h0 = numpy.random.rand(N_SEQS//2, N_RNN, H0_MULT*DIM) 485 | fixed_rand_h0 -= 0.5 486 | fixed_rand_h0 = fixed_rand_h0.astype('float32') 487 | 488 | def generate_and_save_samples(tag, N_SECS=5): 489 | def write_audio_file(name, data): 490 | data = data.astype('float32') 491 | data -= data.min() 492 | data /= data.max() 493 | data -= 0.5 494 | data *= 0.95 495 | scipy.io.wavfile.write( 496 | os.path.join(SAMPLES_PATH, name+'.wav'), 497 | BITRATE, 498 | data) 499 | 500 | total_time = time() 501 | # Generate N_SEQS' sample files, each 5 seconds long 502 | LENGTH = N_SECS*BITRATE if not args.debug else 100 503 | 504 | samples = numpy.zeros((N_SEQS, LENGTH), dtype='int32') 505 | samples[:, :FRAME_SIZE] = Q_ZERO 506 | 507 | # First half zero, others fixed random at each checkpoint 508 | h0 = numpy.zeros( 509 | (N_SEQS-fixed_rand_h0.shape[0], N_RNN, H0_MULT*DIM), 510 | dtype='float32' 511 | ) 512 | h0 = numpy.concatenate((h0, fixed_rand_h0), axis=0) 513 | frame_level_outputs = None 514 | 515 | for t in xrange(FRAME_SIZE, LENGTH): 516 | 517 | if t % FRAME_SIZE == 0: 518 | frame_level_outputs, h0 = frame_level_generate_fn( 519 | samples[:, t-FRAME_SIZE:t], 520 | h0, 521 | #numpy.full((N_SEQS, ), (t == FRAME_SIZE), dtype='int32'), 522 | numpy.int32(t == FRAME_SIZE) 523 | ) 524 | 525 | samples[:, t] = sample_level_generate_fn( 526 | frame_level_outputs[:, t % FRAME_SIZE], 527 | samples[:, t-FRAME_SIZE:t], 528 | ) 529 | 530 | total_time = time() - total_time 531 | log = "{} samples of {} seconds length generated in {} seconds." 
532 |     log = log.format(N_SEQS, N_SECS, total_time)
533 |     print log
534 | 
535 |     for i in xrange(N_SEQS):
536 |         samp = samples[i]
537 |         if Q_TYPE == 'mu-law':
538 |             from datasets.dataset import mu2linear  # an illustrative sketch of the inverse mapping appears after the sampling step below
539 |             samp = mu2linear(samp)
540 |         elif Q_TYPE == 'a-law':
541 |             raise NotImplementedError('a-law is not implemented')
542 | 
543 |         now = datetime.datetime.now()
544 |         now_time = "{}:{}:{}".format(now.hour, now.minute, now.second)
545 | 
546 |         file_name = "sample_{}_{}_{}_{}".format(tag, N_SECS, now_time, i)
547 |         print "writing...", file_name
548 |         write_audio_file(file_name, samp)
549 | 
550 | 
551 | 
552 | def monitor(data_feeder):
553 |     """
554 |     Cost and time of test_fn on a given dataset section.
555 |     Pass only one of `valid_feeder` or `test_feeder`.
556 |     Don't pass `train_feeder`.
557 | 
558 |     :returns:
559 |         Mean cost over the input dataset (data_feeder)
560 |         Total time spent
561 |     """
562 |     _total_time = time()
563 |     _h0 = numpy.zeros((BATCH_SIZE, N_RNN, H0_MULT*DIM), dtype='float32')
564 |     _costs = []
565 |     _data_feeder = load_data(data_feeder)
566 |     for _seqs, _reset, _mask in _data_feeder:
567 |         _cost, _h0 = test_fn(_seqs, _h0, _reset, _mask)
568 |         _costs.append(_cost)
569 | 
570 |     return numpy.mean(_costs), time() - _total_time
571 | 
572 | print "Wall clock time spent before training started: {:.2f}h"\
573 |     .format((time()-exp_start)/3600.)
574 | print "Training!"
575 | total_iters = 0
576 | total_time = 0.
577 | last_print_time = 0.
578 | last_print_iters = 0
579 | costs = []
580 | lowest_valid_cost = numpy.finfo(numpy.float32).max
581 | corresponding_test_cost = numpy.finfo(numpy.float32).max
582 | new_lowest_cost = False
583 | end_of_batch = False
584 | epoch = 0
585 | 
586 | h0 = numpy.zeros((BATCH_SIZE, N_RNN, H0_MULT*DIM), dtype='float32')
587 | 
588 | # Initial load of the train dataset
589 | tr_feeder = load_data(train_feeder)
590 | 
591 | ### Handling the resume option:
592 | if True:  # This generation script always loads the last checkpoint, so the usual `if RESUME:` guard is bypassed on purpose here.
593 |     # Check that the checkpoint from the previous run is not corrupted.
594 |     # Then overwrite some of the variables above.
595 |     iters_to_consume, res_path, epoch, total_iters,\
596 |         [lowest_valid_cost, corresponding_test_cost, test_cost] = \
597 |         lib.resumable(path=FOLDER_PREFIX,
598 |                       iter_key=iter_str,
599 |                       epoch_key=epoch_str,
600 |                       add_resume_counter=True,
601 |                       other_keys=[lowest_valid_str,
602 |                                   corresp_test_str,
603 |                                   test_nll_str])
604 |     # At this point we saved the pkl file.
605 |     last_print_iters = total_iters
606 |     print "### RESUMING JOB FROM EPOCH {}, ITER {}".format(epoch, total_iters)
607 |     # Consume this many iterations to get back to the last position in the training data.
608 |     consume_time = time()
609 |     for i in xrange(iters_to_consume):
610 |         tr_feeder.next()
611 |     consume_time = time() - consume_time
612 |     print "Train data ready in {:.2f}secs after consuming {} minibatches.".\
613 |         format(consume_time, iters_to_consume)
614 | 
615 |     lib.load_params(res_path)
616 |     print "Parameters from last available checkpoint loaded."
617 | 
618 | 
619 | 
620 | # Report progress to stdout
621 | print_info = "epoch:{}\ttotal iters:{}\twall clock time:{:.2f}h\n"
622 | print_info = print_info.format(epoch,
623 |                                total_iters,
624 |                                (time()-exp_start)/3600)
625 | print print_info
626 | 
627 | tag = "e{}_i{}"
628 | tag = tag.format(epoch,
629 |                  total_iters)
630 | 
631 | # Generate and save samples (time consuming)
632 | # If not successful, we still have the params to sample afterward
633 | print "Sampling!",
634 | # Generate samples
635 | generate_and_save_samples(tag, N_SECS)
636 | print "Done!"
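# Illustrative only -- a minimal sketch of the inverse mu-law companding that
# mu2linear (imported from datasets.dataset above) is assumed to perform; the
# authoritative implementation lives in datasets/dataset.py, and
# `_mu2linear_sketch` is a hypothetical name, not part of this repo.
def _mu2linear_sketch(samples, q_levels=256):
    # ints in [0, q_levels) -> floats in [-1, 1] via the mu-law expander
    mu = float(q_levels - 1)
    x = samples.astype('float64')
    x = 2. * (x / mu) - 1.  # undo the [0, mu] bin mapping
    return numpy.sign(x) * (1. / mu) * ((1. + mu) ** numpy.abs(x) - 1.)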
637 | 638 | print "Wall clock time spent: {:.2f}h"\ 639 | .format((time()-exp_start)/3600) 640 | 641 | sys.exit() -------------------------------------------------------------------------------- /models/one_tier/one_tier.py: -------------------------------------------------------------------------------- 1 | """ 2 | RNN Audio Generation Model 3 | 4 | one-tier model, Quantized input 5 | For more info: 6 | $ python one_tier.py -h 7 | 8 | How-to-run example: 9 | sampleRNN$ pwd 10 | /u/mehris/sampleRNN 11 | 12 | sampleRNN$ \ 13 | THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python -u \ 14 | models/one_tier/one_tier.py --exp AXIS1 --seq_len 10 --weight_norm True \ 15 | --emb_size 64 --skip_conn False --dim 32 --n_rnn 2 --rnn_type LSTM --learn_h0 \ 16 | False --q_levels 16 --q_type linear --batch_size 128 --which_set MUSIC 17 | 18 | To resume add ` --resume` to the END of the EXACTLY above line. You can run the 19 | resume code as many time as possible, depending on the TRAIN_MODE. 20 | (folder name, file name, flags, their order, and the values are important) 21 | """ 22 | from time import time 23 | from datetime import datetime 24 | print "Experiment started at:", datetime.strftime(datetime.now(), '%Y-%m-%d %H:%M') 25 | exp_start = time() 26 | 27 | import os, sys, glob 28 | sys.path.insert(1, os.getcwd()) 29 | import argparse 30 | 31 | import numpy 32 | numpy.random.seed(123) 33 | np = numpy 34 | import random 35 | random.seed(123) 36 | 37 | import theano 38 | import theano.tensor as T 39 | import theano.ifelse 40 | import lasagne 41 | import scipy.io.wavfile 42 | 43 | import lib 44 | 45 | LEARNING_RATE = 0.001 46 | 47 | ### Parsing passed args/hyperparameters ### 48 | def get_args(): 49 | def t_or_f(arg): 50 | ua = str(arg).upper() 51 | if 'TRUE'.startswith(ua): 52 | return True 53 | elif 'FALSE'.startswith(ua): 54 | return False 55 | else: 56 | raise ValueError('Arg is neither `True` nor `False`') 57 | 58 | def check_non_negative(value): 59 | ivalue = int(value) 60 | if ivalue < 0: 61 | raise argparse.ArgumentTypeError("%s is not non-negative!" % value) 62 | return ivalue 63 | 64 | def check_positive(value): 65 | ivalue = int(value) 66 | if ivalue < 1: 67 | raise argparse.ArgumentTypeError("%s is not positive!" % value) 68 | return ivalue 69 | 70 | def check_unit_interval(value): 71 | fvalue = float(value) 72 | if fvalue < 0 or fvalue > 1: 73 | raise argparse.ArgumentTypeError("%s is not in [0, 1] interval!" % value) 74 | return fvalue 75 | 76 | # No default value here. Indicate every single arguement. 77 | parser = argparse.ArgumentParser( 78 | description='one_tier.py\nNo default value! 
Indicate every argument.') 79 | 80 | # Hyperparameter arguements: 81 | parser.add_argument('--exp', help='Experiment name', 82 | type=str, required=False, default='_') 83 | parser.add_argument('--seq_len', help='How many audio samples to include\ 84 | in each truncated BPTT pass', type=check_positive, required=True) 85 | parser.add_argument('--weight_norm', help='Adding learnable weight normalization\ 86 | to all the linear layers (except for the embedding layer)',\ 87 | type=t_or_f, required=True) 88 | parser.add_argument('--emb_size', help='Size of embedding layer (> 0)', 89 | type=check_positive, required=True) # different than two_tier 90 | parser.add_argument('--skip_conn', help='Add skip connections to RNN', type=t_or_f, required=True) 91 | parser.add_argument('--dim', help='Dimension of RNN and MLPs',\ 92 | type=check_positive, required=True) 93 | parser.add_argument('--n_rnn', help='Number of layers in the stacked RNN', 94 | type=check_positive, choices=xrange(1,6), required=True) 95 | parser.add_argument('--rnn_type', help='GRU or LSTM', choices=['LSTM', 'GRU'],\ 96 | required=True) 97 | parser.add_argument('--learn_h0', help='Whether to learn the initial state of RNN',\ 98 | type=t_or_f, required=True) 99 | parser.add_argument('--q_levels', help='Number of bins for quantization of audio samples. Should be 256 for mu-law.',\ 100 | type=check_positive, required=True) 101 | parser.add_argument('--q_type', help='Quantization in linear-scale, a-law-companding, or mu-law compandig. With mu-/a-law quantization level shoud be set as 256',\ 102 | choices=['linear', 'a-law', 'mu-law'], required=True) 103 | parser.add_argument('--which_set', help='ONOM, BLIZZ, MUSIC, HENDRIX, GLASS, or COBAIN', 104 | choices=['ONOM', 'BLIZZ', 'MUSIC', 'HENDRIX', 'COBAIN', 'GLASS'], required=True) 105 | parser.add_argument('--batch_size', help='size of mini-batch', 106 | type=check_positive, choices=xrange(1,10000), required=True) 107 | 108 | parser.add_argument('--debug', help='Debug mode', required=False, default=False, action='store_true') 109 | # NEW 110 | parser.add_argument('--resume', help='Resume the same model from the last checkpoint. Order of params are important. [for now]',\ 111 | required=False, default=False, action='store_true') 112 | 113 | args = parser.parse_args() 114 | 115 | # Create tag for this experiment based on passed args 116 | tag = reduce(lambda a, b: a+b, sys.argv).replace('--resume', '').replace('/', '-').replace('--', '-').replace('True', 'T').replace('False', 'F') 117 | tag += '-lr'+str(LEARNING_RATE) 118 | print "Created experiment tag for these args:" 119 | print tag 120 | 121 | return args, tag 122 | 123 | args, tag = get_args() 124 | 125 | SEQ_LEN = args.seq_len # How many audio samples to include in each truncated BPTT pass 126 | OVERLAP = 1 127 | WEIGHT_NORM = args.weight_norm 128 | EMB_SIZE = args.emb_size 129 | SKIP_CONN = args.skip_conn 130 | DIM = args.dim # Model dimensionality. 131 | N_RNN = args.n_rnn # How many RNNs to stack in the frame-level model 132 | RNN_TYPE = args.rnn_type 133 | H0_MULT = 2 if RNN_TYPE == 'LSTM' else 1 134 | LEARN_H0 = args.learn_h0 135 | Q_LEVELS = args.q_levels # How many levels to use when discretizing samples. e.g. 
256 = 8-bit scalar quantization 136 | Q_TYPE = args.q_type # log- or linear-scale 137 | WHICH_SET = args.which_set 138 | BATCH_SIZE = args.batch_size 139 | RESUME = args.resume 140 | 141 | if Q_TYPE == 'mu-law' and Q_LEVELS != 256: 142 | raise ValueError('For mu-law Quantization levels should be exactly 256!') 143 | 144 | # Fixed hyperparams 145 | GRAD_CLIP = 1 # Elementwise grad clip threshold 146 | BITRATE = 16000 147 | 148 | # Other constants 149 | #TRAIN_MODE = 'iters' # To use PRINT_ITERS and STOP_ITERS 150 | TRAIN_MODE = 'time' # To use PRINT_TIME and STOP_TIME 151 | #TRAIN_MODE = 'time-iters' 152 | # To use PRINT_TIME for validation, 153 | # and (STOP_ITERS, STOP_TIME), whichever happened first, for stopping exp. 154 | #TRAIN_MODE = 'iters-time' 155 | # To use PRINT_ITERS for validation, 156 | # and (STOP_ITERS, STOP_TIME), whichever happened first, for stopping exp. 157 | PRINT_ITERS = 10000 # Print cost, generate samples, save model checkpoint every N iterations. 158 | STOP_ITERS = 100000 # Stop after this many iterations 159 | PRINT_TIME = 90*60 # Print cost, generate samples, save model checkpoint every N seconds. 160 | STOP_TIME = 60*60*24*3 # Stop after this many seconds of actual training (not including time req'd to generate samples etc.) 161 | N_SEQS = 20 # Number of samples to generate every time monitoring. 162 | RESULTS_DIR = 'results_1t' 163 | FOLDER_PREFIX = os.path.join(RESULTS_DIR, tag) 164 | Q_ZERO = numpy.int32(Q_LEVELS//2) # Discrete value correponding to zero amplitude 165 | 166 | epoch_str = 'epoch' 167 | iter_str = 'iter' 168 | lowest_valid_str = 'lowest valid cost' 169 | corresp_test_str = 'correponding test cost' 170 | train_nll_str, valid_nll_str, test_nll_str = \ 171 | 'train NLL (bits)', 'valid NLL (bits)', 'test NLL (bits)' 172 | 173 | if args.debug: 174 | import warnings 175 | warnings.warn('----------RUNNING IN DEBUG MODE----------') 176 | TRAIN_MODE = 'time' 177 | PRINT_TIME = 100 178 | STOP_TIME = 3000 179 | STOP_ITERS = 1000 180 | 181 | ### Create directories ### 182 | # FOLDER_PREFIX: root, contains: 183 | # log.txt, __note.txt, train_log.pkl, train_log.png [, model_settings.txt] 184 | # FOLDER_PREFIX/params: saves all checkpoint params as pkl 185 | # FOLDER_PREFIX/samples: keeps all checkpoint samples as wav 186 | # FOLDER_PREFIX/best: keeps the best parameters, samples, ... 
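# Illustrative worked example: for the how-to command in this file's
# docstring, the tag built in get_args() comes out roughly as
#   models-one_tier-one_tier.py-expAXIS1-seq_len10-weight_normT-emb_size64
#   -skip_connF-dim32-n_rnn2-rnn_typeLSTM-learn_h0F-q_levels16-q_typelinear
#   -batch_size128-which_setMUSIC-lr0.001
# so everything for that run lands under results_1t/<tag>/ (params, samples, best).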
187 | if not os.path.exists(FOLDER_PREFIX): 188 | os.makedirs(FOLDER_PREFIX) 189 | PARAMS_PATH = os.path.join(FOLDER_PREFIX, 'params') 190 | if not os.path.exists(PARAMS_PATH): 191 | os.makedirs(PARAMS_PATH) 192 | SAMPLES_PATH = os.path.join(FOLDER_PREFIX, 'samples') 193 | if not os.path.exists(SAMPLES_PATH): 194 | os.makedirs(SAMPLES_PATH) 195 | BEST_PATH = os.path.join(FOLDER_PREFIX, 'best') 196 | if not os.path.exists(BEST_PATH): 197 | os.makedirs(BEST_PATH) 198 | 199 | lib.print_model_settings(locals(), path=FOLDER_PREFIX, sys_arg=True) 200 | 201 | ### Import the data_feeder ### 202 | # Handling WHICH_SET 203 | if WHICH_SET == 'ONOM': 204 | from datasets.dataset import onom_train_feed_epoch as train_feeder 205 | from datasets.dataset import onom_valid_feed_epoch as valid_feeder 206 | from datasets.dataset import onom_test_feed_epoch as test_feeder 207 | elif WHICH_SET == 'BLIZZ': 208 | from datasets.dataset import blizz_train_feed_epoch as train_feeder 209 | from datasets.dataset import blizz_valid_feed_epoch as valid_feeder 210 | from datasets.dataset import blizz_test_feed_epoch as test_feeder 211 | elif WHICH_SET == 'MUSIC': 212 | from datasets.dataset import music_train_feed_epoch as train_feeder 213 | from datasets.dataset import music_valid_feed_epoch as valid_feeder 214 | from datasets.dataset import music_test_feed_epoch as test_feeder 215 | elif WHICH_SET == 'HENDRIX': 216 | from datasets.dataset import hendrix_train_feed_epoch as train_feeder 217 | from datasets.dataset import hendrix_valid_feed_epoch as valid_feeder 218 | from datasets.dataset import hendrix_test_feed_epoch as test_feeder 219 | elif WHICH_SET == 'COBAIN': 220 | from datasets.dataset import cobain_train_feed_epoch as train_feeder 221 | from datasets.dataset import cobain_valid_feed_epoch as valid_feeder 222 | from datasets.dataset import cobain_test_feed_epoch as test_feeder 223 | elif WHICH_SET == 'GLASS': 224 | from datasets.dataset import glass_train_feed_epoch as train_feeder 225 | from datasets.dataset import glass_valid_feed_epoch as valid_feeder 226 | from datasets.dataset import glass_test_feed_epoch as test_feeder 227 | 228 | def load_data(data_feeder): 229 | """ 230 | Helper function to deal with interface of different datasets. 231 | `data_feeder` should be `train_feeder`, `valid_feeder`, or `test_feeder`. 
232 | """ 233 | return data_feeder(BATCH_SIZE, 234 | SEQ_LEN, 235 | OVERLAP, 236 | Q_LEVELS, 237 | Q_ZERO, 238 | Q_TYPE) 239 | 240 | ### Creating computation graph ### 241 | def sample_level_rnn(input_sequences, h0, reset): 242 | """ 243 | input_sequences.shape: (batch size, seq len) 244 | h0.shape: (batch size, N_RNN, DIM) 245 | reset.shape: () 246 | output.shape: (batch size, seq len, DIM) 247 | """ 248 | 249 | # Embedded inputs 250 | # Handling EMB_SIZE 251 | ################# 252 | FRAME_SIZE = EMB_SIZE 253 | frames = lib.ops.Embedding( 254 | 'SampleLevel.Embedding', 255 | Q_LEVELS, 256 | EMB_SIZE, 257 | input_sequences) 258 | 259 | # Real-valued inputs 260 | #################### 261 | # # 'frames' of size 1 262 | # FRAME_SIZE = 1 263 | # frames = input_sequences.reshape(( 264 | # input_sequences.shape[0], 265 | # input_sequences.shape[1], 266 | # 1 267 | # )) 268 | # # Rescale frames from ints in [0, Q_LEVELS) to floats in [-2, 2] 269 | # # (a reasonable range to pass as inputs to the RNN) 270 | # frames = (frames.astype('float32') / lib.floatX(Q_LEVELS/2)) - lib.floatX(1) 271 | # frames *= lib.floatX(2) 272 | 273 | # Initial state of RNNs 274 | learned_h0 = lib.param( 275 | 'SampleLevel.h0', 276 | numpy.zeros((N_RNN, H0_MULT*DIM), dtype=theano.config.floatX) 277 | ) 278 | # Handling LEARN_H0 279 | learned_h0.param = LEARN_H0 280 | learned_h0 = T.alloc(learned_h0, h0.shape[0], N_RNN, H0_MULT*DIM) 281 | learned_h0 = T.unbroadcast(learned_h0, 0, 1, 2) 282 | h0 = theano.ifelse.ifelse(reset, learned_h0, h0) 283 | 284 | # Handling RNN_TYPE 285 | # Handling SKIP_CONN 286 | if RNN_TYPE == 'GRU': 287 | rnns_out, last_hidden = lib.ops.stackedGRU('SampleLevel.GRU', 288 | N_RNN, 289 | FRAME_SIZE, 290 | DIM, 291 | frames, 292 | h0=h0, 293 | weightnorm=WEIGHT_NORM, 294 | skip_conn=SKIP_CONN) 295 | elif RNN_TYPE == 'LSTM': 296 | rnns_out, last_hidden = lib.ops.stackedLSTM('SampleLevel.LSTM', 297 | N_RNN, 298 | FRAME_SIZE, 299 | DIM, 300 | frames, 301 | h0=h0, 302 | weightnorm=WEIGHT_NORM, 303 | skip_conn=SKIP_CONN) 304 | 305 | out = lib.ops.Linear( 306 | 'SampleLevel.L1', 307 | DIM, 308 | DIM, 309 | rnns_out, 310 | initialization='he', 311 | weightnorm=WEIGHT_NORM 312 | ) 313 | out = T.nnet.relu(out) 314 | 315 | out = lib.ops.Linear( 316 | 'SampleLevel.L2', 317 | DIM, 318 | DIM, 319 | out, 320 | initialization='he', 321 | weightnorm=WEIGHT_NORM 322 | ) 323 | out = T.nnet.relu(out) 324 | 325 | out = lib.ops.Linear( 326 | 'SampleLevel.L3', 327 | DIM, 328 | DIM, 329 | out, 330 | initialization='he', 331 | weightnorm=WEIGHT_NORM 332 | ) 333 | out = T.nnet.relu(out) 334 | 335 | # We apply the softmax later 336 | out = lib.ops.Linear( 337 | 'SampleLevel.Output', 338 | DIM, 339 | Q_LEVELS, 340 | out, 341 | initialization='he', 342 | weightnorm=WEIGHT_NORM 343 | ) 344 | 345 | return (out, last_hidden) 346 | 347 | sequences = T.imatrix('sequences') 348 | h0 = T.tensor3('h0') 349 | reset = T.iscalar('reset') 350 | mask = T.matrix('mask') 351 | 352 | if args.debug: 353 | # Solely for debugging purposes. 354 | # Maybe I should set the compute_test_value=warn from here. 
355 | sequences.tag.test_value = numpy.zeros((BATCH_SIZE, SEQ_LEN+OVERLAP), dtype='int32') 356 | h0.tag.test_value = numpy.zeros((BATCH_SIZE, N_RNN, H0_MULT*DIM), dtype='float32') 357 | reset.tag.test_value = numpy.array(1, dtype='int32') 358 | mask.tag.test_value = numpy.ones((BATCH_SIZE, SEQ_LEN+OVERLAP), dtype='float32') 359 | 360 | input_sequences = sequences[:, :-1] 361 | target_sequences = sequences[:, 1:] 362 | 363 | target_mask = mask[:, 1:] 364 | 365 | sample_level_outputs, new_h0 = sample_level_rnn(input_sequences, h0, reset) 366 | 367 | cost = T.nnet.categorical_crossentropy( 368 | T.nnet.softmax(sample_level_outputs.reshape((-1, Q_LEVELS))), 369 | target_sequences.flatten() 370 | ) 371 | cost = cost.reshape(target_sequences.shape) 372 | cost = cost * target_mask 373 | # Don't use these lines; could end up with NaN 374 | # Specially at the end of audio files where mask is 375 | # all zero for some of the shorter files in mini-batch. 376 | #cost = cost.sum(axis=1) / target_mask.sum(axis=1) 377 | #cost = cost.mean(axis=0) 378 | 379 | # Use this one instead. 380 | cost = cost.sum() 381 | cost = cost / target_mask.sum() 382 | 383 | # By default we report cross-entropy cost in bits. 384 | # Switch to nats by commenting out this line: 385 | # log_2(e) = 1.44269504089 386 | cost = cost * lib.floatX(numpy.log2(numpy.e)) 387 | 388 | ### Getting the params, grads, updates, and Theano functions ### 389 | params = lib.get_params(cost, lambda x: hasattr(x, 'param') and x.param==True) 390 | lib.print_params_info(params, path=FOLDER_PREFIX) 391 | 392 | grads = T.grad(cost, wrt=params, disconnected_inputs='warn') 393 | grads = [T.clip(g, lib.floatX(-GRAD_CLIP), lib.floatX(GRAD_CLIP)) for g in grads] 394 | 395 | updates = lasagne.updates.adam(grads, params, learning_rate=LEARNING_RATE) 396 | 397 | # Training function 398 | train_fn = theano.function( 399 | [sequences, h0, reset, mask], 400 | [cost, new_h0], 401 | updates=updates, 402 | on_unused_input='warn' 403 | ) 404 | 405 | # Validation and Test function, hence no updates 406 | test_fn = theano.function( 407 | [sequences, h0, reset, mask], 408 | [cost, new_h0], 409 | on_unused_input='warn' 410 | ) 411 | 412 | # Sampling at audio sample level 413 | generate_outputs, generate_new_h0 = sample_level_rnn(sequences, h0, reset) 414 | generate_fn = theano.function( 415 | [sequences, h0, reset], 416 | [lib.ops.softmax_and_sample(generate_outputs), generate_new_h0], 417 | on_unused_input='warn' 418 | ) 419 | 420 | # Uniform [-0.5, 0.5) for half of initial state for generated samples 421 | # to study the behaviour of the model and also to introduce some diversity 422 | # to samples in a simple way. 
[it's disabled] 423 | fixed_rand_h0 = numpy.random.rand(N_SEQS//2, N_RNN, H0_MULT*DIM) 424 | fixed_rand_h0 -= 0.5 425 | fixed_rand_h0 = fixed_rand_h0.astype('float32') 426 | 427 | def generate_and_save_samples(tag): 428 | def write_audio_file(name, data): 429 | data = data.astype('float32') 430 | data -= data.min() 431 | data /= data.max() 432 | data -= 0.5 433 | data *= 0.95 434 | scipy.io.wavfile.write( 435 | os.path.join(SAMPLES_PATH, name+'.wav'), 436 | BITRATE, 437 | data) 438 | 439 | total_time = time() 440 | # Generate N_SEQS' sample files, each 5 seconds long 441 | N_SECS = 5 442 | LENGTH = N_SECS*BITRATE if not args.debug else 100 443 | 444 | samples = numpy.zeros((N_SEQS, LENGTH), dtype='int32') 445 | samples[:, 0] = Q_ZERO 446 | 447 | # First half zero, others fixed random at each checkpoint 448 | h0 = numpy.zeros( 449 | (N_SEQS-fixed_rand_h0.shape[0], N_RNN, H0_MULT*DIM), 450 | dtype='float32' 451 | ) 452 | h0 = numpy.concatenate((h0, fixed_rand_h0), axis=0) 453 | 454 | for t in xrange(1, LENGTH): 455 | samples[:, t:t+1], h0 = generate_fn( 456 | samples[:, t-1:t], 457 | h0, 458 | numpy.int32(t == 1) 459 | ) 460 | 461 | total_time = time() - total_time 462 | log = "{} samples of {} seconds length generated in {} seconds." 463 | log = log.format(N_SEQS, N_SECS, total_time) 464 | print log, 465 | 466 | for i in xrange(N_SEQS): 467 | samp = samples[i] 468 | if Q_TYPE == 'mu-law': 469 | from datasets.dataset import mu2linear 470 | samp = mu2linear(samp) 471 | elif Q_TYPE == 'a-law': 472 | raise NotImplementedError('a-law is not implemented') 473 | write_audio_file("sample_{}_{}".format(tag, i), samp) 474 | 475 | def monitor(data_feeder): 476 | """ 477 | Cost and time of test_fn on a given dataset section. 478 | Pass only one of `valid_feeder` or `test_feeder`. 479 | Don't pass `train_feed`. 480 | 481 | :returns: 482 | Mean cost over the input dataset (data_feeder) 483 | Total time spent 484 | """ 485 | _total_time = time() 486 | _h0 = numpy.zeros((BATCH_SIZE, N_RNN, H0_MULT*DIM), dtype='float32') 487 | _costs = [] 488 | _data_feeder = load_data(data_feeder) 489 | for _seqs, _reset, _mask in _data_feeder: 490 | _cost, _h0 = test_fn(_seqs, _h0, _reset, _mask) 491 | _costs.append(_cost) 492 | 493 | return numpy.mean(_costs), time() - _total_time 494 | 495 | print "Wall clock time spent before training started: {:.2f}h"\ 496 | .format((time()-exp_start)/3600.) 497 | print "Training!" 498 | total_iters = 0 499 | total_time = 0. 500 | last_print_time = 0. 501 | last_print_iters = 0 502 | costs = [] 503 | lowest_valid_cost = numpy.finfo(numpy.float32).max 504 | corresponding_test_cost = numpy.finfo(numpy.float32).max 505 | new_lowest_cost = False 506 | end_of_batch = False 507 | epoch = 0 508 | 509 | h0 = numpy.zeros((BATCH_SIZE, N_RNN, H0_MULT*DIM), dtype='float32') 510 | 511 | # Initial load train dataset 512 | tr_feeder = load_data(train_feeder) 513 | 514 | ### Handling the resume option: 515 | if RESUME: 516 | # Check if checkpoint from previous run is not corrupted. 517 | # Then overwrite some of the variables above. 518 | iters_to_consume, res_path, epoch, total_iters,\ 519 | [lowest_valid_cost, corresponding_test_cost, test_cost] = \ 520 | lib.resumable(path=FOLDER_PREFIX, 521 | iter_key=iter_str, 522 | epoch_key=epoch_str, 523 | add_resume_counter=True, 524 | other_keys=[lowest_valid_str, 525 | corresp_test_str, 526 | test_nll_str]) 527 | # At this point we saved the pkl file. 
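# (The skip-ahead below presumably works because the data feeders are
#  deterministic given the fixed seeds -- numpy.random.seed(123) and
#  random.seed(123) at import time -- so consuming `iters_to_consume`
#  minibatches reproduces the stream position at which the checkpoint
#  was written.)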
528 | last_print_iters = total_iters 529 | print "### RESUMING JOB FROM EPOCH {}, ITER {}".format(epoch, total_iters) 530 | # Consumes this much iters to get to the last point in training data. 531 | consume_time = time() 532 | for i in xrange(iters_to_consume): 533 | tr_feeder.next() 534 | consume_time = time() - consume_time 535 | print "Train data ready in {:.2f}secs after consuming {} minibatches.".\ 536 | format(consume_time, iters_to_consume) 537 | 538 | lib.load_params(res_path) 539 | print "Parameters from last available checkpoint loaded." 540 | 541 | while True: 542 | # THIS IS ONE ITERATION 543 | if total_iters % 500 == 0: 544 | print total_iters, 545 | 546 | total_iters += 1 547 | 548 | try: 549 | # Take as many mini-batches as possible from train set 550 | mini_batch = tr_feeder.next() 551 | except StopIteration: 552 | # Mini-batches are finished. Load it again. 553 | # Basically, one epoch. 554 | tr_feeder = load_data(train_feeder) 555 | 556 | # and start taking new mini-batches again. 557 | mini_batch = tr_feeder.next() 558 | epoch += 1 559 | end_of_batch = True 560 | print "[Another epoch]", 561 | 562 | seqs, reset, mask = mini_batch 563 | 564 | start_time = time() 565 | cost, h0 = train_fn(seqs, h0, reset, mask) 566 | total_time += time() - start_time 567 | #print "This cost:", cost, "This h0.mean()", h0.mean() 568 | 569 | costs.append(cost) 570 | 571 | # Monitoring step 572 | if (TRAIN_MODE=='iters' and total_iters-last_print_iters == PRINT_ITERS) or \ 573 | (TRAIN_MODE=='time' and total_time-last_print_time >= PRINT_TIME) or \ 574 | (TRAIN_MODE=='time-iters' and total_time-last_print_time >= PRINT_TIME) or \ 575 | (TRAIN_MODE=='iters-time' and total_iters-last_print_iters >= PRINT_ITERS) or \ 576 | end_of_batch: 577 | # 0. Validation 578 | print "\nValidation!", 579 | valid_cost, valid_time = monitor(valid_feeder) 580 | print "Done!" 581 | 582 | # 1. Test 583 | test_time = 0. 584 | # Only when the validation cost is improved get the cost for test set. 585 | if valid_cost < lowest_valid_cost: 586 | lowest_valid_cost = valid_cost 587 | print "\n>>> Best validation cost of {} reached. Testing!"\ 588 | .format(valid_cost), 589 | test_cost, test_time = monitor(test_feeder) 590 | print "Done!" 591 | # Report last one which is the lowest on validation set: 592 | print ">>> test cost:{}\ttotal time:{}".format(test_cost, test_time) 593 | corresponding_test_cost = test_cost 594 | new_lowest_cost = True 595 | 596 | # 2. Stdout the training progress 597 | print_info = "epoch:{}\ttotal iters:{}\twall clock time:{:.2f}h\n" 598 | print_info += ">>> Lowest valid cost:{}\t Corresponding test cost:{}\n" 599 | print_info += "\ttrain cost:{:.4f}\ttotal time:{:.2f}h\tper iter:{:.3f}s\n" 600 | print_info += "\tvalid cost:{:.4f}\ttotal time:{:.2f}h\n" 601 | print_info += "\ttest cost:{:.4f}\ttotal time:{:.2f}h" 602 | print_info = print_info.format(epoch, 603 | total_iters, 604 | (time()-exp_start)/3600, 605 | lowest_valid_cost, 606 | corresponding_test_cost, 607 | numpy.mean(costs), 608 | total_time/3600, 609 | total_time/total_iters, 610 | valid_cost, 611 | valid_time/3600, 612 | test_cost, 613 | test_time/3600) 614 | print print_info 615 | 616 | tag = "e{}_i{}_t{:.2f}_tr{:.4f}_v{:.4f}" 617 | tag = tag.format(epoch, 618 | total_iters, 619 | total_time/3600, 620 | numpy.mean(cost), 621 | valid_cost) 622 | tag += ("_best" if new_lowest_cost else "") 623 | 624 | # 3. 
Save params of model (IO bound, time consuming) 625 | # If saving params is not successful, there shouldn't be any trace of 626 | # successful monitoring step in train_log as well. 627 | print "Saving params!", 628 | lib.save_params( 629 | os.path.join(PARAMS_PATH, 'params_{}.pkl'.format(tag)) 630 | ) 631 | print "Done!" 632 | 633 | # 4. Save and graph training progress (fast) 634 | training_info = {epoch_str : epoch, 635 | iter_str : total_iters, 636 | train_nll_str : numpy.mean(costs), 637 | valid_nll_str : valid_cost, 638 | test_nll_str : test_cost, 639 | lowest_valid_str : lowest_valid_cost, 640 | corresp_test_str : corresponding_test_cost, 641 | 'train time' : total_time, 642 | 'valid time' : valid_time, 643 | 'test time' : test_time, 644 | 'wall clock time' : time()-exp_start} 645 | lib.save_training_info(training_info, FOLDER_PREFIX) 646 | print "Train info saved!", 647 | 648 | y_axis_strs = [train_nll_str, valid_nll_str, test_nll_str] 649 | lib.plot_traing_info(iter_str, y_axis_strs, FOLDER_PREFIX) 650 | print "And plotted!" 651 | 652 | # 5. Generate and save samples (time consuming) 653 | # If not successful, we still have the params to sample afterward 654 | print "Sampling!", 655 | # Generate samples 656 | generate_and_save_samples(tag) 657 | print "Done!" 658 | 659 | if total_iters-last_print_iters == PRINT_ITERS \ 660 | or total_time-last_print_time >= PRINT_TIME: 661 | # If we are here b/c of onom_end_of_batch, we shouldn't mess 662 | # with costs and last_print_iters 663 | costs = [] 664 | last_print_time += PRINT_TIME 665 | last_print_iters += PRINT_ITERS 666 | 667 | end_of_batch = False 668 | new_lowest_cost = False 669 | 670 | print "Validation Done!\nBack to Training..." 671 | 672 | if (TRAIN_MODE=='iters' and total_iters == STOP_ITERS) or \ 673 | (TRAIN_MODE=='time' and total_time >= STOP_TIME) or \ 674 | ((TRAIN_MODE=='time-iters' or TRAIN_MODE=='iters-time') and \ 675 | (total_iters == STOP_ITERS or total_time >= STOP_TIME)): 676 | 677 | print "Done! Total iters:", total_iters, "Total time: ", total_time 678 | print "Experiment ended at:", datetime.strftime(datetime.now(), '%Y-%m-%d %H:%M') 679 | print "Wall clock time spent: {:.2f}h"\ 680 | .format((time()-exp_start)/3600) 681 | 682 | sys.exit() 683 | -------------------------------------------------------------------------------- /models/two_tier/two_tier16k.py: -------------------------------------------------------------------------------- 1 | """ 2 | RNN Audio Generation Model 3 | 4 | Two-tier model, Quantized input 5 | For more info: 6 | $ python two_tier.py -h 7 | 8 | How-to-run example: 9 | sampleRNN$ pwd 10 | /u/mehris/sampleRNN 11 | 12 | sampleRNN$ \ 13 | THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python -u \ 14 | models/two_tier/two_tier.py --exp AXIS1 --n_frames 12 --frame_size 10 \ 15 | --weight_norm True --emb_size 64 --skip_conn False --dim 32 --n_rnn 2 \ 16 | --rnn_type LSTM --learn_h0 False --q_levels 16 --q_type linear \ 17 | --batch_size 128 --which_set MUSIC 18 | 19 | To resume add ` --resume` to the END of the EXACTLY above line. You can run the 20 | resume code as many time as possible, depending on the TRAIN_MODE. 
21 | (folder name, file name, flags, their order, and the values are important) 22 | """ 23 | from time import time 24 | from datetime import datetime 25 | print "Experiment started at:", datetime.strftime(datetime.now(), '%Y-%m-%d %H:%M') 26 | exp_start = time() 27 | 28 | import os, sys, glob 29 | sys.path.insert(1, os.getcwd()) 30 | import argparse 31 | 32 | import numpy 33 | numpy.random.seed(123) 34 | np = numpy 35 | import random 36 | random.seed(123) 37 | 38 | import theano 39 | import theano.tensor as T 40 | import theano.ifelse 41 | import lasagne 42 | import scipy.io.wavfile 43 | 44 | import lib 45 | 46 | LEARNING_RATE = 0.001 47 | 48 | ### Parsing passed args/hyperparameters ### 49 | def get_args(): 50 | def t_or_f(arg): 51 | ua = str(arg).upper() 52 | if 'TRUE'.startswith(ua): 53 | return True 54 | elif 'FALSE'.startswith(ua): 55 | return False 56 | else: 57 | raise ValueError('Arg is neither `True` nor `False`') 58 | 59 | def check_non_negative(value): 60 | ivalue = int(value) 61 | if ivalue < 0: 62 | raise argparse.ArgumentTypeError("%s is not non-negative!" % value) 63 | return ivalue 64 | 65 | def check_positive(value): 66 | ivalue = int(value) 67 | if ivalue < 1: 68 | raise argparse.ArgumentTypeError("%s is not positive!" % value) 69 | return ivalue 70 | 71 | def check_unit_interval(value): 72 | fvalue = float(value) 73 | if fvalue < 0 or fvalue > 1: 74 | raise argparse.ArgumentTypeError("%s is not in [0, 1] interval!" % value) 75 | return fvalue 76 | 77 | # No default value here. Indicate every single arguement. 78 | parser = argparse.ArgumentParser( 79 | description='two_tier.py\nNo default value! Indicate every argument.') 80 | 81 | # Hyperparameter arguements: 82 | parser.add_argument('--exp', help='Experiment name', 83 | type=str, required=False, default='_') 84 | parser.add_argument('--n_frames', help='How many "frames" to include in each\ 85 | Truncated BPTT pass', type=check_positive, required=True) 86 | parser.add_argument('--frame_size', help='How many samples per frame',\ 87 | type=check_positive, required=True) 88 | parser.add_argument('--weight_norm', help='Adding learnable weight normalization\ 89 | to all the linear layers (except for the embedding layer)',\ 90 | type=t_or_f, required=True) 91 | parser.add_argument('--emb_size', help='Size of embedding layer (0 to disable)', type=check_non_negative, required=True) 92 | parser.add_argument('--skip_conn', help='Add skip connections to RNN', type=t_or_f, required=True) 93 | parser.add_argument('--dim', help='Dimension of RNN and MLPs',\ 94 | type=check_positive, required=True) 95 | parser.add_argument('--n_rnn', help='Number of layers in the stacked RNN', 96 | type=check_positive, choices=xrange(1,40), required=True) 97 | parser.add_argument('--rnn_type', help='GRU or LSTM', choices=['LSTM', 'GRU'],\ 98 | required=True) 99 | parser.add_argument('--learn_h0', help='Whether to learn the initial state of RNN',\ 100 | type=t_or_f, required=True) 101 | parser.add_argument('--q_levels', help='Number of bins for quantization of audio samples. Should be 256 for mu-law.',\ 102 | type=check_positive, required=True) 103 | parser.add_argument('--q_type', help='Quantization in linear-scale, a-law-companding, or mu-law compandig. 
With mu-/a-law quantization level shoud be set as 256',\ 104 | choices=['linear', 'a-law', 'mu-law'], required=True) 105 | parser.add_argument('--which_set', help='the directory name of the dataset' , 106 | type=str, required=True) 107 | parser.add_argument('--batch_size', help='size of mini-batch', 108 | type=check_positive, choices=xrange(1,10000), required=True) 109 | 110 | parser.add_argument('--debug', help='Debug mode', required=False, default=False, action='store_true') 111 | # NEW 112 | parser.add_argument('--resume', help='Resume the same model from the last checkpoint. Order of params are important. [for now]',\ 113 | required=False, default=False, action='store_true') 114 | 115 | args = parser.parse_args() 116 | 117 | # NEW 118 | # Create tag for this experiment based on passed args 119 | tag = reduce(lambda a, b: a+b, sys.argv).replace('--resume', '').replace('/', '-').replace('--', '-').replace('True', 'T').replace('False', 'F') 120 | tag += '-lr'+str(LEARNING_RATE) 121 | print "Created experiment tag for these args:" 122 | print tag 123 | 124 | return args, tag 125 | 126 | args, tag = get_args() 127 | 128 | N_FRAMES = args.n_frames # How many 'frames' to include in each truncated BPTT pass 129 | OVERLAP = FRAME_SIZE = args.frame_size # How many samples per frame 130 | WEIGHT_NORM = args.weight_norm 131 | EMB_SIZE = args.emb_size 132 | SKIP_CONN = args.skip_conn 133 | DIM = args.dim # Model dimensionality. 134 | N_RNN = args.n_rnn # How many RNNs to stack 135 | RNN_TYPE = args.rnn_type 136 | H0_MULT = 2 if RNN_TYPE == 'LSTM' else 1 137 | LEARN_H0 = args.learn_h0 138 | Q_LEVELS = args.q_levels # How many levels to use when discretizing samples. e.g. 256 = 8-bit scalar quantization 139 | Q_TYPE = args.q_type # log- or linear-scale 140 | WHICH_SET = args.which_set 141 | BATCH_SIZE = args.batch_size 142 | RESUME = args.resume 143 | 144 | if Q_TYPE == 'mu-law' and Q_LEVELS != 256: 145 | raise ValueError('For mu-law Quantization levels should be exactly 256!') 146 | 147 | # Fixed hyperparams 148 | GRAD_CLIP = 1 # Elementwise grad clip threshold 149 | BITRATE = 16000 150 | 151 | # Other constants 152 | #TRAIN_MODE = 'iters' # To use PRINT_ITERS and STOP_ITERS 153 | TRAIN_MODE = 'time' # To use PRINT_TIME and STOP_TIME 154 | #TRAIN_MODE = 'time-iters' 155 | # To use PRINT_TIME for validation, 156 | # and (STOP_ITERS, STOP_TIME), whichever happened first, for stopping exp. 157 | #TRAIN_MODE = 'iters-time' 158 | # To use PRINT_ITERS for validation, 159 | # and (STOP_ITERS, STOP_TIME), whichever happened first, for stopping exp. 160 | PRINT_ITERS = 10000 # Print cost, generate samples, save model checkpoint every N iterations. 161 | STOP_ITERS = 100000 # Stop after this many iterations 162 | # TODO: 163 | PRINT_TIME = 90*60 # Print cost, generate samples, save model checkpoint every N seconds. 164 | STOP_TIME = 60*60*24*3 # Stop after this many seconds of actual training (not including time req'd to generate samples etc.) 165 | N_SEQS = 10 # Number of samples to generate every time monitoring. 
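# Illustrative worked example (using the docstring's --n_frames 12
# --frame_size 10): SEQ_LEN, defined just below, comes out as 12 * 10 = 120
# samples per truncated BPTT pass, and each minibatch fed to train_fn has
# shape (BATCH_SIZE, SEQ_LEN + OVERLAP) = (128, 130).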
166 | # TODO: 167 | RESULTS_DIR = 'results_2t' 168 | FOLDER_PREFIX = os.path.join(RESULTS_DIR, tag) 169 | SEQ_LEN = N_FRAMES * FRAME_SIZE # Total length (# of samples) of each truncated BPTT sequence 170 | Q_ZERO = numpy.int32(Q_LEVELS//2) # Discrete value correponding to zero amplitude 171 | 172 | 173 | epoch_str = 'epoch' 174 | iter_str = 'iter' 175 | lowest_valid_str = 'lowest valid cost' 176 | corresp_test_str = 'correponding test cost' 177 | train_nll_str, valid_nll_str, test_nll_str = \ 178 | 'train NLL (bits)', 'valid NLL (bits)', 'test NLL (bits)' 179 | 180 | if args.debug: 181 | import warnings 182 | warnings.warn('----------RUNNING IN DEBUG MODE----------') 183 | TRAIN_MODE = 'time' 184 | PRINT_TIME = 100 185 | STOP_TIME = 3000 186 | STOP_ITERS = 1000 187 | 188 | ### Create directories ### 189 | # FOLDER_PREFIX: root, contains: 190 | # log.txt, __note.txt, train_log.pkl, train_log.png [, model_settings.txt] 191 | # FOLDER_PREFIX/params: saves all checkpoint params as pkl 192 | # FOLDER_PREFIX/samples: keeps all checkpoint samples as wav 193 | # FOLDER_PREFIX/best: keeps the best parameters, samples, ... 194 | if not os.path.exists(FOLDER_PREFIX): 195 | os.makedirs(FOLDER_PREFIX) 196 | PARAMS_PATH = os.path.join(FOLDER_PREFIX, 'params') 197 | if not os.path.exists(PARAMS_PATH): 198 | os.makedirs(PARAMS_PATH) 199 | SAMPLES_PATH = os.path.join(FOLDER_PREFIX, 'samples') 200 | if not os.path.exists(SAMPLES_PATH): 201 | os.makedirs(SAMPLES_PATH) 202 | BEST_PATH = os.path.join(FOLDER_PREFIX, 'best') 203 | if not os.path.exists(BEST_PATH): 204 | os.makedirs(BEST_PATH) 205 | 206 | lib.print_model_settings(locals(), path=FOLDER_PREFIX, sys_arg=True) 207 | 208 | ### Import the data_feeder ### 209 | # Handling WHICH_SET 210 | from datasets.dataset import music_train_feed_epoch as train_feeder 211 | from datasets.dataset import music_valid_feed_epoch as valid_feeder 212 | from datasets.dataset import music_test_feed_epoch as test_feeder 213 | 214 | def load_data(data_feeder): 215 | """ 216 | Helper function to deal with interface of different datasets. 217 | `data_feeder` should be `train_feeder`, `valid_feeder`, or `test_feeder`. 
218 | """ 219 | return data_feeder(WHICH_SET, BATCH_SIZE, 220 | SEQ_LEN, 221 | OVERLAP, 222 | Q_LEVELS, 223 | Q_ZERO, 224 | Q_TYPE) 225 | 226 | ### Creating computation graph ### 227 | def frame_level_rnn(input_sequences, h0, reset): 228 | """ 229 | input_sequences.shape: (batch size, n frames * FRAME_SIZE) 230 | h0.shape: (batch size, N_RNN, DIM) 231 | reset.shape: () 232 | 233 | output.shape: (batch size, n frames * FRAME_SIZE, DIM) 234 | """ 235 | frames = input_sequences.reshape(( 236 | input_sequences.shape[0], 237 | input_sequences.shape[1] // FRAME_SIZE, 238 | FRAME_SIZE 239 | )) 240 | 241 | # Rescale frames from ints in [0, Q_LEVELS) to floats in [-2, 2] 242 | # (a reasonable range to pass as inputs to the RNN) 243 | frames = (frames.astype('float32') / lib.floatX(Q_LEVELS/2)) - lib.floatX(1) 244 | frames *= lib.floatX(2) 245 | # (128, 64, 4) 246 | 247 | # Initial state of RNNs 248 | learned_h0 = lib.param( 249 | 'FrameLevel.h0', 250 | numpy.zeros((N_RNN, H0_MULT*DIM), dtype=theano.config.floatX) 251 | ) 252 | # Handling LEARN_H0 253 | learned_h0.param = LEARN_H0 254 | learned_h0 = T.alloc(learned_h0, h0.shape[0], N_RNN, H0_MULT*DIM) 255 | learned_h0 = T.unbroadcast(learned_h0, 0, 1, 2) 256 | h0 = theano.ifelse.ifelse(reset, learned_h0, h0) 257 | 258 | # Handling RNN_TYPE 259 | # Handling SKIP_CONN 260 | if RNN_TYPE == 'GRU': 261 | rnns_out, last_hidden = lib.ops.stackedGRU('FrameLevel.GRU', 262 | N_RNN, 263 | FRAME_SIZE, 264 | DIM, 265 | frames, 266 | h0=h0, 267 | weightnorm=WEIGHT_NORM, 268 | skip_conn=SKIP_CONN) 269 | elif RNN_TYPE == 'LSTM': 270 | rnns_out, last_hidden = lib.ops.stackedLSTM('FrameLevel.LSTM', 271 | N_RNN, 272 | FRAME_SIZE, 273 | DIM, 274 | frames, 275 | h0=h0, 276 | weightnorm=WEIGHT_NORM, 277 | skip_conn=SKIP_CONN) 278 | 279 | # rnns_out (bs, seqlen, dim) (128, 64, 512) 280 | output = lib.ops.Linear( 281 | 'FrameLevel.Output', 282 | DIM, 283 | FRAME_SIZE * DIM, 284 | rnns_out, 285 | initialization='he', 286 | weightnorm=WEIGHT_NORM 287 | ) 288 | # output: (2, 9, 4*dim) 289 | output = output.reshape((output.shape[0], output.shape[1] * FRAME_SIZE, DIM)) 290 | # output: (2, 9*4, dim) 291 | 292 | return (output, last_hidden) 293 | 294 | def sample_level_predictor(frame_level_outputs, prev_samples): 295 | """ 296 | batch size = BATCH_SIZE * SEQ_LEN 297 | SEQ_LEN = N_FRAMES * FRAME_SIZE 298 | 299 | frame_level_outputs.shape: (batch size, DIM) 300 | prev_samples.shape: (batch size, FRAME_SIZE) int32 301 | 302 | output.shape: (batch size, Q_LEVELS) 303 | """ 304 | # Handling EMB_SIZE 305 | if EMB_SIZE == 0: 306 | prev_samples = lib.ops.T_one_hot(prev_samples, Q_LEVELS) 307 | # (BATCH_SIZE*N_FRAMES*FRAME_SIZE, FRAME_SIZE, Q_LEVELS) 308 | last_out_shape = Q_LEVELS 309 | elif EMB_SIZE > 0: 310 | prev_samples = lib.ops.Embedding( 311 | 'SampleLevel.Embedding', 312 | Q_LEVELS, 313 | EMB_SIZE, 314 | prev_samples) 315 | # (BATCH_SIZE*N_FRAMES*FRAME_SIZE, FRAME_SIZE, EMB_SIZE), f32 316 | last_out_shape = EMB_SIZE 317 | else: 318 | raise ValueError('EMB_SIZE cannot be negative.') 319 | 320 | prev_samples = prev_samples.reshape((-1, FRAME_SIZE * last_out_shape)) 321 | 322 | out = lib.ops.Linear( 323 | 'SampleLevel.L1_PrevSamples', 324 | FRAME_SIZE * last_out_shape, 325 | DIM, 326 | prev_samples, 327 | biases=False, 328 | initialization='he', 329 | weightnorm=WEIGHT_NORM) 330 | # shape: (BATCH_SIZE*N_FRAMES*FRAME_SIZE, DIM) 331 | 332 | out += frame_level_outputs 333 | # ^ (2*(9*4), dim) 334 | 335 | # L2 336 | out = lib.ops.Linear('SampleLevel.L2', 337 | DIM, 338 | DIM, 339 | 
def sample_level_predictor(frame_level_outputs, prev_samples):
    """
    batch size = BATCH_SIZE * SEQ_LEN
    SEQ_LEN = N_FRAMES * FRAME_SIZE

    frame_level_outputs.shape: (batch size, DIM)
    prev_samples.shape: (batch size, FRAME_SIZE) int32

    output.shape: (batch size, Q_LEVELS)
    """
    # Handling EMB_SIZE
    if EMB_SIZE == 0:
        prev_samples = lib.ops.T_one_hot(prev_samples, Q_LEVELS)
        # (BATCH_SIZE*N_FRAMES*FRAME_SIZE, FRAME_SIZE, Q_LEVELS)
        last_out_shape = Q_LEVELS
    elif EMB_SIZE > 0:
        prev_samples = lib.ops.Embedding(
            'SampleLevel.Embedding',
            Q_LEVELS,
            EMB_SIZE,
            prev_samples)
        # (BATCH_SIZE*N_FRAMES*FRAME_SIZE, FRAME_SIZE, EMB_SIZE), f32
        last_out_shape = EMB_SIZE
    else:
        raise ValueError('EMB_SIZE cannot be negative.')

    prev_samples = prev_samples.reshape((-1, FRAME_SIZE * last_out_shape))

    out = lib.ops.Linear(
        'SampleLevel.L1_PrevSamples',
        FRAME_SIZE * last_out_shape,
        DIM,
        prev_samples,
        biases=False,
        initialization='he',
        weightnorm=WEIGHT_NORM)
    # shape: (BATCH_SIZE*N_FRAMES*FRAME_SIZE, DIM)

    out += frame_level_outputs
    # ^ (2*(9*4), dim)

    # L2
    out = lib.ops.Linear('SampleLevel.L2',
                         DIM,
                         DIM,
                         out,
                         initialization='he',
                         weightnorm=WEIGHT_NORM)
    out = T.nnet.relu(out)

    # L3
    out = lib.ops.Linear('SampleLevel.L3',
                         DIM,
                         DIM,
                         out,
                         initialization='he',
                         weightnorm=WEIGHT_NORM)
    out = T.nnet.relu(out)

    # Output
    # We apply the softmax later
    out = lib.ops.Linear('SampleLevel.Output',
                         DIM,
                         Q_LEVELS,
                         out,
                         weightnorm=WEIGHT_NORM)
    return out

sequences = T.imatrix('sequences')
h0 = T.tensor3('h0')
reset = T.iscalar('reset')
mask = T.matrix('mask')

if args.debug:
    # Solely for debugging purposes.
    # Consider setting theano.config.compute_test_value = 'warn' here.
    sequences.tag.test_value = numpy.zeros((BATCH_SIZE, SEQ_LEN+OVERLAP), dtype='int32')
    h0.tag.test_value = numpy.zeros((BATCH_SIZE, N_RNN, H0_MULT*DIM), dtype='float32')
    reset.tag.test_value = numpy.array(1, dtype='int32')
    mask.tag.test_value = numpy.ones((BATCH_SIZE, SEQ_LEN+OVERLAP), dtype='float32')

input_sequences = sequences[:, :-FRAME_SIZE]
target_sequences = sequences[:, FRAME_SIZE:]

target_mask = mask[:, FRAME_SIZE:]

frame_level_outputs, new_h0 =\
    frame_level_rnn(input_sequences, h0, reset)

prev_samples = sequences[:, :-1]
prev_samples = prev_samples.reshape((1, BATCH_SIZE, 1, -1))
prev_samples = T.nnet.neighbours.images2neibs(prev_samples, (1, FRAME_SIZE), neib_step=(1, 1), mode='valid')
prev_samples = prev_samples.reshape((BATCH_SIZE * SEQ_LEN, FRAME_SIZE))
# (batch_size*n_frames*frame_size, frame_size)

sample_level_outputs = sample_level_predictor(
    frame_level_outputs.reshape((BATCH_SIZE * SEQ_LEN, DIM)),
    prev_samples,
)

cost = T.nnet.categorical_crossentropy(
    T.nnet.softmax(sample_level_outputs),
    target_sequences.flatten()
)
cost = cost.reshape(target_sequences.shape)
cost = cost * target_mask
# Don't use these lines; they could end up with NaNs,
# especially at the end of audio files where the mask is
# all zero for some of the shorter files in the mini-batch.
#cost = cost.sum(axis=1) / target_mask.sum(axis=1)
#cost = cost.mean(axis=0)

# Use this instead.
cost = cost.sum()
cost = cost / target_mask.sum()
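# A quick sanity check for the nats-to-bits conversion below: with e.g.
# Q_LEVELS = 256, a uniform predictive distribution has cross-entropy
# ln(256) ~= 5.5452 nats; multiplying by log2(e) ~= 1.4427 gives exactly
# 8 bits, as expected for 8-bit samples.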
# By default we report cross-entropy cost in bits.
# Switch to nats by commenting out this line:
# log_2(e) = 1.44269504089
cost = cost * lib.floatX(numpy.log2(numpy.e))

### Getting the params, grads, updates, and Theano functions ###
params = lib.get_params(cost, lambda x: hasattr(x, 'param') and x.param==True)
lib.print_params_info(params, path=FOLDER_PREFIX)

grads = T.grad(cost, wrt=params, disconnected_inputs='warn')
grads = [T.clip(g, lib.floatX(-GRAD_CLIP), lib.floatX(GRAD_CLIP)) for g in grads]

updates = lasagne.updates.adam(grads, params, learning_rate=LEARNING_RATE)

# Training function
train_fn = theano.function(
    [sequences, h0, reset, mask],
    [cost, new_h0],
    updates=updates,
    on_unused_input='warn'
)

# Validation and test function, hence no updates
test_fn = theano.function(
    [sequences, h0, reset, mask],
    [cost, new_h0],
    on_unused_input='warn'
)

# Sampling at frame level
frame_level_generate_fn = theano.function(
    [sequences, h0, reset],
    frame_level_rnn(sequences, h0, reset),
    on_unused_input='warn'
)

# Sampling at audio sample level
frame_level_outputs = T.matrix('frame_level_outputs')
prev_samples = T.imatrix('prev_samples')
sample_level_generate_fn = theano.function(
    [frame_level_outputs, prev_samples],
    lib.ops.softmax_and_sample(
        sample_level_predictor(
            frame_level_outputs,
            prev_samples,
        )
    ),
    on_unused_input='warn'
)
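# For intuition only: assuming lib.ops.softmax_and_sample draws one class
# index per row from the softmax of its input logits, a rough numpy
# equivalent (not used by this script) would be:
#
#   def softmax_and_sample_np(logits):
#       e = numpy.exp(logits - logits.max(axis=1, keepdims=True))
#       p = e / e.sum(axis=1, keepdims=True)
#       return numpy.array([numpy.random.choice(p.shape[1], p=row)
#                           for row in p], dtype='int32')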
# Uniform [-0.5, 0.5) noise for half of the initial states of the generated
# samples, to study the behaviour of the model and to introduce some
# diversity to the samples in a simple way.
fixed_rand_h0 = numpy.random.rand(N_SEQS//2, N_RNN, H0_MULT*DIM)
fixed_rand_h0 -= 0.5
fixed_rand_h0 = fixed_rand_h0.astype('float32')

def generate_and_save_samples(tag):
    def write_audio_file(name, data):
        data = data.astype('float32')
        data -= data.min()
        data /= data.max()
        data -= 0.5
        data *= 0.95
        scipy.io.wavfile.write(
            os.path.join(SAMPLES_PATH, name+'.wav'),
            BITRATE,
            data)

    total_time = time()
    # Generate N_SEQS sample files, each N_SECS seconds long
    N_SECS = 15
    LENGTH = N_SECS*BITRATE if not args.debug else 100

    samples = numpy.zeros((N_SEQS, LENGTH), dtype='int32')
    samples[:, :FRAME_SIZE] = Q_ZERO

    # First half zero, others fixed random at each checkpoint
    h0 = numpy.zeros(
        (N_SEQS-fixed_rand_h0.shape[0], N_RNN, H0_MULT*DIM),
        dtype='float32'
    )
    h0 = numpy.concatenate((h0, fixed_rand_h0), axis=0)
    frame_level_outputs = None

    for t in xrange(FRAME_SIZE, LENGTH):

        if t % FRAME_SIZE == 0:
            frame_level_outputs, h0 = frame_level_generate_fn(
                samples[:, t-FRAME_SIZE:t],
                h0,
                #numpy.full((N_SEQS, ), (t == FRAME_SIZE), dtype='int32'),
                numpy.int32(t == FRAME_SIZE)
            )

        samples[:, t] = sample_level_generate_fn(
            frame_level_outputs[:, t % FRAME_SIZE],
            samples[:, t-FRAME_SIZE:t],
        )

    total_time = time() - total_time
    log = "{} samples of {} seconds length generated in {} seconds."
    log = log.format(N_SEQS, N_SECS, total_time)
    print log,

    for i in xrange(N_SEQS):
        samp = samples[i]
        if Q_TYPE == 'mu-law':
            from datasets.dataset import mu2linear
            samp = mu2linear(samp)
        elif Q_TYPE == 'a-law':
            raise NotImplementedError('a-law is not implemented')
        write_audio_file("sample_{}_{}".format(tag, i), samp)
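# On the mu-law branch above: assuming mu2linear inverts standard mu-law
# companding (mu = Q_LEVELS - 1 = 255), the underlying transform pair is
#   encode: y = sign(x) * log(1 + mu*|x|) / log(1 + mu)
#   decode: x = sign(y) * ((1 + mu)**|y| - 1) / mu
# with x, y in [-1, 1]; the integer codes are mapped into that range first.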
def monitor(data_feeder):
    """
    Cost and time of test_fn on a given dataset section.
    Pass only one of `valid_feeder` or `test_feeder`;
    don't pass `train_feeder`.

    :returns:
        Mean cost over the input dataset (data_feeder)
        Total time spent
    """
    _total_time = time()
    _h0 = numpy.zeros((BATCH_SIZE, N_RNN, H0_MULT*DIM), dtype='float32')
    _costs = []
    _data_feeder = load_data(data_feeder)
    for _seqs, _reset, _mask in _data_feeder:
        _cost, _h0 = test_fn(_seqs, _h0, _reset, _mask)
        _costs.append(_cost)

    return numpy.mean(_costs), time() - _total_time

print "Wall clock time spent before training started: {:.2f}h"\
    .format((time()-exp_start)/3600.)
print "Training!"
total_iters = 0
total_time = 0.
last_print_time = 0.
last_print_iters = 0
costs = []
lowest_valid_cost = numpy.finfo(numpy.float32).max
corresponding_test_cost = numpy.finfo(numpy.float32).max
new_lowest_cost = False
end_of_batch = False
epoch = 0

h0 = numpy.zeros((BATCH_SIZE, N_RNN, H0_MULT*DIM), dtype='float32')

# Initial load of the train dataset
tr_feeder = load_data(train_feeder)

### Handling the resume option:
if RESUME:
    # Check that the checkpoint from the previous run is not corrupted,
    # then overwrite some of the variables above.
    iters_to_consume, res_path, epoch, total_iters,\
        [lowest_valid_cost, corresponding_test_cost, test_cost] = \
        lib.resumable(path=FOLDER_PREFIX,
                      iter_key=iter_str,
                      epoch_key=epoch_str,
                      add_resume_counter=True,
                      other_keys=[lowest_valid_str,
                                  corresp_test_str,
                                  test_nll_str])
    # At this point we saved the pkl file.
    last_print_iters = total_iters
    print "### RESUMING JOB FROM EPOCH {}, ITER {}".format(epoch, total_iters)
    # Consume this many iterations to get back to the last point in the
    # training data.
    consume_time = time()
    for i in xrange(iters_to_consume):
        tr_feeder.next()
    consume_time = time() - consume_time
    print "Train data ready in {:.2f}secs after consuming {} minibatches.".\
        format(consume_time, iters_to_consume)

    lib.load_params(res_path)
    print "Parameters from last available checkpoint loaded."

while True:
    # THIS IS ONE ITERATION
    if total_iters % 500 == 0:
        print total_iters,

    total_iters += 1

    try:
        # Take as many mini-batches as possible from the train set
        mini_batch = tr_feeder.next()
    except StopIteration:
        # Mini-batches are finished; reload the feeder.
        # Basically, one epoch.
        tr_feeder = load_data(train_feeder)

        # and start taking new mini-batches again.
        mini_batch = tr_feeder.next()
        epoch += 1
        end_of_batch = True
        print "[Another epoch]",

    seqs, reset, mask = mini_batch

    start_time = time()
    cost, h0 = train_fn(seqs, h0, reset, mask)
    total_time += time() - start_time
    #print "This cost:", cost, "This h0.mean()", h0.mean()

    costs.append(cost)
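    # Note on state handling: the h0 returned by train_fn is fed back in on
    # the next iteration, so RNN state persists across consecutive truncated
    # BPTT subsequences; the feeder's reset flag makes frame_level_rnn fall
    # back to the (possibly learned) initial state at sequence boundaries.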
    # Monitoring step
    if (TRAIN_MODE=='iters' and total_iters-last_print_iters == PRINT_ITERS) or \
            (TRAIN_MODE=='time' and total_time-last_print_time >= PRINT_TIME) or \
            (TRAIN_MODE=='time-iters' and total_time-last_print_time >= PRINT_TIME) or \
            (TRAIN_MODE=='iters-time' and total_iters-last_print_iters >= PRINT_ITERS) or \
            end_of_batch:
        # 0. Validation
        print "\nValidation!",
        valid_cost, valid_time = monitor(valid_feeder)
        print "Done!"

        # 1. Test
        test_time = 0.
        # Only when the validation cost improves, get the cost for the test set.
        if valid_cost < lowest_valid_cost:
            lowest_valid_cost = valid_cost
            print "\n>>> Best validation cost of {} reached. Testing!"\
                .format(valid_cost),
            test_cost, test_time = monitor(test_feeder)
            print "Done!"
            # Report the last one, which is the lowest on the validation set:
            print ">>> test cost:{}\ttotal time:{}".format(test_cost, test_time)
            corresponding_test_cost = test_cost
            new_lowest_cost = True

        # 2. Stdout the training progress
        print_info = "epoch:{}\ttotal iters:{}\twall clock time:{:.2f}h\n"
        print_info += ">>> Lowest valid cost:{}\t Corresponding test cost:{}\n"
        print_info += "\ttrain cost:{:.4f}\ttotal time:{:.2f}h\tper iter:{:.3f}s\n"
        print_info += "\tvalid cost:{:.4f}\ttotal time:{:.2f}h\n"
        print_info += "\ttest cost:{:.4f}\ttotal time:{:.2f}h"
        print_info = print_info.format(epoch,
                                       total_iters,
                                       (time()-exp_start)/3600,
                                       lowest_valid_cost,
                                       corresponding_test_cost,
                                       numpy.mean(costs),
                                       total_time/3600,
                                       total_time/total_iters,
                                       valid_cost,
                                       valid_time/3600,
                                       test_cost,
                                       test_time/3600)
        print print_info

        tag = "e{}_i{}_t{:.2f}_tr{:.4f}_v{:.4f}"
        tag = tag.format(epoch,
                         total_iters,
                         total_time/3600,
                         numpy.mean(costs),
                         valid_cost)
        tag += ("_best" if new_lowest_cost else "")

        # 3. Save params of the model (IO bound, time consuming)
        # If saving params is not successful, there shouldn't be any trace
        # of a successful monitoring step in train_log either.
        print "Saving params!",
        lib.save_params(
            os.path.join(PARAMS_PATH, 'params_{}.pkl'.format(tag))
        )
        print "Done!"

        # 4. Save and graph training progress (fast)
        training_info = {epoch_str : epoch,
                         iter_str : total_iters,
                         train_nll_str : numpy.mean(costs),
                         valid_nll_str : valid_cost,
                         test_nll_str : test_cost,
                         lowest_valid_str : lowest_valid_cost,
                         corresp_test_str : corresponding_test_cost,
                         'train time' : total_time,
                         'valid time' : valid_time,
                         'test time' : test_time,
                         'wall clock time' : time()-exp_start}
        lib.save_training_info(training_info, FOLDER_PREFIX)
        print "Train info saved!",

        y_axis_strs = [train_nll_str, valid_nll_str, test_nll_str]
        lib.plot_traing_info(iter_str, y_axis_strs, FOLDER_PREFIX)
        print "And plotted!"

        # 5. Generate and save samples (time consuming)
        # If not successful, we still have the params to sample from afterward.
        print "Sampling!",
        # Generate samples
        generate_and_save_samples(tag)
        print "Done!"

        if total_iters-last_print_iters == PRINT_ITERS \
                or total_time-last_print_time >= PRINT_TIME:
            # If we are here because of end_of_batch, we shouldn't mess
            # with costs and last_print_iters.
            costs = []
            last_print_time += PRINT_TIME
            last_print_iters += PRINT_ITERS

        end_of_batch = False
        new_lowest_cost = False

        print "Validation Done!\nBack to Training..."

    if (TRAIN_MODE=='iters' and total_iters == STOP_ITERS) or \
            (TRAIN_MODE=='time' and total_time >= STOP_TIME) or \
            ((TRAIN_MODE=='time-iters' or TRAIN_MODE=='iters-time') and \
            (total_iters == STOP_ITERS or total_time >= STOP_TIME)):

        print "Done! Total iters:", total_iters, "Total time: ", total_time
        print "Experiment ended at:", datetime.strftime(datetime.now(), '%Y-%m-%d %H:%M')
        print "Wall clock time spent: {:.2f}h"\
            .format((time()-exp_start)/3600)

        sys.exit()
--------------------------------------------------------------------------------
/models/two_tier/two_tier32k.py:
--------------------------------------------------------------------------------
"""
RNN Audio Generation Model

Two-tier model, quantized input
For more info:
    $ python two_tier32k.py -h

How-to-run example:
    sampleRNN$ pwd
    /u/mehris/sampleRNN

    sampleRNN$ \
    THEANO_FLAGS=mode=FAST_RUN,device=gpu,floatX=float32 python -u \
    models/two_tier/two_tier32k.py --exp AXIS1 --n_frames 12 --frame_size 10 \
    --weight_norm True --emb_size 64 --skip_conn False --dim 32 --n_rnn 2 \
    --rnn_type LSTM --learn_h0 False --q_levels 16 --q_type linear \
    --batch_size 128 --which_set MUSIC

To resume, add ` --resume` to the END of EXACTLY the line above. You can run
the resume command as many times as needed, depending on the TRAIN_MODE.
(The folder name, file name, flags, their order, and their values all matter.)
"""
from time import time
from datetime import datetime
print "Experiment started at:", datetime.strftime(datetime.now(), '%Y-%m-%d %H:%M')
exp_start = time()

import os, sys, glob
sys.path.insert(1, os.getcwd())
import argparse

import numpy
numpy.random.seed(123)
np = numpy
import random
random.seed(123)

import theano
import theano.tensor as T
import theano.ifelse
import lasagne
import scipy.io.wavfile

import lib

LEARNING_RATE = 0.001

### Parsing passed args/hyperparameters ###
def get_args():
    def t_or_f(arg):
        ua = str(arg).upper()
        if 'TRUE'.startswith(ua):
            return True
        elif 'FALSE'.startswith(ua):
            return False
        else:
            raise ValueError('Arg is neither `True` nor `False`')
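    # Note: t_or_f accepts any case-insensitive prefix of TRUE/FALSE
    # ('t', 'Tr', 'FAL', ...). Since ''.upper() is a prefix of both, an empty
    # string also parses as True ('TRUE'.startswith('') holds and that branch
    # is checked first).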
    def check_non_negative(value):
        ivalue = int(value)
        if ivalue < 0:
            raise argparse.ArgumentTypeError("%s is not non-negative!" % value)
        return ivalue

    def check_positive(value):
        ivalue = int(value)
        if ivalue < 1:
            raise argparse.ArgumentTypeError("%s is not positive!" % value)
        return ivalue

    def check_unit_interval(value):
        fvalue = float(value)
        if fvalue < 0 or fvalue > 1:
            raise argparse.ArgumentTypeError("%s is not in [0, 1] interval!" % value)
        return fvalue

    # No default values here. Indicate every single argument.
    parser = argparse.ArgumentParser(
        description='two_tier32k.py\nNo default value! Indicate every argument.')

    # Hyperparameter arguments:
    parser.add_argument('--exp', help='Experiment name',
            type=str, required=False, default='_')
    parser.add_argument('--n_frames', help='How many "frames" to include in each\
            Truncated BPTT pass', type=check_positive, required=True)
    parser.add_argument('--frame_size', help='How many samples per frame',
            type=check_positive, required=True)
    parser.add_argument('--weight_norm', help='Adding learnable weight normalization\
            to all the linear layers (except for the embedding layer)',
            type=t_or_f, required=True)
    parser.add_argument('--emb_size', help='Size of embedding layer (0 to disable)',
            type=check_non_negative, required=True)
    parser.add_argument('--skip_conn', help='Add skip connections to RNN',
            type=t_or_f, required=True)
    parser.add_argument('--dim', help='Dimension of RNN and MLPs',
            type=check_positive, required=True)
    parser.add_argument('--n_rnn', help='Number of layers in the stacked RNN',
            type=check_positive, choices=xrange(1,12), required=True)
    parser.add_argument('--rnn_type', help='GRU or LSTM', choices=['LSTM', 'GRU'],
            required=True)
    parser.add_argument('--learn_h0', help='Whether to learn the initial state of RNN',
            type=t_or_f, required=True)
    parser.add_argument('--q_levels', help='Number of bins for quantization of audio samples. Should be 256 for mu-law.',
            type=check_positive, required=True)
    parser.add_argument('--q_type', help='Quantization in linear scale, a-law companding, or mu-law companding. With mu-/a-law, the quantization levels should be set to 256.',
            choices=['linear', 'a-law', 'mu-law'], required=True)
    parser.add_argument('--which_set', help='The directory name of the dataset',
            type=str, required=True)
    parser.add_argument('--batch_size', help='Size of mini-batch',
            type=check_positive, choices=xrange(1,10000), required=True)

    parser.add_argument('--debug', help='Debug mode', required=False, default=False, action='store_true')
    # NEW
    parser.add_argument('--resume', help='Resume the same model from the last checkpoint. The order of params is important. [for now]',
            required=False, default=False, action='store_true')

    args = parser.parse_args()

    # NEW
    # Create a tag for this experiment based on the passed args
    tag = reduce(lambda a, b: a+b, sys.argv).replace('--resume', '').replace('/', '-').replace('--', '-').replace('True', 'T').replace('False', 'F')
    tag += '-lr'+str(LEARNING_RATE)
    print "Created experiment tag for these args:"
    print tag

    return args, tag

args, tag = get_args()
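# A worked example of the tag above (hypothetical argv): invoking
#   python models/two_tier/two_tier32k.py --exp AXIS1 --weight_norm True ...
# first concatenates argv, then applies the replacements, yielding something
# like 'models-two_tier-two_tier32k.py-expAXIS1-weight_normT...-lr0.001',
# which becomes the results folder name.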
N_FRAMES = args.n_frames # How many 'frames' to include in each truncated BPTT pass
OVERLAP = FRAME_SIZE = args.frame_size # How many samples per frame
WEIGHT_NORM = args.weight_norm
EMB_SIZE = args.emb_size
SKIP_CONN = args.skip_conn
DIM = args.dim # Model dimensionality.
N_RNN = args.n_rnn # How many RNNs to stack
RNN_TYPE = args.rnn_type
H0_MULT = 2 if RNN_TYPE == 'LSTM' else 1
LEARN_H0 = args.learn_h0
Q_LEVELS = args.q_levels # How many levels to use when discretizing samples. e.g. 256 = 8-bit scalar quantization
Q_TYPE = args.q_type # 'linear', 'a-law', or 'mu-law'
WHICH_SET = args.which_set
BATCH_SIZE = args.batch_size
RESUME = args.resume

if Q_TYPE == 'mu-law' and Q_LEVELS != 256:
    raise ValueError('For mu-law, quantization levels should be exactly 256!')

# Fixed hyperparams
GRAD_CLIP = 1 # Elementwise grad clip threshold
BITRATE = 32000

# Other constants
#TRAIN_MODE = 'iters' # To use PRINT_ITERS and STOP_ITERS
TRAIN_MODE = 'time' # To use PRINT_TIME and STOP_TIME
#TRAIN_MODE = 'time-iters'
# To use PRINT_TIME for validation,
# and (STOP_ITERS, STOP_TIME), whichever happens first, for stopping the experiment.
#TRAIN_MODE = 'iters-time'
# To use PRINT_ITERS for validation,
# and (STOP_ITERS, STOP_TIME), whichever happens first, for stopping the experiment.
PRINT_ITERS = 10000 # Print cost, generate samples, save model checkpoint every N iterations.
STOP_ITERS = 100000 # Stop after this many iterations
# TODO:
PRINT_TIME = 60*60*24*2 # Print cost, generate samples, save model checkpoint every N seconds.
STOP_TIME = 60*60*24*4 # Stop after this many seconds of actual training (not including time req'd to generate samples etc.)
N_SEQS = 5 # Number of samples to generate at every monitoring step.
# TODO:
RESULTS_DIR = 'results_2t'
FOLDER_PREFIX = os.path.join(RESULTS_DIR, tag)
SEQ_LEN = N_FRAMES * FRAME_SIZE # Total length (# of samples) of each truncated BPTT sequence
Q_ZERO = numpy.int32(Q_LEVELS//2) # Discrete value corresponding to zero amplitude
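# Worked example: with --n_frames 64 --frame_size 4, SEQ_LEN is 256 samples
# per truncated-BPTT pass, i.e. 8 ms of audio at the 32 kHz BITRATE, and
# OVERLAP = FRAME_SIZE provides the first frame of left context. With
# --q_levels 256, Q_ZERO = 128 is the code for zero amplitude.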
epoch_str = 'epoch'
iter_str = 'iter'
lowest_valid_str = 'lowest valid cost'
corresp_test_str = 'corresponding test cost'
train_nll_str, valid_nll_str, test_nll_str = \
    'train NLL (bits)', 'valid NLL (bits)', 'test NLL (bits)'

if args.debug:
    import warnings
    warnings.warn('----------RUNNING IN DEBUG MODE----------')
    TRAIN_MODE = 'time'
    PRINT_TIME = 100
    STOP_TIME = 3000
    STOP_ITERS = 1000

### Create directories ###
#   FOLDER_PREFIX: root, contains:
#       log.txt, __note.txt, train_log.pkl, train_log.png [, model_settings.txt]
#   FOLDER_PREFIX/params: saves all checkpoint params as pkl
#   FOLDER_PREFIX/samples: keeps all checkpoint samples as wav
#   FOLDER_PREFIX/best: keeps the best parameters, samples, ...
if not os.path.exists(FOLDER_PREFIX):
    os.makedirs(FOLDER_PREFIX)
PARAMS_PATH = os.path.join(FOLDER_PREFIX, 'params')
if not os.path.exists(PARAMS_PATH):
    os.makedirs(PARAMS_PATH)
SAMPLES_PATH = os.path.join(FOLDER_PREFIX, 'samples')
if not os.path.exists(SAMPLES_PATH):
    os.makedirs(SAMPLES_PATH)
BEST_PATH = os.path.join(FOLDER_PREFIX, 'best')
if not os.path.exists(BEST_PATH):
    os.makedirs(BEST_PATH)

lib.print_model_settings(locals(), path=FOLDER_PREFIX, sys_arg=True)

### Import the data_feeder ###
# Handling WHICH_SET
from datasets.dataset import music_train_feed_epoch as train_feeder
from datasets.dataset import music_valid_feed_epoch as valid_feeder
from datasets.dataset import music_test_feed_epoch as test_feeder

def load_data(data_feeder):
    """
    Helper function to deal with the interfaces of different datasets.
    `data_feeder` should be `train_feeder`, `valid_feeder`, or `test_feeder`.
    """
    return data_feeder(WHICH_SET, BATCH_SIZE,
                       SEQ_LEN,
                       OVERLAP,
                       Q_LEVELS,
                       Q_ZERO,
                       Q_TYPE)

### Creating computation graph ###
def frame_level_rnn(input_sequences, h0, reset):
    """
    input_sequences.shape: (batch size, n frames * FRAME_SIZE)
    h0.shape: (batch size, N_RNN, H0_MULT*DIM)
    reset.shape: ()

    output.shape: (batch size, n frames * FRAME_SIZE, DIM)
    """
    frames = input_sequences.reshape((
        input_sequences.shape[0],
        input_sequences.shape[1] // FRAME_SIZE,
        FRAME_SIZE
    ))

    # Rescale frames from ints in [0, Q_LEVELS) to floats in [-2, 2]
    # (a reasonable range to pass as inputs to the RNN)
    frames = (frames.astype('float32') / lib.floatX(Q_LEVELS/2)) - lib.floatX(1)
    frames *= lib.floatX(2)
    # (128, 64, 4)

    # Initial state of RNNs
    learned_h0 = lib.param(
        'FrameLevel.h0',
        numpy.zeros((N_RNN, H0_MULT*DIM), dtype=theano.config.floatX)
    )
    # Handling LEARN_H0
    learned_h0.param = LEARN_H0
    learned_h0 = T.alloc(learned_h0, h0.shape[0], N_RNN, H0_MULT*DIM)
    learned_h0 = T.unbroadcast(learned_h0, 0, 1, 2)
    h0 = theano.ifelse.ifelse(reset, learned_h0, h0)

    # Handling RNN_TYPE
    # Handling SKIP_CONN
    if RNN_TYPE == 'GRU':
        rnns_out, last_hidden = lib.ops.stackedGRU('FrameLevel.GRU',
                                                   N_RNN,
                                                   FRAME_SIZE,
                                                   DIM,
                                                   frames,
                                                   h0=h0,
                                                   weightnorm=WEIGHT_NORM,
                                                   skip_conn=SKIP_CONN)
    elif RNN_TYPE == 'LSTM':
        rnns_out, last_hidden = lib.ops.stackedLSTM('FrameLevel.LSTM',
                                                    N_RNN,
                                                    FRAME_SIZE,
                                                    DIM,
                                                    frames,
                                                    h0=h0,
                                                    weightnorm=WEIGHT_NORM,
                                                    skip_conn=SKIP_CONN)

    # rnns_out (bs, seqlen, dim) (128, 64, 512)
    output = lib.ops.Linear(
        'FrameLevel.Output',
        DIM,
        FRAME_SIZE * DIM,
        rnns_out,
        initialization='he',
        weightnorm=WEIGHT_NORM
    )
    # output: (2, 9, 4*dim)
    output = output.reshape((output.shape[0], output.shape[1] * FRAME_SIZE, DIM))
    # output: (2, 9*4, dim)

    return (output, last_hidden)
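# On the initial state above: H0_MULT is 2 for LSTM, presumably because each
# layer's recurrent state then carries both a hidden and a cell vector (1 for
# GRU), and T.alloc tiles the single learned h0 across the batch dimension,
# so every sequence in the batch starts from the same (learnable) state.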
def sample_level_predictor(frame_level_outputs, prev_samples):
    """
    batch size = BATCH_SIZE * SEQ_LEN
    SEQ_LEN = N_FRAMES * FRAME_SIZE

    frame_level_outputs.shape: (batch size, DIM)
    prev_samples.shape: (batch size, FRAME_SIZE) int32

    output.shape: (batch size, Q_LEVELS)
    """
    # Handling EMB_SIZE
    if EMB_SIZE == 0:
        prev_samples = lib.ops.T_one_hot(prev_samples, Q_LEVELS)
        # (BATCH_SIZE*N_FRAMES*FRAME_SIZE, FRAME_SIZE, Q_LEVELS)
        last_out_shape = Q_LEVELS
    elif EMB_SIZE > 0:
        prev_samples = lib.ops.Embedding(
            'SampleLevel.Embedding',
            Q_LEVELS,
            EMB_SIZE,
            prev_samples)
        # (BATCH_SIZE*N_FRAMES*FRAME_SIZE, FRAME_SIZE, EMB_SIZE), f32
        last_out_shape = EMB_SIZE
    else:
        raise ValueError('EMB_SIZE cannot be negative.')

    prev_samples = prev_samples.reshape((-1, FRAME_SIZE * last_out_shape))

    out = lib.ops.Linear(
        'SampleLevel.L1_PrevSamples',
        FRAME_SIZE * last_out_shape,
        DIM,
        prev_samples,
        biases=False,
        initialization='he',
        weightnorm=WEIGHT_NORM)
    # shape: (BATCH_SIZE*N_FRAMES*FRAME_SIZE, DIM)

    out += frame_level_outputs
    # ^ (2*(9*4), dim)

    # L2
    out = lib.ops.Linear('SampleLevel.L2',
                         DIM,
                         DIM,
                         out,
                         initialization='he',
                         weightnorm=WEIGHT_NORM)
    out = T.nnet.relu(out)

    # L3
    out = lib.ops.Linear('SampleLevel.L3',
                         DIM,
                         DIM,
                         out,
                         initialization='he',
                         weightnorm=WEIGHT_NORM)
    out = T.nnet.relu(out)

    # Output
    # We apply the softmax later
    out = lib.ops.Linear('SampleLevel.Output',
                         DIM,
                         Q_LEVELS,
                         out,
                         weightnorm=WEIGHT_NORM)
    return out

sequences = T.imatrix('sequences')
h0 = T.tensor3('h0')
reset = T.iscalar('reset')
mask = T.matrix('mask')

if args.debug:
    # Solely for debugging purposes.
    # Consider setting theano.config.compute_test_value = 'warn' here.
    sequences.tag.test_value = numpy.zeros((BATCH_SIZE, SEQ_LEN+OVERLAP), dtype='int32')
    h0.tag.test_value = numpy.zeros((BATCH_SIZE, N_RNN, H0_MULT*DIM), dtype='float32')
    reset.tag.test_value = numpy.array(1, dtype='int32')
    mask.tag.test_value = numpy.ones((BATCH_SIZE, SEQ_LEN+OVERLAP), dtype='float32')

input_sequences = sequences[:, :-FRAME_SIZE]
target_sequences = sequences[:, FRAME_SIZE:]

target_mask = mask[:, FRAME_SIZE:]

frame_level_outputs, new_h0 =\
    frame_level_rnn(input_sequences, h0, reset)

prev_samples = sequences[:, :-1]
prev_samples = prev_samples.reshape((1, BATCH_SIZE, 1, -1))
prev_samples = T.nnet.neighbours.images2neibs(prev_samples, (1, FRAME_SIZE), neib_step=(1, 1), mode='valid')
prev_samples = prev_samples.reshape((BATCH_SIZE * SEQ_LEN, FRAME_SIZE))
# (batch_size*n_frames*frame_size, frame_size)
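# For intuition: the images2neibs call above slides a length-FRAME_SIZE
# window (stride 1) over each sequence of previous samples, so row k holds
# the FRAME_SIZE samples immediately preceding target sample k. A rough
# single-sequence numpy equivalent (illustration only, not used here):
#
#   def sliding_windows(x, frame_size):
#       return numpy.stack([x[k:k+frame_size]
#                           for k in range(len(x) - frame_size + 1)])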
sample_level_outputs = sample_level_predictor(
    frame_level_outputs.reshape((BATCH_SIZE * SEQ_LEN, DIM)),
    prev_samples,
)

cost = T.nnet.categorical_crossentropy(
    T.nnet.softmax(sample_level_outputs),
    target_sequences.flatten()
)
cost = cost.reshape(target_sequences.shape)
cost = cost * target_mask
# Don't use these lines; they could end up with NaNs,
# especially at the end of audio files where the mask is
# all zero for some of the shorter files in the mini-batch.
#cost = cost.sum(axis=1) / target_mask.sum(axis=1)
#cost = cost.mean(axis=0)

# Use this instead.
cost = cost.sum()
cost = cost / target_mask.sum()

# By default we report cross-entropy cost in bits.
# Switch to nats by commenting out this line:
# log_2(e) = 1.44269504089
cost = cost * lib.floatX(numpy.log2(numpy.e))

### Getting the params, grads, updates, and Theano functions ###
params = lib.get_params(cost, lambda x: hasattr(x, 'param') and x.param==True)
lib.print_params_info(params, path=FOLDER_PREFIX)

grads = T.grad(cost, wrt=params, disconnected_inputs='warn')
grads = [T.clip(g, lib.floatX(-GRAD_CLIP), lib.floatX(GRAD_CLIP)) for g in grads]

updates = lasagne.updates.adam(grads, params, learning_rate=LEARNING_RATE)

# Training function
train_fn = theano.function(
    [sequences, h0, reset, mask],
    [cost, new_h0],
    updates=updates,
    on_unused_input='warn'
)

# Validation and test function, hence no updates
test_fn = theano.function(
    [sequences, h0, reset, mask],
    [cost, new_h0],
    on_unused_input='warn'
)

# Sampling at frame level
frame_level_generate_fn = theano.function(
    [sequences, h0, reset],
    frame_level_rnn(sequences, h0, reset),
    on_unused_input='warn'
)

# Sampling at audio sample level
frame_level_outputs = T.matrix('frame_level_outputs')
prev_samples = T.imatrix('prev_samples')
sample_level_generate_fn = theano.function(
    [frame_level_outputs, prev_samples],
    lib.ops.softmax_and_sample(
        sample_level_predictor(
            frame_level_outputs,
            prev_samples,
        )
    ),
    on_unused_input='warn'
)

# Uniform [-0.5, 0.5) noise for half of the initial states of the generated
# samples, to study the behaviour of the model and to introduce some
# diversity to the samples in a simple way.
fixed_rand_h0 = numpy.random.rand(N_SEQS//2, N_RNN, H0_MULT*DIM)
fixed_rand_h0 -= 0.5
fixed_rand_h0 = fixed_rand_h0.astype('float32')

def generate_and_save_samples(tag):
    def write_audio_file(name, data):
        data = data.astype('float32')
        data -= data.min()
        data /= data.max()
        data -= 0.5
        data *= 0.95
        scipy.io.wavfile.write(
            os.path.join(SAMPLES_PATH, name+'.wav'),
            BITRATE,
            data)
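    # write_audio_file normalization walk-through: min-subtraction and
    # max-division map the samples to [0, 1], subtracting 0.5 centers them
    # in [-0.5, 0.5], and the 0.95 factor leaves a little headroom, so the
    # float32 wav data ends up in roughly [-0.475, 0.475].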
    total_time = time()
    # Generate N_SEQS sample files, each N_SECS seconds long
    N_SECS = 30
    LENGTH = N_SECS*BITRATE if not args.debug else 100

    samples = numpy.zeros((N_SEQS, LENGTH), dtype='int32')
    samples[:, :FRAME_SIZE] = Q_ZERO

    # First half zero, others fixed random at each checkpoint
    h0 = numpy.zeros(
        (N_SEQS-fixed_rand_h0.shape[0], N_RNN, H0_MULT*DIM),
        dtype='float32'
    )
    h0 = numpy.concatenate((h0, fixed_rand_h0), axis=0)
    frame_level_outputs = None

    for t in xrange(FRAME_SIZE, LENGTH):

        if t % FRAME_SIZE == 0:
            frame_level_outputs, h0 = frame_level_generate_fn(
                samples[:, t-FRAME_SIZE:t],
                h0,
                #numpy.full((N_SEQS, ), (t == FRAME_SIZE), dtype='int32'),
                numpy.int32(t == FRAME_SIZE)
            )

        samples[:, t] = sample_level_generate_fn(
            frame_level_outputs[:, t % FRAME_SIZE],
            samples[:, t-FRAME_SIZE:t],
        )

    total_time = time() - total_time
    log = "{} samples of {} seconds length generated in {} seconds."
    log = log.format(N_SEQS, N_SECS, total_time)
    print log,

    for i in xrange(N_SEQS):
        samp = samples[i]
        if Q_TYPE == 'mu-law':
            from datasets.dataset import mu2linear
            samp = mu2linear(samp)
        elif Q_TYPE == 'a-law':
            raise NotImplementedError('a-law is not implemented')
        write_audio_file("sample_{}_{}".format(tag, i), samp)

def monitor(data_feeder):
    """
    Cost and time of test_fn on a given dataset section.
    Pass only one of `valid_feeder` or `test_feeder`;
    don't pass `train_feeder`.

    :returns:
        Mean cost over the input dataset (data_feeder)
        Total time spent
    """
    _total_time = time()
    _h0 = numpy.zeros((BATCH_SIZE, N_RNN, H0_MULT*DIM), dtype='float32')
    _costs = []
    _data_feeder = load_data(data_feeder)
    for _seqs, _reset, _mask in _data_feeder:
        _cost, _h0 = test_fn(_seqs, _h0, _reset, _mask)
        _costs.append(_cost)

    return numpy.mean(_costs), time() - _total_time
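# Caveat on monitor: numpy.mean(_costs) averages the per-minibatch mean
# costs with equal weight, so a minibatch with fewer unmasked samples counts
# as much as a full one; for identically masked batches this matches the
# exact dataset mean, otherwise it is a close approximation.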
print "Wall clock time spent before training started: {:.2f}h"\
    .format((time()-exp_start)/3600.)
print "Training!"
total_iters = 0
total_time = 0.
last_print_time = 0.
last_print_iters = 0
costs = []
lowest_valid_cost = numpy.finfo(numpy.float32).max
corresponding_test_cost = numpy.finfo(numpy.float32).max
new_lowest_cost = False
end_of_batch = False
epoch = 0

h0 = numpy.zeros((BATCH_SIZE, N_RNN, H0_MULT*DIM), dtype='float32')

# Initial load of the train dataset
tr_feeder = load_data(train_feeder)

### Handling the resume option:
if RESUME:
    # Check that the checkpoint from the previous run is not corrupted,
    # then overwrite some of the variables above.
    iters_to_consume, res_path, epoch, total_iters,\
        [lowest_valid_cost, corresponding_test_cost, test_cost] = \
        lib.resumable(path=FOLDER_PREFIX,
                      iter_key=iter_str,
                      epoch_key=epoch_str,
                      add_resume_counter=True,
                      other_keys=[lowest_valid_str,
                                  corresp_test_str,
                                  test_nll_str])
    # At this point we saved the pkl file.
    last_print_iters = total_iters
    print "### RESUMING JOB FROM EPOCH {}, ITER {}".format(epoch, total_iters)
    # Consume this many iterations to get back to the last point in the
    # training data.
    consume_time = time()
    for i in xrange(iters_to_consume):
        tr_feeder.next()
    consume_time = time() - consume_time
    print "Train data ready in {:.2f}secs after consuming {} minibatches.".\
        format(consume_time, iters_to_consume)

    lib.load_params(res_path)
    print "Parameters from last available checkpoint loaded."

while True:
    # THIS IS ONE ITERATION
    if total_iters % 500 == 0:
        print total_iters,

    total_iters += 1

    try:
        # Take as many mini-batches as possible from the train set
        mini_batch = tr_feeder.next()
    except StopIteration:
        # Mini-batches are finished; reload the feeder.
        # Basically, one epoch.
        tr_feeder = load_data(train_feeder)

        # and start taking new mini-batches again.
        mini_batch = tr_feeder.next()
        epoch += 1
        end_of_batch = True
        print "[Another epoch]",

    seqs, reset, mask = mini_batch

    start_time = time()
    cost, h0 = train_fn(seqs, h0, reset, mask)
    total_time += time() - start_time
    #print "This cost:", cost, "This h0.mean()", h0.mean()

    costs.append(cost)

    # Monitoring step
    if (TRAIN_MODE=='iters' and total_iters-last_print_iters == PRINT_ITERS) or \
            (TRAIN_MODE=='time' and total_time-last_print_time >= PRINT_TIME) or \
            (TRAIN_MODE=='time-iters' and total_time-last_print_time >= PRINT_TIME) or \
            (TRAIN_MODE=='iters-time' and total_iters-last_print_iters >= PRINT_ITERS) or \
            end_of_batch:
        # 0. Validation
        print "\nValidation!",
        valid_cost, valid_time = monitor(valid_feeder)
        print "Done!"

        # 1. Test
        test_time = 0.
        # Only when the validation cost improves, get the cost for the test set.
        if valid_cost < lowest_valid_cost:
            lowest_valid_cost = valid_cost
            print "\n>>> Best validation cost of {} reached. Testing!"\
                .format(valid_cost),
            test_cost, test_time = monitor(test_feeder)
            print "Done!"
            # Report the last one, which is the lowest on the validation set:
            print ">>> test cost:{}\ttotal time:{}".format(test_cost, test_time)
            corresponding_test_cost = test_cost
            new_lowest_cost = True

        # 2. Stdout the training progress
        print_info = "epoch:{}\ttotal iters:{}\twall clock time:{:.2f}h\n"
        print_info += ">>> Lowest valid cost:{}\t Corresponding test cost:{}\n"
        print_info += "\ttrain cost:{:.4f}\ttotal time:{:.2f}h\tper iter:{:.3f}s\n"
        print_info += "\tvalid cost:{:.4f}\ttotal time:{:.2f}h\n"
        print_info += "\ttest cost:{:.4f}\ttotal time:{:.2f}h"
        print_info = print_info.format(epoch,
                                       total_iters,
                                       (time()-exp_start)/3600,
                                       lowest_valid_cost,
                                       corresponding_test_cost,
                                       numpy.mean(costs),
                                       total_time/3600,
                                       total_time/total_iters,
                                       valid_cost,
                                       valid_time/3600,
                                       test_cost,
                                       test_time/3600)
        print print_info

        tag = "e{}_i{}_t{:.2f}_tr{:.4f}_v{:.4f}"
        tag = tag.format(epoch,
                         total_iters,
                         total_time/3600,
                         numpy.mean(costs),
                         valid_cost)
        tag += ("_best" if new_lowest_cost else "")
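        # Worked example of the checkpoint tag: at epoch 3, iter 20000,
        # after 1.25 h of training time, mean train cost 1.2345 and a
        # new-best valid cost 1.3456, the tag becomes
        # 'e3_i20000_t1.25_tr1.2345_v1.3456_best'.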
        # 3. Save params of the model (IO bound, time consuming)
        # If saving params is not successful, there shouldn't be any trace
        # of a successful monitoring step in train_log either.
        print "Saving params!",
        lib.save_params(
            os.path.join(PARAMS_PATH, 'params_{}.pkl'.format(tag))
        )
        print "Done!"

        # 4. Save and graph training progress (fast)
        training_info = {epoch_str : epoch,
                         iter_str : total_iters,
                         train_nll_str : numpy.mean(costs),
                         valid_nll_str : valid_cost,
                         test_nll_str : test_cost,
                         lowest_valid_str : lowest_valid_cost,
                         corresp_test_str : corresponding_test_cost,
                         'train time' : total_time,
                         'valid time' : valid_time,
                         'test time' : test_time,
                         'wall clock time' : time()-exp_start}
        lib.save_training_info(training_info, FOLDER_PREFIX)
        print "Train info saved!",

        y_axis_strs = [train_nll_str, valid_nll_str, test_nll_str]
        lib.plot_traing_info(iter_str, y_axis_strs, FOLDER_PREFIX)
        print "And plotted!"

        # 5. Generate and save samples (time consuming)
        # If not successful, we still have the params to sample from afterward.
        print "Sampling!",
        # Generate samples
        generate_and_save_samples(tag)
        print "Done!"

        if total_iters-last_print_iters == PRINT_ITERS \
                or total_time-last_print_time >= PRINT_TIME:
            # If we are here because of end_of_batch, we shouldn't mess
            # with costs and last_print_iters.
            costs = []
            last_print_time += PRINT_TIME
            last_print_iters += PRINT_ITERS

        end_of_batch = False
        new_lowest_cost = False

        print "Validation Done!\nBack to Training..."

    if (TRAIN_MODE=='iters' and total_iters == STOP_ITERS) or \
            (TRAIN_MODE=='time' and total_time >= STOP_TIME) or \
            ((TRAIN_MODE=='time-iters' or TRAIN_MODE=='iters-time') and \
            (total_iters == STOP_ITERS or total_time >= STOP_TIME)):

        print "Done! Total iters:", total_iters, "Total time: ", total_time
        print "Experiment ended at:", datetime.strftime(datetime.now(), '%Y-%m-%d %H:%M')
        print "Wall clock time spent: {:.2f}h"\
            .format((time()-exp_start)/3600)

        sys.exit()
--------------------------------------------------------------------------------