├── .gitignore ├── models.pyc ├── reader.pyc ├── final_paper.pdf ├── images └── example_abc.png ├── hyperparameters.txt ├── hparams_seq2seq.txt ├── test_train_tester.py ├── reader.py ├── song_generator.py ├── utils_baseline.py ├── README.md ├── utils_models.py ├── utils_hyperparam.py ├── utils_runtime.py ├── midi_manipulator.py ├── midi_crawler.py ├── utils.py ├── run_gan.py ├── utils_preprocess.py ├── run.py └── models.py /.gitignore: -------------------------------------------------------------------------------- 1 | sample_data/ 2 | -------------------------------------------------------------------------------- /models.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yinoue93/CS224N_proj/HEAD/models.pyc -------------------------------------------------------------------------------- /reader.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yinoue93/CS224N_proj/HEAD/reader.pyc -------------------------------------------------------------------------------- /final_paper.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yinoue93/CS224N_proj/HEAD/final_paper.pdf -------------------------------------------------------------------------------- /images/example_abc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yinoue93/CS224N_proj/HEAD/images/example_abc.png -------------------------------------------------------------------------------- /hyperparameters.txt: -------------------------------------------------------------------------------- 1 | char 2 | 2 3 | meta_embed, 60, 160, 50 4 | embedding_dims, 30, 150, 60 5 | keep_prob, 0.8, 0.8, 0 -------------------------------------------------------------------------------- /hparams_seq2seq.txt: -------------------------------------------------------------------------------- 1 | seq2seq 2 | 1 3 | meta_embed, 60, 160, 50 4 | embedding_dims, 10, 70, 30 5 | keep_prob, 0.8, 0.8, 0 6 | attention_option, ['bahdanau', 'luong'] 7 | bidirectional, [False, True] 8 | -------------------------------------------------------------------------------- /test_train_tester.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import pickle 4 | import datetime 5 | 6 | from utils_hyperparam import OUTPUT_FILE 7 | 8 | TMP_HYPER_PICKLE = 'tmp_hyperparam.p' 9 | 10 | #-----------CHANGE THESE PARAMETERS-------------------------- 11 | TRAIN = '/data/full_dataset/char_rnn_dataset/nn_input_train_stride_25_window_10_nnType_char_rnn_shuffled' 12 | CKPT_DIR = '/data/another/char_10/' 13 | MODEL_TYPE = 'char' 14 | #------------------------------------------------------------ 15 | 16 | TEST = TRAIN.replace('train', 'test') 17 | DEV = TRAIN.replace('train', 'dev') 18 | 19 | def runTests(ckptList, dataset): 20 | for ckptPath in ckptList: 21 | cmd = 'python run.py -p dev -ckpt %s -m %s -c %s -data %s' \ 22 | %(ckptPath, MODEL_TYPE, TMP_HYPER_PICKLE, dataset) 23 | 24 | os.system(cmd) 25 | 26 | def getTestTrainAccuracies(): 27 | if os.path.exists(OUTPUT_FILE): 28 | os.remove(OUTPUT_FILE) 29 | 30 | # first scrape the model names 31 | ckptSet = set() 32 | for filename in os.listdir(CKPT_DIR): 33 | modelName = re.findall('model.ckpt-[0-9]+', filename) 34 | if len(modelName)==0: 35 | continue 36 | ckptSet.add(modelName[0]) 37 | 38 | ckptList = [] 39 | for i in range(len(ckptSet)): 40 | for j,cName in enumerate(ckptSet): 41 | if str(i) in cName: 42 | break 43 | 44 | ckptList.append(cName) 45 | 46 | ckptList = [os.path.join(CKPT_DIR, ckptName) for ckptName in ckptList] 47 | 48 | # dump a fake pickle file to trick run.py to think that we are doing 49 | # hyperparameter tuning 50 | emptyDict = {} 51 | pickle.dump(emptyDict, open(TMP_HYPER_PICKLE, 'wb')) 52 | 53 | with open(OUTPUT_FILE, 'a') as f: 54 | f.write('Train Dataset:\n') 55 | runTests(ckptList, TRAIN) 56 | 57 | with open(OUTPUT_FILE, 'a') as f: 58 | f.write('\nTest Dataset:\n') 59 | runTests(ckptList, TEST) 60 | 61 | with open(OUTPUT_FILE, 'a') as f: 62 | f.write('\nDev Dataset:\n') 63 | runTests(ckptList, DEV) 64 | 65 | # rename the result file with a timestamp 66 | now = datetime.datetime.now() 67 | resultName = '%s_%s_%s.txt' %(TRAIN[(TRAIN.rfind('/')+1):], MODEL_TYPE, 68 | now.strftime("%B_%d_%H_%M_%S")) 69 | 70 | with open(OUTPUT_FILE, 'r') as f, open(resultName, 'w') as g: 71 | txt = f.read() 72 | g.write(txt.replace('Dev set accuracy: ','')) 73 | 74 | os.remove(OUTPUT_FILE) 75 | 76 | 77 | 78 | if __name__ == "__main__": 79 | getTestTrainAccuracies() -------------------------------------------------------------------------------- /reader.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import os 4 | import re 5 | import json 6 | import pickle 7 | 8 | # Metadata + ~50 characters, then sliding window of (t+1) 9 | # Feed dict should pass in an intial state (previous final state) 10 | # Train on entire song & Batching for different songs? 11 | # Train on individual window examples 12 | 13 | 14 | def abc_filenames(datapath): 15 | return [os.path.join(datapath, f) for f in os.listdir(datapath) if os.path.isfile(os.path.join(datapath, f))] 16 | 17 | 18 | def abc_batch(iterable, n=1): 19 | l = len(iterable) 20 | batches = [] 21 | for ndx in range(0, l, n): 22 | if min(ndx + n, l) - ndx == n: 23 | batches.append(iterable[ndx:(ndx + n)]) 24 | return batches 25 | 26 | 27 | def read_abc_pickle(train_file): 28 | with open(train_file, 'r') as fd: 29 | return pickle.load(fd) 30 | 31 | 32 | def compute_save_vocabulary(datapath): 33 | # Iterate through whole dataset directory 34 | filenames = abc_filenames(datapath) 35 | unique_characters = set([]) 36 | for filename in filenames: 37 | characters = read_abc(filename) 38 | unique_characters.update(characters) 39 | 40 | vocabulary = dict(zip(unique_characters, range(len(unique_characters)))) 41 | with open('vocabulary.json', 'w') as v: 42 | json.dump(vocabulary, v) 43 | 44 | 45 | def load_vocabulary(): 46 | with open('vocabulary.json', 'r') as v: 47 | return json.load(v) 48 | 49 | 50 | def get_abs_files(datapath): 51 | filenames = abc_filenames(datapath) 52 | abc_songs = [] # Encoded as indicies 53 | for filename in filenames: 54 | characters = read_abc(filename) 55 | abc_songs.append(characters) 56 | return abc_songs 57 | 58 | 59 | def abc_to_index(filename, vocabulary): 60 | characters = read_abc(filename) 61 | character_indicies = [vocabulary[char] for char in characters] 62 | return character_indicies 63 | 64 | 65 | def read_abc(filename, exclude_title=True): 66 | with open(filename, 'r') as f: 67 | data = [line for line in f] 68 | if exclude_title: 69 | data = data[1:] 70 | # 4 metadata 'symbols' 71 | metadata = [re.split(":|\r\n", meta)[1].lower() for meta in data[:-1]] 72 | return metadata + list(re.split("\r\r\n",data[-1])[0]) 73 | 74 | 75 | def abc_producer(char_ids, batch_size): 76 | pass 77 | 78 | 79 | def main(_): 80 | datapath = "sample_data" 81 | compute_save_vocabulary(datapath) 82 | vocabulary = load_vocabulary() 83 | print vocabulary 84 | filename = "sample_data/Zycanthos jig_0.abc" 85 | abc_indecies = abc_to_index(filename, vocabulary) 86 | 87 | 88 | if __name__ == "__main__": 89 | tf.app.run() 90 | -------------------------------------------------------------------------------- /song_generator.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import re 4 | import pickle 5 | 6 | from argparse import ArgumentParser 7 | 8 | from run import run_model 9 | 10 | class ArgumentParserWannabe(object): 11 | pass 12 | 13 | def generateSong(args): 14 | args_fake = ArgumentParserWannabe() 15 | args_fake.train = 'sample' 16 | args_fake.data_dir = '' 17 | args_fake.num_epochs = 1 18 | args_fake.ckpt_dir = '' 19 | args_fake.set_config = 'song_generator.p' 20 | args_fake.override = False 21 | args_fake.ran_from_script = True 22 | args_fake.warm_len = args.warm_len 23 | 24 | if args.temperature==0 and (args.model=='seq2seq' or args.model=='duet'): 25 | args_fake.temperature = None 26 | else: 27 | args_fake.temperature = args.temperature 28 | 29 | sys.stdout = open(os.devnull, "w") 30 | 31 | if len(args.real_song) != 0: 32 | args_fake.warmupData = '/data/full_dataset/handmade/' + args.real_song 33 | 34 | ckpt_modifier = '' if args.ckpt_num==-1 else ('model.ckpt-'+str(args.ckpt_num)) 35 | 36 | if args.model=='seq2seq': 37 | args_fake.model = 'seq2seq' 38 | args_fake.ckpt_dir = '/data/another/seq2seq_25_2/'+ckpt_modifier 39 | 40 | paramDict = {'meta_embed':160, 'embedding_dims':100, 'keep_prob':0.8, 41 | 'attention_option':'bahnadau', 'bidirectional':False} 42 | with open(args_fake.set_config,'wb') as f: 43 | pickle.dump(paramDict, f) 44 | 45 | generated = run_model(args_fake) 46 | 47 | elif args.model=='char': 48 | args_fake.model = 'char' 49 | args_fake.ckpt_dir = '/data/another/char_50_2/'+ckpt_modifier 50 | 51 | paramDict = {'meta_embed':160, 'embedding_dims':20, 'keep_prob':0.8} 52 | with open(args_fake.set_config,'wb') as f: 53 | pickle.dump(paramDict, f) 54 | 55 | generated = run_model(args_fake) 56 | 57 | elif args.model=='cbow': 58 | args_fake.model = 'cbow' 59 | args_fake.ckpt_dir = '/data/another/cbow_ckpt/model.ckpt-8' 60 | 61 | paramDict = {'meta_embed':100, 'embedding_dims':60, 'keep_prob':0.8} 62 | with open(args_fake.set_config,'wb') as f: 63 | pickle.dump(paramDict, f) 64 | 65 | generated = run_model(args_fake) 66 | 67 | elif args.model=='duet': 68 | args_fake.model = 'seq2seq' 69 | args_fake.ckpt_dir = '/data/another/seq2seq_duet/'+ckpt_modifier 70 | args_fake.meta_map = 'full_dataset/duet_processed/vocab_map_meta.p' 71 | args_fake.music_map = 'full_dataset/duet_processed/vocab_map_music.p' 72 | args_fake.warmupData = '/data/full_dataset/duet_processed/checked' 73 | 74 | paramDict = {'meta_embed':160, 'embedding_dims':100, 'keep_prob':0.8, 75 | 'attention_option':'bahnadau', 'bidirectional':False} 76 | with open(args_fake.set_config,'wb') as f: 77 | pickle.dump(paramDict, f) 78 | 79 | generated = run_model(args_fake).replace('%','\n') 80 | 81 | generated = generated.replace('','').replace('','') 82 | 83 | long_num = re.findall('[0-9][0-9]+', generated) 84 | for longint in long_num: 85 | generated = generated.replace(longint, longint[0]) 86 | 87 | sys.stdout = sys.__stdout__ 88 | print '-'*50 89 | print generated 90 | 91 | def parseCommandLineSong(): 92 | desc = u'{0} [Args] [Options]\nDetailed options -h or --help'.format(__file__) 93 | parser = ArgumentParser(description=desc) 94 | 95 | print("Parsing Command Line Arguments...") 96 | requiredModel = parser.add_argument_group('Required Model arguments') 97 | requiredModel.add_argument('-m', choices = ["seq2seq", "char", "cbow", "duet"], type = str, 98 | dest = 'model', required = True, help = 'Type of model to run') 99 | 100 | parser.add_argument('-r', dest='real_song', default='', 101 | type=str, help='Sample from a real song') 102 | parser.add_argument('-t', dest='temperature', default=1.0, 103 | type=float, help='Temperature') 104 | parser.add_argument('-w', dest='warm_len', default=10, 105 | type=int, help='Warm start length') 106 | parser.add_argument('-n', dest='ckpt_num', default=-1, 107 | type=int, help='Checkpoint Number') 108 | 109 | args = parser.parse_args() 110 | return args 111 | 112 | if __name__ == "__main__": 113 | args = parseCommandLineSong() 114 | 115 | generateSong(args) -------------------------------------------------------------------------------- /utils_baseline.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | from utils import makedir 3 | 4 | import pickle 5 | import os 6 | import re 7 | import random 8 | 9 | import numpy as np 10 | 11 | def generateVocab(foldername, filename): 12 | 13 | with open(filename, 'r') as f: 14 | inputStr = f.readline() 15 | 16 | charList = Counter(inputStr).keys() 17 | 18 | dict2Store = {} 19 | for i,letter in enumerate(charList): 20 | dict2Store[letter] = i 21 | 22 | # write out to a file 23 | pickle.dump(dict2Store, open(os.path.join(foldername, 'vocab_map_baseline.p'), 'wb')) 24 | 25 | def encode(foldername, filename): 26 | outname = os.path.join(foldername, 'encoded.p') 27 | encodedList = [] 28 | encodeMap = pickle.load(open(os.path.join(foldername, 'vocab_map_baseline.p'), 'rb')) 29 | 30 | with open(filename, 'r') as f: 31 | inputStr = f.readline() 32 | 33 | for inStr in inputStr: 34 | encodedList.append(encodeMap[inStr]) 35 | 36 | with open(outname, 'wb') as f: 37 | pickle.dump(encodedList, f) 38 | 39 | def datasetNNInput(foldername, inputSz): 40 | encodedName = os.path.join(foldername, 'encoded.p') 41 | encodedList = pickle.load(open(encodedName, 'rb')) 42 | 43 | makedir(os.path.join(foldername, 'inputs')) 44 | 45 | iterNum = int((len(encodedList)-1)/inputSz) - 1 46 | empty_meta = np.asarray([0]*7) 47 | for i in range(iterNum): 48 | if i%100==0: 49 | print '%d/%d' %(i,iterNum) 50 | 51 | startIndx = i*inputSz 52 | endIndx = (i+1)*inputSz 53 | inData = np.asarray(encodedList[startIndx:endIndx]) 54 | labelData = np.asarray(encodedList[startIndx+1:endIndx+1]) 55 | 56 | data = [empty_meta, inData, labelData] 57 | outname = os.path.join(foldername, 'inputs/input_%d.p' % i) 58 | 59 | with open(outname, 'wb') as f: 60 | pickle.dump(data, f) 61 | 62 | def datasetSplit(folderName, setRatio): 63 | """ 64 | Split the dataset into training, testing, and dev sets. 65 | Usage: testTrainSplit('the_session_cleaned', (0.8,0.1,0.1)) 66 | """ 67 | if sum(setRatio)!=1: 68 | print '[ERROR] datasetSplit(): %f+%f+%f does not equal 1...' \ 69 | %(setRatio[0],setRatio[1],setRatio[2]) 70 | exit(0) 71 | 72 | inputsFname = os.path.join(folderName, 'inputs') 73 | filelist = os.listdir(inputsFname) 74 | 75 | random.shuffle(filelist) 76 | 77 | train_test_split_indx = int(len(filelist)*setRatio[0]) 78 | test_dev_split_indx = int(len(filelist)*(setRatio[0]+setRatio[1])) 79 | trainFiles = filelist[:train_test_split_indx] 80 | testFiles = filelist[train_test_split_indx:test_dev_split_indx] 81 | devFiles = filelist[test_dev_split_indx:] 82 | 83 | trainFilename = os.path.join(folderName, 'train') 84 | testFilename = os.path.join(folderName, 'test') 85 | devFilename = os.path.join(folderName, 'dev') 86 | makedir(trainFilename) 87 | makedir(testFilename) 88 | makedir(devFilename) 89 | 90 | inputfileList = [testFiles, trainFiles, devFiles] 91 | dirNames = [testFilename, trainFilename, devFilename] 92 | 93 | for itr in range(len(inputfileList)): 94 | nextSkip = len(inputfileList[itr]) / 8.0 + 1 95 | 96 | list2Save = [] 97 | count = 0 98 | for i,fname in enumerate(inputfileList[itr]): 99 | fromfname = os.path.join(inputsFname, fname) 100 | with open(fromfname, 'rb') as f: 101 | list2Save.append(pickle.load(f)) 102 | 103 | if i>nextSkip: 104 | outDirname = os.path.join(dirNames[itr], '%d.p' % count) 105 | with open(outDirname, 'wb') as f: 106 | pickle.dump(list2Save, f) 107 | 108 | list2Save = [] 109 | count += 1 110 | nextSkip += len(inputfileList[itr]) / 8.0 111 | 112 | outDirname = os.path.join(dirNames[itr], '%d.p' % 7) 113 | with open(outDirname, 'wb') as f: 114 | pickle.dump(list2Save, f) 115 | 116 | 117 | if __name__ == "__main__": 118 | # preprocessing pipeline 119 | #----------------------------------- 120 | filename = 'war_peace_cleaned.txt' 121 | processedDir = '/data/full_dataset/baseline' 122 | 123 | # print '-'*20 + 'GENERATING VOCAB' + '-'*20 124 | # generateVocab(processedDir, filename) 125 | 126 | # print '-'*20 + 'ENCODING' + '-'*20 127 | # encode(processedDir, filename) 128 | 129 | # print '-'*20 + 'FORMING NNINPUTS' + '-'*20 130 | # datasetNNInput(processedDir, 25) 131 | 132 | print '-'*20 + 'SPLITTING' + '-'*20 133 | datasetSplit(processedDir, (0.8,0.1,0.1)) 134 | #----------------------------------- -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CS224N Final Project 2 | 3 | ## Purpose 4 | 5 | Music artists have composed pieces that are both creative and precise. For example, classical music is well-known for its meticulous structure and emotional effect. Recurrent Neural Networks (RNNs) are powerful models that have achieved excellent performance on difficult learning tasks having temporal dependencies. We propose generative RNN models that create sheet music with well-formed structure and stylistic conventions without predefining music composition rules to the models. 6 | 7 | ## Related Deliverables 8 | 9 | [`Paper`](final_paper.pdf "Paper") 10 | 11 | [`Youtube Survey Videos`](https://www.youtube.com/watch?v=g8DTUFajung&list=PLSynD-DZWHaXvow0cawxhi7InYGQAfDzZ) 12 | 13 | [`Song Samples`](https://yinoue93.github.io/CS224N.html) 14 | 15 | 16 | ## Files 17 | * `midi_crawler.py` - crawls the Internet for .mid files. 18 | * Flags: **-u**: url, **-f**: output folder name, **-d**: crawl depth, **-r**: crawl regEx rules 19 | * `utils_preprocess.py` - utility script for midi preprocessing. 20 | 21 | ## Useful Websites 22 | 23 | 24 | 25 | 26 | ## Example .abc format 27 | 28 | ![exABC](images/example_abc.png?raw=true "Example .abc Music") 29 | 30 | ``` 31 | X:1 32 | T:NeilyCleere's 33 | R:polka 34 | M:2/4 35 | L:1/8 36 | K:Dmaj 37 | Q:100 38 | FG|A2A>B|=c/B/AFG|AB=c/B/A|G2FG|A>^GA>B|=c/B/Af2|edAF|G2:||:fg|a>gfa|gefg|a>gfa|g2fg|a>gfa|gef2|edAF|G2:| 39 | ``` 40 | 41 | ## Data Encoding Structure 42 | The numpy array representing each sample is composed of two parts: the metadata and the song. 43 | 44 | The first **7** integers in the numpy array are the metadata. They are, in order: **song type (R)**, **time signature (M)**, **note unit size (L)**, **number of flats (K)**, **song mode (K)**, **length**, **complexity**. 45 | 46 | Length is calculated by counting the distinct number of times the character '|' appears in a file, and complexity is calculated by (*number of notes in a song*) x 100/(*len* x *number of beats in a measure*). In other words, the complexity measure is trying to estimate how *busy* a song is. 47 | 48 | | | Description | .abc Tag | Dimensions | Examples (Top 3) | 49 | |-----------------|------------------------------------------------------------------------------|----------|------------|---------------------| 50 | | Song Type | Song Genre | R | 16 | Reel, Jig, Hornpipe | 51 | | Time Signature | Specifies how many beats are in each bar and which note value gets one beat | M | 15 | 4/4, 6/8, 3/4 | 52 | | Note Unit Size | Specifies which note value gets one beat in the text file | L | 3 | 1/8, 1/4, 1/16 | 53 | | Number of Flats | Positive for songs with flats, 0 for neutral, negative for songs with sharps | K | 12 | -1, -2, -3 | 54 | | Song Mode | 0=Major, 1=Minor, 2=Mixolydian, 3=Dorian, 4=Phrygian, 5=Lydian, 6=Locrian | K | 6 | 0, 1, 3 | 55 | | Song Length | Number of measures in a song | | | | 56 | | Song Complexity | Busy-ness of a song. | | | | 57 | 58 | The song portion of the numpy array is **82** dimensions (i.e. **80** music characters and **2** BEGIN/END special characters). 59 | 60 | ## Metadata and Music Encoding Map 61 | ``` 62 | >>> pickle.load(open('vocab_map_meta.p')) 63 | {'R': {'jig': 0, 'waltz': 1, 'three-two': 2, 'songair': 3, 'slowair': 4, 'strathspey': 5, 64 | 'polka': 6, 'air': 7, 'barndance': 8, 'slide': 9, 'slipjig': 10, 'hornpipe': 11, 65 | 'mazurka': 12, 'reel': 13, 'highlandfling': 14, 'quickstep': 15}, 66 | 'M': {'7/8': 1, '11/8': 2, '5/4': 0, '6/8': 3, '5/8': 4, '4/4': 5, '6/4': 6, '13/8': 7, 67 | '3/2': 8, '3/4': 9, '9/8': 10, '12/8': 11, '2/2': 12, '9/4': 13, '2/4': 14}, 68 | 'L': {'1/4': 0, '1/16': 1, '1/8': 2}, 69 | 'K_key': {'-5': 0, '-4': 1, '1': 2, '0': 3, '3': 4, '-6': 5, '-1': 6, '4': 7, '-3': 8, 70 | '-2': 9, '2': 10, '5': 11}, 71 | 'K_mode': {'1': 0, '0': 1, '3': 2, '2': 3, '5': 4, '4': 5}} 72 | ``` 73 | 74 | ``` 75 | >>> pickle.load(open('vocab_map_music.p')) 76 | {'!': 0, ' ': 1, '#': 2, "'": 3, '&': 4, ')': 5, '(': 6, '+': 7, '*': 8, '-': 9, ',': 10, 77 | '/': 11, '.': 12, '1': 13, '0': 14, '3': 15, '2': 16, '5': 17, '4': 18, '7': 19, '6': 20, 78 | '9': 21, '8': 22, ':': 23, '=': 24, '<': 25, '>': 26, 'A': 27, 'C': 28, 'B': 29, 'E': 30, 79 | 'D': 31, 'G': 32, 'F': 33, 'H': 34, 'K': 35, 'J': 36, 'M': 37, 'L': 38, 'O': 39, 'Q': 40, 80 | 'P': 41, 'S': 42, 'R': 43, 'U': 44, 'T': 45, 'V': 46, '[': 47, ']': 48, '\\': 49, '_': 50, 81 | '^': 51, 'a': 52, 'c': 53, 'b': 54, 'e': 55, 'd': 56, 'g': 57, 'f': 58, 'i': 59, 'h': 60, 82 | 'j': 61, 'm': 62, 'l': 63, 'o': 64, 'n': 65, 'p': 66, 's': 67, 'r': 68, 'u': 69, 't': 70, 83 | 'w': 71, 'v': 72, 'y': 73, 'x': 74, '{': 75, 'z': 76, '}': 77, '|': 78, '~': 79} 84 | ``` 85 | -------------------------------------------------------------------------------- /utils_models.py: -------------------------------------------------------------------------------- 1 | from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl 2 | from tensorflow.python.framework import dtypes 3 | from tensorflow.python.framework import ops 4 | from tensorflow.python.ops import array_ops 5 | from tensorflow.python.ops import control_flow_ops 6 | from tensorflow.python.ops import math_ops 7 | from tensorflow.python.ops import random_ops 8 | from tensorflow.python.util import nest 9 | from tensorflow.python.ops import nn_ops 10 | 11 | 12 | def attention_decoder_fn_sampled_inference(output_fn, 13 | encoder_state, 14 | attention_keys, 15 | attention_values, 16 | attention_score_fn, 17 | attention_construct_fn, 18 | embeddings, 19 | start_of_sequence_id, 20 | end_of_sequence_id, 21 | maximum_length, 22 | num_decoder_symbols, 23 | dtype=dtypes.int32, 24 | temperature=None, 25 | name=None): 26 | 27 | with ops.name_scope(name, "attention_decoder_fn_inference", [ 28 | output_fn, encoder_state, attention_keys, attention_values, 29 | attention_score_fn, attention_construct_fn, embeddings, 30 | start_of_sequence_id, end_of_sequence_id, maximum_length, 31 | num_decoder_symbols, dtype 32 | ]): 33 | start_of_sequence_id = ops.convert_to_tensor(start_of_sequence_id, dtype) 34 | end_of_sequence_id = ops.convert_to_tensor(end_of_sequence_id, dtype) 35 | maximum_length = ops.convert_to_tensor(maximum_length, dtype) 36 | num_decoder_symbols = ops.convert_to_tensor(num_decoder_symbols, dtype) 37 | encoder_info = nest.flatten(encoder_state)[0] 38 | batch_size = encoder_info.get_shape()[0].value 39 | if output_fn is None: 40 | output_fn = lambda x: x 41 | if batch_size is None: 42 | batch_size = array_ops.shape(encoder_info)[0] 43 | 44 | def decoder_fn(time, cell_state, cell_input, cell_output, context_state): 45 | with ops.name_scope( 46 | name, "attention_decoder_fn_inference", 47 | [time, cell_state, cell_input, cell_output, context_state] 48 | ): 49 | if cell_input is not None: 50 | raise ValueError("Expected cell_input to be None, but saw: %s" % 51 | cell_input) 52 | if cell_output is None: 53 | # invariant that this is time == 0 54 | next_input_id = array_ops.ones( 55 | [batch_size,], dtype=dtype) * (start_of_sequence_id) 56 | done = array_ops.zeros([batch_size,], dtype=dtypes.bool) 57 | cell_state = encoder_state 58 | cell_output = array_ops.zeros( 59 | [num_decoder_symbols], dtype=dtypes.float32) 60 | cell_input = array_ops.gather(embeddings, next_input_id) 61 | 62 | # init attention 63 | attention = _init_attention(encoder_state) 64 | else: 65 | # construct attention 66 | attention = attention_construct_fn(cell_output, attention_keys, 67 | attention_values) 68 | cell_output = attention 69 | 70 | # sampled decoder 71 | cell_output = output_fn(cell_output) # logits 72 | if temperature: 73 | temperature_cell_output = math_ops.divide(cell_output, temperature) 74 | temperature_cell_output = nn_ops.softmax(temperature_cell_output) 75 | sampled_cell_output = random_ops.multinomial(cell_output, 1) 76 | sampled_cell_output = array_ops.reshape(sampled_cell_output, [-1]) 77 | else: 78 | sampled_cell_output = math_ops.argmax(cell_output, 1) 79 | next_input_id = math_ops.cast(sampled_cell_output, dtype=dtype) 80 | done = math_ops.equal(next_input_id, end_of_sequence_id) 81 | cell_input = array_ops.gather(embeddings, next_input_id) 82 | 83 | # combine cell_input and attention 84 | next_input = array_ops.concat([cell_input, attention], 1) 85 | 86 | # if time > maxlen, return all true vector 87 | done = control_flow_ops.cond( 88 | math_ops.greater(time, maximum_length), 89 | lambda: array_ops.ones([batch_size,], dtype=dtypes.bool), 90 | lambda: done) 91 | return (done, cell_state, next_input, cell_output, context_state) 92 | 93 | return decoder_fn 94 | 95 | 96 | 97 | 98 | def _init_attention(encoder_state): 99 | # Multi- vs single-layer 100 | if isinstance(encoder_state, tuple): 101 | top_state = encoder_state[-1] 102 | else: 103 | top_state = encoder_state 104 | 105 | # LSTM vs GRU 106 | if isinstance(top_state, core_rnn_cell_impl.LSTMStateTuple): 107 | attn = array_ops.zeros_like(top_state.h) 108 | else: 109 | attn = array_ops.zeros_like(top_state) 110 | 111 | return attn 112 | -------------------------------------------------------------------------------- /utils_hyperparam.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import pickle 3 | import os 4 | import datetime 5 | import re 6 | import ast 7 | import tensorflow as tf 8 | 9 | from argparse import ArgumentParser 10 | 11 | import numpy as np 12 | 13 | TMP_HYPER_PICKLE = 'tmp_hyperparam.p' 14 | OUTPUT_FILE = 'grid_search_result.txt' 15 | 16 | 17 | tf_ver = tf.__version__ 18 | SHERLOCK = (str(tf_ver) == '0.12.1') 19 | 20 | if SHERLOCK: 21 | DIR_MODIFIER = '/scratch/users/nipuna1' 22 | else: 23 | DIR_MODIFIER = '/data' 24 | 25 | DEV_CKPT_DIR = DIR_MODIFIER + '/dev_ckpt' 26 | 27 | def parseHyperTxt(paramTxtF): 28 | nameList = [] 29 | paramList = [] 30 | 31 | with open(paramTxtF, 'r') as paramF: 32 | count = 0 33 | for line in paramF: 34 | if count==0: 35 | modelType = line.replace('\n','').replace('\r','') 36 | count += 1 37 | continue 38 | elif count==1: 39 | num_epochs = int(line.replace('\n','').replace('\r','')) 40 | count += 1 41 | continue 42 | 43 | if '[' in line: 44 | name = [s.strip() for s in line.split(',')][0] 45 | listStr = line[line.find('['):(line.rfind(']')+1)] 46 | params = ast.literal_eval(listStr) 47 | else: 48 | name,start,end,step = [s.strip() for s in line.split(',')] 49 | 50 | startNum = float(start) 51 | endNum = float(end) 52 | stepNum = float(step) 53 | 54 | if startNum == endNum: 55 | params = [startNum] 56 | else: 57 | params = list(np.arange(startNum,endNum,stepNum)) 58 | params = [round(a,5) for a in params] 59 | # python list is not inclusive 60 | if params[-1]!=endNum: 61 | params.append(endNum) 62 | 63 | params = [(int(par) if par.is_integer() else par) for par in params] 64 | 65 | nameList.append(name) 66 | paramList.append(params) 67 | 68 | return modelType,num_epochs,nameList,paramList 69 | 70 | def runHyperparam(paramTxtF, dataset): 71 | """ 72 | Runs the gridsearch for hyperparameter tuning search. 73 | The grids are as defined in @paramTxtF 74 | """ 75 | 76 | if os.path.exists(OUTPUT_FILE): 77 | os.remove(OUTPUT_FILE) 78 | 79 | # parse the paramTxtF 80 | modelType, num_epochs, nameList, paramList = parseHyperTxt(paramTxtF) 81 | 82 | # create all combinations of params 83 | param_all_combos = list(itertools.product(*paramList)) 84 | 85 | print '[INFO] There are %d combinations of hyperparameters...' %len(param_all_combos) 86 | 87 | dataset_train = dataset 88 | dataset_dev = dataset.replace('train','dev') 89 | 90 | count = 0 91 | for param in param_all_combos: 92 | # create the param list and pickle it to TMP_HYPER_PICKLE 93 | paramDict = {} 94 | paramStrList = [] 95 | for name,par in zip(nameList,param): 96 | paramDict[name] = par 97 | 98 | paramStrList.append('{}: {}'.format(name, par)) 99 | 100 | paramStr = ','.join(paramStrList) + '\n' 101 | 102 | pickle.dump(paramDict, open(TMP_HYPER_PICKLE, 'wb')) 103 | 104 | with open(OUTPUT_FILE, 'a') as f: 105 | f.write(paramStr) 106 | 107 | # run the model 108 | print '='*80 109 | print 'Testing model with param: %s' % str(paramDict) 110 | print '='*80 111 | 112 | # train the model using the new hyperparameters 113 | print '-'*30 + 'TRAINING' + '-'*30 114 | if dataset=='': 115 | cmd = 'python run.py -p train -o -ckpt %s -m %s -e %d -c %s' \ 116 | %(DEV_CKPT_DIR, modelType, num_epochs, TMP_HYPER_PICKLE) 117 | else: 118 | cmd = 'python run.py -p train -o -ckpt %s -m %s -e %d -c %s -data %s' \ 119 | %(DEV_CKPT_DIR, modelType, num_epochs, TMP_HYPER_PICKLE, dataset_train) 120 | os.system(cmd) 121 | 122 | # test the model on the dev set 123 | print '-'*30 + 'TESTING DEV' + '-'*30 124 | if dataset=='': 125 | cmd = 'python run.py -p dev -ckpt %s -m %s -c %s' \ 126 | %(DEV_CKPT_DIR, modelType, TMP_HYPER_PICKLE) 127 | else: 128 | cmd = 'python run.py -p dev -ckpt %s -m %s -c %s -data %s' \ 129 | %(DEV_CKPT_DIR, modelType, TMP_HYPER_PICKLE, dataset_dev) 130 | os.system(cmd) 131 | 132 | os.remove(TMP_HYPER_PICKLE) 133 | 134 | # rename the result file with a timestamp 135 | now = datetime.datetime.now() 136 | resultName = '%s_%s.txt' %(OUTPUT_FILE.replace('.txt',''), 137 | now.strftime("%B_%d_%H_%M_%S")) 138 | os.rename(OUTPUT_FILE, resultName) 139 | 140 | 141 | def setHyperparam(config, hyperparam_path): 142 | """ 143 | To be called by a Config class. 144 | Reads TMP_HYPER_PICKLE, and sets the parameters of the @model accordingly. 145 | """ 146 | 147 | paramDict = pickle.load(open(hyperparam_path, 'rb')) 148 | 149 | for key,val in paramDict.iteritems(): 150 | setattr(config, key, val) 151 | 152 | setattr(config, 'dev_filename', OUTPUT_FILE) 153 | 154 | 155 | def resultParser(resultFname, top_N=3): 156 | """ 157 | Usage: python utils_hyperparam.py -m results -f grid_search_result_tmp.txt 158 | 159 | """ 160 | 161 | top_N_acc = np.array([0]*top_N, dtype=np.float32) 162 | top_N_str = ['']*top_N 163 | prev_str = '' 164 | config_map = None 165 | accuracy_list = [] 166 | with open(resultFname, 'r') as f: 167 | for line in f: 168 | line = line.replace('\n', '') 169 | if 'Dev set accuracy' in line: 170 | accuracy = float(re.search('0.[0-9]+', line).group(0)) 171 | accuracy_list.append(accuracy) 172 | 173 | if accuracy>top_N_acc[0]: 174 | top_N_acc[0] = accuracy 175 | top_N_str[0] = prev_str 176 | 177 | indx = np.argsort(top_N_acc) 178 | top_N_acc = top_N_acc[indx] 179 | top_N_str = [top_N_str[ind] for ind in list(indx)] 180 | else: 181 | prev_str = line 182 | 183 | tokenList = re.findall('[a-zA-Z_]+: [0-9]+\.*[0-9]*', line) 184 | configList = [re.match('[a-zA-Z_]+', token).group(0) for token in tokenList] 185 | valList = [float(re.findall('[0-9]+\.*[0-9]*', token)[0]) 186 | for token in tokenList] 187 | 188 | tokenList = re.findall('[a-zA-Z_]+: [a-zA-Z]+', line) 189 | configList += [token.split(' ')[0] for token in tokenList] 190 | valList += [token.split(' ')[1] for token in tokenList] 191 | 192 | if config_map==None: 193 | config_map = [] 194 | for configStr in configList: 195 | config_map.append([]) 196 | 197 | for i,val in enumerate(valList): 198 | config_map[i].append(val) 199 | 200 | 201 | print '-'*25 + 'Config. with top accuracy:' + '-'*25 202 | print np.flipud(top_N_acc) 203 | print '\n'.join(list(reversed(top_N_str))) 204 | print '-'*50 205 | 206 | print '='*25 + '"Flattened" Accuracies:' + '='*25 207 | accuracy_list = np.asarray(accuracy_list) 208 | for config_idx,configStr in enumerate(configList): 209 | print '*'*50 210 | print configStr 211 | 212 | config_list = np.asarray(config_map[config_idx]) 213 | for level in sorted(list(set(config_list))): 214 | indices = (config_list==level) 215 | print str(level)+':%0.5f' %(np.median(accuracy_list[indices])) 216 | 217 | print '='*50 218 | 219 | 220 | if __name__ == "__main__": 221 | desc = u'{0} [Args] [Options]\nDetailed options -h or --help'.format(__file__) 222 | parser = ArgumentParser(description=desc) 223 | 224 | parser.add_argument('-m', choices = ['results','tune'], type = str, 225 | dest = 'mode', required = True, help = 'Specify which mode to run') 226 | parser.add_argument('-f', type = str, default='', 227 | dest = 'filename', help = 'Filename to read results from') 228 | parser.add_argument('-data', type = str, default='', 229 | dest = 'dataset', help = 'Dataset to run run.py') 230 | parser.add_argument('-n', type = int, default=3, 231 | dest = 'top_N', help = 'Top N accuracies') 232 | parser.add_argument('-ckpt', type = str, default='', dest = 'ckpt_dir', 233 | help = 'Checkpoint to run the train/test set accuracy test') 234 | 235 | args = parser.parse_args() 236 | 237 | if args.mode == 'tune': 238 | hyperparamTxt = 'hparams_seq2seq.txt' 239 | runHyperparam(hyperparamTxt, args.dataset) 240 | elif args.mode == 'results': 241 | resultParser(args.filename, args.top_N) -------------------------------------------------------------------------------- /utils_runtime.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import random 4 | import numpy as np 5 | from argparse import ArgumentParser 6 | import reader 7 | 8 | import tensorflow as tf 9 | from utils_preprocess import * 10 | 11 | tf_ver = tf.__version__ 12 | SHERLOCK = (str(tf_ver) == '0.12.1') 13 | 14 | # for Sherlock 15 | if SHERLOCK: 16 | DIR_MODIFIER = '/scratch/users/nipuna1' 17 | # for Azure 18 | else: 19 | DIR_MODIFIER = '/data' 20 | 21 | def genWarmStartDataset(data_len, meta_map, music_map, 22 | dataFolder=os.path.join(DIR_MODIFIER, 'full_dataset/warmup_dataset/checked')): 23 | """ 24 | Generates metadata and music data for the use in warm starting the RNN models 25 | 26 | A file gets sampled from @dataFolder under ./checked file, and gets encoded using 27 | the 'vocab_map_meta.p' and 'vocab_map_music.p' files under @vocab_dir. 28 | 29 | The first @data_len characters in the music data is returned. 30 | """ 31 | 32 | oneHotHeaders = ('R', 'M', 'L', 'K_key', 'K_mode') 33 | otherHeaders = ('len', 'complexity') 34 | 35 | # while loop here, just in case that the file we choose contains characters that 36 | # does not appear in the original dataset 37 | while True: 38 | # pick a random file in dataFolder 39 | if os.path.isfile(dataFolder): 40 | abc_file = dataFolder 41 | else: 42 | abc_list = os.listdir(dataFolder) 43 | abc_file = os.path.join(dataFolder, random.choice(abc_list)) 44 | 45 | meta,music = loadCleanABC(abc_file) 46 | if data_len==-1: 47 | warm_str = music 48 | else: 49 | warm_str = music[:data_len-1] 50 | 51 | # start encoding 52 | meta_enList = [] 53 | music_enList = [] 54 | encodeSuccess = True 55 | 56 | # encode the metadata info 57 | for header in oneHotHeaders: 58 | if meta[header] not in meta_map[header]: 59 | encodeSuccess = False 60 | break 61 | else: 62 | meta_enList.append(meta_map[header][meta[header]]) 63 | 64 | for header in otherHeaders: 65 | meta_enList.append(meta[header]) 66 | 67 | # encode music data 68 | # add the BEGIN token 69 | music_enList.append(music_map['']) 70 | for i in range(len(warm_str)): 71 | c = music[i] 72 | if c not in music_map: 73 | encodeSuccess = False 74 | break 75 | else: 76 | music_enList.append(music_map[c]) 77 | 78 | if encodeSuccess: 79 | break 80 | 81 | print '-'*50 82 | print 'Generating the warm-start sequence...' 83 | print 'Chose %s to warm-start...' % abc_file 84 | print 'Meta Data is: %s' % str(meta) 85 | print 'The associated encoding is: %s' % str(meta_enList) 86 | print 'Music to warm-start with is: %s' % warm_str 87 | print 'The associated encoding is: %s' % str(music_enList) 88 | print '-'*50 89 | 90 | return meta_enList,music_enList 91 | 92 | 93 | def sample_with_temperature(logits, temperature): 94 | flattened_logits = logits.flatten() 95 | unnormalized = np.exp((flattened_logits - np.max(flattened_logits)) / temperature) 96 | probabilities = unnormalized / float(np.sum(unnormalized)) 97 | sample = np.random.choice(len(probabilities), p=probabilities) 98 | return sample 99 | 100 | 101 | def get_checkpoint(args, session, saver): 102 | # Checkpoint 103 | found_ckpt = False 104 | 105 | if args.override: 106 | if tf.gfile.Exists(args.ckpt_dir): 107 | tf.gfile.DeleteRecursively(args.ckpt_dir) 108 | tf.gfile.MakeDirs(args.ckpt_dir) 109 | 110 | # check if arags.ckpt_dir is a directory of checkpoints, or the checkpoint itself 111 | if len(re.findall('model.ckpt-[0-9]+', args.ckpt_dir)) == 0: 112 | ckpt = tf.train.get_checkpoint_state(args.ckpt_dir) 113 | if ckpt and ckpt.model_checkpoint_path: 114 | saver.restore(session, ckpt.model_checkpoint_path) 115 | i_stopped = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]) 116 | print "Found checkpoint for epoch ({0})".format(i_stopped) 117 | found_ckpt = True 118 | else: 119 | print('No checkpoint file found!') 120 | i_stopped = 0 121 | else: 122 | saver.restore(session, args.ckpt_dir) 123 | i_stopped = int(args.ckpt_dir.split('/')[-1].split('-')[-1]) 124 | print "Found checkpoint for epoch ({0})".format(i_stopped) 125 | found_ckpt = True 126 | 127 | 128 | return i_stopped, found_ckpt 129 | 130 | 131 | def save_checkpoint(args, session, saver, i): 132 | checkpoint_path = os.path.join(args.ckpt_dir, 'model.ckpt') 133 | saver.save(session, checkpoint_path, global_step=i) 134 | # saver.save(session, os.path.join(SUMMARY_DIR,'model.ckpt'), global_step=i) 135 | 136 | 137 | def encode_meta_batch(meta_vocabulary, meta_batch): 138 | new_meta_batch = [] 139 | vocab_lengths = [0] + [len(small_vocab) for small_vocab in meta_vocabulary.values()] 140 | num_values = len(meta_vocabulary.values()) 141 | 142 | for meta_data in meta_batch: 143 | new_meta_data = [meta_data[i] + sum(vocab_lengths[:i+1]) for i in xrange(num_values)] 144 | new_meta_data = np.append(new_meta_data, meta_data[5:]) 145 | new_meta_batch.append(new_meta_data) 146 | return new_meta_batch 147 | 148 | 149 | def create_noise_meta(meta_vocabulary): 150 | vocab_lengths = [len(small_vocab) for small_vocab in meta_vocabulary.values()] 151 | noise_meta_batch = [np.random.randint(upper_bound) for upper_bound in vocab_lengths] 152 | noise_meta_batch += [np.random.randint(10, high=41), np.random.randint(50, high=400)] 153 | return np.array(noise_meta_batch) 154 | 155 | 156 | def pack_feed_values(args, input_batch, label_batch, meta_batch, 157 | initial_state_batch, use_meta_batch, num_encode, num_decode): 158 | # if (args.train != "sample"): 159 | # for i, input_b in enumerate(input_batch): 160 | # if input_b.shape[0] != 50: 161 | # print "Input batch {0} contains and examples of size {1}".format(i, input_b.shape[0]) 162 | # input_batch[i] = np.zeros(50) 163 | 164 | # for j, label_b in enumerate(label_batch): 165 | # if label_b.shape[0] != 50: 166 | # print "Output batch {0} contains and examples of size {1}".format(j, label_b.shape[0]) 167 | # label_batch[j] = np.zeros(50) 168 | packed = [] 169 | 170 | input_batch = np.stack(input_batch) 171 | label_batch = np.stack(label_batch) 172 | 173 | packed = [] 174 | if args.model == 'seq2seq': 175 | packed += [input_batch.T, label_batch.T, meta_batch, initial_state_batch, use_meta_batch, num_encode, num_decode] 176 | # + attention? 177 | elif args.model == 'char': 178 | packed += [input_batch, label_batch, meta_batch, initial_state_batch, use_meta_batch] 179 | elif args.model == 'cbow': 180 | new_label_batch = [d[-1] for d in label_batch] 181 | packed += [input_batch, new_label_batch] 182 | elif args.model == 'gan': 183 | packed += [input_batch, label_batch, meta_batch, initial_state_batch, use_meta_batch] 184 | # MORE? 185 | return packed 186 | 187 | 188 | def parseCommandLine(): 189 | desc = u'{0} [Args] [Options]\nDetailed options -h or --help'.format(__file__) 190 | parser = ArgumentParser(description=desc) 191 | 192 | print("Parsing Command Line Arguments...") 193 | requiredModel = parser.add_argument_group('Required Model arguments') 194 | requiredModel.add_argument('-m', choices = ["seq2seq", "char", "cbow"], type = str, 195 | dest = 'model', required = True, help = 'Type of model to run') 196 | requiredTrain = parser.add_argument_group('Required Train/Test arguments') 197 | requiredTrain.add_argument('-p', choices = ["train", "test", "sample", "dev"], type = str, 198 | dest = 'train', required = True, help = 'Training or Testing phase to be run') 199 | 200 | requiredTrain.add_argument('-c', type = str, dest = 'set_config', 201 | help = 'Set hyperparameters', default='') 202 | 203 | parser.add_argument('-o', dest='override', action="store_true", help='Override the checkpoints') 204 | parser.add_argument('-e', dest='num_epochs', default=50, type=int, help='Set the number of Epochs') 205 | parser.add_argument('-ckpt', dest='ckpt_dir', default=DIR_MODIFIER + '/temp_ckpt/', type=str, help='Set the checkpoint directory') 206 | parser.add_argument('-data', dest='data_dir', default='', type=str, help='Set the data directory') 207 | 208 | args = parser.parse_args() 209 | return args 210 | 211 | 212 | 213 | 214 | 215 | if __name__ == "__main__": 216 | genWarmStartDataset(20) 217 | -------------------------------------------------------------------------------- /midi_manipulator.py: -------------------------------------------------------------------------------- 1 | import pretty_midi 2 | import os 3 | from shutil import copy 4 | import numpy as np 5 | 6 | import matplotlib.pyplot as plt 7 | 8 | from multiprocessing import Pool, Lock, Array 9 | 10 | import h5py 11 | def write2hdf5(filename, dict2store, compression="lzf"): 12 | """ 13 | Write items in a dictionary to an hdf5file 14 | @type filename : String 15 | @param filename : Filename of the hdf5 file to output to. 16 | @type dict2store : Dict 17 | @param dict2store : Dictionary of items to store. The value should be an array. 18 | """ 19 | with h5py.File(filename,'w') as hf: 20 | for key,value in dict2store.iteritems(): 21 | hf.create_dataset(key, data=value,compression=compression) 22 | 23 | def eraseUnreadable(folderN): 24 | iterf = 0 25 | for fileN in os.listdir(folderN): 26 | fileAbsN = folderN+'/'+fileN 27 | try: 28 | if iterf%100 == 0: 29 | print iterf 30 | iterf += 1 31 | midi_data = pretty_midi.PrettyMIDI(fileAbsN) 32 | except: 33 | print 'failed: '+fileN 34 | os.remove(fileAbsN) 35 | 36 | def plotMIDI(file_name): 37 | midi_data = pretty_midi.PrettyMIDI(file_name) 38 | roll = midi_data.get_piano_roll() 39 | 40 | plt.matshow(roll[:,:2000], aspect='auto', origin='lower', cmap='magma') 41 | plt.show() 42 | 43 | def initChecker(l,clist): 44 | global LOCK,CHECKLIST 45 | LOCK = l 46 | CHECKLIST = clist 47 | 48 | def executeParallel(func, mapList): 49 | """ 50 | multiprocess a function 51 | @type func : function handle 52 | @param func : function to multiprocess 53 | @type mapList : list 54 | @param mapList : arguments to be given to the function 55 | Usage: 56 | executeParallel(reduceNpySz,os.listdir('/data/augmented_roi_original')) 57 | """ 58 | 59 | p = Pool(8) 60 | p.map(func, mapList) 61 | 62 | def checkTimeSignature(midi_data): 63 | timeSigs = midi_data.time_signature_changes 64 | 65 | return (len(timeSigs)==1) and (timeSigs[0].numerator==4 and timeSigs[0].denominator==4) 66 | 67 | def checkerWorker(dataPack): 68 | filename,outputFolder = dataPack 69 | midi_data = pretty_midi.PrettyMIDI(filename) 70 | timeSigPass = checkTimeSignature(midi_data) 71 | 72 | checkPass = timeSigPass 73 | 74 | LOCK.acquire() 75 | if checkPass: 76 | CHECKLIST[0] = CHECKLIST[0]+1 77 | else: 78 | CHECKLIST[1] = CHECKLIST[1]+1 79 | if sum(CHECKLIST)%250==0: 80 | print CHECKLIST[:] 81 | LOCK.release() 82 | 83 | if outputFolder!=None: 84 | copy(filename,outputFolder) 85 | 86 | def checker(folderName, outputFolder=None): 87 | if outputFolder != None: 88 | if not os.path.exists(outputFolder): 89 | os.makedirs(outputFolder) 90 | 91 | checkList = Array('i', [0,0]) 92 | LOCK = Lock() 93 | mapList = [(os.path.join(folderName,fname),outputFolder) for fname in os.listdir(folderName)] 94 | p = Pool(8, initializer=initChecker, initargs=(LOCK,checkList)) 95 | p.map(checkerWorker, mapList) 96 | 97 | def convert2pianoRollWorker(dataPack): 98 | h5name,filenames = dataPack 99 | h5Dict = {} 100 | for fname in filenames: 101 | midi_data = pretty_midi.PrettyMIDI(fname) 102 | outname = os.path.basename(fname) 103 | outname = outname[:outname.rfind('.mid')]+'.npy' 104 | h5Dict[outname] = midi_data.get_piano_roll() 105 | 106 | write2hdf5(h5name, h5Dict) 107 | 108 | def convert2pianoRoll(folderName): 109 | """ 110 | Converts .mid files to pianoroll .npys, and save them in .hdf5 format. 111 | """ 112 | outputFolder = folderName + '_pianoroll' 113 | if not os.path.exists(outputFolder): 114 | os.makedirs(outputFolder) 115 | 116 | p = Pool(8) 117 | filenames = [os.path.join(folderName,fname) for fname in os.listdir(folderName)][0:13] 118 | batchNum = 2000 119 | indx = 0 120 | mapList = [] 121 | while indx127: 200 | pianoroll = pianoroll/np.max(pianoroll)*127 201 | if not use_velocity: 202 | pianoroll[np.nonzero(pianoroll)] = 100 203 | 204 | num_notes,midi_duration = pianoroll.shape 205 | duration_history = [0]*num_notes 206 | velocity_history = [0]*num_notes 207 | deltaT = 1.0/fs 208 | for timeIndx in range(midi_duration): 209 | timeshot = pianoroll[:,timeIndx] 210 | 211 | for noteIndx in range(num_notes): 212 | if (velocity_history[noteIndx]!=timeshot[noteIndx]) \ 213 | or (timeIndx==midi_duration-1 and timeshot[noteIndx]>0): 214 | if velocity_history[noteIndx]>0: 215 | start_t = duration_history[noteIndx]*deltaT 216 | end_t = timeIndx*deltaT 217 | note = pretty_midi.Note(velocity=int(velocity_history[noteIndx]), 218 | pitch=noteIndx, 219 | start=start_t, end=end_t) 220 | instrument.notes.append(note) 221 | 222 | velocity_history[noteIndx] = timeshot[noteIndx] 223 | duration_history[noteIndx] = timeIndx 224 | 225 | pm.instruments.append(instrument) 226 | pm.write(outfilename) 227 | 228 | def convertMidiAbcWorker(mapList): 229 | # abc2midi.exe taken from http://ifdo.ca/~seymour/runabc/top.html 230 | filename,fromDir,toDir,abc2midi = mapList 231 | fromStr,toStr,binName = ('.abc','.mid','abc2midi') \ 232 | if abc2midi else ('.mid','.abc','midi2abc') 233 | if fromStr not in filename: 234 | return 235 | 236 | fromFile = os.path.join(fromDir,filename) 237 | toFile = os.path.join(toDir,filename.replace(fromStr,'')+toStr) 238 | 239 | os.system('abcmidi_win32\\%s.exe "%s" -o "%s" -silent' %(binName,fromFile,toFile)) 240 | 241 | def convertMidiAbc(fileDir, abc2midi): 242 | outputFolder = fileDir + ('_midi' if abc2midi else '_abc') 243 | if not os.path.exists(outputFolder): 244 | os.makedirs(outputFolder) 245 | 246 | p = Pool(8) 247 | mapList = [(fname,fileDir,outputFolder,abc2midi) for fname in os.listdir(fileDir)] 248 | 249 | p.map(convertMidiAbcWorker, mapList) 250 | 251 | if __name__ == "__main__": 252 | convertMidiAbc('the_session', abc2midi=True) 253 | #eraseUnreadable('video_games') 254 | #plotMIDI('video_games/dw2.mid') 255 | 256 | # checker('video_games','video_games_cleaned') 257 | 258 | # midi_data = pretty_midi.PrettyMIDI('video_games/Zelda_2-_Battle_Stage.mid') 259 | # piano = midi_data.get_piano_roll() 260 | # pianoroll2midi(piano, 'sample.mid') 261 | pass -------------------------------------------------------------------------------- /midi_crawler.py: -------------------------------------------------------------------------------- 1 | import sys, os, shutil 2 | 3 | from argparse import ArgumentParser 4 | import re 5 | 6 | import urllib 7 | import urllib2 8 | from bs4 import BeautifulSoup 9 | from urlparse import urljoin 10 | from posixpath import basename 11 | 12 | import threading 13 | 14 | max_thread = 2 15 | sema = threading.Semaphore(max_thread) 16 | lock = threading.Lock() 17 | 18 | def html_downloader(url, next_htmls, folderName): 19 | try: 20 | response = urllib2.urlopen(url) 21 | 22 | # check if the url is a valid html website 23 | if "text/html" in response.headers["content-type"]: 24 | htmlStr = response.read() 25 | next_htmls.append((url, htmlStr)) 26 | 27 | elif '.mid' in url: 28 | file = open(folderName+'\\'+basename(url), 'wb') 29 | file.write(response.read()) 30 | file.close() 31 | sys.stdout.write('.') 32 | except: 33 | pass 34 | 35 | sema.release() 36 | exit() 37 | 38 | def titleFinder(url, visited_urls, htmlStr, folderName, urlRegex=None, depth=1): 39 | print '\n'+url 40 | soup = BeautifulSoup(htmlStr, "html.parser") 41 | 42 | # download the htmls 43 | thread_list = [] 44 | next_htmls = [] 45 | for link in soup.findAll('a'): 46 | if link.has_attr('href'): 47 | new_path = urljoin(url, link['href']) 48 | 49 | if (urlRegex!=None): 50 | regexPass= False 51 | for regEx in urlRegex: 52 | if (re.search(regEx, new_path)!=None): 53 | regexPass = True 54 | break 55 | if not regexPass: 56 | continue 57 | 58 | if (new_path not in visited_urls) and (depth!=0 or ('.mid' in new_path)): 59 | visited_urls.append(new_path) 60 | 61 | sema.acquire(True) 62 | th = threading.Thread(target=html_downloader, args=(new_path, next_htmls, folderName)) 63 | 64 | thread_list.append(th) 65 | th.start() 66 | 67 | # wait for the threads to finish 68 | for th in thread_list: 69 | th.join() 70 | 71 | if depth==0: 72 | return 73 | 74 | # keep on crawling 75 | for nextURL,nextResp in next_htmls: 76 | titleFinder(nextURL, visited_urls, nextResp, folderName, urlRegex, depth=depth-1) 77 | 78 | counter = 0 79 | def thesession_downloader(pageNum, folderName): 80 | try: 81 | base_url = 'https://thesession.org' 82 | url = base_url+'/tunes/'+str(pageNum) 83 | 84 | response = urllib2.urlopen(url) 85 | 86 | soup = BeautifulSoup(response.read(), "html.parser") 87 | 88 | songName = soup.find('h1').getText() 89 | count = 0 90 | for link in soup.findAll('a'): 91 | if link.has_attr('href'): 92 | if re.match('/tunes/%s/abc/[0-9]+' % pageNum, link['href']): 93 | if os.path.exists('%s\\%s_%d.abc' %(folderName,songName,count)): 94 | continue 95 | 96 | abcUrl = base_url+link['href'] 97 | 98 | file = open('%s\\%s_%d.abc' %(folderName,songName,count), 'wb') 99 | abcResponse = urllib2.urlopen(abcUrl) 100 | file.write(abcResponse.read()) 101 | file.close() 102 | sys.stdout.write('.') 103 | 104 | count += 1 105 | 106 | except: 107 | sys.stdout.write('X') 108 | pass 109 | 110 | global counter 111 | counter += 1 112 | print counter 113 | sema.release() 114 | exit() 115 | 116 | def scrapeTheSession(folderName): 117 | thread_list = [] 118 | for i in xrange(16000,17000): 119 | sema.acquire(True) 120 | th = threading.Thread(target=thesession_downloader, args=(i, folderName)) 121 | 122 | thread_list.append(th) 123 | th.start() 124 | 125 | # wait for the threads to finish 126 | for th in thread_list: 127 | th.join() 128 | 129 | def montreal_downloader(url, outputname): 130 | try: 131 | response = urllib2.urlopen(url) 132 | soup = BeautifulSoup(response.read(), "html.parser") 133 | 134 | abcTxt = soup.find('div', {'class':'abc'}) 135 | 136 | if abcTxt!=None: 137 | with open(outputname, 'wb') as abcf: 138 | abcf.write(re.sub(r'[^\x00-\x7f]',r'',abcTxt.getText())) 139 | 140 | sys.stdout.write('.') 141 | 142 | except: 143 | sys.stdout.write('X') 144 | pass 145 | 146 | sema.release() 147 | 148 | def scrapeTheMontreal(folderName): 149 | # first find the urls of all songs 150 | url = 'http://music.gordfisch.net/montrealsession/complete.php' 151 | response = urllib2.urlopen(url) 152 | soup = BeautifulSoup(response.read(), "html.parser") 153 | 154 | song_urls = [] 155 | song_names = [] 156 | for link in soup.findAll('a', href=True): 157 | song_name = re.sub(r'[^\x00-\x7f]',r'',link.getText()) 158 | if len(song_name)<3: 159 | continue 160 | 161 | song_name = song_name.replace('.','').replace(',','')+'.abc' 162 | count = 0 163 | while song_name.replace('.abc','_%d.abc'%count) in song_names: 164 | count += 1 165 | song_names.append(song_name.replace('.abc','_%d.abc'%count)) 166 | song_urls.append(urljoin(url, link['href'])) 167 | 168 | thread_list = [] 169 | for song_url,song_name in zip(song_urls,song_names): 170 | sema.acquire(True) 171 | th = threading.Thread(target=montreal_downloader, 172 | args=(song_url, os.path.join(folderName,song_name))) 173 | 174 | thread_list.append(th) 175 | th.start() 176 | 177 | # wait for the threads to finish 178 | for th in thread_list: 179 | th.join() 180 | 181 | def abcNotation_downloader(url, outputname): 182 | try: 183 | for i in range(10): 184 | try: 185 | response = urllib2.urlopen(url, timeout=5) 186 | break 187 | except: 188 | pass 189 | 190 | soup = BeautifulSoup(response.read(), "html.parser") 191 | 192 | abcTxt = soup.find('textarea') 193 | if abcTxt!=None: 194 | with open(outputname, 'wb') as abcf: 195 | abcf.write(re.sub(r'[^\x00-\x7f]',r'',abcTxt.getText())) 196 | 197 | except: 198 | print url 199 | pass 200 | 201 | sema.release() 202 | 203 | def abcNotation_urls(url, song_urls, song_names): 204 | for i in range(10): 205 | try: 206 | response = urllib2.urlopen(url, timeout=5) 207 | break 208 | except: 209 | pass 210 | 211 | soup = BeautifulSoup(response.read(), "html.parser") 212 | 213 | tmp_song_names = [] 214 | tmp_song_urls = [] 215 | link = soup.find('pre') 216 | for song_obj in link.findAll('a', href=True): 217 | song_name = re.sub(r'[^\x00-\x7f]',r'',song_obj.getText()) 218 | song_name = song_name.replace('.','').replace(',','').replace(' ','_')+'.abc' 219 | tmp_song_names.append(song_name) 220 | 221 | tmp_song_urls.append(urljoin(url, song_obj['href'])) 222 | 223 | lock.acquire() 224 | for song_name in tmp_song_names: 225 | count = 0 226 | while song_name.replace('.abc','_%d.abc'%count) in song_names: 227 | count += 1 228 | 229 | song_names.append(song_name.replace('.abc','_%d.abc'%count)) 230 | 231 | song_urls += tmp_song_urls 232 | 233 | lock.release() 234 | sema.release() 235 | 236 | 237 | def scrapeTheABCNotation(folderName): 238 | # first find the urls of all songs 239 | base_url = 'http://abcnotation.com/searchTunes?q=%s&f=t&o=a&s=' %'[v:1]' 240 | 241 | song_urls = [] 242 | song_names = [] 243 | thread_list = [] 244 | for i in xrange(20,25): 245 | print i 246 | sema.acquire(True) 247 | url = base_url+str(50*i) 248 | th = threading.Thread(target=abcNotation_urls, 249 | args=(url, song_urls, song_names)) 250 | 251 | thread_list.append(th) 252 | th.start() 253 | 254 | # wait for the threads to finish 255 | for th in thread_list: 256 | th.join() 257 | 258 | thread_list = [] 259 | for song_url,song_name in zip(song_urls,song_names): 260 | sema.acquire(True) 261 | th = threading.Thread(target=abcNotation_downloader, 262 | args=(song_url, os.path.join(folderName, song_name))) 263 | 264 | thread_list.append(th) 265 | th.start() 266 | 267 | # wait for the threads to finish 268 | for th in thread_list: 269 | th.join() 270 | 271 | 272 | def scrapeLocally(url, folderName): 273 | """ 274 | Scrapes .abc files stored locally under @url. 275 | @url should be a file with multiple .abc files listed. 276 | """ 277 | for fname in os.listdir(url): 278 | with open(os.path.join(url,fname),'r') as abcfile: 279 | count = 0 280 | during_file = False 281 | 282 | for line in abcfile: 283 | line = line.strip() 284 | if not during_file and line.replace(' ','')[:2]=='X:': 285 | during_file = True 286 | newFile = '%s_%d.abc' % (fname,count) 287 | count += 1 288 | f = open(os.path.join(folderName,newFile), 'w') 289 | 290 | if during_file: 291 | if len(line)==0: 292 | during_file = False 293 | f.close() 294 | else: 295 | f.write(line+'\n') 296 | 297 | def parseCLI(): 298 | desc = u'{0} [Args] [Options]\nDetailed options -h or --help'.format(__file__) 299 | 300 | parser = ArgumentParser(description=desc) 301 | 302 | parser.add_argument('-u', '--url', type = str, dest = 'url', required = True, 303 | help = 'URL of the website to crawl in') 304 | parser.add_argument('-f', '--folderName', type = str, dest = 'folderName', required = True, 305 | help = 'MIDI output folder') 306 | parser.add_argument('-d', '--depth', type = int, dest = 'depth', default=2, 307 | help = 'Crawl depth') 308 | parser.add_argument('-r', '--urlRegex', type = str, dest = 'urlRegex', 309 | help = 'RegEx urls need to follow') 310 | 311 | args = parser.parse_args() 312 | return args 313 | 314 | if __name__ == "__main__": 315 | args = parseCLI() 316 | 317 | # make the directory for the created files 318 | try: 319 | os.mkdir(args.folderName) 320 | except: 321 | pass 322 | 323 | if args.urlRegex!=None: 324 | args.urlRegex = args.urlRegex.split(',') 325 | 326 | # start crawling 327 | if 'thesession.org' in args.url: 328 | scrapeTheSession(args.folderName) 329 | elif 'montreal' in args.url: 330 | scrapeTheMontreal(args.folderName) 331 | elif 'abcnotation' in args.url: 332 | scrapeTheABCNotation(args.folderName) 333 | elif ('http' not in args.url) and ('www' not in args.url): 334 | scrapeLocally(args.url, args.folderName) 335 | else: 336 | visited_urls = [] 337 | response = urllib2.urlopen(args.url) 338 | html = response.read() 339 | titleFinder(args.url, visited_urls, html, args.folderName, args.urlRegex, depth=args.depth) 340 | 341 | print "done" -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pretty_midi 4 | import re 5 | import pickle 6 | import random 7 | 8 | 9 | import tensorflow as tf 10 | 11 | from collections import Counter 12 | 13 | CHECK_DIR = 'checked' 14 | 15 | tf_ver = tf.__version__ 16 | SHERLOCK = (str(tf_ver) == '0.12.1') 17 | 18 | # for Sherlock 19 | if SHERLOCK: 20 | DIR_MODIFIER = '/scratch/users/nipuna1' 21 | # for Azure 22 | else: 23 | DIR_MODIFIER = '/data' 24 | 25 | # Midi Related 26 | #------------------------------------ 27 | def plotMIDI(file_name): 28 | """ 29 | Usage: plotMIDI('video_games/dw2.mid') 30 | """ 31 | midi_data = pretty_midi.PrettyMIDI(file_name) 32 | roll = midi_data.get_piano_roll() 33 | 34 | plt.matshow(roll[:,:2000], aspect='auto', origin='lower', cmap='magma') 35 | plt.show() 36 | 37 | 38 | # File Manipulation Related 39 | #------------------------------------ 40 | def makedir(outputFolder): 41 | if not os.path.exists(outputFolder): 42 | os.makedirs(outputFolder) 43 | 44 | import h5py 45 | def write2hdf5(filename, dict2store, compression="lzf"): 46 | """ 47 | Write items in a dictionary to an hdf5file 48 | @type filename : String 49 | @param filename : Filename of the hdf5 file to output to. 50 | @type dict2store : Dict 51 | @param dict2store : Dictionary of items to store. The value should be an array. 52 | 53 | Usage: write2hdf5('encoded_data.h5',{'data':os.listdir('the_session_cleaned_checked_encoded')}) 54 | """ 55 | with h5py.File(filename,'w') as hf: 56 | for key,value in dict2store.iteritems(): 57 | hf.create_dataset(key, data=value,compression=compression) 58 | 59 | 60 | def hdf52dict(hdf5Filename): 61 | """ 62 | Loads an HDF5 file of a game and returns a dictionary of the contents 63 | @type hdf5Filename: String 64 | @param hdf5Filename: Filename of the hdf5 file. 65 | """ 66 | retDict = {} 67 | with h5py.File(hdf5Filename,'r') as hf: 68 | for key in hf.keys(): 69 | retDict[key] = np.array(hf.get(key)) 70 | 71 | return retDict 72 | 73 | def abc2h5(folderName='the_session_cleaned_checked_encoded', outputFile='encoded_data.h5'): 74 | encodeDict = {} 75 | for filestr in os.listdir(folderName): 76 | encodeDict[filestr] = np.load(os.path.join(folderName,filestr)) 77 | write2hdf5(outputFile,encodeDict) 78 | 79 | def find_basename(filename): 80 | firstNum = re.search('[0-9]', filename).group(0) 81 | if filename[0]==firstNum: 82 | firstNum = re.search('_', filename).group(0) 83 | return filename[:(filename.find(firstNum)-1)] 84 | 85 | def datasetSplit(folderName, setRatio): 86 | """ 87 | Split the dataset into training, testing, and dev sets. 88 | Usage: testTrainSplit('the_session_cleaned', (0.8,0.1,0.1)) 89 | """ 90 | if sum(setRatio)!=1: 91 | print '[ERROR] datasetSplit(): %f+%f+%f does not equal 1...' \ 92 | %(setRatio[0],setRatio[1],setRatio[2]) 93 | exit(0) 94 | 95 | songlist = set() 96 | for filename in os.listdir(os.path.join(folderName, CHECK_DIR)): 97 | basename = find_basename(filename) 98 | songlist.add(basename) 99 | 100 | songlist = list(songlist) 101 | random.shuffle(songlist) 102 | 103 | train_test_split_indx = int(len(songlist)*setRatio[0]) 104 | test_dev_split_indx = int(len(songlist)*(setRatio[0]+setRatio[1])) 105 | trainSongs = songlist[:train_test_split_indx] 106 | testSongs = songlist[train_test_split_indx:test_dev_split_indx] 107 | devSongs = songlist[test_dev_split_indx:] 108 | 109 | pickle.dump(trainSongs, open(os.path.join(folderName, 'train_songs.p'),'wb')) 110 | pickle.dump(testSongs, open(os.path.join(folderName, 'test_songs.p'),'wb')) 111 | pickle.dump(devSongs, open(os.path.join(folderName, 'dev_songs.p'),'wb')) 112 | 113 | #------------------------------------ 114 | 115 | # .abc Related 116 | #------------------------------------ 117 | def findNumMeasures(music): 118 | return music.replace('||','|').replace('|||','|').count('|') 119 | 120 | def transposeABC(fromFile, toFile, shiftLvl): 121 | """ 122 | Transposes the .abc file in @fromFile by @shiftLvl and saves it to @toFile 123 | 124 | abc2abc.exe taken from http://ifdo.ca/~seymour/runabc/top.html 125 | """ 126 | 127 | # am I being ran on a Windows machine ('nt'), or a linux machine('posix')? 128 | if os.name=='posix': 129 | exe_cmd = 'abc2abc' 130 | elif os.name=='nt': 131 | exe_cmd = 'abcmidi_win32\\abc2abc.exe' 132 | 133 | cmd = '%s "%s" -t %d -e > "%s"' \ 134 | %(exe_cmd,fromFile,shiftLvl,toFile) 135 | print toFile 136 | 137 | os.system(cmd) 138 | 139 | MODE_MAJ = 0 140 | MODE_MIN = 1 141 | MODE_MIX = 2 142 | MODE_DOR = 3 143 | MODE_PHR = 4 144 | MODE_LYD = 5 145 | MODE_LOC = 6 146 | def keySigDecomposer(line): 147 | """ 148 | Decompose the key signature into two portions- key and mode 149 | 150 | Returns: 151 | key - number of flats, negative for sharps 152 | mode - as defined by MODE_ constants 153 | """ 154 | 155 | # first determine the mode 156 | mode = MODE_MAJ 157 | 158 | searchList = [('mix',MODE_MIX),('dor',MODE_DOR),('phr',MODE_PHR),('lyd',MODE_LYD), 159 | ('loc',MODE_LOC),('maj',MODE_MAJ),('min',MODE_MIN),('m',MODE_MIN), 160 | ('p',MODE_PHR)] 161 | 162 | lower = line.lower() 163 | for searchTup in searchList: 164 | if searchTup[0] in lower: 165 | mode = searchTup[1] 166 | line = line[:lower.rfind(searchTup[0])] 167 | break 168 | 169 | # then determine the key 170 | keys = ['B#','E#','A#','D#','G#','C#','F#','B','E','A','D','G','C', 171 | 'F','Bb','Eb','Ab','Db','Gb','Cb','Fb'] 172 | mode_modifier = {MODE_MAJ:-12, MODE_MIN:-9, MODE_MIX:-11, MODE_DOR:-10, 173 | MODE_PHR:-8, MODE_LYD:-13, MODE_LOC:-7} 174 | 175 | key = keys.index(line) + mode_modifier[mode] 176 | 177 | return str(key),str(mode) 178 | 179 | def keySigComposer(num_flats, mode): 180 | """ 181 | Reverses keySigDecomposer 182 | """ 183 | num_flats = int(num_flats) 184 | mode = int(mode) 185 | 186 | keys = ['B#','E#','A#','D#','G#','C#','F#','B','E','A','D','G','C', 187 | 'F','Bb','Eb','Ab','Db','Gb','Cb','Fb'] 188 | mode_modifier = {MODE_MAJ:12, MODE_MIN:9, MODE_MIX:11, MODE_DOR:10, 189 | MODE_PHR:8, MODE_LYD:13, MODE_LOC:7} 190 | mode_name = {MODE_MAJ:'', MODE_MIN:'m', MODE_MIX:'mix', MODE_DOR:'dor', 191 | MODE_PHR:'phr', MODE_LYD:'lyd', MODE_LOC:'loc'} 192 | 193 | return 'K:%s%s\n' %(keys[mode_modifier[mode]+num_flats], mode_name[mode]) 194 | 195 | def loadCleanABC(abcname): 196 | """ 197 | Loads a file in .abc format (cleaned), and returns the meta data and music contained 198 | in the file. 199 | 200 | @meta - dictionary of metadata, key is the metadata type (ex. 'K') 201 | @music - string of the music 202 | """ 203 | headerTup = ('X', 'T', 'R', 'M', 'L', 'K', 'Q') 204 | 205 | meta = {} 206 | music = '' 207 | counter = len(headerTup) 208 | with open(abcname,'r') as abcfile: 209 | for line in abcfile: 210 | # break down the key signature into # of sharps and flats 211 | # and mode 212 | if counter>0: 213 | if line[0]=='K': 214 | try: 215 | meta['K_key'],meta['K_mode'] = keySigDecomposer(line[2:-1]) 216 | except: 217 | print 'Key signature decomposition failed for file: ' + abcname 218 | raise Exception('Key signature decomposition failed for file: ' + abcname) 219 | elif line[0]=='M': 220 | if 'C' in line[2:-1]: 221 | meta['M'] = '4/4' 222 | else: 223 | meta['M'] = line[2:-1] 224 | else: 225 | meta[line[0]] = line[2:-1] 226 | counter -= 1 227 | else: 228 | music = line[:-1] 229 | 230 | notes = [chr(i) for i in range(ord('a'),ord('g')+1)] 231 | notes += [c.upper() for c in notes] 232 | # add metadata that we manually create 233 | meta['len'] = findNumMeasures(music) 234 | countList = Counter(music) 235 | timeSigNumerator = int(meta['M'][:meta['M'].find('/')]) 236 | meta['complexity'] = (sum(countList[c] for c in notes)*100)/(meta['len']*timeSigNumerator) 237 | 238 | return meta,music 239 | 240 | def writeCleanABC(meta, music, outfile=''): 241 | headerTup = ('X', 'T', 'R', 'M', 'L', 'K', 'Q') 242 | 243 | abcStr = '' 244 | for header in headerTup: 245 | if header=='K': 246 | abcStr += keySigComposer(meta['K_key'], meta['K_mode']) 247 | else: 248 | abcStr += '%s:%s\n' %(header, meta[header]) 249 | 250 | abcStr += music + '\n' 251 | 252 | if len(outfile)!=0: 253 | with open(outfile,'w') as f: 254 | f.write(abcStr) 255 | 256 | return abcStr 257 | 258 | import subprocess 259 | def passesABC2ABC(fromFile): 260 | """ 261 | Returns true if the .abc file in @fromFile passes the abc2abc.exe check 262 | """ 263 | 264 | # am I being ran on a Windows machine ('nt'), or a linux machine('posix')? 265 | if os.name=='posix': 266 | cmd = 'abc2abc' 267 | elif os.name=='nt': 268 | cmd = 'abcmidi_win32\\abc2abc.exe' 269 | 270 | cmdlist = [cmd, fromFile] 271 | proc = subprocess.Popen(cmdlist, stdout=subprocess.PIPE) 272 | 273 | (out, err) = proc.communicate() 274 | 275 | # error check 276 | errorCnt_bar = out.count('Error : Bar') 277 | errorCnt = out.count('Error') - out.count('ignored') 278 | if errorCnt_bar>2 or errorCnt!=errorCnt_bar: 279 | return False 280 | elif errorCnt>0: 281 | barErrorList = re.findall('Bar [0-9]+', out) 282 | for i,barStr in enumerate(barErrorList): 283 | barErrorList[i] = int(re.search('[0-9]+',barStr).group(0)) 284 | 285 | if barErrorList[0] == 1: 286 | errorCnt -= 1 287 | 288 | if abs(findNumMeasures(out)-barErrorList[-1])<3: 289 | errorCnt -= 1 290 | 291 | return errorCnt==0 292 | 293 | def encoding2ABC(metaList, musicList, meta_map, music_map, outputname=None, 294 | vocab_dir=os.path.join(DIR_MODIFIER, 'the_session_processed')): 295 | """ 296 | Converts lists encoding of .abc song into .abc string 297 | 298 | @metaList - A list of encoded metadata 299 | @musicList - A list of encoded music 300 | """ 301 | 302 | oneHotHeaders = ('R', 'M', 'L') 303 | 304 | meta_reverse = {} 305 | for header in meta_map.keys(): 306 | meta_ori = meta_map[header] 307 | meta_reverse[header] = dict(zip(meta_ori.values(), meta_ori.keys())) 308 | 309 | music_reverse = dict(zip(music_map.values(), music_map.keys())) 310 | 311 | abcStr = 'X: 1\n' 312 | for i in range(len(oneHotHeaders)): 313 | header = oneHotHeaders[i] 314 | abcStr += '%s: %s\n' %(header, str(meta_reverse[header][metaList[i]])) 315 | 316 | num_flats = int(meta_reverse['K_key'][metaList[len(oneHotHeaders)]]) 317 | mode = int(meta_reverse['K_mode'][metaList[len(oneHotHeaders)+1]]) 318 | 319 | abcStr += keySigComposer(num_flats, mode) 320 | 321 | for music_ch in musicList: 322 | if music_ch inf, p is uniform. Easy to sample from! 48 | # For T --> 0, p "concentrates" on arg max. Hard to sample from! 49 | TEMPERATURE = 1.0 50 | 51 | 52 | 53 | 54 | def run_gan(args): 55 | use_seq2seq_data = (args.model == 'seq2seq') 56 | if args.data_dir != '': 57 | dataset_dir = args.data_dir 58 | elif args.train == 'train': 59 | dataset_dir = GAN_TRAIN_DATA 60 | elif args.train == 'test': 61 | dataset_dir = GAN_TEST_DATA 62 | else: # args.train == 'dev' or 'sample' (which has no dataset, but we just read anyway) 63 | dataset_dir = GAN_DEVELOPMENT_DATA 64 | 65 | 66 | print 'Using dataset %s' %dataset_dir 67 | dateset_filenames = reader.abc_filenames(dataset_dir) 68 | 69 | # figure out the input data size 70 | window_sz = int(re.findall('[0-9]+', re.findall('window_[0-9]+', dataset_dir)[0])[0]) 71 | if 'output_sz' in dataset_dir: 72 | label_sz = int(re.findall('[0-9]+', re.findall('output_sz_[0-9]+', dataset_dir)[0])[0]) 73 | else: 74 | label_sz = window_sz 75 | 76 | input_size = 1 if (args.train == "sample" and args.model!='cbow') else window_sz 77 | initial_size = 7 78 | label_size = 1 if args.train == "sample" else label_sz 79 | batch_size = 1 if args.train == "sample" else BATCH_SIZE 80 | NUM_EPOCHS = args.num_epochs 81 | print "Using checkpoint directory: {0}".format(args.ckpt_dir) 82 | 83 | # Getting vocabulary mapping: 84 | vocabulary = reader.read_abc_pickle(VOCAB_DATA) 85 | vocab_sz = len(vocabulary) 86 | vocabulary[""] = vocab_sz 87 | vocabulary[""] = vocab_sz+1 88 | if use_seq2seq_data: 89 | vocabulary[""] = vocab_sz+2 90 | 91 | # Vocabulary info 92 | vocabulary_size = len(vocabulary) 93 | vocabulary_decode = dict(zip(vocabulary.values(), vocabulary.keys())) 94 | meta_vocabulary = reader.read_abc_pickle(META_DATA) 95 | num_classes = len(meta_vocabulary['R']) + 1 96 | 97 | gan_label_size = 1 98 | 99 | start_encode = vocabulary[""] if (args.train == "sample" and use_seq2seq_data) else vocabulary[""] 100 | end_encode = vocabulary[""] 101 | # Getting meta mapping: 102 | meta_map = pickle.load(open(META_DATA, 'rb')) 103 | 104 | cell_type = 'lstm' 105 | # cell_type = 'gru' 106 | # cell_type = 'rnn' 107 | 108 | curModel = GenAdversarialNet(input_size, gan_label_size, num_classes, cell_type, 109 | args.train=='train', batch_size, vocabulary_size, 110 | args.set_config, use_lrelu=True, use_batchnorm=False, 111 | dropout=None) 112 | 113 | probabilities_real_op, probabilities_fake_op = curModel.create_model() 114 | input_placeholder, label_placeholder, \ 115 | rnn_meta_placeholder, rnn_initial_state_placeholder, \ 116 | rnn_use_meta_placeholder, train_op_d, train_op_gan, \ 117 | loss_op, class_loss, accuracy_op, gen_accuracy_op = curModel.train() 118 | 119 | print "Reading in {0}-set filenames.".format(args.train) 120 | print "Running {0} model for {1} epochs.".format(args.model, NUM_EPOCHS) 121 | 122 | global_step = tf.Variable(0, trainable=False, name='global_step') #tf.contrib.framework.get_or_create_global_step() 123 | saver = tf.train.Saver(max_to_keep=NUM_EPOCHS) 124 | step = 0 125 | 126 | with tf.Session(config=GPU_CONFIG) as session: 127 | print "Inititialized TF Session!" 128 | 129 | # Checkpoint 130 | i_stopped, found_ckpt = utils_runtime.get_checkpoint(args, session, saver) 131 | 132 | file_writer = tf.summary.FileWriter(SUMMARY_DIR, graph=session.graph, max_queue=10, flush_secs=30) 133 | confusion_matrix = np.zeros((vocabulary_size, vocabulary_size)) 134 | batch_accuracies = [] 135 | 136 | if args.train == "train": 137 | init_op = tf.global_variables_initializer() # tf.group(tf.initialize_all_variables(), tf.initialize_local_variables()) 138 | init_op.run() 139 | else: 140 | # Exit if no checkpoint to test 141 | if not found_ckpt: 142 | return 143 | NUM_EPOCHS = i_stopped + 1 144 | 145 | # Sample Model 146 | if args.train == "sample": 147 | pass 148 | # # Sample Model 149 | # warm_length = 20 150 | # warm_meta, warm_chars = utils_runtime.genWarmStartDataset(warm_length) 151 | # 152 | # warm_meta_array = [warm_meta[:] for idx in xrange(5)] 153 | # 154 | # # Change Key 155 | # warm_meta_array[1][4] = 1 - warm_meta_array[1][4] 156 | # # Change Number of Flats/Sharps 157 | # warm_meta_array[2][3] = np.random.choice(11) 158 | # # Lower Complexity 159 | # warm_meta_array[3][6] = 50 160 | # # Higher Complexity 161 | # warm_meta_array[4][6] = 350 162 | # 163 | # new_warm_meta = utils_runtime.encode_meta_batch(meta_vocabulary, warm_meta_array) 164 | # new_warm_meta_array = zip(warm_meta_array, new_warm_meta) 165 | # 166 | # print "Sampling from single RNN cell using warm start of ({0})".format(warm_length) 167 | # for old_meta, meta in new_warm_meta_array: 168 | # print "Current Metadata: {0}".format(meta) 169 | # generated = warm_chars[:] 170 | # 171 | # if args.model == 'char': 172 | # # Warm Start 173 | # for j, c in enumerate(warm_chars): 174 | # if cell_type == 'lstm': 175 | # if j == 0: 176 | # initial_state_sample = [[np.zeros(curModel.config.hidden_size) for entry in xrange(batch_size)] for layer in xrange(curModel.config.num_layers)] 177 | # else: 178 | # initial_state_sample = [] 179 | # for lstm_tuple in state: 180 | # initial_state_sample.append(lstm_tuple[0]) 181 | # else: 182 | # initial_state_sample = [np.zeros(curModel.config.hidden_size) for entry in xrange(batch_size)] if (j == 0) else state[0] 183 | # 184 | # feed_values = utils_runtime.pack_feed_values(args, [[c]], 185 | # [[0]], [meta], 186 | # initial_state_sample, (j == 0), 187 | # None, None) 188 | # logits, state = curModel.sample(session, feed_values) 189 | # 190 | # # Sample 191 | # sampled_character = utils_runtime.sample_with_temperature(logits, TEMPERATURE) 192 | # while sampled_character != vocabulary[""] and len(generated) < 100: 193 | # if cell_type == 'lstm': 194 | # initial_state_sample = [] 195 | # for lstm_tuple in state: 196 | # initial_state_sample.append(lstm_tuple[0]) 197 | # else: 198 | # initial_state_sample = state[0] 199 | # 200 | # feed_values = utils_runtime.pack_feed_values(args, [[sampled_character]], 201 | # [[0]], [np.zeros_like(meta)], 202 | # initial_state_sample, False, 203 | # None, None) 204 | # logits, state = curModel.sample(session, feed_values) 205 | # 206 | # sampled_character = utils_runtime.sample_with_temperature(logits, TEMPERATURE) 207 | # generated.append(sampled_character) 208 | # 209 | # elif args.model == 'seq2seq': 210 | # prediction = sample_Seq2Seq(args, curModel, cell_type, session, warm_chars, vocabulary, meta, batch_size) 211 | # generated.extend(prediction.flatten()) 212 | # 213 | # 214 | # decoded_characters = [vocabulary_decode[char] for char in generated] 215 | # 216 | # # Currently chopping off the last char regardless if its or not 217 | # encoding = utils.encoding2ABC(old_meta, generated) 218 | 219 | # Train, dev, test model 220 | else: 221 | for i in xrange(i_stopped, NUM_EPOCHS): 222 | print "Running epoch ({0})...".format(i) 223 | random.shuffle(dateset_filenames) 224 | for j, data_file in enumerate(dateset_filenames): 225 | # Get train data - into feed_dict 226 | data = reader.read_abc_pickle(data_file) 227 | random.shuffle(data) 228 | data_batches = reader.abc_batch(data, n=batch_size/2) 229 | for k, data_batch in enumerate(data_batches): 230 | meta_batch, input_window_batch, output_window_batch = tuple([list(tup) for tup in zip(*data_batch)]) 231 | new_meta_batch = utils_runtime.encode_meta_batch(meta_vocabulary, meta_batch) 232 | 233 | for it in range(0, 10): 234 | noise_meta_batch = [utils_runtime.create_noise_meta(meta_vocabulary) for l in xrange(batch_size/2)] 235 | new_noise_meta_batch = utils_runtime.encode_meta_batch(meta_vocabulary, noise_meta_batch) 236 | noise_input_window_batch = [np.random.randint(vocabulary_size, size=window_sz) for l in xrange(batch_size/2)] 237 | initial_state_batch = [[np.zeros(curModel.config.hidden_size) for entry in xrange(batch_size/2)] for layer in xrange(curModel.config.num_layers)] 238 | gan_labels = np.asarray([int(m[0]) for m in meta_batch]) 239 | 240 | feed_dict = { 241 | input_placeholder: noise_input_window_batch, 242 | rnn_meta_placeholder: new_noise_meta_batch, 243 | rnn_initial_state_placeholder: initial_state_batch, 244 | rnn_use_meta_placeholder: True 245 | } 246 | 247 | if args.train == "train": 248 | _, gen_accuracy = session.run([ train_op_gan, gen_accuracy_op], feed_dict=feed_dict) 249 | 250 | print "Only Generator training loss: {0}".format(gen_accuracy) 251 | 252 | 253 | noise_meta_batch = [utils_runtime.create_noise_meta(meta_vocabulary) for l in xrange(batch_size/2)] 254 | new_noise_meta_batch = utils_runtime.encode_meta_batch(meta_vocabulary, noise_meta_batch) 255 | 256 | noise_input_window_batch = [np.random.randint(vocabulary_size, size=window_sz) for l in xrange(batch_size/2)] 257 | noise_input_window_batch += input_window_batch 258 | 259 | initial_state_batch = [[np.zeros(curModel.config.hidden_size) for entry in xrange(batch_size/2)] for layer in xrange(curModel.config.num_layers)] 260 | gan_labels = np.asarray([int(m[0]) for m in meta_batch]) 261 | 262 | feed_dict = { 263 | input_placeholder: noise_input_window_batch, 264 | label_placeholder: gan_labels, 265 | rnn_meta_placeholder: new_noise_meta_batch, 266 | rnn_initial_state_placeholder: initial_state_batch, 267 | rnn_use_meta_placeholder: True 268 | } 269 | 270 | # summary, conf, accuracy = curModel.run(args, session, feed_values) 271 | 272 | if args.train == "train": 273 | _, _ , curLoss, classLoss, accuracy, gen_accuracy = session.run([train_op_d, train_op_gan, loss_op, 274 | class_loss, accuracy_op, gen_accuracy_op], feed_dict=feed_dict) 275 | else: # Sample case not necessary b/c function will only be called during normal runs 276 | pass 277 | # summary, loss, probabilities, prediction, accuracy, confusion_matrix = session.run([self.summary_op, self.loss_op, self.probabilities_op, self.prediction_op, self.accuracy_op, self.confusion_matrix], feed_dict=feed_dict) 278 | 279 | print "The current total Discriminator loss is {0} in epoch {1}".format(curLoss, i) 280 | print "The current Class loss for real data is {0} in epoch {1}".format(classLoss, i) 281 | print "The current accuracy for real data is {0} in epoch {1}".format(accuracy, i) 282 | print "The current accuracy for generator is {0} in epoch {1} \n".format(gen_accuracy, i) 283 | # file_writer.add_summary(summary, step) 284 | # 285 | # # Update confusion matrix 286 | # confusion_matrix += conf 287 | # 288 | # # Record batch accuracies for test code 289 | # if args.train == "test" or args.train == 'dev': 290 | # batch_accuracies.append(accuracy) 291 | # 292 | # # Processed another batch 293 | # step += 1 294 | 295 | if args.train == "train": 296 | # Checkpoint model - every epoch 297 | utils_runtime.save_checkpoint(args, session, saver, i) 298 | confusion_suffix = str(i) 299 | else: # dev or test (NOT sample) 300 | test_accuracy = np.mean(batch_accuracies) 301 | print "Model {0} accuracy: {1}".format(args.train, test_accuracy) 302 | confusion_suffix = "_{0}-set".format(args.train) 303 | 304 | if args.train == 'dev': 305 | # Update the file for choosing best hyperparameters 306 | curFile = open(curModel.config.dev_filename, 'a') 307 | curFile.write("Dev set accuracy: {0}".format(test_accuracy)) 308 | curFile.write('\n') 309 | curFile.close() 310 | 311 | # Plot Confusion Matrix 312 | # plot_confusion(confusion_matrix, vocabulary, confusion_suffix+"_all") 313 | # plot_confusion(confusion_matrix, vocabulary, confusion_suffix+"_removed", characters_remove=['|', '2', '']) 314 | 315 | 316 | 317 | def main(_): 318 | 319 | args = utils_runtime.parseCommandLine() 320 | run_gan(args) 321 | 322 | if args.train != "sample": 323 | if tf.gfile.Exists(SUMMARY_DIR): 324 | tf.gfile.DeleteRecursively(SUMMARY_DIR) 325 | tf.gfile.MakeDirs(SUMMARY_DIR) 326 | 327 | if __name__ == "__main__": 328 | tf.app.run() 329 | -------------------------------------------------------------------------------- /utils_preprocess.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import os 3 | import re 4 | import pickle 5 | import shutil 6 | 7 | import numpy as np 8 | 9 | from collections import Counter 10 | from multiprocessing import Pool 11 | 12 | from utils import * 13 | 14 | FORMAT_DIR = 'formatted' 15 | CHECK_DIR = 'checked' 16 | ENCODE_TEST_DIR = 'test_encoded' 17 | ENCODE_TRAIN_DIR = 'train_encoded' 18 | ENCODE_DEV_DIR = 'dev_encoded' 19 | NN_INPUT_TEST_DIR = 'nn_input_test' 20 | NN_INPUT_TRAIN_DIR = 'nn_input_train' 21 | NN_INPUT_DEV_DIR = 'nn_input_dev' 22 | 23 | def eraseUnreadable(folderN): 24 | iterf = 0 25 | for fileN in os.listdir(folderN): 26 | fileAbsN = folderN+'/'+fileN 27 | try: 28 | if iterf%100 == 0: 29 | print iterf 30 | iterf += 1 31 | midi_data = pretty_midi.PrettyMIDI(fileAbsN) 32 | #roll = midi_data.get_piano_roll() 33 | except: 34 | print 'failed: '+fileN 35 | os.remove(fileAbsN) 36 | 37 | def formatABCtxtWorker(dataPack): 38 | filename,outputname,duet = dataPack 39 | header = True 40 | headerDict = {} 41 | headerTup = ('T', 'R', 'M', 'L', 'K', 'Q') 42 | headerDefault = {'T':'none', 'R':'none', 'M':'4/4', 'L':'1/8', 'K':'C', 'Q':'1/4=100'} 43 | print filename 44 | with open(filename,'r') as infile: 45 | fileStr = infile.read().replace(' ','').replace('\r','\n') 46 | if 'X:' not in fileStr: 47 | return 48 | 49 | with open(outputname,'w') as outfile: 50 | for line in fileStr.split('\n'): 51 | line = line.strip() 52 | # skip empty lines 53 | if line=='' or '%' in line: 54 | continue 55 | 56 | if header: 57 | headerStr = re.match('[a-zA-Z]:',line) 58 | 59 | if headerStr is None: 60 | header = False 61 | for head in headerTup: 62 | if head in headerDict: 63 | if head=='R': 64 | headerDict[head] = headerDict[head][0:2]+headerDict[head][2:].lower() 65 | outfile.write(headerDict[head]+'\n') 66 | else: 67 | outfile.write('%s:%s\n' %(head,headerDefault[head])) 68 | 69 | elif headerStr.group()[0] in headerTup: 70 | headerChr = headerStr.group()[0] 71 | if 'none' == line[line.find(':')+1:].lower(): 72 | headerDict[headerChr] = '%s:%s\n' %(headerChr, headerDefault[headerChr]) 73 | else: 74 | headerDict[headerChr] = line 75 | 76 | if not header: 77 | headerStr = re.match('[a-zA-Z]:',line) 78 | if (headerStr is not None) or (line[:3]=='[V:' and not duet): 79 | break 80 | if len(re.findall('\[V:[03-9]', line))!=0: 81 | continue 82 | 83 | # remove stuff inside of double quotes 84 | quotes = False 85 | addStr = '' 86 | for ch in line: 87 | if not quotes: 88 | addStr += ch 89 | if ch=='"': 90 | if quotes: 91 | addStr = addStr[:-1] 92 | quotes = not quotes 93 | 94 | outfile.write(addStr) 95 | 96 | if duet: 97 | outfile.write('%') 98 | 99 | outfile.write('\n') 100 | 101 | def formatABCtxt(folderName, outputFolder, isDuet): 102 | outputFolder = os.path.join(outputFolder, FORMAT_DIR) 103 | makedir(outputFolder) 104 | 105 | p = Pool(8) 106 | filenames = [re.sub(r'[^\x00-\x7f]',r'',fname) for fname in os.listdir(folderName)] 107 | mapList = [(os.path.join(folderName,fname), os.path.join(outputFolder,fname2), isDuet) 108 | for fname,fname2 in zip(os.listdir(folderName),filenames)] 109 | 110 | p.map(formatABCtxtWorker, mapList) 111 | 112 | MIN_MEASURES = 10 113 | NUM_TRANSPOSITIONS = 4 114 | def checkABCtxtWorker(dataPack): 115 | filename,outputname,isDuet = dataPack 116 | 117 | print filename 118 | header = True 119 | headerDict = {} 120 | headerTup = ('T', 'R', 'M', 'L', 'K', 'Q') 121 | with open(filename,'r') as infile: 122 | fileStr = infile.read() 123 | fileList = fileStr.split('\n') 124 | 125 | # checking stage 126 | #----------------------------- 127 | # each .abc file needs to be 8 lines long (6 metadata, 1 music, and 1 empty line) 128 | if len(fileList)!=8: 129 | print filename+': Does not have 8 lines' 130 | return 131 | 132 | # check that the file contains all metadata tags 133 | for i,header in enumerate(headerTup): 134 | if fileList[i][0] != header: 135 | print filename+': Does not contain the metadata '+header 136 | return 137 | 138 | # make sure that there are more than MIN_MEASURES measures in the song 139 | if fileStr.replace('||','|').count('|')len(music): 356 | break 357 | output_window = music[output_start:output_end] 358 | 359 | tup = (meta, input_window, output_window) 360 | tupList.append(tup) 361 | 362 | count += 1 363 | 364 | # add another window which includes the last element 365 | if nnType=='char_rnn' or nnType=='BOW': 366 | start_indx = len(music)-window_sz-1 367 | output_start = start_indx+1 368 | output_end = start_indx+window_sz+1 if nnType=='char_rnn' else output_start+1 369 | elif nnType=='seq2seq': 370 | start_indx = len(music)-window_sz-output_sz 371 | output_start = start_indx+window_sz 372 | output_end = output_start+output_sz 373 | 374 | tupList.append((meta, music[start_indx:start_indx+window_sz], music[output_start:output_end])) 375 | 376 | return tupList 377 | 378 | def npy2nnInputWorker(dataPack): 379 | outfname,tupList = dataPack 380 | windowList = [] 381 | for tup in tupList: 382 | windowList += npy2nnInputWorkerWorker(tup) 383 | 384 | pickle.dump(windowList, open(outfname,'wb')) 385 | 386 | def npy2nnInput(outputFolder, stride_sz, window_sz, nnType, output_sz=0, num_buckets=8): 387 | """ 388 | Converts encoded npy to an array of tuples for NN input 389 | 390 | @outputFolder - string / filename of h5 file to read from 391 | @stride_sz - int / stride size 392 | @window_sz - int / window size of the input 393 | @output_sz - int / window size of the output (only used for nnType='seq2seq') 394 | @nnType - string / nn to feed the generated data to. 395 | 'BOW' 'seq2seq' 'char_rnn' 396 | @num_buckets - int / number of files to generate 397 | """ 398 | 399 | if output_sz==0 and nnType=='seq2seq': 400 | print '[ERROR] npy2nnInput(): make sure to set the @output_sz for "seq2seq"' 401 | exit(0) 402 | 403 | dir_list = [(NN_INPUT_TEST_DIR, ENCODE_TEST_DIR), 404 | (NN_INPUT_TRAIN_DIR, ENCODE_TRAIN_DIR), 405 | (NN_INPUT_DEV_DIR, ENCODE_DEV_DIR)] 406 | 407 | for outDir,inDir in dir_list: 408 | inputList = [] 409 | outfName = outDir+'_stride_%d_window_%d_nnType_%s'%(stride_sz,window_sz,nnType) 410 | if nnType=='seq2seq': 411 | outfName += '_output_sz_%d' % output_sz 412 | 413 | nnFolder = os.path.join(outputFolder, outfName) 414 | makedir(nnFolder) 415 | 416 | encodedDir = os.path.join(outputFolder, inDir) 417 | for fname in os.listdir(encodedDir): 418 | inputList.append((stride_sz, window_sz, nnType, output_sz, os.path.join(encodedDir, fname))) 419 | 420 | mapList = [] 421 | for i in range(num_buckets): 422 | mapList.append((os.path.join(nnFolder,'%d.p'%i), 423 | inputList[int(i*len(inputList)/num_buckets) 424 | :int((i+1)*len(inputList)/num_buckets)])) 425 | 426 | p = Pool(8) 427 | p.map(npy2nnInputWorker, mapList) 428 | 429 | shuffleDataset(nnFolder) 430 | shutil.rmtree(nnFolder) 431 | 432 | def shuffleDataset(originalDir): 433 | print 'Shuffling %s' % originalDir 434 | outFolder = originalDir+'_shuffled' 435 | makedir(outFolder) 436 | 437 | input_list = [] 438 | filenames = os.listdir(originalDir) 439 | num_buckets = len(filenames) 440 | print 'Loading data' 441 | for filename in filenames: 442 | print filename 443 | with open(os.path.join(originalDir,filename),'r') as f: 444 | input_list += pickle.load(f) 445 | 446 | random.shuffle(input_list) 447 | 448 | print 'Done shuffling, saving the shuffled data...' 449 | for i,filename in enumerate(filenames): 450 | print filename 451 | with open(os.path.join(outFolder,filename),'w') as f: 452 | input_frac = input_list[int(i*len(input_list)/len(filenames)) 453 | :int((i+1)*len(input_list)/len(filenames))] 454 | pickle.dump(input_frac, f) 455 | 456 | def removeWrongDim(folderName): 457 | for subfolder in os.listdir(folderName): 458 | subfolderPath = os.path.join(folderName, subfolder) 459 | print subfolderPath 460 | 461 | if 'nn_input' in subfolderPath: 462 | inputSz = int(re.findall('[0-9]+', re.findall('window_[0-9]+', subfolderPath)[0])[0]) 463 | 464 | if 'output_sz_' in subfolderPath: 465 | outputSz = int(re.findall('[0-9]+', re.findall('output_sz_[0-9]+', subfolderPath)[0])[0]) 466 | else: 467 | outputSz = inputSz 468 | 469 | for filename in os.listdir(subfolderPath): 470 | with open(os.path.join(subfolderPath, filename), 'rb') as inF: 471 | print filename 472 | 473 | filename_abs = os.path.join(subfolderPath, filename) 474 | inputTupList = pickle.load(inF) 475 | 476 | deleteIndx = [] 477 | count = 0 478 | for meta,inArr,outArr in inputTupList: 479 | if (len(inArr) != inputSz) or (len(outArr) != outputSz): 480 | deleteIndx.append(count) 481 | count += 1 482 | 483 | if len(deleteIndx)==0: 484 | continue 485 | 486 | print 'Found %d tuple[s] with malformed inputs...' % len(deleteIndx) 487 | deleteIndx.reverse() 488 | for delI in deleteIndx: 489 | del inputTupList[delI] 490 | 491 | cleanFilename = os.path.join(subfolderPath, filename.replace('.','_cleaned.')) 492 | with open(cleanFilename, 'wb') as outF: 493 | pickle.dump(inputTupList, outF) 494 | 495 | os.remove(filename_abs) 496 | os.rename(cleanFilename, filename_abs) 497 | 498 | 499 | if __name__ == "__main__": 500 | # # preprocessing pipeline 501 | # #----------------------------------- 502 | originalDataDir = '/data/full_dataset/duet' 503 | # processedDir = originalDataDir 504 | processedDir = originalDataDir+'_processed' 505 | isDuet = True 506 | 507 | # print '-'*20 + 'FORMATTING' + '-'*20 508 | # formatABCtxt(originalDataDir, processedDir, isDuet) 509 | # print '-'*20 + 'CHECKING' + '-'*20 510 | # checkABCtxt(processedDir, isDuet) 511 | 512 | # # for Duet: 513 | # convertNewLines2Percent(processedDir) 514 | 515 | # print '-'*20 + 'SPLITTING' + '-'*20 516 | # datasetSplit(processedDir, (0.8,0.1,0.1)) 517 | # print '-'*20 + 'GENERATING VOCAB' + '-'*20 518 | # generateVocab(processedDir) 519 | # print '-'*20 + 'ENCODING' + '-'*20 520 | # encodeABC(processedDir) 521 | # print '-'*20 + 'FORMING NNINPUTS' + '-'*20 522 | # npy2nnInput(processedDir, 25, 10, 'seq2seq', output_sz=10) 523 | # print '-'*20 + 'FORMING NNINPUTS' + '-'*20 524 | # npy2nnInput(processedDir, 25, 25, 'seq2seq', output_sz=25) 525 | # print '-'*20 + 'FORMING NNINPUTS' + '-'*20 526 | # npy2nnInput(processedDir, 10, 100, 'seq2seq', output_sz=100) 527 | # print '-'*20 + 'FORMING NNINPUTS' + '-'*20 528 | # npy2nnInput(processedDir, 25, 10, 'char_rnn') 529 | # print '-'*20 + 'FORMING NNINPUTS' + '-'*20 530 | # npy2nnInput(processedDir, 25, 25, 'char_rnn') 531 | # print '-'*20 + 'FORMING NNINPUTS' + '-'*20 532 | # npy2nnInput(processedDir, 25, 50, 'char_rnn') 533 | # print '-'*20 + 'REMOVING WRONG DIMENSIONS' + '-'*20 534 | # removeWrongDim(processedDir) 535 | # #----------------------------------- 536 | # pass 537 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | import matplotlib.pyplot as plt 4 | 5 | import tensorflow as tf 6 | import numpy as np 7 | import os 8 | import sys 9 | from models import CharRNN, Config, Seq2SeqRNN, CBOW, GenAdversarialNet 10 | import pickle 11 | import reader 12 | import random 13 | import utils_runtime 14 | import utils_hyperparam 15 | import utils 16 | import re 17 | import copy 18 | 19 | tf_ver = tf.__version__ 20 | SHERLOCK = (str(tf_ver) == '0.12.1') 21 | 22 | # TRAIN_DATA = '/data/small_processed/nn_input_train' 23 | # DEVELOPMENT_DATA = '/data/small_processed/nn_input_test' 24 | 25 | # for Sherlock 26 | if SHERLOCK: 27 | DIR_MODIFIER = '/scratch/users/nipuna1' 28 | from tensorflow.contrib.metrics import confusion_matrix as tf_confusion_matrix 29 | # for Azure 30 | else: 31 | DIR_MODIFIER = '/data' 32 | 33 | TRAIN_DATA = DIR_MODIFIER + '/full_dataset/char_rnn_dataset/nn_input_train_stride_25_window_25_nnType_char_rnn_shuffled' 34 | TEST_DATA = DIR_MODIFIER + '/full_dataset/char_rnn_dataset/nn_input_test_stride_25_window_25_nnType_char_rnn_shuffled' 35 | DEVELOPMENT_DATA = DIR_MODIFIER + '/full_dataset/char_rnn_dataset/nn_input_dev_stride_25_window_25_nnType_char_rnn_shuffled' 36 | 37 | GAN_TRAIN_DATA = DIR_MODIFIER + '/full_dataset/gan_dataset/nn_input_train_stride_25_window_25_nnType_seq2seq_output_sz_25_shuffled' 38 | GAN_TEST_DATA = DIR_MODIFIER + '/full_dataset/gan_dataset/nn_input_test_stride_25_window_25_nnType_seq2seq_output_sz_25_shuffled' 39 | GAN_DEVELOPMENT_DATA = DIR_MODIFIER + '/full_dataset/gan_dataset/nn_input_dev_stride_25_window_25_nnType_seq2seq_output_sz_25_shuffled' 40 | 41 | SUMMARY_DIR = DIR_MODIFIER + '/dev_summary2' 42 | 43 | BATCH_SIZE = 100 # should be dynamically passed into Config 44 | NUM_EPOCHS = 50 45 | GPU_CONFIG = tf.ConfigProto() 46 | GPU_CONFIG.gpu_options.per_process_gpu_memory_fraction = 0.3 47 | 48 | # For T --> inf, p is uniform. Easy to sample from! 49 | # For T --> 0, p "concentrates" on arg max. Hard to sample from! 50 | TEMPERATURE = 1.0 51 | 52 | meta_map = pickle.load(open(os.path.join(DIR_MODIFIER, 'full_dataset/global_map_meta.p'),'rb')) 53 | music_map = pickle.load(open(os.path.join(DIR_MODIFIER, 'full_dataset/global_map_music.p'),'rb')) 54 | 55 | def plot_confusion(confusion_matrix, vocabulary, epoch, characters_remove=[], annotate=False): 56 | # Get vocabulary components 57 | vocabulary_keys = music_map.keys() 58 | vocabulary_values = music_map.values() 59 | # print vocabulary_keys 60 | vocabulary_values, vocabulary_keys = tuple([list(tup) for tup in zip(*sorted(zip(vocabulary_values, vocabulary_keys)))]) 61 | # print vocabulary_keys 62 | 63 | removed_indicies = [] 64 | for c in characters_remove: 65 | i = vocabulary_keys.index(c) 66 | vocabulary_keys.remove(c) 67 | index = vocabulary_values.pop(i) 68 | removed_indicies.append(index) 69 | 70 | # Delete unnecessary rows 71 | conf_temp = np.delete(confusion_matrix, removed_indicies, axis=0) 72 | # Delete unnecessary cols 73 | new_confusion = np.delete(conf_temp, removed_indicies, axis=1) 74 | 75 | 76 | vocabulary_values = range(len(vocabulary_keys)) 77 | vocabulary_size = len(vocabulary_keys) 78 | 79 | fig, ax = plt.subplots(figsize=(10, 10)) 80 | res = ax.imshow(new_confusion.astype(int), interpolation='nearest', cmap=plt.cm.jet) 81 | cb = fig.colorbar(res) 82 | 83 | if annotate: 84 | for x in xrange(vocabulary_size): 85 | for y in xrange(vocabulary_size): 86 | ax.annotate(str(new_confusion[x, y]), xy=(y, x), 87 | horizontalalignment='center', 88 | verticalalignment='center', 89 | fontsize=4) 90 | 91 | plt.xticks(vocabulary_values, vocabulary_keys, fontsize=6) 92 | plt.yticks(vocabulary_values, vocabulary_keys, fontsize=6) 93 | fig.savefig('confusion_matrix_epoch{0}.png'.format(epoch)) 94 | 95 | 96 | 97 | def sample_Seq2Seq(args, curModel, cell_type, session, warm_chars, vocabulary, meta, batch_size): 98 | num_encode = [len(warm_chars)] 99 | num_decode = [1000] 100 | 101 | if cell_type == 'lstm': 102 | initial_state_sample = [[np.zeros(curModel.config.hidden_size) for entry in xrange(batch_size)] for layer in xrange(curModel.config.num_layers)] 103 | else: 104 | initial_state_sample = [np.zeros(curModel.config.hidden_size) for entry in xrange(batch_size)] 105 | 106 | feed_values = utils_runtime.pack_feed_values(args, [warm_chars], 107 | [[vocabulary[""]]], [np.zeros_like(meta)], 108 | initial_state_sample, True, 109 | num_encode, num_decode) 110 | # logits, state = curModel.sample(session, feed_values) 111 | prediction = curModel.sample(session, feed_values) 112 | print len(prediction[0]) 113 | return prediction 114 | 115 | 116 | def sampleCBOW(session, args, curModel, vocabulary_decode): 117 | # Sample Model 118 | warm_length = curModel.input_size 119 | warm_meta, warm_chars = utils_runtime.genWarmStartDataset(warm_length, meta_map, music_map) 120 | 121 | warm_meta_array = [warm_meta] 122 | # warm_meta_array = [warm_meta[:] for idx in xrange(3)] 123 | # warm_meta_array[1][4] = 1 - warm_meta_array[1][4] 124 | # warm_meta_array[1][3] = np.random.choice(11) 125 | 126 | print "Sampling from single RNN cell using warm start of ({0})".format(warm_length) 127 | for meta in warm_meta_array: 128 | print "Current Metadata: {0}".format(meta) 129 | generated = warm_chars[:] 130 | context_window = warm_chars[:] 131 | 132 | # Warm Start (get the first prediction) 133 | feed_values = utils_runtime.pack_feed_values(args, [context_window], [[0]*len(context_window)], 134 | None, None, None, None, None) 135 | logits,_ = curModel.sample(session, feed_values) 136 | 137 | # Sample 138 | sampled_character = utils_runtime.sample_with_temperature(logits, TEMPERATURE) 139 | #while sampled_character!=END_TOKEN_ID and len(generated) < 200: 140 | while len(generated) < 200: 141 | # update the context input for the model 142 | context_window = context_window[1:] + [sampled_character] 143 | 144 | feed_values = utils_runtime.pack_feed_values(args, [context_window], [[0]*len(context_window)], 145 | None, None, None, None, None) 146 | logits,_ = curModel.sample(session, feed_values) 147 | 148 | sampled_character = utils_runtime.sample_with_temperature(logits, TEMPERATURE) 149 | generated.append(sampled_character) 150 | 151 | decoded_characters = [vocabulary_decode[char] for char in generated] 152 | 153 | # Currently chopping off the last char regardless if its or not 154 | encoding = utils.encoding2ABC(meta, generated[1:-1], meta_map, music_map) 155 | 156 | return encoding 157 | 158 | 159 | def run_model(args): 160 | # used by song_generator.py 161 | if hasattr(args, 'temperature'): 162 | global TEMPERATURE 163 | TEMPERATURE = args.temperature 164 | 165 | if hasattr(args, 'warm_len'): 166 | warm_length = args.warm_len 167 | else: 168 | warm_length = 15 169 | 170 | if hasattr(args, 'meta_map'): 171 | global meta_map,music_map 172 | meta_map = pickle.load(open(os.path.join(DIR_MODIFIER, args.meta_map),'rb')) 173 | music_map = pickle.load(open(os.path.join(DIR_MODIFIER, args.music_map),'rb')) 174 | 175 | use_seq2seq_data = (args.model == 'seq2seq') 176 | if args.data_dir != '': 177 | dataset_dir = args.data_dir 178 | elif args.train == 'train': 179 | dataset_dir = GAN_TRAIN_DATA if use_seq2seq_data else TRAIN_DATA 180 | elif args.train == 'test': 181 | dataset_dir = GAN_TEST_DATA if use_seq2seq_data else TEST_DATA 182 | else: # args.train == 'dev' or 'sample' (which has no dataset, but we just read anyway) 183 | dataset_dir = GAN_DEVELOPMENT_DATA if use_seq2seq_data else DEVELOPMENT_DATA 184 | 185 | print 'Using dataset %s' %dataset_dir 186 | dateset_filenames = reader.abc_filenames(dataset_dir) 187 | 188 | # figure out the input data size 189 | window_sz = int(re.findall('[0-9]+', re.findall('window_[0-9]+', dataset_dir)[0])[0]) 190 | if 'output_sz' in dataset_dir: 191 | label_sz = int(re.findall('[0-9]+', re.findall('output_sz_[0-9]+', dataset_dir)[0])[0]) 192 | else: 193 | label_sz = window_sz 194 | 195 | input_size = 1 if (args.train == "sample" and args.model!='cbow') else window_sz 196 | initial_size = 7 197 | label_size = 1 if args.train == "sample" else label_sz 198 | batch_size = 1 if args.train == "sample" else BATCH_SIZE 199 | NUM_EPOCHS = args.num_epochs 200 | print "Using checkpoint directory: {0}".format(args.ckpt_dir) 201 | 202 | # Getting vocabulary mapping: 203 | vocab_sz = len(music_map) 204 | music_map[""] = vocab_sz 205 | music_map[""] = vocab_sz+1 206 | if use_seq2seq_data: 207 | music_map[""] = vocab_sz+2 208 | 209 | vocabulary_size = len(music_map) 210 | vocabulary_decode = dict(zip(music_map.values(), music_map.keys())) 211 | 212 | start_encode = music_map[""] if (args.train == "sample" and use_seq2seq_data) else music_map[""] 213 | end_encode = music_map[""] 214 | 215 | cell_type = 'lstm' 216 | # cell_type = 'gru' 217 | # cell_type = 'rnn' 218 | 219 | if args.model == 'seq2seq': 220 | curModel = Seq2SeqRNN(input_size, label_size, batch_size, vocabulary_size, cell_type, args.set_config, start_encode, end_encode) 221 | curModel.create_model(is_train = (args.train=='train')) 222 | curModel.train() 223 | curModel.metrics() 224 | 225 | elif args.model == 'char': 226 | curModel = CharRNN(input_size, label_size, batch_size, vocabulary_size, cell_type, args.set_config) 227 | curModel.create_model(is_train = (args.train=='train')) 228 | curModel.train() 229 | curModel.metrics() 230 | 231 | elif args.model == 'cbow': 232 | curModel = CBOW(input_size, batch_size, vocabulary_size, args.set_config) 233 | curModel.create_model() 234 | curModel.train() 235 | curModel.metrics() 236 | 237 | print "Running {0} model for {1} epochs.".format(args.model, NUM_EPOCHS) 238 | 239 | print "Reading in {0}-set filenames.".format(args.train) 240 | 241 | global_step = tf.Variable(0, trainable=False, name='global_step') #tf.contrib.framework.get_or_create_global_step() 242 | saver = tf.train.Saver(max_to_keep=NUM_EPOCHS) 243 | step = 0 244 | 245 | with tf.Session(config=GPU_CONFIG) as session: 246 | print "Inititialized TF Session!" 247 | 248 | # Checkpoint 249 | i_stopped, found_ckpt = utils_runtime.get_checkpoint(args, session, saver) 250 | 251 | # file_writer = tf.summary.FileWriter(SUMMARY_DIR, graph=session.graph, max_queue=10, flush_secs=30) 252 | file_writer = tf.summary.FileWriter(args.ckpt_dir, graph=session.graph, max_queue=10, flush_secs=30) 253 | confusion_matrix = np.zeros((vocabulary_size, vocabulary_size)) 254 | batch_accuracies = [] 255 | 256 | if args.train == "train": 257 | init_op = tf.global_variables_initializer() # tf.group(tf.initialize_all_variables(), tf.initialize_local_variables()) 258 | init_op.run() 259 | else: 260 | # Exit if no checkpoint to test 261 | if not found_ckpt: 262 | return 263 | NUM_EPOCHS = i_stopped + 1 264 | 265 | # Sample Model 266 | if args.train == "sample": 267 | if args.model=='cbow': 268 | encoding = sampleCBOW(session, args, curModel, vocabulary_decode) 269 | return encoding 270 | 271 | # Sample Model 272 | if hasattr(args, 'warmupData'): 273 | warm_meta, warm_chars = utils_runtime.genWarmStartDataset(warm_length, meta_map, 274 | music_map, dataFolder=args.warmupData) 275 | else: 276 | warm_meta, warm_chars = utils_runtime.genWarmStartDataset(warm_length, meta_map, music_map) 277 | 278 | # warm_meta_array = [warm_meta[:] for idx in xrange(5)] 279 | warm_meta_array = [warm_meta[:] for idx in xrange(10)] 280 | 281 | # Change Key 282 | warm_meta_array[1][4] = 1 - warm_meta_array[1][4] 283 | # Change Number of Flats/Sharps 284 | warm_meta_array[2][3] = np.random.choice(11) 285 | # Lower Complexity 286 | warm_meta_array[3][6] = 50 287 | # Higher Complexity 288 | warm_meta_array[4][6] = 350 289 | # Higher LEngth 290 | warm_meta_array[5][5] = 30 291 | 292 | new_warm_meta = utils_runtime.encode_meta_batch(meta_map, warm_meta_array) 293 | new_warm_meta_array = zip(warm_meta_array, new_warm_meta) 294 | 295 | print "Sampling from single RNN cell using warm start of ({0})".format(warm_length) 296 | for old_meta, meta in new_warm_meta_array: 297 | print "Current Metadata: {0}".format(meta) 298 | generated = warm_chars[:] 299 | 300 | if args.model == 'char': 301 | # Warm Start 302 | for j, c in enumerate(warm_chars): 303 | if cell_type == 'lstm': 304 | if j == 0: 305 | initial_state_sample = [[np.zeros(curModel.config.hidden_size) for entry in xrange(batch_size)] for layer in xrange(curModel.config.num_layers)] 306 | else: 307 | initial_state_sample = [] 308 | for lstm_tuple in state: 309 | initial_state_sample.append(lstm_tuple[0]) 310 | else: 311 | initial_state_sample = [np.zeros(curModel.config.hidden_size) for entry in xrange(batch_size)] if (j == 0) else state[0] 312 | 313 | feed_values = utils_runtime.pack_feed_values(args, [[c]], 314 | [[0]], [meta], 315 | initial_state_sample, (j == 0), 316 | None, None) 317 | logits, state = curModel.sample(session, feed_values) 318 | 319 | # Sample 320 | sampled_character = utils_runtime.sample_with_temperature(logits, TEMPERATURE) 321 | while sampled_character != music_map[""] and len(generated) < 100: 322 | if cell_type == 'lstm': 323 | initial_state_sample = [] 324 | for lstm_tuple in state: 325 | initial_state_sample.append(lstm_tuple[0]) 326 | else: 327 | initial_state_sample = state[0] 328 | 329 | feed_values = utils_runtime.pack_feed_values(args, [[sampled_character]], 330 | [[0]], [np.zeros_like(meta)], 331 | initial_state_sample, False, 332 | None, None) 333 | logits, state = curModel.sample(session, feed_values) 334 | 335 | sampled_character = utils_runtime.sample_with_temperature(logits, TEMPERATURE) 336 | generated.append(sampled_character) 337 | 338 | elif args.model == 'seq2seq': 339 | prediction = sample_Seq2Seq(args, curModel, cell_type, session, warm_chars, music_map, meta, batch_size) 340 | generated.extend(prediction.flatten()) 341 | 342 | decoded_characters = [vocabulary_decode[char] for char in generated] 343 | 344 | encoding = utils.encoding2ABC(old_meta, generated, meta_map, music_map) 345 | 346 | if hasattr(args, 'ran_from_script'): 347 | return encoding 348 | 349 | # Train, dev, test model 350 | else: 351 | for i in xrange(i_stopped, NUM_EPOCHS): 352 | print "Running epoch ({0})...".format(i) 353 | random.shuffle(dateset_filenames) 354 | for j, data_file in enumerate(dateset_filenames): 355 | # Get train data - into feed_dict 356 | data = reader.read_abc_pickle(data_file) 357 | random.shuffle(data) 358 | data_batches = reader.abc_batch(data, n=batch_size) 359 | for k, data_batch in enumerate(data_batches): 360 | meta_batch, input_window_batch, output_window_batch = tuple([list(tup) for tup in zip(*data_batch)]) 361 | new_meta_batch = utils_runtime.encode_meta_batch(meta_map, meta_batch) 362 | 363 | initial_state_batch = [[np.zeros(curModel.config.hidden_size) for entry in xrange(batch_size)] for layer in xrange(curModel.config.num_layers)] 364 | num_encode = [window_sz] * batch_size 365 | num_decode = num_encode[:] 366 | 367 | feed_values = utils_runtime.pack_feed_values(args, input_window_batch, 368 | output_window_batch, new_meta_batch, 369 | initial_state_batch, True, 370 | num_encode, num_decode) 371 | 372 | summary, conf, accuracy = curModel.run(args, session, feed_values) 373 | 374 | file_writer.add_summary(summary, step) 375 | 376 | # Update confusion matrix 377 | confusion_matrix += conf 378 | 379 | # Record batch accuracies for test code 380 | if args.train == "test" or args.train == 'dev': 381 | batch_accuracies.append(accuracy) 382 | 383 | # Processed another batch 384 | step += 1 385 | 386 | if args.train == "train": 387 | # Checkpoint model - every epoch 388 | utils_runtime.save_checkpoint(args, session, saver, i) 389 | confusion_suffix = str(i) 390 | else: # dev or test (NOT sample) 391 | test_accuracy = np.mean(batch_accuracies) 392 | print "Model {0} accuracy: {1}".format(args.train, test_accuracy) 393 | confusion_suffix = "_{0}-set".format(args.train) 394 | 395 | if args.train == 'dev': 396 | # Update the file for choosing best hyperparameters 397 | curFile = open(curModel.config.dev_filename, 'a') 398 | curFile.write("Dev set accuracy: {0}".format(test_accuracy)) 399 | curFile.write('\n') 400 | curFile.close() 401 | 402 | # Plot Confusion Matrix 403 | plot_confusion(confusion_matrix, music_map, confusion_suffix+"_all") 404 | plot_confusion(confusion_matrix, music_map, confusion_suffix+"_removed", characters_remove=['|', '2', '']) 405 | 406 | def main(_): 407 | 408 | args = utils_runtime.parseCommandLine() 409 | run_model(args) 410 | 411 | if args.train != "sample": 412 | if tf.gfile.Exists(SUMMARY_DIR): 413 | tf.gfile.DeleteRecursively(SUMMARY_DIR) 414 | tf.gfile.MakeDirs(SUMMARY_DIR) 415 | 416 | 417 | if __name__ == "__main__": 418 | tf.app.run() 419 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | tf_ver = tf.__version__ 4 | SHERLOCK = (str(tf_ver) == '0.12.1') 5 | 6 | from tensorflow.contrib import rnn 7 | if SHERLOCK: 8 | from tensorflow.python.ops import rnn_cell as rnn 9 | from tensorflow.contrib.metrics import confusion_matrix as tf_confusion_matrix 10 | else: 11 | from tensorflow.contrib import rnn 12 | 13 | 14 | from tensorflow.contrib import seq2seq 15 | import utils_models 16 | import numpy as np 17 | import sys 18 | import os 19 | import logging 20 | import math 21 | 22 | import utils_hyperparam 23 | 24 | 25 | class Config(object): 26 | 27 | def setIfNotSet(self, attrStr, val): 28 | if not hasattr(self, attrStr): 29 | setattr(self, attrStr, val) 30 | 31 | def __init__(self, hyperparam_path): 32 | if len(hyperparam_path)!=0: 33 | print "Setting hyperparameters from a file %s" %hyperparam_path 34 | utils_hyperparam.setHyperparam(self, hyperparam_path) 35 | 36 | self.setIfNotSet('batch_size', 100) 37 | self.setIfNotSet('lr', 0.001) 38 | 39 | self.setIfNotSet('songtype', 19) #20 40 | self.setIfNotSet('sign', 16) 41 | self.setIfNotSet('notesize', 5) 42 | self.setIfNotSet('flats', 12) 43 | self.setIfNotSet('mode', 6) 44 | 45 | self.setIfNotSet('len', 1) 46 | self.setIfNotSet('complex', 1) 47 | self.setIfNotSet('max_length', 8) 48 | 49 | self.setIfNotSet('vocab_size', 81) 50 | self.setIfNotSet('meta_embed', 160) #self.songtype/2 51 | self.setIfNotSet('hidden_size', self.meta_embed*5 + 2) 52 | self.setIfNotSet('embedding_dims', 20) 53 | self.setIfNotSet('vocab_meta', self.songtype + self.sign + self.notesize + self.flats + self.mode) 54 | self.setIfNotSet('num_meta', 7) 55 | self.setIfNotSet('num_layers', 2) 56 | self.setIfNotSet('keep_prob', 0.6) 57 | 58 | # Only for CBOW model 59 | self.setIfNotSet('embed_size', 32) 60 | 61 | # Only for Seq2Seq Attention Models 62 | self.setIfNotSet('num_encode', 8) 63 | self.setIfNotSet('num_decode', 4) 64 | self.setIfNotSet('attention_option', 'bahdanau') 65 | self.setIfNotSet('bidirectional', False) 66 | 67 | # Discriminator Parameters 68 | self.setIfNotSet('numFilters', 32) 69 | self.setIfNotSet('hidden_units', 100) 70 | self.setIfNotSet('num_outputs', 2) 71 | self.setIfNotSet('cnn_lr', 0.001) 72 | self.setIfNotSet('label_smooth', 0.15) 73 | self.setIfNotSet('generator_prob', 0.1) 74 | self.setIfNotSet('num_classes', 19) 75 | self.setIfNotSet('gan_lr', 0.001) 76 | 77 | 78 | class CBOW(object): 79 | 80 | def __init__(self, input_size, batch_size, vocab_size, hyperparam_path): 81 | self.config = Config(hyperparam_path) 82 | self.input_size = input_size 83 | self.config.batch_size = batch_size 84 | self.config.vocab_size = vocab_size 85 | self.input_placeholder = tf.placeholder(tf.int32, shape=[None, self.input_size], name="Inputs") 86 | self.label_placeholder = tf.placeholder(tf.int32, shape=[None], name="Labels") 87 | self.embeddings = tf.Variable(tf.random_uniform([self.config.vocab_size, 88 | self.config.embed_size], -1.0, 1.0)) 89 | 90 | print("Completed Initializing the CBOW Model.....") 91 | 92 | def create_model(self): 93 | weight = tf.get_variable("Wout", shape=[self.config.embed_size, self.config.vocab_size], 94 | initializer=tf.contrib.layers.xavier_initializer()) 95 | bias = tf.Variable(tf.zeros([self.config.vocab_size])) 96 | 97 | word_vec = tf.nn.embedding_lookup(self.embeddings, self.input_placeholder) 98 | average_embedding = tf.reduce_sum(word_vec, reduction_indices=1) 99 | 100 | self.logits_op = tf.add(tf.matmul(average_embedding, weight), bias) 101 | self.probabilities_op = tf.nn.softmax(self.logits_op) 102 | print("Built the CBOW Model.....") 103 | 104 | def train(self): 105 | self.loss_op = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits_op, labels=self.label_placeholder)) 106 | tf.summary.scalar('Loss', self.loss_op) 107 | self.train_op = tf.train.AdamOptimizer(self.config.lr).minimize(self.loss_op) 108 | 109 | print("Setup the training mechanism for the CBOW model.....") 110 | 111 | def metrics(self): 112 | last_axis = len(self.probabilities_op.get_shape().as_list()) 113 | self.prediction_op = tf.to_int32(tf.argmax(self.probabilities_op, axis=last_axis-1)) 114 | difference = self.label_placeholder - self.prediction_op 115 | zero = tf.constant(0, dtype=tf.int32) 116 | boolean_difference = tf.cast(tf.equal(difference, zero), tf.float64) 117 | self.accuracy_op = tf.reduce_mean(boolean_difference) 118 | tf.summary.scalar('Accuracy', self.accuracy_op) 119 | 120 | self.summary_op = tf.summary.merge_all() 121 | 122 | if SHERLOCK: 123 | self.confusion_matrix = tf_confusion_matrix(tf.reshape(self.label_placeholder, [-1]), tf.reshape(self.prediction_op, [-1]), num_classes=self.config.vocab_size, dtype=tf.int32) 124 | else: 125 | self.confusion_matrix = tf.confusion_matrix(tf.reshape(self.label_placeholder, [-1]), tf.reshape(self.prediction_op, [-1]), num_classes=self.config.vocab_size, dtype=tf.int32) 126 | 127 | def _feed_dict(self, feed_values): 128 | input_batch = feed_values[0] 129 | label_batch = feed_values[1] 130 | feed_dict = { 131 | self.input_placeholder: input_batch, 132 | self.label_placeholder: label_batch 133 | } 134 | return feed_dict 135 | 136 | def run(self, args, session, feed_values): 137 | feed_dict = self._feed_dict(feed_values) 138 | 139 | if args.train == "train": 140 | _, summary, loss, probabilities, prediction, accuracy, confusion_matrix = session.run([self.train_op, self.summary_op, self.loss_op, self.probabilities_op, self.prediction_op, self.accuracy_op, self.confusion_matrix], feed_dict=feed_dict) 141 | else: # Sample case not necessary b/c function will only be called during normal runs 142 | summary, loss, probabilities, prediction, accuracy, confusion_matrix = session.run([self.summary_op, self.loss_op, self.probabilities_op, self.prediction_op, self.accuracy_op, self.confusion_matrix], feed_dict=feed_dict) 143 | 144 | print "Average accuracy per batch {0}".format(accuracy) 145 | print "Batch Loss: {0}".format(loss) 146 | # print "Output Predictions: {0}".format(prediction) 147 | # print "Output Prediction Probabilities: {0}".format(probabilities) 148 | 149 | return summary, confusion_matrix, accuracy 150 | 151 | 152 | def sample(self, session, feed_values): 153 | feed_dict = self._feed_dict(feed_values) 154 | 155 | logits = session.run([self.logits_op], feed_dict=feed_dict)[0] 156 | return logits, np.zeros((1, 1)) # dummy value 157 | 158 | 159 | 160 | class CharRNN(object): 161 | 162 | def __init__(self, input_size, label_size, batch_size, vocab_size, cell_type, hyperparam_path, gan_inputs=None): 163 | # with tf.variable_scope("CharRNN") as scope: 164 | self.input_size = input_size 165 | self.label_size = label_size 166 | self.cell_type = cell_type 167 | self.config = Config(hyperparam_path) 168 | self.config.batch_size = batch_size 169 | self.config.vocab_size = vocab_size 170 | self.gan_inputs = gan_inputs 171 | 172 | # self.initial_state = self.cell.zero_state(self.config.batch_size, dtype=tf.int32) 173 | # input_shape = (None,) + tuple([self.config.max_length,input_size]) 174 | input_shape = (None,) + tuple([input_size]) 175 | # output_shape = (None,) + tuple([self.config.max_length,label_size]) 176 | output_shape = (None,) + tuple([label_size]) 177 | 178 | self.input_placeholder = tf.placeholder(tf.int32, shape=input_shape, name='Input') 179 | self.label_placeholder = tf.placeholder(tf.int32, shape=output_shape, name='Output') 180 | self.meta_placeholder = tf.placeholder(tf.int32, shape=[None, self.config.num_meta], name='Meta') 181 | self.use_meta_placeholder = tf.placeholder(tf.bool, name='State_Initialization_Bool') 182 | 183 | if cell_type == 'rnn': 184 | self.cell = rnn.BasicRNNCell(self.config.hidden_size) 185 | self.initial_state_placeholder = tf.placeholder(tf.float32, shape=[None, self.config.hidden_size], name="Initial_State") 186 | elif cell_type == 'gru': 187 | self.cell = rnn.GRUCell(self.config.hidden_size) 188 | self.initial_state_placeholder = tf.placeholder(tf.float32, shape=[None, self.config.hidden_size], name="Initial_State") 189 | elif cell_type == 'lstm': 190 | self.cell = rnn.BasicLSTMCell(self.config.hidden_size) 191 | self.initial_state_placeholder = tf.placeholder(tf.float32, shape=[self.config.num_layers, None, self.config.hidden_size], name="Initial_State") 192 | 193 | print "Completed Initializing the Char RNN Model using a {0} cell".format(cell_type.upper()) 194 | 195 | 196 | def create_model(self, is_train=True): 197 | # with tf.variable_scope("CharRNN") as scope: 198 | if is_train: 199 | self.cell = rnn.DropoutWrapper(self.cell, input_keep_prob=1.0, output_keep_prob=self.config.keep_prob) 200 | rnn_model = rnn.MultiRNNCell([self.cell]*self.config.num_layers, state_is_tuple=True) 201 | 202 | # Embedding lookup for ABC format characters 203 | embeddings_var = tf.Variable(tf.random_uniform([self.config.vocab_size, self.config.embedding_dims], 204 | 0, 10, dtype=tf.float32, seed=3), name='char_embeddings') 205 | true_inputs = self.input_placeholder if (self.gan_inputs == None) else self.gan_inputs 206 | self.embeddings_var = embeddings_var 207 | embeddings = tf.nn.embedding_lookup(embeddings_var, true_inputs) 208 | 209 | # Embedding lookup for Metadata 210 | embeddings_var_meta = tf.Variable(tf.random_uniform([self.config.vocab_meta, self.config.meta_embed], 211 | 0, 10, dtype=tf.float32, seed=3), name='char_embeddings_meta') 212 | embeddings_meta = tf.nn.embedding_lookup(embeddings_var_meta, self.meta_placeholder[:, :5]) 213 | 214 | embeddings_meta_flat = tf.reshape(embeddings_meta, shape=[-1, self.config.hidden_size-2]) 215 | 216 | # Putting all the word embeddings together and then appending the numerical constants at the end of the word embeddings 217 | embeddings_meta = tf.concat([embeddings_meta_flat, tf.to_float(self.meta_placeholder[:, 5:])], axis=-1) 218 | 219 | print embeddings_meta.get_shape().as_list() 220 | if self.cell_type == 'lstm': 221 | initial_added = tf.cond(self.use_meta_placeholder, 222 | lambda: [embeddings_meta for layer in xrange(self.config.num_layers)], 223 | lambda: tf.unstack(self.initial_state_placeholder, axis=0)) # [self.initial_state_placeholder[layer] for layer in xrange(self.config.num_layers)]) 224 | [initial_added[idx].set_shape([self.config.batch_size, self.config.hidden_size]) for idx in xrange(self.config.num_layers)] 225 | initial_tuple = tuple([rnn.LSTMStateTuple(initial_added[idx], np.zeros((self.config.batch_size, self.config.hidden_size), dtype=np.float32)) for idx in xrange(self.config.num_layers)]) 226 | else: 227 | initial_added = tf.cond(self.use_meta_placeholder, 228 | lambda: embeddings_meta, 229 | lambda: self.initial_state_placeholder) 230 | initial_tuple = (initial_added, np.zeros((self.config.batch_size, self.config.hidden_size), dtype=np.float32)) 231 | 232 | rnn_output, self.state_op = tf.nn.dynamic_rnn(rnn_model, embeddings, dtype=tf.float32, initial_state=initial_tuple) 233 | 234 | decode_var = tf.Variable(tf.random_uniform([self.config.hidden_size, self.config.vocab_size], 235 | 0, 10, dtype=tf.float32, seed=3), name='char_decode') 236 | decode_bias = tf.Variable(tf.random_uniform([self.config.vocab_size], 237 | 0, 10, dtype=tf.float32, seed=3), name='char_decode_bias') 238 | decode_list = [] 239 | for i in xrange(self.input_size): 240 | decode_list.append(tf.matmul(rnn_output[:, i, :], decode_var) + decode_bias) 241 | 242 | self.logits_op = tf.stack(decode_list, axis=1) 243 | self.rnn_output = rnn_output 244 | self.probabilities_op = tf.nn.softmax(self.logits_op) 245 | 246 | print("Built the Char RNN Model...") 247 | 248 | 249 | def train(self, max_norm=5, op='adam'): 250 | # with tf.variable_scope("CharRNN") as scope: 251 | self.loss_op = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits_op, labels=self.label_placeholder)) 252 | tf.summary.scalar('Loss', self.loss_op) 253 | tvars = tf.trainable_variables() 254 | 255 | # Gradient clipping 256 | grads, _ = tf.clip_by_global_norm(tf.gradients(self.loss_op, tvars),max_norm) 257 | optimizer = tf.train.AdamOptimizer(self.config.lr) 258 | self.train_op = optimizer.apply_gradients(zip(grads, tvars)) 259 | 260 | print("Setup the training mechanism for the Char RNN Model...") 261 | 262 | 263 | def metrics(self): 264 | # Same function, did not make a general one b/c need to store _ops within class 265 | last_axis = len(self.probabilities_op.get_shape().as_list()) 266 | self.prediction_op = tf.to_int32(tf.argmax(self.probabilities_op, axis=last_axis-1)) 267 | difference = self.label_placeholder - self.prediction_op 268 | zero = tf.constant(0, dtype=tf.int32) 269 | boolean_difference = tf.cast(tf.equal(difference, zero), tf.float64) 270 | self.accuracy_op = tf.reduce_mean(boolean_difference) 271 | tf.summary.scalar('Accuracy', self.accuracy_op) 272 | 273 | self.summary_op = tf.summary.merge_all() 274 | 275 | if SHERLOCK: 276 | self.confusion_matrix = tf_confusion_matrix(tf.reshape(self.label_placeholder, [-1]), tf.reshape(self.prediction_op, [-1]), num_classes=self.config.vocab_size, dtype=tf.int32) 277 | else: 278 | self.confusion_matrix = tf.confusion_matrix(tf.reshape(self.label_placeholder, [-1]), tf.reshape(self.prediction_op, [-1]), num_classes=self.config.vocab_size, dtype=tf.int32) 279 | 280 | 281 | def _feed_dict(self, feed_values): 282 | input_batch = feed_values[0] 283 | label_batch = feed_values[1] 284 | meta_batch = feed_values[2] 285 | initial_state_batch = feed_values[3] 286 | use_meta_batch = feed_values[4] 287 | 288 | feed_dict = { 289 | self.input_placeholder: input_batch, 290 | self.label_placeholder: label_batch, 291 | self.meta_placeholder: meta_batch, 292 | self.initial_state_placeholder: initial_state_batch, 293 | self.use_meta_placeholder: use_meta_batch 294 | } 295 | 296 | return feed_dict 297 | 298 | 299 | def run(self, args, session, feed_values): 300 | feed_dict = self._feed_dict(feed_values) 301 | 302 | if args.train == "train": 303 | _, summary, loss, probabilities, prediction, accuracy, confusion_matrix = session.run([self.train_op, self.summary_op, self.loss_op, self.probabilities_op, self.prediction_op, self.accuracy_op, self.confusion_matrix], feed_dict=feed_dict) 304 | else: # Sample case not necessary b/c function will only be called during normal runs 305 | summary, loss, probabilities, prediction, accuracy, confusion_matrix = session.run([self.summary_op, self.loss_op, self.probabilities_op, self.prediction_op, self.accuracy_op, self.confusion_matrix], feed_dict=feed_dict) 306 | 307 | print "Average accuracy per batch {0}".format(accuracy) 308 | print "Batch Loss: {0}".format(loss) 309 | # print "Output Predictions: {0}".format(prediction) 310 | # print "Output Prediction Probabilities: {0}".format(probabilities) 311 | 312 | return summary, confusion_matrix, accuracy 313 | 314 | 315 | def sample(self, session, feed_values): 316 | feed_dict = self._feed_dict(feed_values) 317 | 318 | logits, state = session.run([self.logits_op, self.state_op], feed_dict=feed_dict) 319 | return logits, state 320 | 321 | 322 | class CharRNNScope(object): 323 | 324 | def __init__(self, input_size, label_size, batch_size, vocab_size, cell_type, hyperparam_path, gan_inputs=None): 325 | with tf.variable_scope("CharRNN") as scope: 326 | self.input_size = input_size 327 | self.label_size = label_size 328 | self.cell_type = cell_type 329 | self.config = Config(hyperparam_path) 330 | self.config.batch_size = batch_size 331 | self.config.vocab_size = vocab_size 332 | self.gan_inputs = gan_inputs 333 | 334 | # self.initial_state = self.cell.zero_state(self.config.batch_size, dtype=tf.int32) 335 | # input_shape = (None,) + tuple([self.config.max_length,input_size]) 336 | input_shape = (None,) + tuple([input_size]) 337 | # output_shape = (None,) + tuple([self.config.max_length,label_size]) 338 | output_shape = (None,) + tuple([label_size]) 339 | 340 | self.input_placeholder = tf.placeholder(tf.int32, shape=input_shape, name='Input') 341 | self.label_placeholder = tf.placeholder(tf.int32, shape=output_shape, name='Output') 342 | self.meta_placeholder = tf.placeholder(tf.int32, shape=[None, self.config.num_meta], name='Meta') 343 | self.use_meta_placeholder = tf.placeholder(tf.bool, name='State_Initialization_Bool') 344 | 345 | if cell_type == 'rnn': 346 | self.cell = rnn.BasicRNNCell(self.config.hidden_size) 347 | self.initial_state_placeholder = tf.placeholder(tf.float32, shape=[None, self.config.hidden_size], name="Initial_State") 348 | elif cell_type == 'gru': 349 | self.cell = rnn.GRUCell(self.config.hidden_size) 350 | self.initial_state_placeholder = tf.placeholder(tf.float32, shape=[None, self.config.hidden_size], name="Initial_State") 351 | elif cell_type == 'lstm': 352 | self.cell = rnn.BasicLSTMCell(self.config.hidden_size) 353 | self.initial_state_placeholder = tf.placeholder(tf.float32, shape=[self.config.num_layers, None, self.config.hidden_size], name="Initial_State") 354 | 355 | print "Completed Initializing the Char RNN Model using a {0} cell".format(cell_type.upper()) 356 | 357 | 358 | def create_model(self, is_train=True): 359 | with tf.variable_scope("CharRNN") as scope: 360 | if is_train: 361 | self.cell = rnn.DropoutWrapper(self.cell, input_keep_prob=1.0, output_keep_prob=self.config.keep_prob) 362 | rnn_model = rnn.MultiRNNCell([self.cell]*self.config.num_layers, state_is_tuple=True) 363 | 364 | # Embedding lookup for ABC format characters 365 | embeddings_var = tf.Variable(tf.random_uniform([self.config.vocab_size, self.config.embedding_dims], 366 | 0, 10, dtype=tf.float32, seed=3), name='char_embeddings') 367 | true_inputs = self.input_placeholder if (self.gan_inputs == None) else self.gan_inputs 368 | self.embeddings_var = embeddings_var 369 | embeddings = tf.nn.embedding_lookup(embeddings_var, true_inputs) 370 | 371 | # Embedding lookup for Metadata 372 | embeddings_var_meta = tf.Variable(tf.random_uniform([self.config.vocab_meta, self.config.meta_embed], 373 | 0, 10, dtype=tf.float32, seed=3), name='char_embeddings_meta') 374 | embeddings_meta = tf.nn.embedding_lookup(embeddings_var_meta, self.meta_placeholder[:, :5]) 375 | 376 | embeddings_meta_flat = tf.reshape(embeddings_meta, shape=[-1, self.config.hidden_size-2]) 377 | 378 | # Putting all the word embeddings together and then appending the numerical constants at the end of the word embeddings 379 | embeddings_meta = tf.concat([embeddings_meta_flat, tf.to_float(self.meta_placeholder[:, 5:])], axis=-1) 380 | 381 | print embeddings_meta.get_shape().as_list() 382 | if self.cell_type == 'lstm': 383 | initial_added = tf.cond(self.use_meta_placeholder, 384 | lambda: [embeddings_meta for layer in xrange(self.config.num_layers)], 385 | lambda: tf.unstack(self.initial_state_placeholder, axis=0)) # [self.initial_state_placeholder[layer] for layer in xrange(self.config.num_layers)]) 386 | [initial_added[idx].set_shape([self.config.batch_size, self.config.hidden_size]) for idx in xrange(self.config.num_layers)] 387 | initial_tuple = tuple([rnn.LSTMStateTuple(initial_added[idx], np.zeros((self.config.batch_size, self.config.hidden_size), dtype=np.float32)) for idx in xrange(self.config.num_layers)]) 388 | else: 389 | initial_added = tf.cond(self.use_meta_placeholder, 390 | lambda: embeddings_meta, 391 | lambda: self.initial_state_placeholder) 392 | initial_tuple = (initial_added, np.zeros((self.config.batch_size, self.config.hidden_size), dtype=np.float32)) 393 | 394 | rnn_output, self.state_op = tf.nn.dynamic_rnn(rnn_model, embeddings, dtype=tf.float32, initial_state=initial_tuple) 395 | 396 | decode_var = tf.Variable(tf.random_uniform([self.config.hidden_size, self.config.vocab_size], 397 | 0, 10, dtype=tf.float32, seed=3), name='char_decode') 398 | decode_bias = tf.Variable(tf.random_uniform([self.config.vocab_size], 399 | 0, 10, dtype=tf.float32, seed=3), name='char_decode_bias') 400 | decode_list = [] 401 | for i in xrange(self.input_size): 402 | decode_list.append(tf.matmul(rnn_output[:, i, :], decode_var) + decode_bias) 403 | 404 | self.logits_op = tf.stack(decode_list, axis=1) 405 | self.rnn_output = rnn_output 406 | self.probabilities_op = tf.nn.softmax(self.logits_op) 407 | 408 | print("Built the Char RNN Model...") 409 | 410 | 411 | def train(self, max_norm=5, op='adam'): 412 | with tf.variable_scope("CharRNN") as scope: 413 | self.loss_op = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits_op, labels=self.label_placeholder)) 414 | tf.summary.scalar('Loss', self.loss_op) 415 | tvars = tf.trainable_variables() 416 | 417 | # Gradient clipping 418 | grads, _ = tf.clip_by_global_norm(tf.gradients(self.loss_op, tvars),max_norm) 419 | optimizer = tf.train.AdamOptimizer(self.config.lr) 420 | self.train_op = optimizer.apply_gradients(zip(grads, tvars)) 421 | 422 | print("Setup the training mechanism for the Char RNN Model...") 423 | 424 | 425 | def metrics(self): 426 | # Same function, did not make a general one b/c need to store _ops within class 427 | last_axis = len(self.probabilities_op.get_shape().as_list()) 428 | self.prediction_op = tf.to_int32(tf.argmax(self.probabilities_op, axis=last_axis-1)) 429 | difference = self.label_placeholder - self.prediction_op 430 | zero = tf.constant(0, dtype=tf.int32) 431 | boolean_difference = tf.cast(tf.equal(difference, zero), tf.float64) 432 | self.accuracy_op = tf.reduce_mean(boolean_difference) 433 | tf.summary.scalar('Accuracy', self.accuracy_op) 434 | 435 | self.summary_op = tf.summary.merge_all() 436 | 437 | if SHERLOCK: 438 | self.confusion_matrix = tf_confusion_matrix(tf.reshape(self.label_placeholder, [-1]), tf.reshape(self.prediction_op, [-1]), num_classes=self.config.vocab_size, dtype=tf.int32) 439 | else: 440 | self.confusion_matrix = tf.confusion_matrix(tf.reshape(self.label_placeholder, [-1]), tf.reshape(self.prediction_op, [-1]), num_classes=self.config.vocab_size, dtype=tf.int32) 441 | 442 | 443 | def _feed_dict(self, feed_values): 444 | input_batch = feed_values[0] 445 | label_batch = feed_values[1] 446 | meta_batch = feed_values[2] 447 | initial_state_batch = feed_values[3] 448 | use_meta_batch = feed_values[4] 449 | 450 | feed_dict = { 451 | self.input_placeholder: input_batch, 452 | self.label_placeholder: label_batch, 453 | self.meta_placeholder: meta_batch, 454 | self.initial_state_placeholder: initial_state_batch, 455 | self.use_meta_placeholder: use_meta_batch 456 | } 457 | 458 | return feed_dict 459 | 460 | 461 | def run(self, args, session, feed_values): 462 | feed_dict = self._feed_dict(feed_values) 463 | 464 | if args.train == "train": 465 | _, summary, loss, probabilities, prediction, accuracy, confusion_matrix = session.run([self.train_op, self.summary_op, self.loss_op, self.probabilities_op, self.prediction_op, self.accuracy_op, self.confusion_matrix], feed_dict=feed_dict) 466 | else: # Sample case not necessary b/c function will only be called during normal runs 467 | summary, loss, probabilities, prediction, accuracy, confusion_matrix = session.run([self.summary_op, self.loss_op, self.probabilities_op, self.prediction_op, self.accuracy_op, self.confusion_matrix], feed_dict=feed_dict) 468 | 469 | print "Average accuracy per batch {0}".format(accuracy) 470 | print "Batch Loss: {0}".format(loss) 471 | # print "Output Predictions: {0}".format(prediction) 472 | # print "Output Prediction Probabilities: {0}".format(probabilities) 473 | 474 | return summary, confusion_matrix, accuracy 475 | 476 | 477 | def sample(self, session, feed_values): 478 | feed_dict = self._feed_dict(feed_values) 479 | 480 | logits, state = session.run([self.logits_op, self.state_op], feed_dict=feed_dict) 481 | return logits, state 482 | 483 | 484 | 485 | 486 | 487 | class Seq2SeqRNN(object): 488 | 489 | def __init__(self,input_size, label_size, batch_size, vocab_size, cell_type, 490 | hyperparam_path, start_encode, end_encode): 491 | self.input_size = input_size 492 | self.label_size = label_size 493 | self.cell_type = cell_type 494 | self.config = Config(hyperparam_path) 495 | self.config.batch_size = batch_size 496 | self.config.vocab_size = vocab_size 497 | 498 | # input_shape = (None,) + tuple([input_size]) 499 | input_shape = (None, None) 500 | # output_shape = (None,) + tuple([label_size]) 501 | output_shape = (None, None) 502 | 503 | self.input_placeholder = tf.placeholder(tf.int32, shape=input_shape, name='Input') 504 | self.label_placeholder = tf.placeholder(tf.int32, shape=output_shape, name='Output') 505 | self.meta_placeholder = tf.placeholder(tf.int32, shape=[None, self.config.num_meta], name='Meta') 506 | self.use_meta_placeholder = tf.placeholder(tf.bool, name='State_Initialization_Bool') 507 | self.num_encode = tf.placeholder(tf.int32, shape=(None,), name='Num_encode') 508 | self.num_decode = tf.placeholder(tf.int32, shape=(None,), name='Num_decode') 509 | 510 | if cell_type == 'rnn': 511 | self.encoder_cell = rnn.BasicRNNCell(self.config.hidden_size) 512 | self.decoder_cell = rnn.BasicRNNCell(2*self.config.hidden_size) 513 | self.initial_state_placeholder = tf.placeholder(tf.float32, shape=[None, self.config.hidden_size], name="Initial_State") 514 | elif cell_type == 'gru': 515 | self.encoder_cell = rnn.GRUCell(self.config.hidden_size) 516 | self.decoder_cell = rnn.GRUCell(2*self.config.hidden_size) 517 | self.initial_state_placeholder = tf.placeholder(tf.float32, shape=[None, self.config.hidden_size], name="Initial_State") 518 | elif cell_type == 'lstm': 519 | self.encoder_cell = rnn.BasicLSTMCell(self.config.hidden_size) 520 | self.decoder_cell = rnn.BasicLSTMCell(self.config.hidden_size) 521 | self.initial_state_placeholder = tf.placeholder(tf.float32, shape=[self.config.num_layers, None, self.config.hidden_size], name="Initial_State") 522 | 523 | print "Completed Initializing the Seq2Seq RNN Model using a {0} cell".format(cell_type.upper()) 524 | 525 | # Seq2Seq specific initializers 526 | self.attention_option = "luong" 527 | self.start_encode = start_encode 528 | self.end_encode = end_encode 529 | 530 | 531 | # Based on the example model presented in https://github.com/ematvey/tensorflow-seq2seq-tutorials/blob/master/model_new.py 532 | def create_model(self, is_train): 533 | with tf.variable_scope("Seq2Seq") as scope: 534 | 535 | def output_fn(outputs): 536 | return tf.contrib.layers.linear(outputs, self.config.vocab_size, scope=scope) 537 | 538 | if is_train: 539 | self.encoder_cell = rnn.DropoutWrapper(self.encoder_cell, input_keep_prob=1.0, output_keep_prob=self.config.keep_prob) 540 | self.decoder_cell = rnn.DropoutWrapper(self.decoder_cell, input_keep_prob=1.0, output_keep_prob=self.config.keep_prob) 541 | 542 | self.encoder_cell = rnn.MultiRNNCell([self.encoder_cell]*self.config.num_layers, state_is_tuple=True) 543 | self.decoder_cell = rnn.MultiRNNCell([self.decoder_cell]*self.config.num_layers, state_is_tuple=True) 544 | 545 | # GO_SLICE = tf.ones([tf.shape(self.input_placeholder)[0],1], dtype=tf.int32)*self.start_encode 546 | 547 | # self.decoder_train_inputs = self.label_placeholder[:,:self.input_size-1] 548 | self.go_token = tf.constant(self.config.vocab_size-1, dtype=tf.int32, shape=[1, self.config.batch_size]) 549 | self.decoder_train_inputs = tf.concat([self.go_token, self.label_placeholder[:self.input_size-1, :]], axis=0) 550 | 551 | self.decoder_train_targets = self.label_placeholder 552 | 553 | self.loss_weights = tf.ones([self.config.batch_size, self.input_size], dtype=tf.float32, name="loss_weights") 554 | sqrt3 = math.sqrt(3) 555 | initializer = tf.random_uniform_initializer(-sqrt3, sqrt3) 556 | 557 | # Creating the embeddings and deriving the embeddings for the encoder and decoder 558 | self.embedding_matrix = tf.get_variable(name="embedding_matrix", 559 | shape=[self.config.vocab_size, self.config.embedding_dims], initializer=initializer, 560 | dtype=tf.float32) 561 | 562 | self.encoder_embedded = tf.nn.embedding_lookup(self.embedding_matrix, self.input_placeholder) 563 | 564 | self.decoder_inputs_embedded = tf.nn.embedding_lookup(self.embedding_matrix, self.decoder_train_inputs) 565 | 566 | # Embedding lookup for Metadata 567 | embeddings_var_meta = tf.Variable(tf.random_uniform([self.config.vocab_meta, self.config.meta_embed], 568 | 0, 10, dtype=tf.float32, seed=3), name='char_embeddings_meta') 569 | embeddings_meta = tf.nn.embedding_lookup(embeddings_var_meta, self.meta_placeholder[:, :5]) 570 | 571 | embeddings_meta_flat = tf.reshape(embeddings_meta, shape=[-1, self.config.hidden_size-2]) 572 | 573 | # Putting all the word embeddings together and then appending the numerical constants at the end of the word embeddings 574 | embeddings_meta = tf.concat([embeddings_meta_flat, tf.to_float(self.meta_placeholder[:, 5:])], axis=-1) 575 | 576 | # Create initial_state 577 | if self.cell_type == 'lstm': 578 | initial_added = tf.cond(self.use_meta_placeholder, 579 | lambda: [embeddings_meta for layer in xrange(self.config.num_layers)], 580 | lambda: tf.unstack(self.initial_state_placeholder, axis=0)) # [self.initial_state_placeholder[layer] for layer in xrange(self.config.num_layers)]) 581 | [initial_added[idx].set_shape([self.config.batch_size, self.config.hidden_size]) for idx in xrange(self.config.num_layers)] 582 | initial_tuple = tuple([rnn.LSTMStateTuple(initial_added[idx], np.zeros((self.config.batch_size, self.config.hidden_size), dtype=np.float32)) for idx in xrange(self.config.num_layers)]) 583 | else: 584 | initial_added = tf.cond(self.use_meta_placeholder, 585 | lambda: embeddings_meta, 586 | lambda: self.initial_state_placeholder) 587 | initial_tuple = (initial_added, np.zeros((self.config.batch_size, self.config.hidden_size), dtype=np.float32)) 588 | 589 | if not self.config.bidirectional: 590 | self.encoder_outputs, self.encoder_state = tf.nn.dynamic_rnn(cell=self.encoder_cell, inputs=self.encoder_embedded, 591 | sequence_length=self.num_encode,time_major=True, dtype=tf.float32, initial_state=initial_tuple) 592 | else: 593 | ((encoder_fw_outputs,encoder_bw_outputs),\ 594 | (encoder_fw_state, encoder_bw_state)) = tf.nn.bidirectional_dynamic_rnn(cell_fw=self.encoder_cell, 595 | cell_bw=self.encoder_cell, inputs=self.encoder_embedded,initial_state_fw=initial_tuple, 596 | initial_state_bw=initial_tuple,sequence_length=self.num_encode, time_major=True, dtype=tf.float32) 597 | 598 | self.encoder_outputs = tf.concat((encoder_fw_outputs, encoder_bw_outputs), 2) 599 | 600 | if isinstance(encoder_fw_state, tuple): 601 | encoder_state_c = tf.concat( (encoder_fw_state[0], encoder_bw_state[1]), 1, name='bidirectional_concat_c') 602 | encoder_state_h = tf.concat( (encoder_fw_state[0], encoder_bw_state[1]), 1, name='bidirectional_concat_h') 603 | self.encoder_state = rnn.LSTMStateTuple(c=encoder_state_c, h=encoder_state_h) 604 | 605 | else: 606 | self.encoder_state = tf.concat((encoder_fw_state, encoder_bw_state), 1, name='bidirectional_concat') 607 | 608 | self.encoder_outputs = encoder_fw_outputs 609 | self.encoder_state = encoder_fw_state 610 | 611 | # Setting up the Attention mechanism 612 | # print type(self.encoder_outputs) 613 | # print type(self.encoder_state) 614 | attention_states = tf.transpose(self.encoder_outputs, [1, 0, 2]) 615 | 616 | attention_keys, attention_values, attention_score_fn, \ 617 | attention_construct_fn = seq2seq.prepare_attention( attention_states=attention_states, 618 | attention_option=self.attention_option, num_units=self.config.hidden_size) 619 | 620 | decoder_fn_train = seq2seq.attention_decoder_fn_train( encoder_state=self.encoder_state, 621 | attention_keys=attention_keys, attention_values=attention_values, 622 | attention_score_fn=attention_score_fn, attention_construct_fn=attention_construct_fn, 623 | name='attention_decoder') 624 | 625 | # decoder_fn_inference = seq2seq.attention_decoder_fn_inference( output_fn=output_fn, encoder_state=self.encoder_state, 626 | # attention_keys=attention_keys, attention_values=attention_values, attention_score_fn=attention_score_fn, 627 | # attention_construct_fn=attention_construct_fn, embeddings=self.embedding_matrix, 628 | # start_of_sequence_id=self.start_encode, end_of_sequence_id=self.end_encode, 629 | # maximum_length=tf.reduce_max(self.num_decode), num_decoder_symbols=self.config.vocab_size) 630 | 631 | decoder_fn_inference = utils_models.attention_decoder_fn_sampled_inference( output_fn=output_fn, encoder_state=self.encoder_state, 632 | attention_keys=attention_keys, attention_values=attention_values, attention_score_fn=attention_score_fn, 633 | attention_construct_fn=attention_construct_fn, embeddings=self.embedding_matrix, 634 | start_of_sequence_id=self.start_encode, end_of_sequence_id=self.end_encode, 635 | maximum_length=tf.reduce_max(self.num_decode) + 3, num_decoder_symbols=self.config.vocab_size, temperature=0.5) 636 | 637 | self.decoder_outputs_train, self.decoder_state_train, \ 638 | self.decoder_context_state_train = seq2seq.dynamic_rnn_decoder( cell=self.decoder_cell, 639 | decoder_fn=decoder_fn_train, inputs=self.decoder_inputs_embedded, 640 | sequence_length=self.num_decode, time_major=True, scope=scope) 641 | 642 | self.decoder_logits_train = tf.contrib.layers.linear(self.decoder_outputs_train, self.config.vocab_size, scope=scope) 643 | self.decoder_prediction_train = tf.argmax(self.decoder_logits_train, axis=-1, name='decoder_prediction_train') 644 | 645 | # self.decoder_prediction_train = tf.argmax(self.decoder_logits_train, axis=-1, name='decoder_prediction_train') 646 | 647 | scope.reuse_variables() 648 | 649 | self.decoder_logits_inference, self.decoder_state_inference, \ 650 | self.decoder_context_state_inference = seq2seq.dynamic_rnn_decoder(cell=self.decoder_cell, 651 | decoder_fn=decoder_fn_inference, time_major=True, scope=scope) 652 | 653 | 654 | self.decoder_prediction_inference = tf.argmax(self.decoder_logits_inference, axis=-1, name='decoder_prediction_inference') 655 | 656 | print("Built the Seq2Seq RNN Model...") 657 | 658 | 659 | def train(self, op='adam', max_norm=5): 660 | logits = tf.transpose(self.decoder_logits_train, [1, 0, 2]) 661 | targets = tf.transpose(self.decoder_train_targets, [1, 0]) 662 | # print self.decoder_logits_train.get_shape().as_list() 663 | # print self.decoder_train_targets.get_shape().as_list() 664 | # print logits.get_shape().as_list() 665 | # print targets.get_shape().as_list() 666 | 667 | self.loss_op = seq2seq.sequence_loss(logits=logits, targets=targets, 668 | weights=self.loss_weights) 669 | tf.summary.scalar('Loss', self.loss_op) 670 | self.train_op = tf.train.AdamOptimizer().minimize(self.loss_op) 671 | print("Setup the training mechanism for the Seq2Seq RNN Model...") 672 | 673 | 674 | def metrics(self): 675 | # Same function, did not make a general one b/c need to store _ops within class 676 | difference = self.decoder_train_targets - tf.cast(self.decoder_prediction_train, tf.int32) 677 | zero = tf.constant(0, dtype=tf.int32) 678 | boolean_difference = tf.cast(tf.equal(difference, zero), tf.float64) 679 | self.accuracy_op = tf.reduce_mean(boolean_difference) 680 | tf.summary.scalar('Accuracy', self.accuracy_op) 681 | 682 | self.summary_op = tf.summary.merge_all() 683 | 684 | if SHERLOCK: 685 | self.confusion_matrix = tf_confusion_matrix(tf.reshape(self.label_placeholder, [-1]), tf.reshape(self.decoder_prediction_train, [-1]), num_classes=self.config.vocab_size, dtype=tf.int32) 686 | else: 687 | self.confusion_matrix = tf.confusion_matrix(tf.reshape(self.label_placeholder, [-1]), tf.reshape(self.decoder_prediction_train, [-1]), num_classes=self.config.vocab_size, dtype=tf.int32) 688 | 689 | 690 | def _feed_dict(self, feed_values): 691 | input_batch = feed_values[0] 692 | label_batch = feed_values[1] 693 | meta_batch = feed_values[2] 694 | initial_state_batch = feed_values[3] 695 | use_meta_batch = feed_values[4] 696 | num_encode = feed_values[5] 697 | num_decode = feed_values[6] 698 | 699 | feed_dict = { 700 | self.input_placeholder: input_batch, 701 | self.label_placeholder: label_batch, 702 | self.meta_placeholder: meta_batch, 703 | self.initial_state_placeholder: initial_state_batch, 704 | self.use_meta_placeholder: use_meta_batch, 705 | self.num_encode: num_encode, 706 | self.num_decode: num_decode 707 | } 708 | 709 | return feed_dict 710 | 711 | 712 | def run(self, args, session, feed_values): 713 | feed_dict = self._feed_dict(feed_values) 714 | 715 | if args.train == "train": 716 | _, summary, loss, prediction, accuracy, confusion_matrix = session.run([self.train_op, self.summary_op, self.loss_op, self.decoder_prediction_train, self.accuracy_op, self.confusion_matrix], feed_dict=feed_dict) 717 | else: # Sample case not necessary b/c function will only be called during normal runs 718 | summary, loss, prediction, accuracy, confusion_matrix = session.run([self.summary_op, self.loss_op, self.decoder_prediction_train, self.accuracy_op, self.confusion_matrix], feed_dict=feed_dict) 719 | 720 | print "Average accuracy per batch {0}".format(accuracy) 721 | print "Batch Loss: {0}".format(loss) 722 | # print "Output Predictions: {0}".format(prediction) 723 | # print "Output Prediction Probabilities: {0}".format(probabilities) 724 | 725 | return summary, confusion_matrix, accuracy 726 | 727 | 728 | def sample(self, session, feed_values): 729 | feed_dict = self._feed_dict(feed_values) 730 | 731 | logits = tf.transpose(self.decoder_logits_inference, [1, 0, 2]) 732 | probabilities = tf.nn.softmax(logits) 733 | predictions = tf.argmax(probabilities, axis=-1) 734 | pred = session.run(predictions, feed_dict=feed_dict) 735 | return pred 736 | 737 | 738 | 739 | 740 | 741 | class Discriminator(object): 742 | 743 | def __init__(self, inputs, labels_size, is_training, batch_size, hyperparam_path, use_lrelu=True, use_batchnorm=False, dropout=None, reuse=True): 744 | self.input = inputs 745 | self.labels_size = labels_size 746 | self.batch_size = batch_size 747 | self.is_training = is_training 748 | self.reuse = reuse 749 | self.dropout = dropout 750 | self.use_batchnorm = use_batchnorm 751 | self.use_lrelu = use_lrelu 752 | self.config = Config(hyperparam_path) 753 | 754 | def lrelu(self, x, leak=0.2, name='lrelu'): 755 | return tf.maximum(x, leak*x) 756 | 757 | def conv_layer(self, inputs, filterSz, strideSz, padding, use_lrelu, use_batchnorm): 758 | with tf.variable_scope("discriminator") as scope: 759 | if self.reuse: 760 | scope.reuse_variables() 761 | 762 | filterWeights = tf.Variable(tf.random_normal(filterSz)) 763 | l1 = tf.nn.conv2d(inputs,filterWeights,strideSz,padding='SAME') 764 | if use_batchnorm: 765 | l1 = tf.contrib.layers.batch_norm(l1, decay=0.9,center=True,scale=True, 766 | epsilon=1e-8,is_training=is_training, reuse=self.reuse, trainable=True, scope=scope) 767 | 768 | if use_lrelu: 769 | l2 = self.lrelu(l1) 770 | else: 771 | l2 = tf.nn.relu(l1) 772 | 773 | if self.dropout is not None and self.is_training == True: 774 | l2 = tf.nn.dropout(l2, self.dropout) 775 | 776 | return l2 777 | 778 | 779 | def create_model(self): 780 | with tf.variable_scope("discriminator") as scope: 781 | if self.reuse: 782 | scope.reuse_variables() 783 | 784 | 785 | filterSz1 = [3, self.config.embedding_dims,1, self.config.numFilters] 786 | strideSz1 = [1,1,1,1] 787 | 788 | conv_layer1 = self.conv_layer(self.input,filterSz1,strideSz1, padding='SAME', use_lrelu=True, use_batchnorm=False) 789 | 790 | filterSz2 = [3, 1, self.config.numFilters, self.config.numFilters] 791 | strideSz2 = [1,1,1,1] 792 | conv_layer2 = self.conv_layer(conv_layer1,filterSz2,strideSz2,padding='SAME',use_lrelu=True, use_batchnorm=False) 793 | 794 | win_size = [1,3,1,1] 795 | strideSz3 = [1,1,1,1] 796 | conv_layer3 = tf.nn.max_pool(conv_layer2,ksize=win_size,strides=strideSz3, padding='SAME') 797 | 798 | layerShape = conv_layer3.get_shape().as_list() 799 | numParams = reduce(lambda x, y: x*y, layerShape[1:]) 800 | 801 | layer_flatten = tf.reshape(conv_layer3, [-1, numParams]) 802 | 803 | fully_conn_weights_1 = tf.get_variable("weights_fully_conn_1", [numParams, self.config.hidden_units], 804 | initializer=tf.random_normal_initializer()) 805 | fully_conn_bias_1 = tf.get_variable("bias_fully_conn_1", [self.config.hidden_units,], 806 | initializer=tf.random_normal_initializer()) 807 | layer4 = tf.matmul(layer_flatten,fully_conn_weights_1 ) + fully_conn_bias_1 808 | 809 | if self.dropout is not None and self.is_training == True: 810 | layer4 = tf.nn.dropout(layer4, self.dropout) 811 | 812 | fully_conn_weights_2 = tf.get_variable("weights_fully_conn_2", [self.config.hidden_units,self.labels_size ], 813 | initializer=tf.random_normal_initializer()) 814 | fully_conn_bias_2 = tf.get_variable("bias_fully_conn_2", [self.labels_size,], 815 | initializer=tf.random_normal_initializer()) 816 | layer5 = tf.matmul(layer4,fully_conn_weights_2 ) + fully_conn_bias_2 817 | 818 | self.output = layer5 819 | self.pred = tf.nn.softmax(layer5) 820 | 821 | return self.pred 822 | 823 | 824 | # def train(self, op='adam'): 825 | # self.loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.output, 826 | # labels=self.labels)) 827 | 828 | 829 | # train_op = tf.train.AdamOptimizer(self.config.cnn_lr).minimize(self.loss) 830 | # return train_op 831 | 832 | 833 | 834 | class GenAdversarialNet(object): 835 | 836 | 837 | def __init__(self, input_size, label_size ,num_classes, cell_type, is_training, batch_size, vocab_size, 838 | hyperparam_path, use_lrelu=True, use_batchnorm=False, dropout=None): 839 | self.input_size = input_size 840 | self.label_size = label_size 841 | self.is_training = is_training 842 | self.cell_type = cell_type 843 | self.hyperparam_path = hyperparam_path 844 | self.use_lrelu = use_lrelu 845 | self.use_batchnorm = use_batchnorm 846 | self.dropout = dropout 847 | self.config = Config(hyperparam_path) 848 | self.batch_size = batch_size 849 | self.config.vocab_size = vocab_size 850 | self.config.num_classes = num_classes 851 | 852 | output_shape = (None,) 853 | self.label_placeholder = tf.placeholder(tf.int32, shape=output_shape, name='Output') 854 | 855 | input_shape = (None,) + tuple([self.input_size]) 856 | self.input_placeholder = tf.placeholder(tf.int32, shape=input_shape, name='Input') 857 | 858 | print "Completed Initializing the GAN Model using a {0} cell".format(cell_type.upper()) 859 | 860 | 861 | # Function taken from Goodfellow's Codebase on Training of GANs: https://github.com/openai/improved-gan/ 862 | def sigmoid_kl_with_logits(self, logits, targets): 863 | # broadcasts the same target value across the whole batch 864 | # this is implemented so awkwardly because tensorflow lacks an x log x op 865 | if targets in [0., 1.]: 866 | entropy = 0. 867 | else: 868 | entropy = - targets * np.log(targets) - (1. - targets) * np.log(1. - targets) 869 | return tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=tf.ones_like(logits) * targets) - entropy 870 | 871 | # Ideas for function taken from Goodfellow's Codebase on Training of GANs: https://github.com/openai/improved-gan/ 872 | def normalize_class_outputs(self,logits): 873 | generated_class_logits = tf.squeeze(tf.slice(logits, [0, self.config.num_classes - 1], [self.batch_size, 1])) 874 | positive_class_logits = tf.slice(logits, [0, 0], [self.batch_size, self.config.num_classes - 1]) 875 | mx = tf.reduce_max(positive_class_logits, 1, keep_dims=True) 876 | safe_pos_class_logits = positive_class_logits - mx 877 | 878 | gan_logits = tf.log(tf.reduce_sum(tf.exp(safe_pos_class_logits), 1)) + tf.squeeze(mx) - generated_class_logits 879 | return gan_logits 880 | 881 | 882 | def create_model(self): 883 | gen_batch_size = self.batch_size/2 884 | disc_batch_size = self.batch_size/2 885 | 886 | generator_inputs = tf.slice(self.input_placeholder, [0,0], [gen_batch_size,self.input_size ]) 887 | 888 | # print generator_inputs.get_shape().as_list() 889 | self.generator_model = CharRNN(self.input_size, self.label_size, gen_batch_size ,self.config.vocab_size, 890 | self.cell_type, self.hyperparam_path, gan_inputs = generator_inputs) 891 | self.generator_model.create_model(is_train = True) 892 | self.generator_model.train() 893 | self.generator_output = self.generator_model.logits_op 894 | # print self.generator_output.get_shape().as_list() 895 | 896 | # Sample the output of the GAN to find the correct prediction of each character 897 | generator_samples = [] 898 | for i in xrange(self.input_size): 899 | generator_samples.append(tf.multinomial(self.generator_output[:,i,:], num_samples=1)) 900 | 901 | self.current_policy = tf.stack(generator_samples, axis=1) 902 | 903 | # Create the Discriminator embeddings and sample the Generator output and Real input from these embeddings 904 | real_inputs = tf.slice(self.input_placeholder, [disc_batch_size,0], [disc_batch_size,self.input_size ]) 905 | with tf.variable_scope("discriminator") as scope: 906 | self.embeddings_disc = tf.get_variable('disc_embeddings',[self.config.vocab_size, self.config.embedding_dims], 907 | dtype=tf.float32) 908 | 909 | embeddings_generator_out = tf.nn.embedding_lookup(self.embeddings_disc, self.current_policy) 910 | embeddings_real_input = tf.nn.embedding_lookup(self.embeddings_disc, real_inputs) 911 | 912 | # Inputs the fake examples from the CharRNN to the CNN Discriminator 913 | embeddings_generator_out = tf.expand_dims(embeddings_generator_out[:,:,0,:], -1) 914 | self.discriminator_gen_model = Discriminator(embeddings_generator_out, self.config.num_classes, is_training=self.is_training, 915 | batch_size=gen_batch_size, hyperparam_path=self.hyperparam_path, use_lrelu=self.use_lrelu, use_batchnorm=self.use_batchnorm, 916 | dropout=self.dropout, reuse=False) 917 | discriminator_gen_pred = self.discriminator_gen_model.create_model() 918 | 919 | # Inputs the real sequences from the text files to the CNN Discriminator 920 | embeddings_real_input = tf.expand_dims(embeddings_real_input, -1) 921 | self.discriminator_real_samp = Discriminator(embeddings_real_input, self.config.num_classes, is_training=self.is_training, 922 | batch_size=disc_batch_size, hyperparam_path=self.hyperparam_path, use_lrelu=self.use_lrelu, use_batchnorm=self.use_batchnorm, 923 | dropout=self.dropout, reuse=True) 924 | discriminator_real_pred = self.discriminator_real_samp.create_model() 925 | 926 | 927 | # Collecting outputs and finding losses 928 | self.gan_real_output = self.discriminator_real_samp.output 929 | self.gan_fake_output = self.discriminator_gen_model.output 930 | 931 | self.tot_gan_logits = tf.concat([self.gan_real_output, self.gan_fake_output], axis=0) 932 | print self.tot_gan_logits.get_shape().as_list() 933 | self.gan_logits_norm = self.normalize_class_outputs(self.tot_gan_logits) 934 | # self.gan_logits_fake = self.normalize_class_outputs(self.gan_fake_output) 935 | 936 | self.gan_logits_real = self.gan_logits_norm[:self.batch_size/2] 937 | self.gan_logits_fake = self.gan_logits_norm[self.batch_size/2:] 938 | 939 | self.gan_pred_real = self.sigmoid_kl_with_logits(self.gan_logits_real, 1. - self.config.label_smooth) 940 | self.gan_pred_fake = tf.nn.sigmoid_cross_entropy_with_logits(logits=self.gan_logits_fake, 941 | labels=tf.zeros_like(self.gan_logits_fake)) 942 | 943 | print("Built the GAN Model...") 944 | return self.gan_pred_real, self.gan_pred_fake 945 | 946 | 947 | 948 | def train(self): 949 | class_loss_weight = 1 950 | 951 | self.loss_class = class_loss_weight*tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.gan_real_output, 952 | labels=self.label_placeholder) 953 | 954 | self.tot_d_loss = tf.reduce_mean(self.gan_pred_real + self.gan_pred_fake + self.loss_class) 955 | # tot_g_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.generator_output, 956 | # labels=self.label_placeholder[self.config.batch_size/2:])) 957 | 958 | disc_accuracy = tf.reduce_mean(tf.to_float(tf.nn.in_top_k(self.gan_real_output, self.label_placeholder, 1))) 959 | disc_pred_corr_examples = tf.cast(tf.equal(tf.argmax(self.gan_fake_output, axis=-1), 960 | tf.ones([self.batch_size/2 ], dtype=tf.int64)*self.config.num_classes-1), tf.float32) 961 | gen_fool_dis_examples = tf.cast(tf.not_equal(tf.argmax(self.gan_fake_output, axis=-1), 962 | tf.ones([self.batch_size/2 ], dtype=tf.int64)*self.config.num_classes-1), tf.float32) 963 | 964 | num_real_class = tf.reduce_mean(gen_fool_dis_examples) 965 | num_fake_class = tf.reduce_mean(disc_pred_corr_examples) 966 | 967 | combined_labels = gen_fool_dis_examples + disc_pred_corr_examples 968 | 969 | combined_labels = tf.expand_dims(combined_labels, -1) 970 | combined_labels = tf.expand_dims(combined_labels, -1) 971 | # prob_grads = tf.multiply(combined_labels, tf.nn.softmax(self.generator_output)) 972 | 973 | all_disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator') 974 | all_gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='CharRNN') 975 | 976 | self.policy_grads = combined_labels*tf.one_hot(self.current_policy[:,:,0], depth=self.config.vocab_size)*self.generator_output 977 | self.d_gen_grad = tf.gradients(self.tot_d_loss, self.embeddings_disc) 978 | self.gen_grad = tf.gradients(self.policy_grads, all_gen_vars) 979 | self.train_op_d = tf.train.AdamOptimizer(self.config.gan_lr).apply_gradients(zip(self.d_gen_grad, [self.embeddings_disc])) 980 | self.train_op_gan = tf.train.AdamOptimizer(self.config.gan_lr).apply_gradients(zip(self.gen_grad, all_gen_vars)) 981 | 982 | print "Completed setup of training mechanism for the GAN...." 983 | return self.input_placeholder, self.label_placeholder, \ 984 | self.generator_model.meta_placeholder, self.generator_model.initial_state_placeholder, \ 985 | self.generator_model.use_meta_placeholder, self.train_op_d, self.train_op_gan, self.tot_d_loss, \ 986 | tf.reduce_mean(self.loss_class), disc_accuracy, num_real_class 987 | --------------------------------------------------------------------------------